1*1c12ee1eSDan Willemsen// Copyright 2018 The Go Authors. All rights reserved. 2*1c12ee1eSDan Willemsen// Use of this source code is governed by a BSD-style 3*1c12ee1eSDan Willemsen// license that can be found in the LICENSE file. 4*1c12ee1eSDan Willemsen 5*1c12ee1eSDan Willemsenpackage proto 6*1c12ee1eSDan Willemsen 7*1c12ee1eSDan Willemsenimport ( 8*1c12ee1eSDan Willemsen "google.golang.org/protobuf/encoding/protowire" 9*1c12ee1eSDan Willemsen "google.golang.org/protobuf/internal/encoding/messageset" 10*1c12ee1eSDan Willemsen "google.golang.org/protobuf/internal/errors" 11*1c12ee1eSDan Willemsen "google.golang.org/protobuf/internal/flags" 12*1c12ee1eSDan Willemsen "google.golang.org/protobuf/internal/genid" 13*1c12ee1eSDan Willemsen "google.golang.org/protobuf/internal/pragma" 14*1c12ee1eSDan Willemsen "google.golang.org/protobuf/reflect/protoreflect" 15*1c12ee1eSDan Willemsen "google.golang.org/protobuf/reflect/protoregistry" 16*1c12ee1eSDan Willemsen "google.golang.org/protobuf/runtime/protoiface" 17*1c12ee1eSDan Willemsen) 18*1c12ee1eSDan Willemsen 19*1c12ee1eSDan Willemsen// UnmarshalOptions configures the unmarshaler. 20*1c12ee1eSDan Willemsen// 21*1c12ee1eSDan Willemsen// Example usage: 22*1c12ee1eSDan Willemsen// 23*1c12ee1eSDan Willemsen// err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m) 24*1c12ee1eSDan Willemsentype UnmarshalOptions struct { 25*1c12ee1eSDan Willemsen pragma.NoUnkeyedLiterals 26*1c12ee1eSDan Willemsen 27*1c12ee1eSDan Willemsen // Merge merges the input into the destination message. 28*1c12ee1eSDan Willemsen // The default behavior is to always reset the message before unmarshaling, 29*1c12ee1eSDan Willemsen // unless Merge is specified. 30*1c12ee1eSDan Willemsen Merge bool 31*1c12ee1eSDan Willemsen 32*1c12ee1eSDan Willemsen // AllowPartial accepts input for messages that will result in missing 33*1c12ee1eSDan Willemsen // required fields. If AllowPartial is false (the default), Unmarshal will 34*1c12ee1eSDan Willemsen // return an error if there are any missing required fields. 35*1c12ee1eSDan Willemsen AllowPartial bool 36*1c12ee1eSDan Willemsen 37*1c12ee1eSDan Willemsen // If DiscardUnknown is set, unknown fields are ignored. 38*1c12ee1eSDan Willemsen DiscardUnknown bool 39*1c12ee1eSDan Willemsen 40*1c12ee1eSDan Willemsen // Resolver is used for looking up types when unmarshaling extension fields. 41*1c12ee1eSDan Willemsen // If nil, this defaults to using protoregistry.GlobalTypes. 42*1c12ee1eSDan Willemsen Resolver interface { 43*1c12ee1eSDan Willemsen FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) 44*1c12ee1eSDan Willemsen FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) 45*1c12ee1eSDan Willemsen } 46*1c12ee1eSDan Willemsen 47*1c12ee1eSDan Willemsen // RecursionLimit limits how deeply messages may be nested. 48*1c12ee1eSDan Willemsen // If zero, a default limit is applied. 49*1c12ee1eSDan Willemsen RecursionLimit int 50*1c12ee1eSDan Willemsen} 51*1c12ee1eSDan Willemsen 52*1c12ee1eSDan Willemsen// Unmarshal parses the wire-format message in b and places the result in m. 53*1c12ee1eSDan Willemsen// The provided message must be mutable (e.g., a non-nil pointer to a message). 54*1c12ee1eSDan Willemsenfunc Unmarshal(b []byte, m Message) error { 55*1c12ee1eSDan Willemsen _, err := UnmarshalOptions{RecursionLimit: protowire.DefaultRecursionLimit}.unmarshal(b, m.ProtoReflect()) 56*1c12ee1eSDan Willemsen return err 57*1c12ee1eSDan Willemsen} 58*1c12ee1eSDan Willemsen 59*1c12ee1eSDan Willemsen// Unmarshal parses the wire-format message in b and places the result in m. 60*1c12ee1eSDan Willemsen// The provided message must be mutable (e.g., a non-nil pointer to a message). 61*1c12ee1eSDan Willemsenfunc (o UnmarshalOptions) Unmarshal(b []byte, m Message) error { 62*1c12ee1eSDan Willemsen if o.RecursionLimit == 0 { 63*1c12ee1eSDan Willemsen o.RecursionLimit = protowire.DefaultRecursionLimit 64*1c12ee1eSDan Willemsen } 65*1c12ee1eSDan Willemsen _, err := o.unmarshal(b, m.ProtoReflect()) 66*1c12ee1eSDan Willemsen return err 67*1c12ee1eSDan Willemsen} 68*1c12ee1eSDan Willemsen 69*1c12ee1eSDan Willemsen// UnmarshalState parses a wire-format message and places the result in m. 70*1c12ee1eSDan Willemsen// 71*1c12ee1eSDan Willemsen// This method permits fine-grained control over the unmarshaler. 72*1c12ee1eSDan Willemsen// Most users should use Unmarshal instead. 73*1c12ee1eSDan Willemsenfunc (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) { 74*1c12ee1eSDan Willemsen if o.RecursionLimit == 0 { 75*1c12ee1eSDan Willemsen o.RecursionLimit = protowire.DefaultRecursionLimit 76*1c12ee1eSDan Willemsen } 77*1c12ee1eSDan Willemsen return o.unmarshal(in.Buf, in.Message) 78*1c12ee1eSDan Willemsen} 79*1c12ee1eSDan Willemsen 80*1c12ee1eSDan Willemsen// unmarshal is a centralized function that all unmarshal operations go through. 81*1c12ee1eSDan Willemsen// For profiling purposes, avoid changing the name of this function or 82*1c12ee1eSDan Willemsen// introducing other code paths for unmarshal that do not go through this. 83*1c12ee1eSDan Willemsenfunc (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out protoiface.UnmarshalOutput, err error) { 84*1c12ee1eSDan Willemsen if o.Resolver == nil { 85*1c12ee1eSDan Willemsen o.Resolver = protoregistry.GlobalTypes 86*1c12ee1eSDan Willemsen } 87*1c12ee1eSDan Willemsen if !o.Merge { 88*1c12ee1eSDan Willemsen Reset(m.Interface()) 89*1c12ee1eSDan Willemsen } 90*1c12ee1eSDan Willemsen allowPartial := o.AllowPartial 91*1c12ee1eSDan Willemsen o.Merge = true 92*1c12ee1eSDan Willemsen o.AllowPartial = true 93*1c12ee1eSDan Willemsen methods := protoMethods(m) 94*1c12ee1eSDan Willemsen if methods != nil && methods.Unmarshal != nil && 95*1c12ee1eSDan Willemsen !(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) { 96*1c12ee1eSDan Willemsen in := protoiface.UnmarshalInput{ 97*1c12ee1eSDan Willemsen Message: m, 98*1c12ee1eSDan Willemsen Buf: b, 99*1c12ee1eSDan Willemsen Resolver: o.Resolver, 100*1c12ee1eSDan Willemsen Depth: o.RecursionLimit, 101*1c12ee1eSDan Willemsen } 102*1c12ee1eSDan Willemsen if o.DiscardUnknown { 103*1c12ee1eSDan Willemsen in.Flags |= protoiface.UnmarshalDiscardUnknown 104*1c12ee1eSDan Willemsen } 105*1c12ee1eSDan Willemsen out, err = methods.Unmarshal(in) 106*1c12ee1eSDan Willemsen } else { 107*1c12ee1eSDan Willemsen o.RecursionLimit-- 108*1c12ee1eSDan Willemsen if o.RecursionLimit < 0 { 109*1c12ee1eSDan Willemsen return out, errors.New("exceeded max recursion depth") 110*1c12ee1eSDan Willemsen } 111*1c12ee1eSDan Willemsen err = o.unmarshalMessageSlow(b, m) 112*1c12ee1eSDan Willemsen } 113*1c12ee1eSDan Willemsen if err != nil { 114*1c12ee1eSDan Willemsen return out, err 115*1c12ee1eSDan Willemsen } 116*1c12ee1eSDan Willemsen if allowPartial || (out.Flags&protoiface.UnmarshalInitialized != 0) { 117*1c12ee1eSDan Willemsen return out, nil 118*1c12ee1eSDan Willemsen } 119*1c12ee1eSDan Willemsen return out, checkInitialized(m) 120*1c12ee1eSDan Willemsen} 121*1c12ee1eSDan Willemsen 122*1c12ee1eSDan Willemsenfunc (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error { 123*1c12ee1eSDan Willemsen _, err := o.unmarshal(b, m) 124*1c12ee1eSDan Willemsen return err 125*1c12ee1eSDan Willemsen} 126*1c12ee1eSDan Willemsen 127*1c12ee1eSDan Willemsenfunc (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error { 128*1c12ee1eSDan Willemsen md := m.Descriptor() 129*1c12ee1eSDan Willemsen if messageset.IsMessageSet(md) { 130*1c12ee1eSDan Willemsen return o.unmarshalMessageSet(b, m) 131*1c12ee1eSDan Willemsen } 132*1c12ee1eSDan Willemsen fields := md.Fields() 133*1c12ee1eSDan Willemsen for len(b) > 0 { 134*1c12ee1eSDan Willemsen // Parse the tag (field number and wire type). 135*1c12ee1eSDan Willemsen num, wtyp, tagLen := protowire.ConsumeTag(b) 136*1c12ee1eSDan Willemsen if tagLen < 0 { 137*1c12ee1eSDan Willemsen return errDecode 138*1c12ee1eSDan Willemsen } 139*1c12ee1eSDan Willemsen if num > protowire.MaxValidNumber { 140*1c12ee1eSDan Willemsen return errDecode 141*1c12ee1eSDan Willemsen } 142*1c12ee1eSDan Willemsen 143*1c12ee1eSDan Willemsen // Find the field descriptor for this field number. 144*1c12ee1eSDan Willemsen fd := fields.ByNumber(num) 145*1c12ee1eSDan Willemsen if fd == nil && md.ExtensionRanges().Has(num) { 146*1c12ee1eSDan Willemsen extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num) 147*1c12ee1eSDan Willemsen if err != nil && err != protoregistry.NotFound { 148*1c12ee1eSDan Willemsen return errors.New("%v: unable to resolve extension %v: %v", md.FullName(), num, err) 149*1c12ee1eSDan Willemsen } 150*1c12ee1eSDan Willemsen if extType != nil { 151*1c12ee1eSDan Willemsen fd = extType.TypeDescriptor() 152*1c12ee1eSDan Willemsen } 153*1c12ee1eSDan Willemsen } 154*1c12ee1eSDan Willemsen var err error 155*1c12ee1eSDan Willemsen if fd == nil { 156*1c12ee1eSDan Willemsen err = errUnknown 157*1c12ee1eSDan Willemsen } else if flags.ProtoLegacy { 158*1c12ee1eSDan Willemsen if fd.IsWeak() && fd.Message().IsPlaceholder() { 159*1c12ee1eSDan Willemsen err = errUnknown // weak referent is not linked in 160*1c12ee1eSDan Willemsen } 161*1c12ee1eSDan Willemsen } 162*1c12ee1eSDan Willemsen 163*1c12ee1eSDan Willemsen // Parse the field value. 164*1c12ee1eSDan Willemsen var valLen int 165*1c12ee1eSDan Willemsen switch { 166*1c12ee1eSDan Willemsen case err != nil: 167*1c12ee1eSDan Willemsen case fd.IsList(): 168*1c12ee1eSDan Willemsen valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd) 169*1c12ee1eSDan Willemsen case fd.IsMap(): 170*1c12ee1eSDan Willemsen valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd) 171*1c12ee1eSDan Willemsen default: 172*1c12ee1eSDan Willemsen valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd) 173*1c12ee1eSDan Willemsen } 174*1c12ee1eSDan Willemsen if err != nil { 175*1c12ee1eSDan Willemsen if err != errUnknown { 176*1c12ee1eSDan Willemsen return err 177*1c12ee1eSDan Willemsen } 178*1c12ee1eSDan Willemsen valLen = protowire.ConsumeFieldValue(num, wtyp, b[tagLen:]) 179*1c12ee1eSDan Willemsen if valLen < 0 { 180*1c12ee1eSDan Willemsen return errDecode 181*1c12ee1eSDan Willemsen } 182*1c12ee1eSDan Willemsen if !o.DiscardUnknown { 183*1c12ee1eSDan Willemsen m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...)) 184*1c12ee1eSDan Willemsen } 185*1c12ee1eSDan Willemsen } 186*1c12ee1eSDan Willemsen b = b[tagLen+valLen:] 187*1c12ee1eSDan Willemsen } 188*1c12ee1eSDan Willemsen return nil 189*1c12ee1eSDan Willemsen} 190*1c12ee1eSDan Willemsen 191*1c12ee1eSDan Willemsenfunc (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp protowire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) { 192*1c12ee1eSDan Willemsen v, n, err := o.unmarshalScalar(b, wtyp, fd) 193*1c12ee1eSDan Willemsen if err != nil { 194*1c12ee1eSDan Willemsen return 0, err 195*1c12ee1eSDan Willemsen } 196*1c12ee1eSDan Willemsen switch fd.Kind() { 197*1c12ee1eSDan Willemsen case protoreflect.GroupKind, protoreflect.MessageKind: 198*1c12ee1eSDan Willemsen m2 := m.Mutable(fd).Message() 199*1c12ee1eSDan Willemsen if err := o.unmarshalMessage(v.Bytes(), m2); err != nil { 200*1c12ee1eSDan Willemsen return n, err 201*1c12ee1eSDan Willemsen } 202*1c12ee1eSDan Willemsen default: 203*1c12ee1eSDan Willemsen // Non-message scalars replace the previous value. 204*1c12ee1eSDan Willemsen m.Set(fd, v) 205*1c12ee1eSDan Willemsen } 206*1c12ee1eSDan Willemsen return n, nil 207*1c12ee1eSDan Willemsen} 208*1c12ee1eSDan Willemsen 209*1c12ee1eSDan Willemsenfunc (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) { 210*1c12ee1eSDan Willemsen if wtyp != protowire.BytesType { 211*1c12ee1eSDan Willemsen return 0, errUnknown 212*1c12ee1eSDan Willemsen } 213*1c12ee1eSDan Willemsen b, n = protowire.ConsumeBytes(b) 214*1c12ee1eSDan Willemsen if n < 0 { 215*1c12ee1eSDan Willemsen return 0, errDecode 216*1c12ee1eSDan Willemsen } 217*1c12ee1eSDan Willemsen var ( 218*1c12ee1eSDan Willemsen keyField = fd.MapKey() 219*1c12ee1eSDan Willemsen valField = fd.MapValue() 220*1c12ee1eSDan Willemsen key protoreflect.Value 221*1c12ee1eSDan Willemsen val protoreflect.Value 222*1c12ee1eSDan Willemsen haveKey bool 223*1c12ee1eSDan Willemsen haveVal bool 224*1c12ee1eSDan Willemsen ) 225*1c12ee1eSDan Willemsen switch valField.Kind() { 226*1c12ee1eSDan Willemsen case protoreflect.GroupKind, protoreflect.MessageKind: 227*1c12ee1eSDan Willemsen val = mapv.NewValue() 228*1c12ee1eSDan Willemsen } 229*1c12ee1eSDan Willemsen // Map entries are represented as a two-element message with fields 230*1c12ee1eSDan Willemsen // containing the key and value. 231*1c12ee1eSDan Willemsen for len(b) > 0 { 232*1c12ee1eSDan Willemsen num, wtyp, n := protowire.ConsumeTag(b) 233*1c12ee1eSDan Willemsen if n < 0 { 234*1c12ee1eSDan Willemsen return 0, errDecode 235*1c12ee1eSDan Willemsen } 236*1c12ee1eSDan Willemsen if num > protowire.MaxValidNumber { 237*1c12ee1eSDan Willemsen return 0, errDecode 238*1c12ee1eSDan Willemsen } 239*1c12ee1eSDan Willemsen b = b[n:] 240*1c12ee1eSDan Willemsen err = errUnknown 241*1c12ee1eSDan Willemsen switch num { 242*1c12ee1eSDan Willemsen case genid.MapEntry_Key_field_number: 243*1c12ee1eSDan Willemsen key, n, err = o.unmarshalScalar(b, wtyp, keyField) 244*1c12ee1eSDan Willemsen if err != nil { 245*1c12ee1eSDan Willemsen break 246*1c12ee1eSDan Willemsen } 247*1c12ee1eSDan Willemsen haveKey = true 248*1c12ee1eSDan Willemsen case genid.MapEntry_Value_field_number: 249*1c12ee1eSDan Willemsen var v protoreflect.Value 250*1c12ee1eSDan Willemsen v, n, err = o.unmarshalScalar(b, wtyp, valField) 251*1c12ee1eSDan Willemsen if err != nil { 252*1c12ee1eSDan Willemsen break 253*1c12ee1eSDan Willemsen } 254*1c12ee1eSDan Willemsen switch valField.Kind() { 255*1c12ee1eSDan Willemsen case protoreflect.GroupKind, protoreflect.MessageKind: 256*1c12ee1eSDan Willemsen if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil { 257*1c12ee1eSDan Willemsen return 0, err 258*1c12ee1eSDan Willemsen } 259*1c12ee1eSDan Willemsen default: 260*1c12ee1eSDan Willemsen val = v 261*1c12ee1eSDan Willemsen } 262*1c12ee1eSDan Willemsen haveVal = true 263*1c12ee1eSDan Willemsen } 264*1c12ee1eSDan Willemsen if err == errUnknown { 265*1c12ee1eSDan Willemsen n = protowire.ConsumeFieldValue(num, wtyp, b) 266*1c12ee1eSDan Willemsen if n < 0 { 267*1c12ee1eSDan Willemsen return 0, errDecode 268*1c12ee1eSDan Willemsen } 269*1c12ee1eSDan Willemsen } else if err != nil { 270*1c12ee1eSDan Willemsen return 0, err 271*1c12ee1eSDan Willemsen } 272*1c12ee1eSDan Willemsen b = b[n:] 273*1c12ee1eSDan Willemsen } 274*1c12ee1eSDan Willemsen // Every map entry should have entries for key and value, but this is not strictly required. 275*1c12ee1eSDan Willemsen if !haveKey { 276*1c12ee1eSDan Willemsen key = keyField.Default() 277*1c12ee1eSDan Willemsen } 278*1c12ee1eSDan Willemsen if !haveVal { 279*1c12ee1eSDan Willemsen switch valField.Kind() { 280*1c12ee1eSDan Willemsen case protoreflect.GroupKind, protoreflect.MessageKind: 281*1c12ee1eSDan Willemsen default: 282*1c12ee1eSDan Willemsen val = valField.Default() 283*1c12ee1eSDan Willemsen } 284*1c12ee1eSDan Willemsen } 285*1c12ee1eSDan Willemsen mapv.Set(key.MapKey(), val) 286*1c12ee1eSDan Willemsen return n, nil 287*1c12ee1eSDan Willemsen} 288*1c12ee1eSDan Willemsen 289*1c12ee1eSDan Willemsen// errUnknown is used internally to indicate fields which should be added 290*1c12ee1eSDan Willemsen// to the unknown field set of a message. It is never returned from an exported 291*1c12ee1eSDan Willemsen// function. 292*1c12ee1eSDan Willemsenvar errUnknown = errors.New("BUG: internal error (unknown)") 293*1c12ee1eSDan Willemsen 294*1c12ee1eSDan Willemsenvar errDecode = errors.New("cannot parse invalid wire-format data") 295