1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker *
4*993b0882SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker *
8*993b0882SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker *
10*993b0882SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker */
16*993b0882SAndroid Build Coastguard Worker
17*993b0882SAndroid Build Coastguard Worker #include "utils/tflite/token_encoder.h"
18*993b0882SAndroid Build Coastguard Worker
19*993b0882SAndroid Build Coastguard Worker #include <vector>
20*993b0882SAndroid Build Coastguard Worker
21*993b0882SAndroid Build Coastguard Worker #include "utils/tflite/encoder_common.h"
22*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/kernels/kernel_util.h"
23*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/model.h"
24*993b0882SAndroid Build Coastguard Worker
25*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
26*993b0882SAndroid Build Coastguard Worker namespace {
27*993b0882SAndroid Build Coastguard Worker
// Indices of the op's input tensors.

// Number of tokens per message, (1, conversation length) int tensor.
constexpr int kInputNumTokens = 0;

// Number of messages, i.e. the conversation length, int scalar.
constexpr int kInputNumInputs = 1;

// Maximum output length of the encoding, int scalar.
constexpr int kInputMaxLength = 2;

// First of the additional attribute inputs that are aligned to the sentence
// pieces, e.g. user ids per message. Every input from this index onwards is
// treated as an attribute.
constexpr int kInputAttr = 3;

// Indices of the op's output tensors.

// Relative position of each token in the input text,
// (1, max output length) int tensor.
constexpr int kOutputPosition = 0;

// Output length after trimming to the maximum output length specified,
// int scalar.
constexpr int kOutputLengths = 1;

// First of the padded and sentence piece aligned attribute outputs, e.g. user
// id per sentence piece. Every output from this index onwards is treated as
// an attribute.
constexpr int kOutputAttr = 2;
54*993b0882SAndroid Build Coastguard Worker
ResizeOutputTensors(TfLiteContext * context,TfLiteNode * node,int max_output_length)55*993b0882SAndroid Build Coastguard Worker TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
56*993b0882SAndroid Build Coastguard Worker int max_output_length) {
57*993b0882SAndroid Build Coastguard Worker TF_LITE_ENSURE_OK(
58*993b0882SAndroid Build Coastguard Worker context,
59*993b0882SAndroid Build Coastguard Worker ResizeOutputTensor(
60*993b0882SAndroid Build Coastguard Worker max_output_length,
61*993b0882SAndroid Build Coastguard Worker &context->tensors[node->outputs->data[kOutputPosition]], context));
62*993b0882SAndroid Build Coastguard Worker
63*993b0882SAndroid Build Coastguard Worker const int num_output_attrs = node->outputs->size - kOutputAttr;
64*993b0882SAndroid Build Coastguard Worker for (int i = 0; i < num_output_attrs; ++i) {
65*993b0882SAndroid Build Coastguard Worker TF_LITE_ENSURE_OK(
66*993b0882SAndroid Build Coastguard Worker context,
67*993b0882SAndroid Build Coastguard Worker ResizeOutputTensor(
68*993b0882SAndroid Build Coastguard Worker max_output_length,
69*993b0882SAndroid Build Coastguard Worker &context->tensors[node->outputs->data[kOutputAttr + i]], context));
70*993b0882SAndroid Build Coastguard Worker }
71*993b0882SAndroid Build Coastguard Worker return kTfLiteOk;
72*993b0882SAndroid Build Coastguard Worker }
73*993b0882SAndroid Build Coastguard Worker
Prepare(TfLiteContext * context,TfLiteNode * node)74*993b0882SAndroid Build Coastguard Worker TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
75*993b0882SAndroid Build Coastguard Worker // Check that the batch dimension is kBatchSize.
76*993b0882SAndroid Build Coastguard Worker const TfLiteTensor& num_tokens =
77*993b0882SAndroid Build Coastguard Worker context->tensors[node->inputs->data[kInputNumTokens]];
78*993b0882SAndroid Build Coastguard Worker TF_LITE_ENSURE_EQ(context, num_tokens.dims->size, kEncoderInputRank);
79*993b0882SAndroid Build Coastguard Worker TF_LITE_ENSURE_EQ(context, num_tokens.dims->data[0], kEncoderBatchSize);
80*993b0882SAndroid Build Coastguard Worker
81*993b0882SAndroid Build Coastguard Worker TfLiteTensor& output_lengths =
82*993b0882SAndroid Build Coastguard Worker context->tensors[node->outputs->data[kOutputLengths]];
83*993b0882SAndroid Build Coastguard Worker TfLiteTensor& output_positions =
84*993b0882SAndroid Build Coastguard Worker context->tensors[node->outputs->data[kOutputPosition]];
85*993b0882SAndroid Build Coastguard Worker
86*993b0882SAndroid Build Coastguard Worker TF_LITE_ENSURE_OK(context,
87*993b0882SAndroid Build Coastguard Worker context->ResizeTensor(context, &output_lengths,
88*993b0882SAndroid Build Coastguard Worker CreateIntArray({kEncoderBatchSize})));
89*993b0882SAndroid Build Coastguard Worker
90*993b0882SAndroid Build Coastguard Worker // Check that there are enough outputs for attributes.
91*993b0882SAndroid Build Coastguard Worker const int num_output_attrs = node->outputs->size - kOutputAttr;
92*993b0882SAndroid Build Coastguard Worker TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttr, num_output_attrs);
93*993b0882SAndroid Build Coastguard Worker
94*993b0882SAndroid Build Coastguard Worker // Copy attribute types from input to output tensors.
95*993b0882SAndroid Build Coastguard Worker for (int i = 0; i < num_output_attrs; ++i) {
96*993b0882SAndroid Build Coastguard Worker TfLiteTensor& input = context->tensors[node->inputs->data[kInputAttr + i]];
97*993b0882SAndroid Build Coastguard Worker TfLiteTensor& output =
98*993b0882SAndroid Build Coastguard Worker context->tensors[node->outputs->data[kOutputAttr + i]];
99*993b0882SAndroid Build Coastguard Worker output.type = input.type;
100*993b0882SAndroid Build Coastguard Worker }
101*993b0882SAndroid Build Coastguard Worker
102*993b0882SAndroid Build Coastguard Worker const TfLiteTensor& output_length =
103*993b0882SAndroid Build Coastguard Worker context->tensors[node->inputs->data[kInputMaxLength]];
104*993b0882SAndroid Build Coastguard Worker
105*993b0882SAndroid Build Coastguard Worker if (tflite::IsConstantTensor(&output_length)) {
106*993b0882SAndroid Build Coastguard Worker return ResizeOutputTensors(context, node, output_length.data.i64[0]);
107*993b0882SAndroid Build Coastguard Worker } else {
108*993b0882SAndroid Build Coastguard Worker tflite::SetTensorToDynamic(&output_positions);
109*993b0882SAndroid Build Coastguard Worker for (int i = 0; i < num_output_attrs; ++i) {
110*993b0882SAndroid Build Coastguard Worker TfLiteTensor& output_attr =
111*993b0882SAndroid Build Coastguard Worker context->tensors[node->outputs->data[kOutputAttr + i]];
112*993b0882SAndroid Build Coastguard Worker tflite::SetTensorToDynamic(&output_attr);
113*993b0882SAndroid Build Coastguard Worker }
114*993b0882SAndroid Build Coastguard Worker }
115*993b0882SAndroid Build Coastguard Worker
116*993b0882SAndroid Build Coastguard Worker return kTfLiteOk;
117*993b0882SAndroid Build Coastguard Worker }
118*993b0882SAndroid Build Coastguard Worker
Eval(TfLiteContext * context,TfLiteNode * node)119*993b0882SAndroid Build Coastguard Worker TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
120*993b0882SAndroid Build Coastguard Worker const TfLiteTensor& num_tokens =
121*993b0882SAndroid Build Coastguard Worker context->tensors[node->inputs->data[kInputNumTokens]];
122*993b0882SAndroid Build Coastguard Worker const int num_inputs =
123*993b0882SAndroid Build Coastguard Worker context->tensors[node->inputs->data[kInputNumInputs]].data.i32[0];
124*993b0882SAndroid Build Coastguard Worker
125*993b0882SAndroid Build Coastguard Worker const TfLiteTensor& output_length =
126*993b0882SAndroid Build Coastguard Worker context->tensors[node->inputs->data[kInputMaxLength]];
127*993b0882SAndroid Build Coastguard Worker TfLiteTensor& output_positions =
128*993b0882SAndroid Build Coastguard Worker context->tensors[node->outputs->data[kOutputPosition]];
129*993b0882SAndroid Build Coastguard Worker if (!tflite::IsConstantTensor(&output_length)) {
130*993b0882SAndroid Build Coastguard Worker TF_LITE_ENSURE_OK(
131*993b0882SAndroid Build Coastguard Worker context, ResizeOutputTensors(context, node, output_length.data.i64[0]));
132*993b0882SAndroid Build Coastguard Worker }
133*993b0882SAndroid Build Coastguard Worker
134*993b0882SAndroid Build Coastguard Worker std::vector<int> encoded_offsets;
135*993b0882SAndroid Build Coastguard Worker std::vector<int> encoded_positions;
136*993b0882SAndroid Build Coastguard Worker encoded_offsets.reserve(num_inputs);
137*993b0882SAndroid Build Coastguard Worker const int max_output_length = output_positions.dims->data[1];
138*993b0882SAndroid Build Coastguard Worker const int max_encoded_position = max_output_length;
139*993b0882SAndroid Build Coastguard Worker int total_tokens = 0;
140*993b0882SAndroid Build Coastguard Worker
141*993b0882SAndroid Build Coastguard Worker for (int i = 0; i < num_inputs; ++i) {
142*993b0882SAndroid Build Coastguard Worker const int num_message_tokens =
143*993b0882SAndroid Build Coastguard Worker num_tokens.data.i32[i] + 2; /* num_tokens + start and end token. */
144*993b0882SAndroid Build Coastguard Worker total_tokens += num_message_tokens;
145*993b0882SAndroid Build Coastguard Worker encoded_offsets.push_back(total_tokens);
146*993b0882SAndroid Build Coastguard Worker for (int k = 0; k < num_message_tokens; k++) {
147*993b0882SAndroid Build Coastguard Worker encoded_positions.push_back(std::min(k, max_encoded_position - 1));
148*993b0882SAndroid Build Coastguard Worker }
149*993b0882SAndroid Build Coastguard Worker }
150*993b0882SAndroid Build Coastguard Worker
151*993b0882SAndroid Build Coastguard Worker const int num_skip = CopyDataToTensorAndPadOrTruncate(
152*993b0882SAndroid Build Coastguard Worker max_output_length, encoded_positions,
153*993b0882SAndroid Build Coastguard Worker /*padding_value=*/max_encoded_position, &output_positions);
154*993b0882SAndroid Build Coastguard Worker TfLiteTensor& output_lengths =
155*993b0882SAndroid Build Coastguard Worker context->tensors[node->outputs->data[kOutputLengths]];
156*993b0882SAndroid Build Coastguard Worker output_lengths.data.i32[0] = encoded_positions.size() - num_skip;
157*993b0882SAndroid Build Coastguard Worker
158*993b0882SAndroid Build Coastguard Worker // Process attributes, all checks of sizes and types are done in Prepare.
159*993b0882SAndroid Build Coastguard Worker const int num_output_attrs = node->outputs->size - kOutputAttr;
160*993b0882SAndroid Build Coastguard Worker TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttr, num_output_attrs);
161*993b0882SAndroid Build Coastguard Worker for (int i = 0; i < num_output_attrs; ++i) {
162*993b0882SAndroid Build Coastguard Worker TfLiteStatus attr_status = CopyValuesToTensorAndPadOrTruncate(
163*993b0882SAndroid Build Coastguard Worker context->tensors[node->inputs->data[kInputAttr + i]], encoded_offsets,
164*993b0882SAndroid Build Coastguard Worker num_skip, context,
165*993b0882SAndroid Build Coastguard Worker &context->tensors[node->outputs->data[kOutputAttr + i]]);
166*993b0882SAndroid Build Coastguard Worker if (attr_status != kTfLiteOk) {
167*993b0882SAndroid Build Coastguard Worker return attr_status;
168*993b0882SAndroid Build Coastguard Worker }
169*993b0882SAndroid Build Coastguard Worker }
170*993b0882SAndroid Build Coastguard Worker
171*993b0882SAndroid Build Coastguard Worker return kTfLiteOk;
172*993b0882SAndroid Build Coastguard Worker }
173*993b0882SAndroid Build Coastguard Worker
174*993b0882SAndroid Build Coastguard Worker } // namespace
175*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3
176*993b0882SAndroid Build Coastguard Worker
177*993b0882SAndroid Build Coastguard Worker namespace tflite {
178*993b0882SAndroid Build Coastguard Worker namespace ops {
179*993b0882SAndroid Build Coastguard Worker namespace custom {
180*993b0882SAndroid Build Coastguard Worker
Register_TOKEN_ENCODER()181*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_TOKEN_ENCODER() {
182*993b0882SAndroid Build Coastguard Worker static TfLiteRegistration registration = {/*init=*/nullptr, /*free=*/nullptr,
183*993b0882SAndroid Build Coastguard Worker libtextclassifier3::Prepare,
184*993b0882SAndroid Build Coastguard Worker libtextclassifier3::Eval};
185*993b0882SAndroid Build Coastguard Worker return ®istration;
186*993b0882SAndroid Build Coastguard Worker }
187*993b0882SAndroid Build Coastguard Worker
188*993b0882SAndroid Build Coastguard Worker } // namespace custom
189*993b0882SAndroid Build Coastguard Worker } // namespace ops
190*993b0882SAndroid Build Coastguard Worker } // namespace tflite
191