logginghandler/vendor/github.com/nishanths/exhaustive/switch.go
Marvin Preuss d095180eb4
All checks were successful
continuous-integration/drone/push Build is passing
build: uses go modules for tool handling
2022-01-14 13:51:56 +01:00

445 lines
12 KiB
Go

package exhaustive
import (
"bytes"
"fmt"
"go/ast"
"go/printer"
"go/token"
"go/types"
"regexp"
"sort"
"strconv"
"strings"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/ast/inspector"
)
// isDefaultCase reports whether c is the `default:` clause of a switch.
// go/ast marks the default clause with a nil List (not merely an empty one).
func isDefaultCase(c *ast.CaseClause) bool {
	noExprs := c.List == nil
	return noExprs
}
// checkSwitchStatements checks every switch statement of the package under
// analysis for exhaustiveness. It creates fresh per-file caches (comment maps
// and "is generated" results) and hands them to checkSwitchStatements_,
// which fills them lazily as it visits files.
func checkSwitchStatements(
	pass *analysis.Pass,
	inspect *inspector.Inspector,
) error {
	var (
		commentMaps    = make(map[*ast.File]ast.CommentMap)
		generatedFiles = make(map[*ast.File]bool)
	)
	return checkSwitchStatements_(pass, inspect, commentMaps, generatedFiles)
}
// checkSwitchStatements_ visits every switch statement in the package under
// analysis. A switch is reported if both of the following hold:
//   - its tag has a named type that the enum facts of the type's package
//     record as an enum;
//   - its case clauses leave some required enum member unmentioned.
//
// comments and generated are per-file caches keyed by *ast.File. Entries are
// added on first use of each file, so the maps may be shared across calls.
// The function always returns nil.
func checkSwitchStatements_(
	pass *analysis.Pass,
	inspect *inspector.Inspector,
	comments map[*ast.File]ast.CommentMap,
	generated map[*ast.File]bool,
) error {
	inspect.WithStack([]ast.Node{&ast.SwitchStmt{}}, func(n ast.Node, push bool, stack []ast.Node) bool {
		// Only work when entering a node; the pop visit has nothing to do.
		if !push {
			return true
		}
		// The bottom of the traversal stack is asserted to be the enclosing file.
		file := stack[0].(*ast.File)
		// Determine if file is a generated file, based on https://golang.org/s/generatedcode.
		// If generated, don't check this file.
		var isGenerated bool
		if gen, ok := generated[file]; ok {
			isGenerated = gen
		} else {
			isGenerated = isGeneratedFile(file)
			generated[file] = isGenerated
		}
		// fCheckGeneratedFiles is a package-level setting declared elsewhere
		// (presumably a command-line flag) that opts generated files back in.
		if isGenerated && !fCheckGeneratedFiles {
			// don't check
			return true
		}
		sw := n.(*ast.SwitchStmt)
		// A tagless `switch { ... }` has no value to be exhaustive over.
		if sw.Tag == nil {
			return true
		}
		// A missing entry yields the zero TypeAndValue, for which IsValue is false.
		t := pass.TypesInfo.Types[sw.Tag]
		if !t.IsValue() {
			return true
		}
		// Only named (defined) types can be enums.
		tagType, ok := t.Type.(*types.Named)
		if !ok {
			return true
		}
		tagPkg := tagType.Obj().Pkg()
		if tagPkg == nil {
			// Doc comment: nil for labels and objects in the Universe scope.
			// This happens for the `error` type, for example.
			// Continuing would mean that ImportPackageFact panics.
			return true
		}
		// Load the enum facts recorded for the package that declares the tag type.
		var enums enumsFact
		if !pass.ImportPackageFact(tagPkg, &enums) {
			// Can't do anything further.
			return true
		}
		em, isEnum := enums.Enums[tagType.Obj().Name()]
		if !isEnum {
			// Tag's type is not a known enum.
			return true
		}
		// Get comment map.
		var allComments ast.CommentMap
		if cm, ok := comments[file]; ok {
			allComments = cm
		} else {
			allComments = ast.NewCommentMap(pass.Fset, file, file.Comments)
			comments[file] = allComments
		}
		// Only comments associated with this switch statement can suppress it.
		specificComments := allComments.Filter(sw)
		for _, group := range specificComments.Comments() {
			if containsIgnoreDirective(group.List) {
				return true // skip checking due to ignore directive
			}
		}
		// Unexported members are required only when the switch lives in the
		// enum's own package; other packages cannot refer to them.
		samePkg := tagPkg == pass.Pkg
		checkUnexported := samePkg
		// hitlist holds the members not yet mentioned by any case clause.
		// fIgnorePattern is a package-level setting declared elsewhere whose
		// Get() result is asserted to be a *regexp.Regexp.
		hitlist := hitlistFromEnumMembers(em, tagPkg, checkUnexported, fIgnorePattern.Get().(*regexp.Regexp))
		if len(hitlist) == 0 {
			return true
		}
		var defaultCase *ast.CaseClause
		for _, stmt := range sw.Body.List {
			caseCl := stmt.(*ast.CaseClause)
			if isDefaultCase(caseCl) {
				defaultCase = caseCl
				continue // nothing more to do if it's the default case
			}
			for _, e := range caseCl.List {
				e = astutil.Unparen(e)
				if samePkg {
					// Inside the enum's package, members appear as bare identifiers.
					ident, ok := e.(*ast.Ident)
					if !ok {
						continue
					}
					updateHitlist(hitlist, em, ident.Name)
				} else {
					// Elsewhere, only `pkg.Member` selectors whose X resolves
					// to a package name are counted.
					selExpr, ok := e.(*ast.SelectorExpr)
					if !ok {
						continue
					}
					// ensure X is package identifier
					ident, ok := selExpr.X.(*ast.Ident)
					if !ok {
						continue
					}
					if !isPackageNameIdentifier(pass, ident) {
						continue
					}
					updateHitlist(hitlist, em, selExpr.Sel.Name)
				}
			}
		}
		// fDefaultSignifiesExhaustive (declared elsewhere) lets a default
		// clause stand in for any missing members.
		defaultSuffices := fDefaultSignifiesExhaustive && defaultCase != nil
		shouldReport := len(hitlist) > 0 && !defaultSuffices
		if shouldReport {
			reportSwitch(pass, sw, defaultCase, samePkg, tagType, em, hitlist, file)
		}
		return true
	})
	return nil
}
// updateHitlist records that foundName was mentioned by a case clause by
// removing it from hitlist. When foundName has a known constant value, every
// member name that shares that value is removed instead. A case listing any
// one of them matches the same value.
func updateHitlist(hitlist map[string]struct{}, em *enumMembers, foundName string) {
	if constVal, hasValue := em.NameToValue[foundName]; hasValue {
		for _, name := range em.ValueToNames[constVal] {
			delete(hitlist, name)
		}
		return
	}
	// No known constant value: only the name itself can be crossed off.
	delete(hitlist, foundName)
}
// isPackageNameIdentifier reports whether ident resolves, according to the
// pass's type information, to an imported package name. An identifier with
// no recorded object yields false.
func isPackageNameIdentifier(pass *analysis.Pass, ident *ast.Ident) bool {
	switch pass.TypesInfo.ObjectOf(ident).(type) {
	case *types.PkgName:
		return true
	default:
		return false
	}
}
// hitlistFromEnumMembers returns the set of enum member names that a switch
// must mention to be exhaustive. A member from em.OrderedNames is left out in
// three cases:
//   - it is the blank identifier "_";
//   - ignorePattern is non-nil and matches "<enumPkg path>.<name>";
//   - it is unexported and checkUnexported is false.
func hitlistFromEnumMembers(em *enumMembers, enumPkg *types.Package, checkUnexported bool, ignorePattern *regexp.Regexp) map[string]struct{} {
	excluded := func(name string) bool {
		switch {
		case name == "_":
			// The blank identifier is often used to skip entries in iota lists.
			return true
		case ignorePattern != nil && ignorePattern.MatchString(enumPkg.Path()+"."+name):
			return true
		case !checkUnexported && !ast.IsExported(name):
			return true
		default:
			return false
		}
	}

	hitlist := make(map[string]struct{})
	for _, name := range em.OrderedNames {
		if !excluded(name) {
			hitlist[name] = struct{}{}
		}
	}
	return hitlist
}
// determineMissingOutput formats the missing member names for a diagnostic.
// Members that share a known constant value are merged into a single entry
// of the form "A|B" (names sorted). Members without a known value appear
// alone. The returned entries are sorted.
func determineMissingOutput(missingMembers map[string]struct{}, em *enumMembers) []string {
	byValue := make(map[string][]string) // constant value -> member names
	var valueless []string               // members with no known constant value

	for name := range missingMembers {
		constVal, known := em.NameToValue[name]
		if !known {
			valueless = append(valueless, name)
			continue
		}
		byValue[constVal] = append(byValue[constVal], name)
	}

	entries := make([]string, 0, len(byValue)+len(valueless))
	for _, group := range byValue {
		sort.Strings(group)
		entries = append(entries, strings.Join(group, "|"))
	}
	entries = append(entries, valueless...)

	sort.Strings(entries)
	return entries
}
// reportSwitch emits one diagnostic spanning sw that lists the enum members
// in missingMembers. em is used to group members sharing a constant value.
// enumType and samePkg determine how the enum type is named in the message.
// A suggested fix is attached whenever computeFix can build one. f is the
// file containing sw.
func reportSwitch(
	pass *analysis.Pass,
	sw *ast.SwitchStmt,
	defaultCase *ast.CaseClause,
	samePkg bool,
	enumType *types.Named,
	em *enumMembers,
	missingMembers map[string]struct{},
	f *ast.File,
) {
	missingList := strings.Join(determineMissingOutput(missingMembers, em), ", ")

	var suggested []analysis.SuggestedFix
	if fix, ok := computeFix(pass, pass.Fset, f, sw, defaultCase, enumType, samePkg, missingMembers); ok {
		suggested = []analysis.SuggestedFix{fix}
	}

	diag := analysis.Diagnostic{
		Pos:            sw.Pos(),
		End:            sw.End(),
		Message:        fmt.Sprintf("missing cases in switch of type %s: %s", enumTypeName(enumType, samePkg), missingList),
		SuggestedFixes: suggested,
	}
	pass.Report(diag)
}
// computeFix builds a suggested fix for a non-exhaustive switch. The fix adds
// a case clause for every member of missingMembers whose body panics with the
// tag's value. It also adds an import of "fmt" when f lacks one. The boolean
// result is false when no fix is offered.
func computeFix(pass *analysis.Pass, fset *token.FileSet, f *ast.File, sw *ast.SwitchStmt, defaultCase *ast.CaseClause, enumType *types.Named, samePkg bool, missingMembers map[string]struct{}) (analysis.SuggestedFix, bool) {
	// The generated clause repeats the tag expression in its panic message.
	// If evaluating the tag involves a function or method call, repeating it
	// could repeat side effects, so no fix is offered. Type conversions are
	// CallExprs in the AST too, which is why the check consults type
	// information rather than the AST alone.
	if containsFuncCall(pass, sw.Tag) {
		return analysis.SuggestedFix{}, false
	}

	edits := []analysis.TextEdit{
		missingCasesTextEdit(fset, f, samePkg, sw, defaultCase, enumType, missingMembers),
	}
	// The inserted panic uses fmt.Sprintf.
	if needFmt := !hasImportWithPath(fset, f, `"fmt"`); needFmt {
		edits = append(edits, fmtImportTextEdit(fset, f))
	}

	names := make([]string, 0, len(missingMembers))
	for name := range missingMembers {
		names = append(names, name)
	}
	sort.Strings(names)

	fix := analysis.SuggestedFix{
		Message:   fmt.Sprintf("add case clause for: %s", strings.Join(names, ", ")),
		TextEdits: edits,
	}
	return fix, true
}
// containsFuncCall reports whether evaluating e may run a function call or
// otherwise cause side effects. If it does, e must not be duplicated into
// generated code (see computeFix).
//
// Every sub-expression of e is inspected, not only the outermost one or call
// arguments. This catches calls nested inside index, selector, unary, binary
// or composite-literal expressions, such as m[next()] or *getPtr(). Type
// conversions like Color(x) are CallExprs in the AST but are not calls; type
// information tells them apart. Channel receives (<-ch) are reported as well:
// evaluating one again would consume another value or block. Bodies of
// function literals are not inspected, because evaluating a literal does not
// run its body. A call whose callee type is not recorded is conservatively
// treated as a function call.
func containsFuncCall(pass *analysis.Pass, e ast.Expr) bool {
	if e == nil {
		return false
	}
	found := false
	ast.Inspect(e, func(n ast.Node) bool {
		if found {
			return false
		}
		switch x := n.(type) {
		case *ast.FuncLit:
			return false
		case *ast.CallExpr:
			fun := pass.TypesInfo.TypeOf(x.Fun)
			if fun == nil {
				// Unknown callee: assume the worst instead of panicking on a
				// nil Type.
				found = true
				return false
			}
			if _, isFunc := fun.Underlying().(*types.Signature); isFunc {
				found = true
				return false
			}
			// A conversion: keep looking inside its operand(s).
		case *ast.UnaryExpr:
			if x.Op == token.ARROW {
				found = true
				return false
			}
		}
		return true
	})
	return found
}
// firstImportDecl returns the first `import` declaration in f.Decls, or nil
// if f has none. fset is not consulted.
func firstImportDecl(fset *token.FileSet, f *ast.File) *ast.GenDecl {
	for i := range f.Decls {
		if gd, isGen := f.Decls[i].(*ast.GenDecl); isGen && gd.Tok == token.IMPORT {
			return gd
		}
	}
	return nil
}
// copyGenDecl returns a shallow copy of im whose Specs slice has its own
// backing array. Appending to the copy's Specs, or overwriting its elements,
// therefore leaves im unchanged. The Spec values themselves are still shared.
func copyGenDecl(im *ast.GenDecl) *ast.GenDecl {
	dup := new(ast.GenDecl)
	*dup = *im
	dup.Specs = make([]ast.Spec, len(im.Specs))
	copy(dup.Specs, im.Specs)
	return dup
}
// hasImportWithPath reports whether f imports a package whose quoted path
// literal equals pathLiteral (for example `"fmt"`), whatever local name it is
// imported under. Only the leading run of import declarations is scanned.
// Go requires imports to precede all other declarations. fset is accepted for
// interface compatibility but is not needed for the lookup.
func hasImportWithPath(fset *token.FileSet, f *ast.File, pathLiteral string) bool {
	for _, decl := range f.Decls {
		gen, isGen := decl.(*ast.GenDecl)
		if !isGen || gen.Tok != token.IMPORT {
			break
		}
		for _, spec := range gen.Specs {
			if spec.(*ast.ImportSpec).Path.Value == pathLiteral {
				return true
			}
		}
	}
	return false
}
// fmtImportTextEdit returns a text edit that adds an import of "fmt" to f.
// It does not check whether f already imports "fmt"; callers do.
//
// If f has no import declaration, a new parenthesized declaration for "fmt"
// is inserted right after the package name. Otherwise the first import
// declaration is reprinted with a "fmt" spec placed before the first spec
// whose unquoted path sorts after "fmt". The edit replaces that
// declaration's text, from the `import` keyword to its end, with the
// reprinted form.
func fmtImportTextEdit(fset *token.FileSet, f *ast.File) analysis.TextEdit {
	firstDecl := firstImportDecl(fset, f)
	if firstDecl == nil {
		// Insert at the end of the package name itself and start with line
		// breaks. One byte past the name could land between '\r' and '\n'
		// (CRLF line endings) or in front of a trailing comment on the
		// package line. Either way the import would share the package
		// clause's line, and the file would no longer parse.
		return analysis.TextEdit{
			Pos:     f.Name.End(),
			End:     f.Name.End(),
			NewText: []byte("\n\nimport (\n\t\"fmt\"\n)"),
		}
	}

	// Work on a copy: inserting into Specs must not mutate the shared AST.
	firstDeclCopy := copyGenDecl(firstDecl)
	// The replaced range starts at the `import` keyword (GenDecl.Pos), which
	// excludes the declaration's doc comment. Printing a bare node emits
	// attached doc comments, so without this the doc comment would appear
	// twice once the edit is applied.
	firstDeclCopy.Doc = nil

	// Find the insertion index for the "fmt" import spec.
	var i int
	for ; i < len(firstDeclCopy.Specs); i++ {
		im := firstDeclCopy.Specs[i].(*ast.ImportSpec)
		if v, _ := strconv.Unquote(im.Path.Value); v > "fmt" {
			break
		}
	}

	fmtSpec := &ast.ImportSpec{
		Path: &ast.BasicLit{
			// The spec has no position; the printer places it by order.
			Kind:  token.STRING,
			Value: `"fmt"`,
		},
	}
	s := firstDeclCopy.Specs
	s = append(s[:i], append([]ast.Spec{fmtSpec}, s[i:]...)...)
	firstDeclCopy.Specs = s

	// Writes to a bytes.Buffer cannot fail; the printer's error is ignored
	// because this function has no way to report it.
	var buf bytes.Buffer
	printer.Fprint(&buf, fset, firstDeclCopy)
	return analysis.TextEdit{
		Pos:     firstDecl.Pos(),
		End:     firstDecl.End(),
		NewText: buf.Bytes(),
	}
}
// missingCasesTextEdit returns a text edit that adds one case clause for
// every member of missingMembers (sorted). The clause body panics with the
// switch tag's value:
//
//	case A, B:
//		panic(fmt.Sprintf("unhandled value: %v",<tag>))
//
// The clause goes immediately before the default clause if there is one, and
// otherwise becomes the last clause of the switch body.
//
// When the enum is declared in another package (samePkg is false), member
// names are qualified with the identifier X of an existing `X.Member` case
// expression. If there is none, the enum package's declared name is used.
// That fallback may be wrong when the package is imported under another
// name, and no import is added for it.
func missingCasesTextEdit(fset *token.FileSet, f *ast.File, samePkg bool, sw *ast.SwitchStmt, defaultCase *ast.CaseClause, enumType *types.Named, missingMembers map[string]struct{}) analysis.TextEdit {
	// Source text of the tag expression, repeated in the panic message.
	var tag bytes.Buffer
	printer.Fprint(&tag, fset, sw.Tag)

	qualifier := ""
	if !samePkg {
		qualifier = switchPackageQualifier(sw)
		if qualifier == "" {
			qualifier = enumType.Obj().Pkg().Name()
		}
		qualifier += "."
	}

	missing := make([]string, 0, len(missingMembers))
	for m := range missingMembers {
		missing = append(missing, qualifier+m)
	}
	sort.Strings(missing)

	// The clause ends in a newline so that whatever follows it starts on a
	// fresh line.
	clause := "case " + strings.Join(missing, ", ") + ":\n" +
		"\tpanic(fmt.Sprintf(\"unhandled value: %v\"," + tag.String() + "))\n"

	var pos token.Pos
	var text string
	if defaultCase != nil {
		// Insert directly before the `default` keyword. Whatever may legally
		// precede `default` may also precede `case`, so the result parses.
		// (Inserting two bytes earlier could land after the previous
		// clause's last statement on the same line, e.g. "foo()case X:".)
		pos = defaultCase.Case
		text = clause
	} else {
		// Insert directly before the closing brace. The leading newline
		// terminates a final statement sharing the brace's line
		// ("foo() }") and keeps an empty body ("switch x {}") valid.
		pos = sw.Body.Rbrace
		text = "\n" + clause
	}
	return analysis.TextEdit{
		Pos:     pos,
		End:     pos,
		NewText: []byte(text),
	}
}

// switchPackageQualifier returns X for the first case expression in sw of the
// form X.Sel where X is an identifier, or "" if there is none. Without type
// information X cannot be confirmed to be a package name, so this is a guess.
func switchPackageQualifier(sw *ast.SwitchStmt) string {
	for _, stmt := range sw.Body.List {
		caseCl := stmt.(*ast.CaseClause) // a switch body holds only CaseClauses
		for _, e := range caseCl.List {
			sel, ok := e.(*ast.SelectorExpr)
			if !ok {
				continue
			}
			// Checked assertion: X may be another selector (a.b.C) or any
			// other expression, which must not panic here.
			if ident, ok := sel.X.(*ast.Ident); ok {
				return ident.Name
			}
		}
	}
	return ""
}