workgroups/vendor/github.com/ryanrolds/sqlclosecheck/pkg/analyzer/analyzer.go
Marvin Preuss 1d4ae27878
All checks were successful
continuous-integration/drone/push Build is passing
ci: drone yaml with reusable anchors
2021-09-24 17:34:17 +02:00

312 lines
6.6 KiB
Go

package analyzer
import (
"go/types"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/buildssa"
"golang.org/x/tools/go/ssa"
)
// Identifiers the analyzer looks up in each SQL package and on the tracked
// values.
const (
	// rowsName is the name of the result-set type looked up in each package.
	rowsName = "Rows"
	// stmtName is the name of the prepared-statement type looked up in each package.
	stmtName = "Stmt"
	// closeMethod is the method name treated as releasing a Rows/Stmt value.
	closeMethod = "Close"
)

// sqlPackages lists the import paths whose Rows and Stmt types are tracked.
var sqlPackages = []string{
	"database/sql",
	"github.com/jmoiron/sqlx",
}
// NewAnalyzer returns the sqlclosecheck analyzer. It requires the buildssa
// pass, whose SSA form of the package is consumed by run.
func NewAnalyzer() *analysis.Analyzer {
	deps := []*analysis.Analyzer{buildssa.Analyzer}

	a := &analysis.Analyzer{
		Name:     "sqlclosecheck",
		Doc:      "Checks that sql.Rows and sql.Stmt are closed.",
		Run:      run,
		Requires: deps,
	}
	return a
}
// run is the analysis entry point. It checks every call in the package's
// source functions that yields a tracked Rows/Stmt value (see
// getTargetTypesValues). It reports "Rows/Stmt was not closed" when
// checkClosed finds no close or hand-off of the value. It then lets
// checkDeferred report a Close that is not deferred.
func run(pass *analysis.Pass) (interface{}, error) {
	pssa := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA)

	// Build the list of pointer types to track; if none of the SQL packages
	// are part of the program there is nothing to check.
	targetTypes := getTargetTypes(pssa, sqlPackages)
	if len(targetTypes) == 0 {
		return nil, nil
	}

	for _, f := range pssa.SrcFuncs {
		for _, b := range f.Blocks {
			for i := range b.Instrs {
				// Non-empty only when b.Instrs[i] is a call returning a target type.
				for _, tv := range getTargetTypesValues(b, i, targetTypes) {
					refs := (*tv.value).Referrers()
					if refs == nil {
						// SSA does not track uses of this value (for example a
						// constant such as a typed nil argument), so there is
						// nothing to follow; dereferencing nil below would panic.
						continue
					}
					if !checkClosed(refs, targetTypes) {
						pass.Reportf(tv.instr.Pos(), "Rows/Stmt was not closed")
					}
					checkDeferred(pass, refs, targetTypes, false)
				}
			}
		}
	}
	return nil, nil
}
// getTargetTypes returns the pointer types *Rows and *Stmt for every package
// in targetPackages that is present in the SSA program. Packages that are
// absent are skipped and the remaining entries are still examined. The
// result is never nil, only possibly empty.
func getTargetTypes(pssa *buildssa.SSA, targetPackages []string) []*types.Pointer {
	targets := []*types.Pointer{}
	for _, sqlPkg := range targetPackages {
		pkg := pssa.Pkg.Prog.ImportedPackage(sqlPkg)
		if pkg == nil {
			// This SQL package isn't part of the program; keep checking the
			// others instead of abandoning the whole list.
			continue
		}
		for _, name := range []string{rowsName, stmtName} {
			if ptr := getTypePointerFromName(pkg, name); ptr != nil {
				targets = append(targets, ptr)
			}
		}
	}
	return targets
}
// getTypePointerFromName looks up the package-level type called name in pkg.
// It returns a pointer to that named type (e.g. *Rows). It returns nil when
// pkg has no such type or the type is not a *types.Named.
func getTypePointerFromName(pkg *ssa.Package, name string) *types.Pointer {
	member := pkg.Type(name)
	if member == nil {
		// The package does not declare a type with this name.
		return nil
	}

	if named, isNamed := member.Object().Type().(*types.Named); isNamed {
		return types.NewPointer(named)
	}
	return nil
}
// targetValue pairs a tracked Rows/Stmt value with the call that produced it.
type targetValue struct {
	// instr is the call whose result is tracked; run reports diagnostics at
	// its position.
	instr ssa.Instruction
	// value points at the SSA value through which that result is used.
	value *ssa.Value
}
// getTargetTypesValues inspects b.Instrs[i]. If it is a call with a result
// whose type is identical to one of targetTypes, it returns the values
// through which that result is used. These are:
//   - a referrer that is itself an ssa.Value of the target type (e.g. the
//     Extract of a tuple result), or
//   - the first argument of a referrer call when that argument has the
//     target type (for a static method call this is the receiver).
//
// Each returned entry records the producing call in its instr field. The
// result is empty (non-nil) when the instruction is not such a call.
func getTargetTypesValues(b *ssa.BasicBlock, i int, targetTypes []*types.Pointer) []targetValue {
	targetValues := []targetValue{}

	call, ok := b.Instrs[i].(*ssa.Call)
	if !ok {
		return targetValues
	}

	results := call.Call.Signature().Results()
	// A call can return several results of the same target type. Scanning
	// the referrers once per such result would append every matching
	// referrer repeatedly and yield duplicate diagnostics, so each target
	// type is scanned at most once.
	scanned := make(map[*types.Pointer]bool, len(targetTypes))
	for r := 0; r < results.Len(); r++ {
		resultType := results.At(r).Type()
		for _, targetType := range targetTypes {
			if scanned[targetType] || !types.Identical(resultType, targetType) {
				continue
			}
			scanned[targetType] = true

			for _, cRef := range *call.Referrers() {
				switch ref := cRef.(type) {
				case *ssa.Call:
					if len(ref.Call.Args) >= 1 && types.Identical(ref.Call.Args[0].Type(), targetType) {
						targetValues = append(targetValues, targetValue{
							value: &ref.Call.Args[0],
							instr: call,
						})
					}
				case ssa.Value:
					if types.Identical(ref.Type(), targetType) {
						targetValues = append(targetValues, targetValue{
							value: &ref,
							instr: call,
						})
					}
				}
			}
		}
	}
	return targetValues
}
// checkClosed reports whether the uses in refs show the tracked value being
// released. It returns true as soon as one referrer is classified by
// getAction as "closed", "returned" or "handled". It also returns true when
// the last referrer in refs is "passed" (handed to something else with no
// use recorded after it). A nil or empty refs yields false.
func checkClosed(refs *[]ssa.Instruction, targetTypes []*types.Pointer) bool {
	if refs == nil {
		// Referrers() returns nil for values whose uses SSA does not track;
		// dereferencing it would panic.
		return false
	}

	last := len(*refs) - 1
	for idx, ref := range *refs {
		switch getAction(ref, targetTypes) {
		case "closed", "returned", "handled":
			return true
		case "passed":
			// Only counts when it is the final recorded use.
			if idx == last {
				return true
			}
		}
	}
	return false
}
// getAction classifies how a single referrer instruction uses a tracked
// Rows/Stmt value. checkClosed acts on the labels "closed", "passed",
// "returned" and "handled". Every other label ("unvalued defer",
// "unvalued call", "noop", "unhandled") means no release was seen here.
//
// The Store, UnOp and FieldAddr cases recurse into checkClosed. They follow
// the value into closures that capture its address, through loads of a
// target type, and through field addresses.
func getAction(instr ssa.Instruction, targetTypes []*types.Pointer) string {
	switch instr := instr.(type) {
	case *ssa.Defer:
		if instr.Call.Value == nil {
			return "unvalued defer"
		}
		// Any deferred call named Close counts; the receiver is not checked.
		if instr.Call.Value.Name() == closeMethod {
			return "closed"
		}
	case *ssa.Call:
		if instr.Call.Value == nil {
			return "unvalued call"
		}
		isTarget := false
		// StaticCallee is nil when the callee is not statically known (a call
		// through a function value, or an interface method invocation). There
		// is no receiver to inspect then, and dereferencing it would panic.
		if callee := instr.Call.StaticCallee(); callee != nil {
			if receiver := callee.Signature.Recv(); receiver != nil {
				isTarget = isTargetType(receiver.Type(), targetTypes)
			}
		}
		name := instr.Call.Value.Name()
		if isTarget && name == closeMethod {
			return "closed"
		}
		if !isTarget {
			// The value is handed to some other function.
			return "passed"
		}
	case *ssa.Phi:
		return "passed"
	case *ssa.MakeInterface:
		return "passed"
	case *ssa.Store:
		// Referrers is nil for addresses SSA does not track (e.g. package-level
		// variables); treat that like an address nobody reads.
		addrRefs := instr.Addr.Referrers()
		if addrRefs == nil || len(*addrRefs) == 0 {
			return "noop"
		}
		for _, aRef := range *addrRefs {
			// The address is captured by a closure: look for a close inside it.
			if c, ok := aRef.(*ssa.MakeClosure); ok {
				f := c.Fn.(*ssa.Function)
				for _, b := range f.Blocks {
					if checkClosed(&b.Instrs, targetTypes) {
						return "handled"
					}
				}
			}
		}
	case *ssa.UnOp:
		// A load (or other unary op) producing a target type: follow its uses.
		instrType := instr.Type()
		for _, targetType := range targetTypes {
			if types.Identical(instrType, targetType) {
				if checkClosed(instr.Referrers(), targetTypes) {
					return "handled"
				}
			}
		}
	case *ssa.FieldAddr:
		if checkClosed(instr.Referrers(), targetTypes) {
			return "handled"
		}
	case *ssa.Return:
		return "returned"
	}
	return "unhandled"
}
// checkDeferred scans instrs for the first Close of the tracked value.
//   - A deferred call named Close ends the scan silently.
//   - A plain call named Close ends the scan; it is reported as
//     "Close should use defer" unless inDefer is true.
//
// The scan follows the value through Store into closures that capture the
// address (scanned with inDefer=true), through UnOp results of a target
// type, and through FieldAddr. A Store to an address with no readers ends
// the scan.
func checkDeferred(pass *analysis.Pass, instrs *[]ssa.Instruction, targetTypes []*types.Pointer, inDefer bool) {
	if instrs == nil {
		// Referrers() returns nil for untracked values; nothing to scan.
		return
	}
	for _, instr := range *instrs {
		switch instr := instr.(type) {
		case *ssa.Defer:
			if instr.Call.Value != nil && instr.Call.Value.Name() == closeMethod {
				return
			}
		case *ssa.Call:
			if instr.Call.Value != nil && instr.Call.Value.Name() == closeMethod {
				if !inDefer {
					pass.Reportf(instr.Pos(), "Close should use defer")
				}
				return
			}
		case *ssa.Store:
			// Referrers is nil for addresses SSA does not track (e.g.
			// package-level variables); handle it like an unread address.
			addrRefs := instr.Addr.Referrers()
			if addrRefs == nil || len(*addrRefs) == 0 {
				return
			}
			for _, aRef := range *addrRefs {
				if c, ok := aRef.(*ssa.MakeClosure); ok {
					f := c.Fn.(*ssa.Function)
					for _, b := range f.Blocks {
						checkDeferred(pass, &b.Instrs, targetTypes, true)
					}
				}
			}
		case *ssa.UnOp:
			instrType := instr.Type()
			for _, targetType := range targetTypes {
				if types.Identical(instrType, targetType) {
					checkDeferred(pass, instr.Referrers(), targetTypes, inDefer)
				}
			}
		case *ssa.FieldAddr:
			checkDeferred(pass, instr.Referrers(), targetTypes, inDefer)
		}
	}
}
// isTargetType reports whether t is identical (per types.Identical) to any
// of the pointer types in targetTypes. An empty targetTypes yields false.
func isTargetType(t types.Type, targetTypes []*types.Pointer) bool {
	found := false
	for idx := 0; idx < len(targetTypes) && !found; idx++ {
		found = types.Identical(t, targetTypes[idx])
	}
	return found
}