/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstddef>
#include <cstdint>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace rnn {

namespace {

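// Op data persisted across invocations: the index of the first scratch tensor
// reserved in Init(), and a flag recording whether the row sums used by the
// hybrid path still need to be computed.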
struct OpData {
  int scratch_tensor_index;
  bool compute_row_sums = false;
};

}  // namespace

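// Input tensors.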
constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kRecurrentWeightsTensor = 2;
constexpr int kBiasTensor = 3;
constexpr int kHiddenStateTensor = 4;

// Output tensor.
constexpr int kOutputTensor = 0;

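// Allocates the op data and reserves six scratch tensor slots; the scratch
// tensors are only wired up in Prepare() when the op runs in hybrid
// (quantized-weight) mode.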
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  context->AddTensors(context, /*tensors_to_add=*/6,
                      &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* input_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kWeightsTensor, &input_weights));
  const TfLiteTensor* recurrent_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, kRecurrentWeightsTensor, &recurrent_weights));
  const TfLiteTensor* bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias));
  const TfLiteTensor* hidden_state;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kHiddenStateTensor, &hidden_state));

  // Check that the tensor dimensions are consistent with each other and with
  // the input configuration.
  const int batch_size = input->dims->data[0];
  const int num_units = input_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, input->dims->data[1],
                    input_weights->dims->data[1]);
  TF_LITE_ENSURE_EQ(context, input_weights->dims->data[0], bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, recurrent_weights->dims->data[0],
                    bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, recurrent_weights->dims->data[1],
                    bias->dims->data[0]);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE_TYPES_EQ(context, input_weights->type,
                          recurrent_weights->type);
  TF_LITE_ENSURE_EQ(context, NumDimensions(hidden_state), 2);
  TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[1], num_units);

  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // Resize output.
  TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
  output_size_array->data[0] = batch_size;
  output_size_array->data[1] = num_units;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size_array));

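  // The op runs in hybrid mode when the input/activations are float but the
  // weights are quantized (uint8/int8).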
  const bool is_hybrid = IsHybridOp(input, input_weights);

  // Allocate temporary tensors for the hybrid path: quantized copies of the
  // input and hidden state, plus scaling factors, accumulator scratch,
  // zero points and row sums.
  if (is_hybrid) {
    auto* op_data = reinterpret_cast<OpData*>(node->user_data);
    op_data->compute_row_sums = true;
    TfLiteIntArrayFree(node->temporaries);
    node->temporaries = TfLiteIntArrayCreate(6);
    node->temporaries->data[0] = op_data->scratch_tensor_index;
    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0,
                                                &input_quantized));
    input_quantized->type = input_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                       input_quantized_size));
    }
    node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
    TfLiteTensor* hidden_state_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
                                                &hidden_state_quantized));
    hidden_state_quantized->type = input_weights->type;
    hidden_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(hidden_state_quantized->dims,
                             hidden_state->dims)) {
      TfLiteIntArray* hidden_state_quantized_size =
          TfLiteIntArrayCopy(hidden_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, hidden_state_quantized,
                                              hidden_state_quantized_size));
    }
    node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
    TfLiteTensor* scaling_factors;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
                                                &scaling_factors));
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
                                                       scaling_factors_size));
    }
    node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
    TfLiteTensor* accum_scratch;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/3, &accum_scratch));
    accum_scratch->type = kTfLiteInt32;
    accum_scratch->allocation_type = kTfLiteArenaRw;
    int accum_scratch_dims[2] = {num_units, batch_size};
    if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
                                   accum_scratch_dims)) {
      TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2);
      accum_scratch_size->data[0] = accum_scratch_dims[0];
      accum_scratch_size->data[1] = accum_scratch_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch,
                                                       accum_scratch_size));
    }
    node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
    TfLiteTensor* zero_points;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/4, &zero_points));
    zero_points->type = kTfLiteInt32;
    zero_points->allocation_type = kTfLiteArenaRw;
    int zero_points_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
      TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
      zero_points_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
                                                       zero_points_size));
    }
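    // Row sums depend only on the constant weights, so keep them in a
    // persistent arena buffer and recompute them only when flagged.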
    node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
    TfLiteTensor* row_sums;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, /*index=*/5, &row_sums));
    row_sums->type = kTfLiteInt32;
    row_sums->name = "Rnn_row_sums";
    row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int row_sums_dims[2] = {2, num_units};
    if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
      TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
      row_sums_size->data[0] = row_sums_dims[0];
      row_sums_size->data[1] = row_sums_dims[1];
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, row_sums, row_sums_size));
    }
  }
  return kTfLiteOk;
}

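// Float path: performs one RNN step over the whole batch with float weights.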
TfLiteStatus EvalFloat(const TfLiteTensor* input,
                       const TfLiteTensor* input_weights,
                       const TfLiteTensor* recurrent_weights,
                       const TfLiteTensor* bias, const TfLiteRNNParams* params,
                       TfLiteTensor* hidden_state, TfLiteTensor* output) {
  const int batch_size = input->dims->data[0];
  const int num_units = input_weights->dims->data[0];
  const int input_size = input->dims->data[1];
  const int output_batch_leading_dim =
      output->dims->data[output->dims->size - 1];

  // Initialize the pointer to hidden state.
  float* hidden_state_ptr_batch = GetTensorData<float>(hidden_state);
  // Initialize the pointer to input and output.
  const float* input_ptr_batch = GetTensorData<float>(input);
  float* output_ptr_batch = GetTensorData<float>(output);
  // Initialize input_weights, recurrent_weights and bias.
  const float* input_weights_ptr = GetTensorData<float>(input_weights);
  const float* recurrent_weights_ptr = GetTensorData<float>(recurrent_weights);
  const float* bias_ptr = GetTensorData<float>(bias);

  kernel_utils::RnnBatchStep(
      input_ptr_batch, input_weights_ptr, recurrent_weights_ptr, bias_ptr,
      input_size, num_units, batch_size, output_batch_leading_dim,
      params->activation, hidden_state_ptr_batch, output_ptr_batch);
  return kTfLiteOk;
}

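// Hybrid path: float activations with uint8/int8 weights. The input and
// hidden state are quantized into the scratch tensors so the matrix
// multiplications can use the integer weights.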
TfLiteStatus EvalHybrid(const TfLiteTensor* input,
                        const TfLiteTensor* input_weights,
                        const TfLiteTensor* recurrent_weights,
                        const TfLiteTensor* bias, const TfLiteRNNParams* params,
                        TfLiteTensor* input_scratch,
                        TfLiteTensor* hidden_state_scratch,
                        TfLiteTensor* scaling_factors,
                        TfLiteTensor* hidden_state, TfLiteTensor* output,
                        TfLiteTensor* zero_points, TfLiteTensor* accum_scratch,
                        TfLiteTensor* row_sums, bool* compute_row_sums) {
  const int batch_size = input->dims->data[0];
  const int num_units = input_weights->dims->data[0];
  const int input_size = input->dims->data[1];
  const int output_batch_leading_dim =
      output->dims->data[output->dims->size - 1];

  // Initialize the pointer to hidden state.
  float* hidden_state_ptr_batch = GetTensorData<float>(hidden_state);
  // Initialize the pointer to input and output.
  const float* input_ptr_batch = GetTensorData<float>(input);
  float* output_ptr_batch = GetTensorData<float>(output);
  // Initialize input_weights, recurrent_weights and bias.
  const int8_t* input_weights_ptr = GetTensorData<int8_t>(input_weights);
  const int8_t* recurrent_weights_ptr =
      GetTensorData<int8_t>(recurrent_weights);
  const float* bias_ptr = GetTensorData<float>(bias);
  // Get the scale of the quantized weights.
  float input_weights_scale = input_weights->params.scale;
  float recurrent_weights_scale = recurrent_weights->params.scale;
  // Initialize temporary storage for quantized values.
  int8_t* quantized_input_ptr = GetTensorData<int8_t>(input_scratch);
  int8_t* quantized_hidden_state_ptr =
      GetTensorData<int8_t>(hidden_state_scratch);
  float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
  int32_t* accum_scratch_ptr = GetTensorData<int32_t>(accum_scratch);
  int32_t* zero_points_ptr = nullptr;
  int32_t* row_sums_ptr = nullptr;
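  // Zero points and row sums are only needed when the inputs are quantized
  // asymmetrically.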
  if (params->asymmetric_quantize_inputs) {
    zero_points_ptr = GetTensorData<int32_t>(zero_points);
    row_sums_ptr = GetTensorData<int32_t>(row_sums);
  }
  kernel_utils::RnnBatchStep(
      input_ptr_batch, input_weights_ptr, input_weights_scale,
      recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size,
      num_units, batch_size, output_batch_leading_dim, params->activation,
      quantized_input_ptr, quantized_hidden_state_ptr, scaling_factors_ptr,
      hidden_state_ptr_batch, output_ptr_batch,
      params->asymmetric_quantize_inputs, zero_points_ptr, accum_scratch_ptr,
      row_sums_ptr, compute_row_sums);
  return kTfLiteOk;
}

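// Dispatches to the float or hybrid kernel based on the weight type.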
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteRNNParams*>(node->builtin_data);
  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* input_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kWeightsTensor, &input_weights));
  const TfLiteTensor* recurrent_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, kRecurrentWeightsTensor, &recurrent_weights));
  const TfLiteTensor* bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias));
  TfLiteTensor* hidden_state =
      GetVariableInput(context, node, kHiddenStateTensor);
  TF_LITE_ENSURE(context, hidden_state != nullptr);
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // We already checked that weight types are consistent, so branch on one.
  switch (input_weights->type) {
    case kTfLiteFloat32:
      return EvalFloat(input, input_weights, recurrent_weights, bias, params,
                       hidden_state, output);
    case kTfLiteUInt8:
    case kTfLiteInt8: {
      // TODO(mirkov): implement eval with quantized inputs as well.
      TfLiteTensor* input_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, 0, &input_quantized));
      TfLiteTensor* hidden_state_quantized;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, 1, &hidden_state_quantized));
      TfLiteTensor* scaling_factors;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, 2, &scaling_factors));
      TfLiteTensor* accum_scratch;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, 3, &accum_scratch));
      TfLiteTensor* zero_points;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, 4, &zero_points));
      TfLiteTensor* row_sums;
      TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 5, &row_sums));
      return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
                        input_quantized, hidden_state_quantized,
                        scaling_factors, hidden_state, output, zero_points,
                        accum_scratch, row_sums, &op_data->compute_row_sums);
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
                         TfLiteTypeGetName(input_weights->type));
      return kTfLiteError;
  }
}

}  // namespace rnn

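// Returns the registration (Init/Free/Prepare/Eval) for the builtin RNN op.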
TfLiteRegistration* Register_RNN() {
  static TfLiteRegistration r = {rnn::Init, rnn::Free, rnn::Prepare, rnn::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite