1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/tflite/text_encoder.h"
18
19 #include <memory>
20 #include <vector>
21
22 #include "utils/base/logging.h"
23 #include "utils/container/double-array-trie.h"
24 #include "utils/container/sorted-strings-table.h"
25 #include "utils/sentencepiece/encoder.h"
26 #include "utils/sentencepiece/normalizer.h"
27 #include "utils/strings/stringpiece.h"
28 #include "utils/tflite/encoder_common.h"
29 #include "utils/tflite/text_encoder_config_generated.h"
30 #include "flatbuffers/flatbuffers.h"
31 #include "flatbuffers/flexbuffers.h"
32 #include "tensorflow/lite/kernels/kernel_util.h"
33 #include "tensorflow/lite/model.h"
34 #include "tensorflow/lite/string_util.h"
35
36 namespace libtextclassifier3 {
37 namespace {
38
// Per-node op state, created by `Initialize` and stored in
// `node->user_data`. The encoder holds a raw pointer to `matcher`
// (see `Initialize`), so all three members must stay alive together.
struct TextEncoderOp {
  std::unique_ptr<SentencePieceNormalizer> normalizer;  // Input text normalization.
  std::unique_ptr<Encoder> encoder;  // Sentence piece encoding; borrows `matcher`.
  std::unique_ptr<StringSet> matcher;  // Piece lookup: trie or sorted string table.
};
44
// ---- Input tensor indices for the op. ----

// The conversation messages as a (1, conversation length) string tensor.
constexpr int kInputTexts = 0;

// The number of messages (the conversation length), int scalar.
constexpr int kInputNumInputs = 1;

// Maximum output length of the encoding, int scalar.
constexpr int kInputMaxLength = 2;

// Additional attributes to align to the sentence pieces, e.g. user ids per
// message.
constexpr int kInputAttr = 3;

// ---- Output tensor indices for the op. ----

// The text sentence piece encodings as ids, (1, max output length) int
// tensor.
constexpr int kOutputEncoded = 0;

// Relative position of each sentence piece in the input text,
// (1, max output length) int tensor.
constexpr int kOutputPosition = 1;

// Output length after trimming to the maximum output length specified,
// int scalar.
constexpr int kOutputLengths = 2;

// Padded and sentence piece aligned provided attributes, e.g. user id per
// sentence piece.
constexpr int kOutputAttr = 3;

// Key under which the serialized `TextEncoderConfig` is stored in the op's
// flexbuffers attribute map.
constexpr char kTextEncoderConfigAttr[] = "text_encoder_config";
76
// Initializes the text encoder object from serialized options:
// The options are a flexbuffers attribute map that contains the op config
// with the key `text_encoder_config` as a serialized `TextEncoderConfig`.
// Returns an owning `TextEncoderOp*` (as void*, later passed to `Free`),
// or nullptr if the config specifies an unknown matcher type.
void* Initialize(TfLiteContext* context, const char* buffer, size_t length) {
  const flexbuffers::Map& attr_map =
      flexbuffers::GetRoot(reinterpret_cast<const uint8_t*>(buffer), length)
          .AsMap();
  const flexbuffers::Blob serialized_config =
      attr_map[kTextEncoderConfigAttr].AsBlob();
  // NOTE(review): `GetRoot` is not verified; assumes the blob is a valid
  // `TextEncoderConfig` buffer produced by the model tooling — confirm.
  const TextEncoderConfig* config =
      flatbuffers::GetRoot<TextEncoderConfig>(serialized_config.data());

  std::unique_ptr<TextEncoderOp> encoder_op(new TextEncoderOp());

  // Create normalizer from options. The charsmap blob is reinterpreted as a
  // packed array of `TrieNode`s backing a double-array trie.
  const TrieNode* charsmap_trie_nodes = reinterpret_cast<const TrieNode*>(
      config->normalization_charsmap()->Data());
  const int charsmap_trie_nodes_length =
      config->normalization_charsmap()->size() / sizeof(TrieNode);
  encoder_op->normalizer.reset(new SentencePieceNormalizer(
      DoubleArrayTrie(charsmap_trie_nodes, charsmap_trie_nodes_length),
      StringPiece(config->normalization_charsmap_values()->data(),
                  config->normalization_charsmap_values()->size()),
      config->add_dummy_prefix(), config->remove_extra_whitespaces(),
      config->escape_whitespaces()));

  const int num_pieces = config->pieces_scores()->size();

  // The sentence pieces are stored either as a memory-mapped trie or as a
  // sorted strings table; construct the matching lookup structure.
  switch (config->matcher_type()) {
    case SentencePieceMatcherType_MAPPED_TRIE: {
      const TrieNode* pieces_trie_nodes =
          reinterpret_cast<const TrieNode*>(config->pieces()->Data());
      const int pieces_trie_nodes_length =
          config->pieces()->size() / sizeof(TrieNode);
      encoder_op->matcher.reset(
          new DoubleArrayTrie(pieces_trie_nodes, pieces_trie_nodes_length));
      break;
    }
    case SentencePieceMatcherType_SORTED_STRING_TABLE: {
      encoder_op->matcher.reset(new SortedStringsTable(
          num_pieces, config->pieces_offsets()->data(),
          StringPiece(config->pieces()->data(), config->pieces()->size())));
      break;
    }
    default: {
      // Unknown matcher type: signal failure to the TFLite runtime by
      // returning null; `encoder_op` cleans up the normalizer via RAII.
      TC3_LOG(ERROR) << "Unknown sentence piece matcher type.";
      return nullptr;
    }
  }
  // The encoder only borrows the matcher; `encoder_op` owns both, keeping
  // their lifetimes tied together.
  encoder_op->encoder.reset(new Encoder(
      encoder_op->matcher.get(), num_pieces, config->pieces_scores()->data(),
      config->start_code(), config->end_code(), config->encoding_offset(),
      config->unknown_code(), config->unknown_score()));
  return encoder_op.release();
}
132
Free(TfLiteContext * context,void * buffer)133 void Free(TfLiteContext* context, void* buffer) {
134 delete reinterpret_cast<TextEncoderOp*>(buffer);
135 }
136
ResizeOutputTensors(TfLiteContext * context,TfLiteNode * node,int max_output_length)137 TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
138 int max_output_length) {
139 TF_LITE_ENSURE_OK(
140 context,
141 ResizeOutputTensor(max_output_length,
142 &context->tensors[node->outputs->data[kOutputEncoded]],
143 context));
144
145 TF_LITE_ENSURE_OK(
146 context,
147 ResizeOutputTensor(
148 max_output_length,
149 &context->tensors[node->outputs->data[kOutputPosition]], context));
150
151 const int num_output_attrs = node->outputs->size - kOutputAttr;
152 for (int i = 0; i < num_output_attrs; ++i) {
153 TF_LITE_ENSURE_OK(
154 context,
155 ResizeOutputTensor(
156 max_output_length,
157 &context->tensors[node->outputs->data[kOutputAttr + i]], context));
158 }
159 return kTfLiteOk;
160 }
161
// Validates the op's tensor signature and resizes the outputs. If the max
// output length is a constant tensor the outputs are sized here; otherwise
// they are marked dynamic and resized during `Eval`.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Check that the batch dimension is kEncoderBatchSize.
  const TfLiteTensor& input_text =
      context->tensors[node->inputs->data[kInputTexts]];
  TF_LITE_ENSURE_EQ(context, input_text.dims->size, kEncoderInputRank);
  TF_LITE_ENSURE_EQ(context, input_text.dims->data[0], kEncoderBatchSize);

  TfLiteTensor& output_lengths =
      context->tensors[node->outputs->data[kOutputLengths]];
  TfLiteTensor& output_encoded =
      context->tensors[node->outputs->data[kOutputEncoded]];
  TfLiteTensor& output_positions =
      context->tensors[node->outputs->data[kOutputPosition]];

  // The trimmed-length output is always a (batch size) vector.
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, &output_lengths,
                                          CreateIntArray({kEncoderBatchSize})));

  // Check that there are enough outputs for attributes: each attribute
  // input must have a matching aligned output.
  const int num_output_attrs = node->outputs->size - kOutputAttr;
  TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttr, num_output_attrs);

  // Copy attribute types from input to output tensors.
  for (int i = 0; i < num_output_attrs; ++i) {
    TfLiteTensor& input = context->tensors[node->inputs->data[kInputAttr + i]];
    TfLiteTensor& output =
        context->tensors[node->outputs->data[kOutputAttr + i]];
    output.type = input.type;
  }

  const TfLiteTensor& output_length =
      context->tensors[node->inputs->data[kInputMaxLength]];

  if (tflite::IsConstantTensor(&output_length)) {
    // Max length known at prepare time: resize outputs now.
    // NOTE(review): reads the tensor as int64 (`data.i64`) without checking
    // its type — assumes the model supplies an int64 scalar; confirm.
    return ResizeOutputTensors(context, node, output_length.data.i64[0]);
  } else {
    // Max length only known at eval time: defer the resize.
    tflite::SetTensorToDynamic(&output_encoded);
    tflite::SetTensorToDynamic(&output_positions);
    for (int i = 0; i < num_output_attrs; ++i) {
      TfLiteTensor& output_attr =
          context->tensors[node->outputs->data[kOutputAttr + i]];
      tflite::SetTensorToDynamic(&output_attr);
    }
  }

  return kTfLiteOk;
}
209
Eval(TfLiteContext * context,TfLiteNode * node)210 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
211 if (node->user_data == nullptr) {
212 return kTfLiteError;
213 }
214 const TextEncoderOp* encoder_op =
215 reinterpret_cast<TextEncoderOp*>(node->user_data);
216 const TfLiteTensor& input_text =
217 context->tensors[node->inputs->data[kInputTexts]];
218 const int num_strings = tflite::GetStringCount(&input_text);
219 // Check that the number of strings matches the length parameter.
220 const int num_strings_param =
221 context->tensors[node->inputs->data[kInputNumInputs]].data.i32[0];
222 TF_LITE_ENSURE_EQ(context, num_strings, num_strings_param);
223
224 TfLiteTensor& output_encoded =
225 context->tensors[node->outputs->data[kOutputEncoded]];
226 if (tflite::IsDynamicTensor(&output_encoded)) {
227 const TfLiteTensor& output_length =
228 context->tensors[node->inputs->data[kInputMaxLength]];
229 TF_LITE_ENSURE_OK(
230 context, ResizeOutputTensors(context, node, output_length.data.i64[0]));
231 }
232 TfLiteTensor& output_positions =
233 context->tensors[node->outputs->data[kOutputPosition]];
234
235 std::vector<int> encoded_total;
236 std::vector<int> encoded_offsets;
237 std::vector<int> encoded_positions;
238 encoded_offsets.reserve(num_strings);
239 const int max_output_length = output_encoded.dims->data[1];
240 const int max_encoded_position = max_output_length;
241
242 for (int i = 0; i < num_strings; ++i) {
243 const auto& strref = tflite::GetString(&input_text, i);
244 std::string normalized;
245 TF_LITE_ENSURE(context,
246 encoder_op->normalizer->Normalize(
247 StringPiece(strref.str, strref.len), &normalized));
248 std::vector<int> encoded;
249 TF_LITE_ENSURE(context, encoder_op->encoder->Encode(normalized, &encoded));
250 encoded_total.insert(encoded_total.end(), encoded.begin(), encoded.end());
251 encoded_offsets.push_back(encoded_total.size());
252 for (int i = 0; i < encoded.size(); i++) {
253 encoded_positions.push_back(std::min(i, max_encoded_position - 1));
254 }
255 }
256
257 const int num_skip = CopyDataToTensorAndPadOrTruncate(
258 max_output_length, encoded_total,
259 /*padding_value=*/encoded_total.back(), &output_encoded);
260 TfLiteTensor& output_lengths =
261 context->tensors[node->outputs->data[kOutputLengths]];
262 output_lengths.data.i32[0] = encoded_total.size() - num_skip;
263 CopyDataToTensorAndPadOrTruncate(max_output_length, encoded_positions,
264 /*padding_value=*/max_encoded_position,
265 &output_positions);
266
267 // Process attributes, all checks of sizes and types are done in Prepare.
268 const int num_output_attrs = node->outputs->size - kOutputAttr;
269 TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttr, num_output_attrs);
270 for (int i = 0; i < num_output_attrs; ++i) {
271 TfLiteStatus attr_status = CopyValuesToTensorAndPadOrTruncate(
272 context->tensors[node->inputs->data[kInputAttr + i]], encoded_offsets,
273 num_skip, context,
274 &context->tensors[node->outputs->data[kOutputAttr + i]]);
275 if (attr_status != kTfLiteOk) {
276 return attr_status;
277 }
278 }
279
280 return kTfLiteOk;
281 }
282
283 } // namespace
284 } // namespace libtextclassifier3
285
286 namespace tflite {
287 namespace ops {
288 namespace custom {
289
Register_TEXT_ENCODER()290 TfLiteRegistration* Register_TEXT_ENCODER() {
291 static TfLiteRegistration registration = {
292 libtextclassifier3::Initialize, libtextclassifier3::Free,
293 libtextclassifier3::Prepare, libtextclassifier3::Eval};
294 return ®istration;
295 }
296
297 } // namespace custom
298 } // namespace ops
299 } // namespace tflite
300