/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <vector>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/lstm_eval.h"
#include "tensorflow/lite/kernels/lstm_shared.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace lstm {

struct OpData {
  // Which kernel type to use: the full kernel (24 inputs) or the basic kernel
  // (5 inputs).
  // Please note the 20-input full kernel is deprecated and only kept
  // here for backward compatibility.
  TfLiteLSTMKernelType kernel_type;

  // Whether the LSTM uses layer normalization.
  bool use_layer_norm;

  // These fields are only used by the full kernel.
  int scratch_tensor_index;
  lstm_eval::IntegerLstmParameter integer_lstm_param;
  bool compute_row_sums;

  // Only used by sparse hybrid LSTM kernels.
  int ledger_index;
  bool ledger_initialized;
};

namespace full {
namespace {

// Named temporary tensors.
enum HybridTemporaryTensor {
  kScratchBuffer = 0,
  kInputQuantized = 1,
  kOutputStateQuantized = 2,
  kCellStateQuantized = 3,
  kInputScalingFactors = 4,
  kOutputStateScalingFactors = 5,
  kProductScalingFactors = 6,
  kRecoveredCellWeights = 7,
  kAccumScratch = 8,
  kInputZeroPoints = 9,
  kOutputStateZeroPoints = 10,
  kRowSums = 11,
  kNumHybridTemporaryTensors = 12,
};

constexpr int kLedgersToAdd = 9;
constexpr int kInputToInputWeightsLedgerOffset = 0;
constexpr int kInputToForgetWeightsLedgerOffset = 1;
constexpr int kInputToCellWeightsLedgerOffset = 2;
constexpr int kInputToOutputWeightsLedgerOffset = 3;
constexpr int kRecurrentToInputWeightsLedgerOffset = 4;
constexpr int kRecurrentToForgetWeightsLedgerOffset = 5;
constexpr int kRecurrentToCellWeightsLedgerOffset = 6;
constexpr int kRecurrentToOutputWeightsLedgerOffset = 7;
constexpr int kProjectionWeightsLedgerOffset = 8;

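// A ledger is a compact uint8 description of a block-sparse weight tensor:
// for each row, the number of non-zero blocks followed by their block-column
// indices. make_ledger() sizes the ledger tensor from the sparsity metadata
// and copy_ledger() fills it in.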
TfLiteStatus make_ledger(const TfLiteSparsity* sparsity, TfLiteContext* context,
                         TfLiteTensor* ledger) {
  ledger->type = kTfLiteUInt8;
  ledger->name = "Lstm_ledger";
  ledger->allocation_type = kTfLiteArenaRwPersistent;
  if (sparsity == nullptr) {
    return kTfLiteOk;
  }
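  // One count byte per row (array_segments->size - 1 rows) plus one index
  // byte per non-zero block.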
  TfLiteIntArray* ledger_size = TfLiteIntArrayCreate(1);
  ledger_size->data[0] = sparsity->dim_metadata[1].array_indices->size +
                         sparsity->dim_metadata[1].array_segments->size - 1;
  return context->ResizeTensor(context, ledger, ledger_size);
}

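// Writes the ledger bytes from the CSR-style sparsity metadata. For example,
// segments {0, 2, 3} with indices {1, 3, 0} describe two rows: row 0 has
// blocks at columns 1 and 3, row 1 a single block at column 0, yielding the
// ledger bytes {2, 1, 3, 1, 0}.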
TfLiteStatus copy_ledger(const TfLiteSparsity* sparsity, TfLiteTensor* ledger) {
  if (sparsity == nullptr) {
    return kTfLiteOk;
  }

  const auto* array_segments = sparsity->dim_metadata[1].array_segments;
  const auto* array_indices = sparsity->dim_metadata[1].array_indices;
  uint8_t* output_data = GetTensorData<uint8_t>(ledger);
  int output_data_ptr = 0;

  for (int i = 0; i < array_segments->size - 1; i++) {
    int row_start = array_segments->data[i];
    int row_end = array_segments->data[i + 1];
    if (row_end - row_start > UINT8_MAX) {
      return kTfLiteError;
    }
    // Copy the number of non-zero blocks in row i.
    output_data[output_data_ptr] = static_cast<uint8_t>(row_end - row_start);
    output_data_ptr++;

    for (int j = row_start; j < row_end; j++) {
      if (array_indices->data[j] > UINT8_MAX) {
        return kTfLiteError;
      }
      // Copy the indices of the non-zero blocks in row i.
      output_data[output_data_ptr] =
          static_cast<uint8_t>(array_indices->data[j]);
      output_data_ptr++;
    }
  }
  return kTfLiteOk;
}

TfLiteStatus PopulateQuantizedLstmParams8x8_16(
    TfLiteContext* context, TfLiteNode* node,
    lstm_eval::IntegerLstmParameter* integer_lstm_param) {
  // Calculate quantized clip for projection and cell.
  const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data);
  const float cell_clip = params->cell_clip;
  const float proj_clip = params->proj_clip;

  const TfLiteTensor* cell_state =
      GetVariableInput(context, node, kCellStateTensor);
  TF_LITE_ENSURE(context, cell_state != nullptr);
  TfLiteTensor* output_tensor;
  TF_LITE_ENSURE_OK(
      context, GetOutputSafe(context, node, kOutputTensor, &output_tensor));

  auto* cell_state_params =
      static_cast<TfLiteAffineQuantization*>(cell_state->quantization.params);
  auto* proj_params = static_cast<TfLiteAffineQuantization*>(
      output_tensor->quantization.params);
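  // The real-valued clip thresholds are converted into the quantized domain
  // by dividing by the corresponding scale. For example, cell_clip = 8.0 with
  // a cell-state scale of 2^-11 quantizes to 8.0 * 2048 = 16384.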
  if (cell_clip > 0.0) {
    integer_lstm_param->quantized_cell_clip = static_cast<int16_t>(std::min(
        std::max(cell_clip / cell_state_params->scale->data[0], -32768.0f),
        32767.0f));
  } else {
    integer_lstm_param->quantized_cell_clip = 0;
  }
  if (proj_clip > 0.0) {
    integer_lstm_param->quantized_proj_clip = static_cast<int8_t>(std::min(
        std::max(proj_clip / proj_params->scale->data[0], -128.0f), 127.0f));
  } else {
    integer_lstm_param->quantized_proj_clip = 0;
  }

  // Calculate effective scales.
  OpData* op_data = static_cast<OpData*>(node->user_data);
  const bool use_layer_norm = op_data->use_layer_norm;

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));

  const TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  const TfLiteTensor* input_to_forget_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputToForgetWeightsTensor,
                                 &input_to_forget_weights));
  const TfLiteTensor* input_to_cell_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputToCellWeightsTensor,
                                 &input_to_cell_weights));
  const TfLiteTensor* input_to_output_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputToOutputWeightsTensor,
                                 &input_to_output_weights));

  const TfLiteTensor* recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
  const TfLiteTensor* recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kRecurrentToForgetWeightsTensor,
                                 &recurrent_to_forget_weights));
  const TfLiteTensor* recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kRecurrentToCellWeightsTensor,
                                 &recurrent_to_cell_weights));
  const TfLiteTensor* recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kRecurrentToOutputWeightsTensor,
                                 &recurrent_to_output_weights));

  const TfLiteTensor* cell_to_input_weights =
      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
  const TfLiteTensor* cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
  const TfLiteTensor* cell_to_output_weights =
      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);

  const TfLiteTensor* input_layer_norm_coefficients =
      GetOptionalInputTensor(context, node, kInputLayerNormCoefficientsTensor);
  const TfLiteTensor* forget_layer_norm_coefficients =
      GetOptionalInputTensor(context, node, kForgetLayerNormCoefficientsTensor);
  const TfLiteTensor* cell_layer_norm_coefficients =
      GetOptionalInputTensor(context, node, kCellLayerNormCoefficientsTensor);
  const TfLiteTensor* output_layer_norm_coefficients =
      GetOptionalInputTensor(context, node, kOutputLayerNormCoefficientsTensor);

  const TfLiteTensor* projection_weights =
      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);

  TfLiteTensor* output_state =
      GetVariableInput(context, node, kOutputStateTensor);
  TF_LITE_ENSURE(context, output_state != nullptr);

  // Since we have already checked that weights are all there or none, we can
  // check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool use_peephole = (cell_to_output_weights != nullptr);
  const bool use_projection = (projection_weights != nullptr);

  // Get intermediate scales and zero points.
  std::vector<float> intermediate_scale;
  std::vector<int32> intermediate_zp;
  for (int i = 0; i < 4; ++i) {
    if (use_layer_norm) {
      TfLiteTensor* intermediate;
      TF_LITE_ENSURE_OK(context,
                        GetIntermediatesSafe(context, node, i, &intermediate));
      auto* params = static_cast<TfLiteAffineQuantization*>(
          intermediate->quantization.params);
      intermediate_scale.push_back(params->scale->data[0]);
      intermediate_zp.push_back(params->zero_point->data[0]);
    } else {
      // Q3.12 for activation functions.
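      // (Q3.12 is a 16-bit fixed-point format with 12 fractional bits,
      // covering [-8, 8) at a resolution of 2^-12.)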
      intermediate_scale.push_back(std::pow(2, -12));
      intermediate_zp.push_back(0);
    }
  }
  // In the absence of projection, hidden becomes output and this intermediate
  // is ignored.
  TfLiteTensor* hidden;
  TF_LITE_ENSURE_OK(context, GetIntermediatesSafe(context, node, 4, &hidden));
  auto* hidden_params =
      static_cast<TfLiteAffineQuantization*>(hidden->quantization.params);
  intermediate_scale.push_back(hidden_params->scale->data[0]);
  intermediate_zp.push_back(hidden_params->zero_point->data[0]);

  // Scales.
  const float default_scale = 1.0;
  float input_scale = default_scale;
  float input_to_input_weight_scale = default_scale;
  float recurrent_to_input_weight_scale = default_scale;
  float cell_to_input_weight_scale = default_scale;
  float input_to_forget_weight_scale = default_scale;
  float recurrent_to_forget_weight_scale = default_scale;
  float cell_to_forget_weight_scale = default_scale;
  float input_to_cell_weight_scale = default_scale;
  float recurrent_to_cell_weight_scale = default_scale;
  float input_to_output_weight_scale = default_scale;
  float recurrent_to_output_weight_scale = default_scale;
  float cell_to_output_weight_scale = default_scale;
  float projection_weight_scale = default_scale;
  float layer_norm_input_scale = default_scale;
  float layer_norm_forget_scale = default_scale;
  float layer_norm_cell_scale = default_scale;
  float layer_norm_output_scale = default_scale;
  float output_state_scale = default_scale;
  int cell_scale = 1;

  // Effective scales.
  float effective_input_to_input_scale = default_scale;
  float effective_recurrent_to_input_scale = default_scale;
  float effective_cell_to_input_scale = default_scale;
  float effective_input_to_forget_scale = default_scale;
  float effective_recurrent_to_forget_scale = default_scale;
  float effective_cell_to_forget_scale = default_scale;
  float effective_input_to_cell_scale = default_scale;
  float effective_recurrent_to_cell_scale = default_scale;
  float effective_input_to_output_scale = default_scale;
  float effective_recurrent_to_output_scale = default_scale;
  float effective_cell_to_output_scale = default_scale;
  float effective_proj_scale = default_scale;
  float effective_hidden_scale = default_scale;

  // Populate scales.
  if (!use_cifg) {
    input_to_input_weight_scale = input_to_input_weights->params.scale;
    recurrent_to_input_weight_scale = recurrent_to_input_weights->params.scale;
  }

  if (use_peephole) {
    if (!use_cifg) {
      cell_to_input_weight_scale = cell_to_input_weights->params.scale;
    }
    cell_to_forget_weight_scale = cell_to_forget_weights->params.scale;
    cell_to_output_weight_scale = cell_to_output_weights->params.scale;
  }

  if (use_layer_norm) {
    if (!use_cifg) {
      layer_norm_input_scale = input_layer_norm_coefficients->params.scale;
    }
    layer_norm_forget_scale = forget_layer_norm_coefficients->params.scale;
    layer_norm_cell_scale = cell_layer_norm_coefficients->params.scale;
    layer_norm_output_scale = output_layer_norm_coefficients->params.scale;
  }

  if (use_projection) {
    projection_weight_scale = projection_weights->params.scale;
  }
  output_state_scale = output_state->params.scale;

  input_to_forget_weight_scale = input_to_forget_weights->params.scale;
  input_to_cell_weight_scale = input_to_cell_weights->params.scale;
  input_to_output_weight_scale = input_to_output_weights->params.scale;
  recurrent_to_forget_weight_scale = recurrent_to_forget_weights->params.scale;
  recurrent_to_cell_weight_scale = recurrent_to_cell_weights->params.scale;
  recurrent_to_output_weight_scale = recurrent_to_output_weights->params.scale;
  // Check the cell state scale (the cell-state tensor itself was fetched
  // above).
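  // CheckedLog2 succeeds only for power-of-two scales; requiring the exponent
  // to be <= -9 guarantees at least 9 fractional bits for the cell state.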
  TF_LITE_ENSURE(context, CheckedLog2(cell_state->params.scale, &cell_scale));
  TF_LITE_ENSURE(context, cell_scale <= -9);
  integer_lstm_param->cell_scale = cell_scale;
  input_scale = input->params.scale;

  // Calculate effective scales.
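  // Each effective scale folds the scales of the two multiplicands into the
  // scale of the destination intermediate: weight_scale * activation_scale /
  // intermediate_scale. For example, 2^-7 * 2^-6 / 2^-12 = 2^-1.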
  if (!use_cifg) {
    effective_input_to_input_scale =
        input_to_input_weight_scale * input_scale / intermediate_scale[0];
    effective_recurrent_to_input_scale = recurrent_to_input_weight_scale *
                                         output_state_scale /
                                         intermediate_scale[0];
  }
  effective_input_to_forget_scale =
      input_to_forget_weight_scale * input_scale / intermediate_scale[1];
  effective_recurrent_to_forget_scale = recurrent_to_forget_weight_scale *
                                        output_state_scale /
                                        intermediate_scale[1];

  effective_input_to_cell_scale =
      input_to_cell_weight_scale * input_scale / intermediate_scale[2];
  effective_recurrent_to_cell_scale = recurrent_to_cell_weight_scale *
                                      output_state_scale /
                                      intermediate_scale[2];

  effective_input_to_output_scale =
      input_to_output_weight_scale * input_scale / intermediate_scale[3];
  effective_recurrent_to_output_scale = recurrent_to_output_weight_scale *
                                        output_state_scale /
                                        intermediate_scale[3];

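  // The two 2^-15 factors reflect the hidden value being the product of two
  // Q0.15 quantities (tanh of the cell state and the sigmoid of the output
  // gate); dividing by the hidden intermediate scale requantizes the result.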
  effective_hidden_scale =
      std::pow(2, -15) / intermediate_scale[4] * std::pow(2, -15);

  effective_proj_scale =
      projection_weight_scale * intermediate_scale[4] / output_state_scale;

  if (use_peephole) {
    if (!use_cifg) {
      effective_cell_to_input_scale = std::pow(2, cell_scale) *  // NOLINT
                                      cell_to_input_weight_scale /
                                      intermediate_scale[0];
    }
    effective_cell_to_forget_scale = std::pow(2, cell_scale) *  // NOLINT
                                     cell_to_forget_weight_scale /
                                     intermediate_scale[1];
    effective_cell_to_output_scale = std::pow(2, cell_scale) *  // NOLINT
                                     cell_to_output_weight_scale /
                                     intermediate_scale[3];
  }

  // Decompose scales.
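  // QuantizeMultiplier splits each real-valued scale into an int32 Q31
  // mantissa `a` and a power-of-two exponent `b`, so that
  // scale ~= a * 2^(b - 31); the kernel can then multiply and shift instead
  // of using floating point.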
  QuantizeMultiplier(effective_input_to_input_scale,
                     &integer_lstm_param->effective_input_to_input_scale_a,
                     &integer_lstm_param->effective_input_to_input_scale_b);
  QuantizeMultiplier(effective_recurrent_to_input_scale,
                     &integer_lstm_param->effective_recurrent_to_input_scale_a,
                     &integer_lstm_param->effective_recurrent_to_input_scale_b);
  QuantizeMultiplier(effective_cell_to_input_scale,
                     &integer_lstm_param->effective_cell_to_input_scale_a,
                     &integer_lstm_param->effective_cell_to_input_scale_b);
  QuantizeMultiplier(effective_input_to_forget_scale,
                     &integer_lstm_param->effective_input_to_forget_scale_a,
                     &integer_lstm_param->effective_input_to_forget_scale_b);
  QuantizeMultiplier(
      effective_recurrent_to_forget_scale,
      &integer_lstm_param->effective_recurrent_to_forget_scale_a,
      &integer_lstm_param->effective_recurrent_to_forget_scale_b);
  QuantizeMultiplier(effective_cell_to_forget_scale,
                     &integer_lstm_param->effective_cell_to_forget_scale_a,
                     &integer_lstm_param->effective_cell_to_forget_scale_b);
  QuantizeMultiplier(effective_input_to_cell_scale,
                     &integer_lstm_param->effective_input_to_cell_scale_a,
                     &integer_lstm_param->effective_input_to_cell_scale_b);
  QuantizeMultiplier(effective_recurrent_to_cell_scale,
                     &integer_lstm_param->effective_recurrent_to_cell_scale_a,
                     &integer_lstm_param->effective_recurrent_to_cell_scale_b);
  QuantizeMultiplier(effective_input_to_output_scale,
                     &integer_lstm_param->effective_input_to_output_scale_a,
                     &integer_lstm_param->effective_input_to_output_scale_b);
  QuantizeMultiplier(
      effective_recurrent_to_output_scale,
      &integer_lstm_param->effective_recurrent_to_output_scale_a,
      &integer_lstm_param->effective_recurrent_to_output_scale_b);
  QuantizeMultiplier(effective_cell_to_output_scale,
                     &integer_lstm_param->effective_cell_to_output_scale_a,
                     &integer_lstm_param->effective_cell_to_output_scale_b);
  QuantizeMultiplier(effective_proj_scale,
                     &integer_lstm_param->effective_proj_scale_a,
                     &integer_lstm_param->effective_proj_scale_b);
  QuantizeMultiplier(effective_hidden_scale,
                     &integer_lstm_param->effective_hidden_scale_a,
                     &integer_lstm_param->effective_hidden_scale_b);
  QuantizeMultiplier(layer_norm_input_scale,
                     &integer_lstm_param->layer_norm_input_scale_a,
                     &integer_lstm_param->layer_norm_input_scale_b);
  QuantizeMultiplier(layer_norm_forget_scale,
                     &integer_lstm_param->layer_norm_forget_scale_a,
                     &integer_lstm_param->layer_norm_forget_scale_b);
  QuantizeMultiplier(layer_norm_cell_scale,
                     &integer_lstm_param->layer_norm_cell_scale_a,
                     &integer_lstm_param->layer_norm_cell_scale_b);
  QuantizeMultiplier(layer_norm_output_scale,
                     &integer_lstm_param->layer_norm_output_scale_a,
                     &integer_lstm_param->layer_norm_output_scale_b);

  integer_lstm_param->hidden_zp = intermediate_zp[4];

  // 10000 is used to make sure the kernel logic does not overflow.
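  // For example, a layer-norm scale of 5e-4 gives a variance guard of
  // max(1, 5) = 5; scales below 1e-4 clamp the guard to 1.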
  if (!use_cifg) {
    integer_lstm_param->input_variance_guard =
        std::max(1, static_cast<int32_t>(10000 * layer_norm_input_scale));
  }
  integer_lstm_param->forget_variance_guard =
      std::max(1, static_cast<int32_t>(10000 * layer_norm_forget_scale));
  integer_lstm_param->cell_variance_guard =
      std::max(1, static_cast<int32_t>(10000 * layer_norm_cell_scale));
  integer_lstm_param->output_variance_guard =
      std::max(1, static_cast<int32_t>(10000 * layer_norm_output_scale));

  return kTfLiteOk;
}

TfLiteStatus PopulateQuantizedLstmParams8x8_8(
    TfLiteContext* context, TfLiteNode* node,
    lstm_eval::IntegerLstmParameter* integer_lstm_param) {
  // Get all tensors.
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  const TfLiteTensor* input_to_forget_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputToForgetWeightsTensor,
                                 &input_to_forget_weights));
  const TfLiteTensor* input_to_cell_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputToCellWeightsTensor,
                                 &input_to_cell_weights));
  const TfLiteTensor* input_to_output_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputToOutputWeightsTensor,
                                 &input_to_output_weights));

  const TfLiteTensor* recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
  const TfLiteTensor* recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kRecurrentToForgetWeightsTensor,
                                 &recurrent_to_forget_weights));
  const TfLiteTensor* recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kRecurrentToCellWeightsTensor,
                                 &recurrent_to_cell_weights));
  const TfLiteTensor* recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kRecurrentToOutputWeightsTensor,
                                 &recurrent_to_output_weights));

  const TfLiteTensor* cell_to_input_weights =
      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
  const TfLiteTensor* cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
  const TfLiteTensor* cell_to_output_weights =
      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);

  const TfLiteTensor* input_layer_norm_coefficients =
      GetOptionalInputTensor(context, node, kInputLayerNormCoefficientsTensor);
  const TfLiteTensor* forget_layer_norm_coefficients =
      GetOptionalInputTensor(context, node, kForgetLayerNormCoefficientsTensor);
  const TfLiteTensor* cell_layer_norm_coefficients =
      GetOptionalInputTensor(context, node, kCellLayerNormCoefficientsTensor);
  const TfLiteTensor* output_layer_norm_coefficients =
      GetOptionalInputTensor(context, node, kOutputLayerNormCoefficientsTensor);

  const TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, kInputGateBiasTensor);
  const TfLiteTensor* forget_gate_bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kForgetGateBiasTensor,
                                          &forget_gate_bias));
  const TfLiteTensor* cell_gate_bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCellGateBiasTensor,
                                          &cell_gate_bias));
  const TfLiteTensor* output_gate_bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kOutputGateBiasTensor,
                                          &output_gate_bias));

  const TfLiteTensor* projection_weights =
      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
  const TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, kProjectionBiasTensor);

  TfLiteTensor* output_state =
      GetVariableInput(context, node, kOutputStateTensor);
  TF_LITE_ENSURE(context, output_state != nullptr);
  TfLiteTensor* cell_state = GetVariableInput(context, node, kCellStateTensor);
  TF_LITE_ENSURE(context, cell_state != nullptr);

  // Since we have already checked that weights are all there or none, we can
  // check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool use_peephole = (cell_to_output_weights != nullptr);
  const bool is_layer_norm_lstm = (forget_layer_norm_coefficients != nullptr);
  const bool use_projection = (projection_weights != nullptr);

  // Weights and states.
  int8_t* input_to_input_weight_ptr = nullptr;
  int8_t* recurrent_to_input_weight_ptr = nullptr;
  int8_t* cell_to_input_weight_ptr = nullptr;
  int8_t* input_to_forget_weight_ptr = nullptr;
  int8_t* recurrent_to_forget_weight_ptr = nullptr;
  int8_t* cell_to_forget_weight_ptr = nullptr;
  int8_t* input_to_cell_weight_ptr = nullptr;
  int8_t* recurrent_to_cell_weight_ptr = nullptr;
  int8_t* input_to_output_weight_ptr = nullptr;
  int8_t* recurrent_to_output_weight_ptr = nullptr;
  int8_t* cell_to_output_weight_ptr = nullptr;
  int8_t* projection_weight_ptr = nullptr;
  int16_t* layer_norm_input_weight_ptr = nullptr;
  int16_t* layer_norm_forget_weight_ptr = nullptr;
  int16_t* layer_norm_cell_weight_ptr = nullptr;
  int16_t* layer_norm_output_weight_ptr = nullptr;
  int32_t* input_gate_bias_ptr = nullptr;
  int32_t* forget_gate_bias_ptr = nullptr;
  int32_t* cell_gate_bias_ptr = nullptr;
  int32_t* output_gate_bias_ptr = nullptr;
  int32_t* projection_bias_ptr = nullptr;
  int16_t* cell_ptr = nullptr;
  int8_t* output_state_ptr = nullptr;

  // Scales.
  const float default_scale = 1.0;
  float input_scale = default_scale;
  float input_to_input_weight_scale = default_scale;
  float recurrent_to_input_weight_scale = default_scale;
  float cell_to_input_weight_scale = default_scale;
  float input_to_forget_weight_scale = default_scale;
  float recurrent_to_forget_weight_scale = default_scale;
  float cell_to_forget_weight_scale = default_scale;
  float input_to_cell_weight_scale = default_scale;
  float recurrent_to_cell_weight_scale = default_scale;
  float input_to_output_weight_scale = default_scale;
  float recurrent_to_output_weight_scale = default_scale;
  float cell_to_output_weight_scale = default_scale;
  float projection_weight_scale = default_scale;
  float layer_norm_input_scale = default_scale;
  float layer_norm_forget_scale = default_scale;
  float layer_norm_cell_scale = default_scale;
  float layer_norm_output_scale = default_scale;
  float output_state_scale = default_scale;

  // Effective scales.
  float effective_input_to_input_scale = default_scale;
  float effective_recurrent_to_input_scale = default_scale;
  float effective_cell_to_input_scale = default_scale;
  float effective_input_to_forget_scale = default_scale;
  float effective_recurrent_to_forget_scale = default_scale;
  float effective_cell_to_forget_scale = default_scale;
  float effective_input_to_cell_scale = default_scale;
  float effective_recurrent_to_cell_scale = default_scale;
  float effective_input_to_output_scale = default_scale;
  float effective_recurrent_to_output_scale = default_scale;
  float effective_cell_to_output_scale = default_scale;
  float effective_proj_scale = default_scale;

  // Zero points.
  int input_zp = 0;
  int output_state_zp = 0;

  // Populate all the values.
  if (!use_cifg) {
    input_to_input_weight_ptr = input_to_input_weights->data.int8;
    recurrent_to_input_weight_ptr = recurrent_to_input_weights->data.int8;
    input_gate_bias_ptr = input_gate_bias->data.i32;
    input_to_input_weight_scale = input_to_input_weights->params.scale;
    recurrent_to_input_weight_scale = recurrent_to_input_weights->params.scale;
  }

  if (use_peephole) {
    if (!use_cifg) {
      cell_to_input_weight_ptr = cell_to_input_weights->data.int8;
      cell_to_input_weight_scale = cell_to_input_weights->params.scale;
    }
    cell_to_forget_weight_ptr = cell_to_forget_weights->data.int8;
    cell_to_output_weight_ptr = cell_to_output_weights->data.int8;
    cell_to_forget_weight_scale = cell_to_forget_weights->params.scale;
    cell_to_output_weight_scale = cell_to_output_weights->params.scale;
  }

  if (is_layer_norm_lstm) {
    if (!use_cifg) {
      layer_norm_input_weight_ptr = input_layer_norm_coefficients->data.i16;
      layer_norm_input_scale = input_layer_norm_coefficients->params.scale;
    }
    layer_norm_forget_weight_ptr = forget_layer_norm_coefficients->data.i16;
    layer_norm_forget_scale = forget_layer_norm_coefficients->params.scale;
    layer_norm_cell_weight_ptr = cell_layer_norm_coefficients->data.i16;
    layer_norm_cell_scale = cell_layer_norm_coefficients->params.scale;
    layer_norm_output_weight_ptr = output_layer_norm_coefficients->data.i16;
    layer_norm_output_scale = output_layer_norm_coefficients->params.scale;
  }

  if (use_projection) {
    projection_weight_ptr = projection_weights->data.int8;
    projection_weight_scale = projection_weights->params.scale;
    if (projection_bias) {
      projection_bias_ptr = projection_bias->data.i32;
    }
  }
  output_state_scale = output_state->params.scale;

  input_to_forget_weight_ptr = input_to_forget_weights->data.int8;
  input_to_forget_weight_scale = input_to_forget_weights->params.scale;
  input_to_cell_weight_ptr = input_to_cell_weights->data.int8;
  input_to_cell_weight_scale = input_to_cell_weights->params.scale;
  input_to_output_weight_ptr = input_to_output_weights->data.int8;
  input_to_output_weight_scale = input_to_output_weights->params.scale;
  recurrent_to_forget_weight_ptr = recurrent_to_forget_weights->data.int8;
  recurrent_to_forget_weight_scale = recurrent_to_forget_weights->params.scale;
  recurrent_to_cell_weight_ptr = recurrent_to_cell_weights->data.int8;
  recurrent_to_cell_weight_scale = recurrent_to_cell_weights->params.scale;
  recurrent_to_output_weight_ptr = recurrent_to_output_weights->data.int8;
  recurrent_to_output_weight_scale = recurrent_to_output_weights->params.scale;
  forget_gate_bias_ptr = forget_gate_bias->data.i32;
  cell_gate_bias_ptr = cell_gate_bias->data.i32;
  output_gate_bias_ptr = output_gate_bias->data.i32;
  output_state_ptr = output_state->data.int8;
  cell_ptr = cell_state->data.i16;
  input_scale = input->params.scale;
  input_zp = input->params.zero_point;
  output_state_zp = output_state->params.zero_point;

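  // The 8x8_8 kernel carries 12 intermediate tensors, three per gate in the
  // order input, forget, cell, output. From the usage below, index 3*i holds
  // the combined Wx+Wh accumulator scale for gate i, 3*i+1 the Wx scale, and
  // 3*i+2 the Wh scale.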
  std::vector<float> intermediate_scale;
  for (int i = 0; i < 12; ++i) {
    TfLiteTensor* intermediate =
        &context->tensors[node->intermediates->data[i]];
    auto* params = reinterpret_cast<TfLiteAffineQuantization*>(
        intermediate->quantization.params);
    intermediate_scale.push_back(params->scale->data[0]);
    integer_lstm_param->intermediate_zp[i] = params->zero_point->data[0];
  }

  // Calculate effective scales.
  if (!use_cifg) {
    effective_input_to_input_scale =
        input_to_input_weight_scale * input_scale / intermediate_scale[1];
    effective_recurrent_to_input_scale = recurrent_to_input_weight_scale *
                                         output_state_scale /
                                         intermediate_scale[2];
  }
  effective_input_to_forget_scale =
      input_to_forget_weight_scale * input_scale / intermediate_scale[4];
  effective_recurrent_to_forget_scale = recurrent_to_forget_weight_scale *
                                        output_state_scale /
                                        intermediate_scale[5];

  effective_input_to_cell_scale =
      input_to_cell_weight_scale * input_scale / intermediate_scale[7];
  effective_recurrent_to_cell_scale = recurrent_to_cell_weight_scale *
                                      output_state_scale /
                                      intermediate_scale[8];

  effective_input_to_output_scale =
      input_to_output_weight_scale * input_scale / intermediate_scale[10];
  effective_recurrent_to_output_scale = recurrent_to_output_weight_scale *
                                        output_state_scale /
                                        intermediate_scale[11];
  effective_proj_scale =
      projection_weight_scale * std::pow(2, -15) / output_state_scale;

  if (use_peephole) {
    if (!use_cifg) {
      effective_cell_to_input_scale =
          std::pow(2, -15) * cell_to_input_weight_scale / intermediate_scale[0];
    }
    effective_cell_to_forget_scale =
        std::pow(2, -15) * cell_to_forget_weight_scale / intermediate_scale[3];
    effective_cell_to_output_scale =
        std::pow(2, -15) * cell_to_output_weight_scale / intermediate_scale[9];
  }

  // Decompose scales.
  QuantizeMultiplier(effective_input_to_input_scale,
                     &integer_lstm_param->effective_input_to_input_scale_a,
                     &integer_lstm_param->effective_input_to_input_scale_b);
  QuantizeMultiplier(effective_recurrent_to_input_scale,
                     &integer_lstm_param->effective_recurrent_to_input_scale_a,
                     &integer_lstm_param->effective_recurrent_to_input_scale_b);
  QuantizeMultiplier(effective_cell_to_input_scale,
                     &integer_lstm_param->effective_cell_to_input_scale_a,
                     &integer_lstm_param->effective_cell_to_input_scale_b);
  QuantizeMultiplier(effective_input_to_forget_scale,
                     &integer_lstm_param->effective_input_to_forget_scale_a,
                     &integer_lstm_param->effective_input_to_forget_scale_b);
  QuantizeMultiplier(
      effective_recurrent_to_forget_scale,
      &integer_lstm_param->effective_recurrent_to_forget_scale_a,
      &integer_lstm_param->effective_recurrent_to_forget_scale_b);
  QuantizeMultiplier(effective_cell_to_forget_scale,
                     &integer_lstm_param->effective_cell_to_forget_scale_a,
                     &integer_lstm_param->effective_cell_to_forget_scale_b);
  QuantizeMultiplier(effective_input_to_cell_scale,
                     &integer_lstm_param->effective_input_to_cell_scale_a,
                     &integer_lstm_param->effective_input_to_cell_scale_b);
  QuantizeMultiplier(effective_recurrent_to_cell_scale,
                     &integer_lstm_param->effective_recurrent_to_cell_scale_a,
                     &integer_lstm_param->effective_recurrent_to_cell_scale_b);
  QuantizeMultiplier(effective_input_to_output_scale,
                     &integer_lstm_param->effective_input_to_output_scale_a,
                     &integer_lstm_param->effective_input_to_output_scale_b);
  QuantizeMultiplier(
      effective_recurrent_to_output_scale,
      &integer_lstm_param->effective_recurrent_to_output_scale_a,
      &integer_lstm_param->effective_recurrent_to_output_scale_b);
  QuantizeMultiplier(effective_cell_to_output_scale,
                     &integer_lstm_param->effective_cell_to_output_scale_a,
                     &integer_lstm_param->effective_cell_to_output_scale_b);
  QuantizeMultiplier(effective_proj_scale,
                     &integer_lstm_param->effective_proj_scale_a,
                     &integer_lstm_param->effective_proj_scale_b);
  QuantizeMultiplier(layer_norm_input_scale,
                     &integer_lstm_param->layer_norm_input_scale_a,
                     &integer_lstm_param->layer_norm_input_scale_b);
  QuantizeMultiplier(layer_norm_forget_scale,
                     &integer_lstm_param->layer_norm_forget_scale_a,
                     &integer_lstm_param->layer_norm_forget_scale_b);
  QuantizeMultiplier(layer_norm_cell_scale,
                     &integer_lstm_param->layer_norm_cell_scale_a,
                     &integer_lstm_param->layer_norm_cell_scale_b);
  QuantizeMultiplier(layer_norm_output_scale,
                     &integer_lstm_param->layer_norm_output_scale_a,
                     &integer_lstm_param->layer_norm_output_scale_b);

  {
    // The intermediates in the flatbuffer hold Wx, Wh and Wx+Wh. The
    // effective Wx and Wh scales already live in
    // effective_input/recurrent_to_<gate>_scale, so use intermediate_scale to
    // hold the rescaling from the Wx and Wh scales to the Wx+Wh scale:
    // 0: [1] -> [0]
    // 1: [2] -> [0]
    // and so on; the intermediate zero points are used as is.
    const float s_1_0 = intermediate_scale[1] / intermediate_scale[0];
    const float s_2_0 = intermediate_scale[2] / intermediate_scale[0];
    const float s_4_3 = intermediate_scale[4] / intermediate_scale[3];
    const float s_5_3 = intermediate_scale[5] / intermediate_scale[3];
    const float s_7_6 = intermediate_scale[7] / intermediate_scale[6];
    const float s_8_6 = intermediate_scale[8] / intermediate_scale[6];
    const float s_10_9 = intermediate_scale[10] / intermediate_scale[9];
    const float s_11_9 = intermediate_scale[11] / intermediate_scale[9];
    QuantizeMultiplier(s_1_0, &integer_lstm_param->intermediate_scale_a[0],
                       &integer_lstm_param->intermediate_scale_b[0]);
    QuantizeMultiplier(s_2_0, &integer_lstm_param->intermediate_scale_a[1],
                       &integer_lstm_param->intermediate_scale_b[1]);
    QuantizeMultiplier(s_4_3, &integer_lstm_param->intermediate_scale_a[2],
                       &integer_lstm_param->intermediate_scale_b[2]);
    QuantizeMultiplier(s_5_3, &integer_lstm_param->intermediate_scale_a[3],
                       &integer_lstm_param->intermediate_scale_b[3]);
    QuantizeMultiplier(s_7_6, &integer_lstm_param->intermediate_scale_a[4],
                       &integer_lstm_param->intermediate_scale_b[4]);
    QuantizeMultiplier(s_8_6, &integer_lstm_param->intermediate_scale_a[5],
                       &integer_lstm_param->intermediate_scale_b[5]);
    QuantizeMultiplier(s_10_9, &integer_lstm_param->intermediate_scale_a[6],
                       &integer_lstm_param->intermediate_scale_b[6]);
    QuantizeMultiplier(s_11_9, &integer_lstm_param->intermediate_scale_a[7],
                       &integer_lstm_param->intermediate_scale_b[7]);
  }

  // Calculate quantized clip for projection and cell.
  const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
  const float cell_clip = params->cell_clip;
  const float proj_clip = params->proj_clip;

  TfLiteTensor* output_tensor;
  TF_LITE_ENSURE_OK(
      context, GetOutputSafe(context, node, kOutputTensor, &output_tensor));

  auto* cell_state_params = reinterpret_cast<TfLiteAffineQuantization*>(
      cell_state->quantization.params);
  auto* proj_params = reinterpret_cast<TfLiteAffineQuantization*>(
      output_tensor->quantization.params);
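  // The 8x8_8 kernel requires the cell state in Q0.15, i.e. a scale of
  // exactly 1/32768 = 2^-15.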
  TF_LITE_ENSURE_EQ(context, cell_state_params->scale->data[0], 1.0 / 32768);
  if (cell_clip > 0.0 && cell_clip < 1.0) {
    integer_lstm_param->quantized_cell_clip = static_cast<int16_t>(std::min(
        std::max(cell_clip / cell_state_params->scale->data[0], -32768.0f),
        32767.0f));
  } else {
    integer_lstm_param->quantized_cell_clip = 0;
  }
  if (proj_clip > 0.0) {
    integer_lstm_param->quantized_proj_clip = static_cast<int8_t>(std::min(
        std::max(proj_clip / proj_params->scale->data[0], -128.0f), 127.0f));
  } else {
    integer_lstm_param->quantized_proj_clip = 0;
  }
  return kTfLiteOk;
}

}  // namespace

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  op_data->kernel_type = kTfLiteLSTMFullKernel;
  // TODO(b/159066113): maybe just add the minimum required temp tensors?
  context->AddTensors(context, kNumHybridTemporaryTensors,
                      &op_data->scratch_tensor_index);
  // Tensors used for the sparse hybrid kernel.
  context->AddTensors(context, /*tensors_to_add=*/kLedgersToAdd,
                      &op_data->ledger_index);
  return op_data;
}

// LINT.IfChange
// Check that the input tensor dimensions match each other.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                                        TfLiteNode* node, int n_input,
                                        int n_output, int n_cell,
                                        bool use_layer_norm, bool is_integer) {
  const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data);

  // Make sure the clipping parameters have valid values:
  // == 0 means no clipping,
  //  > 0 means clipping.
  TF_LITE_ENSURE(context, params->cell_clip >= 0);
  TF_LITE_ENSURE(context, params->proj_clip >= 0);

  const TfLiteTensor* input_to_forget_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputToForgetWeightsTensor,
                                 &input_to_forget_weights));
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
  TF_LITE_ENSURE(context, (input_to_forget_weights->type == kTfLiteFloat32) ||
                              (input_to_forget_weights->type == kTfLiteUInt8) ||
                              (input_to_forget_weights->type == kTfLiteInt8));

  const TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  const bool use_cifg = (input_to_input_weights == nullptr);
  if (!use_cifg) {
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
    TF_LITE_ENSURE_TYPES_EQ(context, input_to_input_weights->type,
                            input_to_forget_weights->type);
  }

  const TfLiteTensor* input_to_cell_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputToCellWeightsTensor,
                                 &input_to_cell_weights));
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
  TF_LITE_ENSURE_TYPES_EQ(context, input_to_cell_weights->type,
                          input_to_forget_weights->type);

  const TfLiteTensor* recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
  if (recurrent_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
                      n_cell);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
                      n_output);
    TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_input_weights->type,
                            input_to_forget_weights->type);
  }

  const TfLiteTensor* recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kRecurrentToForgetWeightsTensor,
                                 &recurrent_to_forget_weights));
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
                    n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
                    n_output);
  TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_forget_weights->type,
                          input_to_forget_weights->type);

  const TfLiteTensor* recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kRecurrentToCellWeightsTensor,
                                 &recurrent_to_cell_weights));
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
                    n_output);
  TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_cell_weights->type,
                          input_to_forget_weights->type);

  // We make sure the input gate's parameters are either both present (regular
  // LSTM) or both absent (CIFG-LSTM).
  const bool cifg_weights_all_or_none =
      ((input_to_input_weights != nullptr) &&
       (recurrent_to_input_weights != nullptr)) ||
      ((input_to_input_weights == nullptr) &&
       (recurrent_to_input_weights == nullptr));
  TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);

  const TfLiteTensor* cell_to_input_weights =
      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
  if (cell_to_input_weights) {
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(
        context, cell_to_input_weights->type,
        is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
  }

  const TfLiteTensor* cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
  if (cell_to_forget_weights) {
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(
        context, cell_to_forget_weights->type,
        is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
  }

  const TfLiteTensor* cell_to_output_weights =
      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
  if (cell_to_output_weights) {
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(
        context, cell_to_output_weights->type,
        is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
  }

  // Make sure the peephole weights are either all present or all absent.
972   const bool peephole_weights_all_or_none =
973       ((cell_to_input_weights != nullptr || use_cifg) &&
974        (cell_to_forget_weights != nullptr) &&
975        (cell_to_output_weights != nullptr)) ||
976       ((cell_to_input_weights == nullptr) &&
977        (cell_to_forget_weights == nullptr) &&
978        (cell_to_output_weights == nullptr));
979   TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);
980 
981   // Make sure the input gate bias is present only when not a CIFG-LSTM.
982   const TfLiteTensor* input_gate_bias =
983       GetOptionalInputTensor(context, node, kInputGateBiasTensor);
984   if (use_cifg) {
985     TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
986   } else {
987     TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
988     TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
989     if (is_integer) {
990       TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteInt32);
991     } else {
992       TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteFloat32);
993     }
994   }
995 
996   const TfLiteTensor* forget_gate_bias;
997   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kForgetGateBiasTensor,
998                                           &forget_gate_bias));
999   TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
1000   TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
1001   if (is_integer) {
1002     TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteInt32);
1003   } else {
1004     TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32);
1005   }
1006 
1007   const TfLiteTensor* cell_gate_bias;
1008   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCellGateBiasTensor,
1009                                           &cell_gate_bias));
1010   TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
1011   TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);
1012   if (is_integer) {
1013     TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteInt32);
1014   } else {
1015     TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteFloat32);
1016   }
1017 
1018   const TfLiteTensor* output_gate_bias;
1019   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kOutputGateBiasTensor,
1020                                           &output_gate_bias));
1021   TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
1022   TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
1023   if (is_integer) {
1024     TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteInt32);
1025   } else {
1026     TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteFloat32);
1027   }
1028 
1029   const TfLiteTensor* projection_weights =
1030       GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
1031   if (projection_weights != nullptr) {
1032     TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
1033     TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
1034     TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
1035     TF_LITE_ENSURE_TYPES_EQ(context, projection_weights->type,
1036                             input_to_forget_weights->type);
1037   }
1038 
1039   const TfLiteTensor* projection_bias =
1040       GetOptionalInputTensor(context, node, kProjectionBiasTensor);
1041   if (projection_bias != nullptr) {
1042     TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
1043     TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
1044     if (is_integer) {
1045       TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteInt32);
1046     } else {
1047       TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteFloat32);
1048     }
1049   }
1050 
1051   // Making sure the projection tensors are consistent:
1052   // 1) If projection weight is not present, then projection bias should not be
1053   // present.
1054   // 2) If projection weight is present, then projection bias is optional.
1055   // TODO(ghodrat): make sure this is correct.
1056   const bool projection_tensors_consistent =
1057       ((projection_weights != nullptr) || (projection_bias == nullptr));
1058   TF_LITE_ENSURE(context, projection_tensors_consistent == true);
1059 
1060   if (use_layer_norm) {
1061     const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
1062         context, node, kInputLayerNormCoefficientsTensor);
1063     if (use_cifg) {
1064       TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr);
1065     } else {
1066       TF_LITE_ENSURE(context, input_layer_norm_coefficients != nullptr);
1067       TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1);
1068       TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->data[0],
1069                         n_cell);
1070       if (is_integer) {
1071         TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type,
1072                                 kTfLiteInt16);
1073       } else {
1074         TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type,
1075                                 kTfLiteFloat32);
1076       }
1077     }
1078 
1079     const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
1080         context, node, kForgetLayerNormCoefficientsTensor);
1081     TF_LITE_ENSURE(context, forget_layer_norm_coefficients != nullptr);
1082     TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1);
1083     TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0],
1084                       n_cell);
1085     if (is_integer) {
1086       TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type,
1087                               kTfLiteInt16);
1088     } else {
1089       TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type,
1090                               kTfLiteFloat32);
1091     }
1092 
1093     const TfLiteTensor* cell_layer_norm_coefficients =
1094         GetOptionalInputTensor(context, node, kCellLayerNormCoefficientsTensor);
1095     TF_LITE_ENSURE(context, cell_layer_norm_coefficients != nullptr);
1096     TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1);
1097     TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0],
1098                       n_cell);
1099     if (is_integer) {
1100       TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type,
1101                               kTfLiteInt16);
1102     } else {
1103       TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type,
1104                               kTfLiteFloat32);
1105     }
1106 
1107     const TfLiteTensor* output_layer_norm_coefficients = GetOptionalInputTensor(
1108         context, node, kOutputLayerNormCoefficientsTensor);
1109     TF_LITE_ENSURE(context, output_layer_norm_coefficients != nullptr);
1110     TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1);
1111     TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0],
1112                       n_cell);
1113     if (is_integer) {
1114       TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type,
1115                               kTfLiteInt16);
1116     } else {
1117       TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type,
1118                               kTfLiteFloat32);
1119     }
1120   }
1121 
1122   return kTfLiteOk;
1123 }
1124 // LINT.ThenChange(//tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc)
1125 
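// Folds a tensor zero point into an effective bias so that the quantized
// matmul can consume raw quantized values: for each row i this computes
//   output[i] = bias[i] + zero_point * sum_j weight[i][j].
// Callers pass the negated tensor zero point, making the result
//   bias[i] - zp * row_sum(weight[i]),
// which is the constant term of sum_j (x_j - zp) * weight[i][j] + bias[i].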
1126 TfLiteStatus PrecomputeZeroPointTimesWeightWithBias(
1127     TfLiteContext* context, int32_t zero_point,
1128     const TfLiteTensor* weight_tensor, const TfLiteTensor* bias_tensor,
1129     std::unique_ptr<int32_t[]>* output) {
1130   if (weight_tensor == nullptr) {
1131     return kTfLiteOk;
1132   }
1133 
1134   const RuntimeShape& weight_shape = GetTensorShape(weight_tensor);
1135   TF_LITE_ENSURE_EQ(context, weight_shape.DimensionsCount(), 2);
1136   const int row = weight_shape.Dims(0);
1137   const int col = weight_shape.Dims(1);
1138   output->reset(new int32_t[row]);
1139   if (bias_tensor == nullptr) {
1140     memset(output->get(), 0, row * sizeof(int32_t));
1141   } else {
1142     const int32_t* bias = GetTensorData<int32_t>(bias_tensor);
1143     memcpy(output->get(), bias, row * sizeof(int32_t));
1144   }
1145   if (zero_point != 0) {
1146     const int8_t* weight = GetTensorData<int8_t>(weight_tensor);
1147     tensor_utils::MatrixScalarMultiplyAccumulate(weight, zero_point, row, col,
1148                                                  output->get());
1149   }
1150   return kTfLiteOk;
1151 }
1152 
1153 TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
1154                                                        OpData* op_data,
1155                                                        TfLiteNode* node) {
1156   const TfLiteTensor* input;
1157   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
1158   const TfLiteTensor* output_state =
1159       GetVariableInput(context, node, kOutputStateTensor);
1160   TF_LITE_ENSURE(context, output_state != nullptr);
1161 
1162   const int32_t input_zero_point = -input->params.zero_point;
1163   const int32_t output_state_zero_point = -output_state->params.zero_point;
1164 
1165   const TfLiteTensor* input_to_input_weights =
1166       GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
1167   const TfLiteTensor* input_to_forget_weights;
1168   TF_LITE_ENSURE_OK(context,
1169                     GetInputSafe(context, node, kInputToForgetWeightsTensor,
1170                                  &input_to_forget_weights));
1171   const TfLiteTensor* input_to_cell_weights;
1172   TF_LITE_ENSURE_OK(context,
1173                     GetInputSafe(context, node, kInputToCellWeightsTensor,
1174                                  &input_to_cell_weights));
1175   const TfLiteTensor* input_to_output_weights;
1176   TF_LITE_ENSURE_OK(context,
1177                     GetInputSafe(context, node, kInputToOutputWeightsTensor,
1178                                  &input_to_output_weights));
1179 
1180   const TfLiteTensor* recurrent_to_input_weights =
1181       GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
1182   const TfLiteTensor* recurrent_to_forget_weights;
1183   TF_LITE_ENSURE_OK(context,
1184                     GetInputSafe(context, node, kRecurrentToForgetWeightsTensor,
1185                                  &recurrent_to_forget_weights));
1186   const TfLiteTensor* recurrent_to_cell_weights;
1187   TF_LITE_ENSURE_OK(context,
1188                     GetInputSafe(context, node, kRecurrentToCellWeightsTensor,
1189                                  &recurrent_to_cell_weights));
1190   const TfLiteTensor* recurrent_to_output_weights;
1191   TF_LITE_ENSURE_OK(context,
1192                     GetInputSafe(context, node, kRecurrentToOutputWeightsTensor,
1193                                  &recurrent_to_output_weights));
1194 
1195   const TfLiteTensor* projection_weights =
1196       GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
1197   const TfLiteTensor* projection_bias =
1198       GetOptionalInputTensor(context, node, kProjectionBiasTensor);
1199 
1200   lstm_eval::IntegerLstmParameter* integer_lstm_params =
1201       &op_data->integer_lstm_param;
1202 
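  // Intermediate tensor 4 carries the quantization parameters of the hidden
  // (projection-input) activation; its zero point is folded into the
  // projection's effective bias below.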
1203   const TfLiteTensor* intermediate =
1204       &context->tensors[node->intermediates->data[4]];
1205   const auto* params =
1206       static_cast<TfLiteAffineQuantization*>(intermediate->quantization.params);
1207   const int32_t hidden_zp = params->zero_point->data[0];
1208 
1209   // Get bias and perform zero point calculation.
1210   // When there is layer normalization, the gate bias is not applied to the
1211   // matmul directly:
1212   //      y = ln(w * x + w * r + w * c) + b.
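  // In that case the bias is added after normalization, so it must not be
  // folded into the precomputed effective bias; nullptr is passed as the
  // bias below and only the zero-point term is folded in.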
1213   const bool is_layer_norm = op_data->use_layer_norm;
1214 
1215   // Forget gate.
1216   const TfLiteTensor* forget_gate_bias =
1217       is_layer_norm ? nullptr : GetInput(context, node, kForgetGateBiasTensor);
1218   TF_LITE_ENSURE_OK(
1219       context,
1220       PrecomputeZeroPointTimesWeightWithBias(
1221           context, input_zero_point, input_to_forget_weights, forget_gate_bias,
1222           &(integer_lstm_params->input_to_forget_effective_bias)));
1223 
1224   TF_LITE_ENSURE_OK(
1225       context,
1226       PrecomputeZeroPointTimesWeightWithBias(
1227           context, output_state_zero_point, recurrent_to_forget_weights,
1228           nullptr, &(integer_lstm_params->recurrent_to_forget_effective_bias)));
1229 
1230   // Modulation gate.
1231   const TfLiteTensor* cell_gate_bias =
1232       is_layer_norm ? nullptr : GetInput(context, node, kCellGateBiasTensor);
1233   TF_LITE_ENSURE_OK(
1234       context,
1235       PrecomputeZeroPointTimesWeightWithBias(
1236           context, input_zero_point, input_to_cell_weights, cell_gate_bias,
1237           &(integer_lstm_params->input_to_cell_effective_bias)));
1238   TF_LITE_ENSURE_OK(
1239       context,
1240       PrecomputeZeroPointTimesWeightWithBias(
1241           context, output_state_zero_point, recurrent_to_cell_weights, nullptr,
1242           &(integer_lstm_params->recurrent_to_cell_effective_bias)));
1243 
1244   // Output gate.
1245   const TfLiteTensor* output_gate_bias =
1246       is_layer_norm ? nullptr : GetInput(context, node, kOutputGateBiasTensor);
1247   TF_LITE_ENSURE_OK(
1248       context,
1249       PrecomputeZeroPointTimesWeightWithBias(
1250           context, input_zero_point, input_to_output_weights, output_gate_bias,
1251           &(integer_lstm_params->input_to_output_effective_bias)));
1252 
1253   TF_LITE_ENSURE_OK(
1254       context,
1255       PrecomputeZeroPointTimesWeightWithBias(
1256           context, output_state_zero_point, recurrent_to_output_weights,
1257           nullptr, &(integer_lstm_params->recurrent_to_output_effective_bias)));
1258 
1259   // Input gate. The calculation is only meaningful for the non-CIFG case.
1260   const TfLiteTensor* input_gate_bias =
1261       is_layer_norm ? nullptr : GetInput(context, node, kInputGateBiasTensor);
1262   TF_LITE_ENSURE_OK(
1263       context,
1264       PrecomputeZeroPointTimesWeightWithBias(
1265           context, input_zero_point, input_to_input_weights, input_gate_bias,
1266           &(integer_lstm_params->input_to_input_effective_bias)));
1267   TF_LITE_ENSURE_OK(
1268       context,
1269       PrecomputeZeroPointTimesWeightWithBias(
1270           context, output_state_zero_point, recurrent_to_input_weights, nullptr,
1271           &(integer_lstm_params->recurrent_to_input_effective_bias)));
1272 
1273   // Projection bias. The calculation is only meaningful when projection is used.
1274   TF_LITE_ENSURE_OK(context,
1275                     PrecomputeZeroPointTimesWeightWithBias(
1276                         context, hidden_zp, projection_weights, projection_bias,
1277                         &(integer_lstm_params->projection_effective_bias)));
1278   return kTfLiteOk;
1279 }
1280 
1281 // Resize the output and state tensors based on the sizes of the input tensors.
1282 // Allocate a temporary scratch tensor. Also check that the sizes of the input
1283 // tensors match each other.
1284 // LINT.IfChange
1285 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
1286   OpData* op_data = static_cast<OpData*>(node->user_data);
1287 
1288   TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
1289   // Logic for distinguishing a regular LSTM from a layer-norm LSTM:
1290   // input_size, forget_gate_layer_norm_tensor (20) null? is_layer_norm?
1291   // 20,         N/A,                                     No.
1292   // 24,         null,                                    No.
1293   // 24,         not null,                                Yes.
1294   // The 20-input LSTM is deprecated and is only kept here for backward
1295   // compatibility.
1296   if (node->inputs->size == 24) {
1297     const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
1298         context, node, kForgetLayerNormCoefficientsTensor);
1299     if (forget_layer_norm_coefficients == nullptr) {
1300       op_data->use_layer_norm = false;
1301     } else {
1302       op_data->use_layer_norm = true;
1303     }
1304   } else if (node->inputs->size == 20) {
1305     // This is deprecated and is only kept here for backward compatibility.
1306     op_data->use_layer_norm = false;
1307   } else {
1308     TF_LITE_KERNEL_LOG(
1309         context, "The LSTM Full kernel expects 20 or 24 inputs. Got %d inputs",
1310         node->inputs->size);
1311     return kTfLiteError;
1312   }
1313 
1314   const bool use_layer_norm = op_data->use_layer_norm;
1315 
1316   // Inferring batch size, number of outputs and number of cells from the
1317   // input tensors.
1318   const TfLiteTensor* input;
1319   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
1320   const bool is_integer = input->type == kTfLiteInt8;
1321   TF_LITE_ENSURE(context, input->dims->size > 1);
1322   const int n_batch = input->dims->data[0];
1323   const int n_input = input->dims->data[1];
1324 
1325   const TfLiteTensor* input_to_output_weights;
1326   TF_LITE_ENSURE_OK(context,
1327                     GetInputSafe(context, node, kInputToOutputWeightsTensor,
1328                                  &input_to_output_weights));
1329   const int n_cell = input_to_output_weights->dims->data[0];
1330   TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
1331   TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
1332 
1333   const TfLiteTensor* recurrent_to_output_weights;
1334   TF_LITE_ENSURE_OK(context,
1335                     GetInputSafe(context, node, kRecurrentToOutputWeightsTensor,
1336                                  &recurrent_to_output_weights));
1337   TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
1338   TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
1339                     n_cell);
1340   const int n_output = recurrent_to_output_weights->dims->data[1];
1341 
1342   // Check that the input tensor dimensions match each other.
1343   TF_LITE_ENSURE_OK(
1344       context, CheckInputTensorDimensions(context, node, n_input, n_output,
1345                                           n_cell, use_layer_norm, is_integer));
1346 
1347   // Get pointers to the output, output_state and cell_state tensors.
1348   TfLiteTensor* output;
1349   TF_LITE_ENSURE_OK(context,
1350                     GetOutputSafe(context, node, kOutputTensor, &output));
1351 
1352   TfLiteTensor* output_state =
1353       GetVariableInput(context, node, kOutputStateTensor);
1354   TF_LITE_ENSURE(context, output_state != nullptr);
1355   TfLiteTensor* cell_state = GetVariableInput(context, node, kCellStateTensor);
1356   TF_LITE_ENSURE(context, cell_state != nullptr);
1357 
1358   // Check the shapes of the input state tensors.
1359   // These tensors may be 1D or 2D. That is fine as long as the total size is
1360   // correct.
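  // For example, with n_batch = 2 and n_output = 3, an output_state of shape
  // {6} and one of shape {2, 3} both satisfy the check below.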
1361   TF_LITE_ENSURE_EQ(context, NumElements(output_state), n_batch * n_output);
1362   TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
1363 
1364   // Resize the output tensors.
1365   TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
1366   output_size->data[0] = n_batch;
1367   output_size->data[1] = n_output;
1368   TF_LITE_ENSURE_OK(context,
1369                     context->ResizeTensor(context, output, output_size));
1370 
1371   // The weights are of consistent type, so it suffices to check one.
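  // A hybrid op keeps float32 activations with quantized (uint8/int8)
  // weights; IsHybridOp checks for exactly this combination of types.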
1372   const bool is_hybrid_op = IsHybridOp(input, input_to_output_weights);
1373 
1374   const bool is_sparse_op = (input_to_output_weights->sparsity != nullptr);
1375 
1376   // Determine the integer LSTM variant from the number of intermediate tensors.
1377   const int num_intermediate_tensors = node->intermediates->size;
1378   if (is_integer) {
1379     TF_LITE_ENSURE(context, num_intermediate_tensors == 5 ||
1380                                 num_intermediate_tensors == 12);
1381   }
1382   // We use the number of intermediate tensors to distinguish between the
1383   // 8-bit matmul output and the 16-bit matmul output versions.
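  // Five intermediates correspond to the 8x8->16 kernel (six scratch
  // temporaries below); twelve correspond to the 8x8->8 kernel (eight).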
1384   const bool is_8x8_16 = num_intermediate_tensors == 5;
1385 
1386   TfLiteIntArrayFree(node->temporaries);
1387   if (is_hybrid_op) {
1388     if (is_sparse_op) {
1389       node->temporaries =
1390           TfLiteIntArrayCreate(kNumHybridTemporaryTensors + kLedgersToAdd);
1391     } else {
1392       node->temporaries = TfLiteIntArrayCreate(kNumHybridTemporaryTensors);
1393     }
1394   } else if (is_integer) {
1395     if (is_8x8_16) {
1396       node->temporaries = TfLiteIntArrayCreate(6);
1397     } else {
1398       node->temporaries = TfLiteIntArrayCreate(8);
1399     }
1400   } else {
1401     node->temporaries = TfLiteIntArrayCreate(1);
1402   }
1403 
1404   // Create a scratch buffer tensor for the float and hybrid cases.
1405   // TODO(b/152066492): Create a is_float boolean and reorganize the temporary
1406   // buffer allocation logic.
1407   if (!is_integer) {
1408     node->temporaries->data[kScratchBuffer] =
1409         op_data->scratch_tensor_index + kScratchBuffer;
1410     TfLiteTensor* scratch_buffer;
1411     TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer,
1412                                                 &scratch_buffer));
1413     scratch_buffer->type = input->type;
1414     scratch_buffer->allocation_type = kTfLiteArenaRw;
1415 
1416     const TfLiteTensor* input_to_input_weights =
1417         GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
1418     const bool use_cifg = (input_to_input_weights == nullptr);
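    // The scratch is sliced into n_cell-wide blocks, one per gate plus an
    // accumulation block; e.g. with n_batch = 2 and n_cell = 8 it is resized
    // to {2, 32} under CIFG and to {2, 40} otherwise.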
1419     TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
1420     scratch_buffer_size->data[0] = n_batch;
1421     if (use_cifg) {
1422       // Reserve space for the Cell, Forget and Output gates, a scratch
1423       // accumulation buffer, and an extra 16 bytes to avoid internal ruy copies.
1424       scratch_buffer_size->data[1] = n_cell * 4;
1425     } else {
1426       // Reserve space for the Input, Cell, Forget and Output gates, a scratch
1427       // accumulation buffer, and an extra 16 bytes to avoid internal ruy copies.
1428       scratch_buffer_size->data[1] = n_cell * 5;
1429     }
1430     TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
1431                                                      scratch_buffer_size));
1432   }
1433 
1434   if (is_hybrid_op) {
1435     if (!is_sparse_op) {
1436       op_data->compute_row_sums = true;
1437     }
1438     // Allocate temporary tensors to store quantized values of input,
1439     // output_state and cell_state tensors.
1440     node->temporaries->data[kInputQuantized] =
1441         op_data->scratch_tensor_index + kInputQuantized;
1442     TfLiteTensor* input_quantized;
1443     TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
1444                                                 &input_quantized));
1445     input_quantized->type = input_to_output_weights->type;
1446     input_quantized->allocation_type = kTfLiteArenaRw;
1447     if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
1448       TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
1449       TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
1450                                                        input_quantized_size));
1451     }
1452     node->temporaries->data[kOutputStateQuantized] =
1453         op_data->scratch_tensor_index + kOutputStateQuantized;
1454     TfLiteTensor* output_state_quantized;
1455     TF_LITE_ENSURE_OK(context,
1456                       GetTemporarySafe(context, node, kOutputStateQuantized,
1457                                        &output_state_quantized));
1458     output_state_quantized->type = input_to_output_weights->type;
1459     output_state_quantized->allocation_type = kTfLiteArenaRw;
1460     if (!TfLiteIntArrayEqual(output_state_quantized->dims,
1461                              output_state->dims)) {
1462       TfLiteIntArray* output_state_quantized_size =
1463           TfLiteIntArrayCopy(output_state->dims);
1464       TF_LITE_ENSURE_OK(context,
1465                         context->ResizeTensor(context, output_state_quantized,
1466                                               output_state_quantized_size));
1467     }
1468     node->temporaries->data[kCellStateQuantized] =
1469         op_data->scratch_tensor_index + kCellStateQuantized;
1470     TfLiteTensor* cell_state_quantized;
1471     TF_LITE_ENSURE_OK(context,
1472                       GetTemporarySafe(context, node, kCellStateQuantized,
1473                                        &cell_state_quantized));
1474     cell_state_quantized->type = input_to_output_weights->type;
1475     cell_state_quantized->allocation_type = kTfLiteArenaRw;
1476     if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
1477       TfLiteIntArray* cell_state_quantized_size =
1478           TfLiteIntArrayCopy(cell_state->dims);
1479       TF_LITE_ENSURE_OK(context,
1480                         context->ResizeTensor(context, cell_state_quantized,
1481                                               cell_state_quantized_size));
1482     }
1483     // Allocate temporary tensors to store scaling factors and product scaling
1484     // factors. The latter is convenience storage that allows quantizing
1485     // a vector once (producing the scaling factors) and multiplying it with
1486     // different matrices (which requires multiplying the scaling factors by
1487     // the scaling factor of each matrix).
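    // For example, a vector quantized with scale s_x and multiplied by
    // matrices with scales s_w1 and s_w2 needs product scales s_x * s_w1 and
    // s_x * s_w2; this buffer holds those products so the vector only needs
    // to be quantized once.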
1488     node->temporaries->data[kInputScalingFactors] =
1489         op_data->scratch_tensor_index + kInputScalingFactors;
1490     TfLiteTensor* input_sf;
1491     TF_LITE_ENSURE_OK(
1492         context,
1493         GetTemporarySafe(context, node, kInputScalingFactors, &input_sf));
1494     input_sf->type = kTfLiteFloat32;
1495     input_sf->allocation_type = kTfLiteArenaRw;
1496     int scaling_dims[1] = {n_batch};
1497     if (!TfLiteIntArrayEqualsArray(input_sf->dims, 1, scaling_dims)) {
1498       TfLiteIntArray* input_sf_size = TfLiteIntArrayCreate(1);
1499       input_sf_size->data[0] = n_batch;
1500       TF_LITE_ENSURE_OK(
1501           context, context->ResizeTensor(context, input_sf, input_sf_size));
1502     }
1503     node->temporaries->data[kOutputStateScalingFactors] =
1504         op_data->scratch_tensor_index + kOutputStateScalingFactors;
1505     TfLiteTensor* output_state_sf;
1506     TF_LITE_ENSURE_OK(
1507         context, GetTemporarySafe(context, node, kOutputStateScalingFactors,
1508                                   &output_state_sf));
1509     output_state_sf->type = kTfLiteFloat32;
1510     output_state_sf->allocation_type = kTfLiteArenaRw;
1511     if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) {
1512       TfLiteIntArray* output_state_sf_size = TfLiteIntArrayCreate(1);
1513       output_state_sf_size->data[0] = n_batch;
1514       TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_sf,
1515                                                        output_state_sf_size));
1516     }
1517     node->temporaries->data[kProductScalingFactors] =
1518         op_data->scratch_tensor_index + kProductScalingFactors;
1519     TfLiteTensor* prod_scaling_factors;
1520     TF_LITE_ENSURE_OK(context,
1521                       GetTemporarySafe(context, node, kProductScalingFactors,
1522                                        &prod_scaling_factors));
1523     prod_scaling_factors->type = kTfLiteFloat32;
1524     prod_scaling_factors->allocation_type = kTfLiteArenaRw;
1525     if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
1526                                    scaling_dims)) {
1527       TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
1528       prod_scaling_factors_size->data[0] = n_batch;
1529       TF_LITE_ENSURE_OK(context,
1530                         context->ResizeTensor(context, prod_scaling_factors,
1531                                               prod_scaling_factors_size));
1532     }
1533 
1534     // Allocate a temporary tensor to store the recovered cell weights. Since
1535     // this is used for diagonal matrices, only n_cell values need to be stored.
1536     node->temporaries->data[kRecoveredCellWeights] =
1537         op_data->scratch_tensor_index + kRecoveredCellWeights;
1538     TfLiteTensor* recovered_cell_weights;
1539     TF_LITE_ENSURE_OK(context,
1540                       GetTemporarySafe(context, node, kRecoveredCellWeights,
1541                                        &recovered_cell_weights));
1542     recovered_cell_weights->type = kTfLiteFloat32;
1543     recovered_cell_weights->allocation_type = kTfLiteArenaRw;
1544     int recovered_cell_dims[1] = {n_cell};
1545     if (!TfLiteIntArrayEqualsArray(recovered_cell_weights->dims, 1,
1546                                    recovered_cell_dims)) {
1547       TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
1548       recovered_cell_weights_size->data[0] = n_cell;
1549       TF_LITE_ENSURE_OK(context,
1550                         context->ResizeTensor(context, recovered_cell_weights,
1551                                               recovered_cell_weights_size));
1552     }
1553     // Allocate a temporary tensor to store accumulated values from the matrix
1554     // multiplications before they are scaled by the scaling factors.
1555     node->temporaries->data[kAccumScratch] =
1556         op_data->scratch_tensor_index + kAccumScratch;
1557     TfLiteTensor* accum_scratch;
1558     TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
1559                                                 &accum_scratch));
1560     accum_scratch->type = kTfLiteInt32;
1561     accum_scratch->allocation_type = kTfLiteArenaRw;
1562     int accum_scratch_dims[2] = {n_cell, n_batch};
1563     if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
1564                                    accum_scratch_dims)) {
1565       TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2);
1566       accum_size->data[0] = n_cell;
1567       accum_size->data[1] = n_batch;
1568       TF_LITE_ENSURE_OK(
1569           context, context->ResizeTensor(context, accum_scratch, accum_size));
1570     }
1571     node->temporaries->data[kInputZeroPoints] =
1572         op_data->scratch_tensor_index + kInputZeroPoints;
1573     TfLiteTensor* input_zp;
1574     TF_LITE_ENSURE_OK(
1575         context, GetTemporarySafe(context, node, kInputZeroPoints, &input_zp));
1576     input_zp->type = kTfLiteFloat32;
1577     input_zp->allocation_type = kTfLiteArenaRw;
1578     if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) {
1579       TfLiteIntArray* input_zp_size = TfLiteIntArrayCreate(1);
1580       input_zp_size->data[0] = n_batch;
1581       TF_LITE_ENSURE_OK(
1582           context, context->ResizeTensor(context, input_zp, input_zp_size));
1583     }
1584     node->temporaries->data[kOutputStateZeroPoints] =
1585         op_data->scratch_tensor_index + kOutputStateZeroPoints;
1586     TfLiteTensor* output_state_zp;
1587     TF_LITE_ENSURE_OK(context,
1588                       GetTemporarySafe(context, node, kOutputStateZeroPoints,
1589                                        &output_state_zp));
1590     output_state_zp->type = kTfLiteFloat32;
1591     output_state_zp->allocation_type = kTfLiteArenaRw;
1592     if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) {
1593       TfLiteIntArray* output_state_zp_size = TfLiteIntArrayCreate(1);
1594       output_state_zp_size->data[0] = n_batch;
1595       TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_zp,
1596                                                        output_state_zp_size));
1597     }
1598 
1599     node->temporaries->data[kRowSums] =
1600         op_data->scratch_tensor_index + kRowSums;
1601     const TfLiteTensor* input_to_input_weights =
1602         GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
1603     const bool use_cifg = (input_to_input_weights == nullptr);
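    // One row of sums is kept per weight matrix: four input-to-gate plus
    // four recurrent-to-gate matrices, or six rows when CIFG drops the two
    // input-gate matrices. A projection matrix has n_output rows, so its
    // sums take ceil(n_output / n_cell) extra rows of this n_cell-wide
    // buffer.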
1604     int row_sums_rows = use_cifg ? 6 : 8;
1605     const TfLiteTensor* projection_weights =
1606         GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
1607     if (projection_weights != nullptr) {
1608       row_sums_rows += ceil(static_cast<float>(n_output) / n_cell);
1609     }
1610 
1611     TfLiteTensor* row_sums;
1612     TF_LITE_ENSURE_OK(context,
1613                       GetTemporarySafe(context, node, kRowSums, &row_sums));
1614     row_sums->type = kTfLiteInt32;
1615     row_sums->name = "Lstm_row_sums";
1616     row_sums->allocation_type = kTfLiteArenaRwPersistent;
1617     const int row_sums_dims[2] = {row_sums_rows, n_cell};
1618     if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
1619       TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
1620       row_sums_size->data[0] = row_sums_dims[0];
1621       row_sums_size->data[1] = row_sums_dims[1];
1622       TF_LITE_ENSURE_OK(
1623           context, context->ResizeTensor(context, row_sums, row_sums_size));
1624     }
1625 
1626     if (is_sparse_op) {
1627       op_data->ledger_initialized = false;
1628       int offset = kNumHybridTemporaryTensors;
1629       {
1630         node->temporaries->data[offset + kInputToInputWeightsLedgerOffset] =
1631             op_data->ledger_index + kInputToInputWeightsLedgerOffset;
1632         const TfLiteTensor* input_to_input_weights =
1633             GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
1634         TfLiteTensor* input_to_input_weights_ledger =
1635             &context->tensors[op_data->ledger_index +
1636                               kInputToInputWeightsLedgerOffset];
1637         auto status = make_ledger(input_to_input_weights == nullptr
1638                                       ? nullptr
1639                                       : input_to_input_weights->sparsity,
1640                                   context, input_to_input_weights_ledger);
1641         if (status != kTfLiteOk) return status;
1642       }
1643       {
1644         node->temporaries->data[offset + kInputToForgetWeightsLedgerOffset] =
1645             op_data->ledger_index + kInputToForgetWeightsLedgerOffset;
1646         const TfLiteTensor* input_to_forget_weights =
1647             GetInput(context, node, kInputToForgetWeightsTensor);
1648         TfLiteTensor* input_to_forget_weights_ledger =
1649             &context->tensors[op_data->ledger_index +
1650                               kInputToForgetWeightsLedgerOffset];
1651         auto status = make_ledger(input_to_forget_weights->sparsity, context,
1652                                   input_to_forget_weights_ledger);
1653         if (status != kTfLiteOk) return status;
1654       }
1655       {
1656         node->temporaries->data[offset + kInputToCellWeightsLedgerOffset] =
1657             op_data->ledger_index + kInputToCellWeightsLedgerOffset;
1658         const TfLiteTensor* input_to_cell_weights =
1659             GetInput(context, node, kInputToCellWeightsTensor);
1660         TfLiteTensor* input_to_cell_weights_ledger =
1661             &context->tensors[op_data->ledger_index +
1662                               kInputToCellWeightsLedgerOffset];
1663         auto status = make_ledger(input_to_cell_weights->sparsity, context,
1664                                   input_to_cell_weights_ledger);
1665         if (status != kTfLiteOk) return status;
1666       }
1667       {
1668         node->temporaries->data[offset + kInputToOutputWeightsLedgerOffset] =
1669             op_data->ledger_index + kInputToOutputWeightsLedgerOffset;
1670         const TfLiteTensor* input_to_output_weights =
1671             GetInput(context, node, kInputToOutputWeightsTensor);
1672         TfLiteTensor* input_to_output_weights_ledger =
1673             &context->tensors[op_data->ledger_index +
1674                               kInputToOutputWeightsLedgerOffset];
1675         auto status = make_ledger(input_to_output_weights->sparsity, context,
1676                                   input_to_output_weights_ledger);
1677         if (status != kTfLiteOk) return status;
1678       }
1679       {
1680         node->temporaries->data[offset + kRecurrentToInputWeightsLedgerOffset] =
1681             op_data->ledger_index + kRecurrentToInputWeightsLedgerOffset;
1682         const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
1683             context, node, kRecurrentToInputWeightsTensor);
1684         TfLiteTensor* recurrent_to_input_weights_ledger =
1685             &context->tensors[op_data->ledger_index +
1686                               kRecurrentToInputWeightsLedgerOffset];
1687         auto status = make_ledger(recurrent_to_input_weights == nullptr
1688                                       ? nullptr
1689                                       : recurrent_to_input_weights->sparsity,
1690                                   context, recurrent_to_input_weights_ledger);
1691         if (status != kTfLiteOk) return status;
1692       }
1693       {
1694         node->temporaries
1695             ->data[offset + kRecurrentToForgetWeightsLedgerOffset] =
1696             op_data->ledger_index + kRecurrentToForgetWeightsLedgerOffset;
1697         const TfLiteTensor* recurrent_to_forget_weights =
1698             GetInput(context, node, kRecurrentToForgetWeightsTensor);
1699         TfLiteTensor* recurrent_to_forget_weights_ledger =
1700             &context->tensors[op_data->ledger_index +
1701                               kRecurrentToForgetWeightsLedgerOffset];
1702         auto status = make_ledger(recurrent_to_forget_weights->sparsity,
1703                                   context, recurrent_to_forget_weights_ledger);
1704         if (status != kTfLiteOk) return status;
1705       }
1706       {
1707         node->temporaries->data[offset + kRecurrentToCellWeightsLedgerOffset] =
1708             op_data->ledger_index + kRecurrentToCellWeightsLedgerOffset;
1709         const TfLiteTensor* recurrent_to_cell_weights =
1710             GetInput(context, node, kRecurrentToCellWeightsTensor);
1711         TfLiteTensor* recurrent_to_cell_weights_ledger =
1712             &context->tensors[op_data->ledger_index +
1713                               kRecurrentToCellWeightsLedgerOffset];
1714         auto status = make_ledger(recurrent_to_cell_weights->sparsity, context,
1715                                   recurrent_to_cell_weights_ledger);
1716         if (status != kTfLiteOk) return status;
1717       }
1718       {
1719         node->temporaries
1720             ->data[offset + kRecurrentToOutputWeightsLedgerOffset] =
1721             op_data->ledger_index + kRecurrentToOutputWeightsLedgerOffset;
1722         const TfLiteTensor* recurrent_to_output_weights =
1723             GetInput(context, node, kRecurrentToOutputWeightsTensor);
1724         TfLiteTensor* recurrent_to_output_weights_ledger =
1725             &context->tensors[op_data->ledger_index +
1726                               kRecurrentToOutputWeightsLedgerOffset];
1727         auto status = make_ledger(recurrent_to_output_weights->sparsity,
1728                                   context, recurrent_to_output_weights_ledger);
1729         if (status != kTfLiteOk) return status;
1730       }
1731       {
1732         node->temporaries->data[offset + kProjectionWeightsLedgerOffset] =
1733             op_data->ledger_index + kProjectionWeightsLedgerOffset;
1734         const TfLiteTensor* projection_weights =
1735             GetInput(context, node, kProjectionWeightsTensor);
1736         TfLiteTensor* projection_weights_ledger =
1737             &context->tensors[op_data->ledger_index +
1738                               kProjectionWeightsLedgerOffset];
1739         auto status = make_ledger(projection_weights->sparsity, context,
1740                                   projection_weights_ledger);
1741         if (status != kTfLiteOk) return status;
1742       }
1743     }
1744   }
1745 
1746   if (is_integer) {
1747     if (is_8x8_16) {
1748       // Integer LSTM prepare function for 8x8->16.
1749       // This code path needs 5 intermediate tensors per Op.
1750       // Populate quantization parameters.
1751       PopulateQuantizedLstmParams8x8_16(context, node,
1752                                         &op_data->integer_lstm_param);
1753 
1754       // Allocate scratch buffers. The loop below creates four 16-bit
1755       // buffers, one 8-bit buffer and one 32-bit buffer, each of size
1756       // n_batch * n_cell.
1757       //
1758       // Handling the CIFG case separately might save one buffer.
1759       for (int scratch_index = 0; scratch_index < 6; ++scratch_index) {
1760         node->temporaries->data[scratch_index] =
1761             op_data->scratch_tensor_index + scratch_index;
1762         TfLiteTensor* scratch_tensor;
1763         TF_LITE_ENSURE_OK(
1764             context,
1765             GetTemporarySafe(context, node, scratch_index, &scratch_tensor));
1766         scratch_tensor->type = kTfLiteInt16;
1767         if (scratch_index == 4) {
1768           scratch_tensor->type = kTfLiteInt8;
1769         } else if (scratch_index == 5) {
1770           scratch_tensor->type = kTfLiteInt32;
1771         }
1772         scratch_tensor->allocation_type = kTfLiteArenaRw;
1773         const int scratch_dimension[2] = {n_batch, n_cell};
1774         if (!TfLiteIntArrayEqualsArray(scratch_tensor->dims, 2,
1775                                        scratch_dimension)) {
1776           TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
1777           scratch_buffer_size->data[0] = n_batch;
1778           scratch_buffer_size->data[1] = n_cell;
1779           TF_LITE_ENSURE_OK(context,
1780                             context->ResizeTensor(context, scratch_tensor,
1781                                                   scratch_buffer_size));
1782         }
1783       }
1784 
1785       // Populate precomputed zp * weight.
1786       TF_LITE_ENSURE_OK(context, PopulatePrecomputedZPTimesWeightsWithBias(
1787                                      context, op_data, node));
1788     } else {
1789       // Integer LSTM prepare function for 8x8->8.
1790       // This code path needs 12 intermediate tensors per Op.
1791       PopulateQuantizedLstmParams8x8_8(context, node,
1792                                        &op_data->integer_lstm_param);
1793 
1794       // Allocate scratch buffers. The loop below creates six 16-bit buffers
1795       // and two 8-bit buffers, each of size n_batch * n_cell.
1796       //
1797       // Handling the CIFG case separately might save one buffer.
1798       for (int scratch_index = 0; scratch_index < 8; ++scratch_index) {
1799         node->temporaries->data[scratch_index] =
1800             op_data->scratch_tensor_index + scratch_index;
1801         TfLiteTensor* scratch_tensor;
1802         TF_LITE_ENSURE_OK(
1803             context,
1804             GetTemporarySafe(context, node, scratch_index, &scratch_tensor));
1805         if (scratch_index == 0 || scratch_index == 1) {
1806           scratch_tensor->type = kTfLiteInt8;
1807         } else {
1808           scratch_tensor->type = kTfLiteInt16;
1809         }
1810         scratch_tensor->allocation_type = kTfLiteArenaRw;
1811         const int scratch_dimension[2] = {n_batch, n_cell};
1812         if (!TfLiteIntArrayEqualsArray(scratch_tensor->dims, 2,
1813                                        scratch_dimension)) {
1814           TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
1815           scratch_buffer_size->data[0] = n_batch;
1816           scratch_buffer_size->data[1] = n_cell;
1817           TF_LITE_ENSURE_OK(context,
1818                             context->ResizeTensor(context, scratch_tensor,
1819                                                   scratch_buffer_size));
1820         }
1821       }
1822     }
1823   }
1824   return kTfLiteOk;
1825 }
1826 // LINT.ThenChange(//tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc)
1827 
1828 // LINT.IfChange
1829 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
1830   const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data);
1831   OpData* op_data = static_cast<OpData*>(node->user_data);
1832 
1833   const TfLiteTensor* input;
1834   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
1835 
1836   const TfLiteTensor* input_to_input_weights =
1837       GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
1838   const TfLiteTensor* input_to_forget_weights;
1839   TF_LITE_ENSURE_OK(context,
1840                     GetInputSafe(context, node, kInputToForgetWeightsTensor,
1841                                  &input_to_forget_weights));
1842   const TfLiteTensor* input_to_cell_weights;
1843   TF_LITE_ENSURE_OK(context,
1844                     GetInputSafe(context, node, kInputToCellWeightsTensor,
1845                                  &input_to_cell_weights));
1846   const TfLiteTensor* input_to_output_weights;
1847   TF_LITE_ENSURE_OK(context,
1848                     GetInputSafe(context, node, kInputToOutputWeightsTensor,
1849                                  &input_to_output_weights));
1850 
1851   const TfLiteTensor* recurrent_to_input_weights =
1852       GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
1853   const TfLiteTensor* recurrent_to_forget_weights;
1854   TF_LITE_ENSURE_OK(context,
1855                     GetInputSafe(context, node, kRecurrentToForgetWeightsTensor,
1856                                  &recurrent_to_forget_weights));
1857   const TfLiteTensor* recurrent_to_cell_weights;
1858   TF_LITE_ENSURE_OK(context,
1859                     GetInputSafe(context, node, kRecurrentToCellWeightsTensor,
1860                                  &recurrent_to_cell_weights));
1861   const TfLiteTensor* recurrent_to_output_weights;
1862   TF_LITE_ENSURE_OK(context,
1863                     GetInputSafe(context, node, kRecurrentToOutputWeightsTensor,
1864                                  &recurrent_to_output_weights));
1865 
1866   const TfLiteTensor* cell_to_input_weights =
1867       GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
1868   const TfLiteTensor* cell_to_forget_weights =
1869       GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
1870   const TfLiteTensor* cell_to_output_weights =
1871       GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
1872 
1873   const TfLiteTensor* input_layer_norm_coefficients =
1874       GetOptionalInputTensor(context, node, kInputLayerNormCoefficientsTensor);
1875   const TfLiteTensor* forget_layer_norm_coefficients =
1876       GetOptionalInputTensor(context, node, kForgetLayerNormCoefficientsTensor);
1877   const TfLiteTensor* cell_layer_norm_coefficients =
1878       GetOptionalInputTensor(context, node, kCellLayerNormCoefficientsTensor);
1879   const TfLiteTensor* output_layer_norm_coefficients =
1880       GetOptionalInputTensor(context, node, kOutputLayerNormCoefficientsTensor);
1881 
1882   const TfLiteTensor* input_gate_bias =
1883       GetOptionalInputTensor(context, node, kInputGateBiasTensor);
1884   const TfLiteTensor* forget_gate_bias;
1885   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kForgetGateBiasTensor,
1886                                           &forget_gate_bias));
1887   const TfLiteTensor* cell_gate_bias;
1888   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCellGateBiasTensor,
1889                                           &cell_gate_bias));
1890   const TfLiteTensor* output_gate_bias;
1891   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kOutputGateBiasTensor,
1892                                           &output_gate_bias));
1893 
1894   const TfLiteTensor* projection_weights =
1895       GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
1896   const TfLiteTensor* projection_bias =
1897       GetOptionalInputTensor(context, node, kProjectionBiasTensor);
1898 
1899   TfLiteTensor* output_state =
1900       GetVariableInput(context, node, kOutputStateTensor);
1901   TFLITE_DCHECK(output_state != nullptr);
1902   TfLiteTensor* cell_state = GetVariableInput(context, node, kCellStateTensor);
1903   TFLITE_DCHECK(cell_state != nullptr);
1904 
1905   TfLiteTensor* output;
1906   TF_LITE_ENSURE_OK(context,
1907                     GetOutputSafe(context, node, kOutputTensor, &output));
1908 
1909   switch (input_to_output_weights->type) {
1910     case kTfLiteFloat32: {
1911       // Fetch the scratch buffer from the node's temporaries.
1912       TfLiteTensor* scratch_buffer;
1913       TF_LITE_ENSURE_OK(context,
1914                         GetTemporarySafe(context, node, 0, &scratch_buffer));
1915       return lstm_eval::EvalFloat(
1916           input, input_to_input_weights, input_to_forget_weights,
1917           input_to_cell_weights, input_to_output_weights,
1918           recurrent_to_input_weights, recurrent_to_forget_weights,
1919           recurrent_to_cell_weights, recurrent_to_output_weights,
1920           cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
1921           input_layer_norm_coefficients, forget_layer_norm_coefficients,
1922           cell_layer_norm_coefficients, output_layer_norm_coefficients,
1923           /*aux_input=*/nullptr,
1924           /*aux_input_to_input_weights=*/nullptr,
1925           /*aux_input_to_forget_weights=*/nullptr,
1926           /*aux_input_to_cell_weights=*/nullptr,
1927           /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
1928           forget_gate_bias, cell_gate_bias, output_gate_bias,
1929           projection_weights, projection_bias, params,
1930           /*forward_sequence=*/true,
1931           /*time_major=*/true,
1932           /*output_offset=*/0, scratch_buffer, output_state, cell_state, output,
1933           CpuBackendContext::GetFromContext(context));
1934     }
1935     case kTfLiteUInt8:
1936     case kTfLiteInt8: {
1937       const bool is_hybrid = (input->type == kTfLiteFloat32);
1938       const bool is_sparse = input_to_output_weights->sparsity != nullptr;
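      // This case covers both hybrid (float32 input with quantized weights)
      // and fully integer models; `is_hybrid` distinguishes them by the input
      // type, and sparse hybrid models additionally use the weight ledgers.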
1939       if (is_hybrid) {
1940         TfLiteTensor* row_sums;
1941         TF_LITE_ENSURE_OK(context,
1942                           GetTemporarySafe(context, node, kRowSums, &row_sums));
1943         const int row_sums_size = row_sums->dims->data[0];
1944         if (is_sparse) {
1945           TfLiteTensor* input_to_input_weights_ledger =
1946               &context->tensors[op_data->ledger_index +
1947                                 kInputToInputWeightsLedgerOffset];
1948           TfLiteTensor* input_to_forget_weights_ledger =
1949               &context->tensors[op_data->ledger_index +
1950                                 kInputToForgetWeightsLedgerOffset];
1951           TfLiteTensor* input_to_cell_weights_ledger =
1952               &context->tensors[op_data->ledger_index +
1953                                 kInputToCellWeightsLedgerOffset];
1954           TfLiteTensor* input_to_output_weights_ledger =
1955               &context->tensors[op_data->ledger_index +
1956                                 kInputToOutputWeightsLedgerOffset];
1957           TfLiteTensor* recurrent_to_input_weights_ledger =
1958               &context->tensors[op_data->ledger_index +
1959                                 kRecurrentToInputWeightsLedgerOffset];
1960           TfLiteTensor* recurrent_to_forget_weights_ledger =
1961               &context->tensors[op_data->ledger_index +
1962                                 kRecurrentToForgetWeightsLedgerOffset];
1963           TfLiteTensor* recurrent_to_cell_weights_ledger =
1964               &context->tensors[op_data->ledger_index +
1965                                 kRecurrentToCellWeightsLedgerOffset];
1966           TfLiteTensor* recurrent_to_output_weights_ledger =
1967               &context->tensors[op_data->ledger_index +
1968                                 kRecurrentToOutputWeightsLedgerOffset];
1969           TfLiteTensor* projection_weights_ledger =
1970               &context->tensors[op_data->ledger_index +
1971                                 kProjectionWeightsLedgerOffset];
1972           if (!op_data->ledger_initialized) {
1973             copy_ledger(input_to_input_weights == nullptr
1974                             ? nullptr
1975                             : input_to_input_weights->sparsity,
1976                         input_to_input_weights_ledger);
1977             copy_ledger(input_to_forget_weights->sparsity,
1978                         input_to_forget_weights_ledger);
1979             copy_ledger(input_to_cell_weights->sparsity,
1980                         input_to_cell_weights_ledger);
1981             copy_ledger(input_to_output_weights->sparsity,
1982                         input_to_output_weights_ledger);
1983             copy_ledger(recurrent_to_input_weights == nullptr
1984                             ? nullptr
1985                             : recurrent_to_input_weights->sparsity,
1986                         recurrent_to_input_weights_ledger);
1987             copy_ledger(recurrent_to_forget_weights->sparsity,
1988                         recurrent_to_forget_weights_ledger);
1989             copy_ledger(recurrent_to_cell_weights->sparsity,
1990                         recurrent_to_cell_weights_ledger);
1991             copy_ledger(recurrent_to_output_weights->sparsity,
1992                         recurrent_to_output_weights_ledger);
1993             copy_ledger(projection_weights->sparsity,
1994                         projection_weights_ledger);
1995             op_data->ledger_initialized = true;
1996           }
1997           return lstm_eval::EvalHybrid(
1998               input, input_to_input_weights, input_to_input_weights_ledger,
1999               input_to_forget_weights, input_to_forget_weights_ledger,
2000               input_to_cell_weights, input_to_cell_weights_ledger,
2001               input_to_output_weights, input_to_output_weights_ledger,
2002               recurrent_to_input_weights, recurrent_to_input_weights_ledger,
2003               recurrent_to_forget_weights, recurrent_to_forget_weights_ledger,
2004               recurrent_to_cell_weights, recurrent_to_cell_weights_ledger,
2005               recurrent_to_output_weights, recurrent_to_output_weights_ledger,
2006               cell_to_input_weights, cell_to_forget_weights,
2007               cell_to_output_weights, input_layer_norm_coefficients,
2008               forget_layer_norm_coefficients, cell_layer_norm_coefficients,
2009               output_layer_norm_coefficients,
2010               /*aux_input=*/nullptr,
2011               /*aux_input_to_input_weights=*/nullptr,
2012               /*aux_input_to_forget_weights=*/nullptr,
2013               /*aux_input_to_cell_weights=*/nullptr,
2014               /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
2015               forget_gate_bias, cell_gate_bias, output_gate_bias,
2016               projection_weights, projection_weights_ledger, projection_bias,
2017               params,
2018               /*forward_sequence=*/true, /*time_major=*/true,
2019               /*output_offset=*/0, GetTemporary(context, node, kScratchBuffer),
2020               GetTemporary(context, node, kInputScalingFactors),
2021               /*aux_input_sf=*/nullptr,
2022               GetTemporary(context, node, kOutputStateScalingFactors),
2023               GetTemporary(context, node, kProductScalingFactors),
2024               GetTemporary(context, node, kRecoveredCellWeights),
2025               GetTemporary(context, node, kInputQuantized),
2026               /*aux_input_quantized=*/nullptr,
2027               GetTemporary(context, node, kOutputStateQuantized),
2028               GetTemporary(context, node, kCellStateQuantized), output_state,
2029               cell_state, GetTemporary(context, node, kAccumScratch), output,
2030               GetTemporary(context, node, kInputZeroPoints),
2031               /*aux_input_zp=*/nullptr,
2032               GetTemporary(context, node, kOutputStateZeroPoints), row_sums,
2033               row_sums_size, &op_data->compute_row_sums,
2034               CpuBackendContext::GetFromContext(context));
2035         }
2036         return lstm_eval::EvalHybrid(
2037             input, input_to_input_weights,
2038             /*input_to_input_weights_ledger=*/nullptr, input_to_forget_weights,
2039             /*input_to_forget_weights_ledger=*/nullptr, input_to_cell_weights,
2040             /*input_to_cell_weights_ledger=*/nullptr, input_to_output_weights,
2041             /*input_to_output_weights_ledger=*/nullptr,
2042             recurrent_to_input_weights,
2043             /*recurrent_to_input_weights_ledger=*/nullptr,
2044             recurrent_to_forget_weights,
2045             /*recurrent_to_forget_weights_ledger=*/nullptr,
2046             recurrent_to_cell_weights,
2047             /*recurrent_to_cell_weights_ledger=*/nullptr,
2048             recurrent_to_output_weights,
2049             /*recurrent_to_output_weights_ledger=*/nullptr,
2050             cell_to_input_weights, cell_to_forget_weights,
2051             cell_to_output_weights, input_layer_norm_coefficients,
2052             forget_layer_norm_coefficients, cell_layer_norm_coefficients,
2053             output_layer_norm_coefficients, /*aux_input=*/nullptr,
2054             /*aux_input_to_input_weights=*/nullptr,
2055             /*aux_input_to_forget_weights=*/nullptr,
2056             /*aux_input_to_cell_weights=*/nullptr,
2057             /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
2058             forget_gate_bias, cell_gate_bias, output_gate_bias,
2059             projection_weights, /*projection_weights_ledger=*/nullptr,
2060             projection_bias, params,
2061             /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
2062             GetTemporary(context, node, kScratchBuffer),
2063             GetTemporary(context, node, kInputScalingFactors),
2064             /*aux_input_sf=*/nullptr,
2065             GetTemporary(context, node, kOutputStateScalingFactors),
2066             GetTemporary(context, node, kProductScalingFactors),
2067             GetTemporary(context, node, kRecoveredCellWeights),
2068             GetTemporary(context, node, kInputQuantized),
2069             /*aux_input_quantized=*/nullptr,
2070             GetTemporary(context, node, kOutputStateQuantized),
2071             GetTemporary(context, node, kCellStateQuantized), output_state,
2072             cell_state, GetTemporary(context, node, kAccumScratch), output,
2073             GetTemporary(context, node, kInputZeroPoints),
2074             /*aux_input_zp=*/nullptr,
2075             GetTemporary(context, node, kOutputStateZeroPoints), row_sums,
2076             row_sums_size, &op_data->compute_row_sums,
2077             CpuBackendContext::GetFromContext(context));
2078       }
2079       const int num_intermediate_tensors = node->intermediates->size;
2080       TfLiteTensor* scratch0;
2081       TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 0, &scratch0));
2082       TfLiteTensor* scratch1;
2083       TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 1, &scratch1));
2084       TfLiteTensor* scratch2;
2085       TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 2, &scratch2));
2086       TfLiteTensor* scratch3;
2087       TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 3, &scratch3));
2088       TfLiteTensor* scratch4;
2089       TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 4, &scratch4));
2090       TfLiteTensor* scratch5;
2091       TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 5, &scratch5));
2092       if (num_intermediate_tensors == 5) {
2093         return lstm_eval::EvalInteger8x8_16(
2094             input, input_to_input_weights, input_to_forget_weights,
2095             input_to_cell_weights, input_to_output_weights,
2096             recurrent_to_input_weights, recurrent_to_forget_weights,
2097             recurrent_to_cell_weights, recurrent_to_output_weights,
2098             cell_to_input_weights, cell_to_forget_weights,
2099             cell_to_output_weights, input_layer_norm_coefficients,
2100             forget_layer_norm_coefficients, cell_layer_norm_coefficients,
2101             output_layer_norm_coefficients, input_gate_bias, forget_gate_bias,
2102             cell_gate_bias, output_gate_bias, projection_weights,
2103             projection_bias, params, /*forward_sequence=*/true,
2104             /*time_major=*/true, &op_data->integer_lstm_param, output_state,
2105             cell_state, output, scratch0, scratch1, scratch2, scratch3,
2106             scratch4, scratch5, CpuBackendContext::GetFromContext(context));
2107       }
2108       TfLiteTensor* scratch6;
2109       TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 6, &scratch6));
2110       TfLiteTensor* scratch7;
2111       TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 7, &scratch7));
2112       return lstm_eval::EvalInteger8x8_8(
2113           input, input_to_input_weights, input_to_forget_weights,
2114           input_to_cell_weights, input_to_output_weights,
2115           recurrent_to_input_weights, recurrent_to_forget_weights,
2116           recurrent_to_cell_weights, recurrent_to_output_weights,
2117           cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
2118           input_layer_norm_coefficients, forget_layer_norm_coefficients,
2119           cell_layer_norm_coefficients, output_layer_norm_coefficients,
2120           input_gate_bias, forget_gate_bias, cell_gate_bias, output_gate_bias,
2121           projection_weights, projection_bias, params, output_state, cell_state,
2122           output, &op_data->integer_lstm_param, scratch0, scratch1, scratch2,
2123           scratch3, scratch4, scratch5, scratch6, scratch7);
2124     }
2125     default:
2126       TF_LITE_KERNEL_LOG(context, "Type %d is not currently supported.",
2127                          input_to_output_weights->type);
2128       return kTfLiteError;
2129   }
2130 }
2131 // LINT.ThenChange(//tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc)
2132 
2133 }  // namespace full
2134 
2135 // For the basic kernel (5 inputs).
2136 namespace basic {
2137 
enum InputTensor {
  kInputData = 0,
  kInputPrevActivation = 1,
  kInputWeights = 2,
  kInputBiases = 3,
  kInputPrevState = 4,
  kInputNum = 5,
};

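// Indices of the basic kernel's four output tensors; the two *Temp entries
// are working buffers that this kernel exposes as outputs.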
enum OutputTensor {
  kOutputActivation = 0,
  kOutputState = 1,
  kOutputConcatTemp = 2,
  kOutputActivationTemp = 3,
  kOutputNum = 4,
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  op_data->kernel_type = kTfLiteLSTMBasicKernel;
  // `scratch_tensor_index` is unused in this kernel.
  op_data->scratch_tensor_index = -1;
  return op_data;
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE(context, node->inputs->size == kInputNum);
  TF_LITE_ENSURE(context, node->outputs->size == kOutputNum);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputData, &input));
  const TfLiteTensor* prev_activation;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputPrevActivation,
                                          &prev_activation));
  const TfLiteTensor* weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputWeights, &weights));
  const TfLiteTensor* bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputBiases, &bias));
  const TfLiteTensor* prev_state;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputPrevState, &prev_state));

  TF_LITE_ENSURE_EQ(context, input->dims->size, 2);
  const int num_batches = input->dims->data[0];
  const int input_depth = input->dims->data[1];

  TF_LITE_ENSURE_EQ(context, prev_activation->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, prev_activation->dims->data[0], num_batches);
  const int activation_depth = prev_activation->dims->data[1];
  const int total_depth = input_depth + activation_depth;

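  // The combined weight matrix stacks the four gate matrices row-wise, giving
  // 4 * activation_depth rows over the concatenated [input, activation]
  // columns.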
  TF_LITE_ENSURE_EQ(context, weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, weights->dims->data[0], 4 * activation_depth);
  TF_LITE_ENSURE_EQ(context, weights->dims->data[1], total_depth);

  TF_LITE_ENSURE_EQ(context, bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, bias->dims->data[0], 4 * activation_depth);

  TF_LITE_ENSURE_EQ(context, prev_state->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, prev_state->dims->data[0], num_batches);
  TF_LITE_ENSURE_EQ(context, prev_state->dims->data[1], activation_depth);

  TfLiteTensor* activation_out;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputActivation,
                                           &activation_out));
  TfLiteTensor* state_out;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputState, &state_out));
  TfLiteTensor* concat_temp;
  TF_LITE_ENSURE_OK(
      context, GetOutputSafe(context, node, kOutputConcatTemp, &concat_temp));
  TfLiteTensor* activation_temp;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputActivationTemp,
                                           &activation_temp));

  TF_LITE_ENSURE_OK(context, context->ResizeTensor(
                                 context, activation_out,
                                 TfLiteIntArrayCopy(prev_activation->dims)));
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, state_out,
                                     TfLiteIntArrayCopy(prev_state->dims)));

  TfLiteIntArray* concat_temp_size = TfLiteIntArrayCreate(2);
  concat_temp_size->data[0] = num_batches;
  concat_temp_size->data[1] = total_depth;
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, concat_temp, concat_temp_size));
  TfLiteIntArray* activation_temp_size = TfLiteIntArrayCreate(2);
  activation_temp_size->data[0] = num_batches;
  activation_temp_size->data[1] = 4 * activation_depth;
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_temp,
                                                   activation_temp_size));

  // Mark the recurrent state inputs persistent so their contents survive
  // across invocations.
  for (auto index : {kInputPrevActivation, kInputPrevState}) {
    TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]];
    tensor->allocation_type = kTfLiteArenaRwPersistent;
  }
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputData, &input));
  const TfLiteTensor* prev_activation;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputPrevActivation,
                                          &prev_activation));
  const TfLiteTensor* weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputWeights, &weights));
  const TfLiteTensor* bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputBiases, &bias));
  const TfLiteTensor* prev_state;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputPrevState, &prev_state));

  TfLiteTensor* activation_out;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputActivation,
                                           &activation_out));
  TfLiteTensor* state_out;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputState, &state_out));
  TfLiteTensor* concat_temp;
  TF_LITE_ENSURE_OK(
      context, GetOutputSafe(context, node, kOutputConcatTemp, &concat_temp));
  TfLiteTensor* activation_temp;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputActivationTemp,
                                           &activation_temp));

  if (input->type == kTfLiteFloat32 &&
      prev_activation->type == kTfLiteFloat32 &&
      weights->type == kTfLiteFloat32 && bias->type == kTfLiteFloat32 &&
      prev_state->type == kTfLiteFloat32 && state_out->type == kTfLiteFloat32 &&
      activation_out->type == kTfLiteFloat32 &&
      concat_temp->type == kTfLiteFloat32 &&
      activation_temp->type == kTfLiteFloat32) {
    tflite::LstmCellParams op_params;
    // The float LSTM cell needs no parameters; leave op_params untouched.
    optimized_ops::LstmCell(
        op_params,
        // Inputs.
        GetTensorShape(input), GetTensorData<float>(input),
        GetTensorShape(prev_activation), GetTensorData<float>(prev_activation),
        GetTensorShape(weights), GetTensorData<float>(weights),
        GetTensorShape(bias), GetTensorData<float>(bias),
        GetTensorShape(prev_state), GetTensorData<float>(prev_state),
        // Outputs.
        GetTensorShape(state_out), GetTensorData<float>(state_out),
        GetTensorShape(activation_out), GetTensorData<float>(activation_out),
        GetTensorShape(concat_temp), GetTensorData<float>(concat_temp),
        GetTensorShape(activation_temp), GetTensorData<float>(activation_temp),
        CpuBackendContext::GetFromContext(context));
  } else if (input->type == kTfLiteUInt8 &&
             prev_activation->type == kTfLiteUInt8 &&
             weights->type == kTfLiteUInt8 && bias->type == kTfLiteInt32 &&
             prev_state->type == kTfLiteInt16 &&
             state_out->type == kTfLiteInt16 &&
             activation_out->type == kTfLiteUInt8 &&
             concat_temp->type == kTfLiteUInt8 &&
             activation_temp->type == kTfLiteInt16) {
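    // The int16 cell state is treated as fixed-point, so its scale must be a
    // power of two; 15 + log2(scale) gives the number of integer bits.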
    int state_scale_log2_rounded;
    if (!CheckedLog2(state_out->params.scale, &state_scale_log2_rounded)) {
      TF_LITE_KERNEL_LOG(
          context,
          "The internal state of an LSTM cell must have a power-of-two scale.");
      return kTfLiteError;
    }
    const int state_integer_bits = 15 + state_scale_log2_rounded;
    if (state_integer_bits != 4) {
      TF_LITE_KERNEL_LOG(context,
                         "The only case of quantized LstmCell currently "
                         "supported is with StateIntegerBits==4");
      return kTfLiteError;
    }

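    // Fold the bias scale into a single rescaling of the int32 accumulators;
    // QuantizeMultiplier splits the real multiplier into a fixed-point
    // multiplier and a power-of-two shift.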
    double real_accum_multiplier = 4096 * bias->params.scale;
    int32_t accum_multiplier;
    int accum_shift;
    tflite::QuantizeMultiplier(real_accum_multiplier, &accum_multiplier,
                               &accum_shift);
    tflite::LstmCellParams op_params;
    op_params.weights_zero_point = weights->params.zero_point;
    op_params.accum_multiplier = accum_multiplier;
    op_params.accum_shift = accum_shift;
    optimized_ops::LstmCell<4>(
        op_params,
        // Inputs.
        GetTensorShape(input), GetTensorData<uint8_t>(input),
        GetTensorShape(prev_activation),
        GetTensorData<uint8_t>(prev_activation), GetTensorShape(weights),
        GetTensorData<uint8_t>(weights), GetTensorShape(bias),
        GetTensorData<int32_t>(bias), GetTensorShape(prev_state),
        GetTensorData<int16_t>(prev_state),
        // Outputs.
        GetTensorShape(state_out), GetTensorData<int16_t>(state_out),
        GetTensorShape(activation_out), GetTensorData<uint8_t>(activation_out),
        GetTensorShape(concat_temp), GetTensorData<uint8_t>(concat_temp),
        GetTensorShape(activation_temp),
        GetTensorData<int16_t>(activation_temp),
        CpuBackendContext::GetFromContext(context));
  } else {
    TF_LITE_KERNEL_LOG(context,
                       "Unsupported combination of data types for LstmCell");
    return kTfLiteError;
  }

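  // Copy the new activation and state back into the persistent input tensors
  // so the next invocation sees the updated recurrent state.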
  memcpy(prev_activation->data.raw, activation_out->data.raw,
         activation_out->bytes);
  memcpy(prev_state->data.raw, state_out->data.raw, state_out->bytes);

  return kTfLiteOk;
}

}  // namespace basic

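// Dispatch Init/Free/Prepare/Eval to the full or basic implementation based
// on the kernel type recorded in the builtin options.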
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  const auto* params = reinterpret_cast<const TfLiteLSTMParams*>(buffer);
  switch (params->kernel_type) {
    case kTfLiteLSTMFullKernel:
      return full::Init(context, buffer, length);
    case kTfLiteLSTMBasicKernel:
      return basic::Init(context, buffer, length);
    default:
      return nullptr;
  }
}

void Free(TfLiteContext* context, void* buffer) {
  delete static_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const auto* op_data = static_cast<const OpData*>(node->user_data);
  switch (op_data->kernel_type) {
    case kTfLiteLSTMFullKernel:
      return full::Prepare(context, node);
    case kTfLiteLSTMBasicKernel:
      return basic::Prepare(context, node);
    default:
      return kTfLiteError;
  }
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* op_data = static_cast<const OpData*>(node->user_data);
  switch (op_data->kernel_type) {
    case kTfLiteLSTMFullKernel:
      return full::Eval(context, node);
    case kTfLiteLSTMBasicKernel:
      return basic::Eval(context, node);
    default:
      return kTfLiteError;
  }
}

}  // namespace lstm

TfLiteRegistration* Register_LSTM() {
  static TfLiteRegistration r = {lstm::Init, lstm::Free, lstm::Prepare,
                                 lstm::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite