xref: /aosp_15_r20/external/libtextclassifier/native/annotator/datetime/regex-parser.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
#include "annotator/datetime/regex-parser.h"

#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "annotator/datetime/extractor.h"
#include "annotator/datetime/utils.h"
#include "utils/base/statusor.h"
#include "utils/calendar/calendar.h"
#include "utils/i18n/locale.h"
#include "utils/strings/split.h"
#include "utils/zlib/zlib_regex.h"
31 
32 namespace libtextclassifier3 {
Instance(const DatetimeModel * model,const UniLib * unilib,const CalendarLib * calendarlib,ZlibDecompressor * decompressor)33 std::unique_ptr<DatetimeParser> RegexDatetimeParser::Instance(
34     const DatetimeModel* model, const UniLib* unilib,
35     const CalendarLib* calendarlib, ZlibDecompressor* decompressor) {
36   std::unique_ptr<RegexDatetimeParser> result(
37       new RegexDatetimeParser(model, unilib, calendarlib, decompressor));
38   if (!result->initialized_) {
39     result.reset();
40   }
41   return result;
42 }
43 
RegexDatetimeParser(const DatetimeModel * model,const UniLib * unilib,const CalendarLib * calendarlib,ZlibDecompressor * decompressor)44 RegexDatetimeParser::RegexDatetimeParser(const DatetimeModel* model,
45                                          const UniLib* unilib,
46                                          const CalendarLib* calendarlib,
47                                          ZlibDecompressor* decompressor)
48     : unilib_(*unilib), calendarlib_(*calendarlib) {
49   initialized_ = false;
50 
51   if (model == nullptr) {
52     return;
53   }
54 
55   if (model->patterns() != nullptr) {
56     for (const DatetimeModelPattern* pattern : *model->patterns()) {
57       if (pattern->regexes()) {
58         for (const DatetimeModelPattern_::Regex* regex : *pattern->regexes()) {
59           std::unique_ptr<UniLib::RegexPattern> regex_pattern =
60               UncompressMakeRegexPattern(
61                   unilib_, regex->pattern(), regex->compressed_pattern(),
62                   model->lazy_regex_compilation(), decompressor);
63           if (!regex_pattern) {
64             TC3_LOG(ERROR) << "Couldn't create rule pattern.";
65             return;
66           }
67           rules_.push_back({std::move(regex_pattern), regex, pattern});
68           if (pattern->locales()) {
69             for (int locale : *pattern->locales()) {
70               locale_to_rules_[locale].push_back(rules_.size() - 1);
71             }
72           }
73         }
74       }
75     }
76   }
77 
78   if (model->extractors() != nullptr) {
79     for (const DatetimeModelExtractor* extractor : *model->extractors()) {
80       std::unique_ptr<UniLib::RegexPattern> regex_pattern =
81           UncompressMakeRegexPattern(
82               unilib_, extractor->pattern(), extractor->compressed_pattern(),
83               model->lazy_regex_compilation(), decompressor);
84       if (!regex_pattern) {
85         TC3_LOG(ERROR) << "Couldn't create extractor pattern";
86         return;
87       }
88       extractor_rules_.push_back(std::move(regex_pattern));
89 
90       if (extractor->locales()) {
91         for (int locale : *extractor->locales()) {
92           type_and_locale_to_extractor_rule_[extractor->extractor()][locale] =
93               extractor_rules_.size() - 1;
94         }
95       }
96     }
97   }
98 
99   if (model->locales() != nullptr) {
100     for (int i = 0; i < model->locales()->size(); ++i) {
101       locale_string_to_id_[model->locales()->Get(i)->str()] = i;
102     }
103   }
104 
105   if (model->default_locales() != nullptr) {
106     for (const int locale : *model->default_locales()) {
107       default_locale_ids_.push_back(locale);
108     }
109   }
110 
111   use_extractors_for_locating_ = model->use_extractors_for_locating();
112   generate_alternative_interpretations_when_ambiguous_ =
113       model->generate_alternative_interpretations_when_ambiguous();
114   prefer_future_for_unspecified_date_ =
115       model->prefer_future_for_unspecified_date();
116 
117   initialized_ = true;
118 }
119 
Parse(const std::string & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const LocaleList & locale_list,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end) const120 StatusOr<std::vector<DatetimeParseResultSpan>> RegexDatetimeParser::Parse(
121     const std::string& input, const int64 reference_time_ms_utc,
122     const std::string& reference_timezone, const LocaleList& locale_list,
123     ModeFlag mode, AnnotationUsecase annotation_usecase,
124     bool anchor_start_end) const {
125   return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
126                reference_time_ms_utc, reference_timezone, locale_list, mode,
127                annotation_usecase, anchor_start_end);
128 }
129 
130 StatusOr<std::vector<DatetimeParseResultSpan>>
FindSpansUsingLocales(const std::vector<int> & locale_ids,const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end,const std::string & reference_locale,std::unordered_set<int> * executed_rules) const131 RegexDatetimeParser::FindSpansUsingLocales(
132     const std::vector<int>& locale_ids, const UnicodeText& input,
133     const int64 reference_time_ms_utc, const std::string& reference_timezone,
134     ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
135     const std::string& reference_locale,
136     std::unordered_set<int>* executed_rules) const {
137   std::vector<DatetimeParseResultSpan> found_spans;
138   for (const int locale_id : locale_ids) {
139     auto rules_it = locale_to_rules_.find(locale_id);
140     if (rules_it == locale_to_rules_.end()) {
141       continue;
142     }
143 
144     for (const int rule_id : rules_it->second) {
145       // Skip rules that were already executed in previous locales.
146       if (executed_rules->find(rule_id) != executed_rules->end()) {
147         continue;
148       }
149 
150       if ((rules_[rule_id].pattern->enabled_annotation_usecases() &
151            (1 << annotation_usecase)) == 0) {
152         continue;
153       }
154 
155       if (!(rules_[rule_id].pattern->enabled_modes() & mode)) {
156         continue;
157       }
158 
159       executed_rules->insert(rule_id);
160       TC3_ASSIGN_OR_RETURN(
161           const std::vector<DatetimeParseResultSpan>& found_spans_per_rule,
162           ParseWithRule(rules_[rule_id], input, reference_time_ms_utc,
163                         reference_timezone, reference_locale, locale_id,
164                         anchor_start_end));
165       found_spans.insert(std::end(found_spans),
166                          std::begin(found_spans_per_rule),
167                          std::end(found_spans_per_rule));
168     }
169   }
170   return found_spans;
171 }
172 
Parse(const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const LocaleList & locale_list,ModeFlag mode,AnnotationUsecase annotation_usecase,bool anchor_start_end) const173 StatusOr<std::vector<DatetimeParseResultSpan>> RegexDatetimeParser::Parse(
174     const UnicodeText& input, const int64 reference_time_ms_utc,
175     const std::string& reference_timezone, const LocaleList& locale_list,
176     ModeFlag mode, AnnotationUsecase annotation_usecase,
177     bool anchor_start_end) const {
178   std::unordered_set<int> executed_rules;
179   const std::vector<int> requested_locales =
180       ParseAndExpandLocales(locale_list.GetLocaleTags());
181   TC3_ASSIGN_OR_RETURN(
182       const std::vector<DatetimeParseResultSpan>& found_spans,
183       FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc,
184                             reference_timezone, mode, annotation_usecase,
185                             anchor_start_end, locale_list.GetReferenceLocale(),
186                             &executed_rules));
187   std::vector<std::pair<DatetimeParseResultSpan, int>> indexed_found_spans;
188   indexed_found_spans.reserve(found_spans.size());
189   for (int i = 0; i < found_spans.size(); i++) {
190     indexed_found_spans.push_back({found_spans[i], i});
191   }
192 
193   // Resolve conflicts by always picking the longer span and breaking ties by
194   // selecting the earlier entry in the list for a given locale.
195   std::stable_sort(indexed_found_spans.begin(), indexed_found_spans.end(),
196                    [](const std::pair<DatetimeParseResultSpan, int>& a,
197                       const std::pair<DatetimeParseResultSpan, int>& b) {
198                      if ((a.first.span.second - a.first.span.first) !=
199                          (b.first.span.second - b.first.span.first)) {
200                        return (a.first.span.second - a.first.span.first) >
201                               (b.first.span.second - b.first.span.first);
202                      } else {
203                        return a.second < b.second;
204                      }
205                    });
206 
207   std::vector<DatetimeParseResultSpan> results;
208   std::vector<DatetimeParseResultSpan> resolved_found_spans;
209   resolved_found_spans.reserve(indexed_found_spans.size());
210   for (auto& span_index_pair : indexed_found_spans) {
211     resolved_found_spans.push_back(span_index_pair.first);
212   }
213 
214   std::set<int, std::function<bool(int, int)>> chosen_indices_set(
215       [&resolved_found_spans](int a, int b) {
216         return resolved_found_spans[a].span.first <
217                resolved_found_spans[b].span.first;
218       });
219   for (int i = 0; i < resolved_found_spans.size(); ++i) {
220     if (!DoesCandidateConflict(i, resolved_found_spans, chosen_indices_set)) {
221       chosen_indices_set.insert(i);
222       results.push_back(resolved_found_spans[i]);
223     }
224   }
225   return results;
226 }
227 
228 StatusOr<std::vector<DatetimeParseResultSpan>>
HandleParseMatch(const CompiledRule & rule,const UniLib::RegexMatcher & matcher,int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,int locale_id) const229 RegexDatetimeParser::HandleParseMatch(const CompiledRule& rule,
230                                       const UniLib::RegexMatcher& matcher,
231                                       int64 reference_time_ms_utc,
232                                       const std::string& reference_timezone,
233                                       const std::string& reference_locale,
234                                       int locale_id) const {
235   std::vector<DatetimeParseResultSpan> results;
236   int status = UniLib::RegexMatcher::kNoError;
237   const int start = matcher.Start(&status);
238   if (status != UniLib::RegexMatcher::kNoError) {
239     return Status(StatusCode::INTERNAL,
240                   "Failed to gets the start offset of the last match.");
241   }
242 
243   const int end = matcher.End(&status);
244   if (status != UniLib::RegexMatcher::kNoError) {
245     return Status(StatusCode::INTERNAL,
246                   "Failed to gets the end offset of the last match.");
247   }
248 
249   DatetimeParseResultSpan parse_result;
250   std::vector<DatetimeParseResult> alternatives;
251   if (!ExtractDatetime(rule, matcher, reference_time_ms_utc, reference_timezone,
252                        reference_locale, locale_id, &alternatives,
253                        &parse_result.span)) {
254     return Status(StatusCode::INTERNAL, "Failed to extract Datetime.");
255   }
256 
257   if (!use_extractors_for_locating_) {
258     parse_result.span = {start, end};
259   }
260 
261   if (parse_result.span.first != kInvalidIndex &&
262       parse_result.span.second != kInvalidIndex) {
263     parse_result.target_classification_score =
264         rule.pattern->target_classification_score();
265     parse_result.priority_score = rule.pattern->priority_score();
266 
267     for (DatetimeParseResult& alternative : alternatives) {
268       parse_result.data.push_back(alternative);
269     }
270   }
271   results.push_back(parse_result);
272   return results;
273 }
274 
275 StatusOr<std::vector<DatetimeParseResultSpan>>
ParseWithRule(const CompiledRule & rule,const UnicodeText & input,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,const int locale_id,bool anchor_start_end) const276 RegexDatetimeParser::ParseWithRule(const CompiledRule& rule,
277                                    const UnicodeText& input,
278                                    const int64 reference_time_ms_utc,
279                                    const std::string& reference_timezone,
280                                    const std::string& reference_locale,
281                                    const int locale_id,
282                                    bool anchor_start_end) const {
283   std::vector<DatetimeParseResultSpan> results;
284   std::unique_ptr<UniLib::RegexMatcher> matcher =
285       rule.compiled_regex->Matcher(input);
286   int status = UniLib::RegexMatcher::kNoError;
287   if (anchor_start_end) {
288     if (matcher->Matches(&status) && status == UniLib::RegexMatcher::kNoError) {
289       return HandleParseMatch(rule, *matcher, reference_time_ms_utc,
290                               reference_timezone, reference_locale, locale_id);
291     }
292   } else {
293     while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
294       TC3_ASSIGN_OR_RETURN(
295           const std::vector<DatetimeParseResultSpan>& pattern_occurrence,
296           HandleParseMatch(rule, *matcher, reference_time_ms_utc,
297                            reference_timezone, reference_locale, locale_id));
298       results.insert(std::end(results), std::begin(pattern_occurrence),
299                      std::end(pattern_occurrence));
300     }
301   }
302   return results;
303 }
304 
ParseAndExpandLocales(const std::vector<StringPiece> & locales) const305 std::vector<int> RegexDatetimeParser::ParseAndExpandLocales(
306     const std::vector<StringPiece>& locales) const {
307   std::vector<int> result;
308   for (const StringPiece& locale_str : locales) {
309     auto locale_it = locale_string_to_id_.find(locale_str.ToString());
310     if (locale_it != locale_string_to_id_.end()) {
311       result.push_back(locale_it->second);
312     }
313 
314     const Locale locale = Locale::FromBCP47(locale_str.ToString());
315     if (!locale.IsValid()) {
316       continue;
317     }
318 
319     const std::string language = locale.Language();
320     const std::string script = locale.Script();
321     const std::string region = locale.Region();
322 
323     // First, try adding *-region locale.
324     if (!region.empty()) {
325       locale_it = locale_string_to_id_.find("*-" + region);
326       if (locale_it != locale_string_to_id_.end()) {
327         result.push_back(locale_it->second);
328       }
329     }
330     // Second, try adding language-script-* locale.
331     if (!script.empty()) {
332       locale_it = locale_string_to_id_.find(language + "-" + script + "-*");
333       if (locale_it != locale_string_to_id_.end()) {
334         result.push_back(locale_it->second);
335       }
336     }
337     // Third, try adding language-* locale.
338     if (!language.empty()) {
339       locale_it = locale_string_to_id_.find(language + "-*");
340       if (locale_it != locale_string_to_id_.end()) {
341         result.push_back(locale_it->second);
342       }
343     }
344   }
345 
346   // Add the default locales if they haven't been added already.
347   const std::unordered_set<int> result_set(result.begin(), result.end());
348   for (const int default_locale_id : default_locale_ids_) {
349     if (result_set.find(default_locale_id) == result_set.end()) {
350       result.push_back(default_locale_id);
351     }
352   }
353 
354   return result;
355 }
356 
ExtractDatetime(const CompiledRule & rule,const UniLib::RegexMatcher & matcher,const int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & reference_locale,int locale_id,std::vector<DatetimeParseResult> * results,CodepointSpan * result_span) const357 bool RegexDatetimeParser::ExtractDatetime(
358     const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
359     const int64 reference_time_ms_utc, const std::string& reference_timezone,
360     const std::string& reference_locale, int locale_id,
361     std::vector<DatetimeParseResult>* results,
362     CodepointSpan* result_span) const {
363   DatetimeParsedData parse;
364   DatetimeExtractor extractor(rule, matcher, locale_id, &unilib_,
365                               extractor_rules_,
366                               type_and_locale_to_extractor_rule_);
367   if (!extractor.Extract(&parse, result_span)) {
368     return false;
369   }
370   std::vector<DatetimeParsedData> interpretations;
371   if (generate_alternative_interpretations_when_ambiguous_) {
372     FillInterpretations(parse, calendarlib_.GetGranularity(parse),
373                         &interpretations);
374   } else {
375     interpretations.push_back(parse);
376   }
377 
378   results->reserve(results->size() + interpretations.size());
379   for (const DatetimeParsedData& interpretation : interpretations) {
380     std::vector<DatetimeComponent> date_components;
381     interpretation.GetDatetimeComponents(&date_components);
382     DatetimeParseResult result;
383     // TODO(hassan): Text classifier only provides ambiguity limited to “AM/PM
384     //               which is encoded in the pair of DatetimeParseResult; both
385     //               corresponding to the same date, but one corresponding to
386     //               “AM” and the other one corresponding to “PM”.
387     //               Remove multiple DatetimeParseResult per datetime span,
388     //               once the ambiguities/DatetimeComponents are added in the
389     //               response. For Details see b/130355975
390     if (!calendarlib_.InterpretParseData(
391             interpretation, reference_time_ms_utc, reference_timezone,
392             reference_locale, prefer_future_for_unspecified_date_,
393             &(result.time_ms_utc), &(result.granularity))) {
394       return false;
395     }
396 
397     // Sort the date time units by component type.
398     std::stable_sort(date_components.begin(), date_components.end(),
399                      [](DatetimeComponent a, DatetimeComponent b) {
400                        return a.component_type > b.component_type;
401                      });
402     result.datetime_components.swap(date_components);
403     results->push_back(result);
404   }
405   return true;
406 }
407 
408 }  // namespace libtextclassifier3
409