xref: /aosp_15_r20/external/libtextclassifier/native/actions/actions-suggestions_test.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #include "actions/actions-suggestions.h"
18*993b0882SAndroid Build Coastguard Worker 
19*993b0882SAndroid Build Coastguard Worker #include <fstream>
20*993b0882SAndroid Build Coastguard Worker #include <iterator>
21*993b0882SAndroid Build Coastguard Worker #include <memory>
22*993b0882SAndroid Build Coastguard Worker #include <string>
23*993b0882SAndroid Build Coastguard Worker 
24*993b0882SAndroid Build Coastguard Worker #include "actions/actions_model_generated.h"
25*993b0882SAndroid Build Coastguard Worker #include "actions/test-utils.h"
26*993b0882SAndroid Build Coastguard Worker #include "actions/zlib-utils.h"
27*993b0882SAndroid Build Coastguard Worker #include "annotator/collections.h"
28*993b0882SAndroid Build Coastguard Worker #include "annotator/types.h"
29*993b0882SAndroid Build Coastguard Worker #include "utils/flatbuffers/flatbuffers.h"
30*993b0882SAndroid Build Coastguard Worker #include "utils/flatbuffers/flatbuffers_generated.h"
31*993b0882SAndroid Build Coastguard Worker #include "utils/flatbuffers/mutable.h"
32*993b0882SAndroid Build Coastguard Worker #include "utils/grammar/utils/locale-shard-map.h"
33*993b0882SAndroid Build Coastguard Worker #include "utils/grammar/utils/rules.h"
34*993b0882SAndroid Build Coastguard Worker #include "utils/hash/farmhash.h"
35*993b0882SAndroid Build Coastguard Worker #include "utils/jvm-test-utils.h"
36*993b0882SAndroid Build Coastguard Worker #include "utils/test-data-test-utils.h"
37*993b0882SAndroid Build Coastguard Worker #include "gmock/gmock.h"
38*993b0882SAndroid Build Coastguard Worker #include "gtest/gtest.h"
39*993b0882SAndroid Build Coastguard Worker #include "flatbuffers/flatbuffers.h"
40*993b0882SAndroid Build Coastguard Worker #include "flatbuffers/reflection.h"
41*993b0882SAndroid Build Coastguard Worker 
42*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
43*993b0882SAndroid Build Coastguard Worker namespace {
44*993b0882SAndroid Build Coastguard Worker 
45*993b0882SAndroid Build Coastguard Worker using ::testing::ElementsAre;
46*993b0882SAndroid Build Coastguard Worker using ::testing::FloatEq;
47*993b0882SAndroid Build Coastguard Worker using ::testing::IsEmpty;
48*993b0882SAndroid Build Coastguard Worker using ::testing::NotNull;
49*993b0882SAndroid Build Coastguard Worker using ::testing::SizeIs;
50*993b0882SAndroid Build Coastguard Worker 
// File names of the serialized models used by the tests in this file. The
// fixture helpers and tests append a name to the directory returned by
// GetModelPath() to obtain the full path.

// Default actions model; most tests load this one.
constexpr char kModelFileName[] = "actions_suggestions_test.model";
constexpr char kModelGrammarFileName[] =
    "actions_suggestions_grammar_test.model";
constexpr char kMultiTaskTF2TestModelFileName[] =
    "actions_suggestions_test.multi_task_tf2_test.model";
constexpr char kMultiTaskModelFileName[] =
    "actions_suggestions_test.multi_task_9heads.model";
constexpr char kHashGramModelFileName[] =
    "actions_suggestions_test.hashgram.model";
constexpr char kMultiTaskSrP13nModelFileName[] =
    "actions_suggestions_test.multi_task_sr_p13n.model";
constexpr char kMultiTaskSrEmojiModelFileName[] =
    "actions_suggestions_test.multi_task_sr_emoji.model";
constexpr char kMultiTaskSrEmojiConceptModelFileName[] =
    "actions_suggestions_test.multi_task_sr_emoji_concept.model";
constexpr char kSensitiveTFliteModelFileName[] =
    "actions_suggestions_test.sensitive_tflite.model";
constexpr char kLiveRelayTFLiteModelFileName[] =
    "actions_suggestions_test.live_relay.model";
70*993b0882SAndroid Build Coastguard Worker 
// Reads the entire contents of `file_name` and returns them as a string.
//
// The stream is opened in binary mode so that the bytes are returned exactly
// as stored on disk; the callers in this file feed the result to the
// flatbuffer unpacker, which must not see any newline translation.
//
// Returns an empty string if the file cannot be opened.
std::string ReadFile(const std::string& file_name) {
  std::ifstream file_stream(file_name, std::ios::in | std::ios::binary);
  return std::string(std::istreambuf_iterator<char>(file_stream),
                     std::istreambuf_iterator<char>());
}
75*993b0882SAndroid Build Coastguard Worker 
GetModelPath()76*993b0882SAndroid Build Coastguard Worker std::string GetModelPath() { return GetTestDataPath("actions/test_data/"); }
77*993b0882SAndroid Build Coastguard Worker 
78*993b0882SAndroid Build Coastguard Worker class ActionsSuggestionsTest : public testing::Test {
79*993b0882SAndroid Build Coastguard Worker  protected:
ActionsSuggestionsTest()80*993b0882SAndroid Build Coastguard Worker   explicit ActionsSuggestionsTest() : unilib_(CreateUniLibForTesting()) {}
LoadTestModel(const std::string model_file_name)81*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsSuggestions> LoadTestModel(
82*993b0882SAndroid Build Coastguard Worker       const std::string model_file_name) {
83*993b0882SAndroid Build Coastguard Worker     return ActionsSuggestions::FromPath(GetModelPath() + model_file_name,
84*993b0882SAndroid Build Coastguard Worker                                         unilib_.get());
85*993b0882SAndroid Build Coastguard Worker   }
LoadHashGramTestModel()86*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsSuggestions> LoadHashGramTestModel() {
87*993b0882SAndroid Build Coastguard Worker     return ActionsSuggestions::FromPath(GetModelPath() + kHashGramModelFileName,
88*993b0882SAndroid Build Coastguard Worker                                         unilib_.get());
89*993b0882SAndroid Build Coastguard Worker   }
LoadMultiTaskTestModel()90*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsSuggestions> LoadMultiTaskTestModel() {
91*993b0882SAndroid Build Coastguard Worker     return ActionsSuggestions::FromPath(
92*993b0882SAndroid Build Coastguard Worker         GetModelPath() + kMultiTaskModelFileName, unilib_.get());
93*993b0882SAndroid Build Coastguard Worker   }
94*993b0882SAndroid Build Coastguard Worker 
LoadMultiTaskSrP13nTestModel()95*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsSuggestions> LoadMultiTaskSrP13nTestModel() {
96*993b0882SAndroid Build Coastguard Worker     return ActionsSuggestions::FromPath(
97*993b0882SAndroid Build Coastguard Worker         GetModelPath() + kMultiTaskSrP13nModelFileName, unilib_.get());
98*993b0882SAndroid Build Coastguard Worker   }
99*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<UniLib> unilib_;
100*993b0882SAndroid Build Coastguard Worker };
101*993b0882SAndroid Build Coastguard Worker 
// Checks that the default test model can be loaded at all.
TEST_F(ActionsSuggestionsTest, InstantiateActionSuggestions) {
  EXPECT_THAT(LoadTestModel(kModelFileName), NotNull());
}
105*993b0882SAndroid Build Coastguard Worker 
// Checks that no actions are suggested for a message whose text ends in a
// truncated UTF-8 sequence ("\xf0\x9f" is the start of a 4-byte sequence with
// the last two continuation bytes missing).
TEST_F(ActionsSuggestionsTest, ProducesEmptyResponseOnInvalidInput) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of dereferencing a null model below.
  ASSERT_THAT(actions_suggestions, NotNull());
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?\xf0\x9f",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, IsEmpty());
}
117*993b0882SAndroid Build Coastguard Worker 
// Checks that no actions are suggested for a message containing encoded
// surrogate code points ("\xed\xa0\x80" encodes U+D800, which is not valid
// UTF-8), even though the text also contains something that looks like a phone
// number.
TEST_F(ActionsSuggestionsTest, ProducesEmptyResponseOnInvalidUtf8) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of dereferencing a null model below.
  ASSERT_THAT(actions_suggestions, NotNull());

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "(857) 225-3556 \xed\xa0\x80\xed\xa0\x80\xed\xa0\x80\xed\xa0\x80",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, IsEmpty());
}
131*993b0882SAndroid Build Coastguard Worker 
// Checks the number of actions suggested for a plain English question.
TEST_F(ActionsSuggestionsTest, SuggestsActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of dereferencing a null model below.
  ASSERT_THAT(actions_suggestions, NotNull());
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}});
  // Unsigned literal keeps the comparison with size() sign-consistent.
  EXPECT_EQ(response.actions.size(), 3u /* share_location + 2 smart replies*/);
}
142*993b0882SAndroid Build Coastguard Worker 
// Checks that the same question yields no actions when the message locale is
// "zz" instead of "en".
TEST_F(ActionsSuggestionsTest, SuggestsNoActionsForUnknownLocale) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of dereferencing a null model below.
  ASSERT_THAT(actions_suggestions, NotNull());
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"zz"}}});
  EXPECT_THAT(response.actions, IsEmpty());
}
153*993b0882SAndroid Build Coastguard Worker 
// Checks that an "address" annotation attached to the message produces a
// "view_map" action as the first suggestion, carrying the annotation's score.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromAnnotations) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of dereferencing a null model below.
  ASSERT_THAT(actions_suggestions, NotNull());

  // The span [11, 15) covers "home" in the message text.
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("address", 1.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1u);
  EXPECT_EQ(response.actions.front().type, "view_map");
  EXPECT_EQ(response.actions.front().score, 1.0);
}
171*993b0882SAndroid Build Coastguard Worker 
TEST_F(ActionsSuggestionsTest,SuggestsActionsFromAnnotationsWithEntityData)172*993b0882SAndroid Build Coastguard Worker TEST_F(ActionsSuggestionsTest, SuggestsActionsFromAnnotationsWithEntityData) {
173*993b0882SAndroid Build Coastguard Worker   const std::string actions_model_string =
174*993b0882SAndroid Build Coastguard Worker       ReadFile(GetModelPath() + kModelFileName);
175*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsModelT> actions_model =
176*993b0882SAndroid Build Coastguard Worker       UnPackActionsModel(actions_model_string.c_str());
177*993b0882SAndroid Build Coastguard Worker   SetTestEntityDataSchema(actions_model.get());
178*993b0882SAndroid Build Coastguard Worker 
179*993b0882SAndroid Build Coastguard Worker   // Set custom actions from annotations config.
180*993b0882SAndroid Build Coastguard Worker   actions_model->annotation_actions_spec->annotation_mapping.clear();
181*993b0882SAndroid Build Coastguard Worker   actions_model->annotation_actions_spec->annotation_mapping.emplace_back(
182*993b0882SAndroid Build Coastguard Worker       new AnnotationActionsSpec_::AnnotationMappingT);
183*993b0882SAndroid Build Coastguard Worker   AnnotationActionsSpec_::AnnotationMappingT* mapping =
184*993b0882SAndroid Build Coastguard Worker       actions_model->annotation_actions_spec->annotation_mapping.back().get();
185*993b0882SAndroid Build Coastguard Worker   mapping->annotation_collection = "address";
186*993b0882SAndroid Build Coastguard Worker   mapping->action.reset(new ActionSuggestionSpecT);
187*993b0882SAndroid Build Coastguard Worker   mapping->action->type = "save_location";
188*993b0882SAndroid Build Coastguard Worker   mapping->action->score = 1.0;
189*993b0882SAndroid Build Coastguard Worker   mapping->action->priority_score = 2.0;
190*993b0882SAndroid Build Coastguard Worker   mapping->entity_field.reset(new FlatbufferFieldPathT);
191*993b0882SAndroid Build Coastguard Worker   mapping->entity_field->field.emplace_back(new FlatbufferFieldT);
192*993b0882SAndroid Build Coastguard Worker   mapping->entity_field->field.back()->field_name = "location";
193*993b0882SAndroid Build Coastguard Worker 
194*993b0882SAndroid Build Coastguard Worker   flatbuffers::FlatBufferBuilder builder;
195*993b0882SAndroid Build Coastguard Worker   FinishActionsModelBuffer(builder,
196*993b0882SAndroid Build Coastguard Worker                            ActionsModel::Pack(builder, actions_model.get()));
197*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsSuggestions> actions_suggestions =
198*993b0882SAndroid Build Coastguard Worker       ActionsSuggestions::FromUnownedBuffer(
199*993b0882SAndroid Build Coastguard Worker           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
200*993b0882SAndroid Build Coastguard Worker           builder.GetSize(), unilib_.get());
201*993b0882SAndroid Build Coastguard Worker 
202*993b0882SAndroid Build Coastguard Worker   AnnotatedSpan annotation;
203*993b0882SAndroid Build Coastguard Worker   annotation.span = {11, 15};
204*993b0882SAndroid Build Coastguard Worker   annotation.classification = {ClassificationResult("address", 1.0)};
205*993b0882SAndroid Build Coastguard Worker   const ActionsSuggestionsResponse response =
206*993b0882SAndroid Build Coastguard Worker       actions_suggestions->SuggestActions(
207*993b0882SAndroid Build Coastguard Worker           {{{/*user_id=*/1, "are you at home?",
208*993b0882SAndroid Build Coastguard Worker              /*reference_time_ms_utc=*/0,
209*993b0882SAndroid Build Coastguard Worker              /*reference_timezone=*/"Europe/Zurich",
210*993b0882SAndroid Build Coastguard Worker              /*annotations=*/{annotation},
211*993b0882SAndroid Build Coastguard Worker              /*locales=*/"en"}}});
212*993b0882SAndroid Build Coastguard Worker   ASSERT_GE(response.actions.size(), 1);
213*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(response.actions.front().type, "save_location");
214*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(response.actions.front().score, 1.0);
215*993b0882SAndroid Build Coastguard Worker 
216*993b0882SAndroid Build Coastguard Worker   // Check that the `location` entity field holds the text from the address
217*993b0882SAndroid Build Coastguard Worker   // annotation.
218*993b0882SAndroid Build Coastguard Worker   const flatbuffers::Table* entity =
219*993b0882SAndroid Build Coastguard Worker       flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
220*993b0882SAndroid Build Coastguard Worker           response.actions.front().serialized_entity_data.data()));
221*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
222*993b0882SAndroid Build Coastguard Worker             "home");
223*993b0882SAndroid Build Coastguard Worker }
224*993b0882SAndroid Build Coastguard Worker 
TEST_F(ActionsSuggestionsTest,SuggestsActionsFromAnnotationsWithNormalization)225*993b0882SAndroid Build Coastguard Worker TEST_F(ActionsSuggestionsTest,
226*993b0882SAndroid Build Coastguard Worker        SuggestsActionsFromAnnotationsWithNormalization) {
227*993b0882SAndroid Build Coastguard Worker   const std::string actions_model_string =
228*993b0882SAndroid Build Coastguard Worker       ReadFile(GetModelPath() + kModelFileName);
229*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsModelT> actions_model =
230*993b0882SAndroid Build Coastguard Worker       UnPackActionsModel(actions_model_string.c_str());
231*993b0882SAndroid Build Coastguard Worker   SetTestEntityDataSchema(actions_model.get());
232*993b0882SAndroid Build Coastguard Worker 
233*993b0882SAndroid Build Coastguard Worker   // Set custom actions from annotations config.
234*993b0882SAndroid Build Coastguard Worker   actions_model->annotation_actions_spec->annotation_mapping.clear();
235*993b0882SAndroid Build Coastguard Worker   actions_model->annotation_actions_spec->annotation_mapping.emplace_back(
236*993b0882SAndroid Build Coastguard Worker       new AnnotationActionsSpec_::AnnotationMappingT);
237*993b0882SAndroid Build Coastguard Worker   AnnotationActionsSpec_::AnnotationMappingT* mapping =
238*993b0882SAndroid Build Coastguard Worker       actions_model->annotation_actions_spec->annotation_mapping.back().get();
239*993b0882SAndroid Build Coastguard Worker   mapping->annotation_collection = "address";
240*993b0882SAndroid Build Coastguard Worker   mapping->action.reset(new ActionSuggestionSpecT);
241*993b0882SAndroid Build Coastguard Worker   mapping->action->type = "save_location";
242*993b0882SAndroid Build Coastguard Worker   mapping->action->score = 1.0;
243*993b0882SAndroid Build Coastguard Worker   mapping->action->priority_score = 2.0;
244*993b0882SAndroid Build Coastguard Worker   mapping->entity_field.reset(new FlatbufferFieldPathT);
245*993b0882SAndroid Build Coastguard Worker   mapping->entity_field->field.emplace_back(new FlatbufferFieldT);
246*993b0882SAndroid Build Coastguard Worker   mapping->entity_field->field.back()->field_name = "location";
247*993b0882SAndroid Build Coastguard Worker   mapping->normalization_options.reset(new NormalizationOptionsT);
248*993b0882SAndroid Build Coastguard Worker   mapping->normalization_options->codepointwise_normalization =
249*993b0882SAndroid Build Coastguard Worker       NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;
250*993b0882SAndroid Build Coastguard Worker 
251*993b0882SAndroid Build Coastguard Worker   flatbuffers::FlatBufferBuilder builder;
252*993b0882SAndroid Build Coastguard Worker   FinishActionsModelBuffer(builder,
253*993b0882SAndroid Build Coastguard Worker                            ActionsModel::Pack(builder, actions_model.get()));
254*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsSuggestions> actions_suggestions =
255*993b0882SAndroid Build Coastguard Worker       ActionsSuggestions::FromUnownedBuffer(
256*993b0882SAndroid Build Coastguard Worker           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
257*993b0882SAndroid Build Coastguard Worker           builder.GetSize(), unilib_.get());
258*993b0882SAndroid Build Coastguard Worker 
259*993b0882SAndroid Build Coastguard Worker   AnnotatedSpan annotation;
260*993b0882SAndroid Build Coastguard Worker   annotation.span = {11, 15};
261*993b0882SAndroid Build Coastguard Worker   annotation.classification = {ClassificationResult("address", 1.0)};
262*993b0882SAndroid Build Coastguard Worker   const ActionsSuggestionsResponse response =
263*993b0882SAndroid Build Coastguard Worker       actions_suggestions->SuggestActions(
264*993b0882SAndroid Build Coastguard Worker           {{{/*user_id=*/1, "are you at home?",
265*993b0882SAndroid Build Coastguard Worker              /*reference_time_ms_utc=*/0,
266*993b0882SAndroid Build Coastguard Worker              /*reference_timezone=*/"Europe/Zurich",
267*993b0882SAndroid Build Coastguard Worker              /*annotations=*/{annotation},
268*993b0882SAndroid Build Coastguard Worker              /*locales=*/"en"}}});
269*993b0882SAndroid Build Coastguard Worker   ASSERT_GE(response.actions.size(), 1);
270*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(response.actions.front().type, "save_location");
271*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(response.actions.front().score, 1.0);
272*993b0882SAndroid Build Coastguard Worker 
273*993b0882SAndroid Build Coastguard Worker   // Check that the `location` entity field holds the normalized text of the
274*993b0882SAndroid Build Coastguard Worker   // annotation.
275*993b0882SAndroid Build Coastguard Worker   const flatbuffers::Table* entity =
276*993b0882SAndroid Build Coastguard Worker       flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
277*993b0882SAndroid Build Coastguard Worker           response.actions.front().serialized_entity_data.data()));
278*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
279*993b0882SAndroid Build Coastguard Worker             "HOME");
280*993b0882SAndroid Build Coastguard Worker }
281*993b0882SAndroid Build Coastguard Worker 
// Checks that two "flight" annotations in one message are expected to yield a
// single "track_flight" action (with the higher of the two scores), followed
// by the "send_email" action for the e-mail annotation.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromDuplicatedAnnotations) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Fail cleanly instead of dereferencing a null model below.
  ASSERT_THAT(actions_suggestions, NotNull());

  // Spans refer to the message text below: [11, 15) and [35, 39) are the two
  // "LX38" occurrences, [43, 56) is the 13-character e-mail address.
  AnnotatedSpan flight_annotation;
  flight_annotation.span = {11, 15};
  flight_annotation.classification = {ClassificationResult("flight", 2.5)};
  AnnotatedSpan flight_annotation2;
  flight_annotation2.span = {35, 39};
  flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
  AnnotatedSpan email_annotation;
  email_annotation.span = {43, 56};
  email_annotation.classification = {ClassificationResult("email", 2.0)};

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1,
             "call me at LX38 or send message to LX38 or test@test.com.",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/
             {flight_annotation, flight_annotation2, email_annotation},
             /*locales=*/"en"}}});

  ASSERT_GE(response.actions.size(), 2u);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 3.0);
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[1].score, 2.0);
}
311*993b0882SAndroid Build Coastguard Worker 
TEST_F(ActionsSuggestionsTest,SuggestsActionsAnnotationsWithNoDeduplication)312*993b0882SAndroid Build Coastguard Worker TEST_F(ActionsSuggestionsTest, SuggestsActionsAnnotationsWithNoDeduplication) {
313*993b0882SAndroid Build Coastguard Worker   const std::string actions_model_string =
314*993b0882SAndroid Build Coastguard Worker       ReadFile(GetModelPath() + kModelFileName);
315*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsModelT> actions_model =
316*993b0882SAndroid Build Coastguard Worker       UnPackActionsModel(actions_model_string.c_str());
317*993b0882SAndroid Build Coastguard Worker   // Disable deduplication.
318*993b0882SAndroid Build Coastguard Worker   actions_model->annotation_actions_spec->deduplicate_annotations = false;
319*993b0882SAndroid Build Coastguard Worker   flatbuffers::FlatBufferBuilder builder;
320*993b0882SAndroid Build Coastguard Worker   FinishActionsModelBuffer(builder,
321*993b0882SAndroid Build Coastguard Worker                            ActionsModel::Pack(builder, actions_model.get()));
322*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsSuggestions> actions_suggestions =
323*993b0882SAndroid Build Coastguard Worker       ActionsSuggestions::FromUnownedBuffer(
324*993b0882SAndroid Build Coastguard Worker           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
325*993b0882SAndroid Build Coastguard Worker           builder.GetSize(), unilib_.get());
326*993b0882SAndroid Build Coastguard Worker   AnnotatedSpan flight_annotation;
327*993b0882SAndroid Build Coastguard Worker   flight_annotation.span = {11, 15};
328*993b0882SAndroid Build Coastguard Worker   flight_annotation.classification = {ClassificationResult("flight", 2.5)};
329*993b0882SAndroid Build Coastguard Worker   AnnotatedSpan flight_annotation2;
330*993b0882SAndroid Build Coastguard Worker   flight_annotation2.span = {35, 39};
331*993b0882SAndroid Build Coastguard Worker   flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
332*993b0882SAndroid Build Coastguard Worker   AnnotatedSpan email_annotation;
333*993b0882SAndroid Build Coastguard Worker   email_annotation.span = {43, 56};
334*993b0882SAndroid Build Coastguard Worker   email_annotation.classification = {ClassificationResult("email", 2.0)};
335*993b0882SAndroid Build Coastguard Worker 
336*993b0882SAndroid Build Coastguard Worker   const ActionsSuggestionsResponse response =
337*993b0882SAndroid Build Coastguard Worker       actions_suggestions->SuggestActions(
338*993b0882SAndroid Build Coastguard Worker           {{{/*user_id=*/1,
339*993b0882SAndroid Build Coastguard Worker              "call me at LX38 or send message to LX38 or [email protected].",
340*993b0882SAndroid Build Coastguard Worker              /*reference_time_ms_utc=*/0,
341*993b0882SAndroid Build Coastguard Worker              /*reference_timezone=*/"Europe/Zurich",
342*993b0882SAndroid Build Coastguard Worker              /*annotations=*/
343*993b0882SAndroid Build Coastguard Worker              {flight_annotation, flight_annotation2, email_annotation},
344*993b0882SAndroid Build Coastguard Worker              /*locales=*/"en"}}});
345*993b0882SAndroid Build Coastguard Worker 
346*993b0882SAndroid Build Coastguard Worker   ASSERT_GE(response.actions.size(), 3);
347*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(response.actions[0].type, "track_flight");
348*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(response.actions[0].score, 3.0);
349*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(response.actions[1].type, "track_flight");
350*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(response.actions[1].score, 2.5);
351*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(response.actions[2].type, "send_email");
352*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(response.actions[2].score, 2.0);
353*993b0882SAndroid Build Coastguard Worker }
354*993b0882SAndroid Build Coastguard Worker 
TestSuggestActionsFromAnnotations(const std::function<void (ActionsModelT *)> & set_config_fn,const UniLib * unilib=nullptr)355*993b0882SAndroid Build Coastguard Worker ActionsSuggestionsResponse TestSuggestActionsFromAnnotations(
356*993b0882SAndroid Build Coastguard Worker     const std::function<void(ActionsModelT*)>& set_config_fn,
357*993b0882SAndroid Build Coastguard Worker     const UniLib* unilib = nullptr) {
358*993b0882SAndroid Build Coastguard Worker   const std::string actions_model_string =
359*993b0882SAndroid Build Coastguard Worker       ReadFile(GetModelPath() + kModelFileName);
360*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsModelT> actions_model =
361*993b0882SAndroid Build Coastguard Worker       UnPackActionsModel(actions_model_string.c_str());
362*993b0882SAndroid Build Coastguard Worker 
363*993b0882SAndroid Build Coastguard Worker   // Set custom config.
364*993b0882SAndroid Build Coastguard Worker   set_config_fn(actions_model.get());
365*993b0882SAndroid Build Coastguard Worker 
366*993b0882SAndroid Build Coastguard Worker   // Disable smart reply for easier testing.
367*993b0882SAndroid Build Coastguard Worker   actions_model->preconditions->min_smart_reply_triggering_score = 1.0;
368*993b0882SAndroid Build Coastguard Worker 
369*993b0882SAndroid Build Coastguard Worker   flatbuffers::FlatBufferBuilder builder;
370*993b0882SAndroid Build Coastguard Worker   FinishActionsModelBuffer(builder,
371*993b0882SAndroid Build Coastguard Worker                            ActionsModel::Pack(builder, actions_model.get()));
372*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsSuggestions> actions_suggestions =
373*993b0882SAndroid Build Coastguard Worker       ActionsSuggestions::FromUnownedBuffer(
374*993b0882SAndroid Build Coastguard Worker           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
375*993b0882SAndroid Build Coastguard Worker           builder.GetSize(), unilib);
376*993b0882SAndroid Build Coastguard Worker 
377*993b0882SAndroid Build Coastguard Worker   AnnotatedSpan flight_annotation;
378*993b0882SAndroid Build Coastguard Worker   flight_annotation.span = {15, 19};
379*993b0882SAndroid Build Coastguard Worker   flight_annotation.classification = {ClassificationResult("flight", 2.0)};
380*993b0882SAndroid Build Coastguard Worker   AnnotatedSpan email_annotation;
381*993b0882SAndroid Build Coastguard Worker   email_annotation.span = {0, 16};
382*993b0882SAndroid Build Coastguard Worker   email_annotation.classification = {ClassificationResult("email", 1.0)};
383*993b0882SAndroid Build Coastguard Worker 
384*993b0882SAndroid Build Coastguard Worker   return actions_suggestions->SuggestActions(
385*993b0882SAndroid Build Coastguard Worker       {{{/*user_id=*/ActionsSuggestions::kLocalUserId,
386*993b0882SAndroid Build Coastguard Worker          "[email protected]",
387*993b0882SAndroid Build Coastguard Worker          /*reference_time_ms_utc=*/0,
388*993b0882SAndroid Build Coastguard Worker          /*reference_timezone=*/"Europe/Zurich",
389*993b0882SAndroid Build Coastguard Worker          /*annotations=*/
390*993b0882SAndroid Build Coastguard Worker          {email_annotation},
391*993b0882SAndroid Build Coastguard Worker          /*locales=*/"en"},
392*993b0882SAndroid Build Coastguard Worker         {/*user_id=*/2,
393*993b0882SAndroid Build Coastguard Worker          "[email protected]",
394*993b0882SAndroid Build Coastguard Worker          /*reference_time_ms_utc=*/0,
395*993b0882SAndroid Build Coastguard Worker          /*reference_timezone=*/"Europe/Zurich",
396*993b0882SAndroid Build Coastguard Worker          /*annotations=*/
397*993b0882SAndroid Build Coastguard Worker          {email_annotation},
398*993b0882SAndroid Build Coastguard Worker          /*locales=*/"en"},
399*993b0882SAndroid Build Coastguard Worker         {/*user_id=*/1,
400*993b0882SAndroid Build Coastguard Worker          "[email protected]",
401*993b0882SAndroid Build Coastguard Worker          /*reference_time_ms_utc=*/0,
402*993b0882SAndroid Build Coastguard Worker          /*reference_timezone=*/"Europe/Zurich",
403*993b0882SAndroid Build Coastguard Worker          /*annotations=*/
404*993b0882SAndroid Build Coastguard Worker          {email_annotation},
405*993b0882SAndroid Build Coastguard Worker          /*locales=*/"en"},
406*993b0882SAndroid Build Coastguard Worker         {/*user_id=*/1,
407*993b0882SAndroid Build Coastguard Worker          "I am on flight LX38.",
408*993b0882SAndroid Build Coastguard Worker          /*reference_time_ms_utc=*/0,
409*993b0882SAndroid Build Coastguard Worker          /*reference_timezone=*/"Europe/Zurich",
410*993b0882SAndroid Build Coastguard Worker          /*annotations=*/
411*993b0882SAndroid Build Coastguard Worker          {flight_annotation},
412*993b0882SAndroid Build Coastguard Worker          /*locales=*/"en"}}});
413*993b0882SAndroid Build Coastguard Worker }
414*993b0882SAndroid Build Coastguard Worker 
// With both history limits set to 1, local-user messages excluded and
// `only_until_last_sent` enabled, exactly one action of type "track_flight"
// is expected for the conversation built by TestSuggestActionsFromAnnotations.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsOnlyLastMessage) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 1;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      unilib_.get());
  // The size check must be fatal: the element access below is out of bounds
  // if the response does not hold exactly one action.
  ASSERT_THAT(response.actions, SizeIs(1));
  EXPECT_EQ(response.actions[0].type, "track_flight");
}
429*993b0882SAndroid Build Coastguard Worker 
// With `max_history_from_last_person` raised to 3 (and a limit of 1 for any
// person), two actions are expected: "track_flight" followed by "send_email".
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsOnlyLastPerson) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 1;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            3;
      },
      unilib_.get());
  // The size check must be fatal: the element accesses below are out of
  // bounds if fewer than two actions were returned.
  ASSERT_THAT(response.actions, SizeIs(2));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
}
445*993b0882SAndroid Build Coastguard Worker 
// With `max_history_from_any_person` set to 2 (and a limit of 1 for the last
// person), two actions are expected: "track_flight" followed by "send_email".
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsFromAny) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 2;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      unilib_.get());
  // The size check must be fatal: the element accesses below are out of
  // bounds if fewer than two actions were returned.
  ASSERT_THAT(response.actions, SizeIs(2));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
}
461*993b0882SAndroid Build Coastguard Worker 
// With `max_history_from_any_person` set to 3, three actions are expected:
// "track_flight" followed by two "send_email" actions.
TEST_F(ActionsSuggestionsTest,
       SuggestsActionsWithAnnotationsFromAnyManyMessages) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 3;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      unilib_.get());
  // The size check must be fatal: the element accesses below are out of
  // bounds if fewer than three actions were returned.
  ASSERT_THAT(response.actions, SizeIs(3));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
}
479*993b0882SAndroid Build Coastguard Worker 
// With `max_history_from_any_person` set to 5 but local-user messages still
// excluded, three actions are expected: "track_flight" followed by two
// "send_email" actions.
TEST_F(ActionsSuggestionsTest,
       SuggestsActionsWithAnnotationsFromAnyManyMessagesButNotLocalUser) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            false;
        actions_model->annotation_actions_spec->only_until_last_sent = true;
        actions_model->annotation_actions_spec->max_history_from_any_person = 5;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      unilib_.get());
  // The size check must be fatal: the element accesses below are out of
  // bounds if fewer than three actions were returned.
  ASSERT_THAT(response.actions, SizeIs(3));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
}
497*993b0882SAndroid Build Coastguard Worker 
// With local-user messages included, `only_until_last_sent` disabled and
// `max_history_from_any_person` set to 5, four actions are expected:
// "track_flight" followed by three "send_email" actions.
TEST_F(ActionsSuggestionsTest,
       SuggestsActionsWithAnnotationsFromAnyManyMessagesAlsoFromLocalUser) {
  const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
      [](ActionsModelT* actions_model) {
        actions_model->annotation_actions_spec->include_local_user_messages =
            true;
        actions_model->annotation_actions_spec->only_until_last_sent = false;
        actions_model->annotation_actions_spec->max_history_from_any_person = 5;
        actions_model->annotation_actions_spec->max_history_from_last_person =
            1;
      },
      unilib_.get());
  // The size check must be fatal: the element accesses below are out of
  // bounds if fewer than four actions were returned.
  ASSERT_THAT(response.actions, SizeIs(4));
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[1].type, "send_email");
  EXPECT_EQ(response.actions[2].type, "send_email");
  EXPECT_EQ(response.actions[3].type, "send_email");
}
516*993b0882SAndroid Build Coastguard Worker 
TestSuggestActionsWithThreshold(const std::function<void (ActionsModelT *)> & set_value_fn,const UniLib * unilib=nullptr,const int expected_size=0,const std::string & preconditions_overwrite="")517*993b0882SAndroid Build Coastguard Worker void TestSuggestActionsWithThreshold(
518*993b0882SAndroid Build Coastguard Worker     const std::function<void(ActionsModelT*)>& set_value_fn,
519*993b0882SAndroid Build Coastguard Worker     const UniLib* unilib = nullptr, const int expected_size = 0,
520*993b0882SAndroid Build Coastguard Worker     const std::string& preconditions_overwrite = "") {
521*993b0882SAndroid Build Coastguard Worker   const std::string actions_model_string =
522*993b0882SAndroid Build Coastguard Worker       ReadFile(GetModelPath() + kModelFileName);
523*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsModelT> actions_model =
524*993b0882SAndroid Build Coastguard Worker       UnPackActionsModel(actions_model_string.c_str());
525*993b0882SAndroid Build Coastguard Worker   set_value_fn(actions_model.get());
526*993b0882SAndroid Build Coastguard Worker   flatbuffers::FlatBufferBuilder builder;
527*993b0882SAndroid Build Coastguard Worker   FinishActionsModelBuffer(builder,
528*993b0882SAndroid Build Coastguard Worker                            ActionsModel::Pack(builder, actions_model.get()));
529*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsSuggestions> actions_suggestions =
530*993b0882SAndroid Build Coastguard Worker       ActionsSuggestions::FromUnownedBuffer(
531*993b0882SAndroid Build Coastguard Worker           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
532*993b0882SAndroid Build Coastguard Worker           builder.GetSize(), unilib, preconditions_overwrite);
533*993b0882SAndroid Build Coastguard Worker   ASSERT_TRUE(actions_suggestions);
534*993b0882SAndroid Build Coastguard Worker   const ActionsSuggestionsResponse response =
535*993b0882SAndroid Build Coastguard Worker       actions_suggestions->SuggestActions(
536*993b0882SAndroid Build Coastguard Worker           {{{/*user_id=*/1, "I have the low-ground. Where are you?",
537*993b0882SAndroid Build Coastguard Worker              /*reference_time_ms_utc=*/0,
538*993b0882SAndroid Build Coastguard Worker              /*reference_timezone=*/"Europe/Zurich",
539*993b0882SAndroid Build Coastguard Worker              /*annotations=*/{}, /*locales=*/"en"}}});
540*993b0882SAndroid Build Coastguard Worker   EXPECT_LE(response.actions.size(), expected_size);
541*993b0882SAndroid Build Coastguard Worker }
542*993b0882SAndroid Build Coastguard Worker 
// Sets the minimum smart reply triggering score to 1.0 and expects at most
// one action (no smart reply, only actions).
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithTriggeringScore) {
  const auto raise_triggering_score = [](ActionsModelT* model) {
    model->preconditions->min_smart_reply_triggering_score = 1.0;
  };
  TestSuggestActionsWithThreshold(raise_triggering_score, unilib_.get(),
                                  /*expected_size=*/1);
}
552*993b0882SAndroid Build Coastguard Worker 
// Sets the minimum reply score threshold to 1.0 and expects at most one
// action (no smart reply, only actions).
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithMinReplyScore) {
  const auto raise_reply_score_threshold = [](ActionsModelT* model) {
    model->preconditions->min_reply_score_threshold = 1.0;
  };
  TestSuggestActionsWithThreshold(raise_reply_score_threshold, unilib_.get(),
                                  /*expected_size=*/1);
}
562*993b0882SAndroid Build Coastguard Worker 
// Sets the maximum sensitive topic score to 0.0. Up to four actions are still
// allowed because the test model has no sensitive prediction.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithSensitiveTopicScore) {
  const auto zero_sensitive_topic_score = [](ActionsModelT* model) {
    model->preconditions->max_sensitive_topic_score = 0.0;
  };
  TestSuggestActionsWithThreshold(zero_sensitive_topic_score, unilib_.get(),
                                  /*expected_size=*/4);
}
571*993b0882SAndroid Build Coastguard Worker 
// Sets the maximum input length to 0; the helper's default bound of zero
// actions applies.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithMaxInputLength) {
  const auto zero_max_input_length = [](ActionsModelT* model) {
    model->preconditions->max_input_length = 0;
  };
  TestSuggestActionsWithThreshold(zero_max_input_length, unilib_.get());
}
579*993b0882SAndroid Build Coastguard Worker 
// Sets the minimum input length to 100; the helper's default bound of zero
// actions applies.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithMinInputLength) {
  const auto require_long_input = [](ActionsModelT* model) {
    model->preconditions->min_input_length = 100;
  };
  TestSuggestActionsWithThreshold(require_long_input, unilib_.get());
}
587*993b0882SAndroid Build Coastguard Worker 
// Passes a serialized preconditions overwrite (maximum input length of 0)
// instead of editing the model, and expects no actions.
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithPreconditionsOverwrite) {
  // Build the overwrite: only max_input_length differs from the defaults.
  TriggeringPreconditionsT overwrite;
  overwrite.max_input_length = 0;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(TriggeringPreconditions::Pack(builder, &overwrite));
  const std::string serialized_overwrite(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize());

  // Keep model untouched.
  TestSuggestActionsWithThreshold([](ActionsModelT*) {}, unilib_.get(),
                                  /*expected_size=*/0, serialized_overwrite);
}
601*993b0882SAndroid Build Coastguard Worker 
602*993b0882SAndroid Build Coastguard Worker #ifdef TC3_UNILIB_ICU
// Enables suppression on low confidence input and installs a single
// low-confidence regex rule with the pattern "low-ground", which occurs in
// the helper's test message; the helper's default bound of zero actions
// applies.
TEST_F(ActionsSuggestionsTest, SuggestsActionsLowConfidence) {
  const auto add_low_confidence_rule = [](ActionsModelT* model) {
    model->preconditions->suppress_on_low_confidence_input = true;
    model->low_confidence_rules.reset(new RulesModelT);
    model->low_confidence_rules->regex_rule.emplace_back(
        new RulesModel_::RegexRuleT);
    RulesModel_::RegexRuleT* const low_confidence_rule =
        model->low_confidence_rules->regex_rule.back().get();
    low_confidence_rule->pattern = "low-ground";
  };
  TestSuggestActionsWithThreshold(add_low_confidence_rule, unilib_.get());
}
615*993b0882SAndroid Build Coastguard Worker 
TEST_F(ActionsSuggestionsTest,SuggestsActionsLowConfidenceInputOutput)616*993b0882SAndroid Build Coastguard Worker TEST_F(ActionsSuggestionsTest, SuggestsActionsLowConfidenceInputOutput) {
617*993b0882SAndroid Build Coastguard Worker   const std::string actions_model_string =
618*993b0882SAndroid Build Coastguard Worker       ReadFile(GetModelPath() + kModelFileName);
619*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsModelT> actions_model =
620*993b0882SAndroid Build Coastguard Worker       UnPackActionsModel(actions_model_string.c_str());
621*993b0882SAndroid Build Coastguard Worker   // Add custom triggering rule.
622*993b0882SAndroid Build Coastguard Worker   actions_model->rules.reset(new RulesModelT());
623*993b0882SAndroid Build Coastguard Worker   actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
624*993b0882SAndroid Build Coastguard Worker   RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
625*993b0882SAndroid Build Coastguard Worker   rule->pattern = "^(?i:hello\\s(there))$";
626*993b0882SAndroid Build Coastguard Worker   {
627*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
628*993b0882SAndroid Build Coastguard Worker         new RulesModel_::RuleActionSpecT);
629*993b0882SAndroid Build Coastguard Worker     rule_action->action.reset(new ActionSuggestionSpecT);
630*993b0882SAndroid Build Coastguard Worker     rule_action->action->type = "text_reply";
631*993b0882SAndroid Build Coastguard Worker     rule_action->action->response_text = "General Desaster!";
632*993b0882SAndroid Build Coastguard Worker     rule_action->action->score = 1.0f;
633*993b0882SAndroid Build Coastguard Worker     rule_action->action->priority_score = 1.0f;
634*993b0882SAndroid Build Coastguard Worker     rule->actions.push_back(std::move(rule_action));
635*993b0882SAndroid Build Coastguard Worker   }
636*993b0882SAndroid Build Coastguard Worker   {
637*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
638*993b0882SAndroid Build Coastguard Worker         new RulesModel_::RuleActionSpecT);
639*993b0882SAndroid Build Coastguard Worker     rule_action->action.reset(new ActionSuggestionSpecT);
640*993b0882SAndroid Build Coastguard Worker     rule_action->action->type = "text_reply";
641*993b0882SAndroid Build Coastguard Worker     rule_action->action->response_text = "General Kenobi!";
642*993b0882SAndroid Build Coastguard Worker     rule_action->action->score = 1.0f;
643*993b0882SAndroid Build Coastguard Worker     rule_action->action->priority_score = 1.0f;
644*993b0882SAndroid Build Coastguard Worker     rule->actions.push_back(std::move(rule_action));
645*993b0882SAndroid Build Coastguard Worker   }
646*993b0882SAndroid Build Coastguard Worker 
647*993b0882SAndroid Build Coastguard Worker   // Add input-output low confidence rule.
648*993b0882SAndroid Build Coastguard Worker   actions_model->preconditions->suppress_on_low_confidence_input = true;
649*993b0882SAndroid Build Coastguard Worker   actions_model->low_confidence_rules.reset(new RulesModelT);
650*993b0882SAndroid Build Coastguard Worker   actions_model->low_confidence_rules->regex_rule.emplace_back(
651*993b0882SAndroid Build Coastguard Worker       new RulesModel_::RegexRuleT);
652*993b0882SAndroid Build Coastguard Worker   actions_model->low_confidence_rules->regex_rule.back()->pattern = "hello";
653*993b0882SAndroid Build Coastguard Worker   actions_model->low_confidence_rules->regex_rule.back()->output_pattern =
654*993b0882SAndroid Build Coastguard Worker       "(?i:desaster)";
655*993b0882SAndroid Build Coastguard Worker 
656*993b0882SAndroid Build Coastguard Worker   flatbuffers::FlatBufferBuilder builder;
657*993b0882SAndroid Build Coastguard Worker   FinishActionsModelBuffer(builder,
658*993b0882SAndroid Build Coastguard Worker                            ActionsModel::Pack(builder, actions_model.get()));
659*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsSuggestions> actions_suggestions =
660*993b0882SAndroid Build Coastguard Worker       ActionsSuggestions::FromUnownedBuffer(
661*993b0882SAndroid Build Coastguard Worker           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
662*993b0882SAndroid Build Coastguard Worker           builder.GetSize(), unilib_.get());
663*993b0882SAndroid Build Coastguard Worker   ASSERT_TRUE(actions_suggestions);
664*993b0882SAndroid Build Coastguard Worker   const ActionsSuggestionsResponse response =
665*993b0882SAndroid Build Coastguard Worker       actions_suggestions->SuggestActions(
666*993b0882SAndroid Build Coastguard Worker           {{{/*user_id=*/1, "hello there",
667*993b0882SAndroid Build Coastguard Worker              /*reference_time_ms_utc=*/0,
668*993b0882SAndroid Build Coastguard Worker              /*reference_timezone=*/"Europe/Zurich",
669*993b0882SAndroid Build Coastguard Worker              /*annotations=*/{}, /*locales=*/"en"}}});
670*993b0882SAndroid Build Coastguard Worker   ASSERT_GE(response.actions.size(), 1);
671*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
672*993b0882SAndroid Build Coastguard Worker }
673*993b0882SAndroid Build Coastguard Worker 
TEST_F(ActionsSuggestionsTest,SuggestsActionsLowConfidenceInputOutputOverwrite)674*993b0882SAndroid Build Coastguard Worker TEST_F(ActionsSuggestionsTest,
675*993b0882SAndroid Build Coastguard Worker        SuggestsActionsLowConfidenceInputOutputOverwrite) {
676*993b0882SAndroid Build Coastguard Worker   const std::string actions_model_string =
677*993b0882SAndroid Build Coastguard Worker       ReadFile(GetModelPath() + kModelFileName);
678*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsModelT> actions_model =
679*993b0882SAndroid Build Coastguard Worker       UnPackActionsModel(actions_model_string.c_str());
680*993b0882SAndroid Build Coastguard Worker   actions_model->low_confidence_rules.reset();
681*993b0882SAndroid Build Coastguard Worker 
682*993b0882SAndroid Build Coastguard Worker   // Add custom triggering rule.
683*993b0882SAndroid Build Coastguard Worker   actions_model->rules.reset(new RulesModelT());
684*993b0882SAndroid Build Coastguard Worker   actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
685*993b0882SAndroid Build Coastguard Worker   RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
686*993b0882SAndroid Build Coastguard Worker   rule->pattern = "^(?i:hello\\s(there))$";
687*993b0882SAndroid Build Coastguard Worker   {
688*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
689*993b0882SAndroid Build Coastguard Worker         new RulesModel_::RuleActionSpecT);
690*993b0882SAndroid Build Coastguard Worker     rule_action->action.reset(new ActionSuggestionSpecT);
691*993b0882SAndroid Build Coastguard Worker     rule_action->action->type = "text_reply";
692*993b0882SAndroid Build Coastguard Worker     rule_action->action->response_text = "General Desaster!";
693*993b0882SAndroid Build Coastguard Worker     rule_action->action->score = 1.0f;
694*993b0882SAndroid Build Coastguard Worker     rule_action->action->priority_score = 1.0f;
695*993b0882SAndroid Build Coastguard Worker     rule->actions.push_back(std::move(rule_action));
696*993b0882SAndroid Build Coastguard Worker   }
697*993b0882SAndroid Build Coastguard Worker   {
698*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
699*993b0882SAndroid Build Coastguard Worker         new RulesModel_::RuleActionSpecT);
700*993b0882SAndroid Build Coastguard Worker     rule_action->action.reset(new ActionSuggestionSpecT);
701*993b0882SAndroid Build Coastguard Worker     rule_action->action->type = "text_reply";
702*993b0882SAndroid Build Coastguard Worker     rule_action->action->response_text = "General Kenobi!";
703*993b0882SAndroid Build Coastguard Worker     rule_action->action->score = 1.0f;
704*993b0882SAndroid Build Coastguard Worker     rule_action->action->priority_score = 1.0f;
705*993b0882SAndroid Build Coastguard Worker     rule->actions.push_back(std::move(rule_action));
706*993b0882SAndroid Build Coastguard Worker   }
707*993b0882SAndroid Build Coastguard Worker 
708*993b0882SAndroid Build Coastguard Worker   // Add custom triggering rule via overwrite.
709*993b0882SAndroid Build Coastguard Worker   actions_model->preconditions->low_confidence_rules.reset();
710*993b0882SAndroid Build Coastguard Worker   TriggeringPreconditionsT preconditions;
711*993b0882SAndroid Build Coastguard Worker   preconditions.suppress_on_low_confidence_input = true;
712*993b0882SAndroid Build Coastguard Worker   preconditions.low_confidence_rules.reset(new RulesModelT);
713*993b0882SAndroid Build Coastguard Worker   preconditions.low_confidence_rules->regex_rule.emplace_back(
714*993b0882SAndroid Build Coastguard Worker       new RulesModel_::RegexRuleT);
715*993b0882SAndroid Build Coastguard Worker   preconditions.low_confidence_rules->regex_rule.back()->pattern = "hello";
716*993b0882SAndroid Build Coastguard Worker   preconditions.low_confidence_rules->regex_rule.back()->output_pattern =
717*993b0882SAndroid Build Coastguard Worker       "(?i:desaster)";
718*993b0882SAndroid Build Coastguard Worker   flatbuffers::FlatBufferBuilder preconditions_builder;
719*993b0882SAndroid Build Coastguard Worker   preconditions_builder.Finish(
720*993b0882SAndroid Build Coastguard Worker       TriggeringPreconditions::Pack(preconditions_builder, &preconditions));
721*993b0882SAndroid Build Coastguard Worker   std::string serialize_preconditions = std::string(
722*993b0882SAndroid Build Coastguard Worker       reinterpret_cast<const char*>(preconditions_builder.GetBufferPointer()),
723*993b0882SAndroid Build Coastguard Worker       preconditions_builder.GetSize());
724*993b0882SAndroid Build Coastguard Worker 
725*993b0882SAndroid Build Coastguard Worker   flatbuffers::FlatBufferBuilder builder;
726*993b0882SAndroid Build Coastguard Worker   FinishActionsModelBuffer(builder,
727*993b0882SAndroid Build Coastguard Worker                            ActionsModel::Pack(builder, actions_model.get()));
728*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsSuggestions> actions_suggestions =
729*993b0882SAndroid Build Coastguard Worker       ActionsSuggestions::FromUnownedBuffer(
730*993b0882SAndroid Build Coastguard Worker           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
731*993b0882SAndroid Build Coastguard Worker           builder.GetSize(), unilib_.get(), serialize_preconditions);
732*993b0882SAndroid Build Coastguard Worker 
733*993b0882SAndroid Build Coastguard Worker   ASSERT_TRUE(actions_suggestions);
734*993b0882SAndroid Build Coastguard Worker   const ActionsSuggestionsResponse response =
735*993b0882SAndroid Build Coastguard Worker       actions_suggestions->SuggestActions(
736*993b0882SAndroid Build Coastguard Worker           {{{/*user_id=*/1, "hello there",
737*993b0882SAndroid Build Coastguard Worker              /*reference_time_ms_utc=*/0,
738*993b0882SAndroid Build Coastguard Worker              /*reference_timezone=*/"Europe/Zurich",
739*993b0882SAndroid Build Coastguard Worker              /*annotations=*/{}, /*locales=*/"en"}}});
740*993b0882SAndroid Build Coastguard Worker   ASSERT_GE(response.actions.size(), 1);
741*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
742*993b0882SAndroid Build Coastguard Worker }
743*993b0882SAndroid Build Coastguard Worker #endif
744*993b0882SAndroid Build Coastguard Worker 
// Checks that no actions are returned for an annotated message once the
// sensitive-topic precondition is switched on with a maximum score of 0.0.
TEST_F(ActionsSuggestionsTest, SuppressActionsFromAnnotationsOnSensitiveTopic) {
  const std::string actions_model_string =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> actions_model =
      UnPackActionsModel(actions_model_string.c_str());

  // Don't test if no sensitivity score is produced
  if (actions_model->tflite_model_spec->output_sensitive_topic_score < 0) {
    return;
  }

  // Enable suppression with a threshold of zero.
  // NOTE(review): presumably this makes every scored input count as
  // sensitive -- confirm against the precondition handling in the library.
  actions_model->preconditions->max_sensitive_topic_score = 0.0;
  actions_model->preconditions->suppress_on_sensitive_topic = true;

  // Re-serialize the modified model and load it from the in-memory buffer.
  flatbuffers::FlatBufferBuilder builder;
  FinishActionsModelBuffer(builder,
                           ActionsModel::Pack(builder, actions_model.get()));
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get());
  // Abort here rather than dereferencing a null pointer below if loading
  // the model failed.
  ASSERT_TRUE(actions_suggestions);

  // Address annotation over the span [11, 15) of the message.
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {
      ClassificationResult(Collections::Address(), 1.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  EXPECT_THAT(response.actions, testing::IsEmpty());
}
778*993b0882SAndroid Build Coastguard Worker 
TEST_F(ActionsSuggestionsTest,SuggestsActionsWithLongerConversation)779*993b0882SAndroid Build Coastguard Worker TEST_F(ActionsSuggestionsTest, SuggestsActionsWithLongerConversation) {
780*993b0882SAndroid Build Coastguard Worker   const std::string actions_model_string =
781*993b0882SAndroid Build Coastguard Worker       ReadFile(GetModelPath() + kModelFileName);
782*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsModelT> actions_model =
783*993b0882SAndroid Build Coastguard Worker       UnPackActionsModel(actions_model_string.c_str());
784*993b0882SAndroid Build Coastguard Worker 
785*993b0882SAndroid Build Coastguard Worker   // Allow a larger conversation context.
786*993b0882SAndroid Build Coastguard Worker   actions_model->max_conversation_history_length = 10;
787*993b0882SAndroid Build Coastguard Worker 
788*993b0882SAndroid Build Coastguard Worker   flatbuffers::FlatBufferBuilder builder;
789*993b0882SAndroid Build Coastguard Worker   FinishActionsModelBuffer(builder,
790*993b0882SAndroid Build Coastguard Worker                            ActionsModel::Pack(builder, actions_model.get()));
791*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsSuggestions> actions_suggestions =
792*993b0882SAndroid Build Coastguard Worker       ActionsSuggestions::FromUnownedBuffer(
793*993b0882SAndroid Build Coastguard Worker           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
794*993b0882SAndroid Build Coastguard Worker           builder.GetSize(), unilib_.get());
795*993b0882SAndroid Build Coastguard Worker   AnnotatedSpan annotation;
796*993b0882SAndroid Build Coastguard Worker   annotation.span = {11, 15};
797*993b0882SAndroid Build Coastguard Worker   annotation.classification = {
798*993b0882SAndroid Build Coastguard Worker       ClassificationResult(Collections::Address(), 1.0)};
799*993b0882SAndroid Build Coastguard Worker   const ActionsSuggestionsResponse response =
800*993b0882SAndroid Build Coastguard Worker       actions_suggestions->SuggestActions(
801*993b0882SAndroid Build Coastguard Worker           {{{/*user_id=*/ActionsSuggestions::kLocalUserId, "hi, how are you?",
802*993b0882SAndroid Build Coastguard Worker              /*reference_time_ms_utc=*/10000,
803*993b0882SAndroid Build Coastguard Worker              /*reference_timezone=*/"Europe/Zurich",
804*993b0882SAndroid Build Coastguard Worker              /*annotations=*/{}, /*locales=*/"en"},
805*993b0882SAndroid Build Coastguard Worker             {/*user_id=*/1, "good! are you at home?",
806*993b0882SAndroid Build Coastguard Worker              /*reference_time_ms_utc=*/15000,
807*993b0882SAndroid Build Coastguard Worker              /*reference_timezone=*/"Europe/Zurich",
808*993b0882SAndroid Build Coastguard Worker              /*annotations=*/{annotation},
809*993b0882SAndroid Build Coastguard Worker              /*locales=*/"en"}}});
810*993b0882SAndroid Build Coastguard Worker   ASSERT_GE(response.actions.size(), 1);
811*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(response.actions[0].type, "view_map");
812*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(response.actions[0].score, 1.0);
813*993b0882SAndroid Build Coastguard Worker }
814*993b0882SAndroid Build Coastguard Worker 
// Checks the suggestions produced by the TF2 multi-task test model for a
// simple greeting: four actions, the first being the reply "Okay" and the
// last being a classifier intent.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromTF2MultiTaskModel) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kMultiTaskTF2TestModelFileName);
  // Abort here rather than dereferencing a null pointer below if loading
  // the model failed.
  ASSERT_TRUE(actions_suggestions);

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Hello how are you",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{},
             /*locales=*/"en"}}});
  // Must be fatal: indices 0 and 3 are read below, which would be an
  // out-of-bounds access if fewer than four actions were returned.
  ASSERT_EQ(response.actions.size(), 4);
  EXPECT_EQ(response.actions[0].response_text, "Okay");
  EXPECT_EQ(response.actions[0].type, "REPLY_SUGGESTION");
  EXPECT_EQ(response.actions[3].type, "TEST_CLASSIFIER_INTENT");
}
830*993b0882SAndroid Build Coastguard Worker 
// Checks that the grammar test model yields a "call_phone" action whose single
// annotation covers the span [15, 20) of the message, with score and priority
// score both 0.0.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromPhoneGrammarAnnotations) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelGrammarFileName);
  // Abort here rather than dereferencing a null pointer below if loading
  // the model failed.
  ASSERT_TRUE(actions_suggestions);

  // Input annotation of type "phone" with a score of 0.0. Note that the span
  // passed in here ([11, 15)) differs from the span expected on the resulting
  // action ([15, 20)).
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("phone", 0.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Contact us at: *1234",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});
  ASSERT_GE(response.actions.size(), 1);
  EXPECT_EQ(response.actions.front().type, "call_phone");
  EXPECT_EQ(response.actions.front().score, 0.0);
  EXPECT_EQ(response.actions.front().priority_score, 0.0);
  // Must be fatal: annotations.front() is read below, which is undefined on
  // an empty vector.
  ASSERT_EQ(response.actions.front().annotations.size(), 1);
  EXPECT_EQ(response.actions.front().annotations.front().span.span.first, 15);
  EXPECT_EQ(response.actions.front().annotations.front().span.span.second, 20);
}
852*993b0882SAndroid Build Coastguard Worker 
// Checks that a flight annotation on the input message produces a
// "track_flight" action that carries the originating annotation (message
// index 0 and the same span).
TEST_F(ActionsSuggestionsTest, CreateActionsFromClassificationResult) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Abort here rather than dereferencing a null pointer below if loading
  // the model failed.
  ASSERT_TRUE(actions_suggestions);

  AnnotatedSpan annotation;
  annotation.span = {8, 12};
  annotation.classification = {
      ClassificationResult(Collections::Flight(), 1.0)};

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "I'm on LX38?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{annotation},
             /*locales=*/"en"}}});

  ASSERT_GE(response.actions.size(), 2);
  EXPECT_EQ(response.actions[0].type, "track_flight");
  EXPECT_EQ(response.actions[0].score, 1.0);
  // Must be fatal: annotations[0] is read below, which would be an
  // out-of-bounds access if the action carried no annotation.
  ASSERT_THAT(response.actions[0].annotations, SizeIs(1));
  EXPECT_EQ(response.actions[0].annotations[0].span.message_index, 0);
  EXPECT_EQ(response.actions[0].annotations[0].span.span, annotation.span);
}
876*993b0882SAndroid Build Coastguard Worker 
877*993b0882SAndroid Build Coastguard Worker #ifdef TC3_UNILIB_ICU
TEST_F(ActionsSuggestionsTest,CreateActionsFromRules)878*993b0882SAndroid Build Coastguard Worker TEST_F(ActionsSuggestionsTest, CreateActionsFromRules) {
879*993b0882SAndroid Build Coastguard Worker   const std::string actions_model_string =
880*993b0882SAndroid Build Coastguard Worker       ReadFile(GetModelPath() + kModelFileName);
881*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsModelT> actions_model =
882*993b0882SAndroid Build Coastguard Worker       UnPackActionsModel(actions_model_string.c_str());
883*993b0882SAndroid Build Coastguard Worker   ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
884*993b0882SAndroid Build Coastguard Worker 
885*993b0882SAndroid Build Coastguard Worker   actions_model->rules.reset(new RulesModelT());
886*993b0882SAndroid Build Coastguard Worker   actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
887*993b0882SAndroid Build Coastguard Worker   RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
888*993b0882SAndroid Build Coastguard Worker   rule->pattern = "^(?i:hello\\s(there))$";
889*993b0882SAndroid Build Coastguard Worker   rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
890*993b0882SAndroid Build Coastguard Worker   rule->actions.back()->action.reset(new ActionSuggestionSpecT);
891*993b0882SAndroid Build Coastguard Worker   ActionSuggestionSpecT* action = rule->actions.back()->action.get();
892*993b0882SAndroid Build Coastguard Worker   action->type = "text_reply";
893*993b0882SAndroid Build Coastguard Worker   action->response_text = "General Kenobi!";
894*993b0882SAndroid Build Coastguard Worker   action->score = 1.0f;
895*993b0882SAndroid Build Coastguard Worker   action->priority_score = 1.0f;
896*993b0882SAndroid Build Coastguard Worker 
897*993b0882SAndroid Build Coastguard Worker   // Set capturing groups for entity data.
898*993b0882SAndroid Build Coastguard Worker   rule->actions.back()->capturing_group.emplace_back(
899*993b0882SAndroid Build Coastguard Worker       new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
900*993b0882SAndroid Build Coastguard Worker   RulesModel_::RuleActionSpec_::RuleCapturingGroupT* greeting_group =
901*993b0882SAndroid Build Coastguard Worker       rule->actions.back()->capturing_group.back().get();
902*993b0882SAndroid Build Coastguard Worker   greeting_group->group_id = 0;
903*993b0882SAndroid Build Coastguard Worker   greeting_group->entity_field.reset(new FlatbufferFieldPathT);
904*993b0882SAndroid Build Coastguard Worker   greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT);
905*993b0882SAndroid Build Coastguard Worker   greeting_group->entity_field->field.back()->field_name = "greeting";
906*993b0882SAndroid Build Coastguard Worker   rule->actions.back()->capturing_group.emplace_back(
907*993b0882SAndroid Build Coastguard Worker       new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
908*993b0882SAndroid Build Coastguard Worker   RulesModel_::RuleActionSpec_::RuleCapturingGroupT* location_group =
909*993b0882SAndroid Build Coastguard Worker       rule->actions.back()->capturing_group.back().get();
910*993b0882SAndroid Build Coastguard Worker   location_group->group_id = 1;
911*993b0882SAndroid Build Coastguard Worker   location_group->entity_field.reset(new FlatbufferFieldPathT);
912*993b0882SAndroid Build Coastguard Worker   location_group->entity_field->field.emplace_back(new FlatbufferFieldT);
913*993b0882SAndroid Build Coastguard Worker   location_group->entity_field->field.back()->field_name = "location";
914*993b0882SAndroid Build Coastguard Worker 
915*993b0882SAndroid Build Coastguard Worker   // Set test entity data schema.
916*993b0882SAndroid Build Coastguard Worker   SetTestEntityDataSchema(actions_model.get());
917*993b0882SAndroid Build Coastguard Worker 
918*993b0882SAndroid Build Coastguard Worker   // Use meta data to generate custom serialized entity data.
919*993b0882SAndroid Build Coastguard Worker   MutableFlatbufferBuilder entity_data_builder(
920*993b0882SAndroid Build Coastguard Worker       flatbuffers::GetRoot<reflection::Schema>(
921*993b0882SAndroid Build Coastguard Worker           actions_model->actions_entity_data_schema.data()));
922*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<MutableFlatbuffer> entity_data =
923*993b0882SAndroid Build Coastguard Worker       entity_data_builder.NewRoot();
924*993b0882SAndroid Build Coastguard Worker   entity_data->Set("person", "Kenobi");
925*993b0882SAndroid Build Coastguard Worker   action->serialized_entity_data = entity_data->Serialize();
926*993b0882SAndroid Build Coastguard Worker 
927*993b0882SAndroid Build Coastguard Worker   flatbuffers::FlatBufferBuilder builder;
928*993b0882SAndroid Build Coastguard Worker   FinishActionsModelBuffer(builder,
929*993b0882SAndroid Build Coastguard Worker                            ActionsModel::Pack(builder, actions_model.get()));
930*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsSuggestions> actions_suggestions =
931*993b0882SAndroid Build Coastguard Worker       ActionsSuggestions::FromUnownedBuffer(
932*993b0882SAndroid Build Coastguard Worker           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
933*993b0882SAndroid Build Coastguard Worker           builder.GetSize(), unilib_.get());
934*993b0882SAndroid Build Coastguard Worker 
935*993b0882SAndroid Build Coastguard Worker   const ActionsSuggestionsResponse response =
936*993b0882SAndroid Build Coastguard Worker       actions_suggestions->SuggestActions(
937*993b0882SAndroid Build Coastguard Worker           {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0,
938*993b0882SAndroid Build Coastguard Worker              /*reference_timezone=*/"Europe/Zurich",
939*993b0882SAndroid Build Coastguard Worker              /*annotations=*/{}, /*locales=*/"en"}}});
940*993b0882SAndroid Build Coastguard Worker   EXPECT_GE(response.actions.size(), 1);
941*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
942*993b0882SAndroid Build Coastguard Worker 
943*993b0882SAndroid Build Coastguard Worker   // Check entity data.
944*993b0882SAndroid Build Coastguard Worker   const flatbuffers::Table* entity =
945*993b0882SAndroid Build Coastguard Worker       flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
946*993b0882SAndroid Build Coastguard Worker           response.actions[0].serialized_entity_data.data()));
947*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
948*993b0882SAndroid Build Coastguard Worker             "hello there");
949*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
950*993b0882SAndroid Build Coastguard Worker             "there");
951*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
952*993b0882SAndroid Build Coastguard Worker             "Kenobi");
953*993b0882SAndroid Build Coastguard Worker }
954*993b0882SAndroid Build Coastguard Worker 
TEST_F(ActionsSuggestionsTest,CreateActionsFromRulesWithNormalization)955*993b0882SAndroid Build Coastguard Worker TEST_F(ActionsSuggestionsTest, CreateActionsFromRulesWithNormalization) {
956*993b0882SAndroid Build Coastguard Worker   const std::string actions_model_string =
957*993b0882SAndroid Build Coastguard Worker       ReadFile(GetModelPath() + kModelFileName);
958*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsModelT> actions_model =
959*993b0882SAndroid Build Coastguard Worker       UnPackActionsModel(actions_model_string.c_str());
960*993b0882SAndroid Build Coastguard Worker   ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
961*993b0882SAndroid Build Coastguard Worker 
962*993b0882SAndroid Build Coastguard Worker   actions_model->rules.reset(new RulesModelT());
963*993b0882SAndroid Build Coastguard Worker   actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
964*993b0882SAndroid Build Coastguard Worker   RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
965*993b0882SAndroid Build Coastguard Worker   rule->pattern = "^(?i:hello\\sthere)$";
966*993b0882SAndroid Build Coastguard Worker   rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
967*993b0882SAndroid Build Coastguard Worker   rule->actions.back()->action.reset(new ActionSuggestionSpecT);
968*993b0882SAndroid Build Coastguard Worker   ActionSuggestionSpecT* action = rule->actions.back()->action.get();
969*993b0882SAndroid Build Coastguard Worker   action->type = "text_reply";
970*993b0882SAndroid Build Coastguard Worker   action->response_text = "General Kenobi!";
971*993b0882SAndroid Build Coastguard Worker   action->score = 1.0f;
972*993b0882SAndroid Build Coastguard Worker   action->priority_score = 1.0f;
973*993b0882SAndroid Build Coastguard Worker 
974*993b0882SAndroid Build Coastguard Worker   // Set capturing groups for entity data.
975*993b0882SAndroid Build Coastguard Worker   rule->actions.back()->capturing_group.emplace_back(
976*993b0882SAndroid Build Coastguard Worker       new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
977*993b0882SAndroid Build Coastguard Worker   RulesModel_::RuleActionSpec_::RuleCapturingGroupT* greeting_group =
978*993b0882SAndroid Build Coastguard Worker       rule->actions.back()->capturing_group.back().get();
979*993b0882SAndroid Build Coastguard Worker   greeting_group->group_id = 0;
980*993b0882SAndroid Build Coastguard Worker   greeting_group->entity_field.reset(new FlatbufferFieldPathT);
981*993b0882SAndroid Build Coastguard Worker   greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT);
982*993b0882SAndroid Build Coastguard Worker   greeting_group->entity_field->field.back()->field_name = "greeting";
983*993b0882SAndroid Build Coastguard Worker   greeting_group->normalization_options.reset(new NormalizationOptionsT);
984*993b0882SAndroid Build Coastguard Worker   greeting_group->normalization_options->codepointwise_normalization =
985*993b0882SAndroid Build Coastguard Worker       NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE |
986*993b0882SAndroid Build Coastguard Worker       NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;
987*993b0882SAndroid Build Coastguard Worker 
988*993b0882SAndroid Build Coastguard Worker   // Set test entity data schema.
989*993b0882SAndroid Build Coastguard Worker   SetTestEntityDataSchema(actions_model.get());
990*993b0882SAndroid Build Coastguard Worker 
991*993b0882SAndroid Build Coastguard Worker   flatbuffers::FlatBufferBuilder builder;
992*993b0882SAndroid Build Coastguard Worker   FinishActionsModelBuffer(builder,
993*993b0882SAndroid Build Coastguard Worker                            ActionsModel::Pack(builder, actions_model.get()));
994*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsSuggestions> actions_suggestions =
995*993b0882SAndroid Build Coastguard Worker       ActionsSuggestions::FromUnownedBuffer(
996*993b0882SAndroid Build Coastguard Worker           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
997*993b0882SAndroid Build Coastguard Worker           builder.GetSize(), unilib_.get());
998*993b0882SAndroid Build Coastguard Worker 
999*993b0882SAndroid Build Coastguard Worker   const ActionsSuggestionsResponse response =
1000*993b0882SAndroid Build Coastguard Worker       actions_suggestions->SuggestActions(
1001*993b0882SAndroid Build Coastguard Worker           {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0,
1002*993b0882SAndroid Build Coastguard Worker              /*reference_timezone=*/"Europe/Zurich",
1003*993b0882SAndroid Build Coastguard Worker              /*annotations=*/{}, /*locales=*/"en"}}});
1004*993b0882SAndroid Build Coastguard Worker   EXPECT_GE(response.actions.size(), 1);
1005*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
1006*993b0882SAndroid Build Coastguard Worker 
1007*993b0882SAndroid Build Coastguard Worker   // Check entity data.
1008*993b0882SAndroid Build Coastguard Worker   const flatbuffers::Table* entity =
1009*993b0882SAndroid Build Coastguard Worker       flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
1010*993b0882SAndroid Build Coastguard Worker           response.actions[0].serialized_entity_data.data()));
1011*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
1012*993b0882SAndroid Build Coastguard Worker             "HELLOTHERE");
1013*993b0882SAndroid Build Coastguard Worker }
1014*993b0882SAndroid Build Coastguard Worker 
TEST_F(ActionsSuggestionsTest,CreatesTextRepliesFromRules)1015*993b0882SAndroid Build Coastguard Worker TEST_F(ActionsSuggestionsTest, CreatesTextRepliesFromRules) {
1016*993b0882SAndroid Build Coastguard Worker   const std::string actions_model_string =
1017*993b0882SAndroid Build Coastguard Worker       ReadFile(GetModelPath() + kModelFileName);
1018*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsModelT> actions_model =
1019*993b0882SAndroid Build Coastguard Worker       UnPackActionsModel(actions_model_string.c_str());
1020*993b0882SAndroid Build Coastguard Worker   ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
1021*993b0882SAndroid Build Coastguard Worker 
1022*993b0882SAndroid Build Coastguard Worker   actions_model->rules.reset(new RulesModelT());
1023*993b0882SAndroid Build Coastguard Worker   actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
1024*993b0882SAndroid Build Coastguard Worker   RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
1025*993b0882SAndroid Build Coastguard Worker   rule->pattern = "(?i:reply (stop|quit|end) (?:to|for) )";
1026*993b0882SAndroid Build Coastguard Worker   rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
1027*993b0882SAndroid Build Coastguard Worker 
1028*993b0882SAndroid Build Coastguard Worker   // Set capturing groups for entity data.
1029*993b0882SAndroid Build Coastguard Worker   rule->actions.back()->capturing_group.emplace_back(
1030*993b0882SAndroid Build Coastguard Worker       new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
1031*993b0882SAndroid Build Coastguard Worker   RulesModel_::RuleActionSpec_::RuleCapturingGroupT* code_group =
1032*993b0882SAndroid Build Coastguard Worker       rule->actions.back()->capturing_group.back().get();
1033*993b0882SAndroid Build Coastguard Worker   code_group->group_id = 1;
1034*993b0882SAndroid Build Coastguard Worker   code_group->text_reply.reset(new ActionSuggestionSpecT);
1035*993b0882SAndroid Build Coastguard Worker   code_group->text_reply->score = 1.0f;
1036*993b0882SAndroid Build Coastguard Worker   code_group->text_reply->priority_score = 1.0f;
1037*993b0882SAndroid Build Coastguard Worker   code_group->normalization_options.reset(new NormalizationOptionsT);
1038*993b0882SAndroid Build Coastguard Worker   code_group->normalization_options->codepointwise_normalization =
1039*993b0882SAndroid Build Coastguard Worker       NormalizationOptions_::CodepointwiseNormalizationOp_LOWERCASE;
1040*993b0882SAndroid Build Coastguard Worker 
1041*993b0882SAndroid Build Coastguard Worker   flatbuffers::FlatBufferBuilder builder;
1042*993b0882SAndroid Build Coastguard Worker   FinishActionsModelBuffer(builder,
1043*993b0882SAndroid Build Coastguard Worker                            ActionsModel::Pack(builder, actions_model.get()));
1044*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsSuggestions> actions_suggestions =
1045*993b0882SAndroid Build Coastguard Worker       ActionsSuggestions::FromUnownedBuffer(
1046*993b0882SAndroid Build Coastguard Worker           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
1047*993b0882SAndroid Build Coastguard Worker           builder.GetSize(), unilib_.get());
1048*993b0882SAndroid Build Coastguard Worker 
1049*993b0882SAndroid Build Coastguard Worker   const ActionsSuggestionsResponse response =
1050*993b0882SAndroid Build Coastguard Worker       actions_suggestions->SuggestActions(
1051*993b0882SAndroid Build Coastguard Worker           {{{/*user_id=*/1,
1052*993b0882SAndroid Build Coastguard Worker              "visit test.com or reply STOP to cancel your subscription",
1053*993b0882SAndroid Build Coastguard Worker              /*reference_time_ms_utc=*/0,
1054*993b0882SAndroid Build Coastguard Worker              /*reference_timezone=*/"Europe/Zurich",
1055*993b0882SAndroid Build Coastguard Worker              /*annotations=*/{}, /*locales=*/"en"}}});
1056*993b0882SAndroid Build Coastguard Worker   EXPECT_GE(response.actions.size(), 1);
1057*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(response.actions[0].response_text, "stop");
1058*993b0882SAndroid Build Coastguard Worker }
1059*993b0882SAndroid Build Coastguard Worker 
TEST_F(ActionsSuggestionsTest,CreatesActionsFromGrammarRules)1060*993b0882SAndroid Build Coastguard Worker TEST_F(ActionsSuggestionsTest, CreatesActionsFromGrammarRules) {
1061*993b0882SAndroid Build Coastguard Worker   const std::string actions_model_string =
1062*993b0882SAndroid Build Coastguard Worker       ReadFile(GetModelPath() + kModelFileName);
1063*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsModelT> actions_model =
1064*993b0882SAndroid Build Coastguard Worker       UnPackActionsModel(actions_model_string.c_str());
1065*993b0882SAndroid Build Coastguard Worker   ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
1066*993b0882SAndroid Build Coastguard Worker 
1067*993b0882SAndroid Build Coastguard Worker   actions_model->rules->grammar_rules.reset(new RulesModel_::GrammarRulesT);
1068*993b0882SAndroid Build Coastguard Worker 
1069*993b0882SAndroid Build Coastguard Worker   // Set tokenizer options.
1070*993b0882SAndroid Build Coastguard Worker   RulesModel_::GrammarRulesT* action_grammar_rules =
1071*993b0882SAndroid Build Coastguard Worker       actions_model->rules->grammar_rules.get();
1072*993b0882SAndroid Build Coastguard Worker   action_grammar_rules->tokenizer_options.reset(new ActionsTokenizerOptionsT);
1073*993b0882SAndroid Build Coastguard Worker   action_grammar_rules->tokenizer_options->type = TokenizationType_ICU;
1074*993b0882SAndroid Build Coastguard Worker   action_grammar_rules->tokenizer_options->icu_preserve_whitespace_tokens =
1075*993b0882SAndroid Build Coastguard Worker       false;
1076*993b0882SAndroid Build Coastguard Worker 
1077*993b0882SAndroid Build Coastguard Worker   // Setup test rules.
1078*993b0882SAndroid Build Coastguard Worker   action_grammar_rules->rules.reset(new grammar::RulesSetT);
1079*993b0882SAndroid Build Coastguard Worker   grammar::LocaleShardMap locale_shard_map =
1080*993b0882SAndroid Build Coastguard Worker       grammar::LocaleShardMap::CreateLocaleShardMap({""});
1081*993b0882SAndroid Build Coastguard Worker   grammar::Rules rules(locale_shard_map);
1082*993b0882SAndroid Build Coastguard Worker   rules.Add(
1083*993b0882SAndroid Build Coastguard Worker       "<knock>", {"<^>", "ventura", "!?", "<$>"},
1084*993b0882SAndroid Build Coastguard Worker       /*callback=*/
1085*993b0882SAndroid Build Coastguard Worker       static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
1086*993b0882SAndroid Build Coastguard Worker       /*callback_param=*/0);
1087*993b0882SAndroid Build Coastguard Worker   rules.Finalize().Serialize(/*include_debug_information=*/false,
1088*993b0882SAndroid Build Coastguard Worker                              action_grammar_rules->rules.get());
1089*993b0882SAndroid Build Coastguard Worker   action_grammar_rules->actions.emplace_back(new RulesModel_::RuleActionSpecT);
1090*993b0882SAndroid Build Coastguard Worker   RulesModel_::RuleActionSpecT* actions_spec =
1091*993b0882SAndroid Build Coastguard Worker       action_grammar_rules->actions.back().get();
1092*993b0882SAndroid Build Coastguard Worker   actions_spec->action.reset(new ActionSuggestionSpecT);
1093*993b0882SAndroid Build Coastguard Worker   actions_spec->action->response_text = "Yes, Satan?";
1094*993b0882SAndroid Build Coastguard Worker   actions_spec->action->priority_score = 1.0;
1095*993b0882SAndroid Build Coastguard Worker   actions_spec->action->score = 1.0;
1096*993b0882SAndroid Build Coastguard Worker   actions_spec->action->type = "text_reply";
1097*993b0882SAndroid Build Coastguard Worker   action_grammar_rules->rule_match.emplace_back(
1098*993b0882SAndroid Build Coastguard Worker       new RulesModel_::GrammarRules_::RuleMatchT);
1099*993b0882SAndroid Build Coastguard Worker   action_grammar_rules->rule_match.back()->action_id.push_back(0);
1100*993b0882SAndroid Build Coastguard Worker 
1101*993b0882SAndroid Build Coastguard Worker   flatbuffers::FlatBufferBuilder builder;
1102*993b0882SAndroid Build Coastguard Worker   FinishActionsModelBuffer(builder,
1103*993b0882SAndroid Build Coastguard Worker                            ActionsModel::Pack(builder, actions_model.get()));
1104*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsSuggestions> actions_suggestions =
1105*993b0882SAndroid Build Coastguard Worker       ActionsSuggestions::FromUnownedBuffer(
1106*993b0882SAndroid Build Coastguard Worker           reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
1107*993b0882SAndroid Build Coastguard Worker           builder.GetSize(), unilib_.get());
1108*993b0882SAndroid Build Coastguard Worker 
1109*993b0882SAndroid Build Coastguard Worker   const ActionsSuggestionsResponse response =
1110*993b0882SAndroid Build Coastguard Worker       actions_suggestions->SuggestActions(
1111*993b0882SAndroid Build Coastguard Worker           {{{/*user_id=*/1, "Ventura!",
1112*993b0882SAndroid Build Coastguard Worker              /*reference_time_ms_utc=*/0,
1113*993b0882SAndroid Build Coastguard Worker              /*reference_timezone=*/"Europe/Zurich",
1114*993b0882SAndroid Build Coastguard Worker              /*annotations=*/{}, /*locales=*/"en"}}});
1115*993b0882SAndroid Build Coastguard Worker 
1116*993b0882SAndroid Build Coastguard Worker   EXPECT_THAT(response.actions, ElementsAre(IsSmartReply("Yes, Satan?")));
1117*993b0882SAndroid Build Coastguard Worker }
1118*993b0882SAndroid Build Coastguard Worker 
#if defined(TC3_UNILIB_ICU) && !defined(TEST_NO_DATETIME)
// Installs a grammar rule that references an annotation-backed nonterminal
// and checks that, with an annotator supplied, a matching message produces a
// single "create_event" action.
TEST_F(ActionsSuggestionsTest, CreatesActionsWithAnnotationsFromGrammarRules) {
  std::unique_ptr<Annotator> annotator =
      Annotator::FromPath(GetModelPath() + "en.fb", unilib_.get());
  const std::string serialized_model =
      ReadFile(GetModelPath() + kModelFileName);
  std::unique_ptr<ActionsModelT> model =
      UnPackActionsModel(serialized_model.c_str());
  ASSERT_TRUE(DecompressActionsModel(model.get()));

  // Replace whatever grammar rules the model ships with by an empty set.
  model->rules->grammar_rules.reset(new RulesModel_::GrammarRulesT);
  RulesModel_::GrammarRulesT* grammar_rules = model->rules->grammar_rules.get();

  // Tokenizer: ICU tokenization without whitespace tokens.
  grammar_rules->tokenizer_options.reset(new ActionsTokenizerOptionsT);
  ActionsTokenizerOptionsT* tokenizer = grammar_rules->tokenizer_options.get();
  tokenizer->type = TokenizationType_ICU;
  tokenizer->icu_preserve_whitespace_tokens = false;

  // Build the rule set: one root rule <event> ending in the <time> nonterminal.
  grammar_rules->rules.reset(new grammar::RulesSetT);
  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rule_builder(shard_map);
  const grammar::CallbackId root_callback =
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule);
  rule_builder.Add("<event>", {"it", "is", "at", "<time>"}, root_callback,
                   /*callback_param=*/0);
  // NOTE(review): presumably this lets the annotator's "datetime" results
  // satisfy <time> via the "time" binding -- confirm against grammar::Rules.
  rule_builder.BindAnnotation("<time>", "time");
  rule_builder.AddAnnotation("datetime");
  rule_builder.Finalize().Serialize(/*include_debug_information=*/false,
                                    grammar_rules->rules.get());

  // Action emitted when the rule fires.
  grammar_rules->actions.emplace_back(new RulesModel_::RuleActionSpecT);
  RulesModel_::RuleActionSpecT* spec = grammar_rules->actions.back().get();
  spec->action.reset(new ActionSuggestionSpecT);
  ActionSuggestionSpecT* event_action = spec->action.get();
  event_action->type = "create_event";
  event_action->score = 1.0;
  event_action->priority_score = 1.0;

  // Rule match entry pointing at action index 0 (the spec added above).
  grammar_rules->rule_match.emplace_back(
      new RulesModel_::GrammarRules_::RuleMatchT);
  grammar_rules->rule_match.back()->action_id.push_back(0);

  // Re-pack the modified model and load it from the in-memory buffer.
  flatbuffers::FlatBufferBuilder fbb;
  FinishActionsModelBuffer(fbb, ActionsModel::Pack(fbb, model.get()));
  std::unique_ptr<ActionsSuggestions> suggestions =
      ActionsSuggestions::FromUnownedBuffer(
          reinterpret_cast<const uint8_t*>(fbb.GetBufferPointer()),
          fbb.GetSize(), unilib_.get());

  const ActionsSuggestionsResponse result = suggestions->SuggestActions(
      {{{/*user_id=*/1, "it is at 10:30",
         /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}},
      annotator.get());

  EXPECT_THAT(result.actions, ElementsAre(IsActionOfType("create_event")));
}
#endif
1183*993b0882SAndroid Build Coastguard Worker 
TEST_F(ActionsSuggestionsTest,DeduplicateActions)1184*993b0882SAndroid Build Coastguard Worker TEST_F(ActionsSuggestionsTest, DeduplicateActions) {
1185*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsSuggestions> actions_suggestions =
1186*993b0882SAndroid Build Coastguard Worker       LoadTestModel(kModelFileName);
1187*993b0882SAndroid Build Coastguard Worker   ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
1188*993b0882SAndroid Build Coastguard Worker       {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
1189*993b0882SAndroid Build Coastguard Worker          /*reference_timezone=*/"Europe/Zurich",
1190*993b0882SAndroid Build Coastguard Worker          /*annotations=*/{}, /*locales=*/"en"}}});
1191*993b0882SAndroid Build Coastguard Worker 
1192*993b0882SAndroid Build Coastguard Worker   // Check that the location sharing model triggered.
1193*993b0882SAndroid Build Coastguard Worker   bool has_location_sharing_action = false;
1194*993b0882SAndroid Build Coastguard Worker   for (const ActionSuggestion& action : response.actions) {
1195*993b0882SAndroid Build Coastguard Worker     if (action.type == ActionsSuggestionsTypes::ShareLocation()) {
1196*993b0882SAndroid Build Coastguard Worker       has_location_sharing_action = true;
1197*993b0882SAndroid Build Coastguard Worker       break;
1198*993b0882SAndroid Build Coastguard Worker     }
1199*993b0882SAndroid Build Coastguard Worker   }
1200*993b0882SAndroid Build Coastguard Worker   EXPECT_TRUE(has_location_sharing_action);
1201*993b0882SAndroid Build Coastguard Worker   const int num_actions = response.actions.size();
1202*993b0882SAndroid Build Coastguard Worker 
1203*993b0882SAndroid Build Coastguard Worker   // Add custom rule for location sharing.
1204*993b0882SAndroid Build Coastguard Worker   const std::string actions_model_string =
1205*993b0882SAndroid Build Coastguard Worker       ReadFile(GetModelPath() + kModelFileName);
1206*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsModelT> actions_model =
1207*993b0882SAndroid Build Coastguard Worker       UnPackActionsModel(actions_model_string.c_str());
1208*993b0882SAndroid Build Coastguard Worker   ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
1209*993b0882SAndroid Build Coastguard Worker 
1210*993b0882SAndroid Build Coastguard Worker   actions_model->rules.reset(new RulesModelT());
1211*993b0882SAndroid Build Coastguard Worker   actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
1212*993b0882SAndroid Build Coastguard Worker   actions_model->rules->regex_rule.back()->pattern =
1213*993b0882SAndroid Build Coastguard Worker       "^(?i:where are you[.?]?)$";
1214*993b0882SAndroid Build Coastguard Worker   actions_model->rules->regex_rule.back()->actions.emplace_back(
1215*993b0882SAndroid Build Coastguard Worker       new RulesModel_::RuleActionSpecT);
1216*993b0882SAndroid Build Coastguard Worker   actions_model->rules->regex_rule.back()->actions.back()->action.reset(
1217*993b0882SAndroid Build Coastguard Worker       new ActionSuggestionSpecT);
1218*993b0882SAndroid Build Coastguard Worker   ActionSuggestionSpecT* action =
1219*993b0882SAndroid Build Coastguard Worker       actions_model->rules->regex_rule.back()->actions.back()->action.get();
1220*993b0882SAndroid Build Coastguard Worker   action->score = 1.0f;
1221*993b0882SAndroid Build Coastguard Worker   action->type = ActionsSuggestionsTypes::ShareLocation();
1222*993b0882SAndroid Build Coastguard Worker 
1223*993b0882SAndroid Build Coastguard Worker   flatbuffers::FlatBufferBuilder builder;
1224*993b0882SAndroid Build Coastguard Worker   FinishActionsModelBuffer(builder,
1225*993b0882SAndroid Build Coastguard Worker                            ActionsModel::Pack(builder, actions_model.get()));
1226*993b0882SAndroid Build Coastguard Worker   actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
1227*993b0882SAndroid Build Coastguard Worker       reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
1228*993b0882SAndroid Build Coastguard Worker       builder.GetSize(), unilib_.get());
1229*993b0882SAndroid Build Coastguard Worker 
1230*993b0882SAndroid Build Coastguard Worker   response = actions_suggestions->SuggestActions(
1231*993b0882SAndroid Build Coastguard Worker       {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
1232*993b0882SAndroid Build Coastguard Worker          /*reference_timezone=*/"Europe/Zurich",
1233*993b0882SAndroid Build Coastguard Worker          /*annotations=*/{}, /*locales=*/"en"}}});
1234*993b0882SAndroid Build Coastguard Worker   EXPECT_THAT(response.actions, SizeIs(num_actions));
1235*993b0882SAndroid Build Coastguard Worker }
1236*993b0882SAndroid Build Coastguard Worker 
TEST_F(ActionsSuggestionsTest,DeduplicateConflictingActions)1237*993b0882SAndroid Build Coastguard Worker TEST_F(ActionsSuggestionsTest, DeduplicateConflictingActions) {
1238*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsSuggestions> actions_suggestions =
1239*993b0882SAndroid Build Coastguard Worker       LoadTestModel(kModelFileName);
1240*993b0882SAndroid Build Coastguard Worker   AnnotatedSpan annotation;
1241*993b0882SAndroid Build Coastguard Worker   annotation.span = {7, 11};
1242*993b0882SAndroid Build Coastguard Worker   annotation.classification = {
1243*993b0882SAndroid Build Coastguard Worker       ClassificationResult(Collections::Flight(), 1.0)};
1244*993b0882SAndroid Build Coastguard Worker   ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
1245*993b0882SAndroid Build Coastguard Worker       {{{/*user_id=*/1, "I'm on LX38",
1246*993b0882SAndroid Build Coastguard Worker          /*reference_time_ms_utc=*/0,
1247*993b0882SAndroid Build Coastguard Worker          /*reference_timezone=*/"Europe/Zurich",
1248*993b0882SAndroid Build Coastguard Worker          /*annotations=*/{annotation},
1249*993b0882SAndroid Build Coastguard Worker          /*locales=*/"en"}}});
1250*993b0882SAndroid Build Coastguard Worker 
1251*993b0882SAndroid Build Coastguard Worker   // Check that the phone actions are present.
1252*993b0882SAndroid Build Coastguard Worker   EXPECT_GE(response.actions.size(), 1);
1253*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(response.actions[0].type, "track_flight");
1254*993b0882SAndroid Build Coastguard Worker 
1255*993b0882SAndroid Build Coastguard Worker   // Add custom rule.
1256*993b0882SAndroid Build Coastguard Worker   const std::string actions_model_string =
1257*993b0882SAndroid Build Coastguard Worker       ReadFile(GetModelPath() + kModelFileName);
1258*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsModelT> actions_model =
1259*993b0882SAndroid Build Coastguard Worker       UnPackActionsModel(actions_model_string.c_str());
1260*993b0882SAndroid Build Coastguard Worker   ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
1261*993b0882SAndroid Build Coastguard Worker 
1262*993b0882SAndroid Build Coastguard Worker   actions_model->rules.reset(new RulesModelT());
1263*993b0882SAndroid Build Coastguard Worker   actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
1264*993b0882SAndroid Build Coastguard Worker   RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
1265*993b0882SAndroid Build Coastguard Worker   rule->pattern = "^(?i:I'm on ([a-z0-9]+))$";
1266*993b0882SAndroid Build Coastguard Worker   rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
1267*993b0882SAndroid Build Coastguard Worker   rule->actions.back()->action.reset(new ActionSuggestionSpecT);
1268*993b0882SAndroid Build Coastguard Worker   ActionSuggestionSpecT* action = rule->actions.back()->action.get();
1269*993b0882SAndroid Build Coastguard Worker   action->score = 1.0f;
1270*993b0882SAndroid Build Coastguard Worker   action->priority_score = 2.0f;
1271*993b0882SAndroid Build Coastguard Worker   action->type = "test_code";
1272*993b0882SAndroid Build Coastguard Worker   rule->actions.back()->capturing_group.emplace_back(
1273*993b0882SAndroid Build Coastguard Worker       new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
1274*993b0882SAndroid Build Coastguard Worker   RulesModel_::RuleActionSpec_::RuleCapturingGroupT* code_group =
1275*993b0882SAndroid Build Coastguard Worker       rule->actions.back()->capturing_group.back().get();
1276*993b0882SAndroid Build Coastguard Worker   code_group->group_id = 1;
1277*993b0882SAndroid Build Coastguard Worker   code_group->annotation_name = "code";
1278*993b0882SAndroid Build Coastguard Worker   code_group->annotation_type = "code";
1279*993b0882SAndroid Build Coastguard Worker 
1280*993b0882SAndroid Build Coastguard Worker   flatbuffers::FlatBufferBuilder builder;
1281*993b0882SAndroid Build Coastguard Worker   FinishActionsModelBuffer(builder,
1282*993b0882SAndroid Build Coastguard Worker                            ActionsModel::Pack(builder, actions_model.get()));
1283*993b0882SAndroid Build Coastguard Worker   actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
1284*993b0882SAndroid Build Coastguard Worker       reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
1285*993b0882SAndroid Build Coastguard Worker       builder.GetSize(), unilib_.get());
1286*993b0882SAndroid Build Coastguard Worker 
1287*993b0882SAndroid Build Coastguard Worker   response = actions_suggestions->SuggestActions(
1288*993b0882SAndroid Build Coastguard Worker       {{{/*user_id=*/1, "I'm on LX38",
1289*993b0882SAndroid Build Coastguard Worker          /*reference_time_ms_utc=*/0,
1290*993b0882SAndroid Build Coastguard Worker          /*reference_timezone=*/"Europe/Zurich",
1291*993b0882SAndroid Build Coastguard Worker          /*annotations=*/{annotation},
1292*993b0882SAndroid Build Coastguard Worker          /*locales=*/"en"}}});
1293*993b0882SAndroid Build Coastguard Worker   EXPECT_GE(response.actions.size(), 1);
1294*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(response.actions[0].type, "test_code");
1295*993b0882SAndroid Build Coastguard Worker }
1296*993b0882SAndroid Build Coastguard Worker #endif
1297*993b0882SAndroid Build Coastguard Worker 
// Checks that actions derived from two address annotations are returned in
// descending score order.
TEST_F(ActionsSuggestionsTest, RanksActions) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kModelFileName);
  // Two address annotations on "are you at home or work?":
  // [11, 15) is "home" (score 1.0), [19, 23) is "work" (score 2.0).
  std::vector<AnnotatedSpan> annotations(2);
  annotations[0].span = {11, 15};
  annotations[0].classification = {ClassificationResult("address", 1.0)};
  annotations[1].span = {19, 23};
  annotations[1].classification = {ClassificationResult("address", 2.0)};
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "are you at home or work?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/annotations,
             /*locales=*/"en"}}});
  // ASSERT (not EXPECT): the indexing below is out of bounds when fewer than
  // two actions come back, so the test must stop here in that case.
  ASSERT_GE(response.actions.size(), 2);
  // The higher-scored annotation must be ranked first.
  EXPECT_EQ(response.actions[0].type, "view_map");
  EXPECT_EQ(response.actions[0].score, 2.0);
  EXPECT_EQ(response.actions[1].type, "view_map");
  EXPECT_EQ(response.actions[1].score, 1.0);
}
1319*993b0882SAndroid Build Coastguard Worker 
// Checks that VisitActionsModel yields true for the bundled test model and
// false for a path that does not exist, using a visitor that only reports
// whether it was handed a non-null model.
TEST_F(ActionsSuggestionsTest, VisitActionsModel) {
  // Shared visitor: true exactly when a model pointer was provided.
  const auto model_is_present = [](const ActionsModel* model) {
    return model != nullptr;
  };

  EXPECT_TRUE(VisitActionsModel<bool>(GetModelPath() + kModelFileName,
                                      model_is_present));
  EXPECT_FALSE(VisitActionsModel<bool>(GetModelPath() + "non_existing_model.fb",
                                       model_is_present));
}
1336*993b0882SAndroid Build Coastguard Worker 
TEST_F(ActionsSuggestionsTest,SuggestsActionsWithHashGramModel)1337*993b0882SAndroid Build Coastguard Worker TEST_F(ActionsSuggestionsTest, SuggestsActionsWithHashGramModel) {
1338*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsSuggestions> actions_suggestions =
1339*993b0882SAndroid Build Coastguard Worker       LoadHashGramTestModel();
1340*993b0882SAndroid Build Coastguard Worker   ASSERT_TRUE(actions_suggestions != nullptr);
1341*993b0882SAndroid Build Coastguard Worker   {
1342*993b0882SAndroid Build Coastguard Worker     const ActionsSuggestionsResponse response =
1343*993b0882SAndroid Build Coastguard Worker         actions_suggestions->SuggestActions(
1344*993b0882SAndroid Build Coastguard Worker             {{{/*user_id=*/1, "hello",
1345*993b0882SAndroid Build Coastguard Worker                /*reference_time_ms_utc=*/0,
1346*993b0882SAndroid Build Coastguard Worker                /*reference_timezone=*/"Europe/Zurich",
1347*993b0882SAndroid Build Coastguard Worker                /*annotations=*/{},
1348*993b0882SAndroid Build Coastguard Worker                /*locales=*/"en"}}});
1349*993b0882SAndroid Build Coastguard Worker     EXPECT_THAT(response.actions, testing::IsEmpty());
1350*993b0882SAndroid Build Coastguard Worker   }
1351*993b0882SAndroid Build Coastguard Worker   {
1352*993b0882SAndroid Build Coastguard Worker     const ActionsSuggestionsResponse response =
1353*993b0882SAndroid Build Coastguard Worker         actions_suggestions->SuggestActions(
1354*993b0882SAndroid Build Coastguard Worker             {{{/*user_id=*/1, "where are you",
1355*993b0882SAndroid Build Coastguard Worker                /*reference_time_ms_utc=*/0,
1356*993b0882SAndroid Build Coastguard Worker                /*reference_timezone=*/"Europe/Zurich",
1357*993b0882SAndroid Build Coastguard Worker                /*annotations=*/{},
1358*993b0882SAndroid Build Coastguard Worker                /*locales=*/"en"}}});
1359*993b0882SAndroid Build Coastguard Worker     EXPECT_THAT(
1360*993b0882SAndroid Build Coastguard Worker         response.actions,
1361*993b0882SAndroid Build Coastguard Worker         ElementsAre(testing::Field(&ActionSuggestion::type, "share_location")));
1362*993b0882SAndroid Build Coastguard Worker   }
1363*993b0882SAndroid Build Coastguard Worker   {
1364*993b0882SAndroid Build Coastguard Worker     const ActionsSuggestionsResponse response =
1365*993b0882SAndroid Build Coastguard Worker         actions_suggestions->SuggestActions(
1366*993b0882SAndroid Build Coastguard Worker             {{{/*user_id=*/1, "do you know johns number",
1367*993b0882SAndroid Build Coastguard Worker                /*reference_time_ms_utc=*/0,
1368*993b0882SAndroid Build Coastguard Worker                /*reference_timezone=*/"Europe/Zurich",
1369*993b0882SAndroid Build Coastguard Worker                /*annotations=*/{},
1370*993b0882SAndroid Build Coastguard Worker                /*locales=*/"en"}}});
1371*993b0882SAndroid Build Coastguard Worker     EXPECT_THAT(
1372*993b0882SAndroid Build Coastguard Worker         response.actions,
1373*993b0882SAndroid Build Coastguard Worker         ElementsAre(testing::Field(&ActionSuggestion::type, "share_contact")));
1374*993b0882SAndroid Build Coastguard Worker   }
1375*993b0882SAndroid Build Coastguard Worker }
1376*993b0882SAndroid Build Coastguard Worker 
1377*993b0882SAndroid Build Coastguard Worker // Test class to expose token embedding methods for testing.
1378*993b0882SAndroid Build Coastguard Worker class TestingMessageEmbedder : private ActionsSuggestions {
1379*993b0882SAndroid Build Coastguard Worker  public:
1380*993b0882SAndroid Build Coastguard Worker   explicit TestingMessageEmbedder(const ActionsModel* model);
1381*993b0882SAndroid Build Coastguard Worker 
1382*993b0882SAndroid Build Coastguard Worker   using ActionsSuggestions::EmbedAndFlattenTokens;
1383*993b0882SAndroid Build Coastguard Worker   using ActionsSuggestions::EmbedTokensPerMessage;
1384*993b0882SAndroid Build Coastguard Worker 
1385*993b0882SAndroid Build Coastguard Worker  protected:
1386*993b0882SAndroid Build Coastguard Worker   // EmbeddingExecutor that always returns features based on
1387*993b0882SAndroid Build Coastguard Worker   // the id of the sparse features.
1388*993b0882SAndroid Build Coastguard Worker   class FakeEmbeddingExecutor : public EmbeddingExecutor {
1389*993b0882SAndroid Build Coastguard Worker    public:
AddEmbedding(const TensorView<int> & sparse_features,float * dest,const int dest_size) const1390*993b0882SAndroid Build Coastguard Worker     bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
1391*993b0882SAndroid Build Coastguard Worker                       const int dest_size) const override {
1392*993b0882SAndroid Build Coastguard Worker       TC3_CHECK_GE(dest_size, 1);
1393*993b0882SAndroid Build Coastguard Worker       EXPECT_EQ(sparse_features.size(), 1);
1394*993b0882SAndroid Build Coastguard Worker       dest[0] = sparse_features.data()[0];
1395*993b0882SAndroid Build Coastguard Worker       return true;
1396*993b0882SAndroid Build Coastguard Worker     }
1397*993b0882SAndroid Build Coastguard Worker   };
1398*993b0882SAndroid Build Coastguard Worker 
1399*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<UniLib> unilib_;
1400*993b0882SAndroid Build Coastguard Worker };
1401*993b0882SAndroid Build Coastguard Worker 
TestingMessageEmbedder(const ActionsModel * model)1402*993b0882SAndroid Build Coastguard Worker TestingMessageEmbedder::TestingMessageEmbedder(const ActionsModel* model)
1403*993b0882SAndroid Build Coastguard Worker     : unilib_(CreateUniLibForTesting()) {
1404*993b0882SAndroid Build Coastguard Worker   model_ = model;
1405*993b0882SAndroid Build Coastguard Worker   const ActionsTokenFeatureProcessorOptions* options =
1406*993b0882SAndroid Build Coastguard Worker       model->feature_processor_options();
1407*993b0882SAndroid Build Coastguard Worker   feature_processor_.reset(new ActionsFeatureProcessor(options, unilib_.get()));
1408*993b0882SAndroid Build Coastguard Worker   embedding_executor_.reset(new FakeEmbeddingExecutor());
1409*993b0882SAndroid Build Coastguard Worker   EXPECT_TRUE(
1410*993b0882SAndroid Build Coastguard Worker       EmbedTokenId(options->padding_token_id(), &embedded_padding_token_));
1411*993b0882SAndroid Build Coastguard Worker   EXPECT_TRUE(EmbedTokenId(options->start_token_id(), &embedded_start_token_));
1412*993b0882SAndroid Build Coastguard Worker   EXPECT_TRUE(EmbedTokenId(options->end_token_id(), &embedded_end_token_));
1413*993b0882SAndroid Build Coastguard Worker   token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
1414*993b0882SAndroid Build Coastguard Worker   EXPECT_EQ(token_embedding_size_, 1);
1415*993b0882SAndroid Build Coastguard Worker }
1416*993b0882SAndroid Build Coastguard Worker 
1417*993b0882SAndroid Build Coastguard Worker class EmbeddingTest : public testing::Test {
1418*993b0882SAndroid Build Coastguard Worker  protected:
EmbeddingTest()1419*993b0882SAndroid Build Coastguard Worker   explicit EmbeddingTest() {
1420*993b0882SAndroid Build Coastguard Worker     model_.feature_processor_options.reset(
1421*993b0882SAndroid Build Coastguard Worker         new ActionsTokenFeatureProcessorOptionsT);
1422*993b0882SAndroid Build Coastguard Worker     options_ = model_.feature_processor_options.get();
1423*993b0882SAndroid Build Coastguard Worker     options_->chargram_orders = {1};
1424*993b0882SAndroid Build Coastguard Worker     options_->num_buckets = 1000;
1425*993b0882SAndroid Build Coastguard Worker     options_->embedding_size = 1;
1426*993b0882SAndroid Build Coastguard Worker     options_->start_token_id = 0;
1427*993b0882SAndroid Build Coastguard Worker     options_->end_token_id = 1;
1428*993b0882SAndroid Build Coastguard Worker     options_->padding_token_id = 2;
1429*993b0882SAndroid Build Coastguard Worker     options_->tokenizer_options.reset(new ActionsTokenizerOptionsT);
1430*993b0882SAndroid Build Coastguard Worker   }
1431*993b0882SAndroid Build Coastguard Worker 
CreateTestingMessageEmbedder()1432*993b0882SAndroid Build Coastguard Worker   TestingMessageEmbedder CreateTestingMessageEmbedder() {
1433*993b0882SAndroid Build Coastguard Worker     flatbuffers::FlatBufferBuilder builder;
1434*993b0882SAndroid Build Coastguard Worker     FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, &model_));
1435*993b0882SAndroid Build Coastguard Worker     buffer_ = builder.Release();
1436*993b0882SAndroid Build Coastguard Worker     return TestingMessageEmbedder(
1437*993b0882SAndroid Build Coastguard Worker         flatbuffers::GetRoot<ActionsModel>(buffer_.data()));
1438*993b0882SAndroid Build Coastguard Worker   }
1439*993b0882SAndroid Build Coastguard Worker 
1440*993b0882SAndroid Build Coastguard Worker   flatbuffers::DetachedBuffer buffer_;
1441*993b0882SAndroid Build Coastguard Worker   ActionsModelT model_;
1442*993b0882SAndroid Build Coastguard Worker   ActionsTokenFeatureProcessorOptionsT* options_;
1443*993b0882SAndroid Build Coastguard Worker };
1444*993b0882SAndroid Build Coastguard Worker 
// Without any per-message token bounds configured, each token of the single
// message is embedded in order and nothing is padded or dropped.
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  // Expected output for a one-character token: its hash bucket.
  const auto bucket_of = [this](const char* token_text) {
    return static_cast<float>(tc3farmhash::Fingerprint64(token_text, 1) %
                              options_->num_buckets);
  };

  EXPECT_EQ(max_num_tokens_per_message, 3);
  // Fatal on mismatch: the checks below index into `embeddings`.
  ASSERT_EQ(embeddings.size(), 3);
  EXPECT_THAT(embeddings[0], FloatEq(bucket_of("a")));
  EXPECT_THAT(embeddings[1], FloatEq(bucket_of("b")));
  EXPECT_THAT(embeddings[2], FloatEq(bucket_of("c")));
}
1464*993b0882SAndroid Build Coastguard Worker 
// With min_num_tokens_per_message = 5, a three-token message is followed by
// two padding tokens.
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithPadding) {
  options_->min_num_tokens_per_message = 5;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  // Expected output for a one-character token: its hash bucket.
  const auto bucket_of = [this](const char* token_text) {
    return static_cast<float>(tc3farmhash::Fingerprint64(token_text, 1) %
                              options_->num_buckets);
  };

  EXPECT_EQ(max_num_tokens_per_message, 5);
  // Fatal on mismatch: the checks below index into `embeddings`.
  ASSERT_EQ(embeddings.size(), 5);
  EXPECT_THAT(embeddings[0], FloatEq(bucket_of("a")));
  EXPECT_THAT(embeddings[1], FloatEq(bucket_of("b")));
  EXPECT_THAT(embeddings[2], FloatEq(bucket_of("c")));
  EXPECT_THAT(embeddings[3], FloatEq(options_->padding_token_id));
  EXPECT_THAT(embeddings[4], FloatEq(options_->padding_token_id));
}
1487*993b0882SAndroid Build Coastguard Worker 
// With max_num_tokens_per_message = 2, only the last two tokens of a
// three-token message are embedded; the first one is dropped.
TEST_F(EmbeddingTest, EmbedsTokensPerMessageDropsAtBeginning) {
  options_->max_num_tokens_per_message = 2;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  // Expected output for a one-character token: its hash bucket.
  const auto bucket_of = [this](const char* token_text) {
    return static_cast<float>(tc3farmhash::Fingerprint64(token_text, 1) %
                              options_->num_buckets);
  };

  EXPECT_EQ(max_num_tokens_per_message, 2);
  // Fatal on mismatch: the checks below index into `embeddings`.
  ASSERT_EQ(embeddings.size(), 2);
  EXPECT_THAT(embeddings[0], FloatEq(bucket_of("b")));
  EXPECT_THAT(embeddings[1], FloatEq(bucket_of("c")));
}
1506*993b0882SAndroid Build Coastguard Worker 
// Two messages of three and two tokens: the reported per-message token count
// is that of the longer message, and the shorter message is padded up to it.
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithMultipleMessagesNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3)}};
  std::vector<float> embeddings;
  int max_num_tokens_per_message = 0;

  EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
                                             &max_num_tokens_per_message));

  // Expected output for a one-character token: its hash bucket.
  const auto bucket_of = [this](const char* token_text) {
    return static_cast<float>(tc3farmhash::Fingerprint64(token_text, 1) %
                              options_->num_buckets);
  };

  EXPECT_EQ(max_num_tokens_per_message, 3);
  // Two messages, three slots each. Fatal on mismatch: the checks below
  // index into `embeddings`.
  ASSERT_EQ(embeddings.size(), 6);
  EXPECT_THAT(embeddings[0], FloatEq(bucket_of("a")));
  EXPECT_THAT(embeddings[1], FloatEq(bucket_of("b")));
  EXPECT_THAT(embeddings[2], FloatEq(bucket_of("c")));
  EXPECT_THAT(embeddings[3], FloatEq(bucket_of("d")));
  EXPECT_THAT(embeddings[4], FloatEq(bucket_of("e")));
  EXPECT_THAT(embeddings[5], FloatEq(options_->padding_token_id));
}
1531*993b0882SAndroid Build Coastguard Worker 
// Flattened embedding of a single message without total-token bounds: the
// message's tokens are wrapped in a start and an end marker token.
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  EXPECT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  // Expected output for a one-character token: its hash bucket.
  const auto bucket_of = [this](const char* token_text) {
    return static_cast<float>(tc3farmhash::Fingerprint64(token_text, 1) %
                              options_->num_buckets);
  };

  EXPECT_EQ(total_token_count, 5);
  // Fatal on mismatch: the checks below index into `embeddings`.
  ASSERT_EQ(embeddings.size(), 5);
  EXPECT_THAT(embeddings[0], FloatEq(options_->start_token_id));
  EXPECT_THAT(embeddings[1], FloatEq(bucket_of("a")));
  EXPECT_THAT(embeddings[2], FloatEq(bucket_of("b")));
  EXPECT_THAT(embeddings[3], FloatEq(bucket_of("c")));
  EXPECT_THAT(embeddings[4], FloatEq(options_->end_token_id));
}
1553*993b0882SAndroid Build Coastguard Worker 
// With min_num_total_tokens = 7, the five flattened tokens (start marker,
// three tokens, end marker) are followed by two padding tokens.
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithPadding) {
  options_->min_num_total_tokens = 7;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  EXPECT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  // Expected output for a one-character token: its hash bucket.
  const auto bucket_of = [this](const char* token_text) {
    return static_cast<float>(tc3farmhash::Fingerprint64(token_text, 1) %
                              options_->num_buckets);
  };

  EXPECT_EQ(total_token_count, 7);
  // Fatal on mismatch: the checks below index into `embeddings`.
  ASSERT_EQ(embeddings.size(), 7);
  EXPECT_THAT(embeddings[0], FloatEq(options_->start_token_id));
  EXPECT_THAT(embeddings[1], FloatEq(bucket_of("a")));
  EXPECT_THAT(embeddings[2], FloatEq(bucket_of("b")));
  EXPECT_THAT(embeddings[3], FloatEq(bucket_of("c")));
  EXPECT_THAT(embeddings[4], FloatEq(options_->end_token_id));
  EXPECT_THAT(embeddings[5], FloatEq(options_->padding_token_id));
  EXPECT_THAT(embeddings[6], FloatEq(options_->padding_token_id));
}
1578*993b0882SAndroid Build Coastguard Worker 
// With max_num_total_tokens = 3, only the last three flattened tokens are
// kept: the start marker and the first token are dropped.
TEST_F(EmbeddingTest, EmbedsFlattenedTokensDropsAtBeginning) {
  options_->max_num_total_tokens = 3;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  EXPECT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  // Expected output for a one-character token: its hash bucket.
  const auto bucket_of = [this](const char* token_text) {
    return static_cast<float>(tc3farmhash::Fingerprint64(token_text, 1) %
                              options_->num_buckets);
  };

  EXPECT_EQ(total_token_count, 3);
  // Fatal on mismatch: the checks below index into `embeddings`.
  ASSERT_EQ(embeddings.size(), 3);
  EXPECT_THAT(embeddings[0], FloatEq(bucket_of("b")));
  EXPECT_THAT(embeddings[1], FloatEq(bucket_of("c")));
  EXPECT_THAT(embeddings[2], FloatEq(options_->end_token_id));
}
1598*993b0882SAndroid Build Coastguard Worker 
// Flattened embedding of two messages without total-token bounds: each
// message is wrapped in its own start/end markers and the messages are
// concatenated in order.
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithMultipleMessagesNoBounds) {
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  EXPECT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  // Expected output for a one-character token: its hash bucket.
  const auto bucket_of = [this](const char* token_text) {
    return static_cast<float>(tc3farmhash::Fingerprint64(token_text, 1) %
                              options_->num_buckets);
  };

  EXPECT_EQ(total_token_count, 9);
  // Fatal on mismatch: the checks below index into `embeddings`.
  ASSERT_EQ(embeddings.size(), 9);
  // First message.
  EXPECT_THAT(embeddings[0], FloatEq(options_->start_token_id));
  EXPECT_THAT(embeddings[1], FloatEq(bucket_of("a")));
  EXPECT_THAT(embeddings[2], FloatEq(bucket_of("b")));
  EXPECT_THAT(embeddings[3], FloatEq(bucket_of("c")));
  EXPECT_THAT(embeddings[4], FloatEq(options_->end_token_id));
  // Second message.
  EXPECT_THAT(embeddings[5], FloatEq(options_->start_token_id));
  EXPECT_THAT(embeddings[6], FloatEq(bucket_of("d")));
  EXPECT_THAT(embeddings[7], FloatEq(bucket_of("e")));
  EXPECT_THAT(embeddings[8], FloatEq(options_->end_token_id));
}
1627*993b0882SAndroid Build Coastguard Worker 
// With max_num_total_tokens = 7 and two three-token messages, only the last
// seven flattened tokens are kept, so the first message loses its start
// marker and its first two tokens.
TEST_F(EmbeddingTest,
       EmbedsFlattenedTokensWithMultipleMessagesDropsAtBeginning) {
  options_->max_num_total_tokens = 7;
  const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
  std::vector<std::vector<Token>> tokens = {
      {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
      {Token("d", 0, 1), Token("e", 2, 3), Token("f", 4, 5)}};
  std::vector<float> embeddings;
  int total_token_count = 0;

  EXPECT_TRUE(
      embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));

  // Expected output for a one-character token: its hash bucket.
  const auto bucket_of = [this](const char* token_text) {
    return static_cast<float>(tc3farmhash::Fingerprint64(token_text, 1) %
                              options_->num_buckets);
  };

  EXPECT_EQ(total_token_count, 7);
  // Fatal on mismatch: the checks below index into `embeddings`.
  ASSERT_EQ(embeddings.size(), 7);
  // Remainder of the first message.
  EXPECT_THAT(embeddings[0], FloatEq(bucket_of("c")));
  EXPECT_THAT(embeddings[1], FloatEq(options_->end_token_id));
  // Second message, complete.
  EXPECT_THAT(embeddings[2], FloatEq(options_->start_token_id));
  EXPECT_THAT(embeddings[3], FloatEq(bucket_of("d")));
  EXPECT_THAT(embeddings[4], FloatEq(bucket_of("e")));
  EXPECT_THAT(embeddings[5], FloatEq(bucket_of("f")));
  EXPECT_THAT(embeddings[6], FloatEq(options_->end_token_id));
}
1655*993b0882SAndroid Build Coastguard Worker 
// Runs the multi-task test model with default options on a one-message
// conversation and checks how many actions come back.
TEST_F(ActionsSuggestionsTest, MultiTaskSuggestActionsDefault) {
  const std::unique_ptr<ActionsSuggestions> suggester =
      LoadMultiTaskTestModel();
  const ActionsSuggestionsResponse response = suggester->SuggestActions(
      {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}});
  // 8 binary classification + 3 smart replies.
  EXPECT_EQ(response.actions.size(), 11);
}
1667*993b0882SAndroid Build Coastguard Worker 
// Threshold value the tests below use to switch a classification head off.
// NOTE(review): presumably chosen to be above any confidence the model can
// output -- confirm against the model.
constexpr float kDisableThresholdVal = 2.0f;

// Names of the per-head threshold model parameters.
constexpr char kSpamThreshold[] = "spam_confidence_threshold";
constexpr char kLocationThreshold[] = "location_confidence_threshold";
constexpr char kPhoneThreshold[] = "phone_confidence_threshold";
constexpr char kWeatherThreshold[] = "weather_confidence_threshold";
constexpr char kRestaurantsThreshold[] = "restaurants_confidence_threshold";
constexpr char kMoviesThreshold[] = "movies_confidence_threshold";
constexpr char kTtrThreshold[] = "time_to_reply_binary_threshold";
constexpr char kReminderThreshold[] = "reminder_intent_confidence_threshold";

// Names of further model parameters that tests can override.
constexpr char kDiversificationParm[] = "diversification_distance_threshold";
constexpr char kEmpiricalProbFactor[] = "empirical_probability_factor";
1680*993b0882SAndroid Build Coastguard Worker 
GetOptionsToDisableAllClassification()1681*993b0882SAndroid Build Coastguard Worker ActionSuggestionOptions GetOptionsToDisableAllClassification() {
1682*993b0882SAndroid Build Coastguard Worker   ActionSuggestionOptions options;
1683*993b0882SAndroid Build Coastguard Worker   // Disable all classification heads.
1684*993b0882SAndroid Build Coastguard Worker   options.model_parameters.insert(
1685*993b0882SAndroid Build Coastguard Worker       {kSpamThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1686*993b0882SAndroid Build Coastguard Worker   options.model_parameters.insert(
1687*993b0882SAndroid Build Coastguard Worker       {kLocationThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1688*993b0882SAndroid Build Coastguard Worker   options.model_parameters.insert(
1689*993b0882SAndroid Build Coastguard Worker       {kPhoneThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1690*993b0882SAndroid Build Coastguard Worker   options.model_parameters.insert(
1691*993b0882SAndroid Build Coastguard Worker       {kWeatherThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1692*993b0882SAndroid Build Coastguard Worker   options.model_parameters.insert(
1693*993b0882SAndroid Build Coastguard Worker       {kRestaurantsThreshold,
1694*993b0882SAndroid Build Coastguard Worker        libtextclassifier3::Variant(kDisableThresholdVal)});
1695*993b0882SAndroid Build Coastguard Worker   options.model_parameters.insert(
1696*993b0882SAndroid Build Coastguard Worker       {kMoviesThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1697*993b0882SAndroid Build Coastguard Worker   options.model_parameters.insert(
1698*993b0882SAndroid Build Coastguard Worker       {kTtrThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1699*993b0882SAndroid Build Coastguard Worker   options.model_parameters.insert(
1700*993b0882SAndroid Build Coastguard Worker       {kReminderThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
1701*993b0882SAndroid Build Coastguard Worker   return options;
1702*993b0882SAndroid Build Coastguard Worker }
1703*993b0882SAndroid Build Coastguard Worker 
// With every classification head disabled, the multi-task model is expected
// to return exactly the three smart replies.
TEST_F(ActionsSuggestionsTest, MultiTaskSuggestActionsSmartReplyOnly) {
  const std::unique_ptr<ActionsSuggestions> suggester =
      LoadMultiTaskTestModel();
  const ActionSuggestionOptions classification_off =
      GetOptionsToDisableAllClassification();
  const ActionsSuggestionsResponse response = suggester->SuggestActions(
      {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}},
      /*annotator=*/nullptr, classification_off);
  EXPECT_THAT(response.actions,
              ElementsAre(IsSmartReply("Here"), IsSmartReply("I'm here"),
                          IsSmartReply("I'm home")));
  // 3 smart replies.
  EXPECT_EQ(response.actions.size(), 3);
}
1720*993b0882SAndroid Build Coastguard Worker 
// Number of entries in each user-profile vector built by the tests.
constexpr int kUserProfileSize = 1000;
// Names of the model parameters that carry the user profile.
constexpr char kUserProfileTokenIndex[] = "user_profile_token_index";
constexpr char kUserProfileTokenWeight[] = "user_profile_token_weight";
1724*993b0882SAndroid Build Coastguard Worker 
GetOptionsForSmartReplyP13nModel()1725*993b0882SAndroid Build Coastguard Worker ActionSuggestionOptions GetOptionsForSmartReplyP13nModel() {
1726*993b0882SAndroid Build Coastguard Worker   ActionSuggestionOptions options;
1727*993b0882SAndroid Build Coastguard Worker   const std::vector<int> user_profile_token_indexes(kUserProfileSize, 1);
1728*993b0882SAndroid Build Coastguard Worker   const std::vector<float> user_profile_token_weights(kUserProfileSize, 0.1f);
1729*993b0882SAndroid Build Coastguard Worker   options.model_parameters.insert(
1730*993b0882SAndroid Build Coastguard Worker       {kUserProfileTokenIndex,
1731*993b0882SAndroid Build Coastguard Worker        libtextclassifier3::Variant(user_profile_token_indexes)});
1732*993b0882SAndroid Build Coastguard Worker   options.model_parameters.insert(
1733*993b0882SAndroid Build Coastguard Worker       {kUserProfileTokenWeight,
1734*993b0882SAndroid Build Coastguard Worker        libtextclassifier3::Variant(user_profile_token_weights)});
1735*993b0882SAndroid Build Coastguard Worker   return options;
1736*993b0882SAndroid Build Coastguard Worker }
1737*993b0882SAndroid Build Coastguard Worker 
// Runs the personalized smart-reply test model with a synthetic user profile
// and checks that three actions are returned.
TEST_F(ActionsSuggestionsTest, MultiTaskSuggestActionsSmartReplyP13n) {
  const std::unique_ptr<ActionsSuggestions> suggester =
      LoadMultiTaskSrP13nTestModel();
  const ActionSuggestionOptions p13n_options =
      GetOptionsForSmartReplyP13nModel();
  const ActionsSuggestionsResponse response = suggester->SuggestActions(
      {{{/*user_id=*/1, "How are you?", /*reference_time_ms_utc=*/0,
         /*reference_timezone=*/"Europe/Zurich",
         /*annotations=*/{}, /*locales=*/"en"}}},
      /*annotator=*/nullptr, p13n_options);
  // 3 smart replies.
  EXPECT_EQ(response.actions.size(), 3);
}
1750*993b0882SAndroid Build Coastguard Worker 
// Runs the multi-task test model with all classification disabled, a location
// threshold of 0.35 and a diversification parameter of 0.5, then checks both
// the order and the number of the returned actions for "Where are you?".
TEST_F(ActionsSuggestionsTest,
       MultiTaskSuggestActionsDiversifiedSmartReplyAndLocation) {
  const std::unique_ptr<ActionsSuggestions> model = LoadMultiTaskTestModel();

  ActionSuggestionOptions test_options = GetOptionsToDisableAllClassification();
  test_options.model_parameters[kLocationThreshold] =
      libtextclassifier3::Variant(0.35f);
  test_options.model_parameters.insert(
      {kDiversificationParm, libtextclassifier3::Variant(0.5f)});

  const Conversation conversation = {
      {{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
        /*reference_timezone=*/"Europe/Zurich",
        /*annotations=*/{}, /*locales=*/"en"}}};
  const ActionsSuggestionsResponse result = model->SuggestActions(
      conversation, /*annotator=*/nullptr, test_options);

  EXPECT_THAT(result.actions,
              ElementsAre(IsActionOfType("LOCATION_SHARE"),
                          IsSmartReply("Here"), IsSmartReply("Yes"),
                          IsSmartReply("��")));
  EXPECT_EQ(result.actions.size(), 4 /*1 location share + 3 smart replies*/);
}
1772*993b0882SAndroid Build Coastguard Worker 
// Runs the multi-task test model with all classification disabled, a location
// threshold of 0.35, a reminder threshold of 0 and an empirical probability
// factor of 2.0, then checks the order and number of the returned actions for
// "Where are you?".
TEST_F(ActionsSuggestionsTest,
       MultiTaskSuggestActionsEmProBoostedSmartReplyAndLocationAndReminder) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadMultiTaskTestModel();
  ActionSuggestionOptions options = GetOptionsToDisableAllClassification();
  options.model_parameters[kLocationThreshold] =
      libtextclassifier3::Variant(0.35f);
  // The reminder head always triggers because its threshold is zero.
  options.model_parameters[kReminderThreshold] =
      libtextclassifier3::Variant(0.0f);
  options.model_parameters.insert(
      {kEmpiricalProbFactor, libtextclassifier3::Variant(2.0f)});
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{}, /*locales=*/"en"}}},
          /*annotator=*/nullptr, options);
  EXPECT_THAT(
      response.actions,
      ElementsAre(IsSmartReply("Okay"), IsActionOfType("LOCATION_SHARE"),
                  IsSmartReply("Yes"),
                  /*Different emoji than previous test*/ IsSmartReply("��"),
                  IsActionOfType("REMINDER_INTENT")));
  // The count includes the reminder intent in addition to the location share
  // and the three smart replies matched above.
  EXPECT_EQ(response.actions.size(),
            5 /*1 location share + 3 smart replies + 1 reminder*/);
}
1799*993b0882SAndroid Build Coastguard Worker 
// Runs the multi-task smart-reply emoji model on "hello?" and checks the total
// number of actions plus the text and type of the first three.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromMultiTaskSrEmojiModel) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kMultiTaskSrEmojiModelFileName);

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "hello?",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{},
             /*locales=*/"en"}}});
  // Fatal assertion: the element accesses below would read out of bounds if
  // fewer actions than expected were returned.
  ASSERT_EQ(response.actions.size(), 5);
  EXPECT_EQ(response.actions[0].response_text, "��");
  EXPECT_EQ(response.actions[0].type, "text_reply");
  EXPECT_EQ(response.actions[1].response_text, "��");
  EXPECT_EQ(response.actions[1].type, "text_reply");
  EXPECT_EQ(response.actions[2].response_text, "Yes");
  EXPECT_EQ(response.actions[2].type, "text_reply");
}
1819*993b0882SAndroid Build Coastguard Worker 
// Runs the multi-task smart-reply emoji model on "a pleasure chatting" and
// checks that exactly three text replies come back with the expected texts.
TEST_F(ActionsSuggestionsTest, MultiTaskSrEmojiModelRemovesTextHeadEmoji) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kMultiTaskSrEmojiModelFileName);

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "a pleasure chatting",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{},
             /*locales=*/"en"}}});
  // Fatal assertion: the element accesses below would read out of bounds if
  // fewer actions than expected were returned.
  ASSERT_EQ(response.actions.size(), 3);
  EXPECT_EQ(response.actions[0].response_text, "��");
  EXPECT_EQ(response.actions[0].type, "text_reply");
  EXPECT_EQ(response.actions[1].response_text, "��");
  EXPECT_EQ(response.actions[1].type, "text_reply");
  EXPECT_EQ(response.actions[2].response_text, "Okay");
  EXPECT_EQ(response.actions[2].type, "text_reply");
}
1839*993b0882SAndroid Build Coastguard Worker 
// Runs the multi-task smart-reply emoji concept model on "i am tired" and
// checks that the first action is an emoji reply whose text is one of the
// accepted emojis listed in `sigh_emojis`.
TEST_F(ActionsSuggestionsTest, MultiTaskSrEmojiModelUsesConcepts) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kMultiTaskSrEmojiConceptModelFileName);

  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "i am tired",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{},
             /*locales=*/"en"}}});
  const std::vector<std::string> sigh_emojis = {"��", "��"};

  // Fatal assertion: accessing actions[0] on an empty result would be
  // undefined behavior, so stop here instead of crashing.
  ASSERT_FALSE(response.actions.empty());
  EXPECT_TRUE(std::find(sigh_emojis.begin(), sigh_emojis.end(),
                        response.actions[0].response_text) !=
              sigh_emojis.end());
  EXPECT_EQ(response.actions[0].type, "emoji_reply");
}
1858*993b0882SAndroid Build Coastguard Worker 
// Runs the live-relay TFLite model on "Hi" and checks that three actions are
// returned, with the expected text and type for the first two.
TEST_F(ActionsSuggestionsTest, LiveRelayModel) {
  std::unique_ptr<ActionsSuggestions> actions_suggestions =
      LoadTestModel(kLiveRelayTFLiteModelFileName);
  const ActionsSuggestionsResponse response =
      actions_suggestions->SuggestActions(
          {{{/*user_id=*/1, "Hi",
             /*reference_time_ms_utc=*/0,
             /*reference_timezone=*/"Europe/Zurich",
             /*annotations=*/{},
             /*locales=*/"en"}}});
  // Fatal assertion: the element accesses below would read out of bounds if
  // fewer actions than expected were returned.
  ASSERT_EQ(response.actions.size(), 3);
  EXPECT_EQ(response.actions[0].response_text, "Hi how are you doing");
  EXPECT_EQ(response.actions[0].type, "text_reply");
  EXPECT_EQ(response.actions[1].response_text, "Hi whats up");
  EXPECT_EQ(response.actions[1].type, "text_reply");
}
1875*993b0882SAndroid Build Coastguard Worker 
// Runs the sensitive-topic TFLite model on a sensitive message and checks
// that no actions are returned, the response is flagged as sensitive, and it
// is not reported as filtered for low confidence.
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromSensitiveTfLiteModel) {
  const std::unique_ptr<ActionsSuggestions> model =
      LoadTestModel(kSensitiveTFliteModelFileName);
  const Conversation conversation = {{{/*user_id=*/1, "I want to kill myself",
                                       /*reference_time_ms_utc=*/0,
                                       /*reference_timezone=*/"Europe/Zurich",
                                       /*annotations=*/{},
                                       /*locales=*/"en"}}};

  const ActionsSuggestionsResponse result = model->SuggestActions(conversation);

  EXPECT_EQ(result.actions.size(), 0);
  EXPECT_TRUE(result.is_sensitive);
  EXPECT_FALSE(result.output_filtered_low_confidence);
}
1890*993b0882SAndroid Build Coastguard Worker 
1891*993b0882SAndroid Build Coastguard Worker }  // namespace
1892*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
1893