1 //
2 //
3 // Copyright 2016 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18
19 #include "src/cpp/ext/proto_server_reflection.h"
20
21 #include <unordered_set>
22 #include <vector>
23
24 #include <grpcpp/grpcpp.h>
25 #include <grpcpp/support/interceptor.h>
26
27 // IWYU pragma: no_include "google/protobuf/descriptor.h"
28 // IWYU pragma: no_include <google/protobuf/descriptor.h>
29
30 using grpc::reflection::v1alpha::ErrorResponse;
31 using grpc::reflection::v1alpha::ExtensionNumberResponse;
32 using grpc::reflection::v1alpha::ExtensionRequest;
33 using grpc::reflection::v1alpha::ListServiceResponse;
34 using grpc::reflection::v1alpha::ServerReflectionRequest;
35 using grpc::reflection::v1alpha::ServerReflectionResponse;
36 using grpc::reflection::v1alpha::ServiceResponse;
37
38 namespace grpc {
39
ProtoServerReflection()40 ProtoServerReflection::ProtoServerReflection()
41 : descriptor_pool_(protobuf::DescriptorPool::generated_pool()) {}
42
SetServiceList(const std::vector<std::string> * services)43 void ProtoServerReflection::SetServiceList(
44 const std::vector<std::string>* services) {
45 services_ = services;
46 }
47
ServerReflectionInfo(ServerContext * context,ServerReaderWriter<ServerReflectionResponse,ServerReflectionRequest> * stream)48 Status ProtoServerReflection::ServerReflectionInfo(
49 ServerContext* context,
50 ServerReaderWriter<ServerReflectionResponse, ServerReflectionRequest>*
51 stream) {
52 ServerReflectionRequest request;
53 ServerReflectionResponse response;
54 Status status;
55 while (stream->Read(&request)) {
56 switch (request.message_request_case()) {
57 case ServerReflectionRequest::MessageRequestCase::kFileByFilename:
58 status = GetFileByName(context, request.file_by_filename(), &response);
59 break;
60 case ServerReflectionRequest::MessageRequestCase::kFileContainingSymbol:
61 status = GetFileContainingSymbol(
62 context, request.file_containing_symbol(), &response);
63 break;
64 case ServerReflectionRequest::MessageRequestCase::
65 kFileContainingExtension:
66 status = GetFileContainingExtension(
67 context, &request.file_containing_extension(), &response);
68 break;
69 case ServerReflectionRequest::MessageRequestCase::
70 kAllExtensionNumbersOfType:
71 status = GetAllExtensionNumbers(
72 context, request.all_extension_numbers_of_type(),
73 response.mutable_all_extension_numbers_response());
74 break;
75 case ServerReflectionRequest::MessageRequestCase::kListServices:
76 status =
77 ListService(context, response.mutable_list_services_response());
78 break;
79 default:
80 status = Status(StatusCode::UNIMPLEMENTED, "");
81 }
82
83 if (!status.ok()) {
84 FillErrorResponse(status, response.mutable_error_response());
85 }
86 response.set_valid_host(request.host());
87 response.set_allocated_original_request(
88 new ServerReflectionRequest(request));
89 stream->Write(response);
90 }
91
92 return Status::OK;
93 }
94
FillErrorResponse(const Status & status,ErrorResponse * error_response)95 void ProtoServerReflection::FillErrorResponse(const Status& status,
96 ErrorResponse* error_response) {
97 error_response->set_error_code(status.error_code());
98 error_response->set_error_message(status.error_message());
99 }
100
ListService(ServerContext *,ListServiceResponse * response)101 Status ProtoServerReflection::ListService(ServerContext* /*context*/,
102 ListServiceResponse* response) {
103 if (services_ == nullptr) {
104 return Status(StatusCode::NOT_FOUND, "Services not found.");
105 }
106 for (const auto& value : *services_) {
107 ServiceResponse* service_response = response->add_service();
108 service_response->set_name(value);
109 }
110 return Status::OK;
111 }
112
GetFileByName(ServerContext *,const std::string & file_name,ServerReflectionResponse * response)113 Status ProtoServerReflection::GetFileByName(
114 ServerContext* /*context*/, const std::string& file_name,
115 ServerReflectionResponse* response) {
116 if (descriptor_pool_ == nullptr) {
117 return Status::CANCELLED;
118 }
119
120 const protobuf::FileDescriptor* file_desc =
121 descriptor_pool_->FindFileByName(file_name);
122 if (file_desc == nullptr) {
123 return Status(StatusCode::NOT_FOUND, "File not found.");
124 }
125 std::unordered_set<std::string> seen_files;
126 FillFileDescriptorResponse(file_desc, response, &seen_files);
127 return Status::OK;
128 }
129
GetFileContainingSymbol(ServerContext *,const std::string & symbol,ServerReflectionResponse * response)130 Status ProtoServerReflection::GetFileContainingSymbol(
131 ServerContext* /*context*/, const std::string& symbol,
132 ServerReflectionResponse* response) {
133 if (descriptor_pool_ == nullptr) {
134 return Status::CANCELLED;
135 }
136
137 const protobuf::FileDescriptor* file_desc =
138 descriptor_pool_->FindFileContainingSymbol(symbol);
139 if (file_desc == nullptr) {
140 return Status(StatusCode::NOT_FOUND, "Symbol not found.");
141 }
142 std::unordered_set<std::string> seen_files;
143 FillFileDescriptorResponse(file_desc, response, &seen_files);
144 return Status::OK;
145 }
146
GetFileContainingExtension(ServerContext *,const ExtensionRequest * request,ServerReflectionResponse * response)147 Status ProtoServerReflection::GetFileContainingExtension(
148 ServerContext* /*context*/, const ExtensionRequest* request,
149 ServerReflectionResponse* response) {
150 if (descriptor_pool_ == nullptr) {
151 return Status::CANCELLED;
152 }
153
154 const protobuf::Descriptor* desc =
155 descriptor_pool_->FindMessageTypeByName(request->containing_type());
156 if (desc == nullptr) {
157 return Status(StatusCode::NOT_FOUND, "Type not found.");
158 }
159
160 const protobuf::FieldDescriptor* field_desc =
161 descriptor_pool_->FindExtensionByNumber(desc,
162 request->extension_number());
163 if (field_desc == nullptr) {
164 return Status(StatusCode::NOT_FOUND, "Extension not found.");
165 }
166 std::unordered_set<std::string> seen_files;
167 FillFileDescriptorResponse(field_desc->file(), response, &seen_files);
168 return Status::OK;
169 }
170
GetAllExtensionNumbers(ServerContext *,const std::string & type,ExtensionNumberResponse * response)171 Status ProtoServerReflection::GetAllExtensionNumbers(
172 ServerContext* /*context*/, const std::string& type,
173 ExtensionNumberResponse* response) {
174 if (descriptor_pool_ == nullptr) {
175 return Status::CANCELLED;
176 }
177
178 const protobuf::Descriptor* desc =
179 descriptor_pool_->FindMessageTypeByName(type);
180 if (desc == nullptr) {
181 return Status(StatusCode::NOT_FOUND, "Type not found.");
182 }
183
184 std::vector<const protobuf::FieldDescriptor*> extensions;
185 descriptor_pool_->FindAllExtensions(desc, &extensions);
186 for (const auto& value : extensions) {
187 response->add_extension_number(value->number());
188 }
189 response->set_base_type_name(type);
190 return Status::OK;
191 }
192
FillFileDescriptorResponse(const protobuf::FileDescriptor * file_desc,ServerReflectionResponse * response,std::unordered_set<std::string> * seen_files)193 void ProtoServerReflection::FillFileDescriptorResponse(
194 const protobuf::FileDescriptor* file_desc,
195 ServerReflectionResponse* response,
196 std::unordered_set<std::string>* seen_files) {
197 if (seen_files->find(file_desc->name()) != seen_files->end()) {
198 return;
199 }
200 seen_files->insert(file_desc->name());
201
202 protobuf::FileDescriptorProto file_desc_proto;
203 std::string data;
204 file_desc->CopyTo(&file_desc_proto);
205 file_desc_proto.SerializeToString(&data);
206 response->mutable_file_descriptor_response()->add_file_descriptor_proto(data);
207
208 for (int i = 0; i < file_desc->dependency_count(); ++i) {
209 FillFileDescriptorResponse(file_desc->dependency(i), response, seen_files);
210 }
211 }
212
213 } // namespace grpc
214