xref: /aosp_15_r20/external/libtextclassifier/native/utils/tflite-model-executor.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/tflite-model-executor.h"
18 
19 #include "utils/base/logging.h"
20 #include "tensorflow/lite/kernels/register.h"
21 #include "tensorflow/lite/schema/schema_generated.h"
22 
23 // Forward declaration of custom TensorFlow Lite ops for registration.
24 namespace tflite {
25 namespace ops {
26 namespace builtin {
27 TfLiteRegistration* Register_GELU();
28 TfLiteRegistration* Register_ADD();
29 TfLiteRegistration* Register_CONCATENATION();
30 TfLiteRegistration* Register_CONV_2D();
31 TfLiteRegistration* Register_DEPTHWISE_CONV_2D();
32 TfLiteRegistration* Register_AVERAGE_POOL_2D();
33 TfLiteRegistration* Register_EQUAL();
34 TfLiteRegistration* Register_FULLY_CONNECTED();
35 TfLiteRegistration* Register_GREATER_EQUAL();
36 TfLiteRegistration* Register_L2_NORMALIZATION();
37 TfLiteRegistration* Register_MUL();
38 TfLiteRegistration* Register_RESHAPE();
39 TfLiteRegistration* Register_REDUCE_MAX();
40 TfLiteRegistration* Register_REDUCE_MIN();
41 TfLiteRegistration* Register_REDUCE_ANY();
42 TfLiteRegistration* Register_SOFTMAX();
43 TfLiteRegistration* Register_GATHER();
44 TfLiteRegistration* Register_GATHER_ND();
45 TfLiteRegistration* Register_IF();
46 TfLiteRegistration* Register_ROUND();
47 TfLiteRegistration* Register_ZEROS_LIKE();
48 TfLiteRegistration* Register_TRANSPOSE();
49 TfLiteRegistration* Register_SUB();
50 TfLiteRegistration* Register_DIV();
51 TfLiteRegistration* Register_STRIDED_SLICE();
52 TfLiteRegistration* Register_EXP();
53 TfLiteRegistration* Register_TOPK_V2();
54 TfLiteRegistration* Register_SLICE();
55 TfLiteRegistration* Register_SPLIT();
56 TfLiteRegistration* Register_CAST();
57 TfLiteRegistration* Register_MAXIMUM();
58 TfLiteRegistration* Register_MINIMUM();
59 TfLiteRegistration* Register_NEG();
60 TfLiteRegistration* Register_SLICE();
61 TfLiteRegistration* Register_LOG();
62 TfLiteRegistration* Register_LOGISTIC();
63 TfLiteRegistration* Register_SUM();
64 TfLiteRegistration* Register_PACK();
65 TfLiteRegistration* Register_DEQUANTIZE();
66 TfLiteRegistration* Register_MEAN();
67 TfLiteRegistration* Register_LESS();
68 TfLiteRegistration* Register_TILE();
69 TfLiteRegistration* Register_SQUARED_DIFFERENCE();
70 TfLiteRegistration* Register_RSQRT();
71 TfLiteRegistration* Register_LOG_SOFTMAX();
72 TfLiteRegistration* Register_WHERE();
73 TfLiteRegistration* Register_ONE_HOT();
74 TfLiteRegistration* Register_POW();
75 TfLiteRegistration* Register_TANH();
76 TfLiteRegistration* Register_UNIQUE();
77 TfLiteRegistration* Register_REDUCE_PROD();
78 TfLiteRegistration* Register_SHAPE();
79 TfLiteRegistration* Register_NOT_EQUAL();
80 TfLiteRegistration* Register_CUMSUM();
81 TfLiteRegistration* Register_EXPAND_DIMS();
82 TfLiteRegistration* Register_FILL();
83 TfLiteRegistration* Register_PADV2();
84 TfLiteRegistration* Register_EMBEDDING_LOOKUP();
85 TfLiteRegistration* Register_GREATER();
86 }  // namespace builtin
87 }  // namespace ops
88 }  // namespace tflite
89 
90 #ifdef TC3_WITH_ACTIONS_OPS
91 #include "utils/tflite/blacklist.h"
92 #include "utils/tflite/dist_diversification.h"
93 #include "utils/tflite/string_projection.h"
94 #include "utils/tflite/text_encoder.h"
95 #include "utils/tflite/text_encoder3s.h"
96 #include "utils/tflite/token_encoder.h"
97 
98 namespace tflite {
99 namespace ops {
100 namespace custom {
101 TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER();
102 TfLiteRegistration* Register_RAGGED_TENSOR_TO_TENSOR();
103 TfLiteRegistration* Register_RAGGED_RANGE();
104 TfLiteRegistration* Register_RANDOM_UNIFORM();
105 }  // namespace custom
106 }  // namespace ops
107 }  // namespace tflite
108 
RegisterSelectedOps(tflite::MutableOpResolver * resolver)109 void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
110   resolver->AddBuiltin(tflite::BuiltinOperator_ADD,
111                        tflite::ops::builtin::Register_ADD(),
112                        /*min_version=*/1,
113                        /*max_version=*/2);
114   resolver->AddBuiltin(tflite::BuiltinOperator_CONCATENATION,
115                        tflite::ops::builtin::Register_CONCATENATION(),
116                        /*min_version=*/1,
117                        /*max_version=*/2);
118   resolver->AddBuiltin(tflite::BuiltinOperator_CONV_2D,
119                        tflite::ops::builtin::Register_CONV_2D(),
120                        /*min_version=*/1,
121                        /*max_version=*/5);
122   resolver->AddBuiltin(tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
123                        tflite::ops::builtin::Register_DEPTHWISE_CONV_2D(),
124                        /*min_version=*/1,
125                        /*max_version=*/6);
126   resolver->AddBuiltin(tflite::BuiltinOperator_AVERAGE_POOL_2D,
127                        tflite::ops::builtin::Register_AVERAGE_POOL_2D(),
128                        /*min_version=*/1,
129                        /*max_version=*/1);
130   resolver->AddBuiltin(::tflite::BuiltinOperator_EQUAL,
131                        ::tflite::ops::builtin::Register_EQUAL());
132 
133   resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
134                        tflite::ops::builtin::Register_FULLY_CONNECTED(),
135                        /*min_version=*/1,
136                        /*max_version=*/9);
137   resolver->AddBuiltin(::tflite::BuiltinOperator_GREATER_EQUAL,
138                        ::tflite::ops::builtin::Register_GREATER_EQUAL());
139   resolver->AddBuiltin(tflite::BuiltinOperator_L2_NORMALIZATION,
140                        tflite::ops::builtin::Register_L2_NORMALIZATION(),
141                        /*min_version=*/1,
142                        /*max_version=*/2);
143   resolver->AddBuiltin(tflite::BuiltinOperator_MUL,
144                        tflite::ops::builtin::Register_MUL());
145   resolver->AddBuiltin(tflite::BuiltinOperator_RESHAPE,
146                        tflite::ops::builtin::Register_RESHAPE());
147   resolver->AddBuiltin(::tflite::BuiltinOperator_REDUCE_MAX,
148                        ::tflite::ops::builtin::Register_REDUCE_MAX());
149   resolver->AddBuiltin(::tflite::BuiltinOperator_REDUCE_MIN,
150                        ::tflite::ops::builtin::Register_REDUCE_MIN());
151   resolver->AddBuiltin(::tflite::BuiltinOperator_REDUCE_ANY,
152                        ::tflite::ops::builtin::Register_REDUCE_ANY());
153   resolver->AddBuiltin(tflite::BuiltinOperator_SOFTMAX,
154                        tflite::ops::builtin::Register_SOFTMAX(),
155                        /*min_version=*/1,
156                        /*max_version=*/2);
157   resolver->AddBuiltin(tflite::BuiltinOperator_GATHER,
158                        tflite::ops::builtin::Register_GATHER(),
159                        /*min_version=*/1,
160                        /*max_version=*/2);
161   resolver->AddBuiltin(::tflite::BuiltinOperator_GATHER_ND,
162                        ::tflite::ops::builtin::Register_GATHER_ND(),
163                        /*version=*/2);
164   resolver->AddBuiltin(::tflite::BuiltinOperator_IF,
165                        ::tflite::ops::builtin::Register_IF()),
166       resolver->AddBuiltin(::tflite::BuiltinOperator_ROUND,
167                            ::tflite::ops::builtin::Register_ROUND());
168   resolver->AddBuiltin(::tflite::BuiltinOperator_ZEROS_LIKE,
169                        ::tflite::ops::builtin::Register_ZEROS_LIKE());
170   resolver->AddBuiltin(tflite::BuiltinOperator_TRANSPOSE,
171                        tflite::ops::builtin::Register_TRANSPOSE(),
172                        /*min_version=*/1,
173                        /*max_version=*/2);
174   resolver->AddBuiltin(tflite::BuiltinOperator_SUB,
175                        tflite::ops::builtin::Register_SUB(),
176                        /*min_version=*/1,
177                        /*max_version=*/2);
178   resolver->AddBuiltin(tflite::BuiltinOperator_DIV,
179                        tflite::ops::builtin::Register_DIV());
180   resolver->AddBuiltin(tflite::BuiltinOperator_STRIDED_SLICE,
181                        tflite::ops::builtin::Register_STRIDED_SLICE(),
182                        /*min_version=*/1,
183                        /*max_version=*/2);
184   resolver->AddBuiltin(tflite::BuiltinOperator_EXP,
185                        tflite::ops::builtin::Register_EXP());
186   resolver->AddBuiltin(tflite::BuiltinOperator_TOPK_V2,
187                        tflite::ops::builtin::Register_TOPK_V2(),
188                        /*min_version=*/1,
189                        /*max_version=*/2);
190   resolver->AddBuiltin(tflite::BuiltinOperator_SLICE,
191                        tflite::ops::builtin::Register_SLICE(),
192                        /*min_version=*/1,
193                        /*max_version=*/3);
194   resolver->AddBuiltin(tflite::BuiltinOperator_SPLIT,
195                        tflite::ops::builtin::Register_SPLIT(),
196                        /*min_version=*/1,
197                        /*max_version=*/3);
198   resolver->AddBuiltin(tflite::BuiltinOperator_CAST,
199                        tflite::ops::builtin::Register_CAST());
200   resolver->AddBuiltin(tflite::BuiltinOperator_MAXIMUM,
201                        tflite::ops::builtin::Register_MAXIMUM(),
202                        /*min_version=*/1,
203                        /*max_version=*/2);
204   resolver->AddBuiltin(tflite::BuiltinOperator_MINIMUM,
205                        tflite::ops::builtin::Register_MINIMUM(),
206                        /*min_version=*/1,
207                        /*max_version=*/2);
208   resolver->AddBuiltin(tflite::BuiltinOperator_NEG,
209                        tflite::ops::builtin::Register_NEG());
210   resolver->AddBuiltin(tflite::BuiltinOperator_SLICE,
211                        tflite::ops::builtin::Register_SLICE(),
212                        /*min_version=*/1,
213                        /*max_version=*/2);
214   resolver->AddBuiltin(tflite::BuiltinOperator_LOG,
215                        tflite::ops::builtin::Register_LOG());
216   resolver->AddBuiltin(tflite::BuiltinOperator_LOGISTIC,
217                        tflite::ops::builtin::Register_LOGISTIC());
218   resolver->AddBuiltin(tflite::BuiltinOperator_SUM,
219                        tflite::ops::builtin::Register_SUM());
220   resolver->AddBuiltin(tflite::BuiltinOperator_PACK,
221                        tflite::ops::builtin::Register_PACK(),
222                        /*min_version=*/1,
223                        /*max_version=*/2);
224   resolver->AddBuiltin(tflite::BuiltinOperator_DEQUANTIZE,
225                        tflite::ops::builtin::Register_DEQUANTIZE(),
226                        /*min_version=*/1,
227                        /*max_version=*/2);
228   resolver->AddBuiltin(tflite::BuiltinOperator_MEAN,
229                        tflite::ops::builtin::Register_MEAN());
230   resolver->AddBuiltin(tflite::BuiltinOperator_LESS,
231                        tflite::ops::builtin::Register_LESS());
232   resolver->AddBuiltin(tflite::BuiltinOperator_TILE,
233                        tflite::ops::builtin::Register_TILE());
234   resolver->AddBuiltin(tflite::BuiltinOperator_SQUARED_DIFFERENCE,
235                        tflite::ops::builtin::Register_SQUARED_DIFFERENCE());
236   resolver->AddBuiltin(tflite::BuiltinOperator_RSQRT,
237                        tflite::ops::builtin::Register_RSQRT());
238   resolver->AddBuiltin(tflite::BuiltinOperator_LOG_SOFTMAX,
239                        tflite::ops::builtin::Register_LOG_SOFTMAX());
240   resolver->AddBuiltin(::tflite::BuiltinOperator_WHERE,
241                        ::tflite::ops::builtin::Register_WHERE());
242   resolver->AddBuiltin(tflite::BuiltinOperator_ONE_HOT,
243                        tflite::ops::builtin::Register_ONE_HOT(),
244                        /*min_version=*/1,
245                        /*max_version=*/1);
246   resolver->AddBuiltin(tflite::BuiltinOperator_POW,
247                        tflite::ops::builtin::Register_POW(),
248                        /*min_version=*/1,
249                        /*max_version=*/1);
250   resolver->AddBuiltin(tflite::BuiltinOperator_TANH,
251                        tflite::ops::builtin::Register_TANH(),
252                        /*min_version=*/1,
253                        /*max_version=*/1);
254   resolver->AddBuiltin(::tflite::BuiltinOperator_UNIQUE,
255                        ::tflite::ops::builtin::Register_UNIQUE());
256   resolver->AddBuiltin(::tflite::BuiltinOperator_REDUCE_PROD,
257                        ::tflite::ops::builtin::Register_REDUCE_PROD());
258   resolver->AddBuiltin(::tflite::BuiltinOperator_SHAPE,
259                        ::tflite::ops::builtin::Register_SHAPE());
260   resolver->AddBuiltin(::tflite::BuiltinOperator_NOT_EQUAL,
261                        ::tflite::ops::builtin::Register_NOT_EQUAL());
262   resolver->AddBuiltin(::tflite::BuiltinOperator_CUMSUM,
263                        ::tflite::ops::builtin::Register_CUMSUM());
264   resolver->AddBuiltin(::tflite::BuiltinOperator_EXPAND_DIMS,
265                        ::tflite::ops::builtin::Register_EXPAND_DIMS());
266   resolver->AddBuiltin(::tflite::BuiltinOperator_FILL,
267                        ::tflite::ops::builtin::Register_FILL());
268   resolver->AddBuiltin(::tflite::BuiltinOperator_PADV2,
269                        ::tflite::ops::builtin::Register_PADV2());
270   resolver->AddBuiltin(::tflite::BuiltinOperator_EMBEDDING_LOOKUP,
271                        ::tflite::ops::builtin::Register_EMBEDDING_LOOKUP(),
272                        /* min_version=*/1,
273                        /*max_version=*/3);
274   resolver->AddBuiltin(::tflite::BuiltinOperator_GREATER,
275                        ::tflite::ops::builtin::Register_GREATER());
276   resolver->AddBuiltin(::tflite::BuiltinOperator_GELU,
277                        ::tflite::ops::builtin::Register_GELU());
278 }
279 #else
RegisterSelectedOps(tflite::MutableOpResolver * resolver)280 void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
281   resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
282                        tflite::ops::builtin::Register_FULLY_CONNECTED());
283 }
284 #endif  // TC3_WITH_ACTIONS_OPS
285 
286 namespace libtextclassifier3 {
287 
BuildOpResolver()288 std::unique_ptr<tflite::OpResolver> BuildOpResolver() {
289   return BuildOpResolver([](tflite::MutableOpResolver* mutable_resolver) {});
290 }
291 
BuildOpResolver(const std::function<void (tflite::MutableOpResolver *)> & customize_fn)292 std::unique_ptr<tflite::OpResolver> BuildOpResolver(
293     const std::function<void(tflite::MutableOpResolver*)>& customize_fn) {
294 #ifdef TC3_USE_SELECTIVE_REGISTRATION
295   std::unique_ptr<tflite::MutableOpResolver> resolver(
296       new tflite::MutableOpResolver);
297   RegisterSelectedOps(resolver.get());
298 #else
299   std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver(
300       new tflite::ops::builtin::BuiltinOpResolver);
301 #endif
302 #ifdef TC3_WITH_ACTIONS_OPS
303   resolver->AddCustom("DistanceDiversification",
304                       tflite::ops::custom::Register_DISTANCE_DIVERSIFICATION());
305   resolver->AddCustom("TextEncoder",
306                       tflite::ops::custom::Register_TEXT_ENCODER());
307   resolver->AddCustom("TextEncoder3S",
308                       tflite::ops::custom::Register_TEXT_ENCODER3S());
309   resolver->AddCustom("TokenEncoder",
310                       tflite::ops::custom::Register_TOKEN_ENCODER());
311   resolver->AddCustom(
312       "TFSentencepieceTokenizeOp",
313       ::tflite::ops::custom::Register_SENTENCEPIECE_TOKENIZER());
314   resolver->AddCustom("RaggedRange",
315                       ::tflite::ops::custom::Register_RAGGED_RANGE());
316   resolver->AddCustom(
317       "RaggedTensorToTensor",
318       ::tflite::ops::custom::Register_RAGGED_TENSOR_TO_TENSOR());
319   resolver->AddCustom(
320       "STRING_PROJECTION",
321       ::tflite::ops::custom::libtextclassifier3::Register_STRING_PROJECTION());
322   resolver->AddCustom(
323       "BLACKLIST",
324       ::tflite::ops::custom::libtextclassifier3::Register_BLACKLIST());
325   resolver->AddCustom("RandomUniform",
326                       ::tflite::ops::custom::Register_RANDOM_UNIFORM());
327 #endif  // TC3_WITH_ACTIONS_OPS
328   customize_fn(resolver.get());
329   return std::unique_ptr<tflite::OpResolver>(std::move(resolver));
330 }
331 
TfLiteModelFromModelSpec(const tflite::Model * model_spec)332 std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec(
333     const tflite::Model* model_spec) {
334   std::unique_ptr<const tflite::FlatBufferModel> model(
335       tflite::FlatBufferModel::BuildFromModel(model_spec));
336   if (!model || !model->initialized()) {
337     TC3_LOG(ERROR) << "Could not build TFLite model from a model spec.";
338     return nullptr;
339   }
340   return model;
341 }
342 
TfLiteModelFromBuffer(const flatbuffers::Vector<uint8_t> * model_spec_buffer)343 std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer(
344     const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
345   const tflite::Model* model =
346       flatbuffers::GetRoot<tflite::Model>(model_spec_buffer->data());
347   flatbuffers::Verifier verifier(model_spec_buffer->data(),
348                                  model_spec_buffer->size());
349   if (!model->Verify(verifier)) {
350     return nullptr;
351   }
352   return TfLiteModelFromModelSpec(model);
353 }
354 
TfLiteModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model)355 TfLiteModelExecutor::TfLiteModelExecutor(
356     std::unique_ptr<const tflite::FlatBufferModel> model)
357     : model_(std::move(model)), resolver_(BuildOpResolver()) {}
TfLiteModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model,std::unique_ptr<tflite::OpResolver> resolver)358 TfLiteModelExecutor::TfLiteModelExecutor(
359     std::unique_ptr<const tflite::FlatBufferModel> model,
360     std::unique_ptr<tflite::OpResolver> resolver)
361     : model_(std::move(model)), resolver_(std::move(resolver)) {}
362 
CreateInterpreter() const363 std::unique_ptr<tflite::Interpreter> TfLiteModelExecutor::CreateInterpreter()
364     const {
365   std::unique_ptr<tflite::Interpreter> interpreter;
366   tflite::InterpreterBuilder(*model_, *resolver_)(&interpreter);
367   return interpreter;
368 }
369 
370 template <>
SetInput(const int input_index,const std::vector<std::string> & input_data,tflite::Interpreter * interpreter) const371 void TfLiteModelExecutor::SetInput(const int input_index,
372                                    const std::vector<std::string>& input_data,
373                                    tflite::Interpreter* interpreter) const {
374   tflite::DynamicBuffer buf;
375   for (const std::string& s : input_data) {
376     buf.AddString(s.data(), s.length());
377   }
378   buf.WriteToTensorAsVector(
379       interpreter->tensor(interpreter->inputs()[input_index]));
380 }
381 
382 template <>
Output(const int output_index,const tflite::Interpreter * interpreter) const383 std::vector<tflite::StringRef> TfLiteModelExecutor::Output(
384     const int output_index, const tflite::Interpreter* interpreter) const {
385   const TfLiteTensor* output_tensor =
386       interpreter->tensor(interpreter->outputs()[output_index]);
387   const int num_strings = tflite::GetStringCount(output_tensor);
388   std::vector<tflite::StringRef> output(num_strings);
389   for (int i = 0; i < num_strings; i++) {
390     output[i] = tflite::GetString(output_tensor, i);
391   }
392   return output;
393 }
394 
395 template <>
Output(const int output_index,const tflite::Interpreter * interpreter) const396 std::vector<std::string> TfLiteModelExecutor::Output(
397     const int output_index, const tflite::Interpreter* interpreter) const {
398   std::vector<std::string> output;
399   for (const tflite::StringRef& s :
400        Output<tflite::StringRef>(output_index, interpreter)) {
401     output.push_back(std::string(s.str, s.len));
402   }
403   return output;
404 }
405 
406 }  // namespace libtextclassifier3
407