242 lines
7.2 KiB
Go
242 lines
7.2 KiB
Go
|
package dynamic
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"reflect"
|
||
|
"sync"
|
||
|
|
||
|
"github.com/golang/protobuf/proto"
|
||
|
|
||
|
"github.com/jhump/protoreflect/desc"
|
||
|
)
|
||
|
|
||
|
// ExtensionRegistry is a registry of known extension fields. This is used to parse
|
||
|
// extension fields encountered when de-serializing a dynamic message.
|
||
|
type ExtensionRegistry struct {
|
||
|
includeDefault bool
|
||
|
mu sync.RWMutex
|
||
|
exts map[string]map[int32]*desc.FieldDescriptor
|
||
|
}
|
||
|
|
||
|
// NewExtensionRegistryWithDefaults is a registry that includes all "default" extensions,
|
||
|
// which are those that are statically linked into the current program (e.g. registered by
|
||
|
// protoc-generated code via proto.RegisterExtension). Extensions explicitly added to the
|
||
|
// registry will override any default extensions that are for the same extendee and have the
|
||
|
// same tag number and/or name.
|
||
|
func NewExtensionRegistryWithDefaults() *ExtensionRegistry {
|
||
|
return &ExtensionRegistry{includeDefault: true}
|
||
|
}
|
||
|
|
||
|
// AddExtensionDesc adds the given extensions to the registry.
|
||
|
func (r *ExtensionRegistry) AddExtensionDesc(exts ...*proto.ExtensionDesc) error {
|
||
|
flds := make([]*desc.FieldDescriptor, len(exts))
|
||
|
for i, ext := range exts {
|
||
|
fd, err := desc.LoadFieldDescriptorForExtension(ext)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
flds[i] = fd
|
||
|
}
|
||
|
r.mu.Lock()
|
||
|
defer r.mu.Unlock()
|
||
|
if r.exts == nil {
|
||
|
r.exts = map[string]map[int32]*desc.FieldDescriptor{}
|
||
|
}
|
||
|
for _, fd := range flds {
|
||
|
r.putExtensionLocked(fd)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// AddExtension adds the given extensions to the registry. The given extensions
|
||
|
// will overwrite any previously added extensions that are for the same extendee
|
||
|
// message and same extension tag number.
|
||
|
func (r *ExtensionRegistry) AddExtension(exts ...*desc.FieldDescriptor) error {
|
||
|
for _, ext := range exts {
|
||
|
if !ext.IsExtension() {
|
||
|
return fmt.Errorf("given field is not an extension: %s", ext.GetFullyQualifiedName())
|
||
|
}
|
||
|
}
|
||
|
r.mu.Lock()
|
||
|
defer r.mu.Unlock()
|
||
|
if r.exts == nil {
|
||
|
r.exts = map[string]map[int32]*desc.FieldDescriptor{}
|
||
|
}
|
||
|
for _, ext := range exts {
|
||
|
r.putExtensionLocked(ext)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// AddExtensionsFromFile adds to the registry all extension fields defined in the given file descriptor.
|
||
|
func (r *ExtensionRegistry) AddExtensionsFromFile(fd *desc.FileDescriptor) {
|
||
|
r.mu.Lock()
|
||
|
defer r.mu.Unlock()
|
||
|
r.addExtensionsFromFileLocked(fd, false, nil)
|
||
|
}
|
||
|
|
||
|
// AddExtensionsFromFileRecursively adds to the registry all extension fields defined in the give file
|
||
|
// descriptor and also recursively adds all extensions defined in that file's dependencies. This adds
|
||
|
// extensions from the entire transitive closure for the given file.
|
||
|
func (r *ExtensionRegistry) AddExtensionsFromFileRecursively(fd *desc.FileDescriptor) {
|
||
|
r.mu.Lock()
|
||
|
defer r.mu.Unlock()
|
||
|
already := map[*desc.FileDescriptor]struct{}{}
|
||
|
r.addExtensionsFromFileLocked(fd, true, already)
|
||
|
}
|
||
|
|
||
|
func (r *ExtensionRegistry) addExtensionsFromFileLocked(fd *desc.FileDescriptor, recursive bool, alreadySeen map[*desc.FileDescriptor]struct{}) {
|
||
|
if _, ok := alreadySeen[fd]; ok {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
if r.exts == nil {
|
||
|
r.exts = map[string]map[int32]*desc.FieldDescriptor{}
|
||
|
}
|
||
|
for _, ext := range fd.GetExtensions() {
|
||
|
r.putExtensionLocked(ext)
|
||
|
}
|
||
|
for _, msg := range fd.GetMessageTypes() {
|
||
|
r.addExtensionsFromMessageLocked(msg)
|
||
|
}
|
||
|
|
||
|
if recursive {
|
||
|
alreadySeen[fd] = struct{}{}
|
||
|
for _, dep := range fd.GetDependencies() {
|
||
|
r.addExtensionsFromFileLocked(dep, recursive, alreadySeen)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (r *ExtensionRegistry) addExtensionsFromMessageLocked(md *desc.MessageDescriptor) {
|
||
|
for _, ext := range md.GetNestedExtensions() {
|
||
|
r.putExtensionLocked(ext)
|
||
|
}
|
||
|
for _, msg := range md.GetNestedMessageTypes() {
|
||
|
r.addExtensionsFromMessageLocked(msg)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (r *ExtensionRegistry) putExtensionLocked(fd *desc.FieldDescriptor) {
|
||
|
msgName := fd.GetOwner().GetFullyQualifiedName()
|
||
|
m := r.exts[msgName]
|
||
|
if m == nil {
|
||
|
m = map[int32]*desc.FieldDescriptor{}
|
||
|
r.exts[msgName] = m
|
||
|
}
|
||
|
m[fd.GetNumber()] = fd
|
||
|
}
|
||
|
|
||
|
// FindExtension queries for the extension field with the given extendee name (must be a fully-qualified
|
||
|
// message name) and tag number. If no extension is known, nil is returned.
|
||
|
func (r *ExtensionRegistry) FindExtension(messageName string, tagNumber int32) *desc.FieldDescriptor {
|
||
|
if r == nil {
|
||
|
return nil
|
||
|
}
|
||
|
r.mu.RLock()
|
||
|
defer r.mu.RUnlock()
|
||
|
fd := r.exts[messageName][tagNumber]
|
||
|
if fd == nil && r.includeDefault {
|
||
|
ext := getDefaultExtensions(messageName)[tagNumber]
|
||
|
if ext != nil {
|
||
|
fd, _ = desc.LoadFieldDescriptorForExtension(ext)
|
||
|
}
|
||
|
}
|
||
|
return fd
|
||
|
}
|
||
|
|
||
|
// FindExtensionByName queries for the extension field with the given extendee name (must be a fully-qualified
|
||
|
// message name) and field name (must also be a fully-qualified extension name). If no extension is known, nil
|
||
|
// is returned.
|
||
|
func (r *ExtensionRegistry) FindExtensionByName(messageName string, fieldName string) *desc.FieldDescriptor {
|
||
|
if r == nil {
|
||
|
return nil
|
||
|
}
|
||
|
r.mu.RLock()
|
||
|
defer r.mu.RUnlock()
|
||
|
for _, fd := range r.exts[messageName] {
|
||
|
if fd.GetFullyQualifiedName() == fieldName {
|
||
|
return fd
|
||
|
}
|
||
|
}
|
||
|
if r.includeDefault {
|
||
|
for _, ext := range getDefaultExtensions(messageName) {
|
||
|
fd, _ := desc.LoadFieldDescriptorForExtension(ext)
|
||
|
if fd.GetFullyQualifiedName() == fieldName {
|
||
|
return fd
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// FindExtensionByJSONName queries for the extension field with the given extendee name (must be a fully-qualified
|
||
|
// message name) and JSON field name (must also be a fully-qualified name). If no extension is known, nil is returned.
|
||
|
// The fully-qualified JSON name is the same as the extension's normal fully-qualified name except that the last
|
||
|
// component uses the field's JSON name (if present).
|
||
|
func (r *ExtensionRegistry) FindExtensionByJSONName(messageName string, fieldName string) *desc.FieldDescriptor {
|
||
|
if r == nil {
|
||
|
return nil
|
||
|
}
|
||
|
r.mu.RLock()
|
||
|
defer r.mu.RUnlock()
|
||
|
for _, fd := range r.exts[messageName] {
|
||
|
if fd.GetFullyQualifiedJSONName() == fieldName {
|
||
|
return fd
|
||
|
}
|
||
|
}
|
||
|
if r.includeDefault {
|
||
|
for _, ext := range getDefaultExtensions(messageName) {
|
||
|
fd, _ := desc.LoadFieldDescriptorForExtension(ext)
|
||
|
if fd.GetFullyQualifiedJSONName() == fieldName {
|
||
|
return fd
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func getDefaultExtensions(messageName string) map[int32]*proto.ExtensionDesc {
|
||
|
t := proto.MessageType(messageName)
|
||
|
if t != nil {
|
||
|
msg := reflect.Zero(t).Interface().(proto.Message)
|
||
|
return proto.RegisteredExtensions(msg)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// AllExtensionsForType returns all known extension fields for the given extendee name (must be a
|
||
|
// fully-qualified message name).
|
||
|
func (r *ExtensionRegistry) AllExtensionsForType(messageName string) []*desc.FieldDescriptor {
|
||
|
if r == nil {
|
||
|
return []*desc.FieldDescriptor(nil)
|
||
|
}
|
||
|
r.mu.RLock()
|
||
|
defer r.mu.RUnlock()
|
||
|
flds := r.exts[messageName]
|
||
|
var ret []*desc.FieldDescriptor
|
||
|
if r.includeDefault {
|
||
|
exts := getDefaultExtensions(messageName)
|
||
|
if len(exts) > 0 || len(flds) > 0 {
|
||
|
ret = make([]*desc.FieldDescriptor, 0, len(exts)+len(flds))
|
||
|
}
|
||
|
for tag, ext := range exts {
|
||
|
if _, ok := flds[tag]; ok {
|
||
|
// skip default extension and use the one explicitly registered instead
|
||
|
continue
|
||
|
}
|
||
|
fd, _ := desc.LoadFieldDescriptorForExtension(ext)
|
||
|
if fd != nil {
|
||
|
ret = append(ret, fd)
|
||
|
}
|
||
|
}
|
||
|
} else if len(flds) > 0 {
|
||
|
ret = make([]*desc.FieldDescriptor, 0, len(flds))
|
||
|
}
|
||
|
|
||
|
for _, ext := range flds {
|
||
|
ret = append(ret, ext)
|
||
|
}
|
||
|
return ret
|
||
|
}
|