wg-quicker/vendor/github.com/charithe/durationcheck/durationcheck.go

192 lines
4.5 KiB
Go

package durationcheck
import (
"bytes"
"fmt"
"go/ast"
"go/format"
"go/token"
"go/types"
"log"
"os"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/inspect"
"golang.org/x/tools/go/ast/inspector"
)
// Analyzer is the durationcheck analysis pass. It depends on the inspect
// pass for AST traversal and reports multiplications in which both operands
// are time.Duration values (see run and check for the exact rules).
var Analyzer = &analysis.Analyzer{
	Name:     "durationcheck",
	Doc:      "check for two durations multiplied together",
	Requires: []*analysis.Analyzer{inspect.Analyzer},
	Run:      run,
}
// run is the Analyzer entry point. As a fast path, packages that do not
// directly import "time" are skipped; otherwise every *ast.BinaryExpr in the
// package is handed to the callback built by check. It never returns a
// result or an error.
//
// NOTE(review): a package can multiply time.Duration values obtained from
// other packages without importing "time" itself; such code is skipped here.
func run(pass *analysis.Pass) (interface{}, error) {
	if !hasImport(pass.Pkg, "time") {
		return nil, nil
	}

	// The local is named insp so it does not shadow the inspect package.
	insp := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
	insp.Preorder([]ast.Node{(*ast.BinaryExpr)(nil)}, check(pass))

	return nil, nil
}
// hasImport reports whether importPath appears among the packages listed by
// pkg.Imports().
func hasImport(pkg *types.Package, importPath string) bool {
	deps := pkg.Imports()
	for i := range deps {
		if deps[i].Path() == importPath {
			return true
		}
	}
	return false
}
// check builds the inspector callback for pass. The callback assumes it is
// only given *ast.BinaryExpr nodes. For a multiplication whose operands both
// have type time.Duration or *time.Duration, it reports a diagnostic at the
// expression unless at least one operand is an acceptable expression (as
// decided by isUnacceptableExpr).
func check(pass *analysis.Pass) func(ast.Node) {
	return func(node ast.Node) {
		bin := node.(*ast.BinaryExpr)
		if bin.Op != token.MUL {
			return
		}

		lhs, lhsKnown := pass.TypesInfo.Types[bin.X]
		rhs, rhsKnown := pass.TypesInfo.Types[bin.Y]
		if !(lhsKnown && rhsKnown) {
			// Without type information for both sides nothing can be said.
			return
		}

		if !isDuration(lhs.Type) || !isDuration(rhs.Type) {
			return
		}

		// A single acceptable operand is enough to suppress the report.
		if isUnacceptableExpr(pass, bin.X) && isUnacceptableExpr(pass, bin.Y) {
			pass.Reportf(bin.Pos(), "Multiplication of durations: `%s`", formatNode(bin))
		}
	}
}
// isDuration reports whether x is time.Duration or *time.Duration, compared
// by type string.
//
// A nil type is never a duration. Callers pass the result of
// types.Info.TypeOf, which is nil for expressions without recorded type
// information. Calling String on a nil interface would panic and abort the
// whole analysis run.
func isDuration(x types.Type) bool {
	if x == nil {
		return false
	}
	switch x.String() {
	case "time.Duration", "*time.Duration":
		return true
	default:
		return false
	}
}
// isUnacceptableExpr reports whether expr is NOT an acceptable operand of a
// duration multiplication. It applies these rules:
//   - basic literals are always acceptable;
//   - calls are acceptable only when isAcceptableCast accepts them;
//   - identifiers, binary, unary, selector, star, paren and index
//     expressions are delegated to isAcceptableNestedExpr;
//   - every other expression kind is unacceptable.
func isUnacceptableExpr(pass *analysis.Pass, expr ast.Expr) bool {
	switch e := expr.(type) {
	case *ast.BasicLit:
		return false
	case *ast.CallExpr:
		return !isAcceptableCast(pass, e)
	case *ast.Ident, *ast.BinaryExpr, *ast.UnaryExpr, *ast.SelectorExpr,
		*ast.StarExpr, *ast.ParenExpr, *ast.IndexExpr:
		return !isAcceptableNestedExpr(pass, expr)
	}
	return true
}
// isAcceptableCast reports whether e is a one-argument call whose argument
// passes isAcceptableNestedExpr and whose callee is the selector
// time.Duration (checked syntactically by isDurationCast), e.g.
// time.Duration(5).
func isAcceptableCast(pass *analysis.Pass, e *ast.CallExpr) bool {
	if len(e.Args) != 1 || !isAcceptableNestedExpr(pass, e.Args[0]) {
		return false
	}
	sel, isSelector := e.Fun.(*ast.SelectorExpr)
	return isSelector && isDurationCast(sel)
}
// isDurationCast reports whether selector is spelled exactly time.Duration,
// meaning an identifier named "time" followed by the selector "Duration".
// The match is purely syntactic: a renamed import of "time" is not
// recognised, and any identifier spelled "time" is.
func isDurationCast(selector *ast.SelectorExpr) bool {
	if qualifier, isIdent := selector.X.(*ast.Ident); isIdent {
		return qualifier.Name == "time" && selector.Sel.Name == "Duration"
	}
	return false
}
// isAcceptableNestedExpr reports whether n, found inside an operand of a
// duration multiplication, is acceptable. The rules are:
//   - basic literals: acceptable;
//   - identifiers: acceptable when isAcceptableIdent says so;
//   - unary, star and paren expressions: acceptable when their operand is;
//   - binary expressions: acceptable when both sides are;
//   - selectors: acceptable when the receiver expression and the selected
//     identifier are both acceptable;
//   - calls: acceptable when they are an acceptable time.Duration cast, or
//     when their result type is not a duration;
//   - index expressions: acceptable when their result type is not a duration;
//   - anything else: not acceptable.
func isAcceptableNestedExpr(pass *analysis.Pass, n ast.Expr) bool {
	switch e := n.(type) {
	case *ast.BasicLit:
		return true
	case *ast.Ident:
		return isAcceptableIdent(pass, e)
	case *ast.UnaryExpr:
		return isAcceptableNestedExpr(pass, e.X)
	case *ast.StarExpr:
		return isAcceptableNestedExpr(pass, e.X)
	case *ast.ParenExpr:
		return isAcceptableNestedExpr(pass, e.X)
	case *ast.BinaryExpr:
		if !isAcceptableNestedExpr(pass, e.X) {
			return false
		}
		return isAcceptableNestedExpr(pass, e.Y)
	case *ast.SelectorExpr:
		if !isAcceptableNestedExpr(pass, e.X) {
			return false
		}
		return isAcceptableIdent(pass, e.Sel)
	case *ast.CallExpr:
		return isAcceptableCast(pass, e) || !isDuration(pass.TypesInfo.TypeOf(e))
	case *ast.IndexExpr:
		return !isDuration(pass.TypesInfo.TypeOf(e))
	}
	return false
}
// isAcceptableIdent reports whether ident refers to an object whose type is
// not time.Duration or *time.Duration.
//
// types.Info.ObjectOf returns nil when the identifier is recorded in neither
// Defs nor Uses. Dereferencing that nil object would crash the analysis, so
// such identifiers are treated as acceptable instead. This errs on the side of
// not reporting, because the identifier cannot be proven to be a duration.
func isAcceptableIdent(pass *analysis.Pass, ident *ast.Ident) bool {
	obj := pass.TypesInfo.ObjectOf(ident)
	if obj == nil {
		return true
	}
	return !isDuration(obj.Type())
}
// formatNode renders node as Go source using go/format, with positions
// resolved against a new, empty FileSet. If formatting fails, the error is
// logged and the empty string is returned.
func formatNode(node ast.Node) string {
	var out bytes.Buffer
	err := format.Node(&out, token.NewFileSet(), node)
	if err != nil {
		log.Printf("Error formatting expression: %v", err)
		return ""
	}
	return out.String()
}
// printAST is a debugging aid. It writes msg and the formatted source of node
// to standard output, then a raw dump of the AST (via ast.Fprint with no
// FileSet and no filter), then a separator line.
func printAST(msg string, node ast.Node) {
	src := formatNode(node)
	fmt.Printf(">>> %s:\n%s\n\n\n", msg, src)
	// Best-effort debug output: a failed dump is deliberately ignored.
	ast.Fprint(os.Stdout, nil, node, nil)
	fmt.Println("--------------")
}