/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cstddef>
#include <cstdint>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace bidirectional_sequence_rnn {

namespace {

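// Per-node data: the index of the first scratch tensor registered in Init(),
// plus flags telling the hybrid path whether the weight row sums still need
// to be (re)computed.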
struct OpData {
  int scratch_tensor_index;
  bool fw_compute_row_sums = false;
  bool bw_compute_row_sums = false;
};

}  // namespace

// LINT.IfChange

constexpr int kInputTensor = 0;
// Forward and backward cell tensors.
constexpr int kFwWeightsTensor = 1;
constexpr int kFwRecurrentWeightsTensor = 2;
constexpr int kFwBiasTensor = 3;
constexpr int kFwHiddenStateTensor = 4;
constexpr int kBwWeightsTensor = 5;
constexpr int kBwRecurrentWeightsTensor = 6;
constexpr int kBwBiasTensor = 7;
constexpr int kBwHiddenStateTensor = 8;
// Used as the auxiliary input and weights when stacking for the
// tf.contrib.rnn.stack_bidirectional_rnn case (with cross links); used as the
// input to the backward cell when stacking for the
// tf.nn.static_bidirectional_rnn case (without cross links).
constexpr int kAuxInputTensor = 9;       // Optional.
constexpr int kFwAuxWeightsTensor = 10;  // Optional.
constexpr int kBwAuxWeightsTensor = 11;  // Optional.
// Output tensors.
constexpr int kFwOutputTensor = 0;
constexpr int kBwOutputTensor = 1;  // Only if merge_outputs is false.

// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantize_weights.cc)

// Temporary tensors.
enum TemporaryTensor {
  kInputQuantized = 0,
  kFwHiddenStateQuantized = 1,
  kBwHiddenStateQuantized = 2,
  kScalingFactors = 3,
  kAccumScratch = 4,
  kZeroPoints = 5,
  kFwRowSums = 6,
  kBwRowSums = 7,
  kAuxInputQuantized = 8,
  kNumTemporaryTensors = 9
};

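// Allocates the op data and reserves indices for the temporary tensors used
// by the hybrid (quantized-weight) execution path.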
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  context->AddTensors(context, kNumTemporaryTensors,
                      &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

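// Checks tensor counts and shapes, sets up the hybrid-path temporaries when
// the weights are quantized, and resizes the output tensor(s).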
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
      node->builtin_data);

  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 12);
  TF_LITE_ENSURE_EQ(context, node->outputs->size,
                    params->merge_outputs ? 1 : 2);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* fw_input_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwWeightsTensor,
                                          &fw_input_weights));
  const TfLiteTensor* fw_recurrent_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwRecurrentWeightsTensor,
                                 &fw_recurrent_weights));
  const TfLiteTensor* fw_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwBiasTensor, &fw_bias));
  const TfLiteTensor* fw_hidden_state;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwHiddenStateTensor,
                                          &fw_hidden_state));
  const TfLiteTensor* bw_input_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwWeightsTensor,
                                          &bw_input_weights));
  const TfLiteTensor* bw_recurrent_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwRecurrentWeightsTensor,
                                 &bw_recurrent_weights));
  const TfLiteTensor* bw_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwBiasTensor, &bw_bias));
  const TfLiteTensor* bw_hidden_state;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwHiddenStateTensor,
                                          &bw_hidden_state));

  const TfLiteTensor* aux_input =
      GetOptionalInputTensor(context, node, kAuxInputTensor);
  const TfLiteTensor* fw_aux_input_weights =
      GetOptionalInputTensor(context, node, kFwAuxWeightsTensor);
  const TfLiteTensor* bw_aux_input_weights =
      GetOptionalInputTensor(context, node, kBwAuxWeightsTensor);

  const bool aux_inputs_weights_or_none =
      ((fw_aux_input_weights != nullptr) &&
       (bw_aux_input_weights != nullptr)) ||
      ((fw_aux_input_weights == nullptr) && (bw_aux_input_weights == nullptr));
  TF_LITE_ENSURE(context, aux_inputs_weights_or_none);
  const bool has_aux_input = (fw_aux_input_weights != nullptr);

  // Check that all tensor parameters are consistent with one another and with
  // the input configuration.
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);

  TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int fw_num_units = fw_input_weights->dims->data[0];
  const int bw_num_units = bw_input_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, input->dims->data[2],
                    fw_input_weights->dims->data[1]);
  TF_LITE_ENSURE_EQ(context, input->dims->data[2],
                    bw_input_weights->dims->data[1]);
  TF_LITE_ENSURE_EQ(context, fw_input_weights->dims->data[0],
                    fw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, bw_input_weights->dims->data[0],
                    bw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, fw_recurrent_weights->dims->data[0],
                    fw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, bw_recurrent_weights->dims->data[1],
                    bw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, NumDimensions(fw_hidden_state), 2);
  TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[1], fw_num_units);
  TF_LITE_ENSURE_EQ(context, NumDimensions(bw_hidden_state), 2);
  TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[1], bw_num_units);

  if (has_aux_input) {
    // Check that aux_input has the same dimensions (except last) as the input.
    TF_LITE_ASSERT_EQ(aux_input->dims->data[0], input->dims->data[0]);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[1], input->dims->data[1]);
    // Check that aux_input_weights has the same dimensions (except last) as
    // the input_weights.
    TF_LITE_ASSERT_EQ(fw_aux_input_weights->dims->data[0], fw_num_units);
    TF_LITE_ASSERT_EQ(bw_aux_input_weights->dims->data[0], bw_num_units);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[2],
                      fw_aux_input_weights->dims->data[1]);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[2],
                      bw_aux_input_weights->dims->data[1]);
  }

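  // Hybrid execution: float activations with uint8/int8 weights. Allocate
  // temporaries for the quantized copies of the inputs and hidden states and
  // for the quantization bookkeeping (scaling factors, zero points, row sums).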
  if (IsHybridOp(input, fw_input_weights)) {
    OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
    op_data->fw_compute_row_sums = true;
    op_data->bw_compute_row_sums = true;
    TfLiteIntArrayFree(node->temporaries);
    if (has_aux_input) {
      node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
    } else {
      // No need to create a temporary tensor for the non-existent aux_input.
      node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors - 1);
    }

    node->temporaries->data[kInputQuantized] =
        op_data->scratch_tensor_index + kInputQuantized;
    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
                                                &input_quantized));
    input_quantized->type = fw_input_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                       input_quantized_size));
    }

    node->temporaries->data[kFwHiddenStateQuantized] =
        op_data->scratch_tensor_index + kFwHiddenStateQuantized;
    TfLiteTensor* fw_hidden_state_quantized;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kFwHiddenStateQuantized,
                                       &fw_hidden_state_quantized));
    fw_hidden_state_quantized->type = fw_input_weights->type;
    fw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(fw_hidden_state_quantized->dims,
                             fw_hidden_state->dims)) {
      TfLiteIntArray* fw_hidden_state_quantized_size =
          TfLiteIntArrayCopy(fw_hidden_state->dims);
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, fw_hidden_state_quantized,
                                         fw_hidden_state_quantized_size));
    }

    node->temporaries->data[kBwHiddenStateQuantized] =
        op_data->scratch_tensor_index + kBwHiddenStateQuantized;
    TfLiteTensor* bw_hidden_state_quantized;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kBwHiddenStateQuantized,
                                       &bw_hidden_state_quantized));
    bw_hidden_state_quantized->type = fw_input_weights->type;
    bw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(bw_hidden_state_quantized->dims,
                             bw_hidden_state->dims)) {
      TfLiteIntArray* bw_hidden_state_quantized_size =
          TfLiteIntArrayCopy(bw_hidden_state->dims);
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, bw_hidden_state_quantized,
                                         bw_hidden_state_quantized_size));
    }

    // Allocate temporary tensors to store the quantization scaling factors.
    node->temporaries->data[kScalingFactors] =
        op_data->scratch_tensor_index + kScalingFactors;
    TfLiteTensor* scaling_factors;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScalingFactors,
                                                &scaling_factors));
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
                                                       scaling_factors_size));
    }
    node->temporaries->data[kAccumScratch] =
        op_data->scratch_tensor_index + kAccumScratch;
    TfLiteTensor* accum_scratch;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
                                                &accum_scratch));
    accum_scratch->type = kTfLiteInt32;
    accum_scratch->allocation_type = kTfLiteArenaRw;
    int accum_scratch_dims[2] = {std::max(fw_num_units, bw_num_units),
                                 batch_size};
    if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
                                   accum_scratch_dims)) {
      TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2);
      accum_scratch_size->data[0] = accum_scratch_dims[0];
      accum_scratch_size->data[1] = accum_scratch_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch,
                                                       accum_scratch_size));
    }
    node->temporaries->data[kZeroPoints] =
        op_data->scratch_tensor_index + kZeroPoints;
    TfLiteTensor* zero_points;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, /*index=*/kZeroPoints, &zero_points));
    zero_points->type = kTfLiteInt32;
    zero_points->allocation_type = kTfLiteArenaRw;
    int zero_points_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
      TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
      zero_points_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
                                                       zero_points_size));
    }
    const int num_row_sums = has_aux_input ? 3 : 2;
    node->temporaries->data[kFwRowSums] =
        op_data->scratch_tensor_index + kFwRowSums;
    TfLiteTensor* fw_row_sums;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, /*index=*/kFwRowSums, &fw_row_sums));
    fw_row_sums->type = kTfLiteInt32;
    fw_row_sums->name = "Lstm_fw_row_sums";
    fw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int fw_row_sums_dims[2] = {num_row_sums, fw_num_units};
    if (!TfLiteIntArrayEqualsArray(fw_row_sums->dims, 2, fw_row_sums_dims)) {
      TfLiteIntArray* fw_row_sums_size = TfLiteIntArrayCreate(2);
      fw_row_sums_size->data[0] = fw_row_sums_dims[0];
      fw_row_sums_size->data[1] = fw_row_sums_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_row_sums,
                                                       fw_row_sums_size));
    }
    node->temporaries->data[kBwRowSums] =
        op_data->scratch_tensor_index + kBwRowSums;
    TfLiteTensor* bw_row_sums;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, /*index=*/kBwRowSums, &bw_row_sums));
    bw_row_sums->type = kTfLiteInt32;
    bw_row_sums->name = "Lstm_bw_row_sums";
    bw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int bw_row_sums_dims[2] = {num_row_sums, bw_num_units};
    if (!TfLiteIntArrayEqualsArray(bw_row_sums->dims, 2, bw_row_sums_dims)) {
      TfLiteIntArray* bw_row_sums_size = TfLiteIntArrayCreate(2);
      bw_row_sums_size->data[0] = bw_row_sums_dims[0];
      bw_row_sums_size->data[1] = bw_row_sums_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_row_sums,
                                                       bw_row_sums_size));
    }
    if (has_aux_input) {
      node->temporaries->data[kAuxInputQuantized] =
          op_data->scratch_tensor_index + kAuxInputQuantized;
      TfLiteTensor* aux_input_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kAuxInputQuantized,
                                         &aux_input_quantized));
      aux_input_quantized->type = fw_input_weights->type;
      aux_input_quantized->allocation_type = kTfLiteArenaRw;
      if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
        TfLiteIntArray* aux_input_quantized_size =
            TfLiteIntArrayCopy(aux_input->dims);
        TF_LITE_ENSURE_OK(context,
                          context->ResizeTensor(context, aux_input_quantized,
                                                aux_input_quantized_size));
      }
    }
  }

  // Resize outputs.
  TfLiteTensor* fw_output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kFwOutputTensor, &fw_output));
  TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3);
  fw_output_size_array->data[0] = (time_major) ? max_time : batch_size;
  fw_output_size_array->data[1] = (time_major) ? batch_size : max_time;
  fw_output_size_array->data[2] =
      params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, fw_output, fw_output_size_array));
  if (!params->merge_outputs) {
    TfLiteTensor* bw_output;
    TF_LITE_ENSURE_OK(
        context, GetOutputSafe(context, node, kBwOutputTensor, &bw_output));
    TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3);
    bw_output_size_array->data[0] = (time_major) ? max_time : batch_size;
    bw_output_size_array->data[1] = (time_major) ? batch_size : max_time;
    bw_output_size_array->data[2] = bw_num_units;
    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_output,
                                                     bw_output_size_array));
  }

  return kTfLiteOk;
}

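// Float path: runs the forward cell over the sequence in increasing time
// order and the backward cell in decreasing time order, for both time-major
// and batch-major layouts.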
TfLiteStatus EvalFloat(const TfLiteTensor* input, const TfLiteTensor* bw_input,
                       const TfLiteTensor* fw_input_weights,
                       const TfLiteTensor* fw_recurrent_weights,
                       const TfLiteTensor* fw_bias,
                       const TfLiteTensor* bw_input_weights,
                       const TfLiteTensor* bw_recurrent_weights,
                       const TfLiteTensor* bw_bias,
                       const TfLiteTensor* aux_input,
                       const TfLiteTensor* fw_aux_input_weights,
                       const TfLiteTensor* bw_aux_input_weights,
                       const TfLiteBidirectionalSequenceRNNParams* params,
                       TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
                       TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) {
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int input_size = input->dims->data[2];
  const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;

  const int fw_num_units = fw_input_weights->dims->data[0];
  const float* fw_bias_ptr = GetTensorData<float>(fw_bias);
  const float* fw_input_weights_ptr = GetTensorData<float>(fw_input_weights);
  const float* fw_recurrent_weights_ptr =
      GetTensorData<float>(fw_recurrent_weights);

  const int bw_num_units = bw_input_weights->dims->data[0];
  const float* bw_bias_ptr = GetTensorData<float>(bw_bias);
  const float* bw_input_weights_ptr = GetTensorData<float>(bw_input_weights);
  const float* bw_recurrent_weights_ptr =
      GetTensorData<float>(bw_recurrent_weights);

  const float* fw_aux_input_weights_ptr =
      (fw_aux_input_weights != nullptr)
          ? GetTensorData<float>(fw_aux_input_weights)
          : nullptr;
  const float* bw_aux_input_weights_ptr =
      (bw_aux_input_weights != nullptr)
          ? GetTensorData<float>(bw_aux_input_weights)
          : nullptr;

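  // With merged outputs the fw and bw results for each step are written side
  // by side into fw_output, so each cell strides by the combined width.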
  const int fw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
  const int bw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
  if (time_major) {
    // Forward cell.
    float* fw_hidden_state_ptr_batch = GetTensorData<float>(fw_hidden_state);
    for (int s = 0; s < max_time; s++) {
      const float* input_ptr_batch =
          GetTensorData<float>(input) + s * input_size * batch_size;
      const float* aux_input_ptr_batch =
          (aux_input != nullptr)
              ? GetTensorData<float>(aux_input) +
                    s * aux_input_size * batch_size
              : nullptr;
      float* output_ptr_batch =
          GetTensorData<float>(fw_output) + s * fw_output_step * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
          fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr,
          input_size, aux_input_size, fw_num_units, batch_size, fw_output_step,
          params->activation, fw_hidden_state_ptr_batch, output_ptr_batch);
    }
    // Backward cell.
    float* bw_hidden_state_ptr_batch = GetTensorData<float>(bw_hidden_state);
    for (int s = max_time - 1; s >= 0; s--) {
      const float* input_ptr_batch =
          GetTensorData<float>(bw_input) + s * input_size * batch_size;
      const float* aux_input_ptr_batch =
          (aux_input != nullptr)
              ? GetTensorData<float>(aux_input) +
                    s * aux_input_size * batch_size
              : nullptr;
      float* output_ptr_batch =
          (params->merge_outputs
               ? GetTensorData<float>(fw_output) + fw_num_units
               : GetTensorData<float>(bw_output)) +
          s * bw_output_step * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
          bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr,
          input_size, aux_input_size, bw_num_units, batch_size, bw_output_step,
          params->activation, bw_hidden_state_ptr_batch, output_ptr_batch);
    }
  } else {
    for (int b = 0; b < batch_size; b++) {
      // Forward cell.
      float* fw_hidden_state_ptr_batch =
          GetTensorData<float>(fw_hidden_state) + b * fw_num_units;
      float* fw_output_offset =
          GetTensorData<float>(fw_output) + b * fw_output_step * max_time;
      for (int s = 0; s < max_time; s++) {
        const float* input_ptr_batch = GetTensorData<float>(input) +
                                       b * input_size * max_time +
                                       s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? GetTensorData<float>(aux_input) +
                      b * aux_input_size * max_time + s * aux_input_size
                : nullptr;
        float* output_ptr_batch = fw_output_offset + s * fw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
            fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr,
            input_size, aux_input_size, fw_num_units, /*batch_size=*/1,
            fw_output_step, params->activation, fw_hidden_state_ptr_batch,
            output_ptr_batch);
      }
      // Backward cell.
      float* bw_hidden_state_ptr_batch =
          GetTensorData<float>(bw_hidden_state) + b * bw_num_units;
      float* bw_output_offset =
          params->merge_outputs
              ? GetTensorData<float>(fw_output) +
                    b * bw_output_step * max_time + fw_num_units
              : GetTensorData<float>(bw_output) + b * bw_output_step * max_time;
      for (int s = max_time - 1; s >= 0; s--) {
        const float* input_ptr_batch = GetTensorData<float>(input) +
                                       b * input_size * max_time +
                                       s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? GetTensorData<float>(aux_input) +
                      b * aux_input_size * max_time + s * aux_input_size
                : nullptr;
        float* output_ptr_batch = bw_output_offset + s * bw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
            bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr,
            input_size, aux_input_size, bw_num_units, /*batch_size=*/1,
            bw_output_step, params->activation, bw_hidden_state_ptr_batch,
            output_ptr_batch);
      }
    }
  }
  return kTfLiteOk;
}

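// Hybrid path: the weights are quantized (uint8/int8) while the activations
// stay float; inputs and hidden states are quantized on the fly using the
// scratch tensors set up in Prepare().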
TfLiteStatus EvalHybrid(
    const TfLiteTensor* input, const TfLiteTensor* bw_input,
    const TfLiteTensor* fw_input_weights,
    const TfLiteTensor* fw_recurrent_weights, const TfLiteTensor* fw_bias,
    const TfLiteTensor* bw_input_weights,
    const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias,
    const TfLiteTensor* aux_input, const TfLiteTensor* aux_fw_input_weights,
    const TfLiteTensor* aux_bw_input_weights,
    const TfLiteBidirectionalSequenceRNNParams* params,
    TfLiteTensor* scaling_factors, TfLiteTensor* input_quantized,
    TfLiteTensor* aux_input_quantized, TfLiteTensor* fw_hidden_state_quantized,
    TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
    TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state,
    TfLiteTensor* bw_output, TfLiteTensor* zero_points,
    TfLiteTensor* accum_scratch, TfLiteTensor* fw_row_sums,
    TfLiteTensor* bw_row_sums, bool* fw_compute_row_sums,
    bool* bw_compute_row_sums) {
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int input_size = input->dims->data[2];
  const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;

  const int fw_num_units = fw_input_weights->dims->data[0];
  const float* fw_bias_ptr = GetTensorData<float>(fw_bias);
  const int8_t* fw_input_weights_ptr = GetTensorData<int8_t>(fw_input_weights);
  float fw_input_weights_scale = fw_input_weights->params.scale;
  const int8_t* fw_recurrent_weights_ptr =
      GetTensorData<int8_t>(fw_recurrent_weights);
  float fw_recurrent_weights_scale = fw_recurrent_weights->params.scale;

  const int bw_num_units = bw_input_weights->dims->data[0];
  const float* bw_bias_ptr = GetTensorData<float>(bw_bias);
  const int8_t* bw_input_weights_ptr = GetTensorData<int8_t>(bw_input_weights);
  float bw_input_weights_scale = bw_input_weights->params.scale;
  const int8_t* bw_recurrent_weights_ptr =
      GetTensorData<int8_t>(bw_recurrent_weights);
  float bw_recurrent_weights_scale = bw_recurrent_weights->params.scale;

  // Set the auxiliary pointers and scales if needed.
  const int8_t* aux_fw_input_weights_ptr = nullptr;
  float aux_fw_input_weights_scale = 0.0f;
  const int8_t* aux_bw_input_weights_ptr = nullptr;
  float aux_bw_input_weights_scale = 0.0f;
  int8_t* aux_quantized_input_ptr = nullptr;
  if (aux_input_size > 0) {
    aux_fw_input_weights_ptr = GetTensorData<int8_t>(aux_fw_input_weights);
    aux_fw_input_weights_scale = aux_fw_input_weights->params.scale;
    aux_bw_input_weights_ptr = GetTensorData<int8_t>(aux_bw_input_weights);
    aux_bw_input_weights_scale = aux_bw_input_weights->params.scale;
    aux_quantized_input_ptr = GetTensorData<int8_t>(aux_input_quantized);
  }

  // Initialize temporary storage for quantized values.
  int8_t* quantized_input_ptr = GetTensorData<int8_t>(input_quantized);
  int8_t* fw_quantized_hidden_state_ptr =
      GetTensorData<int8_t>(fw_hidden_state_quantized);
  int8_t* bw_quantized_hidden_state_ptr =
      GetTensorData<int8_t>(bw_hidden_state_quantized);
  float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
  int32_t* accum_scratch_ptr = GetTensorData<int32_t>(accum_scratch);
  int32_t* zero_points_ptr = nullptr;
  int32_t* fw_row_sums_ptr = nullptr;
  int32_t* bw_row_sums_ptr = nullptr;
  if (params->asymmetric_quantize_inputs) {
    zero_points_ptr = GetTensorData<int32_t>(zero_points);
    fw_row_sums_ptr = GetTensorData<int32_t>(fw_row_sums);
    bw_row_sums_ptr = GetTensorData<int32_t>(bw_row_sums);
  }
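  // As in the float path, merged outputs interleave the fw and bw results, so
  // both cells stride by the combined width.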
  const int fw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
  const int bw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;

  if (time_major) {
    // Forward cell.
    float* fw_hidden_state_ptr_batch = GetTensorData<float>(fw_hidden_state);
    for (int s = 0; s < max_time; s++) {
      const float* input_ptr_batch =
          GetTensorData<float>(input) + s * input_size * batch_size;
      const float* aux_input_ptr_batch =
          (aux_input != nullptr)
              ? GetTensorData<float>(aux_input) +
                    s * aux_input_size * batch_size
              : nullptr;
      float* output_ptr_batch =
          GetTensorData<float>(fw_output) + s * fw_output_step * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
          aux_input_ptr_batch, aux_fw_input_weights_ptr,
          aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
          fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size,
          fw_num_units, batch_size, fw_output_step, params->activation,
          quantized_input_ptr, aux_quantized_input_ptr,
          fw_quantized_hidden_state_ptr, scaling_factors_ptr,
          fw_hidden_state_ptr_batch, output_ptr_batch,
          params->asymmetric_quantize_inputs, zero_points_ptr,
          accum_scratch_ptr, fw_row_sums_ptr, fw_compute_row_sums);
    }
    // Backward cell.
    float* bw_hidden_state_ptr_batch = GetTensorData<float>(bw_hidden_state);
    for (int s = max_time - 1; s >= 0; s--) {
      const float* input_ptr_batch =
          GetTensorData<float>(bw_input) + s * input_size * batch_size;
      const float* aux_input_ptr_batch =
          (aux_input != nullptr)
              ? GetTensorData<float>(aux_input) +
                    s * aux_input_size * batch_size
              : nullptr;
      float* output_ptr_batch =
          (params->merge_outputs
               ? GetTensorData<float>(fw_output) + fw_num_units
               : GetTensorData<float>(bw_output)) +
          s * bw_output_step * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
          aux_input_ptr_batch, aux_bw_input_weights_ptr,
          aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
          bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size,
          bw_num_units, batch_size, bw_output_step, params->activation,
          quantized_input_ptr, aux_quantized_input_ptr,
          bw_quantized_hidden_state_ptr, scaling_factors_ptr,
          bw_hidden_state_ptr_batch, output_ptr_batch,
          params->asymmetric_quantize_inputs, zero_points_ptr,
          accum_scratch_ptr, bw_row_sums_ptr, bw_compute_row_sums);
    }
  } else {
    for (int b = 0; b < batch_size; b++) {
      // Forward cell.
      float* fw_hidden_state_ptr_batch =
          GetTensorData<float>(fw_hidden_state) + b * fw_num_units;
      float* fw_output_offset =
          GetTensorData<float>(fw_output) + b * fw_output_step * max_time;
      for (int s = 0; s < max_time; s++) {
        const float* input_ptr_batch = GetTensorData<float>(input) +
                                       b * input_size * max_time +
                                       s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? GetTensorData<float>(aux_input) +
                      b * aux_input_size * max_time + s * aux_input_size
                : nullptr;
        float* output_ptr_batch = fw_output_offset + s * fw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
            aux_input_ptr_batch, aux_fw_input_weights_ptr,
            aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
            fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size,
            fw_num_units, /*batch_size=*/1, fw_output_step, params->activation,
            quantized_input_ptr, aux_quantized_input_ptr,
            fw_quantized_hidden_state_ptr, scaling_factors_ptr,
            fw_hidden_state_ptr_batch, output_ptr_batch,
            params->asymmetric_quantize_inputs, zero_points_ptr,
            accum_scratch_ptr, fw_row_sums_ptr, fw_compute_row_sums);
      }
      // Backward cell.
      float* bw_hidden_state_ptr_batch =
          GetTensorData<float>(bw_hidden_state) + b * bw_num_units;
      float* bw_output_offset =
          params->merge_outputs
              ? GetTensorData<float>(fw_output) +
                    b * bw_output_step * max_time + fw_num_units
              : GetTensorData<float>(bw_output) + b * bw_output_step * max_time;
      for (int s = max_time - 1; s >= 0; s--) {
        const float* input_ptr_batch = GetTensorData<float>(input) +
                                       b * input_size * max_time +
                                       s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? GetTensorData<float>(aux_input) +
                      b * aux_input_size * max_time + s * aux_input_size
                : nullptr;
        float* output_ptr_batch = bw_output_offset + s * bw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
            aux_input_ptr_batch, aux_bw_input_weights_ptr,
            aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
            bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size,
            bw_num_units, /*batch_size=*/1, bw_output_step, params->activation,
            quantized_input_ptr, aux_quantized_input_ptr,
            bw_quantized_hidden_state_ptr, scaling_factors_ptr,
            bw_hidden_state_ptr_batch, output_ptr_batch,
            params->asymmetric_quantize_inputs, zero_points_ptr,
            accum_scratch_ptr, bw_row_sums_ptr, bw_compute_row_sums);
      }
    }
  }
  return kTfLiteOk;
}

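// Dispatches to the float or hybrid kernel based on the weight type and
// resolves which tensor feeds the backward cell (see the stacking cases
// described below).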
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
      node->builtin_data);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* fw_input_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwWeightsTensor,
                                          &fw_input_weights));
  const TfLiteTensor* fw_recurrent_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwRecurrentWeightsTensor,
                                 &fw_recurrent_weights));
  const TfLiteTensor* fw_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwBiasTensor, &fw_bias));
  const TfLiteTensor* bw_input_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwWeightsTensor,
                                          &bw_input_weights));
  const TfLiteTensor* bw_recurrent_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwRecurrentWeightsTensor,
                                 &bw_recurrent_weights));
  const TfLiteTensor* bw_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwBiasTensor, &bw_bias));

  // Get auxiliary inputs.
  const TfLiteTensor* aux_input =
      GetOptionalInputTensor(context, node, kAuxInputTensor);
  const TfLiteTensor* fw_aux_input_weights =
      GetOptionalInputTensor(context, node, kFwAuxWeightsTensor);
  const TfLiteTensor* bw_aux_input_weights =
      GetOptionalInputTensor(context, node, kBwAuxWeightsTensor);

  TfLiteTensor* fw_hidden_state =
      GetVariableInput(context, node, kFwHiddenStateTensor);
  TFLITE_DCHECK(fw_hidden_state != nullptr);
  TfLiteTensor* bw_hidden_state =
      GetVariableInput(context, node, kBwHiddenStateTensor);
  TFLITE_DCHECK(bw_hidden_state != nullptr);

  TfLiteTensor* fw_output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kFwOutputTensor, &fw_output));
  TfLiteTensor* bw_output = params->merge_outputs
                                ? nullptr
                                : GetOutput(context, node, kBwOutputTensor);

  const bool has_previous_bw_output = (aux_input != nullptr);
  const bool use_aux_input = (fw_aux_input_weights != nullptr);

  // We want to cover the following cases:
  //
  // If not stacking (not connected after other bidi lstms):
  //   both fw & bw will just use `input`; aux_input will be null.
  //
  // If stacking with cross links, TensorFlow equivalent
  // (tf.contrib.rnn.stack_bidirectional_rnn):
  //   both fw & bw will use `input`, but aux_input will be non-null.
  //   Note that in this case it works whether or not the op is connected
  //   after other bidi lstms.
  //
  // If stacking without cross links but connected after other bidi lstms,
  // TensorFlow equivalent (tf.nn.static_bidirectional_rnn):
  //   fw will use `input`, bw will use aux_input, and the `real aux_input`
  //   will be null.

  const bool non_stacking_mode = !use_aux_input && has_previous_bw_output;
  const TfLiteTensor* bw_input = non_stacking_mode ? aux_input : input;
  const TfLiteTensor* real_aux_input = non_stacking_mode ? nullptr : aux_input;

  switch (fw_input_weights->type) {
    case kTfLiteFloat32:
      return EvalFloat(input, bw_input, fw_input_weights, fw_recurrent_weights,
                       fw_bias, bw_input_weights, bw_recurrent_weights, bw_bias,
                       real_aux_input, fw_aux_input_weights,
                       bw_aux_input_weights, params, fw_hidden_state, fw_output,
                       bw_hidden_state, bw_output);
    case kTfLiteUInt8:
    case kTfLiteInt8: {
      TfLiteTensor* input_quantized;
      TF_LITE_ENSURE_OK(
          context,
          GetTemporarySafe(context, node, kInputQuantized, &input_quantized));
      TfLiteTensor* fw_hidden_state_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kFwHiddenStateQuantized,
                                         &fw_hidden_state_quantized));
      TfLiteTensor* bw_hidden_state_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kBwHiddenStateQuantized,
                                         &bw_hidden_state_quantized));
      TfLiteTensor* scaling_factors;
      TF_LITE_ENSURE_OK(
          context,
          GetTemporarySafe(context, node, kScalingFactors, &scaling_factors));
      TfLiteTensor* zero_points;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kZeroPoints, &zero_points));
      TfLiteTensor* accum_scratch;
      TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
                                                  &accum_scratch));
      TfLiteTensor* fw_row_sums;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kFwRowSums, &fw_row_sums));
      TfLiteTensor* bw_row_sums;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kBwRowSums, &bw_row_sums));
      TfLiteTensor* aux_input_quantized =
          use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
                        : nullptr;
      auto* op_data = reinterpret_cast<OpData*>(node->user_data);
      return EvalHybrid(
          input, bw_input, fw_input_weights, fw_recurrent_weights, fw_bias,
          bw_input_weights, bw_recurrent_weights, bw_bias, real_aux_input,
          fw_aux_input_weights, bw_aux_input_weights, params, scaling_factors,
          input_quantized, aux_input_quantized, fw_hidden_state_quantized,
          fw_hidden_state, fw_output, bw_hidden_state_quantized,
          bw_hidden_state, bw_output, zero_points, accum_scratch, fw_row_sums,
          bw_row_sums, &op_data->fw_compute_row_sums,
          &op_data->bw_compute_row_sums);
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Type not currently supported.");
      return kTfLiteError;
  }
}

}  // namespace bidirectional_sequence_rnn

TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN() {
  static TfLiteRegistration r = {
      bidirectional_sequence_rnn::Init, bidirectional_sequence_rnn::Free,
      bidirectional_sequence_rnn::Prepare, bidirectional_sequence_rnn::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite