xref: /aosp_15_r20/external/cronet/third_party/protobuf/python/google/protobuf/internal/python_message.py (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
# This code requires Python 3 (it relies on e.g. BaseException.with_traceback).
32#
33# TODO(robinson): Helpers for verbose, common checks like seeing if a
34# descriptor's cpp_type is CPPTYPE_MESSAGE.
35
"""Runtime machinery that turns Descriptor objects into protocol message classes.

A metaclass is "the type of a class": it relates to a class the way a class
relates to its instances.

The GeneratedProtocolMessageType metaclass defined below is what gives the
message classes emitted by the protocol compiler their behavior: when such a
class is created, the metaclass installs field properties, an __init__ method
and the Message method implementations on it.

As a result, the real implementation details of ALL pure-Python protocol
buffers live *in this file*.
"""

__author__ = '[email protected] (Will Robinson)'
51__author__ = '[email protected] (Will Robinson)'
52
53from io import BytesIO
54import struct
55import sys
56import weakref
57
58# We use "as" to avoid name collisions with variables.
59from google.protobuf.internal import api_implementation
60from google.protobuf.internal import containers
61from google.protobuf.internal import decoder
62from google.protobuf.internal import encoder
63from google.protobuf.internal import enum_type_wrapper
64from google.protobuf.internal import extension_dict
65from google.protobuf.internal import message_listener as message_listener_mod
66from google.protobuf.internal import type_checkers
67from google.protobuf.internal import well_known_types
68from google.protobuf.internal import wire_format
69from google.protobuf import descriptor as descriptor_mod
70from google.protobuf import message as message_mod
71from google.protobuf import text_format
72
# Short module-level aliases for names used throughout this file.
_AnyFullTypeName = 'google.protobuf.Any'
_ExtensionDict = extension_dict._ExtensionDict
_FieldDescriptor = descriptor_mod.FieldDescriptor
76
class GeneratedProtocolMessageType(type):

  """Metaclass for protocol message classes created at runtime from Descriptors.

  We add implementations for all methods described in the Message class.  We
  also create properties to allow getting/setting all fields in the protocol
  message.  Finally, we create slots to prevent users from accidentally
  "setting" nonexistent fields in the protocol message, which then wouldn't get
  serialized / deserialized properly.

  The protocol compiler currently uses this metaclass to create protocol
  message classes at runtime.  Clients can also manually create their own
  classes at runtime, as in this example:

  mydescriptor = Descriptor(.....)
  factory = symbol_database.Default()
  factory.pool.AddDescriptor(mydescriptor)
  MyProtoClass = factory.GetPrototype(mydescriptor)
  myproto_instance = MyProtoClass()
  myproto_instance.foo_field = 23
  ...
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Custom allocation for runtime-generated class types.

    We override __new__ because this is apparently the only place
    where we can meaningfully set __slots__ on the class we're creating(?).
    (The interplay between metaclasses and slots is not very well-documented).

    Args:
      name: Name of the class being created (forwarded to type.__new__).
      bases: Base classes of the class we're constructing (normally just
        message.Message).  If the descriptor's full name is a key of
        well_known_types.WKTBASES, the matching mixin class is appended.
      dictionary: The class dictionary of the class we're
        constructing.  dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.  Nested extensions and a '__slots__' entry are added to it.

    Returns:
      The newly-allocated class or, when the descriptor already has a
      `_concrete_class`, that existing class (no new class is created).

    Raises:
      RuntimeError: if dictionary[_DESCRIPTOR_KEY] is a str, i.e. the
        generated code requires the C++ extension but the pure-Python
        runtime is in use.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    if isinstance(descriptor, str):
      raise RuntimeError('The generated code only work with python cpp '
                         'extension, but it is using pure python runtime.')

    # If a concrete class already exists for this descriptor, don't try to
    # create another.  Doing so will break any messages that already exist with
    # the existing class.
    #
    # The C++ implementation appears to have its own internal `PyMessageFactory`
    # to achieve similar results.
    #
    # This most commonly happens in `text_format.py` when using descriptors from
    # a custom pool; it calls symbol_database.Global().getPrototype() on a
    # descriptor which already has an existing concrete class.
    new_class = getattr(descriptor, '_concrete_class', None)
    if new_class:
      return new_class

    if descriptor.full_name in well_known_types.WKTBASES:
      bases += (well_known_types.WKTBASES[descriptor.full_name],)
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    _AddSlots(descriptor, dictionary)

    superclass = super(GeneratedProtocolMessageType, cls)
    new_class = superclass.__new__(cls, name, bases, dictionary)
    return new_class

  def __init__(cls, name, bases, dictionary):
    """Here we perform the majority of our work on the class.

    We add enum getters, an __init__ method, implementations
    of all Message methods, and properties for all fields
    in the protocol type.  If the descriptor already has a
    `_concrete_class` (i.e. __new__ returned an existing class), nothing is
    re-initialized.

    Args:
      name: Name of the class being created (forwarded to type.__init__).
      bases: Base classes of the class we're constructing (normally just
        message.Message); forwarded to type.__init__.
      dictionary: The class dictionary of the class we're
        constructing.  dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # If this is an _existing_ class looked up via `_concrete_class` in the
    # __new__ method above, then we don't need to re-initialize anything.
    existing_class = getattr(descriptor, '_concrete_class', None)
    if existing_class:
      assert existing_class is cls, (
          'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
          % (descriptor.full_name))
      return

    cls._decoders_by_tag = {}
    if (descriptor.has_options and
        descriptor.GetOptions().message_set_wire_format):
      cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          decoder.MessageSetItemDecoder(descriptor), None)

    # Attach stuff to each FieldDescriptor for quick lookup later on.
    for field in descriptor.fields:
      _AttachFieldHelpers(cls, field)

    # Publish the class on the descriptor before the helpers below run.
    descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddPropertiesForExtensions(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(descriptor, cls)

    superclass = super(GeneratedProtocolMessageType, cls)
    superclass.__init__(name, bases, dictionary)
206
207
208# Stateless helpers for GeneratedProtocolMessageType below.
209# Outside clients should not access these directly.
210#
211# I opted not to make any of these methods on the metaclass, to make it more
212# clear that I'm not really using any state there and to keep clients from
213# thinking that they have direct access to these construction helpers.
214
215
216def _PropertyName(proto_field_name):
217  """Returns the name of the public property attribute which
218  clients can use to get and (in some cases) set the value
219  of a protocol message field.
220
221  Args:
222    proto_field_name: The protocol message field name, exactly
223      as it appears (or would appear) in a .proto file.
224  """
225  # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
226  # nnorwitz makes my day by writing:
227  # """
228  # FYI.  See the keyword module in the stdlib. This could be as simple as:
229  #
230  # if keyword.iskeyword(proto_field_name):
231  #   return proto_field_name + "_"
232  # return proto_field_name
233  # """
234  # Kenton says:  The above is a BAD IDEA.  People rely on being able to use
235  #   getattr() and setattr() to reflectively manipulate field values.  If we
236  #   rename the properties, then every such user has to also make sure to apply
237  #   the same transformation.  Note that currently if you name a field "yield",
238  #   you can still access it just fine using getattr/setattr -- it's not even
239  #   that cumbersome to do so.
240  # TODO(kenton):  Remove this method entirely if/when everyone agrees with my
241  #   position.
242  return proto_field_name
243
244
def _AddSlots(message_descriptor, dictionary):
  """Stores the fixed list of instance attributes as the class's __slots__.

  Restricting instances to these names is meant to stop users from
  accidentally "setting" attributes that are not fields.

  Args:
    message_descriptor: Descriptor of the message type.  Not consulted: every
      message class receives the same slots.
    dictionary: Class dictionary that receives the '__slots__' entry.
  """
  slot_names = [
      '_cached_byte_size',
      '_cached_byte_size_dirty',
      '_fields',
      '_unknown_fields',
      '_unknown_field_set',
      '_is_present_in_parent',
      '_listener',
      '_listener_for_children',
      '__weakref__',
      '_oneofs',
  ]
  dictionary['__slots__'] = slot_names
263
264
def _IsMessageSetExtension(field):
  """Tells whether `field` is an optional message extension of a MessageSet.

  The checks short-circuit in order, and the result is whatever the `and`
  chain yields (truthy/falsy, not necessarily a bool).
  """
  return (field.is_extension
          and field.containing_type.has_options
          and field.containing_type.GetOptions().message_set_wire_format
          and field.type == _FieldDescriptor.TYPE_MESSAGE
          and field.label == _FieldDescriptor.LABEL_OPTIONAL)
271
272
def _IsMapField(field):
  """Tells whether `field` is a map field.

  A map field is a message-typed field whose message type carries the
  `map_entry` option.  Non-message fields yield False; otherwise the result is
  the value of the `and` expression below (not necessarily a bool).
  """
  if field.type != _FieldDescriptor.TYPE_MESSAGE:
    return False
  entry_type = field.message_type
  return entry_type.has_options and entry_type.GetOptions().map_entry
277
278
def _IsMessageMapField(field):
  """Tells whether the map field `field` has message-typed values.

  `field` is expected to be a map field: its message type (the map entry) must
  define a field named 'value'.
  """
  value_field = field.message_type.fields_by_name['value']
  is_message_valued = value_field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
  return is_message_valued
282
283
def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches encoder, sizer, default constructor and decoders for one field.

  Sets `_encoder`, `_sizer` and `_default_constructor` on `field_descriptor`,
  and registers a `(decoder, oneof_member_field_or_None)` pair in
  `cls._decoders_by_tag` for each wire tag the field accepts.

  Args:
    cls: The message class under construction; `cls._decoders_by_tag` must
      already exist.
    field_descriptor: FieldDescriptor of one field of that message.
  """
  is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  is_packable = (is_repeated and
                 wire_format.IsTypePackable(field_descriptor.type))
  is_proto3 = field_descriptor.containing_type.syntax == 'proto3'
  if not is_packable:
    is_packed = False
  elif field_descriptor.containing_type.syntax == 'proto2':
    # proto2: packed encoding is used only when the field opts in.
    is_packed = (field_descriptor.has_options and
                 field_descriptor.GetOptions().packed)
  else:
    # Otherwise packed encoding is the default unless the field explicitly
    # sets `packed = false`.
    has_packed_false = (field_descriptor.has_options and
                        field_descriptor.GetOptions().HasField('packed') and
                        not field_descriptor.GetOptions().packed)
    is_packed = not has_packed_false
  is_map_entry = _IsMapField(field_descriptor)

  if is_map_entry:
    field_encoder = encoder.MapEncoder(field_descriptor)
    sizer = encoder.MapSizer(field_descriptor,
                             _IsMessageMapField(field_descriptor))
  elif _IsMessageSetExtension(field_descriptor):
    field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
    sizer = encoder.MessageSetItemSizer(field_descriptor.number)
  else:
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)
    sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)

  # pylint: disable=protected-access
  field_descriptor._encoder = field_encoder
  field_descriptor._sizer = sizer
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  def AddDecoder(wiretype, packed):
    """Registers a decoder for this field under the tag for `wiretype`.

    Args:
      wiretype: Wire type used to build the tag bytes.
      packed: Whether the registered decoder reads packed data.  This is
        independent of the field's own `is_packed` setting (see the packed
        decoder registered below for every packable repeated field).
    """
    tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
    decode_type = field_descriptor.type
    if (decode_type == _FieldDescriptor.TYPE_ENUM and
        type_checkers.SupportsOpenEnums(field_descriptor)):
      # Enums for which SupportsOpenEnums() holds are decoded as plain int32.
      decode_type = _FieldDescriptor.TYPE_INT32

    # NOTE: despite its name, this holds the *field* descriptor of a oneof
    # member (not a OneofDescriptor), or None for non-oneof fields.
    oneof_descriptor = None
    clear_if_default = False
    if field_descriptor.containing_oneof is not None:
      oneof_descriptor = field_descriptor
    elif (is_proto3 and not is_repeated and
          field_descriptor.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE):
      clear_if_default = True

    if is_map_entry:
      is_message_map = _IsMessageMapField(field_descriptor)

      field_decoder = decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          is_message_map)
    elif decode_type == _FieldDescriptor.TYPE_STRING:
      field_decoder = decoder.StringDecoder(
          field_descriptor.number, is_repeated, packed,
          field_descriptor, field_descriptor._default_constructor,
          clear_if_default)
    elif field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, packed,
          field_descriptor, field_descriptor._default_constructor)
    else:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, packed,
          field_descriptor, field_descriptor._default_constructor,
          clear_if_default)

    cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)

  AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
             False)

  if is_packable:
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
365
366
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Exposes extensions declared inside the message as class attributes.

  Every entry of `descriptor.extensions_by_name` is copied into the class
  dictionary under its own name; it asserts that the name is not already
  taken.

  Args:
    descriptor: Descriptor of the message being constructed.
    dictionary: The class dictionary, updated in place.
  """
  for name, extension in descriptor.extensions_by_name.items():
    assert name not in dictionary
    dictionary[name] = extension
372
373
def _AddEnumValues(descriptor, cls):
  """Exposes the message's nested enums on the class.

  For every enum type nested in the message, `cls.<EnumName>` becomes an
  EnumTypeWrapper around it and `cls.<VALUE_NAME>` becomes each value's
  number.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_desc in descriptor.enum_types:
    wrapper = enum_type_wrapper.EnumTypeWrapper(enum_desc)
    setattr(cls, enum_desc.name, wrapper)
    for value_desc in enum_desc.values:
      setattr(cls, value_desc.name, value_desc.number)
387
388
def _GetInitializeDefaultForMap(field):
  """Returns a factory that builds an empty map container for `field`.

  The factory takes the owning message and returns a containers.MessageMap
  when the map's values are messages, or a containers.ScalarMap otherwise;
  the container is given message._listener_for_children.

  Args:
    field: FieldDescriptor of a map field; its message type is the map entry
      with 'key' and 'value' fields.

  Raises:
    ValueError: if `field` is not repeated.
  """
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))
  entry_fields = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(entry_fields['key'])
  value_field = entry_fields['value']

  if not _IsMessageMapField(field):
    value_checker = type_checkers.GetTypeChecker(value_field)

    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker,
          field.message_type)
    return MakePrimitiveMapDefault

  def MakeMessageMapDefault(message):
    return containers.MessageMap(
        message._listener_for_children, value_field.message_type, key_checker,
        field.message_type)
  return MakeMessageMapDefault
410
def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  The returned function has one argument:
    message: Message instance containing this field, or a weakref proxy
      of same.

  That function in turn returns a default value for this field:
    * map fields: a new ScalarMap / MessageMap container;
    * other repeated fields: a new, empty repeated container;
    * singular message fields: a new instance of the field's concrete message
      class whose listener is `message` (a _OneofListener for oneof members);
    * singular scalar fields: `field.default_value`.

  Raises:
    ValueError: if a repeated field declares a non-empty default value.
  """

  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # We can't look at _concrete_class yet since it might not have
      # been set.  (Depends on order in which we initialize the classes).
      # The container is therefore handed the message descriptor itself.
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    # _concrete_class may not yet be initialized, so it is looked up lazily
    # each time the default is built.
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      assert getattr(message_type, '_concrete_class', None), (
          'Uninitialized concrete class found for field %r (message type %r)'
          % (field.full_name, message_type.full_name))
      result = message_type._concrete_class()
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault
467
468
def _ReraiseTypeErrorWithFieldName(message_name, field_name):
  """Re-raises the exception being handled, naming the offending field.

  Must be called from inside an `except` clause.  A plain TypeError carrying a
  single argument is replaced by a new TypeError whose message ends with
  " for field <message_name>.<field_name>"; any other exception is re-raised
  as-is.  Either way the original traceback is attached.

  Args:
    message_name: Name of the message type being constructed.
    field_name: Name of the field whose value was rejected.
  """
  _, handled, tb = sys.exc_info()
  if len(handled.args) == 1 and type(handled) is TypeError:
    handled = TypeError('%s for field %s.%s' % (
        str(handled), message_name, field_name))
  raise handled.with_traceback(tb)
478
479
def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__(self, **kwargs) resets the message's bookkeeping
  attributes and then assigns each keyword argument to the field of the same
  name:
    * an unknown keyword raises ValueError (from _GetFieldByName);
    * a value of None is skipped, as if the keyword were absent;
    * repeated and map fields are filled into a fresh container (dict items of
      a repeated message field are used as constructor kwargs);
    * message fields accept a message or a dict of constructor kwargs, which
      is merged into a fresh sub-message;
    * scalar fields go through the normal property setter;
    * enum fields (singular or repeated) also accept enum value names.

  Args:
    message_descriptor: Descriptor of the message type.
    cls: The message class being constructed.
  """

  def _GetIntegerEnumValue(enum_type, value):
    """Convert a string or integer enum value to an integer.

    If the value is a string, it is converted to the enum value in
    enum_type with the same name.  If the value is not a string, it's
    returned as-is.  (No conversion or bounds-checking is done.)

    Raises:
      ValueError: if `value` is a string naming no value of enum_type.
    """
    if isinstance(value, str):
      try:
        return enum_type.values_by_name[value].number
      except KeyError:
        raise ValueError('Enum type %s: unknown label "%s"' % (
            enum_type.full_name, value))
    return value

  def init(self, **kwargs):
    self._cached_byte_size = 0
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    # _unknown_field_set is None when empty for efficiency, and will be
    # turned into UnknownFieldSet struct if fields are added.
    self._unknown_field_set = None      # pylint: disable=protected-access
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.items():
      # _GetFieldByName raises ValueError for names that are not fields of
      # this message, so `field` is never None below.
      field = _GetFieldByName(message_descriptor, field_name)
      if field_value is None:
        # field=None is the same as no field at all.
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          if _IsMapField(field):
            if _IsMessageMapField(field):
              for key in field_value:
                copy[key].MergeFrom(field_value[key])
            else:
              copy.update(field_value)
          else:
            for val in field_value:
              if isinstance(val, dict):
                copy.add(**val)
              else:
                copy.add().MergeFrom(val)
        else:  # Scalar
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            field_value = [_GetIntegerEnumValue(field.enum_type, val)
                           for val in field_value]
          copy.extend(field_value)
        self._fields[field] = copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        copy = field._default_constructor(self)
        new_val = field_value
        if isinstance(field_value, dict):
          new_val = field.message_type._concrete_class(**field_value)
        try:
          copy.MergeFrom(new_val)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = copy
      else:
        if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
          field_value = _GetIntegerEnumValue(field.enum_type, field_value)
        try:
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
565
566
def _GetFieldByName(message_descriptor, field_name):
  """Looks up a field of a message type by its name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.

  Returns:
    The FieldDescriptor registered under `field_name`.

  Raises:
    ValueError: if the message type has no field with that name.
  """
  try:
    field = message_descriptor.fields_by_name[field_name]
  except KeyError:
    raise ValueError('Protocol message %s has no "%s" field.' %
                     (message_descriptor.name, field_name))
  return field
581
582
def _AddPropertiesForFields(descriptor, cls):
  """Installs a property on `cls` for every field of the message type.

  Extendable message types additionally get an `Extensions` property.
  """
  for field_desc in descriptor.fields:
    _AddPropertiesForField(field_desc, cls)

  if not descriptor.is_extendable:
    return
  # _ExtensionDict is a stateless adaptor, so a fresh one is built on every
  # access rather than being stored on the instance.
  cls.Extensions = property(lambda self: _ExtensionDict(self))
592
593
def _AddPropertiesForField(field, cls):
  """Installs the accessor property for one field on `cls`.

  Also defines the class constant `<FIELD_NAME>_FIELD_NUMBER` holding the
  field number.  It then delegates to the repeated, composite (message) or
  scalar property builder, chosen by the field's label and C++ type.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Fails loudly if a new C++ type is introduced that may need special
  # handling below.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, '%s_FIELD_NUMBER' % field.name.upper(), field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_property = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_property = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_property = _AddPropertiesForNonRepeatedScalarField
  add_property(field, cls)
617
618
class _FieldProperty(property):
  # A property that additionally records, as `DESCRIPTOR`, the
  # FieldDescriptor of the field it gives access to.
  __slots__ = ('DESCRIPTOR',)

  def __init__(self, descriptor, getter, setter, doc):
    super().__init__(getter, setter, doc=doc)
    self.DESCRIPTOR = descriptor
625
626
def _AddPropertiesForRepeatedField(field, cls):
  """Installs the property for a repeated field on `cls`.

  Reading the property returns the field's container.  The container is
  created lazily on first access by field._default_constructor: a map
  container for map fields, otherwise a RepeatedScalarFieldContainer or
  RepeatedCompositeFieldContainer.  Assigning to the property raises
  AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    container = self._fields.get(field)
    if container is not None:
      return container
    # Build a fresh container and publish it with setdefault(): if another
    # thread stored one first, that one is returned and ours is discarded.
    # WARNING: this relies on dict.setdefault() being atomic, which is true
    #   in CPython but has not been investigated for other interpreters.
    return self._fields.setdefault(field, field._default_constructor(self))
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # Exists only to give assignment a more helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = _FieldProperty(field, getter, setter, doc=doc)
  setattr(cls, property_name, field_property)
669
670
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Installs the get/set property for a singular scalar field on `cls`.

  Reading returns the stored value, or field.default_value when nothing is
  stored.

  Assigning does the following:
    * runs the value through the field's type checker; a TypeError is
      re-raised naming the field and the offending value;
    * stores the checked value;
    * calls self._Modified() unless the cached byte size is already marked
      dirty;
    * for oneof members, updates the oneof state.

  For proto3 fields outside a oneof, assigning a falsy value removes the field
  from `_fields` instead of storing it.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value
  is_proto3 = field.containing_type.syntax == 'proto3'
  clear_when_set_to_default = is_proto3 and not field.containing_oneof

  def getter(self):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    try:
      checked_value = type_checker.CheckValue(new_value)
    except TypeError as e:
      raise TypeError(
          'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e))
    # Testing the value for truthiness captures all of the proto3 defaults
    # (0, 0.0, enum 0, False, and empty strings/bytes).
    if clear_when_set_to_default and not checked_value:
      self._fields.pop(field, None)
    else:
      self._fields[field] = checked_value
    # _cached_byte_size_dirty is tested inline, rather than always calling
    # _Modified(), because scalar setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if not field.containing_oneof:
    setter = field_setter
  else:
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Wrap the getter/setter pair in a property.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = _FieldProperty(field, getter, setter, doc=doc)
  setattr(cls, property_name, field_property)
728
729
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Installs the read-only property for a singular "group"/"message" field.

  Reading the property returns the sub-message, created lazily on first access
  by field._default_constructor.  Assigning to the property raises
  AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(robinson): Remove duplication with _AddPropertiesForRepeatedField,
  # whose getter is identical.
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    sub_message = self._fields.get(field)
    if sub_message is not None:
      return sub_message
    # Build a fresh sub-message and publish it with setdefault(): if another
    # thread stored one first, that one is returned and ours is discarded.
    # WARNING: this relies on dict.setdefault() being atomic, which is true
    #   in CPython but has not been investigated for other interpreters.
    return self._fields.setdefault(field, field._default_constructor(self))
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # Exists only to give assignment a more helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = _FieldProperty(field, getter, setter, doc=doc)
  setattr(cls, property_name, field_property)
772
773
def _AddPropertiesForExtensions(descriptor, cls):
  """Adds extension field-number constants and extension lookup tables to cls.

  For every extension declared in the message's scope, cls gets a
  <NAME>_FIELD_NUMBER constant. When the descriptor belongs to a file, the
  pool's per-message extension tables are also attached to cls.
  """
  for ext_name, ext_field in descriptor.extensions_by_name.items():
    setattr(cls, ext_name.upper() + '_FIELD_NUMBER', ext_field.number)

  # TODO(amauryfa): Migrate all users of these attributes to functions like
  #   pool.FindExtensionByNumber(descriptor).
  if descriptor.file is None:
    return
  # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
  owning_pool = descriptor.file.pool
  cls._extensions_by_number = owning_pool._extensions_by_number[descriptor]
  cls._extensions_by_name = owning_pool._extensions_by_name[descriptor]
788
def _AddStaticMethods(cls):
  """Attaches the RegisterExtension and FromString static methods to cls."""

  # TODO(robinson): This probably needs to be thread-safe(?)
  def RegisterExtension(extension_handle):
    """Binds extension_handle to this message type and registers it."""
    extension_handle.containing_type = cls.DESCRIPTOR
    # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
    # pylint: disable=protected-access
    owning_pool = cls.DESCRIPTOR.file.pool
    owning_pool._AddExtensionDescriptor(extension_handle)
    _AttachFieldHelpers(cls, extension_handle)

  def FromString(s):
    """Returns a new cls instance populated by merging the serialized s."""
    parsed = cls()
    parsed.MergeFromString(s)
    return parsed

  cls.RegisterExtension = staticmethod(RegisterExtension)
  cls.FromString = staticmethod(FromString)
804
805
def _IsPresent(item):
  """Decides whether a (FieldDescriptor, value) pair from _fields is listed.

  Repeated fields count when non-empty, submessages when they are marked
  present in their parent, and every other stored scalar always counts.
  """
  field, value = item
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(value)
  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return value._is_present_in_parent
  return True
816
817
def _AddListFieldsMethod(message_descriptor, cls):
  """Installs ListFields() on cls; helper for _AddMessageMethods()."""

  def ListFields(self):
    # Present (field, value) pairs, ordered by field number.
    present = (item for item in self._fields.items() if _IsPresent(item))
    return sorted(present, key=lambda field_and_value: field_and_value[0].number)

  cls.ListFields = ListFields
827
# Error templates for HasField() on a name that cannot be queried; each is
# formatted with (message full name, requested field name).
_PROTO3_ERROR_TEMPLATE = (
    'Protocol message %s has no non-repeated submessage field "%s" '
    'nor marked as optional')
_PROTO2_ERROR_TEMPLATE = 'Protocol message %s has no non-repeated field "%s"'
832
def _AddHasFieldMethod(message_descriptor, cls):
  """Installs HasField() on cls; helper for _AddMessageMethods().

  The names HasField() accepts are computed once per class: every
  non-repeated field (in proto3 only submessages and oneof members), plus
  the names of the message's oneofs.
  """

  is_proto3 = (message_descriptor.syntax == "proto3")
  if is_proto3:
    error_msg = _PROTO3_ERROR_TEMPLATE
  else:
    error_msg = _PROTO2_ERROR_TEMPLATE

  def _TracksPresence(field):
    """Returns whether HasField() may be asked about this field."""
    if field.label == _FieldDescriptor.LABEL_REPEATED:
      return False
    # For proto3, only submessages and fields inside a oneof have presence.
    return (not is_proto3 or
            field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE or
            bool(field.containing_oneof))

  hassable_fields = {field.name: field
                     for field in message_descriptor.fields
                     if _TracksPresence(field)}
  # Has methods are supported for oneof descriptors.
  for oneof in message_descriptor.oneofs:
    hassable_fields[oneof.name] = oneof

  def HasField(self, field_name):
    try:
      field = hassable_fields[field_name]
    except KeyError:
      raise ValueError(error_msg % (message_descriptor.full_name, field_name))

    if isinstance(field, descriptor_mod.OneofDescriptor):
      # A oneof is set exactly when its currently active member is set.
      try:
        active_name = self._oneofs[field].name
      except KeyError:
        return False
      return HasField(self, active_name)

    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field in self._fields
    submessage = self._fields.get(field)
    return submessage is not None and submessage._is_present_in_parent

  cls.HasField = HasField
872
873
def _AddClearFieldMethod(message_descriptor, cls):
  """Installs ClearField() on cls; helper for _AddMessageMethods().

  ClearField() accepts either a field name or a oneof name; for a oneof it
  clears whichever member is currently active.
  """

  def ClearField(self, field_name):
    field = message_descriptor.fields_by_name.get(field_name)
    if field is None:
      oneof = message_descriptor.oneofs_by_name.get(field_name)
      if oneof is None:
        raise ValueError('Protocol message %s has no "%s" field.' %
                         (message_descriptor.name, field_name))
      if oneof not in self._oneofs:
        # No member of this oneof is set, so there is nothing to clear.
        return
      field = self._oneofs[oneof]

    if field in self._fields:
      stored = self._fields[field]
      # To match the C++ implementation, we need to invalidate iterators
      # for map fields when ClearField() happens.
      if hasattr(stored, 'InvalidateIterators'):
        stored.InvalidateIterators()

      # Note:  If the field is a sub-message, its listener will still point
      #   at us.  That's fine, because the worst than can happen is that it
      #   will call _Modified() and invalidate our byte size.  Big deal.
      del self._fields[field]

      if self._oneofs.get(field.containing_oneof, None) is field:
        del self._oneofs[field.containing_oneof]

    # Always call _Modified() -- even if nothing was changed, this is
    # a mutating method, and thus calling it should cause the field to become
    # present in the parent message.
    self._Modified()

  cls.ClearField = ClearField
910
911
def _AddClearExtensionMethod(cls):
  """Installs ClearExtension() on cls; helper for _AddMessageMethods()."""

  def ClearExtension(self, extension_handle):
    extension_dict._VerifyExtensionHandle(self, extension_handle)
    # Like ClearField(): drop any stored value, and always mark the message
    # as modified because this is a mutating call.
    self._fields.pop(extension_handle, None)
    self._Modified()

  cls.ClearExtension = ClearExtension
922
923
def _AddHasExtensionMethod(cls):
  """Installs HasExtension() on cls; helper for _AddMessageMethods().

  HasExtension() raises KeyError for repeated extensions, which have no
  presence.
  """

  def HasExtension(self, extension_handle):
    extension_dict._VerifyExtensionHandle(self, extension_handle)
    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % extension_handle.full_name)

    if extension_handle.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return extension_handle in self._fields
    submessage = self._fields.get(extension_handle)
    return submessage is not None and submessage._is_present_in_parent

  cls.HasExtension = HasExtension
937
def _InternalUnpackAny(msg):
  """Unpacks Any message and returns the unpacked message.

  This internal method is different from public Any Unpack method which takes
  the target message as argument. _InternalUnpackAny method does not have
  target message type and need to find the message type in descriptor pool.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message, or None if msg has no type_url or the type it
    names cannot be found in the default descriptor pool.
  """
  type_url = msg.type_url

  # An empty type_url carries nothing to unpack; bail out before doing any
  # import or descriptor-pool work.
  if not type_url:
    return None

  # TODO(amauryfa): Don't use the factory of generated messages.
  # To make Any work with custom factories, use the message factory of the
  # parent message.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import symbol_database
  factory = symbol_database.Default()

  # TODO(haberman): For now we just strip the hostname.  Better logic will be
  # required.
  type_name = type_url.split('/')[-1]
  try:
    descriptor = factory.pool.FindMessageTypeByName(type_name)
  except KeyError:
    # The descriptor pool reports an unknown type with KeyError rather than
    # returning None. Treat it as "cannot unpack" so callers such as __eq__
    # fall back to comparing the raw Any fields instead of crashing.
    descriptor = None

  if descriptor is None:
    return None

  message_class = factory.GetPrototype(descriptor)
  message = message_class()

  message.ParseFromString(msg.value)
  return message
976
977
def _AddEqualsMethod(message_descriptor, cls):
  """Installs __eq__ on cls; helper for _AddMessageMethods()."""

  def __eq__(self, other):
    # Only messages sharing the same descriptor can compare equal.
    if not isinstance(other, message_mod.Message):
      return False
    if other.DESCRIPTOR != self.DESCRIPTOR:
      return False
    if self is other:
      return True

    # Any messages compare by their unpacked payloads when both sides can be
    # resolved; otherwise fall through to the field-by-field comparison.
    if self.DESCRIPTOR.full_name == _AnyFullTypeName:
      unpacked_self = _InternalUnpackAny(self)
      unpacked_other = _InternalUnpackAny(other)
      if unpacked_self and unpacked_other:
        return unpacked_self == unpacked_other

    if self.ListFields() != other.ListFields():
      return False

    # TODO(jieluo): Fix UnknownFieldSet to consider MessageSet extensions,
    # then use it for the comparison.
    return sorted(self._unknown_fields) == sorted(other._unknown_fields)

  cls.__eq__ = __eq__
1006
1007
def _AddStrMethod(message_descriptor, cls):
  """Installs __str__ on cls, rendering the message in text format."""

  def __str__(self):
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__str__ = __str__
1013
1014
def _AddReprMethod(message_descriptor, cls):
  """Installs __repr__ on cls; it renders the same text as __str__."""

  def __repr__(self):
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__repr__ = __repr__
1020
1021
def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Installs __unicode__ on cls; helper for _AddMessageMethods().

  The text is rendered with as_utf8=True and then decoded from UTF-8.
  """

  def __unicode__(self):
    utf8_text = text_format.MessageToString(self, as_utf8=True)
    return utf8_text.decode('utf-8')

  cls.__unicode__ = __unicode__
1028
1029
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.
  The returned byte count includes space for tag information and any
  other additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value.  (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field.  One of the TYPE_* constants
      within FieldDescriptor.

  Returns:
    The byte count computed by the sizer that
    type_checkers.TYPE_TO_BYTE_SIZE_FN registers for field_type.

  Raises:
    message.EncodeError: If no sizer is registered for field_type.
  """
  try:
    fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
  # Call the sizer outside the try: a KeyError raised while sizing must not
  # be misreported as an unrecognized field type.
  return fn(field_number, value)
1048
1049
def _AddByteSizeMethod(message_descriptor, cls):
  """Installs ByteSize() on cls; helper for _AddMessageMethods().

  ByteSize() caches its result. The cache is reused until
  _cached_byte_size_dirty is set again.
  """

  def ByteSize(self):
    # Fast path: nothing changed since the size was last computed.
    if not self._cached_byte_size_dirty:
      return self._cached_byte_size

    descriptor = self.DESCRIPTOR
    if descriptor.GetOptions().map_entry:
      # Fields of map entry should always be serialized.
      by_name = descriptor.fields_by_name
      total = (by_name['key']._sizer(self.key) +
               by_name['value']._sizer(self.value))
    else:
      total = sum(field._sizer(value) for field, value in self.ListFields())
      total += sum(len(tag) + len(payload)
                   for tag, payload in self._unknown_fields)

    self._cached_byte_size = total
    self._cached_byte_size_dirty = False
    self._listener_for_children.dirty = False
    return total

  cls.ByteSize = ByteSize
1075
1076
def _AddSerializeToStringMethod(message_descriptor, cls):
  """Installs SerializeToString() on cls; helper for _AddMessageMethods().

  SerializeToString() refuses to serialize a message with missing required
  fields. Otherwise it forwards its keyword arguments to
  SerializePartialToString().
  """

  def SerializeToString(self, **kwargs):
    if self.IsInitialized():
      return self.SerializePartialToString(**kwargs)
    message_name = self.DESCRIPTOR.full_name
    missing = ','.join(self.FindInitializationErrors())
    raise message_mod.EncodeError(
        'Message %s is missing required fields: %s' % (message_name, missing))

  cls.SerializeToString = SerializeToString
1088
1089
def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Installs SerializePartialToString() and _InternalSerialize() on cls.

  Helper for _AddMessageMethods(). Neither method checks required fields.
  """

  def SerializePartialToString(self, **kwargs):
    sink = BytesIO()
    self._InternalSerialize(sink.write, **kwargs)
    return sink.getvalue()

  def InternalSerialize(self, write_bytes, deterministic=None):
    # None means "use the process-wide default"; any other value is coerced
    # to a plain bool.
    if deterministic is not None:
      deterministic = bool(deterministic)
    else:
      deterministic = (
          api_implementation.IsPythonDefaultSerializationDeterministic())

    descriptor = self.DESCRIPTOR
    if descriptor.GetOptions().map_entry:
      # Fields of map entry should always be serialized.
      by_name = descriptor.fields_by_name
      by_name['key']._encoder(write_bytes, self.key, deterministic)
      by_name['value']._encoder(write_bytes, self.value, deterministic)
      return

    for field, value in self.ListFields():
      field._encoder(write_bytes, value, deterministic)
    # Unknown fields are re-emitted verbatim after the known ones.
    for tag, payload in self._unknown_fields:
      write_bytes(tag)
      write_bytes(payload)

  cls.SerializePartialToString = SerializePartialToString
  cls._InternalSerialize = InternalSerialize
1120
1121
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs MergeFromString() and its _InternalParse() worker on cls. The tag
  reader, the field skipper and the class's tag -> (decoder, field) table are
  bound to locals once here so the parsing loop avoids repeated lookups.
  """
  def MergeFromString(self, serialized):
    """Parses wire-format data and merges the decoded fields into self.

    Args:
      serialized: bytes-like data. It is wrapped in a memoryview, which
        _InternalParse() requires.

    Returns:
      len(serialized), kept for legacy callers.

    Raises:
      message.DecodeError: If parsing stops before the end (reported as an
        unexpected end-group tag), if it raises IndexError or TypeError
        (reported as a truncated message), or if it raises struct.error.
    """
    serialized = memoryview(serialized)
    length = len(serialized)
    try:
      if self._InternalParse(serialized, 0, length) != length:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      # Now ord(buf[p:p+1]) == ord('') gets TypeError.
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:
      raise message_mod.DecodeError(e)
    return length   # Return this for legacy reasons.
  cls.MergeFromString = MergeFromString

  local_ReadTag = decoder.ReadTag
  local_SkipField = decoder.SkipField
  decoders_by_tag = cls._decoders_by_tag

  def InternalParse(self, buffer, pos, end):
    """Create a message from serialized bytes.

    Parses the fields in buffer[pos:end] into self, which is marked modified
    up front. Tags found in decoders_by_tag go to their field decoder. Any
    other tag is recorded twice: in the UnknownFieldSet, and as a raw
    (tag_bytes, value_bytes) pair in the legacy _unknown_fields list.

    Args:
      self: Message, instance of the proto message object.
      buffer: memoryview of the serialized data.
      pos: int, position to start in the serialized data.
      end: int, end position of the serialized data.

    Returns:
      int, the position where parsing stopped. This is end once every field
      has been consumed. If _DecodeUnknownField or SkipField returns -1, it
      is instead the start of the unknown field being processed;
      MergeFromString treats that as an unexpected end-group tag.
    """
    # Guard against internal misuse, since this function is called internally
    # quite extensively, and it's easy to accidentally pass bytes.
    assert isinstance(buffer, memoryview)
    self._Modified()
    field_dict = self._fields
    # pylint: disable=protected-access
    unknown_field_set = self._unknown_field_set
    while pos != end:
      (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
      field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None))
      if field_decoder is None:
        # Unknown tag: lazily create both unknown-field stores.
        if not self._unknown_fields:   # pylint: disable=protected-access
          self._unknown_fields = []    # pylint: disable=protected-access
        if unknown_field_set is None:
          # pylint: disable=protected-access
          self._unknown_field_set = containers.UnknownFieldSet()
          # pylint: disable=protected-access
          unknown_field_set = self._unknown_field_set
        # pylint: disable=protected-access
        (tag, _) = decoder._DecodeVarint(tag_bytes, 0)
        field_number, wire_type = wire_format.UnpackTag(tag)
        if field_number == 0:
          raise message_mod.DecodeError('Field number 0 is illegal.')
        # Remember where the value starts. The value is decoded into the
        # UnknownFieldSet below, then skipped again from here to capture its
        # raw bytes for _unknown_fields.
        # TODO(jieluo): remove old_pos.
        old_pos = new_pos
        (data, new_pos) = decoder._DecodeUnknownField(
            buffer, new_pos, wire_type)  # pylint: disable=protected-access
        if new_pos == -1:
          return pos
        # pylint: disable=protected-access
        unknown_field_set._add(field_number, wire_type, data)
        # TODO(jieluo): remove _unknown_fields.
        new_pos = local_SkipField(buffer, old_pos, end, tag_bytes)
        if new_pos == -1:
          return pos
        self._unknown_fields.append(
            (tag_bytes, buffer[old_pos:new_pos].tobytes()))
        pos = new_pos
      else:
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        # field_desc feeds _UpdateOneofState, so it is presumably only set
        # for oneof members. TODO confirm where decoders_by_tag is built.
        if field_desc:
          self._UpdateOneofState(field_desc)
    return pos
  cls._InternalParse = InternalParse
1200
1201
def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationError methods to the
  protocol message class."""

  # Collected once per class so the per-instance checks stay cheap.
  required_fields = [field for field in message_descriptor.fields
                     if field.label == _FieldDescriptor.LABEL_REQUIRED]

  def _MissingRequiredField(message):
    """Returns True if some required field of message is not set."""
    fields = message._fields
    for field in required_fields:
      if field not in fields:
        return True
      if (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
          not fields[field]._is_present_in_parent):
        return True
    return False

  def _HasUninitializedSubmessage(message):
    """Returns True if a set (non-map) submessage is not initialized."""
    # Iterate over a snapshot: the dict can change size!
    for field, value in list(message._fields.items()):
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue
      if field.label != _FieldDescriptor.LABEL_REPEATED:
        if value._is_present_in_parent and not value.IsInitialized():
          return True
        continue
      # Repeated map-entry messages (i.e. map fields) are not inspected.
      if (field.message_type.has_options and
          field.message_type.GetOptions().map_entry):
        continue
      if not all(element.IsInitialized() for element in value):
        return True
    return False

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Args:
      errors:  A list which, if provided, will be populated with the field
               paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """
    # Performance is critical so we avoid HasField() and ListFields().
    if not (_MissingRequiredField(self) or _HasUninitializedSubmessage(self)):
      return True
    if errors is not None:
      errors.extend(self.FindInitializationErrors())
    return False

  def _AppendPrefixedErrors(errors, prefix, child):
    """Appends child's initialization errors to errors, each with prefix."""
    errors.extend(prefix + error for error in child.FindInitializationErrors())

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings.  Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """
    errors = [field.name for field in required_fields
              if not self.HasField(field.name)]

    for field, value in self.ListFields():
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue
      if field.is_extension:
        name = '(%s)' % field.full_name
      else:
        name = field.name

      if _IsMapField(field):
        # ScalarMaps can't have any initialization errors.
        if _IsMessageMapField(field):
          for key in value:
            _AppendPrefixedErrors(errors, '%s[%s].' % (name, key), value[key])
      elif field.label == _FieldDescriptor.LABEL_REPEATED:
        for index, element in enumerate(value):
          _AppendPrefixedErrors(errors, '%s[%d].' % (name, index), element)
      else:
        _AppendPrefixedErrors(errors, name + '.', value)

    return errors

  cls.IsInitialized = IsInitialized
  cls.FindInitializationErrors = FindInitializationErrors
1295
1296
def _FullyQualifiedClassName(klass):
  """Returns 'module.QualifiedName' for klass; builtins get the bare name."""
  module = klass.__module__
  name = getattr(klass, '__qualname__', klass.__name__)
  is_builtin = module in (None, 'builtins', '__builtin__')
  return name if is_builtin else module + '.' + name
1303
1304
def _AddMergeFromMethod(cls):
  """Installs MergeFrom() on cls; helper for _AddMessageMethods()."""
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def MergeFrom(self, msg):
    if not isinstance(msg, cls):
      raise TypeError(
          'Parameter to MergeFrom() must be instance of same class: '
          'expected %s got %s.' % (_FullyQualifiedClassName(cls),
                                   _FullyQualifiedClassName(msg.__class__)))

    assert msg is not self
    self._Modified()

    own_fields = self._fields
    for field, value in msg._fields.items():
      is_container = field.label == LABEL_REPEATED
      if is_container or field.cpp_type == CPPTYPE_MESSAGE:
        # A singular submessage is merged only if it is present in msg.
        if not is_container and not value._is_present_in_parent:
          continue
        target = own_fields.get(field)
        if target is None:
          # Construct a new object to represent this field.
          target = field._default_constructor(self)
          own_fields[field] = target
        target.MergeFrom(value)
        continue
      # Singular scalars simply overwrite.
      own_fields[field] = value
      if field.containing_oneof:
        self._UpdateOneofState(field)

    if not msg._unknown_fields:
      return
    if not self._unknown_fields:
      self._unknown_fields = []
    self._unknown_fields.extend(msg._unknown_fields)
    # pylint: disable=protected-access
    if self._unknown_field_set is None:
      self._unknown_field_set = containers.UnknownFieldSet()
    self._unknown_field_set._extend(msg._unknown_field_set)

  cls.MergeFrom = MergeFrom
1352
1353
def _AddWhichOneofMethod(message_descriptor, cls):
  """Installs WhichOneof() on cls; helper for _AddMessageMethods()."""

  def WhichOneof(self, oneof_name):
    """Returns the name of the currently set field inside a oneof, or None."""
    try:
      oneof = message_descriptor.oneofs_by_name[oneof_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no oneof "%s" field.' % oneof_name)

    active_field = self._oneofs.get(oneof, None)
    if active_field is None or not self.HasField(active_field.name):
      return None
    return active_field.name

  cls.WhichOneof = WhichOneof
1370
1371
def _Clear(self):
  """Resets all fields, unknown fields and oneof state, then marks modified."""
  self._fields = {}
  self._unknown_fields = ()
  # pylint: disable=protected-access
  previous_set = self._unknown_field_set
  if previous_set is not None:
    previous_set._clear()
    self._unknown_field_set = None
  self._oneofs = {}
  self._Modified()
1383
1384
def _UnknownFields(self):
  """Returns the message's UnknownFieldSet, creating an empty one if needed."""
  # pylint: disable=protected-access
  field_set = self._unknown_field_set
  if field_set is None:
    field_set = containers.UnknownFieldSet()
    self._unknown_field_set = field_set
  return field_set
1390
1391
def _DiscardUnknownFields(self):
  """Drops unknown fields from self and, recursively, from its submessages."""
  self._unknown_fields = []
  self._unknown_field_set = None      # pylint: disable=protected-access
  for field, value in self.ListFields():
    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      continue
    if _IsMapField(field):
      # Scalar-valued maps hold no submessages.
      if not _IsMessageMapField(field):
        continue
      children = (value[key] for key in value)
    elif field.label == _FieldDescriptor.LABEL_REPEATED:
      children = value
    else:
      children = (value,)
    for child in children:
      child.DiscardUnknownFields()
1406
1407
def _SetListener(self, listener):
  """Sets the modification listener; None installs a no-op listener."""
  self._listener = (message_listener_mod.NullMessageListener()
                    if listener is None else listener)
1413
1414
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls.

  The installers run in a fixed order. The extension-only methods are added
  only when the message type is extendable.
  """
  for add_method in (_AddListFieldsMethod,
                     _AddHasFieldMethod,
                     _AddClearFieldMethod):
    add_method(message_descriptor, cls)
  if message_descriptor.is_extendable:
    _AddClearExtensionMethod(cls)
    _AddHasExtensionMethod(cls)
  for add_method in (_AddEqualsMethod,
                     _AddStrMethod,
                     _AddReprMethod,
                     _AddUnicodeMethod,
                     _AddByteSizeMethod,
                     _AddSerializeToStringMethod,
                     _AddSerializePartialToStringMethod,
                     _AddMergeFromStringMethod,
                     _AddIsInitializedMethod):
    add_method(message_descriptor, cls)
  _AddMergeFromMethod(cls)
  _AddWhichOneofMethod(message_descriptor, cls)

  # Methods that do not depend on cls are shared module-level functions.
  cls.Clear = _Clear
  cls.UnknownFields = _UnknownFields
  cls.DiscardUnknownFields = _DiscardUnknownFields
  cls._SetListener = _SetListener
1439
1440
def _AddPrivateHelperMethods(message_descriptor, cls):
  """Adds implementation of private helper methods to cls.

  Installs _Modified (also exposed as SetInParent) and _UpdateOneofState.
  """

  def Modified(self):
    """Sets the _cached_byte_size_dirty bit to true,
    and propagates this to our listener iff this was a state change.
    """

    # Note:  Some callers check _cached_byte_size_dirty before calling
    #   _Modified() as an extra optimization.  So, if this method is ever
    #   changed such that it does stuff even when _cached_byte_size_dirty is
    #   already true, the callers need to be updated.
    if not self._cached_byte_size_dirty:
      self._cached_byte_size_dirty = True
      self._listener_for_children.dirty = True
      self._is_present_in_parent = True
      self._listener.Modified()

  def _UpdateOneofState(self, field):
    """Sets field as the active field in its containing oneof.

    Will also delete currently active field in the oneof, if it is different
    from the argument. Does not mark the message as modified.
    """
    other_field = self._oneofs.setdefault(field.containing_oneof, field)
    if other_field is not field:
      # The previously active field is not guaranteed to still be stored.
      # ClearField() only drops the oneof entry when the field is present
      # in _fields, and a detached submessage can re-activate its field
      # through _OneofListener. Remove it tolerantly instead of raising
      # KeyError.
      self._fields.pop(other_field, None)
      self._oneofs[field.containing_oneof] = field

  cls._Modified = Modified
  cls.SetInParent = Modified
  cls._UpdateOneofState = _UpdateOneofState
1473
1474
class _Listener(object):

  """MessageListener implementation that a parent message registers with its
  child message.

  In order to support semantics like:

    foo.bar.baz.moo = 23
    assert foo.HasField('bar')

  ...child objects must have back references to their parents.
  This helper class is at the heart of this support.
  """

  def __init__(self, parent_message):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages. May already be a weakref proxy.
    """
    # This listener establishes a back reference from a child (contained) object
    # to its parent (containing) object.  We make this a weak reference to avoid
    # creating cyclic garbage when the client finishes with the 'parent' object
    # in the tree.
    # weakref.proxy() yields CallableProxyType (not ProxyType) for callable
    # referents, and a proxy cannot itself be weakly referenced, so every
    # proxy flavour must be stored as-is.
    if isinstance(parent_message, weakref.ProxyTypes):
      self._parent_message_weakref = parent_message
    else:
      self._parent_message_weakref = weakref.proxy(parent_message)

    # As an optimization, we also indicate directly on the listener whether
    # or not the parent message is dirty.  This way we can avoid traversing
    # up the tree in the common case.
    self.dirty = False

  def Modified(self):
    """Forwards a modification to the parent unless it is already dirty."""
    if self.dirty:
      return
    try:
      # Propagate the signal to our parents iff this is the first field set.
      self._parent_message_weakref._Modified()
    except ReferenceError:
      # We can get here if a client has kept a reference to a child object,
      # and is now setting a field on it, but the child's parent has been
      # garbage-collected.  This is not an error.
      pass
1519
1520
class _OneofListener(_Listener):
  """Listener for a composite oneof field.

  Besides the regular dirty propagation, every modification of the child
  marks its field as the active member of the parent's oneof.
  """

  def __init__(self, parent_message, field):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
      field: The descriptor of the field being set in the parent message.
    """
    super(_OneofListener, self).__init__(parent_message)
    self._field = field

  def Modified(self):
    """Also updates the state of the containing oneof in the parent message."""
    parent = self._parent_message_weakref
    try:
      parent._UpdateOneofState(self._field)
      super(_OneofListener, self).Modified()
    except ReferenceError:
      # The parent was garbage-collected; nothing left to update.
      pass
1540