xref: /aosp_15_r20/external/libtextclassifier/native/utils/tflite/text_encoder.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/tflite/text_encoder.h"
18 
19 #include <memory>
20 #include <vector>
21 
22 #include "utils/base/logging.h"
23 #include "utils/container/double-array-trie.h"
24 #include "utils/container/sorted-strings-table.h"
25 #include "utils/sentencepiece/encoder.h"
26 #include "utils/sentencepiece/normalizer.h"
27 #include "utils/strings/stringpiece.h"
28 #include "utils/tflite/encoder_common.h"
29 #include "utils/tflite/text_encoder_config_generated.h"
30 #include "flatbuffers/flatbuffers.h"
31 #include "flatbuffers/flexbuffers.h"
32 #include "tensorflow/lite/kernels/kernel_util.h"
33 #include "tensorflow/lite/model.h"
34 #include "tensorflow/lite/string_util.h"
35 
36 namespace libtextclassifier3 {
37 namespace {
38 
39 struct TextEncoderOp {
40   std::unique_ptr<SentencePieceNormalizer> normalizer;
41   std::unique_ptr<Encoder> encoder;
42   std::unique_ptr<StringSet> matcher;
43 };
44 
// ---- Input tensor indices ------------------------------------------------

// Texts of the conversation messages: a (1, conversation length) string
// tensor.
constexpr int kInputTexts = 0;

// Number of messages (the conversation length): an integer scalar.
constexpr int kInputNumInputs = 1;

// Maximum length of the encoded output: an integer scalar.
constexpr int kInputMaxLength = 2;

// Index of the first optional attribute input (e.g. a user id per message);
// attributes are aligned to the sentence pieces.
constexpr int kInputAttr = 3;

// ---- Output tensor indices -----------------------------------------------

// Sentence piece ids of the encoded text: a (1, max output length) int tensor.
constexpr int kOutputEncoded = 0;

// Position of each sentence piece inside its own message:
// a (1, max output length) int tensor.
constexpr int kOutputPosition = 1;

// Number of valid encoded entries once the output has been trimmed to the
// maximum output length: an int scalar.
constexpr int kOutputLengths = 2;

// Index of the first attribute output: each provided attribute padded and
// aligned to the sentence pieces (e.g. one user id per sentence piece).
constexpr int kOutputAttr = 3;

// Key under which the serialized TextEncoderConfig is stored in the op's
// flexbuffers options map.
constexpr char kTextEncoderConfigAttr[] = "text_encoder_config";
76 
77 // Initializes text encoder object from serialized options:
78 //   The options are a flexbuffers attribute map that contain the op config
79 //   with the key `text_encoder_config` as `TextEncoderConfig`.
Initialize(TfLiteContext * context,const char * buffer,size_t length)80 void* Initialize(TfLiteContext* context, const char* buffer, size_t length) {
81   const flexbuffers::Map& attr_map =
82       flexbuffers::GetRoot(reinterpret_cast<const uint8_t*>(buffer), length)
83           .AsMap();
84   const flexbuffers::Blob serialized_config =
85       attr_map[kTextEncoderConfigAttr].AsBlob();
86   const TextEncoderConfig* config =
87       flatbuffers::GetRoot<TextEncoderConfig>(serialized_config.data());
88 
89   std::unique_ptr<TextEncoderOp> encoder_op(new TextEncoderOp());
90 
91   // Create normalizer from options.
92   const TrieNode* charsmap_trie_nodes = reinterpret_cast<const TrieNode*>(
93       config->normalization_charsmap()->Data());
94   const int charsmap_trie_nodes_length =
95       config->normalization_charsmap()->size() / sizeof(TrieNode);
96   encoder_op->normalizer.reset(new SentencePieceNormalizer(
97       DoubleArrayTrie(charsmap_trie_nodes, charsmap_trie_nodes_length),
98       StringPiece(config->normalization_charsmap_values()->data(),
99                   config->normalization_charsmap_values()->size()),
100       config->add_dummy_prefix(), config->remove_extra_whitespaces(),
101       config->escape_whitespaces()));
102 
103   const int num_pieces = config->pieces_scores()->size();
104 
105   switch (config->matcher_type()) {
106     case SentencePieceMatcherType_MAPPED_TRIE: {
107       const TrieNode* pieces_trie_nodes =
108           reinterpret_cast<const TrieNode*>(config->pieces()->Data());
109       const int pieces_trie_nodes_length =
110           config->pieces()->size() / sizeof(TrieNode);
111       encoder_op->matcher.reset(
112           new DoubleArrayTrie(pieces_trie_nodes, pieces_trie_nodes_length));
113       break;
114     }
115     case SentencePieceMatcherType_SORTED_STRING_TABLE: {
116       encoder_op->matcher.reset(new SortedStringsTable(
117           num_pieces, config->pieces_offsets()->data(),
118           StringPiece(config->pieces()->data(), config->pieces()->size())));
119       break;
120     }
121     default: {
122       TC3_LOG(ERROR) << "Unknown sentence piece matcher type.";
123       return nullptr;
124     }
125   }
126   encoder_op->encoder.reset(new Encoder(
127       encoder_op->matcher.get(), num_pieces, config->pieces_scores()->data(),
128       config->start_code(), config->end_code(), config->encoding_offset(),
129       config->unknown_code(), config->unknown_score()));
130   return encoder_op.release();
131 }
132 
Free(TfLiteContext * context,void * buffer)133 void Free(TfLiteContext* context, void* buffer) {
134   delete reinterpret_cast<TextEncoderOp*>(buffer);
135 }
136 
ResizeOutputTensors(TfLiteContext * context,TfLiteNode * node,int max_output_length)137 TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
138                                  int max_output_length) {
139   TF_LITE_ENSURE_OK(
140       context,
141       ResizeOutputTensor(max_output_length,
142                          &context->tensors[node->outputs->data[kOutputEncoded]],
143                          context));
144 
145   TF_LITE_ENSURE_OK(
146       context,
147       ResizeOutputTensor(
148           max_output_length,
149           &context->tensors[node->outputs->data[kOutputPosition]], context));
150 
151   const int num_output_attrs = node->outputs->size - kOutputAttr;
152   for (int i = 0; i < num_output_attrs; ++i) {
153     TF_LITE_ENSURE_OK(
154         context,
155         ResizeOutputTensor(
156             max_output_length,
157             &context->tensors[node->outputs->data[kOutputAttr + i]], context));
158   }
159   return kTfLiteOk;
160 }
161 
Prepare(TfLiteContext * context,TfLiteNode * node)162 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
163   // Check that the batch dimension is kBatchSize.
164   const TfLiteTensor& input_text =
165       context->tensors[node->inputs->data[kInputTexts]];
166   TF_LITE_ENSURE_EQ(context, input_text.dims->size, kEncoderInputRank);
167   TF_LITE_ENSURE_EQ(context, input_text.dims->data[0], kEncoderBatchSize);
168 
169   TfLiteTensor& output_lengths =
170       context->tensors[node->outputs->data[kOutputLengths]];
171   TfLiteTensor& output_encoded =
172       context->tensors[node->outputs->data[kOutputEncoded]];
173   TfLiteTensor& output_positions =
174       context->tensors[node->outputs->data[kOutputPosition]];
175 
176   TF_LITE_ENSURE_OK(context,
177                     context->ResizeTensor(context, &output_lengths,
178                                           CreateIntArray({kEncoderBatchSize})));
179 
180   // Check that there are enough outputs for attributes.
181   const int num_output_attrs = node->outputs->size - kOutputAttr;
182   TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttr, num_output_attrs);
183 
184   // Copy attribute types from input to output tensors.
185   for (int i = 0; i < num_output_attrs; ++i) {
186     TfLiteTensor& input = context->tensors[node->inputs->data[kInputAttr + i]];
187     TfLiteTensor& output =
188         context->tensors[node->outputs->data[kOutputAttr + i]];
189     output.type = input.type;
190   }
191 
192   const TfLiteTensor& output_length =
193       context->tensors[node->inputs->data[kInputMaxLength]];
194 
195   if (tflite::IsConstantTensor(&output_length)) {
196     return ResizeOutputTensors(context, node, output_length.data.i64[0]);
197   } else {
198     tflite::SetTensorToDynamic(&output_encoded);
199     tflite::SetTensorToDynamic(&output_positions);
200     for (int i = 0; i < num_output_attrs; ++i) {
201       TfLiteTensor& output_attr =
202           context->tensors[node->outputs->data[kOutputAttr + i]];
203       tflite::SetTensorToDynamic(&output_attr);
204     }
205   }
206 
207   return kTfLiteOk;
208 }
209 
Eval(TfLiteContext * context,TfLiteNode * node)210 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
211   if (node->user_data == nullptr) {
212     return kTfLiteError;
213   }
214   const TextEncoderOp* encoder_op =
215       reinterpret_cast<TextEncoderOp*>(node->user_data);
216   const TfLiteTensor& input_text =
217       context->tensors[node->inputs->data[kInputTexts]];
218   const int num_strings = tflite::GetStringCount(&input_text);
219   // Check that the number of strings matches the length parameter.
220   const int num_strings_param =
221       context->tensors[node->inputs->data[kInputNumInputs]].data.i32[0];
222   TF_LITE_ENSURE_EQ(context, num_strings, num_strings_param);
223 
224   TfLiteTensor& output_encoded =
225       context->tensors[node->outputs->data[kOutputEncoded]];
226   if (tflite::IsDynamicTensor(&output_encoded)) {
227     const TfLiteTensor& output_length =
228         context->tensors[node->inputs->data[kInputMaxLength]];
229     TF_LITE_ENSURE_OK(
230         context, ResizeOutputTensors(context, node, output_length.data.i64[0]));
231   }
232   TfLiteTensor& output_positions =
233       context->tensors[node->outputs->data[kOutputPosition]];
234 
235   std::vector<int> encoded_total;
236   std::vector<int> encoded_offsets;
237   std::vector<int> encoded_positions;
238   encoded_offsets.reserve(num_strings);
239   const int max_output_length = output_encoded.dims->data[1];
240   const int max_encoded_position = max_output_length;
241 
242   for (int i = 0; i < num_strings; ++i) {
243     const auto& strref = tflite::GetString(&input_text, i);
244     std::string normalized;
245     TF_LITE_ENSURE(context,
246                    encoder_op->normalizer->Normalize(
247                        StringPiece(strref.str, strref.len), &normalized));
248     std::vector<int> encoded;
249     TF_LITE_ENSURE(context, encoder_op->encoder->Encode(normalized, &encoded));
250     encoded_total.insert(encoded_total.end(), encoded.begin(), encoded.end());
251     encoded_offsets.push_back(encoded_total.size());
252     for (int i = 0; i < encoded.size(); i++) {
253       encoded_positions.push_back(std::min(i, max_encoded_position - 1));
254     }
255   }
256 
257   const int num_skip = CopyDataToTensorAndPadOrTruncate(
258       max_output_length, encoded_total,
259       /*padding_value=*/encoded_total.back(), &output_encoded);
260   TfLiteTensor& output_lengths =
261       context->tensors[node->outputs->data[kOutputLengths]];
262   output_lengths.data.i32[0] = encoded_total.size() - num_skip;
263   CopyDataToTensorAndPadOrTruncate(max_output_length, encoded_positions,
264                                    /*padding_value=*/max_encoded_position,
265                                    &output_positions);
266 
267   // Process attributes, all checks of sizes and types are done in Prepare.
268   const int num_output_attrs = node->outputs->size - kOutputAttr;
269   TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttr, num_output_attrs);
270   for (int i = 0; i < num_output_attrs; ++i) {
271     TfLiteStatus attr_status = CopyValuesToTensorAndPadOrTruncate(
272         context->tensors[node->inputs->data[kInputAttr + i]], encoded_offsets,
273         num_skip, context,
274         &context->tensors[node->outputs->data[kOutputAttr + i]]);
275     if (attr_status != kTfLiteOk) {
276       return attr_status;
277     }
278   }
279 
280   return kTfLiteOk;
281 }
282 
283 }  // namespace
284 }  // namespace libtextclassifier3
285 
286 namespace tflite {
287 namespace ops {
288 namespace custom {
289 
Register_TEXT_ENCODER()290 TfLiteRegistration* Register_TEXT_ENCODER() {
291   static TfLiteRegistration registration = {
292       libtextclassifier3::Initialize, libtextclassifier3::Free,
293       libtextclassifier3::Prepare, libtextclassifier3::Eval};
294   return &registration;
295 }
296 
297 }  // namespace custom
298 }  // namespace ops
299 }  // namespace tflite
300