/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstddef>
#include <cstdint>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace unidirectional_sequence_rnn {

namespace {

struct OpData {
  // Index of the first of the six scratch tensors requested in Init(); the
  // remaining five follow consecutively and back the hybrid path.
  int scratch_tensor_index;
  // Reset in Prepare() so the hybrid path (re)computes cached weight row sums.
  bool compute_row_sums = false;
};

}  // namespace

// Input tensors.
constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kRecurrentWeightsTensor = 2;
constexpr int kBiasTensor = 3;
constexpr int kHiddenStateTensor = 4;

// Output tensor.
constexpr int kOutputTensor = 0;

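// Temporary tensors created in Prepare() for the hybrid (quantized-weight)
// path, indexed consecutively from OpData::scratch_tensor_index: quantized
// input, quantized hidden state, scaling factors, accumulator scratch, zero
// points, and cached weight row sums.
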
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  context->AddTensors(context, /*tensors_to_add=*/6,
                      &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* input_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kWeightsTensor, &input_weights));
  const TfLiteTensor* recurrent_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, kRecurrentWeightsTensor, &recurrent_weights));
  const TfLiteTensor* bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias));
  const TfLiteTensor* hidden_state;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kHiddenStateTensor, &hidden_state));

  // Check that the tensor parameters are consistent with each other and with
  // the input configuration.
  auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int num_units = input_weights->dims->data[0];
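  // Expected shapes, as enforced below: input is [batch_size, max_time,
  // input_size] (or [max_time, batch_size, input_size] when time_major),
  // input_weights is [num_units, input_size], recurrent_weights is
  // [num_units, num_units], and bias is [num_units].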
  TF_LITE_ENSURE_EQ(context, input->dims->data[2],
                    input_weights->dims->data[1]);
  TF_LITE_ENSURE_EQ(context, input_weights->dims->data[0], bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, recurrent_weights->dims->data[0],
                    bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, recurrent_weights->dims->data[1],
                    bias->dims->data[0]);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE_TYPES_EQ(context, input_weights->type,
                          recurrent_weights->type);
  TF_LITE_ENSURE_EQ(context, NumDimensions(hidden_state), 2);
  TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[1], num_units);

  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // Resize output.
  TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(3);
  output_size_array->data[0] = (time_major) ? max_time : batch_size;
  output_size_array->data[1] = (time_major) ? batch_size : max_time;
  output_size_array->data[2] = num_units;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size_array));

  const bool is_hybrid = IsHybridOp(input, input_weights);
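  // IsHybridOp is true when float inputs are paired with quantized
  // (uint8/int8) weights; the op then runs the hybrid quantized path.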

  // Allocate temporary tensors to hold quantized copies of the input and
  // hidden state, plus the quantization bookkeeping buffers.
  if (is_hybrid) {
    auto* op_data = reinterpret_cast<OpData*>(node->user_data);
    op_data->compute_row_sums = true;
    TfLiteIntArrayFree(node->temporaries);
    node->temporaries = TfLiteIntArrayCreate(6);
    node->temporaries->data[0] = op_data->scratch_tensor_index;
    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0,
                                                &input_quantized));
    input_quantized->type = input_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                       input_quantized_size));
    }
    node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
    TfLiteTensor* hidden_state_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
                                                &hidden_state_quantized));
    hidden_state_quantized->type = input_weights->type;
    hidden_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(hidden_state_quantized->dims,
                             hidden_state->dims)) {
      TfLiteIntArray* hidden_state_quantized_size =
          TfLiteIntArrayCopy(hidden_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, hidden_state_quantized,
                                              hidden_state_quantized_size));
    }
    node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
    TfLiteTensor* scaling_factors;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
                                                &scaling_factors));
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
                                                       scaling_factors_size));
    }
    node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
    TfLiteTensor* accum_scratch;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/3, &accum_scratch));
    accum_scratch->type = kTfLiteInt32;
    accum_scratch->allocation_type = kTfLiteArenaRw;
    int accum_scratch_dims[2] = {num_units, batch_size};
    if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
                                   accum_scratch_dims)) {
      TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2);
      accum_scratch_size->data[0] = accum_scratch_dims[0];
      accum_scratch_size->data[1] = accum_scratch_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch,
                                                       accum_scratch_size));
    }
    node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
    TfLiteTensor* zero_points;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/4, &zero_points));
    zero_points->type = kTfLiteInt32;
    zero_points->allocation_type = kTfLiteArenaRw;
    int zero_points_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
      TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
      zero_points_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
                                                       zero_points_size));
    }
    node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
    TfLiteTensor* row_sums;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, /*index=*/5, &row_sums));
    row_sums->type = kTfLiteInt32;
    row_sums->allocation_type = kTfLiteArenaRwPersistent;
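    // Persistent cache of weight row sums; the leading dimension of 2 is
    // presumably one row-sum set for the input weights and one for the
    // recurrent weights (an assumption based on the {2, num_units} shape;
    // kernel_utils::RnnBatchStep defines the authoritative layout).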
    int row_sums_dims[2] = {2, num_units};
    if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
      TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
      row_sums_size->data[0] = row_sums_dims[0];
      row_sums_size->data[1] = row_sums_dims[1];
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, row_sums, row_sums_size));
    }
  }
  return kTfLiteOk;
}

TfLiteStatus EvalFloat(const TfLiteTensor* input,
                       const TfLiteTensor* input_weights,
                       const TfLiteTensor* recurrent_weights,
                       const TfLiteTensor* bias,
                       const TfLiteSequenceRNNParams* params,
                       TfLiteTensor* hidden_state, TfLiteTensor* output) {
  // Initialize the bias pointer.
  const float* bias_ptr = GetTensorData<float>(bias);

  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int num_units = input_weights->dims->data[0];
  const int input_size = input->dims->data[2];

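  // Layout note: with time_major the input is [max_time, batch_size,
  // input_size], so each time step is contiguous across batches and can be
  // processed as one batched matrix multiply. Otherwise the input is
  // [batch_size, max_time, input_size] and each batch is stepped separately.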
  // Initialize input_weights and recurrent_weights.
  const float* input_weights_ptr = GetTensorData<float>(input_weights);
  const float* recurrent_weights_ptr = GetTensorData<float>(recurrent_weights);

  if (time_major) {
    // Initialize the pointer to the hidden state.
    float* hidden_state_ptr_batch = GetTensorData<float>(hidden_state);
    // Unroll the sequence and use batch operations for efficiency.
    for (int s = 0; s < max_time; s++) {
      // Initialize the pointers to the input and output.
      const float* input_ptr_batch =
          GetTensorData<float>(input) + s * input_size * batch_size;
      float* output_ptr_batch =
          GetTensorData<float>(output) + s * num_units * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, input_weights_ptr, recurrent_weights_ptr, bias_ptr,
          input_size, num_units, batch_size, num_units, params->activation,
          hidden_state_ptr_batch, output_ptr_batch);
    }
  } else {
    // For each batch, step through the sequence one element at a time.
    for (int b = 0; b < batch_size; b++) {
      // Initialize the pointer to the hidden state.
      float* hidden_state_ptr_batch =
          GetTensorData<float>(hidden_state) + b * num_units;
      for (int s = 0; s < max_time; s++) {
        // Initialize the pointers to the input and output.
        const float* input_ptr_batch = GetTensorData<float>(input) +
                                       b * input_size * max_time +
                                       s * input_size;
        float* output_ptr_batch = GetTensorData<float>(output) +
                                  b * num_units * max_time + s * num_units;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, input_weights_ptr, recurrent_weights_ptr, bias_ptr,
            input_size, num_units, /*batch_size=*/1, num_units,
            params->activation, hidden_state_ptr_batch, output_ptr_batch);
      }
    }
  }
  return kTfLiteOk;
}

TfLiteStatus EvalHybrid(
    const TfLiteTensor* input, const TfLiteTensor* input_weights,
    const TfLiteTensor* recurrent_weights, const TfLiteTensor* bias,
    const TfLiteSequenceRNNParams* params, TfLiteTensor* input_scratch,
    TfLiteTensor* hidden_state_scratch, TfLiteTensor* scaling_factors,
    TfLiteTensor* hidden_state, TfLiteTensor* output, TfLiteTensor* zero_points,
    TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
    bool* compute_row_sums) {
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int num_units = input_weights->dims->data[0];
  const int input_size = input->dims->data[2];

  // Initialize the bias pointer.
  const float* bias_ptr = GetTensorData<float>(bias);

  // Initialize input_weights, recurrent_weights, and temporary storage for
  // quantized values.
  const int8_t* input_weights_ptr = GetTensorData<int8_t>(input_weights);
  const int8_t* recurrent_weights_ptr =
      GetTensorData<int8_t>(recurrent_weights);
  int8_t* quantized_input_ptr = GetTensorData<int8_t>(input_scratch);
  int8_t* quantized_hidden_state_ptr =
      GetTensorData<int8_t>(hidden_state_scratch);

  // Get the scale of the quantized weights.
  float input_weights_scale = input_weights->params.scale;
  float recurrent_weights_scale = recurrent_weights->params.scale;
  float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
  int32_t* accum_scratch_ptr = GetTensorData<int32_t>(accum_scratch);
  int32_t* zero_points_ptr = nullptr;
  int32_t* row_sums_ptr = nullptr;

  if (params->asymmetric_quantize_inputs) {
    zero_points_ptr = GetTensorData<int32_t>(zero_points);
    row_sums_ptr = GetTensorData<int32_t>(row_sums);
  }
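  // With asymmetric input quantization, the per-batch zero points and cached
  // weight row sums let the matmul correct for non-zero input offsets; with
  // symmetric quantization only the per-batch scaling factors are needed.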

  if (time_major) {
    // Initialize the pointer to the hidden state.
    float* hidden_state_ptr_batch = GetTensorData<float>(hidden_state);
    // Unroll the sequence and use batch operations for efficiency.
    for (int s = 0; s < max_time; s++) {
      // Initialize the pointers to the input and output.
      const float* input_ptr_batch =
          GetTensorData<float>(input) + s * input_size * batch_size;
      float* output_ptr_batch =
          GetTensorData<float>(output) + s * num_units * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, input_weights_ptr, input_weights_scale,
          recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size,
          num_units, batch_size, num_units, params->activation,
          quantized_input_ptr, quantized_hidden_state_ptr, scaling_factors_ptr,
          hidden_state_ptr_batch, output_ptr_batch,
          params->asymmetric_quantize_inputs, zero_points_ptr,
          accum_scratch_ptr, row_sums_ptr, compute_row_sums);
    }
  } else {
    // For each batch, step through the sequence one element at a time.
    for (int b = 0; b < batch_size; b++) {
      // Initialize the pointer to the hidden state.
      float* hidden_state_ptr_batch =
          GetTensorData<float>(hidden_state) + b * num_units;
      for (int s = 0; s < max_time; s++) {
        // Initialize the pointers to the input and output.
        const float* input_ptr_batch = GetTensorData<float>(input) +
                                       b * input_size * max_time +
                                       s * input_size;
        float* output_ptr_batch = GetTensorData<float>(output) +
                                  b * num_units * max_time + s * num_units;
        kernel_utils::RnnBatchStep(
            input_ptr_batch, input_weights_ptr, input_weights_scale,
            recurrent_weights_ptr, recurrent_weights_scale, bias_ptr,
            input_size, num_units, /*batch_size=*/1, num_units,
            params->activation, quantized_input_ptr, quantized_hidden_state_ptr,
            scaling_factors_ptr, hidden_state_ptr_batch, output_ptr_batch,
            params->asymmetric_quantize_inputs, zero_points_ptr,
            accum_scratch_ptr, row_sums_ptr, compute_row_sums);
      }
    }
  }
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* input_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kWeightsTensor, &input_weights));
  const TfLiteTensor* recurrent_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, kRecurrentWeightsTensor, &recurrent_weights));
  const TfLiteTensor* bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias));
  // The hidden_state is a variable input tensor that can be modified.
  TfLiteTensor* hidden_state =
      GetVariableInput(context, node, kHiddenStateTensor);
  TF_LITE_ENSURE(context, hidden_state != nullptr);
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  switch (input_weights->type) {
    case kTfLiteFloat32:
      return EvalFloat(input, input_weights, recurrent_weights, bias, params,
                       hidden_state, output);
    case kTfLiteUInt8:
    case kTfLiteInt8: {
      // TODO(mirkov): implement eval with quantized inputs as well.
      auto* op_data = reinterpret_cast<OpData*>(node->user_data);
      TfLiteTensor* input_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, 0, &input_quantized));
      TfLiteTensor* hidden_state_quantized;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, 1, &hidden_state_quantized));
      TfLiteTensor* scaling_factors;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, 2, &scaling_factors));
      TfLiteTensor* accum_scratch;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, 3, &accum_scratch));
      TfLiteTensor* zero_points;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, 4, &zero_points));
      TfLiteTensor* row_sums;
      TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 5, &row_sums));
      return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
                        input_quantized, hidden_state_quantized,
                        scaling_factors, hidden_state, output, zero_points,
                        accum_scratch, row_sums, &op_data->compute_row_sums);
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
                         TfLiteTypeGetName(input_weights->type));
      return kTfLiteError;
  }
}

}  // namespace unidirectional_sequence_rnn

TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_RNN() {
  static TfLiteRegistration r = {
      unidirectional_sequence_rnn::Init, unidirectional_sequence_rnn::Free,
      unidirectional_sequence_rnn::Prepare, unidirectional_sequence_rnn::Eval};
  return &r;
}
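
// Illustrative only, not part of this file: the registration above is
// typically consumed by an op resolver, e.g.
//   resolver.AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
//                       Register_UNIDIRECTIONAL_SEQUENCE_RNN());
// where `resolver` is a tflite::MutableOpResolver.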

}  // namespace builtin
}  // namespace ops
}  // namespace tflite