xref: /aosp_15_r20/external/libtextclassifier/native/utils/grammar/parsing/derivation.h (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_DERIVATION_H_
18*993b0882SAndroid Build Coastguard Worker #define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_DERIVATION_H_
19*993b0882SAndroid Build Coastguard Worker 
20*993b0882SAndroid Build Coastguard Worker #include <vector>
#include <algorithm>
#include <type_traits>
#include <vector>

#include "utils/grammar/parsing/parse-tree.h"
24*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3::grammar {
25*993b0882SAndroid Build Coastguard Worker 
26*993b0882SAndroid Build Coastguard Worker // A parse tree for a root rule.
27*993b0882SAndroid Build Coastguard Worker struct Derivation {
28*993b0882SAndroid Build Coastguard Worker   const ParseTree* parse_tree;
29*993b0882SAndroid Build Coastguard Worker   int64 rule_id;
30*993b0882SAndroid Build Coastguard Worker 
31*993b0882SAndroid Build Coastguard Worker   // Checks that all assertions are fulfilled.
32*993b0882SAndroid Build Coastguard Worker   bool IsValid() const;
GetRuleIdDerivation33*993b0882SAndroid Build Coastguard Worker   int64 GetRuleId() const { return rule_id; }
GetParseTreeDerivation34*993b0882SAndroid Build Coastguard Worker   const ParseTree* GetParseTree() const { return parse_tree; }
35*993b0882SAndroid Build Coastguard Worker };
36*993b0882SAndroid Build Coastguard Worker 
37*993b0882SAndroid Build Coastguard Worker // Deduplicates rule derivations by containing overlap.
38*993b0882SAndroid Build Coastguard Worker // The grammar system can output multiple candidates for optional parts.
39*993b0882SAndroid Build Coastguard Worker // For example if a rule has an optional suffix, we
40*993b0882SAndroid Build Coastguard Worker // will get two rule derivations when the suffix is present: one with and one
41*993b0882SAndroid Build Coastguard Worker // without the suffix. We therefore deduplicate by containing overlap, viz. from
42*993b0882SAndroid Build Coastguard Worker // two candidates we keep the longer one if it completely contains the shorter.
43*993b0882SAndroid Build Coastguard Worker // This factory function works with any type T that extends Derivation.
44*993b0882SAndroid Build Coastguard Worker template <typename T, typename std::enable_if<std::is_base_of<
45*993b0882SAndroid Build Coastguard Worker                           Derivation, T>::value>::type* = nullptr>
46*993b0882SAndroid Build Coastguard Worker // std::vector<T> DeduplicateDerivations(const std::vector<T>& derivations);
DeduplicateDerivations(const std::vector<T> & derivations)47*993b0882SAndroid Build Coastguard Worker std::vector<T> DeduplicateDerivations(const std::vector<T>& derivations) {
48*993b0882SAndroid Build Coastguard Worker   std::vector<T> sorted_candidates = derivations;
49*993b0882SAndroid Build Coastguard Worker 
50*993b0882SAndroid Build Coastguard Worker   std::stable_sort(sorted_candidates.begin(), sorted_candidates.end(),
51*993b0882SAndroid Build Coastguard Worker                    [](const T& a, const T& b) {
52*993b0882SAndroid Build Coastguard Worker                      // Sort by id.
53*993b0882SAndroid Build Coastguard Worker                      if (a.GetRuleId() != b.GetRuleId()) {
54*993b0882SAndroid Build Coastguard Worker                        return a.GetRuleId() < b.GetRuleId();
55*993b0882SAndroid Build Coastguard Worker                      }
56*993b0882SAndroid Build Coastguard Worker 
57*993b0882SAndroid Build Coastguard Worker                      // Sort by increasing start.
58*993b0882SAndroid Build Coastguard Worker                      if (a.GetParseTree()->codepoint_span.first !=
59*993b0882SAndroid Build Coastguard Worker                          b.GetParseTree()->codepoint_span.first) {
60*993b0882SAndroid Build Coastguard Worker                        return a.GetParseTree()->codepoint_span.first <
61*993b0882SAndroid Build Coastguard Worker                               b.GetParseTree()->codepoint_span.first;
62*993b0882SAndroid Build Coastguard Worker                      }
63*993b0882SAndroid Build Coastguard Worker 
64*993b0882SAndroid Build Coastguard Worker                      // Sort by decreasing end.
65*993b0882SAndroid Build Coastguard Worker                      return a.GetParseTree()->codepoint_span.second >
66*993b0882SAndroid Build Coastguard Worker                             b.GetParseTree()->codepoint_span.second;
67*993b0882SAndroid Build Coastguard Worker                    });
68*993b0882SAndroid Build Coastguard Worker 
69*993b0882SAndroid Build Coastguard Worker   // Deduplicate by overlap.
70*993b0882SAndroid Build Coastguard Worker   std::vector<T> result;
71*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < sorted_candidates.size(); i++) {
72*993b0882SAndroid Build Coastguard Worker     const T& candidate = sorted_candidates[i];
73*993b0882SAndroid Build Coastguard Worker     bool eliminated = false;
74*993b0882SAndroid Build Coastguard Worker 
75*993b0882SAndroid Build Coastguard Worker     // Due to the sorting above, the candidate can only be completely
76*993b0882SAndroid Build Coastguard Worker     // intersected by a match before it in the sorted order.
77*993b0882SAndroid Build Coastguard Worker     for (int j = i - 1; j >= 0; j--) {
78*993b0882SAndroid Build Coastguard Worker       if (sorted_candidates[j].rule_id != candidate.rule_id) {
79*993b0882SAndroid Build Coastguard Worker         break;
80*993b0882SAndroid Build Coastguard Worker       }
81*993b0882SAndroid Build Coastguard Worker       if (sorted_candidates[j].parse_tree->codepoint_span.first <=
82*993b0882SAndroid Build Coastguard Worker               candidate.parse_tree->codepoint_span.first &&
83*993b0882SAndroid Build Coastguard Worker           sorted_candidates[j].parse_tree->codepoint_span.second >=
84*993b0882SAndroid Build Coastguard Worker               candidate.parse_tree->codepoint_span.second) {
85*993b0882SAndroid Build Coastguard Worker         eliminated = true;
86*993b0882SAndroid Build Coastguard Worker         break;
87*993b0882SAndroid Build Coastguard Worker       }
88*993b0882SAndroid Build Coastguard Worker     }
89*993b0882SAndroid Build Coastguard Worker     if (!eliminated) {
90*993b0882SAndroid Build Coastguard Worker       result.push_back(candidate);
91*993b0882SAndroid Build Coastguard Worker     }
92*993b0882SAndroid Build Coastguard Worker   }
93*993b0882SAndroid Build Coastguard Worker   return result;
94*993b0882SAndroid Build Coastguard Worker }
95*993b0882SAndroid Build Coastguard Worker 
96*993b0882SAndroid Build Coastguard Worker // Deduplicates and validates rule derivations.
97*993b0882SAndroid Build Coastguard Worker std::vector<Derivation> ValidDeduplicatedDerivations(
98*993b0882SAndroid Build Coastguard Worker     const std::vector<Derivation>& derivations);
99*993b0882SAndroid Build Coastguard Worker 
100*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3::grammar
101*993b0882SAndroid Build Coastguard Worker 
102*993b0882SAndroid Build Coastguard Worker #endif  // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_PARSING_DERIVATION_H_
103