xref: /aosp_15_r20/external/libtextclassifier/native/actions/ranker.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #include "actions/ranker.h"
18*993b0882SAndroid Build Coastguard Worker 
#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
22*993b0882SAndroid Build Coastguard Worker 
23*993b0882SAndroid Build Coastguard Worker #include "actions/actions_model_generated.h"
24*993b0882SAndroid Build Coastguard Worker 
25*993b0882SAndroid Build Coastguard Worker #if !defined(TC3_DISABLE_LUA)
26*993b0882SAndroid Build Coastguard Worker #include "actions/lua-ranker.h"
27*993b0882SAndroid Build Coastguard Worker #endif
28*993b0882SAndroid Build Coastguard Worker #include "actions/zlib-utils.h"
29*993b0882SAndroid Build Coastguard Worker #include "annotator/types.h"
30*993b0882SAndroid Build Coastguard Worker #include "utils/base/logging.h"
31*993b0882SAndroid Build Coastguard Worker #if !defined(TC3_DISABLE_LUA)
32*993b0882SAndroid Build Coastguard Worker #include "utils/lua-utils.h"
33*993b0882SAndroid Build Coastguard Worker #endif
34*993b0882SAndroid Build Coastguard Worker 
35*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
36*993b0882SAndroid Build Coastguard Worker namespace {
37*993b0882SAndroid Build Coastguard Worker 
SortByScoreAndType(std::vector<ActionSuggestion> * actions)38*993b0882SAndroid Build Coastguard Worker void SortByScoreAndType(std::vector<ActionSuggestion>* actions) {
39*993b0882SAndroid Build Coastguard Worker   std::stable_sort(actions->begin(), actions->end(),
40*993b0882SAndroid Build Coastguard Worker                    [](const ActionSuggestion& a, const ActionSuggestion& b) {
41*993b0882SAndroid Build Coastguard Worker                      return a.score > b.score ||
42*993b0882SAndroid Build Coastguard Worker                             (a.score >= b.score && a.type < b.type);
43*993b0882SAndroid Build Coastguard Worker                    });
44*993b0882SAndroid Build Coastguard Worker }
45*993b0882SAndroid Build Coastguard Worker 
SortByPriorityAndScoreAndType(std::vector<ActionSuggestion> * actions)46*993b0882SAndroid Build Coastguard Worker void SortByPriorityAndScoreAndType(std::vector<ActionSuggestion>* actions) {
47*993b0882SAndroid Build Coastguard Worker   std::stable_sort(
48*993b0882SAndroid Build Coastguard Worker       actions->begin(), actions->end(),
49*993b0882SAndroid Build Coastguard Worker       [](const ActionSuggestion& a, const ActionSuggestion& b) {
50*993b0882SAndroid Build Coastguard Worker         return a.priority_score > b.priority_score ||
51*993b0882SAndroid Build Coastguard Worker                (a.priority_score >= b.priority_score && a.score > b.score) ||
52*993b0882SAndroid Build Coastguard Worker                (a.priority_score >= b.priority_score && a.score >= b.score &&
53*993b0882SAndroid Build Coastguard Worker                 a.type < b.type);
54*993b0882SAndroid Build Coastguard Worker       });
55*993b0882SAndroid Build Coastguard Worker }
56*993b0882SAndroid Build Coastguard Worker 
// Generic three-way comparison: returns -1 if `left < right`, 1 if
// `left > right`, and 0 otherwise. Specialized below for composite types.
template <typename T>
int Compare(const T& left, const T& right) {
  return left < right ? -1 : (left > right ? 1 : 0);
}
67*993b0882SAndroid Build Coastguard Worker 
68*993b0882SAndroid Build Coastguard Worker template <>
Compare(const std::string & left,const std::string & right)69*993b0882SAndroid Build Coastguard Worker int Compare(const std::string& left, const std::string& right) {
70*993b0882SAndroid Build Coastguard Worker   return left.compare(right);
71*993b0882SAndroid Build Coastguard Worker }
72*993b0882SAndroid Build Coastguard Worker 
73*993b0882SAndroid Build Coastguard Worker template <>
Compare(const MessageTextSpan & span,const MessageTextSpan & other)74*993b0882SAndroid Build Coastguard Worker int Compare(const MessageTextSpan& span, const MessageTextSpan& other) {
75*993b0882SAndroid Build Coastguard Worker   if (const int value = Compare(span.message_index, other.message_index)) {
76*993b0882SAndroid Build Coastguard Worker     return value;
77*993b0882SAndroid Build Coastguard Worker   }
78*993b0882SAndroid Build Coastguard Worker   if (const int value = Compare(span.span.first, other.span.first)) {
79*993b0882SAndroid Build Coastguard Worker     return value;
80*993b0882SAndroid Build Coastguard Worker   }
81*993b0882SAndroid Build Coastguard Worker   if (const int value = Compare(span.span.second, other.span.second)) {
82*993b0882SAndroid Build Coastguard Worker     return value;
83*993b0882SAndroid Build Coastguard Worker   }
84*993b0882SAndroid Build Coastguard Worker   return 0;
85*993b0882SAndroid Build Coastguard Worker }
86*993b0882SAndroid Build Coastguard Worker 
IsSameSpan(const MessageTextSpan & span,const MessageTextSpan & other)87*993b0882SAndroid Build Coastguard Worker bool IsSameSpan(const MessageTextSpan& span, const MessageTextSpan& other) {
88*993b0882SAndroid Build Coastguard Worker   return Compare(span, other) == 0;
89*993b0882SAndroid Build Coastguard Worker }
90*993b0882SAndroid Build Coastguard Worker 
TextSpansIntersect(const MessageTextSpan & span,const MessageTextSpan & other)91*993b0882SAndroid Build Coastguard Worker bool TextSpansIntersect(const MessageTextSpan& span,
92*993b0882SAndroid Build Coastguard Worker                         const MessageTextSpan& other) {
93*993b0882SAndroid Build Coastguard Worker   return span.message_index == other.message_index &&
94*993b0882SAndroid Build Coastguard Worker          SpansOverlap(span.span, other.span);
95*993b0882SAndroid Build Coastguard Worker }
96*993b0882SAndroid Build Coastguard Worker 
97*993b0882SAndroid Build Coastguard Worker template <>
Compare(const ActionSuggestionAnnotation & annotation,const ActionSuggestionAnnotation & other)98*993b0882SAndroid Build Coastguard Worker int Compare(const ActionSuggestionAnnotation& annotation,
99*993b0882SAndroid Build Coastguard Worker             const ActionSuggestionAnnotation& other) {
100*993b0882SAndroid Build Coastguard Worker   if (const int value = Compare(annotation.span, other.span)) {
101*993b0882SAndroid Build Coastguard Worker     return value;
102*993b0882SAndroid Build Coastguard Worker   }
103*993b0882SAndroid Build Coastguard Worker   if (const int value = Compare(annotation.name, other.name)) {
104*993b0882SAndroid Build Coastguard Worker     return value;
105*993b0882SAndroid Build Coastguard Worker   }
106*993b0882SAndroid Build Coastguard Worker   if (const int value =
107*993b0882SAndroid Build Coastguard Worker           Compare(annotation.entity.collection, other.entity.collection)) {
108*993b0882SAndroid Build Coastguard Worker     return value;
109*993b0882SAndroid Build Coastguard Worker   }
110*993b0882SAndroid Build Coastguard Worker   return 0;
111*993b0882SAndroid Build Coastguard Worker }
112*993b0882SAndroid Build Coastguard Worker 
113*993b0882SAndroid Build Coastguard Worker // Checks whether two annotations can be considered equivalent.
IsEquivalentActionAnnotation(const ActionSuggestionAnnotation & annotation,const ActionSuggestionAnnotation & other)114*993b0882SAndroid Build Coastguard Worker bool IsEquivalentActionAnnotation(const ActionSuggestionAnnotation& annotation,
115*993b0882SAndroid Build Coastguard Worker                                   const ActionSuggestionAnnotation& other) {
116*993b0882SAndroid Build Coastguard Worker   return Compare(annotation, other) == 0;
117*993b0882SAndroid Build Coastguard Worker }
118*993b0882SAndroid Build Coastguard Worker 
119*993b0882SAndroid Build Coastguard Worker // Compares actions based on annotations.
CompareAnnotationsOnly(const ActionSuggestion & action,const ActionSuggestion & other)120*993b0882SAndroid Build Coastguard Worker int CompareAnnotationsOnly(const ActionSuggestion& action,
121*993b0882SAndroid Build Coastguard Worker                            const ActionSuggestion& other) {
122*993b0882SAndroid Build Coastguard Worker   if (const int value =
123*993b0882SAndroid Build Coastguard Worker           Compare(action.annotations.size(), other.annotations.size())) {
124*993b0882SAndroid Build Coastguard Worker     return value;
125*993b0882SAndroid Build Coastguard Worker   }
126*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < action.annotations.size(); i++) {
127*993b0882SAndroid Build Coastguard Worker     if (const int value =
128*993b0882SAndroid Build Coastguard Worker             Compare(action.annotations[i], other.annotations[i])) {
129*993b0882SAndroid Build Coastguard Worker       return value;
130*993b0882SAndroid Build Coastguard Worker     }
131*993b0882SAndroid Build Coastguard Worker   }
132*993b0882SAndroid Build Coastguard Worker   return 0;
133*993b0882SAndroid Build Coastguard Worker }
134*993b0882SAndroid Build Coastguard Worker 
135*993b0882SAndroid Build Coastguard Worker // Checks whether two actions have the same annotations.
HaveEquivalentAnnotations(const ActionSuggestion & action,const ActionSuggestion & other)136*993b0882SAndroid Build Coastguard Worker bool HaveEquivalentAnnotations(const ActionSuggestion& action,
137*993b0882SAndroid Build Coastguard Worker                                const ActionSuggestion& other) {
138*993b0882SAndroid Build Coastguard Worker   return CompareAnnotationsOnly(action, other) == 0;
139*993b0882SAndroid Build Coastguard Worker }
140*993b0882SAndroid Build Coastguard Worker 
141*993b0882SAndroid Build Coastguard Worker template <>
Compare(const ActionSuggestion & action,const ActionSuggestion & other)142*993b0882SAndroid Build Coastguard Worker int Compare(const ActionSuggestion& action, const ActionSuggestion& other) {
143*993b0882SAndroid Build Coastguard Worker   if (const int value = Compare(action.type, other.type)) {
144*993b0882SAndroid Build Coastguard Worker     return value;
145*993b0882SAndroid Build Coastguard Worker   }
146*993b0882SAndroid Build Coastguard Worker   if (const int value = Compare(action.response_text, other.response_text)) {
147*993b0882SAndroid Build Coastguard Worker     return value;
148*993b0882SAndroid Build Coastguard Worker   }
149*993b0882SAndroid Build Coastguard Worker   if (const int value = Compare(action.serialized_entity_data,
150*993b0882SAndroid Build Coastguard Worker                                 other.serialized_entity_data)) {
151*993b0882SAndroid Build Coastguard Worker     return value;
152*993b0882SAndroid Build Coastguard Worker   }
153*993b0882SAndroid Build Coastguard Worker   return CompareAnnotationsOnly(action, other);
154*993b0882SAndroid Build Coastguard Worker }
155*993b0882SAndroid Build Coastguard Worker 
156*993b0882SAndroid Build Coastguard Worker // Checks whether two action suggestions can be considered equivalent.
IsEquivalentActionSuggestion(const ActionSuggestion & action,const ActionSuggestion & other)157*993b0882SAndroid Build Coastguard Worker bool IsEquivalentActionSuggestion(const ActionSuggestion& action,
158*993b0882SAndroid Build Coastguard Worker                                   const ActionSuggestion& other) {
159*993b0882SAndroid Build Coastguard Worker   return Compare(action, other) == 0;
160*993b0882SAndroid Build Coastguard Worker }
161*993b0882SAndroid Build Coastguard Worker 
162*993b0882SAndroid Build Coastguard Worker // Checks whether any action is equivalent to the given one.
IsAnyActionEquivalent(const ActionSuggestion & action,const std::vector<ActionSuggestion> & actions)163*993b0882SAndroid Build Coastguard Worker bool IsAnyActionEquivalent(const ActionSuggestion& action,
164*993b0882SAndroid Build Coastguard Worker                            const std::vector<ActionSuggestion>& actions) {
165*993b0882SAndroid Build Coastguard Worker   for (const ActionSuggestion& other : actions) {
166*993b0882SAndroid Build Coastguard Worker     if (IsEquivalentActionSuggestion(action, other)) {
167*993b0882SAndroid Build Coastguard Worker       return true;
168*993b0882SAndroid Build Coastguard Worker     }
169*993b0882SAndroid Build Coastguard Worker   }
170*993b0882SAndroid Build Coastguard Worker   return false;
171*993b0882SAndroid Build Coastguard Worker }
172*993b0882SAndroid Build Coastguard Worker 
IsConflicting(const ActionSuggestionAnnotation & annotation,const ActionSuggestionAnnotation & other)173*993b0882SAndroid Build Coastguard Worker bool IsConflicting(const ActionSuggestionAnnotation& annotation,
174*993b0882SAndroid Build Coastguard Worker                    const ActionSuggestionAnnotation& other) {
175*993b0882SAndroid Build Coastguard Worker   // Two annotations are conflicting if they are different but refer to
176*993b0882SAndroid Build Coastguard Worker   // overlapping spans in the conversation.
177*993b0882SAndroid Build Coastguard Worker   return (!IsEquivalentActionAnnotation(annotation, other) &&
178*993b0882SAndroid Build Coastguard Worker           TextSpansIntersect(annotation.span, other.span));
179*993b0882SAndroid Build Coastguard Worker }
180*993b0882SAndroid Build Coastguard Worker 
181*993b0882SAndroid Build Coastguard Worker // Checks whether two action suggestions can be considered conflicting.
IsConflictingActionSuggestion(const ActionSuggestion & action,const ActionSuggestion & other)182*993b0882SAndroid Build Coastguard Worker bool IsConflictingActionSuggestion(const ActionSuggestion& action,
183*993b0882SAndroid Build Coastguard Worker                                    const ActionSuggestion& other) {
184*993b0882SAndroid Build Coastguard Worker   // Actions are considered conflicting, iff they refer to the same text span,
185*993b0882SAndroid Build Coastguard Worker   // but were not generated from the same annotation.
186*993b0882SAndroid Build Coastguard Worker   if (action.annotations.empty() || other.annotations.empty()) {
187*993b0882SAndroid Build Coastguard Worker     return false;
188*993b0882SAndroid Build Coastguard Worker   }
189*993b0882SAndroid Build Coastguard Worker   for (const ActionSuggestionAnnotation& annotation : action.annotations) {
190*993b0882SAndroid Build Coastguard Worker     for (const ActionSuggestionAnnotation& other_annotation :
191*993b0882SAndroid Build Coastguard Worker          other.annotations) {
192*993b0882SAndroid Build Coastguard Worker       if (IsConflicting(annotation, other_annotation)) {
193*993b0882SAndroid Build Coastguard Worker         return true;
194*993b0882SAndroid Build Coastguard Worker       }
195*993b0882SAndroid Build Coastguard Worker     }
196*993b0882SAndroid Build Coastguard Worker   }
197*993b0882SAndroid Build Coastguard Worker   return false;
198*993b0882SAndroid Build Coastguard Worker }
199*993b0882SAndroid Build Coastguard Worker 
200*993b0882SAndroid Build Coastguard Worker // Checks whether any action is considered conflicting with the given one.
IsAnyActionConflicting(const ActionSuggestion & action,const std::vector<ActionSuggestion> & actions)201*993b0882SAndroid Build Coastguard Worker bool IsAnyActionConflicting(const ActionSuggestion& action,
202*993b0882SAndroid Build Coastguard Worker                             const std::vector<ActionSuggestion>& actions) {
203*993b0882SAndroid Build Coastguard Worker   for (const ActionSuggestion& other : actions) {
204*993b0882SAndroid Build Coastguard Worker     if (IsConflictingActionSuggestion(action, other)) {
205*993b0882SAndroid Build Coastguard Worker       return true;
206*993b0882SAndroid Build Coastguard Worker     }
207*993b0882SAndroid Build Coastguard Worker   }
208*993b0882SAndroid Build Coastguard Worker   return false;
209*993b0882SAndroid Build Coastguard Worker }
210*993b0882SAndroid Build Coastguard Worker 
211*993b0882SAndroid Build Coastguard Worker }  // namespace
212*993b0882SAndroid Build Coastguard Worker 
213*993b0882SAndroid Build Coastguard Worker std::unique_ptr<ActionsSuggestionsRanker>
CreateActionsSuggestionsRanker(const RankingOptions * options,ZlibDecompressor * decompressor,const std::string & smart_reply_action_type)214*993b0882SAndroid Build Coastguard Worker ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
215*993b0882SAndroid Build Coastguard Worker     const RankingOptions* options, ZlibDecompressor* decompressor,
216*993b0882SAndroid Build Coastguard Worker     const std::string& smart_reply_action_type) {
217*993b0882SAndroid Build Coastguard Worker   auto ranker = std::unique_ptr<ActionsSuggestionsRanker>(
218*993b0882SAndroid Build Coastguard Worker       new ActionsSuggestionsRanker(options, smart_reply_action_type));
219*993b0882SAndroid Build Coastguard Worker 
220*993b0882SAndroid Build Coastguard Worker   if (!ranker->InitializeAndValidate(decompressor)) {
221*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not initialize action ranker.";
222*993b0882SAndroid Build Coastguard Worker     return nullptr;
223*993b0882SAndroid Build Coastguard Worker   }
224*993b0882SAndroid Build Coastguard Worker 
225*993b0882SAndroid Build Coastguard Worker   return ranker;
226*993b0882SAndroid Build Coastguard Worker }
227*993b0882SAndroid Build Coastguard Worker 
InitializeAndValidate(ZlibDecompressor * decompressor)228*993b0882SAndroid Build Coastguard Worker bool ActionsSuggestionsRanker::InitializeAndValidate(
229*993b0882SAndroid Build Coastguard Worker     ZlibDecompressor* decompressor) {
230*993b0882SAndroid Build Coastguard Worker   if (options_ == nullptr) {
231*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "No ranking options specified.";
232*993b0882SAndroid Build Coastguard Worker     return false;
233*993b0882SAndroid Build Coastguard Worker   }
234*993b0882SAndroid Build Coastguard Worker 
235*993b0882SAndroid Build Coastguard Worker #if !defined(TC3_DISABLE_LUA)
236*993b0882SAndroid Build Coastguard Worker   std::string lua_ranking_script;
237*993b0882SAndroid Build Coastguard Worker   if (GetUncompressedString(options_->lua_ranking_script(),
238*993b0882SAndroid Build Coastguard Worker                             options_->compressed_lua_ranking_script(),
239*993b0882SAndroid Build Coastguard Worker                             decompressor, &lua_ranking_script) &&
240*993b0882SAndroid Build Coastguard Worker       !lua_ranking_script.empty()) {
241*993b0882SAndroid Build Coastguard Worker     if (!Compile(lua_ranking_script, &lua_bytecode_)) {
242*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not precompile lua ranking snippet.";
243*993b0882SAndroid Build Coastguard Worker       return false;
244*993b0882SAndroid Build Coastguard Worker     }
245*993b0882SAndroid Build Coastguard Worker   }
246*993b0882SAndroid Build Coastguard Worker #endif
247*993b0882SAndroid Build Coastguard Worker 
248*993b0882SAndroid Build Coastguard Worker   return true;
249*993b0882SAndroid Build Coastguard Worker }
250*993b0882SAndroid Build Coastguard Worker 
RankActions(const Conversation & conversation,ActionsSuggestionsResponse * response,const reflection::Schema * entity_data_schema,const reflection::Schema * annotations_entity_data_schema) const251*993b0882SAndroid Build Coastguard Worker bool ActionsSuggestionsRanker::RankActions(
252*993b0882SAndroid Build Coastguard Worker     const Conversation& conversation, ActionsSuggestionsResponse* response,
253*993b0882SAndroid Build Coastguard Worker     const reflection::Schema* entity_data_schema,
254*993b0882SAndroid Build Coastguard Worker     const reflection::Schema* annotations_entity_data_schema) const {
255*993b0882SAndroid Build Coastguard Worker   if (options_->deduplicate_suggestions() ||
256*993b0882SAndroid Build Coastguard Worker       options_->deduplicate_suggestions_by_span()) {
257*993b0882SAndroid Build Coastguard Worker     // Order suggestions by [priority score -> score] for deduplication
258*993b0882SAndroid Build Coastguard Worker     SortByPriorityAndScoreAndType(&response->actions);
259*993b0882SAndroid Build Coastguard Worker 
260*993b0882SAndroid Build Coastguard Worker     // Deduplicate, keeping the higher score actions.
261*993b0882SAndroid Build Coastguard Worker     if (options_->deduplicate_suggestions()) {
262*993b0882SAndroid Build Coastguard Worker       std::vector<ActionSuggestion> deduplicated_actions;
263*993b0882SAndroid Build Coastguard Worker       for (const ActionSuggestion& candidate : response->actions) {
264*993b0882SAndroid Build Coastguard Worker         // Check whether we already have an equivalent action.
265*993b0882SAndroid Build Coastguard Worker         if (!IsAnyActionEquivalent(candidate, deduplicated_actions)) {
266*993b0882SAndroid Build Coastguard Worker           deduplicated_actions.push_back(std::move(candidate));
267*993b0882SAndroid Build Coastguard Worker         }
268*993b0882SAndroid Build Coastguard Worker       }
269*993b0882SAndroid Build Coastguard Worker       response->actions = std::move(deduplicated_actions);
270*993b0882SAndroid Build Coastguard Worker     }
271*993b0882SAndroid Build Coastguard Worker 
272*993b0882SAndroid Build Coastguard Worker     // Resolve conflicts between conflicting actions referring to the same
273*993b0882SAndroid Build Coastguard Worker     // text span.
274*993b0882SAndroid Build Coastguard Worker     if (options_->deduplicate_suggestions_by_span()) {
275*993b0882SAndroid Build Coastguard Worker       std::vector<ActionSuggestion> deduplicated_actions;
276*993b0882SAndroid Build Coastguard Worker       for (const ActionSuggestion& candidate : response->actions) {
277*993b0882SAndroid Build Coastguard Worker         // Check whether we already have a conflicting action.
278*993b0882SAndroid Build Coastguard Worker         if (!IsAnyActionConflicting(candidate, deduplicated_actions)) {
279*993b0882SAndroid Build Coastguard Worker           deduplicated_actions.push_back(std::move(candidate));
280*993b0882SAndroid Build Coastguard Worker         }
281*993b0882SAndroid Build Coastguard Worker       }
282*993b0882SAndroid Build Coastguard Worker       response->actions = std::move(deduplicated_actions);
283*993b0882SAndroid Build Coastguard Worker     }
284*993b0882SAndroid Build Coastguard Worker   }
285*993b0882SAndroid Build Coastguard Worker 
286*993b0882SAndroid Build Coastguard Worker   bool sort_by_priority =
287*993b0882SAndroid Build Coastguard Worker       options_->sort_type() == RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE;
288*993b0882SAndroid Build Coastguard Worker   // Suppress smart replies if actions are present.
289*993b0882SAndroid Build Coastguard Worker   if (options_->suppress_smart_replies_with_actions()) {
290*993b0882SAndroid Build Coastguard Worker     std::vector<ActionSuggestion> non_smart_reply_actions;
291*993b0882SAndroid Build Coastguard Worker     for (const ActionSuggestion& action : response->actions) {
292*993b0882SAndroid Build Coastguard Worker       if (action.type != smart_reply_action_type_) {
293*993b0882SAndroid Build Coastguard Worker         non_smart_reply_actions.push_back(std::move(action));
294*993b0882SAndroid Build Coastguard Worker       }
295*993b0882SAndroid Build Coastguard Worker     }
296*993b0882SAndroid Build Coastguard Worker     response->actions = std::move(non_smart_reply_actions);
297*993b0882SAndroid Build Coastguard Worker   }
298*993b0882SAndroid Build Coastguard Worker 
299*993b0882SAndroid Build Coastguard Worker   // Group by annotation if specified.
300*993b0882SAndroid Build Coastguard Worker   if (options_->group_by_annotations()) {
301*993b0882SAndroid Build Coastguard Worker     auto group_id = std::map<
302*993b0882SAndroid Build Coastguard Worker         ActionSuggestion, int,
303*993b0882SAndroid Build Coastguard Worker         std::function<bool(const ActionSuggestion&, const ActionSuggestion&)>>{
304*993b0882SAndroid Build Coastguard Worker         [](const ActionSuggestion& action, const ActionSuggestion& other) {
305*993b0882SAndroid Build Coastguard Worker           return (CompareAnnotationsOnly(action, other) < 0);
306*993b0882SAndroid Build Coastguard Worker         }};
307*993b0882SAndroid Build Coastguard Worker     typedef std::vector<ActionSuggestion> ActionSuggestionGroup;
308*993b0882SAndroid Build Coastguard Worker     std::vector<ActionSuggestionGroup> groups;
309*993b0882SAndroid Build Coastguard Worker 
310*993b0882SAndroid Build Coastguard Worker     // Group actions by the annotation set they are based of.
311*993b0882SAndroid Build Coastguard Worker     for (const ActionSuggestion& action : response->actions) {
312*993b0882SAndroid Build Coastguard Worker       // Treat actions with no annotations idependently.
313*993b0882SAndroid Build Coastguard Worker       if (action.annotations.empty()) {
314*993b0882SAndroid Build Coastguard Worker         groups.emplace_back(1, action);
315*993b0882SAndroid Build Coastguard Worker         continue;
316*993b0882SAndroid Build Coastguard Worker       }
317*993b0882SAndroid Build Coastguard Worker 
318*993b0882SAndroid Build Coastguard Worker       auto it = group_id.find(action);
319*993b0882SAndroid Build Coastguard Worker       if (it != group_id.end()) {
320*993b0882SAndroid Build Coastguard Worker         groups[it->second].push_back(action);
321*993b0882SAndroid Build Coastguard Worker       } else {
322*993b0882SAndroid Build Coastguard Worker         group_id[action] = groups.size();
323*993b0882SAndroid Build Coastguard Worker         groups.emplace_back(1, action);
324*993b0882SAndroid Build Coastguard Worker       }
325*993b0882SAndroid Build Coastguard Worker     }
326*993b0882SAndroid Build Coastguard Worker 
327*993b0882SAndroid Build Coastguard Worker     // Sort within each group by score.
328*993b0882SAndroid Build Coastguard Worker     for (std::vector<ActionSuggestion>& group : groups) {
329*993b0882SAndroid Build Coastguard Worker       if (sort_by_priority) {
330*993b0882SAndroid Build Coastguard Worker         SortByPriorityAndScoreAndType(&group);
331*993b0882SAndroid Build Coastguard Worker       } else {
332*993b0882SAndroid Build Coastguard Worker         SortByScoreAndType(&group);
333*993b0882SAndroid Build Coastguard Worker       }
334*993b0882SAndroid Build Coastguard Worker     }
335*993b0882SAndroid Build Coastguard Worker 
336*993b0882SAndroid Build Coastguard Worker     // Sort groups by maximum score or priority score.
337*993b0882SAndroid Build Coastguard Worker     if (sort_by_priority) {
338*993b0882SAndroid Build Coastguard Worker       std::stable_sort(
339*993b0882SAndroid Build Coastguard Worker           groups.begin(), groups.end(),
340*993b0882SAndroid Build Coastguard Worker           [](const std::vector<ActionSuggestion>& a,
341*993b0882SAndroid Build Coastguard Worker              const std::vector<ActionSuggestion>& b) {
342*993b0882SAndroid Build Coastguard Worker             return (a.begin()->priority_score > b.begin()->priority_score) ||
343*993b0882SAndroid Build Coastguard Worker                    (a.begin()->priority_score >= b.begin()->priority_score &&
344*993b0882SAndroid Build Coastguard Worker                     a.begin()->score > b.begin()->score) ||
345*993b0882SAndroid Build Coastguard Worker                    (a.begin()->priority_score >= b.begin()->priority_score &&
346*993b0882SAndroid Build Coastguard Worker                     a.begin()->score >= b.begin()->score &&
347*993b0882SAndroid Build Coastguard Worker                     a.begin()->type < b.begin()->type);
348*993b0882SAndroid Build Coastguard Worker           });
349*993b0882SAndroid Build Coastguard Worker     } else {
350*993b0882SAndroid Build Coastguard Worker       std::stable_sort(groups.begin(), groups.end(),
351*993b0882SAndroid Build Coastguard Worker                        [](const std::vector<ActionSuggestion>& a,
352*993b0882SAndroid Build Coastguard Worker                           const std::vector<ActionSuggestion>& b) {
353*993b0882SAndroid Build Coastguard Worker                          return a.begin()->score > b.begin()->score ||
354*993b0882SAndroid Build Coastguard Worker                                 (a.begin()->score >= b.begin()->score &&
355*993b0882SAndroid Build Coastguard Worker                                  a.begin()->type < b.begin()->type);
356*993b0882SAndroid Build Coastguard Worker                        });
357*993b0882SAndroid Build Coastguard Worker     }
358*993b0882SAndroid Build Coastguard Worker 
359*993b0882SAndroid Build Coastguard Worker     // Flatten result.
360*993b0882SAndroid Build Coastguard Worker     const size_t num_actions = response->actions.size();
361*993b0882SAndroid Build Coastguard Worker     response->actions.clear();
362*993b0882SAndroid Build Coastguard Worker     response->actions.reserve(num_actions);
363*993b0882SAndroid Build Coastguard Worker     for (const std::vector<ActionSuggestion>& actions : groups) {
364*993b0882SAndroid Build Coastguard Worker       response->actions.insert(response->actions.end(), actions.begin(),
365*993b0882SAndroid Build Coastguard Worker                                actions.end());
366*993b0882SAndroid Build Coastguard Worker     }
367*993b0882SAndroid Build Coastguard Worker   } else if (sort_by_priority) {
368*993b0882SAndroid Build Coastguard Worker     SortByPriorityAndScoreAndType(&response->actions);
369*993b0882SAndroid Build Coastguard Worker   } else {
370*993b0882SAndroid Build Coastguard Worker     SortByScoreAndType(&response->actions);
371*993b0882SAndroid Build Coastguard Worker   }
372*993b0882SAndroid Build Coastguard Worker 
373*993b0882SAndroid Build Coastguard Worker #if !defined(TC3_DISABLE_LUA)
374*993b0882SAndroid Build Coastguard Worker   // Run lua ranking snippet, if provided.
375*993b0882SAndroid Build Coastguard Worker   if (!lua_bytecode_.empty()) {
376*993b0882SAndroid Build Coastguard Worker     auto lua_ranker = ActionsSuggestionsLuaRanker::Create(
377*993b0882SAndroid Build Coastguard Worker         conversation, lua_bytecode_, entity_data_schema,
378*993b0882SAndroid Build Coastguard Worker         annotations_entity_data_schema, response);
379*993b0882SAndroid Build Coastguard Worker     if (lua_ranker == nullptr || !lua_ranker->RankActions()) {
380*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not run lua ranking snippet.";
381*993b0882SAndroid Build Coastguard Worker       return false;
382*993b0882SAndroid Build Coastguard Worker     }
383*993b0882SAndroid Build Coastguard Worker   }
384*993b0882SAndroid Build Coastguard Worker #endif
385*993b0882SAndroid Build Coastguard Worker 
386*993b0882SAndroid Build Coastguard Worker   return true;
387*993b0882SAndroid Build Coastguard Worker }
388*993b0882SAndroid Build Coastguard Worker 
389*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
390