1 //
2 //
3 // Copyright 2016 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18
19 #include "test/cpp/util/proto_file_parser.h"
20
21 #include <algorithm>
22 #include <iostream>
23 #include <sstream>
24 #include <unordered_set>
25
26 #include "absl/memory/memory.h"
27 #include "absl/strings/str_split.h"
28
29 #include <grpcpp/support/config.h>
30
31 namespace grpc {
32 namespace testing {
33 namespace {
34
// Returns true if the user-supplied method string `input` names the method
// whose fully-qualified descriptor name is `full_name`.
//
// `input` may use '/' or '.' as separators, e.g. "Echo",
// "EchoTestService/Echo" or "/grpc.testing.EchoTestService/Echo"; a single
// leading separator is ignored. The input must equal a trailing run of WHOLE
// name components of `full_name`: "Echo" matches "pkg.Svc.Echo" but not
// "pkg.Svc.UnaryEcho". An empty input matches nothing.
bool MethodNameMatch(const std::string& full_name, const std::string& input) {
  std::string clean_input = input;
  std::replace(clean_input.begin(), clean_input.end(), '/', '.');
  // A leading separator (as in the "/package.Service/Method" path form) is not
  // part of any component name, so drop it before comparing.
  if (!clean_input.empty() && clean_input.front() == '.') {
    clean_input.erase(0, 1);
  }
  if (clean_input.empty() || clean_input.size() > full_name.size()) {
    return false;
  }
  const std::string::size_type start = full_name.size() - clean_input.size();
  if (full_name.compare(start, clean_input.size(), clean_input) != 0) {
    return false;
  }
  // Only accept matches that begin on a component boundary so a partial
  // identifier ("Echo") does not also match a longer one ("UnaryEcho").
  return start == 0 || full_name[start - 1] == '.';
}
45 } // namespace
46
47 class ErrorPrinter : public protobuf::compiler::MultiFileErrorCollector {
48 public:
ErrorPrinter(ProtoFileParser * parser)49 explicit ErrorPrinter(ProtoFileParser* parser) : parser_(parser) {}
50
AddError(const std::string & filename,int line,int column,const std::string & message)51 void AddError(const std::string& filename, int line, int column,
52 const std::string& message) override {
53 std::ostringstream oss;
54 oss << "error " << filename << " " << line << " " << column << " "
55 << message << "\n";
56 parser_->LogError(oss.str());
57 }
58
AddWarning(const std::string & filename,int line,int column,const std::string & message)59 void AddWarning(const std::string& filename, int line, int column,
60 const std::string& message) override {
61 std::cerr << "warning " << filename << " " << line << " " << column << " "
62 << message << std::endl;
63 }
64
65 private:
66 ProtoFileParser* parser_; // not owned
67 };
68
ProtoFileParser(const std::shared_ptr<grpc::Channel> & channel,const std::string & proto_path,const std::string & protofiles)69 ProtoFileParser::ProtoFileParser(const std::shared_ptr<grpc::Channel>& channel,
70 const std::string& proto_path,
71 const std::string& protofiles)
72 : has_error_(false),
73 dynamic_factory_(new protobuf::DynamicMessageFactory()) {
74 std::vector<std::string> service_list;
75 if (channel) {
76 reflection_db_ =
77 std::make_unique<grpc::ProtoReflectionDescriptorDatabase>(channel);
78 reflection_db_->GetServices(&service_list);
79 }
80
81 std::unordered_set<std::string> known_services;
82 if (!protofiles.empty()) {
83 for (const absl::string_view single_path : absl::StrSplit(
84 proto_path, GRPC_CLI_PATH_SEPARATOR, absl::AllowEmpty())) {
85 source_tree_.MapPath("", std::string(single_path));
86 }
87 error_printer_ = std::make_unique<ErrorPrinter>(this);
88 importer_ = std::make_unique<protobuf::compiler::Importer>(
89 &source_tree_, error_printer_.get());
90
91 std::string file_name;
92 std::stringstream ss(protofiles);
93 while (std::getline(ss, file_name, ',')) {
94 const auto* file_desc = importer_->Import(file_name);
95 if (file_desc) {
96 for (int i = 0; i < file_desc->service_count(); i++) {
97 service_desc_list_.push_back(file_desc->service(i));
98 known_services.insert(file_desc->service(i)->full_name());
99 }
100 } else {
101 std::cerr << file_name << " not found" << std::endl;
102 }
103 }
104
105 file_db_ =
106 std::make_unique<protobuf::DescriptorPoolDatabase>(*importer_->pool());
107 }
108
109 if (!reflection_db_ && !file_db_) {
110 LogError("No available proto database");
111 return;
112 }
113
114 if (!reflection_db_) {
115 desc_db_ = std::move(file_db_);
116 } else if (!file_db_) {
117 desc_db_ = std::move(reflection_db_);
118 } else {
119 desc_db_ = std::make_unique<protobuf::MergedDescriptorDatabase>(
120 reflection_db_.get(), file_db_.get());
121 }
122
123 desc_pool_ = std::make_unique<protobuf::DescriptorPool>(desc_db_.get());
124
125 for (auto it = service_list.begin(); it != service_list.end(); it++) {
126 if (known_services.find(*it) == known_services.end()) {
127 if (const protobuf::ServiceDescriptor* service_desc =
128 desc_pool_->FindServiceByName(*it)) {
129 service_desc_list_.push_back(service_desc);
130 known_services.insert(*it);
131 }
132 }
133 }
134 }
135
~ProtoFileParser()136 ProtoFileParser::~ProtoFileParser() {}
137
GetFullMethodName(const std::string & method)138 std::string ProtoFileParser::GetFullMethodName(const std::string& method) {
139 has_error_ = false;
140
141 if (known_methods_.find(method) != known_methods_.end()) {
142 return known_methods_[method];
143 }
144
145 const protobuf::MethodDescriptor* method_descriptor = nullptr;
146 for (auto it = service_desc_list_.begin(); it != service_desc_list_.end();
147 it++) {
148 const auto* service_desc = *it;
149 for (int j = 0; j < service_desc->method_count(); j++) {
150 const auto* method_desc = service_desc->method(j);
151 if (MethodNameMatch(method_desc->full_name(), method)) {
152 if (method_descriptor) {
153 std::ostringstream error_stream;
154 error_stream << "Ambiguous method names: ";
155 error_stream << method_descriptor->full_name() << " ";
156 error_stream << method_desc->full_name();
157 LogError(error_stream.str());
158 }
159 method_descriptor = method_desc;
160 }
161 }
162 }
163 if (!method_descriptor) {
164 LogError("Method name not found");
165 }
166 if (has_error_) {
167 return "";
168 }
169
170 known_methods_[method] = method_descriptor->full_name();
171
172 return method_descriptor->full_name();
173 }
174
GetFormattedMethodName(const std::string & method)175 std::string ProtoFileParser::GetFormattedMethodName(const std::string& method) {
176 has_error_ = false;
177 std::string formatted_method_name = GetFullMethodName(method);
178 if (has_error_) {
179 return "";
180 }
181 size_t last_dot = formatted_method_name.find_last_of('.');
182 if (last_dot != std::string::npos) {
183 formatted_method_name[last_dot] = '/';
184 }
185 formatted_method_name.insert(formatted_method_name.begin(), '/');
186 return formatted_method_name;
187 }
188
GetMessageTypeFromMethod(const std::string & method,bool is_request)189 std::string ProtoFileParser::GetMessageTypeFromMethod(const std::string& method,
190 bool is_request) {
191 has_error_ = false;
192 std::string full_method_name = GetFullMethodName(method);
193 if (has_error_) {
194 return "";
195 }
196 const protobuf::MethodDescriptor* method_desc =
197 desc_pool_->FindMethodByName(full_method_name);
198 if (!method_desc) {
199 LogError("Method not found");
200 return "";
201 }
202
203 return is_request ? method_desc->input_type()->full_name()
204 : method_desc->output_type()->full_name();
205 }
206
IsStreaming(const std::string & method,bool is_request)207 bool ProtoFileParser::IsStreaming(const std::string& method, bool is_request) {
208 has_error_ = false;
209
210 std::string full_method_name = GetFullMethodName(method);
211 if (has_error_) {
212 return false;
213 }
214
215 const protobuf::MethodDescriptor* method_desc =
216 desc_pool_->FindMethodByName(full_method_name);
217 if (!method_desc) {
218 LogError("Method not found");
219 return false;
220 }
221
222 return is_request ? method_desc->client_streaming()
223 : method_desc->server_streaming();
224 }
225
GetSerializedProtoFromMethod(const std::string & method,const std::string & formatted_proto,bool is_request,bool is_json_format)226 std::string ProtoFileParser::GetSerializedProtoFromMethod(
227 const std::string& method, const std::string& formatted_proto,
228 bool is_request, bool is_json_format) {
229 has_error_ = false;
230 std::string message_type_name = GetMessageTypeFromMethod(method, is_request);
231 if (has_error_) {
232 return "";
233 }
234 return GetSerializedProtoFromMessageType(message_type_name, formatted_proto,
235 is_json_format);
236 }
237
GetFormattedStringFromMethod(const std::string & method,const std::string & serialized_proto,bool is_request,bool is_json_format)238 std::string ProtoFileParser::GetFormattedStringFromMethod(
239 const std::string& method, const std::string& serialized_proto,
240 bool is_request, bool is_json_format) {
241 has_error_ = false;
242 std::string message_type_name = GetMessageTypeFromMethod(method, is_request);
243 if (has_error_) {
244 return "";
245 }
246 return GetFormattedStringFromMessageType(message_type_name, serialized_proto,
247 is_json_format);
248 }
249
GetSerializedProtoFromMessageType(const std::string & message_type_name,const std::string & formatted_proto,bool is_json_format)250 std::string ProtoFileParser::GetSerializedProtoFromMessageType(
251 const std::string& message_type_name, const std::string& formatted_proto,
252 bool is_json_format) {
253 has_error_ = false;
254 std::string serialized;
255 const protobuf::Descriptor* desc =
256 desc_pool_->FindMessageTypeByName(message_type_name);
257 if (!desc) {
258 LogError("Message type not found");
259 return "";
260 }
261 std::unique_ptr<grpc::protobuf::Message> msg(
262 dynamic_factory_->GetPrototype(desc)->New());
263
264 bool ok;
265 if (is_json_format) {
266 ok = grpc::protobuf::json::JsonStringToMessage(formatted_proto, msg.get())
267 .ok();
268 if (!ok) {
269 LogError("Failed to convert json format to proto.");
270 return "";
271 }
272 } else {
273 ok = protobuf::TextFormat::ParseFromString(formatted_proto, msg.get());
274 if (!ok) {
275 LogError("Failed to convert text format to proto.");
276 return "";
277 }
278 }
279
280 ok = msg->SerializeToString(&serialized);
281 if (!ok) {
282 LogError("Failed to serialize proto.");
283 return "";
284 }
285 return serialized;
286 }
287
GetFormattedStringFromMessageType(const std::string & message_type_name,const std::string & serialized_proto,bool is_json_format)288 std::string ProtoFileParser::GetFormattedStringFromMessageType(
289 const std::string& message_type_name, const std::string& serialized_proto,
290 bool is_json_format) {
291 has_error_ = false;
292 const protobuf::Descriptor* desc =
293 desc_pool_->FindMessageTypeByName(message_type_name);
294 if (!desc) {
295 LogError("Message type not found");
296 return "";
297 }
298 std::unique_ptr<grpc::protobuf::Message> msg(
299 dynamic_factory_->GetPrototype(desc)->New());
300 if (!msg->ParseFromString(serialized_proto)) {
301 LogError("Failed to deserialize proto.");
302 return "";
303 }
304 std::string formatted_string;
305
306 if (is_json_format) {
307 grpc::protobuf::json::JsonPrintOptions jsonPrintOptions;
308 jsonPrintOptions.add_whitespace = true;
309 if (!grpc::protobuf::json::MessageToJsonString(*msg, &formatted_string,
310 jsonPrintOptions)
311 .ok()) {
312 LogError("Failed to print proto message to json format");
313 return "";
314 }
315 } else {
316 if (!protobuf::TextFormat::PrintToString(*msg, &formatted_string)) {
317 LogError("Failed to print proto message to text format");
318 return "";
319 }
320 }
321 return formatted_string;
322 }
323
LogError(const std::string & error_msg)324 void ProtoFileParser::LogError(const std::string& error_msg) {
325 if (!error_msg.empty()) {
326 std::cerr << error_msg << std::endl;
327 }
328 has_error_ = true;
329 }
330
331 } // namespace testing
332 } // namespace grpc
333