/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Shared methods for the text and token encoders.
18*993b0882SAndroid Build Coastguard Worker 19*993b0882SAndroid Build Coastguard Worker #ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_ENCODER_COMMON_H_ 20*993b0882SAndroid Build Coastguard Worker #define LIBTEXTCLASSIFIER_UTILS_TFLITE_ENCODER_COMMON_H_ 21*993b0882SAndroid Build Coastguard Worker 22*993b0882SAndroid Build Coastguard Worker #include <memory> 23*993b0882SAndroid Build Coastguard Worker #include <vector> 24*993b0882SAndroid Build Coastguard Worker 25*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/model.h" 26*993b0882SAndroid Build Coastguard Worker 27*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 { 28*993b0882SAndroid Build Coastguard Worker 29*993b0882SAndroid Build Coastguard Worker // Input rank for the encoder ops is 2, because the first dimension is 30*993b0882SAndroid Build Coastguard Worker // always considered to be for batching, and during inference is always set to 31*993b0882SAndroid Build Coastguard Worker // 1, and the second dimension indexes the input values (texts or token 32*993b0882SAndroid Build Coastguard Worker // lengths). 33*993b0882SAndroid Build Coastguard Worker constexpr const int kEncoderInputRank = 2; 34*993b0882SAndroid Build Coastguard Worker constexpr const int kEncoderBatchSize = 1; 35*993b0882SAndroid Build Coastguard Worker 36*993b0882SAndroid Build Coastguard Worker // Creates a TensorFlow Lite array from an initializer list. 37*993b0882SAndroid Build Coastguard Worker TfLiteIntArray* CreateIntArray(const std::initializer_list<int>& values); 38*993b0882SAndroid Build Coastguard Worker 39*993b0882SAndroid Build Coastguard Worker // Copies values associated with the input to the output. 40*993b0882SAndroid Build Coastguard Worker // Typically we have attribute values associated with each item in the input, 41*993b0882SAndroid Build Coastguard Worker // e.g. user id per message in the conversation. 
42*993b0882SAndroid Build Coastguard Worker // This aligns and replicates the attribute values with the encoded input, e.g. 43*993b0882SAndroid Build Coastguard Worker // replicates the same user id per token or sentence piece of the input. 44*993b0882SAndroid Build Coastguard Worker // As the input for the whole conversation is concatenated and (potentially) 45*993b0882SAndroid Build Coastguard Worker // trimmed, `encoding_end_offset` indicates where each item ends and 46*993b0882SAndroid Build Coastguard Worker // `start_offset` indicates how many elements at the beginning were dropped. 47*993b0882SAndroid Build Coastguard Worker TfLiteStatus CopyValuesToTensorAndPadOrTruncate( 48*993b0882SAndroid Build Coastguard Worker const TfLiteTensor& in, const std::vector<int>& encoding_end_offsets, 49*993b0882SAndroid Build Coastguard Worker int start_offset, TfLiteContext* context, TfLiteTensor* out); 50*993b0882SAndroid Build Coastguard Worker 51*993b0882SAndroid Build Coastguard Worker // Resizes an output tensor to shape {kBatchSize, max_output_length}. 52*993b0882SAndroid Build Coastguard Worker TfLiteStatus ResizeOutputTensor(const int max_output_length, 53*993b0882SAndroid Build Coastguard Worker TfLiteTensor* tensor, TfLiteContext* context); 54*993b0882SAndroid Build Coastguard Worker 55*993b0882SAndroid Build Coastguard Worker // Copy a slice of data to output. 56*993b0882SAndroid Build Coastguard Worker // If the size of the data is smaller than `max_output_length` then the output 57*993b0882SAndroid Build Coastguard Worker // is padded with `padding_value`. 58*993b0882SAndroid Build Coastguard Worker // If the size of the data is larger than `max_output_length` then entries at 59*993b0882SAndroid Build Coastguard Worker // the beginning a dropped to fit into the limit. 
60*993b0882SAndroid Build Coastguard Worker int CopyDataToTensorAndPadOrTruncate(const int32_t max_output_length, 61*993b0882SAndroid Build Coastguard Worker const std::vector<int32_t>& data, 62*993b0882SAndroid Build Coastguard Worker const int32_t padding_value, 63*993b0882SAndroid Build Coastguard Worker TfLiteTensor* output_tensor); 64*993b0882SAndroid Build Coastguard Worker 65*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3 66*993b0882SAndroid Build Coastguard Worker 67*993b0882SAndroid Build Coastguard Worker #endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_ENCODER_COMMON_H_ 68