// Copyright 2010 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package proto import ( "errors" "fmt" "reflect" "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/runtime/protoiface" "google.golang.org/protobuf/runtime/protoimpl" ) type ( // ExtensionDesc represents an extension descriptor and // is used to interact with an extension field in a message. // // Variables of this type are generated in code by protoc-gen-go. ExtensionDesc = protoimpl.ExtensionInfo // ExtensionRange represents a range of message extensions. // Used in code generated by protoc-gen-go. ExtensionRange = protoiface.ExtensionRangeV1 // Deprecated: Do not use; this is an internal type. Extension = protoimpl.ExtensionFieldV1 // Deprecated: Do not use; this is an internal type. XXX_InternalExtensions = protoimpl.ExtensionFields ) // ErrMissingExtension reports whether the extension was not present. var ErrMissingExtension = errors.New("proto: missing extension") var errNotExtendable = errors.New("proto: not an extendable proto.Message") // HasExtension reports whether the extension field is present in m // either as an explicitly populated field or as an unknown field. func HasExtension(m Message, xt *ExtensionDesc) (has bool) { mr := MessageReflect(m) if mr == nil || !mr.IsValid() { return false } // Check whether any populated known field matches the field number. xtd := xt.TypeDescriptor() if isValidExtension(mr.Descriptor(), xtd) { has = mr.Has(xtd) } else { mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool { has = int32(fd.Number()) == xt.Field return !has }) } // Check whether any unknown field matches the field number. for b := mr.GetUnknown(); !has && len(b) > 0; { num, _, n := protowire.ConsumeField(b) has = int32(num) == xt.Field b = b[n:] } return has } // ClearExtension removes the extension field from m // either as an explicitly populated field or as an unknown field. func ClearExtension(m Message, xt *ExtensionDesc) { mr := MessageReflect(m) if mr == nil || !mr.IsValid() { return } xtd := xt.TypeDescriptor() if isValidExtension(mr.Descriptor(), xtd) { mr.Clear(xtd) } else { mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool { if int32(fd.Number()) == xt.Field { mr.Clear(fd) return false } return true }) } clearUnknown(mr, fieldNum(xt.Field)) } // ClearAllExtensions clears all extensions from m. // This includes populated fields and unknown fields in the extension range. func ClearAllExtensions(m Message) { mr := MessageReflect(m) if mr == nil || !mr.IsValid() { return } mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool { if fd.IsExtension() { mr.Clear(fd) } return true }) clearUnknown(mr, mr.Descriptor().ExtensionRanges()) } // GetExtension retrieves a proto2 extended field from m. // // If the descriptor is type complete (i.e., ExtensionDesc.ExtensionType is non-nil), // then GetExtension parses the encoded field and returns a Go value of the specified type. // If the field is not present, then the default value is returned (if one is specified), // otherwise ErrMissingExtension is reported. // // If the descriptor is type incomplete (i.e., ExtensionDesc.ExtensionType is nil), // then GetExtension returns the raw encoded bytes for the extension field. func GetExtension(m Message, xt *ExtensionDesc) (interface{}, error) { mr := MessageReflect(m) if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 { return nil, errNotExtendable } // Retrieve the unknown fields for this extension field. var bo protoreflect.RawFields for bi := mr.GetUnknown(); len(bi) > 0; { num, _, n := protowire.ConsumeField(bi) if int32(num) == xt.Field { bo = append(bo, bi[:n]...) } bi = bi[n:] } // For type incomplete descriptors, only retrieve the unknown fields. if xt.ExtensionType == nil { return []byte(bo), nil } // If the extension field only exists as unknown fields, unmarshal it. // This is rarely done since proto.Unmarshal eagerly unmarshals extensions. xtd := xt.TypeDescriptor() if !isValidExtension(mr.Descriptor(), xtd) { return nil, fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m) } if !mr.Has(xtd) && len(bo) > 0 { m2 := mr.New() if err := (proto.UnmarshalOptions{ Resolver: extensionResolver{xt}, }.Unmarshal(bo, m2.Interface())); err != nil { return nil, err } if m2.Has(xtd) { mr.Set(xtd, m2.Get(xtd)) clearUnknown(mr, fieldNum(xt.Field)) } } // Check whether the message has the extension field set or a default. var pv protoreflect.Value switch { case mr.Has(xtd): pv = mr.Get(xtd) case xtd.HasDefault(): pv = xtd.Default() default: return nil, ErrMissingExtension } v := xt.InterfaceOf(pv) rv := reflect.ValueOf(v) if isScalarKind(rv.Kind()) { rv2 := reflect.New(rv.Type()) rv2.Elem().Set(rv) v = rv2.Interface() } return v, nil } // extensionResolver is a custom extension resolver that stores a single // extension type that takes precedence over the global registry. type extensionResolver struct{ xt protoreflect.ExtensionType } func (r extensionResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) { if xtd := r.xt.TypeDescriptor(); xtd.FullName() == field { return r.xt, nil } return protoregistry.GlobalTypes.FindExtensionByName(field) } func (r extensionResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) { if xtd := r.xt.TypeDescriptor(); xtd.ContainingMessage().FullName() == message && xtd.Number() == field { return r.xt, nil } return protoregistry.GlobalTypes.FindExtensionByNumber(message, field) } // GetExtensions returns a list of the extensions values present in m, // corresponding with the provided list of extension descriptors, xts. // If an extension is missing in m, the corresponding value is nil. func GetExtensions(m Message, xts []*ExtensionDesc) ([]interface{}, error) { mr := MessageReflect(m) if mr == nil || !mr.IsValid() { return nil, errNotExtendable } vs := make([]interface{}, len(xts)) for i, xt := range xts { v, err := GetExtension(m, xt) if err != nil { if err == ErrMissingExtension { continue } return vs, err } vs[i] = v } return vs, nil } // SetExtension sets an extension field in m to the provided value. func SetExtension(m Message, xt *ExtensionDesc, v interface{}) error { mr := MessageReflect(m) if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 { return errNotExtendable } rv := reflect.ValueOf(v) if reflect.TypeOf(v) != reflect.TypeOf(xt.ExtensionType) { return fmt.Errorf("proto: bad extension value type. got: %T, want: %T", v, xt.ExtensionType) } if rv.Kind() == reflect.Ptr { if rv.IsNil() { return fmt.Errorf("proto: SetExtension called with nil value of type %T", v) } if isScalarKind(rv.Elem().Kind()) { v = rv.Elem().Interface() } } xtd := xt.TypeDescriptor() if !isValidExtension(mr.Descriptor(), xtd) { return fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m) } mr.Set(xtd, xt.ValueOf(v)) clearUnknown(mr, fieldNum(xt.Field)) return nil } // SetRawExtension inserts b into the unknown fields of m. // // Deprecated: Use Message.ProtoReflect.SetUnknown instead. func SetRawExtension(m Message, fnum int32, b []byte) { mr := MessageReflect(m) if mr == nil || !mr.IsValid() { return } // Verify that the raw field is valid. for b0 := b; len(b0) > 0; { num, _, n := protowire.ConsumeField(b0) if int32(num) != fnum { panic(fmt.Sprintf("mismatching field number: got %d, want %d", num, fnum)) } b0 = b0[n:] } ClearExtension(m, &ExtensionDesc{Field: fnum}) mr.SetUnknown(append(mr.GetUnknown(), b...)) } // ExtensionDescs returns a list of extension descriptors found in m, // containing descriptors for both populated extension fields in m and // also unknown fields of m that are in the extension range. // For the later case, an type incomplete descriptor is provided where only // the ExtensionDesc.Field field is populated. // The order of the extension descriptors is undefined. func ExtensionDescs(m Message) ([]*ExtensionDesc, error) { mr := MessageReflect(m) if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 { return nil, errNotExtendable } // Collect a set of known extension descriptors. extDescs := make(map[protoreflect.FieldNumber]*ExtensionDesc) mr.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { if fd.IsExtension() { xt := fd.(protoreflect.ExtensionTypeDescriptor) if xd, ok := xt.Type().(*ExtensionDesc); ok { extDescs[fd.Number()] = xd } } return true }) // Collect a set of unknown extension descriptors. extRanges := mr.Descriptor().ExtensionRanges() for b := mr.GetUnknown(); len(b) > 0; { num, _, n := protowire.ConsumeField(b) if extRanges.Has(num) && extDescs[num] == nil { extDescs[num] = nil } b = b[n:] } // Transpose the set of descriptors into a list. var xts []*ExtensionDesc for num, xt := range extDescs { if xt == nil { xt = &ExtensionDesc{Field: int32(num)} } xts = append(xts, xt) } return xts, nil } // isValidExtension reports whether xtd is a valid extension descriptor for md. func isValidExtension(md protoreflect.MessageDescriptor, xtd protoreflect.ExtensionTypeDescriptor) bool { return xtd.ContainingMessage() == md && md.ExtensionRanges().Has(xtd.Number()) } // isScalarKind reports whether k is a protobuf scalar kind (except bytes). // This function exists for historical reasons since the representation of // scalars differs between v1 and v2, where v1 uses *T and v2 uses T. func isScalarKind(k reflect.Kind) bool { switch k { case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String: return true default: return false } } // clearUnknown removes unknown fields from m where remover.Has reports true. func clearUnknown(m protoreflect.Message, remover interface { Has(protoreflect.FieldNumber) bool }) { var bo protoreflect.RawFields for bi := m.GetUnknown(); len(bi) > 0; { num, _, n := protowire.ConsumeField(bi) if !remover.Has(num) { bo = append(bo, bi[:n]...) } bi = bi[n:] } if bi := m.GetUnknown(); len(bi) != len(bo) { m.SetUnknown(bo) } } type fieldNum protoreflect.FieldNumber func (n1 fieldNum) Has(n2 protoreflect.FieldNumber) bool { return protoreflect.FieldNumber(n1) == n2 }