xref: /aosp_15_r20/external/golang-protobuf/internal/impl/decode.go (revision 1c12ee1efe575feb122dbf939ff15148a3b3e8f2)
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4*1c12ee1eSDan Willemsen
5*1c12ee1eSDan Willemsenpackage impl
6*1c12ee1eSDan Willemsen
7*1c12ee1eSDan Willemsenimport (
8*1c12ee1eSDan Willemsen	"math/bits"
9*1c12ee1eSDan Willemsen
10*1c12ee1eSDan Willemsen	"google.golang.org/protobuf/encoding/protowire"
11*1c12ee1eSDan Willemsen	"google.golang.org/protobuf/internal/errors"
12*1c12ee1eSDan Willemsen	"google.golang.org/protobuf/internal/flags"
13*1c12ee1eSDan Willemsen	"google.golang.org/protobuf/proto"
14*1c12ee1eSDan Willemsen	"google.golang.org/protobuf/reflect/protoreflect"
15*1c12ee1eSDan Willemsen	"google.golang.org/protobuf/reflect/protoregistry"
16*1c12ee1eSDan Willemsen	"google.golang.org/protobuf/runtime/protoiface"
17*1c12ee1eSDan Willemsen)
18*1c12ee1eSDan Willemsen
var (
	// errDecode reports input that is not well-formed wire-format data
	// (bad tag, bad field number, truncated field, mismatched end-group).
	errDecode = errors.New("cannot parse invalid wire-format data")

	// errRecursionDepth reports that message nesting exceeded the remaining
	// recursion budget carried in unmarshalOptions.depth.
	errRecursionDepth = errors.New("exceeded maximum recursion depth")
)
21*1c12ee1eSDan Willemsen
22*1c12ee1eSDan Willemsentype unmarshalOptions struct {
23*1c12ee1eSDan Willemsen	flags    protoiface.UnmarshalInputFlags
24*1c12ee1eSDan Willemsen	resolver interface {
25*1c12ee1eSDan Willemsen		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
26*1c12ee1eSDan Willemsen		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
27*1c12ee1eSDan Willemsen	}
28*1c12ee1eSDan Willemsen	depth int
29*1c12ee1eSDan Willemsen}
30*1c12ee1eSDan Willemsen
31*1c12ee1eSDan Willemsenfunc (o unmarshalOptions) Options() proto.UnmarshalOptions {
32*1c12ee1eSDan Willemsen	return proto.UnmarshalOptions{
33*1c12ee1eSDan Willemsen		Merge:          true,
34*1c12ee1eSDan Willemsen		AllowPartial:   true,
35*1c12ee1eSDan Willemsen		DiscardUnknown: o.DiscardUnknown(),
36*1c12ee1eSDan Willemsen		Resolver:       o.resolver,
37*1c12ee1eSDan Willemsen	}
38*1c12ee1eSDan Willemsen}
39*1c12ee1eSDan Willemsen
40*1c12ee1eSDan Willemsenfunc (o unmarshalOptions) DiscardUnknown() bool {
41*1c12ee1eSDan Willemsen	return o.flags&protoiface.UnmarshalDiscardUnknown != 0
42*1c12ee1eSDan Willemsen}
43*1c12ee1eSDan Willemsen
44*1c12ee1eSDan Willemsenfunc (o unmarshalOptions) IsDefault() bool {
45*1c12ee1eSDan Willemsen	return o.flags == 0 && o.resolver == protoregistry.GlobalTypes
46*1c12ee1eSDan Willemsen}
47*1c12ee1eSDan Willemsen
48*1c12ee1eSDan Willemsenvar lazyUnmarshalOptions = unmarshalOptions{
49*1c12ee1eSDan Willemsen	resolver: protoregistry.GlobalTypes,
50*1c12ee1eSDan Willemsen	depth:    protowire.DefaultRecursionLimit,
51*1c12ee1eSDan Willemsen}
52*1c12ee1eSDan Willemsen
// unmarshalOutput is the internal result of the fast-path decode functions.
type unmarshalOutput struct {
	// n is the number of input bytes consumed.
	n int

	// initialized is set when the decoder saw every required field and no
	// decoded sub-message or extension reported itself uninitialized.
	initialized bool
}
57*1c12ee1eSDan Willemsen
58*1c12ee1eSDan Willemsen// unmarshal is protoreflect.Methods.Unmarshal.
59*1c12ee1eSDan Willemsenfunc (mi *MessageInfo) unmarshal(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
60*1c12ee1eSDan Willemsen	var p pointer
61*1c12ee1eSDan Willemsen	if ms, ok := in.Message.(*messageState); ok {
62*1c12ee1eSDan Willemsen		p = ms.pointer()
63*1c12ee1eSDan Willemsen	} else {
64*1c12ee1eSDan Willemsen		p = in.Message.(*messageReflectWrapper).pointer()
65*1c12ee1eSDan Willemsen	}
66*1c12ee1eSDan Willemsen	out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
67*1c12ee1eSDan Willemsen		flags:    in.Flags,
68*1c12ee1eSDan Willemsen		resolver: in.Resolver,
69*1c12ee1eSDan Willemsen		depth:    in.Depth,
70*1c12ee1eSDan Willemsen	})
71*1c12ee1eSDan Willemsen	var flags protoiface.UnmarshalOutputFlags
72*1c12ee1eSDan Willemsen	if out.initialized {
73*1c12ee1eSDan Willemsen		flags |= protoiface.UnmarshalInitialized
74*1c12ee1eSDan Willemsen	}
75*1c12ee1eSDan Willemsen	return protoiface.UnmarshalOutput{
76*1c12ee1eSDan Willemsen		Flags: flags,
77*1c12ee1eSDan Willemsen	}, err
78*1c12ee1eSDan Willemsen}
79*1c12ee1eSDan Willemsen
// errUnknown is an internal signal from a field or extension decoder. It
// means "I could not interpret this field", for example because the wire type
// does not match. The caller then stores the raw field in the unknown-fields
// section instead of aborting the whole unmarshal. Truly malformed input, such
// as a field running past the end of the buffer, uses a real error instead.
//
// This sentinel must never escape to users.
var errUnknown = errors.New("unknown")
87*1c12ee1eSDan Willemsen
88*1c12ee1eSDan Willemsenfunc (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
89*1c12ee1eSDan Willemsen	mi.init()
90*1c12ee1eSDan Willemsen	opts.depth--
91*1c12ee1eSDan Willemsen	if opts.depth < 0 {
92*1c12ee1eSDan Willemsen		return out, errRecursionDepth
93*1c12ee1eSDan Willemsen	}
94*1c12ee1eSDan Willemsen	if flags.ProtoLegacy && mi.isMessageSet {
95*1c12ee1eSDan Willemsen		return unmarshalMessageSet(mi, b, p, opts)
96*1c12ee1eSDan Willemsen	}
97*1c12ee1eSDan Willemsen	initialized := true
98*1c12ee1eSDan Willemsen	var requiredMask uint64
99*1c12ee1eSDan Willemsen	var exts *map[int32]ExtensionField
100*1c12ee1eSDan Willemsen	start := len(b)
101*1c12ee1eSDan Willemsen	for len(b) > 0 {
102*1c12ee1eSDan Willemsen		// Parse the tag (field number and wire type).
103*1c12ee1eSDan Willemsen		var tag uint64
104*1c12ee1eSDan Willemsen		if b[0] < 0x80 {
105*1c12ee1eSDan Willemsen			tag = uint64(b[0])
106*1c12ee1eSDan Willemsen			b = b[1:]
107*1c12ee1eSDan Willemsen		} else if len(b) >= 2 && b[1] < 128 {
108*1c12ee1eSDan Willemsen			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
109*1c12ee1eSDan Willemsen			b = b[2:]
110*1c12ee1eSDan Willemsen		} else {
111*1c12ee1eSDan Willemsen			var n int
112*1c12ee1eSDan Willemsen			tag, n = protowire.ConsumeVarint(b)
113*1c12ee1eSDan Willemsen			if n < 0 {
114*1c12ee1eSDan Willemsen				return out, errDecode
115*1c12ee1eSDan Willemsen			}
116*1c12ee1eSDan Willemsen			b = b[n:]
117*1c12ee1eSDan Willemsen		}
118*1c12ee1eSDan Willemsen		var num protowire.Number
119*1c12ee1eSDan Willemsen		if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
120*1c12ee1eSDan Willemsen			return out, errDecode
121*1c12ee1eSDan Willemsen		} else {
122*1c12ee1eSDan Willemsen			num = protowire.Number(n)
123*1c12ee1eSDan Willemsen		}
124*1c12ee1eSDan Willemsen		wtyp := protowire.Type(tag & 7)
125*1c12ee1eSDan Willemsen
126*1c12ee1eSDan Willemsen		if wtyp == protowire.EndGroupType {
127*1c12ee1eSDan Willemsen			if num != groupTag {
128*1c12ee1eSDan Willemsen				return out, errDecode
129*1c12ee1eSDan Willemsen			}
130*1c12ee1eSDan Willemsen			groupTag = 0
131*1c12ee1eSDan Willemsen			break
132*1c12ee1eSDan Willemsen		}
133*1c12ee1eSDan Willemsen
134*1c12ee1eSDan Willemsen		var f *coderFieldInfo
135*1c12ee1eSDan Willemsen		if int(num) < len(mi.denseCoderFields) {
136*1c12ee1eSDan Willemsen			f = mi.denseCoderFields[num]
137*1c12ee1eSDan Willemsen		} else {
138*1c12ee1eSDan Willemsen			f = mi.coderFields[num]
139*1c12ee1eSDan Willemsen		}
140*1c12ee1eSDan Willemsen		var n int
141*1c12ee1eSDan Willemsen		err := errUnknown
142*1c12ee1eSDan Willemsen		switch {
143*1c12ee1eSDan Willemsen		case f != nil:
144*1c12ee1eSDan Willemsen			if f.funcs.unmarshal == nil {
145*1c12ee1eSDan Willemsen				break
146*1c12ee1eSDan Willemsen			}
147*1c12ee1eSDan Willemsen			var o unmarshalOutput
148*1c12ee1eSDan Willemsen			o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
149*1c12ee1eSDan Willemsen			n = o.n
150*1c12ee1eSDan Willemsen			if err != nil {
151*1c12ee1eSDan Willemsen				break
152*1c12ee1eSDan Willemsen			}
153*1c12ee1eSDan Willemsen			requiredMask |= f.validation.requiredBit
154*1c12ee1eSDan Willemsen			if f.funcs.isInit != nil && !o.initialized {
155*1c12ee1eSDan Willemsen				initialized = false
156*1c12ee1eSDan Willemsen			}
157*1c12ee1eSDan Willemsen		default:
158*1c12ee1eSDan Willemsen			// Possible extension.
159*1c12ee1eSDan Willemsen			if exts == nil && mi.extensionOffset.IsValid() {
160*1c12ee1eSDan Willemsen				exts = p.Apply(mi.extensionOffset).Extensions()
161*1c12ee1eSDan Willemsen				if *exts == nil {
162*1c12ee1eSDan Willemsen					*exts = make(map[int32]ExtensionField)
163*1c12ee1eSDan Willemsen				}
164*1c12ee1eSDan Willemsen			}
165*1c12ee1eSDan Willemsen			if exts == nil {
166*1c12ee1eSDan Willemsen				break
167*1c12ee1eSDan Willemsen			}
168*1c12ee1eSDan Willemsen			var o unmarshalOutput
169*1c12ee1eSDan Willemsen			o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
170*1c12ee1eSDan Willemsen			if err != nil {
171*1c12ee1eSDan Willemsen				break
172*1c12ee1eSDan Willemsen			}
173*1c12ee1eSDan Willemsen			n = o.n
174*1c12ee1eSDan Willemsen			if !o.initialized {
175*1c12ee1eSDan Willemsen				initialized = false
176*1c12ee1eSDan Willemsen			}
177*1c12ee1eSDan Willemsen		}
178*1c12ee1eSDan Willemsen		if err != nil {
179*1c12ee1eSDan Willemsen			if err != errUnknown {
180*1c12ee1eSDan Willemsen				return out, err
181*1c12ee1eSDan Willemsen			}
182*1c12ee1eSDan Willemsen			n = protowire.ConsumeFieldValue(num, wtyp, b)
183*1c12ee1eSDan Willemsen			if n < 0 {
184*1c12ee1eSDan Willemsen				return out, errDecode
185*1c12ee1eSDan Willemsen			}
186*1c12ee1eSDan Willemsen			if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
187*1c12ee1eSDan Willemsen				u := mi.mutableUnknownBytes(p)
188*1c12ee1eSDan Willemsen				*u = protowire.AppendTag(*u, num, wtyp)
189*1c12ee1eSDan Willemsen				*u = append(*u, b[:n]...)
190*1c12ee1eSDan Willemsen			}
191*1c12ee1eSDan Willemsen		}
192*1c12ee1eSDan Willemsen		b = b[n:]
193*1c12ee1eSDan Willemsen	}
194*1c12ee1eSDan Willemsen	if groupTag != 0 {
195*1c12ee1eSDan Willemsen		return out, errDecode
196*1c12ee1eSDan Willemsen	}
197*1c12ee1eSDan Willemsen	if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
198*1c12ee1eSDan Willemsen		initialized = false
199*1c12ee1eSDan Willemsen	}
200*1c12ee1eSDan Willemsen	if initialized {
201*1c12ee1eSDan Willemsen		out.initialized = true
202*1c12ee1eSDan Willemsen	}
203*1c12ee1eSDan Willemsen	out.n = start - len(b)
204*1c12ee1eSDan Willemsen	return out, nil
205*1c12ee1eSDan Willemsen}
206*1c12ee1eSDan Willemsen
207*1c12ee1eSDan Willemsenfunc (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp protowire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) {
208*1c12ee1eSDan Willemsen	x := exts[int32(num)]
209*1c12ee1eSDan Willemsen	xt := x.Type()
210*1c12ee1eSDan Willemsen	if xt == nil {
211*1c12ee1eSDan Willemsen		var err error
212*1c12ee1eSDan Willemsen		xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num)
213*1c12ee1eSDan Willemsen		if err != nil {
214*1c12ee1eSDan Willemsen			if err == protoregistry.NotFound {
215*1c12ee1eSDan Willemsen				return out, errUnknown
216*1c12ee1eSDan Willemsen			}
217*1c12ee1eSDan Willemsen			return out, errors.New("%v: unable to resolve extension %v: %v", mi.Desc.FullName(), num, err)
218*1c12ee1eSDan Willemsen		}
219*1c12ee1eSDan Willemsen	}
220*1c12ee1eSDan Willemsen	xi := getExtensionFieldInfo(xt)
221*1c12ee1eSDan Willemsen	if xi.funcs.unmarshal == nil {
222*1c12ee1eSDan Willemsen		return out, errUnknown
223*1c12ee1eSDan Willemsen	}
224*1c12ee1eSDan Willemsen	if flags.LazyUnmarshalExtensions {
225*1c12ee1eSDan Willemsen		if opts.IsDefault() && x.canLazy(xt) {
226*1c12ee1eSDan Willemsen			out, valid := skipExtension(b, xi, num, wtyp, opts)
227*1c12ee1eSDan Willemsen			switch valid {
228*1c12ee1eSDan Willemsen			case ValidationValid:
229*1c12ee1eSDan Willemsen				if out.initialized {
230*1c12ee1eSDan Willemsen					x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
231*1c12ee1eSDan Willemsen					exts[int32(num)] = x
232*1c12ee1eSDan Willemsen					return out, nil
233*1c12ee1eSDan Willemsen				}
234*1c12ee1eSDan Willemsen			case ValidationInvalid:
235*1c12ee1eSDan Willemsen				return out, errDecode
236*1c12ee1eSDan Willemsen			case ValidationUnknown:
237*1c12ee1eSDan Willemsen			}
238*1c12ee1eSDan Willemsen		}
239*1c12ee1eSDan Willemsen	}
240*1c12ee1eSDan Willemsen	ival := x.Value()
241*1c12ee1eSDan Willemsen	if !ival.IsValid() && xi.unmarshalNeedsValue {
242*1c12ee1eSDan Willemsen		// Create a new message, list, or map value to fill in.
243*1c12ee1eSDan Willemsen		// For enums, create a prototype value to let the unmarshal func know the
244*1c12ee1eSDan Willemsen		// concrete type.
245*1c12ee1eSDan Willemsen		ival = xt.New()
246*1c12ee1eSDan Willemsen	}
247*1c12ee1eSDan Willemsen	v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
248*1c12ee1eSDan Willemsen	if err != nil {
249*1c12ee1eSDan Willemsen		return out, err
250*1c12ee1eSDan Willemsen	}
251*1c12ee1eSDan Willemsen	if xi.funcs.isInit == nil {
252*1c12ee1eSDan Willemsen		out.initialized = true
253*1c12ee1eSDan Willemsen	}
254*1c12ee1eSDan Willemsen	x.Set(xt, v)
255*1c12ee1eSDan Willemsen	exts[int32(num)] = x
256*1c12ee1eSDan Willemsen	return out, nil
257*1c12ee1eSDan Willemsen}
258*1c12ee1eSDan Willemsen
259*1c12ee1eSDan Willemsenfunc skipExtension(b []byte, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
260*1c12ee1eSDan Willemsen	if xi.validation.mi == nil {
261*1c12ee1eSDan Willemsen		return out, ValidationUnknown
262*1c12ee1eSDan Willemsen	}
263*1c12ee1eSDan Willemsen	xi.validation.mi.init()
264*1c12ee1eSDan Willemsen	switch xi.validation.typ {
265*1c12ee1eSDan Willemsen	case validationTypeMessage:
266*1c12ee1eSDan Willemsen		if wtyp != protowire.BytesType {
267*1c12ee1eSDan Willemsen			return out, ValidationUnknown
268*1c12ee1eSDan Willemsen		}
269*1c12ee1eSDan Willemsen		v, n := protowire.ConsumeBytes(b)
270*1c12ee1eSDan Willemsen		if n < 0 {
271*1c12ee1eSDan Willemsen			return out, ValidationUnknown
272*1c12ee1eSDan Willemsen		}
273*1c12ee1eSDan Willemsen		out, st := xi.validation.mi.validate(v, 0, opts)
274*1c12ee1eSDan Willemsen		out.n = n
275*1c12ee1eSDan Willemsen		return out, st
276*1c12ee1eSDan Willemsen	case validationTypeGroup:
277*1c12ee1eSDan Willemsen		if wtyp != protowire.StartGroupType {
278*1c12ee1eSDan Willemsen			return out, ValidationUnknown
279*1c12ee1eSDan Willemsen		}
280*1c12ee1eSDan Willemsen		out, st := xi.validation.mi.validate(b, num, opts)
281*1c12ee1eSDan Willemsen		return out, st
282*1c12ee1eSDan Willemsen	default:
283*1c12ee1eSDan Willemsen		return out, ValidationUnknown
284*1c12ee1eSDan Willemsen	}
285*1c12ee1eSDan Willemsen}
286