# xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio_reflection/grpc_reflection/v1alpha/_base.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base implementation of the reflection servicer."""
15
16from google.protobuf import descriptor_pb2
17from google.protobuf import descriptor_pool
18import grpc
19from grpc_reflection.v1alpha import reflection_pb2 as _reflection_pb2
20from grpc_reflection.v1alpha import reflection_pb2_grpc as _reflection_pb2_grpc
21
# Descriptor pool that BaseReflectionServicer consults when its constructor
# is not given an explicit ``pool``.
_POOL = descriptor_pool.Default()
23
24
def _not_found_error():
    """Return a reflection response reporting that a lookup failed.

    The numeric error code and the description text are both taken from
    ``grpc.StatusCode.NOT_FOUND.value``.

    Returns:
        A ServerReflectionResponse whose ``error_response`` is populated.
    """
    status = grpc.StatusCode.NOT_FOUND.value
    return _reflection_pb2.ServerReflectionResponse(
        error_response=_reflection_pb2.ErrorResponse(
            error_code=status[0],
            # error_message is a proto ``string`` field: pass the text as-is.
            # Encoding it to bytes first only made protobuf decode it again.
            error_message=status[1],
        )
    )
32
33
def _collect_transitive_dependencies(descriptor, seen_files):
    """Record ``descriptor`` and every file it transitively depends on.

    Performs a depth-first, pre-order walk over ``dependencies`` (in their
    declared order) using an explicit stack instead of recursion.

    Args:
        descriptor: File descriptor exposing ``name`` and ``dependencies``.
        seen_files: Dict mapping file name -> descriptor, mutated in place.
            ``descriptor`` itself is always (re)recorded. Any dependency whose
            name is already a key is skipped, and so are that file's own
            dependencies.
    """
    seen_files[descriptor.name] = descriptor
    # Children are pushed in reverse so the first-declared dependency is
    # popped, and therefore visited, first.
    pending = list(descriptor.dependencies)[::-1]
    while pending:
        current = pending.pop()
        # A file reachable through several import paths is recorded once.
        if current.name in seen_files:
            continue
        seen_files[current.name] = current
        pending.extend(list(current.dependencies)[::-1])
40
41
def _file_descriptor_response(descriptor):
    """Build a reflection response carrying ``descriptor`` and its imports.

    Args:
        descriptor: The file descriptor that satisfied the lookup.

    Returns:
        A ServerReflectionResponse whose ``file_descriptor_response`` holds
        the serialized FileDescriptorProto of ``descriptor`` first, followed
        by those of every file it transitively depends on (each file once).
    """
    # Keyed by file name, so a file reachable via several import paths is
    # serialized only once; insertion order puts ``descriptor`` first.
    descriptors = {}
    _collect_transitive_dependencies(descriptor, descriptors)

    serialized_protos = []
    for file_descriptor in descriptors.values():
        proto = descriptor_pb2.FileDescriptorProto()
        file_descriptor.CopyToProto(proto)
        serialized_protos.append(proto.SerializeToString())

    return _reflection_pb2.ServerReflectionResponse(
        file_descriptor_response=_reflection_pb2.FileDescriptorResponse(
            file_descriptor_proto=serialized_protos
        ),
    )
59
60
class BaseReflectionServicer(_reflection_pb2_grpc.ServerReflectionServicer):
    """Base class for reflection servicer.

    Stores the advertised service names and a descriptor pool, and provides
    one helper per kind of reflection query. Every helper returns a
    ``ServerReflectionResponse``. A ``KeyError`` raised by a pool lookup is
    reported as a NOT_FOUND error response rather than propagated.
    """

    def __init__(self, service_names, pool=None):
        """Constructor.

        Args:
            service_names: Iterable of fully-qualified service names available.
                They are stored as a sorted tuple.
            pool: An optional DescriptorPool instance. When ``None``, the
                module-level ``_POOL`` is used instead.
        """
        self._service_names = tuple(sorted(service_names))
        if pool is None:
            pool = _POOL
        self._pool = pool

    def _file_by_filename(self, filename):
        """Answer a file-by-filename query.

        Returns the file named ``filename`` plus its transitive dependencies,
        or a NOT_FOUND error response when the pool does not know the file.
        """
        try:
            file_desc = self._pool.FindFileByName(filename)
        except KeyError:
            return _not_found_error()
        return _file_descriptor_response(file_desc)

    def _file_containing_symbol(self, fully_qualified_name):
        """Answer a file-containing-symbol query for ``fully_qualified_name``.

        Returns the defining file plus its transitive dependencies, or a
        NOT_FOUND error response when the symbol is unknown to the pool.
        """
        try:
            file_desc = self._pool.FindFileContainingSymbol(
                fully_qualified_name
            )
        except KeyError:
            return _not_found_error()
        return _file_descriptor_response(file_desc)

    def _file_containing_extension(self, containing_type, extension_number):
        """Answer a file-containing-extension query.

        Resolves extension ``extension_number`` of message ``containing_type``,
        then returns the file containing that extension's symbol. Any failed
        lookup along the way yields a NOT_FOUND error response.
        """
        pool = self._pool
        try:
            extendee = pool.FindMessageTypeByName(containing_type)
            extension = pool.FindExtensionByNumber(extendee, extension_number)
            file_desc = pool.FindFileContainingSymbol(extension.full_name)
        except KeyError:
            return _not_found_error()
        return _file_descriptor_response(file_desc)

    def _all_extension_numbers_of_type(self, containing_type):
        """Answer an all-extension-numbers-of-type query.

        Reports, in ascending order, the field numbers of every extension of
        message ``containing_type`` that the pool knows about.
        """
        pool = self._pool
        try:
            extendee = pool.FindMessageTypeByName(containing_type)
            numbers = tuple(
                sorted(ext.number for ext in pool.FindAllExtensions(extendee))
            )
        except KeyError:
            return _not_found_error()
        return _reflection_pb2.ServerReflectionResponse(
            all_extension_numbers_response=_reflection_pb2.ExtensionNumberResponse(
                base_type_name=extendee.full_name,
                extension_number=numbers,
            )
        )

    def _list_services(self):
        """Answer a list-services query with the constructor's (sorted) names."""
        services = [
            _reflection_pb2.ServiceResponse(name=name)
            for name in self._service_names
        ]
        return _reflection_pb2.ServerReflectionResponse(
            list_services_response=_reflection_pb2.ListServiceResponse(
                service=services
            )
        )
140
141
# Only the servicer base class is public; the underscore-prefixed module
# helpers are implementation details.
__all__ = ["BaseReflectionServicer"]
143