// Source: wg-quicker/vendor/github.com/sylvia7788/contextcheck/contextcheck.go
//
// 508 lines
// 9.9 KiB
// Go

package contextcheck
import (
"go/ast"
"go/token"
"go/types"
"strconv"
"strings"
"sync"
"github.com/gostaticanalysis/analysisutil"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/buildssa"
"golang.org/x/tools/go/ssa"
)
// NewAnalyzer returns the contextcheck analyzer. It requires the buildssa
// analyzer, whose SSA result is consumed by the run function built by NewRun.
func NewAnalyzer() *analysis.Analyzer {
	a := &analysis.Analyzer{
		Name: "contextcheck",
		Doc:  "check the function whether use a non-inherited context",
		Run:  NewRun(),
	}
	a.Requires = []*analysis.Analyzer{buildssa.Analyzer}
	return a
}
// ctxPkg and ctxName identify the tracked type, context.Context: run looks up
// package ctxPkg in the SSA program and then type ctxName inside it.
const (
	ctxPkg  = "context"
	ctxName = "Context"
)
// Bit flags describing how a call or closure relates to a context. They are
// OR-ed together by getCallInstrCtxType / getMakeClosureCtxType and returned
// through getCtxType.
const (
	// CtxIn marks a call that passes a context argument, or a closure that
	// binds a context.
	CtxIn int = 1 << iota
	// CtxOut marks a call whose result is a context, or whose result tuple
	// contains one.
	CtxOut
	// CtxInField marks that the context argument/binding is a load through a
	// field address (*ssa.FieldAddr), i.e. it is read from a struct field.
	CtxInField
	// CtxInOut is CtxIn and CtxOut combined.
	CtxInOut = CtxIn | CtxOut
)
// checkedMap caches, keyed by ssa.Function.RelString(nil), whether a function
// was judged valid (true) or not (false). Being a package-level variable it is
// shared by every analysis pass running in the same process; all access goes
// through getValue/setValue, which hold checkedMapLock.
var (
	checkedMap     = make(map[string]bool)
	checkedMapLock sync.RWMutex
)
// runner holds the state of one analysis pass; NewRun creates a fresh runner
// for every pass.
type runner struct {
	pass     *analysis.Pass
	ctxTyp   *types.Named       // the context.Context type, resolved in run
	ctxPTyp  *types.Pointer     // *context.Context
	cmpPath  string             // first element of the package path; used by checkIsSameRepo
	skipFile map[*ast.File]bool // per-file cache of noImportedContext results
}
// NewRun returns the Run function of the analyzer. Each invocation analyzes
// the pass with a brand-new runner; findings are emitted as diagnostics via
// pass.Reportf, so the function always returns (nil, nil).
func NewRun() func(pass *analysis.Pass) (interface{}, error) {
	runPass := func(pass *analysis.Pass) (interface{}, error) {
		r := &runner{}
		r.run(pass)
		return nil, nil
	}
	return runPass
}
// run analyzes the source functions of the package in pass. Every function
// accepted by checkIsEntry (it receives a context) and not already present in
// checkedMap is checked with checkFuncWithCtx and then recorded as valid.
//
// Nothing is analyzed when the SSA program has no "context" package, or when
// its Context type cannot be resolved to a named type.
func (r *runner) run(pass *analysis.Pass) {
	r.pass = pass
	// cmpPath is the first element of the package path (e.g. "github.com" for
	// "github.com/a/b"); buildPkg only builds callee packages sharing it.
	r.cmpPath = strings.Split(pass.Pkg.Path(), "/")[0]

	pssa := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA)
	funcs := pssa.SrcFuncs

	pkg := pssa.Pkg.Prog.ImportedPackage(ctxPkg)
	if pkg == nil {
		return
	}
	ctxType := pkg.Type(ctxName)
	if ctxType == nil {
		return
	}
	resNamed, ok := ctxType.Object().Type().(*types.Named)
	if !ok {
		return
	}
	r.ctxTyp = resNamed
	r.ctxPTyp = types.NewPointer(resNamed)

	r.skipFile = make(map[*ast.File]bool)
	for _, f := range funcs {
		key := f.RelString(nil)
		// Skip functions already checked; checkedMap is process-wide, so the
		// entry may come from this pass or from another one.
		if _, checked := getValue(key); checked {
			continue
		}
		if !r.checkIsEntry(f, f.Pos()) {
			continue
		}
		r.checkFuncWithCtx(f)
		setValue(key, true)
	}
}
// noImportedContext reports whether the source file containing f lacks an
// import of the "context" package (import paths are passed through
// analysisutil.RemoveVendor before comparison). The per-file answer is cached
// in r.skipFile. It returns false, without caching, when f has no valid
// position or analysisutil.File finds no file for it.
func (r *runner) noImportedContext(f *ssa.Function) (ret bool) {
	pos := f.Pos()
	if !pos.IsValid() {
		return false
	}
	file := analysisutil.File(r.pass, pos)
	if file == nil {
		return false
	}
	if cached, seen := r.skipFile[file]; seen {
		return cached
	}

	ret = true
	for _, spec := range file.Imports {
		importPath, err := strconv.Unquote(spec.Path.Value)
		if err != nil {
			// Unparseable import path literal: ignore it.
			continue
		}
		if analysisutil.RemoveVendor(importPath) == ctxPkg {
			ret = false
			break
		}
	}
	r.skipFile[file] = ret
	return ret
}
// checkIsEntry reports whether f is an "entry" function: one whose file
// imports "context", which receives a context through a parameter or a free
// variable, and which does not itself return a context. The pos argument is
// currently unused.
func (r *runner) checkIsEntry(f *ssa.Function, pos token.Pos) (ret bool) {
	if r.noImportedContext(f) {
		return false
	}

	// Functions that return a context produce one; they are never entries.
	results := f.Signature.Results()
	for i := 0; i < results.Len(); i++ {
		if r.isCtxType(results.At(i).Type()) {
			return false
		}
	}

	params := f.Signature.Params()
	for i := 0; i < params.Len(); i++ {
		if r.isCtxType(params.At(i).Type()) {
			return true
		}
	}

	// Closures may capture the context instead of taking it as a parameter.
	for _, fv := range f.FreeVars {
		if r.isCtxType(fv.Type()) {
			return true
		}
	}
	return false
}
// collectCtxRef follows the data flow of f's context-typed parameters and free
// variables and returns, in refMap, the instructions reached from them: every
// call instruction encountered, and every MakeClosure that binds a context.
//
// The walk continues through context results of calls (CtxOut), UnOp loads,
// phis, context-typed Extracts, and addresses a context was stored to. Stores
// into such addresses and the phis visited are collected. Afterwards, a store
// whose value, or a phi with an incoming edge whose value, was not reached
// from f's own context is reported as a non-inherited context and ok is set to
// false. Each phi is reported at most once.
func (r *runner) collectCtxRef(f *ssa.Function) (refMap map[ssa.Instruction]bool, ok bool) {
	ok = true
	refMap = make(map[ssa.Instruction]bool)
	checkedRefMap := make(map[ssa.Value]bool)
	storeInstrs := make(map[*ssa.Store]bool)
	phiInstrs := make(map[*ssa.Phi]bool)

	var checkRefs func(val ssa.Value, fromAddr bool)
	var checkInstr func(instr ssa.Instruction, fromAddr bool)

	// checkRefs visits the referrers of val, at most once per value. fromAddr
	// is true when val is an address that a context was stored into.
	checkRefs = func(val ssa.Value, fromAddr bool) {
		if val == nil || val.Referrers() == nil {
			return
		}
		if checkedRefMap[val] {
			return
		}
		checkedRefMap[val] = true
		for _, instr := range *val.Referrers() {
			checkInstr(instr, fromAddr)
		}
	}
	checkInstr = func(instr ssa.Instruction, fromAddr bool) {
		switch i := instr.(type) {
		case ssa.CallInstruction:
			refMap[i] = true
			tp := r.getCallInstrCtxType(i)
			if tp&CtxOut != 0 {
				// The call uses this context and returns a context: treat
				// the returned context as inherited and follow it too.
				checkRefs(i.Value(), false)
				return
			}
		case *ssa.Store:
			if fromAddr {
				// A value is stored into an address holding the context;
				// whether that value is inherited is decided below.
				storeInstrs[i] = true
			} else {
				// The context is stored into i.Addr: follow that address so
				// later stores into it can be validated.
				checkRefs(i.Addr, true)
			}
		case *ssa.UnOp:
			checkRefs(i, false)
		case *ssa.MakeClosure:
			for _, param := range i.Bindings {
				if r.isCtxType(param.Type()) {
					refMap[i] = true
					break
				}
			}
		case *ssa.Extract:
			// only care about ctx
			if r.isCtxType(i.Type()) {
				checkRefs(i, false)
			}
		case *ssa.Phi:
			phiInstrs[i] = true
			checkRefs(i, false)
		case *ssa.TypeAssert:
			// ctx.(*bm.Context): type assertions are deliberately not followed.
		}
	}

	for _, param := range f.Params {
		if r.isCtxType(param.Type()) {
			checkRefs(param, false)
		}
	}
	for _, param := range f.FreeVars {
		if r.isCtxType(param.Type()) {
			checkRefs(param, false)
		}
	}

	for instr := range storeInstrs {
		if !checkedRefMap[instr.Val] {
			r.pass.Reportf(instr.Pos(), "Non-inherited new context, use function like `context.WithXXX` instead")
			ok = false
		}
	}
	for instr := range phiInstrs {
		for _, v := range instr.Edges {
			if !checkedRefMap[v] {
				r.pass.Reportf(instr.Pos(), "Non-inherited new context, use function like `context.WithXXX` instead")
				ok = false
				// One diagnostic per phi: further bad edges would only repeat
				// the same message at the same position.
				break
			}
		}
	}
	return
}
// buildPkg builds the SSA bodies of the package defining f so that f can be
// traversed. It does nothing when f already has blocks, when f has no
// enclosing package (f.Pkg is nil for shared/synthetic functions such as
// wrappers), or when the package path fails checkIsSameRepo.
func (r *runner) buildPkg(f *ssa.Function) {
	if f.Blocks != nil {
		return
	}
	// Without this guard f.Pkg.Pkg would be a nil dereference.
	if f.Pkg == nil {
		return
	}
	// only build the pkg which is in the same repo
	if r.checkIsSameRepo(f.Pkg.Pkg.Path()) {
		f.Pkg.Build()
	}
}
// checkIsSameRepo reports whether package path s starts with r.cmpPath
// followed by a slash, i.e. shares the first path element of the analyzed
// package.
func (r *runner) checkIsSameRepo(s string) bool {
	repoPrefix := r.cmpPath + "/"
	return strings.HasPrefix(s, repoPrefix)
}
// checkFuncWithCtx checks f, a function accepted by checkIsEntry. For every
// call or closure in f it reports:
//   - a context argument/binding not reached from f's own context (per
//     collectCtxRef) as a non-inherited context;
//   - a statically known callee that is not an entry and is invalid, either
//     per the checkedMap cache or per a fresh checkFuncWithoutCtx traversal
//     (whose result is then cached).
//
// Nothing is checked when collectCtxRef already reported a problem.
func (r *runner) checkFuncWithCtx(f *ssa.Function) {
	refMap, ok := r.collectCtxRef(f)
	if !ok {
		return
	}

	inspect := func(instr ssa.Instruction) {
		tp, relevant := r.getCtxType(instr)
		// Only calls and closures matter; calls returning a context were
		// already handled by collectCtxRef.
		if !relevant || tp&CtxOut != 0 {
			return
		}
		if tp&CtxIn != 0 && !refMap[instr] {
			r.pass.Reportf(instr.Pos(), "Non-inherited new context, use function like `context.WithXXX` instead")
		}

		callee := r.getFunction(instr)
		if callee == nil {
			return
		}
		key := callee.RelString(nil)
		if cached, found := getValue(key); found {
			if !cached {
				r.pass.Reportf(instr.Pos(), "Function `%s` should pass the context parameter", callee.Name())
			}
			return
		}
		// Synthetic $thunk / $bound wrappers are not traversed.
		if strings.HasSuffix(key, "$thunk") || strings.HasSuffix(key, "$bound") {
			return
		}
		// Entry callees receive a context themselves; run checks them.
		if r.checkIsEntry(callee, instr.Pos()) {
			return
		}

		// The callee has no context: traverse it deeply.
		r.buildPkg(callee)
		valid := r.checkFuncWithoutCtx(callee, map[string]bool{key: true})
		setValue(key, valid)
		if !valid {
			r.pass.Reportf(instr.Pos(), "Function `%s` should pass the context parameter", callee.Name())
		}
	}

	for _, block := range f.Blocks {
		for _, instr := range block.Instrs {
			inspect(instr)
		}
	}
}
// checkFuncWithoutCtx reports whether f, a function that is not an entry (see
// checkIsEntry), is valid. f is invalid when one of its calls or closures
// takes a context that is not loaded from a struct field (CtxIn without
// CtxInField), or when it calls a non-entry function that is itself invalid.
// For invalid callees a diagnostic is reported at the call site inside f; the
// CtxIn case reports nothing here and is surfaced by the caller, which reports
// at its own call site when this function returns false.
//
// Non-entry callees missing from checkedMap are checked recursively and the
// result is cached with setValue. checkingMap holds the keys of functions
// already visited during this traversal; revisiting one is skipped, which also
// breaks recursion cycles. A function without SSA blocks (e.g. one whose
// package buildPkg did not build) is treated as valid.
func (r *runner) checkFuncWithoutCtx(f *ssa.Function, checkingMap map[string]bool) (ret bool) {
	ret = true
	for _, b := range f.Blocks {
		for _, instr := range b.Instrs {
			tp, ok := r.getCtxType(instr)
			if !ok {
				continue
			}
			// Calls returning a context are not examined further.
			if tp&CtxOut != 0 {
				continue
			}
			// it is considered illegal as long as ctx is in the input and not in *struct X
			// (i.e. the context argument is not loaded from a struct field).
			if tp&CtxIn != 0 {
				if tp&CtxInField == 0 {
					ret = false
				}
				continue
			}
			ff := r.getFunction(instr)
			if ff == nil {
				continue
			}
			key := ff.RelString(nil)
			// Reuse a cached verdict for ff when there is one.
			valid, ok := getValue(key)
			if ok {
				if !valid {
					ret = false
					r.pass.Reportf(instr.Pos(), "Function `%s` should pass the context parameter", ff.Name())
				}
				continue
			}
			// check is thunk or bound: synthetic $thunk / $bound wrappers are skipped
			if strings.HasSuffix(key, "$thunk") || strings.HasSuffix(key, "$bound") {
				continue
			}
			if !r.checkIsEntry(ff, instr.Pos()) {
				// handler ring call: a function already visited in this
				// traversal is skipped (treated as valid here)
				if checkingMap[key] {
					continue
				}
				checkingMap[key] = true
				// Try to make ff's SSA body available before descending.
				r.buildPkg(ff)
				valid := r.checkFuncWithoutCtx(ff, checkingMap)
				setValue(key, valid)
				if !valid {
					ret = false
					r.pass.Reportf(instr.Pos(), "Function `%s` should pass the context parameter", ff.Name())
				}
			}
		}
	}
	return ret
}
// getCtxType classifies instr with the Ctx* bit flags. ok is true only for
// call instructions (Call, Go, Defer) and MakeClosure; every other
// instruction yields (0, false).
func (r *runner) getCtxType(instr ssa.Instruction) (tp int, ok bool) {
	if call, isCall := instr.(ssa.CallInstruction); isCall {
		return r.getCallInstrCtxType(call), true
	}
	if closure, isClosure := instr.(*ssa.MakeClosure); isClosure {
		return r.getMakeClosureCtxType(closure), true
	}
	return 0, false
}
// getCallInstrCtxType returns the Ctx* flags of call c:
//   - CtxIn when one of its arguments (c.Common().Args) is a context, plus
//     CtxInField when the first such argument is a load through a
//     *ssa.FieldAddr;
//   - CtxOut when its result is a context or a tuple containing one.
func (r *runner) getCallInstrCtxType(c ssa.CallInstruction) (tp int) {
	// Inputs: only the first context-typed argument is inspected.
	for _, arg := range c.Common().Args {
		if !r.isCtxType(arg.Type()) {
			continue
		}
		tp |= CtxIn
		if load, isLoad := arg.(*ssa.UnOp); isLoad {
			if _, fromField := load.X.(*ssa.FieldAddr); fromField {
				tp |= CtxInField
			}
		}
		break
	}

	// Outputs: Go and Defer instructions have no result value.
	res := c.Value()
	if res == nil {
		return tp
	}
	if r.isCtxType(res.Type()) {
		return tp | CtxOut
	}
	if tuple, isTuple := res.Type().(*types.Tuple); isTuple {
		for i := 0; i < tuple.Len(); i++ {
			if r.isCtxType(tuple.At(i).Type()) {
				return tp | CtxOut
			}
		}
	}
	return tp
}
// getMakeClosureCtxType returns CtxIn when closure c binds a context, plus
// CtxInField when the first such binding is a load through a *ssa.FieldAddr.
// It returns 0 when no binding is a context.
func (r *runner) getMakeClosureCtxType(c *ssa.MakeClosure) (tp int) {
	for _, bound := range c.Bindings {
		if !r.isCtxType(bound.Type()) {
			continue
		}
		tp = CtxIn
		if load, isLoad := bound.(*ssa.UnOp); isLoad {
			if _, fromField := load.X.(*ssa.FieldAddr); fromField {
				tp |= CtxInField
			}
		}
		return tp
	}
	return tp
}
// getFunction returns the statically known function behind instr: the callee
// of a call whose function value is an *ssa.Function, or the function of a
// MakeClosure. It returns nil for interface (invoke-mode) calls and for calls
// through any other value: closures (their MakeClosure is visited on its
// own), builtins, loaded, looked-up, phi'd or computed function values, and
// parameters.
func (r *runner) getFunction(instr ssa.Instruction) (f *ssa.Function) {
	switch i := instr.(type) {
	case ssa.CallInstruction:
		common := i.Common()
		if common.IsInvoke() {
			return nil
		}
		callee, _ := common.Value.(*ssa.Function)
		return callee
	case *ssa.MakeClosure:
		return i.Fn.(*ssa.Function)
	default:
		return nil
	}
}
// isCtxType reports whether tp is identical (types.Identical) to
// context.Context or to *context.Context.
func (r *runner) isCtxType(tp types.Type) bool {
	if types.Identical(tp, r.ctxTyp) {
		return true
	}
	return types.Identical(tp, r.ctxPTyp)
}
// getValue looks key up in checkedMap under a read lock. ok reports whether
// the key was present; valid is the cached verdict.
func getValue(key string) (valid, ok bool) {
	checkedMapLock.RLock()
	defer checkedMapLock.RUnlock()
	valid, ok = checkedMap[key]
	return valid, ok
}
// setValue records the verdict valid for key in checkedMap under the write
// lock, overwriting any previous entry.
func setValue(key string, valid bool) {
	checkedMapLock.Lock()
	defer checkedMapLock.Unlock()
	checkedMap[key] = valid
}