wg-quicker/vendor/github.com/vektra/mockery/v2/pkg/generator.go
Marvin Steadfast 104e72bd86
All checks were successful
continuous-integration/drone/push Build is passing
adds ability to kill a process of a userland wireguard-go
2021-04-16 14:21:36 +02:00

651 lines
17 KiB
Go

package pkg
import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"go/ast"
	"go/build"
	"go/token"
	"go/types"
	"io"
	"os"
	"path/filepath"
	"regexp"
	"sort"
	"strings"
	"unicode"

	"github.com/rs/zerolog"
	"github.com/vektra/mockery/v2/pkg/config"
	"github.com/vektra/mockery/v2/pkg/logging"
	"golang.org/x/tools/imports"
)
// invalidIdentifierChar matches every character that is not an ASCII letter,
// an ASCII digit or an underscore. getNonConflictingName replaces such
// characters with "_" when deriving import names from path components.
var invalidIdentifierChar = regexp.MustCompile("[^[:digit:][:alpha:]_]")
// Generator is responsible for generating the string containing
// imports and the mock struct that will later be written out as file.
// Generated text accumulates in buf; Write formats it and emits it.
type Generator struct {
// Embedded configuration; fields such as StructName, InPackage, KeepTree,
// Exported, UnrollVariadic and DisableVersionString are read through it.
config.Config
// buf accumulates the generated source until Write is called.
buf bytes.Buffer
// iface is the interface being mocked; Generate returns ErrNotSetup if nil.
iface *Interface
// pkg is the package name passed to NewGenerator; nameCollides treats a
// parameter with this name as colliding.
pkg string
// localizationCache memoizes getLocalizedPath (input path -> import path).
localizationCache map[string]string
// packagePathToName and nameToPackagePath are the two directions of the
// import table (import path <-> import name), kept in sync by
// addPackageImportWithName.
packagePathToName map[string]string
nameToPackagePath map[string]string
// packageRoots holds "<GOPATH entry>/src" directories, used by
// calculateImport to turn absolute directories into import paths.
packageRoots []string
}
// NewGenerator builds a Generator that mocks iface using configuration c.
//
// Each GOPATH entry contributes "<entry>/src" as a package root. The testify
// mock package is registered up front under the import name "mock", which
// the generated struct embeds as mock.Mock.
func NewGenerator(ctx context.Context, c config.Config, iface *Interface, pkg string) *Generator {
	var packageRoots []string
	for _, entry := range filepath.SplitList(build.Default.GOPATH) {
		packageRoots = append(packageRoots, filepath.Join(entry, "src"))
	}

	gen := &Generator{
		Config:            c,
		iface:             iface,
		pkg:               pkg,
		packageRoots:      packageRoots,
		localizationCache: map[string]string{},
		packagePathToName: map[string]string{},
		nameToPackagePath: map[string]string{},
	}
	gen.addPackageImportWithName(ctx, "github.com/stretchr/testify/mock", "mock")
	return gen
}
// populateImports registers, before any code is emitted, every import that
// the parameter and result types of the interface's methods require.
//
// The interface's own named type is rendered as well, once per method. Only
// the first call can register anything, since already-registered import
// paths are returned from the cache. If the interface has no methods, the
// named type is never rendered.
func (g *Generator) populateImports(ctx context.Context) {
	logger := zerolog.Ctx(ctx)
	logger.Debug().Msgf("populating imports")

	for _, m := range g.iface.Methods() {
		sig := m.Signature
		for _, tuple := range []*types.Tuple{sig.Params(), sig.Results()} {
			g.addImportsFromTuple(ctx, tuple)
		}
		g.renderType(ctx, g.iface.NamedType)
	}
}
// addImportsFromTuple renders the type of every variable in list and discards
// the text. It is called only for the side effect: renderType registers the
// imports of any named types it encounters, recursively.
func (g *Generator) addImportsFromTuple(ctx context.Context, list *types.Tuple) {
	count := list.Len()
	for idx := 0; idx < count; idx++ {
		_ = g.renderType(ctx, list.At(idx).Type())
	}
}
// addPackageImport registers pkg as an import of the generated file and
// returns the name under which it must be referenced.
func (g *Generator) addPackageImport(ctx context.Context, pkg *types.Package) string {
	path, name := pkg.Path(), pkg.Name()
	return g.addPackageImportWithName(ctx, path, name)
}
// addPackageImportWithName registers the package at path, preferring name as
// its import name, and returns the import name actually assigned.
//
// The path is localized first. A path that is already registered keeps its
// existing name; otherwise a non-conflicting name is chosen and recorded in
// both directions of the import table.
func (g *Generator) addPackageImportWithName(ctx context.Context, path, name string) string {
	importPath := g.getLocalizedPath(ctx, path)
	existing, seen := g.packagePathToName[importPath]
	if seen {
		return existing
	}
	alias := g.getNonConflictingName(importPath, name)
	g.nameToPackagePath[alias] = importPath
	g.packagePathToName[importPath] = alias
	return alias
}
// getNonConflictingName returns an import name for path that is not yet
// registered. It does not register the name itself.
//
// It tries, in order:
//   - name itself;
//   - progressively longer suffixes of the slash-separated path, with every
//     component sanitized by invalidIdentifierChar and the components
//     concatenated without separators (e.g. "xfoo" for "github.com/x/foo");
//   - name followed by 2, 3, ...
//
// Path-derived candidates are only used when they are valid, non-blank Go
// identifiers. Without this check a component such as "3rdparty" or "for",
// an empty trailing component, or "_" produced an invalid or blank import.
func (g *Generator) getNonConflictingName(path, name string) string {
	if !g.importNameExists(name) {
		return name
	}
	// The path will always contain '/' because it is enforced in getLocalizedPath
	// regardless of OS.
	directories := strings.Split(path, "/")
	cleanedDirectories := make([]string, 0, len(directories))
	for _, directory := range directories {
		cleanedDirectories = append(cleanedDirectories, invalidIdentifierChar.ReplaceAllString(directory, "_"))
	}
	numDirectories := len(cleanedDirectories)
	for i := 1; i <= numDirectories; i++ {
		prospectiveName := strings.Join(cleanedDirectories[numDirectories-i:], "")
		if isUsableImportName(prospectiveName) && !g.importNameExists(prospectiveName) {
			return prospectiveName
		}
	}
	// Fall back to adding numbers to the given name.
	for i := 2; ; i++ {
		prospectiveName := fmt.Sprintf("%v%d", name, i)
		if !g.importNameExists(prospectiveName) {
			return prospectiveName
		}
	}
}

// isUsableImportName reports whether s can be written as an explicit import
// name. It must be a Go identifier (non-empty, not starting with a digit, not
// a keyword) other than the blank identifier "_".
func isUsableImportName(s string) bool {
	return s != "_" && token.IsIdentifier(s)
}
// importNameExists reports whether name is already the import name of some
// registered package.
func (g *Generator) importNameExists(name string) bool {
	if _, taken := g.nameToPackagePath[name]; taken {
		return true
	}
	return false
}
// calculateImport converts the filesystem path of a package directory into an
// import path. It makes path relative to the first root in set that contains
// it; if no root contains it, path is returned unchanged.
//
// A root only matches on a whole path-component boundary. Root "/go/src"
// contains "/go/src/x" but not "/go/srcx/y"; a plain string-prefix test would
// turn the latter into the bogus import "../srcx/y".
func calculateImport(ctx context.Context, set []string, path string) string {
	log := zerolog.Ctx(ctx).With().Str(logging.LogKeyPath, path).Logger()
	for _, root := range set {
		if !hasPathPrefix(path, root) {
			continue
		}
		packagePath, err := filepath.Rel(root, path)
		if err == nil {
			return packagePath
		}
		log.Err(err).Msgf("Unable to localize path")
	}
	return path
}

// hasPathPrefix reports whether path equals root or lies below it, comparing
// whole path components rather than raw string prefixes.
func hasPathPrefix(path, root string) bool {
	if !strings.HasPrefix(path, root) {
		return false
	}
	if len(path) == len(root) || strings.HasSuffix(root, string(filepath.Separator)) {
		return true
	}
	return path[len(root)] == filepath.Separator
}
// getLocalizedPath converts a package path (an import path or a filesystem
// path) into the import path to write in the generated file.
//
// A trailing ".go" file name is dropped first. If any component is "vendor",
// everything up to and including the last "vendor" component is stripped.
// Otherwise an absolute path is made relative to a package root via
// calculateImport. Results are cached per input path and always use '/'.
//
// TODO(@IvanMalison): Is there not a better way to get the actual
// import path of a package?
func (g *Generator) getLocalizedPath(ctx context.Context, path string) string {
	log := zerolog.Ctx(ctx).With().Str(logging.LogKeyPath, path).Logger()
	ctx = log.WithContext(ctx)
	if strings.HasSuffix(path, ".go") {
		path, _ = filepath.Split(path)
	}
	if localized, ok := g.localizationCache[path]; ok {
		return localized
	}
	// Import paths (e.g. from types.Package.Path) always use '/', while
	// filesystem paths use the OS separator. Normalizing to '/' before
	// splitting finds "vendor" in both on every OS. This is a no-op on Unix;
	// on Windows, splitting only on '\' missed vendored import paths.
	directories := strings.Split(filepath.ToSlash(path), "/")
	vendorIndex := -1
	for i := len(directories) - 1; i >= 0; i-- {
		if directories[i] == "vendor" {
			vendorIndex = i
			break
		}
	}
	toReturn := path
	if vendorIndex >= 0 {
		toReturn = filepath.Join(directories[vendorIndex+1:]...)
	} else if filepath.IsAbs(path) {
		toReturn = calculateImport(ctx, g.packageRoots, path)
	}
	// Enforce '/' slashes for import paths in every OS.
	toReturn = filepath.ToSlash(toReturn)
	g.localizationCache[path] = toReturn
	return toReturn
}
// mockName returns the name of the generated mock struct:
//   - the configured StructName, when set;
//   - otherwise, when generating inside the interface's package without
//     KeepTree: "Mock"+Name for exported interfaces (or when Exported is
//     set), else "mock"+Name with its first rune upper-cased;
//   - otherwise the interface name itself.
func (g *Generator) mockName() string {
	switch {
	case g.StructName != "":
		return g.StructName
	case g.KeepTree || !g.InPackage:
		return g.iface.Name
	case g.Exported || ast.IsExported(g.iface.Name):
		return "Mock" + g.iface.Name
	}
	runes := []rune(g.iface.Name)
	if len(runes) > 0 {
		runes[0] = unicode.ToUpper(runes[0])
	}
	return "mock" + string(runes)
}
// sortedImportNames returns every registered import name in ascending
// lexical order, so the generated import block is deterministic.
func (g *Generator) sortedImportNames() []string {
	var names []string
	for importName := range g.nameToPackagePath {
		names = append(names, importName)
	}
	sort.Slice(names, func(i, j int) bool { return names[i] < names[j] })
	return names
}
// generateImports writes one `import name "path"` line per registered import,
// ordered by import name so the output is deterministic.
//
// When the mock is generated inside the interface's own package (InPackage
// without KeepTree), an import of that package itself is skipped. The
// package is identified by its localized import path. Looking it up by
// package name could instead match, and drop, an unrelated dependency that
// merely shares that name, and it missed the own package when it was
// registered under a different import name.
func (g *Generator) generateImports(ctx context.Context) {
	log := zerolog.Ctx(ctx)
	log.Debug().Msgf("generating imports")
	log.Debug().Msgf("%v", g.nameToPackagePath)

	skipOwnPackage := !g.KeepTree && g.InPackage
	var pkgPath string
	if skipOwnPackage {
		pkgPath = g.getLocalizedPath(ctx, g.iface.Pkg.Path())
	}
	// Sort by import name so that we get a deterministic order
	for _, name := range g.sortedImportNames() {
		path := g.nameToPackagePath[name]
		logImport := log.With().Str(logging.LogKeyImport, path).Logger()
		logImport.Debug().Msgf("found import")
		if skipOwnPackage && path == pkgPath {
			logImport.Debug().Msgf("import (%s) equals interface's package path (%s), skipping", path, pkgPath)
			continue
		}
		g.printf("import %s \"%s\"\n", name, path)
	}
}
// GeneratePrologue writes the package clause and the import block.
//
// Imports are populated first because the import block is derived from
// them. The package clause uses the interface's own package name when
// generating in-package, and pkg otherwise.
func (g *Generator) GeneratePrologue(ctx context.Context, pkg string) {
	g.populateImports(ctx)
	if !g.InPackage {
		g.printf("package %v\n\n", pkg)
	} else {
		g.printf("package %s\n\n", g.iface.Pkg.Name())
	}
	g.generateImports(ctx)
	g.printf("\n")
}
// GeneratePrologueNote writes the "// Code generated by mockery ... DO NOT
// EDIT." header, including the semver info unless DisableVersionString is
// set. A non-empty note follows as "//" comment lines.
//
// note is split on the two-character sequence `\n` (a backslash followed by
// 'n'), not on real newline characters.
func (g *Generator) GeneratePrologueNote(note string) {
	prologue := "// Code generated by mockery"
	if !g.Config.DisableVersionString {
		prologue += fmt.Sprintf(" %s", config.GetSemverInfo())
	}
	prologue += ". DO NOT EDIT.\n"
	// prologue embeds version text, so it must not be used as a format
	// string: any '%' in it would be interpreted as a formatting verb.
	g.printf("%s", prologue)
	if note != "" {
		g.printf("\n")
		for _, n := range strings.Split(note, "\\n") {
			g.printf("// %s\n", n)
		}
	}
	g.printf("\n")
}
// GenerateBoilerplate writes boilerplate followed by a newline; an empty
// boilerplate writes nothing. Call it before any other generator method so
// the text ends up at the top of the output.
func (g *Generator) GenerateBoilerplate(boilerplate string) {
	if boilerplate == "" {
		return
	}
	g.printf("%s\n", boilerplate)
}
// ErrNotInterface is returned when the given type is not an interface
// type. It is not referenced in this file; presumably it is returned by
// other code in this package. TODO confirm.
var ErrNotInterface = errors.New("expression not an interface")
// printf appends formatted text to the generator's in-memory buffer.
// A write to a bytes.Buffer never returns an error, so the result is discarded.
func (g *Generator) printf(s string, vals ...interface{}) {
	out := &g.buf
	_, _ = fmt.Fprintf(out, s, vals...)
}
// namer is any value with a Name method. renderType uses it as a fallback to
// render types that none of its explicit cases match.
type namer interface {
Name() string
}
// renderType returns the Go source text for typ as it must appear in the
// generated mock file.
//
// Rendering a named type from another package registers that package via
// addPackageImport. populateImports relies on this side effect. Named types
// are left unqualified when they have no package, belong to package main, or
// belong to the interface's own package while generating in-package without
// KeepTree.
//
// The output preserves type identity:
//   - variadic func types keep "...T";
//   - struct field tags are kept;
//   - a bidirectional channel of receive-only channels is parenthesized.
//
// It panics for inline interfaces that declare methods and for types it
// cannot name.
func (g *Generator) renderType(ctx context.Context, typ types.Type) string {
	switch t := typ.(type) {
	case *types.Named:
		o := t.Obj()
		if o.Pkg() == nil || o.Pkg().Name() == "main" || (!g.KeepTree && g.InPackage && o.Pkg() == g.iface.Pkg) {
			return o.Name()
		}
		return g.addPackageImport(ctx, o.Pkg()) + "." + o.Name()
	case *types.Basic:
		return t.Name()
	case *types.Pointer:
		return "*" + g.renderType(ctx, t.Elem())
	case *types.Slice:
		return "[]" + g.renderType(ctx, t.Elem())
	case *types.Array:
		return fmt.Sprintf("[%d]%s", t.Len(), g.renderType(ctx, t.Elem()))
	case *types.Signature:
		// Params are rendered before results, preserving import registration order.
		params := g.renderSignatureParams(ctx, t)
		switch t.Results().Len() {
		case 0:
			return fmt.Sprintf("func(%s)", params)
		case 1:
			return fmt.Sprintf("func(%s) %s", params, g.renderType(ctx, t.Results().At(0).Type()))
		default:
			return fmt.Sprintf("func(%s)(%s)", params, g.renderTypeTuple(ctx, t.Results()))
		}
	case *types.Map:
		kt := g.renderType(ctx, t.Key())
		vt := g.renderType(ctx, t.Elem())
		return fmt.Sprintf("map[%s]%s", kt, vt)
	case *types.Chan:
		elem := g.renderType(ctx, t.Elem())
		switch t.Dir() {
		case types.SendRecv:
			// "chan <-chan T" parses as "chan<- (chan T)"; parenthesize the element.
			if inner, ok := t.Elem().(*types.Chan); ok && inner.Dir() == types.RecvOnly {
				return "chan (" + elem + ")"
			}
			return "chan " + elem
		case types.RecvOnly:
			return "<-chan " + elem
		default:
			return "chan<- " + elem
		}
	case *types.Struct:
		fields := make([]string, 0, t.NumFields())
		for i := 0; i < t.NumFields(); i++ {
			f := t.Field(i)
			var field string
			if f.Anonymous() {
				field = g.renderType(ctx, f.Type())
			} else {
				field = fmt.Sprintf("%s %s", f.Name(), g.renderType(ctx, f.Type()))
			}
			// Tags are part of a struct type's identity, so they must be kept.
			if tag := t.Tag(i); tag != "" {
				field += fmt.Sprintf(" %q", tag)
			}
			fields = append(fields, field)
		}
		return fmt.Sprintf("struct{%s}", strings.Join(fields, ";"))
	case *types.Interface:
		if t.NumMethods() != 0 {
			panic("Unable to mock inline interfaces with methods")
		}
		return "interface{}"
	case namer:
		return t.Name()
	default:
		panic(fmt.Sprintf("un-namable type: %#v (%T)", t, t))
	}
}

// renderSignatureParams renders the parameter types of sig, without
// parentheses, separated like renderTypeTuple. For a variadic signature the
// last parameter (reported by go/types as a slice) is rendered as "...Elem",
// because func(...T) and func([]T) are distinct types.
func (g *Generator) renderSignatureParams(ctx context.Context, sig *types.Signature) string {
	params := sig.Params()
	parts := make([]string, 0, params.Len())
	for i := 0; i < params.Len(); i++ {
		typ := params.At(i).Type()
		if sig.Variadic() && i == params.Len()-1 {
			if slice, ok := typ.(*types.Slice); ok {
				parts = append(parts, "..."+g.renderType(ctx, slice.Elem()))
				continue
			}
		}
		parts = append(parts, g.renderType(ctx, typ))
	}
	return strings.Join(parts, " , ")
}
// renderTypeTuple renders the types of tup (variable names are dropped),
// joined by " , ". An empty tuple yields "".
func (g *Generator) renderTypeTuple(ctx context.Context, tup *types.Tuple) string {
	n := tup.Len()
	if n == 0 {
		return ""
	}
	rendered := make([]string, n)
	for i := range rendered {
		rendered[i] = g.renderType(ctx, tup.At(i).Type())
	}
	return strings.Join(rendered, " , ")
}
// isNillable reports whether generated code should treat a value of typ as
// possibly nil. Named types are judged by their underlying type.
//
// Arrays are reported as nillable even though Go arrays cannot be nil. In
// Generate this only adds a nil guard before the type assertion of a
// returned value.
func isNillable(typ types.Type) bool {
	for {
		named, ok := typ.(*types.Named)
		if !ok {
			break
		}
		typ = named.Underlying()
	}
	switch typ.(type) {
	case *types.Slice, *types.Map, *types.Chan, *types.Pointer,
		*types.Interface, *types.Signature, *types.Array:
		return true
	default:
		return false
	}
}
// paramList describes a rendered parameter or result list as built by
// genList. Names, Types, Params and Nilable are parallel: index i describes
// the i-th variable of the tuple.
type paramList struct {
// Names are the identifiers used for each entry in the generated code.
Names []string
// Types are the rendered types; a variadic last parameter is "...Elem".
Types []string
// Params are "name type" pairs, ready to be joined into a signature.
Params []string
// Nilable records isNillable for each entry's type.
Nilable []bool
// Variadic is set when the last entry was rendered as "...Elem".
Variadic bool
}
// genList renders the variables of list (a parameter or result tuple) into a
// paramList. A nil list yields an empty paramList.
//
// When variadic is true, the last variable must be a slice and is rendered as
// "...Elem"; any other type panics.
//
// Each entry's original name is kept unless it is one of the following, in
// which case the entry is named "_a<index>":
//   - empty;
//   - the blank identifier "_", which cannot be forwarded as a value to
//     _m.Called;
//   - "_m", which is the receiver name of the generated methods;
//   - a name that collides per nameCollides.
func (g *Generator) genList(ctx context.Context, list *types.Tuple, variadic bool) *paramList {
	var params paramList
	if list == nil {
		return &params
	}
	for i := 0; i < list.Len(); i++ {
		v := list.At(i)
		ts := g.renderType(ctx, v.Type())
		if variadic && i == list.Len()-1 {
			slice, ok := v.Type().(*types.Slice)
			if !ok {
				panic("bad variadic type!")
			}
			params.Variadic = true
			ts = "..." + g.renderType(ctx, slice.Elem())
		}
		pname := v.Name()
		if pname == "" || pname == "_" || pname == "_m" || g.nameCollides(pname) {
			pname = fmt.Sprintf("_a%d", i)
		}
		params.Names = append(params.Names, pname)
		params.Types = append(params.Types, ts)
		params.Params = append(params.Params, fmt.Sprintf("%s %s", pname, ts))
		params.Nilable = append(params.Nilable, isNillable(v.Type()))
	}
	return &params
}
// nameCollides reports whether pname cannot be used as a parameter name in
// generated code, because it equals g.pkg or an already registered import name.
func (g *Generator) nameCollides(pname string) bool {
	return pname == g.pkg || g.importNameExists(pname)
}
// ErrNotSetup is returned when the generator is not configured.
// Generate returns it when the generator has no interface to mock.
var ErrNotSetup = errors.New("not setup")
// Generate builds a string that constitutes a valid go source file
// containing the mock of the relevant interface.
//
// It appends to the buffer a struct embedding mock.Mock and, for every
// interface method, a method that records the call via _m.Called. Methods
// with results return either the outcome of a function registered as the
// return value or the stored value type-asserted to the result type.
// Generate returns ErrNotSetup when no interface is configured.
func (g *Generator) Generate(ctx context.Context) error {
	// Check first: populateImports dereferences g.iface, so calling it on an
	// unconfigured generator would panic instead of returning ErrNotSetup.
	if g.iface == nil {
		return ErrNotSetup
	}
	g.populateImports(ctx)

	mockType := g.mockName()
	g.printf(
		"// %s is an autogenerated mock type for the %s type\n", mockType,
		g.iface.Name,
	)
	g.printf(
		"type %s struct {\n\tmock.Mock\n}\n\n", mockType,
	)

	for _, method := range g.iface.Methods() {
		ftype := method.Signature
		fname := method.Name

		params := g.genList(ctx, ftype.Params(), ftype.Variadic())
		returns := g.genList(ctx, ftype.Results(), false)

		if len(params.Names) == 0 {
			g.printf("// %s provides a mock function with given fields:\n", fname)
		} else {
			g.printf(
				"// %s provides a mock function with given fields: %s\n", fname,
				strings.Join(params.Names, ", "),
			)
		}
		g.printf(
			"func (_m *%s) %s(%s) ", mockType, fname,
			strings.Join(params.Params, ", "),
		)

		switch len(returns.Types) {
		case 0:
			g.printf("{\n")
		case 1:
			g.printf("%s {\n", returns.Types[0])
		default:
			g.printf("(%s) {\n", strings.Join(returns.Types, ", "))
		}

		// callArgs holds the parameter names as forwarded inside the body; a
		// variadic parameter is spread with "...". setOfParamNames holds the
		// bare names (without "...") so that the "ret" local below can never
		// redeclare a parameter, including a variadic one.
		callArgs := make([]string, len(params.Names))
		setOfParamNames := make(map[string]struct{}, len(params.Names))
		for i, name := range params.Names {
			setOfParamNames[name] = struct{}{}
			if strings.HasPrefix(params.Types[i], "...") {
				name += "..."
			}
			callArgs[i] = name
		}
		formattedParamNames := strings.Join(callArgs, ", ")

		called := g.generateCalled(params, formattedParamNames) // _m.Called invocation string

		if len(returns.Types) > 0 {
			retVariable := resolveCollision(setOfParamNames, "ret")
			g.printf("\t%s := %s\n\n", retVariable, called)

			ret := make([]string, len(returns.Types))
			for idx, typ := range returns.Types {
				g.printf("\tvar r%d %s\n", idx, typ)
				g.printf("\tif rf, ok := %s.Get(%d).(func(%s) %s); ok {\n",
					retVariable, idx, strings.Join(params.Types, ", "), typ)
				g.printf("\t\tr%d = rf(%s)\n", idx, formattedParamNames)
				g.printf("\t} else {\n")
				if typ == "error" {
					g.printf("\t\tr%d = %s.Error(%d)\n", idx, retVariable, idx)
				} else if returns.Nilable[idx] {
					g.printf("\t\tif %s.Get(%d) != nil {\n", retVariable, idx)
					g.printf("\t\t\tr%d = %s.Get(%d).(%s)\n", idx, retVariable, idx, typ)
					g.printf("\t\t}\n")
				} else {
					g.printf("\t\tr%d = %s.Get(%d).(%s)\n", idx, retVariable, idx, typ)
				}
				g.printf("\t}\n\n")
				ret[idx] = fmt.Sprintf("r%d", idx)
			}
			g.printf("\treturn %s\n", strings.Join(ret, ", "))
		} else {
			g.printf("\t%s\n", called)
		}
		g.printf("}\n")
	}
	return nil
}
// generateCalled returns the Mock.Called invocation string and, if necessary, prints the
// steps to prepare its argument list.
//
// It is separate from Generate to avoid cyclomatic complexity through early return statements.
//
// When the method is variadic and UnrollVariadic is set, every variadic value is
// mirrored into its own argument of Called (see the _ca comment below); otherwise
// the parameters are passed as-is.
func (g *Generator) generateCalled(list *paramList, formattedParamNames string) string {
	count := len(list.Names)
	switch {
	case count == 0:
		return "_m.Called()"
	case !list.Variadic:
		return "_m.Called(" + formattedParamNames + ")"
	case !g.UnrollVariadic:
		return "_m.Called(" + strings.Join(list.Names, ", ") + ")"
	}

	last := count - 1
	variadicName := list.Names[last]
	// list.Types[] will contain a leading '...'. Strip this from the string to
	// do easier comparison.
	elemType := strings.Trim(list.Types[last], "...")

	spreadName := variadicName
	if elemType != "interface{}" {
		// Define _va to avoid "cannot use t (type T) as type []interface {} in append" error
		// whenever the variadic type is non-interface{}.
		g.printf("\t_va := make([]interface{}, len(%s))\n", variadicName)
		g.printf("\tfor _i := range %s {\n\t\t_va[_i] = %s[_i]\n\t}\n", variadicName, variadicName)
		spreadName = "_va"
	}

	// _ca will hold all arguments we'll mirror into Called, one argument per distinct value
	// passed to the method.
	//
	// For example, if the second argument is variadic and consists of three values,
	// a total of 4 arguments will be passed to Called. The alternative is to
	// pass a total of 2 arguments where the second is a slice with those 3 values from
	// the variadic argument. But the alternative is less accessible because it requires
	// building a []interface{} before calling Mock methods like On and AssertCalled for
	// the variadic argument, and creates incompatibility issues with the diff algorithm
	// in github.com/stretchr/testify/mock.
	//
	// This mirroring will allow argument lists for methods like On and AssertCalled to
	// always resemble the expected calls they describe and retain compatibility.
	//
	// It's okay for us to use the interface{} type, regardless of the actual types, because
	// Called receives only interface{} anyway.
	g.printf("\tvar _ca []interface{}\n")
	if count > 1 {
		// Everything before the last ", " is the non-variadic parameter names.
		leading := formattedParamNames[:strings.LastIndex(formattedParamNames, ",")]
		g.printf("\t_ca = append(_ca, %s)\n", leading)
	}
	g.printf("\t_ca = append(_ca, %s...)\n", spreadName)
	return "_m.Called(_ca...)"
}
// Write formats the generated source with goimports (imports.Process, keeping
// comments) and writes the result to w.
//
// If formatting fails, the unformatted source is dumped to stderr between
// separator lines to help debugging, and the formatting error is returned.
// An error from w is propagated to the caller rather than ignored.
func (g *Generator) Write(w io.Writer) error {
	opt := &imports.Options{Comments: true}
	res, err := imports.Process("mock.go", g.buf.Bytes(), opt)
	if err != nil {
		line := "--------------------------------------------------------------------------------------------"
		fmt.Fprintf(os.Stderr, "Between the lines is the file (mock.go) mockery generated in-memory but detected as invalid:\n%s\n%s\n%s\n", line, g.buf.String(), line)
		return err
	}
	_, err = w.Write(res)
	return err
}
// resolveCollision returns variable if it is not in names. Otherwise it
// returns the first free "<variable>_<n>", trying n = len(names),
// len(names)+1, and so on.
func resolveCollision(names map[string]struct{}, variable string) string {
	candidate := variable
	suffix := len(names)
	for {
		if _, taken := names[candidate]; !taken {
			return candidate
		}
		candidate = fmt.Sprintf("%s_%d", variable, suffix)
		suffix++
	}
}