package exhaustive import ( "bytes" "fmt" "go/ast" "go/printer" "go/token" "go/types" "regexp" "sort" "strconv" "strings" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/ast/inspector" ) func isDefaultCase(c *ast.CaseClause) bool { return c.List == nil // see doc comment on field } func checkSwitchStatements( pass *analysis.Pass, inspect *inspector.Inspector, ) error { comments := make(map[*ast.File]ast.CommentMap) // CommentMap per package file, lazily populated by reference generated := make(map[*ast.File]bool) return checkSwitchStatements_(pass, inspect, comments, generated) } func checkSwitchStatements_( pass *analysis.Pass, inspect *inspector.Inspector, comments map[*ast.File]ast.CommentMap, generated map[*ast.File]bool, ) error { inspect.WithStack([]ast.Node{&ast.SwitchStmt{}}, func(n ast.Node, push bool, stack []ast.Node) bool { if !push { return true } file := stack[0].(*ast.File) // Determine if file is a generated file, based on https://golang.org/s/generatedcode. // If generated, don't check this file. var isGenerated bool if gen, ok := generated[file]; ok { isGenerated = gen } else { isGenerated = isGeneratedFile(file) generated[file] = isGenerated } if isGenerated && !fCheckGeneratedFiles { // don't check return true } sw := n.(*ast.SwitchStmt) if sw.Tag == nil { return true } t := pass.TypesInfo.Types[sw.Tag] if !t.IsValue() { return true } tagType, ok := t.Type.(*types.Named) if !ok { return true } tagPkg := tagType.Obj().Pkg() if tagPkg == nil { // Doc comment: nil for labels and objects in the Universe scope. // This happens for the `error` type, for example. // Continuing would mean that ImportPackageFact panics. return true } var enums enumsFact if !pass.ImportPackageFact(tagPkg, &enums) { // Can't do anything further. return true } em, isEnum := enums.Enums[tagType.Obj().Name()] if !isEnum { // Tag's type is not a known enum. return true } // Get comment map. var allComments ast.CommentMap if cm, ok := comments[file]; ok { allComments = cm } else { allComments = ast.NewCommentMap(pass.Fset, file, file.Comments) comments[file] = allComments } specificComments := allComments.Filter(sw) for _, group := range specificComments.Comments() { if containsIgnoreDirective(group.List) { return true // skip checking due to ignore directive } } samePkg := tagPkg == pass.Pkg checkUnexported := samePkg hitlist := hitlistFromEnumMembers(em, tagPkg, checkUnexported, fIgnorePattern.Get().(*regexp.Regexp)) if len(hitlist) == 0 { return true } var defaultCase *ast.CaseClause for _, stmt := range sw.Body.List { caseCl := stmt.(*ast.CaseClause) if isDefaultCase(caseCl) { defaultCase = caseCl continue // nothing more to do if it's the default case } for _, e := range caseCl.List { e = astutil.Unparen(e) if samePkg { ident, ok := e.(*ast.Ident) if !ok { continue } updateHitlist(hitlist, em, ident.Name) } else { selExpr, ok := e.(*ast.SelectorExpr) if !ok { continue } // ensure X is package identifier ident, ok := selExpr.X.(*ast.Ident) if !ok { continue } if !isPackageNameIdentifier(pass, ident) { continue } updateHitlist(hitlist, em, selExpr.Sel.Name) } } } defaultSuffices := fDefaultSignifiesExhaustive && defaultCase != nil shouldReport := len(hitlist) > 0 && !defaultSuffices if shouldReport { reportSwitch(pass, sw, defaultCase, samePkg, tagType, em, hitlist, file) } return true }) return nil } func updateHitlist(hitlist map[string]struct{}, em *enumMembers, foundName string) { constVal, ok := em.NameToValue[foundName] if !ok { // only delete the name alone from hitlist delete(hitlist, foundName) return } // delete all of the same-valued names from hitlist namesToDelete := em.ValueToNames[constVal] for _, n := range namesToDelete { delete(hitlist, n) } } func isPackageNameIdentifier(pass *analysis.Pass, ident *ast.Ident) bool { obj := pass.TypesInfo.ObjectOf(ident) if obj == nil { return false } _, ok := obj.(*types.PkgName) return ok } func hitlistFromEnumMembers(em *enumMembers, enumPkg *types.Package, checkUnexported bool, ignorePattern *regexp.Regexp) map[string]struct{} { hitlist := make(map[string]struct{}) for _, name := range em.OrderedNames { if name == "_" { // blank identifier is often used to skip entries in iota lists continue } if ignorePattern != nil && ignorePattern.MatchString(enumPkg.Path()+"."+name) { continue } if !ast.IsExported(name) && !checkUnexported { continue } hitlist[name] = struct{}{} } return hitlist } func determineMissingOutput(missingMembers map[string]struct{}, em *enumMembers) []string { constValMembers := make(map[string][]string) // value -> names var otherMembers []string // non-constant value names for m := range missingMembers { if constVal, ok := em.NameToValue[m]; ok { constValMembers[constVal] = append(constValMembers[constVal], m) } else { otherMembers = append(otherMembers, m) } } missingOutput := make([]string, 0, len(constValMembers)+len(otherMembers)) for _, names := range constValMembers { sort.Strings(names) missingOutput = append(missingOutput, strings.Join(names, "|")) } missingOutput = append(missingOutput, otherMembers...) sort.Strings(missingOutput) return missingOutput } func reportSwitch( pass *analysis.Pass, sw *ast.SwitchStmt, defaultCase *ast.CaseClause, samePkg bool, enumType *types.Named, em *enumMembers, missingMembers map[string]struct{}, f *ast.File, ) { missingOutput := determineMissingOutput(missingMembers, em) var fixes []analysis.SuggestedFix if fix, ok := computeFix(pass, pass.Fset, f, sw, defaultCase, enumType, samePkg, missingMembers); ok { fixes = append(fixes, fix) } pass.Report(analysis.Diagnostic{ Pos: sw.Pos(), End: sw.End(), Message: fmt.Sprintf("missing cases in switch of type %s: %s", enumTypeName(enumType, samePkg), strings.Join(missingOutput, ", ")), SuggestedFixes: fixes, }) } func computeFix(pass *analysis.Pass, fset *token.FileSet, f *ast.File, sw *ast.SwitchStmt, defaultCase *ast.CaseClause, enumType *types.Named, samePkg bool, missingMembers map[string]struct{}) (analysis.SuggestedFix, bool) { // Function and method calls may be mutative, so we don't want to reuse the // call expression in the about-to-be-inserted case clause body. So we just // don't suggest a fix in such situations. // // However, we need to make an exception for type conversions, which are // also call expressions in the AST. // // We'll need to lookup type information for this, and can't rely solely // on the AST. if containsFuncCall(pass, sw.Tag) { return analysis.SuggestedFix{}, false } textEdits := []analysis.TextEdit{missingCasesTextEdit(fset, f, samePkg, sw, defaultCase, enumType, missingMembers)} // need to add "fmt" import if "fmt" import doesn't already exist if !hasImportWithPath(fset, f, `"fmt"`) { textEdits = append(textEdits, fmtImportTextEdit(fset, f)) } missing := make([]string, 0, len(missingMembers)) for m := range missingMembers { missing = append(missing, m) } sort.Strings(missing) return analysis.SuggestedFix{ Message: fmt.Sprintf("add case clause for: %s", strings.Join(missing, ", ")), TextEdits: textEdits, }, true } func containsFuncCall(pass *analysis.Pass, e ast.Expr) bool { e = astutil.Unparen(e) c, ok := e.(*ast.CallExpr) if !ok { return false } if _, isFunc := pass.TypesInfo.TypeOf(c.Fun).Underlying().(*types.Signature); isFunc { return true } for _, a := range c.Args { if containsFuncCall(pass, a) { return true } } return false } func firstImportDecl(fset *token.FileSet, f *ast.File) *ast.GenDecl { for _, decl := range f.Decls { genDecl, ok := decl.(*ast.GenDecl) if ok && genDecl.Tok == token.IMPORT { // first IMPORT GenDecl return genDecl } } return nil } // copies an GenDecl in a manner such that appending to the returned GenDecl's Specs field // doesn't mutate the original GenDecl func copyGenDecl(im *ast.GenDecl) *ast.GenDecl { imCopy := *im imCopy.Specs = make([]ast.Spec, len(im.Specs)) for i := range im.Specs { imCopy.Specs[i] = im.Specs[i] } return &imCopy } func hasImportWithPath(fset *token.FileSet, f *ast.File, pathLiteral string) bool { igroups := astutil.Imports(fset, f) for _, igroup := range igroups { for _, importSpec := range igroup { if importSpec.Path.Value == pathLiteral { return true } } } return false } func fmtImportTextEdit(fset *token.FileSet, f *ast.File) analysis.TextEdit { firstDecl := firstImportDecl(fset, f) if firstDecl == nil { // file has no import declarations // insert "fmt" import spec after package statement return analysis.TextEdit{ Pos: f.Name.End() + 1, // end of package name + 1 End: f.Name.End() + 1, NewText: []byte(`import ( "fmt" )`), } } // copy because we'll be mutating its Specs field firstDeclCopy := copyGenDecl(firstDecl) // find insertion index for "fmt" import spec var i int for ; i < len(firstDeclCopy.Specs); i++ { im := firstDeclCopy.Specs[i].(*ast.ImportSpec) if v, _ := strconv.Unquote(im.Path.Value); v > "fmt" { break } } // insert "fmt" import spec at the index fmtSpec := &ast.ImportSpec{ Path: &ast.BasicLit{ // NOTE: Pos field doesn't seem to be required for our // purposes here. Kind: token.STRING, Value: `"fmt"`, }, } s := firstDeclCopy.Specs // local var for easier comprehension of next line s = append(s[:i], append([]ast.Spec{fmtSpec}, s[i:]...)...) firstDeclCopy.Specs = s // create the text edit var buf bytes.Buffer printer.Fprint(&buf, fset, firstDeclCopy) return analysis.TextEdit{ Pos: firstDecl.Pos(), End: firstDecl.End(), NewText: buf.Bytes(), } } func missingCasesTextEdit(fset *token.FileSet, f *ast.File, samePkg bool, sw *ast.SwitchStmt, defaultCase *ast.CaseClause, enumType *types.Named, missingMembers map[string]struct{}) analysis.TextEdit { // ... Construct insertion text for case clause and its body ... var tag bytes.Buffer printer.Fprint(&tag, fset, sw.Tag) // If possible and if necessary, determine the package identifier based on // the AST of other `case` clauses. var pkgIdent *ast.Ident if !samePkg { for _, stmt := range sw.Body.List { caseCl := stmt.(*ast.CaseClause) if len(caseCl.List) != 0 { // guard against default case if sel, ok := caseCl.List[0].(*ast.SelectorExpr); ok { pkgIdent = sel.X.(*ast.Ident) break } } } } missing := make([]string, 0, len(missingMembers)) for m := range missingMembers { if !samePkg { if pkgIdent != nil { // we were able to determine package identifier missing = append(missing, pkgIdent.Name+"."+m) } else { // use the package name (may not be correct always) // // TODO: May need to also add import if the package isn't imported // elsewhere. This (ie, a switch with zero case clauses) should // happen rarely, so don't implement this for now. missing = append(missing, enumType.Obj().Pkg().Name()+"."+m) } } else { missing = append(missing, m) } } sort.Strings(missing) insert := `case ` + strings.Join(missing, ", ") + `: panic(fmt.Sprintf("unhandled value: %v",` + tag.String() + `))` // ... Create the text edit ... pos := sw.Body.Rbrace - 1 // put it as last case if defaultCase != nil { pos = defaultCase.Case - 2 // put it before the default case (why -2?) } return analysis.TextEdit{ Pos: pos, End: pos, NewText: []byte(insert), } }