# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""Provides DescriptorPool to use as a container for proto2 descriptors.

The DescriptorPool is used in conjunction with a DescriptorDatabase to maintain
a collection of protocol buffer descriptors for use when dynamically creating
message types at runtime.

For most applications protocol buffers should be used via modules generated by
the protocol buffer compiler tool. This should only be used when the type of
protocol buffers used in an application or library cannot be predetermined.

Below is a straightforward example on how to use this class::

  pool = DescriptorPool()
  file_descriptor_protos = [ ... ]
  for file_descriptor_proto in file_descriptor_protos:
    pool.Add(file_descriptor_proto)
  my_message_descriptor = pool.FindMessageTypeByName('some.package.MessageType')

The message descriptor can be used in conjunction with the message_factory
module in order to create a protocol buffer class that can be encoded and
decoded.

If you want to get a Python class for the specified proto, use the
helper functions inside google.protobuf.message_factory
directly instead of this class.
"""

__author__ = '[email protected] (Matt Toia)'

import collections
import functools
import warnings

from google.protobuf import descriptor
from google.protobuf import descriptor_database
from google.protobuf import text_encoding


_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS  # pylint: disable=protected-access


def _Deprecated(func):
  """Decorator that marks a pool method as deprecated.

  The returned wrapper issues a DeprecationWarning on every call and then
  delegates to `func` with the same arguments, returning its result.

  Args:
    func: The callable to wrap.

  Returns:
    A wrapper carrying func's name, docstring and attributes (via
    functools.wraps).
  """

  @functools.wraps(func)
  def NewFunc(*args, **kwargs):
    # stacklevel=2 attributes the warning to the code that called the
    # deprecated method rather than to this wrapper, so the user sees (and
    # warning filters match) the offending call site.
    warnings.warn(
        'Call to deprecated function %s(). Note: Do add unlinked descriptors '
        'to descriptor_pool is wrong. Use Add() or AddSerializedFile() '
        'instead.' % func.__name__,
        category=DeprecationWarning,
        stacklevel=2)
    return func(*args, **kwargs)

  return NewFunc


def _NormalizeFullyQualifiedName(name):
  """Remove leading period from fully-qualified type name.

  Due to b/13860351 in descriptor_database.py, types in the root namespace are
  generated with a leading period. This function removes that prefix.

  Args:
    name (str): The fully-qualified symbol name.

  Returns:
    str: The normalized fully-qualified symbol name.
  """
  # Note: lstrip removes every leading '.', not just the first one.
  return name.lstrip('.')


def _OptionsOrNone(descriptor_proto):
  """Returns the value of the field `options`, or None if it is not set."""
  if descriptor_proto.HasField('options'):
    return descriptor_proto.options
  else:
    return None


def _IsMessageSetExtension(field):
  """Returns True if `field` is a MessageSet-style extension.

  That is: an extension with LABEL_OPTIONAL and TYPE_MESSAGE whose extended
  (containing) type has options with `message_set_wire_format` set.
  """
  return (field.is_extension and
          field.containing_type.has_options and
          field.containing_type.GetOptions().message_set_wire_format and
          field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
          field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL)


class DescriptorPool(object):
  """A collection of protobufs dynamically constructed by descriptor protos."""

  if _USE_C_DESCRIPTORS:

    def __new__(cls, descriptor_db=None):
      # With the C implementation enabled, construction returns the C
      # extension's pool object instead of an instance of this class.
      # pylint: disable=protected-access
      return descriptor._message.DescriptorPool(descriptor_db)

  def __init__(self, descriptor_db=None):
    """Initializes a pool of protocol buffer descriptors.

    The descriptor_db argument to the constructor is provided to allow
    specialized file descriptor proto lookup code to be triggered on demand. An
    example would be an implementation which will read and compile a file
    specified in a call to FindFileByName() and not require the call to Add()
    at all. Results from this database will be cached internally here as well.

    Args:
      descriptor_db: A secondary source of file descriptors.
    """

    self._internal_db = descriptor_database.DescriptorDatabase()
    self._descriptor_db = descriptor_db
    self._descriptors = {}
    self._enum_descriptors = {}
    self._service_descriptors = {}
    self._file_descriptors = {}
    self._toplevel_extensions = {}
    # TODO(jieluo): Remove _file_desc_by_toplevel_extension after
    # maybe year 2020 for compatibility issue (with 3.4.1 only).
    self._file_desc_by_toplevel_extension = {}
    self._top_enum_values = {}
    # We store extensions in two two-level mappings. In both, the first key is
    # the descriptor of the message being extended; the second key is the
    # extension full name (_extensions_by_name) or its tag number
    # (_extensions_by_number).
    self._extensions_by_name = collections.defaultdict(dict)
    self._extensions_by_number = collections.defaultdict(dict)

  def _CheckConflictRegister(self, desc, desc_name, file_name):
    """Check if the descriptor name conflicts with another of the same name.

    Re-registering the same name for the same kind of descriptor from the same
    file is allowed; anything else raises.

    Args:
      desc: Descriptor of a message, enum, service, extension or enum value.
      desc_name (str): the full name of desc.
      file_name (str): The file name of descriptor.

    Raises:
      TypeError: if desc_name is already registered for a different kind of
        descriptor, or for a descriptor defined in a different file.
    """
    for register, descriptor_type in [
        (self._descriptors, descriptor.Descriptor),
        (self._enum_descriptors, descriptor.EnumDescriptor),
        (self._service_descriptors, descriptor.ServiceDescriptor),
        (self._toplevel_extensions, descriptor.FieldDescriptor),
        (self._top_enum_values, descriptor.EnumValueDescriptor)]:
      if desc_name in register:
        old_desc = register[desc_name]
        # Enum values reach their file through their enum type.
        if isinstance(old_desc, descriptor.EnumValueDescriptor):
          old_file = old_desc.type.file.name
        else:
          old_file = old_desc.file.name

        if not isinstance(desc, descriptor_type) or (
            old_file != file_name):
          error_msg = ('Conflict register for file "' + file_name +
                       '": ' + desc_name +
                       ' is already defined in file "' +
                       old_file + '". Please fix the conflict by adding '
                       'package name on the proto file, or use different '
                       'name for the duplication.')
          if isinstance(desc, descriptor.EnumValueDescriptor):
            error_msg += ('\nNote: enum values appear as '
                          'siblings of the enum type instead of '
                          'children of it.')

          raise TypeError(error_msg)

  def Add(self, file_desc_proto):
    """Adds the FileDescriptorProto and its types to this pool.

    The proto is stored in the pool's internal database; descriptors are built
    from it when they are first looked up through one of the Find* methods.

    Args:
      file_desc_proto (FileDescriptorProto): The file descriptor to add.
    """

    self._internal_db.Add(file_desc_proto)

  def AddSerializedFile(self, serialized_file_desc_proto):
    """Adds the FileDescriptorProto and its types to this pool.

    Unlike Add(), the descriptors are built immediately.

    Args:
      serialized_file_desc_proto (bytes): A bytes string, serialization of the
        :class:`FileDescriptorProto` to add.

    Returns:
      FileDescriptor: Descriptor for the added file.
    """

    # Imported lazily (presumably to avoid an import cycle with the generated
    # descriptor_pb2 module -- TODO confirm).
    # pylint: disable=g-import-not-at-top
    from google.protobuf import descriptor_pb2
    file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString(
        serialized_file_desc_proto)
    file_desc = self._ConvertFileProtoToFileDescriptor(file_desc_proto)
    # Keep the caller's exact bytes rather than a re-serialization.
    file_desc.serialized_pb = serialized_file_desc_proto
    return file_desc

  # Adding a Descriptor to the pool directly is deprecated. Please use Add()
  # or AddSerializedFile() to add a FileDescriptorProto instead.
  @_Deprecated
  def AddDescriptor(self, desc):
    self._AddDescriptor(desc)

  # Never call this method. It is for internal usage only.
  def _AddDescriptor(self, desc):
    """Adds a Descriptor to the pool, non-recursively.

    If the Descriptor contains nested messages or enums, the caller must
    explicitly register them. This method also registers the FileDescriptor
    associated with the message.

    Args:
      desc: A Descriptor.

    Raises:
      TypeError: if desc is not a descriptor.Descriptor, or its name conflicts
        with an already registered symbol.
    """
    if not isinstance(desc, descriptor.Descriptor):
      raise TypeError('Expected instance of descriptor.Descriptor.')

    self._CheckConflictRegister(desc, desc.full_name, desc.file.name)

    self._descriptors[desc.full_name] = desc
    self._AddFileDescriptor(desc.file)

  # Adding an EnumDescriptor to the pool directly is deprecated. Please use
  # Add() or AddSerializedFile() to add a FileDescriptorProto instead.
  @_Deprecated
  def AddEnumDescriptor(self, enum_desc):
    self._AddEnumDescriptor(enum_desc)

  # Never call this method.
It is for internal usage only. 255 def _AddEnumDescriptor(self, enum_desc): 256 """Adds an EnumDescriptor to the pool. 257 258 This method also registers the FileDescriptor associated with the enum. 259 260 Args: 261 enum_desc: An EnumDescriptor. 262 """ 263 264 if not isinstance(enum_desc, descriptor.EnumDescriptor): 265 raise TypeError('Expected instance of descriptor.EnumDescriptor.') 266 267 file_name = enum_desc.file.name 268 self._CheckConflictRegister(enum_desc, enum_desc.full_name, file_name) 269 self._enum_descriptors[enum_desc.full_name] = enum_desc 270 271 # Top enum values need to be indexed. 272 # Count the number of dots to see whether the enum is toplevel or nested 273 # in a message. We cannot use enum_desc.containing_type at this stage. 274 if enum_desc.file.package: 275 top_level = (enum_desc.full_name.count('.') 276 - enum_desc.file.package.count('.') == 1) 277 else: 278 top_level = enum_desc.full_name.count('.') == 0 279 if top_level: 280 file_name = enum_desc.file.name 281 package = enum_desc.file.package 282 for enum_value in enum_desc.values: 283 full_name = _NormalizeFullyQualifiedName( 284 '.'.join((package, enum_value.name))) 285 self._CheckConflictRegister(enum_value, full_name, file_name) 286 self._top_enum_values[full_name] = enum_value 287 self._AddFileDescriptor(enum_desc.file) 288 289 # Add ServiceDescriptor to descriptor pool is dreprecated. Please use Add() 290 # or AddSerializedFile() to add a FileDescriptorProto instead. 291 @_Deprecated 292 def AddServiceDescriptor(self, service_desc): 293 self._AddServiceDescriptor(service_desc) 294 295 # Never call this method. It is for internal usage only. 296 def _AddServiceDescriptor(self, service_desc): 297 """Adds a ServiceDescriptor to the pool. 298 299 Args: 300 service_desc: A ServiceDescriptor. 
301 """ 302 303 if not isinstance(service_desc, descriptor.ServiceDescriptor): 304 raise TypeError('Expected instance of descriptor.ServiceDescriptor.') 305 306 self._CheckConflictRegister(service_desc, service_desc.full_name, 307 service_desc.file.name) 308 self._service_descriptors[service_desc.full_name] = service_desc 309 310 # Add ExtensionDescriptor to descriptor pool is dreprecated. Please use Add() 311 # or AddSerializedFile() to add a FileDescriptorProto instead. 312 @_Deprecated 313 def AddExtensionDescriptor(self, extension): 314 self._AddExtensionDescriptor(extension) 315 316 # Never call this method. It is for internal usage only. 317 def _AddExtensionDescriptor(self, extension): 318 """Adds a FieldDescriptor describing an extension to the pool. 319 320 Args: 321 extension: A FieldDescriptor. 322 323 Raises: 324 AssertionError: when another extension with the same number extends the 325 same message. 326 TypeError: when the specified extension is not a 327 descriptor.FieldDescriptor. 328 """ 329 if not (isinstance(extension, descriptor.FieldDescriptor) and 330 extension.is_extension): 331 raise TypeError('Expected an extension descriptor.') 332 333 if extension.extension_scope is None: 334 self._toplevel_extensions[extension.full_name] = extension 335 336 try: 337 existing_desc = self._extensions_by_number[ 338 extension.containing_type][extension.number] 339 except KeyError: 340 pass 341 else: 342 if extension is not existing_desc: 343 raise AssertionError( 344 'Extensions "%s" and "%s" both try to extend message type "%s" ' 345 'with field number %d.' % 346 (extension.full_name, existing_desc.full_name, 347 extension.containing_type.full_name, extension.number)) 348 349 self._extensions_by_number[extension.containing_type][ 350 extension.number] = extension 351 self._extensions_by_name[extension.containing_type][ 352 extension.full_name] = extension 353 354 # Also register MessageSet extensions with the type name. 
355 if _IsMessageSetExtension(extension): 356 self._extensions_by_name[extension.containing_type][ 357 extension.message_type.full_name] = extension 358 359 @_Deprecated 360 def AddFileDescriptor(self, file_desc): 361 self._InternalAddFileDescriptor(file_desc) 362 363 # Never call this method. It is for internal usage only. 364 def _InternalAddFileDescriptor(self, file_desc): 365 """Adds a FileDescriptor to the pool, non-recursively. 366 367 If the FileDescriptor contains messages or enums, the caller must explicitly 368 register them. 369 370 Args: 371 file_desc: A FileDescriptor. 372 """ 373 374 self._AddFileDescriptor(file_desc) 375 # TODO(jieluo): This is a temporary solution for FieldDescriptor.file. 376 # FieldDescriptor.file is added in code gen. Remove this solution after 377 # maybe 2020 for compatibility reason (with 3.4.1 only). 378 for extension in file_desc.extensions_by_name.values(): 379 self._file_desc_by_toplevel_extension[ 380 extension.full_name] = file_desc 381 382 def _AddFileDescriptor(self, file_desc): 383 """Adds a FileDescriptor to the pool, non-recursively. 384 385 If the FileDescriptor contains messages or enums, the caller must explicitly 386 register them. 387 388 Args: 389 file_desc: A FileDescriptor. 390 """ 391 392 if not isinstance(file_desc, descriptor.FileDescriptor): 393 raise TypeError('Expected instance of descriptor.FileDescriptor.') 394 self._file_descriptors[file_desc.name] = file_desc 395 396 def FindFileByName(self, file_name): 397 """Gets a FileDescriptor by file name. 398 399 Args: 400 file_name (str): The path to the file to get a descriptor for. 401 402 Returns: 403 FileDescriptor: The descriptor for the named file. 404 405 Raises: 406 KeyError: if the file cannot be found in the pool. 
407 """ 408 409 try: 410 return self._file_descriptors[file_name] 411 except KeyError: 412 pass 413 414 try: 415 file_proto = self._internal_db.FindFileByName(file_name) 416 except KeyError as error: 417 if self._descriptor_db: 418 file_proto = self._descriptor_db.FindFileByName(file_name) 419 else: 420 raise error 421 if not file_proto: 422 raise KeyError('Cannot find a file named %s' % file_name) 423 return self._ConvertFileProtoToFileDescriptor(file_proto) 424 425 def FindFileContainingSymbol(self, symbol): 426 """Gets the FileDescriptor for the file containing the specified symbol. 427 428 Args: 429 symbol (str): The name of the symbol to search for. 430 431 Returns: 432 FileDescriptor: Descriptor for the file that contains the specified 433 symbol. 434 435 Raises: 436 KeyError: if the file cannot be found in the pool. 437 """ 438 439 symbol = _NormalizeFullyQualifiedName(symbol) 440 try: 441 return self._InternalFindFileContainingSymbol(symbol) 442 except KeyError: 443 pass 444 445 try: 446 # Try fallback database. Build and find again if possible. 447 self._FindFileContainingSymbolInDb(symbol) 448 return self._InternalFindFileContainingSymbol(symbol) 449 except KeyError: 450 raise KeyError('Cannot find a file containing %s' % symbol) 451 452 def _InternalFindFileContainingSymbol(self, symbol): 453 """Gets the already built FileDescriptor containing the specified symbol. 454 455 Args: 456 symbol (str): The name of the symbol to search for. 457 458 Returns: 459 FileDescriptor: Descriptor for the file that contains the specified 460 symbol. 461 462 Raises: 463 KeyError: if the file cannot be found in the pool. 
464 """ 465 try: 466 return self._descriptors[symbol].file 467 except KeyError: 468 pass 469 470 try: 471 return self._enum_descriptors[symbol].file 472 except KeyError: 473 pass 474 475 try: 476 return self._service_descriptors[symbol].file 477 except KeyError: 478 pass 479 480 try: 481 return self._top_enum_values[symbol].type.file 482 except KeyError: 483 pass 484 485 try: 486 return self._file_desc_by_toplevel_extension[symbol] 487 except KeyError: 488 pass 489 490 # Try fields, enum values and nested extensions inside a message. 491 top_name, _, sub_name = symbol.rpartition('.') 492 try: 493 message = self.FindMessageTypeByName(top_name) 494 assert (sub_name in message.extensions_by_name or 495 sub_name in message.fields_by_name or 496 sub_name in message.enum_values_by_name) 497 return message.file 498 except (KeyError, AssertionError): 499 raise KeyError('Cannot find a file containing %s' % symbol) 500 501 def FindMessageTypeByName(self, full_name): 502 """Loads the named descriptor from the pool. 503 504 Args: 505 full_name (str): The full name of the descriptor to load. 506 507 Returns: 508 Descriptor: The descriptor for the named type. 509 510 Raises: 511 KeyError: if the message cannot be found in the pool. 512 """ 513 514 full_name = _NormalizeFullyQualifiedName(full_name) 515 if full_name not in self._descriptors: 516 self._FindFileContainingSymbolInDb(full_name) 517 return self._descriptors[full_name] 518 519 def FindEnumTypeByName(self, full_name): 520 """Loads the named enum descriptor from the pool. 521 522 Args: 523 full_name (str): The full name of the enum descriptor to load. 524 525 Returns: 526 EnumDescriptor: The enum descriptor for the named type. 527 528 Raises: 529 KeyError: if the enum cannot be found in the pool. 
530 """ 531 532 full_name = _NormalizeFullyQualifiedName(full_name) 533 if full_name not in self._enum_descriptors: 534 self._FindFileContainingSymbolInDb(full_name) 535 return self._enum_descriptors[full_name] 536 537 def FindFieldByName(self, full_name): 538 """Loads the named field descriptor from the pool. 539 540 Args: 541 full_name (str): The full name of the field descriptor to load. 542 543 Returns: 544 FieldDescriptor: The field descriptor for the named field. 545 546 Raises: 547 KeyError: if the field cannot be found in the pool. 548 """ 549 full_name = _NormalizeFullyQualifiedName(full_name) 550 message_name, _, field_name = full_name.rpartition('.') 551 message_descriptor = self.FindMessageTypeByName(message_name) 552 return message_descriptor.fields_by_name[field_name] 553 554 def FindOneofByName(self, full_name): 555 """Loads the named oneof descriptor from the pool. 556 557 Args: 558 full_name (str): The full name of the oneof descriptor to load. 559 560 Returns: 561 OneofDescriptor: The oneof descriptor for the named oneof. 562 563 Raises: 564 KeyError: if the oneof cannot be found in the pool. 565 """ 566 full_name = _NormalizeFullyQualifiedName(full_name) 567 message_name, _, oneof_name = full_name.rpartition('.') 568 message_descriptor = self.FindMessageTypeByName(message_name) 569 return message_descriptor.oneofs_by_name[oneof_name] 570 571 def FindExtensionByName(self, full_name): 572 """Loads the named extension descriptor from the pool. 573 574 Args: 575 full_name (str): The full name of the extension descriptor to load. 576 577 Returns: 578 FieldDescriptor: The field descriptor for the named extension. 579 580 Raises: 581 KeyError: if the extension cannot be found in the pool. 
582 """ 583 full_name = _NormalizeFullyQualifiedName(full_name) 584 try: 585 # The proto compiler does not give any link between the FileDescriptor 586 # and top-level extensions unless the FileDescriptorProto is added to 587 # the DescriptorDatabase, but this can impact memory usage. 588 # So we registered these extensions by name explicitly. 589 return self._toplevel_extensions[full_name] 590 except KeyError: 591 pass 592 message_name, _, extension_name = full_name.rpartition('.') 593 try: 594 # Most extensions are nested inside a message. 595 scope = self.FindMessageTypeByName(message_name) 596 except KeyError: 597 # Some extensions are defined at file scope. 598 scope = self._FindFileContainingSymbolInDb(full_name) 599 return scope.extensions_by_name[extension_name] 600 601 def FindExtensionByNumber(self, message_descriptor, number): 602 """Gets the extension of the specified message with the specified number. 603 604 Extensions have to be registered to this pool by calling :func:`Add` or 605 :func:`AddExtensionDescriptor`. 606 607 Args: 608 message_descriptor (Descriptor): descriptor of the extended message. 609 number (int): Number of the extension field. 610 611 Returns: 612 FieldDescriptor: The descriptor for the extension. 613 614 Raises: 615 KeyError: when no extension with the given number is known for the 616 specified message. 617 """ 618 try: 619 return self._extensions_by_number[message_descriptor][number] 620 except KeyError: 621 self._TryLoadExtensionFromDB(message_descriptor, number) 622 return self._extensions_by_number[message_descriptor][number] 623 624 def FindAllExtensions(self, message_descriptor): 625 """Gets all the known extensions of a given message. 626 627 Extensions have to be registered to this pool by build related 628 :func:`Add` or :func:`AddExtensionDescriptor`. 629 630 Args: 631 message_descriptor (Descriptor): Descriptor of the extended message. 
632 633 Returns: 634 list[FieldDescriptor]: Field descriptors describing the extensions. 635 """ 636 # Fallback to descriptor db if FindAllExtensionNumbers is provided. 637 if self._descriptor_db and hasattr( 638 self._descriptor_db, 'FindAllExtensionNumbers'): 639 full_name = message_descriptor.full_name 640 all_numbers = self._descriptor_db.FindAllExtensionNumbers(full_name) 641 for number in all_numbers: 642 if number in self._extensions_by_number[message_descriptor]: 643 continue 644 self._TryLoadExtensionFromDB(message_descriptor, number) 645 646 return list(self._extensions_by_number[message_descriptor].values()) 647 648 def _TryLoadExtensionFromDB(self, message_descriptor, number): 649 """Try to Load extensions from descriptor db. 650 651 Args: 652 message_descriptor: descriptor of the extended message. 653 number: the extension number that needs to be loaded. 654 """ 655 if not self._descriptor_db: 656 return 657 # Only supported when FindFileContainingExtension is provided. 658 if not hasattr( 659 self._descriptor_db, 'FindFileContainingExtension'): 660 return 661 662 full_name = message_descriptor.full_name 663 file_proto = self._descriptor_db.FindFileContainingExtension( 664 full_name, number) 665 666 if file_proto is None: 667 return 668 669 try: 670 self._ConvertFileProtoToFileDescriptor(file_proto) 671 except: 672 warn_msg = ('Unable to load proto file %s for extension number %d.' % 673 (file_proto.name, number)) 674 warnings.warn(warn_msg, RuntimeWarning) 675 676 def FindServiceByName(self, full_name): 677 """Loads the named service descriptor from the pool. 678 679 Args: 680 full_name (str): The full name of the service descriptor to load. 681 682 Returns: 683 ServiceDescriptor: The service descriptor for the named service. 684 685 Raises: 686 KeyError: if the service cannot be found in the pool. 
687 """ 688 full_name = _NormalizeFullyQualifiedName(full_name) 689 if full_name not in self._service_descriptors: 690 self._FindFileContainingSymbolInDb(full_name) 691 return self._service_descriptors[full_name] 692 693 def FindMethodByName(self, full_name): 694 """Loads the named service method descriptor from the pool. 695 696 Args: 697 full_name (str): The full name of the method descriptor to load. 698 699 Returns: 700 MethodDescriptor: The method descriptor for the service method. 701 702 Raises: 703 KeyError: if the method cannot be found in the pool. 704 """ 705 full_name = _NormalizeFullyQualifiedName(full_name) 706 service_name, _, method_name = full_name.rpartition('.') 707 service_descriptor = self.FindServiceByName(service_name) 708 return service_descriptor.methods_by_name[method_name] 709 710 def _FindFileContainingSymbolInDb(self, symbol): 711 """Finds the file in descriptor DB containing the specified symbol. 712 713 Args: 714 symbol (str): The name of the symbol to search for. 715 716 Returns: 717 FileDescriptor: The file that contains the specified symbol. 718 719 Raises: 720 KeyError: if the file cannot be found in the descriptor database. 721 """ 722 try: 723 file_proto = self._internal_db.FindFileContainingSymbol(symbol) 724 except KeyError as error: 725 if self._descriptor_db: 726 file_proto = self._descriptor_db.FindFileContainingSymbol(symbol) 727 else: 728 raise error 729 if not file_proto: 730 raise KeyError('Cannot find a file containing %s' % symbol) 731 return self._ConvertFileProtoToFileDescriptor(file_proto) 732 733 def _ConvertFileProtoToFileDescriptor(self, file_proto): 734 """Creates a FileDescriptor from a proto or returns a cached copy. 735 736 This method also has the side effect of loading all the symbols found in 737 the file into the appropriate dictionaries in the pool. 738 739 Args: 740 file_proto: The proto to convert. 741 742 Returns: 743 A FileDescriptor matching the passed in proto. 
744 """ 745 if file_proto.name not in self._file_descriptors: 746 built_deps = list(self._GetDeps(file_proto.dependency)) 747 direct_deps = [self.FindFileByName(n) for n in file_proto.dependency] 748 public_deps = [direct_deps[i] for i in file_proto.public_dependency] 749 750 file_descriptor = descriptor.FileDescriptor( 751 pool=self, 752 name=file_proto.name, 753 package=file_proto.package, 754 syntax=file_proto.syntax, 755 options=_OptionsOrNone(file_proto), 756 serialized_pb=file_proto.SerializeToString(), 757 dependencies=direct_deps, 758 public_dependencies=public_deps, 759 # pylint: disable=protected-access 760 create_key=descriptor._internal_create_key) 761 scope = {} 762 763 # This loop extracts all the message and enum types from all the 764 # dependencies of the file_proto. This is necessary to create the 765 # scope of available message types when defining the passed in 766 # file proto. 767 for dependency in built_deps: 768 scope.update(self._ExtractSymbols( 769 dependency.message_types_by_name.values())) 770 scope.update((_PrefixWithDot(enum.full_name), enum) 771 for enum in dependency.enum_types_by_name.values()) 772 773 for message_type in file_proto.message_type: 774 message_desc = self._ConvertMessageDescriptor( 775 message_type, file_proto.package, file_descriptor, scope, 776 file_proto.syntax) 777 file_descriptor.message_types_by_name[message_desc.name] = ( 778 message_desc) 779 780 for enum_type in file_proto.enum_type: 781 file_descriptor.enum_types_by_name[enum_type.name] = ( 782 self._ConvertEnumDescriptor(enum_type, file_proto.package, 783 file_descriptor, None, scope, True)) 784 785 for index, extension_proto in enumerate(file_proto.extension): 786 extension_desc = self._MakeFieldDescriptor( 787 extension_proto, file_proto.package, index, file_descriptor, 788 is_extension=True) 789 extension_desc.containing_type = self._GetTypeFromScope( 790 file_descriptor.package, extension_proto.extendee, scope) 791 self._SetFieldType(extension_proto, 
extension_desc, 792 file_descriptor.package, scope) 793 file_descriptor.extensions_by_name[extension_desc.name] = ( 794 extension_desc) 795 self._file_desc_by_toplevel_extension[extension_desc.full_name] = ( 796 file_descriptor) 797 798 for desc_proto in file_proto.message_type: 799 self._SetAllFieldTypes(file_proto.package, desc_proto, scope) 800 801 if file_proto.package: 802 desc_proto_prefix = _PrefixWithDot(file_proto.package) 803 else: 804 desc_proto_prefix = '' 805 806 for desc_proto in file_proto.message_type: 807 desc = self._GetTypeFromScope( 808 desc_proto_prefix, desc_proto.name, scope) 809 file_descriptor.message_types_by_name[desc_proto.name] = desc 810 811 for index, service_proto in enumerate(file_proto.service): 812 file_descriptor.services_by_name[service_proto.name] = ( 813 self._MakeServiceDescriptor(service_proto, index, scope, 814 file_proto.package, file_descriptor)) 815 816 self._file_descriptors[file_proto.name] = file_descriptor 817 818 # Add extensions to the pool 819 file_desc = self._file_descriptors[file_proto.name] 820 for extension in file_desc.extensions_by_name.values(): 821 self._AddExtensionDescriptor(extension) 822 for message_type in file_desc.message_types_by_name.values(): 823 for extension in message_type.extensions: 824 self._AddExtensionDescriptor(extension) 825 826 return file_desc 827 828 def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None, 829 scope=None, syntax=None): 830 """Adds the proto to the pool in the specified package. 831 832 Args: 833 desc_proto: The descriptor_pb2.DescriptorProto protobuf message. 834 package: The package the proto should be located in. 835 file_desc: The file containing this message. 836 scope: Dict mapping short and full symbols to message and enum types. 837 syntax: string indicating syntax of the file ("proto2" or "proto3") 838 839 Returns: 840 The added descriptor. 
841 """ 842 843 if package: 844 desc_name = '.'.join((package, desc_proto.name)) 845 else: 846 desc_name = desc_proto.name 847 848 if file_desc is None: 849 file_name = None 850 else: 851 file_name = file_desc.name 852 853 if scope is None: 854 scope = {} 855 856 nested = [ 857 self._ConvertMessageDescriptor( 858 nested, desc_name, file_desc, scope, syntax) 859 for nested in desc_proto.nested_type] 860 enums = [ 861 self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, 862 scope, False) 863 for enum in desc_proto.enum_type] 864 fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc) 865 for index, field in enumerate(desc_proto.field)] 866 extensions = [ 867 self._MakeFieldDescriptor(extension, desc_name, index, file_desc, 868 is_extension=True) 869 for index, extension in enumerate(desc_proto.extension)] 870 oneofs = [ 871 # pylint: disable=g-complex-comprehension 872 descriptor.OneofDescriptor( 873 desc.name, 874 '.'.join((desc_name, desc.name)), 875 index, 876 None, 877 [], 878 _OptionsOrNone(desc), 879 # pylint: disable=protected-access 880 create_key=descriptor._internal_create_key) 881 for index, desc in enumerate(desc_proto.oneof_decl) 882 ] 883 extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range] 884 if extension_ranges: 885 is_extendable = True 886 else: 887 is_extendable = False 888 desc = descriptor.Descriptor( 889 name=desc_proto.name, 890 full_name=desc_name, 891 filename=file_name, 892 containing_type=None, 893 fields=fields, 894 oneofs=oneofs, 895 nested_types=nested, 896 enum_types=enums, 897 extensions=extensions, 898 options=_OptionsOrNone(desc_proto), 899 is_extendable=is_extendable, 900 extension_ranges=extension_ranges, 901 file=file_desc, 902 serialized_start=None, 903 serialized_end=None, 904 syntax=syntax, 905 # pylint: disable=protected-access 906 create_key=descriptor._internal_create_key) 907 for nested in desc.nested_types: 908 nested.containing_type = desc 909 for enum in desc.enum_types: 910 
enum.containing_type = desc 911 for field_index, field_desc in enumerate(desc_proto.field): 912 if field_desc.HasField('oneof_index'): 913 oneof_index = field_desc.oneof_index 914 oneofs[oneof_index].fields.append(fields[field_index]) 915 fields[field_index].containing_oneof = oneofs[oneof_index] 916 917 scope[_PrefixWithDot(desc_name)] = desc 918 self._CheckConflictRegister(desc, desc.full_name, desc.file.name) 919 self._descriptors[desc_name] = desc 920 return desc 921 922 def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None, 923 containing_type=None, scope=None, top_level=False): 924 """Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf. 925 926 Args: 927 enum_proto: The descriptor_pb2.EnumDescriptorProto protobuf message. 928 package: Optional package name for the new message EnumDescriptor. 929 file_desc: The file containing the enum descriptor. 930 containing_type: The type containing this enum. 931 scope: Scope containing available types. 932 top_level: If True, the enum is a top level symbol. If False, the enum 933 is defined inside a message. 934 935 Returns: 936 The added descriptor 937 """ 938 939 if package: 940 enum_name = '.'.join((package, enum_proto.name)) 941 else: 942 enum_name = enum_proto.name 943 944 if file_desc is None: 945 file_name = None 946 else: 947 file_name = file_desc.name 948 949 values = [self._MakeEnumValueDescriptor(value, index) 950 for index, value in enumerate(enum_proto.value)] 951 desc = descriptor.EnumDescriptor(name=enum_proto.name, 952 full_name=enum_name, 953 filename=file_name, 954 file=file_desc, 955 values=values, 956 containing_type=containing_type, 957 options=_OptionsOrNone(enum_proto), 958 # pylint: disable=protected-access 959 create_key=descriptor._internal_create_key) 960 scope['.%s' % enum_name] = desc 961 self._CheckConflictRegister(desc, desc.full_name, desc.file.name) 962 self._enum_descriptors[enum_name] = desc 963 964 # Add top level enum values. 
965 if top_level: 966 for value in values: 967 full_name = _NormalizeFullyQualifiedName( 968 '.'.join((package, value.name))) 969 self._CheckConflictRegister(value, full_name, file_name) 970 self._top_enum_values[full_name] = value 971 972 return desc 973 974 def _MakeFieldDescriptor(self, field_proto, message_name, index, 975 file_desc, is_extension=False): 976 """Creates a field descriptor from a FieldDescriptorProto. 977 978 For message and enum type fields, this method will do a look up 979 in the pool for the appropriate descriptor for that type. If it 980 is unavailable, it will fall back to the _source function to 981 create it. If this type is still unavailable, construction will 982 fail. 983 984 Args: 985 field_proto: The proto describing the field. 986 message_name: The name of the containing message. 987 index: Index of the field 988 file_desc: The file containing the field descriptor. 989 is_extension: Indication that this field is for an extension. 990 991 Returns: 992 An initialized FieldDescriptor object 993 """ 994 995 if message_name: 996 full_name = '.'.join((message_name, field_proto.name)) 997 else: 998 full_name = field_proto.name 999 1000 if field_proto.json_name: 1001 json_name = field_proto.json_name 1002 else: 1003 json_name = None 1004 1005 return descriptor.FieldDescriptor( 1006 name=field_proto.name, 1007 full_name=full_name, 1008 index=index, 1009 number=field_proto.number, 1010 type=field_proto.type, 1011 cpp_type=None, 1012 message_type=None, 1013 enum_type=None, 1014 containing_type=None, 1015 label=field_proto.label, 1016 has_default_value=False, 1017 default_value=None, 1018 is_extension=is_extension, 1019 extension_scope=None, 1020 options=_OptionsOrNone(field_proto), 1021 json_name=json_name, 1022 file=file_desc, 1023 # pylint: disable=protected-access 1024 create_key=descriptor._internal_create_key) 1025 1026 def _SetAllFieldTypes(self, package, desc_proto, scope): 1027 """Sets all the descriptor's fields's types. 
1028 1029 This method also sets the containing types on any extensions. 1030 1031 Args: 1032 package: The current package of desc_proto. 1033 desc_proto: The message descriptor to update. 1034 scope: Enclosing scope of available types. 1035 """ 1036 1037 package = _PrefixWithDot(package) 1038 1039 main_desc = self._GetTypeFromScope(package, desc_proto.name, scope) 1040 1041 if package == '.': 1042 nested_package = _PrefixWithDot(desc_proto.name) 1043 else: 1044 nested_package = '.'.join([package, desc_proto.name]) 1045 1046 for field_proto, field_desc in zip(desc_proto.field, main_desc.fields): 1047 self._SetFieldType(field_proto, field_desc, nested_package, scope) 1048 1049 for extension_proto, extension_desc in ( 1050 zip(desc_proto.extension, main_desc.extensions)): 1051 extension_desc.containing_type = self._GetTypeFromScope( 1052 nested_package, extension_proto.extendee, scope) 1053 self._SetFieldType(extension_proto, extension_desc, nested_package, scope) 1054 1055 for nested_type in desc_proto.nested_type: 1056 self._SetAllFieldTypes(nested_package, nested_type, scope) 1057 1058 def _SetFieldType(self, field_proto, field_desc, package, scope): 1059 """Sets the field's type, cpp_type, message_type and enum_type. 1060 1061 Args: 1062 field_proto: Data about the field in proto format. 1063 field_desc: The descriptor to modify. 1064 package: The package the field's container is in. 1065 scope: Enclosing scope of available types. 
1066 """ 1067 if field_proto.type_name: 1068 desc = self._GetTypeFromScope(package, field_proto.type_name, scope) 1069 else: 1070 desc = None 1071 1072 if not field_proto.HasField('type'): 1073 if isinstance(desc, descriptor.Descriptor): 1074 field_proto.type = descriptor.FieldDescriptor.TYPE_MESSAGE 1075 else: 1076 field_proto.type = descriptor.FieldDescriptor.TYPE_ENUM 1077 1078 field_desc.cpp_type = descriptor.FieldDescriptor.ProtoTypeToCppProtoType( 1079 field_proto.type) 1080 1081 if (field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE 1082 or field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP): 1083 field_desc.message_type = desc 1084 1085 if field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: 1086 field_desc.enum_type = desc 1087 1088 if field_proto.label == descriptor.FieldDescriptor.LABEL_REPEATED: 1089 field_desc.has_default_value = False 1090 field_desc.default_value = [] 1091 elif field_proto.HasField('default_value'): 1092 field_desc.has_default_value = True 1093 if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or 1094 field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT): 1095 field_desc.default_value = float(field_proto.default_value) 1096 elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING: 1097 field_desc.default_value = field_proto.default_value 1098 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL: 1099 field_desc.default_value = field_proto.default_value.lower() == 'true' 1100 elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: 1101 field_desc.default_value = field_desc.enum_type.values_by_name[ 1102 field_proto.default_value].number 1103 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES: 1104 field_desc.default_value = text_encoding.CUnescape( 1105 field_proto.default_value) 1106 elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE: 1107 field_desc.default_value = None 1108 else: 1109 # All other types are of the "int" type. 
1110 field_desc.default_value = int(field_proto.default_value) 1111 else: 1112 field_desc.has_default_value = False 1113 if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or 1114 field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT): 1115 field_desc.default_value = 0.0 1116 elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING: 1117 field_desc.default_value = u'' 1118 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL: 1119 field_desc.default_value = False 1120 elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: 1121 field_desc.default_value = field_desc.enum_type.values[0].number 1122 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES: 1123 field_desc.default_value = b'' 1124 elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE: 1125 field_desc.default_value = None 1126 elif field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP: 1127 field_desc.default_value = None 1128 else: 1129 # All other types are of the "int" type. 1130 field_desc.default_value = 0 1131 1132 field_desc.type = field_proto.type 1133 1134 def _MakeEnumValueDescriptor(self, value_proto, index): 1135 """Creates a enum value descriptor object from a enum value proto. 1136 1137 Args: 1138 value_proto: The proto describing the enum value. 1139 index: The index of the enum value. 1140 1141 Returns: 1142 An initialized EnumValueDescriptor object. 1143 """ 1144 1145 return descriptor.EnumValueDescriptor( 1146 name=value_proto.name, 1147 index=index, 1148 number=value_proto.number, 1149 options=_OptionsOrNone(value_proto), 1150 type=None, 1151 # pylint: disable=protected-access 1152 create_key=descriptor._internal_create_key) 1153 1154 def _MakeServiceDescriptor(self, service_proto, service_index, scope, 1155 package, file_desc): 1156 """Make a protobuf ServiceDescriptor given a ServiceDescriptorProto. 1157 1158 Args: 1159 service_proto: The descriptor_pb2.ServiceDescriptorProto protobuf message. 
1160 service_index: The index of the service in the File. 1161 scope: Dict mapping short and full symbols to message and enum types. 1162 package: Optional package name for the new message EnumDescriptor. 1163 file_desc: The file containing the service descriptor. 1164 1165 Returns: 1166 The added descriptor. 1167 """ 1168 1169 if package: 1170 service_name = '.'.join((package, service_proto.name)) 1171 else: 1172 service_name = service_proto.name 1173 1174 methods = [self._MakeMethodDescriptor(method_proto, service_name, package, 1175 scope, index) 1176 for index, method_proto in enumerate(service_proto.method)] 1177 desc = descriptor.ServiceDescriptor( 1178 name=service_proto.name, 1179 full_name=service_name, 1180 index=service_index, 1181 methods=methods, 1182 options=_OptionsOrNone(service_proto), 1183 file=file_desc, 1184 # pylint: disable=protected-access 1185 create_key=descriptor._internal_create_key) 1186 self._CheckConflictRegister(desc, desc.full_name, desc.file.name) 1187 self._service_descriptors[service_name] = desc 1188 return desc 1189 1190 def _MakeMethodDescriptor(self, method_proto, service_name, package, scope, 1191 index): 1192 """Creates a method descriptor from a MethodDescriptorProto. 1193 1194 Args: 1195 method_proto: The proto describing the method. 1196 service_name: The name of the containing service. 1197 package: Optional package name to look up for types. 1198 scope: Scope containing available types. 1199 index: Index of the method in the service. 1200 1201 Returns: 1202 An initialized MethodDescriptor object. 
1203 """ 1204 full_name = '.'.join((service_name, method_proto.name)) 1205 input_type = self._GetTypeFromScope( 1206 package, method_proto.input_type, scope) 1207 output_type = self._GetTypeFromScope( 1208 package, method_proto.output_type, scope) 1209 return descriptor.MethodDescriptor( 1210 name=method_proto.name, 1211 full_name=full_name, 1212 index=index, 1213 containing_service=None, 1214 input_type=input_type, 1215 output_type=output_type, 1216 client_streaming=method_proto.client_streaming, 1217 server_streaming=method_proto.server_streaming, 1218 options=_OptionsOrNone(method_proto), 1219 # pylint: disable=protected-access 1220 create_key=descriptor._internal_create_key) 1221 1222 def _ExtractSymbols(self, descriptors): 1223 """Pulls out all the symbols from descriptor protos. 1224 1225 Args: 1226 descriptors: The messages to extract descriptors from. 1227 Yields: 1228 A two element tuple of the type name and descriptor object. 1229 """ 1230 1231 for desc in descriptors: 1232 yield (_PrefixWithDot(desc.full_name), desc) 1233 for symbol in self._ExtractSymbols(desc.nested_types): 1234 yield symbol 1235 for enum in desc.enum_types: 1236 yield (_PrefixWithDot(enum.full_name), enum) 1237 1238 def _GetDeps(self, dependencies, visited=None): 1239 """Recursively finds dependencies for file protos. 1240 1241 Args: 1242 dependencies: The names of the files being depended on. 1243 visited: The names of files already found. 1244 1245 Yields: 1246 Each direct and indirect dependency. 1247 """ 1248 1249 visited = visited or set() 1250 for dependency in dependencies: 1251 if dependency not in visited: 1252 visited.add(dependency) 1253 dep_desc = self.FindFileByName(dependency) 1254 yield dep_desc 1255 public_files = [d.name for d in dep_desc.public_dependencies] 1256 yield from self._GetDeps(public_files, visited) 1257 1258 def _GetTypeFromScope(self, package, type_name, scope): 1259 """Finds a given type name in the current scope. 
1260 1261 Args: 1262 package: The package the proto should be located in. 1263 type_name: The name of the type to be found in the scope. 1264 scope: Dict mapping short and full symbols to message and enum types. 1265 1266 Returns: 1267 The descriptor for the requested type. 1268 """ 1269 if type_name not in scope: 1270 components = _PrefixWithDot(package).split('.') 1271 while components: 1272 possible_match = '.'.join(components + [type_name]) 1273 if possible_match in scope: 1274 type_name = possible_match 1275 break 1276 else: 1277 components.pop(-1) 1278 return scope[type_name] 1279 1280 1281def _PrefixWithDot(name): 1282 return name if name.startswith('.') else '.%s' % name 1283 1284 1285if _USE_C_DESCRIPTORS: 1286 # TODO(amauryfa): This pool could be constructed from Python code, when we 1287 # support a flag like 'use_cpp_generated_pool=True'. 1288 # pylint: disable=protected-access 1289 _DEFAULT = descriptor._message.default_pool 1290else: 1291 _DEFAULT = DescriptorPool() 1292 1293 1294def Default(): 1295 return _DEFAULT 1296