// Source: /aosp_15_r20/external/libtextclassifier/native/actions/utils.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16*993b0882SAndroid Build Coastguard Worker 
#include "actions/utils.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "annotator/collections.h"
#include "utils/base/logging.h"
#include "utils/normalization.h"
#include "utils/strings/stringpiece.h"
23*993b0882SAndroid Build Coastguard Worker 
24*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
25*993b0882SAndroid Build Coastguard Worker 
// Collection name given to a datetime annotation that has a time but no date.
// The backing string is allocated once and intentionally never deleted, so the
// reference stays valid for the whole lifetime of the process (presumably to
// sidestep static-destruction ordering; confirm before changing).
const std::string& kTimeAnnotation = *new std::string("time");
29*993b0882SAndroid Build Coastguard Worker 
FillSuggestionFromSpec(const ActionSuggestionSpec * action,MutableFlatbuffer * entity_data,ActionSuggestion * suggestion)30*993b0882SAndroid Build Coastguard Worker void FillSuggestionFromSpec(const ActionSuggestionSpec* action,
31*993b0882SAndroid Build Coastguard Worker                             MutableFlatbuffer* entity_data,
32*993b0882SAndroid Build Coastguard Worker                             ActionSuggestion* suggestion) {
33*993b0882SAndroid Build Coastguard Worker   if (action != nullptr) {
34*993b0882SAndroid Build Coastguard Worker     suggestion->score = action->score();
35*993b0882SAndroid Build Coastguard Worker     suggestion->priority_score = action->priority_score();
36*993b0882SAndroid Build Coastguard Worker     if (action->type() != nullptr) {
37*993b0882SAndroid Build Coastguard Worker       suggestion->type = action->type()->str();
38*993b0882SAndroid Build Coastguard Worker     }
39*993b0882SAndroid Build Coastguard Worker     if (action->response_text() != nullptr) {
40*993b0882SAndroid Build Coastguard Worker       suggestion->response_text = action->response_text()->str();
41*993b0882SAndroid Build Coastguard Worker     }
42*993b0882SAndroid Build Coastguard Worker     if (action->serialized_entity_data() != nullptr) {
43*993b0882SAndroid Build Coastguard Worker       TC3_CHECK_NE(entity_data, nullptr);
44*993b0882SAndroid Build Coastguard Worker       entity_data->MergeFromSerializedFlatbuffer(
45*993b0882SAndroid Build Coastguard Worker           StringPiece(action->serialized_entity_data()->data(),
46*993b0882SAndroid Build Coastguard Worker                       action->serialized_entity_data()->size()));
47*993b0882SAndroid Build Coastguard Worker     }
48*993b0882SAndroid Build Coastguard Worker     if (action->entity_data() != nullptr) {
49*993b0882SAndroid Build Coastguard Worker       TC3_CHECK_NE(entity_data, nullptr);
50*993b0882SAndroid Build Coastguard Worker       entity_data->MergeFrom(
51*993b0882SAndroid Build Coastguard Worker           reinterpret_cast<const flatbuffers::Table*>(action->entity_data()));
52*993b0882SAndroid Build Coastguard Worker     }
53*993b0882SAndroid Build Coastguard Worker   }
54*993b0882SAndroid Build Coastguard Worker   if (entity_data != nullptr && entity_data->HasExplicitlySetFields()) {
55*993b0882SAndroid Build Coastguard Worker     suggestion->serialized_entity_data = entity_data->Serialize();
56*993b0882SAndroid Build Coastguard Worker   }
57*993b0882SAndroid Build Coastguard Worker }
58*993b0882SAndroid Build Coastguard Worker 
SuggestTextRepliesFromCapturingMatch(const MutableFlatbufferBuilder * entity_data_builder,const RulesModel_::RuleActionSpec_::RuleCapturingGroup * group,const UnicodeText & match_text,const std::string & smart_reply_action_type,std::vector<ActionSuggestion> * actions)59*993b0882SAndroid Build Coastguard Worker void SuggestTextRepliesFromCapturingMatch(
60*993b0882SAndroid Build Coastguard Worker     const MutableFlatbufferBuilder* entity_data_builder,
61*993b0882SAndroid Build Coastguard Worker     const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
62*993b0882SAndroid Build Coastguard Worker     const UnicodeText& match_text, const std::string& smart_reply_action_type,
63*993b0882SAndroid Build Coastguard Worker     std::vector<ActionSuggestion>* actions) {
64*993b0882SAndroid Build Coastguard Worker   if (group->text_reply() != nullptr) {
65*993b0882SAndroid Build Coastguard Worker     ActionSuggestion suggestion;
66*993b0882SAndroid Build Coastguard Worker     suggestion.response_text = match_text.ToUTF8String();
67*993b0882SAndroid Build Coastguard Worker     suggestion.type = smart_reply_action_type;
68*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<MutableFlatbuffer> entity_data =
69*993b0882SAndroid Build Coastguard Worker         entity_data_builder != nullptr ? entity_data_builder->NewRoot()
70*993b0882SAndroid Build Coastguard Worker                                        : nullptr;
71*993b0882SAndroid Build Coastguard Worker     FillSuggestionFromSpec(group->text_reply(), entity_data.get(), &suggestion);
72*993b0882SAndroid Build Coastguard Worker     actions->push_back(suggestion);
73*993b0882SAndroid Build Coastguard Worker   }
74*993b0882SAndroid Build Coastguard Worker }
75*993b0882SAndroid Build Coastguard Worker 
NormalizeMatchText(const UniLib & unilib,const RulesModel_::RuleActionSpec_::RuleCapturingGroup * group,StringPiece match_text)76*993b0882SAndroid Build Coastguard Worker UnicodeText NormalizeMatchText(
77*993b0882SAndroid Build Coastguard Worker     const UniLib& unilib,
78*993b0882SAndroid Build Coastguard Worker     const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
79*993b0882SAndroid Build Coastguard Worker     StringPiece match_text) {
80*993b0882SAndroid Build Coastguard Worker   return NormalizeMatchText(unilib, group,
81*993b0882SAndroid Build Coastguard Worker                             UTF8ToUnicodeText(match_text, /*do_copy=*/false));
82*993b0882SAndroid Build Coastguard Worker }
83*993b0882SAndroid Build Coastguard Worker 
NormalizeMatchText(const UniLib & unilib,const RulesModel_::RuleActionSpec_::RuleCapturingGroup * group,const UnicodeText match_text)84*993b0882SAndroid Build Coastguard Worker UnicodeText NormalizeMatchText(
85*993b0882SAndroid Build Coastguard Worker     const UniLib& unilib,
86*993b0882SAndroid Build Coastguard Worker     const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
87*993b0882SAndroid Build Coastguard Worker     const UnicodeText match_text) {
88*993b0882SAndroid Build Coastguard Worker   if (group->normalization_options() == nullptr) {
89*993b0882SAndroid Build Coastguard Worker     return match_text;
90*993b0882SAndroid Build Coastguard Worker   }
91*993b0882SAndroid Build Coastguard Worker   return NormalizeText(unilib, group->normalization_options(), match_text);
92*993b0882SAndroid Build Coastguard Worker }
93*993b0882SAndroid Build Coastguard Worker 
FillAnnotationFromCapturingMatch(const CodepointSpan & span,const RulesModel_::RuleActionSpec_::RuleCapturingGroup * group,const int message_index,StringPiece match_text,ActionSuggestionAnnotation * annotation)94*993b0882SAndroid Build Coastguard Worker bool FillAnnotationFromCapturingMatch(
95*993b0882SAndroid Build Coastguard Worker     const CodepointSpan& span,
96*993b0882SAndroid Build Coastguard Worker     const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
97*993b0882SAndroid Build Coastguard Worker     const int message_index, StringPiece match_text,
98*993b0882SAndroid Build Coastguard Worker     ActionSuggestionAnnotation* annotation) {
99*993b0882SAndroid Build Coastguard Worker   if (group->annotation_name() == nullptr &&
100*993b0882SAndroid Build Coastguard Worker       group->annotation_type() == nullptr) {
101*993b0882SAndroid Build Coastguard Worker     return false;
102*993b0882SAndroid Build Coastguard Worker   }
103*993b0882SAndroid Build Coastguard Worker   annotation->span.span = span;
104*993b0882SAndroid Build Coastguard Worker   annotation->span.message_index = message_index;
105*993b0882SAndroid Build Coastguard Worker   annotation->span.text = match_text.ToString();
106*993b0882SAndroid Build Coastguard Worker   if (group->annotation_name() != nullptr) {
107*993b0882SAndroid Build Coastguard Worker     annotation->name = group->annotation_name()->str();
108*993b0882SAndroid Build Coastguard Worker   }
109*993b0882SAndroid Build Coastguard Worker   if (group->annotation_type() != nullptr) {
110*993b0882SAndroid Build Coastguard Worker     annotation->entity.collection = group->annotation_type()->str();
111*993b0882SAndroid Build Coastguard Worker   }
112*993b0882SAndroid Build Coastguard Worker   return true;
113*993b0882SAndroid Build Coastguard Worker }
114*993b0882SAndroid Build Coastguard Worker 
MergeEntityDataFromCapturingMatch(const RulesModel_::RuleActionSpec_::RuleCapturingGroup * group,StringPiece match_text,MutableFlatbuffer * buffer)115*993b0882SAndroid Build Coastguard Worker bool MergeEntityDataFromCapturingMatch(
116*993b0882SAndroid Build Coastguard Worker     const RulesModel_::RuleActionSpec_::RuleCapturingGroup* group,
117*993b0882SAndroid Build Coastguard Worker     StringPiece match_text, MutableFlatbuffer* buffer) {
118*993b0882SAndroid Build Coastguard Worker   if (group->entity_field() != nullptr) {
119*993b0882SAndroid Build Coastguard Worker     if (!buffer->ParseAndSet(group->entity_field(), match_text.ToString())) {
120*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not set entity data from rule capturing group.";
121*993b0882SAndroid Build Coastguard Worker       return false;
122*993b0882SAndroid Build Coastguard Worker     }
123*993b0882SAndroid Build Coastguard Worker   }
124*993b0882SAndroid Build Coastguard Worker   if (group->entity_data() != nullptr) {
125*993b0882SAndroid Build Coastguard Worker     if (!buffer->MergeFrom(reinterpret_cast<const flatbuffers::Table*>(
126*993b0882SAndroid Build Coastguard Worker             group->entity_data()))) {
127*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not set entity data for capturing match.";
128*993b0882SAndroid Build Coastguard Worker       return false;
129*993b0882SAndroid Build Coastguard Worker     }
130*993b0882SAndroid Build Coastguard Worker   }
131*993b0882SAndroid Build Coastguard Worker   return true;
132*993b0882SAndroid Build Coastguard Worker }
133*993b0882SAndroid Build Coastguard Worker 
ConvertDatetimeToTime(std::vector<AnnotatedSpan> * annotations)134*993b0882SAndroid Build Coastguard Worker void ConvertDatetimeToTime(std::vector<AnnotatedSpan>* annotations) {
135*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < annotations->size(); i++) {
136*993b0882SAndroid Build Coastguard Worker     ClassificationResult* classification =
137*993b0882SAndroid Build Coastguard Worker         &(*annotations)[i].classification.front();
138*993b0882SAndroid Build Coastguard Worker     // Specialize datetime annotation to time annotation if no date
139*993b0882SAndroid Build Coastguard Worker     // component is present.
140*993b0882SAndroid Build Coastguard Worker     if (classification->collection == Collections::DateTime() &&
141*993b0882SAndroid Build Coastguard Worker         classification->datetime_parse_result.IsSet()) {
142*993b0882SAndroid Build Coastguard Worker       bool has_only_time = true;
143*993b0882SAndroid Build Coastguard Worker       for (const DatetimeComponent& component :
144*993b0882SAndroid Build Coastguard Worker            classification->datetime_parse_result.datetime_components) {
145*993b0882SAndroid Build Coastguard Worker         if (component.component_type !=
146*993b0882SAndroid Build Coastguard Worker                 DatetimeComponent::ComponentType::UNSPECIFIED &&
147*993b0882SAndroid Build Coastguard Worker             component.component_type < DatetimeComponent::ComponentType::HOUR) {
148*993b0882SAndroid Build Coastguard Worker           has_only_time = false;
149*993b0882SAndroid Build Coastguard Worker           break;
150*993b0882SAndroid Build Coastguard Worker         }
151*993b0882SAndroid Build Coastguard Worker       }
152*993b0882SAndroid Build Coastguard Worker       if (has_only_time) {
153*993b0882SAndroid Build Coastguard Worker         classification->collection = kTimeAnnotation;
154*993b0882SAndroid Build Coastguard Worker       }
155*993b0882SAndroid Build Coastguard Worker     }
156*993b0882SAndroid Build Coastguard Worker   }
157*993b0882SAndroid Build Coastguard Worker }
158*993b0882SAndroid Build Coastguard Worker 
159*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
160