# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc.  All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# This code is meant to work on Python 2.4 and above only.
# NOTE(review): the remark above looks stale -- this module calls
# BaseException.with_traceback(), which only exists on Python 3.
#
# TODO(robinson): Helpers for verbose, common checks like seeing if a
# descriptor's cpp_type is CPPTYPE_MESSAGE.

"""Contains a metaclass and helper functions used to create
protocol message classes from Descriptor objects at runtime.

Recall that a metaclass is the "type" of a class.
(A class is to a metaclass what an instance is to a class.)

In this case, we use the GeneratedProtocolMessageType metaclass
to inject all the useful functionality into the classes
output by the protocol compiler at compile-time.

The upshot of all this is that the real implementation
details for ALL pure-Python protocol buffers are *here in
this file*.
"""

__author__ = '[email protected] (Will Robinson)'

from io import BytesIO
import struct
import sys
import weakref

# We use "as" to avoid name collisions with variables.
from google.protobuf.internal import api_implementation
from google.protobuf.internal import containers
from google.protobuf.internal import decoder
from google.protobuf.internal import encoder
from google.protobuf.internal import enum_type_wrapper
from google.protobuf.internal import extension_dict
from google.protobuf.internal import message_listener as message_listener_mod
from google.protobuf.internal import type_checkers
from google.protobuf.internal import well_known_types
from google.protobuf.internal import wire_format
from google.protobuf import descriptor as descriptor_mod
from google.protobuf import message as message_mod
from google.protobuf import text_format

_FieldDescriptor = descriptor_mod.FieldDescriptor
_AnyFullTypeName = 'google.protobuf.Any'
_ExtensionDict = extension_dict._ExtensionDict


class GeneratedProtocolMessageType(type):

  """Metaclass for protocol message classes created at runtime from Descriptors.

  We add implementations for all methods described in the Message class.  We
  also create properties to allow getting/setting all fields in the protocol
  message.  Finally, we create slots to prevent users from accidentally
  "setting" nonexistent fields in the protocol message, which then wouldn't get
  serialized / deserialized properly.

  The protocol compiler currently uses this metaclass to create protocol
  message classes at runtime.  Clients can also manually create their own
  classes at runtime, as in this example:

  mydescriptor = Descriptor(.....)
  factory = symbol_database.Default()
  factory.pool.AddDescriptor(mydescriptor)
  MyProtoClass = factory.GetPrototype(mydescriptor)
  myproto_instance = MyProtoClass()
  myproto_instance.foo_field = 23
  ...
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Custom allocation for runtime-generated class types.

    We override __new__ because this is apparently the only place
    where we can meaningfully set __slots__ on the class we're creating(?).
    (The interplay between metaclasses and slots is not very well-documented).

    Args:
      name: Name of the class (ignored, but required by the
        metaclass protocol).
      bases: Base classes of the class we're constructing.
        (Should be message.Message).  We ignore this field, but
        it's required by the metaclass protocol
      dictionary: The class dictionary of the class we're
        constructing.  dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.

    Returns:
      Newly-allocated class, or the class already registered on the
      descriptor as `_concrete_class` when one exists.

    Raises:
      RuntimeError: Generated code only work with python cpp extension.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    if isinstance(descriptor, str):
      raise RuntimeError('The generated code only work with python cpp '
                         'extension, but it is using pure python runtime.')

    # If a concrete class already exists for this descriptor, don't try to
    # create another.  Doing so will break any messages that already exist with
    # the existing class.
    #
    # The C++ implementation appears to have its own internal `PyMessageFactory`
    # to achieve similar results.
    #
    # This most commonly happens in `text_format.py` when using descriptors from
    # a custom pool; it calls symbol_database.Global().getPrototype() on a
    # descriptor which already has an existing concrete class.
    new_class = getattr(descriptor, '_concrete_class', None)
    if new_class:
      return new_class

    if descriptor.full_name in well_known_types.WKTBASES:
      bases += (well_known_types.WKTBASES[descriptor.full_name],)
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    _AddSlots(descriptor, dictionary)

    return super().__new__(cls, name, bases, dictionary)

  def __init__(cls, name, bases, dictionary):
    """Here we perform the majority of our work on the class.
    We add enum getters, an __init__ method, implementations
    of all Message methods, and properties for all fields
    in the protocol type.

    Args:
      name: Name of the class (ignored, but required by the
        metaclass protocol).
      bases: Base classes of the class we're constructing.
        (Should be message.Message).  We ignore this field, but
        it's required by the metaclass protocol
      dictionary: The class dictionary of the class we're
        constructing.  dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # If this is an _existing_ class looked up via `_concrete_class` in the
    # __new__ method above, then we don't need to re-initialize anything.
    existing_class = getattr(descriptor, '_concrete_class', None)
    if existing_class:
      assert existing_class is cls, (
          'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
          % (descriptor.full_name))
      return

    cls._decoders_by_tag = {}
    if (descriptor.has_options and
        descriptor.GetOptions().message_set_wire_format):
      cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          decoder.MessageSetItemDecoder(descriptor), None)

    # Attach stuff to each FieldDescriptor for quick lookup later on.
    for field in descriptor.fields:
      _AttachFieldHelpers(cls, field)

    descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddPropertiesForExtensions(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(descriptor, cls)

    super().__init__(name, bases, dictionary)


# Stateless helpers for GeneratedProtocolMessageType below.
# Outside clients should not access these directly.
#
# I opted not to make any of these methods on the metaclass, to make it more
# clear that I'm not really using any state there and to keep clients from
# thinking that they have direct access to these construction helpers.


def _PropertyName(proto_field_name):
  """Returns the name of the public property attribute which
  clients can use to get and (in some cases) set the value
  of a protocol message field.

  Args:
    proto_field_name: The protocol message field name, exactly
      as it appears (or would appear) in a .proto file.
  """
  # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
  # nnorwitz makes my day by writing:
  # """
  # FYI.  See the keyword module in the stdlib. This could be as simple as:
  #
  # if keyword.iskeyword(proto_field_name):
  #   return proto_field_name + "_"
  # return proto_field_name
  # """
  # Kenton says:  The above is a BAD IDEA.  People rely on being able to use
  #   getattr() and setattr() to reflectively manipulate field values.  If we
  #   rename the properties, then every such user has to also make sure to apply
  #   the same transformation.  Note that currently if you name a field "yield",
  #   you can still access it just fine using getattr/setattr -- it's not even
  #   that cumbersome to do so.
  # TODO(kenton):  Remove this method entirely if/when everyone agrees with my
  #   position.
  return proto_field_name


def _AddSlots(message_descriptor, dictionary):
  """Adds a __slots__ entry to dictionary, containing the names of all valid
  attributes for this message type.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
      (Not consulted: the slot list is the same for every message type.)
    dictionary: Class dictionary to which we'll add a '__slots__' entry.
  """
  dictionary['__slots__'] = ['_cached_byte_size',
                             '_cached_byte_size_dirty',
                             '_fields',
                             '_unknown_fields',
                             '_unknown_field_set',
                             '_is_present_in_parent',
                             '_listener',
                             '_listener_for_children',
                             '__weakref__',
                             '_oneofs']


def _IsMessageSetExtension(field):
  """Returns a truthy value if `field` is an optional message-typed extension
  of a message whose options set message_set_wire_format."""
  return (field.is_extension and
          field.containing_type.has_options and
          field.containing_type.GetOptions().message_set_wire_format and
          field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.label == _FieldDescriptor.LABEL_OPTIONAL)


def _IsMapField(field):
  """Returns a truthy value if `field` is a message field whose message type
  has the map_entry option set."""
  return (field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.message_type.has_options and
          field.message_type.GetOptions().map_entry)


def _IsMessageMapField(field):
  """Returns True if the 'value' field of the map entry type is a message.

  Only call this for fields for which _IsMapField() holds: it looks up the
  'value' entry of field.message_type.fields_by_name.
  """
  value_type = field.message_type.fields_by_name['value']
  return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE


def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches encoder/sizer/default-constructor helpers to a field descriptor
  and registers the field's decoder(s) in cls._decoders_by_tag.

  Args:
    cls: The message class whose _decoders_by_tag dict is populated.
    field_descriptor: The FieldDescriptor (regular field or extension) to
      prepare.  Gains the attributes _encoder, _sizer and
      _default_constructor.
  """
  syntax = field_descriptor.containing_type.syntax
  is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  is_packable = (is_repeated and
                 wire_format.IsTypePackable(field_descriptor.type))
  is_proto3 = syntax == 'proto3'
  if not is_packable:
    is_packed = False
  elif syntax == 'proto2':
    # proto2: packed only when explicitly requested through the options.
    is_packed = (field_descriptor.has_options and
                 field_descriptor.GetOptions().packed)
  else:
    # Otherwise packed is the default unless the options explicitly carry
    # packed = false.
    has_packed_false = (field_descriptor.has_options and
                        field_descriptor.GetOptions().HasField('packed') and
                        not field_descriptor.GetOptions().packed)
    is_packed = not has_packed_false
  is_map_entry = _IsMapField(field_descriptor)
  # _IsMessageMapField() is only meaningful (and safe) for map fields, so it
  # is evaluated at most once and only behind the is_map_entry guard.
  is_message_map = is_map_entry and _IsMessageMapField(field_descriptor)

  if is_map_entry:
    field_encoder = encoder.MapEncoder(field_descriptor)
    sizer = encoder.MapSizer(field_descriptor, is_message_map)
  elif _IsMessageSetExtension(field_descriptor):
    field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
    sizer = encoder.MessageSetItemSizer(field_descriptor.number)
  else:
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)
    sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)

  field_descriptor._encoder = field_encoder
  field_descriptor._sizer = sizer
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  def AddDecoder(wiretype, is_packed):
    """Registers a decoder for this field under the tag for `wiretype`."""
    tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
    decode_type = field_descriptor.type
    if (decode_type == _FieldDescriptor.TYPE_ENUM and
        type_checkers.SupportsOpenEnums(field_descriptor)):
      decode_type = _FieldDescriptor.TYPE_INT32

    oneof_descriptor = None
    clear_if_default = False
    if field_descriptor.containing_oneof is not None:
      # Note: what is stored is the field descriptor itself, not the oneof.
      oneof_descriptor = field_descriptor
    elif (is_proto3 and not is_repeated and
          field_descriptor.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE):
      clear_if_default = True

    if is_map_entry:
      field_decoder = decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          is_message_map)
    elif decode_type == _FieldDescriptor.TYPE_STRING:
      field_decoder = decoder.StringDecoder(
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor,
          clear_if_default)
    elif field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor)
    else:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, is_packed,
          # pylint: disable=protected-access
          field_descriptor, field_descriptor._default_constructor,
          clear_if_default)

    cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)

  AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
             False)

  if is_packable:
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)


def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Copies each entry of descriptor.extensions_by_name into the class
  dictionary, asserting that the name is not already taken."""
  extensions = descriptor.extensions_by_name
  for extension_name, extension_field in extensions.items():
    assert extension_name not in dictionary
    dictionary[extension_name] = extension_field


def _AddEnumValues(descriptor, cls):
  """Sets class-level attributes for all enum fields defined in this message.

  Also exporting a class-level object that can name enum values.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_type in descriptor.enum_types:
    setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type))
    for enum_value in enum_type.values:
      setattr(cls, enum_value.name, enum_value.number)


def _GetInitializeDefaultForMap(field):
  """Returns a function building the empty map container for a map field.

  Args:
    field: FieldDescriptor of a map field.

  Returns:
    A one-argument function (taking the parent message) that returns a
    containers.MessageMap or containers.ScalarMap.

  Raises:
    ValueError: If the field is not repeated.
  """
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))
  fields_by_name = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(fields_by_name['key'])

  value_field = fields_by_name['value']
  if _IsMessageMapField(field):
    def MakeMessageMapDefault(message):
      return containers.MessageMap(
          message._listener_for_children, value_field.message_type, key_checker,
          field.message_type)
    return MakeMessageMapDefault
  else:
    value_checker = type_checkers.GetTypeChecker(value_field)
    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker,
          field.message_type)
    return MakePrimitiveMapDefault


def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  The returned function has one argument:
    message: Message instance containing this field, or a weakref proxy
      of same.

  That function in turn returns a default value for this field.  The default
    value may refer back to |message| via a weak reference.

  Raises:
    ValueError: If a repeated field declares a non-empty default value, or
      (via _GetInitializeDefaultForMap) a map field is not repeated.
  """

  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # We can't look at _concrete_class yet since it might not have
      # been set.  (Depends on order in which we initialize the classes).
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    # _concrete_class may not yet be initialized.
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      assert getattr(message_type, '_concrete_class', None), (
          'Uninitialized concrete class found for field %r (message type %r)'
          % (field.full_name, message_type.full_name))
      result = message_type._concrete_class()
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault


def _ReraiseTypeErrorWithFieldName(message_name, field_name):
  """Re-raise the currently-handled TypeError with the field name added."""
  exc = sys.exc_info()[1]
  if len(exc.args) == 1 and type(exc) is TypeError:
    # simple TypeError; add field name to exception message
    exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name))

  # re-raise possibly-amended exception with original traceback:
  raise exc.with_traceback(sys.exc_info()[2])


def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls."""

  def _GetIntegerEnumValue(enum_type, value):
    """Convert a string or integer enum value to an integer.

    If the value is a string, it is converted to the enum value in
    enum_type with the same name.  If the value is not a string, it's
    returned as-is.  (No conversion or bounds-checking is done.)
    """
    if isinstance(value, str):
      try:
        return enum_type.values_by_name[value].number
      except KeyError:
        raise ValueError('Enum type %s: unknown label "%s"' % (
            enum_type.full_name, value))
    return value

  def init(self, **kwargs):
    self._cached_byte_size = 0
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    # _unknown_field_set is None when empty for efficiency, and will be
    # turned into UnknownFieldSet struct if fields are added.
    self._unknown_field_set = None      # pylint: disable=protected-access
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.items():
      # _GetFieldByName() raises ValueError for an unknown name and never
      # returns None, so no separate "unexpected keyword" check is needed.
      field = _GetFieldByName(message_descriptor, field_name)
      if field_value is None:
        # field=None is the same as no field at all.
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          if _IsMapField(field):
            if _IsMessageMapField(field):
              for key in field_value:
                copy[key].MergeFrom(field_value[key])
            else:
              copy.update(field_value)
          else:
            for val in field_value:
              if isinstance(val, dict):
                copy.add(**val)
              else:
                copy.add().MergeFrom(val)
        else:  # Scalar
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            field_value = [_GetIntegerEnumValue(field.enum_type, val)
                           for val in field_value]
          copy.extend(field_value)
        self._fields[field] = copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        copy = field._default_constructor(self)
        new_val = field_value
        if isinstance(field_value, dict):
          new_val = field.message_type._concrete_class(**field_value)
        try:
          copy.MergeFrom(new_val)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = copy
      else:
        if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
          field_value = _GetIntegerEnumValue(field.enum_type, field_value)
        try:
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init


def _GetFieldByName(message_descriptor, field_name):
  """Returns a field descriptor by field name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.
  Returns:
    The field descriptor associated with the field name.
  Raises:
    ValueError: If the message has no field with that name.
  """
  try:
    return message_descriptor.fields_by_name[field_name]
  except KeyError:
    raise ValueError('Protocol message %s has no "%s" field.' %
                     (message_descriptor.name, field_name))


def _AddPropertiesForFields(descriptor, cls):
  """Adds properties for all fields in this protocol message type."""
  for field in descriptor.fields:
    _AddPropertiesForField(field, cls)

  if descriptor.is_extendable:
    # _ExtensionDict is just an adaptor with no state so we allocate a new one
    # every time it is accessed.
    cls.Extensions = property(lambda self: _ExtensionDict(self))


def _AddPropertiesForField(field, cls):
  """Adds a public property for a protocol message field.
  Clients can use this property to get and (in the case
  of non-repeated scalar fields) directly set the value
  of a protocol message field.

  Also sets the class constant <FIELD_NAME>_FIELD_NUMBER.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  constant_name = field.name.upper() + '_FIELD_NUMBER'
  setattr(cls, constant_name, field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    _AddPropertiesForRepeatedField(field, cls)
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    _AddPropertiesForNonRepeatedCompositeField(field, cls)
  else:
    _AddPropertiesForNonRepeatedScalarField(field, cls)


class _FieldProperty(property):
  """A property that also exposes the field's descriptor as DESCRIPTOR."""
  __slots__ = ('DESCRIPTOR',)

  def __init__(self, descriptor, getter, setter, doc):
    property.__init__(self, getter, setter, doc=doc)
    self.DESCRIPTOR = descriptor


def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field.  Clients
  can use this property to get the value of the field, which will be either a
  RepeatedScalarFieldContainer or RepeatedCompositeFieldContainer (see
  below).

  Note that when clients add values to these containers, we perform
  type-checking in the case of repeated scalar fields, and we also set any
  necessary "has" bits as a side-effect.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    field_value = self._fields.get(field)
    if field_value is None:
      # Construct a new object to represent this field.
      field_value = field._default_constructor(self)

      # Atomically check if another thread has preempted us and, if not, swap
      # in the new object we just created.  If someone has preempted us, we
      # take that object and discard ours.
      # WARNING:  We are relying on setdefault() being atomic.  This is true
      #   in CPython but we haven't investigated others.  This warning appears
      #   in several other locations in this file.
      field_value = self._fields.setdefault(field, field_value)
    return field_value
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # We define a setter just so we can throw an exception with a more
  # helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))


def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.
  Clients can use this property to get and directly set the value of the field.
  Note that when the client sets the value of a field by using this property,
  all necessary "has" bits are set as a side-effect, and we also perform
  type-checking.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value
  is_proto3 = field.containing_type.syntax == 'proto3'

  def getter(self):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  clear_when_set_to_default = is_proto3 and not field.containing_oneof

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    try:
      new_value = type_checker.CheckValue(new_value)
    except TypeError as e:
      raise TypeError(
          'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e))
    # Testing the value for truthiness captures all of the proto3 defaults
    # (0, 0.0, enum 0, and False).
    if clear_when_set_to_default and not new_value:
      self._fields.pop(field, None)
    else:
      self._fields[field] = new_value
    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if field.containing_oneof:
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)
  else:
    setter = field_setter

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))


def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a nonrepeated, composite protocol message field.
  A composite field is a "group" or "message" field.

  Clients can use this property to get the value of the field, but cannot
  assign to the property directly.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(robinson): Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    field_value = self._fields.get(field)
    if field_value is None:
      # Construct a new object to represent this field.
      field_value = field._default_constructor(self)

      # Atomically check if another thread has preempted us and, if not, swap
      # in the new object we just created.  If someone has preempted us, we
      # take that object and discard ours.
      # WARNING:  We are relying on setdefault() being atomic.  This is true
      #   in CPython but we haven't investigated others.  This warning appears
      #   in several other locations in this file.
      field_value = self._fields.setdefault(field, field_value)
    return field_value
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # We define a setter just so we can throw an exception with a more
  # helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  # Add a property to encapsulate the getter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))


def _AddPropertiesForExtensions(descriptor, cls):
  """Adds <NAME>_FIELD_NUMBER constants for the extensions declared in this
  message type and, when the descriptor has a file, binds the pool's
  extension lookup tables to the class."""
  extensions = descriptor.extensions_by_name
  for extension_name, extension_field in extensions.items():
    constant_name = extension_name.upper() + '_FIELD_NUMBER'
    setattr(cls, constant_name, extension_field.number)

  # TODO(amauryfa): Migrate all users of these attributes to functions like
  #   pool.FindExtensionByNumber(descriptor).
  if descriptor.file is not None:
    # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
    pool = descriptor.file.pool
    cls._extensions_by_number = pool._extensions_by_number[descriptor]
    cls._extensions_by_name = pool._extensions_by_name[descriptor]


def _AddStaticMethods(cls):
  """Adds the RegisterExtension and FromString static methods to cls."""
  # TODO(robinson): This probably needs to be thread-safe(?)
  def RegisterExtension(extension_handle):
    extension_handle.containing_type = cls.DESCRIPTOR
    # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
    # pylint: disable=protected-access
    cls.DESCRIPTOR.file.pool._AddExtensionDescriptor(extension_handle)
    _AttachFieldHelpers(cls, extension_handle)
  cls.RegisterExtension = staticmethod(RegisterExtension)

  def FromString(s):
    message = cls()
    message.MergeFromString(s)
    return message
  cls.FromString = staticmethod(FromString)


def _IsPresent(item):
  """Given a (FieldDescriptor, value) tuple from _fields, return true if the
  value should be included in the list returned by ListFields()."""

  if item[0].label == _FieldDescriptor.LABEL_REPEATED:
    return bool(item[1])
  elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return item[1]._is_present_in_parent
  else:
    return True


def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  def ListFields(self):
    # Present fields only, ordered by field number.
    return sorted((item for item in self._fields.items() if _IsPresent(item)),
                  key=lambda item: item[0].number)

  cls.ListFields = ListFields


_PROTO3_ERROR_TEMPLATE = (
    'Protocol message %s has no non-repeated submessage field "%s" '
    'nor marked as optional')
_PROTO2_ERROR_TEMPLATE = 'Protocol message %s has no non-repeated field "%s"'


def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  is_proto3 = (message_descriptor.syntax == "proto3")
  error_msg = _PROTO3_ERROR_TEMPLATE if is_proto3 else _PROTO2_ERROR_TEMPLATE

  hassable_fields = {}
  for field in message_descriptor.fields:
    if field.label == _FieldDescriptor.LABEL_REPEATED:
      continue
    # For proto3, only submessages and fields inside a oneof have presence.
    if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and
        not field.containing_oneof):
      continue
    hassable_fields[field.name] = field

  # Has methods are supported for oneof descriptors.
  for oneof in message_descriptor.oneofs:
    hassable_fields[oneof.name] = oneof

  def HasField(self, field_name):
    try:
      field = hassable_fields[field_name]
    except KeyError:
      raise ValueError(error_msg % (message_descriptor.full_name, field_name))

    if isinstance(field, descriptor_mod.OneofDescriptor):
      # A oneof is "set" when the member currently recorded for it is set.
      try:
        return HasField(self, self._oneofs[field].name)
      except KeyError:
        return False
    else:
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        value = self._fields.get(field)
        return value is not None and value._is_present_in_parent
      else:
        return field in self._fields

  cls.HasField = HasField


def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""
  def ClearField(self, field_name):
    try:
      field = message_descriptor.fields_by_name[field_name]
    except KeyError:
      try:
        # Not a field name: treat it as a oneof name and clear whichever
        # member is currently set (nothing to do if none is).
        field = message_descriptor.oneofs_by_name[field_name]
        if field in self._oneofs:
          field = self._oneofs[field]
        else:
          return
      except KeyError:
        raise ValueError('Protocol message %s has no "%s" field.' %
                         (message_descriptor.name, field_name))

    if field in self._fields:
      # To match the C++ implementation, we need to invalidate iterators
      # for map fields when ClearField() happens.
      if hasattr(self._fields[field], 'InvalidateIterators'):
        self._fields[field].InvalidateIterators()

      # Note:  If the field is a sub-message, its listener will still point
      #   at us.  That's fine, because the worst that can happen is that it
      #   will call _Modified() and invalidate our byte size.  Big deal.
      del self._fields[field]

      if self._oneofs.get(field.containing_oneof, None) is field:
        del self._oneofs[field.containing_oneof]

    # Always call _Modified() -- even if nothing was changed, this is
    # a mutating method, and thus calling it should cause the field to become
    # present in the parent message.
    self._Modified()

  cls.ClearField = ClearField


def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods()."""
  def ClearExtension(self, extension_handle):
    extension_dict._VerifyExtensionHandle(self, extension_handle)

    # Similar to ClearField(), above.
    if extension_handle in self._fields:
      del self._fields[extension_handle]
    self._Modified()
  cls.ClearExtension = ClearExtension


def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods()."""
  def HasExtension(self, extension_handle):
    extension_dict._VerifyExtensionHandle(self, extension_handle)
    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % extension_handle.full_name)

    if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      value = self._fields.get(extension_handle)
      return value is not None and value._is_present_in_parent
    else:
      return extension_handle in self._fields
  cls.HasExtension = HasExtension


def _InternalUnpackAny(msg):
  """Unpacks Any message and returns the unpacked message.

  This internal method is different from public Any Unpack method which takes
  the target message as argument. _InternalUnpackAny method does not have
  target message type and need to find the message type in descriptor pool.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message.
  """
  # TODO(amauryfa): Don't use the factory of generated messages.
  # To make Any work with custom factories, use the message factory of the
  # parent message.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import symbol_database
  factory = symbol_database.Default()

  type_url = msg.type_url

  if not type_url:
    return None

  # TODO(haberman): For now we just strip the hostname.  Better logic will be
  # required.
965 type_name = type_url.split('/')[-1] 966 descriptor = factory.pool.FindMessageTypeByName(type_name) 967 968 if descriptor is None: 969 return None 970 971 message_class = factory.GetPrototype(descriptor) 972 message = message_class() 973 974 message.ParseFromString(msg.value) 975 return message 976 977 978def _AddEqualsMethod(message_descriptor, cls): 979 """Helper for _AddMessageMethods().""" 980 def __eq__(self, other): 981 if (not isinstance(other, message_mod.Message) or 982 other.DESCRIPTOR != self.DESCRIPTOR): 983 return False 984 985 if self is other: 986 return True 987 988 if self.DESCRIPTOR.full_name == _AnyFullTypeName: 989 any_a = _InternalUnpackAny(self) 990 any_b = _InternalUnpackAny(other) 991 if any_a and any_b: 992 return any_a == any_b 993 994 if not self.ListFields() == other.ListFields(): 995 return False 996 997 # TODO(jieluo): Fix UnknownFieldSet to consider MessageSet extensions, 998 # then use it for the comparison. 999 unknown_fields = list(self._unknown_fields) 1000 unknown_fields.sort() 1001 other_unknown_fields = list(other._unknown_fields) 1002 other_unknown_fields.sort() 1003 return unknown_fields == other_unknown_fields 1004 1005 cls.__eq__ = __eq__ 1006 1007 1008def _AddStrMethod(message_descriptor, cls): 1009 """Helper for _AddMessageMethods().""" 1010 def __str__(self): 1011 return text_format.MessageToString(self) 1012 cls.__str__ = __str__ 1013 1014 1015def _AddReprMethod(message_descriptor, cls): 1016 """Helper for _AddMessageMethods().""" 1017 def __repr__(self): 1018 return text_format.MessageToString(self) 1019 cls.__repr__ = __repr__ 1020 1021 1022def _AddUnicodeMethod(unused_message_descriptor, cls): 1023 """Helper for _AddMessageMethods().""" 1024 1025 def __unicode__(self): 1026 return text_format.MessageToString(self, as_utf8=True).decode('utf-8') 1027 cls.__unicode__ = __unicode__ 1028 1029 1030def _BytesForNonRepeatedElement(value, field_number, field_type): 1031 """Returns the number of bytes needed to serialize a 
non-repeated element. 1032 The returned byte count includes space for tag information and any 1033 other additional space associated with serializing value. 1034 1035 Args: 1036 value: Value we're serializing. 1037 field_number: Field number of this value. (Since the field number 1038 is stored as part of a varint-encoded tag, this has an impact 1039 on the total bytes required to serialize the value). 1040 field_type: The type of the field. One of the TYPE_* constants 1041 within FieldDescriptor. 1042 """ 1043 try: 1044 fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type] 1045 return fn(field_number, value) 1046 except KeyError: 1047 raise message_mod.EncodeError('Unrecognized field type: %d' % field_type) 1048 1049 1050def _AddByteSizeMethod(message_descriptor, cls): 1051 """Helper for _AddMessageMethods().""" 1052 1053 def ByteSize(self): 1054 if not self._cached_byte_size_dirty: 1055 return self._cached_byte_size 1056 1057 size = 0 1058 descriptor = self.DESCRIPTOR 1059 if descriptor.GetOptions().map_entry: 1060 # Fields of map entry should always be serialized. 1061 size = descriptor.fields_by_name['key']._sizer(self.key) 1062 size += descriptor.fields_by_name['value']._sizer(self.value) 1063 else: 1064 for field_descriptor, field_value in self.ListFields(): 1065 size += field_descriptor._sizer(field_value) 1066 for tag_bytes, value_bytes in self._unknown_fields: 1067 size += len(tag_bytes) + len(value_bytes) 1068 1069 self._cached_byte_size = size 1070 self._cached_byte_size_dirty = False 1071 self._listener_for_children.dirty = False 1072 return size 1073 1074 cls.ByteSize = ByteSize 1075 1076 1077def _AddSerializeToStringMethod(message_descriptor, cls): 1078 """Helper for _AddMessageMethods().""" 1079 1080 def SerializeToString(self, **kwargs): 1081 # Check if the message has all of its required fields set. 
1082 if not self.IsInitialized(): 1083 raise message_mod.EncodeError( 1084 'Message %s is missing required fields: %s' % ( 1085 self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) 1086 return self.SerializePartialToString(**kwargs) 1087 cls.SerializeToString = SerializeToString 1088 1089 1090def _AddSerializePartialToStringMethod(message_descriptor, cls): 1091 """Helper for _AddMessageMethods().""" 1092 1093 def SerializePartialToString(self, **kwargs): 1094 out = BytesIO() 1095 self._InternalSerialize(out.write, **kwargs) 1096 return out.getvalue() 1097 cls.SerializePartialToString = SerializePartialToString 1098 1099 def InternalSerialize(self, write_bytes, deterministic=None): 1100 if deterministic is None: 1101 deterministic = ( 1102 api_implementation.IsPythonDefaultSerializationDeterministic()) 1103 else: 1104 deterministic = bool(deterministic) 1105 1106 descriptor = self.DESCRIPTOR 1107 if descriptor.GetOptions().map_entry: 1108 # Fields of map entry should always be serialized. 1109 descriptor.fields_by_name['key']._encoder( 1110 write_bytes, self.key, deterministic) 1111 descriptor.fields_by_name['value']._encoder( 1112 write_bytes, self.value, deterministic) 1113 else: 1114 for field_descriptor, field_value in self.ListFields(): 1115 field_descriptor._encoder(write_bytes, field_value, deterministic) 1116 for tag_bytes, value_bytes in self._unknown_fields: 1117 write_bytes(tag_bytes) 1118 write_bytes(value_bytes) 1119 cls._InternalSerialize = InternalSerialize 1120 1121 1122def _AddMergeFromStringMethod(message_descriptor, cls): 1123 """Helper for _AddMessageMethods().""" 1124 def MergeFromString(self, serialized): 1125 serialized = memoryview(serialized) 1126 length = len(serialized) 1127 try: 1128 if self._InternalParse(serialized, 0, length) != length: 1129 # The only reason _InternalParse would return early is if it 1130 # encountered an end-group tag. 
1131 raise message_mod.DecodeError('Unexpected end-group tag.') 1132 except (IndexError, TypeError): 1133 # Now ord(buf[p:p+1]) == ord('') gets TypeError. 1134 raise message_mod.DecodeError('Truncated message.') 1135 except struct.error as e: 1136 raise message_mod.DecodeError(e) 1137 return length # Return this for legacy reasons. 1138 cls.MergeFromString = MergeFromString 1139 1140 local_ReadTag = decoder.ReadTag 1141 local_SkipField = decoder.SkipField 1142 decoders_by_tag = cls._decoders_by_tag 1143 1144 def InternalParse(self, buffer, pos, end): 1145 """Create a message from serialized bytes. 1146 1147 Args: 1148 self: Message, instance of the proto message object. 1149 buffer: memoryview of the serialized data. 1150 pos: int, position to start in the serialized data. 1151 end: int, end position of the serialized data. 1152 1153 Returns: 1154 Message object. 1155 """ 1156 # Guard against internal misuse, since this function is called internally 1157 # quite extensively, and its easy to accidentally pass bytes. 
1158 assert isinstance(buffer, memoryview) 1159 self._Modified() 1160 field_dict = self._fields 1161 # pylint: disable=protected-access 1162 unknown_field_set = self._unknown_field_set 1163 while pos != end: 1164 (tag_bytes, new_pos) = local_ReadTag(buffer, pos) 1165 field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None)) 1166 if field_decoder is None: 1167 if not self._unknown_fields: # pylint: disable=protected-access 1168 self._unknown_fields = [] # pylint: disable=protected-access 1169 if unknown_field_set is None: 1170 # pylint: disable=protected-access 1171 self._unknown_field_set = containers.UnknownFieldSet() 1172 # pylint: disable=protected-access 1173 unknown_field_set = self._unknown_field_set 1174 # pylint: disable=protected-access 1175 (tag, _) = decoder._DecodeVarint(tag_bytes, 0) 1176 field_number, wire_type = wire_format.UnpackTag(tag) 1177 if field_number == 0: 1178 raise message_mod.DecodeError('Field number 0 is illegal.') 1179 # TODO(jieluo): remove old_pos. 1180 old_pos = new_pos 1181 (data, new_pos) = decoder._DecodeUnknownField( 1182 buffer, new_pos, wire_type) # pylint: disable=protected-access 1183 if new_pos == -1: 1184 return pos 1185 # pylint: disable=protected-access 1186 unknown_field_set._add(field_number, wire_type, data) 1187 # TODO(jieluo): remove _unknown_fields. 
1188 new_pos = local_SkipField(buffer, old_pos, end, tag_bytes) 1189 if new_pos == -1: 1190 return pos 1191 self._unknown_fields.append( 1192 (tag_bytes, buffer[old_pos:new_pos].tobytes())) 1193 pos = new_pos 1194 else: 1195 pos = field_decoder(buffer, new_pos, end, self, field_dict) 1196 if field_desc: 1197 self._UpdateOneofState(field_desc) 1198 return pos 1199 cls._InternalParse = InternalParse 1200 1201 1202def _AddIsInitializedMethod(message_descriptor, cls): 1203 """Adds the IsInitialized and FindInitializationError methods to the 1204 protocol message class.""" 1205 1206 required_fields = [field for field in message_descriptor.fields 1207 if field.label == _FieldDescriptor.LABEL_REQUIRED] 1208 1209 def IsInitialized(self, errors=None): 1210 """Checks if all required fields of a message are set. 1211 1212 Args: 1213 errors: A list which, if provided, will be populated with the field 1214 paths of all missing required fields. 1215 1216 Returns: 1217 True iff the specified message has all required fields set. 1218 """ 1219 1220 # Performance is critical so we avoid HasField() and ListFields(). 1221 1222 for field in required_fields: 1223 if (field not in self._fields or 1224 (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and 1225 not self._fields[field]._is_present_in_parent)): 1226 if errors is not None: 1227 errors.extend(self.FindInitializationErrors()) 1228 return False 1229 1230 for field, value in list(self._fields.items()): # dict can change size! 
1231 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1232 if field.label == _FieldDescriptor.LABEL_REPEATED: 1233 if (field.message_type.has_options and 1234 field.message_type.GetOptions().map_entry): 1235 continue 1236 for element in value: 1237 if not element.IsInitialized(): 1238 if errors is not None: 1239 errors.extend(self.FindInitializationErrors()) 1240 return False 1241 elif value._is_present_in_parent and not value.IsInitialized(): 1242 if errors is not None: 1243 errors.extend(self.FindInitializationErrors()) 1244 return False 1245 1246 return True 1247 1248 cls.IsInitialized = IsInitialized 1249 1250 def FindInitializationErrors(self): 1251 """Finds required fields which are not initialized. 1252 1253 Returns: 1254 A list of strings. Each string is a path to an uninitialized field from 1255 the top-level message, e.g. "foo.bar[5].baz". 1256 """ 1257 1258 errors = [] # simplify things 1259 1260 for field in required_fields: 1261 if not self.HasField(field.name): 1262 errors.append(field.name) 1263 1264 for field, value in self.ListFields(): 1265 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1266 if field.is_extension: 1267 name = '(%s)' % field.full_name 1268 else: 1269 name = field.name 1270 1271 if _IsMapField(field): 1272 if _IsMessageMapField(field): 1273 for key in value: 1274 element = value[key] 1275 prefix = '%s[%s].' % (name, key) 1276 sub_errors = element.FindInitializationErrors() 1277 errors += [prefix + error for error in sub_errors] 1278 else: 1279 # ScalarMaps can't have any initialization errors. 1280 pass 1281 elif field.label == _FieldDescriptor.LABEL_REPEATED: 1282 for i in range(len(value)): 1283 element = value[i] 1284 prefix = '%s[%d].' % (name, i) 1285 sub_errors = element.FindInitializationErrors() 1286 errors += [prefix + error for error in sub_errors] 1287 else: 1288 prefix = name + '.' 
1289 sub_errors = value.FindInitializationErrors() 1290 errors += [prefix + error for error in sub_errors] 1291 1292 return errors 1293 1294 cls.FindInitializationErrors = FindInitializationErrors 1295 1296 1297def _FullyQualifiedClassName(klass): 1298 module = klass.__module__ 1299 name = getattr(klass, '__qualname__', klass.__name__) 1300 if module in (None, 'builtins', '__builtin__'): 1301 return name 1302 return module + '.' + name 1303 1304 1305def _AddMergeFromMethod(cls): 1306 LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED 1307 CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE 1308 1309 def MergeFrom(self, msg): 1310 if not isinstance(msg, cls): 1311 raise TypeError( 1312 'Parameter to MergeFrom() must be instance of same class: ' 1313 'expected %s got %s.' % (_FullyQualifiedClassName(cls), 1314 _FullyQualifiedClassName(msg.__class__))) 1315 1316 assert msg is not self 1317 self._Modified() 1318 1319 fields = self._fields 1320 1321 for field, value in msg._fields.items(): 1322 if field.label == LABEL_REPEATED: 1323 field_value = fields.get(field) 1324 if field_value is None: 1325 # Construct a new object to represent this field. 1326 field_value = field._default_constructor(self) 1327 fields[field] = field_value 1328 field_value.MergeFrom(value) 1329 elif field.cpp_type == CPPTYPE_MESSAGE: 1330 if value._is_present_in_parent: 1331 field_value = fields.get(field) 1332 if field_value is None: 1333 # Construct a new object to represent this field. 
1334 field_value = field._default_constructor(self) 1335 fields[field] = field_value 1336 field_value.MergeFrom(value) 1337 else: 1338 self._fields[field] = value 1339 if field.containing_oneof: 1340 self._UpdateOneofState(field) 1341 1342 if msg._unknown_fields: 1343 if not self._unknown_fields: 1344 self._unknown_fields = [] 1345 self._unknown_fields.extend(msg._unknown_fields) 1346 # pylint: disable=protected-access 1347 if self._unknown_field_set is None: 1348 self._unknown_field_set = containers.UnknownFieldSet() 1349 self._unknown_field_set._extend(msg._unknown_field_set) 1350 1351 cls.MergeFrom = MergeFrom 1352 1353 1354def _AddWhichOneofMethod(message_descriptor, cls): 1355 def WhichOneof(self, oneof_name): 1356 """Returns the name of the currently set field inside a oneof, or None.""" 1357 try: 1358 field = message_descriptor.oneofs_by_name[oneof_name] 1359 except KeyError: 1360 raise ValueError( 1361 'Protocol message has no oneof "%s" field.' % oneof_name) 1362 1363 nested_field = self._oneofs.get(field, None) 1364 if nested_field is not None and self.HasField(nested_field.name): 1365 return nested_field.name 1366 else: 1367 return None 1368 1369 cls.WhichOneof = WhichOneof 1370 1371 1372def _Clear(self): 1373 # Clear fields. 
1374 self._fields = {} 1375 self._unknown_fields = () 1376 # pylint: disable=protected-access 1377 if self._unknown_field_set is not None: 1378 self._unknown_field_set._clear() 1379 self._unknown_field_set = None 1380 1381 self._oneofs = {} 1382 self._Modified() 1383 1384 1385def _UnknownFields(self): 1386 if self._unknown_field_set is None: # pylint: disable=protected-access 1387 # pylint: disable=protected-access 1388 self._unknown_field_set = containers.UnknownFieldSet() 1389 return self._unknown_field_set # pylint: disable=protected-access 1390 1391 1392def _DiscardUnknownFields(self): 1393 self._unknown_fields = [] 1394 self._unknown_field_set = None # pylint: disable=protected-access 1395 for field, value in self.ListFields(): 1396 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1397 if _IsMapField(field): 1398 if _IsMessageMapField(field): 1399 for key in value: 1400 value[key].DiscardUnknownFields() 1401 elif field.label == _FieldDescriptor.LABEL_REPEATED: 1402 for sub_message in value: 1403 sub_message.DiscardUnknownFields() 1404 else: 1405 value.DiscardUnknownFields() 1406 1407 1408def _SetListener(self, listener): 1409 if listener is None: 1410 self._listener = message_listener_mod.NullMessageListener() 1411 else: 1412 self._listener = listener 1413 1414 1415def _AddMessageMethods(message_descriptor, cls): 1416 """Adds implementations of all Message methods to cls.""" 1417 _AddListFieldsMethod(message_descriptor, cls) 1418 _AddHasFieldMethod(message_descriptor, cls) 1419 _AddClearFieldMethod(message_descriptor, cls) 1420 if message_descriptor.is_extendable: 1421 _AddClearExtensionMethod(cls) 1422 _AddHasExtensionMethod(cls) 1423 _AddEqualsMethod(message_descriptor, cls) 1424 _AddStrMethod(message_descriptor, cls) 1425 _AddReprMethod(message_descriptor, cls) 1426 _AddUnicodeMethod(message_descriptor, cls) 1427 _AddByteSizeMethod(message_descriptor, cls) 1428 _AddSerializeToStringMethod(message_descriptor, cls) 1429 
_AddSerializePartialToStringMethod(message_descriptor, cls) 1430 _AddMergeFromStringMethod(message_descriptor, cls) 1431 _AddIsInitializedMethod(message_descriptor, cls) 1432 _AddMergeFromMethod(cls) 1433 _AddWhichOneofMethod(message_descriptor, cls) 1434 # Adds methods which do not depend on cls. 1435 cls.Clear = _Clear 1436 cls.UnknownFields = _UnknownFields 1437 cls.DiscardUnknownFields = _DiscardUnknownFields 1438 cls._SetListener = _SetListener 1439 1440 1441def _AddPrivateHelperMethods(message_descriptor, cls): 1442 """Adds implementation of private helper methods to cls.""" 1443 1444 def Modified(self): 1445 """Sets the _cached_byte_size_dirty bit to true, 1446 and propagates this to our listener iff this was a state change. 1447 """ 1448 1449 # Note: Some callers check _cached_byte_size_dirty before calling 1450 # _Modified() as an extra optimization. So, if this method is ever 1451 # changed such that it does stuff even when _cached_byte_size_dirty is 1452 # already true, the callers need to be updated. 1453 if not self._cached_byte_size_dirty: 1454 self._cached_byte_size_dirty = True 1455 self._listener_for_children.dirty = True 1456 self._is_present_in_parent = True 1457 self._listener.Modified() 1458 1459 def _UpdateOneofState(self, field): 1460 """Sets field as the active field in its containing oneof. 1461 1462 Will also delete currently active field in the oneof, if it is different 1463 from the argument. Does not mark the message as modified. 1464 """ 1465 other_field = self._oneofs.setdefault(field.containing_oneof, field) 1466 if other_field is not field: 1467 del self._fields[other_field] 1468 self._oneofs[field.containing_oneof] = field 1469 1470 cls._Modified = Modified 1471 cls.SetInParent = Modified 1472 cls._UpdateOneofState = _UpdateOneofState 1473 1474 1475class _Listener(object): 1476 1477 """MessageListener implementation that a parent message registers with its 1478 child message. 
1479 1480 In order to support semantics like: 1481 1482 foo.bar.baz.moo = 23 1483 assert foo.HasField('bar') 1484 1485 ...child objects must have back references to their parents. 1486 This helper class is at the heart of this support. 1487 """ 1488 1489 def __init__(self, parent_message): 1490 """Args: 1491 parent_message: The message whose _Modified() method we should call when 1492 we receive Modified() messages. 1493 """ 1494 # This listener establishes a back reference from a child (contained) object 1495 # to its parent (containing) object. We make this a weak reference to avoid 1496 # creating cyclic garbage when the client finishes with the 'parent' object 1497 # in the tree. 1498 if isinstance(parent_message, weakref.ProxyType): 1499 self._parent_message_weakref = parent_message 1500 else: 1501 self._parent_message_weakref = weakref.proxy(parent_message) 1502 1503 # As an optimization, we also indicate directly on the listener whether 1504 # or not the parent message is dirty. This way we can avoid traversing 1505 # up the tree in the common case. 1506 self.dirty = False 1507 1508 def Modified(self): 1509 if self.dirty: 1510 return 1511 try: 1512 # Propagate the signal to our parents iff this is the first field set. 1513 self._parent_message_weakref._Modified() 1514 except ReferenceError: 1515 # We can get here if a client has kept a reference to a child object, 1516 # and is now setting a field on it, but the child's parent has been 1517 # garbage-collected. This is not an error. 1518 pass 1519 1520 1521class _OneofListener(_Listener): 1522 """Special listener implementation for setting composite oneof fields.""" 1523 1524 def __init__(self, parent_message, field): 1525 """Args: 1526 parent_message: The message whose _Modified() method we should call when 1527 we receive Modified() messages. 1528 field: The descriptor of the field being set in the parent message. 
1529 """ 1530 super(_OneofListener, self).__init__(parent_message) 1531 self._field = field 1532 1533 def Modified(self): 1534 """Also updates the state of the containing oneof in the parent message.""" 1535 try: 1536 self._parent_message_weakref._UpdateOneofState(self._field) 1537 super(_OneofListener, self).Modified() 1538 except ReferenceError: 1539 pass 1540