xref: /aosp_15_r20/external/libtextclassifier/native/utils/tflite/token_encoder.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
#include "utils/tflite/token_encoder.h"

#include <algorithm>
#include <limits>
#include <vector>

#include "utils/tflite/encoder_common.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/model.h"
24*993b0882SAndroid Build Coastguard Worker 
25*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
26*993b0882SAndroid Build Coastguard Worker namespace {
27*993b0882SAndroid Build Coastguard Worker 
// Input parameters for the op.
// The number of tokens per message as (1, conversation length) int tensor.
constexpr int kInputNumTokens = 0;

// The number of messages, the conversation length, int scalar.
constexpr int kInputNumInputs = 1;

// Maximum output length of the encoding, int scalar.
constexpr int kInputMaxLength = 2;

// Additional attributes to align to the tokens, e.g. user ids per message.
// This is the index of the first attribute input; every input from this index
// onwards is treated as an attribute tensor.
constexpr int kInputAttr = 3;

// Output parameters for the op.
// Relative position of each token in the input text,
// (1, max output length) int tensor.
constexpr int kOutputPosition = 0;

// Output length after trimming to the maximum output length specified.
// int scalar.
constexpr int kOutputLengths = 1;

// Padded and token aligned provided attributes, e.g. user id per token.
// This is the index of the first attribute output; every output from this
// index onwards is paired with the attribute input at the same offset from
// kInputAttr.
constexpr int kOutputAttr = 2;
54*993b0882SAndroid Build Coastguard Worker 
ResizeOutputTensors(TfLiteContext * context,TfLiteNode * node,int max_output_length)55*993b0882SAndroid Build Coastguard Worker TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
56*993b0882SAndroid Build Coastguard Worker                                  int max_output_length) {
57*993b0882SAndroid Build Coastguard Worker   TF_LITE_ENSURE_OK(
58*993b0882SAndroid Build Coastguard Worker       context,
59*993b0882SAndroid Build Coastguard Worker       ResizeOutputTensor(
60*993b0882SAndroid Build Coastguard Worker           max_output_length,
61*993b0882SAndroid Build Coastguard Worker           &context->tensors[node->outputs->data[kOutputPosition]], context));
62*993b0882SAndroid Build Coastguard Worker 
63*993b0882SAndroid Build Coastguard Worker   const int num_output_attrs = node->outputs->size - kOutputAttr;
64*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < num_output_attrs; ++i) {
65*993b0882SAndroid Build Coastguard Worker     TF_LITE_ENSURE_OK(
66*993b0882SAndroid Build Coastguard Worker         context,
67*993b0882SAndroid Build Coastguard Worker         ResizeOutputTensor(
68*993b0882SAndroid Build Coastguard Worker             max_output_length,
69*993b0882SAndroid Build Coastguard Worker             &context->tensors[node->outputs->data[kOutputAttr + i]], context));
70*993b0882SAndroid Build Coastguard Worker   }
71*993b0882SAndroid Build Coastguard Worker   return kTfLiteOk;
72*993b0882SAndroid Build Coastguard Worker }
73*993b0882SAndroid Build Coastguard Worker 
Prepare(TfLiteContext * context,TfLiteNode * node)74*993b0882SAndroid Build Coastguard Worker TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
75*993b0882SAndroid Build Coastguard Worker   // Check that the batch dimension is kBatchSize.
76*993b0882SAndroid Build Coastguard Worker   const TfLiteTensor& num_tokens =
77*993b0882SAndroid Build Coastguard Worker       context->tensors[node->inputs->data[kInputNumTokens]];
78*993b0882SAndroid Build Coastguard Worker   TF_LITE_ENSURE_EQ(context, num_tokens.dims->size, kEncoderInputRank);
79*993b0882SAndroid Build Coastguard Worker   TF_LITE_ENSURE_EQ(context, num_tokens.dims->data[0], kEncoderBatchSize);
80*993b0882SAndroid Build Coastguard Worker 
81*993b0882SAndroid Build Coastguard Worker   TfLiteTensor& output_lengths =
82*993b0882SAndroid Build Coastguard Worker       context->tensors[node->outputs->data[kOutputLengths]];
83*993b0882SAndroid Build Coastguard Worker   TfLiteTensor& output_positions =
84*993b0882SAndroid Build Coastguard Worker       context->tensors[node->outputs->data[kOutputPosition]];
85*993b0882SAndroid Build Coastguard Worker 
86*993b0882SAndroid Build Coastguard Worker   TF_LITE_ENSURE_OK(context,
87*993b0882SAndroid Build Coastguard Worker                     context->ResizeTensor(context, &output_lengths,
88*993b0882SAndroid Build Coastguard Worker                                           CreateIntArray({kEncoderBatchSize})));
89*993b0882SAndroid Build Coastguard Worker 
90*993b0882SAndroid Build Coastguard Worker   // Check that there are enough outputs for attributes.
91*993b0882SAndroid Build Coastguard Worker   const int num_output_attrs = node->outputs->size - kOutputAttr;
92*993b0882SAndroid Build Coastguard Worker   TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttr, num_output_attrs);
93*993b0882SAndroid Build Coastguard Worker 
94*993b0882SAndroid Build Coastguard Worker   // Copy attribute types from input to output tensors.
95*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < num_output_attrs; ++i) {
96*993b0882SAndroid Build Coastguard Worker     TfLiteTensor& input = context->tensors[node->inputs->data[kInputAttr + i]];
97*993b0882SAndroid Build Coastguard Worker     TfLiteTensor& output =
98*993b0882SAndroid Build Coastguard Worker         context->tensors[node->outputs->data[kOutputAttr + i]];
99*993b0882SAndroid Build Coastguard Worker     output.type = input.type;
100*993b0882SAndroid Build Coastguard Worker   }
101*993b0882SAndroid Build Coastguard Worker 
102*993b0882SAndroid Build Coastguard Worker   const TfLiteTensor& output_length =
103*993b0882SAndroid Build Coastguard Worker       context->tensors[node->inputs->data[kInputMaxLength]];
104*993b0882SAndroid Build Coastguard Worker 
105*993b0882SAndroid Build Coastguard Worker   if (tflite::IsConstantTensor(&output_length)) {
106*993b0882SAndroid Build Coastguard Worker     return ResizeOutputTensors(context, node, output_length.data.i64[0]);
107*993b0882SAndroid Build Coastguard Worker   } else {
108*993b0882SAndroid Build Coastguard Worker     tflite::SetTensorToDynamic(&output_positions);
109*993b0882SAndroid Build Coastguard Worker     for (int i = 0; i < num_output_attrs; ++i) {
110*993b0882SAndroid Build Coastguard Worker       TfLiteTensor& output_attr =
111*993b0882SAndroid Build Coastguard Worker           context->tensors[node->outputs->data[kOutputAttr + i]];
112*993b0882SAndroid Build Coastguard Worker       tflite::SetTensorToDynamic(&output_attr);
113*993b0882SAndroid Build Coastguard Worker     }
114*993b0882SAndroid Build Coastguard Worker   }
115*993b0882SAndroid Build Coastguard Worker 
116*993b0882SAndroid Build Coastguard Worker   return kTfLiteOk;
117*993b0882SAndroid Build Coastguard Worker }
118*993b0882SAndroid Build Coastguard Worker 
Eval(TfLiteContext * context,TfLiteNode * node)119*993b0882SAndroid Build Coastguard Worker TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
120*993b0882SAndroid Build Coastguard Worker   const TfLiteTensor& num_tokens =
121*993b0882SAndroid Build Coastguard Worker       context->tensors[node->inputs->data[kInputNumTokens]];
122*993b0882SAndroid Build Coastguard Worker   const int num_inputs =
123*993b0882SAndroid Build Coastguard Worker       context->tensors[node->inputs->data[kInputNumInputs]].data.i32[0];
124*993b0882SAndroid Build Coastguard Worker 
125*993b0882SAndroid Build Coastguard Worker   const TfLiteTensor& output_length =
126*993b0882SAndroid Build Coastguard Worker       context->tensors[node->inputs->data[kInputMaxLength]];
127*993b0882SAndroid Build Coastguard Worker   TfLiteTensor& output_positions =
128*993b0882SAndroid Build Coastguard Worker       context->tensors[node->outputs->data[kOutputPosition]];
129*993b0882SAndroid Build Coastguard Worker   if (!tflite::IsConstantTensor(&output_length)) {
130*993b0882SAndroid Build Coastguard Worker     TF_LITE_ENSURE_OK(
131*993b0882SAndroid Build Coastguard Worker         context, ResizeOutputTensors(context, node, output_length.data.i64[0]));
132*993b0882SAndroid Build Coastguard Worker   }
133*993b0882SAndroid Build Coastguard Worker 
134*993b0882SAndroid Build Coastguard Worker   std::vector<int> encoded_offsets;
135*993b0882SAndroid Build Coastguard Worker   std::vector<int> encoded_positions;
136*993b0882SAndroid Build Coastguard Worker   encoded_offsets.reserve(num_inputs);
137*993b0882SAndroid Build Coastguard Worker   const int max_output_length = output_positions.dims->data[1];
138*993b0882SAndroid Build Coastguard Worker   const int max_encoded_position = max_output_length;
139*993b0882SAndroid Build Coastguard Worker   int total_tokens = 0;
140*993b0882SAndroid Build Coastguard Worker 
141*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < num_inputs; ++i) {
142*993b0882SAndroid Build Coastguard Worker     const int num_message_tokens =
143*993b0882SAndroid Build Coastguard Worker         num_tokens.data.i32[i] + 2; /* num_tokens + start and end token. */
144*993b0882SAndroid Build Coastguard Worker     total_tokens += num_message_tokens;
145*993b0882SAndroid Build Coastguard Worker     encoded_offsets.push_back(total_tokens);
146*993b0882SAndroid Build Coastguard Worker     for (int k = 0; k < num_message_tokens; k++) {
147*993b0882SAndroid Build Coastguard Worker       encoded_positions.push_back(std::min(k, max_encoded_position - 1));
148*993b0882SAndroid Build Coastguard Worker     }
149*993b0882SAndroid Build Coastguard Worker   }
150*993b0882SAndroid Build Coastguard Worker 
151*993b0882SAndroid Build Coastguard Worker   const int num_skip = CopyDataToTensorAndPadOrTruncate(
152*993b0882SAndroid Build Coastguard Worker       max_output_length, encoded_positions,
153*993b0882SAndroid Build Coastguard Worker       /*padding_value=*/max_encoded_position, &output_positions);
154*993b0882SAndroid Build Coastguard Worker   TfLiteTensor& output_lengths =
155*993b0882SAndroid Build Coastguard Worker       context->tensors[node->outputs->data[kOutputLengths]];
156*993b0882SAndroid Build Coastguard Worker   output_lengths.data.i32[0] = encoded_positions.size() - num_skip;
157*993b0882SAndroid Build Coastguard Worker 
158*993b0882SAndroid Build Coastguard Worker   // Process attributes, all checks of sizes and types are done in Prepare.
159*993b0882SAndroid Build Coastguard Worker   const int num_output_attrs = node->outputs->size - kOutputAttr;
160*993b0882SAndroid Build Coastguard Worker   TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttr, num_output_attrs);
161*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < num_output_attrs; ++i) {
162*993b0882SAndroid Build Coastguard Worker     TfLiteStatus attr_status = CopyValuesToTensorAndPadOrTruncate(
163*993b0882SAndroid Build Coastguard Worker         context->tensors[node->inputs->data[kInputAttr + i]], encoded_offsets,
164*993b0882SAndroid Build Coastguard Worker         num_skip, context,
165*993b0882SAndroid Build Coastguard Worker         &context->tensors[node->outputs->data[kOutputAttr + i]]);
166*993b0882SAndroid Build Coastguard Worker     if (attr_status != kTfLiteOk) {
167*993b0882SAndroid Build Coastguard Worker       return attr_status;
168*993b0882SAndroid Build Coastguard Worker     }
169*993b0882SAndroid Build Coastguard Worker   }
170*993b0882SAndroid Build Coastguard Worker 
171*993b0882SAndroid Build Coastguard Worker   return kTfLiteOk;
172*993b0882SAndroid Build Coastguard Worker }
173*993b0882SAndroid Build Coastguard Worker 
174*993b0882SAndroid Build Coastguard Worker }  // namespace
175*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
176*993b0882SAndroid Build Coastguard Worker 
177*993b0882SAndroid Build Coastguard Worker namespace tflite {
178*993b0882SAndroid Build Coastguard Worker namespace ops {
179*993b0882SAndroid Build Coastguard Worker namespace custom {
180*993b0882SAndroid Build Coastguard Worker 
Register_TOKEN_ENCODER()181*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_TOKEN_ENCODER() {
182*993b0882SAndroid Build Coastguard Worker   static TfLiteRegistration registration = {/*init=*/nullptr, /*free=*/nullptr,
183*993b0882SAndroid Build Coastguard Worker                                             libtextclassifier3::Prepare,
184*993b0882SAndroid Build Coastguard Worker                                             libtextclassifier3::Eval};
185*993b0882SAndroid Build Coastguard Worker   return &registration;
186*993b0882SAndroid Build Coastguard Worker }
187*993b0882SAndroid Build Coastguard Worker 
188*993b0882SAndroid Build Coastguard Worker }  // namespace custom
189*993b0882SAndroid Build Coastguard Worker }  // namespace ops
190*993b0882SAndroid Build Coastguard Worker }  // namespace tflite
191