1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker *
4*993b0882SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker *
8*993b0882SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker *
10*993b0882SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker */
16*993b0882SAndroid Build Coastguard Worker
17*993b0882SAndroid Build Coastguard Worker #include "actions/ranker.h"
18*993b0882SAndroid Build Coastguard Worker
19*993b0882SAndroid Build Coastguard Worker #include <string>
20*993b0882SAndroid Build Coastguard Worker
21*993b0882SAndroid Build Coastguard Worker #include "actions/actions_model_generated.h"
22*993b0882SAndroid Build Coastguard Worker #include "actions/types.h"
23*993b0882SAndroid Build Coastguard Worker #include "utils/zlib/zlib.h"
24*993b0882SAndroid Build Coastguard Worker #include "gmock/gmock.h"
25*993b0882SAndroid Build Coastguard Worker #include "gtest/gtest.h"
26*993b0882SAndroid Build Coastguard Worker
27*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
28*993b0882SAndroid Build Coastguard Worker namespace {
29*993b0882SAndroid Build Coastguard Worker
30*993b0882SAndroid Build Coastguard Worker MATCHER_P3(IsAction, type, response_text, score, "") {
31*993b0882SAndroid Build Coastguard Worker return testing::Value(arg.type, type) &&
32*993b0882SAndroid Build Coastguard Worker testing::Value(arg.response_text, response_text) &&
33*993b0882SAndroid Build Coastguard Worker testing::Value(arg.score, score);
34*993b0882SAndroid Build Coastguard Worker }
35*993b0882SAndroid Build Coastguard Worker
36*993b0882SAndroid Build Coastguard Worker MATCHER_P(IsActionType, type, "") { return testing::Value(arg.type, type); }
37*993b0882SAndroid Build Coastguard Worker
// Two identical "hello there" smart replies are deduplicated; only the
// higher-scoring one (1.0) remains.
TEST(RankingTest, DeduplicationSmartReply) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"hello there", /*type=*/"text_reply", /*score=*/0.5}};

  RankingOptionsT options;
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing the binary on a null ranker.
  ASSERT_NE(ranker, nullptr);

  // Do not inspect the response if ranking itself reported a failure.
  ASSERT_TRUE(ranker->RankActions(conversation, &response));
  EXPECT_THAT(
      response.actions,
      testing::ElementsAreArray({IsAction("text_reply", "hello there", 1.0)}));
}
59*993b0882SAndroid Build Coastguard Worker
// Duplicated smart replies are merged unless they carry different entity
// data: the 0.6 reply survives because its serialized entity data differs.
TEST(RankingTest, DeduplicationExtraData) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0, /*priority_score=*/0.0},
      {/*response_text=*/"hello there", /*type=*/"text_reply", /*score=*/0.5,
       /*priority_score=*/0.0},
      {/*response_text=*/"hello there", /*type=*/"text_reply", /*score=*/0.6,
       /*priority_score=*/0.0,
       /*annotations=*/{}, /*serialized_entity_data=*/"test"},
  };

  RankingOptionsT options;
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing the binary on a null ranker.
  ASSERT_NE(ranker, nullptr);

  // Do not inspect the response if ranking itself reported a failure.
  ASSERT_TRUE(ranker->RankActions(conversation, &response));
  EXPECT_THAT(
      response.actions,
      testing::ElementsAreArray({IsAction("text_reply", "hello there", 1.0),
                                 // Kept because it has different entity data.
                                 IsAction("text_reply", "hello there", 0.6)}));
}
88*993b0882SAndroid Build Coastguard Worker
// Two view_map actions annotate the same address span; only the one with the
// higher priority score (2.0, score 1.0) is kept. The call_phone action on a
// disjoint span is unaffected.
TEST(RankingTest, DeduplicationAnnotations) {
  const Conversation conversation = {
      {{/*user_id=*/1, "742 Evergreen Terrace, the number is 1-800-TESTING"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
                       /*text=*/"742 Evergreen Terrace"};
    annotation.entity = ClassificationResult("address", 0.5);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"view_map",
                                /*score=*/0.5,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
                       /*text=*/"742 Evergreen Terrace"};
    annotation.entity = ClassificationResult("address", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"view_map",
                                /*score=*/1.0,
                                /*priority_score=*/2.0,
                                /*annotations=*/{annotation}});
  }
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{37, 50},
                       /*text=*/"1-800-TESTING"};
    annotation.entity = ClassificationResult("phone", 0.5);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/0.5,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }

  RankingOptionsT options;
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing the binary on a null ranker.
  ASSERT_NE(ranker, nullptr);

  // Do not inspect the response if ranking itself reported a failure.
  ASSERT_TRUE(ranker->RankActions(conversation, &response));
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsAction("view_map", "", 1.0),
                                         IsAction("call_phone", "", 0.5)}));
}
140*993b0882SAndroid Build Coastguard Worker
// Same setup as DeduplicationAnnotations, but here the lower-scoring view_map
// action has the higher priority score, so it is the one that survives
// deduplication.
TEST(RankingTest, DeduplicationAnnotationsByPriorityScore) {
  const Conversation conversation = {
      {{/*user_id=*/1, "742 Evergreen Terrace, the number is 1-800-TESTING"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
                       /*text=*/"742 Evergreen Terrace"};
    annotation.entity = ClassificationResult("address", 0.5);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"view_map",
                                /*score=*/0.6,
                                /*priority_score=*/2.0,
                                /*annotations=*/{annotation}});
  }
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
                       /*text=*/"742 Evergreen Terrace"};
    annotation.entity = ClassificationResult("address", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"view_map",
                                /*score=*/1.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{37, 50},
                       /*text=*/"1-800-TESTING"};
    annotation.entity = ClassificationResult("phone", 0.5);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/0.5,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }

  RankingOptionsT options;
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing the binary on a null ranker.
  ASSERT_NE(ranker, nullptr);

  // Do not inspect the response if ranking itself reported a failure.
  ASSERT_TRUE(ranker->RankActions(conversation, &response));
  // The lower score (0.6) wins because its priority score (2.0) is higher.
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsAction("view_map", "", 0.6),
                                         IsAction("call_phone", "", 0.5)}));
}
195*993b0882SAndroid Build Coastguard Worker
// "911" ([7, 10)) overlaps "A-911" ([5, 10)); the copy_code action has the
// higher priority score (2.0 vs 1.0), so the call_phone action is dropped.
TEST(RankingTest, DeduplicatesConflictingActions) {
  const Conversation conversation = {{{/*user_id=*/1, "code A-911"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{7, 10},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{5, 10},
                       /*text=*/"A-911"};
    annotation.entity = ClassificationResult("code", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"copy_code",
                                /*score=*/1.0,
                                /*priority_score=*/2.0,
                                /*annotations=*/{annotation}});
  }
  RankingOptionsT options;
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing the binary on a null ranker.
  ASSERT_NE(ranker, nullptr);

  // Do not inspect the response if ranking itself reported a failure.
  ASSERT_TRUE(ranker->RankActions(conversation, &response));
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsAction("copy_code", "", 1.0)}));
}
233*993b0882SAndroid Build Coastguard Worker
// A zlib-compressed Lua ranking script is decompressed and applied: the
// script keeps only the actions whose type is not "text_reply".
TEST(RankingTest, HandlesCompressedLuaScript) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  const std::string test_snippet = R"(
    local result = {}
    for id, action in pairs(actions) do
      if action.type ~= "text_reply" then
        table.insert(result, id)
      end
    end
    return result
  )";
  RankingOptionsT options;
  options.compressed_lua_ranking_script.reset(new CompressedBufferT);
  std::unique_ptr<ZlibCompressor> compressor = ZlibCompressor::Instance();
  // Compression must be available, otherwise the script below is never set.
  ASSERT_NE(compressor, nullptr);
  compressor->Compress(test_snippet,
                       options.compressed_lua_ranking_script.get());
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));

  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
  ASSERT_NE(decompressor, nullptr);
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      decompressor.get(), /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing the binary on a null ranker.
  ASSERT_NE(ranker, nullptr);

  // A Lua script error would otherwise go unnoticed; require success.
  ASSERT_TRUE(ranker->RankActions(conversation, &response));
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsActionType("share_location"),
                                         IsActionType("add_to_collection")}));
}
270*993b0882SAndroid Build Coastguard Worker
// With suppress_smart_replies_with_actions set, the text_reply suggestion is
// removed because a non-smart-reply action (call_phone) is present.
TEST(RankingTest, SuppressSmartRepliesWithAction) {
  const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    // "911" occupies [14, 17) in "should i call 911".
    annotation.span = {/*message_index=*/0, /*span=*/{14, 17},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  response.actions.push_back({/*response_text=*/"How are you?",
                              /*type=*/"text_reply"});
  RankingOptionsT options;
  options.suppress_smart_replies_with_actions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing the binary on a null ranker.
  ASSERT_NE(ranker, nullptr);

  // Do not inspect the response if ranking itself reported a failure.
  ASSERT_TRUE(ranker->RankActions(conversation, &response));

  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsAction("call_phone", "", 1.0)}));
}
300*993b0882SAndroid Build Coastguard Worker
// With group_by_annotations set, actions sharing the "911" annotation are
// emitted together, ahead of the unannotated text reply.
TEST(RankingTest, GroupsActionsByAnnotations) {
  const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    // "911" occupies [14, 17) in "should i call 911".
    annotation.span = {/*message_index=*/0, /*span=*/{14, 17},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/0.0,
                                /*annotations=*/{annotation}});
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"add_contact",
                                /*score=*/0.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  response.actions.push_back({/*response_text=*/"How are you?",
                              /*type=*/"text_reply",
                              /*score=*/0.5});
  RankingOptionsT options;
  options.group_by_annotations = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing the binary on a null ranker.
  ASSERT_NE(ranker, nullptr);

  // Do not inspect the response if ranking itself reported a failure.
  ASSERT_TRUE(ranker->RankActions(conversation, &response));

  // The text reply should be last, even though it has a higher score than the
  // `add_contact` action.
  EXPECT_THAT(
      response.actions,
      testing::ElementsAreArray({IsAction("call_phone", "", 1.0),
                                 IsAction("add_contact", "", 0.0),
                                 IsAction("text_reply", "How are you?", 0.5)}));
}
341*993b0882SAndroid Build Coastguard Worker
// With group_by_annotations and SORT_TYPE_PRIORITY_SCORE, actions inside the
// "phone" group are ordered by priority score, and the whole group comes
// before the standalone text reply.
TEST(RankingTest, GroupsByAnnotationsSortedByPriority) {
  const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
  ActionsSuggestionsResponse response;
  response.actions.push_back({/*response_text=*/"How are you?",
                              /*type=*/"text_reply",
                              /*score=*/2.0,
                              /*priority_score=*/0.0});
  {
    ActionSuggestionAnnotation annotation;
    // "911" occupies [14, 17) in "should i call 911".
    annotation.span = {/*message_index=*/0, /*span=*/{14, 17},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"add_contact",
                                /*score=*/0.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/0.0,
                                /*annotations=*/{annotation}});
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"add_contact2",
                                /*score=*/0.5,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  RankingOptionsT options;
  options.group_by_annotations = true;
  options.sort_type = RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing the binary on a null ranker.
  ASSERT_NE(ranker, nullptr);

  // Do not inspect the response if ranking itself reported a failure.
  ASSERT_TRUE(ranker->RankActions(conversation, &response));

  // The text reply should be last, even though its score is higher than any
  // other score, because its priority_score is lower than the maximum
  // priority_score of the actions with the 'phone' annotation.
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({
                  // Group 1 (phone annotation).
                  IsAction("add_contact2", "", 0.5),  // priority_score=1.0
                  IsAction("add_contact", "", 0.0),   // priority_score=1.0
                  IsAction("call_phone", "", 1.0),    // priority_score=0.0
                  // Group 2 (no annotation).
                  IsAction("text_reply", "How are you?", 2.0),
              }));
}
393*993b0882SAndroid Build Coastguard Worker
// With grouping by annotation disabled and no explicit sort_type, the actions
// are expected to come out ordered by descending score alone. The
// priority_score values below are deliberately inverted relative to the
// scores, so the expectation also checks that they do not affect the order.
TEST(RankingTest, SortsActionsByScore) {
  const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
  ActionsSuggestionsResponse response;
  {
    // Both phone-related actions refer to this same annotation of "911".
    // In "should i call 911", the token "911" occupies offsets [14, 17).
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{14, 17},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/0.0,
                                /*annotations=*/{annotation}});
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"add_contact",
                                /*score=*/0.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  response.actions.push_back({/*response_text=*/"How are you?",
                              /*type=*/"text_reply",
                              /*score=*/0.5});
  RankingOptionsT options;
  // Don't group by annotation.
  options.group_by_annotations = false;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");

  ranker->RankActions(conversation, &response);

  // Highest score first; the text reply sits between the two phone actions.
  EXPECT_THAT(
      response.actions,
      testing::ElementsAreArray({IsAction("call_phone", "", 1.0),
                                 IsAction("text_reply", "How are you?", 0.5),
                                 IsAction("add_contact", "", 0.0)}));
}
433*993b0882SAndroid Build Coastguard Worker
// With sort_type set to SORT_TYPE_PRIORITY_SCORE (and grouping by annotation
// disabled), actions are expected to be ordered by priority_score first.
// Among actions with equal priority_score, the higher score comes first.
// A high-score action with a low priority_score is therefore ranked last.
TEST(RankingTest, SortsActionsByPriority) {
  const Conversation conversation = {{{/*user_id=*/1, "hello?"}}};
  ActionsSuggestionsResponse response;
  // Emoji replies are given a higher priority_score. Their response texts must
  // differ from each other. Two text_reply actions with identical text would
  // presumably be collapsed by the ranker's suggestion deduplication, leaving
  // only two actions instead of three. The emoji are spelled as UTF-8 byte
  // escapes so the test data survives ASCII-only tooling.
  const std::string grinning_face = "\xF0\x9F\x98\x81";  // U+1F601
  const std::string waving_hand = "\xF0\x9F\x91\x8B";    // U+1F44B
  response.actions.push_back({/*response_text=*/grinning_face,
                              /*type=*/"text_reply",
                              /*score=*/0.5,
                              /*priority_score=*/1.0});
  response.actions.push_back({/*response_text=*/waving_hand,
                              /*type=*/"text_reply",
                              /*score=*/0.4,
                              /*priority_score=*/1.0});
  response.actions.push_back({/*response_text=*/"Yes",
                              /*type=*/"text_reply",
                              /*score=*/1.0,
                              /*priority_score=*/0.0});
  RankingOptionsT options;
  // Don't group by annotation.
  options.group_by_annotations = false;
  options.sort_type = RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");

  ranker->RankActions(conversation, &response);

  EXPECT_THAT(response.actions, testing::ElementsAreArray(
                                    {IsAction("text_reply", grinning_face, 0.5),
                                     IsAction("text_reply", waving_hand, 0.4),
                                     // Ranked last because of priority score
                                     IsAction("text_reply", "Yes", 1.0)}));
}
468*993b0882SAndroid Build Coastguard Worker
469*993b0882SAndroid Build Coastguard Worker } // namespace
470*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3
471