Marvin Steadfast
104e72bd86
All checks were successful
continuous-integration/drone/push Build is passing
106 lines
2.9 KiB
Go
106 lines
2.9 KiB
Go
package pkg
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"os"
|
|
"path/filepath"
|
|
"regexp"
|
|
"strings"
|
|
|
|
"github.com/rs/zerolog"
|
|
"github.com/vektra/mockery/v2/pkg/config"
|
|
"github.com/vektra/mockery/v2/pkg/logging"
|
|
)
|
|
|
|
type Cleanup func() error
|
|
|
|
type OutputStreamProvider interface {
|
|
GetWriter(context.Context, *Interface) (io.Writer, error, Cleanup)
|
|
}
|
|
|
|
type StdoutStreamProvider struct {
|
|
}
|
|
|
|
func (*StdoutStreamProvider) GetWriter(ctx context.Context, iface *Interface) (io.Writer, error, Cleanup) {
|
|
return os.Stdout, nil, func() error { return nil }
|
|
}
|
|
|
|
type FileOutputStreamProvider struct {
|
|
Config config.Config
|
|
BaseDir string
|
|
InPackage bool
|
|
TestOnly bool
|
|
Case string
|
|
KeepTree bool
|
|
KeepTreeOriginalDirectory string
|
|
FileName string
|
|
}
|
|
|
|
func (p *FileOutputStreamProvider) GetWriter(ctx context.Context, iface *Interface) (io.Writer, error, Cleanup) {
|
|
log := zerolog.Ctx(ctx).With().Str(logging.LogKeyInterface, iface.Name).Logger()
|
|
ctx = log.WithContext(ctx)
|
|
|
|
var path string
|
|
|
|
caseName := iface.Name
|
|
if p.Case == "underscore" || p.Case == "snake" {
|
|
caseName = p.underscoreCaseName(caseName)
|
|
}
|
|
|
|
if p.KeepTree {
|
|
absOriginalDir, err := filepath.Abs(p.KeepTreeOriginalDirectory)
|
|
if err != nil {
|
|
return nil, err, func() error { return nil }
|
|
}
|
|
relativePath := strings.TrimPrefix(
|
|
filepath.Join(filepath.Dir(iface.FileName), p.filename(caseName)),
|
|
absOriginalDir)
|
|
path = filepath.Join(p.BaseDir, relativePath)
|
|
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
|
return nil, err, func() error { return nil }
|
|
}
|
|
} else if p.InPackage {
|
|
path = filepath.Join(filepath.Dir(iface.FileName), p.filename(caseName))
|
|
} else {
|
|
path = filepath.Join(p.BaseDir, p.filename(caseName))
|
|
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
|
return nil, err, func() error { return nil }
|
|
}
|
|
}
|
|
|
|
log = log.With().Str(logging.LogKeyPath, path).Logger()
|
|
ctx = log.WithContext(ctx)
|
|
|
|
log.Debug().Msgf("creating writer to file")
|
|
f, err := os.Create(path)
|
|
if err != nil {
|
|
return nil, err, func() error { return nil }
|
|
}
|
|
|
|
return f, nil, func() error {
|
|
return f.Close()
|
|
}
|
|
}
|
|
|
|
func (p *FileOutputStreamProvider) filename(name string) string {
|
|
if p.FileName != "" {
|
|
return p.FileName
|
|
} else if p.InPackage && p.TestOnly {
|
|
return "mock_" + name + "_test.go"
|
|
} else if p.InPackage && !p.KeepTree {
|
|
return "mock_" + name + ".go"
|
|
} else if p.TestOnly {
|
|
return name + "_test.go"
|
|
}
|
|
return name + ".go"
|
|
}
|
|
|
|
// shamelessly taken from http://stackoverflow.com/questions/1175208/elegant-python-function-to-convert-camelcase-to-camel-caseo
|
|
func (*FileOutputStreamProvider) underscoreCaseName(caseName string) string {
|
|
rxp1 := regexp.MustCompile("(.)([A-Z][a-z]+)")
|
|
s1 := rxp1.ReplaceAllString(caseName, "${1}_${2}")
|
|
rxp2 := regexp.MustCompile("([a-z0-9])([A-Z])")
|
|
return strings.ToLower(rxp2.ReplaceAllString(s1, "${1}_${2}"))
|
|
}
|