package codec import ( "fmt" "math" "reflect" "sort" "github.com/golang/protobuf/proto" "github.com/golang/protobuf/protoc-gen-go/descriptor" "github.com/jhump/protoreflect/desc" ) // EncodeZigZag64 does zig-zag encoding to convert the given // signed 64-bit integer into a form that can be expressed // efficiently as a varint, even for negative values. func EncodeZigZag64(v int64) uint64 { return (uint64(v) << 1) ^ uint64(v>>63) } // EncodeZigZag32 does zig-zag encoding to convert the given // signed 32-bit integer into a form that can be expressed // efficiently as a varint, even for negative values. func EncodeZigZag32(v int32) uint64 { return uint64((uint32(v) << 1) ^ uint32((v >> 31))) } func (cb *Buffer) EncodeFieldValue(fd *desc.FieldDescriptor, val interface{}) error { if fd.IsMap() { mp := val.(map[interface{}]interface{}) entryType := fd.GetMessageType() keyType := entryType.FindFieldByNumber(1) valType := entryType.FindFieldByNumber(2) var entryBuffer Buffer if cb.IsDeterministic() { entryBuffer.SetDeterministic(true) keys := make([]interface{}, 0, len(mp)) for k := range mp { keys = append(keys, k) } sort.Sort(sortable(keys)) for _, k := range keys { v := mp[k] entryBuffer.Reset() if err := entryBuffer.encodeFieldElement(keyType, k); err != nil { return err } rv := reflect.ValueOf(v) if rv.Kind() != reflect.Ptr || !rv.IsNil() { if err := entryBuffer.encodeFieldElement(valType, v); err != nil { return err } } if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil { return err } if err := cb.EncodeRawBytes(entryBuffer.Bytes()); err != nil { return err } } } else { for k, v := range mp { entryBuffer.Reset() if err := entryBuffer.encodeFieldElement(keyType, k); err != nil { return err } rv := reflect.ValueOf(v) if rv.Kind() != reflect.Ptr || !rv.IsNil() { if err := entryBuffer.encodeFieldElement(valType, v); err != nil { return err } } if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil { return err } if err := cb.EncodeRawBytes(entryBuffer.Bytes()); err != nil { return err } } } return nil } else if fd.IsRepeated() { sl := val.([]interface{}) wt, err := getWireType(fd.GetType()) if err != nil { return err } if isPacked(fd) && len(sl) > 0 && (wt == proto.WireVarint || wt == proto.WireFixed32 || wt == proto.WireFixed64) { // packed repeated field var packedBuffer Buffer for _, v := range sl { if err := packedBuffer.encodeFieldValue(fd, v); err != nil { return err } } if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil { return err } return cb.EncodeRawBytes(packedBuffer.Bytes()) } else { // non-packed repeated field for _, v := range sl { if err := cb.encodeFieldElement(fd, v); err != nil { return err } } return nil } } else { return cb.encodeFieldElement(fd, val) } } func isPacked(fd *desc.FieldDescriptor) bool { opts := fd.AsFieldDescriptorProto().GetOptions() // if set, use that value if opts != nil && opts.Packed != nil { return opts.GetPacked() } // if unset: proto2 defaults to false, proto3 to true return fd.GetFile().IsProto3() } // sortable is used to sort map keys. Values will be integers (int32, int64, uint32, and uint64), // bools, or strings. type sortable []interface{} func (s sortable) Len() int { return len(s) } func (s sortable) Less(i, j int) bool { vi := s[i] vj := s[j] switch reflect.TypeOf(vi).Kind() { case reflect.Int32: return vi.(int32) < vj.(int32) case reflect.Int64: return vi.(int64) < vj.(int64) case reflect.Uint32: return vi.(uint32) < vj.(uint32) case reflect.Uint64: return vi.(uint64) < vj.(uint64) case reflect.String: return vi.(string) < vj.(string) case reflect.Bool: return !vi.(bool) && vj.(bool) default: panic(fmt.Sprintf("cannot compare keys of type %v", reflect.TypeOf(vi))) } } func (s sortable) Swap(i, j int) { s[i], s[j] = s[j], s[i] } func (b *Buffer) encodeFieldElement(fd *desc.FieldDescriptor, val interface{}) error { wt, err := getWireType(fd.GetType()) if err != nil { return err } if err := b.EncodeTagAndWireType(fd.GetNumber(), wt); err != nil { return err } if err := b.encodeFieldValue(fd, val); err != nil { return err } if wt == proto.WireStartGroup { return b.EncodeTagAndWireType(fd.GetNumber(), proto.WireEndGroup) } return nil } func (b *Buffer) encodeFieldValue(fd *desc.FieldDescriptor, val interface{}) error { switch fd.GetType() { case descriptor.FieldDescriptorProto_TYPE_BOOL: v := val.(bool) if v { return b.EncodeVarint(1) } return b.EncodeVarint(0) case descriptor.FieldDescriptorProto_TYPE_ENUM, descriptor.FieldDescriptorProto_TYPE_INT32: v := val.(int32) return b.EncodeVarint(uint64(v)) case descriptor.FieldDescriptorProto_TYPE_SFIXED32: v := val.(int32) return b.EncodeFixed32(uint64(v)) case descriptor.FieldDescriptorProto_TYPE_SINT32: v := val.(int32) return b.EncodeVarint(EncodeZigZag32(v)) case descriptor.FieldDescriptorProto_TYPE_UINT32: v := val.(uint32) return b.EncodeVarint(uint64(v)) case descriptor.FieldDescriptorProto_TYPE_FIXED32: v := val.(uint32) return b.EncodeFixed32(uint64(v)) case descriptor.FieldDescriptorProto_TYPE_INT64: v := val.(int64) return b.EncodeVarint(uint64(v)) case descriptor.FieldDescriptorProto_TYPE_SFIXED64: v := val.(int64) return b.EncodeFixed64(uint64(v)) case descriptor.FieldDescriptorProto_TYPE_SINT64: v := val.(int64) return b.EncodeVarint(EncodeZigZag64(v)) case descriptor.FieldDescriptorProto_TYPE_UINT64: v := val.(uint64) return b.EncodeVarint(v) case descriptor.FieldDescriptorProto_TYPE_FIXED64: v := val.(uint64) return b.EncodeFixed64(v) case descriptor.FieldDescriptorProto_TYPE_DOUBLE: v := val.(float64) return b.EncodeFixed64(math.Float64bits(v)) case descriptor.FieldDescriptorProto_TYPE_FLOAT: v := val.(float32) return b.EncodeFixed32(uint64(math.Float32bits(v))) case descriptor.FieldDescriptorProto_TYPE_BYTES: v := val.([]byte) return b.EncodeRawBytes(v) case descriptor.FieldDescriptorProto_TYPE_STRING: v := val.(string) return b.EncodeRawBytes(([]byte)(v)) case descriptor.FieldDescriptorProto_TYPE_MESSAGE: return b.EncodeDelimitedMessage(val.(proto.Message)) case descriptor.FieldDescriptorProto_TYPE_GROUP: // just append the nested message to this buffer return b.EncodeMessage(val.(proto.Message)) // whosoever writeth start-group tag (e.g. caller) is responsible for writing end-group tag default: return fmt.Errorf("unrecognized field type: %v", fd.GetType()) } } func getWireType(t descriptor.FieldDescriptorProto_Type) (int8, error) { switch t { case descriptor.FieldDescriptorProto_TYPE_ENUM, descriptor.FieldDescriptorProto_TYPE_BOOL, descriptor.FieldDescriptorProto_TYPE_INT32, descriptor.FieldDescriptorProto_TYPE_SINT32, descriptor.FieldDescriptorProto_TYPE_UINT32, descriptor.FieldDescriptorProto_TYPE_INT64, descriptor.FieldDescriptorProto_TYPE_SINT64, descriptor.FieldDescriptorProto_TYPE_UINT64: return proto.WireVarint, nil case descriptor.FieldDescriptorProto_TYPE_FIXED32, descriptor.FieldDescriptorProto_TYPE_SFIXED32, descriptor.FieldDescriptorProto_TYPE_FLOAT: return proto.WireFixed32, nil case descriptor.FieldDescriptorProto_TYPE_FIXED64, descriptor.FieldDescriptorProto_TYPE_SFIXED64, descriptor.FieldDescriptorProto_TYPE_DOUBLE: return proto.WireFixed64, nil case descriptor.FieldDescriptorProto_TYPE_BYTES, descriptor.FieldDescriptorProto_TYPE_STRING, descriptor.FieldDescriptorProto_TYPE_MESSAGE: return proto.WireBytes, nil case descriptor.FieldDescriptorProto_TYPE_GROUP: return proto.WireStartGroup, nil default: return 0, ErrBadWireType } }