xref: /aosp_15_r20/external/golang-protobuf/proto/decode.go (revision 1c12ee1efe575feb122dbf939ff15148a3b3e8f2)
1*1c12ee1eSDan Willemsen// Copyright 2018 The Go Authors. All rights reserved.
2*1c12ee1eSDan Willemsen// Use of this source code is governed by a BSD-style
3*1c12ee1eSDan Willemsen// license that can be found in the LICENSE file.
4*1c12ee1eSDan Willemsen
5*1c12ee1eSDan Willemsenpackage proto
6*1c12ee1eSDan Willemsen
7*1c12ee1eSDan Willemsenimport (
8*1c12ee1eSDan Willemsen	"google.golang.org/protobuf/encoding/protowire"
9*1c12ee1eSDan Willemsen	"google.golang.org/protobuf/internal/encoding/messageset"
10*1c12ee1eSDan Willemsen	"google.golang.org/protobuf/internal/errors"
11*1c12ee1eSDan Willemsen	"google.golang.org/protobuf/internal/flags"
12*1c12ee1eSDan Willemsen	"google.golang.org/protobuf/internal/genid"
13*1c12ee1eSDan Willemsen	"google.golang.org/protobuf/internal/pragma"
14*1c12ee1eSDan Willemsen	"google.golang.org/protobuf/reflect/protoreflect"
15*1c12ee1eSDan Willemsen	"google.golang.org/protobuf/reflect/protoregistry"
16*1c12ee1eSDan Willemsen	"google.golang.org/protobuf/runtime/protoiface"
17*1c12ee1eSDan Willemsen)
18*1c12ee1eSDan Willemsen
// UnmarshalOptions configures the unmarshaler.
//
// The zero value is usable: a nil Resolver falls back to
// protoregistry.GlobalTypes, and the exported Unmarshal methods replace a
// zero RecursionLimit with protowire.DefaultRecursionLimit.
//
// Example usage:
//
//	err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
type UnmarshalOptions struct {
	pragma.NoUnkeyedLiterals

	// Merge merges the input into the destination message.
	// The default behavior is to always reset the message before unmarshaling,
	// unless Merge is specified.
	Merge bool

	// AllowPartial accepts input for messages that will result in missing
	// required fields. If AllowPartial is false (the default), Unmarshal will
	// return an error if there are any missing required fields.
	AllowPartial bool

	// If DiscardUnknown is set, unknown fields are ignored: they are dropped
	// rather than appended to the message's unknown field set.
	DiscardUnknown bool

	// Resolver is used for looking up types when unmarshaling extension fields.
	// If nil, this defaults to using protoregistry.GlobalTypes.
	Resolver interface {
		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
	}

	// RecursionLimit limits how deeply messages may be nested.
	// If zero, a default limit (protowire.DefaultRecursionLimit) is applied.
	RecursionLimit int
}
51*1c12ee1eSDan Willemsen
52*1c12ee1eSDan Willemsen// Unmarshal parses the wire-format message in b and places the result in m.
53*1c12ee1eSDan Willemsen// The provided message must be mutable (e.g., a non-nil pointer to a message).
54*1c12ee1eSDan Willemsenfunc Unmarshal(b []byte, m Message) error {
55*1c12ee1eSDan Willemsen	_, err := UnmarshalOptions{RecursionLimit: protowire.DefaultRecursionLimit}.unmarshal(b, m.ProtoReflect())
56*1c12ee1eSDan Willemsen	return err
57*1c12ee1eSDan Willemsen}
58*1c12ee1eSDan Willemsen
59*1c12ee1eSDan Willemsen// Unmarshal parses the wire-format message in b and places the result in m.
60*1c12ee1eSDan Willemsen// The provided message must be mutable (e.g., a non-nil pointer to a message).
61*1c12ee1eSDan Willemsenfunc (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
62*1c12ee1eSDan Willemsen	if o.RecursionLimit == 0 {
63*1c12ee1eSDan Willemsen		o.RecursionLimit = protowire.DefaultRecursionLimit
64*1c12ee1eSDan Willemsen	}
65*1c12ee1eSDan Willemsen	_, err := o.unmarshal(b, m.ProtoReflect())
66*1c12ee1eSDan Willemsen	return err
67*1c12ee1eSDan Willemsen}
68*1c12ee1eSDan Willemsen
69*1c12ee1eSDan Willemsen// UnmarshalState parses a wire-format message and places the result in m.
70*1c12ee1eSDan Willemsen//
71*1c12ee1eSDan Willemsen// This method permits fine-grained control over the unmarshaler.
72*1c12ee1eSDan Willemsen// Most users should use Unmarshal instead.
73*1c12ee1eSDan Willemsenfunc (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
74*1c12ee1eSDan Willemsen	if o.RecursionLimit == 0 {
75*1c12ee1eSDan Willemsen		o.RecursionLimit = protowire.DefaultRecursionLimit
76*1c12ee1eSDan Willemsen	}
77*1c12ee1eSDan Willemsen	return o.unmarshal(in.Buf, in.Message)
78*1c12ee1eSDan Willemsen}
79*1c12ee1eSDan Willemsen
80*1c12ee1eSDan Willemsen// unmarshal is a centralized function that all unmarshal operations go through.
81*1c12ee1eSDan Willemsen// For profiling purposes, avoid changing the name of this function or
82*1c12ee1eSDan Willemsen// introducing other code paths for unmarshal that do not go through this.
83*1c12ee1eSDan Willemsenfunc (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out protoiface.UnmarshalOutput, err error) {
84*1c12ee1eSDan Willemsen	if o.Resolver == nil {
85*1c12ee1eSDan Willemsen		o.Resolver = protoregistry.GlobalTypes
86*1c12ee1eSDan Willemsen	}
87*1c12ee1eSDan Willemsen	if !o.Merge {
88*1c12ee1eSDan Willemsen		Reset(m.Interface())
89*1c12ee1eSDan Willemsen	}
90*1c12ee1eSDan Willemsen	allowPartial := o.AllowPartial
91*1c12ee1eSDan Willemsen	o.Merge = true
92*1c12ee1eSDan Willemsen	o.AllowPartial = true
93*1c12ee1eSDan Willemsen	methods := protoMethods(m)
94*1c12ee1eSDan Willemsen	if methods != nil && methods.Unmarshal != nil &&
95*1c12ee1eSDan Willemsen		!(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
96*1c12ee1eSDan Willemsen		in := protoiface.UnmarshalInput{
97*1c12ee1eSDan Willemsen			Message:  m,
98*1c12ee1eSDan Willemsen			Buf:      b,
99*1c12ee1eSDan Willemsen			Resolver: o.Resolver,
100*1c12ee1eSDan Willemsen			Depth:    o.RecursionLimit,
101*1c12ee1eSDan Willemsen		}
102*1c12ee1eSDan Willemsen		if o.DiscardUnknown {
103*1c12ee1eSDan Willemsen			in.Flags |= protoiface.UnmarshalDiscardUnknown
104*1c12ee1eSDan Willemsen		}
105*1c12ee1eSDan Willemsen		out, err = methods.Unmarshal(in)
106*1c12ee1eSDan Willemsen	} else {
107*1c12ee1eSDan Willemsen		o.RecursionLimit--
108*1c12ee1eSDan Willemsen		if o.RecursionLimit < 0 {
109*1c12ee1eSDan Willemsen			return out, errors.New("exceeded max recursion depth")
110*1c12ee1eSDan Willemsen		}
111*1c12ee1eSDan Willemsen		err = o.unmarshalMessageSlow(b, m)
112*1c12ee1eSDan Willemsen	}
113*1c12ee1eSDan Willemsen	if err != nil {
114*1c12ee1eSDan Willemsen		return out, err
115*1c12ee1eSDan Willemsen	}
116*1c12ee1eSDan Willemsen	if allowPartial || (out.Flags&protoiface.UnmarshalInitialized != 0) {
117*1c12ee1eSDan Willemsen		return out, nil
118*1c12ee1eSDan Willemsen	}
119*1c12ee1eSDan Willemsen	return out, checkInitialized(m)
120*1c12ee1eSDan Willemsen}
121*1c12ee1eSDan Willemsen
122*1c12ee1eSDan Willemsenfunc (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
123*1c12ee1eSDan Willemsen	_, err := o.unmarshal(b, m)
124*1c12ee1eSDan Willemsen	return err
125*1c12ee1eSDan Willemsen}
126*1c12ee1eSDan Willemsen
// unmarshalMessageSlow decodes b into m one field at a time through the
// protoreflect API. unmarshal uses it when m has no usable Unmarshal method
// (none at all, or one that cannot honor DiscardUnknown).
//
// MessageSet messages are handed off to unmarshalMessageSet. Otherwise, for
// each field in b:
//   - a malformed tag, or a field number above protowire.MaxValidNumber,
//     fails with errDecode;
//   - the number is matched against m's declared fields and then, if it lies
//     in an extension range, against extensions known to o.Resolver;
//   - matched fields are decoded into m as a list, map, or singular value;
//   - fields with no descriptor, weak fields whose message is a placeholder
//     (legacy mode only), and fields whose decoder reports errUnknown (for
//     example unmarshalMap given a non-bytes wire type) are skipped. Unless
//     o.DiscardUnknown is set, their raw tag+value bytes are appended to m's
//     unknown fields.
//
// Any other decoder error is returned unchanged.
func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
	md := m.Descriptor()
	if messageset.IsMessageSet(md) {
		return o.unmarshalMessageSet(b, m)
	}
	fields := md.Fields()
	for len(b) > 0 {
		// Parse the tag (field number and wire type).
		num, wtyp, tagLen := protowire.ConsumeTag(b)
		if tagLen < 0 {
			return errDecode
		}
		// Reject field numbers beyond the valid protobuf range.
		if num > protowire.MaxValidNumber {
			return errDecode
		}

		// Find the field descriptor for this field number.
		fd := fields.ByNumber(num)
		if fd == nil && md.ExtensionRanges().Has(num) {
			// Not a declared field, but possibly an extension. NotFound is not
			// an error: fd stays nil and the field is kept as unknown below.
			extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
			if err != nil && err != protoregistry.NotFound {
				return errors.New("%v: unable to resolve extension %v: %v", md.FullName(), num, err)
			}
			if extType != nil {
				fd = extType.TypeDescriptor()
			}
		}
		// err == errUnknown means "do not decode; store as an unknown field".
		// It is never reported to the caller.
		var err error
		if fd == nil {
			err = errUnknown
		} else if flags.ProtoLegacy {
			if fd.IsWeak() && fd.Message().IsPlaceholder() {
				err = errUnknown // weak referent is not linked in
			}
		}

		// Parse the field value.
		var valLen int
		switch {
		case err != nil:
			// Already classified as unknown; handled after the switch.
		case fd.IsList():
			valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
		case fd.IsMap():
			valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd)
		default:
			valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
		}
		if err != nil {
			if err != errUnknown {
				return err
			}
			// Skip the value to learn its length; this also rejects malformed
			// data inside fields that are not decoded.
			valLen = protowire.ConsumeFieldValue(num, wtyp, b[tagLen:])
			if valLen < 0 {
				return errDecode
			}
			if !o.DiscardUnknown {
				// Keep the original tag and value bytes verbatim.
				m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
			}
		}
		b = b[tagLen+valLen:]
	}
	return nil
}
190*1c12ee1eSDan Willemsen
191*1c12ee1eSDan Willemsenfunc (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp protowire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) {
192*1c12ee1eSDan Willemsen	v, n, err := o.unmarshalScalar(b, wtyp, fd)
193*1c12ee1eSDan Willemsen	if err != nil {
194*1c12ee1eSDan Willemsen		return 0, err
195*1c12ee1eSDan Willemsen	}
196*1c12ee1eSDan Willemsen	switch fd.Kind() {
197*1c12ee1eSDan Willemsen	case protoreflect.GroupKind, protoreflect.MessageKind:
198*1c12ee1eSDan Willemsen		m2 := m.Mutable(fd).Message()
199*1c12ee1eSDan Willemsen		if err := o.unmarshalMessage(v.Bytes(), m2); err != nil {
200*1c12ee1eSDan Willemsen			return n, err
201*1c12ee1eSDan Willemsen		}
202*1c12ee1eSDan Willemsen	default:
203*1c12ee1eSDan Willemsen		// Non-message scalars replace the previous value.
204*1c12ee1eSDan Willemsen		m.Set(fd, v)
205*1c12ee1eSDan Willemsen	}
206*1c12ee1eSDan Willemsen	return n, nil
207*1c12ee1eSDan Willemsen}
208*1c12ee1eSDan Willemsen
209*1c12ee1eSDan Willemsenfunc (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) {
210*1c12ee1eSDan Willemsen	if wtyp != protowire.BytesType {
211*1c12ee1eSDan Willemsen		return 0, errUnknown
212*1c12ee1eSDan Willemsen	}
213*1c12ee1eSDan Willemsen	b, n = protowire.ConsumeBytes(b)
214*1c12ee1eSDan Willemsen	if n < 0 {
215*1c12ee1eSDan Willemsen		return 0, errDecode
216*1c12ee1eSDan Willemsen	}
217*1c12ee1eSDan Willemsen	var (
218*1c12ee1eSDan Willemsen		keyField = fd.MapKey()
219*1c12ee1eSDan Willemsen		valField = fd.MapValue()
220*1c12ee1eSDan Willemsen		key      protoreflect.Value
221*1c12ee1eSDan Willemsen		val      protoreflect.Value
222*1c12ee1eSDan Willemsen		haveKey  bool
223*1c12ee1eSDan Willemsen		haveVal  bool
224*1c12ee1eSDan Willemsen	)
225*1c12ee1eSDan Willemsen	switch valField.Kind() {
226*1c12ee1eSDan Willemsen	case protoreflect.GroupKind, protoreflect.MessageKind:
227*1c12ee1eSDan Willemsen		val = mapv.NewValue()
228*1c12ee1eSDan Willemsen	}
229*1c12ee1eSDan Willemsen	// Map entries are represented as a two-element message with fields
230*1c12ee1eSDan Willemsen	// containing the key and value.
231*1c12ee1eSDan Willemsen	for len(b) > 0 {
232*1c12ee1eSDan Willemsen		num, wtyp, n := protowire.ConsumeTag(b)
233*1c12ee1eSDan Willemsen		if n < 0 {
234*1c12ee1eSDan Willemsen			return 0, errDecode
235*1c12ee1eSDan Willemsen		}
236*1c12ee1eSDan Willemsen		if num > protowire.MaxValidNumber {
237*1c12ee1eSDan Willemsen			return 0, errDecode
238*1c12ee1eSDan Willemsen		}
239*1c12ee1eSDan Willemsen		b = b[n:]
240*1c12ee1eSDan Willemsen		err = errUnknown
241*1c12ee1eSDan Willemsen		switch num {
242*1c12ee1eSDan Willemsen		case genid.MapEntry_Key_field_number:
243*1c12ee1eSDan Willemsen			key, n, err = o.unmarshalScalar(b, wtyp, keyField)
244*1c12ee1eSDan Willemsen			if err != nil {
245*1c12ee1eSDan Willemsen				break
246*1c12ee1eSDan Willemsen			}
247*1c12ee1eSDan Willemsen			haveKey = true
248*1c12ee1eSDan Willemsen		case genid.MapEntry_Value_field_number:
249*1c12ee1eSDan Willemsen			var v protoreflect.Value
250*1c12ee1eSDan Willemsen			v, n, err = o.unmarshalScalar(b, wtyp, valField)
251*1c12ee1eSDan Willemsen			if err != nil {
252*1c12ee1eSDan Willemsen				break
253*1c12ee1eSDan Willemsen			}
254*1c12ee1eSDan Willemsen			switch valField.Kind() {
255*1c12ee1eSDan Willemsen			case protoreflect.GroupKind, protoreflect.MessageKind:
256*1c12ee1eSDan Willemsen				if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
257*1c12ee1eSDan Willemsen					return 0, err
258*1c12ee1eSDan Willemsen				}
259*1c12ee1eSDan Willemsen			default:
260*1c12ee1eSDan Willemsen				val = v
261*1c12ee1eSDan Willemsen			}
262*1c12ee1eSDan Willemsen			haveVal = true
263*1c12ee1eSDan Willemsen		}
264*1c12ee1eSDan Willemsen		if err == errUnknown {
265*1c12ee1eSDan Willemsen			n = protowire.ConsumeFieldValue(num, wtyp, b)
266*1c12ee1eSDan Willemsen			if n < 0 {
267*1c12ee1eSDan Willemsen				return 0, errDecode
268*1c12ee1eSDan Willemsen			}
269*1c12ee1eSDan Willemsen		} else if err != nil {
270*1c12ee1eSDan Willemsen			return 0, err
271*1c12ee1eSDan Willemsen		}
272*1c12ee1eSDan Willemsen		b = b[n:]
273*1c12ee1eSDan Willemsen	}
274*1c12ee1eSDan Willemsen	// Every map entry should have entries for key and value, but this is not strictly required.
275*1c12ee1eSDan Willemsen	if !haveKey {
276*1c12ee1eSDan Willemsen		key = keyField.Default()
277*1c12ee1eSDan Willemsen	}
278*1c12ee1eSDan Willemsen	if !haveVal {
279*1c12ee1eSDan Willemsen		switch valField.Kind() {
280*1c12ee1eSDan Willemsen		case protoreflect.GroupKind, protoreflect.MessageKind:
281*1c12ee1eSDan Willemsen		default:
282*1c12ee1eSDan Willemsen			val = valField.Default()
283*1c12ee1eSDan Willemsen		}
284*1c12ee1eSDan Willemsen	}
285*1c12ee1eSDan Willemsen	mapv.Set(key.MapKey(), val)
286*1c12ee1eSDan Willemsen	return n, nil
287*1c12ee1eSDan Willemsen}
288*1c12ee1eSDan Willemsen
var (
	// errUnknown is used internally to indicate fields which should be added
	// to the unknown field set of a message: for example a field number with
	// no descriptor, or a map field whose wire type is not length-delimited.
	// It is never returned from an exported function.
	errUnknown = errors.New("BUG: internal error (unknown)")

	// errDecode reports malformed wire-format input.
	errDecode = errors.New("cannot parse invalid wire-format data")
)
295