1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc. All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 // * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 // * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 // * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31 // Author: [email protected] (Kenton Varda)
32 // Based on original Protocol Buffers design by
33 // Sanjay Ghemawat, Jeff Dean, and others.
34
35 #include <google/protobuf/descriptor_database.h>
36
37 #include <algorithm>
38 #include <set>
39
40 #include <google/protobuf/descriptor.pb.h>
41 #include <google/protobuf/stubs/map_util.h>
42 #include <google/protobuf/stubs/stl_util.h>
43
44
45 namespace google {
46 namespace protobuf {
47
48 namespace {
RecordMessageNames(const DescriptorProto & desc_proto,const std::string & prefix,std::set<std::string> * output)49 void RecordMessageNames(const DescriptorProto& desc_proto,
50 const std::string& prefix,
51 std::set<std::string>* output) {
52 GOOGLE_CHECK(desc_proto.has_name());
53 std::string full_name = prefix.empty()
54 ? desc_proto.name()
55 : StrCat(prefix, ".", desc_proto.name());
56 output->insert(full_name);
57
58 for (const auto& d : desc_proto.nested_type()) {
59 RecordMessageNames(d, full_name, output);
60 }
61 }
62
RecordMessageNames(const FileDescriptorProto & file_proto,std::set<std::string> * output)63 void RecordMessageNames(const FileDescriptorProto& file_proto,
64 std::set<std::string>* output) {
65 for (const auto& d : file_proto.message_type()) {
66 RecordMessageNames(d, file_proto.package(), output);
67 }
68 }
69
70 template <typename Fn>
ForAllFileProtos(DescriptorDatabase * db,Fn callback,std::vector<std::string> * output)71 bool ForAllFileProtos(DescriptorDatabase* db, Fn callback,
72 std::vector<std::string>* output) {
73 std::vector<std::string> file_names;
74 if (!db->FindAllFileNames(&file_names)) {
75 return false;
76 }
77 std::set<std::string> set;
78 FileDescriptorProto file_proto;
79 for (const auto& f : file_names) {
80 file_proto.Clear();
81 if (!db->FindFileByName(f, &file_proto)) {
82 GOOGLE_LOG(ERROR) << "File not found in database (unexpected): " << f;
83 return false;
84 }
85 callback(file_proto, &set);
86 }
87 output->insert(output->end(), set.begin(), set.end());
88 return true;
89 }
90 } // namespace
91
~DescriptorDatabase()92 DescriptorDatabase::~DescriptorDatabase() {}
93
FindAllPackageNames(std::vector<std::string> * output)94 bool DescriptorDatabase::FindAllPackageNames(std::vector<std::string>* output) {
95 return ForAllFileProtos(
96 this,
97 [](const FileDescriptorProto& file_proto, std::set<std::string>* set) {
98 set->insert(file_proto.package());
99 },
100 output);
101 }
102
FindAllMessageNames(std::vector<std::string> * output)103 bool DescriptorDatabase::FindAllMessageNames(std::vector<std::string>* output) {
104 return ForAllFileProtos(
105 this,
106 [](const FileDescriptorProto& file_proto, std::set<std::string>* set) {
107 RecordMessageNames(file_proto, set);
108 },
109 output);
110 }
111
112 // ===================================================================
113
SimpleDescriptorDatabase()114 SimpleDescriptorDatabase::SimpleDescriptorDatabase() {}
~SimpleDescriptorDatabase()115 SimpleDescriptorDatabase::~SimpleDescriptorDatabase() {}
116
117 template <typename Value>
AddFile(const FileDescriptorProto & file,Value value)118 bool SimpleDescriptorDatabase::DescriptorIndex<Value>::AddFile(
119 const FileDescriptorProto& file, Value value) {
120 if (!InsertIfNotPresent(&by_name_, file.name(), value)) {
121 GOOGLE_LOG(ERROR) << "File already exists in database: " << file.name();
122 return false;
123 }
124
125 // We must be careful here -- calling file.package() if file.has_package() is
126 // false could access an uninitialized static-storage variable if we are being
127 // run at startup time.
128 std::string path = file.has_package() ? file.package() : std::string();
129 if (!path.empty()) path += '.';
130
131 for (int i = 0; i < file.message_type_size(); i++) {
132 if (!AddSymbol(path + file.message_type(i).name(), value)) return false;
133 if (!AddNestedExtensions(file.name(), file.message_type(i), value))
134 return false;
135 }
136 for (int i = 0; i < file.enum_type_size(); i++) {
137 if (!AddSymbol(path + file.enum_type(i).name(), value)) return false;
138 }
139 for (int i = 0; i < file.extension_size(); i++) {
140 if (!AddSymbol(path + file.extension(i).name(), value)) return false;
141 if (!AddExtension(file.name(), file.extension(i), value)) return false;
142 }
143 for (int i = 0; i < file.service_size(); i++) {
144 if (!AddSymbol(path + file.service(i).name(), value)) return false;
145 }
146
147 return true;
148 }
149
150 namespace {
151
152 // Returns true if and only if all characters in the name are alphanumerics,
153 // underscores, or periods.
ValidateSymbolName(StringPiece name)154 bool ValidateSymbolName(StringPiece name) {
155 for (char c : name) {
156 // I don't trust ctype.h due to locales. :(
157 if (c != '.' && c != '_' && (c < '0' || c > '9') && (c < 'A' || c > 'Z') &&
158 (c < 'a' || c > 'z')) {
159 return false;
160 }
161 }
162 return true;
163 }
164
// Returns an iterator to the last element of |container| that sorts less than
// or equal to |key|: the element just before the container's upper_bound().
// Callers must handle two edge cases themselves:
//  - if |container| is empty, the result is end();
//  - if every element sorts after |key|, the result is begin(), which then
//    refers to an element *greater* than |key|.
template <typename Container, typename Key>
typename Container::const_iterator FindLastLessOrEqual(
    const Container* container, const Key& key) {
  typename Container::const_iterator first_greater =
      container->upper_bound(key);
  if (first_greater == container->begin()) {
    return first_greater;
  }
  return --first_greater;
}
175
// Returns an iterator to the last element of the sorted sequence |container|
// that does not sort after |key| under |cmp|, using std::upper_bound. If
// |container| is empty the result is end(); if every element sorts after
// |key| the result is begin() (an element greater than |key|).
template <typename Container, typename Key, typename Cmp>
typename Container::const_iterator FindLastLessOrEqual(
    const Container* container, const Key& key, const Cmp& cmp) {
  typename Container::const_iterator first_greater =
      std::upper_bound(container->begin(), container->end(), key, cmp);
  if (first_greater == container->begin()) {
    return first_greater;
  }
  return --first_greater;
}
184
185 // True if either the arguments are equal or super_symbol identifies a
186 // parent symbol of sub_symbol (e.g. "foo.bar" is a parent of
187 // "foo.bar.baz", but not a parent of "foo.barbaz").
IsSubSymbol(StringPiece sub_symbol,StringPiece super_symbol)188 bool IsSubSymbol(StringPiece sub_symbol, StringPiece super_symbol) {
189 return sub_symbol == super_symbol ||
190 (HasPrefixString(super_symbol, sub_symbol) &&
191 super_symbol[sub_symbol.size()] == '.');
192 }
193
194 } // namespace
195
196 template <typename Value>
AddSymbol(const std::string & name,Value value)197 bool SimpleDescriptorDatabase::DescriptorIndex<Value>::AddSymbol(
198 const std::string& name, Value value) {
199 // We need to make sure not to violate our map invariant.
200
201 // If the symbol name is invalid it could break our lookup algorithm (which
202 // relies on the fact that '.' sorts before all other characters that are
203 // valid in symbol names).
204 if (!ValidateSymbolName(name)) {
205 GOOGLE_LOG(ERROR) << "Invalid symbol name: " << name;
206 return false;
207 }
208
209 // Try to look up the symbol to make sure a super-symbol doesn't already
210 // exist.
211 auto iter = FindLastLessOrEqual(&by_symbol_, name);
212
213 if (iter == by_symbol_.end()) {
214 // Apparently the map is currently empty. Just insert and be done with it.
215 by_symbol_.insert(
216 typename std::map<std::string, Value>::value_type(name, value));
217 return true;
218 }
219
220 if (IsSubSymbol(iter->first, name)) {
221 GOOGLE_LOG(ERROR) << "Symbol name \"" << name
222 << "\" conflicts with the existing "
223 "symbol \""
224 << iter->first << "\".";
225 return false;
226 }
227
228 // OK, that worked. Now we have to make sure that no symbol in the map is
229 // a sub-symbol of the one we are inserting. The only symbol which could
230 // be so is the first symbol that is greater than the new symbol. Since
231 // |iter| points at the last symbol that is less than or equal, we just have
232 // to increment it.
233 ++iter;
234
235 if (iter != by_symbol_.end() && IsSubSymbol(name, iter->first)) {
236 GOOGLE_LOG(ERROR) << "Symbol name \"" << name
237 << "\" conflicts with the existing "
238 "symbol \""
239 << iter->first << "\".";
240 return false;
241 }
242
243 // OK, no conflicts.
244
245 // Insert the new symbol using the iterator as a hint, the new entry will
246 // appear immediately before the one the iterator is pointing at.
247 by_symbol_.insert(
248 iter, typename std::map<std::string, Value>::value_type(name, value));
249
250 return true;
251 }
252
253 template <typename Value>
AddNestedExtensions(const std::string & filename,const DescriptorProto & message_type,Value value)254 bool SimpleDescriptorDatabase::DescriptorIndex<Value>::AddNestedExtensions(
255 const std::string& filename, const DescriptorProto& message_type,
256 Value value) {
257 for (int i = 0; i < message_type.nested_type_size(); i++) {
258 if (!AddNestedExtensions(filename, message_type.nested_type(i), value))
259 return false;
260 }
261 for (int i = 0; i < message_type.extension_size(); i++) {
262 if (!AddExtension(filename, message_type.extension(i), value)) return false;
263 }
264 return true;
265 }
266
267 template <typename Value>
AddExtension(const std::string & filename,const FieldDescriptorProto & field,Value value)268 bool SimpleDescriptorDatabase::DescriptorIndex<Value>::AddExtension(
269 const std::string& filename, const FieldDescriptorProto& field,
270 Value value) {
271 if (!field.extendee().empty() && field.extendee()[0] == '.') {
272 // The extension is fully-qualified. We can use it as a lookup key in
273 // the by_symbol_ table.
274 if (!InsertIfNotPresent(
275 &by_extension_,
276 std::make_pair(field.extendee().substr(1), field.number()),
277 value)) {
278 GOOGLE_LOG(ERROR) << "Extension conflicts with extension already in database: "
279 "extend "
280 << field.extendee() << " { " << field.name() << " = "
281 << field.number() << " } from:" << filename;
282 return false;
283 }
284 } else {
285 // Not fully-qualified. We can't really do anything here, unfortunately.
286 // We don't consider this an error, though, because the descriptor is
287 // valid.
288 }
289 return true;
290 }
291
292 template <typename Value>
FindFile(const std::string & filename)293 Value SimpleDescriptorDatabase::DescriptorIndex<Value>::FindFile(
294 const std::string& filename) {
295 return FindWithDefault(by_name_, filename, Value());
296 }
297
298 template <typename Value>
FindSymbol(const std::string & name)299 Value SimpleDescriptorDatabase::DescriptorIndex<Value>::FindSymbol(
300 const std::string& name) {
301 auto iter = FindLastLessOrEqual(&by_symbol_, name);
302
303 return (iter != by_symbol_.end() && IsSubSymbol(iter->first, name))
304 ? iter->second
305 : Value();
306 }
307
308 template <typename Value>
FindExtension(const std::string & containing_type,int field_number)309 Value SimpleDescriptorDatabase::DescriptorIndex<Value>::FindExtension(
310 const std::string& containing_type, int field_number) {
311 return FindWithDefault(
312 by_extension_, std::make_pair(containing_type, field_number), Value());
313 }
314
315 template <typename Value>
FindAllExtensionNumbers(const std::string & containing_type,std::vector<int> * output)316 bool SimpleDescriptorDatabase::DescriptorIndex<Value>::FindAllExtensionNumbers(
317 const std::string& containing_type, std::vector<int>* output) {
318 typename std::map<std::pair<std::string, int>, Value>::const_iterator it =
319 by_extension_.lower_bound(std::make_pair(containing_type, 0));
320 bool success = false;
321
322 for (; it != by_extension_.end() && it->first.first == containing_type;
323 ++it) {
324 output->push_back(it->first.second);
325 success = true;
326 }
327
328 return success;
329 }
330
331 template <typename Value>
FindAllFileNames(std::vector<std::string> * output)332 void SimpleDescriptorDatabase::DescriptorIndex<Value>::FindAllFileNames(
333 std::vector<std::string>* output) {
334 output->resize(by_name_.size());
335 int i = 0;
336 for (const auto& kv : by_name_) {
337 (*output)[i] = kv.first;
338 i++;
339 }
340 }
341
342 // -------------------------------------------------------------------
343
Add(const FileDescriptorProto & file)344 bool SimpleDescriptorDatabase::Add(const FileDescriptorProto& file) {
345 FileDescriptorProto* new_file = new FileDescriptorProto;
346 new_file->CopyFrom(file);
347 return AddAndOwn(new_file);
348 }
349
AddAndOwn(const FileDescriptorProto * file)350 bool SimpleDescriptorDatabase::AddAndOwn(const FileDescriptorProto* file) {
351 files_to_delete_.emplace_back(file);
352 return index_.AddFile(*file, file);
353 }
354
FindFileByName(const std::string & filename,FileDescriptorProto * output)355 bool SimpleDescriptorDatabase::FindFileByName(const std::string& filename,
356 FileDescriptorProto* output) {
357 return MaybeCopy(index_.FindFile(filename), output);
358 }
359
FindFileContainingSymbol(const std::string & symbol_name,FileDescriptorProto * output)360 bool SimpleDescriptorDatabase::FindFileContainingSymbol(
361 const std::string& symbol_name, FileDescriptorProto* output) {
362 return MaybeCopy(index_.FindSymbol(symbol_name), output);
363 }
364
FindFileContainingExtension(const std::string & containing_type,int field_number,FileDescriptorProto * output)365 bool SimpleDescriptorDatabase::FindFileContainingExtension(
366 const std::string& containing_type, int field_number,
367 FileDescriptorProto* output) {
368 return MaybeCopy(index_.FindExtension(containing_type, field_number), output);
369 }
370
FindAllExtensionNumbers(const std::string & extendee_type,std::vector<int> * output)371 bool SimpleDescriptorDatabase::FindAllExtensionNumbers(
372 const std::string& extendee_type, std::vector<int>* output) {
373 return index_.FindAllExtensionNumbers(extendee_type, output);
374 }
375
376
FindAllFileNames(std::vector<std::string> * output)377 bool SimpleDescriptorDatabase::FindAllFileNames(
378 std::vector<std::string>* output) {
379 index_.FindAllFileNames(output);
380 return true;
381 }
382
MaybeCopy(const FileDescriptorProto * file,FileDescriptorProto * output)383 bool SimpleDescriptorDatabase::MaybeCopy(const FileDescriptorProto* file,
384 FileDescriptorProto* output) {
385 if (file == nullptr) return false;
386 output->CopyFrom(*file);
387 return true;
388 }
389
390 // -------------------------------------------------------------------
391
// Index used by EncodedDescriptorDatabase. It maps file names, symbols and
// (extendee, number) pairs to the serialized file (data pointer + size) that
// defines them. Every entry refers to its file through an offset into
// all_values_ rather than storing the pointer itself.
//
// NOTE(review): the comparator members of the sets below hold a reference to
// *this. Copying or moving a DescriptorIndex would leave the copied sets
// comparing against the source object, so instances are presumably never
// copied. Confirm before adding copy/move support.
class EncodedDescriptorDatabase::DescriptorIndex {
 public:
  // (pointer to serialized FileDescriptorProto, size in bytes). The lookups
  // below return {nullptr, 0} when nothing matches.
  using Value = std::pair<const void*, int>;
  // Helpers to recursively add particular descriptors and all their contents
  // to the index.
  template <typename FileProto>
  bool AddFile(const FileProto& file, Value value);

  Value FindFile(StringPiece filename);
  Value FindSymbol(StringPiece name);
  Value FindSymbolOnlyFlat(StringPiece name) const;
  Value FindExtension(StringPiece containing_type, int field_number);
  bool FindAllExtensionNumbers(StringPiece containing_type,
                               std::vector<int>* output);
  void FindAllFileNames(std::vector<std::string>* output) const;

 private:
  friend class EncodedDescriptorDatabase;

  // Adds |symbol| (relative to the package of the most recently added file)
  // after checking it does not collide with an existing symbol.
  bool AddSymbol(StringPiece symbol);

  template <typename DescProto>
  bool AddNestedExtensions(StringPiece filename,
                           const DescProto& message_type);
  template <typename FieldProto>
  bool AddExtension(StringPiece filename, const FieldProto& field);

  // All the maps below have two representations:
  //  - a std::set<> where we insert initially.
  //  - a std::vector<> where we flatten the structure on demand.
  // The initial tree helps avoid O(N) behavior of inserting into a sorted
  // vector, while the vector reduces the heap requirements of the data
  // structure.

  // Merges each pending std::set<> into its sorted std::vector<> counterpart.
  void EnsureFlat();

  using String = std::string;

  // Identity encoding in this version: strings are stored as-is and the int
  // argument to DecodeString is unused.
  String EncodeString(StringPiece str) const { return String(str); }
  StringPiece DecodeString(const String& str, int) const { return str; }

  // One entry per file passed to AddFile(), in insertion order.
  struct EncodedEntry {
    // Do not use `Value` here to avoid the padding of that object.
    const void* data;
    int size;
    // Keep the package here instead of each SymbolEntry to save space.
    String encoded_package;

    Value value() const { return {data, size}; }
  };
  std::vector<EncodedEntry> all_values_;

  // File-name index entry; data_offset indexes all_values_.
  struct FileEntry {
    int data_offset;
    String encoded_name;

    StringPiece name(const DescriptorIndex& index) const {
      return index.DecodeString(encoded_name, data_offset);
    }
  };
  // Orders FileEntry values by name; the mixed overloads allow heterogeneous
  // lookup with a plain StringPiece.
  struct FileCompare {
    const DescriptorIndex& index;

    bool operator()(const FileEntry& a, const FileEntry& b) const {
      return a.name(index) < b.name(index);
    }
    bool operator()(const FileEntry& a, StringPiece b) const {
      return a.name(index) < b;
    }
    bool operator()(StringPiece a, const FileEntry& b) const {
      return a < b.name(index);
    }
  };
  std::set<FileEntry, FileCompare> by_name_{FileCompare{*this}};
  std::vector<FileEntry> by_name_flat_;

  // Symbol index entry. encoded_symbol is relative to the package stored in
  // all_values_[data_offset]; AsString() rebuilds the full dotted name.
  struct SymbolEntry {
    int data_offset;
    String encoded_symbol;

    StringPiece package(const DescriptorIndex& index) const {
      return index.DecodeString(index.all_values_[data_offset].encoded_package,
                                data_offset);
    }
    StringPiece symbol(const DescriptorIndex& index) const {
      return index.DecodeString(encoded_symbol, data_offset);
    }

    std::string AsString(const DescriptorIndex& index) const {
      auto p = package(index);
      return StrCat(p, p.empty() ? "" : ".", symbol(index));
    }
  };

  // Orders SymbolEntry values (and plain full names) as if comparing their
  // fully-qualified strings, while avoiding building those strings in the
  // common cases.
  struct SymbolCompare {
    const DescriptorIndex& index;

    std::string AsString(const SymbolEntry& entry) const {
      return entry.AsString(index);
    }
    static StringPiece AsString(StringPiece str) { return str; }

    // Splits a value into (package, symbol). A plain string, or an entry with
    // no package, yields (whole name, empty).
    std::pair<StringPiece, StringPiece> GetParts(
        const SymbolEntry& entry) const {
      auto package = entry.package(index);
      if (package.empty()) return {entry.symbol(index), StringPiece{}};
      return {package, entry.symbol(index)};
    }
    std::pair<StringPiece, StringPiece> GetParts(
        StringPiece str) const {
      return {str, {}};
    }

    template <typename T, typename U>
    bool operator()(const T& lhs, const U& rhs) const {
      auto lhs_parts = GetParts(lhs);
      auto rhs_parts = GetParts(rhs);

      // Fast path to avoid making the whole string for common cases.
      // Compare only the overlapping prefix of the two first parts: any
      // difference there decides the full-name order as well.
      if (int res =
              lhs_parts.first.substr(0, rhs_parts.first.size())
                  .compare(rhs_parts.first.substr(0, lhs_parts.first.size()))) {
        // If the packages already differ, exit early.
        return res < 0;
      } else if (lhs_parts.first.size() == rhs_parts.first.size()) {
        // Identical first parts: the full names differ only after them.
        return lhs_parts.second < rhs_parts.second;
      }
      // One first part is a proper prefix of the other; fall back to
      // comparing the fully-qualified strings.
      return AsString(lhs) < AsString(rhs);
    }
  };
  std::set<SymbolEntry, SymbolCompare> by_symbol_{SymbolCompare{*this}};
  std::vector<SymbolEntry> by_symbol_flat_;

  // Extension index entry keyed by (extendee, extension_number).
  struct ExtensionEntry {
    int data_offset;
    // Stored with its leading '.'; extendee() strips it.
    String encoded_extendee;
    StringPiece extendee(const DescriptorIndex& index) const {
      return index.DecodeString(encoded_extendee, data_offset).substr(1);
    }
    int extension_number;
  };
  // Orders ExtensionEntry values by (extendee, number); the tuple overloads
  // allow lookup without constructing an entry.
  struct ExtensionCompare {
    const DescriptorIndex& index;

    bool operator()(const ExtensionEntry& a, const ExtensionEntry& b) const {
      return std::make_tuple(a.extendee(index), a.extension_number) <
             std::make_tuple(b.extendee(index), b.extension_number);
    }
    bool operator()(const ExtensionEntry& a,
                    std::tuple<StringPiece, int> b) const {
      return std::make_tuple(a.extendee(index), a.extension_number) < b;
    }
    bool operator()(std::tuple<StringPiece, int> a,
                    const ExtensionEntry& b) const {
      return a < std::make_tuple(b.extendee(index), b.extension_number);
    }
  };
  std::set<ExtensionEntry, ExtensionCompare> by_extension_{
      ExtensionCompare{*this}};
  std::vector<ExtensionEntry> by_extension_flat_;
};
553
Add(const void * encoded_file_descriptor,int size)554 bool EncodedDescriptorDatabase::Add(const void* encoded_file_descriptor,
555 int size) {
556 FileDescriptorProto file;
557 if (file.ParseFromArray(encoded_file_descriptor, size)) {
558 return index_->AddFile(file, std::make_pair(encoded_file_descriptor, size));
559 } else {
560 GOOGLE_LOG(ERROR) << "Invalid file descriptor data passed to "
561 "EncodedDescriptorDatabase::Add().";
562 return false;
563 }
564 }
565
AddCopy(const void * encoded_file_descriptor,int size)566 bool EncodedDescriptorDatabase::AddCopy(const void* encoded_file_descriptor,
567 int size) {
568 void* copy = operator new(size);
569 memcpy(copy, encoded_file_descriptor, size);
570 files_to_delete_.push_back(copy);
571 return Add(copy, size);
572 }
573
FindFileByName(const std::string & filename,FileDescriptorProto * output)574 bool EncodedDescriptorDatabase::FindFileByName(const std::string& filename,
575 FileDescriptorProto* output) {
576 return MaybeParse(index_->FindFile(filename), output);
577 }
578
FindFileContainingSymbol(const std::string & symbol_name,FileDescriptorProto * output)579 bool EncodedDescriptorDatabase::FindFileContainingSymbol(
580 const std::string& symbol_name, FileDescriptorProto* output) {
581 return MaybeParse(index_->FindSymbol(symbol_name), output);
582 }
583
FindNameOfFileContainingSymbol(const std::string & symbol_name,std::string * output)584 bool EncodedDescriptorDatabase::FindNameOfFileContainingSymbol(
585 const std::string& symbol_name, std::string* output) {
586 auto encoded_file = index_->FindSymbol(symbol_name);
587 if (encoded_file.first == nullptr) return false;
588
589 // Optimization: The name should be the first field in the encoded message.
590 // Try to just read it directly.
591 io::CodedInputStream input(static_cast<const uint8_t*>(encoded_file.first),
592 encoded_file.second);
593
594 const uint32_t kNameTag = internal::WireFormatLite::MakeTag(
595 FileDescriptorProto::kNameFieldNumber,
596 internal::WireFormatLite::WIRETYPE_LENGTH_DELIMITED);
597
598 if (input.ReadTagNoLastTag() == kNameTag) {
599 // Success!
600 return internal::WireFormatLite::ReadString(&input, output);
601 } else {
602 // Slow path. Parse whole message.
603 FileDescriptorProto file_proto;
604 if (!file_proto.ParseFromArray(encoded_file.first, encoded_file.second)) {
605 return false;
606 }
607 *output = file_proto.name();
608 return true;
609 }
610 }
611
FindFileContainingExtension(const std::string & containing_type,int field_number,FileDescriptorProto * output)612 bool EncodedDescriptorDatabase::FindFileContainingExtension(
613 const std::string& containing_type, int field_number,
614 FileDescriptorProto* output) {
615 return MaybeParse(index_->FindExtension(containing_type, field_number),
616 output);
617 }
618
FindAllExtensionNumbers(const std::string & extendee_type,std::vector<int> * output)619 bool EncodedDescriptorDatabase::FindAllExtensionNumbers(
620 const std::string& extendee_type, std::vector<int>* output) {
621 return index_->FindAllExtensionNumbers(extendee_type, output);
622 }
623
624 template <typename FileProto>
AddFile(const FileProto & file,Value value)625 bool EncodedDescriptorDatabase::DescriptorIndex::AddFile(const FileProto& file,
626 Value value) {
627 // We push `value` into the array first. This is important because the AddXXX
628 // functions below will expect it to be there.
629 all_values_.push_back({value.first, value.second, {}});
630
631 if (!ValidateSymbolName(file.package())) {
632 GOOGLE_LOG(ERROR) << "Invalid package name: " << file.package();
633 return false;
634 }
635 all_values_.back().encoded_package = EncodeString(file.package());
636
637 if (!InsertIfNotPresent(
638 &by_name_, FileEntry{static_cast<int>(all_values_.size() - 1),
639 EncodeString(file.name())}) ||
640 std::binary_search(by_name_flat_.begin(), by_name_flat_.end(),
641 file.name(), by_name_.key_comp())) {
642 GOOGLE_LOG(ERROR) << "File already exists in database: " << file.name();
643 return false;
644 }
645
646 for (const auto& message_type : file.message_type()) {
647 if (!AddSymbol(message_type.name())) return false;
648 if (!AddNestedExtensions(file.name(), message_type)) return false;
649 }
650 for (const auto& enum_type : file.enum_type()) {
651 if (!AddSymbol(enum_type.name())) return false;
652 }
653 for (const auto& extension : file.extension()) {
654 if (!AddSymbol(extension.name())) return false;
655 if (!AddExtension(file.name(), extension)) return false;
656 }
657 for (const auto& service : file.service()) {
658 if (!AddSymbol(service.name())) return false;
659 }
660
661 return true;
662 }
663
664 template <typename Iter, typename Iter2, typename Index>
CheckForMutualSubsymbols(StringPiece symbol_name,Iter * iter,Iter2 end,const Index & index)665 static bool CheckForMutualSubsymbols(StringPiece symbol_name, Iter* iter,
666 Iter2 end, const Index& index) {
667 if (*iter != end) {
668 if (IsSubSymbol((*iter)->AsString(index), symbol_name)) {
669 GOOGLE_LOG(ERROR) << "Symbol name \"" << symbol_name
670 << "\" conflicts with the existing symbol \""
671 << (*iter)->AsString(index) << "\".";
672 return false;
673 }
674
675 // OK, that worked. Now we have to make sure that no symbol in the map is
676 // a sub-symbol of the one we are inserting. The only symbol which could
677 // be so is the first symbol that is greater than the new symbol. Since
678 // |iter| points at the last symbol that is less than or equal, we just have
679 // to increment it.
680 ++*iter;
681
682 if (*iter != end && IsSubSymbol(symbol_name, (*iter)->AsString(index))) {
683 GOOGLE_LOG(ERROR) << "Symbol name \"" << symbol_name
684 << "\" conflicts with the existing symbol \""
685 << (*iter)->AsString(index) << "\".";
686 return false;
687 }
688 }
689 return true;
690 }
691
AddSymbol(StringPiece symbol)692 bool EncodedDescriptorDatabase::DescriptorIndex::AddSymbol(
693 StringPiece symbol) {
694 SymbolEntry entry = {static_cast<int>(all_values_.size() - 1),
695 EncodeString(symbol)};
696 std::string entry_as_string = entry.AsString(*this);
697
698 // We need to make sure not to violate our map invariant.
699
700 // If the symbol name is invalid it could break our lookup algorithm (which
701 // relies on the fact that '.' sorts before all other characters that are
702 // valid in symbol names).
703 if (!ValidateSymbolName(symbol)) {
704 GOOGLE_LOG(ERROR) << "Invalid symbol name: " << entry_as_string;
705 return false;
706 }
707
708 auto iter = FindLastLessOrEqual(&by_symbol_, entry);
709 if (!CheckForMutualSubsymbols(entry_as_string, &iter, by_symbol_.end(),
710 *this)) {
711 return false;
712 }
713
714 // Same, but on by_symbol_flat_
715 auto flat_iter =
716 FindLastLessOrEqual(&by_symbol_flat_, entry, by_symbol_.key_comp());
717 if (!CheckForMutualSubsymbols(entry_as_string, &flat_iter,
718 by_symbol_flat_.end(), *this)) {
719 return false;
720 }
721
722 // OK, no conflicts.
723
724 // Insert the new symbol using the iterator as a hint, the new entry will
725 // appear immediately before the one the iterator is pointing at.
726 by_symbol_.insert(iter, entry);
727
728 return true;
729 }
730
731 template <typename DescProto>
AddNestedExtensions(StringPiece filename,const DescProto & message_type)732 bool EncodedDescriptorDatabase::DescriptorIndex::AddNestedExtensions(
733 StringPiece filename, const DescProto& message_type) {
734 for (const auto& nested_type : message_type.nested_type()) {
735 if (!AddNestedExtensions(filename, nested_type)) return false;
736 }
737 for (const auto& extension : message_type.extension()) {
738 if (!AddExtension(filename, extension)) return false;
739 }
740 return true;
741 }
742
743 template <typename FieldProto>
AddExtension(StringPiece filename,const FieldProto & field)744 bool EncodedDescriptorDatabase::DescriptorIndex::AddExtension(
745 StringPiece filename, const FieldProto& field) {
746 if (!field.extendee().empty() && field.extendee()[0] == '.') {
747 // The extension is fully-qualified. We can use it as a lookup key in
748 // the by_symbol_ table.
749 if (!InsertIfNotPresent(
750 &by_extension_,
751 ExtensionEntry{static_cast<int>(all_values_.size() - 1),
752 EncodeString(field.extendee()), field.number()}) ||
753 std::binary_search(
754 by_extension_flat_.begin(), by_extension_flat_.end(),
755 std::make_pair(field.extendee().substr(1), field.number()),
756 by_extension_.key_comp())) {
757 GOOGLE_LOG(ERROR) << "Extension conflicts with extension already in database: "
758 "extend "
759 << field.extendee() << " { " << field.name() << " = "
760 << field.number() << " } from:" << filename;
761 return false;
762 }
763 } else {
764 // Not fully-qualified. We can't really do anything here, unfortunately.
765 // We don't consider this an error, though, because the descriptor is
766 // valid.
767 }
768 return true;
769 }
770
771 std::pair<const void*, int>
FindSymbol(StringPiece name)772 EncodedDescriptorDatabase::DescriptorIndex::FindSymbol(StringPiece name) {
773 EnsureFlat();
774 return FindSymbolOnlyFlat(name);
775 }
776
777 std::pair<const void*, int>
FindSymbolOnlyFlat(StringPiece name) const778 EncodedDescriptorDatabase::DescriptorIndex::FindSymbolOnlyFlat(
779 StringPiece name) const {
780 auto iter =
781 FindLastLessOrEqual(&by_symbol_flat_, name, by_symbol_.key_comp());
782
783 return iter != by_symbol_flat_.end() &&
784 IsSubSymbol(iter->AsString(*this), name)
785 ? all_values_[iter->data_offset].value()
786 : Value();
787 }
788
789 std::pair<const void*, int>
FindExtension(StringPiece containing_type,int field_number)790 EncodedDescriptorDatabase::DescriptorIndex::FindExtension(
791 StringPiece containing_type, int field_number) {
792 EnsureFlat();
793
794 auto it = std::lower_bound(
795 by_extension_flat_.begin(), by_extension_flat_.end(),
796 std::make_tuple(containing_type, field_number), by_extension_.key_comp());
797 return it == by_extension_flat_.end() ||
798 it->extendee(*this) != containing_type ||
799 it->extension_number != field_number
800 ? std::make_pair(nullptr, 0)
801 : all_values_[it->data_offset].value();
802 }
803
// Merges every element of the sorted set `*s` into the sorted vector `*flat`,
// so that `*flat` stays ordered by the set's comparator, and then empties
// `*s`. std::merge is stable, so on ties the set's element precedes the
// vector's. Does nothing when `*s` is already empty.
template <typename T, typename Less>
static void MergeIntoFlat(std::set<T, Less>* s, std::vector<T>* flat) {
  if (!s->empty()) {
    std::vector<T> merged(s->size() + flat->size());
    std::merge(s->begin(), s->end(), flat->begin(), flat->end(),
               merged.begin(), s->key_comp());
    flat->swap(merged);
    s->clear();
  }
}
813
EnsureFlat()814 void EncodedDescriptorDatabase::DescriptorIndex::EnsureFlat() {
815 all_values_.shrink_to_fit();
816 // Merge each of the sets into their flat counterpart.
817 MergeIntoFlat(&by_name_, &by_name_flat_);
818 MergeIntoFlat(&by_symbol_, &by_symbol_flat_);
819 MergeIntoFlat(&by_extension_, &by_extension_flat_);
820 }
821
FindAllExtensionNumbers(StringPiece containing_type,std::vector<int> * output)822 bool EncodedDescriptorDatabase::DescriptorIndex::FindAllExtensionNumbers(
823 StringPiece containing_type, std::vector<int>* output) {
824 EnsureFlat();
825
826 bool success = false;
827 auto it = std::lower_bound(
828 by_extension_flat_.begin(), by_extension_flat_.end(),
829 std::make_tuple(containing_type, 0), by_extension_.key_comp());
830 for (;
831 it != by_extension_flat_.end() && it->extendee(*this) == containing_type;
832 ++it) {
833 output->push_back(it->extension_number);
834 success = true;
835 }
836
837 return success;
838 }
839
FindAllFileNames(std::vector<std::string> * output) const840 void EncodedDescriptorDatabase::DescriptorIndex::FindAllFileNames(
841 std::vector<std::string>* output) const {
842 output->resize(by_name_.size() + by_name_flat_.size());
843 int i = 0;
844 for (const auto& entry : by_name_) {
845 (*output)[i] = std::string(entry.name(*this));
846 i++;
847 }
848 for (const auto& entry : by_name_flat_) {
849 (*output)[i] = std::string(entry.name(*this));
850 i++;
851 }
852 }
853
854 std::pair<const void*, int>
FindFile(StringPiece filename)855 EncodedDescriptorDatabase::DescriptorIndex::FindFile(
856 StringPiece filename) {
857 EnsureFlat();
858
859 auto it = std::lower_bound(by_name_flat_.begin(), by_name_flat_.end(),
860 filename, by_name_.key_comp());
861 return it == by_name_flat_.end() || it->name(*this) != filename
862 ? std::make_pair(nullptr, 0)
863 : all_values_[it->data_offset].value();
864 }
865
866
FindAllFileNames(std::vector<std::string> * output)867 bool EncodedDescriptorDatabase::FindAllFileNames(
868 std::vector<std::string>* output) {
869 index_->FindAllFileNames(output);
870 return true;
871 }
872
MaybeParse(std::pair<const void *,int> encoded_file,FileDescriptorProto * output)873 bool EncodedDescriptorDatabase::MaybeParse(
874 std::pair<const void*, int> encoded_file, FileDescriptorProto* output) {
875 if (encoded_file.first == nullptr) return false;
876 return output->ParseFromArray(encoded_file.first, encoded_file.second);
877 }
878
EncodedDescriptorDatabase()879 EncodedDescriptorDatabase::EncodedDescriptorDatabase()
880 : index_(new DescriptorIndex()) {}
881
~EncodedDescriptorDatabase()882 EncodedDescriptorDatabase::~EncodedDescriptorDatabase() {
883 for (void* p : files_to_delete_) {
884 operator delete(p);
885 }
886 }
887
888 // ===================================================================
889
DescriptorPoolDatabase(const DescriptorPool & pool)890 DescriptorPoolDatabase::DescriptorPoolDatabase(const DescriptorPool& pool)
891 : pool_(pool) {}
~DescriptorPoolDatabase()892 DescriptorPoolDatabase::~DescriptorPoolDatabase() {}
893
FindFileByName(const std::string & filename,FileDescriptorProto * output)894 bool DescriptorPoolDatabase::FindFileByName(const std::string& filename,
895 FileDescriptorProto* output) {
896 const FileDescriptor* file = pool_.FindFileByName(filename);
897 if (file == nullptr) return false;
898 output->Clear();
899 file->CopyTo(output);
900 return true;
901 }
902
FindFileContainingSymbol(const std::string & symbol_name,FileDescriptorProto * output)903 bool DescriptorPoolDatabase::FindFileContainingSymbol(
904 const std::string& symbol_name, FileDescriptorProto* output) {
905 const FileDescriptor* file = pool_.FindFileContainingSymbol(symbol_name);
906 if (file == nullptr) return false;
907 output->Clear();
908 file->CopyTo(output);
909 return true;
910 }
911
FindFileContainingExtension(const std::string & containing_type,int field_number,FileDescriptorProto * output)912 bool DescriptorPoolDatabase::FindFileContainingExtension(
913 const std::string& containing_type, int field_number,
914 FileDescriptorProto* output) {
915 const Descriptor* extendee = pool_.FindMessageTypeByName(containing_type);
916 if (extendee == nullptr) return false;
917
918 const FieldDescriptor* extension =
919 pool_.FindExtensionByNumber(extendee, field_number);
920 if (extension == nullptr) return false;
921
922 output->Clear();
923 extension->file()->CopyTo(output);
924 return true;
925 }
926
FindAllExtensionNumbers(const std::string & extendee_type,std::vector<int> * output)927 bool DescriptorPoolDatabase::FindAllExtensionNumbers(
928 const std::string& extendee_type, std::vector<int>* output) {
929 const Descriptor* extendee = pool_.FindMessageTypeByName(extendee_type);
930 if (extendee == nullptr) return false;
931
932 std::vector<const FieldDescriptor*> extensions;
933 pool_.FindAllExtensions(extendee, &extensions);
934
935 for (const FieldDescriptor* extension : extensions) {
936 output->push_back(extension->number());
937 }
938
939 return true;
940 }
941
942 // ===================================================================
943
MergedDescriptorDatabase(DescriptorDatabase * source1,DescriptorDatabase * source2)944 MergedDescriptorDatabase::MergedDescriptorDatabase(
945 DescriptorDatabase* source1, DescriptorDatabase* source2) {
946 sources_.push_back(source1);
947 sources_.push_back(source2);
948 }
MergedDescriptorDatabase(const std::vector<DescriptorDatabase * > & sources)949 MergedDescriptorDatabase::MergedDescriptorDatabase(
950 const std::vector<DescriptorDatabase*>& sources)
951 : sources_(sources) {}
~MergedDescriptorDatabase()952 MergedDescriptorDatabase::~MergedDescriptorDatabase() {}
953
FindFileByName(const std::string & filename,FileDescriptorProto * output)954 bool MergedDescriptorDatabase::FindFileByName(const std::string& filename,
955 FileDescriptorProto* output) {
956 for (DescriptorDatabase* source : sources_) {
957 if (source->FindFileByName(filename, output)) {
958 return true;
959 }
960 }
961 return false;
962 }
963
FindFileContainingSymbol(const std::string & symbol_name,FileDescriptorProto * output)964 bool MergedDescriptorDatabase::FindFileContainingSymbol(
965 const std::string& symbol_name, FileDescriptorProto* output) {
966 for (size_t i = 0; i < sources_.size(); i++) {
967 if (sources_[i]->FindFileContainingSymbol(symbol_name, output)) {
968 // The symbol was found in source i. However, if one of the previous
969 // sources defines a file with the same name (which presumably doesn't
970 // contain the symbol, since it wasn't found in that source), then we
971 // must hide it from the caller.
972 FileDescriptorProto temp;
973 for (size_t j = 0; j < i; j++) {
974 if (sources_[j]->FindFileByName(output->name(), &temp)) {
975 // Found conflicting file in a previous source.
976 return false;
977 }
978 }
979 return true;
980 }
981 }
982 return false;
983 }
984
FindFileContainingExtension(const std::string & containing_type,int field_number,FileDescriptorProto * output)985 bool MergedDescriptorDatabase::FindFileContainingExtension(
986 const std::string& containing_type, int field_number,
987 FileDescriptorProto* output) {
988 for (size_t i = 0; i < sources_.size(); i++) {
989 if (sources_[i]->FindFileContainingExtension(containing_type, field_number,
990 output)) {
991 // The symbol was found in source i. However, if one of the previous
992 // sources defines a file with the same name (which presumably doesn't
993 // contain the symbol, since it wasn't found in that source), then we
994 // must hide it from the caller.
995 FileDescriptorProto temp;
996 for (size_t j = 0; j < i; j++) {
997 if (sources_[j]->FindFileByName(output->name(), &temp)) {
998 // Found conflicting file in a previous source.
999 return false;
1000 }
1001 }
1002 return true;
1003 }
1004 }
1005 return false;
1006 }
1007
FindAllExtensionNumbers(const std::string & extendee_type,std::vector<int> * output)1008 bool MergedDescriptorDatabase::FindAllExtensionNumbers(
1009 const std::string& extendee_type, std::vector<int>* output) {
1010 std::set<int> merged_results;
1011 std::vector<int> results;
1012 bool success = false;
1013
1014 for (DescriptorDatabase* source : sources_) {
1015 if (source->FindAllExtensionNumbers(extendee_type, &results)) {
1016 std::copy(results.begin(), results.end(),
1017 std::insert_iterator<std::set<int> >(merged_results,
1018 merged_results.begin()));
1019 success = true;
1020 }
1021 results.clear();
1022 }
1023
1024 std::copy(merged_results.begin(), merged_results.end(),
1025 std::insert_iterator<std::vector<int> >(*output, output->end()));
1026
1027 return success;
1028 }
1029
1030
FindAllFileNames(std::vector<std::string> * output)1031 bool MergedDescriptorDatabase::FindAllFileNames(
1032 std::vector<std::string>* output) {
1033 bool implemented = false;
1034 for (DescriptorDatabase* source : sources_) {
1035 std::vector<std::string> source_output;
1036 if (source->FindAllFileNames(&source_output)) {
1037 output->reserve(output->size() + source_output.size());
1038 for (auto& source : source_output) {
1039 output->push_back(std::move(source));
1040 }
1041 implemented = true;
1042 }
1043 }
1044 return implemented;
1045 }
1046
1047 } // namespace protobuf
1048 } // namespace google
1049