xref: /aosp_15_r20/external/cronet/third_party/protobuf/src/google/protobuf/descriptor_database.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 //     * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 //     * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 //     * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 
31 // Author: [email protected] (Kenton Varda)
32 //  Based on original Protocol Buffers design by
33 //  Sanjay Ghemawat, Jeff Dean, and others.
34 
35 #include <google/protobuf/descriptor_database.h>
36 
37 #include <algorithm>
38 #include <set>
39 
40 #include <google/protobuf/descriptor.pb.h>
41 #include <google/protobuf/stubs/map_util.h>
42 #include <google/protobuf/stubs/stl_util.h>
43 
44 
45 namespace google {
46 namespace protobuf {
47 
48 namespace {
RecordMessageNames(const DescriptorProto & desc_proto,const std::string & prefix,std::set<std::string> * output)49 void RecordMessageNames(const DescriptorProto& desc_proto,
50                         const std::string& prefix,
51                         std::set<std::string>* output) {
52   GOOGLE_CHECK(desc_proto.has_name());
53   std::string full_name = prefix.empty()
54                               ? desc_proto.name()
55                               : StrCat(prefix, ".", desc_proto.name());
56   output->insert(full_name);
57 
58   for (const auto& d : desc_proto.nested_type()) {
59     RecordMessageNames(d, full_name, output);
60   }
61 }
62 
RecordMessageNames(const FileDescriptorProto & file_proto,std::set<std::string> * output)63 void RecordMessageNames(const FileDescriptorProto& file_proto,
64                         std::set<std::string>* output) {
65   for (const auto& d : file_proto.message_type()) {
66     RecordMessageNames(d, file_proto.package(), output);
67   }
68 }
69 
70 template <typename Fn>
ForAllFileProtos(DescriptorDatabase * db,Fn callback,std::vector<std::string> * output)71 bool ForAllFileProtos(DescriptorDatabase* db, Fn callback,
72                       std::vector<std::string>* output) {
73   std::vector<std::string> file_names;
74   if (!db->FindAllFileNames(&file_names)) {
75     return false;
76   }
77   std::set<std::string> set;
78   FileDescriptorProto file_proto;
79   for (const auto& f : file_names) {
80     file_proto.Clear();
81     if (!db->FindFileByName(f, &file_proto)) {
82       GOOGLE_LOG(ERROR) << "File not found in database (unexpected): " << f;
83       return false;
84     }
85     callback(file_proto, &set);
86   }
87   output->insert(output->end(), set.begin(), set.end());
88   return true;
89 }
90 }  // namespace
91 
~DescriptorDatabase()92 DescriptorDatabase::~DescriptorDatabase() {}
93 
FindAllPackageNames(std::vector<std::string> * output)94 bool DescriptorDatabase::FindAllPackageNames(std::vector<std::string>* output) {
95   return ForAllFileProtos(
96       this,
97       [](const FileDescriptorProto& file_proto, std::set<std::string>* set) {
98         set->insert(file_proto.package());
99       },
100       output);
101 }
102 
FindAllMessageNames(std::vector<std::string> * output)103 bool DescriptorDatabase::FindAllMessageNames(std::vector<std::string>* output) {
104   return ForAllFileProtos(
105       this,
106       [](const FileDescriptorProto& file_proto, std::set<std::string>* set) {
107         RecordMessageNames(file_proto, set);
108       },
109       output);
110 }
111 
112 // ===================================================================
113 
SimpleDescriptorDatabase()114 SimpleDescriptorDatabase::SimpleDescriptorDatabase() {}
~SimpleDescriptorDatabase()115 SimpleDescriptorDatabase::~SimpleDescriptorDatabase() {}
116 
117 template <typename Value>
AddFile(const FileDescriptorProto & file,Value value)118 bool SimpleDescriptorDatabase::DescriptorIndex<Value>::AddFile(
119     const FileDescriptorProto& file, Value value) {
120   if (!InsertIfNotPresent(&by_name_, file.name(), value)) {
121     GOOGLE_LOG(ERROR) << "File already exists in database: " << file.name();
122     return false;
123   }
124 
125   // We must be careful here -- calling file.package() if file.has_package() is
126   // false could access an uninitialized static-storage variable if we are being
127   // run at startup time.
128   std::string path = file.has_package() ? file.package() : std::string();
129   if (!path.empty()) path += '.';
130 
131   for (int i = 0; i < file.message_type_size(); i++) {
132     if (!AddSymbol(path + file.message_type(i).name(), value)) return false;
133     if (!AddNestedExtensions(file.name(), file.message_type(i), value))
134       return false;
135   }
136   for (int i = 0; i < file.enum_type_size(); i++) {
137     if (!AddSymbol(path + file.enum_type(i).name(), value)) return false;
138   }
139   for (int i = 0; i < file.extension_size(); i++) {
140     if (!AddSymbol(path + file.extension(i).name(), value)) return false;
141     if (!AddExtension(file.name(), file.extension(i), value)) return false;
142   }
143   for (int i = 0; i < file.service_size(); i++) {
144     if (!AddSymbol(path + file.service(i).name(), value)) return false;
145   }
146 
147   return true;
148 }
149 
150 namespace {
151 
152 // Returns true if and only if all characters in the name are alphanumerics,
153 // underscores, or periods.
ValidateSymbolName(StringPiece name)154 bool ValidateSymbolName(StringPiece name) {
155   for (char c : name) {
156     // I don't trust ctype.h due to locales.  :(
157     if (c != '.' && c != '_' && (c < '0' || c > '9') && (c < 'A' || c > 'Z') &&
158         (c < 'a' || c > 'z')) {
159       return false;
160     }
161   }
162   return true;
163 }
164 
// Returns an iterator to the last element of the ordered `container` whose
// key sorts less than or equal to `key`.  upper_bound() yields the first
// element that sorts strictly greater, so the wanted element is the one just
// before it.  When there is no element before it (the container is empty or
// every key sorts greater), the upper_bound() result itself is returned.
template <typename Container, typename Key>
typename Container::const_iterator FindLastLessOrEqual(
    const Container* container, const Key& key) {
  auto pos = container->upper_bound(key);
  if (pos == container->begin()) return pos;
  --pos;
  return pos;
}
175 
// Returns an iterator to the last element of the sorted range held by
// `container` that does not sort after `key` according to `cmp`.  Uses
// std::upper_bound, so it works on sequence containers such as a sorted
// vector.  When no element precedes the upper bound (empty container, or
// every element sorts greater), the upper bound itself is returned.
template <typename Container, typename Key, typename Cmp>
typename Container::const_iterator FindLastLessOrEqual(
    const Container* container, const Key& key, const Cmp& cmp) {
  auto pos = std::upper_bound(container->begin(), container->end(), key, cmp);
  if (pos == container->begin()) return pos;
  --pos;
  return pos;
}
184 
185 // True if either the arguments are equal or super_symbol identifies a
186 // parent symbol of sub_symbol (e.g. "foo.bar" is a parent of
187 // "foo.bar.baz", but not a parent of "foo.barbaz").
IsSubSymbol(StringPiece sub_symbol,StringPiece super_symbol)188 bool IsSubSymbol(StringPiece sub_symbol, StringPiece super_symbol) {
189   return sub_symbol == super_symbol ||
190          (HasPrefixString(super_symbol, sub_symbol) &&
191           super_symbol[sub_symbol.size()] == '.');
192 }
193 
194 }  // namespace
195 
196 template <typename Value>
AddSymbol(const std::string & name,Value value)197 bool SimpleDescriptorDatabase::DescriptorIndex<Value>::AddSymbol(
198     const std::string& name, Value value) {
199   // We need to make sure not to violate our map invariant.
200 
201   // If the symbol name is invalid it could break our lookup algorithm (which
202   // relies on the fact that '.' sorts before all other characters that are
203   // valid in symbol names).
204   if (!ValidateSymbolName(name)) {
205     GOOGLE_LOG(ERROR) << "Invalid symbol name: " << name;
206     return false;
207   }
208 
209   // Try to look up the symbol to make sure a super-symbol doesn't already
210   // exist.
211   auto iter = FindLastLessOrEqual(&by_symbol_, name);
212 
213   if (iter == by_symbol_.end()) {
214     // Apparently the map is currently empty.  Just insert and be done with it.
215     by_symbol_.insert(
216         typename std::map<std::string, Value>::value_type(name, value));
217     return true;
218   }
219 
220   if (IsSubSymbol(iter->first, name)) {
221     GOOGLE_LOG(ERROR) << "Symbol name \"" << name
222                << "\" conflicts with the existing "
223                   "symbol \""
224                << iter->first << "\".";
225     return false;
226   }
227 
228   // OK, that worked.  Now we have to make sure that no symbol in the map is
229   // a sub-symbol of the one we are inserting.  The only symbol which could
230   // be so is the first symbol that is greater than the new symbol.  Since
231   // |iter| points at the last symbol that is less than or equal, we just have
232   // to increment it.
233   ++iter;
234 
235   if (iter != by_symbol_.end() && IsSubSymbol(name, iter->first)) {
236     GOOGLE_LOG(ERROR) << "Symbol name \"" << name
237                << "\" conflicts with the existing "
238                   "symbol \""
239                << iter->first << "\".";
240     return false;
241   }
242 
243   // OK, no conflicts.
244 
245   // Insert the new symbol using the iterator as a hint, the new entry will
246   // appear immediately before the one the iterator is pointing at.
247   by_symbol_.insert(
248       iter, typename std::map<std::string, Value>::value_type(name, value));
249 
250   return true;
251 }
252 
253 template <typename Value>
AddNestedExtensions(const std::string & filename,const DescriptorProto & message_type,Value value)254 bool SimpleDescriptorDatabase::DescriptorIndex<Value>::AddNestedExtensions(
255     const std::string& filename, const DescriptorProto& message_type,
256     Value value) {
257   for (int i = 0; i < message_type.nested_type_size(); i++) {
258     if (!AddNestedExtensions(filename, message_type.nested_type(i), value))
259       return false;
260   }
261   for (int i = 0; i < message_type.extension_size(); i++) {
262     if (!AddExtension(filename, message_type.extension(i), value)) return false;
263   }
264   return true;
265 }
266 
267 template <typename Value>
AddExtension(const std::string & filename,const FieldDescriptorProto & field,Value value)268 bool SimpleDescriptorDatabase::DescriptorIndex<Value>::AddExtension(
269     const std::string& filename, const FieldDescriptorProto& field,
270     Value value) {
271   if (!field.extendee().empty() && field.extendee()[0] == '.') {
272     // The extension is fully-qualified.  We can use it as a lookup key in
273     // the by_symbol_ table.
274     if (!InsertIfNotPresent(
275             &by_extension_,
276             std::make_pair(field.extendee().substr(1), field.number()),
277             value)) {
278       GOOGLE_LOG(ERROR) << "Extension conflicts with extension already in database: "
279                     "extend "
280                  << field.extendee() << " { " << field.name() << " = "
281                  << field.number() << " } from:" << filename;
282       return false;
283     }
284   } else {
285     // Not fully-qualified.  We can't really do anything here, unfortunately.
286     // We don't consider this an error, though, because the descriptor is
287     // valid.
288   }
289   return true;
290 }
291 
292 template <typename Value>
FindFile(const std::string & filename)293 Value SimpleDescriptorDatabase::DescriptorIndex<Value>::FindFile(
294     const std::string& filename) {
295   return FindWithDefault(by_name_, filename, Value());
296 }
297 
298 template <typename Value>
FindSymbol(const std::string & name)299 Value SimpleDescriptorDatabase::DescriptorIndex<Value>::FindSymbol(
300     const std::string& name) {
301   auto iter = FindLastLessOrEqual(&by_symbol_, name);
302 
303   return (iter != by_symbol_.end() && IsSubSymbol(iter->first, name))
304              ? iter->second
305              : Value();
306 }
307 
308 template <typename Value>
FindExtension(const std::string & containing_type,int field_number)309 Value SimpleDescriptorDatabase::DescriptorIndex<Value>::FindExtension(
310     const std::string& containing_type, int field_number) {
311   return FindWithDefault(
312       by_extension_, std::make_pair(containing_type, field_number), Value());
313 }
314 
315 template <typename Value>
FindAllExtensionNumbers(const std::string & containing_type,std::vector<int> * output)316 bool SimpleDescriptorDatabase::DescriptorIndex<Value>::FindAllExtensionNumbers(
317     const std::string& containing_type, std::vector<int>* output) {
318   typename std::map<std::pair<std::string, int>, Value>::const_iterator it =
319       by_extension_.lower_bound(std::make_pair(containing_type, 0));
320   bool success = false;
321 
322   for (; it != by_extension_.end() && it->first.first == containing_type;
323        ++it) {
324     output->push_back(it->first.second);
325     success = true;
326   }
327 
328   return success;
329 }
330 
331 template <typename Value>
FindAllFileNames(std::vector<std::string> * output)332 void SimpleDescriptorDatabase::DescriptorIndex<Value>::FindAllFileNames(
333     std::vector<std::string>* output) {
334   output->resize(by_name_.size());
335   int i = 0;
336   for (const auto& kv : by_name_) {
337     (*output)[i] = kv.first;
338     i++;
339   }
340 }
341 
342 // -------------------------------------------------------------------
343 
Add(const FileDescriptorProto & file)344 bool SimpleDescriptorDatabase::Add(const FileDescriptorProto& file) {
345   FileDescriptorProto* new_file = new FileDescriptorProto;
346   new_file->CopyFrom(file);
347   return AddAndOwn(new_file);
348 }
349 
AddAndOwn(const FileDescriptorProto * file)350 bool SimpleDescriptorDatabase::AddAndOwn(const FileDescriptorProto* file) {
351   files_to_delete_.emplace_back(file);
352   return index_.AddFile(*file, file);
353 }
354 
FindFileByName(const std::string & filename,FileDescriptorProto * output)355 bool SimpleDescriptorDatabase::FindFileByName(const std::string& filename,
356                                               FileDescriptorProto* output) {
357   return MaybeCopy(index_.FindFile(filename), output);
358 }
359 
FindFileContainingSymbol(const std::string & symbol_name,FileDescriptorProto * output)360 bool SimpleDescriptorDatabase::FindFileContainingSymbol(
361     const std::string& symbol_name, FileDescriptorProto* output) {
362   return MaybeCopy(index_.FindSymbol(symbol_name), output);
363 }
364 
FindFileContainingExtension(const std::string & containing_type,int field_number,FileDescriptorProto * output)365 bool SimpleDescriptorDatabase::FindFileContainingExtension(
366     const std::string& containing_type, int field_number,
367     FileDescriptorProto* output) {
368   return MaybeCopy(index_.FindExtension(containing_type, field_number), output);
369 }
370 
FindAllExtensionNumbers(const std::string & extendee_type,std::vector<int> * output)371 bool SimpleDescriptorDatabase::FindAllExtensionNumbers(
372     const std::string& extendee_type, std::vector<int>* output) {
373   return index_.FindAllExtensionNumbers(extendee_type, output);
374 }
375 
376 
FindAllFileNames(std::vector<std::string> * output)377 bool SimpleDescriptorDatabase::FindAllFileNames(
378     std::vector<std::string>* output) {
379   index_.FindAllFileNames(output);
380   return true;
381 }
382 
MaybeCopy(const FileDescriptorProto * file,FileDescriptorProto * output)383 bool SimpleDescriptorDatabase::MaybeCopy(const FileDescriptorProto* file,
384                                          FileDescriptorProto* output) {
385   if (file == nullptr) return false;
386   output->CopyFrom(*file);
387   return true;
388 }
389 
390 // -------------------------------------------------------------------
391 
392 class EncodedDescriptorDatabase::DescriptorIndex {
393  public:
394   using Value = std::pair<const void*, int>;
395   // Helpers to recursively add particular descriptors and all their contents
396   // to the index.
397   template <typename FileProto>
398   bool AddFile(const FileProto& file, Value value);
399 
400   Value FindFile(StringPiece filename);
401   Value FindSymbol(StringPiece name);
402   Value FindSymbolOnlyFlat(StringPiece name) const;
403   Value FindExtension(StringPiece containing_type, int field_number);
404   bool FindAllExtensionNumbers(StringPiece containing_type,
405                                std::vector<int>* output);
406   void FindAllFileNames(std::vector<std::string>* output) const;
407 
408  private:
409   friend class EncodedDescriptorDatabase;
410 
411   bool AddSymbol(StringPiece symbol);
412 
413   template <typename DescProto>
414   bool AddNestedExtensions(StringPiece filename,
415                            const DescProto& message_type);
416   template <typename FieldProto>
417   bool AddExtension(StringPiece filename, const FieldProto& field);
418 
419   // All the maps below have two representations:
420   //  - a std::set<> where we insert initially.
421   //  - a std::vector<> where we flatten the structure on demand.
422   // The initial tree helps avoid O(N) behavior of inserting into a sorted
423   // vector, while the vector reduces the heap requirements of the data
424   // structure.
425 
426   void EnsureFlat();
427 
428   using String = std::string;
429 
EncodeString(StringPiece str) const430   String EncodeString(StringPiece str) const { return String(str); }
DecodeString(const String & str,int) const431   StringPiece DecodeString(const String& str, int) const { return str; }
432 
433   struct EncodedEntry {
434     // Do not use `Value` here to avoid the padding of that object.
435     const void* data;
436     int size;
437     // Keep the package here instead of each SymbolEntry to save space.
438     String encoded_package;
439 
valuegoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::EncodedEntry440     Value value() const { return {data, size}; }
441   };
442   std::vector<EncodedEntry> all_values_;
443 
444   struct FileEntry {
445     int data_offset;
446     String encoded_name;
447 
namegoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::FileEntry448     StringPiece name(const DescriptorIndex& index) const {
449       return index.DecodeString(encoded_name, data_offset);
450     }
451   };
452   struct FileCompare {
453     const DescriptorIndex& index;
454 
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::FileCompare455     bool operator()(const FileEntry& a, const FileEntry& b) const {
456       return a.name(index) < b.name(index);
457     }
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::FileCompare458     bool operator()(const FileEntry& a, StringPiece b) const {
459       return a.name(index) < b;
460     }
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::FileCompare461     bool operator()(StringPiece a, const FileEntry& b) const {
462       return a < b.name(index);
463     }
464   };
465   std::set<FileEntry, FileCompare> by_name_{FileCompare{*this}};
466   std::vector<FileEntry> by_name_flat_;
467 
468   struct SymbolEntry {
469     int data_offset;
470     String encoded_symbol;
471 
packagegoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolEntry472     StringPiece package(const DescriptorIndex& index) const {
473       return index.DecodeString(index.all_values_[data_offset].encoded_package,
474                                 data_offset);
475     }
symbolgoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolEntry476     StringPiece symbol(const DescriptorIndex& index) const {
477       return index.DecodeString(encoded_symbol, data_offset);
478     }
479 
AsStringgoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolEntry480     std::string AsString(const DescriptorIndex& index) const {
481       auto p = package(index);
482       return StrCat(p, p.empty() ? "" : ".", symbol(index));
483     }
484   };
485 
486   struct SymbolCompare {
487     const DescriptorIndex& index;
488 
AsStringgoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolCompare489     std::string AsString(const SymbolEntry& entry) const {
490       return entry.AsString(index);
491     }
AsStringgoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolCompare492     static StringPiece AsString(StringPiece str) { return str; }
493 
GetPartsgoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolCompare494     std::pair<StringPiece, StringPiece> GetParts(
495         const SymbolEntry& entry) const {
496       auto package = entry.package(index);
497       if (package.empty()) return {entry.symbol(index), StringPiece{}};
498       return {package, entry.symbol(index)};
499     }
GetPartsgoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolCompare500     std::pair<StringPiece, StringPiece> GetParts(
501         StringPiece str) const {
502       return {str, {}};
503     }
504 
505     template <typename T, typename U>
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::SymbolCompare506     bool operator()(const T& lhs, const U& rhs) const {
507       auto lhs_parts = GetParts(lhs);
508       auto rhs_parts = GetParts(rhs);
509 
510       // Fast path to avoid making the whole string for common cases.
511       if (int res =
512               lhs_parts.first.substr(0, rhs_parts.first.size())
513                   .compare(rhs_parts.first.substr(0, lhs_parts.first.size()))) {
514         // If the packages already differ, exit early.
515         return res < 0;
516       } else if (lhs_parts.first.size() == rhs_parts.first.size()) {
517         return lhs_parts.second < rhs_parts.second;
518       }
519       return AsString(lhs) < AsString(rhs);
520     }
521   };
522   std::set<SymbolEntry, SymbolCompare> by_symbol_{SymbolCompare{*this}};
523   std::vector<SymbolEntry> by_symbol_flat_;
524 
525   struct ExtensionEntry {
526     int data_offset;
527     String encoded_extendee;
extendeegoogle::protobuf::EncodedDescriptorDatabase::DescriptorIndex::ExtensionEntry528     StringPiece extendee(const DescriptorIndex& index) const {
529       return index.DecodeString(encoded_extendee, data_offset).substr(1);
530     }
531     int extension_number;
532   };
533   struct ExtensionCompare {
534     const DescriptorIndex& index;
535 
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::ExtensionCompare536     bool operator()(const ExtensionEntry& a, const ExtensionEntry& b) const {
537       return std::make_tuple(a.extendee(index), a.extension_number) <
538              std::make_tuple(b.extendee(index), b.extension_number);
539     }
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::ExtensionCompare540     bool operator()(const ExtensionEntry& a,
541                     std::tuple<StringPiece, int> b) const {
542       return std::make_tuple(a.extendee(index), a.extension_number) < b;
543     }
operator ()google::protobuf::EncodedDescriptorDatabase::DescriptorIndex::ExtensionCompare544     bool operator()(std::tuple<StringPiece, int> a,
545                     const ExtensionEntry& b) const {
546       return a < std::make_tuple(b.extendee(index), b.extension_number);
547     }
548   };
549   std::set<ExtensionEntry, ExtensionCompare> by_extension_{
550       ExtensionCompare{*this}};
551   std::vector<ExtensionEntry> by_extension_flat_;
552 };
553 
Add(const void * encoded_file_descriptor,int size)554 bool EncodedDescriptorDatabase::Add(const void* encoded_file_descriptor,
555                                     int size) {
556   FileDescriptorProto file;
557   if (file.ParseFromArray(encoded_file_descriptor, size)) {
558     return index_->AddFile(file, std::make_pair(encoded_file_descriptor, size));
559   } else {
560     GOOGLE_LOG(ERROR) << "Invalid file descriptor data passed to "
561                   "EncodedDescriptorDatabase::Add().";
562     return false;
563   }
564 }
565 
AddCopy(const void * encoded_file_descriptor,int size)566 bool EncodedDescriptorDatabase::AddCopy(const void* encoded_file_descriptor,
567                                         int size) {
568   void* copy = operator new(size);
569   memcpy(copy, encoded_file_descriptor, size);
570   files_to_delete_.push_back(copy);
571   return Add(copy, size);
572 }
573 
FindFileByName(const std::string & filename,FileDescriptorProto * output)574 bool EncodedDescriptorDatabase::FindFileByName(const std::string& filename,
575                                                FileDescriptorProto* output) {
576   return MaybeParse(index_->FindFile(filename), output);
577 }
578 
FindFileContainingSymbol(const std::string & symbol_name,FileDescriptorProto * output)579 bool EncodedDescriptorDatabase::FindFileContainingSymbol(
580     const std::string& symbol_name, FileDescriptorProto* output) {
581   return MaybeParse(index_->FindSymbol(symbol_name), output);
582 }
583 
FindNameOfFileContainingSymbol(const std::string & symbol_name,std::string * output)584 bool EncodedDescriptorDatabase::FindNameOfFileContainingSymbol(
585     const std::string& symbol_name, std::string* output) {
586   auto encoded_file = index_->FindSymbol(symbol_name);
587   if (encoded_file.first == nullptr) return false;
588 
589   // Optimization:  The name should be the first field in the encoded message.
590   //   Try to just read it directly.
591   io::CodedInputStream input(static_cast<const uint8_t*>(encoded_file.first),
592                              encoded_file.second);
593 
594   const uint32_t kNameTag = internal::WireFormatLite::MakeTag(
595       FileDescriptorProto::kNameFieldNumber,
596       internal::WireFormatLite::WIRETYPE_LENGTH_DELIMITED);
597 
598   if (input.ReadTagNoLastTag() == kNameTag) {
599     // Success!
600     return internal::WireFormatLite::ReadString(&input, output);
601   } else {
602     // Slow path.  Parse whole message.
603     FileDescriptorProto file_proto;
604     if (!file_proto.ParseFromArray(encoded_file.first, encoded_file.second)) {
605       return false;
606     }
607     *output = file_proto.name();
608     return true;
609   }
610 }
611 
FindFileContainingExtension(const std::string & containing_type,int field_number,FileDescriptorProto * output)612 bool EncodedDescriptorDatabase::FindFileContainingExtension(
613     const std::string& containing_type, int field_number,
614     FileDescriptorProto* output) {
615   return MaybeParse(index_->FindExtension(containing_type, field_number),
616                     output);
617 }
618 
FindAllExtensionNumbers(const std::string & extendee_type,std::vector<int> * output)619 bool EncodedDescriptorDatabase::FindAllExtensionNumbers(
620     const std::string& extendee_type, std::vector<int>* output) {
621   return index_->FindAllExtensionNumbers(extendee_type, output);
622 }
623 
624 template <typename FileProto>
AddFile(const FileProto & file,Value value)625 bool EncodedDescriptorDatabase::DescriptorIndex::AddFile(const FileProto& file,
626                                                          Value value) {
627   // We push `value` into the array first. This is important because the AddXXX
628   // functions below will expect it to be there.
629   all_values_.push_back({value.first, value.second, {}});
630 
631   if (!ValidateSymbolName(file.package())) {
632     GOOGLE_LOG(ERROR) << "Invalid package name: " << file.package();
633     return false;
634   }
635   all_values_.back().encoded_package = EncodeString(file.package());
636 
637   if (!InsertIfNotPresent(
638           &by_name_, FileEntry{static_cast<int>(all_values_.size() - 1),
639                                EncodeString(file.name())}) ||
640       std::binary_search(by_name_flat_.begin(), by_name_flat_.end(),
641                          file.name(), by_name_.key_comp())) {
642     GOOGLE_LOG(ERROR) << "File already exists in database: " << file.name();
643     return false;
644   }
645 
646   for (const auto& message_type : file.message_type()) {
647     if (!AddSymbol(message_type.name())) return false;
648     if (!AddNestedExtensions(file.name(), message_type)) return false;
649   }
650   for (const auto& enum_type : file.enum_type()) {
651     if (!AddSymbol(enum_type.name())) return false;
652   }
653   for (const auto& extension : file.extension()) {
654     if (!AddSymbol(extension.name())) return false;
655     if (!AddExtension(file.name(), extension)) return false;
656   }
657   for (const auto& service : file.service()) {
658     if (!AddSymbol(service.name())) return false;
659   }
660 
661   return true;
662 }
663 
664 template <typename Iter, typename Iter2, typename Index>
CheckForMutualSubsymbols(StringPiece symbol_name,Iter * iter,Iter2 end,const Index & index)665 static bool CheckForMutualSubsymbols(StringPiece symbol_name, Iter* iter,
666                                      Iter2 end, const Index& index) {
667   if (*iter != end) {
668     if (IsSubSymbol((*iter)->AsString(index), symbol_name)) {
669       GOOGLE_LOG(ERROR) << "Symbol name \"" << symbol_name
670                  << "\" conflicts with the existing symbol \""
671                  << (*iter)->AsString(index) << "\".";
672       return false;
673     }
674 
675     // OK, that worked.  Now we have to make sure that no symbol in the map is
676     // a sub-symbol of the one we are inserting.  The only symbol which could
677     // be so is the first symbol that is greater than the new symbol.  Since
678     // |iter| points at the last symbol that is less than or equal, we just have
679     // to increment it.
680     ++*iter;
681 
682     if (*iter != end && IsSubSymbol(symbol_name, (*iter)->AsString(index))) {
683       GOOGLE_LOG(ERROR) << "Symbol name \"" << symbol_name
684                  << "\" conflicts with the existing symbol \""
685                  << (*iter)->AsString(index) << "\".";
686       return false;
687     }
688   }
689   return true;
690 }
691 
AddSymbol(StringPiece symbol)692 bool EncodedDescriptorDatabase::DescriptorIndex::AddSymbol(
693     StringPiece symbol) {
694   SymbolEntry entry = {static_cast<int>(all_values_.size() - 1),
695                        EncodeString(symbol)};
696   std::string entry_as_string = entry.AsString(*this);
697 
698   // We need to make sure not to violate our map invariant.
699 
700   // If the symbol name is invalid it could break our lookup algorithm (which
701   // relies on the fact that '.' sorts before all other characters that are
702   // valid in symbol names).
703   if (!ValidateSymbolName(symbol)) {
704     GOOGLE_LOG(ERROR) << "Invalid symbol name: " << entry_as_string;
705     return false;
706   }
707 
708   auto iter = FindLastLessOrEqual(&by_symbol_, entry);
709   if (!CheckForMutualSubsymbols(entry_as_string, &iter, by_symbol_.end(),
710                                 *this)) {
711     return false;
712   }
713 
714   // Same, but on by_symbol_flat_
715   auto flat_iter =
716       FindLastLessOrEqual(&by_symbol_flat_, entry, by_symbol_.key_comp());
717   if (!CheckForMutualSubsymbols(entry_as_string, &flat_iter,
718                                 by_symbol_flat_.end(), *this)) {
719     return false;
720   }
721 
722   // OK, no conflicts.
723 
724   // Insert the new symbol using the iterator as a hint, the new entry will
725   // appear immediately before the one the iterator is pointing at.
726   by_symbol_.insert(iter, entry);
727 
728   return true;
729 }
730 
731 template <typename DescProto>
AddNestedExtensions(StringPiece filename,const DescProto & message_type)732 bool EncodedDescriptorDatabase::DescriptorIndex::AddNestedExtensions(
733     StringPiece filename, const DescProto& message_type) {
734   for (const auto& nested_type : message_type.nested_type()) {
735     if (!AddNestedExtensions(filename, nested_type)) return false;
736   }
737   for (const auto& extension : message_type.extension()) {
738     if (!AddExtension(filename, extension)) return false;
739   }
740   return true;
741 }
742 
743 template <typename FieldProto>
AddExtension(StringPiece filename,const FieldProto & field)744 bool EncodedDescriptorDatabase::DescriptorIndex::AddExtension(
745     StringPiece filename, const FieldProto& field) {
746   if (!field.extendee().empty() && field.extendee()[0] == '.') {
747     // The extension is fully-qualified.  We can use it as a lookup key in
748     // the by_symbol_ table.
749     if (!InsertIfNotPresent(
750             &by_extension_,
751             ExtensionEntry{static_cast<int>(all_values_.size() - 1),
752                            EncodeString(field.extendee()), field.number()}) ||
753         std::binary_search(
754             by_extension_flat_.begin(), by_extension_flat_.end(),
755             std::make_pair(field.extendee().substr(1), field.number()),
756             by_extension_.key_comp())) {
757       GOOGLE_LOG(ERROR) << "Extension conflicts with extension already in database: "
758                     "extend "
759                  << field.extendee() << " { " << field.name() << " = "
760                  << field.number() << " } from:" << filename;
761       return false;
762     }
763   } else {
764     // Not fully-qualified.  We can't really do anything here, unfortunately.
765     // We don't consider this an error, though, because the descriptor is
766     // valid.
767   }
768   return true;
769 }
770 
771 std::pair<const void*, int>
FindSymbol(StringPiece name)772 EncodedDescriptorDatabase::DescriptorIndex::FindSymbol(StringPiece name) {
773   EnsureFlat();
774   return FindSymbolOnlyFlat(name);
775 }
776 
777 std::pair<const void*, int>
FindSymbolOnlyFlat(StringPiece name) const778 EncodedDescriptorDatabase::DescriptorIndex::FindSymbolOnlyFlat(
779     StringPiece name) const {
780   auto iter =
781       FindLastLessOrEqual(&by_symbol_flat_, name, by_symbol_.key_comp());
782 
783   return iter != by_symbol_flat_.end() &&
784                  IsSubSymbol(iter->AsString(*this), name)
785              ? all_values_[iter->data_offset].value()
786              : Value();
787 }
788 
789 std::pair<const void*, int>
FindExtension(StringPiece containing_type,int field_number)790 EncodedDescriptorDatabase::DescriptorIndex::FindExtension(
791     StringPiece containing_type, int field_number) {
792   EnsureFlat();
793 
794   auto it = std::lower_bound(
795       by_extension_flat_.begin(), by_extension_flat_.end(),
796       std::make_tuple(containing_type, field_number), by_extension_.key_comp());
797   return it == by_extension_flat_.end() ||
798                  it->extendee(*this) != containing_type ||
799                  it->extension_number != field_number
800              ? std::make_pair(nullptr, 0)
801              : all_values_[it->data_offset].value();
802 }
803 
// Merges the contents of the sorted set |s| into the sorted vector |flat| and
// then empties |s|.
//
// Both inputs must already be ordered by |s|'s comparator (s->key_comp());
// the resulting |flat| is ordered the same way.  For elements that compare
// equivalent, those coming from |s| are placed before those coming from
// |flat| (std::merge keeps first-range elements first).
//
// If |s| is empty this is a no-op and |flat| is left untouched.
//
// The merged vector is built with reserve() + back_inserter rather than by
// pre-sizing it, so T does not need to be default-constructible and no
// element is initialized twice.
template <typename T, typename Less>
static void MergeIntoFlat(std::set<T, Less>* s, std::vector<T>* flat) {
  if (s->empty()) return;

  std::vector<T> merged;
  merged.reserve(s->size() + flat->size());
  std::merge(s->begin(), s->end(), flat->begin(), flat->end(),
             std::back_inserter(merged), s->key_comp());

  *flat = std::move(merged);
  s->clear();
}
813 
EnsureFlat()814 void EncodedDescriptorDatabase::DescriptorIndex::EnsureFlat() {
815   all_values_.shrink_to_fit();
816   // Merge each of the sets into their flat counterpart.
817   MergeIntoFlat(&by_name_, &by_name_flat_);
818   MergeIntoFlat(&by_symbol_, &by_symbol_flat_);
819   MergeIntoFlat(&by_extension_, &by_extension_flat_);
820 }
821 
FindAllExtensionNumbers(StringPiece containing_type,std::vector<int> * output)822 bool EncodedDescriptorDatabase::DescriptorIndex::FindAllExtensionNumbers(
823     StringPiece containing_type, std::vector<int>* output) {
824   EnsureFlat();
825 
826   bool success = false;
827   auto it = std::lower_bound(
828       by_extension_flat_.begin(), by_extension_flat_.end(),
829       std::make_tuple(containing_type, 0), by_extension_.key_comp());
830   for (;
831        it != by_extension_flat_.end() && it->extendee(*this) == containing_type;
832        ++it) {
833     output->push_back(it->extension_number);
834     success = true;
835   }
836 
837   return success;
838 }
839 
FindAllFileNames(std::vector<std::string> * output) const840 void EncodedDescriptorDatabase::DescriptorIndex::FindAllFileNames(
841     std::vector<std::string>* output) const {
842   output->resize(by_name_.size() + by_name_flat_.size());
843   int i = 0;
844   for (const auto& entry : by_name_) {
845     (*output)[i] = std::string(entry.name(*this));
846     i++;
847   }
848   for (const auto& entry : by_name_flat_) {
849     (*output)[i] = std::string(entry.name(*this));
850     i++;
851   }
852 }
853 
854 std::pair<const void*, int>
FindFile(StringPiece filename)855 EncodedDescriptorDatabase::DescriptorIndex::FindFile(
856     StringPiece filename) {
857   EnsureFlat();
858 
859   auto it = std::lower_bound(by_name_flat_.begin(), by_name_flat_.end(),
860                              filename, by_name_.key_comp());
861   return it == by_name_flat_.end() || it->name(*this) != filename
862              ? std::make_pair(nullptr, 0)
863              : all_values_[it->data_offset].value();
864 }
865 
866 
FindAllFileNames(std::vector<std::string> * output)867 bool EncodedDescriptorDatabase::FindAllFileNames(
868     std::vector<std::string>* output) {
869   index_->FindAllFileNames(output);
870   return true;
871 }
872 
MaybeParse(std::pair<const void *,int> encoded_file,FileDescriptorProto * output)873 bool EncodedDescriptorDatabase::MaybeParse(
874     std::pair<const void*, int> encoded_file, FileDescriptorProto* output) {
875   if (encoded_file.first == nullptr) return false;
876   return output->ParseFromArray(encoded_file.first, encoded_file.second);
877 }
878 
EncodedDescriptorDatabase()879 EncodedDescriptorDatabase::EncodedDescriptorDatabase()
880     : index_(new DescriptorIndex()) {}
881 
~EncodedDescriptorDatabase()882 EncodedDescriptorDatabase::~EncodedDescriptorDatabase() {
883   for (void* p : files_to_delete_) {
884     operator delete(p);
885   }
886 }
887 
888 // ===================================================================
889 
DescriptorPoolDatabase(const DescriptorPool & pool)890 DescriptorPoolDatabase::DescriptorPoolDatabase(const DescriptorPool& pool)
891     : pool_(pool) {}
~DescriptorPoolDatabase()892 DescriptorPoolDatabase::~DescriptorPoolDatabase() {}
893 
FindFileByName(const std::string & filename,FileDescriptorProto * output)894 bool DescriptorPoolDatabase::FindFileByName(const std::string& filename,
895                                             FileDescriptorProto* output) {
896   const FileDescriptor* file = pool_.FindFileByName(filename);
897   if (file == nullptr) return false;
898   output->Clear();
899   file->CopyTo(output);
900   return true;
901 }
902 
FindFileContainingSymbol(const std::string & symbol_name,FileDescriptorProto * output)903 bool DescriptorPoolDatabase::FindFileContainingSymbol(
904     const std::string& symbol_name, FileDescriptorProto* output) {
905   const FileDescriptor* file = pool_.FindFileContainingSymbol(symbol_name);
906   if (file == nullptr) return false;
907   output->Clear();
908   file->CopyTo(output);
909   return true;
910 }
911 
FindFileContainingExtension(const std::string & containing_type,int field_number,FileDescriptorProto * output)912 bool DescriptorPoolDatabase::FindFileContainingExtension(
913     const std::string& containing_type, int field_number,
914     FileDescriptorProto* output) {
915   const Descriptor* extendee = pool_.FindMessageTypeByName(containing_type);
916   if (extendee == nullptr) return false;
917 
918   const FieldDescriptor* extension =
919       pool_.FindExtensionByNumber(extendee, field_number);
920   if (extension == nullptr) return false;
921 
922   output->Clear();
923   extension->file()->CopyTo(output);
924   return true;
925 }
926 
FindAllExtensionNumbers(const std::string & extendee_type,std::vector<int> * output)927 bool DescriptorPoolDatabase::FindAllExtensionNumbers(
928     const std::string& extendee_type, std::vector<int>* output) {
929   const Descriptor* extendee = pool_.FindMessageTypeByName(extendee_type);
930   if (extendee == nullptr) return false;
931 
932   std::vector<const FieldDescriptor*> extensions;
933   pool_.FindAllExtensions(extendee, &extensions);
934 
935   for (const FieldDescriptor* extension : extensions) {
936     output->push_back(extension->number());
937   }
938 
939   return true;
940 }
941 
942 // ===================================================================
943 
MergedDescriptorDatabase(DescriptorDatabase * source1,DescriptorDatabase * source2)944 MergedDescriptorDatabase::MergedDescriptorDatabase(
945     DescriptorDatabase* source1, DescriptorDatabase* source2) {
946   sources_.push_back(source1);
947   sources_.push_back(source2);
948 }
MergedDescriptorDatabase(const std::vector<DescriptorDatabase * > & sources)949 MergedDescriptorDatabase::MergedDescriptorDatabase(
950     const std::vector<DescriptorDatabase*>& sources)
951     : sources_(sources) {}
~MergedDescriptorDatabase()952 MergedDescriptorDatabase::~MergedDescriptorDatabase() {}
953 
FindFileByName(const std::string & filename,FileDescriptorProto * output)954 bool MergedDescriptorDatabase::FindFileByName(const std::string& filename,
955                                               FileDescriptorProto* output) {
956   for (DescriptorDatabase* source : sources_) {
957     if (source->FindFileByName(filename, output)) {
958       return true;
959     }
960   }
961   return false;
962 }
963 
FindFileContainingSymbol(const std::string & symbol_name,FileDescriptorProto * output)964 bool MergedDescriptorDatabase::FindFileContainingSymbol(
965     const std::string& symbol_name, FileDescriptorProto* output) {
966   for (size_t i = 0; i < sources_.size(); i++) {
967     if (sources_[i]->FindFileContainingSymbol(symbol_name, output)) {
968       // The symbol was found in source i.  However, if one of the previous
969       // sources defines a file with the same name (which presumably doesn't
970       // contain the symbol, since it wasn't found in that source), then we
971       // must hide it from the caller.
972       FileDescriptorProto temp;
973       for (size_t j = 0; j < i; j++) {
974         if (sources_[j]->FindFileByName(output->name(), &temp)) {
975           // Found conflicting file in a previous source.
976           return false;
977         }
978       }
979       return true;
980     }
981   }
982   return false;
983 }
984 
FindFileContainingExtension(const std::string & containing_type,int field_number,FileDescriptorProto * output)985 bool MergedDescriptorDatabase::FindFileContainingExtension(
986     const std::string& containing_type, int field_number,
987     FileDescriptorProto* output) {
988   for (size_t i = 0; i < sources_.size(); i++) {
989     if (sources_[i]->FindFileContainingExtension(containing_type, field_number,
990                                                  output)) {
991       // The symbol was found in source i.  However, if one of the previous
992       // sources defines a file with the same name (which presumably doesn't
993       // contain the symbol, since it wasn't found in that source), then we
994       // must hide it from the caller.
995       FileDescriptorProto temp;
996       for (size_t j = 0; j < i; j++) {
997         if (sources_[j]->FindFileByName(output->name(), &temp)) {
998           // Found conflicting file in a previous source.
999           return false;
1000         }
1001       }
1002       return true;
1003     }
1004   }
1005   return false;
1006 }
1007 
FindAllExtensionNumbers(const std::string & extendee_type,std::vector<int> * output)1008 bool MergedDescriptorDatabase::FindAllExtensionNumbers(
1009     const std::string& extendee_type, std::vector<int>* output) {
1010   std::set<int> merged_results;
1011   std::vector<int> results;
1012   bool success = false;
1013 
1014   for (DescriptorDatabase* source : sources_) {
1015     if (source->FindAllExtensionNumbers(extendee_type, &results)) {
1016       std::copy(results.begin(), results.end(),
1017                 std::insert_iterator<std::set<int> >(merged_results,
1018                                                      merged_results.begin()));
1019       success = true;
1020     }
1021     results.clear();
1022   }
1023 
1024   std::copy(merged_results.begin(), merged_results.end(),
1025             std::insert_iterator<std::vector<int> >(*output, output->end()));
1026 
1027   return success;
1028 }
1029 
1030 
FindAllFileNames(std::vector<std::string> * output)1031 bool MergedDescriptorDatabase::FindAllFileNames(
1032     std::vector<std::string>* output) {
1033   bool implemented = false;
1034   for (DescriptorDatabase* source : sources_) {
1035     std::vector<std::string> source_output;
1036     if (source->FindAllFileNames(&source_output)) {
1037       output->reserve(output->size() + source_output.size());
1038       for (auto& source : source_output) {
1039         output->push_back(std::move(source));
1040       }
1041       implemented = true;
1042     }
1043   }
1044   return implemented;
1045 }
1046 
1047 }  // namespace protobuf
1048 }  // namespace google
1049