xref: /aosp_15_r20/external/libtextclassifier/native/utils/grammar/utils/ir.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/grammar/utils/ir.h"
18 
19 #include <algorithm>
20 
21 #include "utils/i18n/locale.h"
22 #include "utils/strings/append.h"
23 #include "utils/strings/stringpiece.h"
24 #include "utils/zlib/zlib.h"
25 
26 namespace libtextclassifier3::grammar {
27 namespace {
28 
// Upper bound on the number of buckets created for a serialized binary rules
// hash table (see SerializeBinaryRulesShard).
constexpr size_t kMaxHashTableSize = 100;
30 
// Orders a container of pointer-like entries by ascending `key` member so that
// it can be queried with binary search. Entries with equal keys keep their
// relative order.
template <typename T>
void SortForBinarySearchLookup(T* entries) {
  const auto key_precedes = [](const auto& left, const auto& right) {
    return left->key < right->key;
  };
  std::stable_sort(entries->begin(), entries->end(), key_precedes);
}
37 
// Orders a container of value entries by ascending `key()` accessor so that it
// can be queried with binary search. Entries with equal keys keep their
// relative order.
template <typename T>
void SortStructsForBinarySearchLookup(T* entries) {
  const auto key_precedes = [](const auto& left, const auto& right) {
    return left.key() < right.key();
  };
  std::stable_sort(entries->begin(), entries->end(), key_precedes);
}
44 
IsSameLhs(const Ir::Lhs & lhs,const RulesSet_::Lhs & other)45 bool IsSameLhs(const Ir::Lhs& lhs, const RulesSet_::Lhs& other) {
46   return (lhs.nonterminal == other.nonterminal() &&
47           lhs.callback.id == other.callback_id() &&
48           lhs.callback.param == other.callback_param() &&
49           lhs.preconditions.max_whitespace_gap == other.max_whitespace_gap());
50 }
51 
IsSameLhsEntry(const Ir::Lhs & lhs,const int32 lhs_entry,const std::vector<RulesSet_::Lhs> & candidates)52 bool IsSameLhsEntry(const Ir::Lhs& lhs, const int32 lhs_entry,
53                     const std::vector<RulesSet_::Lhs>& candidates) {
54   // Simple case: direct encoding of the nonterminal.
55   if (lhs_entry > 0) {
56     return (lhs.nonterminal == lhs_entry && lhs.callback.id == kNoCallback &&
57             lhs.preconditions.max_whitespace_gap == -1);
58   }
59 
60   // Entry is index into callback lookup.
61   return IsSameLhs(lhs, candidates[-lhs_entry]);
62 }
63 
IsSameLhsSet(const Ir::LhsSet & lhs_set,const RulesSet_::LhsSetT & candidate,const std::vector<RulesSet_::Lhs> & candidates)64 bool IsSameLhsSet(const Ir::LhsSet& lhs_set,
65                   const RulesSet_::LhsSetT& candidate,
66                   const std::vector<RulesSet_::Lhs>& candidates) {
67   if (lhs_set.size() != candidate.lhs.size()) {
68     return false;
69   }
70 
71   for (int i = 0; i < lhs_set.size(); i++) {
72     // Check that entries are the same.
73     if (!IsSameLhsEntry(lhs_set[i], candidate.lhs[i], candidates)) {
74       return false;
75     }
76   }
77 
78   return true;
79 }
80 
SortedLhsSet(const Ir::LhsSet & lhs_set)81 Ir::LhsSet SortedLhsSet(const Ir::LhsSet& lhs_set) {
82   Ir::LhsSet sorted_lhs = lhs_set;
83   std::stable_sort(
84       sorted_lhs.begin(), sorted_lhs.end(),
85       [](const Ir::Lhs& a, const Ir::Lhs& b) {
86         return std::tie(a.nonterminal, a.callback.id, a.callback.param,
87                         a.preconditions.max_whitespace_gap) <
88                std::tie(b.nonterminal, b.callback.id, b.callback.param,
89                         b.preconditions.max_whitespace_gap);
90       });
91   return lhs_set;
92 }
93 
94 // Adds a new lhs match set to the output.
95 // Reuses the same set, if it was previously observed.
AddLhsSet(const Ir::LhsSet & lhs_set,RulesSetT * rules_set)96 int AddLhsSet(const Ir::LhsSet& lhs_set, RulesSetT* rules_set) {
97   Ir::LhsSet sorted_lhs = SortedLhsSet(lhs_set);
98   // Check whether we can reuse an entry.
99   const int output_size = rules_set->lhs_set.size();
100   for (int i = 0; i < output_size; i++) {
101     if (IsSameLhsSet(lhs_set, *rules_set->lhs_set[i], rules_set->lhs)) {
102       return i;
103     }
104   }
105 
106   // Add new entry.
107   rules_set->lhs_set.emplace_back(std::make_unique<RulesSet_::LhsSetT>());
108   RulesSet_::LhsSetT* serialized_lhs_set = rules_set->lhs_set.back().get();
109   for (const Ir::Lhs& lhs : lhs_set) {
110     // Simple case: No callback and no special requirements, we directly encode
111     // the nonterminal.
112     if (lhs.callback.id == kNoCallback &&
113         lhs.preconditions.max_whitespace_gap < 0) {
114       serialized_lhs_set->lhs.push_back(lhs.nonterminal);
115     } else {
116       // Check whether we can reuse a callback entry.
117       const int lhs_size = rules_set->lhs.size();
118       bool found_entry = false;
119       for (int i = 0; i < lhs_size; i++) {
120         if (IsSameLhs(lhs, rules_set->lhs[i])) {
121           found_entry = true;
122           serialized_lhs_set->lhs.push_back(-i);
123           break;
124         }
125       }
126 
127       // We could reuse an existing entry.
128       if (found_entry) {
129         continue;
130       }
131 
132       // Add a new one.
133       rules_set->lhs.push_back(
134           RulesSet_::Lhs(lhs.nonterminal, lhs.callback.id, lhs.callback.param,
135                          lhs.preconditions.max_whitespace_gap));
136       serialized_lhs_set->lhs.push_back(-lhs_size);
137     }
138   }
139   return output_size;
140 }
141 
142 // Serializes a unary rules table.
SerializeUnaryRulesShard(const std::unordered_map<Nonterm,Ir::LhsSet> & unary_rules,RulesSetT * rules_set,RulesSet_::RulesT * rules)143 void SerializeUnaryRulesShard(
144     const std::unordered_map<Nonterm, Ir::LhsSet>& unary_rules,
145     RulesSetT* rules_set, RulesSet_::RulesT* rules) {
146   for (const auto& it : unary_rules) {
147     rules->unary_rules.push_back(RulesSet_::Rules_::UnaryRulesEntry(
148         it.first, AddLhsSet(it.second, rules_set)));
149   }
150   SortStructsForBinarySearchLookup(&rules->unary_rules);
151 }
152 
153 // // Serializes a binary rules table.
SerializeBinaryRulesShard(const std::unordered_map<TwoNonterms,Ir::LhsSet,BinaryRuleHasher> & binary_rules,RulesSetT * rules_set,RulesSet_::RulesT * rules)154 void SerializeBinaryRulesShard(
155     const std::unordered_map<TwoNonterms, Ir::LhsSet, BinaryRuleHasher>&
156         binary_rules,
157     RulesSetT* rules_set, RulesSet_::RulesT* rules) {
158   const size_t num_buckets = std::min(binary_rules.size(), kMaxHashTableSize);
159   for (int i = 0; i < num_buckets; i++) {
160     rules->binary_rules.emplace_back(
161         new RulesSet_::Rules_::BinaryRuleTableBucketT());
162   }
163 
164   // Serialize the table.
165   BinaryRuleHasher hash;
166   for (const auto& it : binary_rules) {
167     const TwoNonterms key = it.first;
168     uint32 bucket_index = hash(key) % num_buckets;
169 
170     // Add entry to bucket chain list.
171     rules->binary_rules[bucket_index]->rules.push_back(
172         RulesSet_::Rules_::BinaryRule(key.first, key.second,
173                                       AddLhsSet(it.second, rules_set)));
174   }
175 }
176 
177 }  // namespace
178 
AddToSet(const Lhs & lhs,LhsSet * lhs_set)179 Nonterm Ir::AddToSet(const Lhs& lhs, LhsSet* lhs_set) {
180   const int lhs_set_size = lhs_set->size();
181   Nonterm shareable_nonterm = lhs.nonterminal;
182   for (int i = 0; i < lhs_set_size; i++) {
183     Lhs* candidate = &lhs_set->at(i);
184 
185     // Exact match, just reuse rule.
186     if (lhs == *candidate) {
187       return candidate->nonterminal;
188     }
189 
190     // Cannot reuse unshareable ids.
191     if (nonshareable_.find(candidate->nonterminal) != nonshareable_.end() ||
192         nonshareable_.find(lhs.nonterminal) != nonshareable_.end()) {
193       continue;
194     }
195 
196     // Cannot reuse id if the preconditions are different.
197     if (!(lhs.preconditions == candidate->preconditions)) {
198       continue;
199     }
200 
201     // If the nonterminal is already defined, it must match for sharing.
202     if (lhs.nonterminal != kUnassignedNonterm &&
203         lhs.nonterminal != candidate->nonterminal) {
204       continue;
205     }
206 
207     // Check whether the callbacks match.
208     if (lhs.callback == candidate->callback) {
209       return candidate->nonterminal;
210     }
211 
212     // We can reuse if one of the output callbacks is not used.
213     if (lhs.callback.id == kNoCallback) {
214       return candidate->nonterminal;
215     } else if (candidate->callback.id == kNoCallback) {
216       // Old entry has no output callback, which is redundant now.
217       candidate->callback = lhs.callback;
218       return candidate->nonterminal;
219     }
220 
221     // We can share the nonterminal, but we need to
222     // add a new output callback. Defer this as we might find a shareable
223     // nonterminal first.
224     shareable_nonterm = candidate->nonterminal;
225   }
226 
227   // We didn't find a redundant entry, so create a new one.
228   shareable_nonterm = DefineNonterminal(shareable_nonterm);
229   lhs_set->push_back(Lhs{shareable_nonterm, lhs.callback, lhs.preconditions});
230   return shareable_nonterm;
231 }
232 
Add(const Lhs & lhs,const std::string & terminal,const bool case_sensitive,const int shard)233 Nonterm Ir::Add(const Lhs& lhs, const std::string& terminal,
234                 const bool case_sensitive, const int shard) {
235   TC3_CHECK_LT(shard, shards_.size());
236   if (case_sensitive) {
237     return AddRule(lhs, terminal, &shards_[shard].terminal_rules);
238   } else {
239     return AddRule(lhs, terminal, &shards_[shard].lowercase_terminal_rules);
240   }
241 }
242 
243 // For latency we put sub-rules on the first shard which must be any match
244 // i.e. '*' rules are always included while parsing the tree as it is only
245 // on shard one hence will be deduped correctly.
Add(const Lhs & lhs,const std::vector<Nonterm> & rhs,const int shard)246 Nonterm Ir::Add(const Lhs& lhs, const std::vector<Nonterm>& rhs,
247                 const int shard) {
248   // Add a new unary rule.
249   if (rhs.size() == 1) {
250     return Add(lhs, rhs.front(), shard);
251   }
252 
253   // Add a chain of (rhs.size() - 1) binary rules.
254   Nonterm prev = rhs.front();
255   for (int i = 1; i < rhs.size() - 1; i++) {
256     prev = Add(kUnassignedNonterm, prev, rhs[i]);
257   }
258   return Add(lhs, prev, rhs.back());
259 }
260 
AddRegex(Nonterm lhs,const std::string & regex_pattern)261 Nonterm Ir::AddRegex(Nonterm lhs, const std::string& regex_pattern) {
262   lhs = DefineNonterminal(lhs);
263   regex_rules_.emplace_back(regex_pattern, lhs);
264   return lhs;
265 }
266 
AddAnnotation(const Nonterm lhs,const std::string & annotation)267 void Ir::AddAnnotation(const Nonterm lhs, const std::string& annotation) {
268   annotations_.emplace_back(annotation, lhs);
269 }
270 
271 // Serializes the terminal rules table.
SerializeTerminalRules(RulesSetT * rules_set,std::vector<std::unique_ptr<RulesSet_::RulesT>> * rules_shards) const272 void Ir::SerializeTerminalRules(
273     RulesSetT* rules_set,
274     std::vector<std::unique_ptr<RulesSet_::RulesT>>* rules_shards) const {
275   // Use common pool for all terminals.
276   struct TerminalEntry {
277     std::string terminal;
278     int set_index;
279     int index;
280     Ir::LhsSet lhs_set;
281   };
282   std::vector<TerminalEntry> terminal_rules;
283 
284   // Merge all terminals into a common pool.
285   // We want to use one common pool, but still need to track which set they
286   // belong to.
287   std::vector<const std::unordered_map<std::string, Ir::LhsSet>*>
288       terminal_rules_sets;
289   std::vector<RulesSet_::Rules_::TerminalRulesMapT*> rules_maps;
290   terminal_rules_sets.reserve(2 * shards_.size());
291   rules_maps.reserve(terminal_rules_sets.size());
292   for (int i = 0; i < shards_.size(); i++) {
293     terminal_rules_sets.push_back(&shards_[i].terminal_rules);
294     terminal_rules_sets.push_back(&shards_[i].lowercase_terminal_rules);
295     rules_shards->at(i)->terminal_rules.reset(
296         new RulesSet_::Rules_::TerminalRulesMapT());
297     rules_shards->at(i)->lowercase_terminal_rules.reset(
298         new RulesSet_::Rules_::TerminalRulesMapT());
299     rules_maps.push_back(rules_shards->at(i)->terminal_rules.get());
300     rules_maps.push_back(rules_shards->at(i)->lowercase_terminal_rules.get());
301   }
302   for (int i = 0; i < terminal_rules_sets.size(); i++) {
303     for (const auto& it : *terminal_rules_sets[i]) {
304       terminal_rules.push_back(
305           TerminalEntry{it.first, /*set_index=*/i, /*index=*/0, it.second});
306     }
307   }
308   std::stable_sort(terminal_rules.begin(), terminal_rules.end(),
309                    [](const TerminalEntry& a, const TerminalEntry& b) {
310                      return a.terminal < b.terminal;
311                    });
312 
313   // Index the entries in sorted order.
314   std::vector<int> index(terminal_rules_sets.size(), 0);
315   for (int i = 0; i < terminal_rules.size(); i++) {
316     terminal_rules[i].index = index[terminal_rules[i].set_index]++;
317   }
318 
319   // We store the terminal strings sorted into a buffer and keep offsets into
320   // that buffer. In this way, we don't need extra space for terminals that are
321   // suffixes of others.
322 
323   // Find terminals that are a suffix of others, O(n^2) algorithm.
324   constexpr int kInvalidIndex = -1;
325   std::vector<int> suffix(terminal_rules.size(), kInvalidIndex);
326   for (int i = 0; i < terminal_rules.size(); i++) {
327     const StringPiece terminal(terminal_rules[i].terminal);
328 
329     // Check whether the ith terminal is a suffix of another.
330     for (int j = 0; j < terminal_rules.size(); j++) {
331       if (i == j) {
332         continue;
333       }
334       if (StringPiece(terminal_rules[j].terminal).EndsWith(terminal)) {
335         // If both terminals are the same keep the first.
336         // This avoids cyclic dependencies.
337         // This can happen if multiple shards use same terminals, such as
338         // punctuation.
339         if (terminal_rules[j].terminal.size() == terminal.size() && j < i) {
340           continue;
341         }
342         suffix[i] = j;
343         break;
344       }
345     }
346   }
347 
348   rules_set->terminals = "";
349 
350   for (int i = 0; i < terminal_rules_sets.size(); i++) {
351     rules_maps[i]->terminal_offsets.resize(terminal_rules_sets[i]->size());
352     rules_maps[i]->max_terminal_length = 0;
353     rules_maps[i]->min_terminal_length = std::numeric_limits<int>::max();
354   }
355 
356   for (int i = 0; i < terminal_rules.size(); i++) {
357     const TerminalEntry& entry = terminal_rules[i];
358 
359     // Update bounds.
360     rules_maps[entry.set_index]->min_terminal_length =
361         std::min(rules_maps[entry.set_index]->min_terminal_length,
362                  static_cast<int>(entry.terminal.size()));
363     rules_maps[entry.set_index]->max_terminal_length =
364         std::max(rules_maps[entry.set_index]->max_terminal_length,
365                  static_cast<int>(entry.terminal.size()));
366 
367     // Only include terminals that are not suffixes of others.
368     if (suffix[i] != kInvalidIndex) {
369       continue;
370     }
371 
372     rules_maps[entry.set_index]->terminal_offsets[entry.index] =
373         rules_set->terminals.length();
374     rules_set->terminals += entry.terminal + '\0';
375   }
376 
377   // Store just an offset into the existing terminal data for the terminals
378   // that are suffixes of others.
379   for (int i = 0; i < terminal_rules.size(); i++) {
380     int canonical_index = i;
381     if (suffix[canonical_index] == kInvalidIndex) {
382       continue;
383     }
384 
385     // Find the overlapping string that was included in the data.
386     while (suffix[canonical_index] != kInvalidIndex) {
387       canonical_index = suffix[canonical_index];
388     }
389 
390     const TerminalEntry& entry = terminal_rules[i];
391     const TerminalEntry& canonical_entry = terminal_rules[canonical_index];
392 
393     // The offset is the offset of the overlapping string and the offset within
394     // that string.
395     rules_maps[entry.set_index]->terminal_offsets[entry.index] =
396         rules_maps[canonical_entry.set_index]
397             ->terminal_offsets[canonical_entry.index] +
398         (canonical_entry.terminal.length() - entry.terminal.length());
399   }
400 
401   for (const TerminalEntry& entry : terminal_rules) {
402     rules_maps[entry.set_index]->lhs_set_index.push_back(
403         AddLhsSet(entry.lhs_set, rules_set));
404   }
405 }
406 
Serialize(const bool include_debug_information,RulesSetT * output) const407 void Ir::Serialize(const bool include_debug_information,
408                    RulesSetT* output) const {
409   // Add information about predefined nonterminal classes.
410   output->nonterminals.reset(new RulesSet_::NonterminalsT);
411   output->nonterminals->start_nt = GetNonterminalForName(kStartNonterm);
412   output->nonterminals->end_nt = GetNonterminalForName(kEndNonterm);
413   output->nonterminals->wordbreak_nt = GetNonterminalForName(kWordBreakNonterm);
414   output->nonterminals->token_nt = GetNonterminalForName(kTokenNonterm);
415   output->nonterminals->uppercase_token_nt =
416       GetNonterminalForName(kUppercaseTokenNonterm);
417   output->nonterminals->digits_nt = GetNonterminalForName(kDigitsNonterm);
418   for (int i = 1; i <= kMaxNDigitsNontermLength; i++) {
419     if (const Nonterm n_digits_nt =
420             GetNonterminalForName(strings::StringPrintf(kNDigitsNonterm, i))) {
421       output->nonterminals->n_digits_nt.resize(i, kUnassignedNonterm);
422       output->nonterminals->n_digits_nt[i - 1] = n_digits_nt;
423     }
424   }
425   for (const auto& [annotation, annotation_nt] : annotations_) {
426     output->nonterminals->annotation_nt.emplace_back(
427         new RulesSet_::Nonterminals_::AnnotationNtEntryT);
428     output->nonterminals->annotation_nt.back()->key = annotation;
429     output->nonterminals->annotation_nt.back()->value = annotation_nt;
430   }
431   SortForBinarySearchLookup(&output->nonterminals->annotation_nt);
432 
433   if (include_debug_information) {
434     output->debug_information.reset(new RulesSet_::DebugInformationT);
435     // Keep original non-terminal names.
436     for (const auto& it : nonterminal_names_) {
437       output->debug_information->nonterminal_names.emplace_back(
438           new RulesSet_::DebugInformation_::NonterminalNamesEntryT);
439       output->debug_information->nonterminal_names.back()->key = it.first;
440       output->debug_information->nonterminal_names.back()->value = it.second;
441     }
442     SortForBinarySearchLookup(&output->debug_information->nonterminal_names);
443   }
444 
445   // Add regex rules.
446   std::unique_ptr<ZlibCompressor> compressor = ZlibCompressor::Instance();
447   for (auto [pattern, lhs] : regex_rules_) {
448     output->regex_annotator.emplace_back(new RulesSet_::RegexAnnotatorT);
449     output->regex_annotator.back()->compressed_pattern.reset(
450         new CompressedBufferT);
451     compressor->Compress(
452         pattern, output->regex_annotator.back()->compressed_pattern.get());
453     output->regex_annotator.back()->nonterminal = lhs;
454   }
455 
456   // Serialize the unary and binary rules.
457   for (int i = 0; i < shards_.size(); i++) {
458     output->rules.emplace_back(std::make_unique<RulesSet_::RulesT>());
459     RulesSet_::RulesT* rules = output->rules.back().get();
460     for (const Locale& shard_locale : locale_shard_map_.GetLocales(i)) {
461       if (shard_locale.IsValid()) {
462         // Check if the language is set to all i.e. '*' which is a special, to
463         // make it consistent with device side parser here instead of filling
464         // the all locale leave the language tag list empty
465         rules->locale.emplace_back(
466             std::make_unique<libtextclassifier3::LanguageTagT>());
467         libtextclassifier3::LanguageTagT* language_tag =
468             rules->locale.back().get();
469         language_tag->language = shard_locale.Language();
470         language_tag->region = shard_locale.Region();
471         language_tag->script = shard_locale.Script();
472       }
473     }
474 
475     // Serialize the unary rules.
476     SerializeUnaryRulesShard(shards_[i].unary_rules, output, rules);
477     // Serialize the binary rules.
478     SerializeBinaryRulesShard(shards_[i].binary_rules, output, rules);
479   }
480   // Serialize the terminal rules.
481   // We keep the rules separate by shard but merge the actual terminals into
482   // one shared string pool to most effectively exploit reuse.
483   SerializeTerminalRules(output, &output->rules);
484 }
485 
SerializeAsFlatbuffer(const bool include_debug_information) const486 std::string Ir::SerializeAsFlatbuffer(
487     const bool include_debug_information) const {
488   RulesSetT output;
489   Serialize(include_debug_information, &output);
490   flatbuffers::FlatBufferBuilder builder;
491   builder.Finish(RulesSet::Pack(builder, &output));
492   return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
493                      builder.GetSize());
494 }
495 
496 }  // namespace libtextclassifier3::grammar
497