1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/intents/intent-generator.h"
18
19 #include <memory>
20 #include <string>
21 #include <vector>
22
23 #include "utils/base/logging.h"
24 #include "utils/intents/jni-lua.h"
25 #include "utils/java/jni-helper.h"
26 #include "utils/utf8/unicodetext.h"
27 #include "utils/zlib/zlib.h"
28
29 #ifdef __cplusplus
30 extern "C" {
31 #endif
32 #include "lauxlib.h"
33 #include "lua.h"
34 #ifdef __cplusplus
35 }
36 #endif
37
38 namespace libtextclassifier3 {
39 namespace {
40
// Names of the fields exported to the Lua intent generators.
//
// NOTE(review): despite the "Usec" in the identifier, the exported field is
// named "reference_time_ms_utc" and is populated from a value that the rest
// of this file calls reference_time_ms_utc.
constexpr const char* kReferenceTimeUsecKey = "reference_time_ms_utc";
constexpr const char* kEnableAddContactIntent = "enable_add_contact_intent";
constexpr const char* kEnableSearchIntent = "enable_search_intent";
45
46 // Lua environment for classfication result intent generation.
47 class AnnotatorJniEnvironment : public JniLuaEnvironment {
48 public:
AnnotatorJniEnvironment(const Resources & resources,const JniCache * jni_cache,const jobject context,const std::vector<Locale> & device_locales,const std::string & entity_text,const ClassificationResult & classification,const int64 reference_time_ms_utc,const reflection::Schema * entity_data_schema,const bool enable_add_contact_intent,const bool enable_search_intent)49 AnnotatorJniEnvironment(const Resources& resources, const JniCache* jni_cache,
50 const jobject context,
51 const std::vector<Locale>& device_locales,
52 const std::string& entity_text,
53 const ClassificationResult& classification,
54 const int64 reference_time_ms_utc,
55 const reflection::Schema* entity_data_schema,
56 const bool enable_add_contact_intent,
57 const bool enable_search_intent)
58 : JniLuaEnvironment(resources, jni_cache, context, device_locales),
59 entity_text_(entity_text),
60 classification_(classification),
61 reference_time_ms_utc_(reference_time_ms_utc),
62 enable_add_contact_intent_(enable_add_contact_intent),
63 enable_search_intent_(enable_search_intent),
64 entity_data_schema_(entity_data_schema) {}
65
66 protected:
SetupExternalHook()67 void SetupExternalHook() override {
68 JniLuaEnvironment::SetupExternalHook();
69 lua_pushinteger(state_, reference_time_ms_utc_);
70 lua_setfield(state_, /*idx=*/-2, kReferenceTimeUsecKey);
71
72 PushAnnotation(classification_, entity_text_, entity_data_schema_);
73 lua_setfield(state_, /*idx=*/-2, "entity");
74
75 lua_pushboolean(state_, enable_add_contact_intent_);
76 lua_setfield(state_, /*idx=*/-2, kEnableAddContactIntent);
77
78 lua_pushboolean(state_, enable_search_intent_);
79 lua_setfield(state_, /*idx=*/-2, kEnableSearchIntent);
80 }
81
82 const std::string& entity_text_;
83 const ClassificationResult& classification_;
84 const int64 reference_time_ms_utc_;
85 const bool enable_add_contact_intent_;
86 const bool enable_search_intent_;
87
88 // Reflection schema data.
89 const reflection::Schema* const entity_data_schema_;
90 };
91
92 // Lua environment for actions intent generation.
93 class ActionsJniLuaEnvironment : public JniLuaEnvironment {
94 public:
ActionsJniLuaEnvironment(const Resources & resources,const JniCache * jni_cache,const jobject context,const std::vector<Locale> & device_locales,const Conversation & conversation,const ActionSuggestion & action,const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema)95 ActionsJniLuaEnvironment(
96 const Resources& resources, const JniCache* jni_cache,
97 const jobject context, const std::vector<Locale>& device_locales,
98 const Conversation& conversation, const ActionSuggestion& action,
99 const reflection::Schema* actions_entity_data_schema,
100 const reflection::Schema* annotations_entity_data_schema)
101 : JniLuaEnvironment(resources, jni_cache, context, device_locales),
102 conversation_(conversation),
103 action_(action),
104 actions_entity_data_schema_(actions_entity_data_schema),
105 annotations_entity_data_schema_(annotations_entity_data_schema) {}
106
107 protected:
SetupExternalHook()108 void SetupExternalHook() override {
109 JniLuaEnvironment::SetupExternalHook();
110 PushConversation(&conversation_.messages, annotations_entity_data_schema_);
111 lua_setfield(state_, /*idx=*/-2, "conversation");
112
113 PushAction(action_, actions_entity_data_schema_,
114 annotations_entity_data_schema_);
115 lua_setfield(state_, /*idx=*/-2, "entity");
116 }
117
118 const Conversation& conversation_;
119 const ActionSuggestion& action_;
120 const reflection::Schema* actions_entity_data_schema_;
121 const reflection::Schema* annotations_entity_data_schema_;
122 };
123
124 } // namespace
125
Create(const IntentFactoryModel * options,const ResourcePool * resources,const std::shared_ptr<JniCache> & jni_cache)126 std::unique_ptr<IntentGenerator> IntentGenerator::Create(
127 const IntentFactoryModel* options, const ResourcePool* resources,
128 const std::shared_ptr<JniCache>& jni_cache) {
129 std::unique_ptr<IntentGenerator> intent_generator(
130 new IntentGenerator(options, resources, jni_cache));
131
132 if (options == nullptr || options->generator() == nullptr) {
133 TC3_LOG(ERROR) << "No intent generator options.";
134 return nullptr;
135 }
136
137 std::unique_ptr<ZlibDecompressor> zlib_decompressor =
138 ZlibDecompressor::Instance();
139 if (!zlib_decompressor) {
140 TC3_LOG(ERROR) << "Cannot initialize decompressor.";
141 return nullptr;
142 }
143
144 for (const IntentFactoryModel_::IntentGenerator* generator :
145 *options->generator()) {
146 std::string lua_template_generator;
147 if (!zlib_decompressor->MaybeDecompressOptionallyCompressedBuffer(
148 generator->lua_template_generator(),
149 generator->compressed_lua_template_generator(),
150 &lua_template_generator)) {
151 TC3_LOG(ERROR) << "Could not decompress generator template.";
152 return nullptr;
153 }
154
155 std::string lua_code = lua_template_generator;
156 if (options->precompile_generators()) {
157 if (!Compile(lua_template_generator, &lua_code)) {
158 TC3_LOG(ERROR) << "Could not precompile generator template.";
159 return nullptr;
160 }
161 }
162
163 intent_generator->generators_[generator->type()->str()] = lua_code;
164 }
165
166 return intent_generator;
167 }
168
ParseDeviceLocales(const jstring device_locales) const169 std::vector<Locale> IntentGenerator::ParseDeviceLocales(
170 const jstring device_locales) const {
171 if (device_locales == nullptr) {
172 TC3_LOG(ERROR) << "No locales provided.";
173 return {};
174 }
175 StatusOr<std::string> status_or_locales_str =
176 JStringToUtf8String(jni_cache_->GetEnv(), device_locales);
177 if (!status_or_locales_str.ok()) {
178 TC3_LOG(ERROR)
179 << "JStringToUtf8String failed, cannot retrieve provided locales.";
180 return {};
181 }
182 std::vector<Locale> locales;
183 if (!ParseLocales(status_or_locales_str.ValueOrDie(), &locales)) {
184 TC3_LOG(ERROR) << "Cannot parse locales.";
185 return {};
186 }
187 return locales;
188 }
189
GenerateIntents(const jstring device_locales,const ClassificationResult & classification,const int64 reference_time_ms_utc,const std::string & text,const CodepointSpan selection_indices,const jobject context,const reflection::Schema * annotations_entity_data_schema,const bool enable_add_contact_intent,const bool enable_search_intent,std::vector<RemoteActionTemplate> * remote_actions) const190 bool IntentGenerator::GenerateIntents(
191 const jstring device_locales, const ClassificationResult& classification,
192 const int64 reference_time_ms_utc, const std::string& text,
193 const CodepointSpan selection_indices, const jobject context,
194 const reflection::Schema* annotations_entity_data_schema,
195 const bool enable_add_contact_intent, const bool enable_search_intent,
196 std::vector<RemoteActionTemplate>* remote_actions) const {
197 if (options_ == nullptr) {
198 return false;
199 }
200
201 // Retrieve generator for specified entity.
202 auto it = generators_.find(classification.collection);
203 if (it == generators_.end()) {
204 TC3_VLOG(INFO) << "Cannot find a generator for the specified collection.";
205 return true;
206 }
207
208 const std::string entity_text =
209 UTF8ToUnicodeText(text, /*do_copy=*/false)
210 .UTF8Substring(selection_indices.first, selection_indices.second);
211
212 std::unique_ptr<AnnotatorJniEnvironment> interpreter(
213 new AnnotatorJniEnvironment(
214 resources_, jni_cache_.get(), context,
215 ParseDeviceLocales(device_locales), entity_text, classification,
216 reference_time_ms_utc, annotations_entity_data_schema,
217 enable_add_contact_intent, enable_search_intent));
218
219 if (!interpreter->Initialize()) {
220 TC3_LOG(ERROR) << "Could not create Lua interpreter.";
221 return false;
222 }
223
224 return interpreter->RunIntentGenerator(it->second, remote_actions);
225 }
226
GenerateIntents(const jstring device_locales,const ActionSuggestion & action,const Conversation & conversation,const jobject context,const reflection::Schema * annotations_entity_data_schema,const reflection::Schema * actions_entity_data_schema,std::vector<RemoteActionTemplate> * remote_actions) const227 bool IntentGenerator::GenerateIntents(
228 const jstring device_locales, const ActionSuggestion& action,
229 const Conversation& conversation, const jobject context,
230 const reflection::Schema* annotations_entity_data_schema,
231 const reflection::Schema* actions_entity_data_schema,
232 std::vector<RemoteActionTemplate>* remote_actions) const {
233 if (options_ == nullptr) {
234 return false;
235 }
236
237 // Retrieve generator for specified action.
238 auto it = generators_.find(action.type);
239 if (it == generators_.end()) {
240 return true;
241 }
242
243 std::unique_ptr<ActionsJniLuaEnvironment> interpreter(
244 new ActionsJniLuaEnvironment(
245 resources_, jni_cache_.get(), context,
246 ParseDeviceLocales(device_locales), conversation, action,
247 actions_entity_data_schema, annotations_entity_data_schema));
248
249 if (!interpreter->Initialize()) {
250 TC3_LOG(ERROR) << "Could not create Lua interpreter.";
251 return false;
252 }
253
254 return interpreter->RunIntentGenerator(it->second, remote_actions);
255 }
256
257 } // namespace libtextclassifier3
258