1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker *
4*993b0882SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker *
8*993b0882SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker *
10*993b0882SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker */
16*993b0882SAndroid Build Coastguard Worker
#include "actions/utils.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "annotator/collections.h"
#include "utils/base/logging.h"
#include "utils/normalization.h"
#include "utils/strings/stringpiece.h"
23*993b0882SAndroid Build Coastguard Worker
24*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
25*993b0882SAndroid Build Coastguard Worker
// Collection name given to a datetime annotation that has a time component
// but no date component (see ConvertDatetimeToTime).
// The string is allocated once and deliberately never freed, so no destructor
// for it runs at program exit.
const std::string& kTimeAnnotation = *new std::string("time");
29*993b0882SAndroid Build Coastguard Worker
FillSuggestionFromSpec(const ActionSuggestionSpec * action,MutableFlatbuffer * entity_data,ActionSuggestion * suggestion)30*993b0882SAndroid Build Coastguard Worker void FillSuggestionFromSpec(const ActionSuggestionSpec* action,
31*993b0882SAndroid Build Coastguard Worker MutableFlatbuffer* entity_data,
32*993b0882SAndroid Build Coastguard Worker ActionSuggestion* suggestion) {
33*993b0882SAndroid Build Coastguard Worker if (action != nullptr) {
34*993b0882SAndroid Build Coastguard Worker suggestion->score = action->score();
35*993b0882SAndroid Build Coastguard Worker suggestion->priority_score = action->priority_score();
36*993b0882SAndroid Build Coastguard Worker if (action->type() != nullptr) {
37*993b0882SAndroid Build Coastguard Worker suggestion->type = action->type()->str();
38*993b0882SAndroid Build Coastguard Worker }
39*993b0882SAndroid Build Coastguard Worker if (action->response_text() != nullptr) {
40*993b0882SAndroid Build Coastguard Worker suggestion->response_text = action->response_text()->str();
41*993b0882SAndroid Build Coastguard Worker }
42*993b0882SAndroid Build Coastguard Worker if (action->serialized_entity_data() != nullptr) {
43*993b0882SAndroid Build Coastguard Worker TC3_CHECK_NE(entity_data, nullptr);
44*993b0882SAndroid Build Coastguard Worker entity_data->MergeFromSerializedFlatbuffer(
45*993b0882SAndroid Build Coastguard Worker StringPiece(action->serialized_entity_data()->data(),
46*993b0882SAndroid Build Coastguard Worker action->serialized_entity_data()->size()));
47*993b0882SAndroid Build Coastguard Worker }
48*993b0882SAndroid Build Coastguard Worker if (action->entity_data() != nullptr) {
49*993b0882SAndroid Build Coastguard Worker TC3_CHECK_NE(entity_data, nullptr);
50*993b0882SAndroid Build Coastguard Worker entity_data->MergeFrom(
51*993b0882SAndroid Build Coastguard Worker reinterpret_cast<const flatbuffers::Table*>(action->entity_data()));
52*993b0882SAndroid Build Coastguard Worker }
53*993b0882SAndroid Build Coastguard Worker }
54*993b0882SAndroid Build Coastguard Worker if (entity_data != nullptr && entity_data->HasExplicitlySetFields()) {
55*993b0882SAndroid Build Coastguard Worker suggestion->serialized_entity_data = entity_data->Serialize();
56*993b0882SAndroid Build Coastguard Worker }
57*993b0882SAndroid Build Coastguard Worker }
58*993b0882SAndroid Build Coastguard Worker
SuggestTextRepliesFromCapturingMatch(const MutableFlatbufferBuilder * entity_data_builder,const RulesModel_::RuleActionSpec_::RuleCapturingGroup * group,const UnicodeText & match_text,const std::string & smart_reply_action_type,std::vector<ActionSuggestion> * actions)59*993b0882SAndroid Build Coastguard Worker void SuggestTextRepliesFromCapturingMatch(
60*993b0882SAndroid Build Coastguard Worker const MutableFlatbufferBuilder* entity_data_builder,
61*993b0882SAndroid Build Coastguard Worker const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
62*993b0882SAndroid Build Coastguard Worker const UnicodeText& match_text, const std::string& smart_reply_action_type,
63*993b0882SAndroid Build Coastguard Worker std::vector<ActionSuggestion>* actions) {
64*993b0882SAndroid Build Coastguard Worker if (group->text_reply() != nullptr) {
65*993b0882SAndroid Build Coastguard Worker ActionSuggestion suggestion;
66*993b0882SAndroid Build Coastguard Worker suggestion.response_text = match_text.ToUTF8String();
67*993b0882SAndroid Build Coastguard Worker suggestion.type = smart_reply_action_type;
68*993b0882SAndroid Build Coastguard Worker std::unique_ptr<MutableFlatbuffer> entity_data =
69*993b0882SAndroid Build Coastguard Worker entity_data_builder != nullptr ? entity_data_builder->NewRoot()
70*993b0882SAndroid Build Coastguard Worker : nullptr;
71*993b0882SAndroid Build Coastguard Worker FillSuggestionFromSpec(group->text_reply(), entity_data.get(), &suggestion);
72*993b0882SAndroid Build Coastguard Worker actions->push_back(suggestion);
73*993b0882SAndroid Build Coastguard Worker }
74*993b0882SAndroid Build Coastguard Worker }
75*993b0882SAndroid Build Coastguard Worker
NormalizeMatchText(const UniLib & unilib,const RulesModel_::RuleActionSpec_::RuleCapturingGroup * group,StringPiece match_text)76*993b0882SAndroid Build Coastguard Worker UnicodeText NormalizeMatchText(
77*993b0882SAndroid Build Coastguard Worker const UniLib& unilib,
78*993b0882SAndroid Build Coastguard Worker const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
79*993b0882SAndroid Build Coastguard Worker StringPiece match_text) {
80*993b0882SAndroid Build Coastguard Worker return NormalizeMatchText(unilib, group,
81*993b0882SAndroid Build Coastguard Worker UTF8ToUnicodeText(match_text, /*do_copy=*/false));
82*993b0882SAndroid Build Coastguard Worker }
83*993b0882SAndroid Build Coastguard Worker
NormalizeMatchText(const UniLib & unilib,const RulesModel_::RuleActionSpec_::RuleCapturingGroup * group,const UnicodeText match_text)84*993b0882SAndroid Build Coastguard Worker UnicodeText NormalizeMatchText(
85*993b0882SAndroid Build Coastguard Worker const UniLib& unilib,
86*993b0882SAndroid Build Coastguard Worker const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
87*993b0882SAndroid Build Coastguard Worker const UnicodeText match_text) {
88*993b0882SAndroid Build Coastguard Worker if (group->normalization_options() == nullptr) {
89*993b0882SAndroid Build Coastguard Worker return match_text;
90*993b0882SAndroid Build Coastguard Worker }
91*993b0882SAndroid Build Coastguard Worker return NormalizeText(unilib, group->normalization_options(), match_text);
92*993b0882SAndroid Build Coastguard Worker }
93*993b0882SAndroid Build Coastguard Worker
FillAnnotationFromCapturingMatch(const CodepointSpan & span,const RulesModel_::RuleActionSpec_::RuleCapturingGroup * group,const int message_index,StringPiece match_text,ActionSuggestionAnnotation * annotation)94*993b0882SAndroid Build Coastguard Worker bool FillAnnotationFromCapturingMatch(
95*993b0882SAndroid Build Coastguard Worker const CodepointSpan& span,
96*993b0882SAndroid Build Coastguard Worker const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
97*993b0882SAndroid Build Coastguard Worker const int message_index, StringPiece match_text,
98*993b0882SAndroid Build Coastguard Worker ActionSuggestionAnnotation* annotation) {
99*993b0882SAndroid Build Coastguard Worker if (group->annotation_name() == nullptr &&
100*993b0882SAndroid Build Coastguard Worker group->annotation_type() == nullptr) {
101*993b0882SAndroid Build Coastguard Worker return false;
102*993b0882SAndroid Build Coastguard Worker }
103*993b0882SAndroid Build Coastguard Worker annotation->span.span = span;
104*993b0882SAndroid Build Coastguard Worker annotation->span.message_index = message_index;
105*993b0882SAndroid Build Coastguard Worker annotation->span.text = match_text.ToString();
106*993b0882SAndroid Build Coastguard Worker if (group->annotation_name() != nullptr) {
107*993b0882SAndroid Build Coastguard Worker annotation->name = group->annotation_name()->str();
108*993b0882SAndroid Build Coastguard Worker }
109*993b0882SAndroid Build Coastguard Worker if (group->annotation_type() != nullptr) {
110*993b0882SAndroid Build Coastguard Worker annotation->entity.collection = group->annotation_type()->str();
111*993b0882SAndroid Build Coastguard Worker }
112*993b0882SAndroid Build Coastguard Worker return true;
113*993b0882SAndroid Build Coastguard Worker }
114*993b0882SAndroid Build Coastguard Worker
MergeEntityDataFromCapturingMatch(const RulesModel_::RuleActionSpec_::RuleCapturingGroup * group,StringPiece match_text,MutableFlatbuffer * buffer)115*993b0882SAndroid Build Coastguard Worker bool MergeEntityDataFromCapturingMatch(
116*993b0882SAndroid Build Coastguard Worker const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
117*993b0882SAndroid Build Coastguard Worker StringPiece match_text, MutableFlatbuffer* buffer) {
118*993b0882SAndroid Build Coastguard Worker if (group->entity_field() != nullptr) {
119*993b0882SAndroid Build Coastguard Worker if (!buffer->ParseAndSet(group->entity_field(), match_text.ToString())) {
120*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Could not set entity data from rule capturing group.";
121*993b0882SAndroid Build Coastguard Worker return false;
122*993b0882SAndroid Build Coastguard Worker }
123*993b0882SAndroid Build Coastguard Worker }
124*993b0882SAndroid Build Coastguard Worker if (group->entity_data() != nullptr) {
125*993b0882SAndroid Build Coastguard Worker if (!buffer->MergeFrom(reinterpret_cast<const flatbuffers::Table*>(
126*993b0882SAndroid Build Coastguard Worker group->entity_data()))) {
127*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Could not set entity data for capturing match.";
128*993b0882SAndroid Build Coastguard Worker return false;
129*993b0882SAndroid Build Coastguard Worker }
130*993b0882SAndroid Build Coastguard Worker }
131*993b0882SAndroid Build Coastguard Worker return true;
132*993b0882SAndroid Build Coastguard Worker }
133*993b0882SAndroid Build Coastguard Worker
ConvertDatetimeToTime(std::vector<AnnotatedSpan> * annotations)134*993b0882SAndroid Build Coastguard Worker void ConvertDatetimeToTime(std::vector<AnnotatedSpan>* annotations) {
135*993b0882SAndroid Build Coastguard Worker for (int i = 0; i < annotations->size(); i++) {
136*993b0882SAndroid Build Coastguard Worker ClassificationResult* classification =
137*993b0882SAndroid Build Coastguard Worker &(*annotations)[i].classification.front();
138*993b0882SAndroid Build Coastguard Worker // Specialize datetime annotation to time annotation if no date
139*993b0882SAndroid Build Coastguard Worker // component is present.
140*993b0882SAndroid Build Coastguard Worker if (classification->collection == Collections::DateTime() &&
141*993b0882SAndroid Build Coastguard Worker classification->datetime_parse_result.IsSet()) {
142*993b0882SAndroid Build Coastguard Worker bool has_only_time = true;
143*993b0882SAndroid Build Coastguard Worker for (const DatetimeComponent& component :
144*993b0882SAndroid Build Coastguard Worker classification->datetime_parse_result.datetime_components) {
145*993b0882SAndroid Build Coastguard Worker if (component.component_type !=
146*993b0882SAndroid Build Coastguard Worker DatetimeComponent::ComponentType::UNSPECIFIED &&
147*993b0882SAndroid Build Coastguard Worker component.component_type < DatetimeComponent::ComponentType::HOUR) {
148*993b0882SAndroid Build Coastguard Worker has_only_time = false;
149*993b0882SAndroid Build Coastguard Worker break;
150*993b0882SAndroid Build Coastguard Worker }
151*993b0882SAndroid Build Coastguard Worker }
152*993b0882SAndroid Build Coastguard Worker if (has_only_time) {
153*993b0882SAndroid Build Coastguard Worker classification->collection = kTimeAnnotation;
154*993b0882SAndroid Build Coastguard Worker }
155*993b0882SAndroid Build Coastguard Worker }
156*993b0882SAndroid Build Coastguard Worker }
157*993b0882SAndroid Build Coastguard Worker }
158*993b0882SAndroid Build Coastguard Worker
159*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3
160