xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <vector>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/lstm_eval.h"
#include "tensorflow/lite/kernels/lstm_shared.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace unidirectional_sequence_lstm {
namespace {

struct OpData {
  // Whether the LSTM uses layer normalization.
  bool use_layer_norm;
  // The scratch tensor index.
  int scratch_tensor_index;
  bool compute_row_sums = false;

  lstm_eval::IntegerLstmParameter integer_lstm_param;
};

TfLiteStatus PopulateQuantizedLstmParams8x8_16(
    TfLiteContext* context, TfLiteNode* node,
    lstm_eval::IntegerLstmParameter* integer_lstm_param) {
  // Calculate quantized clip for projection and cell.
  const auto* params =
      static_cast<TfLiteUnidirectionalSequenceLSTMParams*>(node->builtin_data);
  const float cell_clip = params->cell_clip;
  const float proj_clip = params->proj_clip;

  const TfLiteTensor* cell_state =
      GetVariableInput(context, node, lstm::full::kCellStateTensor);
  TF_LITE_ENSURE(context, cell_state != nullptr);
  TfLiteTensor* output_tensor;
  TF_LITE_ENSURE_OK(
      context,
      GetOutputSafe(context, node, lstm::full::kOutputTensor, &output_tensor));

  TF_LITE_ENSURE(context,
                 cell_state->quantization.type != kTfLiteNoQuantization);
  auto* cell_state_params =
      static_cast<TfLiteAffineQuantization*>(cell_state->quantization.params);
  TF_LITE_ENSURE(context,
                 output_tensor->quantization.type != kTfLiteNoQuantization);
  auto* proj_params = static_cast<TfLiteAffineQuantization*>(
      output_tensor->quantization.params);
  if (cell_clip > 0.0) {
    integer_lstm_param->quantized_cell_clip = static_cast<int16_t>(std::min(
        std::max(cell_clip / cell_state_params->scale->data[0], -32768.0f),
        32767.0f));
  } else {
    integer_lstm_param->quantized_cell_clip = 0;
  }
  if (proj_clip > 0.0) {
    integer_lstm_param->quantized_proj_clip = static_cast<int8_t>(std::min(
        std::max(proj_clip / proj_params->scale->data[0], -128.0f), 127.0f));
  } else {
    integer_lstm_param->quantized_proj_clip = 0;
  }
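  // For example, with cell_clip = 8.0 and a cell-state scale of 2^-11, the
  // quantized clip is min(max(8.0 / 2^-11, -32768), 32767) = 16384.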

  // Calculate effective scales.
  OpData* op_data = static_cast<OpData*>(node->user_data);
  const bool use_layer_norm = op_data->use_layer_norm;

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, lstm::full::kInputTensor, &input));

  const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kInputToInputWeightsTensor);
  const TfLiteTensor* input_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor,
                   &input_to_forget_weights));
  const TfLiteTensor* input_to_cell_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node,
                                          lstm::full::kInputToCellWeightsTensor,
                                          &input_to_cell_weights));
  const TfLiteTensor* input_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor,
                   &input_to_output_weights));

  const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kRecurrentToInputWeightsTensor);
  const TfLiteTensor* recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor,
                   &recurrent_to_forget_weights));
  const TfLiteTensor* recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor,
                   &recurrent_to_cell_weights));
  const TfLiteTensor* recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor,
                   &recurrent_to_output_weights));

  const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kCellToInputWeightsTensor);
  const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
      context, node, lstm::full::kCellToForgetWeightsTensor);
  const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
      context, node, lstm::full::kCellToOutputWeightsTensor);

  const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
      context, node, lstm::full::kInputLayerNormCoefficientsTensor);
  const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
      context, node, lstm::full::kForgetLayerNormCoefficientsTensor);
  const TfLiteTensor* cell_layer_norm_coefficients = GetOptionalInputTensor(
      context, node, lstm::full::kCellLayerNormCoefficientsTensor);
  const TfLiteTensor* output_layer_norm_coefficients = GetOptionalInputTensor(
      context, node, lstm::full::kOutputLayerNormCoefficientsTensor);

  const TfLiteTensor* projection_weights = GetOptionalInputTensor(
      context, node, lstm::full::kProjectionWeightsTensor);

  TfLiteTensor* output_state =
      GetVariableInput(context, node, lstm::full::kOutputStateTensor);
  TF_LITE_ENSURE(context, output_state != nullptr);

  // Since we have already checked that weights are all there or none, we can
  // check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool use_peephole = (cell_to_output_weights != nullptr);
  const bool use_projection = (projection_weights != nullptr);

  // Get intermediate scales and zero points.
  std::vector<float> intermediate_scale;
  std::vector<int32> intermediate_zp;
  for (int i = 0; i < 4; ++i) {
    if (use_layer_norm) {
      TfLiteTensor* intermediate;
      TF_LITE_ENSURE_OK(context,
                        GetIntermediatesSafe(context, node, i, &intermediate));
      TF_LITE_ENSURE(context,
                     intermediate->quantization.type != kTfLiteNoQuantization);
      auto* params = static_cast<TfLiteAffineQuantization*>(
          intermediate->quantization.params);
      intermediate_scale.push_back(params->scale->data[0]);
      intermediate_zp.push_back(params->zero_point->data[0]);
    } else {
      // Q3.12 for activation functions.
      intermediate_scale.push_back(std::pow(2, -12));
      intermediate_zp.push_back(0);
    }
  }
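  // Q3.12 gives the activation-function inputs a scale of 2^-12, i.e. a
  // representable int16 range of [-8, 8) at a resolution of about 0.000244.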
  // In the absence of projection, hidden becomes the output and this
  // intermediate is ignored.
  TfLiteTensor* hidden;
  TF_LITE_ENSURE_OK(context, GetIntermediatesSafe(context, node, 4, &hidden));
  TF_LITE_ENSURE(context, hidden->quantization.type != kTfLiteNoQuantization);
  auto* hidden_params =
      static_cast<TfLiteAffineQuantization*>(hidden->quantization.params);
  intermediate_scale.push_back(hidden_params->scale->data[0]);
  intermediate_zp.push_back(hidden_params->zero_point->data[0]);

  // Scales.
  const float default_scale = 1.0;
  float input_scale = default_scale;
  float input_to_input_weight_scale = default_scale;
  float recurrent_to_input_weight_scale = default_scale;
  float cell_to_input_weight_scale = default_scale;
  float input_to_forget_weight_scale = default_scale;
  float recurrent_to_forget_weight_scale = default_scale;
  float cell_to_forget_weight_scale = default_scale;
  float input_to_cell_weight_scale = default_scale;
  float recurrent_to_cell_weight_scale = default_scale;
  float input_to_output_weight_scale = default_scale;
  float recurrent_to_output_weight_scale = default_scale;
  float cell_to_output_weight_scale = default_scale;
  float projection_weight_scale = default_scale;
  float layer_norm_input_scale = default_scale;
  float layer_norm_forget_scale = default_scale;
  float layer_norm_cell_scale = default_scale;
  float layer_norm_output_scale = default_scale;
  float output_state_scale = default_scale;
  int cell_scale = 1;

  // Effective scales.
  float effective_input_to_input_scale = default_scale;
  float effective_recurrent_to_input_scale = default_scale;
  float effective_cell_to_input_scale = default_scale;
  float effective_input_to_forget_scale = default_scale;
  float effective_recurrent_to_forget_scale = default_scale;
  float effective_cell_to_forget_scale = default_scale;
  float effective_input_to_cell_scale = default_scale;
  float effective_recurrent_to_cell_scale = default_scale;
  float effective_input_to_output_scale = default_scale;
  float effective_recurrent_to_output_scale = default_scale;
  float effective_cell_to_output_scale = default_scale;
  float effective_proj_scale = default_scale;
  float effective_hidden_scale = default_scale;

  // Populate scales.
  if (!use_cifg) {
    input_to_input_weight_scale = input_to_input_weights->params.scale;
    recurrent_to_input_weight_scale = recurrent_to_input_weights->params.scale;
  }

  if (use_peephole) {
    if (!use_cifg) {
      cell_to_input_weight_scale = cell_to_input_weights->params.scale;
    }
    cell_to_forget_weight_scale = cell_to_forget_weights->params.scale;
    cell_to_output_weight_scale = cell_to_output_weights->params.scale;
  }

  if (use_layer_norm) {
    if (!use_cifg) {
      layer_norm_input_scale = input_layer_norm_coefficients->params.scale;
    }
    layer_norm_forget_scale = forget_layer_norm_coefficients->params.scale;
    layer_norm_cell_scale = cell_layer_norm_coefficients->params.scale;
    layer_norm_output_scale = output_layer_norm_coefficients->params.scale;
  }

  if (use_projection) {
    projection_weight_scale = projection_weights->params.scale;
  }
  output_state_scale = output_state->params.scale;

  input_to_forget_weight_scale = input_to_forget_weights->params.scale;
  input_to_cell_weight_scale = input_to_cell_weights->params.scale;
  input_to_output_weight_scale = input_to_output_weights->params.scale;
  recurrent_to_forget_weight_scale = recurrent_to_forget_weights->params.scale;
  recurrent_to_cell_weight_scale = recurrent_to_cell_weights->params.scale;
  recurrent_to_output_weight_scale = recurrent_to_output_weights->params.scale;

  // Check cell state (already used above).
  TF_LITE_ENSURE(context, CheckedLog2(cell_state->params.scale, &cell_scale));
  // TF_LITE_ENSURE(context, cell_scale <= -9);
  integer_lstm_param->cell_scale = cell_scale;
  input_scale = input->params.scale;

  // Calculate effective scales.
  if (!use_cifg) {
    effective_input_to_input_scale =
        input_to_input_weight_scale * input_scale / intermediate_scale[0];
    effective_recurrent_to_input_scale = recurrent_to_input_weight_scale *
                                         output_state_scale /
                                         intermediate_scale[0];
  }
  effective_input_to_forget_scale =
      input_to_forget_weight_scale * input_scale / intermediate_scale[1];
  effective_recurrent_to_forget_scale = recurrent_to_forget_weight_scale *
                                        output_state_scale /
                                        intermediate_scale[1];

  effective_input_to_cell_scale =
      input_to_cell_weight_scale * input_scale / intermediate_scale[2];
  effective_recurrent_to_cell_scale = recurrent_to_cell_weight_scale *
                                      output_state_scale /
                                      intermediate_scale[2];

  effective_input_to_output_scale =
      input_to_output_weight_scale * input_scale / intermediate_scale[3];
  effective_recurrent_to_output_scale = recurrent_to_output_weight_scale *
                                        output_state_scale /
                                        intermediate_scale[3];

  effective_hidden_scale =
      std::pow(2, -15) / intermediate_scale[4] * std::pow(2, -15);
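  // The output gate and the cell tanh are both Q0.15, so their product
  // carries a 2^-30 scale that is requantized to the hidden scale; hence the
  // two 2^-15 factors above.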

  effective_proj_scale =
      projection_weight_scale * intermediate_scale[4] / output_state_scale;

  if (use_peephole) {
    if (!use_cifg) {
      effective_cell_to_input_scale = std::pow(2, cell_scale) *  // NOLINT
                                      cell_to_input_weight_scale /
                                      intermediate_scale[0];
    }
    effective_cell_to_forget_scale = std::pow(2, cell_scale) *  // NOLINT
                                     cell_to_forget_weight_scale /
                                     intermediate_scale[1];
    effective_cell_to_output_scale = std::pow(2, cell_scale) *  // NOLINT
                                     cell_to_output_weight_scale /
                                     intermediate_scale[3];
  }

  // Decompose scales.
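  // Each real-valued scale is decomposed into a quantized multiplier and a
  // shift; roughly, scale ~= (scale_a / 2^31) * 2^scale_b.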
  QuantizeMultiplier(effective_input_to_input_scale,
                     &integer_lstm_param->effective_input_to_input_scale_a,
                     &integer_lstm_param->effective_input_to_input_scale_b);
  QuantizeMultiplier(effective_recurrent_to_input_scale,
                     &integer_lstm_param->effective_recurrent_to_input_scale_a,
                     &integer_lstm_param->effective_recurrent_to_input_scale_b);
  QuantizeMultiplier(effective_cell_to_input_scale,
                     &integer_lstm_param->effective_cell_to_input_scale_a,
                     &integer_lstm_param->effective_cell_to_input_scale_b);
  QuantizeMultiplier(effective_input_to_forget_scale,
                     &integer_lstm_param->effective_input_to_forget_scale_a,
                     &integer_lstm_param->effective_input_to_forget_scale_b);
  QuantizeMultiplier(
      effective_recurrent_to_forget_scale,
      &integer_lstm_param->effective_recurrent_to_forget_scale_a,
      &integer_lstm_param->effective_recurrent_to_forget_scale_b);
  QuantizeMultiplier(effective_cell_to_forget_scale,
                     &integer_lstm_param->effective_cell_to_forget_scale_a,
                     &integer_lstm_param->effective_cell_to_forget_scale_b);
  QuantizeMultiplier(effective_input_to_cell_scale,
                     &integer_lstm_param->effective_input_to_cell_scale_a,
                     &integer_lstm_param->effective_input_to_cell_scale_b);
  QuantizeMultiplier(effective_recurrent_to_cell_scale,
                     &integer_lstm_param->effective_recurrent_to_cell_scale_a,
                     &integer_lstm_param->effective_recurrent_to_cell_scale_b);
  QuantizeMultiplier(effective_input_to_output_scale,
                     &integer_lstm_param->effective_input_to_output_scale_a,
                     &integer_lstm_param->effective_input_to_output_scale_b);
  QuantizeMultiplier(
      effective_recurrent_to_output_scale,
      &integer_lstm_param->effective_recurrent_to_output_scale_a,
      &integer_lstm_param->effective_recurrent_to_output_scale_b);
  QuantizeMultiplier(effective_cell_to_output_scale,
                     &integer_lstm_param->effective_cell_to_output_scale_a,
                     &integer_lstm_param->effective_cell_to_output_scale_b);
  QuantizeMultiplier(effective_proj_scale,
                     &integer_lstm_param->effective_proj_scale_a,
                     &integer_lstm_param->effective_proj_scale_b);
  QuantizeMultiplier(effective_hidden_scale,
                     &integer_lstm_param->effective_hidden_scale_a,
                     &integer_lstm_param->effective_hidden_scale_b);
  QuantizeMultiplier(layer_norm_input_scale,
                     &integer_lstm_param->layer_norm_input_scale_a,
                     &integer_lstm_param->layer_norm_input_scale_b);
  QuantizeMultiplier(layer_norm_forget_scale,
                     &integer_lstm_param->layer_norm_forget_scale_a,
                     &integer_lstm_param->layer_norm_forget_scale_b);
  QuantizeMultiplier(layer_norm_cell_scale,
                     &integer_lstm_param->layer_norm_cell_scale_a,
                     &integer_lstm_param->layer_norm_cell_scale_b);
  QuantizeMultiplier(layer_norm_output_scale,
                     &integer_lstm_param->layer_norm_output_scale_a,
                     &integer_lstm_param->layer_norm_output_scale_b);

  integer_lstm_param->hidden_zp = intermediate_zp[4];

  // 10000 is used to make sure the kernel logic does not overflow.
  if (!use_cifg) {
    integer_lstm_param->input_variance_guard =
        std::max(1, static_cast<int32_t>(10000 * layer_norm_input_scale));
  }
  integer_lstm_param->forget_variance_guard =
      std::max(1, static_cast<int32_t>(10000 * layer_norm_forget_scale));
  integer_lstm_param->cell_variance_guard =
      std::max(1, static_cast<int32_t>(10000 * layer_norm_cell_scale));
  integer_lstm_param->output_variance_guard =
      std::max(1, static_cast<int32_t>(10000 * layer_norm_output_scale));
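  // For example, a layer-norm scale of 0.00025 yields a guard of
  // max(1, static_cast<int32_t>(2.5)) = 2.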

  return kTfLiteOk;
}

}  // namespace

// Temporary tensors
enum TemporaryTensor {
  kScratchBuffer = 0,
  kInputQuantized = 1,
  kOutputStateQuantized = 2,
  kCellStateQuantized = 3,
  kInputScalingFactors = 4,
  kOutputStateScalingFactors = 5,
  kProductScalingFactors = 6,
  kRecoveredCellWeights = 7,
  kAccumScratch = 8,
  kInputZeroPoints = 9,
  kOutputStateZeroPoints = 10,
  kRowSums = 11,
  kNumTemporaryTensors = 12,
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  context->AddTensors(context, kNumTemporaryTensors,
                      &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

// Check that the input tensor dimensions match each other.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                                        TfLiteNode* node, int n_input,
                                        int n_output, int n_cell,
                                        bool use_layer_norm, bool is_integer) {
  const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);

  // Making sure clipping parameters have valid values.
  // == 0 means no clipping
  //  > 0 means clipping
  TF_LITE_ENSURE(context, params->cell_clip >= 0);
  TF_LITE_ENSURE(context, params->proj_clip >= 0);

  const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kInputToInputWeightsTensor);
  if (input_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
  }

  const TfLiteTensor* input_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor,
                   &input_to_forget_weights));
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);

  const TfLiteTensor* input_to_cell_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node,
                                          lstm::full::kInputToCellWeightsTensor,
                                          &input_to_cell_weights));
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);

  const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kRecurrentToInputWeightsTensor);
  if (recurrent_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
                      n_cell);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
                      n_output);
  }

  const TfLiteTensor* recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor,
                   &recurrent_to_forget_weights));
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
                    n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
                    n_output);

  const TfLiteTensor* recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor,
                   &recurrent_to_cell_weights));
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
                    n_output);

  // We make sure the input-gate's parameters are either both present (regular
  // LSTM) or not at all (CIFG-LSTM).
  const bool cifg_weights_all_or_none =
      ((input_to_input_weights != nullptr) &&
       (recurrent_to_input_weights != nullptr)) ||
      ((input_to_input_weights == nullptr) &&
       (recurrent_to_input_weights == nullptr));
  TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);

  const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kCellToInputWeightsTensor);
  if (cell_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(
        context, cell_to_input_weights->type,
        is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
  }

  const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
      context, node, lstm::full::kCellToForgetWeightsTensor);
  if (cell_to_forget_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(
        context, cell_to_forget_weights->type,
        is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
  }

  const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
      context, node, lstm::full::kCellToOutputWeightsTensor);
  if (cell_to_output_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(
        context, cell_to_output_weights->type,
        is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
  }

  // Make sure the peephole weights are either all present or all absent.
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool peephole_weights_all_or_none =
      ((cell_to_input_weights != nullptr || use_cifg) &&
       (cell_to_forget_weights != nullptr) &&
       (cell_to_output_weights != nullptr)) ||
      ((cell_to_input_weights == nullptr) &&
       (cell_to_forget_weights == nullptr) &&
       (cell_to_output_weights == nullptr));
  TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);

  // Make sure the input gate bias is present only when not a CIFG-LSTM.
  const TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, lstm::full::kInputGateBiasTensor);
  if (use_cifg) {
    TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
  } else {
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
    if (is_integer) {
      TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteInt32);
    } else {
      TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteFloat32);
    }
  }

  const TfLiteTensor* forget_gate_bias;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, lstm::full::kForgetGateBiasTensor,
                            &forget_gate_bias));
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
  if (is_integer) {
    TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteInt32);
  } else {
    TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32);
  }

  const TfLiteTensor* cell_gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, lstm::full::kCellGateBiasTensor,
                                 &cell_gate_bias));
  TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);
  if (is_integer) {
    TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteInt32);
  } else {
    TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteFloat32);
  }

  const TfLiteTensor* output_gate_bias;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, lstm::full::kOutputGateBiasTensor,
                            &output_gate_bias));
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
  if (is_integer) {
    TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteInt32);
  } else {
    TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteFloat32);
  }

  const TfLiteTensor* projection_weights = GetOptionalInputTensor(
      context, node, lstm::full::kProjectionWeightsTensor);
  if (projection_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
  }

  const TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor);
  if (projection_bias != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
    if (is_integer) {
      TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteInt32);
    } else {
      TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteFloat32);
    }
  }

  // Making sure the projection tensors are consistent:
  // 1) If projection weight is not present, then projection bias should not be
  // present.
  // 2) If projection weight is present, then projection bias is optional.
  // TODO(ghodrat): make sure this is correct.
  const bool projection_tensors_consistent =
      ((projection_weights != nullptr) || (projection_bias == nullptr));
  TF_LITE_ENSURE(context, projection_tensors_consistent == true);

  if (use_layer_norm) {
    const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
        context, node, lstm::full::kInputLayerNormCoefficientsTensor);
    if (use_cifg) {
      TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr);
    } else {
      TF_LITE_ENSURE(context, input_layer_norm_coefficients != nullptr);
      TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1);
      TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->data[0],
                        n_cell);
      if (is_integer) {
        TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type,
                                kTfLiteInt16);
      } else {
        TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type,
                                kTfLiteFloat32);
      }
    }

    const TfLiteTensor* forget_layer_norm_coefficients;
    TF_LITE_ENSURE_OK(
        context, GetInputSafe(context, node,
                              lstm::full::kForgetLayerNormCoefficientsTensor,
                              &forget_layer_norm_coefficients));
    TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0],
                      n_cell);
    if (is_integer) {
      TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type,
                              kTfLiteInt16);
    } else {
      TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type,
                              kTfLiteFloat32);
    }

    const TfLiteTensor* cell_layer_norm_coefficients;
    TF_LITE_ENSURE_OK(context,
                      GetInputSafe(context, node,
                                   lstm::full::kCellLayerNormCoefficientsTensor,
                                   &cell_layer_norm_coefficients));
    TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0],
                      n_cell);
    if (is_integer) {
      TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type,
                              kTfLiteInt16);
    } else {
      TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type,
                              kTfLiteFloat32);
    }

    const TfLiteTensor* output_layer_norm_coefficients;
    TF_LITE_ENSURE_OK(
        context, GetInputSafe(context, node,
                              lstm::full::kOutputLayerNormCoefficientsTensor,
                              &output_layer_norm_coefficients));
    TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0],
                      n_cell);
    if (is_integer) {
      TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type,
                              kTfLiteInt16);
    } else {
      TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type,
                              kTfLiteFloat32);
    }
  }

  return kTfLiteOk;
}

TfLiteStatus PrecomputeZeroPointTimesWeightWithBias(
    TfLiteContext* context, int32_t zero_point,
    const TfLiteTensor* weight_tensor, const TfLiteTensor* bias_tensor,
    std::unique_ptr<int32_t[]>* output) {
  if (weight_tensor == nullptr) {
    return kTfLiteOk;
  }

  const RuntimeShape& weight_shape = GetTensorShape(weight_tensor);
  TF_LITE_ENSURE_EQ(context, weight_shape.DimensionsCount(), 2);
  const int row = weight_shape.Dims(0);
  const int col = weight_shape.Dims(1);
  output->reset(new int32_t[row]);
  if (bias_tensor == nullptr) {
    memset(output->get(), 0, row * sizeof(int32_t));
  } else {
    const int32_t* bias = GetTensorData<int32_t>(bias_tensor);
    memcpy(output->get(), bias, row * sizeof(int32_t));
  }
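  // Folding the offset into the bias relies on the identity
  //   (x - x_zp) * W = x * W + (-x_zp) * rowsum(W);
  // callers pass the negated zero point, so accumulating zero_point *
  // rowsum(W) below yields bias - x_zp * rowsum(W).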
  if (zero_point != 0) {
    const int8_t* weight = GetTensorData<int8_t>(weight_tensor);
    tensor_utils::MatrixScalarMultiplyAccumulate(weight, zero_point, row, col,
                                                 output->get());
  }
  return kTfLiteOk;
}

TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
                                                       OpData* op_data,
                                                       TfLiteNode* node) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, lstm::full::kInputTensor, &input));
  const TfLiteTensor* output_state =
      GetVariableInput(context, node, lstm::full::kOutputStateTensor);
  TF_LITE_ENSURE(context, output_state != nullptr);

  const int32_t input_zero_point = -input->params.zero_point;
  const int32_t output_state_zero_point = -output_state->params.zero_point;

  const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kInputToInputWeightsTensor);
  const TfLiteTensor* input_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor,
                   &input_to_forget_weights));
  const TfLiteTensor* input_to_cell_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node,
                                          lstm::full::kInputToCellWeightsTensor,
                                          &input_to_cell_weights));
  const TfLiteTensor* input_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor,
                   &input_to_output_weights));

  const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kRecurrentToInputWeightsTensor);
  const TfLiteTensor* recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor,
                   &recurrent_to_forget_weights));
  const TfLiteTensor* recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor,
                   &recurrent_to_cell_weights));
  const TfLiteTensor* recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor,
                   &recurrent_to_output_weights));

  const TfLiteTensor* projection_weights = GetOptionalInputTensor(
      context, node, lstm::full::kProjectionWeightsTensor);
  const TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor);

  lstm_eval::IntegerLstmParameter* integer_lstm_params =
      &op_data->integer_lstm_param;

  const TfLiteTensor* intermediate =
      &context->tensors[node->intermediates->data[4]];
  TF_LITE_ENSURE(context,
                 intermediate->quantization.type != kTfLiteNoQuantization);
  const auto* params =
      static_cast<TfLiteAffineQuantization*>(intermediate->quantization.params);
  const int32_t hidden_zp = params->zero_point->data[0];

  // Get the gate biases and perform the zero-point calculation.
  // When there is layer normalization, the gate bias does not apply to the
  // matmul directly:
  //      y = ln(w * x + w * r + w * c) + b,
  // so the bias is applied after normalization and must not be folded into
  // the precomputed products; nullptr is passed below in that case.
  const bool is_layer_norm = op_data->use_layer_norm;

  // Forget gate.
  const TfLiteTensor* forget_gate_bias =
      is_layer_norm
          ? nullptr
          : GetInput(context, node, lstm::full::kForgetGateBiasTensor);
  TF_LITE_ENSURE_OK(
      context,
      PrecomputeZeroPointTimesWeightWithBias(
          context, input_zero_point, input_to_forget_weights, forget_gate_bias,
          &(integer_lstm_params->input_to_forget_effective_bias)));

  TF_LITE_ENSURE_OK(
      context,
      PrecomputeZeroPointTimesWeightWithBias(
          context, output_state_zero_point, recurrent_to_forget_weights,
          nullptr, &(integer_lstm_params->recurrent_to_forget_effective_bias)));

  // Modulation gate.
  const TfLiteTensor* cell_gate_bias =
      is_layer_norm ? nullptr
                    : GetInput(context, node, lstm::full::kCellGateBiasTensor);
  TF_LITE_ENSURE_OK(
      context,
      PrecomputeZeroPointTimesWeightWithBias(
          context, input_zero_point, input_to_cell_weights, cell_gate_bias,
          &(integer_lstm_params->input_to_cell_effective_bias)));
  TF_LITE_ENSURE_OK(
      context,
      PrecomputeZeroPointTimesWeightWithBias(
          context, output_state_zero_point, recurrent_to_cell_weights, nullptr,
          &(integer_lstm_params->recurrent_to_cell_effective_bias)));

  // Output gate.
  const TfLiteTensor* output_gate_bias =
      is_layer_norm
          ? nullptr
          : GetInput(context, node, lstm::full::kOutputGateBiasTensor);
  TF_LITE_ENSURE_OK(
      context,
      PrecomputeZeroPointTimesWeightWithBias(
          context, input_zero_point, input_to_output_weights, output_gate_bias,
          &(integer_lstm_params->input_to_output_effective_bias)));

  TF_LITE_ENSURE_OK(
      context,
      PrecomputeZeroPointTimesWeightWithBias(
          context, output_state_zero_point, recurrent_to_output_weights,
          nullptr, &(integer_lstm_params->recurrent_to_output_effective_bias)));

  // Input gate. The calculation is only meaningful for the non-CIFG case.
  const TfLiteTensor* input_gate_bias =
      is_layer_norm ? nullptr
                    : GetInput(context, node, lstm::full::kInputGateBiasTensor);
  TF_LITE_ENSURE_OK(
      context,
      PrecomputeZeroPointTimesWeightWithBias(
          context, input_zero_point, input_to_input_weights, input_gate_bias,
          &(integer_lstm_params->input_to_input_effective_bias)));
  TF_LITE_ENSURE_OK(
      context,
      PrecomputeZeroPointTimesWeightWithBias(
          context, output_state_zero_point, recurrent_to_input_weights, nullptr,
          &(integer_lstm_params->recurrent_to_input_effective_bias)));

  // Projection bias. The calculation is only meaningful when projection is
  // used.
  TF_LITE_ENSURE_OK(context,
                    PrecomputeZeroPointTimesWeightWithBias(
                        context, hidden_zp, projection_weights, projection_bias,
                        &(integer_lstm_params->projection_effective_bias)));
  return kTfLiteOk;
}

// Resize the output and state tensors based on the sizes of the input
// tensors. Allocate a temporary scratch tensor. Also check that the sizes of
// the input tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  const int scratch_tensor_index = op_data->scratch_tensor_index;

  // Check we have all the inputs and outputs we need.
  bool use_layer_norm = false;
  if (node->inputs->size == 24) {
    const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
        context, node, lstm::full::kForgetLayerNormCoefficientsTensor);
    use_layer_norm = (forget_layer_norm_coefficients != nullptr);
  } else if (node->inputs->size == 20) {
    // This is deprecated and is only kept here for backward compatibility.
    use_layer_norm = false;
  } else {
    TF_LITE_KERNEL_LOG(
        context, "The LSTM Full kernel expects 20 or 24 inputs. Got %d inputs",
        node->inputs->size);
    return kTfLiteError;
  }
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
  op_data->use_layer_norm = use_layer_norm;

  // Infer the batch size, number of outputs, sequence length and number of
  // cells from the input tensors.
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, lstm::full::kInputTensor, &input));
  const bool is_integer = input->type == kTfLiteInt8;
  TF_LITE_ENSURE(context, input->dims->size > 1);
  const auto* params =
      reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
          node->builtin_data);
  const bool time_major = params->time_major;
  const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
  const int n_input = input->dims->data[2];

  const TfLiteTensor* input_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor,
                   &input_to_output_weights));
  const int n_cell = input_to_output_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);

  const TfLiteTensor* recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor,
                   &recurrent_to_output_weights));
  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
                    n_cell);
  const int n_output = recurrent_to_output_weights->dims->data[1];

  // Check that the input tensor dimensions match each other.
  TF_LITE_ENSURE_OK(
      context, CheckInputTensorDimensions(context, node, n_input, n_output,
                                          n_cell, use_layer_norm, is_integer));

  // Get the pointer to output, output_state and cell_state buffer tensors.
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
                                           lstm::full::kOutputTensor, &output));

  TfLiteTensor* output_state =
      GetVariableInput(context, node, lstm::full::kOutputStateTensor);
  TF_LITE_ENSURE(context, output_state != nullptr);
  TfLiteTensor* cell_state =
      GetVariableInput(context, node, lstm::full::kCellStateTensor);
  TF_LITE_ENSURE(context, cell_state != nullptr);

  // Check the shape of the input state tensors. These tensors may be 1D or
  // 2D; any shape is fine as long as the total size is correct.
  TF_LITE_ENSURE_EQ(context, NumElements(output_state), n_batch * n_output);
  TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);

  // Resize the output tensors.
  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
  output_size->data[input->dims->size - 1] = n_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size));
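  // For example, a time-major input of shape {max_time, n_batch, n_input}
  // yields an output of shape {max_time, n_batch, n_output}.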

  if (is_integer) {
    const int num_intermediate_tensors = node->intermediates->size;
    TF_LITE_ENSURE(context, num_intermediate_tensors == 5);
  }

  TfLiteIntArrayFree(node->temporaries);
  if (IsHybridOp(input, input_to_output_weights)) {
    node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
  } else if (is_integer) {
    node->temporaries = TfLiteIntArrayCreate(6);
  } else {
    node->temporaries = TfLiteIntArrayCreate(1);
  }
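  // The hybrid path (float input, quantized weights) needs all twelve
  // temporaries, the 8x8->16 integer path needs six, and the float path only
  // needs the scratch buffer.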
  node->temporaries->data[kScratchBuffer] =
      scratch_tensor_index + kScratchBuffer;

  // Create a scratch buffer tensor.
  TfLiteTensor* scratch_buffer;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer,
                                              &scratch_buffer));
  scratch_buffer->type = input->type;
  scratch_buffer->allocation_type = kTfLiteArenaRw;

  const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kInputToInputWeightsTensor);
  const bool use_cifg = (input_to_input_weights == nullptr);
  TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
  scratch_buffer_size->data[0] = n_batch;
  if (use_cifg) {
    // Reserving space for Cell, Forget, Output gates and scratch accumulation
    // buffer and an extra 16 bytes to avoid internal ruy copies.
    scratch_buffer_size->data[1] = n_cell * 4 + 16;
  } else {
    // Reserving space for Input, Cell, Forget, Output gates and scratch
    // accumulation buffer and an extra 16 bytes to avoid internal ruy copies.
    scratch_buffer_size->data[1] = n_cell * 5 + 16;
  }
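  // For example, with n_cell = 20 and no CIFG, each batch row reserves
  // 20 * 5 + 16 = 116 scratch elements.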
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
                                                   scratch_buffer_size));

  if (IsHybridOp(input, input_to_output_weights)) {
    op_data->compute_row_sums = true;
    // Allocate temporary tensors to store quantized values of input,
    // output_state and cell_state tensors.
    node->temporaries->data[kInputQuantized] =
        scratch_tensor_index + kInputQuantized;
    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
                                                &input_quantized));
    input_quantized->type = input_to_output_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                       input_quantized_size));
    }
    node->temporaries->data[kOutputStateQuantized] =
        scratch_tensor_index + kOutputStateQuantized;
    TfLiteTensor* output_state_quantized;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kOutputStateQuantized,
                                       &output_state_quantized));
    output_state_quantized->type = input_to_output_weights->type;
    output_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(output_state_quantized->dims,
                             output_state->dims)) {
      TfLiteIntArray* output_state_quantized_size =
          TfLiteIntArrayCopy(output_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, output_state_quantized,
                                              output_state_quantized_size));
    }
    node->temporaries->data[kCellStateQuantized] =
        scratch_tensor_index + kCellStateQuantized;
    TfLiteTensor* cell_state_quantized;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kCellStateQuantized,
                                       &cell_state_quantized));
    cell_state_quantized->type = input_to_output_weights->type;
    cell_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
      TfLiteIntArray* cell_state_quantized_size =
          TfLiteIntArrayCopy(cell_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, cell_state_quantized,
                                              cell_state_quantized_size));
    }

    // Allocate temporary tensors to store scaling factors and product scaling
    // factors. The latter is convenience storage that lets us quantize a
    // vector once (producing the scaling factors) and multiply it with
    // different matrices (which requires multiplying the scaling factors by
    // the scaling factor of the matrix).
    node->temporaries->data[kInputScalingFactors] =
        op_data->scratch_tensor_index + kInputScalingFactors;
    TfLiteTensor* input_sf;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, kInputScalingFactors, &input_sf));
    input_sf->type = kTfLiteFloat32;
    input_sf->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {n_batch};
    if (!TfLiteIntArrayEqualsArray(input_sf->dims, 1, scaling_dims)) {
      TfLiteIntArray* input_sf_size = TfLiteIntArrayCreate(1);
      input_sf_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, input_sf, input_sf_size));
    }
    node->temporaries->data[kOutputStateScalingFactors] =
        op_data->scratch_tensor_index + kOutputStateScalingFactors;
    TfLiteTensor* output_state_sf;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kOutputStateScalingFactors,
                                  &output_state_sf));
    output_state_sf->type = kTfLiteFloat32;
    output_state_sf->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) {
      TfLiteIntArray* output_state_sf_size = TfLiteIntArrayCreate(1);
      output_state_sf_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_sf,
                                                       output_state_sf_size));
    }
    node->temporaries->data[kProductScalingFactors] =
        scratch_tensor_index + kProductScalingFactors;
    TfLiteTensor* prod_scaling_factors;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kProductScalingFactors,
                                       &prod_scaling_factors));
    prod_scaling_factors->type = kTfLiteFloat32;
    prod_scaling_factors->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
                                   scaling_dims)) {
      TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
      prod_scaling_factors_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, prod_scaling_factors,
                                              prod_scaling_factors_size));
    }
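    // For example, quantizing one input row yields a scaling factor s_x;
    // multiplying by a weight matrix with scale s_w then needs the product
    // s_x * s_w, which is what gets stored here.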

    // Allocate a temporary tensor to store the recovered cell weights. Since
    // this is used for diagonal matrices, we only need to store n_cell values.
    node->temporaries->data[kRecoveredCellWeights] =
        scratch_tensor_index + kRecoveredCellWeights;
    TfLiteTensor* recovered_cell_weights;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kRecoveredCellWeights,
                                       &recovered_cell_weights));
    recovered_cell_weights->type = kTfLiteFloat32;
    recovered_cell_weights->allocation_type = kTfLiteArenaRw;
    int recovered_cell_dims[1] = {n_cell};
    if (!TfLiteIntArrayEqualsArray(recovered_cell_weights->dims, 1,
                                   recovered_cell_dims)) {
      TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
      recovered_cell_weights_size->data[0] = n_cell;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, recovered_cell_weights,
                                              recovered_cell_weights_size));
    }

    // Allocate a temporary tensor to store the accumulated int32 values.
    node->temporaries->data[kAccumScratch] =
        scratch_tensor_index + kAccumScratch;
    TfLiteTensor* accum_scratch;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
                                                &accum_scratch));
    accum_scratch->type = kTfLiteInt32;
    accum_scratch->allocation_type = kTfLiteArenaRw;
    int accum_scratch_dims[2] = {n_cell, n_batch};
    if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
                                   accum_scratch_dims)) {
      TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2);
      accum_size->data[0] = n_cell;
      accum_size->data[1] = n_batch;
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, accum_scratch, accum_size));
    }
    node->temporaries->data[kInputZeroPoints] =
        op_data->scratch_tensor_index + kInputZeroPoints;
    TfLiteTensor* input_zp;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kInputZeroPoints, &input_zp));
    input_zp->type = kTfLiteFloat32;
    input_zp->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) {
      TfLiteIntArray* input_zp_size = TfLiteIntArrayCreate(1);
      input_zp_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, input_zp, input_zp_size));
    }
    node->temporaries->data[kOutputStateZeroPoints] =
        op_data->scratch_tensor_index + kOutputStateZeroPoints;
    TfLiteTensor* output_state_zp;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kOutputStateZeroPoints,
                                       &output_state_zp));
    output_state_zp->type = kTfLiteFloat32;
    output_state_zp->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) {
      TfLiteIntArray* output_state_zp_size = TfLiteIntArrayCreate(1);
      output_state_zp_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_zp,
                                                       output_state_zp_size));
    }
    node->temporaries->data[kRowSums] = scratch_tensor_index + kRowSums;
    TfLiteTensor* row_sums;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kRowSums, &row_sums));
    row_sums->type = kTfLiteInt32;
    row_sums->name = "Lstm_row_sums";
    row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int row_sums_rows = use_cifg ? 6 : 8;
    const TfLiteTensor* projection_weights = GetOptionalInputTensor(
        context, node, lstm::full::kProjectionWeightsTensor);
1159     if (projection_weights != nullptr) {
1160       row_sums_rows += ceil(static_cast<float>(n_output) / n_cell);
1161     }
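    // NOTE (editorial): the row count appears to be one cached row-sum row per
    // weight matrix: 4 gates x (input + recurrent) = 8 rows, or 6 once CIFG
    // drops the input gate. Projection weights are [n_output, n_cell], so
    // their n_output row-sums are packed into ceil(n_output / n_cell) extra
    // rows of width n_cell -- inferred from the arithmetic above.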
1162     int row_sums_dims[2] = {row_sums_rows, n_cell};
1163     if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
1164       TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
1165       row_sums_size->data[0] = row_sums_dims[0];
1166       row_sums_size->data[1] = row_sums_dims[1];
1167       TF_LITE_ENSURE_OK(
1168           context, context->ResizeTensor(context, row_sums, row_sums_size));
1169     }
1170   }
1171 
1172   if (is_integer) {
1173     // Integer UnidirectionalSequenceLSTM prepare function for 8x8->16.
1174     // This code path needs 5 intermediate tensors per Op.
1175     // Populate quantization parameters.
1176     PopulateQuantizedLstmParams8x8_16(context, node,
1177                                       &op_data->integer_lstm_param);
1178     // Allocate the scratch buffers: six tensors of size n_batch * n_cell in
1179     // total -- four 16-bit buffers (indices 0-3), one 8-bit buffer (index 4),
1180     // and one 32-bit buffer (index 5), as set up in the loop below.
1181     //
1182     // The CIFG case is handled as well, which might save one buffer.
1183     for (int scratch_index = 0; scratch_index < 6; ++scratch_index) {
1184       node->temporaries->data[scratch_index] =
1185           op_data->scratch_tensor_index + scratch_index;
1186       TfLiteTensor* scratch_tensor;
1187       TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, scratch_index,
1188                                                   &scratch_tensor));
1189 
1190       scratch_tensor->type = kTfLiteInt16;
1191       if (scratch_index == 4) {
1192         scratch_tensor->type = kTfLiteInt8;
1193       } else if (scratch_index == 5) {
1194         scratch_tensor->type = kTfLiteInt32;
1195       }
1196 
1197       scratch_tensor->allocation_type = kTfLiteArenaRw;
1198       const int scratch_dimension[2] = {n_batch, n_cell};
1199       if (!TfLiteIntArrayEqualsArray(scratch_tensor->dims, 2,
1200                                      scratch_dimension)) {
1201         TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
1202         scratch_buffer_size->data[0] = n_batch;
1203         scratch_buffer_size->data[1] = n_cell;
1204         TF_LITE_ENSURE_OK(context,
1205                           context->ResizeTensor(context, scratch_tensor,
1206                                                 scratch_buffer_size));
1207       }
1208     }
1209 
1210     // Populate precomputed zp * weight.
1211     TF_LITE_ENSURE_OK(context, PopulatePrecomputedZPTimesWeightsWithBias(
1212                                    context, op_data, node));
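    // NOTE (editorial): judging by its name, the helper above folds
    // zero_point * weight row-sums into the bias ahead of time so that the
    // integer kernel can run plain int8 x int8 matmuls at inference; the
    // helper body is outside this excerpt.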
1213   }
1214 
1215   return kTfLiteOk;
1216 }
1217 
1218 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
1219   const auto* params =
1220       reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
1221           node->builtin_data);
1222   const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
1223   const bool use_layer_norm = op_data->use_layer_norm;
1224   const bool time_major = params->time_major;
1225   const TfLiteTensor* input;
1226   TF_LITE_ENSURE_OK(
1227       context, GetInputSafe(context, node, lstm::full::kInputTensor, &input));
1228 
1229   const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
1230       context, node, lstm::full::kInputToInputWeightsTensor);
1231   const TfLiteTensor* input_to_forget_weights;
1232   TF_LITE_ENSURE_OK(
1233       context,
1234       GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor,
1235                    &input_to_forget_weights));
1236   const TfLiteTensor* input_to_cell_weights;
1237   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node,
1238                                           lstm::full::kInputToCellWeightsTensor,
1239                                           &input_to_cell_weights));
1240   const TfLiteTensor* input_to_output_weights;
1241   TF_LITE_ENSURE_OK(
1242       context,
1243       GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor,
1244                    &input_to_output_weights));
1245 
1246   const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
1247       context, node, lstm::full::kRecurrentToInputWeightsTensor);
1248   const TfLiteTensor* recurrent_to_forget_weights;
1249   TF_LITE_ENSURE_OK(
1250       context,
1251       GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor,
1252                    &recurrent_to_forget_weights));
1253   const TfLiteTensor* recurrent_to_cell_weights;
1254   TF_LITE_ENSURE_OK(
1255       context,
1256       GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor,
1257                    &recurrent_to_cell_weights));
1258   const TfLiteTensor* recurrent_to_output_weights;
1259   TF_LITE_ENSURE_OK(
1260       context,
1261       GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor,
1262                    &recurrent_to_output_weights));
1263 
1264   const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
1265       context, node, lstm::full::kCellToInputWeightsTensor);
1266   const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
1267       context, node, lstm::full::kCellToForgetWeightsTensor);
1268   const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
1269       context, node, lstm::full::kCellToOutputWeightsTensor);
1270 
1271   const TfLiteTensor* input_gate_bias =
1272       GetOptionalInputTensor(context, node, lstm::full::kInputGateBiasTensor);
1273   const TfLiteTensor* forget_gate_bias;
1274   TF_LITE_ENSURE_OK(
1275       context, GetInputSafe(context, node, lstm::full::kForgetGateBiasTensor,
1276                             &forget_gate_bias));
1277   const TfLiteTensor* cell_gate_bias;
1278   TF_LITE_ENSURE_OK(context,
1279                     GetInputSafe(context, node, lstm::full::kCellGateBiasTensor,
1280                                  &cell_gate_bias));
1281   const TfLiteTensor* output_gate_bias;
1282   TF_LITE_ENSURE_OK(
1283       context, GetInputSafe(context, node, lstm::full::kOutputGateBiasTensor,
1284                             &output_gate_bias));
1285 
1286   const TfLiteTensor* projection_weights = GetOptionalInputTensor(
1287       context, node, lstm::full::kProjectionWeightsTensor);
1288   const TfLiteTensor* projection_bias =
1289       GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor);
1290 
1291   TfLiteTensor* output_state =
1292       GetVariableInput(context, node, lstm::full::kOutputStateTensor);
1293   TFLITE_DCHECK(output_state != nullptr);
1294   TfLiteTensor* cell_state =
1295       GetVariableInput(context, node, lstm::full::kCellStateTensor);
1296   TFLITE_DCHECK(cell_state != nullptr);
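  // output_state and cell_state are variable tensors that persist the
  // recurrent state across invocations, which is why they are fetched with
  // GetVariableInput rather than GetInputSafe.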
1297 
1298   const TfLiteTensor* input_layer_norm_coefficients =
1299       use_layer_norm
1300           ? GetOptionalInputTensor(
1301                 context, node, lstm::full::kInputLayerNormCoefficientsTensor)
1302           : nullptr;
1303   const TfLiteTensor* forget_layer_norm_coefficients =
1304       use_layer_norm ? GetInput(context, node,
1305                                 lstm::full::kForgetLayerNormCoefficientsTensor)
1306                      : nullptr;
1307   const TfLiteTensor* cell_layer_norm_coefficients =
1308       use_layer_norm ? GetInput(context, node,
1309                                 lstm::full::kCellLayerNormCoefficientsTensor)
1310                      : nullptr;
1311   const TfLiteTensor* output_layer_norm_coefficients =
1312       use_layer_norm ? GetInput(context, node,
1313                                 lstm::full::kOutputLayerNormCoefficientsTensor)
1314                      : nullptr;
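  // Only the input-gate layer-norm coefficients are fetched as optional:
  // under CIFG there is no separate input gate, so that tensor can be absent
  // even when layer normalization is enabled (inferred from the
  // GetOptionalInputTensor call above).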
1315 
1316   TfLiteTensor* output;
1317   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
1318                                            lstm::full::kOutputTensor, &output));
1319 
1320   // Copy out the LSTM-specific params so they can be passed into the function.
1321   TfLiteLSTMParams lstm_params;
1322   lstm_params.activation = params->activation;
1323   lstm_params.cell_clip = params->cell_clip;
1324   lstm_params.proj_clip = params->proj_clip;
1325   lstm_params.asymmetric_quantize_inputs = params->asymmetric_quantize_inputs;
1326 
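  // The dispatch below keys on the weight type: float32 weights select the
  // float kernel, while int8/uint8 weights select either the hybrid kernel
  // (float input) or the fully integer 8x8->16 kernel (integer input).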
1327   switch (input_to_output_weights->type) {
1328     case kTfLiteFloat32: {
1329       // Index the scratch buffer pointers into the global scratch buffer.
1330       TfLiteTensor* scratch_buffer;
1331       TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer,
1332                                                   &scratch_buffer));
1333       return lstm_eval::EvalFloat(
1334           input, input_to_input_weights, input_to_forget_weights,
1335           input_to_cell_weights, input_to_output_weights,
1336           recurrent_to_input_weights, recurrent_to_forget_weights,
1337           recurrent_to_cell_weights, recurrent_to_output_weights,
1338           cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
1339           input_layer_norm_coefficients, forget_layer_norm_coefficients,
1340           cell_layer_norm_coefficients, output_layer_norm_coefficients,
1341           /*aux_input=*/nullptr,
1342           /*aux_input_to_input_weights=*/nullptr,
1343           /*aux_input_to_forget_weights=*/nullptr,
1344           /*aux_input_to_cell_weights=*/nullptr,
1345           /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
1346           forget_gate_bias, cell_gate_bias, output_gate_bias,
1347           projection_weights, projection_bias, &lstm_params,
1348           /*forward_sequence=*/true, time_major,
1349           /*output_offset=*/0, scratch_buffer, output_state, cell_state, output,
1350           CpuBackendContext::GetFromContext(context));
1351     }
1352     case kTfLiteUInt8:
1353     case kTfLiteInt8: {
1354       const bool is_hybrid = input->type == kTfLiteFloat32;
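      // Hybrid mode: quantized weights with float activations. The inputs are
      // quantized on the fly using the scaling-factor and zero-point
      // temporaries allocated in Prepare().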
1355       if (is_hybrid) {
1356         // Index the scratch buffer pointers into the global scratch buffer.
1357         TfLiteTensor* scratch_buffer;
1358         TF_LITE_ENSURE_OK(
1359             context,
1360             GetTemporarySafe(context, node, kScratchBuffer, &scratch_buffer));
1361 
1362         OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
1363         TfLiteTensor* row_sums;
1364         TF_LITE_ENSURE_OK(context,
1365                           GetTemporarySafe(context, node, kRowSums, &row_sums));
1366         const int row_sums_size = row_sums->dims->data[0];
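        // EvalHybrid receives &op_data->compute_row_sums so it can decide
        // whether the persistent row-sums cache (kTfLiteArenaRwPersistent,
        // see above) still needs to be filled; the flag's setting site is
        // outside this excerpt.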
1367         return lstm_eval::EvalHybrid(
1368             input, input_to_input_weights,
1369             /*input_to_input_weights_ledger*/ nullptr, input_to_forget_weights,
1370             /*input_to_forget_weights_ledger*/ nullptr, input_to_cell_weights,
1371             /*input_to_cell_weights_ledger*/ nullptr, input_to_output_weights,
1372             /*input_to_output_weights_ledger*/ nullptr,
1373             recurrent_to_input_weights,
1374             /*recurrent_to_input_weights_ledger*/ nullptr,
1375             recurrent_to_forget_weights,
1376             /*recurrent_to_forget_weights_ledger*/ nullptr,
1377             recurrent_to_cell_weights,
1378             /*recurrent_to_cell_weights_ledger*/ nullptr,
1379             recurrent_to_output_weights,
1380             /*recurrent_to_output_weights_ledger*/ nullptr,
1381             cell_to_input_weights, cell_to_forget_weights,
1382             cell_to_output_weights, input_layer_norm_coefficients,
1383             forget_layer_norm_coefficients, cell_layer_norm_coefficients,
1384             output_layer_norm_coefficients,
1385             /*aux_input=*/nullptr,
1386             /*aux_input_to_input_weights=*/nullptr,
1387             /*aux_input_to_forget_weights=*/nullptr,
1388             /*aux_input_to_cell_weights=*/nullptr,
1389             /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
1390             forget_gate_bias, cell_gate_bias, output_gate_bias,
1391             projection_weights, /*projection_weights_ledger*/ nullptr,
1392             projection_bias, &lstm_params,
1393             /*forward_sequence=*/true, time_major,
1394             /*output_offset=*/0, scratch_buffer,
1395             GetTemporary(context, node, kInputScalingFactors),
1396             /*aux_input_sf=*/nullptr,
1397             GetTemporary(context, node, kOutputStateScalingFactors),
1398             GetTemporary(context, node, kProductScalingFactors),
1399             GetTemporary(context, node, kRecoveredCellWeights),
1400             GetTemporary(context, node, kInputQuantized),
1401             /*aux_input_quantized=*/nullptr,
1402             GetTemporary(context, node, kOutputStateQuantized),
1403             GetTemporary(context, node, kCellStateQuantized), output_state,
1404             cell_state, GetTemporary(context, node, kAccumScratch), output,
1405             GetTemporary(context, node, kInputZeroPoints),
1406             /*aux_input_zp=*/nullptr,
1407             GetTemporary(context, node, kOutputStateZeroPoints), row_sums,
1408             row_sums_size, &op_data->compute_row_sums,
1409             CpuBackendContext::GetFromContext(context));
1410       } else {
1411         TfLiteTensor* scratch0;
1412         TF_LITE_ENSURE_OK(context,
1413                           GetTemporarySafe(context, node, 0, &scratch0));
1414         TfLiteTensor* scratch1;
1415         TF_LITE_ENSURE_OK(context,
1416                           GetTemporarySafe(context, node, 1, &scratch1));
1417         TfLiteTensor* scratch2;
1418         TF_LITE_ENSURE_OK(context,
1419                           GetTemporarySafe(context, node, 2, &scratch2));
1420         TfLiteTensor* scratch3;
1421         TF_LITE_ENSURE_OK(context,
1422                           GetTemporarySafe(context, node, 3, &scratch3));
1423         TfLiteTensor* scratch4;
1424         TF_LITE_ENSURE_OK(context,
1425                           GetTemporarySafe(context, node, 4, &scratch4));
1426         TfLiteTensor* scratch5;
1427         TF_LITE_ENSURE_OK(context,
1428                           GetTemporarySafe(context, node, 5, &scratch5));
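        // These six temporaries match the scratch tensors sized in Prepare():
        // scratch0-3 are int16 gate buffers, scratch4 is int8, and scratch5
        // is int32.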
1429         return lstm_eval::EvalInteger8x8_16(
1430             input, input_to_input_weights, input_to_forget_weights,
1431             input_to_cell_weights, input_to_output_weights,
1432             recurrent_to_input_weights, recurrent_to_forget_weights,
1433             recurrent_to_cell_weights, recurrent_to_output_weights,
1434             cell_to_input_weights, cell_to_forget_weights,
1435             cell_to_output_weights, input_layer_norm_coefficients,
1436             forget_layer_norm_coefficients, cell_layer_norm_coefficients,
1437             output_layer_norm_coefficients, input_gate_bias, forget_gate_bias,
1438             cell_gate_bias, output_gate_bias, projection_weights,
1439             projection_bias, &lstm_params, /*forward_sequence=*/true,
1440             time_major, &op_data->integer_lstm_param, output_state, cell_state,
1441             output, scratch0, scratch1, scratch2, scratch3, scratch4, scratch5,
1442             CpuBackendContext::GetFromContext(context));
1443       }
1444     }
1445     default:
1446       TF_LITE_KERNEL_LOG(context, "Type %s is not currently supported.",
1447                          TfLiteTypeGetName(input_to_output_weights->type));
1448       return kTfLiteError;
1449   }
1450 }
1451 }  // namespace unidirectional_sequence_lstm
1452 
1453 TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() {
1454   static TfLiteRegistration r = {unidirectional_sequence_lstm::Init,
1455                                  unidirectional_sequence_lstm::Free,
1456                                  unidirectional_sequence_lstm::Prepare,
1457                                  unidirectional_sequence_lstm::Eval};
1458   return &r;
1459 }
1460 
1461 }  // namespace builtin
1462 }  // namespace ops
1463 }  // namespace tflite
1464