xref: /aosp_15_r20/external/cronet/third_party/protobuf/python/google/protobuf/internal/decoder.py (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31"""Code for decoding protocol buffer primitives.
32
33This code is very similar to encoder.py -- read the docs for that module first.
34
35A "decoder" is a function with the signature:
36  Decode(buffer, pos, end, message, field_dict)
37The arguments are:
38  buffer:     The string containing the encoded message.
39  pos:        The current position in the string.
40  end:        The position in the string where the current message ends.  May be
41              less than len(buffer) if we're reading a sub-message.
42  message:    The message object into which we're parsing.
43  field_dict: message._fields (avoids a hashtable lookup).
44The decoder reads the field and stores it into field_dict, returning the new
45buffer position.  A decoder for a repeated field may proactively decode all of
46the elements of that field, if they appear consecutively.
47
48Note that decoders may throw any of the following:
49  IndexError:  Indicates a truncated message.
50  struct.error:  Unpacking of a fixed-width field failed.
51  message.DecodeError:  Other errors.
52
53Decoders are expected to raise an exception if they are called with pos > end.
This allows callers to be lax about bounds checking:  it's fine to read past
55"end" as long as you are sure that someone else will notice and throw an
56exception later on.
57
58Something up the call stack is expected to catch IndexError and struct.error
59and convert them to message.DecodeError.
60
61Decoders are constructed using decoder constructors with the signature:
62  MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
63The arguments are:
64  field_number:  The field number of the field we want to decode.
65  is_repeated:   Is the field a repeated field? (bool)
66  is_packed:     Is the field a packed field? (bool)
67  key:           The key to use when looking up the field within field_dict.
68                 (This is actually the FieldDescriptor but nothing in this
69                 file should depend on that.)
70  new_default:   A function which takes a message object as a parameter and
71                 returns a new instance of the default value for this field.
72                 (This is called for repeated fields and sub-messages, when an
73                 instance does not already exist.)
74
75As with encoders, we define a decoder constructor for every type of field.
76Then, for every field of every message class we construct an actual decoder.
77That decoder goes into a dict indexed by tag, so when we decode a message
78we repeatedly read a tag, look up the corresponding decoder, and invoke it.
79"""
80
81__author__ = '[email protected] (Kenton Varda)'
82
83import math
84import struct
85
86from google.protobuf.internal import containers
87from google.protobuf.internal import encoder
88from google.protobuf.internal import wire_format
89from google.protobuf import message
90
91
# Module-level alias for the decode exception.  This is not for optimization,
# but rather to avoid conflicts with local variables and parameters named
# "message" (every decoder below takes one), which hide the imported module.
_DecodeError = message.DecodeError
95
96
97def _VarintDecoder(mask, result_type):
98  """Return an encoder for a basic varint value (does not include tag).
99
100  Decoded values will be bitwise-anded with the given mask before being
101  returned, e.g. to limit them to 32 bits.  The returned decoder does not
102  take the usual "end" parameter -- the caller is expected to do bounds checking
103  after the fact (often the caller can defer such checking until later).  The
104  decoder returns a (value, new_pos) pair.
105  """
106
107  def DecodeVarint(buffer, pos):
108    result = 0
109    shift = 0
110    while 1:
111      b = buffer[pos]
112      result |= ((b & 0x7f) << shift)
113      pos += 1
114      if not (b & 0x80):
115        result &= mask
116        result = result_type(result)
117        return (result, pos)
118      shift += 7
119      if shift >= 64:
120        raise _DecodeError('Too many bytes when decoding varint.')
121  return DecodeVarint
122
123
124def _SignedVarintDecoder(bits, result_type):
125  """Like _VarintDecoder() but decodes signed values."""
126
127  signbit = 1 << (bits - 1)
128  mask = (1 << bits) - 1
129
130  def DecodeVarint(buffer, pos):
131    result = 0
132    shift = 0
133    while 1:
134      b = buffer[pos]
135      result |= ((b & 0x7f) << shift)
136      pos += 1
137      if not (b & 0x80):
138        result &= mask
139        result = (result ^ signbit) - signbit
140        result = result_type(result)
141        return (result, pos)
142      shift += 7
143      if shift >= 64:
144        raise _DecodeError('Too many bytes when decoding varint.')
145  return DecodeVarint
146
# Full-width (64-bit) varint decoders.  All 32-bit and 64-bit values are
# represented as int.
_DecodeVarint = _VarintDecoder(0xFFFFFFFFFFFFFFFF, int)
_DecodeSignedVarint = _SignedVarintDecoder(64, int)

# Variants for values which must be limited to 32 bits.
_DecodeVarint32 = _VarintDecoder(0xFFFFFFFF, int)
_DecodeSignedVarint32 = _SignedVarintDecoder(32, int)
154
155
def ReadTag(buffer, pos):
  """Read one tag from the memoryview and return it undecoded.

  The tag is returned as raw bytes rather than as a decoded integer.  Those
  bytes can be used directly to look up the proper decoder, which trades work
  done in pure Python (decoding a varint) for work done in C (looking up a
  byte string in a hash table).  A low-level language would prefer to decode
  the varint, but in Python this is the cheaper option.

  Args:
    buffer: memoryview object of the encoded bytes
    pos: int of the current position to start from

  Returns:
    Tuple[bytes, int] of the tag data and new position.
  """
  tag_end = pos
  # Step over every byte that has its continuation bit set ...
  while buffer[tag_end] & 0x80:
    tag_end += 1
  # ... and include the terminating byte as well.
  tag_end += 1

  return buffer[pos:tag_end].tobytes(), tag_end
180
181
182# --------------------------------------------------------------------
183
184
185def _SimpleDecoder(wire_type, decode_value):
186  """Return a constructor for a decoder for fields of a particular type.
187
188  Args:
189      wire_type:  The field's wire type.
190      decode_value:  A function which decodes an individual value, e.g.
191        _DecodeVarint()
192  """
193
194  def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
195                      clear_if_default=False):
196    if is_packed:
197      local_DecodeVarint = _DecodeVarint
198      def DecodePackedField(buffer, pos, end, message, field_dict):
199        value = field_dict.get(key)
200        if value is None:
201          value = field_dict.setdefault(key, new_default(message))
202        (endpoint, pos) = local_DecodeVarint(buffer, pos)
203        endpoint += pos
204        if endpoint > end:
205          raise _DecodeError('Truncated message.')
206        while pos < endpoint:
207          (element, pos) = decode_value(buffer, pos)
208          value.append(element)
209        if pos > endpoint:
210          del value[-1]   # Discard corrupt value.
211          raise _DecodeError('Packed element was truncated.')
212        return pos
213      return DecodePackedField
214    elif is_repeated:
215      tag_bytes = encoder.TagBytes(field_number, wire_type)
216      tag_len = len(tag_bytes)
217      def DecodeRepeatedField(buffer, pos, end, message, field_dict):
218        value = field_dict.get(key)
219        if value is None:
220          value = field_dict.setdefault(key, new_default(message))
221        while 1:
222          (element, new_pos) = decode_value(buffer, pos)
223          value.append(element)
224          # Predict that the next tag is another copy of the same repeated
225          # field.
226          pos = new_pos + tag_len
227          if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
228            # Prediction failed.  Return.
229            if new_pos > end:
230              raise _DecodeError('Truncated message.')
231            return new_pos
232      return DecodeRepeatedField
233    else:
234      def DecodeField(buffer, pos, end, message, field_dict):
235        (new_value, pos) = decode_value(buffer, pos)
236        if pos > end:
237          raise _DecodeError('Truncated message.')
238        if clear_if_default and not new_value:
239          field_dict.pop(key, None)
240        else:
241          field_dict[key] = new_value
242        return pos
243      return DecodeField
244
245  return SpecificDecoder
246
247
def _ModifiedDecoder(wire_type, decode_value, modify_value):
  """Like _SimpleDecoder(), but post-processes every decoded value.

  Each value produced by decode_value is passed through modify_value before
  it is stored.  Usually modify_value is ZigZagDecode.
  """

  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
  # not enough to make a significant difference.

  def InnerDecode(buffer, pos):
    raw_value, next_pos = decode_value(buffer, pos)
    return (modify_value(raw_value), next_pos)

  return _SimpleDecoder(wire_type, InnerDecode)
260
261
def _StructPackDecoder(wire_type, format):
  """Return a constructor for a decoder for a fixed-width field.

  Args:
      wire_type:  The field's wire type.
      format:  The struct format string describing one value.
  """

  # Compile the format once; the compiled object knows the value's byte width.
  compiled = struct.Struct(format)
  width = compiled.size
  unpack_one = compiled.unpack

  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
  # not enough to make a significant difference.

  # Note that we expect someone up-stack to catch struct.error and convert
  # it to _DecodeError -- this way we don't have to set up exception-
  # handling blocks every time we parse one value.

  def InnerDecode(buffer, pos):
    value_end = pos + width
    return (unpack_one(buffer[pos:value_end])[0], value_end)

  return _SimpleDecoder(wire_type, InnerDecode)
285
286
def _FloatDecoder():
  """Returns a decoder for a float field.

  Non-finite 32-bit values (infinities and NaN) are recognised from their bit
  pattern and mapped to math.inf / -math.inf / math.nan directly instead of
  going through struct.unpack.
  """

  unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized float to a float and new position.

    Args:
      buffer: memoryview of the serialized bytes
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the deserialized float value and new position
      in the serialized data.
    """
    # We expect a 32-bit value in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand.
    new_pos = pos + 4
    raw = buffer[pos:new_pos].tobytes()
    top_byte = raw[3:4]

    # All exponent bits set means the value is non-finite.
    exponent_all_ones = top_byte in b'\x7F\xFF' and raw[2:3] >= b'\x80'
    if not exponent_all_ones:
      # Note that we expect someone up-stack to catch struct.error and convert
      # it to _DecodeError -- this way we don't have to set up exception-
      # handling blocks every time we parse one value.
      return (unpack('<f', raw)[0], new_pos)

    # If at least one significand bit is set, it's not a number.
    if raw[0:3] != b'\x00\x00\x80':
      return (math.nan, new_pos)
    # Otherwise it is an infinity whose sign is given by the sign bit.
    if top_byte == b'\xFF':
      return (-math.inf, new_pos)
    return (math.inf, new_pos)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)
330
331
def _DoubleDecoder():
  """Returns a decoder for a double field.

  64-bit not-a-number bit patterns are recognised directly and mapped to
  math.nan instead of going through struct.unpack.
  """

  unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized double to a double and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the decoded double value and new position
      in the serialized data.
    """
    # We expect a 64-bit value in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-12 represent the exponent, and bits 13-64 are the significand.
    new_pos = pos + 8
    raw = buffer[pos:new_pos].tobytes()

    # All exponent bits set plus at least one significand bit set means the
    # value is not a number.
    is_nan = (raw[7:8] in b'\x7F\xFF'
              and raw[6:7] >= b'\xF0'
              and raw[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')
    if is_nan:
      return (math.nan, new_pos)

    # Note that we expect someone up-stack to catch struct.error and convert
    # it to _DecodeError -- this way we don't have to set up exception-
    # handling blocks every time we parse one value.
    return (unpack('<d', raw)[0], new_pos)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)
370
371
def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
                clear_if_default=False):
  """Returns a decoder for enum field.

  Values whose number is present in key.enum_type.values_by_number are stored
  in the field.  Any other value is preserved as an unknown field on the
  message (both in message._unknown_fields and message._unknown_field_set).

  Args:
    field_number: The field number of the enum field.
    is_repeated: bool, whether the field is repeated.
    is_packed: bool, whether the field is packed.
    key: The key used to look the field up in field_dict; its enum_type
      attribute supplies the set of known values.
    new_default: Function taking the message and returning a new container
      for the field; used by the packed and repeated decoders.
    clear_if_default: bool, singular decoder only; when true a zero value
      removes the field from field_dict instead of storing it.

  Returns:
    A decoder function with the signature
    Decode(buffer, pos, end, message, field_dict).
  """
  enum_type = key.enum_type
  # Unknown enum values are always recorded under this field's varint tag.
  # The tag depends only on the field, so it is computed once here.
  tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)

  def RecordUnknownValue(message, buffer, start, stop, enum_value):
    """Preserve an unrecognized enum value as an unknown field.

    Args:
      message: Message object to store the unknown field in.
      buffer: memoryview of the serialized bytes.
      start: int, position of the first byte of the encoded value.
      stop: int, position just past the encoded value.
      enum_value: int, the decoded (unrecognized) enum number.
    """
    # pylint: disable=protected-access
    if not message._unknown_fields:
      message._unknown_fields = []
    message._unknown_fields.append(
        (tag_bytes, buffer[start:stop].tobytes()))
    if message._unknown_field_set is None:
      message._unknown_field_set = containers.UnknownFieldSet()
    message._unknown_field_set._add(
        field_number, wire_format.WIRETYPE_VARINT, enum_value)
    # pylint: enable=protected-access

  if is_packed:
    local_DecodeVarint = _DecodeVarint
    def DecodePackedField(buffer, pos, end, message, field_dict):
      """Decode serialized packed enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      (endpoint, pos) = local_DecodeVarint(buffer, pos)
      endpoint += pos
      if endpoint > end:
        raise _DecodeError('Truncated message.')
      while pos < endpoint:
        value_start_pos = pos
        (element, pos) = _DecodeSignedVarint32(buffer, pos)
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          RecordUnknownValue(message, buffer, value_start_pos, pos, element)
      if pos > endpoint:
        # The last element ran past the declared length: undo whichever
        # store it went to before reporting the error.
        if element in enum_type.values_by_number:
          del value[-1]   # Discard corrupt value.
        else:
          # pylint: disable=protected-access
          del message._unknown_fields[-1]
          del message._unknown_field_set._values[-1]
          # pylint: enable=protected-access
        raise _DecodeError('Packed element was truncated.')
      return pos
    return DecodePackedField
  elif is_repeated:
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      """Decode serialized repeated enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          RecordUnknownValue(message, buffer, pos, new_pos, element)
        # Predict that the next tag is another copy of the same repeated
        # field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
          # Prediction failed.  Return.
          if new_pos > end:
            raise _DecodeError('Truncated message.')
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      """Decode serialized singular enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value_start_pos = pos
      (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
      if pos > end:
        raise _DecodeError('Truncated message.')
      if clear_if_default and not enum_value:
        field_dict.pop(key, None)
        return pos
      if enum_value in enum_type.values_by_number:
        field_dict[key] = enum_value
      else:
        RecordUnknownValue(message, buffer, value_start_pos, pos, enum_value)
      return pos
    return DecodeField
509
510
511# --------------------------------------------------------------------
512
513
# Varint-encoded integer types: signed and unsigned, 32 and 64 bits wide.
Int32Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)
Int64Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint)
UInt32Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint32)
UInt64Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint)

# Varints that are passed through ZigZagDecode after being read.
SInt32Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode)

# Fixed-width types.  Note that Python conveniently guarantees that when using
# the '<' prefix on formats, they will also have the same size across all
# platforms (as opposed to without the prefix, where their sizes depend on the
# C compiler's basic type sizes).
Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')

# Floating-point types get dedicated decoders that handle non-finite values.
FloatDecoder = _FloatDecoder()
DoubleDecoder = _DoubleDecoder()

# A bool is a varint converted with bool().
BoolDecoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)
541
542
def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
                  clear_if_default=False):
  """Returns a decoder for a string field.

  Each value is read as a varint length followed by that many bytes, which
  are decoded as UTF-8.  String fields cannot be packed.
  """

  read_length = _DecodeVarint

  def _ConvertToUnicode(memview):
    """Decode the bytes of a memoryview slice as UTF-8 text."""
    raw = memview.tobytes()
    try:
      return str(raw, 'utf-8')
    except UnicodeDecodeError as e:
      # add more information to the error message and re-raise it.
      e.reason = '%s in field: %s' % (e, key.full_name)
      raise

  assert not is_packed

  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      (size, start) = read_length(buffer, pos)
      stop = start + size
      if stop > end:
        raise _DecodeError('Truncated string.')
      if clear_if_default and not size:
        # An empty string removes the entry instead of storing it.
        field_dict.pop(key, None)
      else:
        field_dict[key] = _ConvertToUnicode(buffer[start:stop])
      return stop
    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    strings = field_dict.get(key)
    if strings is None:
      strings = field_dict.setdefault(key, new_default(message))
    while True:
      (size, start) = read_length(buffer, pos)
      stop = start + size
      if stop > end:
        raise _DecodeError('Truncated string.')
      strings.append(_ConvertToUnicode(buffer[start:stop]))
      # Predict that the next tag is another copy of the same repeated field.
      pos = stop + tag_len
      if buffer[stop:pos] != tag_bytes or stop == end:
        # Prediction failed.  Return.
        return stop
  return DecodeRepeatedField
594
595
def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
                 clear_if_default=False):
  """Returns a decoder for a bytes field.

  Each value is read as a varint length followed by that many bytes, which
  are copied out of the buffer unchanged.  Bytes fields cannot be packed.
  """

  read_length = _DecodeVarint

  assert not is_packed

  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      (size, start) = read_length(buffer, pos)
      stop = start + size
      if stop > end:
        raise _DecodeError('Truncated string.')
      if clear_if_default and not size:
        # An empty value removes the entry instead of storing it.
        field_dict.pop(key, None)
      else:
        field_dict[key] = buffer[start:stop].tobytes()
      return stop
    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    chunks = field_dict.get(key)
    if chunks is None:
      chunks = field_dict.setdefault(key, new_default(message))
    while True:
      (size, start) = read_length(buffer, pos)
      stop = start + size
      if stop > end:
        raise _DecodeError('Truncated string.')
      chunks.append(buffer[start:stop].tobytes())
      # Predict that the next tag is another copy of the same repeated field.
      pos = stop + tag_len
      if buffer[stop:pos] != tag_bytes or stop == end:
        # Prediction failed.  Return.
        return stop
  return DecodeRepeatedField
635
636
def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a group field.

  A group's contents are parsed by the sub-message's _InternalParse and must
  be followed by the END_GROUP tag for the same field number.  Group fields
  cannot be packed.

  Args:
    field_number: The field number of the group field.
    is_repeated: bool, whether the field is repeated.
    is_packed: bool, must be false.
    key: The key used to look the field up in field_dict.
    new_default: Function taking the message and returning a new default
      value (sub-message or repeated container) for the field.

  Returns:
    A decoder function with the signature
    Decode(buffer, pos, end, message, field_dict); it raises _DecodeError
    when the group's end tag is missing.
  """

  end_tag_bytes = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_END_GROUP)
  end_tag_len = len(end_tag_bytes)

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_START_GROUP)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      # Resolve the repeated container once; every element decoded below is
      # added to this same container.
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Read sub-message.
        pos = value.add()._InternalParse(buffer, pos, end)
        # Read end tag.
        new_pos = pos+end_tag_len
        if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
          raise _DecodeError('Missing group end tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read sub-message.
      pos = value._InternalParse(buffer, pos, end)
      # Read end tag.
      new_pos = pos+end_tag_len
      if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
        raise _DecodeError('Missing group end tag.')
      return new_pos
    return DecodeField
682
683
def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a message field.

  Each sub-message is read as a varint length followed by that many bytes,
  which are handed to the sub-message's _InternalParse.  Message fields
  cannot be packed.
  """

  read_length = _DecodeVarint

  assert not is_packed

  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      submessage = field_dict.get(key)
      if submessage is None:
        submessage = field_dict.setdefault(key, new_default(message))
      # Read length.
      (size, start) = read_length(buffer, pos)
      stop = start + size
      if stop > end:
        raise _DecodeError('Truncated message.')
      # Read sub-message.
      if submessage._InternalParse(buffer, start, stop) != stop:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
      return stop
    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    submessages = field_dict.get(key)
    if submessages is None:
      submessages = field_dict.setdefault(key, new_default(message))
    while True:
      # Read length.
      (size, start) = read_length(buffer, pos)
      stop = start + size
      if stop > end:
        raise _DecodeError('Truncated message.')
      # Read sub-message.
      if submessages.add()._InternalParse(buffer, start, stop) != stop:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
      # Predict that the next tag is another copy of the same repeated field.
      pos = stop + tag_len
      if buffer[stop:pos] != tag_bytes or stop == end:
        # Prediction failed.  Return.
        return stop
  return DecodeRepeatedField
732
733
734# --------------------------------------------------------------------
735
# Raw tag bytes for field number 1 with the START_GROUP wire type.
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)
737
def MessageSetItemDecoder(descriptor):
  """Returns a decoder for a MessageSet item.

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }

  Args:
    descriptor: the message Descriptor.  It is not referenced by the decoder
      built here.

  Returns:
    A DecodeItem(buffer, pos, end, message, field_dict) function.
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  # Bind the helpers once so the per-item loop below uses closure lookups.
  local_ReadTag = ReadTag
  local_DecodeVarint = _DecodeVarint
  local_SkipField = SkipField

  def DecodeItem(buffer, pos, end, message, field_dict):
    """Decode serialized message set to its value and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.
      end: int, end position of serialized data
      message: Message object to store unknown fields in
      field_dict: Map[Descriptor, Any] to store decoded values in.

    Returns:
      int, new position in serialized data.

    Raises:
      _DecodeError: if the item runs past `end`, lacks a type_id or a message
        field, contains an end-group tag other than the item's own, or if the
        embedded message stops parsing before its declared end.
    """
    message_set_item_start = pos
    # -1 marks "not seen yet" for each of the two required item fields.
    type_id = -1
    message_start = -1
    message_end = -1

    # Technically, type_id and message can appear in any order, so we need
    # a little loop here.
    while 1:
      (tag_bytes, pos) = local_ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = local_DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        # Only record where the payload lives; it is parsed after the loop,
        # once the type_id is known.
        (size, message_start) = local_DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = local_SkipField(buffer, pos, end, tag_bytes)
        # -1 means the skipped tag was an end-group tag that is not the
        # item's own end tag (that one is matched by the branch above).
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    extension = message.Extensions._FindExtensionByNumber(type_id)
    # pylint: disable=protected-access
    if extension is not None:
      # Known extension: parse the payload into the extension's message,
      # creating it in field_dict on first use.
      value = field_dict.get(extension)
      if value is None:
        message_type = extension.message_type
        if not hasattr(message_type, '_concrete_class'):
          # pylint: disable=protected-access
          message._FACTORY.GetPrototype(message_type)
        value = field_dict.setdefault(
            extension, message_type._concrete_class())
      if value._InternalParse(buffer, message_start, message_end) != message_end:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
    else:
      # Unknown type_id: keep the raw item bytes in _unknown_fields and the
      # payload alone in the structured unknown field set.
      if not message._unknown_fields:
        message._unknown_fields = []
      message._unknown_fields.append(
          (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
      if message._unknown_field_set is None:
        message._unknown_field_set = containers.UnknownFieldSet()
      message._unknown_field_set._add(
          type_id,
          wire_format.WIRETYPE_LENGTH_DELIMITED,
          buffer[message_start:message_end].tobytes())
      # pylint: enable=protected-access

    return pos

  return DecodeItem
833
834
def UnknownMessageSetItemDecoder():
  """Returns a decoder for an unknown MessageSet item.

  Returns:
    A DecodeUnknownItem(buffer) function.
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  def DecodeUnknownItem(buffer):
    """Extracts the type_id and raw message bytes from one MessageSet item.

    Parsing starts at offset 0 of `buffer` and stops at the item's end-group
    tag.

    Args:
      buffer: the serialized item contents; it must support len() and slices
        that provide tobytes() (presumably a memoryview -- confirm against
        callers).

    Returns:
      A (type_id, message_bytes) tuple.

    Raises:
      _DecodeError: if the item runs past the end of `buffer`, lacks a type_id
        or a message field, or contains an end-group tag other than the
        item's own.
    """
    pos = 0
    end = len(buffer)
    # -1 marks "not seen yet".  type_id must be initialised here as well,
    # otherwise an item without a type_id field would hit an unbound local in
    # the check below instead of raising the intended _DecodeError.
    type_id = -1
    message_start = -1
    message_end = -1
    while 1:
      (tag_bytes, pos) = ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = _DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        (size, message_start) = _DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = SkipField(buffer, pos, end, tag_bytes)
        # -1 means the skipped tag was an end-group tag that is not the
        # item's own end tag (that one is matched by the branch above).
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    return (type_id, buffer[message_start:message_end].tobytes())

  return DecodeUnknownItem
872
873# --------------------------------------------------------------------
874
def MapDecoder(field_descriptor, new_default, is_message_map):
  """Builds the decoder used for a map field.

  Args:
    field_descriptor: descriptor of the map field; it also serves as the key
      under which the map container is stored in field_dict.
    new_default: callable taking the parent message and returning a fresh
      container for the field.
    is_message_map: whether decoded values are merged in with CopyFrom()
      rather than assigned.

  Returns:
    A DecodeMap(buffer, pos, end, message, field_dict) function.
  """

  entry_tag = encoder.TagBytes(field_descriptor.number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  entry_tag_len = len(entry_tag)
  read_varint = _DecodeVarint
  # Can't read _concrete_class yet; might not be initialized.
  entry_type = field_descriptor.message_type

  def DecodeMap(buffer, pos, end, message, field_dict):
    # A single scratch entry message is reused for every entry in this run.
    entry = entry_type._concrete_class()
    container = field_dict.get(field_descriptor)
    if container is None:
      container = field_dict.setdefault(field_descriptor, new_default(message))

    while True:
      # Each entry is a length-prefixed sub-message.
      (length, entry_start) = read_varint(buffer, pos)
      entry_end = entry_start + length
      if entry_end > end:
        raise _DecodeError('Truncated message.')

      entry.Clear()
      if entry._InternalParse(buffer, entry_start, entry_end) != entry_end:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')

      if is_message_map:
        container[entry.key].CopyFrom(entry.value)
      else:
        container[entry.key] = entry.value

      # Guess that another entry of this same field follows immediately; if
      # it does, loop again with pos already past its tag.
      pos = entry_end + entry_tag_len
      if entry_end == end or buffer[entry_end:pos] != entry_tag:
        # Guess was wrong (or input is exhausted): hand back the position
        # right after the last entry.
        return entry_end

  return DecodeMap
916
917# --------------------------------------------------------------------
918# Optimization is not as heavy here because calls to SkipField() are rare,
919# except for handling end-group tags.
920
921def _SkipVarint(buffer, pos, end):
922  """Skip a varint value.  Returns the new position."""
923  # Previously ord(buffer[pos]) raised IndexError when pos is out of range.
924  # With this code, ord(b'') raises TypeError.  Both are handled in
925  # python_message.py to generate a 'Truncated message' error.
926  while ord(buffer[pos:pos+1].tobytes()) & 0x80:
927    pos += 1
928  pos += 1
929  if pos > end:
930    raise _DecodeError('Truncated message.')
931  return pos
932
933def _SkipFixed64(buffer, pos, end):
934  """Skip a fixed64 value.  Returns the new position."""
935
936  pos += 8
937  if pos > end:
938    raise _DecodeError('Truncated message.')
939  return pos
940
941
942def _DecodeFixed64(buffer, pos):
943  """Decode a fixed64."""
944  new_pos = pos + 8
945  return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos)
946
947
def _SkipLengthDelimited(buffer, pos, end):
  """Steps over a length-prefixed value and returns the position after it.

  Raises _DecodeError if the value would extend beyond `end`.
  """

  (length, payload_start) = _DecodeVarint(buffer, pos)
  payload_end = payload_start + length
  if payload_end > end:
    raise _DecodeError('Truncated message.')
  return payload_end
956
957
def _SkipGroup(buffer, pos, end):
  """Steps over the contents of a group and returns the position after it.

  Fields are skipped one at a time until SkipField reports an end-group tag;
  the position just past that tag is returned.
  """

  while True:
    (tag, after_tag) = ReadTag(buffer, pos)
    pos = SkipField(buffer, after_tag, end, tag)
    if pos == -1:
      # The tag just read was an end-group tag: the group stops right after it.
      return after_tag
967
968
def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
  """Parses fields into a new UnknownFieldSet.

  Reading stops at the first end-group tag, or once `pos` reaches `end_pos`
  when an `end_pos` is supplied.

  Returns a (UnknownFieldSet, new position) tuple.
  """

  fields = containers.UnknownFieldSet()
  while True:
    if end_pos is not None and pos >= end_pos:
      break
    (tag_bytes, pos) = ReadTag(buffer, pos)
    tag = _DecodeVarint(tag_bytes, 0)[0]
    (field_number, wire_type) = wire_format.UnpackTag(tag)
    if wire_type == wire_format.WIRETYPE_END_GROUP:
      break
    (data, pos) = _DecodeUnknownField(buffer, pos, wire_type)
    # pylint: disable=protected-access
    fields._add(field_number, wire_type, data)

  return (fields, pos)
984
985
def _DecodeUnknownField(buffer, pos, wire_type):
  """Decodes the value of one unknown field according to its wire type.

  Returns a (data, new position) tuple.  For an end-group wire type the result
  is (0, -1).  Raises _DecodeError for any unrecognised wire type.
  """

  if wire_type == wire_format.WIRETYPE_END_GROUP:
    # No value to read; -1 is the position reported for an end-group tag.
    return (0, -1)

  if wire_type == wire_format.WIRETYPE_VARINT:
    (value, next_pos) = _DecodeVarint(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_FIXED64:
    (value, next_pos) = _DecodeFixed64(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_FIXED32:
    (value, next_pos) = _DecodeFixed32(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
    (length, start) = _DecodeVarint(buffer, pos)
    next_pos = start + length
    value = buffer[start:next_pos].tobytes()
  elif wire_type == wire_format.WIRETYPE_START_GROUP:
    # A nested group is decoded into its own UnknownFieldSet.
    (value, next_pos) = _DecodeUnknownFieldSet(buffer, pos)
  else:
    raise _DecodeError('Wrong wire type in tag.')

  return (value, next_pos)
1007
1008
1009def _EndGroup(buffer, pos, end):
1010  """Skipping an END_GROUP tag returns -1 to tell the parent loop to break."""
1011
1012  return -1
1013
1014
1015def _SkipFixed32(buffer, pos, end):
1016  """Skip a fixed32 value.  Returns the new position."""
1017
1018  pos += 4
1019  if pos > end:
1020    raise _DecodeError('Truncated message.')
1021  return pos
1022
1023
1024def _DecodeFixed32(buffer, pos):
1025  """Decode a fixed32."""
1026
1027  new_pos = pos + 4
1028  return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos)
1029
1030
def _RaiseInvalidWireType(buffer, pos, end):
  """Skipper used for wire types that have no valid meaning.

  Never returns; always raises _DecodeError.  The arguments are ignored.
  """

  raise _DecodeError('Tag had invalid wire type.')
1035
def _FieldSkipper():
  """Builds and returns the SkipField function."""

  # Skip routines indexed by wire type value; the last two entries reject the
  # tag with an error.
  skippers = (
      _SkipVarint,
      _SkipFixed64,
      _SkipLengthDelimited,
      _SkipGroup,
      _EndGroup,
      _SkipFixed32,
      _RaiseInvalidWireType,
      _RaiseInvalidWireType,
  )

  type_mask = wire_format.TAG_TYPE_MASK

  def SkipField(buffer, pos, end, tag_bytes):
    """Skips over the value of the field identified by `tag_bytes`.

    |pos| must be the offset of the first byte following the tag.

    Returns:
        The position just past the field's value, or -1 when the tag is an
        end-group tag (the calling loop should then break).
    """

    # Varints are little-endian, so the wire type always sits in the low bits
    # of the first tag byte.
    wire_type = ord(tag_bytes[0:1]) & type_mask
    skip = skippers[wire_type]
    return skip(buffer, pos, end)

  return SkipField
1067
# Module-level skipper, built once at import time; the decoders in this file
# call it to step over fields they do not handle.
SkipField = _FieldSkipper()
1069