/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <math.h>

#include <algorithm>
#include <cstddef>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/lstm_eval.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace bidirectional_sequence_lstm {

// LINT.IfChange

// Input Tensors of size {max_time, n_batch, n_input}
constexpr int kInputTensor = 0;
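// Illustrative indexing (a sketch, not part of this kernel): TfLite tensors
// are row-major, so for a time-major float input, element (t, b, i) lives at
//   input->data.f[(t * n_batch + b) * n_input + i];
// for batch-major inputs the first two dimensions are swapped.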

// Forward LSTM cell tensors.
// Input weight tensors of size: {n_cell, n_input}
constexpr int kFwInputToInputWeightsTensor = 1;  // Optional
constexpr int kFwInputToForgetWeightsTensor = 2;
constexpr int kFwInputToCellWeightsTensor = 3;
constexpr int kFwInputToOutputWeightsTensor = 4;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr int kFwRecurrentToInputWeightsTensor = 5;  // Optional
constexpr int kFwRecurrentToForgetWeightsTensor = 6;
constexpr int kFwRecurrentToCellWeightsTensor = 7;
constexpr int kFwRecurrentToOutputWeightsTensor = 8;

// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr int kFwCellToInputWeightsTensor = 9;    // Optional
constexpr int kFwCellToForgetWeightsTensor = 10;  // Optional
constexpr int kFwCellToOutputWeightsTensor = 11;  // Optional

// Gates bias tensors of size {n_cell}
constexpr int kFwInputGateBiasTensor = 12;  // Optional
constexpr int kFwForgetGateBiasTensor = 13;
constexpr int kFwCellGateBiasTensor = 14;
constexpr int kFwOutputGateBiasTensor = 15;

// Projection weight tensor of size {n_output, n_cell}
constexpr int kFwProjectionWeightsTensor = 16;  // Optional
// Projection bias tensor of size {n_output}
constexpr int kFwProjectionBiasTensor = 17;  // Optional

// Backward LSTM cell tensors.
// Input weight tensors of size: {n_cell, n_input}
constexpr int kBwInputToInputWeightsTensor = 18;  // Optional
constexpr int kBwInputToForgetWeightsTensor = 19;
constexpr int kBwInputToCellWeightsTensor = 20;
constexpr int kBwInputToOutputWeightsTensor = 21;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr int kBwRecurrentToInputWeightsTensor = 22;  // Optional
constexpr int kBwRecurrentToForgetWeightsTensor = 23;
constexpr int kBwRecurrentToCellWeightsTensor = 24;
constexpr int kBwRecurrentToOutputWeightsTensor = 25;

// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr int kBwCellToInputWeightsTensor = 26;   // Optional
constexpr int kBwCellToForgetWeightsTensor = 27;  // Optional
constexpr int kBwCellToOutputWeightsTensor = 28;  // Optional

// Gates bias tensors of size {n_cell}
constexpr int kBwInputGateBiasTensor = 29;  // Optional
constexpr int kBwForgetGateBiasTensor = 30;
constexpr int kBwCellGateBiasTensor = 31;
constexpr int kBwOutputGateBiasTensor = 32;

// Projection weight tensor of size {n_output, n_cell}
constexpr int kBwProjectionWeightsTensor = 33;  // Optional
// Projection bias tensor of size {n_output}
constexpr int kBwProjectionBiasTensor = 34;  // Optional

// Stateful input tensors that are variables and will be modified by the Op.
// Activation state tensors of size {n_batch, n_output}
constexpr int kFwInputActivationStateTensor = 35;
// Cell state tensors of size {n_batch, n_cell}
constexpr int kFwInputCellStateTensor = 36;
// Activation state tensors of size {n_batch, n_output}
constexpr int kBwInputActivationStateTensor = 37;
// Cell state tensors of size {n_batch, n_cell}
constexpr int kBwInputCellStateTensor = 38;

// Used as auxiliary input and weights when stacking for the
// tf.contrib.rnn.stack_bidirectional_rnn case (with cross links); used as
// input to the backward cell when stacking for the
// tf.nn.static_bidirectional_rnn case (without cross links).
constexpr int kAuxInputTensor = 39;  // Optional
// Forward weights.
constexpr int kFwAuxInputToInputWeightsTensor = 40;   // Optional
constexpr int kFwAuxInputToForgetWeightsTensor = 41;  // Optional
constexpr int kFwAuxInputToCellWeightsTensor = 42;    // Optional
constexpr int kFwAuxInputToOutputWeightsTensor = 43;  // Optional
// Backward weights.
constexpr int kBwAuxInputToInputWeightsTensor = 44;   // Optional
constexpr int kBwAuxInputToForgetWeightsTensor = 45;  // Optional
constexpr int kBwAuxInputToCellWeightsTensor = 46;    // Optional
constexpr int kBwAuxInputToOutputWeightsTensor = 47;  // Optional

// Output tensors.
constexpr int kFwOutputTensor = 0;
constexpr int kBwOutputTensor = 1;  // Ignored if merge_outputs is set.

// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantize_weights.cc)

// Temporary tensors.
enum TemporaryTensor {
  // Scratch buffers for input, forget, etc. gates
  kFwScratchBuffer = 0,
  kBwScratchBuffer = 1,
  // Quantized tensors needed for the hybrid kernel.
  kInputQuantized = 2,
  kFwActivationStateQuantized = 3,
  kBwActivationStateQuantized = 4,
  kFwCellStateQuantized = 5,
  kBwCellStateQuantized = 6,
  kInputScalingFactors = 7,
  kAuxInputScalingFactors = 8,
  kOutputStateScalingFactors = 9,
  kProductScalingFactors = 10,
  kRecoveredCellWeights = 11,
  kAccumScratchBuffer = 12,
  kInputZeroPoints = 13,
  kAuxInputZeroPoints = 14,
  kOutputStateZeroPoints = 15,
  kFwRowSums = 16,
  kBwRowSums = 17,
  kAuxInputQuantized = 18,  // Optional, quantized tensor for auxiliary input.
  kNumTemporaryTensors = 19,
};
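// These enum values are offsets into the block of kNumTemporaryTensors
// consecutive tensor indices that Init() below requests via
// context->AddTensors(); illustrative use (mirroring Prepare()):
//   node->temporaries->data[kFwScratchBuffer] =
//       op_data->scratch_tensor_index + kFwScratchBuffer;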

struct OpData {
  int scratch_tensor_index;
  bool compute_fw_row_sums = false;
  bool compute_bw_row_sums = false;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  context->AddTensors(context, kNumTemporaryTensors,
                      &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

// Check that the input tensor dimensions match each other.
TfLiteStatus CheckLstmTensorDimensionsAndTypes(
    TfLiteContext* context, TfLiteNode* node, int n_input, int n_output,
    int n_cell, int input_to_input_weights_tensor,
    int input_to_forget_weights_tensor, int input_to_cell_weights_tensor,
    int input_to_output_weights_tensor, int recurrent_to_input_weights_tensor,
    int recurrent_to_forget_weights_tensor,
    int recurrent_to_cell_weights_tensor,
    int recurrent_to_output_weights_tensor, int cell_to_input_weights_tensor,
    int cell_to_forget_weights_tensor, int cell_to_output_weights_tensor,
    int input_gate_bias_tensor, int forget_gate_bias_tensor,
    int cell_gate_bias_tensor, int output_gate_bias_tensor,
    int projection_weights_tensor, int projection_bias_tensor) {
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
      node->builtin_data);

  // Make sure the clipping parameters have valid values:
  //  == 0 means no clipping,
  //   > 0 means clipping.
  TF_LITE_ENSURE(context, params->cell_clip >= 0);
  TF_LITE_ENSURE(context, params->proj_clip >= 0);
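  // Illustrative semantics (a sketch, not this kernel's code): a positive
  // cell_clip clamps the updated cell state elementwise, e.g.
  //   c = std::max(std::min(c, params->cell_clip), -params->cell_clip);
  // while a value of 0 leaves the cell state unclamped.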

  const TfLiteTensor* input_to_forget_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, input_to_forget_weights_tensor,
                                 &input_to_forget_weights));
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
  TF_LITE_ENSURE(context, (input_to_forget_weights->type == kTfLiteFloat32) ||
                              (input_to_forget_weights->type == kTfLiteInt8) ||
                              (input_to_forget_weights->type == kTfLiteUInt8));

  const TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, input_to_input_weights_tensor);
  if (input_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
    TF_LITE_ENSURE_TYPES_EQ(context, input_to_input_weights->type,
                            input_to_forget_weights->type);
  }

  const TfLiteTensor* input_to_cell_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, input_to_cell_weights_tensor,
                                 &input_to_cell_weights));
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
  TF_LITE_ENSURE_TYPES_EQ(context, input_to_cell_weights->type,
                          input_to_forget_weights->type);

  const TfLiteTensor* input_to_output_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, input_to_output_weights_tensor,
                                 &input_to_output_weights));
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
  TF_LITE_ENSURE_TYPES_EQ(context, input_to_output_weights->type,
                          input_to_forget_weights->type);

  const TfLiteTensor* recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, recurrent_to_input_weights_tensor);
  if (recurrent_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
                      n_cell);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
                      n_output);
    TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_input_weights->type,
                            input_to_forget_weights->type);
  }

  const TfLiteTensor* recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, recurrent_to_forget_weights_tensor,
                            &recurrent_to_forget_weights));
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
                    n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
                    n_output);
  TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_forget_weights->type,
                          input_to_forget_weights->type);

  const TfLiteTensor* recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, recurrent_to_cell_weights_tensor,
                            &recurrent_to_cell_weights));
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
                    n_output);
  TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_cell_weights->type,
                          input_to_forget_weights->type);

  // We make sure the input-gate parameters are either all present (regular
  // LSTM) or all absent (CIFG-LSTM).
  const bool cifg_weights_all_or_none =
      ((input_to_input_weights != nullptr) &&
       (recurrent_to_input_weights != nullptr)) ||
      ((input_to_input_weights == nullptr) &&
       (recurrent_to_input_weights == nullptr));
  TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);
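  // Background: in a CIFG (Coupled Input and Forget Gate) LSTM the input
  // gate is not computed from its own weights but coupled to the forget
  // gate, conceptually i_t = 1 - f_t, so the input-gate weights may only be
  // absent as a complete set.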

  const TfLiteTensor* cell_to_input_weights =
      GetOptionalInputTensor(context, node, cell_to_input_weights_tensor);
  if (cell_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, cell_to_input_weights->type,
                            input_to_forget_weights->type);
  }

  const TfLiteTensor* cell_to_forget_weights =
      GetOptionalInputTensor(context, node, cell_to_forget_weights_tensor);
  if (cell_to_forget_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, cell_to_forget_weights->type,
                            input_to_forget_weights->type);
  }

  const TfLiteTensor* cell_to_output_weights =
      GetOptionalInputTensor(context, node, cell_to_output_weights_tensor);
  if (cell_to_output_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, cell_to_output_weights->type,
                            input_to_forget_weights->type);
  }

  // Make sure the peephole weights are all present or all absent.
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool peephole_weights_all_or_none =
      ((cell_to_input_weights != nullptr || use_cifg) &&
       (cell_to_forget_weights != nullptr) &&
       (cell_to_output_weights != nullptr)) ||
      ((cell_to_input_weights == nullptr) &&
       (cell_to_forget_weights == nullptr) &&
       (cell_to_output_weights == nullptr));
  TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);

  // Make sure the input gate bias is present only when not a CIFG-LSTM.
  const TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, input_gate_bias_tensor);
  if (use_cifg) {
    TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
  } else {
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteFloat32);
  }

  const TfLiteTensor* forget_gate_bias;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, forget_gate_bias_tensor, &forget_gate_bias));
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
  TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32);

  const TfLiteTensor* cell_gate_bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, cell_gate_bias_tensor,
                                          &cell_gate_bias));
  TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);
  TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteFloat32);

  const TfLiteTensor* output_gate_bias;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, output_gate_bias_tensor, &output_gate_bias));
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
  TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteFloat32);

  const TfLiteTensor* projection_weights =
      GetOptionalInputTensor(context, node, projection_weights_tensor);
  if (projection_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, projection_weights->type,
                            input_to_forget_weights->type);
  }

  const TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, projection_bias_tensor);
  if (projection_bias != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
    TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteFloat32);
  }

  // Make sure the projection tensors are consistent:
  // 1) If the projection weights are not present, then the projection bias
  //    should not be present.
  // 2) If the projection weights are present, then the projection bias is
  //    optional.
  // TODO(ghodrat): make sure this is correct.
  const bool projection_tensors_consistent =
      ((projection_weights != nullptr) || (projection_bias == nullptr));
  TF_LITE_ENSURE(context, projection_tensors_consistent == true);

  return kTfLiteOk;
}

TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                                        TfLiteNode* node, int n_input,
                                        int n_output, int n_cell) {
  TF_LITE_ENSURE_OK(
      context,
      CheckLstmTensorDimensionsAndTypes(
          context, node, n_input, n_output, n_cell,
          kFwInputToInputWeightsTensor, kFwInputToForgetWeightsTensor,
          kFwInputToCellWeightsTensor, kFwInputToOutputWeightsTensor,
          kFwRecurrentToInputWeightsTensor, kFwRecurrentToForgetWeightsTensor,
          kFwRecurrentToCellWeightsTensor, kFwRecurrentToOutputWeightsTensor,
          kFwCellToInputWeightsTensor, kFwCellToForgetWeightsTensor,
          kFwCellToOutputWeightsTensor, kFwInputGateBiasTensor,
          kFwForgetGateBiasTensor, kFwCellGateBiasTensor,
          kFwOutputGateBiasTensor, kFwProjectionWeightsTensor,
          kFwProjectionBiasTensor));

  TF_LITE_ENSURE_OK(
      context,
      CheckLstmTensorDimensionsAndTypes(
          context, node, n_input, n_output, n_cell,
          kBwInputToInputWeightsTensor, kBwInputToForgetWeightsTensor,
          kBwInputToCellWeightsTensor, kBwInputToOutputWeightsTensor,
          kBwRecurrentToInputWeightsTensor, kBwRecurrentToForgetWeightsTensor,
          kBwRecurrentToCellWeightsTensor, kBwRecurrentToOutputWeightsTensor,
          kBwCellToInputWeightsTensor, kBwCellToForgetWeightsTensor,
          kBwCellToOutputWeightsTensor, kBwInputGateBiasTensor,
          kBwForgetGateBiasTensor, kBwCellGateBiasTensor,
          kBwOutputGateBiasTensor, kBwProjectionWeightsTensor,
          kBwProjectionBiasTensor));

  // Check if Forward and Backward tensors match along required dimensions.
  return kTfLiteOk;
}

// Resize the output and scratch tensors based on the sizes of the input
// tensors. Also check that the sizes of the input tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
      node->builtin_data);

  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 48);
  TF_LITE_ENSURE_EQ(context, node->outputs->size,
                    params->merge_outputs ? 1 : 2);

  // Infer the batch size, sequence length, number of outputs, and number of
  // cells from the input tensors.
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
  const bool time_major = params->time_major;
  const int max_time = time_major ? input->dims->data[0] : input->dims->data[1];
  const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
  const int n_input = input->dims->data[2];

  const TfLiteTensor* fw_input_to_output_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwInputToOutputWeightsTensor,
                                 &fw_input_to_output_weights));
  const int n_fw_cell = fw_input_to_output_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->data[1],
                    n_input);

  const TfLiteTensor* bw_input_to_output_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwInputToOutputWeightsTensor,
                                 &bw_input_to_output_weights));
  const int n_bw_cell = bw_input_to_output_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->data[1],
                    n_input);
  TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->type,
                    fw_input_to_output_weights->type);

  const TfLiteTensor* fw_recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kFwRecurrentToOutputWeightsTensor,
                            &fw_recurrent_to_output_weights));
  TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->data[0],
                    n_fw_cell);
  TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->type,
                    fw_input_to_output_weights->type);
  const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1];

  const TfLiteTensor* bw_recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kBwRecurrentToOutputWeightsTensor,
                            &bw_recurrent_to_output_weights));
  TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->data[0],
                    n_bw_cell);
  TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->type,
                    fw_input_to_output_weights->type);
  const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1];

  // Check that the input tensor dimensions match each other.
  TF_LITE_ENSURE_OK(
      context, CheckInputTensorDimensions(context, node, n_input, n_fw_output,
                                          n_fw_cell));

  // Get (optional) auxiliary inputs and weights.
  const TfLiteTensor* aux_input =
      GetOptionalInputTensor(context, node, kAuxInputTensor);
  const TfLiteTensor* fw_aux_input_to_input_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToInputWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_forget_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToForgetWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_cell_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToCellWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_output_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToOutputWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_input_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToInputWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_forget_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToForgetWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_cell_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToCellWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_output_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor);

  const bool aux_inputs_weights_all_or_none =
      ((fw_aux_input_to_cell_weights != nullptr) &&
       (fw_aux_input_to_forget_weights != nullptr) &&
       (fw_aux_input_to_output_weights != nullptr) &&
       (bw_aux_input_to_cell_weights != nullptr) &&
       (bw_aux_input_to_forget_weights != nullptr) &&
       (bw_aux_input_to_output_weights != nullptr)) ||
      ((fw_aux_input_to_cell_weights == nullptr) &&
       (fw_aux_input_to_forget_weights == nullptr) &&
       (fw_aux_input_to_output_weights == nullptr) &&
       (bw_aux_input_to_cell_weights == nullptr) &&
       (bw_aux_input_to_forget_weights == nullptr) &&
       (bw_aux_input_to_output_weights == nullptr));
  TF_LITE_ENSURE(context, aux_inputs_weights_all_or_none);

  const bool has_aux_input = (fw_aux_input_to_forget_weights != nullptr);

  if (has_aux_input) {
    // Check that aux_input has the same dimensions (except last) as the input.
    TF_LITE_ASSERT_EQ(aux_input->dims->data[0], input->dims->data[0]);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[1], input->dims->data[1]);
  }

  // Get the pointer to output, activation_state and cell_state buffer tensors.
  TfLiteTensor* fw_output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kFwOutputTensor, &fw_output));
  TfLiteTensor* fw_activation_state =
      GetVariableInput(context, node, kFwInputActivationStateTensor);
  TF_LITE_ENSURE(context, fw_activation_state != nullptr);
  TfLiteTensor* fw_cell_state =
      GetVariableInput(context, node, kFwInputCellStateTensor);
  TF_LITE_ENSURE(context, fw_cell_state != nullptr);

  // Check the shape of the input state tensors.
  // These tensors may be 1D or 2D; either is fine as long as the total size
  // is correct.
  TF_LITE_ENSURE_EQ(context, NumElements(fw_activation_state),
                    n_batch * n_fw_output);
  TF_LITE_ENSURE_EQ(context, NumElements(fw_cell_state), n_batch * n_fw_cell);

  // Resize the output tensors.
  TfLiteIntArray* fw_output_size = TfLiteIntArrayCreate(3);
  fw_output_size->data[0] = time_major ? max_time : n_batch;
  fw_output_size->data[1] = time_major ? n_batch : max_time;
  fw_output_size->data[2] =
      params->merge_outputs ? n_bw_output + n_fw_output : n_fw_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, fw_output, fw_output_size));
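  // Worked example (illustrative): with time_major == false, n_batch = 2,
  // max_time = 7, n_fw_output = 16, n_bw_output = 16 and merge_outputs set,
  // fw_output is resized to {2, 7, 32}; without merge_outputs it would be
  // {2, 7, 16}.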

  // The weights are of consistent type, so it suffices to check one.
  const bool is_hybrid_op = IsHybridOp(input, fw_input_to_output_weights);
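  // "Hybrid" means float activations with quantized (int8/uint8) weights,
  // which is what the quantization temporaries below exist for. A sketch of
  // the check, assuming IsHybridOp compares exactly these two types:
  //   input->type == kTfLiteFloat32 &&
  //   (weights->type == kTfLiteInt8 || weights->type == kTfLiteUInt8)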

  TfLiteIntArrayFree(node->temporaries);
  if (is_hybrid_op) {
    node->temporaries = TfLiteIntArrayCreate(
        has_aux_input ? kNumTemporaryTensors : kNumTemporaryTensors - 1);
  } else {
    node->temporaries = TfLiteIntArrayCreate(2);  // the two scratch buffers.
  }
  // Create a scratch buffer tensor.
  node->temporaries->data[kFwScratchBuffer] =
      op_data->scratch_tensor_index + kFwScratchBuffer;
  TfLiteTensor* fw_scratch_buffer;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kFwScratchBuffer,
                                              &fw_scratch_buffer));
  fw_scratch_buffer->type = input->type;
  fw_scratch_buffer->allocation_type = kTfLiteArenaRw;

  const TfLiteTensor* fw_input_to_input_weights =
      GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor);
  const bool fw_use_cifg = (fw_input_to_input_weights == nullptr);
  if (has_aux_input && !fw_use_cifg) {
    TF_LITE_ENSURE_EQ(context, fw_aux_input_to_input_weights->dims->data[0],
                      fw_input_to_input_weights->dims->data[0]);
  }
  TfLiteIntArray* fw_scratch_buffer_size = TfLiteIntArrayCreate(2);
  fw_scratch_buffer_size->data[0] = n_batch;
  if (fw_use_cifg) {
    // Reserving space for Cell, Forget, Output gates and scratch accumulation
    // buffer and an extra 16 bytes to avoid internal ruy copies.
    fw_scratch_buffer_size->data[1] = n_fw_cell * 4 + 16;
  } else {
    // Reserving space for Input, Cell, Forget, Output gates and scratch
    // accumulation buffer and an extra 16 bytes to avoid internal ruy copies.
    fw_scratch_buffer_size->data[1] = n_fw_cell * 5 + 16;
  }
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_scratch_buffer,
                                                   fw_scratch_buffer_size));
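  // For example (illustrative): a non-CIFG forward cell with n_fw_cell = 64
  // reserves a scratch row of 64 * 5 + 16 = 336 floats per batch entry; with
  // CIFG the input-gate slice is dropped, giving 64 * 4 + 16 = 272.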
  // Same for the backward cell.

  // Check that the input tensor dimensions match each other.
  TF_LITE_ENSURE_OK(
      context, CheckInputTensorDimensions(context, node, n_input, n_bw_output,
                                          n_bw_cell));

  // Get the pointer to activation_state and cell_state buffer tensors.
  TfLiteTensor* bw_activation_state =
      GetVariableInput(context, node, kBwInputActivationStateTensor);
  TF_LITE_ENSURE(context, bw_activation_state != nullptr);
  TfLiteTensor* bw_cell_state =
      GetVariableInput(context, node, kBwInputCellStateTensor);
  TF_LITE_ENSURE(context, bw_cell_state != nullptr);

  // Resize the output tensors.
  if (!params->merge_outputs) {
    TfLiteTensor* bw_output;
    TF_LITE_ENSURE_OK(
        context, GetOutputSafe(context, node, kBwOutputTensor, &bw_output));
    TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3);
    bw_output_size->data[0] = time_major ? max_time : n_batch;
    bw_output_size->data[1] = time_major ? n_batch : max_time;
    bw_output_size->data[2] = n_bw_output;
    TF_LITE_ENSURE_OK(
        context, context->ResizeTensor(context, bw_output, bw_output_size));
  }

  // Check the shape of the input state tensors.
  // These tensors may be 1D or 2D; either is fine as long as the total size
  // is correct.
  TF_LITE_ENSURE_EQ(context, NumElements(bw_activation_state),
                    n_batch * n_bw_output);
  TF_LITE_ENSURE_EQ(context, NumElements(bw_cell_state), n_batch * n_bw_cell);

  // Create a scratch buffer tensor.
  node->temporaries->data[kBwScratchBuffer] =
      op_data->scratch_tensor_index + kBwScratchBuffer;
  TfLiteTensor* bw_scratch_buffer;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kBwScratchBuffer,
                                              &bw_scratch_buffer));
  bw_scratch_buffer->type = input->type;
  bw_scratch_buffer->allocation_type = kTfLiteArenaRw;

  const TfLiteTensor* bw_input_to_input_weights =
      GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor);
  const bool bw_use_cifg = (bw_input_to_input_weights == nullptr);
  if (has_aux_input && !bw_use_cifg) {
    TF_LITE_ENSURE_EQ(context, bw_aux_input_to_input_weights->dims->data[0],
                      bw_input_to_input_weights->dims->data[0]);
  }
  TfLiteIntArray* bw_scratch_buffer_size = TfLiteIntArrayCreate(2);
  bw_scratch_buffer_size->data[0] = n_batch;
  if (bw_use_cifg) {
    // Reserving space for Cell, Forget, Output gates and scratch accumulation
    // buffer and an extra 16 bytes to avoid internal ruy copies.
    bw_scratch_buffer_size->data[1] = n_bw_cell * 4 + 16;
  } else {
    // Reserving space for Input, Cell, Forget, Output gates and scratch
    // accumulation buffer and an extra 16 bytes to avoid internal ruy copies.
    bw_scratch_buffer_size->data[1] = n_bw_cell * 5 + 16;
  }
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_scratch_buffer,
                                                   bw_scratch_buffer_size));
  if (is_hybrid_op) {
    // Compute the row sums for cached zero_point offset calculation.
    op_data->compute_fw_row_sums = true;
    op_data->compute_bw_row_sums = true;
    // Allocate temporary tensors to store quantized values of input, aux_input
    // (if present), activation_state and cell_state tensors.
    node->temporaries->data[kInputQuantized] =
        op_data->scratch_tensor_index + kInputQuantized;
    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
                                                &input_quantized));
    input_quantized->type = fw_input_to_output_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                       input_quantized_size));
    }

    node->temporaries->data[kFwActivationStateQuantized] =
        op_data->scratch_tensor_index + kFwActivationStateQuantized;
    TfLiteTensor* fw_activation_state_quantized;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kFwActivationStateQuantized,
                                  &fw_activation_state_quantized));
    fw_activation_state_quantized->type = fw_input_to_output_weights->type;
    fw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(fw_activation_state_quantized->dims,
                             fw_activation_state->dims)) {
      TfLiteIntArray* fw_activation_state_quantized_size =
          TfLiteIntArrayCopy(fw_activation_state->dims);
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, fw_activation_state_quantized,
                                         fw_activation_state_quantized_size));
    }
    node->temporaries->data[kBwActivationStateQuantized] =
        op_data->scratch_tensor_index + kBwActivationStateQuantized;
    TfLiteTensor* bw_activation_state_quantized;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kBwActivationStateQuantized,
                                  &bw_activation_state_quantized));
    bw_activation_state_quantized->type = fw_input_to_output_weights->type;
    bw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(bw_activation_state_quantized->dims,
                             bw_activation_state->dims)) {
      TfLiteIntArray* bw_activation_state_quantized_size =
          TfLiteIntArrayCopy(bw_activation_state->dims);
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, bw_activation_state_quantized,
                                         bw_activation_state_quantized_size));
    }
    node->temporaries->data[kFwCellStateQuantized] =
        op_data->scratch_tensor_index + kFwCellStateQuantized;
    TfLiteTensor* fw_cell_state_quantized;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kFwCellStateQuantized,
                                       &fw_cell_state_quantized));
    fw_cell_state_quantized->type = fw_input_to_output_weights->type;
    fw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(fw_cell_state_quantized->dims,
                             fw_cell_state->dims)) {
      TfLiteIntArray* fw_cell_state_quantized_size =
          TfLiteIntArrayCopy(fw_cell_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, fw_cell_state_quantized,
                                              fw_cell_state_quantized_size));
    }
    node->temporaries->data[kBwCellStateQuantized] =
        op_data->scratch_tensor_index + kBwCellStateQuantized;
    TfLiteTensor* bw_cell_state_quantized;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kBwCellStateQuantized,
                                       &bw_cell_state_quantized));
    bw_cell_state_quantized->type = fw_input_to_output_weights->type;
    bw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(bw_cell_state_quantized->dims,
                             bw_cell_state->dims)) {
      TfLiteIntArray* bw_cell_state_quantized_size =
          TfLiteIntArrayCopy(bw_cell_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, bw_cell_state_quantized,
                                              bw_cell_state_quantized_size));
    }

    // Allocate temporary tensors to store scaling factors and product scaling
    // factors. The latter is convenience storage that lets us quantize a
    // vector once (which produces the scaling factors) and multiply it with
    // different matrices (which requires multiplying the scaling factors by
    // the scaling factor of each matrix).
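    // Illustrative math: if a batch row of the input is quantized with
    // scaling factor s_x and a weight matrix was quantized with scaling
    // factor s_w, the int32 accumulator converts back to float via s_x * s_w;
    // caching those products avoids recomputing n_batch multiplies for every
    // matrix/vector pair.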
    node->temporaries->data[kInputScalingFactors] =
        op_data->scratch_tensor_index + kInputScalingFactors;
    TfLiteTensor* input_sf;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, kInputScalingFactors, &input_sf));
    input_sf->type = kTfLiteFloat32;
    input_sf->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {n_batch};
    if (!TfLiteIntArrayEqualsArray(input_sf->dims, 1, scaling_dims)) {
      TfLiteIntArray* input_sf_size = TfLiteIntArrayCreate(1);
      input_sf_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, input_sf, input_sf_size));
    }
    node->temporaries->data[kAuxInputScalingFactors] =
        op_data->scratch_tensor_index + kAuxInputScalingFactors;
    TfLiteTensor* aux_input_sf;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kAuxInputScalingFactors,
                                       &aux_input_sf));
    aux_input_sf->type = kTfLiteFloat32;
    aux_input_sf->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(aux_input_sf->dims, 1, scaling_dims)) {
      TfLiteIntArray* aux_input_sf_size = TfLiteIntArrayCreate(1);
      aux_input_sf_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, aux_input_sf,
                                                       aux_input_sf_size));
    }
    node->temporaries->data[kOutputStateScalingFactors] =
        op_data->scratch_tensor_index + kOutputStateScalingFactors;
    TfLiteTensor* output_state_sf;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kOutputStateScalingFactors,
                                  &output_state_sf));
    output_state_sf->type = kTfLiteFloat32;
    output_state_sf->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) {
      TfLiteIntArray* output_state_sf_size = TfLiteIntArrayCreate(1);
      output_state_sf_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_sf,
                                                       output_state_sf_size));
    }
    node->temporaries->data[kProductScalingFactors] =
        op_data->scratch_tensor_index + kProductScalingFactors;
    TfLiteTensor* prod_scaling_factors;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kProductScalingFactors,
                                       &prod_scaling_factors));
    prod_scaling_factors->type = kTfLiteFloat32;
    prod_scaling_factors->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
                                   scaling_dims)) {
      TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
      prod_scaling_factors_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, prod_scaling_factors,
                                              prod_scaling_factors_size));
    }

    // Allocate a temporary tensor to store the recovered cell weights. Since
    // these represent diagonal matrices, we only need to store n_cell values.
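    // Illustrative sketch (not this kernel's code): "recovering" a
    // symmetrically quantized int8 diagonal weight w_q[i] with scale s means
    // dequantizing it as
    //   recovered[i] = w_q[i] * s;
    // so one float vector of n_cell entries is enough.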
    node->temporaries->data[kRecoveredCellWeights] =
        op_data->scratch_tensor_index + kRecoveredCellWeights;
    TfLiteTensor* recovered_cell_weights;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kRecoveredCellWeights,
                                       &recovered_cell_weights));
    recovered_cell_weights->type = kTfLiteFloat32;
    recovered_cell_weights->allocation_type = kTfLiteArenaRw;
    int recovered_cell_dims[1] = {n_fw_cell};
    if (!TfLiteIntArrayEqualsArray(recovered_cell_weights->dims, 1,
                                   recovered_cell_dims)) {
      TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
      recovered_cell_weights_size->data[0] = n_fw_cell;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, recovered_cell_weights,
                                              recovered_cell_weights_size));
    }

    // Allocate a temporary tensor to store the accumulated int32 values.
    node->temporaries->data[kAccumScratchBuffer] =
        op_data->scratch_tensor_index + kAccumScratchBuffer;
    TfLiteTensor* accum_scratch;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, kAccumScratchBuffer, &accum_scratch));
    accum_scratch->type = kTfLiteInt32;
    accum_scratch->allocation_type = kTfLiteArenaRw;
    int n_cell = std::max(n_fw_cell, n_bw_cell);
    if (has_aux_input) {
      n_cell = std::max(n_cell, fw_aux_input_to_output_weights->dims->data[0]);
      n_cell = std::max(n_cell, bw_aux_input_to_output_weights->dims->data[0]);
    }
    int accum_scratch_dims[2] = {n_cell, n_batch};
    if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
                                   accum_scratch_dims)) {
      TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2);
      accum_size->data[0] = n_cell;
      accum_size->data[1] = n_batch;
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, accum_scratch, accum_size));
    }

    // Allocate temporary tensors for storing zero-points.
    node->temporaries->data[kInputZeroPoints] =
        op_data->scratch_tensor_index + kInputZeroPoints;
    TfLiteTensor* input_zp;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kInputZeroPoints, &input_zp));
    input_zp->type = kTfLiteFloat32;
    input_zp->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) {
      TfLiteIntArray* input_zp_size = TfLiteIntArrayCreate(1);
      input_zp_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, input_zp, input_zp_size));
    }
    node->temporaries->data[kAuxInputZeroPoints] =
        op_data->scratch_tensor_index + kAuxInputZeroPoints;
    TfLiteTensor* aux_input_zp;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, kAuxInputZeroPoints, &aux_input_zp));
    aux_input_zp->type = kTfLiteFloat32;
    aux_input_zp->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(aux_input_zp->dims, 1, scaling_dims)) {
      TfLiteIntArray* aux_input_zp_size = TfLiteIntArrayCreate(1);
      aux_input_zp_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, aux_input_zp,
                                                       aux_input_zp_size));
    }
    node->temporaries->data[kOutputStateZeroPoints] =
        op_data->scratch_tensor_index + kOutputStateZeroPoints;
    TfLiteTensor* output_state_zp;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kOutputStateZeroPoints,
                                       &output_state_zp));
    output_state_zp->type = kTfLiteFloat32;
    output_state_zp->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) {
      TfLiteIntArray* output_state_zp_size = TfLiteIntArrayCreate(1);
      output_state_zp_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_zp,
                                                       output_state_zp_size));
    }

    // Allocate temporary tensors for caching row sums for hybrid zero-point
    // calculations.
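    // Illustrative math: for an input row quantized asymmetrically as x_q
    // with zero point zp, each dot product satisfies
    //   dot(x_q - zp, w_row) = dot(x_q, w_row) - zp * sum(w_row),
    // so caching sum(w_row) per weight-matrix row reduces the zero-point
    // correction to one multiply per row.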
    int fw_row_sums_rows = fw_use_cifg ? 6 : 8;
    if (has_aux_input) {
      fw_row_sums_rows += fw_use_cifg ? 3 : 4;
    }
    const TfLiteTensor* fw_projection_weights =
        GetOptionalInputTensor(context, node, kFwProjectionWeightsTensor);
    if (fw_projection_weights != nullptr) {
      fw_row_sums_rows += ceil(static_cast<float>(n_fw_output) / n_fw_cell);
    }
    node->temporaries->data[kFwRowSums] =
        op_data->scratch_tensor_index + kFwRowSums;
    TfLiteTensor* fw_row_sums;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kFwRowSums, &fw_row_sums));
    fw_row_sums->type = kTfLiteInt32;
    fw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int fw_row_sums_dims[2] = {fw_row_sums_rows, n_fw_cell};
    if (!TfLiteIntArrayEqualsArray(fw_row_sums->dims, 2, fw_row_sums_dims)) {
      TfLiteIntArray* fw_hybrid_scratch_size = TfLiteIntArrayCreate(2);
      fw_hybrid_scratch_size->data[0] = fw_row_sums_dims[0];
      fw_hybrid_scratch_size->data[1] = fw_row_sums_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_row_sums,
                                                       fw_hybrid_scratch_size));
    }

    int bw_row_sums_rows = bw_use_cifg ? 6 : 8;
    if (has_aux_input) {
      bw_row_sums_rows += bw_use_cifg ? 3 : 4;
    }
    const TfLiteTensor* bw_projection_weights =
        GetOptionalInputTensor(context, node, kBwProjectionWeightsTensor);
    if (bw_projection_weights != nullptr) {
      bw_row_sums_rows += ceil(static_cast<float>(n_bw_output) / n_bw_cell);
    }
    node->temporaries->data[kBwRowSums] =
        op_data->scratch_tensor_index + kBwRowSums;
    TfLiteTensor* bw_row_sums;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kBwRowSums, &bw_row_sums));
    bw_row_sums->type = kTfLiteInt32;
    bw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int bw_row_sums_dims[2] = {bw_row_sums_rows, n_bw_cell};
    if (!TfLiteIntArrayEqualsArray(bw_row_sums->dims, 2, bw_row_sums_dims)) {
      TfLiteIntArray* bw_row_sums_size = TfLiteIntArrayCreate(2);
      bw_row_sums_size->data[0] = bw_row_sums_dims[0];
      bw_row_sums_size->data[1] = bw_row_sums_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_row_sums,
                                                       bw_row_sums_size));
    }

    // Only allocate a temporary tensor for quantized auxiliary input if we are
    // actually going to use it.
    if (has_aux_input) {
      node->temporaries->data[kAuxInputQuantized] =
          op_data->scratch_tensor_index + kAuxInputQuantized;
      TfLiteTensor* aux_input_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kAuxInputQuantized,
                                         &aux_input_quantized));
      aux_input_quantized->type = fw_input_to_output_weights->type;
      aux_input_quantized->allocation_type = kTfLiteArenaRw;
      if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
        TfLiteIntArray* aux_input_quantized_size =
            TfLiteIntArrayCopy(aux_input->dims);
        TF_LITE_ENSURE_OK(context,
                          context->ResizeTensor(context, aux_input_quantized,
                                                aux_input_quantized_size));
      }
    }
  }
  return kTfLiteOk;
}

// The LSTM Op engine.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
      node->builtin_data);
  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
  // Input tensor.
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));

  // Tensors for the forward cell.
  const TfLiteTensor* fw_input_to_input_weights =
      GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor);
  const TfLiteTensor* fw_input_to_forget_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwInputToForgetWeightsTensor,
                                 &fw_input_to_forget_weights));
  const TfLiteTensor* fw_input_to_cell_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwInputToCellWeightsTensor,
                                 &fw_input_to_cell_weights));
  const TfLiteTensor* fw_input_to_output_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwInputToOutputWeightsTensor,
                                 &fw_input_to_output_weights));

  const TfLiteTensor* fw_recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kFwRecurrentToInputWeightsTensor);
  const TfLiteTensor* fw_recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kFwRecurrentToForgetWeightsTensor,
                            &fw_recurrent_to_forget_weights));
  const TfLiteTensor* fw_recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwRecurrentToCellWeightsTensor,
                                 &fw_recurrent_to_cell_weights));
  const TfLiteTensor* fw_recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kFwRecurrentToOutputWeightsTensor,
                            &fw_recurrent_to_output_weights));

  const TfLiteTensor* fw_cell_to_input_weights =
      GetOptionalInputTensor(context, node, kFwCellToInputWeightsTensor);
  const TfLiteTensor* fw_cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kFwCellToForgetWeightsTensor);
  const TfLiteTensor* fw_cell_to_output_weights =
      GetOptionalInputTensor(context, node, kFwCellToOutputWeightsTensor);

  const TfLiteTensor* fw_input_gate_bias =
      GetOptionalInputTensor(context, node, kFwInputGateBiasTensor);
  const TfLiteTensor* fw_forget_gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwForgetGateBiasTensor,
                                 &fw_forget_gate_bias));
  const TfLiteTensor* fw_cell_gate_bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwCellGateBiasTensor,
                                          &fw_cell_gate_bias));
  const TfLiteTensor* fw_output_gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwOutputGateBiasTensor,
                                 &fw_output_gate_bias));

  const TfLiteTensor* fw_projection_weights =
      GetOptionalInputTensor(context, node, kFwProjectionWeightsTensor);
  const TfLiteTensor* fw_projection_bias =
      GetOptionalInputTensor(context, node, kFwProjectionBiasTensor);

  TfLiteTensor* fw_activation_state =
      GetVariableInput(context, node, kFwInputActivationStateTensor);
  TFLITE_DCHECK(fw_activation_state != nullptr);
  TfLiteTensor* fw_cell_state =
      GetVariableInput(context, node, kFwInputCellStateTensor);
  TFLITE_DCHECK(fw_cell_state != nullptr);
  TfLiteTensor* fw_output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kFwOutputTensor, &fw_output));

  // Tensors for the backward cell.
  const TfLiteTensor* bw_input_to_input_weights =
      GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor);
  const TfLiteTensor* bw_input_to_forget_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwInputToForgetWeightsTensor,
                                 &bw_input_to_forget_weights));
  const TfLiteTensor* bw_input_to_cell_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwInputToCellWeightsTensor,
                                 &bw_input_to_cell_weights));
  const TfLiteTensor* bw_input_to_output_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwInputToOutputWeightsTensor,
                                 &bw_input_to_output_weights));

  const TfLiteTensor* bw_recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kBwRecurrentToInputWeightsTensor);
  const TfLiteTensor* bw_recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kBwRecurrentToForgetWeightsTensor,
                            &bw_recurrent_to_forget_weights));
  const TfLiteTensor* bw_recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwRecurrentToCellWeightsTensor,
                                 &bw_recurrent_to_cell_weights));
  const TfLiteTensor* bw_recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kBwRecurrentToOutputWeightsTensor,
                            &bw_recurrent_to_output_weights));

  const TfLiteTensor* bw_cell_to_input_weights =
      GetOptionalInputTensor(context, node, kBwCellToInputWeightsTensor);
  const TfLiteTensor* bw_cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kBwCellToForgetWeightsTensor);
  const TfLiteTensor* bw_cell_to_output_weights =
      GetOptionalInputTensor(context, node, kBwCellToOutputWeightsTensor);

  const TfLiteTensor* bw_input_gate_bias =
      GetOptionalInputTensor(context, node, kBwInputGateBiasTensor);
  const TfLiteTensor* bw_forget_gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwForgetGateBiasTensor,
                                 &bw_forget_gate_bias));
  const TfLiteTensor* bw_cell_gate_bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwCellGateBiasTensor,
                                          &bw_cell_gate_bias));
  const TfLiteTensor* bw_output_gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwOutputGateBiasTensor,
                                 &bw_output_gate_bias));

  const TfLiteTensor* bw_projection_weights =
      GetOptionalInputTensor(context, node, kBwProjectionWeightsTensor);
  const TfLiteTensor* bw_projection_bias =
      GetOptionalInputTensor(context, node, kBwProjectionBiasTensor);

1105   // State tensors.
1106   TfLiteTensor* bw_activation_state =
1107       GetVariableInput(context, node, kBwInputActivationStateTensor);
1108   TFLITE_DCHECK(bw_activation_state != nullptr);
1109   TfLiteTensor* bw_cell_state =
1110       GetVariableInput(context, node, kBwInputCellStateTensor);
1111   TFLITE_DCHECK(bw_cell_state != nullptr);
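  // With merged outputs there is no separate backward output tensor; the
  // backward results are written into fw_output at an offset instead.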
  TfLiteTensor* bw_output = params->merge_outputs
                                ? nullptr
                                : GetOutput(context, node, kBwOutputTensor);

  // Temporary tensors.
  TfLiteTensor* fw_scratch_buffer;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kFwScratchBuffer,
                                              &fw_scratch_buffer));
  TfLiteTensor* bw_scratch_buffer;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kBwScratchBuffer,
                                              &bw_scratch_buffer));

  // (Optional) auxiliary inputs.
  const TfLiteTensor* aux_input =
      GetOptionalInputTensor(context, node, kAuxInputTensor);
  const TfLiteTensor* fw_aux_input_to_input_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToInputWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_forget_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToForgetWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_cell_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToCellWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_output_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToOutputWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_input_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToInputWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_forget_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToForgetWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_cell_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToCellWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_output_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor);

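  // aux_input is present when this op is stacked on another bidirectional
  // layer; the aux-input weights are present only when that input is consumed
  // as a cross link (see the case analysis below).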
  const bool has_previous_bw_output = (aux_input != nullptr);
  const bool use_aux_input = (fw_aux_input_to_forget_weights != nullptr);

  // Populate a TfLiteLSTMParams struct for the evaluation functions.
  TfLiteLSTMParams lstm_params = {params->activation, params->cell_clip,
                                  params->proj_clip, kTfLiteLSTMFullKernel,
                                  params->asymmetric_quantize_inputs};

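  // In merged-output mode the backward results start at column n_output,
  // i.e. the second dimension of the forward recurrent-to-output weights.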
  const int bw_output_offset =
      params->merge_outputs ? fw_recurrent_to_output_weights->dims->data[1] : 0;
  const auto actual_bw_output = params->merge_outputs ? fw_output : bw_output;

  const bool time_major = params->time_major;

  // We want to cover the following cases:
  //
  // If not stacking (not connected after other bidi lstms):
  //   both fw & bw will just use `input`; aux_input will be null.
  //
  // If stacking with cross_links, TensorFlow equivalent
  // (tf.contrib.rnn.stack_bidirectional_rnn):
  //   both fw & bw will use `input`, but aux_input will be non-null.
  //   Note that in this case, it works whether or not this op is connected
  //   after other bidi lstms.
  //
  // If stacking without cross_links, but connected after other bidi lstms,
  // TensorFlow equivalent (tf.nn.static_bidirectional_rnn):
  //   fw will use `input`, bw will use aux_input, and the `real aux_input`
  //   will be null.

  const bool non_stacking_mode = !use_aux_input && has_previous_bw_output;
  const TfLiteTensor* bw_input = non_stacking_mode ? aux_input : input;
  const TfLiteTensor* real_aux_input = non_stacking_mode ? nullptr : aux_input;

  switch (fw_input_to_output_weights->type) {
    case kTfLiteFloat32: {
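      // Full-precision path: evaluate the forward cell over the sequence,
      // then the backward cell, sharing the unidirectional LSTM kernel.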
      TfLiteStatus fw_pass_status = lstm_eval::EvalFloat(
          input, fw_input_to_input_weights, fw_input_to_forget_weights,
          fw_input_to_cell_weights, fw_input_to_output_weights,
          fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
          fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
          fw_cell_to_input_weights, fw_cell_to_forget_weights,
          fw_cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr, real_aux_input,
          fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights,
          fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights,
          fw_input_gate_bias, fw_forget_gate_bias, fw_cell_gate_bias,
          fw_output_gate_bias, fw_projection_weights, fw_projection_bias,
          &lstm_params,
          /*forward_sequence=*/true, time_major, /*output_offset=*/0,
          fw_scratch_buffer, fw_activation_state, fw_cell_state, fw_output,
          CpuBackendContext::GetFromContext(context));
      TF_LITE_ENSURE_OK(context, fw_pass_status);

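      // The backward pass walks the sequence in reverse
      // (/*forward_sequence=*/false); with merged outputs it writes into
      // fw_output at bw_output_offset.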
      TfLiteStatus bw_pass_status = lstm_eval::EvalFloat(
          bw_input, bw_input_to_input_weights, bw_input_to_forget_weights,
          bw_input_to_cell_weights, bw_input_to_output_weights,
          bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
          bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
          bw_cell_to_input_weights, bw_cell_to_forget_weights,
          bw_cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr, real_aux_input,
          bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights,
          bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights,
          bw_input_gate_bias, bw_forget_gate_bias, bw_cell_gate_bias,
          bw_output_gate_bias, bw_projection_weights, bw_projection_bias,
          &lstm_params,
          /*forward_sequence=*/false, time_major, bw_output_offset,
          bw_scratch_buffer, bw_activation_state, bw_cell_state,
          actual_bw_output, CpuBackendContext::GetFromContext(context));
      TF_LITE_ENSURE_OK(context, bw_pass_status);
      return kTfLiteOk;
    }
    case kTfLiteUInt8:
    case kTfLiteInt8: {
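      // Hybrid path: weights are quantized to 8 bits while activations stay
      // in float; fetch the quantization temporaries first.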
      TfLiteTensor* input_quantized;
      TF_LITE_ENSURE_OK(
          context,
          GetTemporarySafe(context, node, kInputQuantized, &input_quantized));
      TfLiteTensor* fw_activation_state_quantized;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kFwActivationStateQuantized,
                                    &fw_activation_state_quantized));
      TfLiteTensor* bw_activation_state_quantized;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kBwActivationStateQuantized,
                                    &bw_activation_state_quantized));
      TfLiteTensor* fw_cell_state_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kFwCellStateQuantized,
                                         &fw_cell_state_quantized));
      TfLiteTensor* bw_cell_state_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kBwCellStateQuantized,
                                         &bw_cell_state_quantized));
      TfLiteTensor* prod_scaling_factors;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kProductScalingFactors,
                                         &prod_scaling_factors));
      TfLiteTensor* recovered_cell_weights;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kRecoveredCellWeights,
                                         &recovered_cell_weights));
      TfLiteTensor* aux_input_quantized =
          use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
                        : nullptr;
      TfLiteTensor* accum_scratch;
      TF_LITE_ENSURE_OK(
          context,
          GetTemporarySafe(context, node, kAccumScratchBuffer, &accum_scratch));
      TfLiteTensor* fw_row_sums;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kFwRowSums, &fw_row_sums));
      TfLiteTensor* bw_row_sums;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kBwRowSums, &bw_row_sums));
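      // Row sums are cached to speed up asymmetric quantization; the
      // compute_*_row_sums flags in op_data track when they need recomputing.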
      const int fw_row_sums_size = fw_row_sums->dims->data[0];
      const int bw_row_sums_size = bw_row_sums->dims->data[0];
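      // The *_ledger arguments below accompany sparse weight formats and are
      // unused (nullptr) for this dense op.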
      TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(
          input, fw_input_to_input_weights,
          /*input_to_input_weights_ledger*/ nullptr, fw_input_to_forget_weights,
          /*input_to_forget_weights_ledger*/ nullptr, fw_input_to_cell_weights,
          /*input_to_cell_weights_ledger*/ nullptr, fw_input_to_output_weights,
          /*input_to_output_weights_ledger*/ nullptr,
          fw_recurrent_to_input_weights,
          /*recurrent_to_input_weights_ledger*/ nullptr,
          fw_recurrent_to_forget_weights,
          /*recurrent_to_forget_weights_ledger*/ nullptr,
          fw_recurrent_to_cell_weights,
          /*recurrent_to_cell_weights_ledger*/ nullptr,
          fw_recurrent_to_output_weights,
          /*recurrent_to_output_weights_ledger*/ nullptr,
          fw_cell_to_input_weights, fw_cell_to_forget_weights,
          fw_cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr, real_aux_input,
          fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights,
          fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights,
          fw_input_gate_bias, fw_forget_gate_bias, fw_cell_gate_bias,
          fw_output_gate_bias, fw_projection_weights,
          /*projection_weights_ledger*/ nullptr, fw_projection_bias,
          &lstm_params,
          /*forward_sequence=*/true, time_major, /*output_offset=*/0,
          fw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors),
          GetTemporary(context, node, kAuxInputScalingFactors),
          GetTemporary(context, node, kOutputStateScalingFactors),
          prod_scaling_factors, recovered_cell_weights, input_quantized,
          aux_input_quantized, fw_activation_state_quantized,
          fw_cell_state_quantized, fw_activation_state, fw_cell_state,
          accum_scratch, fw_output,
          GetTemporary(context, node, kInputZeroPoints),
          GetTemporary(context, node, kAuxInputZeroPoints),
          GetTemporary(context, node, kOutputStateZeroPoints), fw_row_sums,
          fw_row_sums_size, &op_data->compute_fw_row_sums,
          CpuBackendContext::GetFromContext(context));
      TF_LITE_ENSURE_OK(context, fw_pass_status);

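      // As in the float path, the backward pass runs in reverse and writes at
      // bw_output_offset when outputs are merged.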
      TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid(
          bw_input, bw_input_to_input_weights,
          /*input_to_input_weights_ledger*/ nullptr, bw_input_to_forget_weights,
          /*input_to_forget_weights_ledger*/ nullptr, bw_input_to_cell_weights,
          /*input_to_cell_weights_ledger*/ nullptr, bw_input_to_output_weights,
          /*input_to_output_weights_ledger*/ nullptr,
          bw_recurrent_to_input_weights,
          /*recurrent_to_input_weights_ledger*/ nullptr,
          bw_recurrent_to_forget_weights,
          /*recurrent_to_forget_weights_ledger*/ nullptr,
          bw_recurrent_to_cell_weights,
          /*recurrent_to_cell_weights_ledger*/ nullptr,
          bw_recurrent_to_output_weights,
          /*recurrent_to_output_weights_ledger*/ nullptr,
          bw_cell_to_input_weights, bw_cell_to_forget_weights,
          bw_cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr, real_aux_input,
          bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights,
          bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights,
          bw_input_gate_bias, bw_forget_gate_bias, bw_cell_gate_bias,
          bw_output_gate_bias, bw_projection_weights,
          /*projection_weights_ledger*/ nullptr, bw_projection_bias,
          &lstm_params,
          /*forward_sequence=*/false, time_major, bw_output_offset,
          bw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors),
          GetTemporary(context, node, kAuxInputScalingFactors),
          GetTemporary(context, node, kOutputStateScalingFactors),
          prod_scaling_factors, recovered_cell_weights, input_quantized,
          aux_input_quantized, bw_activation_state_quantized,
          bw_cell_state_quantized, bw_activation_state, bw_cell_state,
          accum_scratch, actual_bw_output,
          GetTemporary(context, node, kInputZeroPoints),
          GetTemporary(context, node, kAuxInputZeroPoints),
          GetTemporary(context, node, kOutputStateZeroPoints), bw_row_sums,
          bw_row_sums_size, &op_data->compute_bw_row_sums,
          CpuBackendContext::GetFromContext(context));
      TF_LITE_ENSURE_OK(context, bw_pass_status);
      return kTfLiteOk;
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s is not currently supported.",
                         TfLiteTypeGetName(fw_input_to_output_weights->type));
      return kTfLiteError;
  }
}

}  // namespace bidirectional_sequence_lstm

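// Registration: exposes this kernel's Init/Free/Prepare/Eval lifecycle
// callbacks to the TFLite interpreter.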
TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM() {
  static TfLiteRegistration r = {
      bidirectional_sequence_lstm::Init, bidirectional_sequence_lstm::Free,
      bidirectional_sequence_lstm::Prepare, bidirectional_sequence_lstm::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite