xref: /aosp_15_r20/external/libtextclassifier/native/actions/lua-actions.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #include "actions/lua-actions.h"
18*993b0882SAndroid Build Coastguard Worker 
19*993b0882SAndroid Build Coastguard Worker #include "utils/base/logging.h"
20*993b0882SAndroid Build Coastguard Worker #include "utils/lua-utils.h"
21*993b0882SAndroid Build Coastguard Worker 
22*993b0882SAndroid Build Coastguard Worker #ifdef __cplusplus
23*993b0882SAndroid Build Coastguard Worker extern "C" {
24*993b0882SAndroid Build Coastguard Worker #endif
25*993b0882SAndroid Build Coastguard Worker #include "lauxlib.h"
26*993b0882SAndroid Build Coastguard Worker #include "lualib.h"
27*993b0882SAndroid Build Coastguard Worker #ifdef __cplusplus
28*993b0882SAndroid Build Coastguard Worker }
29*993b0882SAndroid Build Coastguard Worker #endif
30*993b0882SAndroid Build Coastguard Worker 
31*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
32*993b0882SAndroid Build Coastguard Worker namespace {
33*993b0882SAndroid Build Coastguard Worker 
GetTensorViewForOutput(const TfLiteModelExecutor * model_executor,const tflite::Interpreter * interpreter,int output)34*993b0882SAndroid Build Coastguard Worker TensorView<float> GetTensorViewForOutput(
35*993b0882SAndroid Build Coastguard Worker     const TfLiteModelExecutor* model_executor,
36*993b0882SAndroid Build Coastguard Worker     const tflite::Interpreter* interpreter, int output) {
37*993b0882SAndroid Build Coastguard Worker   if (output < 0 || model_executor == nullptr || interpreter == nullptr) {
38*993b0882SAndroid Build Coastguard Worker     return TensorView<float>::Invalid();
39*993b0882SAndroid Build Coastguard Worker   }
40*993b0882SAndroid Build Coastguard Worker   return model_executor->OutputView<float>(output, interpreter);
41*993b0882SAndroid Build Coastguard Worker }
42*993b0882SAndroid Build Coastguard Worker 
GetStringTensorForOutput(const TfLiteModelExecutor * model_executor,const tflite::Interpreter * interpreter,int output)43*993b0882SAndroid Build Coastguard Worker std::vector<std::string> GetStringTensorForOutput(
44*993b0882SAndroid Build Coastguard Worker     const TfLiteModelExecutor* model_executor,
45*993b0882SAndroid Build Coastguard Worker     const tflite::Interpreter* interpreter, int output) {
46*993b0882SAndroid Build Coastguard Worker   if (output < 0 || model_executor == nullptr || interpreter == nullptr) {
47*993b0882SAndroid Build Coastguard Worker     return {};
48*993b0882SAndroid Build Coastguard Worker   }
49*993b0882SAndroid Build Coastguard Worker   return model_executor->Output<std::string>(output, interpreter);
50*993b0882SAndroid Build Coastguard Worker }
51*993b0882SAndroid Build Coastguard Worker 
52*993b0882SAndroid Build Coastguard Worker }  // namespace
53*993b0882SAndroid Build Coastguard Worker 
54*993b0882SAndroid Build Coastguard Worker std::unique_ptr<LuaActionsSuggestions>
CreateLuaActionsSuggestions(const std::string & snippet,const Conversation & conversation,const TfLiteModelExecutor * model_executor,const TensorflowLiteModelSpec * model_spec,const tflite::Interpreter * interpreter,const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema)55*993b0882SAndroid Build Coastguard Worker LuaActionsSuggestions::CreateLuaActionsSuggestions(
56*993b0882SAndroid Build Coastguard Worker     const std::string& snippet, const Conversation& conversation,
57*993b0882SAndroid Build Coastguard Worker     const TfLiteModelExecutor* model_executor,
58*993b0882SAndroid Build Coastguard Worker     const TensorflowLiteModelSpec* model_spec,
59*993b0882SAndroid Build Coastguard Worker     const tflite::Interpreter* interpreter,
60*993b0882SAndroid Build Coastguard Worker     const reflection::Schema* actions_entity_data_schema,
61*993b0882SAndroid Build Coastguard Worker     const reflection::Schema* annotations_entity_data_schema) {
62*993b0882SAndroid Build Coastguard Worker   auto lua_actions =
63*993b0882SAndroid Build Coastguard Worker       std::unique_ptr<LuaActionsSuggestions>(new LuaActionsSuggestions(
64*993b0882SAndroid Build Coastguard Worker           snippet, conversation, model_executor, model_spec, interpreter,
65*993b0882SAndroid Build Coastguard Worker           actions_entity_data_schema, annotations_entity_data_schema));
66*993b0882SAndroid Build Coastguard Worker   if (!lua_actions->Initialize()) {
67*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR)
68*993b0882SAndroid Build Coastguard Worker         << "Could not initialize lua environment for actions suggestions.";
69*993b0882SAndroid Build Coastguard Worker     return nullptr;
70*993b0882SAndroid Build Coastguard Worker   }
71*993b0882SAndroid Build Coastguard Worker   return lua_actions;
72*993b0882SAndroid Build Coastguard Worker }
73*993b0882SAndroid Build Coastguard Worker 
LuaActionsSuggestions(const std::string & snippet,const Conversation & conversation,const TfLiteModelExecutor * model_executor,const TensorflowLiteModelSpec * model_spec,const tflite::Interpreter * interpreter,const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema)74*993b0882SAndroid Build Coastguard Worker LuaActionsSuggestions::LuaActionsSuggestions(
75*993b0882SAndroid Build Coastguard Worker     const std::string& snippet, const Conversation& conversation,
76*993b0882SAndroid Build Coastguard Worker     const TfLiteModelExecutor* model_executor,
77*993b0882SAndroid Build Coastguard Worker     const TensorflowLiteModelSpec* model_spec,
78*993b0882SAndroid Build Coastguard Worker     const tflite::Interpreter* interpreter,
79*993b0882SAndroid Build Coastguard Worker     const reflection::Schema* actions_entity_data_schema,
80*993b0882SAndroid Build Coastguard Worker     const reflection::Schema* annotations_entity_data_schema)
81*993b0882SAndroid Build Coastguard Worker     : snippet_(snippet),
82*993b0882SAndroid Build Coastguard Worker       conversation_(conversation),
83*993b0882SAndroid Build Coastguard Worker       actions_scores_(
84*993b0882SAndroid Build Coastguard Worker           model_spec == nullptr
85*993b0882SAndroid Build Coastguard Worker               ? TensorView<float>::Invalid()
86*993b0882SAndroid Build Coastguard Worker               : GetTensorViewForOutput(model_executor, interpreter,
87*993b0882SAndroid Build Coastguard Worker                                        model_spec->output_actions_scores())),
88*993b0882SAndroid Build Coastguard Worker       smart_reply_scores_(
89*993b0882SAndroid Build Coastguard Worker           model_spec == nullptr
90*993b0882SAndroid Build Coastguard Worker               ? TensorView<float>::Invalid()
91*993b0882SAndroid Build Coastguard Worker               : GetTensorViewForOutput(model_executor, interpreter,
92*993b0882SAndroid Build Coastguard Worker                                        model_spec->output_replies_scores())),
93*993b0882SAndroid Build Coastguard Worker       sensitivity_score_(model_spec == nullptr
94*993b0882SAndroid Build Coastguard Worker                              ? TensorView<float>::Invalid()
95*993b0882SAndroid Build Coastguard Worker                              : GetTensorViewForOutput(
96*993b0882SAndroid Build Coastguard Worker                                    model_executor, interpreter,
97*993b0882SAndroid Build Coastguard Worker                                    model_spec->output_sensitive_topic_score())),
98*993b0882SAndroid Build Coastguard Worker       triggering_score_(
99*993b0882SAndroid Build Coastguard Worker           model_spec == nullptr
100*993b0882SAndroid Build Coastguard Worker               ? TensorView<float>::Invalid()
101*993b0882SAndroid Build Coastguard Worker               : GetTensorViewForOutput(model_executor, interpreter,
102*993b0882SAndroid Build Coastguard Worker                                        model_spec->output_triggering_score())),
103*993b0882SAndroid Build Coastguard Worker       smart_replies_(model_spec == nullptr ? std::vector<std::string>{}
104*993b0882SAndroid Build Coastguard Worker                                            : GetStringTensorForOutput(
105*993b0882SAndroid Build Coastguard Worker                                                  model_executor, interpreter,
106*993b0882SAndroid Build Coastguard Worker                                                  model_spec->output_replies())),
107*993b0882SAndroid Build Coastguard Worker       actions_entity_data_schema_(actions_entity_data_schema),
108*993b0882SAndroid Build Coastguard Worker       annotations_entity_data_schema_(annotations_entity_data_schema) {}
109*993b0882SAndroid Build Coastguard Worker 
Initialize()110*993b0882SAndroid Build Coastguard Worker bool LuaActionsSuggestions::Initialize() {
111*993b0882SAndroid Build Coastguard Worker   return RunProtected([this] {
112*993b0882SAndroid Build Coastguard Worker            LoadDefaultLibraries();
113*993b0882SAndroid Build Coastguard Worker 
114*993b0882SAndroid Build Coastguard Worker            // Expose conversation message stream.
115*993b0882SAndroid Build Coastguard Worker            PushConversation(&conversation_.messages,
116*993b0882SAndroid Build Coastguard Worker                             annotations_entity_data_schema_);
117*993b0882SAndroid Build Coastguard Worker            lua_setglobal(state_, "messages");
118*993b0882SAndroid Build Coastguard Worker 
119*993b0882SAndroid Build Coastguard Worker            // Expose ML model output.
120*993b0882SAndroid Build Coastguard Worker            lua_newtable(state_);
121*993b0882SAndroid Build Coastguard Worker 
122*993b0882SAndroid Build Coastguard Worker            PushTensor(&actions_scores_);
123*993b0882SAndroid Build Coastguard Worker            lua_setfield(state_, /*idx=*/-2, "actions_scores");
124*993b0882SAndroid Build Coastguard Worker 
125*993b0882SAndroid Build Coastguard Worker            PushTensor(&smart_reply_scores_);
126*993b0882SAndroid Build Coastguard Worker            lua_setfield(state_, /*idx=*/-2, "reply_scores");
127*993b0882SAndroid Build Coastguard Worker 
128*993b0882SAndroid Build Coastguard Worker            PushTensor(&sensitivity_score_);
129*993b0882SAndroid Build Coastguard Worker            lua_setfield(state_, /*idx=*/-2, "sensitivity");
130*993b0882SAndroid Build Coastguard Worker 
131*993b0882SAndroid Build Coastguard Worker            PushTensor(&triggering_score_);
132*993b0882SAndroid Build Coastguard Worker            lua_setfield(state_, /*idx=*/-2, "triggering_score");
133*993b0882SAndroid Build Coastguard Worker 
134*993b0882SAndroid Build Coastguard Worker            PushVectorIterator(&smart_replies_);
135*993b0882SAndroid Build Coastguard Worker            lua_setfield(state_, /*idx=*/-2, "reply");
136*993b0882SAndroid Build Coastguard Worker 
137*993b0882SAndroid Build Coastguard Worker            lua_setglobal(state_, "model");
138*993b0882SAndroid Build Coastguard Worker 
139*993b0882SAndroid Build Coastguard Worker            return LUA_OK;
140*993b0882SAndroid Build Coastguard Worker          }) == LUA_OK;
141*993b0882SAndroid Build Coastguard Worker }
142*993b0882SAndroid Build Coastguard Worker 
SuggestActions(std::vector<ActionSuggestion> * actions)143*993b0882SAndroid Build Coastguard Worker bool LuaActionsSuggestions::SuggestActions(
144*993b0882SAndroid Build Coastguard Worker     std::vector<ActionSuggestion>* actions) {
145*993b0882SAndroid Build Coastguard Worker   if (luaL_loadbuffer(state_, snippet_.data(), snippet_.size(),
146*993b0882SAndroid Build Coastguard Worker                       /*name=*/nullptr) != LUA_OK) {
147*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not load actions suggestions snippet.";
148*993b0882SAndroid Build Coastguard Worker     return false;
149*993b0882SAndroid Build Coastguard Worker   }
150*993b0882SAndroid Build Coastguard Worker 
151*993b0882SAndroid Build Coastguard Worker   if (lua_pcall(state_, /*nargs=*/0, /*nargs=*/1, /*errfunc=*/0) != LUA_OK) {
152*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not run actions suggestions snippet.";
153*993b0882SAndroid Build Coastguard Worker     return false;
154*993b0882SAndroid Build Coastguard Worker   }
155*993b0882SAndroid Build Coastguard Worker 
156*993b0882SAndroid Build Coastguard Worker   if (RunProtected(
157*993b0882SAndroid Build Coastguard Worker           [this, actions] {
158*993b0882SAndroid Build Coastguard Worker             return ReadActions(actions_entity_data_schema_,
159*993b0882SAndroid Build Coastguard Worker                                annotations_entity_data_schema_, actions);
160*993b0882SAndroid Build Coastguard Worker           },
161*993b0882SAndroid Build Coastguard Worker           /*num_args=*/1) != LUA_OK) {
162*993b0882SAndroid Build Coastguard Worker     TC3_LOG(ERROR) << "Could not read lua result.";
163*993b0882SAndroid Build Coastguard Worker     return false;
164*993b0882SAndroid Build Coastguard Worker   }
165*993b0882SAndroid Build Coastguard Worker   return true;
166*993b0882SAndroid Build Coastguard Worker }
167*993b0882SAndroid Build Coastguard Worker 
168*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
169