// https://bongnv.com/blog/2021-05-08-create-codes-generator-with-protobuf/ // https://rotemtam.com/2021/03/22/creating-a-protoc-plugin-to-gen-go-code/ package main import ( "flag" "fmt" "strconv" "strings" "github.com/rs/zerolog/log" "google.golang.org/protobuf/compiler/protogen" ) var ( useProtoNames *bool emitUnpopulated *bool excludeMsgs *string ) func main() { var flags flag.FlagSet useProtoNames = flags.Bool("use_proto_names", true, "default: true") emitUnpopulated = flags.Bool("emit_unpopulated", true, "default: true") excludeMsgs = flags.String("exclude", "", "message to ignore. delimited with `,`") protogen.Options{ParamFunc: flags.Set}.Run(func(plugin *protogen.Plugin) error { for _, file := range plugin.Files { if !file.Generate { continue } if err := generateFile(plugin, file); err != nil { return fmt.Errorf("failed to generate file: %w", err) } } return nil }) } func generateFile(p *protogen.Plugin, f *protogen.File) error { if len(f.Messages) == 0 { log.Info().Msg("no messages") return nil } filename := f.GeneratedFilenamePrefix + ".value_scan.pb.go" g := p.NewGeneratedFile(filename, f.GoImportPath) g.P(`// Code generated by protoc-gen-go-value-scan. DO NOT EDIT.`) g.P() g.P(`package `, f.GoPackageName) g.P() g.P(`import (`) g.P(`"database/sql/driver"`) g.P(`"fmt"`) g.P() g.P(`protojson "google.golang.org/protobuf/encoding/protojson"`) g.P(`)`) g.P() for _, msg := range f.Messages { // Check if message should be excluded. if exclude(*excludeMsgs, msg.GoIdent) { continue } g.P("func (x ", msg.GoIdent.GoName, ") Value() (driver.Value, error) {") g.P("f := &x") g.P() g.P("opts := protojson.MarshalOptions{}") g.P("opts.UseProtoNames = ", strconv.FormatBool(*useProtoNames)) g.P("opts.EmitUnpopulated = ", strconv.FormatBool(*emitUnpopulated)) g.P() g.P("return opts.Marshal(f)") g.P(`}`) g.P() g.P(`func (x *`, msg.GoIdent.GoName, `) Scan(value interface{}) error {`) g.P(`switch v := value.(type) {`) g.P(`case []byte:`) g.P(` return protojson.Unmarshal(v, x)`) g.P(`case string:`) g.P(` return protojson.Unmarshal([]byte(v), x)`) g.P(`default:`) g.P(` return fmt.Errorf("type assertion failed")`) g.P(`}`) g.P(`}`) g.P() } return nil } func identifier(id protogen.GoIdent) string { return fmt.Sprintf("%s.%s", strings.Trim(id.GoImportPath.String(), `"`), id.GoName) } func exclude(e string, i protogen.GoIdent) bool { if e == "" { return false } excs := strings.Split(e, ",") for _, ex := range excs { if ex == identifier(i) { return true } } return false }