xref: /aosp_15_r20/external/libtextclassifier/native/utils/intents/intent-generator.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/intents/intent-generator.h"
18 
#include <memory>
#include <string>
#include <utility>
#include <vector>
22 
23 #include "utils/base/logging.h"
24 #include "utils/intents/jni-lua.h"
25 #include "utils/java/jni-helper.h"
26 #include "utils/utf8/unicodetext.h"
27 #include "utils/zlib/zlib.h"
28 
29 #ifdef __cplusplus
30 extern "C" {
31 #endif
32 #include "lauxlib.h"
33 #include "lua.h"
34 #ifdef __cplusplus
35 }
36 #endif
37 
38 namespace libtextclassifier3 {
39 namespace {
40 
// Names of the fields that AnnotatorJniEnvironment stores into the Lua table
// prepared by its external hook.
//
// NOTE(review): despite the "Usec" in its name, kReferenceTimeUsecKey names a
// millisecond field ("reference_time_ms_utc", filled from
// reference_time_ms_utc_); consider renaming it to kReferenceTimeMsUtcKey.
constexpr const char* kReferenceTimeUsecKey = "reference_time_ms_utc";
constexpr const char* kEnableAddContactIntent = "enable_add_contact_intent";
constexpr const char* kEnableSearchIntent = "enable_search_intent";
45 
46 // Lua environment for classfication result intent generation.
47 class AnnotatorJniEnvironment : public JniLuaEnvironment {
48  public:
AnnotatorJniEnvironment(const Resources & resources,const JniCache * jni_cache,const jobject context,const std::vector<Locale> & device_locales,const std::string & entity_text,const ClassificationResult & classification,const int64 reference_time_ms_utc,const reflection::Schema * entity_data_schema,const bool enable_add_contact_intent,const bool enable_search_intent)49   AnnotatorJniEnvironment(const Resources& resources, const JniCache* jni_cache,
50                           const jobject context,
51                           const std::vector<Locale>& device_locales,
52                           const std::string& entity_text,
53                           const ClassificationResult& classification,
54                           const int64 reference_time_ms_utc,
55                           const reflection::Schema* entity_data_schema,
56                           const bool enable_add_contact_intent,
57                           const bool enable_search_intent)
58       : JniLuaEnvironment(resources, jni_cache, context, device_locales),
59         entity_text_(entity_text),
60         classification_(classification),
61         reference_time_ms_utc_(reference_time_ms_utc),
62         enable_add_contact_intent_(enable_add_contact_intent),
63         enable_search_intent_(enable_search_intent),
64         entity_data_schema_(entity_data_schema) {}
65 
66  protected:
SetupExternalHook()67   void SetupExternalHook() override {
68     JniLuaEnvironment::SetupExternalHook();
69     lua_pushinteger(state_, reference_time_ms_utc_);
70     lua_setfield(state_, /*idx=*/-2, kReferenceTimeUsecKey);
71 
72     PushAnnotation(classification_, entity_text_, entity_data_schema_);
73     lua_setfield(state_, /*idx=*/-2, "entity");
74 
75     lua_pushboolean(state_, enable_add_contact_intent_);
76     lua_setfield(state_, /*idx=*/-2, kEnableAddContactIntent);
77 
78     lua_pushboolean(state_, enable_search_intent_);
79     lua_setfield(state_, /*idx=*/-2, kEnableSearchIntent);
80   }
81 
82   const std::string& entity_text_;
83   const ClassificationResult& classification_;
84   const int64 reference_time_ms_utc_;
85   const bool enable_add_contact_intent_;
86   const bool enable_search_intent_;
87 
88   // Reflection schema data.
89   const reflection::Schema* const entity_data_schema_;
90 };
91 
92 // Lua environment for actions intent generation.
93 class ActionsJniLuaEnvironment : public JniLuaEnvironment {
94  public:
ActionsJniLuaEnvironment(const Resources & resources,const JniCache * jni_cache,const jobject context,const std::vector<Locale> & device_locales,const Conversation & conversation,const ActionSuggestion & action,const reflection::Schema * actions_entity_data_schema,const reflection::Schema * annotations_entity_data_schema)95   ActionsJniLuaEnvironment(
96       const Resources& resources, const JniCache* jni_cache,
97       const jobject context, const std::vector<Locale>& device_locales,
98       const Conversation& conversation, const ActionSuggestion& action,
99       const reflection::Schema* actions_entity_data_schema,
100       const reflection::Schema* annotations_entity_data_schema)
101       : JniLuaEnvironment(resources, jni_cache, context, device_locales),
102         conversation_(conversation),
103         action_(action),
104         actions_entity_data_schema_(actions_entity_data_schema),
105         annotations_entity_data_schema_(annotations_entity_data_schema) {}
106 
107  protected:
SetupExternalHook()108   void SetupExternalHook() override {
109     JniLuaEnvironment::SetupExternalHook();
110     PushConversation(&conversation_.messages, annotations_entity_data_schema_);
111     lua_setfield(state_, /*idx=*/-2, "conversation");
112 
113     PushAction(action_, actions_entity_data_schema_,
114                annotations_entity_data_schema_);
115     lua_setfield(state_, /*idx=*/-2, "entity");
116   }
117 
118   const Conversation& conversation_;
119   const ActionSuggestion& action_;
120   const reflection::Schema* actions_entity_data_schema_;
121   const reflection::Schema* annotations_entity_data_schema_;
122 };
123 
124 }  // namespace
125 
Create(const IntentFactoryModel * options,const ResourcePool * resources,const std::shared_ptr<JniCache> & jni_cache)126 std::unique_ptr<IntentGenerator> IntentGenerator::Create(
127     const IntentFactoryModel* options, const ResourcePool* resources,
128     const std::shared_ptr<JniCache>& jni_cache) {
129   std::unique_ptr<IntentGenerator> intent_generator(
130       new IntentGenerator(options, resources, jni_cache));
131 
132   if (options == nullptr || options->generator() == nullptr) {
133     TC3_LOG(ERROR) << "No intent generator options.";
134     return nullptr;
135   }
136 
137   std::unique_ptr<ZlibDecompressor> zlib_decompressor =
138       ZlibDecompressor::Instance();
139   if (!zlib_decompressor) {
140     TC3_LOG(ERROR) << "Cannot initialize decompressor.";
141     return nullptr;
142   }
143 
144   for (const IntentFactoryModel_::IntentGenerator* generator :
145        *options->generator()) {
146     std::string lua_template_generator;
147     if (!zlib_decompressor->MaybeDecompressOptionallyCompressedBuffer(
148             generator->lua_template_generator(),
149             generator->compressed_lua_template_generator(),
150             &lua_template_generator)) {
151       TC3_LOG(ERROR) << "Could not decompress generator template.";
152       return nullptr;
153     }
154 
155     std::string lua_code = lua_template_generator;
156     if (options->precompile_generators()) {
157       if (!Compile(lua_template_generator, &lua_code)) {
158         TC3_LOG(ERROR) << "Could not precompile generator template.";
159         return nullptr;
160       }
161     }
162 
163     intent_generator->generators_[generator->type()->str()] = lua_code;
164   }
165 
166   return intent_generator;
167 }
168 
ParseDeviceLocales(const jstring device_locales) const169 std::vector<Locale> IntentGenerator::ParseDeviceLocales(
170     const jstring device_locales) const {
171   if (device_locales == nullptr) {
172     TC3_LOG(ERROR) << "No locales provided.";
173     return {};
174   }
175   StatusOr<std::string> status_or_locales_str =
176       JStringToUtf8String(jni_cache_->GetEnv(), device_locales);
177   if (!status_or_locales_str.ok()) {
178     TC3_LOG(ERROR)
179         << "JStringToUtf8String failed, cannot retrieve provided locales.";
180     return {};
181   }
182   std::vector<Locale> locales;
183   if (!ParseLocales(status_or_locales_str.ValueOrDie(), &locales)) {
184     TC3_LOG(ERROR) << "Cannot parse locales.";
185     return {};
186   }
187   return locales;
188 }
189 
GenerateIntents(const jstring device_locales,const ClassificationResult & classification,const int64 reference_time_ms_utc,const std::string & text,const CodepointSpan selection_indices,const jobject context,const reflection::Schema * annotations_entity_data_schema,const bool enable_add_contact_intent,const bool enable_search_intent,std::vector<RemoteActionTemplate> * remote_actions) const190 bool IntentGenerator::GenerateIntents(
191     const jstring device_locales, const ClassificationResult& classification,
192     const int64 reference_time_ms_utc, const std::string& text,
193     const CodepointSpan selection_indices, const jobject context,
194     const reflection::Schema* annotations_entity_data_schema,
195     const bool enable_add_contact_intent, const bool enable_search_intent,
196     std::vector<RemoteActionTemplate>* remote_actions) const {
197   if (options_ == nullptr) {
198     return false;
199   }
200 
201   // Retrieve generator for specified entity.
202   auto it = generators_.find(classification.collection);
203   if (it == generators_.end()) {
204     TC3_VLOG(INFO) << "Cannot find a generator for the specified collection.";
205     return true;
206   }
207 
208   const std::string entity_text =
209       UTF8ToUnicodeText(text, /*do_copy=*/false)
210           .UTF8Substring(selection_indices.first, selection_indices.second);
211 
212   std::unique_ptr<AnnotatorJniEnvironment> interpreter(
213       new AnnotatorJniEnvironment(
214           resources_, jni_cache_.get(), context,
215           ParseDeviceLocales(device_locales), entity_text, classification,
216           reference_time_ms_utc, annotations_entity_data_schema,
217           enable_add_contact_intent, enable_search_intent));
218 
219   if (!interpreter->Initialize()) {
220     TC3_LOG(ERROR) << "Could not create Lua interpreter.";
221     return false;
222   }
223 
224   return interpreter->RunIntentGenerator(it->second, remote_actions);
225 }
226 
GenerateIntents(const jstring device_locales,const ActionSuggestion & action,const Conversation & conversation,const jobject context,const reflection::Schema * annotations_entity_data_schema,const reflection::Schema * actions_entity_data_schema,std::vector<RemoteActionTemplate> * remote_actions) const227 bool IntentGenerator::GenerateIntents(
228     const jstring device_locales, const ActionSuggestion& action,
229     const Conversation& conversation, const jobject context,
230     const reflection::Schema* annotations_entity_data_schema,
231     const reflection::Schema* actions_entity_data_schema,
232     std::vector<RemoteActionTemplate>* remote_actions) const {
233   if (options_ == nullptr) {
234     return false;
235   }
236 
237   // Retrieve generator for specified action.
238   auto it = generators_.find(action.type);
239   if (it == generators_.end()) {
240     return true;
241   }
242 
243   std::unique_ptr<ActionsJniLuaEnvironment> interpreter(
244       new ActionsJniLuaEnvironment(
245           resources_, jni_cache_.get(), context,
246           ParseDeviceLocales(device_locales), conversation, action,
247           actions_entity_data_schema, annotations_entity_data_schema));
248 
249   if (!interpreter->Initialize()) {
250     TC3_LOG(ERROR) << "Could not create Lua interpreter.";
251     return false;
252   }
253 
254   return interpreter->RunIntentGenerator(it->second, remote_actions);
255 }
256 
257 }  // namespace libtextclassifier3
258