diff --git a/main.go b/main.go index e0f3d73..b675453 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "flag" "fmt" "strconv" + "strings" "github.com/rs/zerolog/log" "google.golang.org/protobuf/compiler/protogen" @@ -14,6 +15,7 @@ import ( var ( useProtoNames *bool emitUnpopulated *bool + excludeMsgs *string ) func main() { @@ -21,6 +23,7 @@ func main() { useProtoNames = flags.Bool("use_proto_names", true, "default: true") emitUnpopulated = flags.Bool("emit_unpopulated", true, "default: true") + excludeMsgs = flags.String("exclude", "", "message to ignore. delimited with `,`") protogen.Options{ParamFunc: flags.Set}.Run(func(plugin *protogen.Plugin) error { for _, file := range plugin.Files { @@ -60,6 +63,11 @@ func generateFile(p *protogen.Plugin, f *protogen.File) error { g.P() for _, msg := range f.Messages { + // Check if message should be excluded. + if exclude(*excludeMsgs, msg.GoIdent) { + continue + } + g.P("func (x ", msg.GoIdent.GoName, ") Value() (driver.Value, error) {") g.P("f := &x") g.P() @@ -80,7 +88,27 @@ func generateFile(p *protogen.Plugin, f *protogen.File) error { g.P(` return fmt.Errorf("type assertion failed")`) g.P(`}`) g.P(`}`) + g.P() } return nil } + +func identifier(id protogen.GoIdent) string { + return fmt.Sprintf("%s.%s", strings.Trim(id.GoImportPath.String(), `"`), id.GoName) +} + +func exclude(e string, i protogen.GoIdent) bool { + if e == "" { + return false + } + + excs := strings.Split(e, ",") + for _, ex := range excs { + if ex == identifier(i) { + return true + } + } + + return false +}