# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc.  All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""Code for decoding protocol buffer primitives.

This code is very similar to encoder.py -- read the docs for that module first.

A "decoder" is a function with the signature:
  Decode(buffer, pos, end, message, field_dict)
The arguments are:
  buffer:     The string containing the encoded message.
  pos:        The current position in the string.
  end:        The position in the string where the current message ends.  May be
              less than len(buffer) if we're reading a sub-message.
  message:    The message object into which we're parsing.
  field_dict: message._fields (avoids a hashtable lookup).
The decoder reads the field and stores it into field_dict, returning the new
buffer position.  A decoder for a repeated field may proactively decode all of
the elements of that field, if they appear consecutively.

Note that decoders may throw any of the following:
  IndexError:  Indicates a truncated message.
  struct.error:  Unpacking of a fixed-width field failed.
  message.DecodeError:  Other errors.

Decoders are expected to raise an exception if they are called with pos > end.
This allows callers to be lax about bounds checking:  it's fine to read past
"end" as long as you are sure that someone else will notice and throw an
exception later on.

Something up the call stack is expected to catch IndexError and struct.error
and convert them to message.DecodeError.

Decoders are constructed using decoder constructors with the signature:
  MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
The arguments are:
  field_number:  The field number of the field we want to decode.
  is_repeated:   Is the field a repeated field? (bool)
  is_packed:     Is the field a packed field? (bool)
  key:           The key to use when looking up the field within field_dict.
                 (This is actually the FieldDescriptor but nothing in this
                 file should depend on that.)
  new_default:   A function which takes a message object as a parameter and
                 returns a new instance of the default value for this field.
                 (This is called for repeated fields and sub-messages, when an
                 instance does not already exist.)

As with encoders, we define a decoder constructor for every type of field.
Then, for every field of every message class we construct an actual decoder.
That decoder goes into a dict indexed by tag, so when we decode a message
we repeatedly read a tag, look up the corresponding decoder, and invoke it.
"""

__author__ = '[email protected] (Kenton Varda)'

import math
import struct

from google.protobuf.internal import containers
from google.protobuf.internal import encoder
from google.protobuf.internal import wire_format
from google.protobuf import message


# This is not for optimization, but rather to avoid conflicts with local
# variables named "message".
_DecodeError = message.DecodeError


def _VarintDecoder(mask, result_type):
  """Return a decoder for a basic varint value (does not include tag).

  Decoded values will be bitwise-anded with the given mask before being
  returned, e.g. to limit them to 32 bits.  The returned decoder does not
  take the usual "end" parameter -- the caller is expected to do bounds checking
  after the fact (often the caller can defer such checking until later).  The
  decoder returns a (value, new_pos) pair.

  Args:
    mask: Integer bit mask and-ed onto the decoded value.
    result_type: Callable applied to the masked value before it is returned.

  Raises (from the returned decoder):
    _DecodeError: if the continuation bit is still set once 64 bits of
      payload have been consumed.
  """

  def DecodeVarint(buffer, pos):
    result = 0
    shift = 0
    while True:
      b = buffer[pos]
      # The low 7 bits carry payload; the high bit flags "more bytes follow".
      result |= ((b & 0x7f) << shift)
      pos += 1
      if not (b & 0x80):
        result &= mask
        result = result_type(result)
        return (result, pos)
      shift += 7
      if shift >= 64:
        raise _DecodeError('Too many bytes when decoding varint.')
  return DecodeVarint


def _SignedVarintDecoder(bits, result_type):
  """Like _VarintDecoder() but decodes signed values.

  Args:
    bits: Width of the two's-complement value, in bits.
    result_type: Callable applied to the sign-extended value before it is
      returned.
  """

  signbit = 1 << (bits - 1)
  mask = (1 << bits) - 1

  def DecodeVarint(buffer, pos):
    result = 0
    shift = 0
    while True:
      b = buffer[pos]
      result |= ((b & 0x7f) << shift)
      pos += 1
      if not (b & 0x80):
        result &= mask
        # Sign-extend: values with the top bit set become negative.
        result = (result ^ signbit) - signbit
        result = result_type(result)
        return (result, pos)
      shift += 7
      if shift >= 64:
        raise _DecodeError('Too many bytes when decoding varint.')
  return DecodeVarint

# All 32-bit and 64-bit values are represented as int.
_DecodeVarint = _VarintDecoder((1 << 64) - 1, int)
_DecodeSignedVarint = _SignedVarintDecoder(64, int)

# Use these versions for values which must be limited to 32 bits.
_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int)
_DecodeSignedVarint32 = _SignedVarintDecoder(32, int)


def ReadTag(buffer, pos):
  """Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple.

  We return the raw bytes of the tag rather than decoding them.  The raw
  bytes can then be used to look up the proper decoder.  This effectively allows
  us to trade some work that would be done in pure-python (decoding a varint)
  for work that is done in C (searching for a byte string in a hash table).
  In a low-level language it would be much cheaper to decode the varint and
  use that, but not in Python.

  Args:
    buffer: memoryview object of the encoded bytes
    pos: int of the current position to start from

  Returns:
    Tuple[bytes, int] of the tag data and new position.
  """
  start = pos
  # Advance over every byte that has its continuation bit set ...
  while buffer[pos] & 0x80:
    pos += 1
  # ... and over the final byte of the varint.
  pos += 1

  tag_bytes = buffer[start:pos].tobytes()
  return tag_bytes, pos


# --------------------------------------------------------------------


def _SimpleDecoder(wire_type, decode_value):
  """Return a constructor for a decoder for fields of a particular type.

  Args:
      wire_type:  The field's wire type.
      decode_value:  A function which decodes an individual value, e.g.
        _DecodeVarint()
  """

  def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
                      clear_if_default=False):
    if is_packed:
      local_DecodeVarint = _DecodeVarint
      def DecodePackedField(buffer, pos, end, message, field_dict):
        value = field_dict.get(key)
        if value is None:
          value = field_dict.setdefault(key, new_default(message))
        # A packed field is a length prefix followed by back-to-back values.
        (endpoint, pos) = local_DecodeVarint(buffer, pos)
        endpoint += pos
        if endpoint > end:
          raise _DecodeError('Truncated message.')
        while pos < endpoint:
          (element, pos) = decode_value(buffer, pos)
          value.append(element)
        if pos > endpoint:
          del value[-1]   # Discard corrupt value.
          raise _DecodeError('Packed element was truncated.')
        return pos
      return DecodePackedField
    elif is_repeated:
      tag_bytes = encoder.TagBytes(field_number, wire_type)
      tag_len = len(tag_bytes)
      def DecodeRepeatedField(buffer, pos, end, message, field_dict):
        value = field_dict.get(key)
        if value is None:
          value = field_dict.setdefault(key, new_default(message))
        while True:
          (element, new_pos) = decode_value(buffer, pos)
          value.append(element)
          # Predict that the next tag is another copy of the same repeated
          # field.
          pos = new_pos + tag_len
          if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
            # Prediction failed.  Return.
            if new_pos > end:
              raise _DecodeError('Truncated message.')
            return new_pos
      return DecodeRepeatedField
    else:
      def DecodeField(buffer, pos, end, message, field_dict):
        (new_value, pos) = decode_value(buffer, pos)
        if pos > end:
          raise _DecodeError('Truncated message.')
        if clear_if_default and not new_value:
          # A falsy value is dropped from field_dict rather than stored.
          field_dict.pop(key, None)
        else:
          field_dict[key] = new_value
        return pos
      return DecodeField

  return SpecificDecoder


def _ModifiedDecoder(wire_type, decode_value, modify_value):
  """Like SimpleDecoder but additionally invokes modify_value on every value
  before storing it.  Usually modify_value is ZigZagDecode.
  """

  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
  # not enough to make a significant difference.

  def InnerDecode(buffer, pos):
    (result, new_pos) = decode_value(buffer, pos)
    return (modify_value(result), new_pos)
  return _SimpleDecoder(wire_type, InnerDecode)


def _StructPackDecoder(wire_type, format):
  """Return a constructor for a decoder for a fixed-width field.

  Args:
      wire_type:  The field's wire type.
      format:  The format string to pass to struct.unpack().
  """

  value_size = struct.calcsize(format)
  local_unpack = struct.unpack

  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
  # not enough to make a significant difference.

  # Note that we expect someone up-stack to catch struct.error and convert
  # it to _DecodeError -- this way we don't have to set up exception-
  # handling blocks every time we parse one value.

  def InnerDecode(buffer, pos):
    new_pos = pos + value_size
    result = local_unpack(format, buffer[pos:new_pos])[0]
    return (result, new_pos)
  return _SimpleDecoder(wire_type, InnerDecode)


def _FloatDecoder():
  """Returns a decoder for a float field.

  This code works around a bug in struct.unpack for non-finite 32-bit
  floating-point values.
  """

  local_unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized float to a float and new position.

    Args:
      buffer: memoryview of the serialized bytes
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the deserialized float value and new position
      in the serialized data.
    """
    # We expect a 32-bit value in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand.
    new_pos = pos + 4
    float_bytes = buffer[pos:new_pos].tobytes()

    # If this value has all its exponent bits set, then it's non-finite.
    # In Python 2.4, struct.unpack will convert it to a finite 64-bit value.
    # To avoid that, we parse it specially.
    if (float_bytes[3:4] in b'\x7F\xFF' and float_bytes[2:3] >= b'\x80'):
      # If at least one significand bit is set...
      if float_bytes[0:3] != b'\x00\x00\x80':
        return (math.nan, new_pos)
      # If sign bit is set...
      if float_bytes[3:4] == b'\xFF':
        return (-math.inf, new_pos)
      return (math.inf, new_pos)

    # Note that we expect someone up-stack to catch struct.error and convert
    # it to _DecodeError -- this way we don't have to set up exception-
    # handling blocks every time we parse one value.
    result = local_unpack('<f', float_bytes)[0]
    return (result, new_pos)
  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)


def _DoubleDecoder():
  """Returns a decoder for a double field.

  This code works around a bug in struct.unpack for not-a-number.
  """

  local_unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized double to a double and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the decoded double value and new position
      in the serialized data.
    """
    # We expect a 64-bit value in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-12 represent the exponent, and bits 13-64 are the significand.
    new_pos = pos + 8
    double_bytes = buffer[pos:new_pos].tobytes()

    # If this value has all its exponent bits set and at least one significand
    # bit set, it's not a number.  In Python 2.4, struct.unpack will treat it
    # as inf or -inf.  To avoid that, we treat it specially.
    if ((double_bytes[7:8] in b'\x7F\xFF')
        and (double_bytes[6:7] >= b'\xF0')
        and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')):
      return (math.nan, new_pos)

    # Note that we expect someone up-stack to catch struct.error and convert
    # it to _DecodeError -- this way we don't have to set up exception-
    # handling blocks every time we parse one value.
    result = local_unpack('<d', double_bytes)[0]
    return (result, new_pos)
  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)


def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
                clear_if_default=False):
  """Returns a decoder for enum field.

  Values that are not present in key.enum_type.values_by_number are not stored
  in the field; they are recorded in the message's unknown-field storage
  instead.
  """
  enum_type = key.enum_type
  # The tag only depends on the field number, so build it once here instead of
  # once per unrecognized value.
  tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
  if is_packed:
    local_DecodeVarint = _DecodeVarint
    def DecodePackedField(buffer, pos, end, message, field_dict):
      """Decode serialized packed enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      (endpoint, pos) = local_DecodeVarint(buffer, pos)
      endpoint += pos
      if endpoint > end:
        raise _DecodeError('Truncated message.')
      while pos < endpoint:
        value_start_pos = pos
        (element, pos) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []

          message._unknown_fields.append(
              (tag_bytes, buffer[value_start_pos:pos].tobytes()))
          if message._unknown_field_set is None:
            message._unknown_field_set = containers.UnknownFieldSet()
          message._unknown_field_set._add(
              field_number, wire_format.WIRETYPE_VARINT, element)
        # pylint: enable=protected-access
      if pos > endpoint:
        # The last element ran past the declared length: undo whichever store
        # it went into before reporting the error.
        if element in enum_type.values_by_number:
          del value[-1]   # Discard corrupt value.
        else:
          del message._unknown_fields[-1]
          # pylint: disable=protected-access
          del message._unknown_field_set._values[-1]
          # pylint: enable=protected-access
        raise _DecodeError('Packed element was truncated.')
      return pos
    return DecodePackedField
  elif is_repeated:
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      """Decode serialized repeated enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while True:
        (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[pos:new_pos].tobytes()))
          if message._unknown_field_set is None:
            message._unknown_field_set = containers.UnknownFieldSet()
          message._unknown_field_set._add(
              field_number, wire_format.WIRETYPE_VARINT, element)
        # pylint: enable=protected-access
        # Predict that the next tag is another copy of the same repeated
        # field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
          # Prediction failed.  Return.
          if new_pos > end:
            raise _DecodeError('Truncated message.')
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      """Decode serialized singular enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value_start_pos = pos
      (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
      if pos > end:
        raise _DecodeError('Truncated message.')
      if clear_if_default and not enum_value:
        field_dict.pop(key, None)
        return pos
      # pylint: disable=protected-access
      if enum_value in enum_type.values_by_number:
        field_dict[key] = enum_value
      else:
        if not message._unknown_fields:
          message._unknown_fields = []
        message._unknown_fields.append(
            (tag_bytes, buffer[value_start_pos:pos].tobytes()))
        if message._unknown_field_set is None:
          message._unknown_field_set = containers.UnknownFieldSet()
        message._unknown_field_set._add(
            field_number, wire_format.WIRETYPE_VARINT, enum_value)
      # pylint: enable=protected-access
      return pos
    return DecodeField


# --------------------------------------------------------------------


Int32Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)

Int64Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint)

UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32)
UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint)

SInt32Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode)

# Note that Python conveniently guarantees that when using the '<' prefix on
# formats, they will also have the same size across all platforms (as opposed
# to without the prefix, where their sizes depend on the C compiler's basic
# type sizes).
532Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I') 533Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q') 534SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i') 535SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q') 536FloatDecoder = _FloatDecoder() 537DoubleDecoder = _DoubleDecoder() 538 539BoolDecoder = _ModifiedDecoder( 540 wire_format.WIRETYPE_VARINT, _DecodeVarint, bool) 541 542 543def StringDecoder(field_number, is_repeated, is_packed, key, new_default, 544 clear_if_default=False): 545 """Returns a decoder for a string field.""" 546 547 local_DecodeVarint = _DecodeVarint 548 549 def _ConvertToUnicode(memview): 550 """Convert byte to unicode.""" 551 byte_str = memview.tobytes() 552 try: 553 value = str(byte_str, 'utf-8') 554 except UnicodeDecodeError as e: 555 # add more information to the error message and re-raise it. 556 e.reason = '%s in field: %s' % (e, key.full_name) 557 raise 558 559 return value 560 561 assert not is_packed 562 if is_repeated: 563 tag_bytes = encoder.TagBytes(field_number, 564 wire_format.WIRETYPE_LENGTH_DELIMITED) 565 tag_len = len(tag_bytes) 566 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 567 value = field_dict.get(key) 568 if value is None: 569 value = field_dict.setdefault(key, new_default(message)) 570 while 1: 571 (size, pos) = local_DecodeVarint(buffer, pos) 572 new_pos = pos + size 573 if new_pos > end: 574 raise _DecodeError('Truncated string.') 575 value.append(_ConvertToUnicode(buffer[pos:new_pos])) 576 # Predict that the next tag is another copy of the same repeated field. 577 pos = new_pos + tag_len 578 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 579 # Prediction failed. Return. 
580 return new_pos 581 return DecodeRepeatedField 582 else: 583 def DecodeField(buffer, pos, end, message, field_dict): 584 (size, pos) = local_DecodeVarint(buffer, pos) 585 new_pos = pos + size 586 if new_pos > end: 587 raise _DecodeError('Truncated string.') 588 if clear_if_default and not size: 589 field_dict.pop(key, None) 590 else: 591 field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos]) 592 return new_pos 593 return DecodeField 594 595 596def BytesDecoder(field_number, is_repeated, is_packed, key, new_default, 597 clear_if_default=False): 598 """Returns a decoder for a bytes field.""" 599 600 local_DecodeVarint = _DecodeVarint 601 602 assert not is_packed 603 if is_repeated: 604 tag_bytes = encoder.TagBytes(field_number, 605 wire_format.WIRETYPE_LENGTH_DELIMITED) 606 tag_len = len(tag_bytes) 607 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 608 value = field_dict.get(key) 609 if value is None: 610 value = field_dict.setdefault(key, new_default(message)) 611 while 1: 612 (size, pos) = local_DecodeVarint(buffer, pos) 613 new_pos = pos + size 614 if new_pos > end: 615 raise _DecodeError('Truncated string.') 616 value.append(buffer[pos:new_pos].tobytes()) 617 # Predict that the next tag is another copy of the same repeated field. 618 pos = new_pos + tag_len 619 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 620 # Prediction failed. Return. 
621 return new_pos 622 return DecodeRepeatedField 623 else: 624 def DecodeField(buffer, pos, end, message, field_dict): 625 (size, pos) = local_DecodeVarint(buffer, pos) 626 new_pos = pos + size 627 if new_pos > end: 628 raise _DecodeError('Truncated string.') 629 if clear_if_default and not size: 630 field_dict.pop(key, None) 631 else: 632 field_dict[key] = buffer[pos:new_pos].tobytes() 633 return new_pos 634 return DecodeField 635 636 637def GroupDecoder(field_number, is_repeated, is_packed, key, new_default): 638 """Returns a decoder for a group field.""" 639 640 end_tag_bytes = encoder.TagBytes(field_number, 641 wire_format.WIRETYPE_END_GROUP) 642 end_tag_len = len(end_tag_bytes) 643 644 assert not is_packed 645 if is_repeated: 646 tag_bytes = encoder.TagBytes(field_number, 647 wire_format.WIRETYPE_START_GROUP) 648 tag_len = len(tag_bytes) 649 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 650 value = field_dict.get(key) 651 if value is None: 652 value = field_dict.setdefault(key, new_default(message)) 653 while 1: 654 value = field_dict.get(key) 655 if value is None: 656 value = field_dict.setdefault(key, new_default(message)) 657 # Read sub-message. 658 pos = value.add()._InternalParse(buffer, pos, end) 659 # Read end tag. 660 new_pos = pos+end_tag_len 661 if buffer[pos:new_pos] != end_tag_bytes or new_pos > end: 662 raise _DecodeError('Missing group end tag.') 663 # Predict that the next tag is another copy of the same repeated field. 664 pos = new_pos + tag_len 665 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 666 # Prediction failed. Return. 667 return new_pos 668 return DecodeRepeatedField 669 else: 670 def DecodeField(buffer, pos, end, message, field_dict): 671 value = field_dict.get(key) 672 if value is None: 673 value = field_dict.setdefault(key, new_default(message)) 674 # Read sub-message. 675 pos = value._InternalParse(buffer, pos, end) 676 # Read end tag. 
677 new_pos = pos+end_tag_len 678 if buffer[pos:new_pos] != end_tag_bytes or new_pos > end: 679 raise _DecodeError('Missing group end tag.') 680 return new_pos 681 return DecodeField 682 683 684def MessageDecoder(field_number, is_repeated, is_packed, key, new_default): 685 """Returns a decoder for a message field.""" 686 687 local_DecodeVarint = _DecodeVarint 688 689 assert not is_packed 690 if is_repeated: 691 tag_bytes = encoder.TagBytes(field_number, 692 wire_format.WIRETYPE_LENGTH_DELIMITED) 693 tag_len = len(tag_bytes) 694 def DecodeRepeatedField(buffer, pos, end, message, field_dict): 695 value = field_dict.get(key) 696 if value is None: 697 value = field_dict.setdefault(key, new_default(message)) 698 while 1: 699 # Read length. 700 (size, pos) = local_DecodeVarint(buffer, pos) 701 new_pos = pos + size 702 if new_pos > end: 703 raise _DecodeError('Truncated message.') 704 # Read sub-message. 705 if value.add()._InternalParse(buffer, pos, new_pos) != new_pos: 706 # The only reason _InternalParse would return early is if it 707 # encountered an end-group tag. 708 raise _DecodeError('Unexpected end-group tag.') 709 # Predict that the next tag is another copy of the same repeated field. 710 pos = new_pos + tag_len 711 if buffer[new_pos:pos] != tag_bytes or new_pos == end: 712 # Prediction failed. Return. 713 return new_pos 714 return DecodeRepeatedField 715 else: 716 def DecodeField(buffer, pos, end, message, field_dict): 717 value = field_dict.get(key) 718 if value is None: 719 value = field_dict.setdefault(key, new_default(message)) 720 # Read length. 721 (size, pos) = local_DecodeVarint(buffer, pos) 722 new_pos = pos + size 723 if new_pos > end: 724 raise _DecodeError('Truncated message.') 725 # Read sub-message. 726 if value._InternalParse(buffer, pos, new_pos) != new_pos: 727 # The only reason _InternalParse would return early is if it encountered 728 # an end-group tag. 
729 raise _DecodeError('Unexpected end-group tag.') 730 return new_pos 731 return DecodeField 732 733 734# -------------------------------------------------------------------- 735 736MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP) 737 738def MessageSetItemDecoder(descriptor): 739 """Returns a decoder for a MessageSet item. 740 741 The parameter is the message Descriptor. 742 743 The message set message looks like this: 744 message MessageSet { 745 repeated group Item = 1 { 746 required int32 type_id = 2; 747 required string message = 3; 748 } 749 } 750 """ 751 752 type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT) 753 message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED) 754 item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP) 755 756 local_ReadTag = ReadTag 757 local_DecodeVarint = _DecodeVarint 758 local_SkipField = SkipField 759 760 def DecodeItem(buffer, pos, end, message, field_dict): 761 """Decode serialized message set to its value and new position. 762 763 Args: 764 buffer: memoryview of the serialized bytes. 765 pos: int, position in the memory view to start at. 766 end: int, end position of serialized data 767 message: Message object to store unknown fields in 768 field_dict: Map[Descriptor, Any] to store decoded values in. 769 770 Returns: 771 int, new position in serialized data. 772 """ 773 message_set_item_start = pos 774 type_id = -1 775 message_start = -1 776 message_end = -1 777 778 # Technically, type_id and message can appear in any order, so we need 779 # a little loop here. 
780 while 1: 781 (tag_bytes, pos) = local_ReadTag(buffer, pos) 782 if tag_bytes == type_id_tag_bytes: 783 (type_id, pos) = local_DecodeVarint(buffer, pos) 784 elif tag_bytes == message_tag_bytes: 785 (size, message_start) = local_DecodeVarint(buffer, pos) 786 pos = message_end = message_start + size 787 elif tag_bytes == item_end_tag_bytes: 788 break 789 else: 790 pos = SkipField(buffer, pos, end, tag_bytes) 791 if pos == -1: 792 raise _DecodeError('Missing group end tag.') 793 794 if pos > end: 795 raise _DecodeError('Truncated message.') 796 797 if type_id == -1: 798 raise _DecodeError('MessageSet item missing type_id.') 799 if message_start == -1: 800 raise _DecodeError('MessageSet item missing message.') 801 802 extension = message.Extensions._FindExtensionByNumber(type_id) 803 # pylint: disable=protected-access 804 if extension is not None: 805 value = field_dict.get(extension) 806 if value is None: 807 message_type = extension.message_type 808 if not hasattr(message_type, '_concrete_class'): 809 # pylint: disable=protected-access 810 message._FACTORY.GetPrototype(message_type) 811 value = field_dict.setdefault( 812 extension, message_type._concrete_class()) 813 if value._InternalParse(buffer, message_start,message_end) != message_end: 814 # The only reason _InternalParse would return early is if it encountered 815 # an end-group tag. 
816 raise _DecodeError('Unexpected end-group tag.') 817 else: 818 if not message._unknown_fields: 819 message._unknown_fields = [] 820 message._unknown_fields.append( 821 (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes())) 822 if message._unknown_field_set is None: 823 message._unknown_field_set = containers.UnknownFieldSet() 824 message._unknown_field_set._add( 825 type_id, 826 wire_format.WIRETYPE_LENGTH_DELIMITED, 827 buffer[message_start:message_end].tobytes()) 828 # pylint: enable=protected-access 829 830 return pos 831 832 return DecodeItem 833 834 835def UnknownMessageSetItemDecoder(): 836 """Returns a decoder for a Unknown MessageSet item.""" 837 838 type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT) 839 message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED) 840 item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP) 841 842 def DecodeUnknownItem(buffer): 843 pos = 0 844 end = len(buffer) 845 message_start = -1 846 message_end = -1 847 while 1: 848 (tag_bytes, pos) = ReadTag(buffer, pos) 849 if tag_bytes == type_id_tag_bytes: 850 (type_id, pos) = _DecodeVarint(buffer, pos) 851 elif tag_bytes == message_tag_bytes: 852 (size, message_start) = _DecodeVarint(buffer, pos) 853 pos = message_end = message_start + size 854 elif tag_bytes == item_end_tag_bytes: 855 break 856 else: 857 pos = SkipField(buffer, pos, end, tag_bytes) 858 if pos == -1: 859 raise _DecodeError('Missing group end tag.') 860 861 if pos > end: 862 raise _DecodeError('Truncated message.') 863 864 if type_id == -1: 865 raise _DecodeError('MessageSet item missing type_id.') 866 if message_start == -1: 867 raise _DecodeError('MessageSet item missing message.') 868 869 return (type_id, buffer[message_start:message_end].tobytes()) 870 871 return DecodeUnknownItem 872 873# -------------------------------------------------------------------- 874 875def MapDecoder(field_descriptor, new_default, is_message_map): 876 """Returns 
def MapDecoder(field_descriptor, new_default, is_message_map):
  """Returns a decoder for a map field.

  Args:
    field_descriptor: Descriptor of the map field.  It is used as the key
      into field_dict, and its number and message_type are read.
    new_default: Called with the message to create the map container when
      field_dict has no entry for this field yet.
    is_message_map: If true, values are merged into the container with
      CopyFrom(); otherwise they are assigned directly.
  """

  key = field_descriptor
  entry_tag = encoder.TagBytes(field_descriptor.number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  entry_tag_len = len(entry_tag)
  decode_varint = _DecodeVarint
  # Can't read _concrete_class yet; might not be initialized.
  entry_descriptor = field_descriptor.message_type

  def DecodeMap(buffer, pos, end, message, field_dict):
    """Decodes consecutive entries of this map; returns the new position."""
    # One scratch entry message is reused (cleared) for every entry parsed.
    entry = entry_descriptor._concrete_class()
    container = field_dict.get(key)
    if container is None:
      container = field_dict.setdefault(key, new_default(message))

    while True:
      # Each entry is a length-prefixed sub-message.
      (entry_size, pos) = decode_varint(buffer, pos)
      entry_end = pos + entry_size
      if entry_end > end:
        raise _DecodeError('Truncated message.')

      entry.Clear()
      parsed_to = entry._InternalParse(buffer, pos, entry_end)
      if parsed_to != entry_end:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')

      if is_message_map:
        container[entry.key].CopyFrom(entry.value)
      else:
        container[entry.key] = entry.value

      # Predict that the next tag is another entry of the same map field;
      # if so, keep looping with pos already advanced past that tag.
      pos = entry_end + entry_tag_len
      next_is_same_field = buffer[entry_end:pos] == entry_tag
      if not next_is_same_field or entry_end == end:
        # Prediction failed.  Return the position right after this entry.
        return entry_end

  return DecodeMap
def _SkipVarint(buffer, pos, end):
  """Skip a varint value.  Returns the new position.

  Raises _DecodeError if the varint ends beyond |end|.
  """
  # Previously ord(buffer[pos]) raised IndexError when pos is out of range.
  # With this code, ord(b'') raises TypeError.  Both are handled in
  # python_message.py to generate a 'Truncated message' error.
  while True:
    current = ord(buffer[pos:pos + 1].tobytes())
    pos += 1
    # A clear high bit marks the last byte of the varint.
    if not current & 0x80:
      break
  if pos > end:
    raise _DecodeError('Truncated message.')
  return pos


def _SkipFixed64(buffer, pos, end):
  """Skip a fixed64 value.  Returns the new position.

  Raises _DecodeError if the eight bytes extend beyond |end|.
  """

  after_value = pos + 8
  if after_value > end:
    raise _DecodeError('Truncated message.')
  return after_value


def _DecodeFixed64(buffer, pos):
  """Decode a fixed64.

  Returns a (value, new_position) tuple; the value is read as a
  little-endian unsigned 64-bit integer.
  """
  after_value = pos + 8
  (value,) = struct.unpack('<Q', buffer[pos:after_value])
  return (value, after_value)


def _SkipLengthDelimited(buffer, pos, end):
  """Skip a length-delimited value.  Returns the new position.

  Raises _DecodeError if the payload extends beyond |end|.
  """

  (size, payload_start) = _DecodeVarint(buffer, pos)
  payload_end = payload_start + size
  if payload_end > end:
    raise _DecodeError('Truncated message.')
  return payload_end


def _SkipGroup(buffer, pos, end):
  """Skip sub-group.  Returns the new position.

  Fields are skipped one by one until SkipField() reports an end-group tag
  by returning -1; the position just after that tag is returned.
  """

  while True:
    (tag_bytes, after_tag) = ReadTag(buffer, pos)
    pos = SkipField(buffer, after_tag, end, tag_bytes)
    if pos == -1:
      return after_tag


def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
  """Decode UnknownFieldSet.  Returns the UnknownFieldSet and new position.

  Fields are read until an end-group tag is found or, when |end_pos| is
  given, until |pos| reaches it.  When an end-group tag stops the loop, the
  returned position is the one just after that tag.
  """

  fields = containers.UnknownFieldSet()
  while True:
    if end_pos is not None and pos >= end_pos:
      break
    (tag_bytes, pos) = ReadTag(buffer, pos)
    tag = _DecodeVarint(tag_bytes, 0)[0]
    field_number, wire_type = wire_format.UnpackTag(tag)
    if wire_type == wire_format.WIRETYPE_END_GROUP:
      break
    (data, pos) = _DecodeUnknownField(buffer, pos, wire_type)
    # pylint: disable=protected-access
    fields._add(field_number, wire_type, data)

  return (fields, pos)
def _DecodeUnknownField(buffer, pos, wire_type):
  """Decode a unknown field.  Returns the UnknownField and new position.

  The shape of the returned data depends on |wire_type|: an integer for
  varint/fixed64/fixed32, bytes for length-delimited, and the result of
  _DecodeUnknownFieldSet() for a group.  An end-group wire type yields
  (0, -1).  Any other wire type raises _DecodeError.
  """

  if wire_type == wire_format.WIRETYPE_VARINT:
    (value, next_pos) = _DecodeVarint(buffer, pos)
    return (value, next_pos)

  if wire_type == wire_format.WIRETYPE_FIXED64:
    (value, next_pos) = _DecodeFixed64(buffer, pos)
    return (value, next_pos)

  if wire_type == wire_format.WIRETYPE_FIXED32:
    (value, next_pos) = _DecodeFixed32(buffer, pos)
    return (value, next_pos)

  if wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
    (size, payload_start) = _DecodeVarint(buffer, pos)
    payload_end = payload_start + size
    return (buffer[payload_start:payload_end].tobytes(), payload_end)

  if wire_type == wire_format.WIRETYPE_START_GROUP:
    (group_fields, next_pos) = _DecodeUnknownFieldSet(buffer, pos)
    return (group_fields, next_pos)

  if wire_type == wire_format.WIRETYPE_END_GROUP:
    # -1 as the position signals the end of the enclosing group.
    return (0, -1)

  raise _DecodeError('Wrong wire type in tag.')


def _EndGroup(buffer, pos, end):
  """Skipper for END_GROUP tags.

  Always returns -1, which tells the parent loop to break; the arguments
  are unused.
  """

  return -1


def _SkipFixed32(buffer, pos, end):
  """Skip a fixed32 value.  Returns the new position.

  Raises _DecodeError if the four bytes extend beyond |end|.
  """

  after_value = pos + 4
  if after_value > end:
    raise _DecodeError('Truncated message.')
  return after_value


def _DecodeFixed32(buffer, pos):
  """Decode a fixed32.

  Returns a (value, new_position) tuple; the value is read as a
  little-endian unsigned 32-bit integer.
  """

  after_value = pos + 4
  (value,) = struct.unpack('<I', buffer[pos:after_value])
  return (value, after_value)


def _RaiseInvalidWireType(buffer, pos, end):
  """Skip function for unknown wire types.  Raises an exception."""

  raise _DecodeError('Tag had invalid wire type.')
def _FieldSkipper():
  """Constructs the SkipField function."""

  # Indexed by wire type (the tag masked with TAG_TYPE_MASK).  The last two
  # slots reject the wire type outright.
  skippers_by_wire_type = (
      _SkipVarint,
      _SkipFixed64,
      _SkipLengthDelimited,
      _SkipGroup,
      _EndGroup,
      _SkipFixed32,
      _RaiseInvalidWireType,
      _RaiseInvalidWireType,
  )

  type_mask = wire_format.TAG_TYPE_MASK

  def SkipField(buffer, pos, end, tag_bytes):
    """Skips a field with the specified tag.

    |pos| should point to the byte immediately after the tag.

    Returns:
        The new position (after the tag value), or -1 if the tag is an end-group
        tag (in which case the calling loop should break).
    """

    # The wire type is always in the first byte since varints are little-endian.
    first_tag_byte = ord(tag_bytes[0:1])
    skipper = skippers_by_wire_type[first_tag_byte & type_mask]
    return skipper(buffer, pos, end)

  return SkipField


SkipField = _FieldSkipper()