# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base implementation of reflection servicer."""

from google.protobuf import descriptor_pb2
from google.protobuf import descriptor_pool
import grpc
from grpc_reflection.v1alpha import reflection_pb2 as _reflection_pb2
from grpc_reflection.v1alpha import reflection_pb2_grpc as _reflection_pb2_grpc

# Descriptor pool used when BaseReflectionServicer is not given one explicitly.
_POOL = descriptor_pool.Default()


def _not_found_error():
    """Return a reflection response carrying gRPC status NOT_FOUND.

    Returns:
      A ServerReflectionResponse whose ``error_response`` holds the numeric
      NOT_FOUND code and its textual description.
    """
    code, message = grpc.StatusCode.NOT_FOUND.value
    return _reflection_pb2.ServerReflectionResponse(
        error_response=_reflection_pb2.ErrorResponse(
            error_code=code,
            # error_message is a proto ``string`` field, so pass ``str``
            # directly rather than relying on implicit bytes decoding.
            error_message=message,
        )
    )


def _collect_transitive_dependencies(descriptor, seen_files):
    """Record ``descriptor`` and every file it transitively depends on.

    Args:
      descriptor: The file descriptor whose dependency closure is collected.
      seen_files: Dict mapping file name to file descriptor, updated in place.
        ``descriptor`` itself is always (re)recorded; dependencies whose name
        is already present are not revisited.
    """
    seen_files[descriptor.name] = descriptor
    # An explicit stack replaces recursion so long dependency chains cannot
    # exhaust the interpreter's recursion limit. Dependencies are pushed in
    # reverse so files are recorded in the same depth-first pre-order a
    # recursive walk would produce (the output order is unchanged).
    pending = list(reversed(list(descriptor.dependencies)))
    while pending:
        current = pending.pop()
        if current.name in seen_files:
            continue
        seen_files[current.name] = current
        pending.extend(reversed(list(current.dependencies)))


def _file_descriptor_response(descriptor):
    """Build a response holding ``descriptor`` and its transitive dependencies.

    Args:
      descriptor: The file descriptor requested by the client.

    Returns:
      A ServerReflectionResponse whose ``file_descriptor_response`` lists the
      serialized FileDescriptorProto of ``descriptor`` first, followed by
      those of every file it transitively depends on.
    """
    descriptors = {}
    _collect_transitive_dependencies(descriptor, descriptors)

    serialized_proto_list = []
    for file_descriptor in descriptors.values():
        proto = descriptor_pb2.FileDescriptorProto()
        file_descriptor.CopyToProto(proto)
        serialized_proto_list.append(proto.SerializeToString())

    return _reflection_pb2.ServerReflectionResponse(
        file_descriptor_response=_reflection_pb2.FileDescriptorResponse(
            file_descriptor_proto=serialized_proto_list
        ),
    )


class BaseReflectionServicer(_reflection_pb2_grpc.ServerReflectionServicer):
    """Base class for reflection servicer.

    Provides one helper per reflection request kind. Each helper returns a
    ServerReflectionResponse; a failed descriptor lookup is reported as a
    NOT_FOUND ``error_response`` rather than raised. This class does not
    itself implement the streaming RPC handler.
    """

    def __init__(self, service_names, pool=None):
        """Constructor.

        Args:
          service_names: Iterable of fully-qualified service names available.
          pool: An optional DescriptorPool instance. Defaults to the
            protobuf default descriptor pool.
        """
        self._service_names = tuple(sorted(service_names))
        self._pool = _POOL if pool is None else pool

    def _file_by_filename(self, filename):
        """Return the file named ``filename`` and its dependencies.

        Returns a NOT_FOUND error response if the pool has no such file.
        """
        try:
            descriptor = self._pool.FindFileByName(filename)
        except KeyError:
            return _not_found_error()
        else:
            return _file_descriptor_response(descriptor)

    def _file_containing_symbol(self, fully_qualified_name):
        """Return the file defining ``fully_qualified_name`` and its dependencies.

        Returns a NOT_FOUND error response if the symbol is unknown.
        """
        try:
            descriptor = self._pool.FindFileContainingSymbol(
                fully_qualified_name
            )
        except KeyError:
            return _not_found_error()
        else:
            return _file_descriptor_response(descriptor)

    def _file_containing_extension(self, containing_type, extension_number):
        """Return the file defining an extension, plus its dependencies.

        Args:
          containing_type: Fully-qualified name of the extended message type.
          extension_number: Field number of the extension.

        Returns a NOT_FOUND error response if any of the three lookups
        (message type, extension, defining file) fails.
        """
        try:
            message_descriptor = self._pool.FindMessageTypeByName(
                containing_type
            )
            extension_descriptor = self._pool.FindExtensionByNumber(
                message_descriptor, extension_number
            )
            descriptor = self._pool.FindFileContainingSymbol(
                extension_descriptor.full_name
            )
        except KeyError:
            return _not_found_error()
        else:
            return _file_descriptor_response(descriptor)

    def _all_extension_numbers_of_type(self, containing_type):
        """Return the sorted extension numbers known for ``containing_type``.

        Returns a NOT_FOUND error response if the message type is unknown.
        """
        try:
            message_descriptor = self._pool.FindMessageTypeByName(
                containing_type
            )
            extension_numbers = tuple(
                sorted(
                    extension.number
                    for extension in self._pool.FindAllExtensions(
                        message_descriptor
                    )
                )
            )
        except KeyError:
            return _not_found_error()
        else:
            return _reflection_pb2.ServerReflectionResponse(
                all_extension_numbers_response=_reflection_pb2.ExtensionNumberResponse(
                    base_type_name=message_descriptor.full_name,
                    extension_number=extension_numbers,
                )
            )

    def _list_services(self):
        """Return the configured service names, in sorted order."""
        return _reflection_pb2.ServerReflectionResponse(
            list_services_response=_reflection_pb2.ListServiceResponse(
                service=[
                    _reflection_pb2.ServiceResponse(name=service_name)
                    for service_name in self._service_names
                ]
            )
        )


__all__ = ["BaseReflectionServicer"]