332 lines
6.5 KiB
Go
332 lines
6.5 KiB
Go
|
package rowserr
|
||
|
|
||
|
import (
|
||
|
"go/ast"
|
||
|
"go/types"
|
||
|
"strconv"
|
||
|
|
||
|
"github.com/gostaticanalysis/analysisutil"
|
||
|
"golang.org/x/tools/go/analysis"
|
||
|
"golang.org/x/tools/go/analysis/passes/buildssa"
|
||
|
"golang.org/x/tools/go/ssa"
|
||
|
)
|
||
|
|
||
|
func NewAnalyzer(sqlPkgs ...string) *analysis.Analyzer {
|
||
|
return &analysis.Analyzer{
|
||
|
Name: "rowserrcheck",
|
||
|
Doc: Doc,
|
||
|
Run: NewRun(sqlPkgs...),
|
||
|
Requires: []*analysis.Analyzer{
|
||
|
buildssa.Analyzer,
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Doc is the documentation string of the rowserrcheck analyzer.
const Doc = "rowserrcheck checks whether Rows.Err is checked"

const (
	// errMethod is the name of the method whose call counts as checking
	// the error of a Rows value.
	errMethod = "Err"
	// rowsName is the name of the type looked up in every SQL package.
	rowsName = "Rows"
)
|
||
|
|
||
|
type runner struct {
|
||
|
pass *analysis.Pass
|
||
|
rowsTyp *types.Pointer
|
||
|
rowsObj types.Object
|
||
|
skipFile map[*ast.File]bool
|
||
|
sqlPkgs []string
|
||
|
}
|
||
|
|
||
|
func NewRun(pkgs ...string) func(pass *analysis.Pass) (interface{}, error) {
|
||
|
return func(pass *analysis.Pass) (interface{}, error) {
|
||
|
sqlPkgs := append(pkgs, "database/sql")
|
||
|
for _, pkg := range sqlPkgs {
|
||
|
r := new(runner)
|
||
|
r.sqlPkgs = sqlPkgs
|
||
|
r.run(pass, pkg)
|
||
|
}
|
||
|
return nil, nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// run executes an analysis for the pass. The receiver is passed
|
||
|
// by value because this func is called in parallel for different passes.
|
||
|
func (r runner) run(pass *analysis.Pass, pkgPath string) {
|
||
|
r.pass = pass
|
||
|
pssa := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA)
|
||
|
funcs := pssa.SrcFuncs
|
||
|
|
||
|
pkg := pssa.Pkg.Prog.ImportedPackage(pkgPath)
|
||
|
if pkg == nil {
|
||
|
// skip
|
||
|
return
|
||
|
}
|
||
|
|
||
|
rowsType := pkg.Type(rowsName)
|
||
|
if rowsType == nil {
|
||
|
// skip checking
|
||
|
return
|
||
|
}
|
||
|
|
||
|
r.rowsObj = rowsType.Object()
|
||
|
if r.rowsObj == nil {
|
||
|
// skip checking
|
||
|
return
|
||
|
}
|
||
|
|
||
|
resNamed, ok := r.rowsObj.Type().(*types.Named)
|
||
|
if !ok {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
r.rowsTyp = types.NewPointer(resNamed)
|
||
|
r.skipFile = map[*ast.File]bool{}
|
||
|
|
||
|
for _, f := range funcs {
|
||
|
if r.noImportedDBSQL(f) {
|
||
|
// skip this
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
// skip if the function is just referenced
|
||
|
var isRefFunc bool
|
||
|
|
||
|
for i := 0; i < f.Signature.Results().Len(); i++ {
|
||
|
if types.Identical(f.Signature.Results().At(i).Type(), r.rowsTyp) {
|
||
|
isRefFunc = true
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if isRefFunc {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
for _, b := range f.Blocks {
|
||
|
for i := range b.Instrs {
|
||
|
if r.errCallMissing(b, i) {
|
||
|
pass.Reportf(b.Instrs[i].Pos(), "rows.Err must be checked")
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (r *runner) errCallMissing(b *ssa.BasicBlock, i int) (ret bool) {
|
||
|
call, ok := r.getCallReturnsRow(b.Instrs[i])
|
||
|
if !ok {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
for _, cRef := range *call.Referrers() {
|
||
|
val, ok := r.getRowsVal(cRef)
|
||
|
if !ok {
|
||
|
continue
|
||
|
}
|
||
|
if len(*val.Referrers()) == 0 {
|
||
|
continue
|
||
|
}
|
||
|
resRefs := *val.Referrers()
|
||
|
var errCalled func(resRef ssa.Instruction) bool
|
||
|
errCalled = func(resRef ssa.Instruction) bool {
|
||
|
switch resRef := resRef.(type) {
|
||
|
case *ssa.Phi:
|
||
|
for _, rf := range *resRef.Referrers() {
|
||
|
if errCalled(rf) {
|
||
|
return true
|
||
|
}
|
||
|
}
|
||
|
case *ssa.Store: // Call in Closure function
|
||
|
for _, aref := range *resRef.Addr.Referrers() {
|
||
|
switch c := aref.(type) {
|
||
|
case *ssa.MakeClosure:
|
||
|
f := c.Fn.(*ssa.Function)
|
||
|
if r.noImportedDBSQL(f) {
|
||
|
// skip this
|
||
|
continue
|
||
|
}
|
||
|
called := r.isClosureCalled(c)
|
||
|
if r.calledInFunc(f, called) {
|
||
|
return true
|
||
|
}
|
||
|
case *ssa.UnOp:
|
||
|
for _, rf := range *c.Referrers() {
|
||
|
if errCalled(rf) {
|
||
|
return true
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
case *ssa.Call: // Indirect function call
|
||
|
if r.isErrCall(resRef) {
|
||
|
return true
|
||
|
}
|
||
|
if f, ok := resRef.Call.Value.(*ssa.Function); ok {
|
||
|
for _, b := range f.Blocks {
|
||
|
for i := range b.Instrs {
|
||
|
if !r.errCallMissing(b, i) {
|
||
|
return true
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
case *ssa.FieldAddr:
|
||
|
for _, bRef := range *resRef.Referrers() {
|
||
|
bOp, ok := r.getBodyOp(bRef)
|
||
|
if !ok {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
for _, ccall := range *bOp.Referrers() {
|
||
|
if r.isErrCall(ccall) {
|
||
|
return true
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
for _, resRef := range resRefs {
|
||
|
if errCalled(resRef) {
|
||
|
return false
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
func (r *runner) getCallReturnsRow(instr ssa.Instruction) (*ssa.Call, bool) {
|
||
|
call, ok := instr.(*ssa.Call)
|
||
|
if !ok {
|
||
|
return nil, false
|
||
|
}
|
||
|
|
||
|
res := call.Call.Signature().Results()
|
||
|
flag := false
|
||
|
|
||
|
for i := 0; i < res.Len(); i++ {
|
||
|
flag = flag || types.Identical(res.At(i).Type(), r.rowsTyp)
|
||
|
}
|
||
|
|
||
|
if !flag {
|
||
|
return nil, false
|
||
|
}
|
||
|
|
||
|
return call, true
|
||
|
}
|
||
|
|
||
|
func (r *runner) getRowsVal(instr ssa.Instruction) (ssa.Value, bool) {
|
||
|
switch instr := instr.(type) {
|
||
|
case *ssa.Call:
|
||
|
if len(instr.Call.Args) == 1 && types.Identical(instr.Call.Args[0].Type(), r.rowsTyp) {
|
||
|
return instr.Call.Args[0], true
|
||
|
}
|
||
|
case ssa.Value:
|
||
|
if types.Identical(instr.Type(), r.rowsTyp) {
|
||
|
return instr, true
|
||
|
}
|
||
|
default:
|
||
|
}
|
||
|
|
||
|
return nil, false
|
||
|
}
|
||
|
|
||
|
func (r *runner) getBodyOp(instr ssa.Instruction) (*ssa.UnOp, bool) {
|
||
|
op, ok := instr.(*ssa.UnOp)
|
||
|
if !ok {
|
||
|
return nil, false
|
||
|
}
|
||
|
// fix: try to check type
|
||
|
// if op.Type() != r.rowsObj.Type() {
|
||
|
// return nil, false
|
||
|
// }
|
||
|
return op, true
|
||
|
}
|
||
|
|
||
|
func (r *runner) isErrCall(ccall ssa.Instruction) bool {
|
||
|
switch ccall := ccall.(type) {
|
||
|
case *ssa.Defer:
|
||
|
if ccall.Call.Value != nil && ccall.Call.Value.Name() == errMethod {
|
||
|
return true
|
||
|
}
|
||
|
case *ssa.Call:
|
||
|
if ccall.Call.Value != nil && ccall.Call.Value.Name() == errMethod {
|
||
|
return true
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
func (r *runner) isClosureCalled(c *ssa.MakeClosure) bool {
|
||
|
for _, ref := range *c.Referrers() {
|
||
|
switch ref.(type) {
|
||
|
case *ssa.Call, *ssa.Defer:
|
||
|
return true
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
func (r *runner) noImportedDBSQL(f *ssa.Function) (ret bool) {
|
||
|
obj := f.Object()
|
||
|
if obj == nil {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
file := analysisutil.File(r.pass, obj.Pos())
|
||
|
if file == nil {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
if skip, has := r.skipFile[file]; has {
|
||
|
return skip
|
||
|
}
|
||
|
defer func() {
|
||
|
r.skipFile[file] = ret
|
||
|
}()
|
||
|
|
||
|
for _, impt := range file.Imports {
|
||
|
path, err := strconv.Unquote(impt.Path.Value)
|
||
|
if err != nil {
|
||
|
continue
|
||
|
}
|
||
|
path = analysisutil.RemoveVendor(path)
|
||
|
for _, pkg := range r.sqlPkgs {
|
||
|
if pkg == path {
|
||
|
return false
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
func (r *runner) calledInFunc(f *ssa.Function, called bool) bool {
|
||
|
for _, b := range f.Blocks {
|
||
|
for i, instr := range b.Instrs {
|
||
|
switch instr := instr.(type) {
|
||
|
case *ssa.UnOp:
|
||
|
for _, ref := range *instr.Referrers() {
|
||
|
if v, ok := ref.(ssa.Value); ok {
|
||
|
if vCall, ok := v.(*ssa.Call); ok {
|
||
|
if vCall.Call.Value != nil && vCall.Call.Value.Name() == errMethod {
|
||
|
if called {
|
||
|
return true
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
default:
|
||
|
if r.errCallMissing(b, i) || !called {
|
||
|
return false
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return false
|
||
|
}
|