xref: /aosp_15_r20/external/golang-protobuf/internal/impl/validate.go (revision 1c12ee1efe575feb122dbf939ff15148a3b3e8f2)
1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8	"fmt"
9	"math"
10	"math/bits"
11	"reflect"
12	"unicode/utf8"
13
14	"google.golang.org/protobuf/encoding/protowire"
15	"google.golang.org/protobuf/internal/encoding/messageset"
16	"google.golang.org/protobuf/internal/flags"
17	"google.golang.org/protobuf/internal/genid"
18	"google.golang.org/protobuf/internal/strs"
19	"google.golang.org/protobuf/reflect/protoreflect"
20	"google.golang.org/protobuf/reflect/protoregistry"
21	"google.golang.org/protobuf/runtime/protoiface"
22)
23
// ValidationStatus is the result of validating the wire-format encoding of a message.
//
// The zero value is not a valid status; the named statuses start at 1.
type ValidationStatus int

const (
	// The zero value is deliberately left unnamed.
	_ ValidationStatus = iota

	// ValidationUnknown indicates that unmarshaling the message might succeed or fail.
	// The validator was unable to render a judgement.
	//
	// The causes of this status are an aberrant message type appearing somewhere
	// in the message, or a failure while looking up a type (in the extension
	// resolver or, for weak fields, in the global registry).
	ValidationUnknown

	// ValidationInvalid indicates that unmarshaling the message will fail.
	ValidationInvalid

	// ValidationValid indicates that unmarshaling the message will succeed.
	ValidationValid
)
41
42func (v ValidationStatus) String() string {
43	switch v {
44	case ValidationUnknown:
45		return "ValidationUnknown"
46	case ValidationInvalid:
47		return "ValidationInvalid"
48	case ValidationValid:
49		return "ValidationValid"
50	default:
51		return fmt.Sprintf("ValidationStatus(%d)", int(v))
52	}
53}
54
55// Validate determines whether the contents of the buffer are a valid wire encoding
56// of the message type.
57//
58// This function is exposed for testing.
59func Validate(mt protoreflect.MessageType, in protoiface.UnmarshalInput) (out protoiface.UnmarshalOutput, _ ValidationStatus) {
60	mi, ok := mt.(*MessageInfo)
61	if !ok {
62		return out, ValidationUnknown
63	}
64	if in.Resolver == nil {
65		in.Resolver = protoregistry.GlobalTypes
66	}
67	o, st := mi.validate(in.Buf, 0, unmarshalOptions{
68		flags:    in.Flags,
69		resolver: in.Resolver,
70	})
71	if o.initialized {
72		out.Flags |= protoiface.UnmarshalInitialized
73	}
74	return out, st
75}
76
77type validationInfo struct {
78	mi               *MessageInfo
79	typ              validationType
80	keyType, valType validationType
81
82	// For non-required fields, requiredBit is 0.
83	//
84	// For required fields, requiredBit's nth bit is set, where n is a
85	// unique index in the range [0, MessageInfo.numRequiredFields).
86	//
87	// If there are more than 64 required fields, requiredBit is 0.
88	requiredBit uint64
89}
90
// validationType selects how the validator checks a field's encoded value.
type validationType uint8

const (
	// validationTypeOther applies no type-specific check.
	validationTypeOther validationType = iota
	// validationTypeMessage is a length-delimited message, validated
	// against its *MessageInfo.
	validationTypeMessage
	// validationTypeGroup is a group-encoded message, validated up to its
	// END_GROUP tag.
	validationTypeGroup
	// validationTypeMap is a map field; entries are checked using the
	// keyType and valType of the validationInfo.
	validationTypeMap
	// validationTypeRepeatedVarint is a repeated varint field, possibly packed.
	validationTypeRepeatedVarint
	// validationTypeRepeatedFixed32 is a repeated 32-bit field; a packed
	// payload must be a multiple of 4 bytes.
	validationTypeRepeatedFixed32
	// validationTypeRepeatedFixed64 is a repeated 64-bit field; a packed
	// payload must be a multiple of 8 bytes.
	validationTypeRepeatedFixed64
	// validationTypeVarint is a singular varint field.
	validationTypeVarint
	// validationTypeFixed32 is a singular 32-bit field.
	validationTypeFixed32
	// validationTypeFixed64 is a singular 64-bit field.
	validationTypeFixed64
	// validationTypeBytes is a length-delimited bytes field, or a string
	// field whose UTF-8 validity is not enforced.
	validationTypeBytes
	// validationTypeUTF8String is a string field that must be valid UTF-8.
	validationTypeUTF8String
	// validationTypeMessageSetItem is an item group of a legacy MessageSet.
	validationTypeMessageSetItem
)
108
109func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
110	var vi validationInfo
111	switch {
112	case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
113		switch fd.Kind() {
114		case protoreflect.MessageKind:
115			vi.typ = validationTypeMessage
116			if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
117				vi.mi = getMessageInfo(ot.Field(0).Type)
118			}
119		case protoreflect.GroupKind:
120			vi.typ = validationTypeGroup
121			if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
122				vi.mi = getMessageInfo(ot.Field(0).Type)
123			}
124		case protoreflect.StringKind:
125			if strs.EnforceUTF8(fd) {
126				vi.typ = validationTypeUTF8String
127			}
128		}
129	default:
130		vi = newValidationInfo(fd, ft)
131	}
132	if fd.Cardinality() == protoreflect.Required {
133		// Avoid overflow. The required field check is done with a 64-bit mask, with
134		// any message containing more than 64 required fields always reported as
135		// potentially uninitialized, so it is not important to get a precise count
136		// of the required fields past 64.
137		if mi.numRequiredFields < math.MaxUint8 {
138			mi.numRequiredFields++
139			vi.requiredBit = 1 << (mi.numRequiredFields - 1)
140		}
141	}
142	return vi
143}
144
145func newValidationInfo(fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
146	var vi validationInfo
147	switch {
148	case fd.IsList():
149		switch fd.Kind() {
150		case protoreflect.MessageKind:
151			vi.typ = validationTypeMessage
152			if ft.Kind() == reflect.Slice {
153				vi.mi = getMessageInfo(ft.Elem())
154			}
155		case protoreflect.GroupKind:
156			vi.typ = validationTypeGroup
157			if ft.Kind() == reflect.Slice {
158				vi.mi = getMessageInfo(ft.Elem())
159			}
160		case protoreflect.StringKind:
161			vi.typ = validationTypeBytes
162			if strs.EnforceUTF8(fd) {
163				vi.typ = validationTypeUTF8String
164			}
165		default:
166			switch wireTypes[fd.Kind()] {
167			case protowire.VarintType:
168				vi.typ = validationTypeRepeatedVarint
169			case protowire.Fixed32Type:
170				vi.typ = validationTypeRepeatedFixed32
171			case protowire.Fixed64Type:
172				vi.typ = validationTypeRepeatedFixed64
173			}
174		}
175	case fd.IsMap():
176		vi.typ = validationTypeMap
177		switch fd.MapKey().Kind() {
178		case protoreflect.StringKind:
179			if strs.EnforceUTF8(fd) {
180				vi.keyType = validationTypeUTF8String
181			}
182		}
183		switch fd.MapValue().Kind() {
184		case protoreflect.MessageKind:
185			vi.valType = validationTypeMessage
186			if ft.Kind() == reflect.Map {
187				vi.mi = getMessageInfo(ft.Elem())
188			}
189		case protoreflect.StringKind:
190			if strs.EnforceUTF8(fd) {
191				vi.valType = validationTypeUTF8String
192			}
193		}
194	default:
195		switch fd.Kind() {
196		case protoreflect.MessageKind:
197			vi.typ = validationTypeMessage
198			if !fd.IsWeak() {
199				vi.mi = getMessageInfo(ft)
200			}
201		case protoreflect.GroupKind:
202			vi.typ = validationTypeGroup
203			vi.mi = getMessageInfo(ft)
204		case protoreflect.StringKind:
205			vi.typ = validationTypeBytes
206			if strs.EnforceUTF8(fd) {
207				vi.typ = validationTypeUTF8String
208			}
209		default:
210			switch wireTypes[fd.Kind()] {
211			case protowire.VarintType:
212				vi.typ = validationTypeVarint
213			case protowire.Fixed32Type:
214				vi.typ = validationTypeFixed32
215			case protowire.Fixed64Type:
216				vi.typ = validationTypeFixed64
217			case protowire.BytesType:
218				vi.typ = validationTypeBytes
219			}
220		}
221	}
222	return vi
223}
224
225func (mi *MessageInfo) validate(b []byte, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) {
226	mi.init()
227	type validationState struct {
228		typ              validationType
229		keyType, valType validationType
230		endGroup         protowire.Number
231		mi               *MessageInfo
232		tail             []byte
233		requiredMask     uint64
234	}
235
236	// Pre-allocate some slots to avoid repeated slice reallocation.
237	states := make([]validationState, 0, 16)
238	states = append(states, validationState{
239		typ: validationTypeMessage,
240		mi:  mi,
241	})
242	if groupTag > 0 {
243		states[0].typ = validationTypeGroup
244		states[0].endGroup = groupTag
245	}
246	initialized := true
247	start := len(b)
248State:
249	for len(states) > 0 {
250		st := &states[len(states)-1]
251		for len(b) > 0 {
252			// Parse the tag (field number and wire type).
253			var tag uint64
254			if b[0] < 0x80 {
255				tag = uint64(b[0])
256				b = b[1:]
257			} else if len(b) >= 2 && b[1] < 128 {
258				tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
259				b = b[2:]
260			} else {
261				var n int
262				tag, n = protowire.ConsumeVarint(b)
263				if n < 0 {
264					return out, ValidationInvalid
265				}
266				b = b[n:]
267			}
268			var num protowire.Number
269			if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
270				return out, ValidationInvalid
271			} else {
272				num = protowire.Number(n)
273			}
274			wtyp := protowire.Type(tag & 7)
275
276			if wtyp == protowire.EndGroupType {
277				if st.endGroup == num {
278					goto PopState
279				}
280				return out, ValidationInvalid
281			}
282			var vi validationInfo
283			switch {
284			case st.typ == validationTypeMap:
285				switch num {
286				case genid.MapEntry_Key_field_number:
287					vi.typ = st.keyType
288				case genid.MapEntry_Value_field_number:
289					vi.typ = st.valType
290					vi.mi = st.mi
291					vi.requiredBit = 1
292				}
293			case flags.ProtoLegacy && st.mi.isMessageSet:
294				switch num {
295				case messageset.FieldItem:
296					vi.typ = validationTypeMessageSetItem
297				}
298			default:
299				var f *coderFieldInfo
300				if int(num) < len(st.mi.denseCoderFields) {
301					f = st.mi.denseCoderFields[num]
302				} else {
303					f = st.mi.coderFields[num]
304				}
305				if f != nil {
306					vi = f.validation
307					if vi.typ == validationTypeMessage && vi.mi == nil {
308						// Probable weak field.
309						//
310						// TODO: Consider storing the results of this lookup somewhere
311						// rather than recomputing it on every validation.
312						fd := st.mi.Desc.Fields().ByNumber(num)
313						if fd == nil || !fd.IsWeak() {
314							break
315						}
316						messageName := fd.Message().FullName()
317						messageType, err := protoregistry.GlobalTypes.FindMessageByName(messageName)
318						switch err {
319						case nil:
320							vi.mi, _ = messageType.(*MessageInfo)
321						case protoregistry.NotFound:
322							vi.typ = validationTypeBytes
323						default:
324							return out, ValidationUnknown
325						}
326					}
327					break
328				}
329				// Possible extension field.
330				//
331				// TODO: We should return ValidationUnknown when:
332				//   1. The resolver is not frozen. (More extensions may be added to it.)
333				//   2. The resolver returns preg.NotFound.
334				// In this case, a type added to the resolver in the future could cause
335				// unmarshaling to begin failing. Supporting this requires some way to
336				// determine if the resolver is frozen.
337				xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
338				if err != nil && err != protoregistry.NotFound {
339					return out, ValidationUnknown
340				}
341				if err == nil {
342					vi = getExtensionFieldInfo(xt).validation
343				}
344			}
345			if vi.requiredBit != 0 {
346				// Check that the field has a compatible wire type.
347				// We only need to consider non-repeated field types,
348				// since repeated fields (and maps) can never be required.
349				ok := false
350				switch vi.typ {
351				case validationTypeVarint:
352					ok = wtyp == protowire.VarintType
353				case validationTypeFixed32:
354					ok = wtyp == protowire.Fixed32Type
355				case validationTypeFixed64:
356					ok = wtyp == protowire.Fixed64Type
357				case validationTypeBytes, validationTypeUTF8String, validationTypeMessage:
358					ok = wtyp == protowire.BytesType
359				case validationTypeGroup:
360					ok = wtyp == protowire.StartGroupType
361				}
362				if ok {
363					st.requiredMask |= vi.requiredBit
364				}
365			}
366
367			switch wtyp {
368			case protowire.VarintType:
369				if len(b) >= 10 {
370					switch {
371					case b[0] < 0x80:
372						b = b[1:]
373					case b[1] < 0x80:
374						b = b[2:]
375					case b[2] < 0x80:
376						b = b[3:]
377					case b[3] < 0x80:
378						b = b[4:]
379					case b[4] < 0x80:
380						b = b[5:]
381					case b[5] < 0x80:
382						b = b[6:]
383					case b[6] < 0x80:
384						b = b[7:]
385					case b[7] < 0x80:
386						b = b[8:]
387					case b[8] < 0x80:
388						b = b[9:]
389					case b[9] < 0x80 && b[9] < 2:
390						b = b[10:]
391					default:
392						return out, ValidationInvalid
393					}
394				} else {
395					switch {
396					case len(b) > 0 && b[0] < 0x80:
397						b = b[1:]
398					case len(b) > 1 && b[1] < 0x80:
399						b = b[2:]
400					case len(b) > 2 && b[2] < 0x80:
401						b = b[3:]
402					case len(b) > 3 && b[3] < 0x80:
403						b = b[4:]
404					case len(b) > 4 && b[4] < 0x80:
405						b = b[5:]
406					case len(b) > 5 && b[5] < 0x80:
407						b = b[6:]
408					case len(b) > 6 && b[6] < 0x80:
409						b = b[7:]
410					case len(b) > 7 && b[7] < 0x80:
411						b = b[8:]
412					case len(b) > 8 && b[8] < 0x80:
413						b = b[9:]
414					case len(b) > 9 && b[9] < 2:
415						b = b[10:]
416					default:
417						return out, ValidationInvalid
418					}
419				}
420				continue State
421			case protowire.BytesType:
422				var size uint64
423				if len(b) >= 1 && b[0] < 0x80 {
424					size = uint64(b[0])
425					b = b[1:]
426				} else if len(b) >= 2 && b[1] < 128 {
427					size = uint64(b[0]&0x7f) + uint64(b[1])<<7
428					b = b[2:]
429				} else {
430					var n int
431					size, n = protowire.ConsumeVarint(b)
432					if n < 0 {
433						return out, ValidationInvalid
434					}
435					b = b[n:]
436				}
437				if size > uint64(len(b)) {
438					return out, ValidationInvalid
439				}
440				v := b[:size]
441				b = b[size:]
442				switch vi.typ {
443				case validationTypeMessage:
444					if vi.mi == nil {
445						return out, ValidationUnknown
446					}
447					vi.mi.init()
448					fallthrough
449				case validationTypeMap:
450					if vi.mi != nil {
451						vi.mi.init()
452					}
453					states = append(states, validationState{
454						typ:     vi.typ,
455						keyType: vi.keyType,
456						valType: vi.valType,
457						mi:      vi.mi,
458						tail:    b,
459					})
460					b = v
461					continue State
462				case validationTypeRepeatedVarint:
463					// Packed field.
464					for len(v) > 0 {
465						_, n := protowire.ConsumeVarint(v)
466						if n < 0 {
467							return out, ValidationInvalid
468						}
469						v = v[n:]
470					}
471				case validationTypeRepeatedFixed32:
472					// Packed field.
473					if len(v)%4 != 0 {
474						return out, ValidationInvalid
475					}
476				case validationTypeRepeatedFixed64:
477					// Packed field.
478					if len(v)%8 != 0 {
479						return out, ValidationInvalid
480					}
481				case validationTypeUTF8String:
482					if !utf8.Valid(v) {
483						return out, ValidationInvalid
484					}
485				}
486			case protowire.Fixed32Type:
487				if len(b) < 4 {
488					return out, ValidationInvalid
489				}
490				b = b[4:]
491			case protowire.Fixed64Type:
492				if len(b) < 8 {
493					return out, ValidationInvalid
494				}
495				b = b[8:]
496			case protowire.StartGroupType:
497				switch {
498				case vi.typ == validationTypeGroup:
499					if vi.mi == nil {
500						return out, ValidationUnknown
501					}
502					vi.mi.init()
503					states = append(states, validationState{
504						typ:      validationTypeGroup,
505						mi:       vi.mi,
506						endGroup: num,
507					})
508					continue State
509				case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem:
510					typeid, v, n, err := messageset.ConsumeFieldValue(b, false)
511					if err != nil {
512						return out, ValidationInvalid
513					}
514					xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
515					switch {
516					case err == protoregistry.NotFound:
517						b = b[n:]
518					case err != nil:
519						return out, ValidationUnknown
520					default:
521						xvi := getExtensionFieldInfo(xt).validation
522						if xvi.mi != nil {
523							xvi.mi.init()
524						}
525						states = append(states, validationState{
526							typ:  xvi.typ,
527							mi:   xvi.mi,
528							tail: b[n:],
529						})
530						b = v
531						continue State
532					}
533				default:
534					n := protowire.ConsumeFieldValue(num, wtyp, b)
535					if n < 0 {
536						return out, ValidationInvalid
537					}
538					b = b[n:]
539				}
540			default:
541				return out, ValidationInvalid
542			}
543		}
544		if st.endGroup != 0 {
545			return out, ValidationInvalid
546		}
547		if len(b) != 0 {
548			return out, ValidationInvalid
549		}
550		b = st.tail
551	PopState:
552		numRequiredFields := 0
553		switch st.typ {
554		case validationTypeMessage, validationTypeGroup:
555			numRequiredFields = int(st.mi.numRequiredFields)
556		case validationTypeMap:
557			// If this is a map field with a message value that contains
558			// required fields, require that the value be present.
559			if st.mi != nil && st.mi.numRequiredFields > 0 {
560				numRequiredFields = 1
561			}
562		}
563		// If there are more than 64 required fields, this check will
564		// always fail and we will report that the message is potentially
565		// uninitialized.
566		if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
567			initialized = false
568		}
569		states = states[:len(states)-1]
570	}
571	out.n = start - len(b)
572	if initialized {
573		out.initialized = true
574	}
575	return out, ValidationValid
576}
577