xref: /aosp_15_r20/external/cronet/third_party/protobuf/python/google/protobuf/descriptor_pool.py (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31"""Provides DescriptorPool to use as a container for proto2 descriptors.
32
The DescriptorPool is used in conjunction with a DescriptorDatabase to maintain
34a collection of protocol buffer descriptors for use when dynamically creating
35message types at runtime.
36
37For most applications protocol buffers should be used via modules generated by
38the protocol buffer compiler tool. This should only be used when the type of
39protocol buffers used in an application or library cannot be predetermined.
40
41Below is a straightforward example on how to use this class::
42
43  pool = DescriptorPool()
44  file_descriptor_protos = [ ... ]
45  for file_descriptor_proto in file_descriptor_protos:
46    pool.Add(file_descriptor_proto)
47  my_message_descriptor = pool.FindMessageTypeByName('some.package.MessageType')
48
49The message descriptor can be used in conjunction with the message_factory
50module in order to create a protocol buffer class that can be encoded and
51decoded.
52
53If you want to get a Python class for the specified proto, use the
54helper functions inside google.protobuf.message_factory
55directly instead of this class.
56"""
57
# NOTE(review): the e-mail part of this value reads "[email protected]", which
# looks like redaction applied by the page this listing was captured from;
# confirm against the upstream file.
__author__ = '[email protected] (Matt Toia)'
59
import collections
import functools
import warnings

from google.protobuf import descriptor
from google.protobuf import descriptor_database
from google.protobuf import text_encoding
66
67
# Mirrors descriptor._USE_C_DESCRIPTORS. When true, DescriptorPool.__new__
# hands construction off to descriptor._message.DescriptorPool.
_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS  # pylint: disable=protected-access
69
70
def _Deprecated(func):
  """Marks `func` as deprecated.

  The returned wrapper emits a DeprecationWarning on every call and then
  delegates to `func` with the same arguments. `functools.wraps` copies the
  wrapped function's metadata (name, qualified name, module, docstring and
  attribute dict) and records it as `__wrapped__`, so introspection still
  sees the original function.

  Args:
    func: The function to wrap.

  Returns:
    The warning-emitting wrapper around `func`.
  """

  @functools.wraps(func)
  def NewFunc(*args, **kwargs):
    # stacklevel=2 attributes the warning to the code that called the
    # deprecated function rather than to this wrapper.
    warnings.warn(
        'Call to deprecated function %s(). Note: Do add unlinked descriptors '
        'to descriptor_pool is wrong. Use Add() or AddSerializedFile() '
        'instead.' % func.__name__,
        category=DeprecationWarning,
        stacklevel=2)
    return func(*args, **kwargs)

  return NewFunc
85
86
def _NormalizeFullyQualifiedName(name):
  """Strips the leading period(s) from a fully-qualified type name.

  Because of b/13860351 in descriptor_database.py, types that live in the root
  namespace are generated with a leading period; every leading period is
  dropped here.

  Args:
    name (str): The fully-qualified symbol name.

  Returns:
    str: `name` without its leading periods.
  """
  while name.startswith('.'):
    name = name[1:]
  return name
100
101
def _OptionsOrNone(descriptor_proto):
  """Returns `descriptor_proto.options` when that field is set, else None."""
  has_options = descriptor_proto.HasField('options')
  return descriptor_proto.options if has_options else None
108
109
def _IsMessageSetExtension(field):
  """Tells whether `field` is an extension of a MessageSet-format message.

  That is: `field` is an extension, the message it extends has options with
  `message_set_wire_format` enabled, and `field` itself is an optional,
  message-typed field.
  """
  if not field.is_extension:
    return False
  extended = field.containing_type
  if not (extended.has_options and
          extended.GetOptions().message_set_wire_format):
    return False
  field_desc = descriptor.FieldDescriptor
  return (field.type == field_desc.TYPE_MESSAGE and
          field.label == field_desc.LABEL_OPTIONAL)
116
117
118class DescriptorPool(object):
119  """A collection of protobufs dynamically constructed by descriptor protos."""
120
121  if _USE_C_DESCRIPTORS:
122
123    def __new__(cls, descriptor_db=None):
124      # pylint: disable=protected-access
125      return descriptor._message.DescriptorPool(descriptor_db)
126
127  def __init__(self, descriptor_db=None):
128    """Initializes a Pool of proto buffs.
129
130    The descriptor_db argument to the constructor is provided to allow
131    specialized file descriptor proto lookup code to be triggered on demand. An
132    example would be an implementation which will read and compile a file
133    specified in a call to FindFileByName() and not require the call to Add()
134    at all. Results from this database will be cached internally here as well.
135
136    Args:
137      descriptor_db: A secondary source of file descriptors.
138    """
139
140    self._internal_db = descriptor_database.DescriptorDatabase()
141    self._descriptor_db = descriptor_db
142    self._descriptors = {}
143    self._enum_descriptors = {}
144    self._service_descriptors = {}
145    self._file_descriptors = {}
146    self._toplevel_extensions = {}
147    # TODO(jieluo): Remove _file_desc_by_toplevel_extension after
148    # maybe year 2020 for compatibility issue (with 3.4.1 only).
149    self._file_desc_by_toplevel_extension = {}
150    self._top_enum_values = {}
151    # We store extensions in two two-level mappings: The first key is the
152    # descriptor of the message being extended, the second key is the extension
153    # full name or its tag number.
154    self._extensions_by_name = collections.defaultdict(dict)
155    self._extensions_by_number = collections.defaultdict(dict)
156
157  def _CheckConflictRegister(self, desc, desc_name, file_name):
158    """Check if the descriptor name conflicts with another of the same name.
159
160    Args:
161      desc: Descriptor of a message, enum, service, extension or enum value.
162      desc_name (str): the full name of desc.
163      file_name (str): The file name of descriptor.
164    """
165    for register, descriptor_type in [
166        (self._descriptors, descriptor.Descriptor),
167        (self._enum_descriptors, descriptor.EnumDescriptor),
168        (self._service_descriptors, descriptor.ServiceDescriptor),
169        (self._toplevel_extensions, descriptor.FieldDescriptor),
170        (self._top_enum_values, descriptor.EnumValueDescriptor)]:
171      if desc_name in register:
172        old_desc = register[desc_name]
173        if isinstance(old_desc, descriptor.EnumValueDescriptor):
174          old_file = old_desc.type.file.name
175        else:
176          old_file = old_desc.file.name
177
178        if not isinstance(desc, descriptor_type) or (
179            old_file != file_name):
180          error_msg = ('Conflict register for file "' + file_name +
181                       '": ' + desc_name +
182                       ' is already defined in file "' +
183                       old_file + '". Please fix the conflict by adding '
184                       'package name on the proto file, or use different '
185                       'name for the duplication.')
186          if isinstance(desc, descriptor.EnumValueDescriptor):
187            error_msg += ('\nNote: enum values appear as '
188                          'siblings of the enum type instead of '
189                          'children of it.')
190
191          raise TypeError(error_msg)
192
193        return
194
195  def Add(self, file_desc_proto):
196    """Adds the FileDescriptorProto and its types to this pool.
197
198    Args:
199      file_desc_proto (FileDescriptorProto): The file descriptor to add.
200    """
201
202    self._internal_db.Add(file_desc_proto)
203
204  def AddSerializedFile(self, serialized_file_desc_proto):
205    """Adds the FileDescriptorProto and its types to this pool.
206
207    Args:
208      serialized_file_desc_proto (bytes): A bytes string, serialization of the
209        :class:`FileDescriptorProto` to add.
210
211    Returns:
212      FileDescriptor: Descriptor for the added file.
213    """
214
215    # pylint: disable=g-import-not-at-top
216    from google.protobuf import descriptor_pb2
217    file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString(
218        serialized_file_desc_proto)
219    file_desc = self._ConvertFileProtoToFileDescriptor(file_desc_proto)
220    file_desc.serialized_pb = serialized_file_desc_proto
221    return file_desc
222
223  # Add Descriptor to descriptor pool is dreprecated. Please use Add()
224  # or AddSerializedFile() to add a FileDescriptorProto instead.
225  @_Deprecated
226  def AddDescriptor(self, desc):
227    self._AddDescriptor(desc)
228
229  # Never call this method. It is for internal usage only.
230  def _AddDescriptor(self, desc):
231    """Adds a Descriptor to the pool, non-recursively.
232
233    If the Descriptor contains nested messages or enums, the caller must
234    explicitly register them. This method also registers the FileDescriptor
235    associated with the message.
236
237    Args:
238      desc: A Descriptor.
239    """
240    if not isinstance(desc, descriptor.Descriptor):
241      raise TypeError('Expected instance of descriptor.Descriptor.')
242
243    self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
244
245    self._descriptors[desc.full_name] = desc
246    self._AddFileDescriptor(desc.file)
247
248  # Add EnumDescriptor to descriptor pool is dreprecated. Please use Add()
249  # or AddSerializedFile() to add a FileDescriptorProto instead.
250  @_Deprecated
251  def AddEnumDescriptor(self, enum_desc):
252    self._AddEnumDescriptor(enum_desc)
253
254  # Never call this method. It is for internal usage only.
255  def _AddEnumDescriptor(self, enum_desc):
256    """Adds an EnumDescriptor to the pool.
257
258    This method also registers the FileDescriptor associated with the enum.
259
260    Args:
261      enum_desc: An EnumDescriptor.
262    """
263
264    if not isinstance(enum_desc, descriptor.EnumDescriptor):
265      raise TypeError('Expected instance of descriptor.EnumDescriptor.')
266
267    file_name = enum_desc.file.name
268    self._CheckConflictRegister(enum_desc, enum_desc.full_name, file_name)
269    self._enum_descriptors[enum_desc.full_name] = enum_desc
270
271    # Top enum values need to be indexed.
272    # Count the number of dots to see whether the enum is toplevel or nested
273    # in a message. We cannot use enum_desc.containing_type at this stage.
274    if enum_desc.file.package:
275      top_level = (enum_desc.full_name.count('.')
276                   - enum_desc.file.package.count('.') == 1)
277    else:
278      top_level = enum_desc.full_name.count('.') == 0
279    if top_level:
280      file_name = enum_desc.file.name
281      package = enum_desc.file.package
282      for enum_value in enum_desc.values:
283        full_name = _NormalizeFullyQualifiedName(
284            '.'.join((package, enum_value.name)))
285        self._CheckConflictRegister(enum_value, full_name, file_name)
286        self._top_enum_values[full_name] = enum_value
287    self._AddFileDescriptor(enum_desc.file)
288
289  # Add ServiceDescriptor to descriptor pool is dreprecated. Please use Add()
290  # or AddSerializedFile() to add a FileDescriptorProto instead.
291  @_Deprecated
292  def AddServiceDescriptor(self, service_desc):
293    self._AddServiceDescriptor(service_desc)
294
295  # Never call this method. It is for internal usage only.
296  def _AddServiceDescriptor(self, service_desc):
297    """Adds a ServiceDescriptor to the pool.
298
299    Args:
300      service_desc: A ServiceDescriptor.
301    """
302
303    if not isinstance(service_desc, descriptor.ServiceDescriptor):
304      raise TypeError('Expected instance of descriptor.ServiceDescriptor.')
305
306    self._CheckConflictRegister(service_desc, service_desc.full_name,
307                                service_desc.file.name)
308    self._service_descriptors[service_desc.full_name] = service_desc
309
310  # Add ExtensionDescriptor to descriptor pool is dreprecated. Please use Add()
311  # or AddSerializedFile() to add a FileDescriptorProto instead.
312  @_Deprecated
313  def AddExtensionDescriptor(self, extension):
314    self._AddExtensionDescriptor(extension)
315
316  # Never call this method. It is for internal usage only.
317  def _AddExtensionDescriptor(self, extension):
318    """Adds a FieldDescriptor describing an extension to the pool.
319
320    Args:
321      extension: A FieldDescriptor.
322
323    Raises:
324      AssertionError: when another extension with the same number extends the
325        same message.
326      TypeError: when the specified extension is not a
327        descriptor.FieldDescriptor.
328    """
329    if not (isinstance(extension, descriptor.FieldDescriptor) and
330            extension.is_extension):
331      raise TypeError('Expected an extension descriptor.')
332
333    if extension.extension_scope is None:
334      self._toplevel_extensions[extension.full_name] = extension
335
336    try:
337      existing_desc = self._extensions_by_number[
338          extension.containing_type][extension.number]
339    except KeyError:
340      pass
341    else:
342      if extension is not existing_desc:
343        raise AssertionError(
344            'Extensions "%s" and "%s" both try to extend message type "%s" '
345            'with field number %d.' %
346            (extension.full_name, existing_desc.full_name,
347             extension.containing_type.full_name, extension.number))
348
349    self._extensions_by_number[extension.containing_type][
350        extension.number] = extension
351    self._extensions_by_name[extension.containing_type][
352        extension.full_name] = extension
353
354    # Also register MessageSet extensions with the type name.
355    if _IsMessageSetExtension(extension):
356      self._extensions_by_name[extension.containing_type][
357          extension.message_type.full_name] = extension
358
359  @_Deprecated
360  def AddFileDescriptor(self, file_desc):
361    self._InternalAddFileDescriptor(file_desc)
362
363  # Never call this method. It is for internal usage only.
364  def _InternalAddFileDescriptor(self, file_desc):
365    """Adds a FileDescriptor to the pool, non-recursively.
366
367    If the FileDescriptor contains messages or enums, the caller must explicitly
368    register them.
369
370    Args:
371      file_desc: A FileDescriptor.
372    """
373
374    self._AddFileDescriptor(file_desc)
375    # TODO(jieluo): This is a temporary solution for FieldDescriptor.file.
376    # FieldDescriptor.file is added in code gen. Remove this solution after
377    # maybe 2020 for compatibility reason (with 3.4.1 only).
378    for extension in file_desc.extensions_by_name.values():
379      self._file_desc_by_toplevel_extension[
380          extension.full_name] = file_desc
381
382  def _AddFileDescriptor(self, file_desc):
383    """Adds a FileDescriptor to the pool, non-recursively.
384
385    If the FileDescriptor contains messages or enums, the caller must explicitly
386    register them.
387
388    Args:
389      file_desc: A FileDescriptor.
390    """
391
392    if not isinstance(file_desc, descriptor.FileDescriptor):
393      raise TypeError('Expected instance of descriptor.FileDescriptor.')
394    self._file_descriptors[file_desc.name] = file_desc
395
396  def FindFileByName(self, file_name):
397    """Gets a FileDescriptor by file name.
398
399    Args:
400      file_name (str): The path to the file to get a descriptor for.
401
402    Returns:
403      FileDescriptor: The descriptor for the named file.
404
405    Raises:
406      KeyError: if the file cannot be found in the pool.
407    """
408
409    try:
410      return self._file_descriptors[file_name]
411    except KeyError:
412      pass
413
414    try:
415      file_proto = self._internal_db.FindFileByName(file_name)
416    except KeyError as error:
417      if self._descriptor_db:
418        file_proto = self._descriptor_db.FindFileByName(file_name)
419      else:
420        raise error
421    if not file_proto:
422      raise KeyError('Cannot find a file named %s' % file_name)
423    return self._ConvertFileProtoToFileDescriptor(file_proto)
424
425  def FindFileContainingSymbol(self, symbol):
426    """Gets the FileDescriptor for the file containing the specified symbol.
427
428    Args:
429      symbol (str): The name of the symbol to search for.
430
431    Returns:
432      FileDescriptor: Descriptor for the file that contains the specified
433      symbol.
434
435    Raises:
436      KeyError: if the file cannot be found in the pool.
437    """
438
439    symbol = _NormalizeFullyQualifiedName(symbol)
440    try:
441      return self._InternalFindFileContainingSymbol(symbol)
442    except KeyError:
443      pass
444
445    try:
446      # Try fallback database. Build and find again if possible.
447      self._FindFileContainingSymbolInDb(symbol)
448      return self._InternalFindFileContainingSymbol(symbol)
449    except KeyError:
450      raise KeyError('Cannot find a file containing %s' % symbol)
451
452  def _InternalFindFileContainingSymbol(self, symbol):
453    """Gets the already built FileDescriptor containing the specified symbol.
454
455    Args:
456      symbol (str): The name of the symbol to search for.
457
458    Returns:
459      FileDescriptor: Descriptor for the file that contains the specified
460      symbol.
461
462    Raises:
463      KeyError: if the file cannot be found in the pool.
464    """
465    try:
466      return self._descriptors[symbol].file
467    except KeyError:
468      pass
469
470    try:
471      return self._enum_descriptors[symbol].file
472    except KeyError:
473      pass
474
475    try:
476      return self._service_descriptors[symbol].file
477    except KeyError:
478      pass
479
480    try:
481      return self._top_enum_values[symbol].type.file
482    except KeyError:
483      pass
484
485    try:
486      return self._file_desc_by_toplevel_extension[symbol]
487    except KeyError:
488      pass
489
490    # Try fields, enum values and nested extensions inside a message.
491    top_name, _, sub_name = symbol.rpartition('.')
492    try:
493      message = self.FindMessageTypeByName(top_name)
494      assert (sub_name in message.extensions_by_name or
495              sub_name in message.fields_by_name or
496              sub_name in message.enum_values_by_name)
497      return message.file
498    except (KeyError, AssertionError):
499      raise KeyError('Cannot find a file containing %s' % symbol)
500
501  def FindMessageTypeByName(self, full_name):
502    """Loads the named descriptor from the pool.
503
504    Args:
505      full_name (str): The full name of the descriptor to load.
506
507    Returns:
508      Descriptor: The descriptor for the named type.
509
510    Raises:
511      KeyError: if the message cannot be found in the pool.
512    """
513
514    full_name = _NormalizeFullyQualifiedName(full_name)
515    if full_name not in self._descriptors:
516      self._FindFileContainingSymbolInDb(full_name)
517    return self._descriptors[full_name]
518
519  def FindEnumTypeByName(self, full_name):
520    """Loads the named enum descriptor from the pool.
521
522    Args:
523      full_name (str): The full name of the enum descriptor to load.
524
525    Returns:
526      EnumDescriptor: The enum descriptor for the named type.
527
528    Raises:
529      KeyError: if the enum cannot be found in the pool.
530    """
531
532    full_name = _NormalizeFullyQualifiedName(full_name)
533    if full_name not in self._enum_descriptors:
534      self._FindFileContainingSymbolInDb(full_name)
535    return self._enum_descriptors[full_name]
536
537  def FindFieldByName(self, full_name):
538    """Loads the named field descriptor from the pool.
539
540    Args:
541      full_name (str): The full name of the field descriptor to load.
542
543    Returns:
544      FieldDescriptor: The field descriptor for the named field.
545
546    Raises:
547      KeyError: if the field cannot be found in the pool.
548    """
549    full_name = _NormalizeFullyQualifiedName(full_name)
550    message_name, _, field_name = full_name.rpartition('.')
551    message_descriptor = self.FindMessageTypeByName(message_name)
552    return message_descriptor.fields_by_name[field_name]
553
554  def FindOneofByName(self, full_name):
555    """Loads the named oneof descriptor from the pool.
556
557    Args:
558      full_name (str): The full name of the oneof descriptor to load.
559
560    Returns:
561      OneofDescriptor: The oneof descriptor for the named oneof.
562
563    Raises:
564      KeyError: if the oneof cannot be found in the pool.
565    """
566    full_name = _NormalizeFullyQualifiedName(full_name)
567    message_name, _, oneof_name = full_name.rpartition('.')
568    message_descriptor = self.FindMessageTypeByName(message_name)
569    return message_descriptor.oneofs_by_name[oneof_name]
570
571  def FindExtensionByName(self, full_name):
572    """Loads the named extension descriptor from the pool.
573
574    Args:
575      full_name (str): The full name of the extension descriptor to load.
576
577    Returns:
578      FieldDescriptor: The field descriptor for the named extension.
579
580    Raises:
581      KeyError: if the extension cannot be found in the pool.
582    """
583    full_name = _NormalizeFullyQualifiedName(full_name)
584    try:
585      # The proto compiler does not give any link between the FileDescriptor
586      # and top-level extensions unless the FileDescriptorProto is added to
587      # the DescriptorDatabase, but this can impact memory usage.
588      # So we registered these extensions by name explicitly.
589      return self._toplevel_extensions[full_name]
590    except KeyError:
591      pass
592    message_name, _, extension_name = full_name.rpartition('.')
593    try:
594      # Most extensions are nested inside a message.
595      scope = self.FindMessageTypeByName(message_name)
596    except KeyError:
597      # Some extensions are defined at file scope.
598      scope = self._FindFileContainingSymbolInDb(full_name)
599    return scope.extensions_by_name[extension_name]
600
601  def FindExtensionByNumber(self, message_descriptor, number):
602    """Gets the extension of the specified message with the specified number.
603
604    Extensions have to be registered to this pool by calling :func:`Add` or
605    :func:`AddExtensionDescriptor`.
606
607    Args:
608      message_descriptor (Descriptor): descriptor of the extended message.
609      number (int): Number of the extension field.
610
611    Returns:
612      FieldDescriptor: The descriptor for the extension.
613
614    Raises:
615      KeyError: when no extension with the given number is known for the
616        specified message.
617    """
618    try:
619      return self._extensions_by_number[message_descriptor][number]
620    except KeyError:
621      self._TryLoadExtensionFromDB(message_descriptor, number)
622      return self._extensions_by_number[message_descriptor][number]
623
624  def FindAllExtensions(self, message_descriptor):
625    """Gets all the known extensions of a given message.
626
627    Extensions have to be registered to this pool by build related
628    :func:`Add` or :func:`AddExtensionDescriptor`.
629
630    Args:
631      message_descriptor (Descriptor): Descriptor of the extended message.
632
633    Returns:
634      list[FieldDescriptor]: Field descriptors describing the extensions.
635    """
636    # Fallback to descriptor db if FindAllExtensionNumbers is provided.
637    if self._descriptor_db and hasattr(
638        self._descriptor_db, 'FindAllExtensionNumbers'):
639      full_name = message_descriptor.full_name
640      all_numbers = self._descriptor_db.FindAllExtensionNumbers(full_name)
641      for number in all_numbers:
642        if number in self._extensions_by_number[message_descriptor]:
643          continue
644        self._TryLoadExtensionFromDB(message_descriptor, number)
645
646    return list(self._extensions_by_number[message_descriptor].values())
647
648  def _TryLoadExtensionFromDB(self, message_descriptor, number):
649    """Try to Load extensions from descriptor db.
650
651    Args:
652      message_descriptor: descriptor of the extended message.
653      number: the extension number that needs to be loaded.
654    """
655    if not self._descriptor_db:
656      return
657    # Only supported when FindFileContainingExtension is provided.
658    if not hasattr(
659        self._descriptor_db, 'FindFileContainingExtension'):
660      return
661
662    full_name = message_descriptor.full_name
663    file_proto = self._descriptor_db.FindFileContainingExtension(
664        full_name, number)
665
666    if file_proto is None:
667      return
668
669    try:
670      self._ConvertFileProtoToFileDescriptor(file_proto)
671    except:
672      warn_msg = ('Unable to load proto file %s for extension number %d.' %
673                  (file_proto.name, number))
674      warnings.warn(warn_msg, RuntimeWarning)
675
676  def FindServiceByName(self, full_name):
677    """Loads the named service descriptor from the pool.
678
679    Args:
680      full_name (str): The full name of the service descriptor to load.
681
682    Returns:
683      ServiceDescriptor: The service descriptor for the named service.
684
685    Raises:
686      KeyError: if the service cannot be found in the pool.
687    """
688    full_name = _NormalizeFullyQualifiedName(full_name)
689    if full_name not in self._service_descriptors:
690      self._FindFileContainingSymbolInDb(full_name)
691    return self._service_descriptors[full_name]
692
693  def FindMethodByName(self, full_name):
694    """Loads the named service method descriptor from the pool.
695
696    Args:
697      full_name (str): The full name of the method descriptor to load.
698
699    Returns:
700      MethodDescriptor: The method descriptor for the service method.
701
702    Raises:
703      KeyError: if the method cannot be found in the pool.
704    """
705    full_name = _NormalizeFullyQualifiedName(full_name)
706    service_name, _, method_name = full_name.rpartition('.')
707    service_descriptor = self.FindServiceByName(service_name)
708    return service_descriptor.methods_by_name[method_name]
709
710  def _FindFileContainingSymbolInDb(self, symbol):
711    """Finds the file in descriptor DB containing the specified symbol.
712
713    Args:
714      symbol (str): The name of the symbol to search for.
715
716    Returns:
717      FileDescriptor: The file that contains the specified symbol.
718
719    Raises:
720      KeyError: if the file cannot be found in the descriptor database.
721    """
722    try:
723      file_proto = self._internal_db.FindFileContainingSymbol(symbol)
724    except KeyError as error:
725      if self._descriptor_db:
726        file_proto = self._descriptor_db.FindFileContainingSymbol(symbol)
727      else:
728        raise error
729    if not file_proto:
730      raise KeyError('Cannot find a file containing %s' % symbol)
731    return self._ConvertFileProtoToFileDescriptor(file_proto)
732
  def _ConvertFileProtoToFileDescriptor(self, file_proto):
    """Creates a FileDescriptor from a proto or returns a cached copy.

    This method also has the side effect of loading all the symbols found in
    the file into the appropriate dictionaries in the pool. The file's
    extensions are (re)registered with the pool on every call, including
    calls that hit the cache.

    Args:
      file_proto: The proto to convert.

    Returns:
      A FileDescriptor matching the passed in proto.

    Raises:
      KeyError: if a dependency cannot be found (raised by FindFileByName).
      AssertionError: if one of the file's extensions reuses a field number
        already taken for the same extended message (raised by
        _AddExtensionDescriptor).
    """
    # Build only the first time this file name is seen; afterwards the
    # cached descriptor is reused.
    if file_proto.name not in self._file_descriptors:
      # NOTE(review): _GetDeps is defined outside this excerpt; it presumably
      # yields the (transitive) dependency FileDescriptors that feed `scope`
      # below -- confirm.
      built_deps = list(self._GetDeps(file_proto.dependency))
      # FindFileByName may itself build each direct dependency on demand.
      direct_deps = [self.FindFileByName(n) for n in file_proto.dependency]
      # public_dependency holds indexes into the dependency list.
      public_deps = [direct_deps[i] for i in file_proto.public_dependency]

      file_descriptor = descriptor.FileDescriptor(
          pool=self,
          name=file_proto.name,
          package=file_proto.package,
          syntax=file_proto.syntax,
          options=_OptionsOrNone(file_proto),
          serialized_pb=file_proto.SerializeToString(),
          dependencies=direct_deps,
          public_dependencies=public_deps,
          # pylint: disable=protected-access
          create_key=descriptor._internal_create_key)
      scope = {}

      # This loop extracts all the message and enum types from all the
      # dependencies of the file_proto. This is necessary to create the
      # scope of available message types when defining the passed in
      # file proto.
      for dependency in built_deps:
        scope.update(self._ExtractSymbols(
            dependency.message_types_by_name.values()))
        scope.update((_PrefixWithDot(enum.full_name), enum)
                     for enum in dependency.enum_types_by_name.values())

      # Create this file's top-level messages (the helper also receives
      # `scope`, presumably to record them in it -- confirm).
      for message_type in file_proto.message_type:
        message_desc = self._ConvertMessageDescriptor(
            message_type, file_proto.package, file_descriptor, scope,
            file_proto.syntax)
        file_descriptor.message_types_by_name[message_desc.name] = (
            message_desc)

      for enum_type in file_proto.enum_type:
        file_descriptor.enum_types_by_name[enum_type.name] = (
            self._ConvertEnumDescriptor(enum_type, file_proto.package,
                                        file_descriptor, None, scope, True))

      # File-scope extensions: resolve the extended message from `scope`,
      # then set the extension's own field type.
      for index, extension_proto in enumerate(file_proto.extension):
        extension_desc = self._MakeFieldDescriptor(
            extension_proto, file_proto.package, index, file_descriptor,
            is_extension=True)
        extension_desc.containing_type = self._GetTypeFromScope(
            file_descriptor.package, extension_proto.extendee, scope)
        self._SetFieldType(extension_proto, extension_desc,
                           file_descriptor.package, scope)
        file_descriptor.extensions_by_name[extension_desc.name] = (
            extension_desc)
        self._file_desc_by_toplevel_extension[extension_desc.full_name] = (
            file_descriptor)

      # Second pass, run only after every message, enum and extension of
      # this file has been created: set the field types of all messages.
      for desc_proto in file_proto.message_type:
        self._SetAllFieldTypes(file_proto.package, desc_proto, scope)

      if file_proto.package:
        desc_proto_prefix = _PrefixWithDot(file_proto.package)
      else:
        desc_proto_prefix = ''

      # Overwrite the message entries stored above with the descriptors
      # resolved from `scope` under the package prefix.
      for desc_proto in file_proto.message_type:
        desc = self._GetTypeFromScope(
            desc_proto_prefix, desc_proto.name, scope)
        file_descriptor.message_types_by_name[desc_proto.name] = desc

      for index, service_proto in enumerate(file_proto.service):
        file_descriptor.services_by_name[service_proto.name] = (
            self._MakeServiceDescriptor(service_proto, index, scope,
                                        file_proto.package, file_descriptor))

      self._file_descriptors[file_proto.name] = file_descriptor

    # Add extensions to the pool. This runs even for a cached file;
    # _AddExtensionDescriptor accepts re-adding the same descriptor object.
    file_desc = self._file_descriptors[file_proto.name]
    for extension in file_desc.extensions_by_name.values():
      self._AddExtensionDescriptor(extension)
    for message_type in file_desc.message_types_by_name.values():
      for extension in message_type.extensions:
        self._AddExtensionDescriptor(extension)

    return file_desc
827
828  def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None,
829                                scope=None, syntax=None):
830    """Adds the proto to the pool in the specified package.
831
832    Args:
833      desc_proto: The descriptor_pb2.DescriptorProto protobuf message.
834      package: The package the proto should be located in.
835      file_desc: The file containing this message.
836      scope: Dict mapping short and full symbols to message and enum types.
837      syntax: string indicating syntax of the file ("proto2" or "proto3")
838
839    Returns:
840      The added descriptor.
841    """
842
843    if package:
844      desc_name = '.'.join((package, desc_proto.name))
845    else:
846      desc_name = desc_proto.name
847
848    if file_desc is None:
849      file_name = None
850    else:
851      file_name = file_desc.name
852
853    if scope is None:
854      scope = {}
855
856    nested = [
857        self._ConvertMessageDescriptor(
858            nested, desc_name, file_desc, scope, syntax)
859        for nested in desc_proto.nested_type]
860    enums = [
861        self._ConvertEnumDescriptor(enum, desc_name, file_desc, None,
862                                    scope, False)
863        for enum in desc_proto.enum_type]
864    fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc)
865              for index, field in enumerate(desc_proto.field)]
866    extensions = [
867        self._MakeFieldDescriptor(extension, desc_name, index, file_desc,
868                                  is_extension=True)
869        for index, extension in enumerate(desc_proto.extension)]
870    oneofs = [
871        # pylint: disable=g-complex-comprehension
872        descriptor.OneofDescriptor(
873            desc.name,
874            '.'.join((desc_name, desc.name)),
875            index,
876            None,
877            [],
878            _OptionsOrNone(desc),
879            # pylint: disable=protected-access
880            create_key=descriptor._internal_create_key)
881        for index, desc in enumerate(desc_proto.oneof_decl)
882    ]
883    extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range]
884    if extension_ranges:
885      is_extendable = True
886    else:
887      is_extendable = False
888    desc = descriptor.Descriptor(
889        name=desc_proto.name,
890        full_name=desc_name,
891        filename=file_name,
892        containing_type=None,
893        fields=fields,
894        oneofs=oneofs,
895        nested_types=nested,
896        enum_types=enums,
897        extensions=extensions,
898        options=_OptionsOrNone(desc_proto),
899        is_extendable=is_extendable,
900        extension_ranges=extension_ranges,
901        file=file_desc,
902        serialized_start=None,
903        serialized_end=None,
904        syntax=syntax,
905        # pylint: disable=protected-access
906        create_key=descriptor._internal_create_key)
907    for nested in desc.nested_types:
908      nested.containing_type = desc
909    for enum in desc.enum_types:
910      enum.containing_type = desc
911    for field_index, field_desc in enumerate(desc_proto.field):
912      if field_desc.HasField('oneof_index'):
913        oneof_index = field_desc.oneof_index
914        oneofs[oneof_index].fields.append(fields[field_index])
915        fields[field_index].containing_oneof = oneofs[oneof_index]
916
917    scope[_PrefixWithDot(desc_name)] = desc
918    self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
919    self._descriptors[desc_name] = desc
920    return desc
921
922  def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None,
923                             containing_type=None, scope=None, top_level=False):
924    """Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf.
925
926    Args:
927      enum_proto: The descriptor_pb2.EnumDescriptorProto protobuf message.
928      package: Optional package name for the new message EnumDescriptor.
929      file_desc: The file containing the enum descriptor.
930      containing_type: The type containing this enum.
931      scope: Scope containing available types.
932      top_level: If True, the enum is a top level symbol. If False, the enum
933          is defined inside a message.
934
935    Returns:
936      The added descriptor
937    """
938
939    if package:
940      enum_name = '.'.join((package, enum_proto.name))
941    else:
942      enum_name = enum_proto.name
943
944    if file_desc is None:
945      file_name = None
946    else:
947      file_name = file_desc.name
948
949    values = [self._MakeEnumValueDescriptor(value, index)
950              for index, value in enumerate(enum_proto.value)]
951    desc = descriptor.EnumDescriptor(name=enum_proto.name,
952                                     full_name=enum_name,
953                                     filename=file_name,
954                                     file=file_desc,
955                                     values=values,
956                                     containing_type=containing_type,
957                                     options=_OptionsOrNone(enum_proto),
958                                     # pylint: disable=protected-access
959                                     create_key=descriptor._internal_create_key)
960    scope['.%s' % enum_name] = desc
961    self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
962    self._enum_descriptors[enum_name] = desc
963
964    # Add top level enum values.
965    if top_level:
966      for value in values:
967        full_name = _NormalizeFullyQualifiedName(
968            '.'.join((package, value.name)))
969        self._CheckConflictRegister(value, full_name, file_name)
970        self._top_enum_values[full_name] = value
971
972    return desc
973
974  def _MakeFieldDescriptor(self, field_proto, message_name, index,
975                           file_desc, is_extension=False):
976    """Creates a field descriptor from a FieldDescriptorProto.
977
978    For message and enum type fields, this method will do a look up
979    in the pool for the appropriate descriptor for that type. If it
980    is unavailable, it will fall back to the _source function to
981    create it. If this type is still unavailable, construction will
982    fail.
983
984    Args:
985      field_proto: The proto describing the field.
986      message_name: The name of the containing message.
987      index: Index of the field
988      file_desc: The file containing the field descriptor.
989      is_extension: Indication that this field is for an extension.
990
991    Returns:
992      An initialized FieldDescriptor object
993    """
994
995    if message_name:
996      full_name = '.'.join((message_name, field_proto.name))
997    else:
998      full_name = field_proto.name
999
1000    if field_proto.json_name:
1001      json_name = field_proto.json_name
1002    else:
1003      json_name = None
1004
1005    return descriptor.FieldDescriptor(
1006        name=field_proto.name,
1007        full_name=full_name,
1008        index=index,
1009        number=field_proto.number,
1010        type=field_proto.type,
1011        cpp_type=None,
1012        message_type=None,
1013        enum_type=None,
1014        containing_type=None,
1015        label=field_proto.label,
1016        has_default_value=False,
1017        default_value=None,
1018        is_extension=is_extension,
1019        extension_scope=None,
1020        options=_OptionsOrNone(field_proto),
1021        json_name=json_name,
1022        file=file_desc,
1023        # pylint: disable=protected-access
1024        create_key=descriptor._internal_create_key)
1025
1026  def _SetAllFieldTypes(self, package, desc_proto, scope):
1027    """Sets all the descriptor's fields's types.
1028
1029    This method also sets the containing types on any extensions.
1030
1031    Args:
1032      package: The current package of desc_proto.
1033      desc_proto: The message descriptor to update.
1034      scope: Enclosing scope of available types.
1035    """
1036
1037    package = _PrefixWithDot(package)
1038
1039    main_desc = self._GetTypeFromScope(package, desc_proto.name, scope)
1040
1041    if package == '.':
1042      nested_package = _PrefixWithDot(desc_proto.name)
1043    else:
1044      nested_package = '.'.join([package, desc_proto.name])
1045
1046    for field_proto, field_desc in zip(desc_proto.field, main_desc.fields):
1047      self._SetFieldType(field_proto, field_desc, nested_package, scope)
1048
1049    for extension_proto, extension_desc in (
1050        zip(desc_proto.extension, main_desc.extensions)):
1051      extension_desc.containing_type = self._GetTypeFromScope(
1052          nested_package, extension_proto.extendee, scope)
1053      self._SetFieldType(extension_proto, extension_desc, nested_package, scope)
1054
1055    for nested_type in desc_proto.nested_type:
1056      self._SetAllFieldTypes(nested_package, nested_type, scope)
1057
1058  def _SetFieldType(self, field_proto, field_desc, package, scope):
1059    """Sets the field's type, cpp_type, message_type and enum_type.
1060
1061    Args:
1062      field_proto: Data about the field in proto format.
1063      field_desc: The descriptor to modify.
1064      package: The package the field's container is in.
1065      scope: Enclosing scope of available types.
1066    """
1067    if field_proto.type_name:
1068      desc = self._GetTypeFromScope(package, field_proto.type_name, scope)
1069    else:
1070      desc = None
1071
1072    if not field_proto.HasField('type'):
1073      if isinstance(desc, descriptor.Descriptor):
1074        field_proto.type = descriptor.FieldDescriptor.TYPE_MESSAGE
1075      else:
1076        field_proto.type = descriptor.FieldDescriptor.TYPE_ENUM
1077
1078    field_desc.cpp_type = descriptor.FieldDescriptor.ProtoTypeToCppProtoType(
1079        field_proto.type)
1080
1081    if (field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE
1082        or field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP):
1083      field_desc.message_type = desc
1084
1085    if field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
1086      field_desc.enum_type = desc
1087
1088    if field_proto.label == descriptor.FieldDescriptor.LABEL_REPEATED:
1089      field_desc.has_default_value = False
1090      field_desc.default_value = []
1091    elif field_proto.HasField('default_value'):
1092      field_desc.has_default_value = True
1093      if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or
1094          field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT):
1095        field_desc.default_value = float(field_proto.default_value)
1096      elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING:
1097        field_desc.default_value = field_proto.default_value
1098      elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL:
1099        field_desc.default_value = field_proto.default_value.lower() == 'true'
1100      elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
1101        field_desc.default_value = field_desc.enum_type.values_by_name[
1102            field_proto.default_value].number
1103      elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
1104        field_desc.default_value = text_encoding.CUnescape(
1105            field_proto.default_value)
1106      elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
1107        field_desc.default_value = None
1108      else:
1109        # All other types are of the "int" type.
1110        field_desc.default_value = int(field_proto.default_value)
1111    else:
1112      field_desc.has_default_value = False
1113      if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or
1114          field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT):
1115        field_desc.default_value = 0.0
1116      elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING:
1117        field_desc.default_value = u''
1118      elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL:
1119        field_desc.default_value = False
1120      elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
1121        field_desc.default_value = field_desc.enum_type.values[0].number
1122      elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
1123        field_desc.default_value = b''
1124      elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
1125        field_desc.default_value = None
1126      elif field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP:
1127        field_desc.default_value = None
1128      else:
1129        # All other types are of the "int" type.
1130        field_desc.default_value = 0
1131
1132    field_desc.type = field_proto.type
1133
1134  def _MakeEnumValueDescriptor(self, value_proto, index):
1135    """Creates a enum value descriptor object from a enum value proto.
1136
1137    Args:
1138      value_proto: The proto describing the enum value.
1139      index: The index of the enum value.
1140
1141    Returns:
1142      An initialized EnumValueDescriptor object.
1143    """
1144
1145    return descriptor.EnumValueDescriptor(
1146        name=value_proto.name,
1147        index=index,
1148        number=value_proto.number,
1149        options=_OptionsOrNone(value_proto),
1150        type=None,
1151        # pylint: disable=protected-access
1152        create_key=descriptor._internal_create_key)
1153
1154  def _MakeServiceDescriptor(self, service_proto, service_index, scope,
1155                             package, file_desc):
1156    """Make a protobuf ServiceDescriptor given a ServiceDescriptorProto.
1157
1158    Args:
1159      service_proto: The descriptor_pb2.ServiceDescriptorProto protobuf message.
1160      service_index: The index of the service in the File.
1161      scope: Dict mapping short and full symbols to message and enum types.
1162      package: Optional package name for the new message EnumDescriptor.
1163      file_desc: The file containing the service descriptor.
1164
1165    Returns:
1166      The added descriptor.
1167    """
1168
1169    if package:
1170      service_name = '.'.join((package, service_proto.name))
1171    else:
1172      service_name = service_proto.name
1173
1174    methods = [self._MakeMethodDescriptor(method_proto, service_name, package,
1175                                          scope, index)
1176               for index, method_proto in enumerate(service_proto.method)]
1177    desc = descriptor.ServiceDescriptor(
1178        name=service_proto.name,
1179        full_name=service_name,
1180        index=service_index,
1181        methods=methods,
1182        options=_OptionsOrNone(service_proto),
1183        file=file_desc,
1184        # pylint: disable=protected-access
1185        create_key=descriptor._internal_create_key)
1186    self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
1187    self._service_descriptors[service_name] = desc
1188    return desc
1189
1190  def _MakeMethodDescriptor(self, method_proto, service_name, package, scope,
1191                            index):
1192    """Creates a method descriptor from a MethodDescriptorProto.
1193
1194    Args:
1195      method_proto: The proto describing the method.
1196      service_name: The name of the containing service.
1197      package: Optional package name to look up for types.
1198      scope: Scope containing available types.
1199      index: Index of the method in the service.
1200
1201    Returns:
1202      An initialized MethodDescriptor object.
1203    """
1204    full_name = '.'.join((service_name, method_proto.name))
1205    input_type = self._GetTypeFromScope(
1206        package, method_proto.input_type, scope)
1207    output_type = self._GetTypeFromScope(
1208        package, method_proto.output_type, scope)
1209    return descriptor.MethodDescriptor(
1210        name=method_proto.name,
1211        full_name=full_name,
1212        index=index,
1213        containing_service=None,
1214        input_type=input_type,
1215        output_type=output_type,
1216        client_streaming=method_proto.client_streaming,
1217        server_streaming=method_proto.server_streaming,
1218        options=_OptionsOrNone(method_proto),
1219        # pylint: disable=protected-access
1220        create_key=descriptor._internal_create_key)
1221
1222  def _ExtractSymbols(self, descriptors):
1223    """Pulls out all the symbols from descriptor protos.
1224
1225    Args:
1226      descriptors: The messages to extract descriptors from.
1227    Yields:
1228      A two element tuple of the type name and descriptor object.
1229    """
1230
1231    for desc in descriptors:
1232      yield (_PrefixWithDot(desc.full_name), desc)
1233      for symbol in self._ExtractSymbols(desc.nested_types):
1234        yield symbol
1235      for enum in desc.enum_types:
1236        yield (_PrefixWithDot(enum.full_name), enum)
1237
1238  def _GetDeps(self, dependencies, visited=None):
1239    """Recursively finds dependencies for file protos.
1240
1241    Args:
1242      dependencies: The names of the files being depended on.
1243      visited: The names of files already found.
1244
1245    Yields:
1246      Each direct and indirect dependency.
1247    """
1248
1249    visited = visited or set()
1250    for dependency in dependencies:
1251      if dependency not in visited:
1252        visited.add(dependency)
1253        dep_desc = self.FindFileByName(dependency)
1254        yield dep_desc
1255        public_files = [d.name for d in dep_desc.public_dependencies]
1256        yield from self._GetDeps(public_files, visited)
1257
1258  def _GetTypeFromScope(self, package, type_name, scope):
1259    """Finds a given type name in the current scope.
1260
1261    Args:
1262      package: The package the proto should be located in.
1263      type_name: The name of the type to be found in the scope.
1264      scope: Dict mapping short and full symbols to message and enum types.
1265
1266    Returns:
1267      The descriptor for the requested type.
1268    """
1269    if type_name not in scope:
1270      components = _PrefixWithDot(package).split('.')
1271      while components:
1272        possible_match = '.'.join(components + [type_name])
1273        if possible_match in scope:
1274          type_name = possible_match
1275          break
1276        else:
1277          components.pop(-1)
1278    return scope[type_name]
1279
1280
1281def _PrefixWithDot(name):
1282  return name if name.startswith('.') else '.%s' % name
1283
1284
# TODO(amauryfa): This pool could be constructed from Python code, when we
# support a flag like 'use_cpp_generated_pool=True'.
# When C descriptors are in use (_USE_C_DESCRIPTORS), reuse the extension
# module's default pool; otherwise build a pure-Python DescriptorPool.
_DEFAULT = (
    descriptor._message.default_pool  # pylint: disable=protected-access
    if _USE_C_DESCRIPTORS
    else DescriptorPool())
1292
1293
def Default():
  """Returns this module's default descriptor pool (the ``_DEFAULT`` object)."""
  default_pool = _DEFAULT
  return default_pool
1296