1 //
2 //
3 // Copyright 2016 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #include "src/cpp/ext/proto_server_reflection.h"
20 
21 #include <unordered_set>
22 #include <vector>
23 
24 #include <grpcpp/grpcpp.h>
25 #include <grpcpp/support/interceptor.h>
26 
27 // IWYU pragma: no_include "google/protobuf/descriptor.h"
28 // IWYU pragma: no_include <google/protobuf/descriptor.h>
29 
30 using grpc::reflection::v1alpha::ErrorResponse;
31 using grpc::reflection::v1alpha::ExtensionNumberResponse;
32 using grpc::reflection::v1alpha::ExtensionRequest;
33 using grpc::reflection::v1alpha::ListServiceResponse;
34 using grpc::reflection::v1alpha::ServerReflectionRequest;
35 using grpc::reflection::v1alpha::ServerReflectionResponse;
36 using grpc::reflection::v1alpha::ServiceResponse;
37 
38 namespace grpc {
39 
ProtoServerReflection()40 ProtoServerReflection::ProtoServerReflection()
41     : descriptor_pool_(protobuf::DescriptorPool::generated_pool()) {}
42 
SetServiceList(const std::vector<std::string> * services)43 void ProtoServerReflection::SetServiceList(
44     const std::vector<std::string>* services) {
45   services_ = services;
46 }
47 
ServerReflectionInfo(ServerContext * context,ServerReaderWriter<ServerReflectionResponse,ServerReflectionRequest> * stream)48 Status ProtoServerReflection::ServerReflectionInfo(
49     ServerContext* context,
50     ServerReaderWriter<ServerReflectionResponse, ServerReflectionRequest>*
51         stream) {
52   ServerReflectionRequest request;
53   ServerReflectionResponse response;
54   Status status;
55   while (stream->Read(&request)) {
56     switch (request.message_request_case()) {
57       case ServerReflectionRequest::MessageRequestCase::kFileByFilename:
58         status = GetFileByName(context, request.file_by_filename(), &response);
59         break;
60       case ServerReflectionRequest::MessageRequestCase::kFileContainingSymbol:
61         status = GetFileContainingSymbol(
62             context, request.file_containing_symbol(), &response);
63         break;
64       case ServerReflectionRequest::MessageRequestCase::
65           kFileContainingExtension:
66         status = GetFileContainingExtension(
67             context, &request.file_containing_extension(), &response);
68         break;
69       case ServerReflectionRequest::MessageRequestCase::
70           kAllExtensionNumbersOfType:
71         status = GetAllExtensionNumbers(
72             context, request.all_extension_numbers_of_type(),
73             response.mutable_all_extension_numbers_response());
74         break;
75       case ServerReflectionRequest::MessageRequestCase::kListServices:
76         status =
77             ListService(context, response.mutable_list_services_response());
78         break;
79       default:
80         status = Status(StatusCode::UNIMPLEMENTED, "");
81     }
82 
83     if (!status.ok()) {
84       FillErrorResponse(status, response.mutable_error_response());
85     }
86     response.set_valid_host(request.host());
87     response.set_allocated_original_request(
88         new ServerReflectionRequest(request));
89     stream->Write(response);
90   }
91 
92   return Status::OK;
93 }
94 
FillErrorResponse(const Status & status,ErrorResponse * error_response)95 void ProtoServerReflection::FillErrorResponse(const Status& status,
96                                               ErrorResponse* error_response) {
97   error_response->set_error_code(status.error_code());
98   error_response->set_error_message(status.error_message());
99 }
100 
ListService(ServerContext *,ListServiceResponse * response)101 Status ProtoServerReflection::ListService(ServerContext* /*context*/,
102                                           ListServiceResponse* response) {
103   if (services_ == nullptr) {
104     return Status(StatusCode::NOT_FOUND, "Services not found.");
105   }
106   for (const auto& value : *services_) {
107     ServiceResponse* service_response = response->add_service();
108     service_response->set_name(value);
109   }
110   return Status::OK;
111 }
112 
GetFileByName(ServerContext *,const std::string & file_name,ServerReflectionResponse * response)113 Status ProtoServerReflection::GetFileByName(
114     ServerContext* /*context*/, const std::string& file_name,
115     ServerReflectionResponse* response) {
116   if (descriptor_pool_ == nullptr) {
117     return Status::CANCELLED;
118   }
119 
120   const protobuf::FileDescriptor* file_desc =
121       descriptor_pool_->FindFileByName(file_name);
122   if (file_desc == nullptr) {
123     return Status(StatusCode::NOT_FOUND, "File not found.");
124   }
125   std::unordered_set<std::string> seen_files;
126   FillFileDescriptorResponse(file_desc, response, &seen_files);
127   return Status::OK;
128 }
129 
GetFileContainingSymbol(ServerContext *,const std::string & symbol,ServerReflectionResponse * response)130 Status ProtoServerReflection::GetFileContainingSymbol(
131     ServerContext* /*context*/, const std::string& symbol,
132     ServerReflectionResponse* response) {
133   if (descriptor_pool_ == nullptr) {
134     return Status::CANCELLED;
135   }
136 
137   const protobuf::FileDescriptor* file_desc =
138       descriptor_pool_->FindFileContainingSymbol(symbol);
139   if (file_desc == nullptr) {
140     return Status(StatusCode::NOT_FOUND, "Symbol not found.");
141   }
142   std::unordered_set<std::string> seen_files;
143   FillFileDescriptorResponse(file_desc, response, &seen_files);
144   return Status::OK;
145 }
146 
GetFileContainingExtension(ServerContext *,const ExtensionRequest * request,ServerReflectionResponse * response)147 Status ProtoServerReflection::GetFileContainingExtension(
148     ServerContext* /*context*/, const ExtensionRequest* request,
149     ServerReflectionResponse* response) {
150   if (descriptor_pool_ == nullptr) {
151     return Status::CANCELLED;
152   }
153 
154   const protobuf::Descriptor* desc =
155       descriptor_pool_->FindMessageTypeByName(request->containing_type());
156   if (desc == nullptr) {
157     return Status(StatusCode::NOT_FOUND, "Type not found.");
158   }
159 
160   const protobuf::FieldDescriptor* field_desc =
161       descriptor_pool_->FindExtensionByNumber(desc,
162                                               request->extension_number());
163   if (field_desc == nullptr) {
164     return Status(StatusCode::NOT_FOUND, "Extension not found.");
165   }
166   std::unordered_set<std::string> seen_files;
167   FillFileDescriptorResponse(field_desc->file(), response, &seen_files);
168   return Status::OK;
169 }
170 
GetAllExtensionNumbers(ServerContext *,const std::string & type,ExtensionNumberResponse * response)171 Status ProtoServerReflection::GetAllExtensionNumbers(
172     ServerContext* /*context*/, const std::string& type,
173     ExtensionNumberResponse* response) {
174   if (descriptor_pool_ == nullptr) {
175     return Status::CANCELLED;
176   }
177 
178   const protobuf::Descriptor* desc =
179       descriptor_pool_->FindMessageTypeByName(type);
180   if (desc == nullptr) {
181     return Status(StatusCode::NOT_FOUND, "Type not found.");
182   }
183 
184   std::vector<const protobuf::FieldDescriptor*> extensions;
185   descriptor_pool_->FindAllExtensions(desc, &extensions);
186   for (const auto& value : extensions) {
187     response->add_extension_number(value->number());
188   }
189   response->set_base_type_name(type);
190   return Status::OK;
191 }
192 
FillFileDescriptorResponse(const protobuf::FileDescriptor * file_desc,ServerReflectionResponse * response,std::unordered_set<std::string> * seen_files)193 void ProtoServerReflection::FillFileDescriptorResponse(
194     const protobuf::FileDescriptor* file_desc,
195     ServerReflectionResponse* response,
196     std::unordered_set<std::string>* seen_files) {
197   if (seen_files->find(file_desc->name()) != seen_files->end()) {
198     return;
199   }
200   seen_files->insert(file_desc->name());
201 
202   protobuf::FileDescriptorProto file_desc_proto;
203   std::string data;
204   file_desc->CopyTo(&file_desc_proto);
205   file_desc_proto.SerializeToString(&data);
206   response->mutable_file_descriptor_response()->add_file_descriptor_proto(data);
207 
208   for (int i = 0; i < file_desc->dependency_count(); ++i) {
209     FillFileDescriptorResponse(file_desc->dependency(i), response, seen_files);
210   }
211 }
212 
213 }  // namespace grpc
214