protoc-gen-go-value-scan/main.go
2022-12-13 13:12:42 +00:00

115 lines
2.6 KiB
Go

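// Command protoc-gen-go-value-scan is a protoc plugin that generates Value
// (database/sql/driver.Valuer-style) and Scan (database/sql.Scanner-style)
// methods for the top-level messages of each .proto file, encoding messages
// as protojson.
//
// Background reading:
// https://bongnv.com/blog/2021-05-08-create-codes-generator-with-protobuf/
// https://rotemtam.com/2021/03/22/creating-a-protoc-plugin-to-gen-go-code/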
package main
import (
"flag"
"fmt"
"strconv"
"strings"
"github.com/rs/zerolog/log"
"google.golang.org/protobuf/compiler/protogen"
)
// Plugin options. They are registered as flags in main and set from the
// protoc parameter string (for example
// --go-value-scan_opt=use_proto_names=false,exclude=example.com/pb.Foo).
var (
	// useProtoNames is written into every generated Value method as
	// protojson.MarshalOptions.UseProtoNames.
	useProtoNames *bool
	// emitUnpopulated is written into every generated Value method as
	// protojson.MarshalOptions.EmitUnpopulated.
	emitUnpopulated *bool
	// excludeMsgs is a comma-separated list of "<go import path>.<GoName>"
	// identifiers for which no Value/Scan methods are generated.
	excludeMsgs *string
)

// main runs the plugin through protogen, invoking generateFile for every
// input file protoc asked to generate.
func main() {
	var flags flag.FlagSet
	useProtoNames = flags.Bool("use_proto_names", true, "default: true")
	emitUnpopulated = flags.Bool("emit_unpopulated", true, "default: true")

	// protogen splits the protoc parameter string on ',' before calling
	// ParamFunc, so a comma-separated exclude value can never arrive in one
	// piece, and a plain string flag would keep only the last exclude=...
	// setting. Accumulate every exclude=... occurrence instead, joined with
	// ',' in the format exclude() expects.
	excludeMsgs = new(string)
	flags.Func("exclude", "message to ignore, as <go import path>.<GoName>; repeat the option to ignore several", func(v string) error {
		if v == "" {
			return nil
		}
		if *excludeMsgs != "" {
			*excludeMsgs += ","
		}
		*excludeMsgs += v
		return nil
	})

	protogen.Options{ParamFunc: flags.Set}.Run(func(plugin *protogen.Plugin) error {
		for _, file := range plugin.Files {
			if !file.Generate {
				continue
			}
			if err := generateFile(plugin, file); err != nil {
				return fmt.Errorf("failed to generate file %s: %w", file.Desc.Path(), err)
			}
		}
		return nil
	})
}
// generateFile writes <GeneratedFilenamePrefix>.value_scan.pb.go for f.
// The file contains, for every top-level message of f not listed in the
// "exclude" option:
//   - a Value method (value receiver) that marshals the message with protojson,
//     honoring the use_proto_names and emit_unpopulated options;
//   - a Scan method (pointer receiver) that unmarshals protojson from a []byte
//     or string.
//
// No file is produced when there is nothing to generate. Emitting just the
// package clause would be pointless, and with hand-written imports it would
// not even compile (unused imports).
//
// The returned error is currently always nil.
func generateFile(p *protogen.Plugin, f *protogen.File) error {
	// Collect the messages first so we know whether a file is needed at all.
	// Only top-level messages are considered; nested ones are not visited.
	var msgs []*protogen.Message
	for _, msg := range f.Messages {
		if exclude(*excludeMsgs, msg.GoIdent) {
			continue
		}
		msgs = append(msgs, msg)
	}
	if len(msgs) == 0 {
		log.Info().Str("file", f.Desc.Path()).Msg("no messages")
		return nil
	}

	// Referencing external identifiers through GoImportPath.Ident lets protogen
	// add exactly the imports that end up being used (and pick non-conflicting
	// package names) instead of hard-coding an import block.
	protojsonPkg := protogen.GoImportPath("google.golang.org/protobuf/encoding/protojson")
	driverValue := protogen.GoImportPath("database/sql/driver").Ident("Value")
	fmtErrorf := protogen.GoImportPath("fmt").Ident("Errorf")
	marshalOptions := protojsonPkg.Ident("MarshalOptions")
	unmarshal := protojsonPkg.Ident("Unmarshal")

	filename := f.GeneratedFilenamePrefix + ".value_scan.pb.go"
	g := p.NewGeneratedFile(filename, f.GoImportPath)
	g.P("// Code generated by protoc-gen-go-value-scan. DO NOT EDIT.")
	g.P()
	g.P("package ", f.GoPackageName)
	g.P()

	for _, msg := range msgs {
		name := msg.GoIdent.GoName

		// Value keeps the value receiver so both T and *T satisfy
		// driver.Valuer. NOTE(review): this copies the message struct, which
		// go vet's copylocks check may flag for generated protobuf types;
		// switching to a pointer receiver would change the generated method set.
		g.P("func (x ", name, ") Value() (", driverValue, ", error) {")
		g.P("opts := ", marshalOptions, "{")
		g.P("UseProtoNames: ", strconv.FormatBool(*useProtoNames), ",")
		g.P("EmitUnpopulated: ", strconv.FormatBool(*emitUnpopulated), ",")
		g.P("}")
		g.P("return opts.Marshal(&x)")
		g.P("}")
		g.P()

		// Scan accepts []byte and string sources; any other type, including a
		// nil value, is reported as an error that names the offending type.
		// The %T below is emitted literally: g.P does not interpret verbs.
		g.P("func (x *", name, ") Scan(value interface{}) error {")
		g.P("switch v := value.(type) {")
		g.P("case []byte:")
		g.P("return ", unmarshal, "(v, x)")
		g.P("case string:")
		g.P("return ", unmarshal, "([]byte(v), x)")
		g.P("default:")
		g.P("return ", fmtErrorf, `("type assertion failed: cannot scan %T into *`, name, `", value)`)
		g.P("}")
		g.P("}")
		g.P()
	}
	return nil
}
// identifier returns the fully qualified name of a Go identifier as
// "<import path>.<GoName>", e.g. "example.com/foo/bar.Baz". This is the
// format matched against entries of the exclude option.
//
// The import path is used verbatim. GoImportPath.String would Go-quote it,
// and trimming the quotes afterwards leaves escape sequences behind for paths
// containing characters that need escaping.
func identifier(id protogen.GoIdent) string {
	return string(id.GoImportPath) + "." + id.GoName
}
// exclude reports whether the message identified by i is listed in e.
// e is a comma-separated list of fully qualified identifiers in the format
// produced by identifier ("<go import path>.<GoName>"). Whitespace around
// each entry is ignored, and an empty list excludes nothing.
func exclude(e string, i protogen.GoIdent) bool {
	if e == "" {
		return false
	}
	// The target does not depend on the entry; compute it once.
	want := identifier(i)
	for _, ex := range strings.Split(e, ",") {
		if strings.TrimSpace(ex) == want {
			return true
		}
	}
	return false
}