xref: /aosp_15_r20/external/libtextclassifier/native/actions/regex-actions.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #include "actions/regex-actions.h"
18*993b0882SAndroid Build Coastguard Worker 
19*993b0882SAndroid Build Coastguard Worker #include "actions/utils.h"
20*993b0882SAndroid Build Coastguard Worker #include "utils/base/logging.h"
21*993b0882SAndroid Build Coastguard Worker #include "utils/regex-match.h"
22*993b0882SAndroid Build Coastguard Worker #include "utils/utf8/unicodetext.h"
23*993b0882SAndroid Build Coastguard Worker #include "utils/zlib/zlib_regex.h"
24*993b0882SAndroid Build Coastguard Worker 
25*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
26*993b0882SAndroid Build Coastguard Worker namespace {
27*993b0882SAndroid Build Coastguard Worker 
28*993b0882SAndroid Build Coastguard Worker // Creates an annotation from a regex capturing group.
FillAnnotationFromMatchGroup(const UniLib::RegexMatcher * matcher,const RulesModel_::RuleActionSpec_::RuleCapturingGroup * group,const std::string & group_match_text,const int message_index,ActionSuggestionAnnotation * annotation)29*993b0882SAndroid Build Coastguard Worker bool FillAnnotationFromMatchGroup(
30*993b0882SAndroid Build Coastguard Worker     const UniLib::RegexMatcher* matcher,
31*993b0882SAndroid Build Coastguard Worker     const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
32*993b0882SAndroid Build Coastguard Worker     const std::string& group_match_text, const int message_index,
33*993b0882SAndroid Build Coastguard Worker     ActionSuggestionAnnotation* annotation) {
34*993b0882SAndroid Build Coastguard Worker   if (group->annotation_name() != nullptr ||
35*993b0882SAndroid Build Coastguard Worker       group->annotation_type() != nullptr) {
36*993b0882SAndroid Build Coastguard Worker     int status = UniLib::RegexMatcher::kNoError;
37*993b0882SAndroid Build Coastguard Worker     const CodepointSpan span = {matcher->Start(group->group_id(), &status),
38*993b0882SAndroid Build Coastguard Worker                                 matcher->End(group->group_id(), &status)};
39*993b0882SAndroid Build Coastguard Worker     if (status != UniLib::RegexMatcher::kNoError) {
40*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not extract span from rule capturing group.";
41*993b0882SAndroid Build Coastguard Worker       return false;
42*993b0882SAndroid Build Coastguard Worker     }
43*993b0882SAndroid Build Coastguard Worker     return FillAnnotationFromCapturingMatch(span, group, message_index,
44*993b0882SAndroid Build Coastguard Worker                                             group_match_text, annotation);
45*993b0882SAndroid Build Coastguard Worker   }
46*993b0882SAndroid Build Coastguard Worker   return true;
47*993b0882SAndroid Build Coastguard Worker }
48*993b0882SAndroid Build Coastguard Worker 
49*993b0882SAndroid Build Coastguard Worker }  // namespace
50*993b0882SAndroid Build Coastguard Worker 
InitializeRules(const RulesModel * rules,const RulesModel * low_confidence_rules,const TriggeringPreconditions * triggering_preconditions_overlay,ZlibDecompressor * decompressor)51*993b0882SAndroid Build Coastguard Worker bool RegexActions::InitializeRules(
52*993b0882SAndroid Build Coastguard Worker     const RulesModel* rules, const RulesModel* low_confidence_rules,
53*993b0882SAndroid Build Coastguard Worker     const TriggeringPreconditions* triggering_preconditions_overlay,
54*993b0882SAndroid Build Coastguard Worker     ZlibDecompressor* decompressor) {
55*993b0882SAndroid Build Coastguard Worker   if (rules != nullptr) {
56*993b0882SAndroid Build Coastguard Worker     if (!InitializeRulesModel(rules, decompressor, &rules_)) {
57*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not initialize action rules.";
58*993b0882SAndroid Build Coastguard Worker       return false;
59*993b0882SAndroid Build Coastguard Worker     }
60*993b0882SAndroid Build Coastguard Worker   }
61*993b0882SAndroid Build Coastguard Worker 
62*993b0882SAndroid Build Coastguard Worker   if (low_confidence_rules != nullptr) {
63*993b0882SAndroid Build Coastguard Worker     if (!InitializeRulesModel(low_confidence_rules, decompressor,
64*993b0882SAndroid Build Coastguard Worker                               &low_confidence_rules_)) {
65*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not initialize low confidence rules.";
66*993b0882SAndroid Build Coastguard Worker       return false;
67*993b0882SAndroid Build Coastguard Worker     }
68*993b0882SAndroid Build Coastguard Worker   }
69*993b0882SAndroid Build Coastguard Worker 
70*993b0882SAndroid Build Coastguard Worker   // Extend by rules provided by the overwrite.
71*993b0882SAndroid Build Coastguard Worker   // NOTE: The rules from the original models are *not* cleared.
72*993b0882SAndroid Build Coastguard Worker   if (triggering_preconditions_overlay != nullptr &&
73*993b0882SAndroid Build Coastguard Worker       triggering_preconditions_overlay->low_confidence_rules() != nullptr) {
74*993b0882SAndroid Build Coastguard Worker     // These rules are optionally compressed, but separately.
75*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<ZlibDecompressor> overwrite_decompressor =
76*993b0882SAndroid Build Coastguard Worker         ZlibDecompressor::Instance();
77*993b0882SAndroid Build Coastguard Worker     if (overwrite_decompressor == nullptr) {
78*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not initialze decompressor for overwrite rules.";
79*993b0882SAndroid Build Coastguard Worker       return false;
80*993b0882SAndroid Build Coastguard Worker     }
81*993b0882SAndroid Build Coastguard Worker     if (!InitializeRulesModel(
82*993b0882SAndroid Build Coastguard Worker             triggering_preconditions_overlay->low_confidence_rules(),
83*993b0882SAndroid Build Coastguard Worker             overwrite_decompressor.get(), &low_confidence_rules_)) {
84*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR)
85*993b0882SAndroid Build Coastguard Worker           << "Could not initialize low confidence rules from overwrite.";
86*993b0882SAndroid Build Coastguard Worker       return false;
87*993b0882SAndroid Build Coastguard Worker     }
88*993b0882SAndroid Build Coastguard Worker   }
89*993b0882SAndroid Build Coastguard Worker 
90*993b0882SAndroid Build Coastguard Worker   return true;
91*993b0882SAndroid Build Coastguard Worker }
92*993b0882SAndroid Build Coastguard Worker 
InitializeRulesModel(const RulesModel * rules,ZlibDecompressor * decompressor,std::vector<CompiledRule> * compiled_rules) const93*993b0882SAndroid Build Coastguard Worker bool RegexActions::InitializeRulesModel(
94*993b0882SAndroid Build Coastguard Worker     const RulesModel* rules, ZlibDecompressor* decompressor,
95*993b0882SAndroid Build Coastguard Worker     std::vector<CompiledRule>* compiled_rules) const {
96*993b0882SAndroid Build Coastguard Worker   if (rules->regex_rule() == nullptr) {
97*993b0882SAndroid Build Coastguard Worker     return true;
98*993b0882SAndroid Build Coastguard Worker   }
99*993b0882SAndroid Build Coastguard Worker   for (const RulesModel_::RegexRule* rule : *rules->regex_rule()) {
100*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
101*993b0882SAndroid Build Coastguard Worker         UncompressMakeRegexPattern(
102*993b0882SAndroid Build Coastguard Worker             unilib_, rule->pattern(), rule->compressed_pattern(),
103*993b0882SAndroid Build Coastguard Worker             rules->lazy_regex_compilation(), decompressor);
104*993b0882SAndroid Build Coastguard Worker     if (compiled_pattern == nullptr) {
105*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Failed to load rule pattern.";
106*993b0882SAndroid Build Coastguard Worker       return false;
107*993b0882SAndroid Build Coastguard Worker     }
108*993b0882SAndroid Build Coastguard Worker 
109*993b0882SAndroid Build Coastguard Worker     // Check whether there is a check on the output.
110*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<UniLib::RegexPattern> compiled_output_pattern;
111*993b0882SAndroid Build Coastguard Worker     if (rule->output_pattern() != nullptr ||
112*993b0882SAndroid Build Coastguard Worker         rule->compressed_output_pattern() != nullptr) {
113*993b0882SAndroid Build Coastguard Worker       compiled_output_pattern = UncompressMakeRegexPattern(
114*993b0882SAndroid Build Coastguard Worker           unilib_, rule->output_pattern(), rule->compressed_output_pattern(),
115*993b0882SAndroid Build Coastguard Worker           rules->lazy_regex_compilation(), decompressor);
116*993b0882SAndroid Build Coastguard Worker       if (compiled_output_pattern == nullptr) {
117*993b0882SAndroid Build Coastguard Worker         TC3_LOG(ERROR) << "Failed to load rule output pattern.";
118*993b0882SAndroid Build Coastguard Worker         return false;
119*993b0882SAndroid Build Coastguard Worker       }
120*993b0882SAndroid Build Coastguard Worker     }
121*993b0882SAndroid Build Coastguard Worker 
122*993b0882SAndroid Build Coastguard Worker     compiled_rules->emplace_back(rule, std::move(compiled_pattern),
123*993b0882SAndroid Build Coastguard Worker                                  std::move(compiled_output_pattern));
124*993b0882SAndroid Build Coastguard Worker   }
125*993b0882SAndroid Build Coastguard Worker 
126*993b0882SAndroid Build Coastguard Worker   return true;
127*993b0882SAndroid Build Coastguard Worker }
128*993b0882SAndroid Build Coastguard Worker 
IsLowConfidenceInput(const Conversation & conversation,const int num_messages,std::vector<const UniLib::RegexPattern * > * post_check_rules) const129*993b0882SAndroid Build Coastguard Worker bool RegexActions::IsLowConfidenceInput(
130*993b0882SAndroid Build Coastguard Worker     const Conversation& conversation, const int num_messages,
131*993b0882SAndroid Build Coastguard Worker     std::vector<const UniLib::RegexPattern*>* post_check_rules) const {
132*993b0882SAndroid Build Coastguard Worker   for (int i = 1; i <= num_messages; i++) {
133*993b0882SAndroid Build Coastguard Worker     const std::string& message =
134*993b0882SAndroid Build Coastguard Worker         conversation.messages[conversation.messages.size() - i].text;
135*993b0882SAndroid Build Coastguard Worker     const UnicodeText message_unicode(
136*993b0882SAndroid Build Coastguard Worker         UTF8ToUnicodeText(message, /*do_copy=*/false));
137*993b0882SAndroid Build Coastguard Worker     for (int low_confidence_rule = 0;
138*993b0882SAndroid Build Coastguard Worker          low_confidence_rule < low_confidence_rules_.size();
139*993b0882SAndroid Build Coastguard Worker          low_confidence_rule++) {
140*993b0882SAndroid Build Coastguard Worker       const CompiledRule& rule = low_confidence_rules_[low_confidence_rule];
141*993b0882SAndroid Build Coastguard Worker       const std::unique_ptr<UniLib::RegexMatcher> matcher =
142*993b0882SAndroid Build Coastguard Worker           rule.pattern->Matcher(message_unicode);
143*993b0882SAndroid Build Coastguard Worker       int status = UniLib::RegexMatcher::kNoError;
144*993b0882SAndroid Build Coastguard Worker       if (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
145*993b0882SAndroid Build Coastguard Worker         // Rule only applies to input-output pairs, so defer the check.
146*993b0882SAndroid Build Coastguard Worker         if (rule.output_pattern != nullptr) {
147*993b0882SAndroid Build Coastguard Worker           post_check_rules->push_back(rule.output_pattern.get());
148*993b0882SAndroid Build Coastguard Worker           continue;
149*993b0882SAndroid Build Coastguard Worker         }
150*993b0882SAndroid Build Coastguard Worker         return true;
151*993b0882SAndroid Build Coastguard Worker       }
152*993b0882SAndroid Build Coastguard Worker     }
153*993b0882SAndroid Build Coastguard Worker   }
154*993b0882SAndroid Build Coastguard Worker   return false;
155*993b0882SAndroid Build Coastguard Worker }
156*993b0882SAndroid Build Coastguard Worker 
FilterConfidenceOutput(const std::vector<const UniLib::RegexPattern * > & post_check_rules,std::vector<ActionSuggestion> * actions) const157*993b0882SAndroid Build Coastguard Worker bool RegexActions::FilterConfidenceOutput(
158*993b0882SAndroid Build Coastguard Worker     const std::vector<const UniLib::RegexPattern*>& post_check_rules,
159*993b0882SAndroid Build Coastguard Worker     std::vector<ActionSuggestion>* actions) const {
160*993b0882SAndroid Build Coastguard Worker   if (post_check_rules.empty() || actions->empty()) {
161*993b0882SAndroid Build Coastguard Worker     return true;
162*993b0882SAndroid Build Coastguard Worker   }
163*993b0882SAndroid Build Coastguard Worker   std::vector<ActionSuggestion> filtered_text_replies;
164*993b0882SAndroid Build Coastguard Worker   for (const ActionSuggestion& action : *actions) {
165*993b0882SAndroid Build Coastguard Worker     if (action.response_text.empty()) {
166*993b0882SAndroid Build Coastguard Worker       filtered_text_replies.push_back(action);
167*993b0882SAndroid Build Coastguard Worker       continue;
168*993b0882SAndroid Build Coastguard Worker     }
169*993b0882SAndroid Build Coastguard Worker     bool passes_post_check = true;
170*993b0882SAndroid Build Coastguard Worker     const UnicodeText text_reply_unicode(
171*993b0882SAndroid Build Coastguard Worker         UTF8ToUnicodeText(action.response_text, /*do_copy=*/false));
172*993b0882SAndroid Build Coastguard Worker     for (const UniLib::RegexPattern* post_check_rule : post_check_rules) {
173*993b0882SAndroid Build Coastguard Worker       const std::unique_ptr<UniLib::RegexMatcher> matcher =
174*993b0882SAndroid Build Coastguard Worker           post_check_rule->Matcher(text_reply_unicode);
175*993b0882SAndroid Build Coastguard Worker       if (matcher == nullptr) {
176*993b0882SAndroid Build Coastguard Worker         TC3_LOG(ERROR) << "Could not create matcher for post check rule.";
177*993b0882SAndroid Build Coastguard Worker         return false;
178*993b0882SAndroid Build Coastguard Worker       }
179*993b0882SAndroid Build Coastguard Worker       int status = UniLib::RegexMatcher::kNoError;
180*993b0882SAndroid Build Coastguard Worker       if (matcher->Find(&status) || status != UniLib::RegexMatcher::kNoError) {
181*993b0882SAndroid Build Coastguard Worker         passes_post_check = false;
182*993b0882SAndroid Build Coastguard Worker         break;
183*993b0882SAndroid Build Coastguard Worker       }
184*993b0882SAndroid Build Coastguard Worker     }
185*993b0882SAndroid Build Coastguard Worker     if (passes_post_check) {
186*993b0882SAndroid Build Coastguard Worker       filtered_text_replies.push_back(action);
187*993b0882SAndroid Build Coastguard Worker     }
188*993b0882SAndroid Build Coastguard Worker   }
189*993b0882SAndroid Build Coastguard Worker   *actions = std::move(filtered_text_replies);
190*993b0882SAndroid Build Coastguard Worker   return true;
191*993b0882SAndroid Build Coastguard Worker }
192*993b0882SAndroid Build Coastguard Worker 
SuggestActions(const Conversation & conversation,const MutableFlatbufferBuilder * entity_data_builder,std::vector<ActionSuggestion> * actions) const193*993b0882SAndroid Build Coastguard Worker bool RegexActions::SuggestActions(
194*993b0882SAndroid Build Coastguard Worker     const Conversation& conversation,
195*993b0882SAndroid Build Coastguard Worker     const MutableFlatbufferBuilder* entity_data_builder,
196*993b0882SAndroid Build Coastguard Worker     std::vector<ActionSuggestion>* actions) const {
197*993b0882SAndroid Build Coastguard Worker   // Create actions based on rules checking the last message.
198*993b0882SAndroid Build Coastguard Worker   const int message_index = conversation.messages.size() - 1;
199*993b0882SAndroid Build Coastguard Worker   const std::string& message = conversation.messages.back().text;
200*993b0882SAndroid Build Coastguard Worker   const UnicodeText message_unicode(
201*993b0882SAndroid Build Coastguard Worker       UTF8ToUnicodeText(message, /*do_copy=*/false));
202*993b0882SAndroid Build Coastguard Worker   for (const CompiledRule& rule : rules_) {
203*993b0882SAndroid Build Coastguard Worker     const std::unique_ptr<UniLib::RegexMatcher> matcher =
204*993b0882SAndroid Build Coastguard Worker         rule.pattern->Matcher(message_unicode);
205*993b0882SAndroid Build Coastguard Worker     int status = UniLib::RegexMatcher::kNoError;
206*993b0882SAndroid Build Coastguard Worker     while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
207*993b0882SAndroid Build Coastguard Worker       for (const RulesModel_::RuleActionSpec* rule_action :
208*993b0882SAndroid Build Coastguard Worker            *rule.rule->actions()) {
209*993b0882SAndroid Build Coastguard Worker         const ActionSuggestionSpec* action = rule_action->action();
210*993b0882SAndroid Build Coastguard Worker         std::vector<ActionSuggestionAnnotation> annotations;
211*993b0882SAndroid Build Coastguard Worker 
212*993b0882SAndroid Build Coastguard Worker         std::unique_ptr<MutableFlatbuffer> entity_data =
213*993b0882SAndroid Build Coastguard Worker             entity_data_builder != nullptr ? entity_data_builder->NewRoot()
214*993b0882SAndroid Build Coastguard Worker                                            : nullptr;
215*993b0882SAndroid Build Coastguard Worker 
216*993b0882SAndroid Build Coastguard Worker         // Add entity data from rule capturing groups.
217*993b0882SAndroid Build Coastguard Worker         if (rule_action->capturing_group() != nullptr) {
218*993b0882SAndroid Build Coastguard Worker           for (const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group :
219*993b0882SAndroid Build Coastguard Worker                *rule_action->capturing_group()) {
220*993b0882SAndroid Build Coastguard Worker             Optional<std::string> group_match_text =
221*993b0882SAndroid Build Coastguard Worker                 GetCapturingGroupText(matcher.get(), group->group_id());
222*993b0882SAndroid Build Coastguard Worker             if (!group_match_text.has_value()) {
223*993b0882SAndroid Build Coastguard Worker               // The group was not part of the match, ignore and continue.
224*993b0882SAndroid Build Coastguard Worker               continue;
225*993b0882SAndroid Build Coastguard Worker             }
226*993b0882SAndroid Build Coastguard Worker 
227*993b0882SAndroid Build Coastguard Worker             UnicodeText normalized_group_match_text =
228*993b0882SAndroid Build Coastguard Worker                 NormalizeMatchText(unilib_, group, group_match_text.value());
229*993b0882SAndroid Build Coastguard Worker 
230*993b0882SAndroid Build Coastguard Worker             if (!MergeEntityDataFromCapturingMatch(
231*993b0882SAndroid Build Coastguard Worker                     group, normalized_group_match_text.ToUTF8String(),
232*993b0882SAndroid Build Coastguard Worker                     entity_data.get())) {
233*993b0882SAndroid Build Coastguard Worker               TC3_LOG(ERROR)
234*993b0882SAndroid Build Coastguard Worker                   << "Could not merge entity data from a capturing match.";
235*993b0882SAndroid Build Coastguard Worker               return false;
236*993b0882SAndroid Build Coastguard Worker             }
237*993b0882SAndroid Build Coastguard Worker 
238*993b0882SAndroid Build Coastguard Worker             // Create a text annotation for the group span.
239*993b0882SAndroid Build Coastguard Worker             ActionSuggestionAnnotation annotation;
240*993b0882SAndroid Build Coastguard Worker             if (FillAnnotationFromMatchGroup(matcher.get(), group,
241*993b0882SAndroid Build Coastguard Worker                                              group_match_text.value(),
242*993b0882SAndroid Build Coastguard Worker                                              message_index, &annotation)) {
243*993b0882SAndroid Build Coastguard Worker               annotations.push_back(annotation);
244*993b0882SAndroid Build Coastguard Worker             }
245*993b0882SAndroid Build Coastguard Worker 
246*993b0882SAndroid Build Coastguard Worker             // Create text reply.
247*993b0882SAndroid Build Coastguard Worker             SuggestTextRepliesFromCapturingMatch(
248*993b0882SAndroid Build Coastguard Worker                 entity_data_builder, group, normalized_group_match_text,
249*993b0882SAndroid Build Coastguard Worker                 smart_reply_action_type_, actions);
250*993b0882SAndroid Build Coastguard Worker           }
251*993b0882SAndroid Build Coastguard Worker         }
252*993b0882SAndroid Build Coastguard Worker 
253*993b0882SAndroid Build Coastguard Worker         if (action != nullptr) {
254*993b0882SAndroid Build Coastguard Worker           ActionSuggestion suggestion;
255*993b0882SAndroid Build Coastguard Worker           suggestion.annotations = annotations;
256*993b0882SAndroid Build Coastguard Worker           FillSuggestionFromSpec(action, entity_data.get(), &suggestion);
257*993b0882SAndroid Build Coastguard Worker           actions->push_back(suggestion);
258*993b0882SAndroid Build Coastguard Worker         }
259*993b0882SAndroid Build Coastguard Worker       }
260*993b0882SAndroid Build Coastguard Worker     }
261*993b0882SAndroid Build Coastguard Worker   }
262*993b0882SAndroid Build Coastguard Worker   return true;
263*993b0882SAndroid Build Coastguard Worker }
264*993b0882SAndroid Build Coastguard Worker 
265*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
266