Marvin Preuss
1d4ae27878
All checks were successful
continuous-integration/drone/push Build is passing
312 lines
6.6 KiB
Go
312 lines
6.6 KiB
Go
package analyzer
|
|
|
|
import (
|
|
"go/types"
|
|
|
|
"golang.org/x/tools/go/analysis"
|
|
"golang.org/x/tools/go/analysis/passes/buildssa"
|
|
"golang.org/x/tools/go/ssa"
|
|
)
|
|
|
|
// Identifiers the analyzer matches on: the type names looked up in every
// package listed in sqlPackages, and the method name that counts as closing
// a tracked value.
const (
	rowsName    = "Rows"  // type name looked up in each SQL package
	stmtName    = "Stmt"  // type name looked up in each SQL package
	closeMethod = "Close" // callee name treated as releasing the value
)
|
|
|
|
// sqlPackages lists the import paths whose Rows and Stmt types are tracked.
var sqlPackages = []string{
	"database/sql",
	"github.com/jmoiron/sqlx",
}
|
|
|
|
func NewAnalyzer() *analysis.Analyzer {
|
|
return &analysis.Analyzer{
|
|
Name: "sqlclosecheck",
|
|
Doc: "Checks that sql.Rows and sql.Stmt are closed.",
|
|
Run: run,
|
|
Requires: []*analysis.Analyzer{
|
|
buildssa.Analyzer,
|
|
},
|
|
}
|
|
}
|
|
|
|
func run(pass *analysis.Pass) (interface{}, error) {
|
|
pssa := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA)
|
|
|
|
// Build list of types we are looking for
|
|
targetTypes := getTargetTypes(pssa, sqlPackages)
|
|
|
|
// If non of the types are found, skip
|
|
if len(targetTypes) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
funcs := pssa.SrcFuncs
|
|
for _, f := range funcs {
|
|
for _, b := range f.Blocks {
|
|
for i := range b.Instrs {
|
|
// Check if instruction is call that returns a target type
|
|
targetValues := getTargetTypesValues(b, i, targetTypes)
|
|
if len(targetValues) == 0 {
|
|
continue
|
|
}
|
|
|
|
// log.Printf("%s", f.Name())
|
|
|
|
// For each found target check if they are closed and deferred
|
|
for _, targetValue := range targetValues {
|
|
refs := (*targetValue.value).Referrers()
|
|
isClosed := checkClosed(refs, targetTypes)
|
|
if !isClosed {
|
|
pass.Reportf((targetValue.instr).Pos(), "Rows/Stmt was not closed")
|
|
}
|
|
|
|
checkDeferred(pass, refs, targetTypes, false)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil, nil
|
|
}
|
|
|
|
func getTargetTypes(pssa *buildssa.SSA, targetPackages []string) []*types.Pointer {
|
|
targets := []*types.Pointer{}
|
|
|
|
for _, sqlPkg := range targetPackages {
|
|
pkg := pssa.Pkg.Prog.ImportedPackage(sqlPkg)
|
|
if pkg == nil {
|
|
// the SQL package being checked isn't imported
|
|
return targets
|
|
}
|
|
|
|
rowsType := getTypePointerFromName(pkg, rowsName)
|
|
if rowsType != nil {
|
|
targets = append(targets, rowsType)
|
|
}
|
|
|
|
stmtType := getTypePointerFromName(pkg, stmtName)
|
|
if stmtType != nil {
|
|
targets = append(targets, stmtType)
|
|
}
|
|
}
|
|
|
|
return targets
|
|
}
|
|
|
|
func getTypePointerFromName(pkg *ssa.Package, name string) *types.Pointer {
|
|
pkgType := pkg.Type(name)
|
|
if pkgType == nil {
|
|
// this package does not use Rows/Stmt
|
|
return nil
|
|
}
|
|
|
|
obj := pkgType.Object()
|
|
named, ok := obj.Type().(*types.Named)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
return types.NewPointer(named)
|
|
}
|
|
|
|
type targetValue struct {
|
|
value *ssa.Value
|
|
instr ssa.Instruction
|
|
}
|
|
|
|
func getTargetTypesValues(b *ssa.BasicBlock, i int, targetTypes []*types.Pointer) []targetValue {
|
|
targetValues := []targetValue{}
|
|
|
|
instr := b.Instrs[i]
|
|
call, ok := instr.(*ssa.Call)
|
|
if !ok {
|
|
return targetValues
|
|
}
|
|
|
|
signature := call.Call.Signature()
|
|
results := signature.Results()
|
|
for i := 0; i < results.Len(); i++ {
|
|
v := results.At(i)
|
|
varType := v.Type()
|
|
|
|
for _, targetType := range targetTypes {
|
|
if !types.Identical(varType, targetType) {
|
|
continue
|
|
}
|
|
|
|
for _, cRef := range *call.Referrers() {
|
|
switch instr := cRef.(type) {
|
|
case *ssa.Call:
|
|
if len(instr.Call.Args) >= 1 && types.Identical(instr.Call.Args[0].Type(), targetType) {
|
|
targetValues = append(targetValues, targetValue{
|
|
value: &instr.Call.Args[0],
|
|
instr: call,
|
|
})
|
|
}
|
|
case ssa.Value:
|
|
if types.Identical(instr.Type(), targetType) {
|
|
targetValues = append(targetValues, targetValue{
|
|
value: &instr,
|
|
instr: call,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return targetValues
|
|
}
|
|
|
|
func checkClosed(refs *[]ssa.Instruction, targetTypes []*types.Pointer) bool {
|
|
numInstrs := len(*refs)
|
|
for idx, ref := range *refs {
|
|
// log.Printf("%T - %s", ref, ref)
|
|
|
|
action := getAction(ref, targetTypes)
|
|
switch action {
|
|
case "closed":
|
|
return true
|
|
case "passed":
|
|
// Passed and not used after
|
|
if numInstrs == idx+1 {
|
|
return true
|
|
}
|
|
case "returned":
|
|
return true
|
|
case "handled":
|
|
return true
|
|
default:
|
|
// log.Printf(action)
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func getAction(instr ssa.Instruction, targetTypes []*types.Pointer) string {
|
|
switch instr := instr.(type) {
|
|
case *ssa.Defer:
|
|
if instr.Call.Value == nil {
|
|
return "unvalued defer"
|
|
}
|
|
|
|
name := instr.Call.Value.Name()
|
|
if name == closeMethod {
|
|
return "closed"
|
|
}
|
|
case *ssa.Call:
|
|
if instr.Call.Value == nil {
|
|
return "unvalued call"
|
|
}
|
|
|
|
isTarget := false
|
|
receiver := instr.Call.StaticCallee().Signature.Recv()
|
|
if receiver != nil {
|
|
isTarget = isTargetType(receiver.Type(), targetTypes)
|
|
}
|
|
|
|
name := instr.Call.Value.Name()
|
|
if isTarget && name == closeMethod {
|
|
return "closed"
|
|
}
|
|
|
|
if !isTarget {
|
|
return "passed"
|
|
}
|
|
case *ssa.Phi:
|
|
return "passed"
|
|
case *ssa.MakeInterface:
|
|
return "passed"
|
|
case *ssa.Store:
|
|
if len(*instr.Addr.Referrers()) == 0 {
|
|
return "noop"
|
|
}
|
|
|
|
for _, aRef := range *instr.Addr.Referrers() {
|
|
if c, ok := aRef.(*ssa.MakeClosure); ok {
|
|
f := c.Fn.(*ssa.Function)
|
|
for _, b := range f.Blocks {
|
|
if checkClosed(&b.Instrs, targetTypes) {
|
|
return "handled"
|
|
}
|
|
}
|
|
}
|
|
}
|
|
case *ssa.UnOp:
|
|
instrType := instr.Type()
|
|
for _, targetType := range targetTypes {
|
|
if types.Identical(instrType, targetType) {
|
|
if checkClosed(instr.Referrers(), targetTypes) {
|
|
return "handled"
|
|
}
|
|
}
|
|
}
|
|
case *ssa.FieldAddr:
|
|
if checkClosed(instr.Referrers(), targetTypes) {
|
|
return "handled"
|
|
}
|
|
case *ssa.Return:
|
|
return "returned"
|
|
default:
|
|
// log.Printf("%s", instr)
|
|
}
|
|
|
|
return "unhandled"
|
|
}
|
|
|
|
func checkDeferred(pass *analysis.Pass, instrs *[]ssa.Instruction, targetTypes []*types.Pointer, inDefer bool) {
|
|
for _, instr := range *instrs {
|
|
switch instr := instr.(type) {
|
|
case *ssa.Defer:
|
|
if instr.Call.Value != nil && instr.Call.Value.Name() == closeMethod {
|
|
return
|
|
}
|
|
case *ssa.Call:
|
|
if instr.Call.Value != nil && instr.Call.Value.Name() == closeMethod {
|
|
if !inDefer {
|
|
pass.Reportf(instr.Pos(), "Close should use defer")
|
|
}
|
|
|
|
return
|
|
}
|
|
case *ssa.Store:
|
|
if len(*instr.Addr.Referrers()) == 0 {
|
|
return
|
|
}
|
|
|
|
for _, aRef := range *instr.Addr.Referrers() {
|
|
if c, ok := aRef.(*ssa.MakeClosure); ok {
|
|
f := c.Fn.(*ssa.Function)
|
|
|
|
for _, b := range f.Blocks {
|
|
checkDeferred(pass, &b.Instrs, targetTypes, true)
|
|
}
|
|
}
|
|
}
|
|
case *ssa.UnOp:
|
|
instrType := instr.Type()
|
|
for _, targetType := range targetTypes {
|
|
if types.Identical(instrType, targetType) {
|
|
checkDeferred(pass, instr.Referrers(), targetTypes, inDefer)
|
|
}
|
|
}
|
|
case *ssa.FieldAddr:
|
|
checkDeferred(pass, instr.Referrers(), targetTypes, inDefer)
|
|
}
|
|
}
|
|
}
|
|
|
|
// isTargetType reports whether t is identical, in the types.Identical sense,
// to at least one of the tracked target pointer types.
func isTargetType(t types.Type, targetTypes []*types.Pointer) bool {
	matched := false
	for i := 0; !matched && i < len(targetTypes); i++ {
		matched = types.Identical(t, targetTypes[i])
	}
	return matched
}
|