xref: /aosp_15_r20/external/libtextclassifier/native/actions/actions-suggestions.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #include "actions/actions-suggestions.h"
18*993b0882SAndroid Build Coastguard Worker 
19*993b0882SAndroid Build Coastguard Worker #include <memory>
20*993b0882SAndroid Build Coastguard Worker #include <string>
21*993b0882SAndroid Build Coastguard Worker #include <vector>
22*993b0882SAndroid Build Coastguard Worker 
23*993b0882SAndroid Build Coastguard Worker #include "utils/base/statusor.h"
24*993b0882SAndroid Build Coastguard Worker #include "absl/container/flat_hash_map.h"
25*993b0882SAndroid Build Coastguard Worker #include "absl/random/random.h"
26*993b0882SAndroid Build Coastguard Worker 
27*993b0882SAndroid Build Coastguard Worker #if !defined(TC3_DISABLE_LUA)
28*993b0882SAndroid Build Coastguard Worker #include "actions/lua-actions.h"
29*993b0882SAndroid Build Coastguard Worker #endif
30*993b0882SAndroid Build Coastguard Worker #include "actions/ngram-model.h"
31*993b0882SAndroid Build Coastguard Worker #include "actions/tflite-sensitive-model.h"
32*993b0882SAndroid Build Coastguard Worker #include "actions/types.h"
33*993b0882SAndroid Build Coastguard Worker #include "actions/utils.h"
34*993b0882SAndroid Build Coastguard Worker #include "actions/zlib-utils.h"
35*993b0882SAndroid Build Coastguard Worker #include "annotator/collections.h"
36*993b0882SAndroid Build Coastguard Worker #include "utils/base/logging.h"
37*993b0882SAndroid Build Coastguard Worker #if !defined(TC3_DISABLE_LUA)
38*993b0882SAndroid Build Coastguard Worker #include "utils/lua-utils.h"
39*993b0882SAndroid Build Coastguard Worker #endif
40*993b0882SAndroid Build Coastguard Worker #include "utils/normalization.h"
41*993b0882SAndroid Build Coastguard Worker #include "utils/optional.h"
42*993b0882SAndroid Build Coastguard Worker #include "utils/strings/split.h"
43*993b0882SAndroid Build Coastguard Worker #include "utils/strings/stringpiece.h"
44*993b0882SAndroid Build Coastguard Worker #include "utils/strings/utf8.h"
45*993b0882SAndroid Build Coastguard Worker #include "utils/utf8/unicodetext.h"
46*993b0882SAndroid Build Coastguard Worker #include "absl/container/flat_hash_set.h"
47*993b0882SAndroid Build Coastguard Worker #include "absl/random/distributions.h"
48*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/string_util.h"
49*993b0882SAndroid Build Coastguard Worker 
50*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
51*993b0882SAndroid Build Coastguard Worker 
// Default values for float, bool and int settings. Their uses are outside
// this excerpt.
constexpr float kDefaultFloat{0.0f};
constexpr bool kDefaultBool{false};
constexpr int kDefaultInt{1};
55*993b0882SAndroid Build Coastguard Worker 
56*993b0882SAndroid Build Coastguard Worker namespace {
57*993b0882SAndroid Build Coastguard Worker 
LoadAndVerifyModel(const uint8_t * addr,int size)58*993b0882SAndroid Build Coastguard Worker const ActionsModel* LoadAndVerifyModel(const uint8_t* addr, int size) {
59*993b0882SAndroid Build Coastguard Worker   flatbuffers::Verifier verifier(addr, size);
60*993b0882SAndroid Build Coastguard Worker   if (VerifyActionsModelBuffer(verifier)) {
61*993b0882SAndroid Build Coastguard Worker     return GetActionsModel(addr);
62*993b0882SAndroid Build Coastguard Worker   } else {
63*993b0882SAndroid Build Coastguard Worker     return nullptr;
64*993b0882SAndroid Build Coastguard Worker   }
65*993b0882SAndroid Build Coastguard Worker }
66*993b0882SAndroid Build Coastguard Worker 
67*993b0882SAndroid Build Coastguard Worker template <typename T>
ValueOrDefault(const flatbuffers::Table * values,const int32 field_offset,const T default_value)68*993b0882SAndroid Build Coastguard Worker T ValueOrDefault(const flatbuffers::Table* values, const int32 field_offset,
69*993b0882SAndroid Build Coastguard Worker                  const T default_value) {
70*993b0882SAndroid Build Coastguard Worker   if (values == nullptr) {
71*993b0882SAndroid Build Coastguard Worker     return default_value;
72*993b0882SAndroid Build Coastguard Worker   }
73*993b0882SAndroid Build Coastguard Worker   return values->GetField<T>(field_offset, default_value);
74*993b0882SAndroid Build Coastguard Worker }
75*993b0882SAndroid Build Coastguard Worker 
76*993b0882SAndroid Build Coastguard Worker // Returns number of (tail) messages of a conversation to consider.
NumMessagesToConsider(const Conversation & conversation,const int max_conversation_history_length)77*993b0882SAndroid Build Coastguard Worker int NumMessagesToConsider(const Conversation& conversation,
78*993b0882SAndroid Build Coastguard Worker                           const int max_conversation_history_length) {
79*993b0882SAndroid Build Coastguard Worker   return ((max_conversation_history_length < 0 ||
80*993b0882SAndroid Build Coastguard Worker            conversation.messages.size() < max_conversation_history_length)
81*993b0882SAndroid Build Coastguard Worker               ? conversation.messages.size()
82*993b0882SAndroid Build Coastguard Worker               : max_conversation_history_length);
83*993b0882SAndroid Build Coastguard Worker }
84*993b0882SAndroid Build Coastguard Worker 
// Returns a copy of `inputs` resized to exactly `max_length` elements.
//
// Inputs longer than `max_length` are cut down to their first `max_length`
// entries; shorter inputs are extended at the end with copies of `pad_value`.
// A zero or negative `max_length` yields an empty vector.
template <typename T>
std::vector<T> PadOrTruncateToTargetLength(const std::vector<T>& inputs,
                                           const int max_length,
                                           const T pad_value) {
  using SizeType = typename std::vector<T>::size_type;
  // Clamp before converting to an unsigned size: a negative length would
  // otherwise wrap around to a huge value and be passed on to reserve().
  const SizeType target_length =
      max_length > 0 ? static_cast<SizeType>(max_length) : 0;

  if (inputs.size() >= target_length) {
    return std::vector<T>(inputs.begin(), inputs.begin() + target_length);
  }

  std::vector<T> result;
  result.reserve(target_length);
  result.assign(inputs.begin(), inputs.end());
  result.resize(target_length, pad_value);
  return result;
}
99*993b0882SAndroid Build Coastguard Worker 
100*993b0882SAndroid Build Coastguard Worker template <typename T>
SetVectorOrScalarAsModelInput(const int param_index,const Variant & param_value,tflite::Interpreter * interpreter,const std::unique_ptr<const TfLiteModelExecutor> & model_executor)101*993b0882SAndroid Build Coastguard Worker void SetVectorOrScalarAsModelInput(
102*993b0882SAndroid Build Coastguard Worker     const int param_index, const Variant& param_value,
103*993b0882SAndroid Build Coastguard Worker     tflite::Interpreter* interpreter,
104*993b0882SAndroid Build Coastguard Worker     const std::unique_ptr<const TfLiteModelExecutor>& model_executor) {
105*993b0882SAndroid Build Coastguard Worker   if (param_value.Has<std::vector<T>>()) {
106*993b0882SAndroid Build Coastguard Worker     model_executor->SetInput<T>(
107*993b0882SAndroid Build Coastguard Worker         param_index, param_value.ConstRefValue<std::vector<T>>(), interpreter);
108*993b0882SAndroid Build Coastguard Worker   } else if (param_value.Has<T>()) {
109*993b0882SAndroid Build Coastguard Worker     model_executor->SetInput<float>(param_index, param_value.Value<T>(),
110*993b0882SAndroid Build Coastguard Worker                                     interpreter);
111*993b0882SAndroid Build Coastguard Worker   } else {
112*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Variant type error!";
113*993b0882SAndroid Build Coastguard Worker   }
114*993b0882SAndroid Build Coastguard Worker }
115*993b0882SAndroid Build Coastguard Worker }  // namespace
116*993b0882SAndroid Build Coastguard Worker 
FromUnownedBuffer(const uint8_t * buffer,const int size,const UniLib * unilib,const std::string & triggering_preconditions_overlay)117*993b0882SAndroid Build Coastguard Worker std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromUnownedBuffer(
118*993b0882SAndroid Build Coastguard Worker     const uint8_t* buffer, const int size, const UniLib* unilib,
119*993b0882SAndroid Build Coastguard Worker     const std::string& triggering_preconditions_overlay) {
120*993b0882SAndroid Build Coastguard Worker   auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
121*993b0882SAndroid Build Coastguard Worker   const ActionsModel* model = LoadAndVerifyModel(buffer, size);
122*993b0882SAndroid Build Coastguard Worker   if (model == nullptr) {
123*993b0882SAndroid Build Coastguard Worker     return nullptr;
124*993b0882SAndroid Build Coastguard Worker   }
125*993b0882SAndroid Build Coastguard Worker   actions->model_ = model;
126*993b0882SAndroid Build Coastguard Worker   actions->SetOrCreateUnilib(unilib);
127*993b0882SAndroid Build Coastguard Worker   actions->triggering_preconditions_overlay_buffer_ =
128*993b0882SAndroid Build Coastguard Worker       triggering_preconditions_overlay;
129*993b0882SAndroid Build Coastguard Worker   if (!actions->ValidateAndInitialize()) {
130*993b0882SAndroid Build Coastguard Worker     return nullptr;
131*993b0882SAndroid Build Coastguard Worker   }
132*993b0882SAndroid Build Coastguard Worker   return actions;
133*993b0882SAndroid Build Coastguard Worker }
134*993b0882SAndroid Build Coastguard Worker 
FromScopedMmap(std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,const UniLib * unilib,const std::string & triggering_preconditions_overlay)135*993b0882SAndroid Build Coastguard Worker std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
136*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<libtextclassifier3::ScopedMmap> mmap, const UniLib* unilib,
137*993b0882SAndroid Build Coastguard Worker     const std::string& triggering_preconditions_overlay) {
138*993b0882SAndroid Build Coastguard Worker   if (!mmap->handle().ok()) {
139*993b0882SAndroid Build Coastguard Worker     TC3_VLOG(1) << "Mmap failed.";
140*993b0882SAndroid Build Coastguard Worker     return nullptr;
141*993b0882SAndroid Build Coastguard Worker   }
142*993b0882SAndroid Build Coastguard Worker   const ActionsModel* model = LoadAndVerifyModel(
143*993b0882SAndroid Build Coastguard Worker       reinterpret_cast<const uint8_t*>(mmap->handle().start()),
144*993b0882SAndroid Build Coastguard Worker       mmap->handle().num_bytes());
145*993b0882SAndroid Build Coastguard Worker   if (!model) {
146*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Model verification failed.";
147*993b0882SAndroid Build Coastguard Worker     return nullptr;
148*993b0882SAndroid Build Coastguard Worker   }
149*993b0882SAndroid Build Coastguard Worker   auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
150*993b0882SAndroid Build Coastguard Worker   actions->model_ = model;
151*993b0882SAndroid Build Coastguard Worker   actions->mmap_ = std::move(mmap);
152*993b0882SAndroid Build Coastguard Worker   actions->SetOrCreateUnilib(unilib);
153*993b0882SAndroid Build Coastguard Worker   actions->triggering_preconditions_overlay_buffer_ =
154*993b0882SAndroid Build Coastguard Worker       triggering_preconditions_overlay;
155*993b0882SAndroid Build Coastguard Worker   if (!actions->ValidateAndInitialize()) {
156*993b0882SAndroid Build Coastguard Worker     return nullptr;
157*993b0882SAndroid Build Coastguard Worker   }
158*993b0882SAndroid Build Coastguard Worker   return actions;
159*993b0882SAndroid Build Coastguard Worker }
160*993b0882SAndroid Build Coastguard Worker 
FromScopedMmap(std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,std::unique_ptr<UniLib> unilib,const std::string & triggering_preconditions_overlay)161*993b0882SAndroid Build Coastguard Worker std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
162*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
163*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<UniLib> unilib,
164*993b0882SAndroid Build Coastguard Worker     const std::string& triggering_preconditions_overlay) {
165*993b0882SAndroid Build Coastguard Worker   if (!mmap->handle().ok()) {
166*993b0882SAndroid Build Coastguard Worker     TC3_VLOG(1) << "Mmap failed.";
167*993b0882SAndroid Build Coastguard Worker     return nullptr;
168*993b0882SAndroid Build Coastguard Worker   }
169*993b0882SAndroid Build Coastguard Worker   const ActionsModel* model = LoadAndVerifyModel(
170*993b0882SAndroid Build Coastguard Worker       reinterpret_cast<const uint8_t*>(mmap->handle().start()),
171*993b0882SAndroid Build Coastguard Worker       mmap->handle().num_bytes());
172*993b0882SAndroid Build Coastguard Worker   if (!model) {
173*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Model verification failed.";
174*993b0882SAndroid Build Coastguard Worker     return nullptr;
175*993b0882SAndroid Build Coastguard Worker   }
176*993b0882SAndroid Build Coastguard Worker   auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
177*993b0882SAndroid Build Coastguard Worker   actions->model_ = model;
178*993b0882SAndroid Build Coastguard Worker   actions->mmap_ = std::move(mmap);
179*993b0882SAndroid Build Coastguard Worker   actions->owned_unilib_ = std::move(unilib);
180*993b0882SAndroid Build Coastguard Worker   actions->unilib_ = actions->owned_unilib_.get();
181*993b0882SAndroid Build Coastguard Worker   actions->triggering_preconditions_overlay_buffer_ =
182*993b0882SAndroid Build Coastguard Worker       triggering_preconditions_overlay;
183*993b0882SAndroid Build Coastguard Worker   if (!actions->ValidateAndInitialize()) {
184*993b0882SAndroid Build Coastguard Worker     return nullptr;
185*993b0882SAndroid Build Coastguard Worker   }
186*993b0882SAndroid Build Coastguard Worker   return actions;
187*993b0882SAndroid Build Coastguard Worker }
188*993b0882SAndroid Build Coastguard Worker 
FromFileDescriptor(const int fd,const int offset,const int size,const UniLib * unilib,const std::string & triggering_preconditions_overlay)189*993b0882SAndroid Build Coastguard Worker std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
190*993b0882SAndroid Build Coastguard Worker     const int fd, const int offset, const int size, const UniLib* unilib,
191*993b0882SAndroid Build Coastguard Worker     const std::string& triggering_preconditions_overlay) {
192*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
193*993b0882SAndroid Build Coastguard Worker   if (offset >= 0 && size >= 0) {
194*993b0882SAndroid Build Coastguard Worker     mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
195*993b0882SAndroid Build Coastguard Worker   } else {
196*993b0882SAndroid Build Coastguard Worker     mmap.reset(new libtextclassifier3::ScopedMmap(fd));
197*993b0882SAndroid Build Coastguard Worker   }
198*993b0882SAndroid Build Coastguard Worker   return FromScopedMmap(std::move(mmap), unilib,
199*993b0882SAndroid Build Coastguard Worker                         triggering_preconditions_overlay);
200*993b0882SAndroid Build Coastguard Worker }
201*993b0882SAndroid Build Coastguard Worker 
FromFileDescriptor(const int fd,const int offset,const int size,std::unique_ptr<UniLib> unilib,const std::string & triggering_preconditions_overlay)202*993b0882SAndroid Build Coastguard Worker std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
203*993b0882SAndroid Build Coastguard Worker     const int fd, const int offset, const int size,
204*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<UniLib> unilib,
205*993b0882SAndroid Build Coastguard Worker     const std::string& triggering_preconditions_overlay) {
206*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
207*993b0882SAndroid Build Coastguard Worker   if (offset >= 0 && size >= 0) {
208*993b0882SAndroid Build Coastguard Worker     mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
209*993b0882SAndroid Build Coastguard Worker   } else {
210*993b0882SAndroid Build Coastguard Worker     mmap.reset(new libtextclassifier3::ScopedMmap(fd));
211*993b0882SAndroid Build Coastguard Worker   }
212*993b0882SAndroid Build Coastguard Worker   return FromScopedMmap(std::move(mmap), std::move(unilib),
213*993b0882SAndroid Build Coastguard Worker                         triggering_preconditions_overlay);
214*993b0882SAndroid Build Coastguard Worker }
215*993b0882SAndroid Build Coastguard Worker 
FromFileDescriptor(const int fd,const UniLib * unilib,const std::string & triggering_preconditions_overlay)216*993b0882SAndroid Build Coastguard Worker std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
217*993b0882SAndroid Build Coastguard Worker     const int fd, const UniLib* unilib,
218*993b0882SAndroid Build Coastguard Worker     const std::string& triggering_preconditions_overlay) {
219*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
220*993b0882SAndroid Build Coastguard Worker       new libtextclassifier3::ScopedMmap(fd));
221*993b0882SAndroid Build Coastguard Worker   return FromScopedMmap(std::move(mmap), unilib,
222*993b0882SAndroid Build Coastguard Worker                         triggering_preconditions_overlay);
223*993b0882SAndroid Build Coastguard Worker }
224*993b0882SAndroid Build Coastguard Worker 
FromFileDescriptor(const int fd,std::unique_ptr<UniLib> unilib,const std::string & triggering_preconditions_overlay)225*993b0882SAndroid Build Coastguard Worker std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
226*993b0882SAndroid Build Coastguard Worker     const int fd, std::unique_ptr<UniLib> unilib,
227*993b0882SAndroid Build Coastguard Worker     const std::string& triggering_preconditions_overlay) {
228*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
229*993b0882SAndroid Build Coastguard Worker       new libtextclassifier3::ScopedMmap(fd));
230*993b0882SAndroid Build Coastguard Worker   return FromScopedMmap(std::move(mmap), std::move(unilib),
231*993b0882SAndroid Build Coastguard Worker                         triggering_preconditions_overlay);
232*993b0882SAndroid Build Coastguard Worker }
233*993b0882SAndroid Build Coastguard Worker 
FromPath(const std::string & path,const UniLib * unilib,const std::string & triggering_preconditions_overlay)234*993b0882SAndroid Build Coastguard Worker std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
235*993b0882SAndroid Build Coastguard Worker     const std::string& path, const UniLib* unilib,
236*993b0882SAndroid Build Coastguard Worker     const std::string& triggering_preconditions_overlay) {
237*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
238*993b0882SAndroid Build Coastguard Worker       new libtextclassifier3::ScopedMmap(path));
239*993b0882SAndroid Build Coastguard Worker   return FromScopedMmap(std::move(mmap), unilib,
240*993b0882SAndroid Build Coastguard Worker                         triggering_preconditions_overlay);
241*993b0882SAndroid Build Coastguard Worker }
242*993b0882SAndroid Build Coastguard Worker 
FromPath(const std::string & path,std::unique_ptr<UniLib> unilib,const std::string & triggering_preconditions_overlay)243*993b0882SAndroid Build Coastguard Worker std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
244*993b0882SAndroid Build Coastguard Worker     const std::string& path, std::unique_ptr<UniLib> unilib,
245*993b0882SAndroid Build Coastguard Worker     const std::string& triggering_preconditions_overlay) {
246*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
247*993b0882SAndroid Build Coastguard Worker       new libtextclassifier3::ScopedMmap(path));
248*993b0882SAndroid Build Coastguard Worker   return FromScopedMmap(std::move(mmap), std::move(unilib),
249*993b0882SAndroid Build Coastguard Worker                         triggering_preconditions_overlay);
250*993b0882SAndroid Build Coastguard Worker }
251*993b0882SAndroid Build Coastguard Worker 
SetOrCreateUnilib(const UniLib * unilib)252*993b0882SAndroid Build Coastguard Worker void ActionsSuggestions::SetOrCreateUnilib(const UniLib* unilib) {
253*993b0882SAndroid Build Coastguard Worker   if (unilib != nullptr) {
254*993b0882SAndroid Build Coastguard Worker     unilib_ = unilib;
255*993b0882SAndroid Build Coastguard Worker   } else {
256*993b0882SAndroid Build Coastguard Worker     owned_unilib_.reset(new UniLib);
257*993b0882SAndroid Build Coastguard Worker     unilib_ = owned_unilib_.get();
258*993b0882SAndroid Build Coastguard Worker   }
259*993b0882SAndroid Build Coastguard Worker }
260*993b0882SAndroid Build Coastguard Worker 
ValidateAndInitialize()261*993b0882SAndroid Build Coastguard Worker bool ActionsSuggestions::ValidateAndInitialize() {
262*993b0882SAndroid Build Coastguard Worker   if (model_ == nullptr) {
263*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "No model specified.";
264*993b0882SAndroid Build Coastguard Worker     return false;
265*993b0882SAndroid Build Coastguard Worker   }
266*993b0882SAndroid Build Coastguard Worker 
267*993b0882SAndroid Build Coastguard Worker   if (model_->smart_reply_action_type() == nullptr) {
268*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "No smart reply action type specified.";
269*993b0882SAndroid Build Coastguard Worker     return false;
270*993b0882SAndroid Build Coastguard Worker   }
271*993b0882SAndroid Build Coastguard Worker 
272*993b0882SAndroid Build Coastguard Worker   if (!InitializeTriggeringPreconditions()) {
273*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not initialize preconditions.";
274*993b0882SAndroid Build Coastguard Worker     return false;
275*993b0882SAndroid Build Coastguard Worker   }
276*993b0882SAndroid Build Coastguard Worker 
277*993b0882SAndroid Build Coastguard Worker   if (model_->locales() &&
278*993b0882SAndroid Build Coastguard Worker       !ParseLocales(model_->locales()->c_str(), &locales_)) {
279*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not parse model supported locales.";
280*993b0882SAndroid Build Coastguard Worker     return false;
281*993b0882SAndroid Build Coastguard Worker   }
282*993b0882SAndroid Build Coastguard Worker 
283*993b0882SAndroid Build Coastguard Worker   if (model_->tflite_model_spec() != nullptr) {
284*993b0882SAndroid Build Coastguard Worker     model_executor_ = TfLiteModelExecutor::FromBuffer(
285*993b0882SAndroid Build Coastguard Worker         model_->tflite_model_spec()->tflite_model());
286*993b0882SAndroid Build Coastguard Worker     if (!model_executor_) {
287*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not initialize model executor.";
288*993b0882SAndroid Build Coastguard Worker       return false;
289*993b0882SAndroid Build Coastguard Worker     }
290*993b0882SAndroid Build Coastguard Worker   }
291*993b0882SAndroid Build Coastguard Worker 
292*993b0882SAndroid Build Coastguard Worker   // Gather annotation entities for the rules.
293*993b0882SAndroid Build Coastguard Worker   if (model_->annotation_actions_spec() != nullptr &&
294*993b0882SAndroid Build Coastguard Worker       model_->annotation_actions_spec()->annotation_mapping() != nullptr) {
295*993b0882SAndroid Build Coastguard Worker     for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
296*993b0882SAndroid Build Coastguard Worker          *model_->annotation_actions_spec()->annotation_mapping()) {
297*993b0882SAndroid Build Coastguard Worker       annotation_entity_types_.insert(mapping->annotation_collection()->str());
298*993b0882SAndroid Build Coastguard Worker     }
299*993b0882SAndroid Build Coastguard Worker   }
300*993b0882SAndroid Build Coastguard Worker 
301*993b0882SAndroid Build Coastguard Worker   if (model_->actions_entity_data_schema() != nullptr) {
302*993b0882SAndroid Build Coastguard Worker     entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
303*993b0882SAndroid Build Coastguard Worker         model_->actions_entity_data_schema()->Data(),
304*993b0882SAndroid Build Coastguard Worker         model_->actions_entity_data_schema()->size());
305*993b0882SAndroid Build Coastguard Worker     if (entity_data_schema_ == nullptr) {
306*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not load entity data schema data.";
307*993b0882SAndroid Build Coastguard Worker       return false;
308*993b0882SAndroid Build Coastguard Worker     }
309*993b0882SAndroid Build Coastguard Worker 
310*993b0882SAndroid Build Coastguard Worker     entity_data_builder_.reset(
311*993b0882SAndroid Build Coastguard Worker         new MutableFlatbufferBuilder(entity_data_schema_));
312*993b0882SAndroid Build Coastguard Worker   } else {
313*993b0882SAndroid Build Coastguard Worker     entity_data_schema_ = nullptr;
314*993b0882SAndroid Build Coastguard Worker   }
315*993b0882SAndroid Build Coastguard Worker 
316*993b0882SAndroid Build Coastguard Worker   // Initialize regular expressions model.
317*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
318*993b0882SAndroid Build Coastguard Worker   regex_actions_.reset(
319*993b0882SAndroid Build Coastguard Worker       new RegexActions(unilib_, model_->smart_reply_action_type()->str()));
320*993b0882SAndroid Build Coastguard Worker   if (!regex_actions_->InitializeRules(
321*993b0882SAndroid Build Coastguard Worker           model_->rules(), model_->low_confidence_rules(),
322*993b0882SAndroid Build Coastguard Worker           triggering_preconditions_overlay_, decompressor.get())) {
323*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not initialize regex rules.";
324*993b0882SAndroid Build Coastguard Worker     return false;
325*993b0882SAndroid Build Coastguard Worker   }
326*993b0882SAndroid Build Coastguard Worker 
327*993b0882SAndroid Build Coastguard Worker   // Setup grammar model.
328*993b0882SAndroid Build Coastguard Worker   if (model_->rules() != nullptr &&
329*993b0882SAndroid Build Coastguard Worker       model_->rules()->grammar_rules() != nullptr) {
330*993b0882SAndroid Build Coastguard Worker     grammar_actions_.reset(new GrammarActions(
331*993b0882SAndroid Build Coastguard Worker         unilib_, model_->rules()->grammar_rules(), entity_data_builder_.get(),
332*993b0882SAndroid Build Coastguard Worker         model_->smart_reply_action_type()->str()));
333*993b0882SAndroid Build Coastguard Worker 
334*993b0882SAndroid Build Coastguard Worker     // Gather annotation entities for the grammars.
335*993b0882SAndroid Build Coastguard Worker     if (auto annotation_nt = model_->rules()
336*993b0882SAndroid Build Coastguard Worker                                  ->grammar_rules()
337*993b0882SAndroid Build Coastguard Worker                                  ->rules()
338*993b0882SAndroid Build Coastguard Worker                                  ->nonterminals()
339*993b0882SAndroid Build Coastguard Worker                                  ->annotation_nt()) {
340*993b0882SAndroid Build Coastguard Worker       for (const grammar::RulesSet_::Nonterminals_::AnnotationNtEntry* entry :
341*993b0882SAndroid Build Coastguard Worker            *annotation_nt) {
342*993b0882SAndroid Build Coastguard Worker         annotation_entity_types_.insert(entry->key()->str());
343*993b0882SAndroid Build Coastguard Worker       }
344*993b0882SAndroid Build Coastguard Worker     }
345*993b0882SAndroid Build Coastguard Worker   }
346*993b0882SAndroid Build Coastguard Worker 
347*993b0882SAndroid Build Coastguard Worker #if !defined(TC3_DISABLE_LUA)
348*993b0882SAndroid Build Coastguard Worker   std::string actions_script;
349*993b0882SAndroid Build Coastguard Worker   if (GetUncompressedString(model_->lua_actions_script(),
350*993b0882SAndroid Build Coastguard Worker                             model_->compressed_lua_actions_script(),
351*993b0882SAndroid Build Coastguard Worker                             decompressor.get(), &actions_script) &&
352*993b0882SAndroid Build Coastguard Worker       !actions_script.empty()) {
353*993b0882SAndroid Build Coastguard Worker     if (!Compile(actions_script, &lua_bytecode_)) {
354*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not precompile lua actions snippet.";
355*993b0882SAndroid Build Coastguard Worker       return false;
356*993b0882SAndroid Build Coastguard Worker     }
357*993b0882SAndroid Build Coastguard Worker   }
358*993b0882SAndroid Build Coastguard Worker #endif  // TC3_DISABLE_LUA
359*993b0882SAndroid Build Coastguard Worker 
360*993b0882SAndroid Build Coastguard Worker   if (!(ranker_ = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
361*993b0882SAndroid Build Coastguard Worker             model_->ranking_options(), decompressor.get(),
362*993b0882SAndroid Build Coastguard Worker             model_->smart_reply_action_type()->str()))) {
363*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not create an action suggestions ranker.";
364*993b0882SAndroid Build Coastguard Worker     return false;
365*993b0882SAndroid Build Coastguard Worker   }
366*993b0882SAndroid Build Coastguard Worker 
367*993b0882SAndroid Build Coastguard Worker   // Create feature processor if specified.
368*993b0882SAndroid Build Coastguard Worker   const ActionsTokenFeatureProcessorOptions* options =
369*993b0882SAndroid Build Coastguard Worker       model_->feature_processor_options();
370*993b0882SAndroid Build Coastguard Worker   if (options != nullptr) {
371*993b0882SAndroid Build Coastguard Worker     if (options->tokenizer_options() == nullptr) {
372*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "No tokenizer options specified.";
373*993b0882SAndroid Build Coastguard Worker       return false;
374*993b0882SAndroid Build Coastguard Worker     }
375*993b0882SAndroid Build Coastguard Worker 
376*993b0882SAndroid Build Coastguard Worker     feature_processor_.reset(new ActionsFeatureProcessor(options, unilib_));
377*993b0882SAndroid Build Coastguard Worker     embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
378*993b0882SAndroid Build Coastguard Worker         options->embedding_model(), options->embedding_size(),
379*993b0882SAndroid Build Coastguard Worker         options->embedding_quantization_bits());
380*993b0882SAndroid Build Coastguard Worker 
381*993b0882SAndroid Build Coastguard Worker     if (embedding_executor_ == nullptr) {
382*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not initialize embedding executor.";
383*993b0882SAndroid Build Coastguard Worker       return false;
384*993b0882SAndroid Build Coastguard Worker     }
385*993b0882SAndroid Build Coastguard Worker 
386*993b0882SAndroid Build Coastguard Worker     // Cache embedding of padding, start and end token.
387*993b0882SAndroid Build Coastguard Worker     if (!EmbedTokenId(options->padding_token_id(), &embedded_padding_token_) ||
388*993b0882SAndroid Build Coastguard Worker         !EmbedTokenId(options->start_token_id(), &embedded_start_token_) ||
389*993b0882SAndroid Build Coastguard Worker         !EmbedTokenId(options->end_token_id(), &embedded_end_token_)) {
390*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not precompute token embeddings.";
391*993b0882SAndroid Build Coastguard Worker       return false;
392*993b0882SAndroid Build Coastguard Worker     }
393*993b0882SAndroid Build Coastguard Worker     token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
394*993b0882SAndroid Build Coastguard Worker   }
395*993b0882SAndroid Build Coastguard Worker 
396*993b0882SAndroid Build Coastguard Worker   // Create low confidence model if specified.
397*993b0882SAndroid Build Coastguard Worker   if (model_->low_confidence_ngram_model() != nullptr) {
398*993b0882SAndroid Build Coastguard Worker     sensitive_model_ = NGramSensitiveModel::Create(
399*993b0882SAndroid Build Coastguard Worker         unilib_, model_->low_confidence_ngram_model(),
400*993b0882SAndroid Build Coastguard Worker         feature_processor_ == nullptr ? nullptr
401*993b0882SAndroid Build Coastguard Worker                                       : feature_processor_->tokenizer());
402*993b0882SAndroid Build Coastguard Worker     if (sensitive_model_ == nullptr) {
403*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not create ngram linear regression model.";
404*993b0882SAndroid Build Coastguard Worker       return false;
405*993b0882SAndroid Build Coastguard Worker     }
406*993b0882SAndroid Build Coastguard Worker   }
407*993b0882SAndroid Build Coastguard Worker   if (model_->low_confidence_tflite_model() != nullptr) {
408*993b0882SAndroid Build Coastguard Worker     sensitive_model_ =
409*993b0882SAndroid Build Coastguard Worker         TFLiteSensitiveModel::Create(model_->low_confidence_tflite_model());
410*993b0882SAndroid Build Coastguard Worker     if (sensitive_model_ == nullptr) {
411*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not create TFLite sensitive model.";
412*993b0882SAndroid Build Coastguard Worker       return false;
413*993b0882SAndroid Build Coastguard Worker     }
414*993b0882SAndroid Build Coastguard Worker   }
415*993b0882SAndroid Build Coastguard Worker 
416*993b0882SAndroid Build Coastguard Worker   return true;
417*993b0882SAndroid Build Coastguard Worker }
418*993b0882SAndroid Build Coastguard Worker 
InitializeTriggeringPreconditions()419*993b0882SAndroid Build Coastguard Worker bool ActionsSuggestions::InitializeTriggeringPreconditions() {
420*993b0882SAndroid Build Coastguard Worker   triggering_preconditions_overlay_ =
421*993b0882SAndroid Build Coastguard Worker       LoadAndVerifyFlatbuffer<TriggeringPreconditions>(
422*993b0882SAndroid Build Coastguard Worker           triggering_preconditions_overlay_buffer_);
423*993b0882SAndroid Build Coastguard Worker 
424*993b0882SAndroid Build Coastguard Worker   if (triggering_preconditions_overlay_ == nullptr &&
425*993b0882SAndroid Build Coastguard Worker       !triggering_preconditions_overlay_buffer_.empty()) {
426*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not load triggering preconditions overwrites.";
427*993b0882SAndroid Build Coastguard Worker     return false;
428*993b0882SAndroid Build Coastguard Worker   }
429*993b0882SAndroid Build Coastguard Worker   const flatbuffers::Table* overlay =
430*993b0882SAndroid Build Coastguard Worker       reinterpret_cast<const flatbuffers::Table*>(
431*993b0882SAndroid Build Coastguard Worker           triggering_preconditions_overlay_);
432*993b0882SAndroid Build Coastguard Worker   const TriggeringPreconditions* defaults = model_->preconditions();
433*993b0882SAndroid Build Coastguard Worker   if (defaults == nullptr) {
434*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "No triggering conditions specified.";
435*993b0882SAndroid Build Coastguard Worker     return false;
436*993b0882SAndroid Build Coastguard Worker   }
437*993b0882SAndroid Build Coastguard Worker 
438*993b0882SAndroid Build Coastguard Worker   preconditions_.min_smart_reply_triggering_score = ValueOrDefault(
439*993b0882SAndroid Build Coastguard Worker       overlay, TriggeringPreconditions::VT_MIN_SMART_REPLY_TRIGGERING_SCORE,
440*993b0882SAndroid Build Coastguard Worker       defaults->min_smart_reply_triggering_score());
441*993b0882SAndroid Build Coastguard Worker   preconditions_.max_sensitive_topic_score = ValueOrDefault(
442*993b0882SAndroid Build Coastguard Worker       overlay, TriggeringPreconditions::VT_MAX_SENSITIVE_TOPIC_SCORE,
443*993b0882SAndroid Build Coastguard Worker       defaults->max_sensitive_topic_score());
444*993b0882SAndroid Build Coastguard Worker   preconditions_.suppress_on_sensitive_topic = ValueOrDefault(
445*993b0882SAndroid Build Coastguard Worker       overlay, TriggeringPreconditions::VT_SUPPRESS_ON_SENSITIVE_TOPIC,
446*993b0882SAndroid Build Coastguard Worker       defaults->suppress_on_sensitive_topic());
447*993b0882SAndroid Build Coastguard Worker   preconditions_.min_input_length =
448*993b0882SAndroid Build Coastguard Worker       ValueOrDefault(overlay, TriggeringPreconditions::VT_MIN_INPUT_LENGTH,
449*993b0882SAndroid Build Coastguard Worker                      defaults->min_input_length());
450*993b0882SAndroid Build Coastguard Worker   preconditions_.max_input_length =
451*993b0882SAndroid Build Coastguard Worker       ValueOrDefault(overlay, TriggeringPreconditions::VT_MAX_INPUT_LENGTH,
452*993b0882SAndroid Build Coastguard Worker                      defaults->max_input_length());
453*993b0882SAndroid Build Coastguard Worker   preconditions_.min_locale_match_fraction = ValueOrDefault(
454*993b0882SAndroid Build Coastguard Worker       overlay, TriggeringPreconditions::VT_MIN_LOCALE_MATCH_FRACTION,
455*993b0882SAndroid Build Coastguard Worker       defaults->min_locale_match_fraction());
456*993b0882SAndroid Build Coastguard Worker   preconditions_.handle_missing_locale_as_supported = ValueOrDefault(
457*993b0882SAndroid Build Coastguard Worker       overlay, TriggeringPreconditions::VT_HANDLE_MISSING_LOCALE_AS_SUPPORTED,
458*993b0882SAndroid Build Coastguard Worker       defaults->handle_missing_locale_as_supported());
459*993b0882SAndroid Build Coastguard Worker   preconditions_.handle_unknown_locale_as_supported = ValueOrDefault(
460*993b0882SAndroid Build Coastguard Worker       overlay, TriggeringPreconditions::VT_HANDLE_UNKNOWN_LOCALE_AS_SUPPORTED,
461*993b0882SAndroid Build Coastguard Worker       defaults->handle_unknown_locale_as_supported());
462*993b0882SAndroid Build Coastguard Worker   preconditions_.suppress_on_low_confidence_input = ValueOrDefault(
463*993b0882SAndroid Build Coastguard Worker       overlay, TriggeringPreconditions::VT_SUPPRESS_ON_LOW_CONFIDENCE_INPUT,
464*993b0882SAndroid Build Coastguard Worker       defaults->suppress_on_low_confidence_input());
465*993b0882SAndroid Build Coastguard Worker   preconditions_.min_reply_score_threshold = ValueOrDefault(
466*993b0882SAndroid Build Coastguard Worker       overlay, TriggeringPreconditions::VT_MIN_REPLY_SCORE_THRESHOLD,
467*993b0882SAndroid Build Coastguard Worker       defaults->min_reply_score_threshold());
468*993b0882SAndroid Build Coastguard Worker 
469*993b0882SAndroid Build Coastguard Worker   return true;
470*993b0882SAndroid Build Coastguard Worker }
471*993b0882SAndroid Build Coastguard Worker 
EmbedTokenId(const int32 token_id,std::vector<float> * embedding) const472*993b0882SAndroid Build Coastguard Worker bool ActionsSuggestions::EmbedTokenId(const int32 token_id,
473*993b0882SAndroid Build Coastguard Worker                                       std::vector<float>* embedding) const {
474*993b0882SAndroid Build Coastguard Worker   return feature_processor_->AppendFeatures(
475*993b0882SAndroid Build Coastguard Worker       {token_id},
476*993b0882SAndroid Build Coastguard Worker       /*dense_features=*/{}, embedding_executor_.get(), embedding);
477*993b0882SAndroid Build Coastguard Worker }
478*993b0882SAndroid Build Coastguard Worker 
Tokenize(const std::vector<std::string> & context) const479*993b0882SAndroid Build Coastguard Worker std::vector<std::vector<Token>> ActionsSuggestions::Tokenize(
480*993b0882SAndroid Build Coastguard Worker     const std::vector<std::string>& context) const {
481*993b0882SAndroid Build Coastguard Worker   std::vector<std::vector<Token>> tokens;
482*993b0882SAndroid Build Coastguard Worker   tokens.reserve(context.size());
483*993b0882SAndroid Build Coastguard Worker   for (const std::string& message : context) {
484*993b0882SAndroid Build Coastguard Worker     tokens.push_back(feature_processor_->tokenizer()->Tokenize(message));
485*993b0882SAndroid Build Coastguard Worker   }
486*993b0882SAndroid Build Coastguard Worker   return tokens;
487*993b0882SAndroid Build Coastguard Worker }
488*993b0882SAndroid Build Coastguard Worker 
EmbedTokensPerMessage(const std::vector<std::vector<Token>> & tokens,std::vector<float> * embeddings,int * max_num_tokens_per_message) const489*993b0882SAndroid Build Coastguard Worker bool ActionsSuggestions::EmbedTokensPerMessage(
490*993b0882SAndroid Build Coastguard Worker     const std::vector<std::vector<Token>>& tokens,
491*993b0882SAndroid Build Coastguard Worker     std::vector<float>* embeddings, int* max_num_tokens_per_message) const {
492*993b0882SAndroid Build Coastguard Worker   const int num_messages = tokens.size();
493*993b0882SAndroid Build Coastguard Worker   *max_num_tokens_per_message = 0;
494*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < num_messages; i++) {
495*993b0882SAndroid Build Coastguard Worker     const int num_message_tokens = tokens[i].size();
496*993b0882SAndroid Build Coastguard Worker     if (num_message_tokens > *max_num_tokens_per_message) {
497*993b0882SAndroid Build Coastguard Worker       *max_num_tokens_per_message = num_message_tokens;
498*993b0882SAndroid Build Coastguard Worker     }
499*993b0882SAndroid Build Coastguard Worker   }
500*993b0882SAndroid Build Coastguard Worker 
501*993b0882SAndroid Build Coastguard Worker   if (model_->feature_processor_options()->min_num_tokens_per_message() >
502*993b0882SAndroid Build Coastguard Worker       *max_num_tokens_per_message) {
503*993b0882SAndroid Build Coastguard Worker     *max_num_tokens_per_message =
504*993b0882SAndroid Build Coastguard Worker         model_->feature_processor_options()->min_num_tokens_per_message();
505*993b0882SAndroid Build Coastguard Worker   }
506*993b0882SAndroid Build Coastguard Worker   if (model_->feature_processor_options()->max_num_tokens_per_message() > 0 &&
507*993b0882SAndroid Build Coastguard Worker       *max_num_tokens_per_message >
508*993b0882SAndroid Build Coastguard Worker           model_->feature_processor_options()->max_num_tokens_per_message()) {
509*993b0882SAndroid Build Coastguard Worker     *max_num_tokens_per_message =
510*993b0882SAndroid Build Coastguard Worker         model_->feature_processor_options()->max_num_tokens_per_message();
511*993b0882SAndroid Build Coastguard Worker   }
512*993b0882SAndroid Build Coastguard Worker 
513*993b0882SAndroid Build Coastguard Worker   // Embed all tokens and add paddings to pad tokens of each message to the
514*993b0882SAndroid Build Coastguard Worker   // maximum number of tokens in a message of the conversation.
515*993b0882SAndroid Build Coastguard Worker   // If a number of tokens is specified in the model config, tokens at the
516*993b0882SAndroid Build Coastguard Worker   // beginning of a message are dropped if they don't fit in the limit.
517*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < num_messages; i++) {
518*993b0882SAndroid Build Coastguard Worker     const int start =
519*993b0882SAndroid Build Coastguard Worker         std::max<int>(tokens[i].size() - *max_num_tokens_per_message, 0);
520*993b0882SAndroid Build Coastguard Worker     for (int pos = start; pos < tokens[i].size(); pos++) {
521*993b0882SAndroid Build Coastguard Worker       if (!feature_processor_->AppendTokenFeatures(
522*993b0882SAndroid Build Coastguard Worker               tokens[i][pos], embedding_executor_.get(), embeddings)) {
523*993b0882SAndroid Build Coastguard Worker         TC3_LOG(ERROR) << "Could not run token feature extractor.";
524*993b0882SAndroid Build Coastguard Worker         return false;
525*993b0882SAndroid Build Coastguard Worker       }
526*993b0882SAndroid Build Coastguard Worker     }
527*993b0882SAndroid Build Coastguard Worker     // Add padding.
528*993b0882SAndroid Build Coastguard Worker     for (int k = tokens[i].size(); k < *max_num_tokens_per_message; k++) {
529*993b0882SAndroid Build Coastguard Worker       embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
530*993b0882SAndroid Build Coastguard Worker                          embedded_padding_token_.end());
531*993b0882SAndroid Build Coastguard Worker     }
532*993b0882SAndroid Build Coastguard Worker   }
533*993b0882SAndroid Build Coastguard Worker 
534*993b0882SAndroid Build Coastguard Worker   return true;
535*993b0882SAndroid Build Coastguard Worker }
536*993b0882SAndroid Build Coastguard Worker 
EmbedAndFlattenTokens(const std::vector<std::vector<Token>> & tokens,std::vector<float> * embeddings,int * total_token_count) const537*993b0882SAndroid Build Coastguard Worker bool ActionsSuggestions::EmbedAndFlattenTokens(
538*993b0882SAndroid Build Coastguard Worker     const std::vector<std::vector<Token>>& tokens,
539*993b0882SAndroid Build Coastguard Worker     std::vector<float>* embeddings, int* total_token_count) const {
540*993b0882SAndroid Build Coastguard Worker   const int num_messages = tokens.size();
541*993b0882SAndroid Build Coastguard Worker   int start_message = 0;
542*993b0882SAndroid Build Coastguard Worker   int message_token_offset = 0;
543*993b0882SAndroid Build Coastguard Worker 
544*993b0882SAndroid Build Coastguard Worker   // If a maximum model input length is specified, we need to check how
545*993b0882SAndroid Build Coastguard Worker   // much we need to trim at the start.
546*993b0882SAndroid Build Coastguard Worker   const int max_num_total_tokens =
547*993b0882SAndroid Build Coastguard Worker       model_->feature_processor_options()->max_num_total_tokens();
548*993b0882SAndroid Build Coastguard Worker   if (max_num_total_tokens > 0) {
549*993b0882SAndroid Build Coastguard Worker     int total_tokens = 0;
550*993b0882SAndroid Build Coastguard Worker     start_message = num_messages - 1;
551*993b0882SAndroid Build Coastguard Worker     for (; start_message >= 0; start_message--) {
552*993b0882SAndroid Build Coastguard Worker       // Tokens of the message + start and end token.
553*993b0882SAndroid Build Coastguard Worker       const int num_message_tokens = tokens[start_message].size() + 2;
554*993b0882SAndroid Build Coastguard Worker       total_tokens += num_message_tokens;
555*993b0882SAndroid Build Coastguard Worker 
556*993b0882SAndroid Build Coastguard Worker       // Check whether we exhausted the budget.
557*993b0882SAndroid Build Coastguard Worker       if (total_tokens >= max_num_total_tokens) {
558*993b0882SAndroid Build Coastguard Worker         message_token_offset = total_tokens - max_num_total_tokens;
559*993b0882SAndroid Build Coastguard Worker         break;
560*993b0882SAndroid Build Coastguard Worker       }
561*993b0882SAndroid Build Coastguard Worker     }
562*993b0882SAndroid Build Coastguard Worker   }
563*993b0882SAndroid Build Coastguard Worker 
564*993b0882SAndroid Build Coastguard Worker   // Add embeddings.
565*993b0882SAndroid Build Coastguard Worker   *total_token_count = 0;
566*993b0882SAndroid Build Coastguard Worker   for (int i = start_message; i < num_messages; i++) {
567*993b0882SAndroid Build Coastguard Worker     if (message_token_offset == 0) {
568*993b0882SAndroid Build Coastguard Worker       ++(*total_token_count);
569*993b0882SAndroid Build Coastguard Worker       // Add `start message` token.
570*993b0882SAndroid Build Coastguard Worker       embeddings->insert(embeddings->end(), embedded_start_token_.begin(),
571*993b0882SAndroid Build Coastguard Worker                          embedded_start_token_.end());
572*993b0882SAndroid Build Coastguard Worker     }
573*993b0882SAndroid Build Coastguard Worker 
574*993b0882SAndroid Build Coastguard Worker     for (int pos = std::max(0, message_token_offset - 1);
575*993b0882SAndroid Build Coastguard Worker          pos < tokens[i].size(); pos++) {
576*993b0882SAndroid Build Coastguard Worker       ++(*total_token_count);
577*993b0882SAndroid Build Coastguard Worker       if (!feature_processor_->AppendTokenFeatures(
578*993b0882SAndroid Build Coastguard Worker               tokens[i][pos], embedding_executor_.get(), embeddings)) {
579*993b0882SAndroid Build Coastguard Worker         TC3_LOG(ERROR) << "Could not run token feature extractor.";
580*993b0882SAndroid Build Coastguard Worker         return false;
581*993b0882SAndroid Build Coastguard Worker       }
582*993b0882SAndroid Build Coastguard Worker     }
583*993b0882SAndroid Build Coastguard Worker 
584*993b0882SAndroid Build Coastguard Worker     // Add `end message` token.
585*993b0882SAndroid Build Coastguard Worker     ++(*total_token_count);
586*993b0882SAndroid Build Coastguard Worker     embeddings->insert(embeddings->end(), embedded_end_token_.begin(),
587*993b0882SAndroid Build Coastguard Worker                        embedded_end_token_.end());
588*993b0882SAndroid Build Coastguard Worker 
589*993b0882SAndroid Build Coastguard Worker     // Reset for the subsequent messages.
590*993b0882SAndroid Build Coastguard Worker     message_token_offset = 0;
591*993b0882SAndroid Build Coastguard Worker   }
592*993b0882SAndroid Build Coastguard Worker 
593*993b0882SAndroid Build Coastguard Worker   // Add optional padding.
594*993b0882SAndroid Build Coastguard Worker   const int min_num_total_tokens =
595*993b0882SAndroid Build Coastguard Worker       model_->feature_processor_options()->min_num_total_tokens();
596*993b0882SAndroid Build Coastguard Worker   for (; *total_token_count < min_num_total_tokens; ++(*total_token_count)) {
597*993b0882SAndroid Build Coastguard Worker     embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
598*993b0882SAndroid Build Coastguard Worker                        embedded_padding_token_.end());
599*993b0882SAndroid Build Coastguard Worker   }
600*993b0882SAndroid Build Coastguard Worker 
601*993b0882SAndroid Build Coastguard Worker   return true;
602*993b0882SAndroid Build Coastguard Worker }
603*993b0882SAndroid Build Coastguard Worker 
AllocateInput(const int conversation_length,const int max_tokens,const int total_token_count,tflite::Interpreter * interpreter) const604*993b0882SAndroid Build Coastguard Worker bool ActionsSuggestions::AllocateInput(const int conversation_length,
605*993b0882SAndroid Build Coastguard Worker                                        const int max_tokens,
606*993b0882SAndroid Build Coastguard Worker                                        const int total_token_count,
607*993b0882SAndroid Build Coastguard Worker                                        tflite::Interpreter* interpreter) const {
608*993b0882SAndroid Build Coastguard Worker   if (model_->tflite_model_spec()->resize_inputs()) {
609*993b0882SAndroid Build Coastguard Worker     if (model_->tflite_model_spec()->input_context() >= 0) {
610*993b0882SAndroid Build Coastguard Worker       interpreter->ResizeInputTensor(
611*993b0882SAndroid Build Coastguard Worker           interpreter->inputs()[model_->tflite_model_spec()->input_context()],
612*993b0882SAndroid Build Coastguard Worker           {1, conversation_length});
613*993b0882SAndroid Build Coastguard Worker     }
614*993b0882SAndroid Build Coastguard Worker     if (model_->tflite_model_spec()->input_user_id() >= 0) {
615*993b0882SAndroid Build Coastguard Worker       interpreter->ResizeInputTensor(
616*993b0882SAndroid Build Coastguard Worker           interpreter->inputs()[model_->tflite_model_spec()->input_user_id()],
617*993b0882SAndroid Build Coastguard Worker           {1, conversation_length});
618*993b0882SAndroid Build Coastguard Worker     }
619*993b0882SAndroid Build Coastguard Worker     if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
620*993b0882SAndroid Build Coastguard Worker       interpreter->ResizeInputTensor(
621*993b0882SAndroid Build Coastguard Worker           interpreter
622*993b0882SAndroid Build Coastguard Worker               ->inputs()[model_->tflite_model_spec()->input_time_diffs()],
623*993b0882SAndroid Build Coastguard Worker           {1, conversation_length});
624*993b0882SAndroid Build Coastguard Worker     }
625*993b0882SAndroid Build Coastguard Worker     if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
626*993b0882SAndroid Build Coastguard Worker       interpreter->ResizeInputTensor(
627*993b0882SAndroid Build Coastguard Worker           interpreter
628*993b0882SAndroid Build Coastguard Worker               ->inputs()[model_->tflite_model_spec()->input_num_tokens()],
629*993b0882SAndroid Build Coastguard Worker           {conversation_length, 1});
630*993b0882SAndroid Build Coastguard Worker     }
631*993b0882SAndroid Build Coastguard Worker     if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
632*993b0882SAndroid Build Coastguard Worker       interpreter->ResizeInputTensor(
633*993b0882SAndroid Build Coastguard Worker           interpreter
634*993b0882SAndroid Build Coastguard Worker               ->inputs()[model_->tflite_model_spec()->input_token_embeddings()],
635*993b0882SAndroid Build Coastguard Worker           {conversation_length, max_tokens, token_embedding_size_});
636*993b0882SAndroid Build Coastguard Worker     }
637*993b0882SAndroid Build Coastguard Worker     if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
638*993b0882SAndroid Build Coastguard Worker       interpreter->ResizeInputTensor(
639*993b0882SAndroid Build Coastguard Worker           interpreter->inputs()[model_->tflite_model_spec()
640*993b0882SAndroid Build Coastguard Worker                                     ->input_flattened_token_embeddings()],
641*993b0882SAndroid Build Coastguard Worker           {1, total_token_count});
642*993b0882SAndroid Build Coastguard Worker     }
643*993b0882SAndroid Build Coastguard Worker   }
644*993b0882SAndroid Build Coastguard Worker 
645*993b0882SAndroid Build Coastguard Worker   return interpreter->AllocateTensors() == kTfLiteOk;
646*993b0882SAndroid Build Coastguard Worker }
647*993b0882SAndroid Build Coastguard Worker 
SetupModelInput(const std::vector<std::string> & context,const std::vector<int> & user_ids,const std::vector<float> & time_diffs,const int num_suggestions,const ActionSuggestionOptions & options,tflite::Interpreter * interpreter) const648*993b0882SAndroid Build Coastguard Worker bool ActionsSuggestions::SetupModelInput(
649*993b0882SAndroid Build Coastguard Worker     const std::vector<std::string>& context, const std::vector<int>& user_ids,
650*993b0882SAndroid Build Coastguard Worker     const std::vector<float>& time_diffs, const int num_suggestions,
651*993b0882SAndroid Build Coastguard Worker     const ActionSuggestionOptions& options,
652*993b0882SAndroid Build Coastguard Worker     tflite::Interpreter* interpreter) const {
653*993b0882SAndroid Build Coastguard Worker   // Compute token embeddings.
654*993b0882SAndroid Build Coastguard Worker   std::vector<std::vector<Token>> tokens;
655*993b0882SAndroid Build Coastguard Worker   std::vector<float> token_embeddings;
656*993b0882SAndroid Build Coastguard Worker   std::vector<float> flattened_token_embeddings;
657*993b0882SAndroid Build Coastguard Worker   int max_tokens = 0;
658*993b0882SAndroid Build Coastguard Worker   int total_token_count = 0;
659*993b0882SAndroid Build Coastguard Worker   if (model_->tflite_model_spec()->input_num_tokens() >= 0 ||
660*993b0882SAndroid Build Coastguard Worker       model_->tflite_model_spec()->input_token_embeddings() >= 0 ||
661*993b0882SAndroid Build Coastguard Worker       model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
662*993b0882SAndroid Build Coastguard Worker     if (feature_processor_ == nullptr) {
663*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "No feature processor specified.";
664*993b0882SAndroid Build Coastguard Worker       return false;
665*993b0882SAndroid Build Coastguard Worker     }
666*993b0882SAndroid Build Coastguard Worker 
667*993b0882SAndroid Build Coastguard Worker     // Tokenize the messages in the conversation.
668*993b0882SAndroid Build Coastguard Worker     tokens = Tokenize(context);
669*993b0882SAndroid Build Coastguard Worker     if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
670*993b0882SAndroid Build Coastguard Worker       if (!EmbedTokensPerMessage(tokens, &token_embeddings, &max_tokens)) {
671*993b0882SAndroid Build Coastguard Worker         TC3_LOG(ERROR) << "Could not extract token features.";
672*993b0882SAndroid Build Coastguard Worker         return false;
673*993b0882SAndroid Build Coastguard Worker       }
674*993b0882SAndroid Build Coastguard Worker     }
675*993b0882SAndroid Build Coastguard Worker     if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
676*993b0882SAndroid Build Coastguard Worker       if (!EmbedAndFlattenTokens(tokens, &flattened_token_embeddings,
677*993b0882SAndroid Build Coastguard Worker                                  &total_token_count)) {
678*993b0882SAndroid Build Coastguard Worker         TC3_LOG(ERROR) << "Could not extract token features.";
679*993b0882SAndroid Build Coastguard Worker         return false;
680*993b0882SAndroid Build Coastguard Worker       }
681*993b0882SAndroid Build Coastguard Worker     }
682*993b0882SAndroid Build Coastguard Worker   }
683*993b0882SAndroid Build Coastguard Worker 
684*993b0882SAndroid Build Coastguard Worker   if (!AllocateInput(context.size(), max_tokens, total_token_count,
685*993b0882SAndroid Build Coastguard Worker                      interpreter)) {
686*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "TensorFlow Lite model allocation failed.";
687*993b0882SAndroid Build Coastguard Worker     return false;
688*993b0882SAndroid Build Coastguard Worker   }
689*993b0882SAndroid Build Coastguard Worker   if (model_->tflite_model_spec()->input_context() >= 0) {
690*993b0882SAndroid Build Coastguard Worker     if (model_->tflite_model_spec()->input_length_to_pad() > 0) {
691*993b0882SAndroid Build Coastguard Worker       model_executor_->SetInput<std::string>(
692*993b0882SAndroid Build Coastguard Worker           model_->tflite_model_spec()->input_context(),
693*993b0882SAndroid Build Coastguard Worker           PadOrTruncateToTargetLength(
694*993b0882SAndroid Build Coastguard Worker               context, model_->tflite_model_spec()->input_length_to_pad(),
695*993b0882SAndroid Build Coastguard Worker               std::string("")),
696*993b0882SAndroid Build Coastguard Worker           interpreter);
697*993b0882SAndroid Build Coastguard Worker     } else {
698*993b0882SAndroid Build Coastguard Worker       model_executor_->SetInput<std::string>(
699*993b0882SAndroid Build Coastguard Worker           model_->tflite_model_spec()->input_context(), context, interpreter);
700*993b0882SAndroid Build Coastguard Worker     }
701*993b0882SAndroid Build Coastguard Worker   }
702*993b0882SAndroid Build Coastguard Worker   if (model_->tflite_model_spec()->input_context_length() >= 0) {
703*993b0882SAndroid Build Coastguard Worker     model_executor_->SetInput<int>(
704*993b0882SAndroid Build Coastguard Worker         model_->tflite_model_spec()->input_context_length(), context.size(),
705*993b0882SAndroid Build Coastguard Worker         interpreter);
706*993b0882SAndroid Build Coastguard Worker   }
707*993b0882SAndroid Build Coastguard Worker   if (model_->tflite_model_spec()->input_user_id() >= 0) {
708*993b0882SAndroid Build Coastguard Worker     if (model_->tflite_model_spec()->input_length_to_pad() > 0) {
709*993b0882SAndroid Build Coastguard Worker       model_executor_->SetInput<int>(
710*993b0882SAndroid Build Coastguard Worker           model_->tflite_model_spec()->input_user_id(),
711*993b0882SAndroid Build Coastguard Worker           PadOrTruncateToTargetLength(
712*993b0882SAndroid Build Coastguard Worker               user_ids, model_->tflite_model_spec()->input_length_to_pad(), 0),
713*993b0882SAndroid Build Coastguard Worker           interpreter);
714*993b0882SAndroid Build Coastguard Worker     } else {
715*993b0882SAndroid Build Coastguard Worker       model_executor_->SetInput<int>(
716*993b0882SAndroid Build Coastguard Worker           model_->tflite_model_spec()->input_user_id(), user_ids, interpreter);
717*993b0882SAndroid Build Coastguard Worker     }
718*993b0882SAndroid Build Coastguard Worker   }
719*993b0882SAndroid Build Coastguard Worker   if (model_->tflite_model_spec()->input_num_suggestions() >= 0) {
720*993b0882SAndroid Build Coastguard Worker     model_executor_->SetInput<int>(
721*993b0882SAndroid Build Coastguard Worker         model_->tflite_model_spec()->input_num_suggestions(), num_suggestions,
722*993b0882SAndroid Build Coastguard Worker         interpreter);
723*993b0882SAndroid Build Coastguard Worker   }
724*993b0882SAndroid Build Coastguard Worker   if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
725*993b0882SAndroid Build Coastguard Worker     model_executor_->SetInput<float>(
726*993b0882SAndroid Build Coastguard Worker         model_->tflite_model_spec()->input_time_diffs(), time_diffs,
727*993b0882SAndroid Build Coastguard Worker         interpreter);
728*993b0882SAndroid Build Coastguard Worker   }
729*993b0882SAndroid Build Coastguard Worker   if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
730*993b0882SAndroid Build Coastguard Worker     std::vector<int> num_tokens_per_message(tokens.size());
731*993b0882SAndroid Build Coastguard Worker     for (int i = 0; i < tokens.size(); i++) {
732*993b0882SAndroid Build Coastguard Worker       num_tokens_per_message[i] = tokens[i].size();
733*993b0882SAndroid Build Coastguard Worker     }
734*993b0882SAndroid Build Coastguard Worker     model_executor_->SetInput<int>(
735*993b0882SAndroid Build Coastguard Worker         model_->tflite_model_spec()->input_num_tokens(), num_tokens_per_message,
736*993b0882SAndroid Build Coastguard Worker         interpreter);
737*993b0882SAndroid Build Coastguard Worker   }
738*993b0882SAndroid Build Coastguard Worker   if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
739*993b0882SAndroid Build Coastguard Worker     model_executor_->SetInput<float>(
740*993b0882SAndroid Build Coastguard Worker         model_->tflite_model_spec()->input_token_embeddings(), token_embeddings,
741*993b0882SAndroid Build Coastguard Worker         interpreter);
742*993b0882SAndroid Build Coastguard Worker   }
743*993b0882SAndroid Build Coastguard Worker   if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
744*993b0882SAndroid Build Coastguard Worker     model_executor_->SetInput<float>(
745*993b0882SAndroid Build Coastguard Worker         model_->tflite_model_spec()->input_flattened_token_embeddings(),
746*993b0882SAndroid Build Coastguard Worker         flattened_token_embeddings, interpreter);
747*993b0882SAndroid Build Coastguard Worker   }
748*993b0882SAndroid Build Coastguard Worker   // Set up additional input parameters.
749*993b0882SAndroid Build Coastguard Worker   if (const auto* input_name_index =
750*993b0882SAndroid Build Coastguard Worker           model_->tflite_model_spec()->input_name_index()) {
751*993b0882SAndroid Build Coastguard Worker     const std::unordered_map<std::string, Variant>& model_parameters =
752*993b0882SAndroid Build Coastguard Worker         options.model_parameters;
753*993b0882SAndroid Build Coastguard Worker     for (const TensorflowLiteModelSpec_::InputNameIndexEntry* entry :
754*993b0882SAndroid Build Coastguard Worker          *input_name_index) {
755*993b0882SAndroid Build Coastguard Worker       const std::string param_name = entry->key()->str();
756*993b0882SAndroid Build Coastguard Worker       const int param_index = entry->value();
757*993b0882SAndroid Build Coastguard Worker       const TfLiteType param_type =
758*993b0882SAndroid Build Coastguard Worker           interpreter->tensor(interpreter->inputs()[param_index])->type;
759*993b0882SAndroid Build Coastguard Worker       const auto param_value_it = model_parameters.find(param_name);
760*993b0882SAndroid Build Coastguard Worker       const bool has_value = param_value_it != model_parameters.end();
761*993b0882SAndroid Build Coastguard Worker       switch (param_type) {
762*993b0882SAndroid Build Coastguard Worker         case kTfLiteFloat32:
763*993b0882SAndroid Build Coastguard Worker           if (has_value) {
764*993b0882SAndroid Build Coastguard Worker             SetVectorOrScalarAsModelInput<float>(param_index,
765*993b0882SAndroid Build Coastguard Worker                                                  param_value_it->second,
766*993b0882SAndroid Build Coastguard Worker                                                  interpreter, model_executor_);
767*993b0882SAndroid Build Coastguard Worker           } else {
768*993b0882SAndroid Build Coastguard Worker             model_executor_->SetInput<float>(param_index, kDefaultFloat,
769*993b0882SAndroid Build Coastguard Worker                                              interpreter);
770*993b0882SAndroid Build Coastguard Worker           }
771*993b0882SAndroid Build Coastguard Worker           break;
772*993b0882SAndroid Build Coastguard Worker         case kTfLiteInt32:
773*993b0882SAndroid Build Coastguard Worker           if (has_value) {
774*993b0882SAndroid Build Coastguard Worker             SetVectorOrScalarAsModelInput<int32_t>(
775*993b0882SAndroid Build Coastguard Worker                 param_index, param_value_it->second, interpreter,
776*993b0882SAndroid Build Coastguard Worker                 model_executor_);
777*993b0882SAndroid Build Coastguard Worker           } else {
778*993b0882SAndroid Build Coastguard Worker             model_executor_->SetInput<int32_t>(param_index, kDefaultInt,
779*993b0882SAndroid Build Coastguard Worker                                                interpreter);
780*993b0882SAndroid Build Coastguard Worker           }
781*993b0882SAndroid Build Coastguard Worker           break;
782*993b0882SAndroid Build Coastguard Worker         case kTfLiteInt64:
783*993b0882SAndroid Build Coastguard Worker           model_executor_->SetInput<int64_t>(
784*993b0882SAndroid Build Coastguard Worker               param_index,
785*993b0882SAndroid Build Coastguard Worker               has_value ? param_value_it->second.Value<int64>() : kDefaultInt,
786*993b0882SAndroid Build Coastguard Worker               interpreter);
787*993b0882SAndroid Build Coastguard Worker           break;
788*993b0882SAndroid Build Coastguard Worker         case kTfLiteUInt8:
789*993b0882SAndroid Build Coastguard Worker           model_executor_->SetInput<uint8_t>(
790*993b0882SAndroid Build Coastguard Worker               param_index,
791*993b0882SAndroid Build Coastguard Worker               has_value ? param_value_it->second.Value<uint8>() : kDefaultInt,
792*993b0882SAndroid Build Coastguard Worker               interpreter);
793*993b0882SAndroid Build Coastguard Worker           break;
794*993b0882SAndroid Build Coastguard Worker         case kTfLiteInt8:
795*993b0882SAndroid Build Coastguard Worker           model_executor_->SetInput<int8_t>(
796*993b0882SAndroid Build Coastguard Worker               param_index,
797*993b0882SAndroid Build Coastguard Worker               has_value ? param_value_it->second.Value<int8>() : kDefaultInt,
798*993b0882SAndroid Build Coastguard Worker               interpreter);
799*993b0882SAndroid Build Coastguard Worker           break;
800*993b0882SAndroid Build Coastguard Worker         case kTfLiteBool:
801*993b0882SAndroid Build Coastguard Worker           model_executor_->SetInput<bool>(
802*993b0882SAndroid Build Coastguard Worker               param_index,
803*993b0882SAndroid Build Coastguard Worker               has_value ? param_value_it->second.Value<bool>() : kDefaultBool,
804*993b0882SAndroid Build Coastguard Worker               interpreter);
805*993b0882SAndroid Build Coastguard Worker           break;
806*993b0882SAndroid Build Coastguard Worker         default:
807*993b0882SAndroid Build Coastguard Worker           TC3_LOG(ERROR) << "Unsupported type of additional input parameter: "
808*993b0882SAndroid Build Coastguard Worker                          << param_name;
809*993b0882SAndroid Build Coastguard Worker       }
810*993b0882SAndroid Build Coastguard Worker     }
811*993b0882SAndroid Build Coastguard Worker   }
812*993b0882SAndroid Build Coastguard Worker   return true;
813*993b0882SAndroid Build Coastguard Worker }
814*993b0882SAndroid Build Coastguard Worker 
PopulateTextReplies(const tflite::Interpreter * interpreter,int suggestion_index,int score_index,const std::string & type,float priority_score,const absl::flat_hash_set<std::string> & blocklist,const absl::flat_hash_map<std::string,std::vector<std::string>> & concept_mappings,ActionsSuggestionsResponse * response) const815*993b0882SAndroid Build Coastguard Worker void ActionsSuggestions::PopulateTextReplies(
816*993b0882SAndroid Build Coastguard Worker     const tflite::Interpreter* interpreter, int suggestion_index,
817*993b0882SAndroid Build Coastguard Worker     int score_index, const std::string& type, float priority_score,
818*993b0882SAndroid Build Coastguard Worker     const absl::flat_hash_set<std::string>& blocklist,
819*993b0882SAndroid Build Coastguard Worker     const absl::flat_hash_map<std::string, std::vector<std::string>>&
820*993b0882SAndroid Build Coastguard Worker         concept_mappings,
821*993b0882SAndroid Build Coastguard Worker     ActionsSuggestionsResponse* response) const {
822*993b0882SAndroid Build Coastguard Worker   const std::vector<tflite::StringRef> replies =
823*993b0882SAndroid Build Coastguard Worker       model_executor_->Output<tflite::StringRef>(suggestion_index, interpreter);
824*993b0882SAndroid Build Coastguard Worker   const TensorView<float> scores =
825*993b0882SAndroid Build Coastguard Worker       model_executor_->OutputView<float>(score_index, interpreter);
826*993b0882SAndroid Build Coastguard Worker 
827*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < replies.size(); i++) {
828*993b0882SAndroid Build Coastguard Worker     if (replies[i].len == 0) {
829*993b0882SAndroid Build Coastguard Worker       continue;
830*993b0882SAndroid Build Coastguard Worker     }
831*993b0882SAndroid Build Coastguard Worker     const float score = scores.data()[i];
832*993b0882SAndroid Build Coastguard Worker     if (score < preconditions_.min_reply_score_threshold) {
833*993b0882SAndroid Build Coastguard Worker       continue;
834*993b0882SAndroid Build Coastguard Worker     }
835*993b0882SAndroid Build Coastguard Worker     std::string response_text(replies[i].str, replies[i].len);
836*993b0882SAndroid Build Coastguard Worker     if (blocklist.contains(response_text)) {
837*993b0882SAndroid Build Coastguard Worker       continue;
838*993b0882SAndroid Build Coastguard Worker     }
839*993b0882SAndroid Build Coastguard Worker     if (concept_mappings.contains(response_text)) {
840*993b0882SAndroid Build Coastguard Worker       const int candidates_size = concept_mappings.at(response_text).size();
841*993b0882SAndroid Build Coastguard Worker       const int candidate_index = absl::Uniform<int>(
842*993b0882SAndroid Build Coastguard Worker           absl::IntervalOpenOpen, bit_gen_, 0, candidates_size);
843*993b0882SAndroid Build Coastguard Worker       response_text = concept_mappings.at(response_text)[candidate_index];
844*993b0882SAndroid Build Coastguard Worker     }
845*993b0882SAndroid Build Coastguard Worker 
846*993b0882SAndroid Build Coastguard Worker     response->actions.push_back({response_text, type, score, priority_score});
847*993b0882SAndroid Build Coastguard Worker   }
848*993b0882SAndroid Build Coastguard Worker }
849*993b0882SAndroid Build Coastguard Worker 
FillSuggestionFromSpecWithEntityData(const ActionSuggestionSpec * spec,ActionSuggestion * suggestion) const850*993b0882SAndroid Build Coastguard Worker void ActionsSuggestions::FillSuggestionFromSpecWithEntityData(
851*993b0882SAndroid Build Coastguard Worker     const ActionSuggestionSpec* spec, ActionSuggestion* suggestion) const {
852*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<MutableFlatbuffer> entity_data =
853*993b0882SAndroid Build Coastguard Worker       entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
854*993b0882SAndroid Build Coastguard Worker                                       : nullptr;
855*993b0882SAndroid Build Coastguard Worker   FillSuggestionFromSpec(spec, entity_data.get(), suggestion);
856*993b0882SAndroid Build Coastguard Worker }
857*993b0882SAndroid Build Coastguard Worker 
PopulateIntentTriggering(const tflite::Interpreter * interpreter,int suggestion_index,int score_index,const ActionSuggestionSpec * task_spec,ActionsSuggestionsResponse * response) const858*993b0882SAndroid Build Coastguard Worker void ActionsSuggestions::PopulateIntentTriggering(
859*993b0882SAndroid Build Coastguard Worker     const tflite::Interpreter* interpreter, int suggestion_index,
860*993b0882SAndroid Build Coastguard Worker     int score_index, const ActionSuggestionSpec* task_spec,
861*993b0882SAndroid Build Coastguard Worker     ActionsSuggestionsResponse* response) const {
862*993b0882SAndroid Build Coastguard Worker   if (!task_spec || task_spec->type()->size() == 0) {
863*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR)
864*993b0882SAndroid Build Coastguard Worker         << "Task type for intent (action) triggering cannot be empty!";
865*993b0882SAndroid Build Coastguard Worker     return;
866*993b0882SAndroid Build Coastguard Worker   }
867*993b0882SAndroid Build Coastguard Worker   const TensorView<bool> intent_prediction =
868*993b0882SAndroid Build Coastguard Worker       model_executor_->OutputView<bool>(suggestion_index, interpreter);
869*993b0882SAndroid Build Coastguard Worker   const TensorView<float> intent_scores =
870*993b0882SAndroid Build Coastguard Worker       model_executor_->OutputView<float>(score_index, interpreter);
871*993b0882SAndroid Build Coastguard Worker   // Two result corresponding to binary triggering case.
872*993b0882SAndroid Build Coastguard Worker   TC3_CHECK_EQ(intent_prediction.size(), 2);
873*993b0882SAndroid Build Coastguard Worker   TC3_CHECK_EQ(intent_scores.size(), 2);
874*993b0882SAndroid Build Coastguard Worker   // We rely on in-graph thresholding logic so at this point the results
875*993b0882SAndroid Build Coastguard Worker   // have been ranked properly according to threshold.
876*993b0882SAndroid Build Coastguard Worker   const bool triggering = intent_prediction.data()[0];
877*993b0882SAndroid Build Coastguard Worker   const float trigger_score = intent_scores.data()[0];
878*993b0882SAndroid Build Coastguard Worker 
879*993b0882SAndroid Build Coastguard Worker   if (triggering) {
880*993b0882SAndroid Build Coastguard Worker     ActionSuggestion suggestion;
881*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<MutableFlatbuffer> entity_data =
882*993b0882SAndroid Build Coastguard Worker         entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
883*993b0882SAndroid Build Coastguard Worker                                         : nullptr;
884*993b0882SAndroid Build Coastguard Worker     FillSuggestionFromSpecWithEntityData(task_spec, &suggestion);
885*993b0882SAndroid Build Coastguard Worker     suggestion.score = trigger_score;
886*993b0882SAndroid Build Coastguard Worker     response->actions.push_back(std::move(suggestion));
887*993b0882SAndroid Build Coastguard Worker   }
888*993b0882SAndroid Build Coastguard Worker }
889*993b0882SAndroid Build Coastguard Worker 
ReadModelOutput(tflite::Interpreter * interpreter,const ActionSuggestionOptions & options,ActionsSuggestionsResponse * response) const890*993b0882SAndroid Build Coastguard Worker bool ActionsSuggestions::ReadModelOutput(
891*993b0882SAndroid Build Coastguard Worker     tflite::Interpreter* interpreter, const ActionSuggestionOptions& options,
892*993b0882SAndroid Build Coastguard Worker     ActionsSuggestionsResponse* response) const {
893*993b0882SAndroid Build Coastguard Worker   // Read sensitivity and triggering score predictions.
894*993b0882SAndroid Build Coastguard Worker   if (model_->tflite_model_spec()->output_triggering_score() >= 0) {
895*993b0882SAndroid Build Coastguard Worker     const TensorView<float> triggering_score =
896*993b0882SAndroid Build Coastguard Worker         model_executor_->OutputView<float>(
897*993b0882SAndroid Build Coastguard Worker             model_->tflite_model_spec()->output_triggering_score(),
898*993b0882SAndroid Build Coastguard Worker             interpreter);
899*993b0882SAndroid Build Coastguard Worker     if (!triggering_score.is_valid() || triggering_score.size() == 0) {
900*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not compute triggering score.";
901*993b0882SAndroid Build Coastguard Worker       return false;
902*993b0882SAndroid Build Coastguard Worker     }
903*993b0882SAndroid Build Coastguard Worker     response->triggering_score = triggering_score.data()[0];
904*993b0882SAndroid Build Coastguard Worker     response->output_filtered_min_triggering_score =
905*993b0882SAndroid Build Coastguard Worker         (response->triggering_score <
906*993b0882SAndroid Build Coastguard Worker          preconditions_.min_smart_reply_triggering_score);
907*993b0882SAndroid Build Coastguard Worker   }
908*993b0882SAndroid Build Coastguard Worker   if (model_->tflite_model_spec()->output_sensitive_topic_score() >= 0) {
909*993b0882SAndroid Build Coastguard Worker     const TensorView<float> sensitive_topic_score =
910*993b0882SAndroid Build Coastguard Worker         model_executor_->OutputView<float>(
911*993b0882SAndroid Build Coastguard Worker             model_->tflite_model_spec()->output_sensitive_topic_score(),
912*993b0882SAndroid Build Coastguard Worker             interpreter);
913*993b0882SAndroid Build Coastguard Worker     if (!sensitive_topic_score.is_valid() ||
914*993b0882SAndroid Build Coastguard Worker         sensitive_topic_score.dim(0) != 1) {
915*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not compute sensitive topic score.";
916*993b0882SAndroid Build Coastguard Worker       return false;
917*993b0882SAndroid Build Coastguard Worker     }
918*993b0882SAndroid Build Coastguard Worker     response->sensitivity_score = sensitive_topic_score.data()[0];
919*993b0882SAndroid Build Coastguard Worker     response->is_sensitive = (response->sensitivity_score >
920*993b0882SAndroid Build Coastguard Worker                               preconditions_.max_sensitive_topic_score);
921*993b0882SAndroid Build Coastguard Worker   }
922*993b0882SAndroid Build Coastguard Worker 
923*993b0882SAndroid Build Coastguard Worker   // Suppress model outputs.
924*993b0882SAndroid Build Coastguard Worker   if (response->is_sensitive) {
925*993b0882SAndroid Build Coastguard Worker     return true;
926*993b0882SAndroid Build Coastguard Worker   }
927*993b0882SAndroid Build Coastguard Worker 
928*993b0882SAndroid Build Coastguard Worker   // Read smart reply predictions.
929*993b0882SAndroid Build Coastguard Worker   if (!response->output_filtered_min_triggering_score &&
930*993b0882SAndroid Build Coastguard Worker       model_->tflite_model_spec()->output_replies() >= 0) {
931*993b0882SAndroid Build Coastguard Worker     absl::flat_hash_set<std::string> empty_blocklist;
932*993b0882SAndroid Build Coastguard Worker     PopulateTextReplies(
933*993b0882SAndroid Build Coastguard Worker         interpreter, model_->tflite_model_spec()->output_replies(),
934*993b0882SAndroid Build Coastguard Worker         model_->tflite_model_spec()->output_replies_scores(),
935*993b0882SAndroid Build Coastguard Worker         model_->smart_reply_action_type()->str(),
936*993b0882SAndroid Build Coastguard Worker         /* priority_score */ 0.0, empty_blocklist, {}, response);
937*993b0882SAndroid Build Coastguard Worker   }
938*993b0882SAndroid Build Coastguard Worker 
939*993b0882SAndroid Build Coastguard Worker   // Read actions suggestions.
940*993b0882SAndroid Build Coastguard Worker   if (model_->tflite_model_spec()->output_actions_scores() >= 0) {
941*993b0882SAndroid Build Coastguard Worker     const TensorView<float> actions_scores = model_executor_->OutputView<float>(
942*993b0882SAndroid Build Coastguard Worker         model_->tflite_model_spec()->output_actions_scores(), interpreter);
943*993b0882SAndroid Build Coastguard Worker     for (int i = 0; i < model_->action_type()->size(); i++) {
944*993b0882SAndroid Build Coastguard Worker       const ActionTypeOptions* action_type = model_->action_type()->Get(i);
945*993b0882SAndroid Build Coastguard Worker       // Skip disabled action classes, such as the default other category.
946*993b0882SAndroid Build Coastguard Worker       if (!action_type->enabled()) {
947*993b0882SAndroid Build Coastguard Worker         continue;
948*993b0882SAndroid Build Coastguard Worker       }
949*993b0882SAndroid Build Coastguard Worker       const float score = actions_scores.data()[i];
950*993b0882SAndroid Build Coastguard Worker       if (score < action_type->min_triggering_score()) {
951*993b0882SAndroid Build Coastguard Worker         continue;
952*993b0882SAndroid Build Coastguard Worker       }
953*993b0882SAndroid Build Coastguard Worker 
954*993b0882SAndroid Build Coastguard Worker       // Create action from model output.
955*993b0882SAndroid Build Coastguard Worker       ActionSuggestion suggestion;
956*993b0882SAndroid Build Coastguard Worker       suggestion.type = action_type->name()->str();
957*993b0882SAndroid Build Coastguard Worker       std::unique_ptr<MutableFlatbuffer> entity_data =
958*993b0882SAndroid Build Coastguard Worker           entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
959*993b0882SAndroid Build Coastguard Worker                                           : nullptr;
960*993b0882SAndroid Build Coastguard Worker       FillSuggestionFromSpecWithEntityData(action_type->action(), &suggestion);
961*993b0882SAndroid Build Coastguard Worker       suggestion.score = score;
962*993b0882SAndroid Build Coastguard Worker       response->actions.push_back(std::move(suggestion));
963*993b0882SAndroid Build Coastguard Worker     }
964*993b0882SAndroid Build Coastguard Worker   }
965*993b0882SAndroid Build Coastguard Worker 
966*993b0882SAndroid Build Coastguard Worker   // Read multi-task predictions and construct the result properly.
967*993b0882SAndroid Build Coastguard Worker   if (const auto* prediction_metadata =
968*993b0882SAndroid Build Coastguard Worker           model_->tflite_model_spec()->prediction_metadata()) {
969*993b0882SAndroid Build Coastguard Worker     for (const PredictionMetadata* metadata : *prediction_metadata) {
970*993b0882SAndroid Build Coastguard Worker       const ActionSuggestionSpec* task_spec = metadata->task_spec();
971*993b0882SAndroid Build Coastguard Worker       const int suggestions_index = metadata->output_suggestions();
972*993b0882SAndroid Build Coastguard Worker       const int suggestions_scores_index =
973*993b0882SAndroid Build Coastguard Worker           metadata->output_suggestions_scores();
974*993b0882SAndroid Build Coastguard Worker       absl::flat_hash_set<std::string> response_text_blocklist;
975*993b0882SAndroid Build Coastguard Worker       absl::flat_hash_map<std::string, std::vector<std::string>>
976*993b0882SAndroid Build Coastguard Worker           concept_mappings;
977*993b0882SAndroid Build Coastguard Worker       switch (metadata->prediction_type()) {
978*993b0882SAndroid Build Coastguard Worker         case PredictionType_NEXT_MESSAGE_PREDICTION:
979*993b0882SAndroid Build Coastguard Worker           if (!task_spec || task_spec->type()->size() == 0) {
980*993b0882SAndroid Build Coastguard Worker             TC3_LOG(WARNING) << "Task type not provided, use default "
981*993b0882SAndroid Build Coastguard Worker                                 "smart_reply_action_type!";
982*993b0882SAndroid Build Coastguard Worker           }
983*993b0882SAndroid Build Coastguard Worker           if (task_spec) {
984*993b0882SAndroid Build Coastguard Worker             if (task_spec->response_text_blocklist()) {
985*993b0882SAndroid Build Coastguard Worker               for (const auto& val : *task_spec->response_text_blocklist()) {
986*993b0882SAndroid Build Coastguard Worker                 response_text_blocklist.insert(val->str());
987*993b0882SAndroid Build Coastguard Worker               }
988*993b0882SAndroid Build Coastguard Worker             }
989*993b0882SAndroid Build Coastguard Worker             if (task_spec->concept_mappings()) {
990*993b0882SAndroid Build Coastguard Worker               for (const auto& concept : *task_spec->concept_mappings()) {
991*993b0882SAndroid Build Coastguard Worker                 std::vector<std::string> candidates;
992*993b0882SAndroid Build Coastguard Worker                 for (const auto& candidate : *concept->candidates()) {
993*993b0882SAndroid Build Coastguard Worker                   candidates.push_back(candidate->str());
994*993b0882SAndroid Build Coastguard Worker                 }
995*993b0882SAndroid Build Coastguard Worker                 concept_mappings[concept->concept_name()->str()] = candidates;
996*993b0882SAndroid Build Coastguard Worker               }
997*993b0882SAndroid Build Coastguard Worker             }
998*993b0882SAndroid Build Coastguard Worker           }
999*993b0882SAndroid Build Coastguard Worker           PopulateTextReplies(
1000*993b0882SAndroid Build Coastguard Worker               interpreter, suggestions_index, suggestions_scores_index,
1001*993b0882SAndroid Build Coastguard Worker               task_spec ? task_spec->type()->str()
1002*993b0882SAndroid Build Coastguard Worker                         : model_->smart_reply_action_type()->str(),
1003*993b0882SAndroid Build Coastguard Worker               task_spec ? task_spec->priority_score() : 0.0,
1004*993b0882SAndroid Build Coastguard Worker               response_text_blocklist, concept_mappings, response);
1005*993b0882SAndroid Build Coastguard Worker           break;
1006*993b0882SAndroid Build Coastguard Worker         case PredictionType_INTENT_TRIGGERING:
1007*993b0882SAndroid Build Coastguard Worker           PopulateIntentTriggering(interpreter, suggestions_index,
1008*993b0882SAndroid Build Coastguard Worker                                    suggestions_scores_index, task_spec,
1009*993b0882SAndroid Build Coastguard Worker                                    response);
1010*993b0882SAndroid Build Coastguard Worker           break;
1011*993b0882SAndroid Build Coastguard Worker         default:
1012*993b0882SAndroid Build Coastguard Worker           TC3_LOG(ERROR) << "Unsupported prediction type!";
1013*993b0882SAndroid Build Coastguard Worker           return false;
1014*993b0882SAndroid Build Coastguard Worker       }
1015*993b0882SAndroid Build Coastguard Worker     }
1016*993b0882SAndroid Build Coastguard Worker   }
1017*993b0882SAndroid Build Coastguard Worker 
1018*993b0882SAndroid Build Coastguard Worker   return true;
1019*993b0882SAndroid Build Coastguard Worker }
1020*993b0882SAndroid Build Coastguard Worker 
SuggestActionsFromModel(const Conversation & conversation,const int num_messages,const ActionSuggestionOptions & options,ActionsSuggestionsResponse * response,std::unique_ptr<tflite::Interpreter> * interpreter) const1021*993b0882SAndroid Build Coastguard Worker bool ActionsSuggestions::SuggestActionsFromModel(
1022*993b0882SAndroid Build Coastguard Worker     const Conversation& conversation, const int num_messages,
1023*993b0882SAndroid Build Coastguard Worker     const ActionSuggestionOptions& options,
1024*993b0882SAndroid Build Coastguard Worker     ActionsSuggestionsResponse* response,
1025*993b0882SAndroid Build Coastguard Worker     std::unique_ptr<tflite::Interpreter>* interpreter) const {
1026*993b0882SAndroid Build Coastguard Worker   TC3_CHECK_LE(num_messages, conversation.messages.size());
1027*993b0882SAndroid Build Coastguard Worker 
1028*993b0882SAndroid Build Coastguard Worker   if (sensitive_model_ != nullptr &&
1029*993b0882SAndroid Build Coastguard Worker       sensitive_model_->EvalConversation(conversation, num_messages).first) {
1030*993b0882SAndroid Build Coastguard Worker     response->is_sensitive = true;
1031*993b0882SAndroid Build Coastguard Worker     return true;
1032*993b0882SAndroid Build Coastguard Worker   }
1033*993b0882SAndroid Build Coastguard Worker 
1034*993b0882SAndroid Build Coastguard Worker   if (!model_executor_) {
1035*993b0882SAndroid Build Coastguard Worker     return true;
1036*993b0882SAndroid Build Coastguard Worker   }
1037*993b0882SAndroid Build Coastguard Worker   *interpreter = model_executor_->CreateInterpreter();
1038*993b0882SAndroid Build Coastguard Worker 
1039*993b0882SAndroid Build Coastguard Worker   if (!*interpreter) {
1040*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not build TensorFlow Lite interpreter for the "
1041*993b0882SAndroid Build Coastguard Worker                       "actions suggestions model.";
1042*993b0882SAndroid Build Coastguard Worker     return false;
1043*993b0882SAndroid Build Coastguard Worker   }
1044*993b0882SAndroid Build Coastguard Worker 
1045*993b0882SAndroid Build Coastguard Worker   std::vector<std::string> context;
1046*993b0882SAndroid Build Coastguard Worker   std::vector<int> user_ids;
1047*993b0882SAndroid Build Coastguard Worker   std::vector<float> time_diffs;
1048*993b0882SAndroid Build Coastguard Worker   context.reserve(num_messages);
1049*993b0882SAndroid Build Coastguard Worker   user_ids.reserve(num_messages);
1050*993b0882SAndroid Build Coastguard Worker   time_diffs.reserve(num_messages);
1051*993b0882SAndroid Build Coastguard Worker 
1052*993b0882SAndroid Build Coastguard Worker   // Gather last `num_messages` messages from the conversation.
1053*993b0882SAndroid Build Coastguard Worker   int64 last_message_reference_time_ms_utc = 0;
1054*993b0882SAndroid Build Coastguard Worker   const float second_in_ms = 1000;
1055*993b0882SAndroid Build Coastguard Worker   for (int i = conversation.messages.size() - num_messages;
1056*993b0882SAndroid Build Coastguard Worker        i < conversation.messages.size(); i++) {
1057*993b0882SAndroid Build Coastguard Worker     const ConversationMessage& message = conversation.messages[i];
1058*993b0882SAndroid Build Coastguard Worker     context.push_back(message.text);
1059*993b0882SAndroid Build Coastguard Worker     user_ids.push_back(message.user_id);
1060*993b0882SAndroid Build Coastguard Worker 
1061*993b0882SAndroid Build Coastguard Worker     float time_diff_secs = 0;
1062*993b0882SAndroid Build Coastguard Worker     if (message.reference_time_ms_utc != 0 &&
1063*993b0882SAndroid Build Coastguard Worker         last_message_reference_time_ms_utc != 0) {
1064*993b0882SAndroid Build Coastguard Worker       time_diff_secs = std::max(0.0f, (message.reference_time_ms_utc -
1065*993b0882SAndroid Build Coastguard Worker                                        last_message_reference_time_ms_utc) /
1066*993b0882SAndroid Build Coastguard Worker                                           second_in_ms);
1067*993b0882SAndroid Build Coastguard Worker     }
1068*993b0882SAndroid Build Coastguard Worker     if (message.reference_time_ms_utc != 0) {
1069*993b0882SAndroid Build Coastguard Worker       last_message_reference_time_ms_utc = message.reference_time_ms_utc;
1070*993b0882SAndroid Build Coastguard Worker     }
1071*993b0882SAndroid Build Coastguard Worker     time_diffs.push_back(time_diff_secs);
1072*993b0882SAndroid Build Coastguard Worker   }
1073*993b0882SAndroid Build Coastguard Worker 
1074*993b0882SAndroid Build Coastguard Worker   if (!SetupModelInput(context, user_ids, time_diffs,
1075*993b0882SAndroid Build Coastguard Worker                        /*num_suggestions=*/model_->num_smart_replies(), options,
1076*993b0882SAndroid Build Coastguard Worker                        interpreter->get())) {
1077*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Failed to setup input for TensorFlow Lite model.";
1078*993b0882SAndroid Build Coastguard Worker     return false;
1079*993b0882SAndroid Build Coastguard Worker   }
1080*993b0882SAndroid Build Coastguard Worker 
1081*993b0882SAndroid Build Coastguard Worker   if ((*interpreter)->Invoke() != kTfLiteOk) {
1082*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Failed to invoke TensorFlow Lite interpreter.";
1083*993b0882SAndroid Build Coastguard Worker     return false;
1084*993b0882SAndroid Build Coastguard Worker   }
1085*993b0882SAndroid Build Coastguard Worker 
1086*993b0882SAndroid Build Coastguard Worker   return ReadModelOutput(interpreter->get(), options, response);
1087*993b0882SAndroid Build Coastguard Worker }
1088*993b0882SAndroid Build Coastguard Worker 
SuggestActionsFromConversationIntentDetection(const Conversation & conversation,const ActionSuggestionOptions & options,std::vector<ActionSuggestion> * actions) const1089*993b0882SAndroid Build Coastguard Worker Status ActionsSuggestions::SuggestActionsFromConversationIntentDetection(
1090*993b0882SAndroid Build Coastguard Worker     const Conversation& conversation, const ActionSuggestionOptions& options,
1091*993b0882SAndroid Build Coastguard Worker     std::vector<ActionSuggestion>* actions) const {
1092*993b0882SAndroid Build Coastguard Worker   TC3_ASSIGN_OR_RETURN(
1093*993b0882SAndroid Build Coastguard Worker       std::vector<ActionSuggestion> new_actions,
1094*993b0882SAndroid Build Coastguard Worker       conversation_intent_detection_->SuggestActions(conversation, options));
1095*993b0882SAndroid Build Coastguard Worker   for (auto& action : new_actions) {
1096*993b0882SAndroid Build Coastguard Worker     actions->push_back(std::move(action));
1097*993b0882SAndroid Build Coastguard Worker   }
1098*993b0882SAndroid Build Coastguard Worker   return Status::OK;
1099*993b0882SAndroid Build Coastguard Worker }
1100*993b0882SAndroid Build Coastguard Worker 
AnnotationOptionsForMessage(const ConversationMessage & message) const1101*993b0882SAndroid Build Coastguard Worker AnnotationOptions ActionsSuggestions::AnnotationOptionsForMessage(
1102*993b0882SAndroid Build Coastguard Worker     const ConversationMessage& message) const {
1103*993b0882SAndroid Build Coastguard Worker   AnnotationOptions options;
1104*993b0882SAndroid Build Coastguard Worker   options.detected_text_language_tags = message.detected_text_language_tags;
1105*993b0882SAndroid Build Coastguard Worker   options.reference_time_ms_utc = message.reference_time_ms_utc;
1106*993b0882SAndroid Build Coastguard Worker   options.reference_timezone = message.reference_timezone;
1107*993b0882SAndroid Build Coastguard Worker   options.annotation_usecase =
1108*993b0882SAndroid Build Coastguard Worker       model_->annotation_actions_spec()->annotation_usecase();
1109*993b0882SAndroid Build Coastguard Worker   options.is_serialized_entity_data_enabled =
1110*993b0882SAndroid Build Coastguard Worker       model_->annotation_actions_spec()->is_serialized_entity_data_enabled();
1111*993b0882SAndroid Build Coastguard Worker   options.entity_types = annotation_entity_types_;
1112*993b0882SAndroid Build Coastguard Worker   return options;
1113*993b0882SAndroid Build Coastguard Worker }
1114*993b0882SAndroid Build Coastguard Worker 
1115*993b0882SAndroid Build Coastguard Worker // Run annotator on the messages of a conversation.
AnnotateConversation(const Conversation & conversation,const Annotator * annotator) const1116*993b0882SAndroid Build Coastguard Worker Conversation ActionsSuggestions::AnnotateConversation(
1117*993b0882SAndroid Build Coastguard Worker     const Conversation& conversation, const Annotator* annotator) const {
1118*993b0882SAndroid Build Coastguard Worker   if (annotator == nullptr) {
1119*993b0882SAndroid Build Coastguard Worker     return conversation;
1120*993b0882SAndroid Build Coastguard Worker   }
1121*993b0882SAndroid Build Coastguard Worker   const int num_messages_grammar =
1122*993b0882SAndroid Build Coastguard Worker       ((model_->rules() && model_->rules()->grammar_rules() &&
1123*993b0882SAndroid Build Coastguard Worker         model_->rules()
1124*993b0882SAndroid Build Coastguard Worker             ->grammar_rules()
1125*993b0882SAndroid Build Coastguard Worker             ->rules()
1126*993b0882SAndroid Build Coastguard Worker             ->nonterminals()
1127*993b0882SAndroid Build Coastguard Worker             ->annotation_nt())
1128*993b0882SAndroid Build Coastguard Worker            ? 1
1129*993b0882SAndroid Build Coastguard Worker            : 0);
1130*993b0882SAndroid Build Coastguard Worker   const int num_messages_mapping =
1131*993b0882SAndroid Build Coastguard Worker       (model_->annotation_actions_spec()
1132*993b0882SAndroid Build Coastguard Worker            ? std::max(model_->annotation_actions_spec()
1133*993b0882SAndroid Build Coastguard Worker                           ->max_history_from_any_person(),
1134*993b0882SAndroid Build Coastguard Worker                       model_->annotation_actions_spec()
1135*993b0882SAndroid Build Coastguard Worker                           ->max_history_from_last_person())
1136*993b0882SAndroid Build Coastguard Worker            : 0);
1137*993b0882SAndroid Build Coastguard Worker   const int num_messages = std::max(num_messages_grammar, num_messages_mapping);
1138*993b0882SAndroid Build Coastguard Worker   if (num_messages == 0) {
1139*993b0882SAndroid Build Coastguard Worker     // No annotations are used.
1140*993b0882SAndroid Build Coastguard Worker     return conversation;
1141*993b0882SAndroid Build Coastguard Worker   }
1142*993b0882SAndroid Build Coastguard Worker   Conversation annotated_conversation = conversation;
1143*993b0882SAndroid Build Coastguard Worker   for (int i = 0, message_index = annotated_conversation.messages.size() - 1;
1144*993b0882SAndroid Build Coastguard Worker        i < num_messages && message_index >= 0; i++, message_index--) {
1145*993b0882SAndroid Build Coastguard Worker     ConversationMessage* message =
1146*993b0882SAndroid Build Coastguard Worker         &annotated_conversation.messages[message_index];
1147*993b0882SAndroid Build Coastguard Worker     if (message->annotations.empty()) {
1148*993b0882SAndroid Build Coastguard Worker       message->annotations = annotator->Annotate(
1149*993b0882SAndroid Build Coastguard Worker           message->text, AnnotationOptionsForMessage(*message));
1150*993b0882SAndroid Build Coastguard Worker       ConvertDatetimeToTime(&message->annotations);
1151*993b0882SAndroid Build Coastguard Worker     }
1152*993b0882SAndroid Build Coastguard Worker   }
1153*993b0882SAndroid Build Coastguard Worker   return annotated_conversation;
1154*993b0882SAndroid Build Coastguard Worker }
1155*993b0882SAndroid Build Coastguard Worker 
SuggestActionsFromAnnotations(const Conversation & conversation,std::vector<ActionSuggestion> * actions) const1156*993b0882SAndroid Build Coastguard Worker void ActionsSuggestions::SuggestActionsFromAnnotations(
1157*993b0882SAndroid Build Coastguard Worker     const Conversation& conversation,
1158*993b0882SAndroid Build Coastguard Worker     std::vector<ActionSuggestion>* actions) const {
1159*993b0882SAndroid Build Coastguard Worker   if (model_->annotation_actions_spec() == nullptr ||
1160*993b0882SAndroid Build Coastguard Worker       model_->annotation_actions_spec()->annotation_mapping() == nullptr ||
1161*993b0882SAndroid Build Coastguard Worker       model_->annotation_actions_spec()->annotation_mapping()->size() == 0) {
1162*993b0882SAndroid Build Coastguard Worker     return;
1163*993b0882SAndroid Build Coastguard Worker   }
1164*993b0882SAndroid Build Coastguard Worker 
1165*993b0882SAndroid Build Coastguard Worker   // Create actions based on the annotations.
1166*993b0882SAndroid Build Coastguard Worker   const int max_from_any_person =
1167*993b0882SAndroid Build Coastguard Worker       model_->annotation_actions_spec()->max_history_from_any_person();
1168*993b0882SAndroid Build Coastguard Worker   const int max_from_last_person =
1169*993b0882SAndroid Build Coastguard Worker       model_->annotation_actions_spec()->max_history_from_last_person();
1170*993b0882SAndroid Build Coastguard Worker   const int last_person = conversation.messages.back().user_id;
1171*993b0882SAndroid Build Coastguard Worker 
1172*993b0882SAndroid Build Coastguard Worker   int num_messages_last_person = 0;
1173*993b0882SAndroid Build Coastguard Worker   int num_messages_any_person = 0;
1174*993b0882SAndroid Build Coastguard Worker   bool all_from_last_person = true;
1175*993b0882SAndroid Build Coastguard Worker   for (int message_index = conversation.messages.size() - 1; message_index >= 0;
1176*993b0882SAndroid Build Coastguard Worker        message_index--) {
1177*993b0882SAndroid Build Coastguard Worker     const ConversationMessage& message = conversation.messages[message_index];
1178*993b0882SAndroid Build Coastguard Worker     std::vector<AnnotatedSpan> annotations = message.annotations;
1179*993b0882SAndroid Build Coastguard Worker 
1180*993b0882SAndroid Build Coastguard Worker     // Update how many messages we have processed from the last person in the
1181*993b0882SAndroid Build Coastguard Worker     // conversation and from any person in the conversation.
1182*993b0882SAndroid Build Coastguard Worker     num_messages_any_person++;
1183*993b0882SAndroid Build Coastguard Worker     if (all_from_last_person && message.user_id == last_person) {
1184*993b0882SAndroid Build Coastguard Worker       num_messages_last_person++;
1185*993b0882SAndroid Build Coastguard Worker     } else {
1186*993b0882SAndroid Build Coastguard Worker       all_from_last_person = false;
1187*993b0882SAndroid Build Coastguard Worker     }
1188*993b0882SAndroid Build Coastguard Worker 
1189*993b0882SAndroid Build Coastguard Worker     if (num_messages_any_person > max_from_any_person &&
1190*993b0882SAndroid Build Coastguard Worker         (!all_from_last_person ||
1191*993b0882SAndroid Build Coastguard Worker          num_messages_last_person > max_from_last_person)) {
1192*993b0882SAndroid Build Coastguard Worker       break;
1193*993b0882SAndroid Build Coastguard Worker     }
1194*993b0882SAndroid Build Coastguard Worker 
1195*993b0882SAndroid Build Coastguard Worker     if (message.user_id == kLocalUserId) {
1196*993b0882SAndroid Build Coastguard Worker       if (model_->annotation_actions_spec()->only_until_last_sent()) {
1197*993b0882SAndroid Build Coastguard Worker         break;
1198*993b0882SAndroid Build Coastguard Worker       }
1199*993b0882SAndroid Build Coastguard Worker       if (!model_->annotation_actions_spec()->include_local_user_messages()) {
1200*993b0882SAndroid Build Coastguard Worker         continue;
1201*993b0882SAndroid Build Coastguard Worker       }
1202*993b0882SAndroid Build Coastguard Worker     }
1203*993b0882SAndroid Build Coastguard Worker 
1204*993b0882SAndroid Build Coastguard Worker     std::vector<ActionSuggestionAnnotation> action_annotations;
1205*993b0882SAndroid Build Coastguard Worker     action_annotations.reserve(annotations.size());
1206*993b0882SAndroid Build Coastguard Worker     for (const AnnotatedSpan& annotation : annotations) {
1207*993b0882SAndroid Build Coastguard Worker       if (annotation.classification.empty()) {
1208*993b0882SAndroid Build Coastguard Worker         continue;
1209*993b0882SAndroid Build Coastguard Worker       }
1210*993b0882SAndroid Build Coastguard Worker 
1211*993b0882SAndroid Build Coastguard Worker       const ClassificationResult& classification_result =
1212*993b0882SAndroid Build Coastguard Worker           annotation.classification[0];
1213*993b0882SAndroid Build Coastguard Worker 
1214*993b0882SAndroid Build Coastguard Worker       ActionSuggestionAnnotation action_annotation;
1215*993b0882SAndroid Build Coastguard Worker       action_annotation.span = {
1216*993b0882SAndroid Build Coastguard Worker           message_index, annotation.span,
1217*993b0882SAndroid Build Coastguard Worker           UTF8ToUnicodeText(message.text, /*do_copy=*/false)
1218*993b0882SAndroid Build Coastguard Worker               .UTF8Substring(annotation.span.first, annotation.span.second)};
1219*993b0882SAndroid Build Coastguard Worker       action_annotation.entity = classification_result;
1220*993b0882SAndroid Build Coastguard Worker       action_annotation.name = classification_result.collection;
1221*993b0882SAndroid Build Coastguard Worker       action_annotations.push_back(std::move(action_annotation));
1222*993b0882SAndroid Build Coastguard Worker     }
1223*993b0882SAndroid Build Coastguard Worker 
1224*993b0882SAndroid Build Coastguard Worker     if (model_->annotation_actions_spec()->deduplicate_annotations()) {
1225*993b0882SAndroid Build Coastguard Worker       // Create actions only for deduplicated annotations.
1226*993b0882SAndroid Build Coastguard Worker       for (const int annotation_id :
1227*993b0882SAndroid Build Coastguard Worker            DeduplicateAnnotations(action_annotations)) {
1228*993b0882SAndroid Build Coastguard Worker         SuggestActionsFromAnnotation(
1229*993b0882SAndroid Build Coastguard Worker             message_index, action_annotations[annotation_id], actions);
1230*993b0882SAndroid Build Coastguard Worker       }
1231*993b0882SAndroid Build Coastguard Worker     } else {
1232*993b0882SAndroid Build Coastguard Worker       // Create actions for all annotations.
1233*993b0882SAndroid Build Coastguard Worker       for (const ActionSuggestionAnnotation& annotation : action_annotations) {
1234*993b0882SAndroid Build Coastguard Worker         SuggestActionsFromAnnotation(message_index, annotation, actions);
1235*993b0882SAndroid Build Coastguard Worker       }
1236*993b0882SAndroid Build Coastguard Worker     }
1237*993b0882SAndroid Build Coastguard Worker   }
1238*993b0882SAndroid Build Coastguard Worker }
1239*993b0882SAndroid Build Coastguard Worker 
SuggestActionsFromAnnotation(const int message_index,const ActionSuggestionAnnotation & annotation,std::vector<ActionSuggestion> * actions) const1240*993b0882SAndroid Build Coastguard Worker void ActionsSuggestions::SuggestActionsFromAnnotation(
1241*993b0882SAndroid Build Coastguard Worker     const int message_index, const ActionSuggestionAnnotation& annotation,
1242*993b0882SAndroid Build Coastguard Worker     std::vector<ActionSuggestion>* actions) const {
1243*993b0882SAndroid Build Coastguard Worker   for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
1244*993b0882SAndroid Build Coastguard Worker        *model_->annotation_actions_spec()->annotation_mapping()) {
1245*993b0882SAndroid Build Coastguard Worker     if (annotation.entity.collection ==
1246*993b0882SAndroid Build Coastguard Worker         mapping->annotation_collection()->str()) {
1247*993b0882SAndroid Build Coastguard Worker       if (annotation.entity.score < mapping->min_annotation_score()) {
1248*993b0882SAndroid Build Coastguard Worker         continue;
1249*993b0882SAndroid Build Coastguard Worker       }
1250*993b0882SAndroid Build Coastguard Worker 
1251*993b0882SAndroid Build Coastguard Worker       std::unique_ptr<MutableFlatbuffer> entity_data =
1252*993b0882SAndroid Build Coastguard Worker           entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
1253*993b0882SAndroid Build Coastguard Worker                                           : nullptr;
1254*993b0882SAndroid Build Coastguard Worker 
1255*993b0882SAndroid Build Coastguard Worker       // Set annotation text as (additional) entity data field.
1256*993b0882SAndroid Build Coastguard Worker       if (mapping->entity_field() != nullptr) {
1257*993b0882SAndroid Build Coastguard Worker         TC3_CHECK_NE(entity_data, nullptr);
1258*993b0882SAndroid Build Coastguard Worker 
1259*993b0882SAndroid Build Coastguard Worker         UnicodeText normalized_annotation_text =
1260*993b0882SAndroid Build Coastguard Worker             UTF8ToUnicodeText(annotation.span.text, /*do_copy=*/false);
1261*993b0882SAndroid Build Coastguard Worker 
1262*993b0882SAndroid Build Coastguard Worker         // Apply normalization if specified.
1263*993b0882SAndroid Build Coastguard Worker         if (mapping->normalization_options() != nullptr) {
1264*993b0882SAndroid Build Coastguard Worker           normalized_annotation_text =
1265*993b0882SAndroid Build Coastguard Worker               NormalizeText(*unilib_, mapping->normalization_options(),
1266*993b0882SAndroid Build Coastguard Worker                             normalized_annotation_text);
1267*993b0882SAndroid Build Coastguard Worker         }
1268*993b0882SAndroid Build Coastguard Worker 
1269*993b0882SAndroid Build Coastguard Worker         entity_data->ParseAndSet(mapping->entity_field(),
1270*993b0882SAndroid Build Coastguard Worker                                  normalized_annotation_text.ToUTF8String());
1271*993b0882SAndroid Build Coastguard Worker       }
1272*993b0882SAndroid Build Coastguard Worker 
1273*993b0882SAndroid Build Coastguard Worker       ActionSuggestion suggestion;
1274*993b0882SAndroid Build Coastguard Worker       FillSuggestionFromSpec(mapping->action(), entity_data.get(), &suggestion);
1275*993b0882SAndroid Build Coastguard Worker       if (mapping->use_annotation_score()) {
1276*993b0882SAndroid Build Coastguard Worker         suggestion.score = annotation.entity.score;
1277*993b0882SAndroid Build Coastguard Worker       }
1278*993b0882SAndroid Build Coastguard Worker       suggestion.annotations = {annotation};
1279*993b0882SAndroid Build Coastguard Worker       actions->push_back(std::move(suggestion));
1280*993b0882SAndroid Build Coastguard Worker     }
1281*993b0882SAndroid Build Coastguard Worker   }
1282*993b0882SAndroid Build Coastguard Worker }
1283*993b0882SAndroid Build Coastguard Worker 
DeduplicateAnnotations(const std::vector<ActionSuggestionAnnotation> & annotations) const1284*993b0882SAndroid Build Coastguard Worker std::vector<int> ActionsSuggestions::DeduplicateAnnotations(
1285*993b0882SAndroid Build Coastguard Worker     const std::vector<ActionSuggestionAnnotation>& annotations) const {
1286*993b0882SAndroid Build Coastguard Worker   std::map<std::pair<std::string, std::string>, int> deduplicated_annotations;
1287*993b0882SAndroid Build Coastguard Worker 
1288*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < annotations.size(); i++) {
1289*993b0882SAndroid Build Coastguard Worker     const std::pair<std::string, std::string> key = {annotations[i].name,
1290*993b0882SAndroid Build Coastguard Worker                                                      annotations[i].span.text};
1291*993b0882SAndroid Build Coastguard Worker     auto entry = deduplicated_annotations.find(key);
1292*993b0882SAndroid Build Coastguard Worker     if (entry != deduplicated_annotations.end()) {
1293*993b0882SAndroid Build Coastguard Worker       // Kepp the annotation with the higher score.
1294*993b0882SAndroid Build Coastguard Worker       if (annotations[entry->second].entity.score <
1295*993b0882SAndroid Build Coastguard Worker           annotations[i].entity.score) {
1296*993b0882SAndroid Build Coastguard Worker         entry->second = i;
1297*993b0882SAndroid Build Coastguard Worker       }
1298*993b0882SAndroid Build Coastguard Worker       continue;
1299*993b0882SAndroid Build Coastguard Worker     }
1300*993b0882SAndroid Build Coastguard Worker     deduplicated_annotations.insert(entry, {key, i});
1301*993b0882SAndroid Build Coastguard Worker   }
1302*993b0882SAndroid Build Coastguard Worker 
1303*993b0882SAndroid Build Coastguard Worker   std::vector<int> result;
1304*993b0882SAndroid Build Coastguard Worker   result.reserve(deduplicated_annotations.size());
1305*993b0882SAndroid Build Coastguard Worker   for (const auto& key_and_annotation : deduplicated_annotations) {
1306*993b0882SAndroid Build Coastguard Worker     result.push_back(key_and_annotation.second);
1307*993b0882SAndroid Build Coastguard Worker   }
1308*993b0882SAndroid Build Coastguard Worker   return result;
1309*993b0882SAndroid Build Coastguard Worker }
1310*993b0882SAndroid Build Coastguard Worker 
1311*993b0882SAndroid Build Coastguard Worker #if !defined(TC3_DISABLE_LUA)
SuggestActionsFromLua(const Conversation & conversation,const TfLiteModelExecutor * model_executor,const tflite::Interpreter * interpreter,const reflection::Schema * annotation_entity_data_schema,std::vector<ActionSuggestion> * actions) const1312*993b0882SAndroid Build Coastguard Worker bool ActionsSuggestions::SuggestActionsFromLua(
1313*993b0882SAndroid Build Coastguard Worker     const Conversation& conversation, const TfLiteModelExecutor* model_executor,
1314*993b0882SAndroid Build Coastguard Worker     const tflite::Interpreter* interpreter,
1315*993b0882SAndroid Build Coastguard Worker     const reflection::Schema* annotation_entity_data_schema,
1316*993b0882SAndroid Build Coastguard Worker     std::vector<ActionSuggestion>* actions) const {
1317*993b0882SAndroid Build Coastguard Worker   if (lua_bytecode_.empty()) {
1318*993b0882SAndroid Build Coastguard Worker     return true;
1319*993b0882SAndroid Build Coastguard Worker   }
1320*993b0882SAndroid Build Coastguard Worker 
1321*993b0882SAndroid Build Coastguard Worker   auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
1322*993b0882SAndroid Build Coastguard Worker       lua_bytecode_, conversation, model_executor, model_->tflite_model_spec(),
1323*993b0882SAndroid Build Coastguard Worker       interpreter, entity_data_schema_, annotation_entity_data_schema);
1324*993b0882SAndroid Build Coastguard Worker   if (lua_actions == nullptr) {
1325*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not create lua actions.";
1326*993b0882SAndroid Build Coastguard Worker     return false;
1327*993b0882SAndroid Build Coastguard Worker   }
1328*993b0882SAndroid Build Coastguard Worker   return lua_actions->SuggestActions(actions);
1329*993b0882SAndroid Build Coastguard Worker }
1330*993b0882SAndroid Build Coastguard Worker #else
SuggestActionsFromLua(const Conversation & conversation,const TfLiteModelExecutor * model_executor,const tflite::Interpreter * interpreter,const reflection::Schema * annotation_entity_data_schema,std::vector<ActionSuggestion> * actions) const1331*993b0882SAndroid Build Coastguard Worker bool ActionsSuggestions::SuggestActionsFromLua(
1332*993b0882SAndroid Build Coastguard Worker     const Conversation& conversation, const TfLiteModelExecutor* model_executor,
1333*993b0882SAndroid Build Coastguard Worker     const tflite::Interpreter* interpreter,
1334*993b0882SAndroid Build Coastguard Worker     const reflection::Schema* annotation_entity_data_schema,
1335*993b0882SAndroid Build Coastguard Worker     std::vector<ActionSuggestion>* actions) const {
1336*993b0882SAndroid Build Coastguard Worker   return true;
1337*993b0882SAndroid Build Coastguard Worker }
1338*993b0882SAndroid Build Coastguard Worker #endif
1339*993b0882SAndroid Build Coastguard Worker 
GatherActionsSuggestions(const Conversation & conversation,const Annotator * annotator,const ActionSuggestionOptions & options,ActionsSuggestionsResponse * response) const1340*993b0882SAndroid Build Coastguard Worker bool ActionsSuggestions::GatherActionsSuggestions(
1341*993b0882SAndroid Build Coastguard Worker     const Conversation& conversation, const Annotator* annotator,
1342*993b0882SAndroid Build Coastguard Worker     const ActionSuggestionOptions& options,
1343*993b0882SAndroid Build Coastguard Worker     ActionsSuggestionsResponse* response) const {
1344*993b0882SAndroid Build Coastguard Worker   if (conversation.messages.empty()) {
1345*993b0882SAndroid Build Coastguard Worker     return true;
1346*993b0882SAndroid Build Coastguard Worker   }
1347*993b0882SAndroid Build Coastguard Worker 
1348*993b0882SAndroid Build Coastguard Worker   // Run annotator against messages.
1349*993b0882SAndroid Build Coastguard Worker   const Conversation annotated_conversation =
1350*993b0882SAndroid Build Coastguard Worker       AnnotateConversation(conversation, annotator);
1351*993b0882SAndroid Build Coastguard Worker 
1352*993b0882SAndroid Build Coastguard Worker   const int num_messages = NumMessagesToConsider(
1353*993b0882SAndroid Build Coastguard Worker       annotated_conversation, model_->max_conversation_history_length());
1354*993b0882SAndroid Build Coastguard Worker 
1355*993b0882SAndroid Build Coastguard Worker   if (num_messages <= 0) {
1356*993b0882SAndroid Build Coastguard Worker     TC3_LOG(INFO) << "No messages provided for actions suggestions.";
1357*993b0882SAndroid Build Coastguard Worker     return false;
1358*993b0882SAndroid Build Coastguard Worker   }
1359*993b0882SAndroid Build Coastguard Worker 
1360*993b0882SAndroid Build Coastguard Worker   SuggestActionsFromAnnotations(annotated_conversation, &response->actions);
1361*993b0882SAndroid Build Coastguard Worker 
1362*993b0882SAndroid Build Coastguard Worker   if (grammar_actions_ != nullptr &&
1363*993b0882SAndroid Build Coastguard Worker       !grammar_actions_->SuggestActions(annotated_conversation,
1364*993b0882SAndroid Build Coastguard Worker                                         &response->actions)) {
1365*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not suggest actions from grammar rules.";
1366*993b0882SAndroid Build Coastguard Worker     return false;
1367*993b0882SAndroid Build Coastguard Worker   }
1368*993b0882SAndroid Build Coastguard Worker 
1369*993b0882SAndroid Build Coastguard Worker   int input_text_length = 0;
1370*993b0882SAndroid Build Coastguard Worker   int num_matching_locales = 0;
1371*993b0882SAndroid Build Coastguard Worker   for (int i = annotated_conversation.messages.size() - num_messages;
1372*993b0882SAndroid Build Coastguard Worker        i < annotated_conversation.messages.size(); i++) {
1373*993b0882SAndroid Build Coastguard Worker     input_text_length += annotated_conversation.messages[i].text.length();
1374*993b0882SAndroid Build Coastguard Worker     std::vector<Locale> message_languages;
1375*993b0882SAndroid Build Coastguard Worker     if (!ParseLocales(
1376*993b0882SAndroid Build Coastguard Worker             annotated_conversation.messages[i].detected_text_language_tags,
1377*993b0882SAndroid Build Coastguard Worker             &message_languages)) {
1378*993b0882SAndroid Build Coastguard Worker       continue;
1379*993b0882SAndroid Build Coastguard Worker     }
1380*993b0882SAndroid Build Coastguard Worker     if (Locale::IsAnyLocaleSupported(
1381*993b0882SAndroid Build Coastguard Worker             message_languages, locales_,
1382*993b0882SAndroid Build Coastguard Worker             preconditions_.handle_unknown_locale_as_supported)) {
1383*993b0882SAndroid Build Coastguard Worker       ++num_matching_locales;
1384*993b0882SAndroid Build Coastguard Worker     }
1385*993b0882SAndroid Build Coastguard Worker   }
1386*993b0882SAndroid Build Coastguard Worker 
1387*993b0882SAndroid Build Coastguard Worker   // Bail out if we are provided with too few or too much input.
1388*993b0882SAndroid Build Coastguard Worker   if (input_text_length < preconditions_.min_input_length ||
1389*993b0882SAndroid Build Coastguard Worker       (preconditions_.max_input_length >= 0 &&
1390*993b0882SAndroid Build Coastguard Worker        input_text_length > preconditions_.max_input_length)) {
1391*993b0882SAndroid Build Coastguard Worker     TC3_LOG(INFO) << "Too much or not enough input for inference.";
1392*993b0882SAndroid Build Coastguard Worker     return response;
1393*993b0882SAndroid Build Coastguard Worker   }
1394*993b0882SAndroid Build Coastguard Worker 
1395*993b0882SAndroid Build Coastguard Worker   // Bail out if the text does not look like it can be handled by the model.
1396*993b0882SAndroid Build Coastguard Worker   const float matching_fraction =
1397*993b0882SAndroid Build Coastguard Worker       static_cast<float>(num_matching_locales) / num_messages;
1398*993b0882SAndroid Build Coastguard Worker   if (matching_fraction < preconditions_.min_locale_match_fraction) {
1399*993b0882SAndroid Build Coastguard Worker     TC3_LOG(INFO) << "Not enough locale matches.";
1400*993b0882SAndroid Build Coastguard Worker     response->output_filtered_locale_mismatch = true;
1401*993b0882SAndroid Build Coastguard Worker     return true;
1402*993b0882SAndroid Build Coastguard Worker   }
1403*993b0882SAndroid Build Coastguard Worker 
1404*993b0882SAndroid Build Coastguard Worker   std::vector<const UniLib::RegexPattern*> post_check_rules;
1405*993b0882SAndroid Build Coastguard Worker   if (preconditions_.suppress_on_low_confidence_input) {
1406*993b0882SAndroid Build Coastguard Worker     if (regex_actions_->IsLowConfidenceInput(annotated_conversation,
1407*993b0882SAndroid Build Coastguard Worker                                              num_messages, &post_check_rules)) {
1408*993b0882SAndroid Build Coastguard Worker       response->output_filtered_low_confidence = true;
1409*993b0882SAndroid Build Coastguard Worker       return true;
1410*993b0882SAndroid Build Coastguard Worker     }
1411*993b0882SAndroid Build Coastguard Worker   }
1412*993b0882SAndroid Build Coastguard Worker 
1413*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<tflite::Interpreter> interpreter;
1414*993b0882SAndroid Build Coastguard Worker   if (!SuggestActionsFromModel(annotated_conversation, num_messages, options,
1415*993b0882SAndroid Build Coastguard Worker                                response, &interpreter)) {
1416*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not run model.";
1417*993b0882SAndroid Build Coastguard Worker     return false;
1418*993b0882SAndroid Build Coastguard Worker   }
1419*993b0882SAndroid Build Coastguard Worker 
1420*993b0882SAndroid Build Coastguard Worker   // SuggestActionsFromModel also detects if the conversation is sensitive,
1421*993b0882SAndroid Build Coastguard Worker   // either by using the old ngram model or the new model.
1422*993b0882SAndroid Build Coastguard Worker   // Suppress all predictions if the conversation was deemed sensitive.
1423*993b0882SAndroid Build Coastguard Worker   if (preconditions_.suppress_on_sensitive_topic && response->is_sensitive) {
1424*993b0882SAndroid Build Coastguard Worker     return true;
1425*993b0882SAndroid Build Coastguard Worker   }
1426*993b0882SAndroid Build Coastguard Worker 
1427*993b0882SAndroid Build Coastguard Worker   if (conversation_intent_detection_) {
1428*993b0882SAndroid Build Coastguard Worker     // TODO(zbin): Ensure the deduplication/ranking logic in ranker.cc works.
1429*993b0882SAndroid Build Coastguard Worker     auto actions = SuggestActionsFromConversationIntentDetection(
1430*993b0882SAndroid Build Coastguard Worker         annotated_conversation, options, &response->actions);
1431*993b0882SAndroid Build Coastguard Worker     if (!actions.ok()) {
1432*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Could not run conversation intent detection: "
1433*993b0882SAndroid Build Coastguard Worker                      << actions.error_message();
1434*993b0882SAndroid Build Coastguard Worker       return false;
1435*993b0882SAndroid Build Coastguard Worker     }
1436*993b0882SAndroid Build Coastguard Worker   }
1437*993b0882SAndroid Build Coastguard Worker 
1438*993b0882SAndroid Build Coastguard Worker   if (!SuggestActionsFromLua(
1439*993b0882SAndroid Build Coastguard Worker           annotated_conversation, model_executor_.get(), interpreter.get(),
1440*993b0882SAndroid Build Coastguard Worker           annotator != nullptr ? annotator->entity_data_schema() : nullptr,
1441*993b0882SAndroid Build Coastguard Worker           &response->actions)) {
1442*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not suggest actions from script.";
1443*993b0882SAndroid Build Coastguard Worker     return false;
1444*993b0882SAndroid Build Coastguard Worker   }
1445*993b0882SAndroid Build Coastguard Worker 
1446*993b0882SAndroid Build Coastguard Worker   if (!regex_actions_->SuggestActions(annotated_conversation,
1447*993b0882SAndroid Build Coastguard Worker                                       entity_data_builder_.get(),
1448*993b0882SAndroid Build Coastguard Worker                                       &response->actions)) {
1449*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not suggest actions from regex rules.";
1450*993b0882SAndroid Build Coastguard Worker     return false;
1451*993b0882SAndroid Build Coastguard Worker   }
1452*993b0882SAndroid Build Coastguard Worker 
1453*993b0882SAndroid Build Coastguard Worker   if (preconditions_.suppress_on_low_confidence_input &&
1454*993b0882SAndroid Build Coastguard Worker       !regex_actions_->FilterConfidenceOutput(post_check_rules,
1455*993b0882SAndroid Build Coastguard Worker                                               &response->actions)) {
1456*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not post-check actions.";
1457*993b0882SAndroid Build Coastguard Worker     return false;
1458*993b0882SAndroid Build Coastguard Worker   }
1459*993b0882SAndroid Build Coastguard Worker 
1460*993b0882SAndroid Build Coastguard Worker   return true;
1461*993b0882SAndroid Build Coastguard Worker }
1462*993b0882SAndroid Build Coastguard Worker 
SuggestActions(const Conversation & conversation,const Annotator * annotator,const ActionSuggestionOptions & options) const1463*993b0882SAndroid Build Coastguard Worker ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
1464*993b0882SAndroid Build Coastguard Worker     const Conversation& conversation, const Annotator* annotator,
1465*993b0882SAndroid Build Coastguard Worker     const ActionSuggestionOptions& options) const {
1466*993b0882SAndroid Build Coastguard Worker   ActionsSuggestionsResponse response;
1467*993b0882SAndroid Build Coastguard Worker 
1468*993b0882SAndroid Build Coastguard Worker   // Assert that messages are sorted correctly.
1469*993b0882SAndroid Build Coastguard Worker   for (int i = 1; i < conversation.messages.size(); i++) {
1470*993b0882SAndroid Build Coastguard Worker     if (conversation.messages[i].reference_time_ms_utc <
1471*993b0882SAndroid Build Coastguard Worker         conversation.messages[i - 1].reference_time_ms_utc) {
1472*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Messages are not sorted most recent last.";
1473*993b0882SAndroid Build Coastguard Worker       return response;
1474*993b0882SAndroid Build Coastguard Worker     }
1475*993b0882SAndroid Build Coastguard Worker   }
1476*993b0882SAndroid Build Coastguard Worker 
1477*993b0882SAndroid Build Coastguard Worker   // Check that messages are valid utf8.
1478*993b0882SAndroid Build Coastguard Worker   for (const ConversationMessage& message : conversation.messages) {
1479*993b0882SAndroid Build Coastguard Worker     if (message.text.size() > std::numeric_limits<int>::max()) {
1480*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Rejecting too long input: " << message.text.size();
1481*993b0882SAndroid Build Coastguard Worker       return {};
1482*993b0882SAndroid Build Coastguard Worker     }
1483*993b0882SAndroid Build Coastguard Worker 
1484*993b0882SAndroid Build Coastguard Worker     if (!unilib_->IsValidUtf8(UTF8ToUnicodeText(
1485*993b0882SAndroid Build Coastguard Worker             message.text.data(), message.text.size(), /*do_copy=*/false))) {
1486*993b0882SAndroid Build Coastguard Worker       TC3_LOG(ERROR) << "Not valid utf8 provided.";
1487*993b0882SAndroid Build Coastguard Worker       return response;
1488*993b0882SAndroid Build Coastguard Worker     }
1489*993b0882SAndroid Build Coastguard Worker   }
1490*993b0882SAndroid Build Coastguard Worker 
1491*993b0882SAndroid Build Coastguard Worker   if (!GatherActionsSuggestions(conversation, annotator, options, &response)) {
1492*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not gather actions suggestions.";
1493*993b0882SAndroid Build Coastguard Worker     response.actions.clear();
1494*993b0882SAndroid Build Coastguard Worker   } else if (!ranker_->RankActions(conversation, &response, entity_data_schema_,
1495*993b0882SAndroid Build Coastguard Worker                                    annotator != nullptr
1496*993b0882SAndroid Build Coastguard Worker                                        ? annotator->entity_data_schema()
1497*993b0882SAndroid Build Coastguard Worker                                        : nullptr)) {
1498*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not rank actions.";
1499*993b0882SAndroid Build Coastguard Worker     response.actions.clear();
1500*993b0882SAndroid Build Coastguard Worker   }
1501*993b0882SAndroid Build Coastguard Worker   return response;
1502*993b0882SAndroid Build Coastguard Worker }
1503*993b0882SAndroid Build Coastguard Worker 
SuggestActions(const Conversation & conversation,const ActionSuggestionOptions & options) const1504*993b0882SAndroid Build Coastguard Worker ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
1505*993b0882SAndroid Build Coastguard Worker     const Conversation& conversation,
1506*993b0882SAndroid Build Coastguard Worker     const ActionSuggestionOptions& options) const {
1507*993b0882SAndroid Build Coastguard Worker   return SuggestActions(conversation, /*annotator=*/nullptr, options);
1508*993b0882SAndroid Build Coastguard Worker }
1509*993b0882SAndroid Build Coastguard Worker 
model() const1510*993b0882SAndroid Build Coastguard Worker const ActionsModel* ActionsSuggestions::model() const { return model_; }
entity_data_schema() const1511*993b0882SAndroid Build Coastguard Worker const reflection::Schema* ActionsSuggestions::entity_data_schema() const {
1512*993b0882SAndroid Build Coastguard Worker   return entity_data_schema_;
1513*993b0882SAndroid Build Coastguard Worker }
1514*993b0882SAndroid Build Coastguard Worker 
ViewActionsModel(const void * buffer,int size)1515*993b0882SAndroid Build Coastguard Worker const ActionsModel* ViewActionsModel(const void* buffer, int size) {
1516*993b0882SAndroid Build Coastguard Worker   if (buffer == nullptr) {
1517*993b0882SAndroid Build Coastguard Worker     return nullptr;
1518*993b0882SAndroid Build Coastguard Worker   }
1519*993b0882SAndroid Build Coastguard Worker   return LoadAndVerifyModel(reinterpret_cast<const uint8_t*>(buffer), size);
1520*993b0882SAndroid Build Coastguard Worker }
1521*993b0882SAndroid Build Coastguard Worker 
InitializeConversationIntentDetection(const std::string & serialized_config)1522*993b0882SAndroid Build Coastguard Worker bool ActionsSuggestions::InitializeConversationIntentDetection(
1523*993b0882SAndroid Build Coastguard Worker     const std::string& serialized_config) {
1524*993b0882SAndroid Build Coastguard Worker   auto conversation_intent_detection =
1525*993b0882SAndroid Build Coastguard Worker       std::make_unique<ConversationIntentDetection>();
1526*993b0882SAndroid Build Coastguard Worker   if (!conversation_intent_detection->Initialize(serialized_config).ok()) {
1527*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Failed to initialize conversation intent detection.";
1528*993b0882SAndroid Build Coastguard Worker     return false;
1529*993b0882SAndroid Build Coastguard Worker   }
1530*993b0882SAndroid Build Coastguard Worker   conversation_intent_detection_ = std::move(conversation_intent_detection);
1531*993b0882SAndroid Build Coastguard Worker   return true;
1532*993b0882SAndroid Build Coastguard Worker }
1533*993b0882SAndroid Build Coastguard Worker 
1534*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
1535