xref: /aosp_15_r20/external/libtextclassifier/native/actions/grammar-actions.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #include "actions/grammar-actions.h"
18*993b0882SAndroid Build Coastguard Worker 
19*993b0882SAndroid Build Coastguard Worker #include "actions/feature-processor.h"
20*993b0882SAndroid Build Coastguard Worker #include "actions/utils.h"
21*993b0882SAndroid Build Coastguard Worker #include "annotator/types.h"
22*993b0882SAndroid Build Coastguard Worker #include "utils/base/arena.h"
23*993b0882SAndroid Build Coastguard Worker #include "utils/base/statusor.h"
24*993b0882SAndroid Build Coastguard Worker #include "utils/utf8/unicodetext.h"
25*993b0882SAndroid Build Coastguard Worker 
26*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
27*993b0882SAndroid Build Coastguard Worker 
GrammarActions(const UniLib * unilib,const RulesModel_::GrammarRules * grammar_rules,const MutableFlatbufferBuilder * entity_data_builder,const std::string & smart_reply_action_type)28*993b0882SAndroid Build Coastguard Worker GrammarActions::GrammarActions(
29*993b0882SAndroid Build Coastguard Worker     const UniLib* unilib, const RulesModel_::GrammarRules* grammar_rules,
30*993b0882SAndroid Build Coastguard Worker     const MutableFlatbufferBuilder* entity_data_builder,
31*993b0882SAndroid Build Coastguard Worker     const std::string& smart_reply_action_type)
32*993b0882SAndroid Build Coastguard Worker     : unilib_(*unilib),
33*993b0882SAndroid Build Coastguard Worker       grammar_rules_(grammar_rules),
34*993b0882SAndroid Build Coastguard Worker       tokenizer_(CreateTokenizer(grammar_rules->tokenizer_options(), unilib)),
35*993b0882SAndroid Build Coastguard Worker       entity_data_builder_(entity_data_builder),
36*993b0882SAndroid Build Coastguard Worker       analyzer_(unilib, grammar_rules->rules(), tokenizer_.get()),
37*993b0882SAndroid Build Coastguard Worker       smart_reply_action_type_(smart_reply_action_type) {}
38*993b0882SAndroid Build Coastguard Worker 
InstantiateActionsFromMatch(const grammar::TextContext & text_context,const int message_index,const grammar::Derivation & derivation,std::vector<ActionSuggestion> * result) const39*993b0882SAndroid Build Coastguard Worker bool GrammarActions::InstantiateActionsFromMatch(
40*993b0882SAndroid Build Coastguard Worker     const grammar::TextContext& text_context, const int message_index,
41*993b0882SAndroid Build Coastguard Worker     const grammar::Derivation& derivation,
42*993b0882SAndroid Build Coastguard Worker     std::vector<ActionSuggestion>* result) const {
43*993b0882SAndroid Build Coastguard Worker   const RulesModel_::GrammarRules_::RuleMatch* rule_match =
44*993b0882SAndroid Build Coastguard Worker       grammar_rules_->rule_match()->Get(derivation.rule_id);
45*993b0882SAndroid Build Coastguard Worker   if (rule_match == nullptr || rule_match->action_id() == nullptr) {
46*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "No rule action defined.";
47*993b0882SAndroid Build Coastguard Worker     return false;
48*993b0882SAndroid Build Coastguard Worker   }
49*993b0882SAndroid Build Coastguard Worker 
50*993b0882SAndroid Build Coastguard Worker   // Gather active capturing matches.
51*993b0882SAndroid Build Coastguard Worker   std::unordered_map<uint16, const grammar::ParseTree*> capturing_matches;
52*993b0882SAndroid Build Coastguard Worker   for (const grammar::MappingNode* mapping_node :
53*993b0882SAndroid Build Coastguard Worker        grammar::SelectAllOfType<grammar::MappingNode>(
54*993b0882SAndroid Build Coastguard Worker            derivation.parse_tree, grammar::ParseTree::Type::kMapping)) {
55*993b0882SAndroid Build Coastguard Worker     capturing_matches[mapping_node->id] = mapping_node;
56*993b0882SAndroid Build Coastguard Worker   }
57*993b0882SAndroid Build Coastguard Worker 
58*993b0882SAndroid Build Coastguard Worker   // Instantiate actions from the rule match.
59*993b0882SAndroid Build Coastguard Worker   for (const uint16 action_id : *rule_match->action_id()) {
60*993b0882SAndroid Build Coastguard Worker     const RulesModel_::RuleActionSpec* action_spec =
61*993b0882SAndroid Build Coastguard Worker         grammar_rules_->actions()->Get(action_id);
62*993b0882SAndroid Build Coastguard Worker     std::vector<ActionSuggestionAnnotation> annotations;
63*993b0882SAndroid Build Coastguard Worker 
64*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<MutableFlatbuffer> entity_data =
65*993b0882SAndroid Build Coastguard Worker         entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
66*993b0882SAndroid Build Coastguard Worker                                         : nullptr;
67*993b0882SAndroid Build Coastguard Worker 
68*993b0882SAndroid Build Coastguard Worker     // Set information from capturing matches.
69*993b0882SAndroid Build Coastguard Worker     if (action_spec->capturing_group() != nullptr) {
70*993b0882SAndroid Build Coastguard Worker       for (const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group :
71*993b0882SAndroid Build Coastguard Worker            *action_spec->capturing_group()) {
72*993b0882SAndroid Build Coastguard Worker         auto it = capturing_matches.find(group->group_id());
73*993b0882SAndroid Build Coastguard Worker         if (it == capturing_matches.end()) {
74*993b0882SAndroid Build Coastguard Worker           // Capturing match is not active, skip.
75*993b0882SAndroid Build Coastguard Worker           continue;
76*993b0882SAndroid Build Coastguard Worker         }
77*993b0882SAndroid Build Coastguard Worker 
78*993b0882SAndroid Build Coastguard Worker         const grammar::ParseTree* capturing_match = it->second;
79*993b0882SAndroid Build Coastguard Worker         const UnicodeText match_text =
80*993b0882SAndroid Build Coastguard Worker             text_context.Span(capturing_match->codepoint_span);
81*993b0882SAndroid Build Coastguard Worker         UnicodeText normalized_match_text =
82*993b0882SAndroid Build Coastguard Worker             NormalizeMatchText(unilib_, group, match_text);
83*993b0882SAndroid Build Coastguard Worker 
84*993b0882SAndroid Build Coastguard Worker         if (!MergeEntityDataFromCapturingMatch(
85*993b0882SAndroid Build Coastguard Worker                 group, normalized_match_text.ToUTF8String(),
86*993b0882SAndroid Build Coastguard Worker                 entity_data.get())) {
87*993b0882SAndroid Build Coastguard Worker           TC3_LOG(ERROR)
88*993b0882SAndroid Build Coastguard Worker               << "Could not merge entity data from a capturing match.";
89*993b0882SAndroid Build Coastguard Worker           return false;
90*993b0882SAndroid Build Coastguard Worker         }
91*993b0882SAndroid Build Coastguard Worker 
92*993b0882SAndroid Build Coastguard Worker         // Add smart reply suggestions.
93*993b0882SAndroid Build Coastguard Worker         SuggestTextRepliesFromCapturingMatch(entity_data_builder_, group,
94*993b0882SAndroid Build Coastguard Worker                                              normalized_match_text,
95*993b0882SAndroid Build Coastguard Worker                                              smart_reply_action_type_, result);
96*993b0882SAndroid Build Coastguard Worker 
97*993b0882SAndroid Build Coastguard Worker         // Add annotation.
98*993b0882SAndroid Build Coastguard Worker         ActionSuggestionAnnotation annotation;
99*993b0882SAndroid Build Coastguard Worker         if (FillAnnotationFromCapturingMatch(
100*993b0882SAndroid Build Coastguard Worker                 /*span=*/capturing_match->codepoint_span, group,
101*993b0882SAndroid Build Coastguard Worker                 /*message_index=*/message_index, match_text.ToUTF8String(),
102*993b0882SAndroid Build Coastguard Worker                 &annotation)) {
103*993b0882SAndroid Build Coastguard Worker           if (group->use_annotation_match()) {
104*993b0882SAndroid Build Coastguard Worker             std::vector<const grammar::AnnotationNode*> annotations =
105*993b0882SAndroid Build Coastguard Worker                 grammar::SelectAllOfType<grammar::AnnotationNode>(
106*993b0882SAndroid Build Coastguard Worker                     capturing_match, grammar::ParseTree::Type::kAnnotation);
107*993b0882SAndroid Build Coastguard Worker             if (annotations.size() != 1) {
108*993b0882SAndroid Build Coastguard Worker               TC3_LOG(ERROR) << "Could not get annotation for match.";
109*993b0882SAndroid Build Coastguard Worker               return false;
110*993b0882SAndroid Build Coastguard Worker             }
111*993b0882SAndroid Build Coastguard Worker             annotation.entity = *annotations.front()->annotation;
112*993b0882SAndroid Build Coastguard Worker           }
113*993b0882SAndroid Build Coastguard Worker           annotations.push_back(std::move(annotation));
114*993b0882SAndroid Build Coastguard Worker         }
115*993b0882SAndroid Build Coastguard Worker       }
116*993b0882SAndroid Build Coastguard Worker     }
117*993b0882SAndroid Build Coastguard Worker 
118*993b0882SAndroid Build Coastguard Worker     if (action_spec->action() != nullptr) {
119*993b0882SAndroid Build Coastguard Worker       ActionSuggestion suggestion;
120*993b0882SAndroid Build Coastguard Worker       suggestion.annotations = annotations;
121*993b0882SAndroid Build Coastguard Worker       FillSuggestionFromSpec(action_spec->action(), entity_data.get(),
122*993b0882SAndroid Build Coastguard Worker                              &suggestion);
123*993b0882SAndroid Build Coastguard Worker       result->push_back(std::move(suggestion));
124*993b0882SAndroid Build Coastguard Worker     }
125*993b0882SAndroid Build Coastguard Worker   }
126*993b0882SAndroid Build Coastguard Worker   return true;
127*993b0882SAndroid Build Coastguard Worker }
SuggestActions(const Conversation & conversation,std::vector<ActionSuggestion> * result) const128*993b0882SAndroid Build Coastguard Worker bool GrammarActions::SuggestActions(
129*993b0882SAndroid Build Coastguard Worker     const Conversation& conversation,
130*993b0882SAndroid Build Coastguard Worker     std::vector<ActionSuggestion>* result) const {
131*993b0882SAndroid Build Coastguard Worker   if (grammar_rules_->rules()->rules() == nullptr ||
132*993b0882SAndroid Build Coastguard Worker       conversation.messages.back().text.empty()) {
133*993b0882SAndroid Build Coastguard Worker     // Nothing to do.
134*993b0882SAndroid Build Coastguard Worker     return true;
135*993b0882SAndroid Build Coastguard Worker   }
136*993b0882SAndroid Build Coastguard Worker 
137*993b0882SAndroid Build Coastguard Worker   std::vector<Locale> locales;
138*993b0882SAndroid Build Coastguard Worker   if (!ParseLocales(conversation.messages.back().detected_text_language_tags,
139*993b0882SAndroid Build Coastguard Worker                     &locales)) {
140*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not parse locales of input text.";
141*993b0882SAndroid Build Coastguard Worker     return false;
142*993b0882SAndroid Build Coastguard Worker   }
143*993b0882SAndroid Build Coastguard Worker 
144*993b0882SAndroid Build Coastguard Worker   const int message_index = conversation.messages.size() - 1;
145*993b0882SAndroid Build Coastguard Worker   grammar::TextContext text = analyzer_.BuildTextContextForInput(
146*993b0882SAndroid Build Coastguard Worker       UTF8ToUnicodeText(conversation.messages.back().text, /*do_copy=*/false),
147*993b0882SAndroid Build Coastguard Worker       locales);
148*993b0882SAndroid Build Coastguard Worker   text.annotations = conversation.messages.back().annotations;
149*993b0882SAndroid Build Coastguard Worker 
150*993b0882SAndroid Build Coastguard Worker   UnsafeArena arena(/*block_size=*/16 << 10);
151*993b0882SAndroid Build Coastguard Worker   StatusOr<std::vector<grammar::EvaluatedDerivation>> evaluated_derivations =
152*993b0882SAndroid Build Coastguard Worker       analyzer_.Parse(text, &arena);
153*993b0882SAndroid Build Coastguard Worker   // TODO(b/171294882): Return the status here and below.
154*993b0882SAndroid Build Coastguard Worker   if (!evaluated_derivations.ok()) {
155*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not run grammar analyzer: "
156*993b0882SAndroid Build Coastguard Worker                    << evaluated_derivations.status().error_message();
157*993b0882SAndroid Build Coastguard Worker     return false;
158*993b0882SAndroid Build Coastguard Worker   }
159*993b0882SAndroid Build Coastguard Worker 
160*993b0882SAndroid Build Coastguard Worker   for (const grammar::EvaluatedDerivation& evaluated_derivation :
161*993b0882SAndroid Build Coastguard Worker        evaluated_derivations.ValueOrDie()) {
162*993b0882SAndroid Build Coastguard Worker     if (!InstantiateActionsFromMatch(text, message_index, evaluated_derivation,
163*993b0882SAndroid Build Coastguard Worker                                      result)) {
164*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not instantiate actions from a grammar match.";
165*993b0882SAndroid Build Coastguard Worker       return false;
166*993b0882SAndroid Build Coastguard Worker     }
167*993b0882SAndroid Build Coastguard Worker   }
168*993b0882SAndroid Build Coastguard Worker 
169*993b0882SAndroid Build Coastguard Worker   return true;
170*993b0882SAndroid Build Coastguard Worker }
171*993b0882SAndroid Build Coastguard Worker 
172*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
173