Marvin Steadfast
104e72bd86
All checks were successful
continuous-integration/drone/push Build is passing
651 lines
17 KiB
Go
651 lines
17 KiB
Go
package pkg
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"go/ast"
|
|
"go/build"
|
|
"go/types"
|
|
"io"
|
|
"os"
|
|
"path/filepath"
|
|
"regexp"
|
|
"sort"
|
|
"strings"
|
|
"unicode"
|
|
|
|
"github.com/rs/zerolog"
|
|
"github.com/vektra/mockery/v2/pkg/config"
|
|
"github.com/vektra/mockery/v2/pkg/logging"
|
|
"golang.org/x/tools/imports"
|
|
)
|
|
|
|
// invalidIdentifierChar matches any single character that is neither an ASCII
// digit, an ASCII letter, nor an underscore. getNonConflictingName uses it to
// scrub import-path segments before gluing them together into an import alias.
var invalidIdentifierChar = regexp.MustCompile(`[^[:digit:][:alpha:]_]`)
|
|
|
|
// Generator is responsible for generating the string containing
|
|
// imports and the mock struct that will later be written out as file.
|
|
type Generator struct {
|
|
config.Config
|
|
buf bytes.Buffer
|
|
|
|
iface *Interface
|
|
pkg string
|
|
|
|
localizationCache map[string]string
|
|
packagePathToName map[string]string
|
|
nameToPackagePath map[string]string
|
|
|
|
packageRoots []string
|
|
}
|
|
|
|
// NewGenerator builds a Generator.
|
|
func NewGenerator(ctx context.Context, c config.Config, iface *Interface, pkg string) *Generator {
|
|
|
|
var roots []string
|
|
|
|
for _, root := range filepath.SplitList(build.Default.GOPATH) {
|
|
roots = append(roots, filepath.Join(root, "src"))
|
|
}
|
|
|
|
g := &Generator{
|
|
Config: c,
|
|
iface: iface,
|
|
pkg: pkg,
|
|
localizationCache: make(map[string]string),
|
|
packagePathToName: make(map[string]string),
|
|
nameToPackagePath: make(map[string]string),
|
|
packageRoots: roots,
|
|
}
|
|
|
|
g.addPackageImportWithName(ctx, "github.com/stretchr/testify/mock", "mock")
|
|
return g
|
|
}
|
|
|
|
func (g *Generator) populateImports(ctx context.Context) {
|
|
log := zerolog.Ctx(ctx)
|
|
|
|
log.Debug().Msgf("populating imports")
|
|
|
|
for _, method := range g.iface.Methods() {
|
|
ftype := method.Signature
|
|
g.addImportsFromTuple(ctx, ftype.Params())
|
|
g.addImportsFromTuple(ctx, ftype.Results())
|
|
g.renderType(ctx, g.iface.NamedType)
|
|
}
|
|
}
|
|
|
|
func (g *Generator) addImportsFromTuple(ctx context.Context, list *types.Tuple) {
|
|
for i := 0; i < list.Len(); i++ {
|
|
// We use renderType here because we need to recursively
|
|
// resolve any types to make sure that all named types that
|
|
// will appear in the interface file are known
|
|
g.renderType(ctx, list.At(i).Type())
|
|
}
|
|
}
|
|
|
|
func (g *Generator) addPackageImport(ctx context.Context, pkg *types.Package) string {
|
|
return g.addPackageImportWithName(ctx, pkg.Path(), pkg.Name())
|
|
}
|
|
|
|
func (g *Generator) addPackageImportWithName(ctx context.Context, path, name string) string {
|
|
path = g.getLocalizedPath(ctx, path)
|
|
if existingName, pathExists := g.packagePathToName[path]; pathExists {
|
|
return existingName
|
|
}
|
|
|
|
nonConflictingName := g.getNonConflictingName(path, name)
|
|
g.packagePathToName[path] = nonConflictingName
|
|
g.nameToPackagePath[nonConflictingName] = path
|
|
return nonConflictingName
|
|
}
|
|
|
|
func (g *Generator) getNonConflictingName(path, name string) string {
|
|
if !g.importNameExists(name) {
|
|
return name
|
|
}
|
|
|
|
// The path will always contain '/' because it is enforced in getLocalizedPath
|
|
// regardless of OS.
|
|
directories := strings.Split(path, "/")
|
|
|
|
cleanedDirectories := make([]string, 0, len(directories))
|
|
for _, directory := range directories {
|
|
cleaned := invalidIdentifierChar.ReplaceAllString(directory, "_")
|
|
cleanedDirectories = append(cleanedDirectories, cleaned)
|
|
}
|
|
numDirectories := len(cleanedDirectories)
|
|
var prospectiveName string
|
|
for i := 1; i <= numDirectories; i++ {
|
|
prospectiveName = strings.Join(cleanedDirectories[numDirectories-i:], "")
|
|
if !g.importNameExists(prospectiveName) {
|
|
return prospectiveName
|
|
}
|
|
}
|
|
// Try adding numbers to the given name
|
|
i := 2
|
|
for {
|
|
prospectiveName = fmt.Sprintf("%v%d", name, i)
|
|
if !g.importNameExists(prospectiveName) {
|
|
return prospectiveName
|
|
}
|
|
i++
|
|
}
|
|
}
|
|
|
|
func (g *Generator) importNameExists(name string) bool {
|
|
_, nameExists := g.nameToPackagePath[name]
|
|
return nameExists
|
|
}
|
|
|
|
func calculateImport(ctx context.Context, set []string, path string) string {
|
|
log := zerolog.Ctx(ctx).With().Str(logging.LogKeyPath, path).Logger()
|
|
ctx = log.WithContext(ctx)
|
|
|
|
for _, root := range set {
|
|
if strings.HasPrefix(path, root) {
|
|
packagePath, err := filepath.Rel(root, path)
|
|
if err == nil {
|
|
return packagePath
|
|
}
|
|
log.Err(err).Msgf("Unable to localize path")
|
|
}
|
|
}
|
|
return path
|
|
}
|
|
|
|
// TODO(@IvanMalison): Is there not a better way to get the actual
|
|
// import path of a package?
|
|
func (g *Generator) getLocalizedPath(ctx context.Context, path string) string {
|
|
log := zerolog.Ctx(ctx).With().Str(logging.LogKeyPath, path).Logger()
|
|
ctx = log.WithContext(ctx)
|
|
|
|
if strings.HasSuffix(path, ".go") {
|
|
path, _ = filepath.Split(path)
|
|
}
|
|
if localized, ok := g.localizationCache[path]; ok {
|
|
return localized
|
|
}
|
|
directories := strings.Split(path, string(filepath.Separator))
|
|
numDirectories := len(directories)
|
|
vendorIndex := -1
|
|
for i := 1; i <= numDirectories; i++ {
|
|
dir := directories[numDirectories-i]
|
|
if dir == "vendor" {
|
|
vendorIndex = numDirectories - i
|
|
break
|
|
}
|
|
}
|
|
|
|
toReturn := path
|
|
if vendorIndex >= 0 {
|
|
toReturn = filepath.Join(directories[vendorIndex+1:]...)
|
|
} else if filepath.IsAbs(path) {
|
|
toReturn = calculateImport(ctx, g.packageRoots, path)
|
|
}
|
|
|
|
// Enforce '/' slashes for import paths in every OS.
|
|
toReturn = filepath.ToSlash(toReturn)
|
|
|
|
g.localizationCache[path] = toReturn
|
|
return toReturn
|
|
}
|
|
|
|
func (g *Generator) mockName() string {
|
|
if g.StructName != "" {
|
|
return g.StructName
|
|
}
|
|
|
|
if !g.KeepTree && g.InPackage {
|
|
if g.Exported || ast.IsExported(g.iface.Name) {
|
|
return "Mock" + g.iface.Name
|
|
}
|
|
|
|
first := true
|
|
return "mock" + strings.Map(func(r rune) rune {
|
|
if first {
|
|
first = false
|
|
return unicode.ToUpper(r)
|
|
}
|
|
return r
|
|
}, g.iface.Name)
|
|
}
|
|
|
|
return g.iface.Name
|
|
}
|
|
|
|
func (g *Generator) sortedImportNames() (importNames []string) {
|
|
for name := range g.nameToPackagePath {
|
|
importNames = append(importNames, name)
|
|
}
|
|
sort.Strings(importNames)
|
|
return
|
|
}
|
|
|
|
func (g *Generator) generateImports(ctx context.Context) {
|
|
log := zerolog.Ctx(ctx)
|
|
|
|
log.Debug().Msgf("generating imports")
|
|
log.Debug().Msgf("%v", g.nameToPackagePath)
|
|
|
|
pkgPath := g.nameToPackagePath[g.iface.Pkg.Name()]
|
|
// Sort by import name so that we get a deterministic order
|
|
for _, name := range g.sortedImportNames() {
|
|
logImport := log.With().Str(logging.LogKeyImport, g.nameToPackagePath[name]).Logger()
|
|
logImport.Debug().Msgf("found import")
|
|
|
|
path := g.nameToPackagePath[name]
|
|
if !g.KeepTree && g.InPackage && path == pkgPath {
|
|
logImport.Debug().Msgf("import (%s) equals interface's package path (%s), skipping", path, pkgPath)
|
|
continue
|
|
}
|
|
g.printf("import %s \"%s\"\n", name, path)
|
|
}
|
|
}
|
|
|
|
// GeneratePrologue generates the prologue of the mock.
|
|
func (g *Generator) GeneratePrologue(ctx context.Context, pkg string) {
|
|
g.populateImports(ctx)
|
|
if g.InPackage {
|
|
g.printf("package %s\n\n", g.iface.Pkg.Name())
|
|
} else {
|
|
g.printf("package %v\n\n", pkg)
|
|
}
|
|
|
|
g.generateImports(ctx)
|
|
g.printf("\n")
|
|
}
|
|
|
|
// GeneratePrologueNote adds a note after the prologue to the output
|
|
// string.
|
|
func (g *Generator) GeneratePrologueNote(note string) {
|
|
prologue := "// Code generated by mockery"
|
|
if !g.Config.DisableVersionString {
|
|
prologue += fmt.Sprintf(" %s", config.GetSemverInfo())
|
|
}
|
|
prologue += ". DO NOT EDIT.\n"
|
|
|
|
g.printf(prologue)
|
|
if note != "" {
|
|
g.printf("\n")
|
|
for _, n := range strings.Split(note, "\\n") {
|
|
g.printf("// %s\n", n)
|
|
}
|
|
}
|
|
g.printf("\n")
|
|
}
|
|
|
|
// GenerateBoilerplate adds a boilerplate text. It should be called
|
|
// before any other generator methods to ensure the text is on top.
|
|
func (g *Generator) GenerateBoilerplate(boilerplate string) {
|
|
if boilerplate != "" {
|
|
g.printf("%s\n", boilerplate)
|
|
}
|
|
}
|
|
|
|
// ErrNotInterface reports that the given type is not an interface type.
var ErrNotInterface = errors.New("expression not an interface")
|
|
|
|
func (g *Generator) printf(s string, vals ...interface{}) {
|
|
fmt.Fprintf(&g.buf, s, vals...)
|
|
}
|
|
|
|
// namer is implemented by any type that can report its own name. renderType
// falls back to it for types it has no dedicated case for.
type namer interface {
	// Name returns the type's name.
	Name() string
}
|
|
|
|
func (g *Generator) renderType(ctx context.Context, typ types.Type) string {
|
|
switch t := typ.(type) {
|
|
case *types.Named:
|
|
o := t.Obj()
|
|
if o.Pkg() == nil || o.Pkg().Name() == "main" || (!g.KeepTree && g.InPackage && o.Pkg() == g.iface.Pkg) {
|
|
return o.Name()
|
|
}
|
|
return g.addPackageImport(ctx, o.Pkg()) + "." + o.Name()
|
|
case *types.Basic:
|
|
return t.Name()
|
|
case *types.Pointer:
|
|
return "*" + g.renderType(ctx, t.Elem())
|
|
case *types.Slice:
|
|
return "[]" + g.renderType(ctx, t.Elem())
|
|
case *types.Array:
|
|
return fmt.Sprintf("[%d]%s", t.Len(), g.renderType(ctx, t.Elem()))
|
|
case *types.Signature:
|
|
switch t.Results().Len() {
|
|
case 0:
|
|
return fmt.Sprintf(
|
|
"func(%s)",
|
|
g.renderTypeTuple(ctx, t.Params()),
|
|
)
|
|
case 1:
|
|
return fmt.Sprintf(
|
|
"func(%s) %s",
|
|
g.renderTypeTuple(ctx, t.Params()),
|
|
g.renderType(ctx, t.Results().At(0).Type()),
|
|
)
|
|
default:
|
|
return fmt.Sprintf(
|
|
"func(%s)(%s)",
|
|
g.renderTypeTuple(ctx, t.Params()),
|
|
g.renderTypeTuple(ctx, t.Results()),
|
|
)
|
|
}
|
|
case *types.Map:
|
|
kt := g.renderType(ctx, t.Key())
|
|
vt := g.renderType(ctx, t.Elem())
|
|
|
|
return fmt.Sprintf("map[%s]%s", kt, vt)
|
|
case *types.Chan:
|
|
switch t.Dir() {
|
|
case types.SendRecv:
|
|
return "chan " + g.renderType(ctx, t.Elem())
|
|
case types.RecvOnly:
|
|
return "<-chan " + g.renderType(ctx, t.Elem())
|
|
default:
|
|
return "chan<- " + g.renderType(ctx, t.Elem())
|
|
}
|
|
case *types.Struct:
|
|
var fields []string
|
|
|
|
for i := 0; i < t.NumFields(); i++ {
|
|
f := t.Field(i)
|
|
|
|
if f.Anonymous() {
|
|
fields = append(fields, g.renderType(ctx, f.Type()))
|
|
} else {
|
|
fields = append(fields, fmt.Sprintf("%s %s", f.Name(), g.renderType(ctx, f.Type())))
|
|
}
|
|
}
|
|
|
|
return fmt.Sprintf("struct{%s}", strings.Join(fields, ";"))
|
|
case *types.Interface:
|
|
if t.NumMethods() != 0 {
|
|
panic("Unable to mock inline interfaces with methods")
|
|
}
|
|
|
|
return "interface{}"
|
|
case namer:
|
|
return t.Name()
|
|
default:
|
|
panic(fmt.Sprintf("un-namable type: %#v (%T)", t, t))
|
|
}
|
|
}
|
|
|
|
func (g *Generator) renderTypeTuple(ctx context.Context, tup *types.Tuple) string {
|
|
var parts []string
|
|
|
|
for i := 0; i < tup.Len(); i++ {
|
|
v := tup.At(i)
|
|
|
|
parts = append(parts, g.renderType(ctx, v.Type()))
|
|
}
|
|
|
|
return strings.Join(parts, " , ")
|
|
}
|
|
|
|
// isNillable reports whether the generated mock should nil-check a value of
// type typ before type-asserting it. Pointers, arrays, maps, interfaces,
// functions, channels and slices qualify; a named type qualifies when its
// underlying type does.
//
// NOTE(review): arrays are on the list although a Go array value can never be
// nil; the entry is kept because it changes the generated code.
func isNillable(typ types.Type) bool {
	if named, ok := typ.(*types.Named); ok {
		return isNillable(named.Underlying())
	}

	switch typ.(type) {
	case *types.Pointer,
		*types.Array,
		*types.Map,
		*types.Interface,
		*types.Signature,
		*types.Chan,
		*types.Slice:
		return true
	default:
		return false
	}
}
|
|
|
|
// paramList describes a method's parameter list or result list as parallel
// slices: index i of every slice refers to the same variable. It is produced
// by genList.
type paramList struct {
	// Names holds the identifier used for each variable.
	Names []string
	// Types holds the rendered type of each variable; a variadic parameter is
	// stored with a leading "...".
	Types []string
	// Params holds "name type" declarations, one per variable.
	Params []string
	// Nilable holds the isNillable verdict for each variable's type.
	Nilable []bool
	// Variadic is set when the final variable was rendered as variadic.
	Variadic bool
}
|
|
|
|
func (g *Generator) genList(ctx context.Context, list *types.Tuple, variadic bool) *paramList {
|
|
var params paramList
|
|
|
|
if list == nil {
|
|
return ¶ms
|
|
}
|
|
|
|
for i := 0; i < list.Len(); i++ {
|
|
v := list.At(i)
|
|
|
|
ts := g.renderType(ctx, v.Type())
|
|
|
|
if variadic && i == list.Len()-1 {
|
|
t := v.Type()
|
|
switch t := t.(type) {
|
|
case *types.Slice:
|
|
params.Variadic = true
|
|
ts = "..." + g.renderType(ctx, t.Elem())
|
|
default:
|
|
panic("bad variadic type!")
|
|
}
|
|
}
|
|
|
|
pname := v.Name()
|
|
|
|
if g.nameCollides(pname) || pname == "" {
|
|
pname = fmt.Sprintf("_a%d", i)
|
|
}
|
|
|
|
params.Names = append(params.Names, pname)
|
|
params.Types = append(params.Types, ts)
|
|
|
|
params.Params = append(params.Params, fmt.Sprintf("%s %s", pname, ts))
|
|
params.Nilable = append(params.Nilable, isNillable(v.Type()))
|
|
}
|
|
|
|
return ¶ms
|
|
}
|
|
|
|
func (g *Generator) nameCollides(pname string) bool {
|
|
if pname == g.pkg {
|
|
return true
|
|
}
|
|
return g.importNameExists(pname)
|
|
}
|
|
|
|
// ErrNotSetup is the error Generate uses to report a Generator that has no
// interface to mock.
var ErrNotSetup = errors.New("not setup")
|
|
|
|
// Generate builds a string that constitutes a valid go source file
|
|
// containing the mock of the relevant interface.
|
|
func (g *Generator) Generate(ctx context.Context) error {
|
|
g.populateImports(ctx)
|
|
if g.iface == nil {
|
|
return ErrNotSetup
|
|
}
|
|
|
|
g.printf(
|
|
"// %s is an autogenerated mock type for the %s type\n", g.mockName(),
|
|
g.iface.Name,
|
|
)
|
|
|
|
g.printf(
|
|
"type %s struct {\n\tmock.Mock\n}\n\n", g.mockName(),
|
|
)
|
|
|
|
for _, method := range g.iface.Methods() {
|
|
|
|
ftype := method.Signature
|
|
fname := method.Name
|
|
|
|
params := g.genList(ctx, ftype.Params(), ftype.Variadic())
|
|
returns := g.genList(ctx, ftype.Results(), false)
|
|
|
|
if len(params.Names) == 0 {
|
|
g.printf("// %s provides a mock function with given fields:\n", fname)
|
|
} else {
|
|
g.printf(
|
|
"// %s provides a mock function with given fields: %s\n", fname,
|
|
strings.Join(params.Names, ", "),
|
|
)
|
|
}
|
|
g.printf(
|
|
"func (_m *%s) %s(%s) ", g.mockName(), fname,
|
|
strings.Join(params.Params, ", "),
|
|
)
|
|
|
|
switch len(returns.Types) {
|
|
case 0:
|
|
g.printf("{\n")
|
|
case 1:
|
|
g.printf("%s {\n", returns.Types[0])
|
|
default:
|
|
g.printf("(%s) {\n", strings.Join(returns.Types, ", "))
|
|
}
|
|
|
|
formattedParamNames := ""
|
|
setOfParamNames := make(map[string]struct{}, len(params.Names))
|
|
for i, name := range params.Names {
|
|
if i > 0 {
|
|
formattedParamNames += ", "
|
|
}
|
|
|
|
paramType := params.Types[i]
|
|
// for variable args, move the ... to the end.
|
|
if strings.Index(paramType, "...") == 0 {
|
|
name += "..."
|
|
}
|
|
formattedParamNames += name
|
|
|
|
setOfParamNames[name] = struct{}{}
|
|
}
|
|
|
|
called := g.generateCalled(params, formattedParamNames) // _m.Called invocation string
|
|
|
|
if len(returns.Types) > 0 {
|
|
retVariable := resolveCollision(setOfParamNames, "ret")
|
|
g.printf("\t%s := %s\n\n", retVariable, called)
|
|
|
|
ret := make([]string, len(returns.Types))
|
|
|
|
for idx, typ := range returns.Types {
|
|
g.printf("\tvar r%d %s\n", idx, typ)
|
|
g.printf("\tif rf, ok := %s.Get(%d).(func(%s) %s); ok {\n",
|
|
retVariable, idx, strings.Join(params.Types, ", "), typ)
|
|
g.printf("\t\tr%d = rf(%s)\n", idx, formattedParamNames)
|
|
g.printf("\t} else {\n")
|
|
if typ == "error" {
|
|
g.printf("\t\tr%d = %s.Error(%d)\n", idx, retVariable, idx)
|
|
} else if returns.Nilable[idx] {
|
|
g.printf("\t\tif %s.Get(%d) != nil {\n", retVariable, idx)
|
|
g.printf("\t\t\tr%d = %s.Get(%d).(%s)\n", idx, retVariable, idx, typ)
|
|
g.printf("\t\t}\n")
|
|
} else {
|
|
g.printf("\t\tr%d = %s.Get(%d).(%s)\n", idx, retVariable, idx, typ)
|
|
}
|
|
g.printf("\t}\n\n")
|
|
|
|
ret[idx] = fmt.Sprintf("r%d", idx)
|
|
}
|
|
|
|
g.printf("\treturn %s\n", strings.Join(ret, ", "))
|
|
} else {
|
|
g.printf("\t%s\n", called)
|
|
}
|
|
|
|
g.printf("}\n")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// generateCalled returns the Mock.Called invocation string and, if necessary, prints the
|
|
// steps to prepare its argument list.
|
|
//
|
|
// It is separate from Generate to avoid cyclomatic complexity through early return statements.
|
|
func (g *Generator) generateCalled(list *paramList, formattedParamNames string) string {
|
|
namesLen := len(list.Names)
|
|
if namesLen == 0 {
|
|
return "_m.Called()"
|
|
}
|
|
|
|
if !list.Variadic {
|
|
return "_m.Called(" + formattedParamNames + ")"
|
|
}
|
|
|
|
if !g.UnrollVariadic {
|
|
return "_m.Called(" + strings.Join(list.Names, ", ") + ")"
|
|
}
|
|
|
|
var variadicArgsName string
|
|
variadicName := list.Names[namesLen-1]
|
|
|
|
// list.Types[] will contain a leading '...'. Strip this from the string to
|
|
// do easier comparison.
|
|
strippedIfaceType := strings.Trim(list.Types[namesLen-1], "...")
|
|
variadicIface := strippedIfaceType == "interface{}"
|
|
|
|
if variadicIface {
|
|
// Variadic is already of the interface{} type, so we don't need special handling.
|
|
variadicArgsName = variadicName
|
|
} else {
|
|
// Define _va to avoid "cannot use t (type T) as type []interface {} in append" error
|
|
// whenever the variadic type is non-interface{}.
|
|
g.printf("\t_va := make([]interface{}, len(%s))\n", variadicName)
|
|
g.printf("\tfor _i := range %s {\n\t\t_va[_i] = %s[_i]\n\t}\n", variadicName, variadicName)
|
|
variadicArgsName = "_va"
|
|
}
|
|
|
|
// _ca will hold all arguments we'll mirror into Called, one argument per distinct value
|
|
// passed to the method.
|
|
//
|
|
// For example, if the second argument is variadic and consists of three values,
|
|
// a total of 4 arguments will be passed to Called. The alternative is to
|
|
// pass a total of 2 arguments where the second is a slice with those 3 values from
|
|
// the variadic argument. But the alternative is less accessible because it requires
|
|
// building a []interface{} before calling Mock methods like On and AssertCalled for
|
|
// the variadic argument, and creates incompatibility issues with the diff algorithm
|
|
// in github.com/stretchr/testify/mock.
|
|
//
|
|
// This mirroring will allow argument lists for methods like On and AssertCalled to
|
|
// always resemble the expected calls they describe and retain compatibility.
|
|
//
|
|
// It's okay for us to use the interface{} type, regardless of the actual types, because
|
|
// Called receives only interface{} anyway.
|
|
g.printf("\tvar _ca []interface{}\n")
|
|
|
|
if namesLen > 1 {
|
|
nonVariadicParamNames := formattedParamNames[0:strings.LastIndex(formattedParamNames, ",")]
|
|
g.printf("\t_ca = append(_ca, %s)\n", nonVariadicParamNames)
|
|
}
|
|
g.printf("\t_ca = append(_ca, %s...)\n", variadicArgsName)
|
|
|
|
return "_m.Called(_ca...)"
|
|
}
|
|
|
|
func (g *Generator) Write(w io.Writer) error {
|
|
opt := &imports.Options{Comments: true}
|
|
theBytes := g.buf.Bytes()
|
|
|
|
res, err := imports.Process("mock.go", theBytes, opt)
|
|
if err != nil {
|
|
line := "--------------------------------------------------------------------------------------------"
|
|
fmt.Fprintf(os.Stderr, "Between the lines is the file (mock.go) mockery generated in-memory but detected as invalid:\n%s\n%s\n%s\n", line, g.buf.String(), line)
|
|
return err
|
|
}
|
|
|
|
w.Write(res)
|
|
return nil
|
|
}
|
|
|
|
// resolveCollision returns an identifier based on variable that is not a key
// of names. variable itself is returned when it is free; otherwise
// "<variable>_<n>" is tried with n starting at len(names) and counting up.
func resolveCollision(names map[string]struct{}, variable string) string {
	candidate := variable
	for suffix := len(names); ; suffix++ {
		if _, taken := names[candidate]; !taken {
			return candidate
		}
		candidate = fmt.Sprintf("%s_%d", variable, suffix)
	}
}
|