/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16*993b0882SAndroid Build Coastguard Worker
17*993b0882SAndroid Build Coastguard Worker #include "actions/lua-actions.h"
18*993b0882SAndroid Build Coastguard Worker
19*993b0882SAndroid Build Coastguard Worker #include "utils/base/logging.h"
20*993b0882SAndroid Build Coastguard Worker #include "utils/lua-utils.h"
21*993b0882SAndroid Build Coastguard Worker
22*993b0882SAndroid Build Coastguard Worker #ifdef __cplusplus
23*993b0882SAndroid Build Coastguard Worker extern "C" {
24*993b0882SAndroid Build Coastguard Worker #endif
25*993b0882SAndroid Build Coastguard Worker #include "lauxlib.h"
26*993b0882SAndroid Build Coastguard Worker #include "lualib.h"
27*993b0882SAndroid Build Coastguard Worker #ifdef __cplusplus
28*993b0882SAndroid Build Coastguard Worker }
29*993b0882SAndroid Build Coastguard Worker #endif
30*993b0882SAndroid Build Coastguard Worker
31*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
32*993b0882SAndroid Build Coastguard Worker namespace {
33*993b0882SAndroid Build Coastguard Worker
GetTensorViewForOutput(const TfLiteModelExecutor * model_executor,const tflite::Interpreter * interpreter,int output)34*993b0882SAndroid Build Coastguard Worker TensorView<float> GetTensorViewForOutput(
35*993b0882SAndroid Build Coastguard Worker const TfLiteModelExecutor* model_executor,
36*993b0882SAndroid Build Coastguard Worker const tflite::Interpreter* interpreter, int output) {
37*993b0882SAndroid Build Coastguard Worker if (output < 0 || model_executor == nullptr || interpreter == nullptr) {
38*993b0882SAndroid Build Coastguard Worker return TensorView<float>::Invalid();
39*993b0882SAndroid Build Coastguard Worker }
40*993b0882SAndroid Build Coastguard Worker return model_executor->OutputView<float>(output, interpreter);
41*993b0882SAndroid Build Coastguard Worker }
42*993b0882SAndroid Build Coastguard Worker
GetStringTensorForOutput(const TfLiteModelExecutor * model_executor,const tflite::Interpreter * interpreter,int output)43*993b0882SAndroid Build Coastguard Worker std::vector<std::string> GetStringTensorForOutput(
44*993b0882SAndroid Build Coastguard Worker const TfLiteModelExecutor* model_executor,
45*993b0882SAndroid Build Coastguard Worker const tflite::Interpreter* interpreter, int output) {
46*993b0882SAndroid Build Coastguard Worker if (output < 0 || model_executor == nullptr || interpreter == nullptr) {
47*993b0882SAndroid Build Coastguard Worker return {};
48*993b0882SAndroid Build Coastguard Worker }
49*993b0882SAndroid Build Coastguard Worker return model_executor->Output<std::string>(output, interpreter);
50*993b0882SAndroid Build Coastguard Worker }
51*993b0882SAndroid Build Coastguard Worker
52*993b0882SAndroid Build Coastguard Worker } // namespace
53*993b0882SAndroid Build Coastguard Worker
54*993b0882SAndroid Build Coastguard Worker std::unique_ptr<LuaActionsSuggestions>
CreateLuaActionsSuggestions(const std::string & snippet,const Conversation & conversation,const TfLiteModelExecutor * model_executor,const TensorflowLiteModelSpec * model_spec,const tflite::Interpreter * interpreter,const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema)55*993b0882SAndroid Build Coastguard Worker LuaActionsSuggestions::CreateLuaActionsSuggestions(
56*993b0882SAndroid Build Coastguard Worker const std::string& snippet, const Conversation& conversation,
57*993b0882SAndroid Build Coastguard Worker const TfLiteModelExecutor* model_executor,
58*993b0882SAndroid Build Coastguard Worker const TensorflowLiteModelSpec* model_spec,
59*993b0882SAndroid Build Coastguard Worker const tflite::Interpreter* interpreter,
60*993b0882SAndroid Build Coastguard Worker const reflection::Schema* actions_entity_data_schema,
61*993b0882SAndroid Build Coastguard Worker const reflection::Schema* annotations_entity_data_schema) {
62*993b0882SAndroid Build Coastguard Worker auto lua_actions =
63*993b0882SAndroid Build Coastguard Worker std::unique_ptr<LuaActionsSuggestions>(new LuaActionsSuggestions(
64*993b0882SAndroid Build Coastguard Worker snippet, conversation, model_executor, model_spec, interpreter,
65*993b0882SAndroid Build Coastguard Worker actions_entity_data_schema, annotations_entity_data_schema));
66*993b0882SAndroid Build Coastguard Worker if (!lua_actions->Initialize()) {
67*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR)
68*993b0882SAndroid Build Coastguard Worker << "Could not initialize lua environment for actions suggestions.";
69*993b0882SAndroid Build Coastguard Worker return nullptr;
70*993b0882SAndroid Build Coastguard Worker }
71*993b0882SAndroid Build Coastguard Worker return lua_actions;
72*993b0882SAndroid Build Coastguard Worker }
73*993b0882SAndroid Build Coastguard Worker
LuaActionsSuggestions(const std::string & snippet,const Conversation & conversation,const TfLiteModelExecutor * model_executor,const TensorflowLiteModelSpec * model_spec,const tflite::Interpreter * interpreter,const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema)74*993b0882SAndroid Build Coastguard Worker LuaActionsSuggestions::LuaActionsSuggestions(
75*993b0882SAndroid Build Coastguard Worker const std::string& snippet, const Conversation& conversation,
76*993b0882SAndroid Build Coastguard Worker const TfLiteModelExecutor* model_executor,
77*993b0882SAndroid Build Coastguard Worker const TensorflowLiteModelSpec* model_spec,
78*993b0882SAndroid Build Coastguard Worker const tflite::Interpreter* interpreter,
79*993b0882SAndroid Build Coastguard Worker const reflection::Schema* actions_entity_data_schema,
80*993b0882SAndroid Build Coastguard Worker const reflection::Schema* annotations_entity_data_schema)
81*993b0882SAndroid Build Coastguard Worker : snippet_(snippet),
82*993b0882SAndroid Build Coastguard Worker conversation_(conversation),
83*993b0882SAndroid Build Coastguard Worker actions_scores_(
84*993b0882SAndroid Build Coastguard Worker model_spec == nullptr
85*993b0882SAndroid Build Coastguard Worker ? TensorView<float>::Invalid()
86*993b0882SAndroid Build Coastguard Worker : GetTensorViewForOutput(model_executor, interpreter,
87*993b0882SAndroid Build Coastguard Worker model_spec->output_actions_scores())),
88*993b0882SAndroid Build Coastguard Worker smart_reply_scores_(
89*993b0882SAndroid Build Coastguard Worker model_spec == nullptr
90*993b0882SAndroid Build Coastguard Worker ? TensorView<float>::Invalid()
91*993b0882SAndroid Build Coastguard Worker : GetTensorViewForOutput(model_executor, interpreter,
92*993b0882SAndroid Build Coastguard Worker model_spec->output_replies_scores())),
93*993b0882SAndroid Build Coastguard Worker sensitivity_score_(model_spec == nullptr
94*993b0882SAndroid Build Coastguard Worker ? TensorView<float>::Invalid()
95*993b0882SAndroid Build Coastguard Worker : GetTensorViewForOutput(
96*993b0882SAndroid Build Coastguard Worker model_executor, interpreter,
97*993b0882SAndroid Build Coastguard Worker model_spec->output_sensitive_topic_score())),
98*993b0882SAndroid Build Coastguard Worker triggering_score_(
99*993b0882SAndroid Build Coastguard Worker model_spec == nullptr
100*993b0882SAndroid Build Coastguard Worker ? TensorView<float>::Invalid()
101*993b0882SAndroid Build Coastguard Worker : GetTensorViewForOutput(model_executor, interpreter,
102*993b0882SAndroid Build Coastguard Worker model_spec->output_triggering_score())),
103*993b0882SAndroid Build Coastguard Worker smart_replies_(model_spec == nullptr ? std::vector<std::string>{}
104*993b0882SAndroid Build Coastguard Worker : GetStringTensorForOutput(
105*993b0882SAndroid Build Coastguard Worker model_executor, interpreter,
106*993b0882SAndroid Build Coastguard Worker model_spec->output_replies())),
107*993b0882SAndroid Build Coastguard Worker actions_entity_data_schema_(actions_entity_data_schema),
108*993b0882SAndroid Build Coastguard Worker annotations_entity_data_schema_(annotations_entity_data_schema) {}
109*993b0882SAndroid Build Coastguard Worker
Initialize()110*993b0882SAndroid Build Coastguard Worker bool LuaActionsSuggestions::Initialize() {
111*993b0882SAndroid Build Coastguard Worker return RunProtected([this] {
112*993b0882SAndroid Build Coastguard Worker LoadDefaultLibraries();
113*993b0882SAndroid Build Coastguard Worker
114*993b0882SAndroid Build Coastguard Worker // Expose conversation message stream.
115*993b0882SAndroid Build Coastguard Worker PushConversation(&conversation_.messages,
116*993b0882SAndroid Build Coastguard Worker annotations_entity_data_schema_);
117*993b0882SAndroid Build Coastguard Worker lua_setglobal(state_, "messages");
118*993b0882SAndroid Build Coastguard Worker
119*993b0882SAndroid Build Coastguard Worker // Expose ML model output.
120*993b0882SAndroid Build Coastguard Worker lua_newtable(state_);
121*993b0882SAndroid Build Coastguard Worker
122*993b0882SAndroid Build Coastguard Worker PushTensor(&actions_scores_);
123*993b0882SAndroid Build Coastguard Worker lua_setfield(state_, /*idx=*/-2, "actions_scores");
124*993b0882SAndroid Build Coastguard Worker
125*993b0882SAndroid Build Coastguard Worker PushTensor(&smart_reply_scores_);
126*993b0882SAndroid Build Coastguard Worker lua_setfield(state_, /*idx=*/-2, "reply_scores");
127*993b0882SAndroid Build Coastguard Worker
128*993b0882SAndroid Build Coastguard Worker PushTensor(&sensitivity_score_);
129*993b0882SAndroid Build Coastguard Worker lua_setfield(state_, /*idx=*/-2, "sensitivity");
130*993b0882SAndroid Build Coastguard Worker
131*993b0882SAndroid Build Coastguard Worker PushTensor(&triggering_score_);
132*993b0882SAndroid Build Coastguard Worker lua_setfield(state_, /*idx=*/-2, "triggering_score");
133*993b0882SAndroid Build Coastguard Worker
134*993b0882SAndroid Build Coastguard Worker PushVectorIterator(&smart_replies_);
135*993b0882SAndroid Build Coastguard Worker lua_setfield(state_, /*idx=*/-2, "reply");
136*993b0882SAndroid Build Coastguard Worker
137*993b0882SAndroid Build Coastguard Worker lua_setglobal(state_, "model");
138*993b0882SAndroid Build Coastguard Worker
139*993b0882SAndroid Build Coastguard Worker return LUA_OK;
140*993b0882SAndroid Build Coastguard Worker }) == LUA_OK;
141*993b0882SAndroid Build Coastguard Worker }
142*993b0882SAndroid Build Coastguard Worker
SuggestActions(std::vector<ActionSuggestion> * actions)143*993b0882SAndroid Build Coastguard Worker bool LuaActionsSuggestions::SuggestActions(
144*993b0882SAndroid Build Coastguard Worker std::vector<ActionSuggestion>* actions) {
145*993b0882SAndroid Build Coastguard Worker if (luaL_loadbuffer(state_, snippet_.data(), snippet_.size(),
146*993b0882SAndroid Build Coastguard Worker /*name=*/nullptr) != LUA_OK) {
147*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Could not load actions suggestions snippet.";
148*993b0882SAndroid Build Coastguard Worker return false;
149*993b0882SAndroid Build Coastguard Worker }
150*993b0882SAndroid Build Coastguard Worker
151*993b0882SAndroid Build Coastguard Worker if (lua_pcall(state_, /*nargs=*/0, /*nargs=*/1, /*errfunc=*/0) != LUA_OK) {
152*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Could not run actions suggestions snippet.";
153*993b0882SAndroid Build Coastguard Worker return false;
154*993b0882SAndroid Build Coastguard Worker }
155*993b0882SAndroid Build Coastguard Worker
156*993b0882SAndroid Build Coastguard Worker if (RunProtected(
157*993b0882SAndroid Build Coastguard Worker [this, actions] {
158*993b0882SAndroid Build Coastguard Worker return ReadActions(actions_entity_data_schema_,
159*993b0882SAndroid Build Coastguard Worker annotations_entity_data_schema_, actions);
160*993b0882SAndroid Build Coastguard Worker },
161*993b0882SAndroid Build Coastguard Worker /*num_args=*/1) != LUA_OK) {
162*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Could not read lua result.";
163*993b0882SAndroid Build Coastguard Worker return false;
164*993b0882SAndroid Build Coastguard Worker }
165*993b0882SAndroid Build Coastguard Worker return true;
166*993b0882SAndroid Build Coastguard Worker }
167*993b0882SAndroid Build Coastguard Worker
168*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3
169