schnutibox/vendor/github.com/jhump/protoreflect/desc/protoparse/linker.go
Marvin Steadfast ae35d9ab41
Some checks failed
continuous-integration/drone/push Build is failing
uses buf for compiling proto files and implements the server service
2021-05-05 11:14:17 +02:00

746 lines
23 KiB
Go

package protoparse
import (
"bytes"
"fmt"
"sort"
"strings"
"github.com/golang/protobuf/proto"
dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
"github.com/jhump/protoreflect/desc"
"github.com/jhump/protoreflect/desc/internal"
"github.com/jhump/protoreflect/desc/protoparse/ast"
)
// linker resolves references across a set of parsed files and turns them
// into linked rich descriptors.
type linker struct {
	// files maps each file name to its parse result.
	files map[string]*parseResult
	// filenames lists the file names in the order the linker processes them.
	filenames []string
	// errs is the handler through which all linking errors are reported.
	errs *errorHandler
	// descriptorPool holds, per file, the symbol table built by
	// createDescriptorPool: fully-qualified name -> declaring descriptor proto.
	descriptorPool map[*dpb.FileDescriptorProto]map[string]proto.Message
	// extensions maps a fully-qualified extendee name to the extension tags
	// already claimed for it (tag -> qualified name of the extension). It is
	// reset by resolveReferences and filled in by resolveFieldTypes.
	extensions map[string]map[int32]string
}
// newLinker creates a linker over the given parse results. Linking errors are
// reported through errs.
func newLinker(files *parseResults, errs *errorHandler) *linker {
	l := &linker{errs: errs}
	l.files = files.resultsByFilename
	l.filenames = files.filenames
	return l
}
// linkFiles validates and links all parsed files, returning the resulting rich
// descriptors keyed by file name.
//
// It runs in phases:
//  1. build a symbol pool per file, reporting duplicate symbols;
//  2. resolve and fully-qualify all type references;
//  3. create linked descriptors;
//  4. interpret remaining uninterpreted options and validate extension tag
//     ranges.
//
// Errors are reported through l.errs, and any error the handler surfaces is
// returned.
func (l *linker) linkFiles() (map[string]*desc.FileDescriptor, error) {
	// First, we put all symbols into a single pool, which lets us ensure there
	// are no duplicate symbols and will also let us resolve and revise all type
	// references in next step.
	if err := l.createDescriptorPool(); err != nil {
		return nil, err
	}

	// After we've populated the pool, we can now try to resolve all type
	// references. All references must be checked for correct type, any fields
	// with enum types must be corrected (since we parse them as if they are
	// message references since we don't actually know message or enum until
	// link time), and references will be re-written to be fully-qualified
	// references (e.g. start with a dot ".").
	if err := l.resolveReferences(); err != nil {
		return nil, err
	}

	if err := l.errs.getError(); err != nil {
		// we won't be able to create real descriptors if we've encountered
		// errors up to this point, so bail at this point
		return nil, err
	}

	// Now we've validated the descriptors, so we can link them into rich
	// descriptors. This is a little redundant since that step does similar
	// checking of symbols. But, without breaking encapsulation (e.g. exporting
	// a lot of fields from desc package that are currently unexported) or
	// merging this into the same package, we can't really prevent it.
	linked, err := l.createdLinkedDescriptors()
	if err != nil {
		return nil, err
	}

	// Now that we have linked descriptors, we can interpret any uninterpreted
	// options that remain.
	//
	// Iterate in l.filenames order rather than ranging over the l.files map.
	// Map iteration order is random, which made it nondeterministic which
	// file's option or extension-range error was reported first. The other
	// passes in this linker already iterate l.filenames.
	// NOTE(review): assumes l.filenames names every entry of l.files.
	for _, filename := range l.filenames {
		r := l.files[filename]
		fd := linked[r.fd.GetName()]
		if err := interpretFileOptions(r, richFileDescriptorish{FileDescriptor: fd}); err != nil {
			return nil, err
		}
		// we should now have any message_set_wire_format options parsed
		// and can do further validation on tag ranges
		if err := checkExtensionTagsInFile(fd, r); err != nil {
			return nil, err
		}
	}

	// When Parser calls linkFiles, it does not check errs again, and it expects that linkFiles
	// will return all errors it should process. If the ErrorReporter handles all errors itself
	// and always returns nil, we should get ErrInvalidSource here, and need to propagate this
	if err := l.errs.getError(); err != nil {
		return nil, err
	}
	return linked, nil
}
// createDescriptorPool builds l.descriptorPool, a per-file map from
// fully-qualified symbol name to the descriptor proto that declares it.
//
// Duplicates within one file are reported by the add*ToPool helpers.
// Duplicates across files are detected by merging all per-file pools into one
// map. It returns the first error that the error handler chooses to surface.
func (l *linker) createDescriptorPool() error {
	l.descriptorPool = map[*dpb.FileDescriptorProto]map[string]proto.Message{}

	// Phase 1: one symbol table per file.
	for _, filename := range l.filenames {
		r := l.files[filename]
		fd := r.fd
		filePool := map[string]proto.Message{}
		l.descriptorPool[fd] = filePool

		var pkgPrefix string
		if pkg := fd.GetPackage(); pkg != "" {
			pkgPrefix = pkg + "."
		}

		for _, md := range fd.MessageType {
			if err := addMessageToPool(r, filePool, l.errs, pkgPrefix, md); err != nil {
				return err
			}
		}
		for _, ext := range fd.Extension {
			if err := addFieldToPool(r, filePool, l.errs, pkgPrefix, ext); err != nil {
				return err
			}
		}
		for _, ed := range fd.EnumType {
			if err := addEnumToPool(r, filePool, l.errs, pkgPrefix, ed); err != nil {
				return err
			}
		}
		for _, sd := range fd.Service {
			if err := addServiceToPool(r, filePool, l.errs, pkgPrefix, sd); err != nil {
				return err
			}
		}
	}

	// Phase 2: merge every file's table into one map to catch a symbol that
	// is declared in two different files.
	type owner struct {
		file string
		msg  proto.Message
	}
	merged := map[string]owner{}
	for _, filename := range l.filenames {
		fd := l.files[filename].fd
		filePool := l.descriptorPool[fd]

		names := make([]string, 0, len(filePool))
		for name := range filePool {
			names = append(names, name)
		}
		// Sorting keeps duplicate-symbol error reporting deterministic.
		sort.Strings(names)

		for _, name := range names {
			dsc := filePool[name]
			if prev, dup := merged[name]; dup {
				// Report at the declaration in the file whose name sorts
				// later, citing the declaration in the file that sorts earlier.
				keptFile, keptDesc := prev.file, prev.msg
				dupFile, dupDesc := fd.GetName(), dsc
				if dupFile < keptFile {
					keptFile, dupFile = dupFile, keptFile
					keptDesc, dupDesc = dupDesc, keptDesc
				}
				node := l.files[dupFile].nodes[dupDesc]
				if err := l.errs.handleErrorWithPos(node.Start(), "duplicate symbol %s: already defined as %s in %q", name, descriptorType(keptDesc), keptFile); err != nil {
					return err
				}
			}
			merged[name] = owner{file: fd.GetName(), msg: dsc}
		}
	}
	return nil
}
// addMessageToPool registers message md under prefix+name. It then registers,
// qualified by the message's own fully-qualified name, its fields, its
// extensions, its nested messages (recursively) and its nested enums, in that
// order.
func addMessageToPool(r *parseResult, pool map[string]proto.Message, errs *errorHandler, prefix string, md *dpb.DescriptorProto) error {
	msgName := prefix + md.GetName()
	if err := addToPool(r, pool, errs, msgName, md); err != nil {
		return err
	}
	childPrefix := msgName + "."

	// Fields and extensions are both FieldDescriptorProtos and share a path.
	for _, group := range [][]*dpb.FieldDescriptorProto{md.Field, md.Extension} {
		for _, fld := range group {
			if err := addFieldToPool(r, pool, errs, childPrefix, fld); err != nil {
				return err
			}
		}
	}
	for _, nested := range md.NestedType {
		if err := addMessageToPool(r, pool, errs, childPrefix, nested); err != nil {
			return err
		}
	}
	for _, nestedEnum := range md.EnumType {
		if err := addEnumToPool(r, pool, errs, childPrefix, nestedEnum); err != nil {
			return err
		}
	}
	return nil
}
// addFieldToPool registers a field or extension under prefix+name.
func addFieldToPool(r *parseResult, pool map[string]proto.Message, errs *errorHandler, prefix string, fld *dpb.FieldDescriptorProto) error {
	return addToPool(r, pool, errs, prefix+fld.GetName(), fld)
}
// addEnumToPool registers enum ed under prefix+name, and each of its values
// under "<enum fqn>.<value name>".
func addEnumToPool(r *parseResult, pool map[string]proto.Message, errs *errorHandler, prefix string, ed *dpb.EnumDescriptorProto) error {
	enumName := prefix + ed.GetName()
	if err := addToPool(r, pool, errs, enumName, ed); err != nil {
		return err
	}
	valuePrefix := enumName + "."
	for _, val := range ed.Value {
		if err := addToPool(r, pool, errs, valuePrefix+val.GetName(), val); err != nil {
			return err
		}
	}
	return nil
}
// addServiceToPool registers service sd under prefix+name, and each of its
// methods under "<service fqn>.<method name>".
func addServiceToPool(r *parseResult, pool map[string]proto.Message, errs *errorHandler, prefix string, sd *dpb.ServiceDescriptorProto) error {
	svcName := prefix + sd.GetName()
	if err := addToPool(r, pool, errs, svcName, sd); err != nil {
		return err
	}
	methodPrefix := svcName + "."
	for _, method := range sd.Method {
		if err := addToPool(r, pool, errs, methodPrefix+method.GetName(), method); err != nil {
			return err
		}
	}
	return nil
}
// addToPool stores dsc in pool under fqn. If fqn is already present, a
// "duplicate symbol" error is reported at dsc's source position. If the error
// handler does not return an error, the new descriptor replaces the old one.
func addToPool(r *parseResult, pool map[string]proto.Message, errs *errorHandler, fqn string, dsc proto.Message) error {
	existing, dup := pool[fqn]
	if dup {
		node := r.nodes[dsc]
		err := errs.handleErrorWithPos(node.Start(), "duplicate symbol %s: already defined as %s", fqn, descriptorType(existing))
		if err != nil {
			return err
		}
	}
	pool[fqn] = dsc
	return nil
}
// descriptorType returns a human-readable kind name for a descriptor proto,
// for use in error messages. A FieldDescriptorProto with an extendee is an
// "extension"; otherwise it is a "field". Unrecognized values fall back to
// their Go type name.
func descriptorType(m proto.Message) string {
	switch d := m.(type) {
	case *dpb.DescriptorProto:
		return "message"
	case *dpb.DescriptorProto_ExtensionRange:
		return "extension range"
	case *dpb.FieldDescriptorProto:
		if d.GetExtendee() != "" {
			return "extension"
		}
		return "field"
	case *dpb.EnumDescriptorProto:
		return "enum"
	case *dpb.EnumValueDescriptorProto:
		return "enum value"
	case *dpb.ServiceDescriptorProto:
		return "service"
	case *dpb.MethodDescriptorProto:
		return "method"
	case *dpb.FileDescriptorProto:
		return "file"
	}
	// shouldn't be possible
	return fmt.Sprintf("%T", m)
}
// resolveReferences resolves every reference in every file, in l.filenames
// order. This covers:
//   - extension names used in options;
//   - field and extension types;
//   - extendees;
//   - method request/response types.
//
// It first resets l.extensions, which tracks the extension tags already used
// per extendee.
func (l *linker) resolveReferences() error {
	l.extensions = map[string]map[int32]string{}
	for _, filename := range l.filenames {
		r := l.files[filename]
		fd := r.fd

		var pkgPrefix string
		if pkg := fd.GetPackage(); pkg != "" {
			pkgPrefix = pkg + "."
		}
		// The outermost scope: this file's package (and parent packages),
		// searched in this file and its imports.
		fileScopes := []scope{fileScope(fd, l)}

		if opts := fd.Options; opts != nil {
			if err := l.resolveOptions(r, fd, "file", fd.GetName(), proto.MessageName(opts), opts.UninterpretedOption, fileScopes); err != nil {
				return err
			}
		}
		for _, md := range fd.MessageType {
			if err := l.resolveMessageTypes(r, fd, pkgPrefix, md, fileScopes); err != nil {
				return err
			}
		}
		for _, ext := range fd.Extension {
			if err := l.resolveFieldTypes(r, fd, pkgPrefix, ext, fileScopes); err != nil {
				return err
			}
		}
		for _, ed := range fd.EnumType {
			if err := l.resolveEnumTypes(r, fd, pkgPrefix, ed, fileScopes); err != nil {
				return err
			}
		}
		for _, sd := range fd.Service {
			if err := l.resolveServiceTypes(r, fd, pkgPrefix, sd, fileScopes); err != nil {
				return err
			}
		}
	}
	return nil
}
// resolveEnumTypes resolves extension names used in the options of enum ed
// (declared under prefix) and in the options of each of its values.
func (l *linker) resolveEnumTypes(r *parseResult, fd *dpb.FileDescriptorProto, prefix string, ed *dpb.EnumDescriptorProto, scopes []scope) error {
	enumName := prefix + ed.GetName()
	if opts := ed.Options; opts != nil {
		if err := l.resolveOptions(r, fd, "enum", enumName, proto.MessageName(opts), opts.UninterpretedOption, scopes); err != nil {
			return err
		}
	}
	for _, val := range ed.Value {
		opts := val.Options
		if opts == nil {
			continue
		}
		valueName := enumName + "." + val.GetName()
		if err := l.resolveOptions(r, fd, "enum value", valueName, proto.MessageName(opts), opts.UninterpretedOption, scopes); err != nil {
			return err
		}
	}
	return nil
}
// resolveMessageTypes resolves references inside message md (declared under
// prefix) and, recursively, inside everything nested in it.
//
// The message's own scope is appended to scopes. Because resolve searches
// from the last scope outward, names nested in this message are tried first.
// Processing order:
//  1. message options;
//  2. nested messages;
//  3. nested enums;
//  4. fields;
//  5. extensions;
//  6. extension-range options.
func (l *linker) resolveMessageTypes(r *parseResult, fd *dpb.FileDescriptorProto, prefix string, md *dpb.DescriptorProto, scopes []scope) error {
	msgName := prefix + md.GetName()
	scopes = append(scopes, messageScope(msgName, isProto3(fd), l.descriptorPool[fd]))
	childPrefix := msgName + "."

	if opts := md.Options; opts != nil {
		if err := l.resolveOptions(r, fd, "message", msgName, proto.MessageName(opts), opts.UninterpretedOption, scopes); err != nil {
			return err
		}
	}
	for _, nested := range md.NestedType {
		if err := l.resolveMessageTypes(r, fd, childPrefix, nested, scopes); err != nil {
			return err
		}
	}
	for _, nestedEnum := range md.EnumType {
		if err := l.resolveEnumTypes(r, fd, childPrefix, nestedEnum, scopes); err != nil {
			return err
		}
	}
	// Regular fields first, then extensions declared inside this message.
	for _, group := range [][]*dpb.FieldDescriptorProto{md.Field, md.Extension} {
		for _, fld := range group {
			if err := l.resolveFieldTypes(r, fd, childPrefix, fld, scopes); err != nil {
				return err
			}
		}
	}
	for _, er := range md.ExtensionRange {
		if er.Options == nil {
			continue
		}
		// Ranges are named with an inclusive upper bound, e.g. "pkg.Msg:100-199".
		rangeName := fmt.Sprintf("%s:%d-%d", msgName, er.GetStart(), er.GetEnd()-1)
		if err := l.resolveOptions(r, fd, "extension range", rangeName, proto.MessageName(er.Options), er.Options.UninterpretedOption, scopes); err != nil {
			return err
		}
	}
	return nil
}
// resolveFieldTypes resolves the references of a single field or extension
// fld, declared under prefix (the enclosing package or message name plus "."),
// using scopes for relative-name lookup. It does three things:
//   - For extensions: resolves the extendee to a message and rewrites Extendee
//     to ".<fqn>". It checks that the tag lies in one of the extendee's
//     extension ranges, and that no other extension of the same extendee
//     already uses the tag (tracked in l.extensions).
//   - Resolves extension names used in the field's options.
//   - For named (non-scalar) types: rewrites TypeName to ".<fqn>". It sets
//     Type to TYPE_MESSAGE (only if still unset) or TYPE_ENUM.
//
// Errors go through l.errs. The returned error is whatever the handler
// returns; callers propagate a non-nil error and stop.
func (l *linker) resolveFieldTypes(r *parseResult, fd *dpb.FileDescriptorProto, prefix string, fld *dpb.FieldDescriptorProto, scopes []scope) error {
thisName := prefix + fld.GetName()
// scope is the "field <name>" prefix used in error messages (it shadows the
// scope type within this function).
scope := fmt.Sprintf("field %s", thisName)
node := r.getFieldNode(fld)
// elemType is passed to resolveOptions to label error messages.
elemType := "field"
if fld.GetExtendee() != "" {
elemType = "extension"
fqn, dsc, _ := l.resolve(fd, fld.GetExtendee(), isMessage, scopes)
if dsc == nil {
// Without an extendee nothing else about this extension can be checked,
// so processing of this field stops here.
return l.errs.handleErrorWithPos(node.FieldExtendee().Start(), "unknown extendee type %s", fld.GetExtendee())
}
extd, ok := dsc.(*dpb.DescriptorProto)
if !ok {
otherType := descriptorType(dsc)
return l.errs.handleErrorWithPos(node.FieldExtendee().Start(), "extendee is invalid: %s is a %s, not a message", fqn, otherType)
}
fld.Extendee = proto.String("." + fqn)
// make sure the tag number is in range
// (ExtensionRange end is exclusive, hence tag < GetEnd())
found := false
tag := fld.GetNumber()
for _, rng := range extd.ExtensionRange {
if tag >= rng.GetStart() && tag < rng.GetEnd() {
found = true
break
}
}
if !found {
if err := l.errs.handleErrorWithPos(node.FieldTag().Start(), "%s: tag %d is not in valid range for extended type %s", scope, tag, fqn); err != nil {
return err
}
} else {
// make sure tag is not a duplicate
// l.extensions[fqn] maps tag -> qualified name of the first extension
// that claimed it; the per-extendee map is created lazily.
usedExtTags := l.extensions[fqn]
if usedExtTags == nil {
usedExtTags = map[int32]string{}
l.extensions[fqn] = usedExtTags
}
if other := usedExtTags[fld.GetNumber()]; other != "" {
if err := l.errs.handleErrorWithPos(node.FieldTag().Start(), "%s: duplicate extension: %s and %s are both using tag %d", scope, other, thisName, fld.GetNumber()); err != nil {
return err
}
} else {
usedExtTags[fld.GetNumber()] = thisName
}
}
}
if fld.Options != nil {
if err := l.resolveOptions(r, fd, elemType, thisName, proto.MessageName(fld.Options), fld.Options.UninterpretedOption, scopes); err != nil {
return err
}
}
if fld.GetTypeName() == "" {
// scalar type; no further resolution required
return nil
}
// proto3 reports whether the file in which the type was found is proto3.
fqn, dsc, proto3 := l.resolve(fd, fld.GetTypeName(), isType, scopes)
if dsc == nil {
return l.errs.handleErrorWithPos(node.FieldType().Start(), "%s: unknown type %s", scope, fld.GetTypeName())
}
switch dsc := dsc.(type) {
case *dpb.DescriptorProto:
fld.TypeName = proto.String("." + fqn)
// if type was tentatively unset, we now know it's actually a message
if fld.Type == nil {
fld.Type = dpb.FieldDescriptorProto_TYPE_MESSAGE.Enum()
}
case *dpb.EnumDescriptorProto:
// Only non-extension fields declared in a proto3 file are restricted here.
if fld.GetExtendee() == "" && isProto3(fd) && !proto3 {
// fields in a proto3 message cannot refer to proto2 enums
return l.errs.handleErrorWithPos(node.FieldType().Start(), "%s: cannot use proto2 enum %s in a proto3 message", scope, fld.GetTypeName())
}
fld.TypeName = proto.String("." + fqn)
// the type was tentatively unset, but now we know it's actually an enum
fld.Type = dpb.FieldDescriptorProto_TYPE_ENUM.Enum()
default:
// resolve may return a non-type "best guess" so the message can say what
// the name actually refers to.
otherType := descriptorType(dsc)
return l.errs.handleErrorWithPos(node.FieldType().Start(), "%s: invalid type: %s is a %s, not a message or enum", scope, fqn, otherType)
}
return nil
}
// resolveServiceTypes handles service sd (declared under prefix). It resolves
// extension names in the options of the service and of its methods, and it
// resolves each method's request and response types.
//
// A request or response type that resolves to a message is rewritten to its
// fully-qualified ".<fqn>" form. Unknown or non-message types are reported
// through l.errs and left unchanged.
func (l *linker) resolveServiceTypes(r *parseResult, fd *dpb.FileDescriptorProto, prefix string, sd *dpb.ServiceDescriptorProto, scopes []scope) error {
	svcName := prefix + sd.GetName()
	if opts := sd.Options; opts != nil {
		if err := l.resolveOptions(r, fd, "service", svcName, proto.MessageName(opts), opts.UninterpretedOption, scopes); err != nil {
			return err
		}
	}
	for _, mtd := range sd.Method {
		if opts := mtd.Options; opts != nil {
			if err := l.resolveOptions(r, fd, "method", svcName+"."+mtd.GetName(), proto.MessageName(opts), opts.UninterpretedOption, scopes); err != nil {
				return err
			}
		}
		where := fmt.Sprintf("method %s.%s", svcName, mtd.GetName())
		node := r.getMethodNode(mtd)

		// Request type.
		fqn, dsc, _ := l.resolve(fd, mtd.GetInputType(), isMessage, scopes)
		switch _, isMsg := dsc.(*dpb.DescriptorProto); {
		case dsc == nil:
			if err := l.errs.handleErrorWithPos(node.GetInputType().Start(), "%s: unknown request type %s", where, mtd.GetInputType()); err != nil {
				return err
			}
		case !isMsg:
			otherType := descriptorType(dsc)
			if err := l.errs.handleErrorWithPos(node.GetInputType().Start(), "%s: invalid request type: %s is a %s, not a message", where, fqn, otherType); err != nil {
				return err
			}
		default:
			mtd.InputType = proto.String("." + fqn)
		}

		// Response type.
		fqn, dsc, _ = l.resolve(fd, mtd.GetOutputType(), isMessage, scopes)
		switch _, isMsg := dsc.(*dpb.DescriptorProto); {
		case dsc == nil:
			if err := l.errs.handleErrorWithPos(node.GetOutputType().Start(), "%s: unknown response type %s", where, mtd.GetOutputType()); err != nil {
				return err
			}
		case !isMsg:
			otherType := descriptorType(dsc)
			if err := l.errs.handleErrorWithPos(node.GetOutputType().Start(), "%s: invalid response type: %s is a %s, not a message", where, fqn, otherType); err != nil {
				return err
			}
		default:
			mtd.OutputType = proto.String("." + fqn)
		}
	}
	return nil
}
// resolveOptions resolves every extension name part (e.g. the "(foo.bar)" in
// an option name) of the given uninterpreted options.
//
// Each extension part must name an extension, i.e. a field with an extendee.
// On success the part is rewritten to ".<fqn>". On failure an error is
// reported, prefixed with "<elemType> <elemName>: " except for file options,
// and the rest of that option's name is skipped. optType is not used here.
func (l *linker) resolveOptions(r *parseResult, fd *dpb.FileDescriptorProto, elemType, elemName, optType string, opts []*dpb.UninterpretedOption, scopes []scope) error {
	errPrefix := ""
	if elemType != "file" {
		errPrefix = fmt.Sprintf("%s %s: ", elemType, elemName)
	}
nextOption:
	for _, opt := range opts {
		for _, part := range opt.Name {
			if !part.GetIsExtension() {
				continue
			}
			node := r.getOptionNamePartNode(part)
			fqn, dsc, _ := l.resolve(fd, part.GetNamePart(), isField, scopes)

			var err error
			switch ext, isFld := dsc.(*dpb.FieldDescriptorProto); {
			case dsc == nil:
				err = l.errs.handleErrorWithPos(node.Start(), "%sunknown extension %s", errPrefix, part.GetNamePart())
			case !isFld:
				otherType := descriptorType(dsc)
				err = l.errs.handleErrorWithPos(node.Start(), "%sinvalid extension: %s is a %s, not an extension", errPrefix, part.GetNamePart(), otherType)
			case ext.GetExtendee() == "":
				err = l.errs.handleErrorWithPos(node.Start(), "%sinvalid extension: %s is a field but not an extension", errPrefix, part.GetNamePart())
			default:
				part.NamePart = proto.String("." + fqn)
				continue
			}

			// One of the error cases above: give up on this option.
			if err != nil {
				return err
			}
			continue nextOption
		}
	}
	return nil
}
// resolve looks up a type or extension name as referenced from file fd.
//
// A name with a leading "." is already fully qualified. It is looked up in fd
// and in the files findSymbol reaches from it; the returned fqn has the dot
// stripped.
//
// Any other name is tried against scopes from the last (innermost) to the
// first (outermost). The first match accepted by allowed is returned.
// Otherwise the innermost match of any kind is returned, so the caller can
// report "found, but wrong kind". A nil element means nothing matched.
func (l *linker) resolve(fd *dpb.FileDescriptorProto, name string, allowed func(proto.Message) bool, scopes []scope) (fqn string, element proto.Message, proto3 bool) {
	if strings.HasPrefix(name, ".") {
		qualified := name[1:]
		d, p3 := l.findSymbol(fd, qualified, false, map[*dpb.FileDescriptorProto]struct{}{})
		if d == nil {
			return "", nil, false
		}
		return qualified, d, p3
	}

	var (
		fallbackFqn    string
		fallback       proto.Message
		fallbackProto3 bool
	)
	for i := len(scopes) - 1; i >= 0; i-- {
		candFqn, cand, candProto3 := scopes[i](name)
		if cand == nil {
			continue
		}
		if allowed(cand) {
			return candFqn, cand, candProto3
		}
		if fallback == nil {
			fallbackFqn, fallback, fallbackProto3 = candFqn, cand, candProto3
		}
	}
	return fallbackFqn, fallback, fallbackProto3
}
// isField reports whether m is a *dpb.FieldDescriptorProto. That type is used
// for both fields and extensions, so both match.
func isField(m proto.Message) bool {
	switch m.(type) {
	case *dpb.FieldDescriptorProto:
		return true
	default:
		return false
	}
}
// isMessage reports whether m is a message descriptor (*dpb.DescriptorProto).
func isMessage(m proto.Message) bool {
	switch m.(type) {
	case *dpb.DescriptorProto:
		return true
	default:
		return false
	}
}
// isType reports whether m declares something usable as a field type: a
// message or an enum.
func isType(m proto.Message) bool {
	if _, ok := m.(*dpb.DescriptorProto); ok {
		return true
	}
	_, ok := m.(*dpb.EnumDescriptorProto)
	return ok
}
// scope represents a lexical scope in a proto file in which messages and enums
// can be declared. Given a relative (not dot-prefixed) name, it returns:
//   - the fully-qualified name the name resolves to in this scope;
//   - the declaring descriptor, or nil if the name is not found here;
//   - whether the file where the symbol was found is proto3.
type scope func(name string) (fqn string, element proto.Message, proto3 bool)
// fileScope returns the outermost scope for file fd. A name is tried under
// each package prefix from internal.CreatePrefixList(fd's package), in the
// order that function returns them, where an empty prefix means "unqualified".
// Each candidate is looked up with findSymbol, i.e. in fd and its imports.
// In protobuf, packages nest like C++ namespaces, so parent packages count.
func fileScope(fd *dpb.FileDescriptorProto, l *linker) scope {
	prefixes := internal.CreatePrefixList(fd.GetPackage())
	return func(name string) (string, proto.Message, bool) {
		for _, pkg := range prefixes {
			candidate := name
			if pkg != "" {
				candidate = pkg + "." + name
			}
			visited := map[*dpb.FileDescriptorProto]struct{}{}
			if d, p3 := l.findSymbol(fd, candidate, false, visited); d != nil {
				return candidate, d, p3
			}
		}
		return "", nil, false
	}
}
// messageScope returns the scope for the body of message messageName. A name
// resolves to "<messageName>.<name>" if that key is present in filePool, and
// the given proto3 flag is returned with it.
func messageScope(messageName string, proto3 bool, filePool map[string]proto.Message) scope {
	prefix := messageName + "."
	return func(name string) (string, proto.Message, bool) {
		qualified := prefix + name
		d, found := filePool[qualified]
		if !found {
			return "", nil, false
		}
		return qualified, d, proto3
	}
}
// findSymbol looks up the fully-qualified name in fd's symbol pool and then in
// fd's dependencies. Each dependency is searched through its own public
// imports, transitively.
//
// With public == false, all of fd's direct dependencies are searched. With
// public == true, only fd's public dependencies are searched, which is how a
// dependency re-exports symbols. checked records files already visited, so
// each file is searched at most once and cycles terminate. It returns the
// descriptor and whether its file is proto3, or (nil, false).
func (l *linker) findSymbol(fd *dpb.FileDescriptorProto, name string, public bool, checked map[*dpb.FileDescriptorProto]struct{}) (element proto.Message, proto3 bool) {
	if _, seen := checked[fd]; seen {
		return nil, false
	}
	checked[fd] = struct{}{}

	if d := l.descriptorPool[fd][name]; d != nil {
		return d, isProto3(fd)
	}

	// searchDep searches a single dependency by file name. Dependencies with
	// no parse result are skipped here; the missing file is caught later.
	searchDep := func(depName string) (proto.Message, bool) {
		depres := l.files[depName]
		if depres == nil {
			return nil, false
		}
		return l.findSymbol(depres.fd, name, true, checked)
	}

	if public {
		for _, idx := range fd.PublicDependency {
			if d, p3 := searchDep(fd.Dependency[idx]); d != nil {
				return d, p3
			}
		}
		return nil, false
	}
	for _, depName := range fd.Dependency {
		if d, p3 := searchDep(depName); d != nil {
			return d, p3
		}
	}
	return nil, false
}
// isProto3 reports whether fd's syntax field is "proto3".
func isProto3(fd *dpb.FileDescriptorProto) bool {
	switch fd.GetSyntax() {
	case "proto3":
		return true
	}
	return false
}
// createdLinkedDescriptors links every file in l.files into a rich
// descriptor, keyed by file name. Files are processed in sorted name order,
// so the first failure (e.g. an import cycle) is reported deterministically.
// Dependencies are linked on demand by linkFile and memoized in the map.
func (l *linker) createdLinkedDescriptors() (map[string]*desc.FileDescriptor, error) {
	names := make([]string, 0, len(l.files))
	for name := range l.files {
		names = append(names, name)
	}
	sort.Strings(names)

	linked := make(map[string]*desc.FileDescriptor, len(names))
	for _, name := range names {
		_, err := l.linkFile(name, nil, nil, linked)
		if err != nil {
			return nil, err
		}
	}
	return linked, nil
}
// linkFile links the named file into a rich descriptor. It first links all of
// the file's dependencies, and memoizes every result in linked.
//
// seen is the chain of file names currently being linked; it is used to
// detect import cycles. rootImportLoc is nil for top-level calls. Otherwise it
// is the source position of the import, in the top-level file, through which
// this file was reached, and any import cycle is attributed to it.
//
// Errors from desc.CreateFileDescriptor are wrapped with %w, so callers can
// inspect them with errors.Is / errors.As.
func (l *linker) linkFile(name string, rootImportLoc *SourcePos, seen []string, linked map[string]*desc.FileDescriptor) (*desc.FileDescriptor, error) {
	// check for import cycle
	for _, s := range seen {
		if s != name {
			continue
		}
		var msg bytes.Buffer
		for i, prev := range seen {
			if i > 0 {
				msg.WriteString(" -> ")
			}
			_, _ = fmt.Fprintf(&msg, "%q", prev)
		}
		_, _ = fmt.Fprintf(&msg, " -> %q", name)
		return nil, ErrorWithSourcePos{
			Underlying: fmt.Errorf("cycle found in imports: %s", msg.String()),
			Pos:        rootImportLoc,
		}
	}
	seen = append(seen, name)

	if lfd, ok := linked[name]; ok {
		// already linked
		return lfd, nil
	}

	r := l.files[name]
	if r == nil {
		// len(seen)-1 is *this* file; the entry before it (if any) imported
		// it. Guard the index so a missing top-level file cannot panic.
		if len(seen) < 2 {
			return nil, fmt.Errorf("no descriptor found for %q", name)
		}
		importer := seen[len(seen)-2]
		return nil, fmt.Errorf("no descriptor found for %q, imported by %q", name, importer)
	}

	// linkDep links one dependency and appends it, so deps stays in
	// declaration order.
	var deps []*desc.FileDescriptor
	linkDep := func(depName string, pos *SourcePos) error {
		ldep, err := l.linkFile(depName, pos, seen, linked)
		if err != nil {
			return err
		}
		deps = append(deps, ldep)
		return nil
	}

	switch {
	case rootImportLoc != nil:
		// we can just use the descriptor since we don't need source location
		// (we'll just attribute any import cycles found to the "root" import)
		for _, dep := range r.fd.Dependency {
			if err := linkDep(dep, rootImportLoc); err != nil {
				return nil, err
			}
		}
	default:
		// try to find a source location for this "root" import
		decl := r.getFileNode(r.fd)
		if fnode, ok := decl.(*ast.FileNode); ok {
			for _, d := range fnode.Decls {
				imp, isImport := d.(*ast.ImportNode)
				if !isImport {
					continue
				}
				if err := linkDep(imp.Name.AsString(), imp.Name.Start()); err != nil {
					return nil, err
				}
			}
		} else {
			// no AST? just use the descriptor
			for _, dep := range r.fd.Dependency {
				if err := linkDep(dep, decl.Start()); err != nil {
					return nil, err
				}
			}
		}
	}

	lfd, err := desc.CreateFileDescriptor(r.fd, deps...)
	if err != nil {
		return nil, fmt.Errorf("error linking %q: %w", name, err)
	}
	linked[name] = lfd
	return lfd, nil
}