xref: /aosp_15_r20/external/libtextclassifier/native/utils/grammar/parsing/matcher_test.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/grammar/parsing/matcher.h"
18 
19 #include <string>
20 #include <vector>
21 
22 #include "utils/base/arena.h"
23 #include "utils/grammar/rules_generated.h"
24 #include "utils/grammar/types.h"
25 #include "utils/grammar/utils/rules.h"
26 #include "utils/strings/append.h"
27 #include "utils/utf8/unilib.h"
28 #include "gmock/gmock.h"
29 #include "gtest/gtest.h"
30 
31 namespace libtextclassifier3::grammar {
32 namespace {
33 
34 using ::testing::DescribeMatcher;
35 using ::testing::ElementsAre;
36 using ::testing::ExplainMatchResult;
37 using ::testing::IsEmpty;
38 
39 struct TestMatchResult {
40   CodepointSpan codepoint_span;
41   std::string terminal;
42   std::string nonterminal;
43   int rule_id;
44 
operator <<(std::ostream & os,const TestMatchResult & match)45   friend std::ostream& operator<<(std::ostream& os,
46                                   const TestMatchResult& match) {
47     return os << "Result(rule_id=" << match.rule_id
48               << ", begin=" << match.codepoint_span.first
49               << ", end=" << match.codepoint_span.second
50               << ", terminal=" << match.terminal
51               << ", nonterminal=" << match.nonterminal << ")";
52   }
53 };
54 
55 MATCHER_P3(IsTerminal, begin, end, terminal,
56            "is terminal with begin that " +
57                DescribeMatcher<int>(begin, negation) + ", end that " +
58                DescribeMatcher<int>(end, negation) + ", value that " +
59                DescribeMatcher<std::string>(terminal, negation)) {
60   return ExplainMatchResult(CodepointSpan(begin, end), arg.codepoint_span,
61                             result_listener) &&
62          ExplainMatchResult(terminal, arg.terminal, result_listener);
63 }
64 
65 MATCHER_P3(IsNonterminal, begin, end, name,
66            "is nonterminal with begin that " +
67                DescribeMatcher<int>(begin, negation) + ", end that " +
68                DescribeMatcher<int>(end, negation) + ", name that " +
69                DescribeMatcher<std::string>(name, negation)) {
70   return ExplainMatchResult(CodepointSpan(begin, end), arg.codepoint_span,
71                             result_listener) &&
72          ExplainMatchResult(name, arg.nonterminal, result_listener);
73 }
74 
75 MATCHER_P4(IsDerivation, begin, end, name, rule_id,
76            "is derivation of rule that " +
77                DescribeMatcher<int>(rule_id, negation) + ", begin that " +
78                DescribeMatcher<int>(begin, negation) + ", end that " +
79                DescribeMatcher<int>(end, negation) + ", name that " +
80                DescribeMatcher<std::string>(name, negation)) {
81   return ExplainMatchResult(IsNonterminal(begin, end, name), arg,
82                             result_listener) &&
83          ExplainMatchResult(rule_id, arg.rule_id, result_listener);
84 }
85 
86 // Superclass of all tests.
87 class MatcherTest : public testing::Test {
88  protected:
MatcherTest()89   MatcherTest()
90       : INIT_UNILIB_FOR_TESTING(unilib_), arena_(/*block_size=*/16 << 10) {}
91 
GetNonterminalName(const RulesSet_::DebugInformation * debug_information,const Nonterm nonterminal) const92   std::string GetNonterminalName(
93       const RulesSet_::DebugInformation* debug_information,
94       const Nonterm nonterminal) const {
95     if (const RulesSet_::DebugInformation_::NonterminalNamesEntry* entry =
96             debug_information->nonterminal_names()->LookupByKey(nonterminal)) {
97       return entry->value()->str();
98     }
99     // Unnamed Nonterm.
100     return "()";
101   }
102 
GetMatchResults(const Chart<> & chart,const RulesSet_::DebugInformation * debug_information)103   std::vector<TestMatchResult> GetMatchResults(
104       const Chart<>& chart,
105       const RulesSet_::DebugInformation* debug_information) {
106     std::vector<TestMatchResult> result;
107     for (const Derivation& derivation : chart.derivations()) {
108       result.emplace_back();
109       result.back().rule_id = derivation.rule_id;
110       result.back().codepoint_span = derivation.parse_tree->codepoint_span;
111       result.back().nonterminal =
112           GetNonterminalName(debug_information, derivation.parse_tree->lhs);
113       if (derivation.parse_tree->IsTerminalRule()) {
114         result.back().terminal = derivation.parse_tree->terminal;
115       }
116     }
117     return result;
118   }
119 
120   UniLib unilib_;
121   UnsafeArena arena_;
122 };
123 
// Feeds "the quick brown fox" token by token and expects both <test> and the
// <action> rule built on top of it to be reported over the full span.
TEST_F(MatcherTest, HandlesBasicOperations) {
  // Build a small example grammar.
  const CallbackId root_callback =
      static_cast<CallbackId>(DefaultCallback::kRootRule);
  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_builder(shard_map);
  grammar_builder.Add("<test>", {"the", "quick", "brown", "fox"},
                      root_callback);
  grammar_builder.Add("<action>", {"<test>"}, root_callback);
  const std::string serialized_rules =
      grammar_builder.Finalize().SerializeAsFlatbuffer(
          /*include_debug_information=*/true);
  const RulesSet* grammar_rules =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Matcher matcher(&unilib_, grammar_rules, &arena_);

  matcher.AddTerminal(0, 1, "the");
  matcher.AddTerminal(1, 2, "quick");
  matcher.AddTerminal(2, 3, "brown");
  matcher.AddTerminal(3, 4, "fox");

  const std::vector<TestMatchResult> matches =
      GetMatchResults(matcher.chart(), grammar_rules->debug_information());
  EXPECT_THAT(matches, ElementsAre(IsNonterminal(0, 4, "<test>"),
                                   IsNonterminal(0, 4, "<action>")));
}
148 
CreateTestGrammar()149 std::string CreateTestGrammar() {
150   // Create an example grammar.
151   grammar::LocaleShardMap locale_shard_map =
152       grammar::LocaleShardMap::CreateLocaleShardMap({""});
153   Rules rules(locale_shard_map);
154 
155   // Callbacks on terminal rules.
156   rules.Add("<output_5>", {"quick"},
157             static_cast<CallbackId>(DefaultCallback::kRootRule), 6);
158   rules.Add("<output_0>", {"the"},
159             static_cast<CallbackId>(DefaultCallback::kRootRule), 1);
160 
161   // Callbacks on non-terminal rules.
162   rules.Add("<output_1>", {"the", "quick", "brown", "fox"},
163             static_cast<CallbackId>(DefaultCallback::kRootRule), 2);
164   rules.Add("<output_2>", {"the", "quick"},
165             static_cast<CallbackId>(DefaultCallback::kRootRule), 3);
166   rules.Add("<output_3>", {"brown", "fox"},
167             static_cast<CallbackId>(DefaultCallback::kRootRule), 4);
168 
169   // Now a complex thing: "the* brown fox".
170   rules.Add("<thestarbrownfox>", {"brown", "fox"},
171             static_cast<CallbackId>(DefaultCallback::kRootRule), 5);
172   rules.Add("<thestarbrownfox>", {"the", "<thestarbrownfox>"},
173             static_cast<CallbackId>(DefaultCallback::kRootRule), 5);
174 
175   return rules.Finalize().SerializeAsFlatbuffer(
176       /*include_debug_information=*/true);
177 }
178 
FindNontermForName(const RulesSet * rules,const std::string & nonterminal_name)179 Nonterm FindNontermForName(const RulesSet* rules,
180                            const std::string& nonterminal_name) {
181   for (const RulesSet_::DebugInformation_::NonterminalNamesEntry* entry :
182        *rules->debug_information()->nonterminal_names()) {
183     if (entry->value()->str() == nonterminal_name) {
184       return entry->key();
185     }
186   }
187   return kUnassignedNonterm;
188 }
189 
// Checks which derivations, and with which rule ids, are reported for the
// shared test grammar, including overlapping and non-adjacent "fox" tokens.
TEST_F(MatcherTest, HandlesDerivationsOfRules) {
  const std::string serialized_rules = CreateTestGrammar();
  const RulesSet* grammar_rules =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Matcher matcher(&unilib_, grammar_rules, &arena_);

  matcher.AddTerminal(0, 1, "the");
  matcher.AddTerminal(1, 2, "quick");
  matcher.AddTerminal(2, 3, "brown");
  matcher.AddTerminal(3, 4, "fox");
  matcher.AddTerminal(3, 5, "fox");
  matcher.AddTerminal(4, 6, "fox");  // Not adjacent to "brown".

  const std::vector<TestMatchResult> matches =
      GetMatchResults(matcher.chart(), grammar_rules->debug_information());
  EXPECT_THAT(matches,
              ElementsAre(
                  // After "the".
                  IsDerivation(0, 1, "<output_0>", 1),

                  // After "quick".
                  IsDerivation(1, 2, "<output_5>", 6),
                  IsDerivation(0, 2, "<output_2>", 3),

                  // After "brown": nothing.

                  // After "fox" ending at 4.
                  IsDerivation(0, 4, "<output_1>", 2),
                  IsDerivation(2, 4, "<output_3>", 4),
                  IsDerivation(2, 4, "<thestarbrownfox>", 5),

                  // After "fox" ending at 5.
                  IsDerivation(0, 5, "<output_1>", 2),
                  IsDerivation(2, 5, "<output_3>", 4),
                  IsDerivation(2, 5, "<thestarbrownfox>", 5)));
}
224 
// Exercises the recursive <thestarbrownfox> rule with several leading "the"
// tokens, two of which end at the same position.
TEST_F(MatcherTest, HandlesRecursiveRules) {
  const std::string serialized_rules = CreateTestGrammar();
  const RulesSet* grammar_rules =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Matcher matcher(&unilib_, grammar_rules, &arena_);

  matcher.AddTerminal(0, 1, "the");
  matcher.AddTerminal(1, 2, "the");
  matcher.AddTerminal(2, 4, "the");
  matcher.AddTerminal(3, 4, "the");
  matcher.AddTerminal(4, 5, "brown");
  matcher.AddTerminal(5, 6, "fox");  // Generates 5 of <thestarbrownfox>

  const std::vector<TestMatchResult> matches =
      GetMatchResults(matcher.chart(), grammar_rules->debug_information());
  EXPECT_THAT(matches,
              ElementsAre(IsTerminal(0, 1, "the"), IsTerminal(1, 2, "the"),
                          IsTerminal(2, 4, "the"), IsTerminal(3, 4, "the"),
                          IsNonterminal(4, 6, "<output_3>"),
                          IsNonterminal(4, 6, "<thestarbrownfox>"),
                          IsNonterminal(3, 6, "<thestarbrownfox>"),
                          IsNonterminal(2, 6, "<thestarbrownfox>"),
                          IsNonterminal(1, 6, "<thestarbrownfox>"),
                          IsNonterminal(0, 6, "<thestarbrownfox>")));
}
248 
// Mixes AddTerminal() with a hand-built parse tree passed to AddParseTree(),
// and expects the recursive rule to combine the two.
TEST_F(MatcherTest, HandlesManualAddParseTreeCalls) {
  const std::string serialized_rules = CreateTestGrammar();
  const RulesSet* grammar_rules =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Matcher matcher(&unilib_, grammar_rules, &arena_);

  // Test having the lexer call AddParseTree() instead of AddTerminal()
  matcher.AddTerminal(-4, 37, "the");
  const Nonterm the_star_brown_fox =
      FindNontermForName(grammar_rules, "<thestarbrownfox>");
  matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
      the_star_brown_fox, CodepointSpan{37, 42},
      /*match_offset=*/37, ParseTree::Type::kDefault));

  const std::vector<TestMatchResult> matches =
      GetMatchResults(matcher.chart(), grammar_rules->debug_information());
  EXPECT_THAT(matches,
              ElementsAre(IsTerminal(-4, 37, "the"),
                          IsNonterminal(-4, 42, "<thestarbrownfox>")));
}
265 
// Covers rules whose right-hand sides contain optional ("?") elements.
TEST_F(MatcherTest, HandlesOptionalRuleElements) {
  const CallbackId root_callback =
      static_cast<CallbackId>(DefaultCallback::kRootRule);
  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_builder(shard_map);
  grammar_builder.Add("<output_0>", {"a?", "b?", "c?", "d?", "e"},
                      root_callback);
  grammar_builder.Add("<output_1>", {"a", "b?", "c", "d?", "e"},
                      root_callback);
  grammar_builder.Add("<output_2>", {"a", "b?", "c", "d", "e?"},
                      root_callback);

  const std::string serialized_rules =
      grammar_builder.Finalize().SerializeAsFlatbuffer(
          /*include_debug_information=*/true);

  const RulesSet* grammar_rules =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());
  Matcher matcher(&unilib_, grammar_rules, &arena_);

  // Run the matcher on "a b c d e".
  matcher.AddTerminal(0, 1, "a");
  matcher.AddTerminal(1, 2, "b");
  matcher.AddTerminal(2, 3, "c");
  matcher.AddTerminal(3, 4, "d");
  matcher.AddTerminal(4, 5, "e");

  const std::vector<TestMatchResult> matches =
      GetMatchResults(matcher.chart(), grammar_rules->debug_information());
  EXPECT_THAT(
      matches,
      ElementsAre(
          IsNonterminal(0, 4, "<output_2>"), IsTerminal(4, 5, "e"),
          IsNonterminal(0, 5, "<output_0>"), IsNonterminal(0, 5, "<output_1>"),
          IsNonterminal(0, 5, "<output_2>"), IsNonterminal(1, 5, "<output_0>"),
          IsNonterminal(2, 5, "<output_0>"),
          IsNonterminal(3, 5, "<output_0>")));
}
300 
// Verifies that a rule added with max_whitespace_gap=0 matches when its parts
// are directly adjacent and does not match when they are separated.
TEST_F(MatcherTest, HandlesWhitespaceGapLimits) {
  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_builder(shard_map);
  grammar_builder.Add("<iata>", {"lx"});
  grammar_builder.Add("<iata>", {"aa"});
  // Require no whitespace between code and flight number.
  grammar_builder.Add(
      "<flight_number>", {"<iata>", "<4_digits>"},
      /*callback=*/static_cast<CallbackId>(DefaultCallback::kRootRule), 0,
      /*max_whitespace_gap=*/0);
  const std::string serialized_rules =
      grammar_builder.Finalize().SerializeAsFlatbuffer(
          /*include_debug_information=*/true);
  const RulesSet* grammar_rules =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());

  // Nonterminal stored at index 4 - 1 of n_digits_nt; used below for the
  // four-digit part of the flight number.
  const auto four_digits_nt =
      grammar_rules->nonterminals()->n_digits_nt()->Get(4 - 1);

  // Check that the grammar triggers on LX1138.
  {
    Matcher matcher(&unilib_, grammar_rules, &arena_);
    matcher.AddTerminal(0, 2, "LX");
    matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
        four_digits_nt, CodepointSpan{2, 6}, /*match_offset=*/2,
        ParseTree::Type::kDefault));
    const std::vector<TestMatchResult> matches =
        GetMatchResults(matcher.chart(), grammar_rules->debug_information());
    EXPECT_THAT(matches, ElementsAre(IsNonterminal(0, 6, "<flight_number>")));
  }

  // Check that the grammar doesn't trigger on LX 1138.
  {
    Matcher matcher(&unilib_, grammar_rules, &arena_);
    matcher.AddTerminal(6, 8, "LX");
    matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
        four_digits_nt, CodepointSpan{9, 13}, /*match_offset=*/8,
        ParseTree::Type::kDefault));
    const std::vector<TestMatchResult> matches =
        GetMatchResults(matcher.chart(), grammar_rules->debug_information());
    EXPECT_THAT(matches, IsEmpty());
  }
}
340 
// Verifies that terminals added with case_sensitive=true match only in their
// exact casing, while a terminal added with case_sensitive=false still matches.
TEST_F(MatcherTest, HandlesCaseSensitiveTerminals) {
  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_builder(shard_map);
  grammar_builder.Add("<iata>", {"LX"}, /*callback=*/kNoCallback, 0,
                      /*max_whitespace_gap=*/-1, /*case_sensitive=*/true);
  grammar_builder.Add("<iata>", {"AA"}, /*callback=*/kNoCallback, 0,
                      /*max_whitespace_gap=*/-1, /*case_sensitive=*/true);
  grammar_builder.Add("<iata>", {"dl"}, /*callback=*/kNoCallback, 0,
                      /*max_whitespace_gap=*/-1, /*case_sensitive=*/false);
  // Require no whitespace between code and flight number.
  grammar_builder.Add(
      "<flight_number>", {"<iata>", "<4_digits>"},
      /*callback=*/static_cast<CallbackId>(DefaultCallback::kRootRule), 0,
      /*max_whitespace_gap=*/0);
  const std::string serialized_rules =
      grammar_builder.Finalize().SerializeAsFlatbuffer(
          /*include_debug_information=*/true);
  const RulesSet* grammar_rules =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());

  // Nonterminal stored at index 4 - 1 of n_digits_nt; used below for the
  // four-digit part of the flight number.
  const auto four_digits_nt =
      grammar_rules->nonterminals()->n_digits_nt()->Get(4 - 1);

  // Check that the grammar triggers on LX1138.
  {
    Matcher matcher(&unilib_, grammar_rules, &arena_);
    matcher.AddTerminal(0, 2, "LX");
    matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
        four_digits_nt, CodepointSpan{2, 6}, /*match_offset=*/2,
        ParseTree::Type::kDefault));
    const std::vector<TestMatchResult> matches =
        GetMatchResults(matcher.chart(), grammar_rules->debug_information());
    EXPECT_THAT(matches, ElementsAre(IsNonterminal(0, 6, "<flight_number>")));
  }

  // Check that the grammar doesn't trigger on lx1138.
  {
    Matcher matcher(&unilib_, grammar_rules, &arena_);
    matcher.AddTerminal(6, 8, "lx");
    matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
        four_digits_nt, CodepointSpan{8, 12}, /*match_offset=*/8,
        ParseTree::Type::kDefault));
    EXPECT_THAT(matcher.chart().derivations(), IsEmpty());
  }

  // Check that the grammar does trigger on dl1138.
  {
    Matcher matcher(&unilib_, grammar_rules, &arena_);
    matcher.AddTerminal(12, 14, "dl");
    matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
        four_digits_nt, CodepointSpan{14, 18}, /*match_offset=*/14,
        ParseTree::Type::kDefault));
    const std::vector<TestMatchResult> matches =
        GetMatchResults(matcher.chart(), grammar_rules->debug_information());
    EXPECT_THAT(matches,
                ElementsAre(IsNonterminal(12, 18, "<flight_number>")));
  }
}
394 
// Verifies AddWithExclusion(): <flight_code> is defined as <4_digits> with
// <all_zeros> excluded, so "LX1138" yields a flight number and "LX0000" does
// not.
TEST_F(MatcherTest, HandlesExclusions) {
  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Rules grammar_builder(shard_map);

  grammar_builder.Add("<all_zeros>", {"0000"});
  grammar_builder.AddWithExclusion("<flight_code>", {"<4_digits>"},
                                   /*excluded_nonterminal=*/"<all_zeros>");
  grammar_builder.Add("<iata>", {"lx"});
  grammar_builder.Add("<iata>", {"aa"});
  grammar_builder.Add("<iata>", {"dl"});
  // A flight number is an airline code followed by a flight code. Unlike the
  // sibling tests, no max_whitespace_gap argument is passed for this rule.
  grammar_builder.Add("<flight_number>", {"<iata>", "<flight_code>"},
                      static_cast<CallbackId>(DefaultCallback::kRootRule));
  const std::string serialized_rules =
      grammar_builder.Finalize().SerializeAsFlatbuffer(
          /*include_debug_information=*/true);
  const RulesSet* grammar_rules =
      flatbuffers::GetRoot<RulesSet>(serialized_rules.data());

  // Nonterminal stored at index 4 - 1 of n_digits_nt; used below for the
  // four-digit part of the flight number.
  const auto four_digits_nt =
      grammar_rules->nonterminals()->n_digits_nt()->Get(4 - 1);

  // Check that the grammar triggers on LX1138.
  {
    Matcher matcher(&unilib_, grammar_rules, &arena_);
    matcher.AddTerminal(0, 2, "LX");
    matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
        four_digits_nt, CodepointSpan{2, 6}, /*match_offset=*/2,
        ParseTree::Type::kDefault));
    matcher.Finish();
    const std::vector<TestMatchResult> matches =
        GetMatchResults(matcher.chart(), grammar_rules->debug_information());
    EXPECT_THAT(matches, ElementsAre(IsNonterminal(0, 6, "<flight_number>")));
  }

  // Check that the grammar doesn't trigger on LX0000.
  {
    Matcher matcher(&unilib_, grammar_rules, &arena_);
    matcher.AddTerminal(6, 8, "LX");
    matcher.AddTerminal(8, 12, "0000");
    matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
        four_digits_nt, CodepointSpan{8, 12}, /*match_offset=*/8,
        ParseTree::Type::kDefault));
    matcher.Finish();
    EXPECT_THAT(matcher.chart().derivations(), IsEmpty());
  }
}
439 
440 }  // namespace
441 }  // namespace libtextclassifier3::grammar
442