1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/datetime/regex-parser.h"
18
19 #include <algorithm>
20 #include <iterator>
21 #include <set>
22 #include <unordered_set>
23
24 #include "annotator/datetime/extractor.h"
25 #include "annotator/datetime/utils.h"
26 #include "utils/base/statusor.h"
27 #include "utils/calendar/calendar.h"
28 #include "utils/i18n/locale.h"
29 #include "utils/strings/split.h"
30 #include "utils/zlib/zlib_regex.h"
31
32 namespace libtextclassifier3 {
Instance(const DatetimeModel * model,const UniLib * unilib,const CalendarLib * calendarlib,ZlibDecompressor * decompressor)33 std::unique_ptr<DatetimeParser> RegexDatetimeParser::Instance(
34 const DatetimeModel* model, const UniLib* unilib,
35 const CalendarLib* calendarlib, ZlibDecompressor* decompressor) {
36 std::unique_ptr<RegexDatetimeParser> result(
37 new RegexDatetimeParser(model, unilib, calendarlib, decompressor));
38 if (!result->initialized_) {
39 result.reset();
40 }
41 return result;
42 }
43
RegexDatetimeParser(const DatetimeModel * model,const UniLib * unilib,const CalendarLib * calendarlib,ZlibDecompressor * decompressor)44 RegexDatetimeParser::RegexDatetimeParser(const DatetimeModel* model,
45 const UniLib* unilib,
46 const CalendarLib* calendarlib,
47 ZlibDecompressor* decompressor)
48 : unilib_(*unilib), calendarlib_(*calendarlib) {
49 initialized_ = false;
50
51 if (model == nullptr) {
52 return;
53 }
54
55 if (model->patterns() != nullptr) {
56 for (const DatetimeModelPattern* pattern : *model->patterns()) {
57 if (pattern->regexes()) {
58 for (const DatetimeModelPattern_::Regex* regex : *pattern->regexes()) {
59 std::unique_ptr<UniLib::RegexPattern> regex_pattern =
60 UncompressMakeRegexPattern(
61 unilib_, regex->pattern(), regex->compressed_pattern(),
62 model->lazy_regex_compilation(), decompressor);
63 if (!regex_pattern) {
64 TC3_LOG(ERROR) << "Couldn't create rule pattern.";
65 return;
66 }
67 rules_.push_back({std::move(regex_pattern), regex, pattern});
68 if (pattern->locales()) {
69 for (int locale : *pattern->locales()) {
70 locale_to_rules_[locale].push_back(rules_.size() - 1);
71 }
72 }
73 }
74 }
75 }
76 }
77
78 if (model->extractors() != nullptr) {
79 for (const DatetimeModelExtractor* extractor : *model->extractors()) {
80 std::unique_ptr<UniLib::RegexPattern> regex_pattern =
81 UncompressMakeRegexPattern(
82 unilib_, extractor->pattern(), extractor->compressed_pattern(),
83 model->lazy_regex_compilation(), decompressor);
84 if (!regex_pattern) {
85 TC3_LOG(ERROR) << "Couldn't create extractor pattern";
86 return;
87 }
88 extractor_rules_.push_back(std::move(regex_pattern));
89
90 if (extractor->locales()) {
91 for (int locale : *extractor->locales()) {
92 type_and_locale_to_extractor_rule_[extractor->extractor()][locale] =
93 extractor_rules_.size() - 1;
94 }
95 }
96 }
97 }
98
99 if (model->locales() != nullptr) {
100 for (int i = 0; i < model->locales()->size(); ++i) {
101 locale_string_to_id_[model->locales()->Get(i)->str()] = i;
102 }
103 }
104
105 if (model->default_locales() != nullptr) {
106 for (const int locale : *model->default_locales()) {
107 default_locale_ids_.push_back(locale);
108 }
109 }
110
111 use_extractors_for_locating_ = model->use_extractors_for_locating();
112 generate_alternative_interpretations_when_ambiguous_ =
113 model->generate_alternative_interpretations_when_ambiguous();
114 prefer_future_for_unspecified_date_ =
115 model->prefer_future_for_unspecified_date();
116
117 initialized_ = true;
118 }
119
Parse(const std::string & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const LocaleList & locale_list,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end) const120 StatusOr<std::vector<DatetimeParseResultSpan>> RegexDatetimeParser::Parse(
121 const std::string& input, const int64 reference_time_ms_utc,
122 const std::string& reference_timezone, const LocaleList& locale_list,
123 ModeFlag mode, AnnotationUsecase annotation_usecase,
124 bool anchor_start_end) const {
125 return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
126 reference_time_ms_utc, reference_timezone, locale_list, mode,
127 annotation_usecase, anchor_start_end);
128 }
129
130 StatusOr<std::vector<DatetimeParseResultSpan>>
FindSpansUsingLocales(const std::vector<int> & locale_ids,const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end,const std::string & reference_locale,std::unordered_set<int> * executed_rules) const131 RegexDatetimeParser::FindSpansUsingLocales(
132 const std::vector<int>& locale_ids, const UnicodeText& input,
133 const int64 reference_time_ms_utc, const std::string& reference_timezone,
134 ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
135 const std::string& reference_locale,
136 std::unordered_set<int>* executed_rules) const {
137 std::vector<DatetimeParseResultSpan> found_spans;
138 for (const int locale_id : locale_ids) {
139 auto rules_it = locale_to_rules_.find(locale_id);
140 if (rules_it == locale_to_rules_.end()) {
141 continue;
142 }
143
144 for (const int rule_id : rules_it->second) {
145 // Skip rules that were already executed in previous locales.
146 if (executed_rules->find(rule_id) != executed_rules->end()) {
147 continue;
148 }
149
150 if ((rules_[rule_id].pattern->enabled_annotation_usecases() &
151 (1 << annotation_usecase)) == 0) {
152 continue;
153 }
154
155 if (!(rules_[rule_id].pattern->enabled_modes() & mode)) {
156 continue;
157 }
158
159 executed_rules->insert(rule_id);
160 TC3_ASSIGN_OR_RETURN(
161 const std::vector<DatetimeParseResultSpan>& found_spans_per_rule,
162 ParseWithRule(rules_[rule_id], input, reference_time_ms_utc,
163 reference_timezone, reference_locale, locale_id,
164 anchor_start_end));
165 found_spans.insert(std::end(found_spans),
166 std::begin(found_spans_per_rule),
167 std::end(found_spans_per_rule));
168 }
169 }
170 return found_spans;
171 }
172
Parse(const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const LocaleList & locale_list,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end) const173 StatusOr<std::vector<DatetimeParseResultSpan>> RegexDatetimeParser::Parse(
174 const UnicodeText& input, const int64 reference_time_ms_utc,
175 const std::string& reference_timezone, const LocaleList& locale_list,
176 ModeFlag mode, AnnotationUsecase annotation_usecase,
177 bool anchor_start_end) const {
178 std::unordered_set<int> executed_rules;
179 const std::vector<int> requested_locales =
180 ParseAndExpandLocales(locale_list.GetLocaleTags());
181 TC3_ASSIGN_OR_RETURN(
182 const std::vector<DatetimeParseResultSpan>& found_spans,
183 FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc,
184 reference_timezone, mode, annotation_usecase,
185 anchor_start_end, locale_list.GetReferenceLocale(),
186 &executed_rules));
187 std::vector<std::pair<DatetimeParseResultSpan, int>> indexed_found_spans;
188 indexed_found_spans.reserve(found_spans.size());
189 for (int i = 0; i < found_spans.size(); i++) {
190 indexed_found_spans.push_back({found_spans[i], i});
191 }
192
193 // Resolve conflicts by always picking the longer span and breaking ties by
194 // selecting the earlier entry in the list for a given locale.
195 std::stable_sort(indexed_found_spans.begin(), indexed_found_spans.end(),
196 [](const std::pair<DatetimeParseResultSpan, int>& a,
197 const std::pair<DatetimeParseResultSpan, int>& b) {
198 if ((a.first.span.second - a.first.span.first) !=
199 (b.first.span.second - b.first.span.first)) {
200 return (a.first.span.second - a.first.span.first) >
201 (b.first.span.second - b.first.span.first);
202 } else {
203 return a.second < b.second;
204 }
205 });
206
207 std::vector<DatetimeParseResultSpan> results;
208 std::vector<DatetimeParseResultSpan> resolved_found_spans;
209 resolved_found_spans.reserve(indexed_found_spans.size());
210 for (auto& span_index_pair : indexed_found_spans) {
211 resolved_found_spans.push_back(span_index_pair.first);
212 }
213
214 std::set<int, std::function<bool(int, int)>> chosen_indices_set(
215 [&resolved_found_spans](int a, int b) {
216 return resolved_found_spans[a].span.first <
217 resolved_found_spans[b].span.first;
218 });
219 for (int i = 0; i < resolved_found_spans.size(); ++i) {
220 if (!DoesCandidateConflict(i, resolved_found_spans, chosen_indices_set)) {
221 chosen_indices_set.insert(i);
222 results.push_back(resolved_found_spans[i]);
223 }
224 }
225 return results;
226 }
227
228 StatusOr<std::vector<DatetimeParseResultSpan>>
HandleParseMatch(const CompiledRule & rule,const UniLib::RegexMatcher & matcher,int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,int locale_id) const229 RegexDatetimeParser::HandleParseMatch(const CompiledRule& rule,
230 const UniLib::RegexMatcher& matcher,
231 int64 reference_time_ms_utc,
232 const std::string& reference_timezone,
233 const std::string& reference_locale,
234 int locale_id) const {
235 std::vector<DatetimeParseResultSpan> results;
236 int status = UniLib::RegexMatcher::kNoError;
237 const int start = matcher.Start(&status);
238 if (status != UniLib::RegexMatcher::kNoError) {
239 return Status(StatusCode::INTERNAL,
240 "Failed to gets the start offset of the last match.");
241 }
242
243 const int end = matcher.End(&status);
244 if (status != UniLib::RegexMatcher::kNoError) {
245 return Status(StatusCode::INTERNAL,
246 "Failed to gets the end offset of the last match.");
247 }
248
249 DatetimeParseResultSpan parse_result;
250 std::vector<DatetimeParseResult> alternatives;
251 if (!ExtractDatetime(rule, matcher, reference_time_ms_utc, reference_timezone,
252 reference_locale, locale_id, &alternatives,
253 &parse_result.span)) {
254 return Status(StatusCode::INTERNAL, "Failed to extract Datetime.");
255 }
256
257 if (!use_extractors_for_locating_) {
258 parse_result.span = {start, end};
259 }
260
261 if (parse_result.span.first != kInvalidIndex &&
262 parse_result.span.second != kInvalidIndex) {
263 parse_result.target_classification_score =
264 rule.pattern->target_classification_score();
265 parse_result.priority_score = rule.pattern->priority_score();
266
267 for (DatetimeParseResult& alternative : alternatives) {
268 parse_result.data.push_back(alternative);
269 }
270 }
271 results.push_back(parse_result);
272 return results;
273 }
274
275 StatusOr<std::vector<DatetimeParseResultSpan>>
ParseWithRule(const CompiledRule & rule,const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,const int locale_id,bool anchor_start_end) const276 RegexDatetimeParser::ParseWithRule(const CompiledRule& rule,
277 const UnicodeText& input,
278 const int64 reference_time_ms_utc,
279 const std::string& reference_timezone,
280 const std::string& reference_locale,
281 const int locale_id,
282 bool anchor_start_end) const {
283 std::vector<DatetimeParseResultSpan> results;
284 std::unique_ptr<UniLib::RegexMatcher> matcher =
285 rule.compiled_regex->Matcher(input);
286 int status = UniLib::RegexMatcher::kNoError;
287 if (anchor_start_end) {
288 if (matcher->Matches(&status) && status == UniLib::RegexMatcher::kNoError) {
289 return HandleParseMatch(rule, *matcher, reference_time_ms_utc,
290 reference_timezone, reference_locale, locale_id);
291 }
292 } else {
293 while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
294 TC3_ASSIGN_OR_RETURN(
295 const std::vector<DatetimeParseResultSpan>& pattern_occurrence,
296 HandleParseMatch(rule, *matcher, reference_time_ms_utc,
297 reference_timezone, reference_locale, locale_id));
298 results.insert(std::end(results), std::begin(pattern_occurrence),
299 std::end(pattern_occurrence));
300 }
301 }
302 return results;
303 }
304
ParseAndExpandLocales(const std::vector<StringPiece> & locales) const305 std::vector<int> RegexDatetimeParser::ParseAndExpandLocales(
306 const std::vector<StringPiece>& locales) const {
307 std::vector<int> result;
308 for (const StringPiece& locale_str : locales) {
309 auto locale_it = locale_string_to_id_.find(locale_str.ToString());
310 if (locale_it != locale_string_to_id_.end()) {
311 result.push_back(locale_it->second);
312 }
313
314 const Locale locale = Locale::FromBCP47(locale_str.ToString());
315 if (!locale.IsValid()) {
316 continue;
317 }
318
319 const std::string language = locale.Language();
320 const std::string script = locale.Script();
321 const std::string region = locale.Region();
322
323 // First, try adding *-region locale.
324 if (!region.empty()) {
325 locale_it = locale_string_to_id_.find("*-" + region);
326 if (locale_it != locale_string_to_id_.end()) {
327 result.push_back(locale_it->second);
328 }
329 }
330 // Second, try adding language-script-* locale.
331 if (!script.empty()) {
332 locale_it = locale_string_to_id_.find(language + "-" + script + "-*");
333 if (locale_it != locale_string_to_id_.end()) {
334 result.push_back(locale_it->second);
335 }
336 }
337 // Third, try adding language-* locale.
338 if (!language.empty()) {
339 locale_it = locale_string_to_id_.find(language + "-*");
340 if (locale_it != locale_string_to_id_.end()) {
341 result.push_back(locale_it->second);
342 }
343 }
344 }
345
346 // Add the default locales if they haven't been added already.
347 const std::unordered_set<int> result_set(result.begin(), result.end());
348 for (const int default_locale_id : default_locale_ids_) {
349 if (result_set.find(default_locale_id) == result_set.end()) {
350 result.push_back(default_locale_id);
351 }
352 }
353
354 return result;
355 }
356
ExtractDatetime(const CompiledRule & rule,const UniLib::RegexMatcher & matcher,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,int locale_id,std::vector<DatetimeParseResult> * results,CodepointSpan * result_span) const357 bool RegexDatetimeParser::ExtractDatetime(
358 const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
359 const int64 reference_time_ms_utc, const std::string& reference_timezone,
360 const std::string& reference_locale, int locale_id,
361 std::vector<DatetimeParseResult>* results,
362 CodepointSpan* result_span) const {
363 DatetimeParsedData parse;
364 DatetimeExtractor extractor(rule, matcher, locale_id, &unilib_,
365 extractor_rules_,
366 type_and_locale_to_extractor_rule_);
367 if (!extractor.Extract(&parse, result_span)) {
368 return false;
369 }
370 std::vector<DatetimeParsedData> interpretations;
371 if (generate_alternative_interpretations_when_ambiguous_) {
372 FillInterpretations(parse, calendarlib_.GetGranularity(parse),
373 &interpretations);
374 } else {
375 interpretations.push_back(parse);
376 }
377
378 results->reserve(results->size() + interpretations.size());
379 for (const DatetimeParsedData& interpretation : interpretations) {
380 std::vector<DatetimeComponent> date_components;
381 interpretation.GetDatetimeComponents(&date_components);
382 DatetimeParseResult result;
383 // TODO(hassan): Text classifier only provides ambiguity limited to “AM/PM”
384 // which is encoded in the pair of DatetimeParseResult; both
385 // corresponding to the same date, but one corresponding to
386 // “AM” and the other one corresponding to “PM”.
387 // Remove multiple DatetimeParseResult per datetime span,
388 // once the ambiguities/DatetimeComponents are added in the
389 // response. For Details see b/130355975
390 if (!calendarlib_.InterpretParseData(
391 interpretation, reference_time_ms_utc, reference_timezone,
392 reference_locale, prefer_future_for_unspecified_date_,
393 &(result.time_ms_utc), &(result.granularity))) {
394 return false;
395 }
396
397 // Sort the date time units by component type.
398 std::stable_sort(date_components.begin(), date_components.end(),
399 [](DatetimeComponent a, DatetimeComponent b) {
400 return a.component_type > b.component_type;
401 });
402 result.datetime_components.swap(date_components);
403 results->push_back(result);
404 }
405 return true;
406 }
407
408 } // namespace libtextclassifier3
409