xref: /aosp_15_r20/external/libtextclassifier/native/actions/ranker_test.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #include "actions/ranker.h"
18*993b0882SAndroid Build Coastguard Worker 
19*993b0882SAndroid Build Coastguard Worker #include <string>
20*993b0882SAndroid Build Coastguard Worker 
21*993b0882SAndroid Build Coastguard Worker #include "actions/actions_model_generated.h"
22*993b0882SAndroid Build Coastguard Worker #include "actions/types.h"
23*993b0882SAndroid Build Coastguard Worker #include "utils/zlib/zlib.h"
24*993b0882SAndroid Build Coastguard Worker #include "gmock/gmock.h"
25*993b0882SAndroid Build Coastguard Worker #include "gtest/gtest.h"
26*993b0882SAndroid Build Coastguard Worker 
27*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
28*993b0882SAndroid Build Coastguard Worker namespace {
29*993b0882SAndroid Build Coastguard Worker 
30*993b0882SAndroid Build Coastguard Worker MATCHER_P3(IsAction, type, response_text, score, "") {
31*993b0882SAndroid Build Coastguard Worker   return testing::Value(arg.type, type) &&
32*993b0882SAndroid Build Coastguard Worker          testing::Value(arg.response_text, response_text) &&
33*993b0882SAndroid Build Coastguard Worker          testing::Value(arg.score, score);
34*993b0882SAndroid Build Coastguard Worker }
35*993b0882SAndroid Build Coastguard Worker 
36*993b0882SAndroid Build Coastguard Worker MATCHER_P(IsActionType, type, "") { return testing::Value(arg.type, type); }
37*993b0882SAndroid Build Coastguard Worker 
// Feeds two textually identical smart replies to a ranker that has
// `deduplicate_suggestions` enabled and expects only the one with score 1.0
// to remain.
TEST(RankingTest, DeduplicationSmartReply) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};

  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply", /*score=*/1.0},
      {/*response_text=*/"hello there", /*type=*/"text_reply", /*score=*/0.5}};

  // Serialize the options so the ranker can read them as a flatbuffer root.
  RankingOptionsT ranking_options;
  ranking_options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(RankingOptions::Pack(fbb, &ranking_options));

  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(fbb.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Stop here with a clean failure instead of dereferencing a null ranker.
  // NOTE(review): assumes the factory yields a null pointer on failure —
  // confirm against ranker.h.
  ASSERT_TRUE(ranker != nullptr);

  // A ranking failure must fail the test rather than be silently ignored.
  ASSERT_TRUE(ranker->RankActions(conversation, &response));
  EXPECT_THAT(response.actions,
              testing::ElementsAre(IsAction("text_reply", "hello there", 1.0)));
}
59*993b0882SAndroid Build Coastguard Worker 
// Three smart replies share the same text; the third additionally carries
// serialized entity data. With `deduplicate_suggestions` enabled, the test
// expects the 1.0-scored reply and the entity-data reply (0.6) to remain.
TEST(RankingTest, DeduplicationExtraData) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};

  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0, /*priority_score=*/0.0},
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/0.5, /*priority_score=*/0.0},
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/0.6, /*priority_score=*/0.0,
       /*annotations=*/{}, /*serialized_entity_data=*/"test"},
  };

  // Serialize the options so the ranker can read them as a flatbuffer root.
  RankingOptionsT ranking_options;
  ranking_options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(RankingOptions::Pack(fbb, &ranking_options));

  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(fbb.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Stop here with a clean failure instead of dereferencing a null ranker.
  // NOTE(review): assumes the factory yields a null pointer on failure —
  // confirm against ranker.h.
  ASSERT_TRUE(ranker != nullptr);

  // A ranking failure must fail the test rather than be silently ignored.
  ASSERT_TRUE(ranker->RankActions(conversation, &response));
  EXPECT_THAT(
      response.actions,
      testing::ElementsAre(IsAction("text_reply", "hello there", 1.0),
                           // Is kept as it has different entity data.
                           IsAction("text_reply", "hello there", 0.6)));
}
88*993b0882SAndroid Build Coastguard Worker 
TEST(RankingTest,DeduplicationAnnotations)89*993b0882SAndroid Build Coastguard Worker TEST(RankingTest, DeduplicationAnnotations) {
90*993b0882SAndroid Build Coastguard Worker   const Conversation conversation = {
91*993b0882SAndroid Build Coastguard Worker       {{/*user_id=*/1, "742 Evergreen Terrace, the number is 1-800-TESTING"}}};
92*993b0882SAndroid Build Coastguard Worker   ActionsSuggestionsResponse response;
93*993b0882SAndroid Build Coastguard Worker   {
94*993b0882SAndroid Build Coastguard Worker     ActionSuggestionAnnotation annotation;
95*993b0882SAndroid Build Coastguard Worker     annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
96*993b0882SAndroid Build Coastguard Worker                        /*text=*/"742 Evergreen Terrace"};
97*993b0882SAndroid Build Coastguard Worker     annotation.entity = ClassificationResult("address", 0.5);
98*993b0882SAndroid Build Coastguard Worker     response.actions.push_back({/*response_text=*/"",
99*993b0882SAndroid Build Coastguard Worker                                 /*type=*/"view_map",
100*993b0882SAndroid Build Coastguard Worker                                 /*score=*/0.5,
101*993b0882SAndroid Build Coastguard Worker                                 /*priority_score=*/1.0,
102*993b0882SAndroid Build Coastguard Worker                                 /*annotations=*/{annotation}});
103*993b0882SAndroid Build Coastguard Worker   }
104*993b0882SAndroid Build Coastguard Worker   {
105*993b0882SAndroid Build Coastguard Worker     ActionSuggestionAnnotation annotation;
106*993b0882SAndroid Build Coastguard Worker     annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
107*993b0882SAndroid Build Coastguard Worker                        /*text=*/"742 Evergreen Terrace"};
108*993b0882SAndroid Build Coastguard Worker     annotation.entity = ClassificationResult("address", 1.0);
109*993b0882SAndroid Build Coastguard Worker     response.actions.push_back({/*response_text=*/"",
110*993b0882SAndroid Build Coastguard Worker                                 /*type=*/"view_map",
111*993b0882SAndroid Build Coastguard Worker                                 /*score=*/1.0,
112*993b0882SAndroid Build Coastguard Worker                                 /*priority_score=*/2.0,
113*993b0882SAndroid Build Coastguard Worker                                 /*annotations=*/{annotation}});
114*993b0882SAndroid Build Coastguard Worker   }
115*993b0882SAndroid Build Coastguard Worker   {
116*993b0882SAndroid Build Coastguard Worker     ActionSuggestionAnnotation annotation;
117*993b0882SAndroid Build Coastguard Worker     annotation.span = {/*message_index=*/0, /*span=*/{37, 50},
118*993b0882SAndroid Build Coastguard Worker                        /*text=*/"1-800-TESTING"};
119*993b0882SAndroid Build Coastguard Worker     annotation.entity = ClassificationResult("phone", 0.5);
120*993b0882SAndroid Build Coastguard Worker     response.actions.push_back({/*response_text=*/"",
121*993b0882SAndroid Build Coastguard Worker                                 /*type=*/"call_phone",
122*993b0882SAndroid Build Coastguard Worker                                 /*score=*/0.5,
123*993b0882SAndroid Build Coastguard Worker                                 /*priority_score=*/1.0,
124*993b0882SAndroid Build Coastguard Worker                                 /*annotations=*/{annotation}});
125*993b0882SAndroid Build Coastguard Worker   }
126*993b0882SAndroid Build Coastguard Worker 
127*993b0882SAndroid Build Coastguard Worker   RankingOptionsT options;
128*993b0882SAndroid Build Coastguard Worker   options.deduplicate_suggestions = true;
129*993b0882SAndroid Build Coastguard Worker   flatbuffers::FlatBufferBuilder builder;
130*993b0882SAndroid Build Coastguard Worker   builder.Finish(RankingOptions::Pack(builder, &options));
131*993b0882SAndroid Build Coastguard Worker   auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
132*993b0882SAndroid Build Coastguard Worker       flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
133*993b0882SAndroid Build Coastguard Worker       /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
134*993b0882SAndroid Build Coastguard Worker 
135*993b0882SAndroid Build Coastguard Worker   ranker->RankActions(conversation, &response);
136*993b0882SAndroid Build Coastguard Worker   EXPECT_THAT(response.actions,
137*993b0882SAndroid Build Coastguard Worker               testing::ElementsAreArray({IsAction("view_map", "", 1.0),
138*993b0882SAndroid Build Coastguard Worker                                          IsAction("call_phone", "", 0.5)}));
139*993b0882SAndroid Build Coastguard Worker }
140*993b0882SAndroid Build Coastguard Worker 
TEST(RankingTest,DeduplicationAnnotationsByPriorityScore)141*993b0882SAndroid Build Coastguard Worker TEST(RankingTest, DeduplicationAnnotationsByPriorityScore) {
142*993b0882SAndroid Build Coastguard Worker   const Conversation conversation = {
143*993b0882SAndroid Build Coastguard Worker       {{/*user_id=*/1, "742 Evergreen Terrace, the number is 1-800-TESTING"}}};
144*993b0882SAndroid Build Coastguard Worker   ActionsSuggestionsResponse response;
145*993b0882SAndroid Build Coastguard Worker   {
146*993b0882SAndroid Build Coastguard Worker     ActionSuggestionAnnotation annotation;
147*993b0882SAndroid Build Coastguard Worker     annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
148*993b0882SAndroid Build Coastguard Worker                        /*text=*/"742 Evergreen Terrace"};
149*993b0882SAndroid Build Coastguard Worker     annotation.entity = ClassificationResult("address", 0.5);
150*993b0882SAndroid Build Coastguard Worker     response.actions.push_back({/*response_text=*/"",
151*993b0882SAndroid Build Coastguard Worker                                 /*type=*/"view_map",
152*993b0882SAndroid Build Coastguard Worker                                 /*score=*/0.6,
153*993b0882SAndroid Build Coastguard Worker                                 /*priority_score=*/2.0,
154*993b0882SAndroid Build Coastguard Worker                                 /*annotations=*/{annotation}});
155*993b0882SAndroid Build Coastguard Worker   }
156*993b0882SAndroid Build Coastguard Worker   {
157*993b0882SAndroid Build Coastguard Worker     ActionSuggestionAnnotation annotation;
158*993b0882SAndroid Build Coastguard Worker     annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
159*993b0882SAndroid Build Coastguard Worker                        /*text=*/"742 Evergreen Terrace"};
160*993b0882SAndroid Build Coastguard Worker     annotation.entity = ClassificationResult("address", 1.0);
161*993b0882SAndroid Build Coastguard Worker     response.actions.push_back({/*response_text=*/"",
162*993b0882SAndroid Build Coastguard Worker                                 /*type=*/"view_map",
163*993b0882SAndroid Build Coastguard Worker                                 /*score=*/1.0,
164*993b0882SAndroid Build Coastguard Worker                                 /*priority_score=*/1.0,
165*993b0882SAndroid Build Coastguard Worker                                 /*annotations=*/{annotation}});
166*993b0882SAndroid Build Coastguard Worker   }
167*993b0882SAndroid Build Coastguard Worker   {
168*993b0882SAndroid Build Coastguard Worker     ActionSuggestionAnnotation annotation;
169*993b0882SAndroid Build Coastguard Worker     annotation.span = {/*message_index=*/0, /*span=*/{37, 50},
170*993b0882SAndroid Build Coastguard Worker                        /*text=*/"1-800-TESTING"};
171*993b0882SAndroid Build Coastguard Worker     annotation.entity = ClassificationResult("phone", 0.5);
172*993b0882SAndroid Build Coastguard Worker     response.actions.push_back({/*response_text=*/"",
173*993b0882SAndroid Build Coastguard Worker                                 /*type=*/"call_phone",
174*993b0882SAndroid Build Coastguard Worker                                 /*score=*/0.5,
175*993b0882SAndroid Build Coastguard Worker                                 /*priority_score=*/1.0,
176*993b0882SAndroid Build Coastguard Worker                                 /*annotations=*/{annotation}});
177*993b0882SAndroid Build Coastguard Worker   }
178*993b0882SAndroid Build Coastguard Worker 
179*993b0882SAndroid Build Coastguard Worker   RankingOptionsT options;
180*993b0882SAndroid Build Coastguard Worker   options.deduplicate_suggestions = true;
181*993b0882SAndroid Build Coastguard Worker   flatbuffers::FlatBufferBuilder builder;
182*993b0882SAndroid Build Coastguard Worker   builder.Finish(RankingOptions::Pack(builder, &options));
183*993b0882SAndroid Build Coastguard Worker   auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
184*993b0882SAndroid Build Coastguard Worker       flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
185*993b0882SAndroid Build Coastguard Worker       /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
186*993b0882SAndroid Build Coastguard Worker 
187*993b0882SAndroid Build Coastguard Worker   ranker->RankActions(conversation, &response);
188*993b0882SAndroid Build Coastguard Worker   EXPECT_THAT(
189*993b0882SAndroid Build Coastguard Worker       response.actions,
190*993b0882SAndroid Build Coastguard Worker       testing::ElementsAreArray(
191*993b0882SAndroid Build Coastguard Worker           {IsAction("view_map", "",
192*993b0882SAndroid Build Coastguard Worker                     0.6),  // lower score wins, as priority score is higher
193*993b0882SAndroid Build Coastguard Worker            IsAction("call_phone", "", 0.5)}));
194*993b0882SAndroid Build Coastguard Worker }
195*993b0882SAndroid Build Coastguard Worker 
TEST(RankingTest,DeduplicatesConflictingActions)196*993b0882SAndroid Build Coastguard Worker TEST(RankingTest, DeduplicatesConflictingActions) {
197*993b0882SAndroid Build Coastguard Worker   const Conversation conversation = {{{/*user_id=*/1, "code A-911"}}};
198*993b0882SAndroid Build Coastguard Worker   ActionsSuggestionsResponse response;
199*993b0882SAndroid Build Coastguard Worker   {
200*993b0882SAndroid Build Coastguard Worker     ActionSuggestionAnnotation annotation;
201*993b0882SAndroid Build Coastguard Worker     annotation.span = {/*message_index=*/0, /*span=*/{7, 10},
202*993b0882SAndroid Build Coastguard Worker                        /*text=*/"911"};
203*993b0882SAndroid Build Coastguard Worker     annotation.entity = ClassificationResult("phone", 1.0);
204*993b0882SAndroid Build Coastguard Worker     response.actions.push_back({/*response_text=*/"",
205*993b0882SAndroid Build Coastguard Worker                                 /*type=*/"call_phone",
206*993b0882SAndroid Build Coastguard Worker                                 /*score=*/1.0,
207*993b0882SAndroid Build Coastguard Worker                                 /*priority_score=*/1.0,
208*993b0882SAndroid Build Coastguard Worker                                 /*annotations=*/{annotation}});
209*993b0882SAndroid Build Coastguard Worker   }
210*993b0882SAndroid Build Coastguard Worker   {
211*993b0882SAndroid Build Coastguard Worker     ActionSuggestionAnnotation annotation;
212*993b0882SAndroid Build Coastguard Worker     annotation.span = {/*message_index=*/0, /*span=*/{5, 10},
213*993b0882SAndroid Build Coastguard Worker                        /*text=*/"A-911"};
214*993b0882SAndroid Build Coastguard Worker     annotation.entity = ClassificationResult("code", 1.0);
215*993b0882SAndroid Build Coastguard Worker     response.actions.push_back({/*response_text=*/"",
216*993b0882SAndroid Build Coastguard Worker                                 /*type=*/"copy_code",
217*993b0882SAndroid Build Coastguard Worker                                 /*score=*/1.0,
218*993b0882SAndroid Build Coastguard Worker                                 /*priority_score=*/2.0,
219*993b0882SAndroid Build Coastguard Worker                                 /*annotations=*/{annotation}});
220*993b0882SAndroid Build Coastguard Worker   }
221*993b0882SAndroid Build Coastguard Worker   RankingOptionsT options;
222*993b0882SAndroid Build Coastguard Worker   options.deduplicate_suggestions = true;
223*993b0882SAndroid Build Coastguard Worker   flatbuffers::FlatBufferBuilder builder;
224*993b0882SAndroid Build Coastguard Worker   builder.Finish(RankingOptions::Pack(builder, &options));
225*993b0882SAndroid Build Coastguard Worker   auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
226*993b0882SAndroid Build Coastguard Worker       flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
227*993b0882SAndroid Build Coastguard Worker       /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
228*993b0882SAndroid Build Coastguard Worker 
229*993b0882SAndroid Build Coastguard Worker   ranker->RankActions(conversation, &response);
230*993b0882SAndroid Build Coastguard Worker   EXPECT_THAT(response.actions,
231*993b0882SAndroid Build Coastguard Worker               testing::ElementsAreArray({IsAction("copy_code", "", 1.0)}));
232*993b0882SAndroid Build Coastguard Worker }
233*993b0882SAndroid Build Coastguard Worker 
// Supplies a zlib-compressed Lua ranking script that returns the ids of all
// actions whose type is not "text_reply", and expects the ranked response to
// contain exactly the two non-reply actions.
TEST(RankingTest, HandlesCompressedLuaScript) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};

  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply", /*score=*/1.0},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};

  const std::string lua_script = R"(
    local result = {}
    for id, action in pairs(actions) do
      if action.type ~= "text_reply" then
        table.insert(result, id)
      end
    end
    return result
  )";

  RankingOptionsT ranking_options;
  ranking_options.deduplicate_suggestions = true;
  ranking_options.compressed_lua_ranking_script.reset(new CompressedBufferT);

  // Compress the script into the options before they are serialized.
  std::unique_ptr<ZlibCompressor> compressor = ZlibCompressor::Instance();
  ASSERT_TRUE(compressor != nullptr);
  compressor->Compress(lua_script,
                       ranking_options.compressed_lua_ranking_script.get());

  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(RankingOptions::Pack(fbb, &ranking_options));

  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
  ASSERT_TRUE(decompressor != nullptr);

  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(fbb.GetBufferPointer()),
      decompressor.get(), /*smart_reply_action_type=*/"text_reply");
  // Stop here with a clean failure instead of dereferencing a null ranker.
  // NOTE(review): assumes the factory yields a null pointer on failure —
  // confirm against ranker.h.
  ASSERT_TRUE(ranker != nullptr);

  // A ranking failure must fail the test rather than be silently ignored.
  ASSERT_TRUE(ranker->RankActions(conversation, &response));
  EXPECT_THAT(response.actions,
              testing::ElementsAre(IsActionType("share_location"),
                                   IsActionType("add_to_collection")));
}
270*993b0882SAndroid Build Coastguard Worker 
// One annotated `call_phone` action and one smart reply are ranked with
// `suppress_smart_replies_with_actions` enabled; the test expects only the
// `call_phone` action to remain.
TEST(RankingTest, SuppressSmartRepliesWithAction) {
  const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
  ActionsSuggestionsResponse response;

  ActionSuggestionAnnotation phone;
  // "911" occupies codepoints [14, 17) of the message; the span is kept
  // consistent with the annotated text.
  phone.span = {/*message_index=*/0, /*span=*/{14, 17},
                /*text=*/"911"};
  phone.entity = ClassificationResult("phone", 1.0);
  response.actions.push_back({/*response_text=*/"",
                              /*type=*/"call_phone",
                              /*score=*/1.0,
                              /*priority_score=*/1.0,
                              /*annotations=*/{phone}});

  response.actions.push_back({/*response_text=*/"How are you?",
                              /*type=*/"text_reply"});

  // Serialize the options so the ranker can read them as a flatbuffer root.
  RankingOptionsT ranking_options;
  ranking_options.suppress_smart_replies_with_actions = true;
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(RankingOptions::Pack(fbb, &ranking_options));

  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(fbb.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Stop here with a clean failure instead of dereferencing a null ranker.
  // NOTE(review): assumes the factory yields a null pointer on failure —
  // confirm against ranker.h.
  ASSERT_TRUE(ranker != nullptr);

  // A ranking failure must fail the test rather than be silently ignored.
  ASSERT_TRUE(ranker->RankActions(conversation, &response));

  EXPECT_THAT(response.actions,
              testing::ElementsAre(IsAction("call_phone", "", 1.0)));
}
300*993b0882SAndroid Build Coastguard Worker 
TEST(RankingTest,GroupsActionsByAnnotations)301*993b0882SAndroid Build Coastguard Worker TEST(RankingTest, GroupsActionsByAnnotations) {
302*993b0882SAndroid Build Coastguard Worker   const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
303*993b0882SAndroid Build Coastguard Worker   ActionsSuggestionsResponse response;
304*993b0882SAndroid Build Coastguard Worker   {
305*993b0882SAndroid Build Coastguard Worker     ActionSuggestionAnnotation annotation;
306*993b0882SAndroid Build Coastguard Worker     annotation.span = {/*message_index=*/0, /*span=*/{5, 8},
307*993b0882SAndroid Build Coastguard Worker                        /*text=*/"911"};
308*993b0882SAndroid Build Coastguard Worker     annotation.entity = ClassificationResult("phone", 1.0);
309*993b0882SAndroid Build Coastguard Worker     response.actions.push_back({/*response_text=*/"",
310*993b0882SAndroid Build Coastguard Worker                                 /*type=*/"call_phone",
311*993b0882SAndroid Build Coastguard Worker                                 /*score=*/1.0,
312*993b0882SAndroid Build Coastguard Worker                                 /*priority_score=*/0.0,
313*993b0882SAndroid Build Coastguard Worker                                 /*annotations=*/{annotation}});
314*993b0882SAndroid Build Coastguard Worker     response.actions.push_back({/*response_text=*/"",
315*993b0882SAndroid Build Coastguard Worker                                 /*type=*/"add_contact",
316*993b0882SAndroid Build Coastguard Worker                                 /*score=*/0.0,
317*993b0882SAndroid Build Coastguard Worker                                 /*priority_score=*/1.0,
318*993b0882SAndroid Build Coastguard Worker                                 /*annotations=*/{annotation}});
319*993b0882SAndroid Build Coastguard Worker   }
320*993b0882SAndroid Build Coastguard Worker   response.actions.push_back({/*response_text=*/"How are you?",
321*993b0882SAndroid Build Coastguard Worker                               /*type=*/"text_reply",
322*993b0882SAndroid Build Coastguard Worker                               /*score=*/0.5});
323*993b0882SAndroid Build Coastguard Worker   RankingOptionsT options;
324*993b0882SAndroid Build Coastguard Worker   options.group_by_annotations = true;
325*993b0882SAndroid Build Coastguard Worker   flatbuffers::FlatBufferBuilder builder;
326*993b0882SAndroid Build Coastguard Worker   builder.Finish(RankingOptions::Pack(builder, &options));
327*993b0882SAndroid Build Coastguard Worker   auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
328*993b0882SAndroid Build Coastguard Worker       flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
329*993b0882SAndroid Build Coastguard Worker       /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
330*993b0882SAndroid Build Coastguard Worker 
331*993b0882SAndroid Build Coastguard Worker   ranker->RankActions(conversation, &response);
332*993b0882SAndroid Build Coastguard Worker 
333*993b0882SAndroid Build Coastguard Worker   // The text reply should be last, even though it has a higher score than the
334*993b0882SAndroid Build Coastguard Worker   // `add_contact` action.
335*993b0882SAndroid Build Coastguard Worker   EXPECT_THAT(
336*993b0882SAndroid Build Coastguard Worker       response.actions,
337*993b0882SAndroid Build Coastguard Worker       testing::ElementsAreArray({IsAction("call_phone", "", 1.0),
338*993b0882SAndroid Build Coastguard Worker                                  IsAction("add_contact", "", 0.0),
339*993b0882SAndroid Build Coastguard Worker                                  IsAction("text_reply", "How are you?", 0.5)}));
340*993b0882SAndroid Build Coastguard Worker }
341*993b0882SAndroid Build Coastguard Worker 
TEST(RankingTest,GroupsByAnnotationsSortedByPriority)342*993b0882SAndroid Build Coastguard Worker TEST(RankingTest, GroupsByAnnotationsSortedByPriority) {
343*993b0882SAndroid Build Coastguard Worker   const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
344*993b0882SAndroid Build Coastguard Worker   ActionsSuggestionsResponse response;
345*993b0882SAndroid Build Coastguard Worker   response.actions.push_back({/*response_text=*/"How are you?",
346*993b0882SAndroid Build Coastguard Worker                               /*type=*/"text_reply",
347*993b0882SAndroid Build Coastguard Worker                               /*score=*/2.0,
348*993b0882SAndroid Build Coastguard Worker                               /*priority_score=*/0.0});
349*993b0882SAndroid Build Coastguard Worker   {
350*993b0882SAndroid Build Coastguard Worker     ActionSuggestionAnnotation annotation;
351*993b0882SAndroid Build Coastguard Worker     annotation.span = {/*message_index=*/0, /*span=*/{5, 8},
352*993b0882SAndroid Build Coastguard Worker                        /*text=*/"911"};
353*993b0882SAndroid Build Coastguard Worker     annotation.entity = ClassificationResult("phone", 1.0);
354*993b0882SAndroid Build Coastguard Worker     response.actions.push_back({/*response_text=*/"",
355*993b0882SAndroid Build Coastguard Worker                                 /*type=*/"add_contact",
356*993b0882SAndroid Build Coastguard Worker                                 /*score=*/0.0,
357*993b0882SAndroid Build Coastguard Worker                                 /*priority_score=*/1.0,
358*993b0882SAndroid Build Coastguard Worker                                 /*annotations=*/{annotation}});
359*993b0882SAndroid Build Coastguard Worker     response.actions.push_back({/*response_text=*/"",
360*993b0882SAndroid Build Coastguard Worker                                 /*type=*/"call_phone",
361*993b0882SAndroid Build Coastguard Worker                                 /*score=*/1.0,
362*993b0882SAndroid Build Coastguard Worker                                 /*priority_score=*/0.0,
363*993b0882SAndroid Build Coastguard Worker                                 /*annotations=*/{annotation}});
364*993b0882SAndroid Build Coastguard Worker     response.actions.push_back({/*response_text=*/"",
365*993b0882SAndroid Build Coastguard Worker                                 /*type=*/"add_contact2",
366*993b0882SAndroid Build Coastguard Worker                                 /*score=*/0.5,
367*993b0882SAndroid Build Coastguard Worker                                 /*priority_score=*/1.0,
368*993b0882SAndroid Build Coastguard Worker                                 /*annotations=*/{annotation}});
369*993b0882SAndroid Build Coastguard Worker   }
370*993b0882SAndroid Build Coastguard Worker   RankingOptionsT options;
371*993b0882SAndroid Build Coastguard Worker   options.group_by_annotations = true;
372*993b0882SAndroid Build Coastguard Worker   options.sort_type = RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE;
373*993b0882SAndroid Build Coastguard Worker   flatbuffers::FlatBufferBuilder builder;
374*993b0882SAndroid Build Coastguard Worker   builder.Finish(RankingOptions::Pack(builder, &options));
375*993b0882SAndroid Build Coastguard Worker   auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
376*993b0882SAndroid Build Coastguard Worker       flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
377*993b0882SAndroid Build Coastguard Worker       /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
378*993b0882SAndroid Build Coastguard Worker 
379*993b0882SAndroid Build Coastguard Worker   ranker->RankActions(conversation, &response);
380*993b0882SAndroid Build Coastguard Worker 
381*993b0882SAndroid Build Coastguard Worker   // The text reply should be last, even though it's score is higher than
382*993b0882SAndroid Build Coastguard Worker   // any other scores -- because it's priority_score is lower than the max
383*993b0882SAndroid Build Coastguard Worker   // of those with the 'phone' annotation
384*993b0882SAndroid Build Coastguard Worker   EXPECT_THAT(response.actions,
385*993b0882SAndroid Build Coastguard Worker               testing::ElementsAreArray({
386*993b0882SAndroid Build Coastguard Worker                   // Group 1 (Phone annotation)
387*993b0882SAndroid Build Coastguard Worker                   IsAction("add_contact2", "", 0.5),  // priority_score=1.0
388*993b0882SAndroid Build Coastguard Worker                   IsAction("add_contact", "", 0.0),   // priority_score=1.0
389*993b0882SAndroid Build Coastguard Worker                   IsAction("call_phone", "", 1.0),    // priority_score=0.0
390*993b0882SAndroid Build Coastguard Worker                   IsAction("text_reply", "How are you?", 2.0),  // Group 2
391*993b0882SAndroid Build Coastguard Worker               }));
392*993b0882SAndroid Build Coastguard Worker }
393*993b0882SAndroid Build Coastguard Worker 
TEST(RankingTest,SortsActionsByScore)394*993b0882SAndroid Build Coastguard Worker TEST(RankingTest, SortsActionsByScore) {
395*993b0882SAndroid Build Coastguard Worker   const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
396*993b0882SAndroid Build Coastguard Worker   ActionsSuggestionsResponse response;
397*993b0882SAndroid Build Coastguard Worker   {
398*993b0882SAndroid Build Coastguard Worker     ActionSuggestionAnnotation annotation;
399*993b0882SAndroid Build Coastguard Worker     annotation.span = {/*message_index=*/0, /*span=*/{5, 8},
400*993b0882SAndroid Build Coastguard Worker                        /*text=*/"911"};
401*993b0882SAndroid Build Coastguard Worker     annotation.entity = ClassificationResult("phone", 1.0);
402*993b0882SAndroid Build Coastguard Worker     response.actions.push_back({/*response_text=*/"",
403*993b0882SAndroid Build Coastguard Worker                                 /*type=*/"call_phone",
404*993b0882SAndroid Build Coastguard Worker                                 /*score=*/1.0,
405*993b0882SAndroid Build Coastguard Worker                                 /*priority_score=*/0.0,
406*993b0882SAndroid Build Coastguard Worker                                 /*annotations=*/{annotation}});
407*993b0882SAndroid Build Coastguard Worker     response.actions.push_back({/*response_text=*/"",
408*993b0882SAndroid Build Coastguard Worker                                 /*type=*/"add_contact",
409*993b0882SAndroid Build Coastguard Worker                                 /*score=*/0.0,
410*993b0882SAndroid Build Coastguard Worker                                 /*priority_score=*/1.0,
411*993b0882SAndroid Build Coastguard Worker                                 /*annotations=*/{annotation}});
412*993b0882SAndroid Build Coastguard Worker   }
413*993b0882SAndroid Build Coastguard Worker   response.actions.push_back({/*response_text=*/"How are you?",
414*993b0882SAndroid Build Coastguard Worker                               /*type=*/"text_reply",
415*993b0882SAndroid Build Coastguard Worker                               /*score=*/0.5});
416*993b0882SAndroid Build Coastguard Worker   RankingOptionsT options;
417*993b0882SAndroid Build Coastguard Worker   // Don't group by annotation.
418*993b0882SAndroid Build Coastguard Worker   options.group_by_annotations = false;
419*993b0882SAndroid Build Coastguard Worker   flatbuffers::FlatBufferBuilder builder;
420*993b0882SAndroid Build Coastguard Worker   builder.Finish(RankingOptions::Pack(builder, &options));
421*993b0882SAndroid Build Coastguard Worker   auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
422*993b0882SAndroid Build Coastguard Worker       flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
423*993b0882SAndroid Build Coastguard Worker       /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
424*993b0882SAndroid Build Coastguard Worker 
425*993b0882SAndroid Build Coastguard Worker   ranker->RankActions(conversation, &response);
426*993b0882SAndroid Build Coastguard Worker 
427*993b0882SAndroid Build Coastguard Worker   EXPECT_THAT(
428*993b0882SAndroid Build Coastguard Worker       response.actions,
429*993b0882SAndroid Build Coastguard Worker       testing::ElementsAreArray({IsAction("call_phone", "", 1.0),
430*993b0882SAndroid Build Coastguard Worker                                  IsAction("text_reply", "How are you?", 0.5),
431*993b0882SAndroid Build Coastguard Worker                                  IsAction("add_contact", "", 0.0)}));
432*993b0882SAndroid Build Coastguard Worker }
433*993b0882SAndroid Build Coastguard Worker 
// Ranks three text replies with grouping by annotation disabled and the
// priority-score sort type selected. The two emoji replies carry the higher
// priority_score and are expected first (in descending score order); the
// plain "Yes" reply is expected last despite having the highest score.
TEST(RankingTest, SortsActionsByPriority) {
  const Conversation conversation = {{{/*user_id=*/1, "hello?"}}};
  ActionsSuggestionsResponse response;
  // Emoji replies are given the higher priority_score. Two distinct, valid
  // UTF-8 emoji are used so that each reply is identified unambiguously by
  // its text in the expectation below.
  // NOTE(review): identical reply texts might also be collapsed if the
  // ranker deduplicates suggestions -- confirm against the ranker defaults.
  response.actions.push_back({/*response_text=*/"😊",
                              /*type=*/"text_reply",
                              /*score=*/0.5,
                              /*priority_score=*/1.0});
  response.actions.push_back({/*response_text=*/"👍",
                              /*type=*/"text_reply",
                              /*score=*/0.4,
                              /*priority_score=*/1.0});
  response.actions.push_back({/*response_text=*/"Yes",
                              /*type=*/"text_reply",
                              /*score=*/1.0,
                              /*priority_score=*/0.0});
  RankingOptionsT options;
  // Don't group by annotation.
  options.group_by_annotations = false;
  options.sort_type = RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");

  ranker->RankActions(conversation, &response);

  EXPECT_THAT(response.actions, testing::ElementsAreArray(
                                    {IsAction("text_reply", "😊", 0.5),
                                     IsAction("text_reply", "👍", 0.4),
                                     // Ranked last because of priority score
                                     IsAction("text_reply", "Yes", 1.0)}));
}
468*993b0882SAndroid Build Coastguard Worker 
469*993b0882SAndroid Build Coastguard Worker }  // namespace
470*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
471