319 lines
10 KiB
Go
319 lines
10 KiB
Go
package codec
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"math"
|
|
|
|
"github.com/golang/protobuf/proto"
|
|
"github.com/golang/protobuf/protoc-gen-go/descriptor"
|
|
|
|
"github.com/jhump/protoreflect/desc"
|
|
)
|
|
|
|
var varintTypes = map[descriptor.FieldDescriptorProto_Type]bool{}
|
|
var fixed32Types = map[descriptor.FieldDescriptorProto_Type]bool{}
|
|
var fixed64Types = map[descriptor.FieldDescriptorProto_Type]bool{}
|
|
|
|
func init() {
|
|
varintTypes[descriptor.FieldDescriptorProto_TYPE_BOOL] = true
|
|
varintTypes[descriptor.FieldDescriptorProto_TYPE_INT32] = true
|
|
varintTypes[descriptor.FieldDescriptorProto_TYPE_INT64] = true
|
|
varintTypes[descriptor.FieldDescriptorProto_TYPE_UINT32] = true
|
|
varintTypes[descriptor.FieldDescriptorProto_TYPE_UINT64] = true
|
|
varintTypes[descriptor.FieldDescriptorProto_TYPE_SINT32] = true
|
|
varintTypes[descriptor.FieldDescriptorProto_TYPE_SINT64] = true
|
|
varintTypes[descriptor.FieldDescriptorProto_TYPE_ENUM] = true
|
|
|
|
fixed32Types[descriptor.FieldDescriptorProto_TYPE_FIXED32] = true
|
|
fixed32Types[descriptor.FieldDescriptorProto_TYPE_SFIXED32] = true
|
|
fixed32Types[descriptor.FieldDescriptorProto_TYPE_FLOAT] = true
|
|
|
|
fixed64Types[descriptor.FieldDescriptorProto_TYPE_FIXED64] = true
|
|
fixed64Types[descriptor.FieldDescriptorProto_TYPE_SFIXED64] = true
|
|
fixed64Types[descriptor.FieldDescriptorProto_TYPE_DOUBLE] = true
|
|
}
|
|
|
|
// ErrWireTypeEndGroup is returned from DecodeFieldValue if the tag and wire
// type it reads indicate an end-group marker. In that case, the value returned
// alongside the error is the tag number read from the marker.
var ErrWireTypeEndGroup = errors.New("unexpected wire type: end group")
|
|
|
|
// MessageFactory is used to instantiate messages when DecodeFieldValue needs to
// decode a message value.
//
// Also see MessageFactory in "github.com/jhump/protoreflect/dynamic", which
// implements this interface.
type MessageFactory interface {
	// NewMessage returns a new message for the given message descriptor. The
	// decoding functions in this file unmarshal the bytes of a nested message
	// or group into the returned value.
	NewMessage(md *desc.MessageDescriptor) proto.Message
}
|
|
|
|
// UnknownField represents a field that was parsed from the binary wire
// format for a message, but was not a recognized field number. Enough
// information is preserved so that re-serializing the message won't lose
// any of the unrecognized data.
type UnknownField struct {
	// The tag number for the unrecognized field.
	Tag int32

	// Encoding indicates how the unknown field was encoded on the wire. If it
	// is proto.WireBytes or proto.WireStartGroup then Contents will be set to
	// the raw bytes. If it is proto.WireFixed32 then the data is in the least
	// significant 32 bits of Value. Otherwise, the data is in all 64 bits of
	// Value.
	Encoding int8
	Contents []byte
	Value    uint64
}
|
|
|
|
// DecodeZigZag32 reverses zig-zag encoding for a 32-bit signed integer.
// Only the low 32 bits of v are considered; any higher bits are ignored.
func DecodeZigZag32(v uint64) int32 {
	n := uint32(v)
	// The lowest bit carries the sign: negating it yields either 0 or an
	// all-ones mask, which flips every magnitude bit for negative values.
	return int32(n>>1) ^ -int32(n&1)
}
|
|
|
|
// DecodeZigZag64 reverses zig-zag encoding for a 64-bit signed integer.
func DecodeZigZag64(v uint64) int64 {
	// The lowest bit carries the sign: negating it yields either 0 or an
	// all-ones mask, which flips every magnitude bit for negative values.
	return int64(v>>1) ^ -int64(v&1)
}
|
|
|
|
// DecodeFieldValue will read a field value from the buffer and return its
|
|
// value and the corresponding field descriptor. The given function is used
|
|
// to lookup a field descriptor by tag number. The given factory is used to
|
|
// instantiate a message if the field value is (or contains) a message value.
|
|
//
|
|
// On error, the field descriptor and value are typically nil. However, if the
|
|
// error returned is ErrWireTypeEndGroup, the returned value will indicate any
|
|
// tag number encoded in the end-group marker.
|
|
//
|
|
// If the field descriptor returned is nil, that means that the given function
|
|
// returned nil. This is expected to happen for unrecognized tag numbers. In
|
|
// that case, no error is returned, and the value will be an UnknownField.
|
|
func (cb *Buffer) DecodeFieldValue(fieldFinder func(int32) *desc.FieldDescriptor, fact MessageFactory) (*desc.FieldDescriptor, interface{}, error) {
|
|
if cb.EOF() {
|
|
return nil, nil, io.EOF
|
|
}
|
|
tagNumber, wireType, err := cb.DecodeTagAndWireType()
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
if wireType == proto.WireEndGroup {
|
|
return nil, tagNumber, ErrWireTypeEndGroup
|
|
}
|
|
fd := fieldFinder(tagNumber)
|
|
if fd == nil {
|
|
val, err := cb.decodeUnknownField(tagNumber, wireType)
|
|
return nil, val, err
|
|
}
|
|
val, err := cb.decodeKnownField(fd, wireType, fact)
|
|
return fd, val, err
|
|
}
|
|
|
|
// DecodeScalarField extracts a properly-typed value from v. The returned value's
|
|
// type depends on the given field descriptor type. It will be the same type as
|
|
// generated structs use for the field descriptor's type. Enum types will return
|
|
// an int32. If the given field type uses length-delimited encoding (nested
|
|
// messages, bytes, and strings), an error is returned.
|
|
func DecodeScalarField(fd *desc.FieldDescriptor, v uint64) (interface{}, error) {
|
|
switch fd.GetType() {
|
|
case descriptor.FieldDescriptorProto_TYPE_BOOL:
|
|
return v != 0, nil
|
|
case descriptor.FieldDescriptorProto_TYPE_UINT32,
|
|
descriptor.FieldDescriptorProto_TYPE_FIXED32:
|
|
if v > math.MaxUint32 {
|
|
return nil, ErrOverflow
|
|
}
|
|
return uint32(v), nil
|
|
|
|
case descriptor.FieldDescriptorProto_TYPE_INT32,
|
|
descriptor.FieldDescriptorProto_TYPE_ENUM:
|
|
s := int64(v)
|
|
if s > math.MaxInt32 || s < math.MinInt32 {
|
|
return nil, ErrOverflow
|
|
}
|
|
return int32(s), nil
|
|
|
|
case descriptor.FieldDescriptorProto_TYPE_SFIXED32:
|
|
if v > math.MaxUint32 {
|
|
return nil, ErrOverflow
|
|
}
|
|
return int32(v), nil
|
|
|
|
case descriptor.FieldDescriptorProto_TYPE_SINT32:
|
|
if v > math.MaxUint32 {
|
|
return nil, ErrOverflow
|
|
}
|
|
return DecodeZigZag32(v), nil
|
|
|
|
case descriptor.FieldDescriptorProto_TYPE_UINT64,
|
|
descriptor.FieldDescriptorProto_TYPE_FIXED64:
|
|
return v, nil
|
|
|
|
case descriptor.FieldDescriptorProto_TYPE_INT64,
|
|
descriptor.FieldDescriptorProto_TYPE_SFIXED64:
|
|
return int64(v), nil
|
|
|
|
case descriptor.FieldDescriptorProto_TYPE_SINT64:
|
|
return DecodeZigZag64(v), nil
|
|
|
|
case descriptor.FieldDescriptorProto_TYPE_FLOAT:
|
|
if v > math.MaxUint32 {
|
|
return nil, ErrOverflow
|
|
}
|
|
return math.Float32frombits(uint32(v)), nil
|
|
|
|
case descriptor.FieldDescriptorProto_TYPE_DOUBLE:
|
|
return math.Float64frombits(v), nil
|
|
|
|
default:
|
|
// bytes, string, message, and group cannot be represented as a simple numeric value
|
|
return nil, fmt.Errorf("bad input; field %s requires length-delimited wire type", fd.GetFullyQualifiedName())
|
|
}
|
|
}
|
|
|
|
// DecodeLengthDelimitedField extracts a properly-typed value from bytes. The
|
|
// returned value's type will usually be []byte, string, or, for nested messages,
|
|
// the type returned from the given message factory. However, since repeated
|
|
// scalar fields can be length-delimited, when they used packed encoding, it can
|
|
// also return an []interface{}, where each element is a scalar value. Furthermore,
|
|
// it could return a scalar type, not in a slice, if the given field descriptor is
|
|
// not repeated. This is to support cases where a field is changed from optional
|
|
// to repeated. New code may emit a packed repeated representation, but old code
|
|
// still expects a single scalar value. In this case, if the actual data in bytes
|
|
// contains multiple values, only the last value is returned.
|
|
func DecodeLengthDelimitedField(fd *desc.FieldDescriptor, bytes []byte, mf MessageFactory) (interface{}, error) {
|
|
switch {
|
|
case fd.GetType() == descriptor.FieldDescriptorProto_TYPE_BYTES:
|
|
return bytes, nil
|
|
|
|
case fd.GetType() == descriptor.FieldDescriptorProto_TYPE_STRING:
|
|
return string(bytes), nil
|
|
|
|
case fd.GetType() == descriptor.FieldDescriptorProto_TYPE_MESSAGE ||
|
|
fd.GetType() == descriptor.FieldDescriptorProto_TYPE_GROUP:
|
|
msg := mf.NewMessage(fd.GetMessageType())
|
|
err := proto.Unmarshal(bytes, msg)
|
|
if err != nil {
|
|
return nil, err
|
|
} else {
|
|
return msg, nil
|
|
}
|
|
|
|
default:
|
|
// even if the field is not repeated or not packed, we still parse it as such for
|
|
// backwards compatibility (e.g. message we are de-serializing could have been both
|
|
// repeated and packed at the time of serialization)
|
|
packedBuf := NewBuffer(bytes)
|
|
var slice []interface{}
|
|
var val interface{}
|
|
for !packedBuf.EOF() {
|
|
var v uint64
|
|
var err error
|
|
if varintTypes[fd.GetType()] {
|
|
v, err = packedBuf.DecodeVarint()
|
|
} else if fixed32Types[fd.GetType()] {
|
|
v, err = packedBuf.DecodeFixed32()
|
|
} else if fixed64Types[fd.GetType()] {
|
|
v, err = packedBuf.DecodeFixed64()
|
|
} else {
|
|
return nil, fmt.Errorf("bad input; cannot parse length-delimited wire type for field %s", fd.GetFullyQualifiedName())
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
val, err = DecodeScalarField(fd, v)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if fd.IsRepeated() {
|
|
slice = append(slice, val)
|
|
}
|
|
}
|
|
if fd.IsRepeated() {
|
|
return slice, nil
|
|
} else {
|
|
// if not a repeated field, last value wins
|
|
return val, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
func (b *Buffer) decodeKnownField(fd *desc.FieldDescriptor, encoding int8, fact MessageFactory) (interface{}, error) {
|
|
var val interface{}
|
|
var err error
|
|
switch encoding {
|
|
case proto.WireFixed32:
|
|
var num uint64
|
|
num, err = b.DecodeFixed32()
|
|
if err == nil {
|
|
val, err = DecodeScalarField(fd, num)
|
|
}
|
|
case proto.WireFixed64:
|
|
var num uint64
|
|
num, err = b.DecodeFixed64()
|
|
if err == nil {
|
|
val, err = DecodeScalarField(fd, num)
|
|
}
|
|
case proto.WireVarint:
|
|
var num uint64
|
|
num, err = b.DecodeVarint()
|
|
if err == nil {
|
|
val, err = DecodeScalarField(fd, num)
|
|
}
|
|
|
|
case proto.WireBytes:
|
|
alloc := fd.GetType() == descriptor.FieldDescriptorProto_TYPE_BYTES
|
|
var raw []byte
|
|
raw, err = b.DecodeRawBytes(alloc)
|
|
if err == nil {
|
|
val, err = DecodeLengthDelimitedField(fd, raw, fact)
|
|
}
|
|
|
|
case proto.WireStartGroup:
|
|
if fd.GetMessageType() == nil {
|
|
return nil, fmt.Errorf("cannot parse field %s from group-encoded wire type", fd.GetFullyQualifiedName())
|
|
}
|
|
msg := fact.NewMessage(fd.GetMessageType())
|
|
var data []byte
|
|
data, err = b.ReadGroup(false)
|
|
if err == nil {
|
|
err = proto.Unmarshal(data, msg)
|
|
if err == nil {
|
|
val = msg
|
|
}
|
|
}
|
|
|
|
default:
|
|
return nil, ErrBadWireType
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return val, nil
|
|
}
|
|
|
|
func (b *Buffer) decodeUnknownField(tagNumber int32, encoding int8) (interface{}, error) {
|
|
u := UnknownField{Tag: tagNumber, Encoding: encoding}
|
|
var err error
|
|
switch encoding {
|
|
case proto.WireFixed32:
|
|
u.Value, err = b.DecodeFixed32()
|
|
case proto.WireFixed64:
|
|
u.Value, err = b.DecodeFixed64()
|
|
case proto.WireVarint:
|
|
u.Value, err = b.DecodeVarint()
|
|
case proto.WireBytes:
|
|
u.Contents, err = b.DecodeRawBytes(true)
|
|
case proto.WireStartGroup:
|
|
u.Contents, err = b.ReadGroup(true)
|
|
default:
|
|
err = ErrBadWireType
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return u, nil
|
|
}
|