/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16*993b0882SAndroid Build Coastguard Worker
17*993b0882SAndroid Build Coastguard Worker #include "actions/grammar-actions.h"
18*993b0882SAndroid Build Coastguard Worker
19*993b0882SAndroid Build Coastguard Worker #include "actions/feature-processor.h"
20*993b0882SAndroid Build Coastguard Worker #include "actions/utils.h"
21*993b0882SAndroid Build Coastguard Worker #include "annotator/types.h"
22*993b0882SAndroid Build Coastguard Worker #include "utils/base/arena.h"
23*993b0882SAndroid Build Coastguard Worker #include "utils/base/statusor.h"
24*993b0882SAndroid Build Coastguard Worker #include "utils/utf8/unicodetext.h"
25*993b0882SAndroid Build Coastguard Worker
26*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
27*993b0882SAndroid Build Coastguard Worker
GrammarActions(const UniLib * unilib,const RulesModel_::GrammarRules * grammar_rules,const MutableFlatbufferBuilder * entity_data_builder,const std::string & smart_reply_action_type)28*993b0882SAndroid Build Coastguard Worker GrammarActions::GrammarActions(
29*993b0882SAndroid Build Coastguard Worker const UniLib* unilib, const RulesModel_::GrammarRules* grammar_rules,
30*993b0882SAndroid Build Coastguard Worker const MutableFlatbufferBuilder* entity_data_builder,
31*993b0882SAndroid Build Coastguard Worker const std::string& smart_reply_action_type)
32*993b0882SAndroid Build Coastguard Worker : unilib_(*unilib),
33*993b0882SAndroid Build Coastguard Worker grammar_rules_(grammar_rules),
34*993b0882SAndroid Build Coastguard Worker tokenizer_(CreateTokenizer(grammar_rules->tokenizer_options(), unilib)),
35*993b0882SAndroid Build Coastguard Worker entity_data_builder_(entity_data_builder),
36*993b0882SAndroid Build Coastguard Worker analyzer_(unilib, grammar_rules->rules(), tokenizer_.get()),
37*993b0882SAndroid Build Coastguard Worker smart_reply_action_type_(smart_reply_action_type) {}
38*993b0882SAndroid Build Coastguard Worker
InstantiateActionsFromMatch(const grammar::TextContext & text_context,const int message_index,const grammar::Derivation & derivation,std::vector<ActionSuggestion> * result) const39*993b0882SAndroid Build Coastguard Worker bool GrammarActions::InstantiateActionsFromMatch(
40*993b0882SAndroid Build Coastguard Worker const grammar::TextContext& text_context, const int message_index,
41*993b0882SAndroid Build Coastguard Worker const grammar::Derivation& derivation,
42*993b0882SAndroid Build Coastguard Worker std::vector<ActionSuggestion>* result) const {
43*993b0882SAndroid Build Coastguard Worker const RulesModel_::GrammarRules_::RuleMatch* rule_match =
44*993b0882SAndroid Build Coastguard Worker grammar_rules_->rule_match()->Get(derivation.rule_id);
45*993b0882SAndroid Build Coastguard Worker if (rule_match == nullptr || rule_match->action_id() == nullptr) {
46*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "No rule action defined.";
47*993b0882SAndroid Build Coastguard Worker return false;
48*993b0882SAndroid Build Coastguard Worker }
49*993b0882SAndroid Build Coastguard Worker
50*993b0882SAndroid Build Coastguard Worker // Gather active capturing matches.
51*993b0882SAndroid Build Coastguard Worker std::unordered_map<uint16, const grammar::ParseTree*> capturing_matches;
52*993b0882SAndroid Build Coastguard Worker for (const grammar::MappingNode* mapping_node :
53*993b0882SAndroid Build Coastguard Worker grammar::SelectAllOfType<grammar::MappingNode>(
54*993b0882SAndroid Build Coastguard Worker derivation.parse_tree, grammar::ParseTree::Type::kMapping)) {
55*993b0882SAndroid Build Coastguard Worker capturing_matches[mapping_node->id] = mapping_node;
56*993b0882SAndroid Build Coastguard Worker }
57*993b0882SAndroid Build Coastguard Worker
58*993b0882SAndroid Build Coastguard Worker // Instantiate actions from the rule match.
59*993b0882SAndroid Build Coastguard Worker for (const uint16 action_id : *rule_match->action_id()) {
60*993b0882SAndroid Build Coastguard Worker const RulesModel_::RuleActionSpec* action_spec =
61*993b0882SAndroid Build Coastguard Worker grammar_rules_->actions()->Get(action_id);
62*993b0882SAndroid Build Coastguard Worker std::vector<ActionSuggestionAnnotation> annotations;
63*993b0882SAndroid Build Coastguard Worker
64*993b0882SAndroid Build Coastguard Worker std::unique_ptr<MutableFlatbuffer> entity_data =
65*993b0882SAndroid Build Coastguard Worker entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
66*993b0882SAndroid Build Coastguard Worker : nullptr;
67*993b0882SAndroid Build Coastguard Worker
68*993b0882SAndroid Build Coastguard Worker // Set information from capturing matches.
69*993b0882SAndroid Build Coastguard Worker if (action_spec->capturing_group() != nullptr) {
70*993b0882SAndroid Build Coastguard Worker for (const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group :
71*993b0882SAndroid Build Coastguard Worker *action_spec->capturing_group()) {
72*993b0882SAndroid Build Coastguard Worker auto it = capturing_matches.find(group->group_id());
73*993b0882SAndroid Build Coastguard Worker if (it == capturing_matches.end()) {
74*993b0882SAndroid Build Coastguard Worker // Capturing match is not active, skip.
75*993b0882SAndroid Build Coastguard Worker continue;
76*993b0882SAndroid Build Coastguard Worker }
77*993b0882SAndroid Build Coastguard Worker
78*993b0882SAndroid Build Coastguard Worker const grammar::ParseTree* capturing_match = it->second;
79*993b0882SAndroid Build Coastguard Worker const UnicodeText match_text =
80*993b0882SAndroid Build Coastguard Worker text_context.Span(capturing_match->codepoint_span);
81*993b0882SAndroid Build Coastguard Worker UnicodeText normalized_match_text =
82*993b0882SAndroid Build Coastguard Worker NormalizeMatchText(unilib_, group, match_text);
83*993b0882SAndroid Build Coastguard Worker
84*993b0882SAndroid Build Coastguard Worker if (!MergeEntityDataFromCapturingMatch(
85*993b0882SAndroid Build Coastguard Worker group, normalized_match_text.ToUTF8String(),
86*993b0882SAndroid Build Coastguard Worker entity_data.get())) {
87*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR)
88*993b0882SAndroid Build Coastguard Worker << "Could not merge entity data from a capturing match.";
89*993b0882SAndroid Build Coastguard Worker return false;
90*993b0882SAndroid Build Coastguard Worker }
91*993b0882SAndroid Build Coastguard Worker
92*993b0882SAndroid Build Coastguard Worker // Add smart reply suggestions.
93*993b0882SAndroid Build Coastguard Worker SuggestTextRepliesFromCapturingMatch(entity_data_builder_, group,
94*993b0882SAndroid Build Coastguard Worker normalized_match_text,
95*993b0882SAndroid Build Coastguard Worker smart_reply_action_type_, result);
96*993b0882SAndroid Build Coastguard Worker
97*993b0882SAndroid Build Coastguard Worker // Add annotation.
98*993b0882SAndroid Build Coastguard Worker ActionSuggestionAnnotation annotation;
99*993b0882SAndroid Build Coastguard Worker if (FillAnnotationFromCapturingMatch(
100*993b0882SAndroid Build Coastguard Worker /*span=*/capturing_match->codepoint_span, group,
101*993b0882SAndroid Build Coastguard Worker /*message_index=*/message_index, match_text.ToUTF8String(),
102*993b0882SAndroid Build Coastguard Worker &annotation)) {
103*993b0882SAndroid Build Coastguard Worker if (group->use_annotation_match()) {
104*993b0882SAndroid Build Coastguard Worker std::vector<const grammar::AnnotationNode*> annotations =
105*993b0882SAndroid Build Coastguard Worker grammar::SelectAllOfType<grammar::AnnotationNode>(
106*993b0882SAndroid Build Coastguard Worker capturing_match, grammar::ParseTree::Type::kAnnotation);
107*993b0882SAndroid Build Coastguard Worker if (annotations.size() != 1) {
108*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Could not get annotation for match.";
109*993b0882SAndroid Build Coastguard Worker return false;
110*993b0882SAndroid Build Coastguard Worker }
111*993b0882SAndroid Build Coastguard Worker annotation.entity = *annotations.front()->annotation;
112*993b0882SAndroid Build Coastguard Worker }
113*993b0882SAndroid Build Coastguard Worker annotations.push_back(std::move(annotation));
114*993b0882SAndroid Build Coastguard Worker }
115*993b0882SAndroid Build Coastguard Worker }
116*993b0882SAndroid Build Coastguard Worker }
117*993b0882SAndroid Build Coastguard Worker
118*993b0882SAndroid Build Coastguard Worker if (action_spec->action() != nullptr) {
119*993b0882SAndroid Build Coastguard Worker ActionSuggestion suggestion;
120*993b0882SAndroid Build Coastguard Worker suggestion.annotations = annotations;
121*993b0882SAndroid Build Coastguard Worker FillSuggestionFromSpec(action_spec->action(), entity_data.get(),
122*993b0882SAndroid Build Coastguard Worker &suggestion);
123*993b0882SAndroid Build Coastguard Worker result->push_back(std::move(suggestion));
124*993b0882SAndroid Build Coastguard Worker }
125*993b0882SAndroid Build Coastguard Worker }
126*993b0882SAndroid Build Coastguard Worker return true;
127*993b0882SAndroid Build Coastguard Worker }
SuggestActions(const Conversation & conversation,std::vector<ActionSuggestion> * result) const128*993b0882SAndroid Build Coastguard Worker bool GrammarActions::SuggestActions(
129*993b0882SAndroid Build Coastguard Worker const Conversation& conversation,
130*993b0882SAndroid Build Coastguard Worker std::vector<ActionSuggestion>* result) const {
131*993b0882SAndroid Build Coastguard Worker if (grammar_rules_->rules()->rules() == nullptr ||
132*993b0882SAndroid Build Coastguard Worker conversation.messages.back().text.empty()) {
133*993b0882SAndroid Build Coastguard Worker // Nothing to do.
134*993b0882SAndroid Build Coastguard Worker return true;
135*993b0882SAndroid Build Coastguard Worker }
136*993b0882SAndroid Build Coastguard Worker
137*993b0882SAndroid Build Coastguard Worker std::vector<Locale> locales;
138*993b0882SAndroid Build Coastguard Worker if (!ParseLocales(conversation.messages.back().detected_text_language_tags,
139*993b0882SAndroid Build Coastguard Worker &locales)) {
140*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Could not parse locales of input text.";
141*993b0882SAndroid Build Coastguard Worker return false;
142*993b0882SAndroid Build Coastguard Worker }
143*993b0882SAndroid Build Coastguard Worker
144*993b0882SAndroid Build Coastguard Worker const int message_index = conversation.messages.size() - 1;
145*993b0882SAndroid Build Coastguard Worker grammar::TextContext text = analyzer_.BuildTextContextForInput(
146*993b0882SAndroid Build Coastguard Worker UTF8ToUnicodeText(conversation.messages.back().text, /*do_copy=*/false),
147*993b0882SAndroid Build Coastguard Worker locales);
148*993b0882SAndroid Build Coastguard Worker text.annotations = conversation.messages.back().annotations;
149*993b0882SAndroid Build Coastguard Worker
150*993b0882SAndroid Build Coastguard Worker UnsafeArena arena(/*block_size=*/16 << 10);
151*993b0882SAndroid Build Coastguard Worker StatusOr<std::vector<grammar::EvaluatedDerivation>> evaluated_derivations =
152*993b0882SAndroid Build Coastguard Worker analyzer_.Parse(text, &arena);
153*993b0882SAndroid Build Coastguard Worker // TODO(b/171294882): Return the status here and below.
154*993b0882SAndroid Build Coastguard Worker if (!evaluated_derivations.ok()) {
155*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Could not run grammar analyzer: "
156*993b0882SAndroid Build Coastguard Worker << evaluated_derivations.status().error_message();
157*993b0882SAndroid Build Coastguard Worker return false;
158*993b0882SAndroid Build Coastguard Worker }
159*993b0882SAndroid Build Coastguard Worker
160*993b0882SAndroid Build Coastguard Worker for (const grammar::EvaluatedDerivation& evaluated_derivation :
161*993b0882SAndroid Build Coastguard Worker evaluated_derivations.ValueOrDie()) {
162*993b0882SAndroid Build Coastguard Worker if (!InstantiateActionsFromMatch(text, message_index, evaluated_derivation,
163*993b0882SAndroid Build Coastguard Worker result)) {
164*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Could not instantiate actions from a grammar match.";
165*993b0882SAndroid Build Coastguard Worker return false;
166*993b0882SAndroid Build Coastguard Worker }
167*993b0882SAndroid Build Coastguard Worker }
168*993b0882SAndroid Build Coastguard Worker
169*993b0882SAndroid Build Coastguard Worker return true;
170*993b0882SAndroid Build Coastguard Worker }
171*993b0882SAndroid Build Coastguard Worker
172*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3
173