1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/grammar/parsing/matcher.h"
18
19 #include <string>
20 #include <vector>
21
22 #include "utils/base/arena.h"
23 #include "utils/grammar/rules_generated.h"
24 #include "utils/grammar/types.h"
25 #include "utils/grammar/utils/rules.h"
26 #include "utils/strings/append.h"
27 #include "utils/utf8/unilib.h"
28 #include "gmock/gmock.h"
29 #include "gtest/gtest.h"
30
31 namespace libtextclassifier3::grammar {
32 namespace {
33
34 using ::testing::DescribeMatcher;
35 using ::testing::ElementsAre;
36 using ::testing::ExplainMatchResult;
37 using ::testing::IsEmpty;
38
39 struct TestMatchResult {
40 CodepointSpan codepoint_span;
41 std::string terminal;
42 std::string nonterminal;
43 int rule_id;
44
operator <<(std::ostream & os,const TestMatchResult & match)45 friend std::ostream& operator<<(std::ostream& os,
46 const TestMatchResult& match) {
47 return os << "Result(rule_id=" << match.rule_id
48 << ", begin=" << match.codepoint_span.first
49 << ", end=" << match.codepoint_span.second
50 << ", terminal=" << match.terminal
51 << ", nonterminal=" << match.nonterminal << ")";
52 }
53 };
54
55 MATCHER_P3(IsTerminal, begin, end, terminal,
56 "is terminal with begin that " +
57 DescribeMatcher<int>(begin, negation) + ", end that " +
58 DescribeMatcher<int>(end, negation) + ", value that " +
59 DescribeMatcher<std::string>(terminal, negation)) {
60 return ExplainMatchResult(CodepointSpan(begin, end), arg.codepoint_span,
61 result_listener) &&
62 ExplainMatchResult(terminal, arg.terminal, result_listener);
63 }
64
65 MATCHER_P3(IsNonterminal, begin, end, name,
66 "is nonterminal with begin that " +
67 DescribeMatcher<int>(begin, negation) + ", end that " +
68 DescribeMatcher<int>(end, negation) + ", name that " +
69 DescribeMatcher<std::string>(name, negation)) {
70 return ExplainMatchResult(CodepointSpan(begin, end), arg.codepoint_span,
71 result_listener) &&
72 ExplainMatchResult(name, arg.nonterminal, result_listener);
73 }
74
75 MATCHER_P4(IsDerivation, begin, end, name, rule_id,
76 "is derivation of rule that " +
77 DescribeMatcher<int>(rule_id, negation) + ", begin that " +
78 DescribeMatcher<int>(begin, negation) + ", end that " +
79 DescribeMatcher<int>(end, negation) + ", name that " +
80 DescribeMatcher<std::string>(name, negation)) {
81 return ExplainMatchResult(IsNonterminal(begin, end, name), arg,
82 result_listener) &&
83 ExplainMatchResult(rule_id, arg.rule_id, result_listener);
84 }
85
86 // Superclass of all tests.
87 class MatcherTest : public testing::Test {
88 protected:
MatcherTest()89 MatcherTest()
90 : INIT_UNILIB_FOR_TESTING(unilib_), arena_(/*block_size=*/16 << 10) {}
91
GetNonterminalName(const RulesSet_::DebugInformation * debug_information,const Nonterm nonterminal) const92 std::string GetNonterminalName(
93 const RulesSet_::DebugInformation* debug_information,
94 const Nonterm nonterminal) const {
95 if (const RulesSet_::DebugInformation_::NonterminalNamesEntry* entry =
96 debug_information->nonterminal_names()->LookupByKey(nonterminal)) {
97 return entry->value()->str();
98 }
99 // Unnamed Nonterm.
100 return "()";
101 }
102
GetMatchResults(const Chart<> & chart,const RulesSet_::DebugInformation * debug_information)103 std::vector<TestMatchResult> GetMatchResults(
104 const Chart<>& chart,
105 const RulesSet_::DebugInformation* debug_information) {
106 std::vector<TestMatchResult> result;
107 for (const Derivation& derivation : chart.derivations()) {
108 result.emplace_back();
109 result.back().rule_id = derivation.rule_id;
110 result.back().codepoint_span = derivation.parse_tree->codepoint_span;
111 result.back().nonterminal =
112 GetNonterminalName(debug_information, derivation.parse_tree->lhs);
113 if (derivation.parse_tree->IsTerminalRule()) {
114 result.back().terminal = derivation.parse_tree->terminal;
115 }
116 }
117 return result;
118 }
119
120 UniLib unilib_;
121 UnsafeArena arena_;
122 };
123
TEST_F(MatcherTest,HandlesBasicOperations)124 TEST_F(MatcherTest, HandlesBasicOperations) {
125 // Create an example grammar.
126 grammar::LocaleShardMap locale_shard_map =
127 grammar::LocaleShardMap::CreateLocaleShardMap({""});
128 Rules rules(locale_shard_map);
129 rules.Add("<test>", {"the", "quick", "brown", "fox"},
130 static_cast<CallbackId>(DefaultCallback::kRootRule));
131 rules.Add("<action>", {"<test>"},
132 static_cast<CallbackId>(DefaultCallback::kRootRule));
133 const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer(
134 /*include_debug_information=*/true);
135 const RulesSet* rules_set =
136 flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
137 Matcher matcher(&unilib_, rules_set, &arena_);
138
139 matcher.AddTerminal(0, 1, "the");
140 matcher.AddTerminal(1, 2, "quick");
141 matcher.AddTerminal(2, 3, "brown");
142 matcher.AddTerminal(3, 4, "fox");
143
144 EXPECT_THAT(GetMatchResults(matcher.chart(), rules_set->debug_information()),
145 ElementsAre(IsNonterminal(0, 4, "<test>"),
146 IsNonterminal(0, 4, "<action>")));
147 }
148
CreateTestGrammar()149 std::string CreateTestGrammar() {
150 // Create an example grammar.
151 grammar::LocaleShardMap locale_shard_map =
152 grammar::LocaleShardMap::CreateLocaleShardMap({""});
153 Rules rules(locale_shard_map);
154
155 // Callbacks on terminal rules.
156 rules.Add("<output_5>", {"quick"},
157 static_cast<CallbackId>(DefaultCallback::kRootRule), 6);
158 rules.Add("<output_0>", {"the"},
159 static_cast<CallbackId>(DefaultCallback::kRootRule), 1);
160
161 // Callbacks on non-terminal rules.
162 rules.Add("<output_1>", {"the", "quick", "brown", "fox"},
163 static_cast<CallbackId>(DefaultCallback::kRootRule), 2);
164 rules.Add("<output_2>", {"the", "quick"},
165 static_cast<CallbackId>(DefaultCallback::kRootRule), 3);
166 rules.Add("<output_3>", {"brown", "fox"},
167 static_cast<CallbackId>(DefaultCallback::kRootRule), 4);
168
169 // Now a complex thing: "the* brown fox".
170 rules.Add("<thestarbrownfox>", {"brown", "fox"},
171 static_cast<CallbackId>(DefaultCallback::kRootRule), 5);
172 rules.Add("<thestarbrownfox>", {"the", "<thestarbrownfox>"},
173 static_cast<CallbackId>(DefaultCallback::kRootRule), 5);
174
175 return rules.Finalize().SerializeAsFlatbuffer(
176 /*include_debug_information=*/true);
177 }
178
FindNontermForName(const RulesSet * rules,const std::string & nonterminal_name)179 Nonterm FindNontermForName(const RulesSet* rules,
180 const std::string& nonterminal_name) {
181 for (const RulesSet_::DebugInformation_::NonterminalNamesEntry* entry :
182 *rules->debug_information()->nonterminal_names()) {
183 if (entry->value()->str() == nonterminal_name) {
184 return entry->key();
185 }
186 }
187 return kUnassignedNonterm;
188 }
189
TEST_F(MatcherTest,HandlesDerivationsOfRules)190 TEST_F(MatcherTest, HandlesDerivationsOfRules) {
191 const std::string rules_buffer = CreateTestGrammar();
192 const RulesSet* rules_set =
193 flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
194 Matcher matcher(&unilib_, rules_set, &arena_);
195
196 matcher.AddTerminal(0, 1, "the");
197 matcher.AddTerminal(1, 2, "quick");
198 matcher.AddTerminal(2, 3, "brown");
199 matcher.AddTerminal(3, 4, "fox");
200 matcher.AddTerminal(3, 5, "fox");
201 matcher.AddTerminal(4, 6, "fox"); // Not adjacent to "brown".
202
203 EXPECT_THAT(GetMatchResults(matcher.chart(), rules_set->debug_information()),
204 ElementsAre(
205 // the
206 IsDerivation(0, 1, "<output_0>", 1),
207
208 // quick
209 IsDerivation(1, 2, "<output_5>", 6),
210 IsDerivation(0, 2, "<output_2>", 3),
211
212 // brown
213
214 // fox
215 IsDerivation(0, 4, "<output_1>", 2),
216 IsDerivation(2, 4, "<output_3>", 4),
217 IsDerivation(2, 4, "<thestarbrownfox>", 5),
218
219 // fox
220 IsDerivation(0, 5, "<output_1>", 2),
221 IsDerivation(2, 5, "<output_3>", 4),
222 IsDerivation(2, 5, "<thestarbrownfox>", 5)));
223 }
224
TEST_F(MatcherTest,HandlesRecursiveRules)225 TEST_F(MatcherTest, HandlesRecursiveRules) {
226 const std::string rules_buffer = CreateTestGrammar();
227 const RulesSet* rules_set =
228 flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
229 Matcher matcher(&unilib_, rules_set, &arena_);
230
231 matcher.AddTerminal(0, 1, "the");
232 matcher.AddTerminal(1, 2, "the");
233 matcher.AddTerminal(2, 4, "the");
234 matcher.AddTerminal(3, 4, "the");
235 matcher.AddTerminal(4, 5, "brown");
236 matcher.AddTerminal(5, 6, "fox"); // Generates 5 of <thestarbrownfox>
237
238 EXPECT_THAT(GetMatchResults(matcher.chart(), rules_set->debug_information()),
239 ElementsAre(IsTerminal(0, 1, "the"), IsTerminal(1, 2, "the"),
240 IsTerminal(2, 4, "the"), IsTerminal(3, 4, "the"),
241 IsNonterminal(4, 6, "<output_3>"),
242 IsNonterminal(4, 6, "<thestarbrownfox>"),
243 IsNonterminal(3, 6, "<thestarbrownfox>"),
244 IsNonterminal(2, 6, "<thestarbrownfox>"),
245 IsNonterminal(1, 6, "<thestarbrownfox>"),
246 IsNonterminal(0, 6, "<thestarbrownfox>")));
247 }
248
TEST_F(MatcherTest,HandlesManualAddParseTreeCalls)249 TEST_F(MatcherTest, HandlesManualAddParseTreeCalls) {
250 const std::string rules_buffer = CreateTestGrammar();
251 const RulesSet* rules_set =
252 flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
253 Matcher matcher(&unilib_, rules_set, &arena_);
254
255 // Test having the lexer call AddParseTree() instead of AddTerminal()
256 matcher.AddTerminal(-4, 37, "the");
257 matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
258 FindNontermForName(rules_set, "<thestarbrownfox>"), CodepointSpan{37, 42},
259 /*match_offset=*/37, ParseTree::Type::kDefault));
260
261 EXPECT_THAT(GetMatchResults(matcher.chart(), rules_set->debug_information()),
262 ElementsAre(IsTerminal(-4, 37, "the"),
263 IsNonterminal(-4, 42, "<thestarbrownfox>")));
264 }
265
TEST_F(MatcherTest,HandlesOptionalRuleElements)266 TEST_F(MatcherTest, HandlesOptionalRuleElements) {
267 grammar::LocaleShardMap locale_shard_map =
268 grammar::LocaleShardMap::CreateLocaleShardMap({""});
269 Rules rules(locale_shard_map);
270 rules.Add("<output_0>", {"a?", "b?", "c?", "d?", "e"},
271 static_cast<CallbackId>(DefaultCallback::kRootRule));
272 rules.Add("<output_1>", {"a", "b?", "c", "d?", "e"},
273 static_cast<CallbackId>(DefaultCallback::kRootRule));
274 rules.Add("<output_2>", {"a", "b?", "c", "d", "e?"},
275 static_cast<CallbackId>(DefaultCallback::kRootRule));
276
277 const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer(
278 /*include_debug_information=*/true);
279
280 const RulesSet* rules_set =
281 flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
282 Matcher matcher(&unilib_, rules_set, &arena_);
283
284 // Run the matcher on "a b c d e".
285 matcher.AddTerminal(0, 1, "a");
286 matcher.AddTerminal(1, 2, "b");
287 matcher.AddTerminal(2, 3, "c");
288 matcher.AddTerminal(3, 4, "d");
289 matcher.AddTerminal(4, 5, "e");
290
291 EXPECT_THAT(
292 GetMatchResults(matcher.chart(), rules_set->debug_information()),
293 ElementsAre(
294 IsNonterminal(0, 4, "<output_2>"), IsTerminal(4, 5, "e"),
295 IsNonterminal(0, 5, "<output_0>"), IsNonterminal(0, 5, "<output_1>"),
296 IsNonterminal(0, 5, "<output_2>"), IsNonterminal(1, 5, "<output_0>"),
297 IsNonterminal(2, 5, "<output_0>"),
298 IsNonterminal(3, 5, "<output_0>")));
299 }
300
TEST_F(MatcherTest,HandlesWhitespaceGapLimits)301 TEST_F(MatcherTest, HandlesWhitespaceGapLimits) {
302 grammar::LocaleShardMap locale_shard_map =
303 grammar::LocaleShardMap::CreateLocaleShardMap({""});
304 Rules rules(locale_shard_map);
305 rules.Add("<iata>", {"lx"});
306 rules.Add("<iata>", {"aa"});
307 // Require no whitespace between code and flight number.
308 rules.Add("<flight_number>", {"<iata>", "<4_digits>"},
309 /*callback=*/static_cast<CallbackId>(DefaultCallback::kRootRule), 0,
310 /*max_whitespace_gap=*/0);
311 const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer(
312 /*include_debug_information=*/true);
313 const RulesSet* rules_set =
314 flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
315
316 // Check that the grammar triggers on LX1138.
317 {
318 Matcher matcher(&unilib_, rules_set, &arena_);
319 matcher.AddTerminal(0, 2, "LX");
320 matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
321 rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
322 CodepointSpan{2, 6}, /*match_offset=*/2, ParseTree::Type::kDefault));
323 EXPECT_THAT(
324 GetMatchResults(matcher.chart(), rules_set->debug_information()),
325 ElementsAre(IsNonterminal(0, 6, "<flight_number>")));
326 }
327
328 // Check that the grammar doesn't trigger on LX 1138.
329 {
330 Matcher matcher(&unilib_, rules_set, &arena_);
331 matcher.AddTerminal(6, 8, "LX");
332 matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
333 rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
334 CodepointSpan{9, 13}, /*match_offset=*/8, ParseTree::Type::kDefault));
335 EXPECT_THAT(
336 GetMatchResults(matcher.chart(), rules_set->debug_information()),
337 IsEmpty());
338 }
339 }
340
TEST_F(MatcherTest,HandlesCaseSensitiveTerminals)341 TEST_F(MatcherTest, HandlesCaseSensitiveTerminals) {
342 grammar::LocaleShardMap locale_shard_map =
343 grammar::LocaleShardMap::CreateLocaleShardMap({""});
344 Rules rules(locale_shard_map);
345 rules.Add("<iata>", {"LX"}, /*callback=*/kNoCallback, 0,
346 /*max_whitespace_gap*/ -1, /*case_sensitive=*/true);
347 rules.Add("<iata>", {"AA"}, /*callback=*/kNoCallback, 0,
348 /*max_whitespace_gap*/ -1, /*case_sensitive=*/true);
349 rules.Add("<iata>", {"dl"}, /*callback=*/kNoCallback, 0,
350 /*max_whitespace_gap*/ -1, /*case_sensitive=*/false);
351 // Require no whitespace between code and flight number.
352 rules.Add("<flight_number>", {"<iata>", "<4_digits>"},
353 /*callback=*/static_cast<CallbackId>(DefaultCallback::kRootRule), 0,
354 /*max_whitespace_gap=*/0);
355 const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer(
356 /*include_debug_information=*/true);
357 const RulesSet* rules_set =
358 flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
359
360 // Check that the grammar triggers on LX1138.
361 {
362 Matcher matcher(&unilib_, rules_set, &arena_);
363 matcher.AddTerminal(0, 2, "LX");
364 matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
365 rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
366 CodepointSpan{2, 6}, /*match_offset=*/2, ParseTree::Type::kDefault));
367 EXPECT_THAT(
368 GetMatchResults(matcher.chart(), rules_set->debug_information()),
369 ElementsAre(IsNonterminal(0, 6, "<flight_number>")));
370 }
371
372 // Check that the grammar doesn't trigger on lx1138.
373 {
374 Matcher matcher(&unilib_, rules_set, &arena_);
375 matcher.AddTerminal(6, 8, "lx");
376 matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
377 rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
378 CodepointSpan{8, 12}, /*match_offset=*/8, ParseTree::Type::kDefault));
379 EXPECT_THAT(matcher.chart().derivations(), IsEmpty());
380 }
381
382 // Check that the grammar does trigger on dl1138.
383 {
384 Matcher matcher(&unilib_, rules_set, &arena_);
385 matcher.AddTerminal(12, 14, "dl");
386 matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
387 rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
388 CodepointSpan{14, 18}, /*match_offset=*/14, ParseTree::Type::kDefault));
389 EXPECT_THAT(
390 GetMatchResults(matcher.chart(), rules_set->debug_information()),
391 ElementsAre(IsNonterminal(12, 18, "<flight_number>")));
392 }
393 }
394
TEST_F(MatcherTest,HandlesExclusions)395 TEST_F(MatcherTest, HandlesExclusions) {
396 grammar::LocaleShardMap locale_shard_map =
397 grammar::LocaleShardMap::CreateLocaleShardMap({""});
398 Rules rules(locale_shard_map);
399
400 rules.Add("<all_zeros>", {"0000"});
401 rules.AddWithExclusion("<flight_code>", {"<4_digits>"},
402 /*excluded_nonterminal=*/"<all_zeros>");
403 rules.Add("<iata>", {"lx"});
404 rules.Add("<iata>", {"aa"});
405 rules.Add("<iata>", {"dl"});
406 // Require no whitespace between code and flight number.
407 rules.Add("<flight_number>", {"<iata>", "<flight_code>"},
408 static_cast<CallbackId>(DefaultCallback::kRootRule));
409 const std::string rules_buffer = rules.Finalize().SerializeAsFlatbuffer(
410 /*include_debug_information=*/true);
411 const RulesSet* rules_set =
412 flatbuffers::GetRoot<RulesSet>(rules_buffer.data());
413
414 // Check that the grammar triggers on LX1138.
415 {
416 Matcher matcher(&unilib_, rules_set, &arena_);
417 matcher.AddTerminal(0, 2, "LX");
418 matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
419 rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
420 CodepointSpan{2, 6}, /*match_offset=*/2, ParseTree::Type::kDefault));
421 matcher.Finish();
422 EXPECT_THAT(
423 GetMatchResults(matcher.chart(), rules_set->debug_information()),
424 ElementsAre(IsNonterminal(0, 6, "<flight_number>")));
425 }
426
427 // Check that the grammar doesn't trigger on LX0000.
428 {
429 Matcher matcher(&unilib_, rules_set, &arena_);
430 matcher.AddTerminal(6, 8, "LX");
431 matcher.AddTerminal(8, 12, "0000");
432 matcher.AddParseTree(arena_.AllocAndInit<ParseTree>(
433 rules_set->nonterminals()->n_digits_nt()->Get(4 - 1),
434 CodepointSpan{8, 12}, /*match_offset=*/8, ParseTree::Type::kDefault));
435 matcher.Finish();
436 EXPECT_THAT(matcher.chart().derivations(), IsEmpty());
437 }
438 }
439
440 } // namespace
441 } // namespace libtextclassifier3::grammar
442