1*1c12ee1eSDan Willemsen// Copyright 2019 The Go Authors. All rights reserved. 2*1c12ee1eSDan Willemsen// Use of this source code is governed by a BSD-style 3*1c12ee1eSDan Willemsen// license that can be found in the LICENSE file. 4*1c12ee1eSDan Willemsen 5*1c12ee1eSDan Willemsenpackage impl 6*1c12ee1eSDan Willemsen 7*1c12ee1eSDan Willemsenimport ( 8*1c12ee1eSDan Willemsen "math/bits" 9*1c12ee1eSDan Willemsen 10*1c12ee1eSDan Willemsen "google.golang.org/protobuf/encoding/protowire" 11*1c12ee1eSDan Willemsen "google.golang.org/protobuf/internal/errors" 12*1c12ee1eSDan Willemsen "google.golang.org/protobuf/internal/flags" 13*1c12ee1eSDan Willemsen "google.golang.org/protobuf/proto" 14*1c12ee1eSDan Willemsen "google.golang.org/protobuf/reflect/protoreflect" 15*1c12ee1eSDan Willemsen "google.golang.org/protobuf/reflect/protoregistry" 16*1c12ee1eSDan Willemsen "google.golang.org/protobuf/runtime/protoiface" 17*1c12ee1eSDan Willemsen) 18*1c12ee1eSDan Willemsen 19*1c12ee1eSDan Willemsenvar errDecode = errors.New("cannot parse invalid wire-format data") 20*1c12ee1eSDan Willemsenvar errRecursionDepth = errors.New("exceeded maximum recursion depth") 21*1c12ee1eSDan Willemsen 22*1c12ee1eSDan Willemsentype unmarshalOptions struct { 23*1c12ee1eSDan Willemsen flags protoiface.UnmarshalInputFlags 24*1c12ee1eSDan Willemsen resolver interface { 25*1c12ee1eSDan Willemsen FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) 26*1c12ee1eSDan Willemsen FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) 27*1c12ee1eSDan Willemsen } 28*1c12ee1eSDan Willemsen depth int 29*1c12ee1eSDan Willemsen} 30*1c12ee1eSDan Willemsen 31*1c12ee1eSDan Willemsenfunc (o unmarshalOptions) Options() proto.UnmarshalOptions { 32*1c12ee1eSDan Willemsen return proto.UnmarshalOptions{ 33*1c12ee1eSDan Willemsen Merge: true, 34*1c12ee1eSDan Willemsen AllowPartial: true, 35*1c12ee1eSDan Willemsen DiscardUnknown: o.DiscardUnknown(), 36*1c12ee1eSDan Willemsen Resolver: o.resolver, 37*1c12ee1eSDan Willemsen } 38*1c12ee1eSDan Willemsen} 39*1c12ee1eSDan Willemsen 40*1c12ee1eSDan Willemsenfunc (o unmarshalOptions) DiscardUnknown() bool { 41*1c12ee1eSDan Willemsen return o.flags&protoiface.UnmarshalDiscardUnknown != 0 42*1c12ee1eSDan Willemsen} 43*1c12ee1eSDan Willemsen 44*1c12ee1eSDan Willemsenfunc (o unmarshalOptions) IsDefault() bool { 45*1c12ee1eSDan Willemsen return o.flags == 0 && o.resolver == protoregistry.GlobalTypes 46*1c12ee1eSDan Willemsen} 47*1c12ee1eSDan Willemsen 48*1c12ee1eSDan Willemsenvar lazyUnmarshalOptions = unmarshalOptions{ 49*1c12ee1eSDan Willemsen resolver: protoregistry.GlobalTypes, 50*1c12ee1eSDan Willemsen depth: protowire.DefaultRecursionLimit, 51*1c12ee1eSDan Willemsen} 52*1c12ee1eSDan Willemsen 53*1c12ee1eSDan Willemsentype unmarshalOutput struct { 54*1c12ee1eSDan Willemsen n int // number of bytes consumed 55*1c12ee1eSDan Willemsen initialized bool 56*1c12ee1eSDan Willemsen} 57*1c12ee1eSDan Willemsen 58*1c12ee1eSDan Willemsen// unmarshal is protoreflect.Methods.Unmarshal. 59*1c12ee1eSDan Willemsenfunc (mi *MessageInfo) unmarshal(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) { 60*1c12ee1eSDan Willemsen var p pointer 61*1c12ee1eSDan Willemsen if ms, ok := in.Message.(*messageState); ok { 62*1c12ee1eSDan Willemsen p = ms.pointer() 63*1c12ee1eSDan Willemsen } else { 64*1c12ee1eSDan Willemsen p = in.Message.(*messageReflectWrapper).pointer() 65*1c12ee1eSDan Willemsen } 66*1c12ee1eSDan Willemsen out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{ 67*1c12ee1eSDan Willemsen flags: in.Flags, 68*1c12ee1eSDan Willemsen resolver: in.Resolver, 69*1c12ee1eSDan Willemsen depth: in.Depth, 70*1c12ee1eSDan Willemsen }) 71*1c12ee1eSDan Willemsen var flags protoiface.UnmarshalOutputFlags 72*1c12ee1eSDan Willemsen if out.initialized { 73*1c12ee1eSDan Willemsen flags |= protoiface.UnmarshalInitialized 74*1c12ee1eSDan Willemsen } 75*1c12ee1eSDan Willemsen return protoiface.UnmarshalOutput{ 76*1c12ee1eSDan Willemsen Flags: flags, 77*1c12ee1eSDan Willemsen }, err 78*1c12ee1eSDan Willemsen} 79*1c12ee1eSDan Willemsen 80*1c12ee1eSDan Willemsen// errUnknown is returned during unmarshaling to indicate a parse error that 81*1c12ee1eSDan Willemsen// should result in a field being placed in the unknown fields section (for example, 82*1c12ee1eSDan Willemsen// when the wire type doesn't match) as opposed to the entire unmarshal operation 83*1c12ee1eSDan Willemsen// failing (for example, when a field extends past the available input). 84*1c12ee1eSDan Willemsen// 85*1c12ee1eSDan Willemsen// This is a sentinel error which should never be visible to the user. 86*1c12ee1eSDan Willemsenvar errUnknown = errors.New("unknown") 87*1c12ee1eSDan Willemsen 88*1c12ee1eSDan Willemsenfunc (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) { 89*1c12ee1eSDan Willemsen mi.init() 90*1c12ee1eSDan Willemsen opts.depth-- 91*1c12ee1eSDan Willemsen if opts.depth < 0 { 92*1c12ee1eSDan Willemsen return out, errRecursionDepth 93*1c12ee1eSDan Willemsen } 94*1c12ee1eSDan Willemsen if flags.ProtoLegacy && mi.isMessageSet { 95*1c12ee1eSDan Willemsen return unmarshalMessageSet(mi, b, p, opts) 96*1c12ee1eSDan Willemsen } 97*1c12ee1eSDan Willemsen initialized := true 98*1c12ee1eSDan Willemsen var requiredMask uint64 99*1c12ee1eSDan Willemsen var exts *map[int32]ExtensionField 100*1c12ee1eSDan Willemsen start := len(b) 101*1c12ee1eSDan Willemsen for len(b) > 0 { 102*1c12ee1eSDan Willemsen // Parse the tag (field number and wire type). 103*1c12ee1eSDan Willemsen var tag uint64 104*1c12ee1eSDan Willemsen if b[0] < 0x80 { 105*1c12ee1eSDan Willemsen tag = uint64(b[0]) 106*1c12ee1eSDan Willemsen b = b[1:] 107*1c12ee1eSDan Willemsen } else if len(b) >= 2 && b[1] < 128 { 108*1c12ee1eSDan Willemsen tag = uint64(b[0]&0x7f) + uint64(b[1])<<7 109*1c12ee1eSDan Willemsen b = b[2:] 110*1c12ee1eSDan Willemsen } else { 111*1c12ee1eSDan Willemsen var n int 112*1c12ee1eSDan Willemsen tag, n = protowire.ConsumeVarint(b) 113*1c12ee1eSDan Willemsen if n < 0 { 114*1c12ee1eSDan Willemsen return out, errDecode 115*1c12ee1eSDan Willemsen } 116*1c12ee1eSDan Willemsen b = b[n:] 117*1c12ee1eSDan Willemsen } 118*1c12ee1eSDan Willemsen var num protowire.Number 119*1c12ee1eSDan Willemsen if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) { 120*1c12ee1eSDan Willemsen return out, errDecode 121*1c12ee1eSDan Willemsen } else { 122*1c12ee1eSDan Willemsen num = protowire.Number(n) 123*1c12ee1eSDan Willemsen } 124*1c12ee1eSDan Willemsen wtyp := protowire.Type(tag & 7) 125*1c12ee1eSDan Willemsen 126*1c12ee1eSDan Willemsen if wtyp == protowire.EndGroupType { 127*1c12ee1eSDan Willemsen if num != groupTag { 128*1c12ee1eSDan Willemsen return out, errDecode 129*1c12ee1eSDan Willemsen } 130*1c12ee1eSDan Willemsen groupTag = 0 131*1c12ee1eSDan Willemsen break 132*1c12ee1eSDan Willemsen } 133*1c12ee1eSDan Willemsen 134*1c12ee1eSDan Willemsen var f *coderFieldInfo 135*1c12ee1eSDan Willemsen if int(num) < len(mi.denseCoderFields) { 136*1c12ee1eSDan Willemsen f = mi.denseCoderFields[num] 137*1c12ee1eSDan Willemsen } else { 138*1c12ee1eSDan Willemsen f = mi.coderFields[num] 139*1c12ee1eSDan Willemsen } 140*1c12ee1eSDan Willemsen var n int 141*1c12ee1eSDan Willemsen err := errUnknown 142*1c12ee1eSDan Willemsen switch { 143*1c12ee1eSDan Willemsen case f != nil: 144*1c12ee1eSDan Willemsen if f.funcs.unmarshal == nil { 145*1c12ee1eSDan Willemsen break 146*1c12ee1eSDan Willemsen } 147*1c12ee1eSDan Willemsen var o unmarshalOutput 148*1c12ee1eSDan Willemsen o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts) 149*1c12ee1eSDan Willemsen n = o.n 150*1c12ee1eSDan Willemsen if err != nil { 151*1c12ee1eSDan Willemsen break 152*1c12ee1eSDan Willemsen } 153*1c12ee1eSDan Willemsen requiredMask |= f.validation.requiredBit 154*1c12ee1eSDan Willemsen if f.funcs.isInit != nil && !o.initialized { 155*1c12ee1eSDan Willemsen initialized = false 156*1c12ee1eSDan Willemsen } 157*1c12ee1eSDan Willemsen default: 158*1c12ee1eSDan Willemsen // Possible extension. 159*1c12ee1eSDan Willemsen if exts == nil && mi.extensionOffset.IsValid() { 160*1c12ee1eSDan Willemsen exts = p.Apply(mi.extensionOffset).Extensions() 161*1c12ee1eSDan Willemsen if *exts == nil { 162*1c12ee1eSDan Willemsen *exts = make(map[int32]ExtensionField) 163*1c12ee1eSDan Willemsen } 164*1c12ee1eSDan Willemsen } 165*1c12ee1eSDan Willemsen if exts == nil { 166*1c12ee1eSDan Willemsen break 167*1c12ee1eSDan Willemsen } 168*1c12ee1eSDan Willemsen var o unmarshalOutput 169*1c12ee1eSDan Willemsen o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts) 170*1c12ee1eSDan Willemsen if err != nil { 171*1c12ee1eSDan Willemsen break 172*1c12ee1eSDan Willemsen } 173*1c12ee1eSDan Willemsen n = o.n 174*1c12ee1eSDan Willemsen if !o.initialized { 175*1c12ee1eSDan Willemsen initialized = false 176*1c12ee1eSDan Willemsen } 177*1c12ee1eSDan Willemsen } 178*1c12ee1eSDan Willemsen if err != nil { 179*1c12ee1eSDan Willemsen if err != errUnknown { 180*1c12ee1eSDan Willemsen return out, err 181*1c12ee1eSDan Willemsen } 182*1c12ee1eSDan Willemsen n = protowire.ConsumeFieldValue(num, wtyp, b) 183*1c12ee1eSDan Willemsen if n < 0 { 184*1c12ee1eSDan Willemsen return out, errDecode 185*1c12ee1eSDan Willemsen } 186*1c12ee1eSDan Willemsen if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() { 187*1c12ee1eSDan Willemsen u := mi.mutableUnknownBytes(p) 188*1c12ee1eSDan Willemsen *u = protowire.AppendTag(*u, num, wtyp) 189*1c12ee1eSDan Willemsen *u = append(*u, b[:n]...) 190*1c12ee1eSDan Willemsen } 191*1c12ee1eSDan Willemsen } 192*1c12ee1eSDan Willemsen b = b[n:] 193*1c12ee1eSDan Willemsen } 194*1c12ee1eSDan Willemsen if groupTag != 0 { 195*1c12ee1eSDan Willemsen return out, errDecode 196*1c12ee1eSDan Willemsen } 197*1c12ee1eSDan Willemsen if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) { 198*1c12ee1eSDan Willemsen initialized = false 199*1c12ee1eSDan Willemsen } 200*1c12ee1eSDan Willemsen if initialized { 201*1c12ee1eSDan Willemsen out.initialized = true 202*1c12ee1eSDan Willemsen } 203*1c12ee1eSDan Willemsen out.n = start - len(b) 204*1c12ee1eSDan Willemsen return out, nil 205*1c12ee1eSDan Willemsen} 206*1c12ee1eSDan Willemsen 207*1c12ee1eSDan Willemsenfunc (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp protowire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) { 208*1c12ee1eSDan Willemsen x := exts[int32(num)] 209*1c12ee1eSDan Willemsen xt := x.Type() 210*1c12ee1eSDan Willemsen if xt == nil { 211*1c12ee1eSDan Willemsen var err error 212*1c12ee1eSDan Willemsen xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num) 213*1c12ee1eSDan Willemsen if err != nil { 214*1c12ee1eSDan Willemsen if err == protoregistry.NotFound { 215*1c12ee1eSDan Willemsen return out, errUnknown 216*1c12ee1eSDan Willemsen } 217*1c12ee1eSDan Willemsen return out, errors.New("%v: unable to resolve extension %v: %v", mi.Desc.FullName(), num, err) 218*1c12ee1eSDan Willemsen } 219*1c12ee1eSDan Willemsen } 220*1c12ee1eSDan Willemsen xi := getExtensionFieldInfo(xt) 221*1c12ee1eSDan Willemsen if xi.funcs.unmarshal == nil { 222*1c12ee1eSDan Willemsen return out, errUnknown 223*1c12ee1eSDan Willemsen } 224*1c12ee1eSDan Willemsen if flags.LazyUnmarshalExtensions { 225*1c12ee1eSDan Willemsen if opts.IsDefault() && x.canLazy(xt) { 226*1c12ee1eSDan Willemsen out, valid := skipExtension(b, xi, num, wtyp, opts) 227*1c12ee1eSDan Willemsen switch valid { 228*1c12ee1eSDan Willemsen case ValidationValid: 229*1c12ee1eSDan Willemsen if out.initialized { 230*1c12ee1eSDan Willemsen x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n]) 231*1c12ee1eSDan Willemsen exts[int32(num)] = x 232*1c12ee1eSDan Willemsen return out, nil 233*1c12ee1eSDan Willemsen } 234*1c12ee1eSDan Willemsen case ValidationInvalid: 235*1c12ee1eSDan Willemsen return out, errDecode 236*1c12ee1eSDan Willemsen case ValidationUnknown: 237*1c12ee1eSDan Willemsen } 238*1c12ee1eSDan Willemsen } 239*1c12ee1eSDan Willemsen } 240*1c12ee1eSDan Willemsen ival := x.Value() 241*1c12ee1eSDan Willemsen if !ival.IsValid() && xi.unmarshalNeedsValue { 242*1c12ee1eSDan Willemsen // Create a new message, list, or map value to fill in. 243*1c12ee1eSDan Willemsen // For enums, create a prototype value to let the unmarshal func know the 244*1c12ee1eSDan Willemsen // concrete type. 245*1c12ee1eSDan Willemsen ival = xt.New() 246*1c12ee1eSDan Willemsen } 247*1c12ee1eSDan Willemsen v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts) 248*1c12ee1eSDan Willemsen if err != nil { 249*1c12ee1eSDan Willemsen return out, err 250*1c12ee1eSDan Willemsen } 251*1c12ee1eSDan Willemsen if xi.funcs.isInit == nil { 252*1c12ee1eSDan Willemsen out.initialized = true 253*1c12ee1eSDan Willemsen } 254*1c12ee1eSDan Willemsen x.Set(xt, v) 255*1c12ee1eSDan Willemsen exts[int32(num)] = x 256*1c12ee1eSDan Willemsen return out, nil 257*1c12ee1eSDan Willemsen} 258*1c12ee1eSDan Willemsen 259*1c12ee1eSDan Willemsenfunc skipExtension(b []byte, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) { 260*1c12ee1eSDan Willemsen if xi.validation.mi == nil { 261*1c12ee1eSDan Willemsen return out, ValidationUnknown 262*1c12ee1eSDan Willemsen } 263*1c12ee1eSDan Willemsen xi.validation.mi.init() 264*1c12ee1eSDan Willemsen switch xi.validation.typ { 265*1c12ee1eSDan Willemsen case validationTypeMessage: 266*1c12ee1eSDan Willemsen if wtyp != protowire.BytesType { 267*1c12ee1eSDan Willemsen return out, ValidationUnknown 268*1c12ee1eSDan Willemsen } 269*1c12ee1eSDan Willemsen v, n := protowire.ConsumeBytes(b) 270*1c12ee1eSDan Willemsen if n < 0 { 271*1c12ee1eSDan Willemsen return out, ValidationUnknown 272*1c12ee1eSDan Willemsen } 273*1c12ee1eSDan Willemsen out, st := xi.validation.mi.validate(v, 0, opts) 274*1c12ee1eSDan Willemsen out.n = n 275*1c12ee1eSDan Willemsen return out, st 276*1c12ee1eSDan Willemsen case validationTypeGroup: 277*1c12ee1eSDan Willemsen if wtyp != protowire.StartGroupType { 278*1c12ee1eSDan Willemsen return out, ValidationUnknown 279*1c12ee1eSDan Willemsen } 280*1c12ee1eSDan Willemsen out, st := xi.validation.mi.validate(b, num, opts) 281*1c12ee1eSDan Willemsen return out, st 282*1c12ee1eSDan Willemsen default: 283*1c12ee1eSDan Willemsen return out, ValidationUnknown 284*1c12ee1eSDan Willemsen } 285*1c12ee1eSDan Willemsen} 286