1# Copyright 2022 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Reference implementation for reflection client in gRPC Python.
15
16For usage instructions, see the Python Reflection documentation at
17``doc/python/server_reflection.md``.
18"""
19
20import logging
21from typing import Any, Dict, Iterable, List, Set
22
23from google.protobuf.descriptor_database import DescriptorDatabase
24from google.protobuf.descriptor_pb2 import FileDescriptorProto
25import grpc
26from grpc_reflection.v1alpha.reflection_pb2 import ExtensionNumberResponse
27from grpc_reflection.v1alpha.reflection_pb2 import ExtensionRequest
28from grpc_reflection.v1alpha.reflection_pb2 import FileDescriptorResponse
29from grpc_reflection.v1alpha.reflection_pb2 import ListServiceResponse
30from grpc_reflection.v1alpha.reflection_pb2 import ServerReflectionRequest
31from grpc_reflection.v1alpha.reflection_pb2 import ServerReflectionResponse
32from grpc_reflection.v1alpha.reflection_pb2 import ServiceResponse
33from grpc_reflection.v1alpha.reflection_pb2_grpc import ServerReflectionStub
34
35
class ProtoReflectionDescriptorDatabase(DescriptorDatabase):
    """
    A container and interface for receiving descriptors from a server's
    Reflection service.

    ProtoReflectionDescriptorDatabase takes a channel to a server with
    Reflection service, and provides an interface to retrieve the Reflection
    information. It implements the DescriptorDatabase interface.

    Every file descriptor received from the server is parsed and added to
    this database, so a given file is only added once per instance.

    It is typically used to feed a DescriptorPool instance.
    """

    # Implementation based on C++ version found here (version tag 1.39.1):
    #   grpc/test/cpp/util/proto_reflection_descriptor_database.cc
    # while implementing the Python interface given here:
    #   https://googleapis.dev/python/protobuf/3.17.0/google/protobuf/descriptor_database.html

    def __init__(self, channel: grpc.Channel):
        DescriptorDatabase.__init__(self)
        self._logger = logging.getLogger(__name__)
        self._stub = ServerReflectionStub(channel)
        # Names of every file already added to this database via Add().
        self._known_files: Set[str] = set()
        # extendee full name -> extension field numbers reported by the server.
        self._cached_extension_numbers: Dict[str, List[int]] = dict()
        # (extendee full name, field number) -> file that declares that
        # extension. Kept locally because the base-class extension lookup
        # cannot be relied upon (NOTE: in protobuf's pure-Python
        # DescriptorDatabase it returns None instead of indexing files —
        # verify against the installed protobuf version).
        self._cached_extension_files: Dict[tuple, FileDescriptorProto] = dict()

    def get_services(self) -> Iterable[str]:
        """
        Get list of full names of the registered services.

        Returns:
            A list of strings corresponding to the names of the services.

        Raises:
            KeyError: the server answered the request with a NOT_FOUND error.
        """

        request = ServerReflectionRequest(list_services="")
        response = self._do_one_request(request, key="")
        list_services: ListServiceResponse = response.list_services_response
        services: List[ServiceResponse] = list_services.service
        return [service.name for service in services]

    def FindFileByName(self, name: str) -> FileDescriptorProto:
        """
        Find a file descriptor by file name.

        This function implements a DescriptorDatabase interface, and is
        typically not called directly; prefer using a DescriptorPool instead.

        Args:
            name: The name of the file. Typically this is a relative path ending in ".proto".

        Returns:
            A FileDescriptorProto for the file.

        Raises:
            KeyError: the file was not found.
        """

        try:
            return super().FindFileByName(name)
        except KeyError:
            pass
        # Every name in _known_files was passed to Add(), so a local miss
        # means the file has never been loaded.
        assert name not in self._known_files
        request = ServerReflectionRequest(file_by_filename=name)
        response = self._do_one_request(request, key=name)
        self._add_file_from_response(response.file_descriptor_response)
        return super().FindFileByName(name)

    def FindFileContainingSymbol(self, symbol: str) -> FileDescriptorProto:
        """
        Find the file containing the symbol, and return its file descriptor.

        The symbol should be a fully qualified name including the file
        descriptor's package and any containing messages. Some examples:

            * "some.package.name.Message"
            * "some.package.name.Message.NestedEnum"
            * "some.package.name.Message.some_field"

        This function implements a DescriptorDatabase interface, and is
        typically not called directly; prefer using a DescriptorPool instead.

        Args:
            symbol: The fully-qualified name of the symbol.

        Returns:
            FileDescriptorProto for the file containing the symbol.

        Raises:
            KeyError: the symbol was not found.
        """

        try:
            return super().FindFileContainingSymbol(symbol)
        except KeyError:
            pass
        # Query the server
        request = ServerReflectionRequest(file_containing_symbol=symbol)
        response = self._do_one_request(request, key=symbol)
        self._add_file_from_response(response.file_descriptor_response)
        return super().FindFileContainingSymbol(symbol)

    def FindAllExtensionNumbers(self, extendee_name: str) -> Iterable[int]:
        """
        Find the field numbers used by all known extensions of `extendee_name`.

        This function implements a DescriptorDatabase interface, and is
        typically not called directly; prefer using a DescriptorPool instead.

        Args:
            extendee_name: fully-qualified name of the extended message type.

        Returns:
            A new list of field numbers used by all known extensions. The
            result is a copy, so callers may modify it without affecting
            the cache.

        Raises:
            KeyError: The message type `extendee_name` was not found.
        """

        if extendee_name in self._cached_extension_numbers:
            return list(self._cached_extension_numbers[extendee_name])
        request = ServerReflectionRequest(
            all_extension_numbers_of_type=extendee_name
        )
        response = self._do_one_request(request, key=extendee_name)
        all_extension_numbers: ExtensionNumberResponse = (
            response.all_extension_numbers_response
        )
        numbers = list(all_extension_numbers.extension_number)
        self._cached_extension_numbers[extendee_name] = numbers
        return list(numbers)

    def FindFileContainingExtension(
        self, extendee_name: str, extension_number: int
    ) -> FileDescriptorProto:
        """
        Find the file which defines an extension for the given message type
        and field number.

        This function implements a DescriptorDatabase interface, and is
        typically not called directly; prefer using a DescriptorPool instead.

        Args:
            extendee_name: fully-qualified name of the extended message type.
            extension_number: the number of the extension field.

        Returns:
            FileDescriptorProto for the file containing the extension.

        Raises:
            KeyError: The message or the extension number were not found.
        """

        cached = self._lookup_extension_file(extendee_name, extension_number)
        if cached is not None:
            return cached
        key = (extendee_name, extension_number)
        request = ServerReflectionRequest(
            file_containing_extension=ExtensionRequest(
                containing_type=extendee_name, extension_number=extension_number
            )
        )
        response = self._do_one_request(request, key=key)
        self._add_file_from_response(response.file_descriptor_response)
        found = self._lookup_extension_file(extendee_name, extension_number)
        if found is None:
            # The server answered, but none of the files it returned (or
            # that were already loaded) declares this extension.
            raise KeyError(key)
        return found

    def _lookup_extension_file(
        self, extendee_name: str, extension_number: int
    ):
        """Return the already-loaded file declaring the extension, or None.

        Checks the local extension index first, then defers to the base
        class; a KeyError or a None result from the base class both count
        as a miss.
        """
        found = self._cached_extension_files.get(
            (extendee_name, extension_number)
        )
        if found is not None:
            return found
        try:
            return super().FindFileContainingExtension(
                extendee_name, extension_number
            )
        except KeyError:
            return None

    def _do_one_request(
        self, request: ServerReflectionRequest, key: Any
    ) -> ServerReflectionResponse:
        """Send a single request on the reflection stream and return the reply.

        Args:
            request: the reflection request to send.
            key: value used as the KeyError argument on a NOT_FOUND reply.

        Raises:
            KeyError: the server replied with a NOT_FOUND error.
            AssertionError: the server replied with any other error code
                (only when assertions are enabled; under `python -O` the
                check is skipped and KeyError is raised instead).
        """
        # The bidirectional stream carries exactly one request; only the
        # first response message is read.
        response = self._stub.ServerReflectionInfo(iter([request]))
        res = next(response)
        if res.WhichOneof("message_response") == "error_response":
            # Only NOT_FOUND errors are expected at this layer
            error_code = res.error_response.error_code
            assert (
                error_code == grpc.StatusCode.NOT_FOUND.value[0]
            ), "unexpected error response: " + repr(res.error_response)
            raise KeyError(key)
        return res

    def _add_file_from_response(
        self, file_descriptor: FileDescriptorResponse
    ) -> None:
        """Parse each serialized file in the response and add the new ones.

        Files whose name is already known are skipped, so a file is added
        to the database at most once. Extensions declared by each newly
        added file are indexed for FindFileContainingExtension.
        """
        protos: List[bytes] = file_descriptor.file_descriptor_proto
        for proto in protos:
            desc = FileDescriptorProto()
            desc.ParseFromString(proto)
            if desc.name in self._known_files:
                continue
            self._logger.info("Loading descriptors from file: %s", desc.name)
            self._known_files.add(desc.name)
            self.Add(desc)
            self._index_extensions(desc)

    def _index_extensions(self, desc: FileDescriptorProto) -> None:
        """Record `desc` as the declaring file of every extension it contains.

        Covers file-level extensions as well as extensions declared inside
        message types at any nesting depth. The first file registered for a
        given (extendee, number) pair wins.
        """
        fields = list(desc.extension)
        pending = list(desc.message_type)
        while pending:
            message = pending.pop()
            fields.extend(message.extension)
            pending.extend(message.nested_type)
        for field in fields:
            extendee = field.extendee
            # Fully-qualified type names in descriptors start with "."
            # (e.g. ".pkg.Message"); callers are expected to pass names
            # without it (e.g. "pkg.Message"), so normalize here.
            if extendee.startswith("."):
                extendee = extendee[1:]
            self._cached_extension_files.setdefault(
                (extendee, field.number), desc
            )
233