1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/lstm_eval.h"
16 
17 #include <math.h>
18 #include <string.h>
19 
20 #include <algorithm>
21 #include <cstdint>
22 #include <memory>
23 #include <vector>
24 
25 #include "ruy/matrix.h"  // from @ruy
26 #include "ruy/mul_params.h"  // from @ruy
27 #include "ruy/profiler/instrumentation.h"  // from @ruy
28 #include "ruy/ruy.h"  // from @ruy
29 #include "tensorflow/lite/c/builtin_op_data.h"
30 #include "tensorflow/lite/c/common.h"
31 #include "tensorflow/lite/kernels/cpu_backend_context.h"
32 #include "tensorflow/lite/kernels/internal/compatibility.h"
33 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
34 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
35 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
36 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
37 #include "tensorflow/lite/kernels/op_macros.h"
38 
39 namespace tflite {
40 namespace ops {
41 namespace builtin {
42 namespace lstm_eval {
43 namespace {
44 
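// Computes a batched matrix-vector product and accumulates `result` into it:
//   output[b] = matrix * vector[b] + result[b]   for b in [0, n_batch)
// by routing through the optimized FullyConnected kernel (`result` acts as
// the bias on the single-batch fast path). A minimal usage sketch with
// illustrative sizes (not taken from any caller in this file):
//
//   // 4x3 weight matrix, single batch: out = W * x + acc.
//   MatrixBatchVectorMultiplyAccumulate(W, x, acc, out,
//                                       /*m_rows=*/4, /*m_cols=*/3,
//                                       /*n_batch=*/1, context);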
45 void MatrixBatchVectorMultiplyAccumulate(
46     const float* matrix, const float* vector, const float* result,
47     float* output, int m_rows, int m_cols, int n_batch,
48     CpuBackendContext* cpu_backend_context) {
49   tflite::FullyConnectedParams float_fc_params;
50   float_fc_params.float_activation_min = std::numeric_limits<float>::lowest();
51   float_fc_params.float_activation_max = std::numeric_limits<float>::max();
52   float_fc_params.lhs_cacheable = true;
53   float_fc_params.rhs_cacheable = false;
54 
55   tflite::RuntimeShape weight_shape({m_rows, m_cols});
56   tflite::RuntimeShape input_shape({n_batch, m_cols});
57   tflite::RuntimeShape output_shape({n_batch, m_rows});
58   if (n_batch == 1) {
59     tflite::optimized_ops::FullyConnected(
60         float_fc_params, input_shape, vector, weight_shape, matrix,
61         output_shape, result, output_shape, output, cpu_backend_context);
62   } else {
63     tflite::optimized_ops::FullyConnected(
64         float_fc_params, input_shape, vector, weight_shape, matrix,
65         output_shape, nullptr, output_shape, output, cpu_backend_context);
66     for (int i = 0; i < m_rows * n_batch; ++i) {
67       output[i] += result[i];
68     }
69   }
70 }
71 
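// Precomputes per-row sums of the quantized weight matrices. The hybrid
// kernels use these to correct for a nonzero input zero point: for a weight
// row w and a quantized input q with zero point zp,
//   dot(w, q - zp) = dot(w, q) - zp * sum(w),
// so the cached row sums make the correction cheap without dequantizing the
// weights. (This summarizes how the *_row_sums buffers are consumed by
// MatrixBatchVectorMultiplyAccumulate; see tensor_utils for the details.)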
72 void ComputeRowSums(
73     int32_t* input_to_input_row_sums, int32_t* input_to_forget_row_sums,
74     int32_t* input_to_cell_row_sums, int32_t* input_to_output_row_sums,
75     int32_t* aux_input_to_input_row_sums, int32_t* aux_input_to_forget_row_sums,
76     int32_t* aux_input_to_cell_row_sums, int32_t* aux_input_to_output_row_sums,
77     int32_t* recurrent_to_input_row_sums, int32_t* recurrent_to_forget_row_sums,
78     int32_t* recurrent_to_cell_row_sums, int32_t* recurrent_to_output_row_sums,
79     int32_t* projection_weights_row_sums, int32_t* row_sums, int n_cell,
80     int n_input, int n_aux_input, int n_output,
81     const int8_t* input_to_input_weights_ptr,
82     const int8_t* input_to_forget_weights_ptr,
83     const int8_t* input_to_cell_weights_ptr,
84     const int8_t* input_to_output_weights_ptr,
85     const int8_t* aux_input_to_input_weights_ptr,
86     const int8_t* aux_input_to_forget_weights_ptr,
87     const int8_t* aux_input_to_cell_weights_ptr,
88     const int8_t* aux_input_to_output_weights_ptr,
89     const int8_t* recurrent_to_input_weights_ptr,
90     const int8_t* recurrent_to_forget_weights_ptr,
91     const int8_t* recurrent_to_cell_weights_ptr,
92     const int8_t* recurrent_to_output_weights_ptr,
93     const int8_t* projection_weights_ptr, bool use_cifg,
94     const float* aux_input_ptr) {
95   // Compute the row sums for dequantization
96   if (!use_cifg) {
97     tensor_utils::ReductionSumVector(input_to_input_weights_ptr,
98                                      input_to_input_row_sums, n_cell, n_input);
99   }
100   tensor_utils::ReductionSumVector(input_to_forget_weights_ptr,
101                                    input_to_forget_row_sums, n_cell, n_input);
102   tensor_utils::ReductionSumVector(input_to_cell_weights_ptr,
103                                    input_to_cell_row_sums, n_cell, n_input);
104   tensor_utils::ReductionSumVector(input_to_output_weights_ptr,
105                                    input_to_output_row_sums, n_cell, n_input);
106 
107   if (aux_input_ptr) {
108     if (!use_cifg) {
109       tensor_utils::ReductionSumVector(aux_input_to_input_weights_ptr,
110                                        aux_input_to_input_row_sums, n_cell,
111                                        n_aux_input);
112     }
113     tensor_utils::ReductionSumVector(aux_input_to_forget_weights_ptr,
114                                      aux_input_to_forget_row_sums, n_cell,
115                                      n_aux_input);
116     tensor_utils::ReductionSumVector(aux_input_to_cell_weights_ptr,
117                                      aux_input_to_cell_row_sums, n_cell,
118                                      n_aux_input);
119     tensor_utils::ReductionSumVector(aux_input_to_output_weights_ptr,
120                                      aux_input_to_output_row_sums, n_cell,
121                                      n_aux_input);
122   }
123   if (!use_cifg) {
124     tensor_utils::ReductionSumVector(recurrent_to_input_weights_ptr,
125                                      recurrent_to_input_row_sums, n_cell,
126                                      n_output);
127   }
128   tensor_utils::ReductionSumVector(recurrent_to_forget_weights_ptr,
129                                    recurrent_to_forget_row_sums, n_cell,
130                                    n_output);
131   tensor_utils::ReductionSumVector(recurrent_to_cell_weights_ptr,
132                                    recurrent_to_cell_row_sums, n_cell,
133                                    n_output);
134   tensor_utils::ReductionSumVector(recurrent_to_output_weights_ptr,
135                                    recurrent_to_output_row_sums, n_cell,
136                                    n_output);
137 
138   if (projection_weights_ptr != nullptr) {
139     tensor_utils::ReductionSumVector(
140         projection_weights_ptr, projection_weights_row_sums, n_output, n_cell);
141   }
142 }
143 
144 inline float GetTensorScale(const TfLiteTensor* tensor) {
145   return tensor == nullptr ? 1.0f : tensor->params.scale;
146 }
147 
148 // LINT.IfChange
149 // Calculates a single LSTM gate.
150 //
151 // Implements the following formula: (* is matrix multiply)
152 //   gate = activate(W_input    * input + W_aux       * aux_input   +
153 //                   W_peephole * cell  + W_recurrent * prev_output + bias)
154 // with layer norm:
155 //   gate = activate(W_norm * normalize(...) + bias) // not adding bias inside
156 //
157 // Activation is sigmoid except for the "cell" gate (configurable, usually tanh)
158 //
159 // Parameters:
160 // Input vectors (to LSTM):    | Size:                | Optional?
161 //   input                     | n_input              |
162 //   aux_input                 | n_aux_input          | y (bidir LSTM)
163 // Input vectors (persistent states):
164 //   output_state              | n_output             |
165 //   cell_state                | n_cell               |
166 // 'Constant' inputs:
167 //   input_to_gate_weights     | n_cell * n_input     |
168 //   aux_input_to_gate_weights | n_cell * n_aux_input | y (bidir LSTM)
169 //   recurrent_to_gate_weights | n_cell * n_output    |
170 //   cell_to_gate_weights      | n_cell               | y (peephole)
171 //   gate_bias                 | n_cell               |
172 //   layer_norm_coefficients   | n_cell               | y (layer norm)
173 // Output vector:
174 //   gate                      | n_cell               |
175 // Scalar parameters:
176 //   n_batch                                    - batch size / number of vectors
177 //   n_input, n_aux_input, n_output, n_cell     - size of vectors.
178 //   activation                                 - activation to use.
179 //   is_input_all_zeros, is_aux_input_all_zeros - if input vectors are all zero.
180 //   use_layer_norm                             - if doing layer norm LSTM.
181 inline void CalculateLstmGateFloat(
182     const float* input, const float* input_to_gate_weights,
183     const float* aux_input, const float* aux_input_to_gate_weights,
184     const float* output_state, const float* recurrent_to_gate_weights,
185     const float* cell_state, const float* cell_to_gate_weights,
186     const float* layer_norm_coefficients, const float* gate_bias,
187     const int n_batch, const int n_input, const int n_aux_input,
188     const int n_output, const int n_cell,
189     const TfLiteFusedActivation activation, float* gate,
190     const bool is_input_all_zeros, const bool is_aux_input_all_zeros,
191     float* output, CpuBackendContext* context) {
192   const bool use_peephole = (cell_to_gate_weights != nullptr);
193   const bool use_layer_norm = (layer_norm_coefficients != nullptr);
194 
195   // Initialize scratch buffers with bias for regular lstm or initialize with
196   // zero for layer norm lstm.
197   if (use_layer_norm) {
198     std::fill_n(gate, n_cell * n_batch, 0.0f);
199   } else {
200     tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate);
201   }
202   // For each batch and cell: compute input_weight * input.
203   // Skip if input is all zeros.
204   float* accumulation_buffer = gate;
205   if (!is_input_all_zeros) {
206     MatrixBatchVectorMultiplyAccumulate(input_to_gate_weights, input,
207                                         accumulation_buffer, output, n_cell,
208                                         n_input, n_batch, context);
209     std::swap(accumulation_buffer, output);
210   }
211   // For each batch and cell: compute aux_input_weight * aux_input.
212   // Skip if auxiliary input is not available or all zeros.
213   if (!is_aux_input_all_zeros) {
214     MatrixBatchVectorMultiplyAccumulate(aux_input_to_gate_weights, aux_input,
215                                         accumulation_buffer, output, n_cell,
216                                         n_aux_input, n_batch, context);
217     std::swap(accumulation_buffer, output);
218   }
219   // For each batch and cell: compute recurrent_weight * output_state.
220   MatrixBatchVectorMultiplyAccumulate(recurrent_to_gate_weights, output_state,
221                                       accumulation_buffer, output, n_cell,
222                                       n_output, n_batch, context);
223   // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
224   if (use_peephole) {
225     tensor_utils::VectorBatchVectorCwiseProductAccumulate(
226         cell_to_gate_weights, n_cell, cell_state, n_batch, output);
227   }
228   // Do layer normalization (if layer norm LSTM)
229   if (use_layer_norm) {
230     tensor_utils::MeanStddevNormalization(output, output, n_cell, n_batch);
231     tensor_utils::VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell,
232                                                 output, n_batch, output);
233     tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, output);
234   }
235   // Apply activation
236   tensor_utils::ApplyActivationToVector(output, n_batch * n_cell, activation,
237                                         gate);
238 }
239 
240 // Updates the LSTM cell state, used by both float and hybrid LSTM versions.
241 //
242 // Implements the following formula:
243 //   cell_state_new = clip(forget_gate * cell_state + input_gate * cell_gate)
244 //
245 // With CIFG LSTM, input gate is replaced by (1-forget_gate).
246 //
247 // Parameters:
248 //  - n_batch, n_cell: sizes of vectors
249 //  - cell_state: input/output vector, size n_batch*n_cell
250 //  - input_gate: input vector, size n_batch*n_cell.
251 //  - forget_gate: input/scratch vector, size n_batch*n_cell, modified with CIFG
252 //  - cell_gate: input vector, size n_batch*n_cell.
253 //  - use_cifg: use 1-forget_gate instead of input_gate.
254 //  - clip: if > 0, clip the resulting cell state to [-clip, +clip].
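//
// Worked example with illustrative numbers only: with forget_gate = 0.5,
// cell_state = 2.0, input_gate = 0.25 and cell_gate = 0.8, the update yields
//   cell_state_new = 0.5 * 2.0 + 0.25 * 0.8 = 1.2,
// which is then clipped if clip > 0.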
255 void UpdateLstmCellFloat(int n_batch, int n_cell, float* cell_state,
256                          const float* input_gate, float* forget_gate,
257                          const float* cell_gate, bool use_cifg, float clip) {
258   tensor_utils::VectorVectorCwiseProduct(forget_gate, cell_state,
259                                          n_batch * n_cell, cell_state);
260 
261   if (use_cifg) {
262     // With CIFG, input_gate = 1-forget_gate. Use the forget_gate array as
263     // scratch, as input_gate array is not allocated in this case. (Be careful
264     // not to write to the scratch before reading the forget gate data.)
265     float* scratch = forget_gate;
266     tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch);
267     tensor_utils::VectorVectorCwiseProductAccumulate(
268         cell_gate, scratch, n_batch * n_cell, cell_state);
269   } else {
270     tensor_utils::VectorVectorCwiseProductAccumulate(
271         cell_gate, input_gate, n_batch * n_cell, cell_state);
272   }
273   if (clip > 0.0f) {
274     tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip);
275   }
276 }
277 
278 // Calculates the output state tensor of an LSTM step.
279 //
280 // Implements the following formula:
281 //   output_no_projection = output_gate .* activate(cell_state)
282 //     (elementwise vector product)
283 // If no projection is used:
284 //   output = output_state = output_no_projection
285 // With projection:
286 //   output = output_state = clip(W*output_no_projection + bias)
287 //
288 // The output may have a different 'stride' than n_output, so we need to copy.
289 //
290 // Parameters:
291 //  - n_batch: batches: the number of distinct vectors in each array.
292 //  - n_cell, n_output: sizes of vectors.
293 //  - cell_state, output_gate: input vectors, size n_batch*n_cell.
294 //  - projection_weights, projection_weights_scale, projection_bias:
295 //      constant inputs, describing projection matrix and bias.
296 //  - proj_clip: if > 0, clip the output of the projection.
297 //  - output_state: output vector, size n_batch*n_output. Must be contiguous.
298 //  - scratch: scratch area to store output_no_projection. Size n_batch*n_cell.
299 //  - projection_bias_scratch: scratch area to store projection_bias. Size
300 //  n_batch*n_cell.
301 //  - context: the CpuBackendContext for use with matrix multiplications.
302 void CalculateLstmOutputFloat(int n_batch, int n_cell, int n_output,
303                               const float* cell_state, const float* output_gate,
304                               TfLiteFusedActivation activation,
305                               const float* projection_weights,
306                               const float* projection_bias,
307                               const float proj_clip, float* output_state,
308                               float* scratch, float* projection_bias_scratch,
309                               CpuBackendContext* context) {
310   tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
311                                         activation, scratch);
312   tensor_utils::VectorVectorCwiseProduct(output_gate, scratch, n_batch * n_cell,
313                                          scratch);
314 
315   const bool use_projection = (projection_weights != nullptr);
316   const bool use_projection_bias = (projection_bias != nullptr);
317 
318   if (use_projection) {
319     if (use_projection_bias) {
320       tensor_utils::VectorBatchVectorAssign(projection_bias, n_output, n_batch,
321                                             projection_bias_scratch);
322     } else {
323       std::fill_n(projection_bias_scratch, n_batch * n_output, 0.0f);
324     }
325     MatrixBatchVectorMultiplyAccumulate(projection_weights, scratch,
326                                         projection_bias_scratch, output_state,
327                                         n_output, n_cell, n_batch, context);
328     if (proj_clip > 0.0f) {
329       tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip);
330     }
331   } else {
332     std::copy_n(scratch, n_batch * n_output, output_state);
333   }
334 }
335 // LINT.ThenChange(../tools/optimize/calibration/builtin_logging_ops/lstm.cc,\
336 //                 ../experimental/kernels/fp16/lstm_eval.cc)
337 
338 // Calculates a single LSTM gate, hybrid version.
339 // Implements the same functionality as CalculateLstmGateFloat.
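// In the hybrid scheme the weights are 8-bit with a float scale per tensor,
// while the float inputs have been quantized per batch row beforehand (the
// *_sf arrays hold the per-row scaling factors, the *_zp arrays the zero
// points); accumulation and the gate output remain float. The optional
// *_ledger arrays describe sparse weights. This is a summary of how the
// parameters below are used; see tensor_utils for the exact kernels.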
340 void CalculateLstmGateHybrid(
341     // Input and weights
342     const int8_t* input, const float* input_sf, const int32_t* input_zp,
343     const int8_t* input_to_gate_weights,
344     const uint8_t* input_to_gate_weights_ledger,
345     const float input_to_gate_weights_scale, int32_t* input_to_gate_row_sums,
346     // Aux input and weights
347     const int8_t* aux_input, const float* aux_input_sf,
348     const int32_t* aux_input_zp, const int8_t* aux_input_to_gate_weights,
349     const float aux_input_to_gate_weights_scale,
350     int32_t* aux_input_to_gate_row_sums,
351     // Output state and weights
352     const int8_t* output_state, const float* output_state_sf,
353     const int32_t* output_state_zp, const int8_t* recurrent_to_gate_weights,
354     const uint8_t* recurrent_to_gate_weights_ledger,
355     const float recurrent_to_gate_weights_scale,
356     int32_t* recurrent_to_gate_row_sums,
357     // Cell state and weights (peephole LSTM)
358     const float* cell_state, const int8_t* cell_to_gate_weights,
359     const float cell_to_gate_weights_scale,
360     // Layer normalization coefficients (layer norm LSTM) + gate bias
361     const float* layer_norm_coefficients, const float* gate_bias,
362     // Array sizes
363     const int n_batch, const int n_input, const int n_aux_input,
364     const int n_output, const int n_cell,
365     const TfLiteFusedActivation activation,
366     // Output
367     float* gate,
368     // Parameters for performance optimizations
369     const bool is_input_all_zeros, const bool is_aux_input_all_zeros,
370     const bool is_output_state_all_zeros, bool* compute_row_sums,
371     CpuBackendContext* context,
372     // Scratch arrays
373     float* scratch0,        // size: n_batch
374     float* scratch1,        // size: n_cell, only used if peephole LSTM
375     int32_t* accum_scratch  // For MatrixBatchVectorMultiplyAccumulate
376 ) {
377   const bool use_peephole = (cell_to_gate_weights != nullptr);
378   const bool use_layer_norm = (layer_norm_coefficients != nullptr);
379 
380   // Initialize scratch buffers with bias for regular lstm or initialize with
381   // zero for layer norm lstm.
382   if (use_layer_norm) {
383     std::fill_n(gate, n_cell * n_batch, 0.0f);
384   } else {
385     tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate);
386   }
387   // For each batch and cell: compute input_weight * input.
388   // Skip if input is all zeros.
389   if (!is_input_all_zeros) {
390     if (input_to_gate_weights_ledger != nullptr) {
391       std::vector<float> scales(n_batch);
392       for (int i = 0; i < n_batch; i++) {
393         scales[i] = input_to_gate_weights_scale * input_sf[i];
394       }
395       tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
396           input_to_gate_weights, input_to_gate_weights_ledger, n_cell, n_input,
397           input, scales.data(), n_batch, gate);
398 
399     } else {
400       tensor_utils::MatrixBatchVectorMultiplyAccumulate(
401           input_to_gate_weights, n_cell, n_input, input,
402           input_to_gate_weights_scale, input_sf, n_batch, gate,
403           /*per_channel_scale=*/nullptr, input_zp, accum_scratch,
404           input_to_gate_row_sums, compute_row_sums, scratch0, context);
405     }
406   }
407   // For each batch and cell: compute aux_input_weight * aux_input.
408   // Skip if auxiliary input is not available or all zeros.
409   if (!is_aux_input_all_zeros) {
410     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
411         aux_input_to_gate_weights, n_cell, n_aux_input, aux_input,
412         aux_input_to_gate_weights_scale, aux_input_sf, n_batch, gate,
413         /*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch,
414         aux_input_to_gate_row_sums, compute_row_sums, scratch0, context);
415   }
416   // For each batch and cell: compute recurrent_weight * output_state.
417   // Skip if output state is all zeros.
418   if (!is_output_state_all_zeros) {
419     if (recurrent_to_gate_weights_ledger != nullptr) {
420       std::vector<float> scales(n_batch);
421       for (int i = 0; i < n_batch; i++) {
422         scales[i] = recurrent_to_gate_weights_scale * input_sf[i];
423       }
424       tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
425           recurrent_to_gate_weights, recurrent_to_gate_weights_ledger, n_cell,
426           n_output, output_state, scales.data(), n_batch, gate);
427     } else {
428       tensor_utils::MatrixBatchVectorMultiplyAccumulate(
429           recurrent_to_gate_weights, n_cell, n_output, output_state,
430           recurrent_to_gate_weights_scale, output_state_sf, n_batch, gate,
431           /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch,
432           recurrent_to_gate_row_sums, compute_row_sums, scratch0, context);
433     }
434   }
435   // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
436   if (use_peephole) {
437     float* recovered_cell_weights = scratch1;
438     tensor_utils::VectorScalarMultiply(cell_to_gate_weights, n_cell,
439                                        cell_to_gate_weights_scale,
440                                        recovered_cell_weights);
441     tensor_utils::VectorBatchVectorCwiseProductAccumulate(
442         recovered_cell_weights, n_cell, cell_state, n_batch, gate);
443   }
444   // Do layer normalization (if layer norm LSTM)
445   if (use_layer_norm) {
446     tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
447     tensor_utils::VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell,
448                                                 gate, n_batch, gate);
449     tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
450   }
451   // Apply activation
452   tensor_utils::ApplyActivationToVector(gate, n_cell * n_batch, activation,
453                                         gate);
454 }
455 
456 // Calculates the output state tensor of an LSTM step. See Float version too.
457 //
458 // Parameters:
459 //  - n_batch: batches: the number of distinct vectors in each array.
460 //  - n_cell, n_output: sizes of vectors.
461 //  - cell_state, output_gate: input vectors, size n_batch*n_cell.
462 //  - projection_weights, projection_weights_scale, projection_bias:
463 //      constant inputs, describing projection matrix and bias.
464 //  - proj_clip: if > 0, clip the output of the projection.
465 //  - output_state: output vector, size n_batch*n_output. Must be contiguous.
466 //  - asymmetric_quantize_inputs: parameter to control quantization.
467 //  - projection_weights_row_sums, compute_row_sums, context: Data for optimized
468 //      MatrixBatchVectorMultiplyAccumulate.
469 //  - scratch0: scratch area of size n_batch*n_cell
470 //  - scratch1: scratch area of size n_batch*n_cell
471 //  - scratch2: scratch area of size n_batch
472 //  - scratch3: scratch area of size n_batch
473 //  - scratch4: scratch area used by MatrixBatchVectorMultiplyAccumulate
474 void CalculateLstmOutputHybrid(
475     int n_batch, int n_cell, int n_output, const float* cell_state,
476     const float* output_gate, TfLiteFusedActivation activation,
477     const int8_t* projection_weights, const uint8_t* projection_weights_ledger,
478     float projection_weights_scale, const float* projection_bias,
479     const float proj_clip, float* output_state, bool asymmetric_quantize_inputs,
480     int32_t* projection_weights_row_sums, bool* compute_row_sums,
481     CpuBackendContext* context, float* scratch0, int8_t* scratch1,
482     float* scratch2, int32_t* scratch3, int32_t* scratch4) {
483   tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
484                                         activation, scratch0);
485   tensor_utils::VectorVectorCwiseProduct(output_gate, scratch0,
486                                          n_batch * n_cell, scratch0);
487 
488   const bool use_projection = (projection_weights != nullptr);
489   const bool use_projection_bias = (projection_bias != nullptr);
490 
491   if (use_projection) {
492     if (use_projection_bias) {
493       tensor_utils::VectorBatchVectorAssign(projection_bias, n_output, n_batch,
494                                             output_state);
495     } else {
496       std::fill_n(output_state, n_batch * n_output, 0.0f);
497     }
498     if (!tensor_utils::IsZeroVector(scratch0, n_batch * n_cell)) {
499       // Save quantization and matmul computation for all zero output.
500       tensor_utils::BatchQuantizeFloats(scratch0, n_batch, n_cell, scratch1,
501                                         scratch2, scratch3,
502                                         asymmetric_quantize_inputs);
503       if (projection_weights_ledger != nullptr) {
504         std::vector<float> scales(n_batch);
505         for (int i = 0; i < n_batch; i++) {
506           scales[i] = projection_weights_scale * scratch2[i];
507         }
508         tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
509             projection_weights, projection_weights_ledger, n_output, n_cell,
510             scratch1, scales.data(), n_batch, output_state);
511       } else {
512         tensor_utils::MatrixBatchVectorMultiplyAccumulate(
513             projection_weights, n_output, n_cell, scratch1,
514             projection_weights_scale, scratch2, n_batch, output_state,
515             /*per_channel_scale=*/nullptr, scratch3, scratch4,
516             projection_weights_row_sums, compute_row_sums, scratch2, context);
517       }
518     }
519     if (proj_clip > 0.0f) {
520       tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip);
521     }
522   } else {
523     std::copy_n(scratch0, n_batch * n_output, output_state);
524   }
525 }
526 
527 // Calculates a single LSTM gate, int8x8_16 version.
528 // Implements the same functionality as CalculateLstmGateFloat.
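// "8x8_16" refers to 8-bit activations times 8-bit weights with the gate kept
// in 16-bit fixed point; each scale_a/scale_b pair is the usual quantized
// multiplier/shift representation of an effective scale.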
529 void CalculateLstmGateInteger8x8_16(
530     // Input and weights
531     const int8_t* input, const int8_t* input_to_gate_weights,
532     const int32_t* input_to_gate_bias, const int32_t input_to_gate_scale_a,
533     const int32_t input_to_gate_scale_b,
534     // Output state and weights
535     const int8_t* output_state, const int8_t* recurrent_to_gate_weights,
536     const int32_t* recurrent_to_gate_bias,
537     const int32_t recurrent_to_gate_scale_a,
538     const int32_t recurrent_to_gate_scale_b,
539     // Cell state and weights
540     const int16_t* cell_state, const int16_t* cell_to_gate_weights,
541     const int32_t cell_to_gate_scale_a, const int32_t cell_to_gate_scale_b,
542     // Layer normalization parameters (layer norm LSTM)
543     const int16_t* layer_norm_coefficients, const int32_t* layer_norm_bias,
544     const int32_t layer_norm_input_scale_a,
545     const int32_t layer_norm_input_scale_b,
546     const int32_t layer_norm_variance_guard,
547     // Array sizes
548     const int n_batch, const int n_input, const int n_output, const int n_cell,
549     const TfLiteFusedActivation activation,
550     // Output
551     int16_t* gate,
552     // Parameters for performance optimizations
553     CpuBackendContext* context,
554     // Scratch arrays
555     int32_t* scratch5) {
556   const bool use_peephole = (cell_to_gate_weights != nullptr);
557   const bool use_layer_norm = (layer_norm_coefficients != nullptr);
558 
559   // Initialize scratch buffers with zeros. Note that unlike float and hybrid
560   // versions, bias is only used in layer normalization.
561   std::fill_n(gate, n_batch * n_cell, 0);
562   // For each batch and cell: compute input_weight * input.
563   tensor_utils::MatrixBatchVectorMultiplyAccumulate(
564       input, input_to_gate_bias, input_to_gate_weights, input_to_gate_scale_a,
565       input_to_gate_scale_b, n_batch, n_input, n_cell, 0, scratch5, gate,
566       context);
567   // Note: no aux_input.
568 
569   // For each batch and cell: compute recurrent_weight * output_state.
570   tensor_utils::MatrixBatchVectorMultiplyAccumulate(
571       output_state, recurrent_to_gate_bias, recurrent_to_gate_weights,
572       recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output,
573       n_cell, 0, scratch5, gate, context);
574   // For each batch and cell: compute cell_weight * cell_state (peephole LSTM)
575   if (use_peephole) {
576     tensor_utils::VectorBatchVectorCwiseProductAccumulate(
577         cell_to_gate_weights, n_output, cell_state, n_batch,
578         cell_to_gate_scale_a, cell_to_gate_scale_b, gate);
579   }
580   // Do layer normalization (if layer norm LSTM)
581   if (use_layer_norm) {
582     tensor_utils::ApplyLayerNorm(
583         gate, layer_norm_coefficients, layer_norm_bias,
584         layer_norm_input_scale_a, layer_norm_input_scale_b,
585         layer_norm_variance_guard, n_batch, n_cell, gate);
586   }
587   // Apply activation
588   switch (activation) {
589     case kTfLiteActSigmoid:
590       tensor_utils::ApplySigmoid(gate, n_batch, n_cell, gate);
591       break;
592     case kTfLiteActTanh:
593       tensor_utils::ApplyTanh(3, gate, n_batch, n_cell, gate);
594       break;
595     default:
596       // Only Sigmoid or Tanh is used.
597       TFLITE_ASSERT_FALSE;
598   }
599 }
600 
601 // Updates the LSTM cell state, used by both integer LSTM versions.
602 // Also see UpdateLstmCellFloat.
603 //
604 // Parameters:
605 //  - n_batch, n_cell: sizes of vectors
606 //  - cell_state: input/output vector, size n_batch*n_cell
607 //  - cell_state_scale: scaling factor of cell state.
608 //  - input_gate: input vector, size n_batch*n_cell.
609 //  - forget_gate: input/scratch vector, size n_batch*n_cell, always modified.
610 //  - cell_gate: input vector, size n_batch*n_cell.
611 //  - use_cifg: use 1-forget_gate instead of input_gate.
612 //  - clip: if > 0, clip the resulting cell state to [-clip, +clip].
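//
// Note on the fixed-point arithmetic below (a reading of the shift amounts,
// assuming the gates are Q0.15 and cell_state_scale is the power-of-two
// exponent of the cell state): forget_gate * cell_state is rescaled by 15 to
// stay in the cell-state scale, while input_gate * cell_gate (a Q0.30
// product) is rescaled by 30 + cell_state_scale to convert into that scale.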
613 void UpdateLstmCellInteger(int n_batch, int n_cell, int16_t* cell_state,
614                            int32_t cell_state_scale, const int16_t* input_gate,
615                            int16_t* forget_gate, const int16_t* cell_gate,
616                            bool use_cifg, int16_t clip) {
617   // Use the forget_gate array as scratch, as input_gate array is not allocated
618   // in CIFG case. (Be careful not to write to the scratch before reading the
619   // forget gate data.)
620   int16_t* scratch = forget_gate;
621 
622   tensor_utils::CwiseMul(forget_gate, cell_state, n_batch, n_cell, 15,
623                          cell_state);
624   if (use_cifg) {
625     tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch);
626     tensor_utils::CwiseMul(scratch, cell_gate, n_batch, n_cell,
627                            30 + cell_state_scale, scratch);
628   } else {
629     tensor_utils::CwiseMul(input_gate, cell_gate, n_batch, n_cell,
630                            30 + cell_state_scale, scratch);
631   }
632   tensor_utils::CwiseAdd(cell_state, scratch, n_batch, n_cell, cell_state);
633 
634   if (clip > 0) {
635     tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip);
636   }
637 }
638 
639 // Calculates the output state tensor of an LSTM step. See Float and hybrid
640 // versions as well.
641 //
642 // Parameters:
643 //  - n_batch: batches: the number of distinct vectors in each array.
644 //  - n_cell, n_output: sizes of vectors.
645 //  - cell_state, output_gate: input vectors, size n_batch*n_cell.
646 //  - cell_state_scale: scaling of cell_state.
647 //  - hidden_scale_[a|b]: effective scale of cell_state.*output_gate
648 //  - hidden_zp: zero_point for cell_state.*output_gate
649 //  - projection_weights, proj_scale_[a|b], projection_bias:
650 //      constant inputs, describing projection matrix and bias.
651 //  - output_state_zp: zero point of output_state. (Input, calibrated value.)
652 //  - quantized_proj_clip: if > 0, clip the output of the projection.
653 //  - output_state: output vector, size n_batch*n_output. Must be contiguous.
654 //  - context: data for optimized MatrixBatchVectorMultiplyAccumulate.
655 //  - scratch0: scratch area of size n_batch*n_cell
656 //  - scratch1: scratch area of size n_batch*n_cell
657 //  - scratch2: scratch area used by MatrixBatchVectorMultiplyAccumulate
658 void CalculateLstmOutputInteger8x8_16(
659     int n_batch, int n_cell, int n_output, const int16_t* cell_state,
660     int32_t cell_state_scale, const int16_t* output_gate,
661     int32_t hidden_scale_a, int32_t hidden_scale_b, int32_t hidden_zp,
662     const int8_t* projection_weights, int32_t proj_scale_a,
663     int32_t proj_scale_b, const int32_t* projection_bias,
664     int32_t output_state_zp, int8_t quantized_proj_clip, int8_t* output_state,
665     CpuBackendContext* context, int16_t* scratch0, int8_t* scratch1,
666     int32_t* scratch2) {
667   // Note: unlike float/hybrid, the activation is always Tanh.
668   tensor_utils::ApplyTanh(15 + cell_state_scale, cell_state, n_batch, n_cell,
669                           scratch0);
670   tensor_utils::CwiseMul(output_gate, scratch0, hidden_scale_a, hidden_scale_b,
671                          n_batch, n_cell, hidden_zp, scratch1);
672 
673   const bool use_projection = (projection_weights != nullptr);
674 
675   if (use_projection) {
676     // Note: no bias like in float/hybrid
677     std::fill_n(output_state, n_batch * n_output, 0);
678     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
679         scratch1, projection_bias, projection_weights, proj_scale_a,
680         proj_scale_b, n_batch, n_cell, n_output, output_state_zp, scratch2,
681         output_state, context);
682     if (quantized_proj_clip > 0) {
683       tensor_utils::CwiseClipping(output_state, n_batch * n_output,
684                                   quantized_proj_clip);
685     }
686   } else {
687     std::copy_n(scratch1, n_batch * n_output, output_state);
688   }
689 }
690 
691 // Calculates a single LSTM gate, int8x8_8 version.
692 // Implements the same functionality as CalculateLstmGateFloat.
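// In this variant the intermediate products input * weights and
// output_state * weights are stored as 8-bit values (scratch0 and scratch1
// below, each sized n_batch*n_cell) and combined into a 16-bit gate with a
// saturating add, followed by layer normalization and the activation.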
693 void CalculateLstmGateInteger8x8_8(
694     // Inputs and weights
695     const int8_t* input, int32_t input_zp, const int8_t* input_to_gate_weight,
696     const int32_t input_to_gate_scale_a, const int32_t input_to_gate_scale_b,
697     const int32_t input_times_weights_scale_a,
698     const int32_t input_times_weights_scale_b,
699     const int32_t input_times_weights_zp,
700     // Output state and weights
701     const int8_t* output_state, const int32_t output_state_zp,
702     const int8_t* recurrent_to_gate_weight,
703     const int32_t recurrent_to_gate_scale_a,
704     const int32_t recurrent_to_gate_scale_b,
705     const int32_t output_state_times_weights_scale_a,
706     const int32_t output_state_times_weights_scale_b,
707     const int32_t output_state_times_weights_zp,
708     // Layer normalization parameters (layer norm LSTM)
709     const int16_t* layer_norm_gate_weight,
710     const int32_t layer_norm_gate_scale_a,
711     const int32_t layer_norm_gate_scale_b, const int32_t* gate_bias,
712     // Array sizes
713     const int n_batch, const int n_input, const int n_output, const int n_cell,
714     const TfLiteFusedActivation activation,
715     // Output
716     int16_t* gate,
717     // Scratch arrays, both sized n_batch*n_cell
718     int8_t* scratch0, int8_t* scratch1) {
719   // Multiply input * input_weights => scratch0
720   tensor_utils::MatrixBatchVectorMultiply(
721       input, input_zp, input_to_gate_weight, input_to_gate_scale_a,
722       input_to_gate_scale_b, n_batch, n_input, n_cell, scratch0,
723       input_times_weights_zp);
724   // Multiply output_state * recurrent_weights => scratch1
725   tensor_utils::MatrixBatchVectorMultiply(
726       output_state, output_state_zp, recurrent_to_gate_weight,
727       recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output,
728       n_cell, scratch1, output_state_times_weights_zp);
729   // Add scratch0 + scratch1 => gate
730   tensor_utils::TwoGateSaturatingAdd(
731       scratch0, input_times_weights_zp, scratch1, output_state_times_weights_zp,
732       input_times_weights_scale_a, input_times_weights_scale_b,
733       output_state_times_weights_scale_a, output_state_times_weights_scale_b,
734       n_batch, n_cell, gate);
735   // Apply layer normalization.
736   tensor_utils::ApplyLayerNormFloat(
737       gate, layer_norm_gate_weight, layer_norm_gate_scale_a,
738       layer_norm_gate_scale_b, gate_bias, n_batch, n_cell, gate);
739   // Apply activation.
740   switch (activation) {
741     case kTfLiteActSigmoid:
742       tensor_utils::ApplySigmoidFloat(gate, n_batch, n_cell, gate);
743       break;
744     case kTfLiteActTanh:
745       tensor_utils::ApplyTanhFloat(gate, n_batch, n_cell, -12, gate);
746       break;
747     default:
748       // Only Sigmoid or Tanh is used.
749       TFLITE_ASSERT_FALSE;
750   }
751 }
752 
753 // Calculates the output state tensor of an LSTM step. See Float and hybrid
754 // versions as well.
755 //
756 // Parameters:
757 //  - n_batch: batches: the number of distinct vectors in each array.
758 //  - n_cell, n_output: sizes of vectors.
759 //  - cell_state, output_gate: input vectors, size n_batch*n_cell.
760 //  - projection_weights, proj_scale_[a|b], projection_bias:
761 //      constant inputs, describing projection matrix and bias.
762 //  - output_state_zp: zero point of the output state.
763 //  - quantized_proj_clip: if > 0, clip the output of the projection.
764 //  - output_state: output vector, size n_batch*n_output. Must be contiguous.
765 //  - scratch: scratch area of size n_batch*n_cell
766 void CalculateLstmOutputInteger8x8_8(
767     int n_batch, int n_cell, int n_output, const int16_t* cell_state,
768     const int16_t* output_gate, const int8_t* projection_weights,
769     int32_t proj_scale_a, int32_t proj_scale_b, const int32_t* projection_bias,
770     int32_t output_state_zp, int32_t quantized_proj_clip, int8_t* output_state,
771     int16_t* scratch) {
772   // Note: unlike float/hybrid, the activation is always Tanh.
773   tensor_utils::ApplyTanhFloat(cell_state, n_batch, n_cell, -15, scratch);
774   tensor_utils::CwiseMul(output_gate, scratch, n_batch, n_cell, 15 + 15 - 15,
775                          scratch);
776   // Note: no bias like in float/hybrid
777   tensor_utils::MatrixBatchVectorMultiply(
778       scratch, projection_weights, proj_scale_a, proj_scale_b, projection_bias,
779       n_batch, n_cell, n_output, output_state_zp, output_state);
780   if (quantized_proj_clip > 0) {
781     tensor_utils::CwiseClipping(output_state, n_batch * n_output,
782                                 quantized_proj_clip);
783   }
784 }
785 
786 // Performs an LSTM batch inference step for input specified by input_ptr.
787 // The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
788 // biases (*_bias_ptr), and buffers (*_scratch), along with additional
789 // parameters:
790 //  - params: various LSTM params including activation, clipping, etc.,
791 //  - n_batch: size of batch,
792 //  - n_cell: number of cells (or units),
793 //  - n_input: the input size,
794 //  - n_aux_input: the auxiliary input size.
795 //  - n_output: the output size.
796 //  - output_batch_leading_dim: the leading dimension of the output buffer.
797 //  - context: the CpuBackendContext for use with matrix multiplications.
798 //
799 // Input of size 'n_batch * n_input':
800 //   input_ptr
801 // Input of size 'n_batch * n_aux_input':
802 //   aux_input_ptr                     - optional (can be nullptr)
803 //
804 // LSTM weights:
805 // Input weights of size 'n_cell * n_input':
806 //   input_to_input_weights            - optional
807 //   input_to_forget_weights
808 //   input_to_cell_weights
809 //   input_to_output_weights
810 // Auxiliary input weights of size 'n_cell * n_aux_input':
811 //   aux_input_to_input_weights        - optional
812 //   aux_input_to_forget_weights       - optional
813 //   aux_input_to_cell_weights         - optional
814 //   aux_input_to_output_weights       - optional
815 // Recurrent weights of size 'n_cell * n_output':
816 //   recurrent_to_input_weights        - optional
817 //   recurrent_to_forget_weights
818 //   recurrent_to_cell_weights
819 //   recurrent_to_output_weights
820 // Peephole weights of size 'n_cell', representing diagonal matrices.
821 //   cell_to_input_weights             - optional
822 //   cell_to_forget_weights            - optional
823 //   cell_to_output_weights            - optional
824 // Projection weights of size 'n_output * n_cell'
825 //   projection_weights_ptr            - optional
826 // Gate biases of size 'n_cell':
827 //   input_gate_bias_ptr               - optional
828 //   forget_gate_bias_ptr
829 //   cell_gate_bias_ptr
830 //   output_gate_bias_ptr
831 //
832 // Layer norm coefficients of size 'n_cell', representing diagonal matrices.
833 //   input_layer_norm_coefficients_ptr  - optional
834 //   forget_layer_norm_coefficients_ptr - optional
835 //   cell_layer_norm_coefficients_ptr   - optional
836 //   output_layer_norm_coefficients_ptr - optional
837 //
838 // The pointers to the cell and output state and the output are updated.
839 //
840 // The pointers input_ptr, aux_input_ptr, and output_ptr point to data aligned
841 // in batch_major order, and each step processes batch_size many inputs from
842 // input_ptr, and updates batch_size many cell and output states.
843 //
844 // The output_batch_leading_dim is output.shape[-1], i.e. the last dimension of
845 // the output tensor, and in most cases is equal to n_output. It differs when we
846 // want to store the LSTM output into a slice of the output tensor, e.g. for
847 // bidirectional LSTMs with merge_outputs. In this case, the batched operations
848 // cannot be used, since they assume that the batched outputs are contiguous,
849 // and we manually loop over the batched outputs.
850 // LINT.IfChange
851 inline void LstmStepFloat(
852     const float* input_ptr, const float* input_to_input_weights_ptr,
853     const float* input_to_forget_weights_ptr,
854     const float* input_to_cell_weights_ptr,
855     const float* input_to_output_weights_ptr, const float* aux_input_ptr,
856     const float* aux_input_to_input_weights_ptr,
857     const float* aux_input_to_forget_weights_ptr,
858     const float* aux_input_to_cell_weights_ptr,
859     const float* aux_input_to_output_weights_ptr,
860     const float* recurrent_to_input_weights_ptr,
861     const float* recurrent_to_forget_weights_ptr,
862     const float* recurrent_to_cell_weights_ptr,
863     const float* recurrent_to_output_weights_ptr,
864     const float* cell_to_input_weights_ptr,
865     const float* cell_to_forget_weights_ptr,
866     const float* cell_to_output_weights_ptr,
867     const float* input_layer_norm_coefficients_ptr,
868     const float* forget_layer_norm_coefficients_ptr,
869     const float* cell_layer_norm_coefficients_ptr,
870     const float* output_layer_norm_coefficients_ptr,
871     const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
872     const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
873     const float* projection_weights_ptr, const float* projection_bias_ptr,
874     const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
875     int n_aux_input, int n_output, int output_batch_leading_dim,
876     float* output_state_ptr, float* cell_state_ptr, float* scratch0,
877     float* scratch1, float* scratch2, float* scratch3, float* scratch4,
878     float* output_ptr, CpuBackendContext* context) {
879   ruy::profiler::ScopeLabel label("LstmStepFloat");
880   // Since we have already checked that weights are all there or none, we can
881   // check the existence of only one to get the condition.
882   const bool use_cifg = (input_to_input_weights_ptr == nullptr);
883 
884   // Make named scratch buffers.
885   float* input_gate_scratch = scratch0;
886   float* forget_gate_scratch = scratch1;
887   float* cell_gate_scratch = scratch2;
888   float* output_gate_scratch = scratch3;
889   float* accumulation_scratch_buffer = scratch4;
890 
891   // Check if inputs are all zeros so we can skip some computations.
892   const bool is_input_all_zeros =
893       tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
894   const bool is_aux_input_all_zeros =
895       (aux_input_ptr == nullptr ||
896        tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
897 
898   if (!use_cifg) {
899     // Calculate the input gate. (If not CIFG.)
900     CalculateLstmGateFloat(input_ptr, input_to_input_weights_ptr, aux_input_ptr,
901                            aux_input_to_input_weights_ptr, output_state_ptr,
902                            recurrent_to_input_weights_ptr,
903 
904                            cell_state_ptr, cell_to_input_weights_ptr,
905                            input_layer_norm_coefficients_ptr,
906                            input_gate_bias_ptr, n_batch, n_input, n_aux_input,
907                            n_output, n_cell,
908                            /*activation=*/kTfLiteActSigmoid, input_gate_scratch,
909                            is_input_all_zeros, is_aux_input_all_zeros,
910                            accumulation_scratch_buffer, context);
911   }
912   // Calculate the forget gate.
913   CalculateLstmGateFloat(
914       input_ptr, input_to_forget_weights_ptr, aux_input_ptr,
915       aux_input_to_forget_weights_ptr, output_state_ptr,
916       recurrent_to_forget_weights_ptr,
917 
918       cell_state_ptr, cell_to_forget_weights_ptr,
919       forget_layer_norm_coefficients_ptr, forget_gate_bias_ptr, n_batch,
920       n_input, n_aux_input, n_output, n_cell,
921       /*activation=*/kTfLiteActSigmoid, forget_gate_scratch, is_input_all_zeros,
922       is_aux_input_all_zeros, accumulation_scratch_buffer, context);
923   // Calculate the cell update gate.
924   CalculateLstmGateFloat(
925       input_ptr, input_to_cell_weights_ptr, aux_input_ptr,
926       aux_input_to_cell_weights_ptr, output_state_ptr,
927       recurrent_to_cell_weights_ptr,
928 
929       /*cell_state=*/nullptr,
930       /*cell_to_gate_weights=*/nullptr, cell_layer_norm_coefficients_ptr,
931       cell_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
932       params->activation, cell_gate_scratch, is_input_all_zeros,
933       is_aux_input_all_zeros, accumulation_scratch_buffer, context);
934   // Update the cell state.
935   UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch,
936                       forget_gate_scratch, cell_gate_scratch, use_cifg,
937                       params->cell_clip);
938   // Calculate output gate.
939   CalculateLstmGateFloat(
940       input_ptr, input_to_output_weights_ptr, aux_input_ptr,
941       aux_input_to_output_weights_ptr, output_state_ptr,
942       recurrent_to_output_weights_ptr,
943 
944       cell_state_ptr, cell_to_output_weights_ptr,
945       output_layer_norm_coefficients_ptr, output_gate_bias_ptr, n_batch,
946       n_input, n_aux_input, n_output, n_cell,
947       /*activation=*/kTfLiteActSigmoid, output_gate_scratch, is_input_all_zeros,
948       is_aux_input_all_zeros, accumulation_scratch_buffer, context);
949   // Update the output state.
950   CalculateLstmOutputFloat(n_batch, n_cell, n_output, cell_state_ptr,
951                            output_gate_scratch, params->activation,
952                            projection_weights_ptr, projection_bias_ptr,
953                            params->proj_clip, output_state_ptr, scratch2,
954                            accumulation_scratch_buffer, context);
955   // Copy output state to the output. Note that the output's rows may not be
956   // contiguous (output_batch_leading_dim != n_output).
957   for (int b = 0; b < n_batch; b++) {
958     std::copy_n(output_state_ptr + b * n_output, n_output,
959                 output_ptr + b * output_batch_leading_dim);
960   }
961 }
962 // LINT.ThenChange(../tools/optimize/calibration/builtin_logging_ops/lstm.cc,\
963 //                 ../experimental/kernels/fp16/lstm_eval.cc)
964 
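// The float step above, stripped of CIFG, peephole connections, layer norm,
// the aux input, the projection and clipping, reduces to the textbook LSTM
// recurrence. A minimal single-unit, single-batch sketch for reference only
// (not used by the kernels; the scalar weights w_*, r_*, b_* are hypothetical
// stand-ins for the matrices handled by CalculateLstmGateFloat):
inline void LstmCellSketchFloat(float x, float h_prev, float c_prev, float w_i,
                                float r_i, float b_i, float w_f, float r_f,
                                float b_f, float w_c, float r_c, float b_c,
                                float w_o, float r_o, float b_o, float* h_out,
                                float* c_out) {
  const auto sigmoid = [](float v) { return 1.0f / (1.0f + expf(-v)); };
  const float input_gate = sigmoid(w_i * x + r_i * h_prev + b_i);
  const float forget_gate = sigmoid(w_f * x + r_f * h_prev + b_f);
  const float cell_update = tanhf(w_c * x + r_c * h_prev + b_c);
  const float new_cell = forget_gate * c_prev + input_gate * cell_update;
  const float output_gate = sigmoid(w_o * x + r_o * h_prev + b_o);
  *c_out = new_cell;
  *h_out = output_gate * tanhf(new_cell);
}
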
965 // Same as above but with quantized weight matrices. In detail:
966 // Input of size 'n_batch * n_input':
967 //   input_ptr
968 // Input of size 'n_batch * n_aux_input':
969 //   aux_input_ptr                     - optional (can be nullptr)
970 //
971 // LSTM weights:
972 // Quantized input weights of size 'n_cell * n_input':
973 //   input_to_input_weights            - optional
974 //   input_to_forget_weights
975 //   input_to_cell_weights
976 //   input_to_output_weights
977 // Quantized auxiliary input weights of size 'n_cell * n_aux_input':
978 //   aux_input_to_input_weights        - optional
979 //   aux_input_to_forget_weights       - optional
980 //   aux_input_to_cell_weights         - optional
981 //   aux_input_to_output_weights       - optional
982 // Quantized recurrent weights of size 'n_cell * n_output':
983 //   recurrent_to_input_weights        - optional
984 //   recurrent_to_forget_weights
985 //   recurrent_to_cell_weights
986 //   recurrent_to_output_weights
987 // Quantized peephole weights of size 'n_cell', representing diagonal matrices.
988 //   cell_to_input_weights             - optional
989 //   cell_to_forget_weights            - optional
990 //   cell_to_output_weights            - optional
991 // Quantized projection weights of size 'n_output * n_cell'
992 //   projection_weights_ptr            - optional
993 // Weight scales (scalars) for each of the weights above.
994 //   input_to_input_weights_scale      - optional
995 //   input_to_forget_weights_scale
996 //   input_to_cell_weights_scale
997 //   input_to_output_weights_scale
998 //   aux_input_to_input_weights_scale  - optional
999 //   aux_input_to_forget_weights_scale - optional
1000 //   aux_input_to_cell_weights_scale   - optional
1001 //   aux_input_to_output_weights_scale - optional
1002 //   recurrent_to_input_weights_scale  - optional
1003 //   recurrent_to_forget_weights_scale
1004 //   recurrent_to_cell_weights_scale
1005 //   recurrent_to_output_weights_scale
1006 //   cell_to_input_weights_scale
1007 //   cell_to_forget_weights_scale
1008 //   cell_to_output_weights_scale
1009 //   projection_weights_scale          - optional
1010 // Gate biases of size 'n_cell':
1011 //   input_gate_bias_ptr               - optional
1012 //   forget_gate_bias_ptr
1013 //   cell_gate_bias_ptr
1014 //   output_gate_bias_ptr
1015 //
1016 // Layer norm coefficients of size 'n_cell', representing diagonal matrices.
1017 //   input_layer_norm_coefficients_ptr  - optional
1018 //   forget_layer_norm_coefficients_ptr - optional
1019 //   cell_layer_norm_coefficients_ptr   - optional
1020 //   output_layer_norm_coefficients_ptr - optional
1021 //
1022 // Temporary pre-allocated storage for quantized values:
1023 //   quantized_input_ptr (same size as input_ptr)
1024 //   quantized_output_state_ptr (same size as output_state_ptr)
1025 //   quantized_output_scratch (same size as cell_state_ptr)
1026 // Temporary pre-allocated storage for recovered values:
1027 //   recovered_cell_weights (same size as cell_to_*_weights)
1028 //
1029 // Outputs:
1030 //   output_state_ptr - size 'n_batch * n_output'
1031 //   cell_state_ptr   - size 'n_batch * n_cell'
1032 //   output_ptr       - size 'n_batch * output_batch_leading_dim'
1033 inline void LstmStepHybrid(
1034     const float* input_ptr, const int8_t* input_to_input_weights_ptr,
1035     const uint8_t* input_to_input_weights_ledger_ptr,
1036     float input_to_input_weights_scale,
1037     const int8_t* input_to_forget_weights_ptr,
1038     const uint8_t* input_to_forget_weights_ledger_ptr,
1039     float input_to_forget_weights_scale,
1040     const int8_t* input_to_cell_weights_ptr,
1041     const uint8_t* input_to_cell_weights_ledger_ptr,
1042     float input_to_cell_weights_scale,
1043     const int8_t* input_to_output_weights_ptr,
1044     const uint8_t* input_to_output_weights_ledger_ptr,
1045     float input_to_output_weights_scale, const float* aux_input_ptr,
1046     const int8_t* aux_input_to_input_weights_ptr,
1047     float aux_input_to_input_weights_scale,
1048     const int8_t* aux_input_to_forget_weights_ptr,
1049     float aux_input_to_forget_weights_scale,
1050     const int8_t* aux_input_to_cell_weights_ptr,
1051     float aux_input_to_cell_weights_scale,
1052     const int8_t* aux_input_to_output_weights_ptr,
1053     float aux_input_to_output_weights_scale,
1054     const int8_t* recurrent_to_input_weights_ptr,
1055     const uint8_t* recurrent_to_input_weights_ledger_ptr,
1056     float recurrent_to_input_weights_scale,
1057     const int8_t* recurrent_to_forget_weights_ptr,
1058     const uint8_t* recurrent_to_forget_weights_ledger_ptr,
1059     float recurrent_to_forget_weights_scale,
1060     const int8_t* recurrent_to_cell_weights_ptr,
1061     const uint8_t* recurrent_to_cell_weights_ledger_ptr,
1062     float recurrent_to_cell_weights_scale,
1063     const int8_t* recurrent_to_output_weights_ptr,
1064     const uint8_t* recurrent_to_output_weights_ledger_ptr,
1065     float recurrent_to_output_weights_scale,
1066     const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
1067     const int8_t* cell_to_forget_weights_ptr,
1068     float cell_to_forget_weights_scale,
1069     const int8_t* cell_to_output_weights_ptr,
1070     float cell_to_output_weights_scale,
1071     const float* input_layer_norm_coefficients_ptr,
1072     const float* forget_layer_norm_coefficients_ptr,
1073     const float* cell_layer_norm_coefficients_ptr,
1074     const float* output_layer_norm_coefficients_ptr,
1075     const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
1076     const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
1077     const int8_t* projection_weights_ptr,
1078     const uint8_t* projection_weights_ledger_ptr,
1079     float projection_weights_scale, const float* projection_bias_ptr,
1080     const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
1081     int n_aux_input, int n_output, int output_batch_leading_dim,
1082     float* scratch0, float* scratch1, float* scratch2, float* scratch3,
1083     float* input_sf, float* aux_input_sf, float* output_state_sf,
1084     float* scaling_factors_scratch, float* recovered_cell_weights,
1085     int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr,
1086     int8_t* quantized_output_state_ptr, int8_t* quantized_output_scratch,
1087     float* output_state_ptr, float* cell_state_ptr, int32_t* accum_scratch_ptr,
1088     float* output_ptr, int32_t* input_zp, int32_t* aux_input_zp,
1089     int32_t* output_state_zp, int32_t* row_sums, int row_sums_size,
1090     bool* compute_row_sums, bool asymmetric_quantize_inputs,
1091     CpuBackendContext* context) {
1092   ruy::profiler::ScopeLabel label("LstmStepHybrid");
1093   // Since we have already checked that weights are all there or none, we
1094   // can check the existence of only one to get the condition.
1095   const bool use_cifg = (input_to_input_weights_ptr == nullptr);
1096   // Make named scratch buffers for the different gates.
1097   float* input_gate_scratch = scratch0;
1098   float* forget_gate_scratch = scratch1;
1099   float* cell_gate_scratch = scratch2;
1100   float* output_gate_scratch = scratch3;
1101 
1102   int32_t* input_to_input_row_sums = nullptr;
1103   int32_t* input_to_forget_row_sums = nullptr;
1104   int32_t* input_to_cell_row_sums = nullptr;
1105   int32_t* input_to_output_row_sums = nullptr;
1106   int32_t* aux_input_to_input_row_sums = nullptr;
1107   int32_t* aux_input_to_forget_row_sums = nullptr;
1108   int32_t* aux_input_to_cell_row_sums = nullptr;
1109   int32_t* aux_input_to_output_row_sums = nullptr;
1110   int32_t* recurrent_to_input_row_sums = nullptr;
1111   int32_t* recurrent_to_forget_row_sums = nullptr;
1112   int32_t* recurrent_to_cell_row_sums = nullptr;
1113   int32_t* recurrent_to_output_row_sums = nullptr;
1114   int32_t* projection_weights_row_sums = nullptr;
1115 
1116   if (asymmetric_quantize_inputs) {
1117     int num_row_sums = use_cifg ? 6 : 8;
1118     if (aux_input_ptr != nullptr) {
1119       num_row_sums += use_cifg ? 3 : 4;
1120     }
1121     if (projection_weights_ptr != nullptr) {
1122       num_row_sums += ceil(static_cast<float>(n_output) / n_cell);
1123     }
1124     TF_LITE_ASSERT(row_sums_size == num_row_sums);
1125     input_to_input_row_sums = row_sums;
1126     input_to_forget_row_sums =
1127         use_cifg ? input_to_input_row_sums : input_to_input_row_sums + n_cell;
1128     input_to_cell_row_sums = input_to_forget_row_sums + n_cell;
1129     input_to_output_row_sums = input_to_cell_row_sums + n_cell;
1130     if (aux_input_ptr != nullptr) {
1131       aux_input_to_input_row_sums = input_to_output_row_sums + n_cell;
1132       aux_input_to_forget_row_sums = use_cifg
1133                                          ? aux_input_to_input_row_sums
1134                                          : aux_input_to_input_row_sums + n_cell;
1135       aux_input_to_cell_row_sums = aux_input_to_forget_row_sums + n_cell;
1136       aux_input_to_output_row_sums = aux_input_to_cell_row_sums + n_cell;
1137     }
1138     recurrent_to_input_row_sums = aux_input_ptr
1139                                       ? aux_input_to_output_row_sums + n_cell
1140                                       : input_to_output_row_sums + n_cell;
1141     recurrent_to_forget_row_sums = use_cifg
1142                                        ? recurrent_to_input_row_sums
1143                                        : recurrent_to_input_row_sums + n_cell;
1144     recurrent_to_cell_row_sums = recurrent_to_forget_row_sums + n_cell;
1145     recurrent_to_output_row_sums = recurrent_to_cell_row_sums + n_cell;
1146     if (projection_weights_ptr != nullptr) {
1147       projection_weights_row_sums = recurrent_to_output_row_sums + n_cell;
1148     }
1149     if (*compute_row_sums) {
1150       ComputeRowSums(
1151           input_to_input_row_sums, input_to_forget_row_sums,
1152           input_to_cell_row_sums, input_to_output_row_sums,
1153           aux_input_to_input_row_sums, aux_input_to_forget_row_sums,
1154           aux_input_to_cell_row_sums, aux_input_to_output_row_sums,
1155           recurrent_to_input_row_sums, recurrent_to_forget_row_sums,
1156           recurrent_to_cell_row_sums, recurrent_to_output_row_sums,
1157           projection_weights_row_sums, row_sums, n_cell, n_input, n_aux_input,
1158           n_output, input_to_input_weights_ptr, input_to_forget_weights_ptr,
1159           input_to_cell_weights_ptr, input_to_output_weights_ptr,
1160           aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
1161           aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
1162           recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
1163           recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
1164           projection_weights_ptr, use_cifg, aux_input_ptr);
1165       *compute_row_sums = false;
1166     }
1167   }
1168 
1169   // Check if inputs are all zeros so we can skip some computations.
1170   const bool is_input_all_zeros =
1171       tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
1172   const bool is_aux_input_all_zeros =
1173       (aux_input_ptr == nullptr ||
1174        tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
1175   const bool is_output_state_all_zeros =
1176       tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output);
1177   // Quantize inputs.
1178   if (!is_input_all_zeros) {
1179     tensor_utils::BatchQuantizeFloats(input_ptr, n_batch, n_input,
1180                                       quantized_input_ptr, input_sf, input_zp,
1181                                       asymmetric_quantize_inputs);
1182   }
1183   if (!is_aux_input_all_zeros) {
1184     tensor_utils::BatchQuantizeFloats(aux_input_ptr, n_batch, n_aux_input,
1185                                       quantized_aux_input_ptr, aux_input_sf,
1186                                       aux_input_zp, asymmetric_quantize_inputs);
1187   }
1188   if (!is_output_state_all_zeros) {
1189     tensor_utils::BatchQuantizeFloats(
1190         output_state_ptr, n_batch, n_output, quantized_output_state_ptr,
1191         output_state_sf, output_state_zp, asymmetric_quantize_inputs);
1192   }
1193   if (!use_cifg) {
1194     // Calculate the input gate. (If not CIFG.)
1195     CalculateLstmGateHybrid(
1196         quantized_input_ptr, input_sf, input_zp, input_to_input_weights_ptr,
1197         input_to_input_weights_ledger_ptr, input_to_input_weights_scale,
1198         input_to_input_row_sums, quantized_aux_input_ptr, aux_input_sf,
1199         aux_input_zp, aux_input_to_input_weights_ptr,
1200         aux_input_to_input_weights_scale, aux_input_to_input_row_sums,
1201         quantized_output_state_ptr, output_state_sf, output_state_zp,
1202         recurrent_to_input_weights_ptr, recurrent_to_input_weights_ledger_ptr,
1203         recurrent_to_input_weights_scale, recurrent_to_input_row_sums,
1204         cell_state_ptr, cell_to_input_weights_ptr, cell_to_input_weights_scale,
1205         input_layer_norm_coefficients_ptr, input_gate_bias_ptr, n_batch,
1206         n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid,
1207         input_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros,
1208         is_output_state_all_zeros, compute_row_sums, context,
1209         scaling_factors_scratch, recovered_cell_weights, accum_scratch_ptr);
1210   }
1211   // Calculate the forget gate.
1212   CalculateLstmGateHybrid(
1213       quantized_input_ptr, input_sf, input_zp, input_to_forget_weights_ptr,
1214       input_to_forget_weights_ledger_ptr, input_to_forget_weights_scale,
1215       input_to_forget_row_sums, quantized_aux_input_ptr, aux_input_sf,
1216       aux_input_zp, aux_input_to_forget_weights_ptr,
1217       aux_input_to_forget_weights_scale, aux_input_to_forget_row_sums,
1218       quantized_output_state_ptr, output_state_sf, output_state_zp,
1219       recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_ledger_ptr,
1220       recurrent_to_forget_weights_scale, recurrent_to_forget_row_sums,
1221       cell_state_ptr, cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
1222       forget_layer_norm_coefficients_ptr, forget_gate_bias_ptr, n_batch,
1223       n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid,
1224       forget_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros,
1225       is_output_state_all_zeros, compute_row_sums, context,
1226       scaling_factors_scratch, recovered_cell_weights, accum_scratch_ptr);
1227   // Calculate the cell update gate.
1228   CalculateLstmGateHybrid(
1229       quantized_input_ptr, input_sf, input_zp, input_to_cell_weights_ptr,
1230       input_to_cell_weights_ledger_ptr, input_to_cell_weights_scale,
1231       input_to_cell_row_sums, quantized_aux_input_ptr, aux_input_sf,
1232       aux_input_zp, aux_input_to_cell_weights_ptr,
1233       aux_input_to_cell_weights_scale, aux_input_to_cell_row_sums,
1234       quantized_output_state_ptr, output_state_sf, output_state_zp,
1235       recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_ledger_ptr,
1236       recurrent_to_cell_weights_scale, recurrent_to_cell_row_sums,
1237       /*cell_state=*/nullptr, /*cell_to_gate_weights=*/nullptr,
1238       /*cell_to_gate_weights_scale=*/0.0f, cell_layer_norm_coefficients_ptr,
1239       cell_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
1240       params->activation, cell_gate_scratch, is_input_all_zeros,
1241       is_aux_input_all_zeros, is_output_state_all_zeros, compute_row_sums,
1242       context, scaling_factors_scratch, recovered_cell_weights,
1243       accum_scratch_ptr);
1244   // Update the cell state.
1245   UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch,
1246                       forget_gate_scratch, cell_gate_scratch, use_cifg,
1247                       params->cell_clip);
1248   // Calculate the output gate.
1249   CalculateLstmGateHybrid(
1250       quantized_input_ptr, input_sf, input_zp, input_to_output_weights_ptr,
1251       input_to_output_weights_ledger_ptr, input_to_output_weights_scale,
1252       input_to_output_row_sums, quantized_aux_input_ptr, aux_input_sf,
1253       aux_input_zp, aux_input_to_output_weights_ptr,
1254       aux_input_to_output_weights_scale, aux_input_to_output_row_sums,
1255       quantized_output_state_ptr, output_state_sf, output_state_zp,
1256       recurrent_to_output_weights_ptr, recurrent_to_output_weights_ledger_ptr,
1257       recurrent_to_output_weights_scale, recurrent_to_output_row_sums,
1258       cell_state_ptr, cell_to_output_weights_ptr, cell_to_output_weights_scale,
1259       output_layer_norm_coefficients_ptr, output_gate_bias_ptr, n_batch,
1260       n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid,
1261       output_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros,
1262       is_output_state_all_zeros, compute_row_sums, context,
1263       scaling_factors_scratch, recovered_cell_weights, accum_scratch_ptr);
1264   // Update the output state.
1265   CalculateLstmOutputHybrid(
1266       n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
1267       params->activation, projection_weights_ptr, projection_weights_ledger_ptr,
1268       projection_weights_scale, projection_bias_ptr, params->proj_clip,
1269       output_state_ptr, asymmetric_quantize_inputs, projection_weights_row_sums,
1270       compute_row_sums, context, scratch2, quantized_output_scratch, input_sf,
1271       input_zp, accum_scratch_ptr);
1272   // Copy output state to the output. Note that the output's rows may not be
1273   // contiguous (output_batch_leading_dim != n_output).
1274   for (int b = 0; b < n_batch; b++) {
1275     std::copy_n(output_state_ptr + b * n_output, n_output,
1276                 output_ptr + b * output_batch_leading_dim);
1277   }
1278 }
1279 
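// The pointer arithmetic that carves `row_sums` into per-matrix blocks above
// is easiest to sanity-check against the expected block count. A rough sketch
// of that count, mirroring the bookkeeping in LstmStepHybrid (illustrative
// only; not called by the kernel):
inline int RowSumsBlockCountSketch(bool use_cifg, bool has_aux_input,
                                   bool has_projection, int n_cell,
                                   int n_output) {
  // Four input_to_* and four recurrent_to_* matrices; CIFG drops both
  // *_to_input matrices.
  int num_blocks = use_cifg ? 6 : 8;
  if (has_aux_input) {
    num_blocks += use_cifg ? 3 : 4;  // aux_input_to_* matrices.
  }
  if (has_projection) {
    // The projection matrix has n_output rows, rounded up to whole blocks of
    // n_cell entries.
    num_blocks += (n_output + n_cell - 1) / n_cell;
  }
  // Each block holds n_cell int32 row sums, so row_sums_size equals the
  // number of blocks.
  return num_blocks;
}
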
1280 // Fully quantized lstm kernel for 16 bit gate matmul output.
1281 //
1282 // Input tensor of size n_batch * n_input:
1283 //   input_ptr
1284 //
1285 // LSTM weights:
1286 // Quantized input weights of size 'n_cell * n_input':
1287 //   input_to_input_weight_ptr            - optional
1288 //   input_to_forget_weight_ptr           - optional
1289 //   input_to_cell_weight_ptr             - optional
1290 //   input_to_output_weight_ptr           - optional
1291 //
1292 // Quantized recurrent weights of size 'n_cell * n_output':
1293 //   recurrent_to_input_weight_ptr        - optional
1294 //   recurrent_to_forget_weights_ptr
1295 //   recurrent_to_cell_weights_ptr
1296 //   recurrent_to_input_weights_ptr
1297 //
1298 // Quantized peephole weights of size 'n_cell', representing diagonal matrices.
1299 //   cell_to_input_weights               - optional
1300 //   cell_to_forget_weights              - optional
1301 //   cell_to_output_weights              - optional
1302 //
1303 // Quantized projection weights of size 'n_output * n_cell'
1304 //   projection_weight_ptr                     - optional
1305 //
1306 // Weight scales (scalars) for each of the weights above.
1307 //   effective_input_to_input_scale_a    - optional
1308 //   effective_input_to_input_scale_b    - optional
1309 //   effective_input_to_forget_scale_a
1310 //   effective_input_to_forget_scale_b
1311 //   effective_input_to_cell_scale_a
1312 //   effective_input_to_cell_scale_b
1313 //   effective_input_to_output_scale_a
1314 //   effective_input_to_output_scale_b
1315 //   effective_recurrent_to_input_scale_a    - optional
1316 //   effective_recurrent_to_input_scale_b    - optional
1317 //   effective_recurrent_to_forget_scale_a
1318 //   effective_recurrent_to_forget_scale_b
1319 //   effective_recurrent_to_cell_scale_a
1320 //   effective_recurrent_to_cell_scale_b
1321 //   effective_recurrent_to_output_scale_a
1322 //   effective_recurrent_to_output_scale_b
1323 //   effective_proj_scale_a                  - optional
1324 //   effective_proj_scale_b                  - optional
1325 //
1326 // Gate biases of size 'n_cell':
1327 //   input_gate_bias_ptr                 - optional
1328 //   forget_gate_bias_ptr
1329 //   cell_gate_bias_ptr
1330 //   output_gate_bias_ptr
1331 //
1332 // Layer norm coefficients of size 'n_cell', representing diagonal matrices.
1333 //   layer_norm_input_weight_ptr    - optional
1334 //   layer_norm_forget_weight_ptr   - optional
1335 //   layer_norm_cell_weight_ptr     - optional
1336 //   layer_norm_output_weight_ptr   - optional
1337 //
1338 // Layer norm scales of size 'n_cell'.
1339 //   layer_norm_input_scale_a     - optional
1340 //   layer_norm_input_scale_b     - optional
1341 //   layer_norm_forget_scale_a    - optional
1342 //   layer_norm_forget_scale_b    - optional
1343 //   layer_norm_cell_scale_a      - optional
1344 //   layer_norm_cell_scale_b      - optional
1345 //   layer_norm_output_scale_a    - optional
1346 //   layer_norm_output_scale_b    - optional
1347 //
1348 // Scalar values:
1349 //   quantized_cell_clip: quantized clip value for cell.
1350 //   quantized_proj_clip: quantized clip value for projection.
1351 //   cell_state_scale: the power of two scale for cell state.
1352 //
1353 // Zero points:
1354 //   output_state_zp: zero point of output state
1355 //   hidden_zp: zero point for hidden state.
1356 //
1357 // Temporary pre-allocated storage for the calculation. Each is of size n_cell *
1358 // n_batch.
1359 //   scratch0
1360 //   scratch1
1361 //   scratch2
1362 //   scratch3
1363 //   scratch4
1364 //   scratch5: this scratch buffer is created purely for optimizing the
1365 //              MatrixBatchVectorMultiplyAccumulate.
1366 //
1367 // Outputs:
1368 //   output_state_ptr - size 'n_batch * n_output'
1369 //   cell_state_ptr   - size 'n_batch * n_cell'
1370 //   output_ptr       - size 'n_batch * n_output'
1371 // TODO(b/159947023): scratch0 is not used if (!cifg). Don't allocate then.
1372 inline void LstmStepInteger8x8_16(
1373     const int8_t* input_ptr, const int8_t* input_to_input_weight_ptr,
1374     int32_t effective_input_to_input_scale_a,
1375     int32_t effective_input_to_input_scale_b,
1376     const int8_t* input_to_forget_weight_ptr,
1377     int32_t effective_input_to_forget_scale_a,
1378     int32_t effective_input_to_forget_scale_b,
1379     const int8_t* input_to_cell_weight_ptr,
1380     int32_t effective_input_to_cell_scale_a,
1381     int32_t effective_input_to_cell_scale_b,
1382     const int8_t* input_to_output_weight_ptr,
1383     int32_t effective_input_to_output_scale_a,
1384     int32_t effective_input_to_output_scale_b,
1385     const int8_t* recurrent_to_input_weight_ptr,
1386     int32_t effective_recurrent_to_input_scale_a,
1387     int32_t effective_recurrent_to_input_scale_b,
1388     const int8_t* recurrent_to_forget_weight_ptr,
1389     int32_t effective_recurrent_to_forget_scale_a,
1390     int32_t effective_recurrent_to_forget_scale_b,
1391     const int8_t* recurrent_to_cell_weight_ptr,
1392     int32_t effective_recurrent_to_cell_scale_a,
1393     int32_t effective_recurrent_to_cell_scale_b,
1394     const int8_t* recurrent_to_output_weight_ptr,
1395     int32_t effective_recurrent_to_output_scale_a,
1396     int32_t effective_recurrent_to_output_scale_b,
1397     const int16_t* cell_to_input_weight_ptr,
1398     int32_t effective_cell_to_input_scale_a,
1399     int32_t effective_cell_to_input_scale_b,
1400     const int16_t* cell_to_forget_weight_ptr,
1401     int32_t effective_cell_to_forget_scale_a,
1402     int32_t effective_cell_to_forget_scale_b,
1403     const int16_t* cell_to_output_weight_ptr,
1404     int32_t effective_cell_to_output_scale_a,
1405     int32_t effective_cell_to_output_scale_b,
1406     const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a,
1407     int32_t effective_proj_scale_b, int32_t hidden_zp,
1408     int32_t effective_hidden_scale_a, int32_t effective_hidden_scale_b,
1409     const int16_t* layer_norm_input_weight_ptr,
1410     int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
1411     const int16_t* layer_norm_forget_weight_ptr,
1412     int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
1413     const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
1414     int32_t layer_norm_cell_scale_b,
1415     const int16_t* layer_norm_output_weight_ptr,
1416     int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
1417     const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
1418     const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
1419     int16_t quantized_cell_clip, int8_t quantized_proj_clip,
1420     int32_t cell_state_scale, int32_t input_variance_guard,
1421     int32_t forget_variance_guard, int32_t cell_variance_guard,
1422     int32_t output_variance_guard,
1423     const int32_t* input_to_forget_effective_bias,
1424     const int32_t* recurrent_to_forget_effective_bias,
1425     const int32_t* input_to_cell_effective_bias,
1426     const int32_t* recurrent_to_cell_effective_bias,
1427     const int32_t* input_to_output_effective_bias,
1428     const int32_t* recurrent_to_output_effective_bias,
1429     const int32_t* input_to_input_effective_bias,
1430     const int32_t* recurrent_to_input_effective_bias,
1431     const int32_t* projection_effective_bias, int n_batch, int n_cell,
1432     int n_input, int n_output, int8_t* output_state_ptr,
1433     int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
1434     int16_t* scratch0, int16_t* scratch1, int16_t* scratch2, int16_t* scratch3,
1435     int8_t* scratch4, int32_t* scratch5, CpuBackendContext* context) {
1436   ruy::profiler::ScopeLabel label("LstmStepInteger8x8_16");
1437   // Make named scratch buffers for the different gates.
1438   int16_t* input_gate_scratch = scratch0;
1439   int16_t* forget_gate_scratch = scratch1;
1440   int16_t* cell_gate_scratch = scratch2;
1441   int16_t* output_gate_scratch = scratch3;
1442 
1443   // Since we have already checked that weights are all there or none, we
1444   // can check the existence of only one to get the condition.
1445   const bool use_cifg = (input_to_input_weight_ptr == nullptr);
1446 
1447   // Check for nullptrs.
1448   TFLITE_DCHECK(input_to_forget_effective_bias);
1449   TFLITE_DCHECK(recurrent_to_forget_effective_bias);
1450   TFLITE_DCHECK(input_to_cell_effective_bias);
1451   TFLITE_DCHECK(recurrent_to_cell_effective_bias);
1452   TFLITE_DCHECK(input_to_output_effective_bias);
1453   TFLITE_DCHECK(recurrent_to_output_effective_bias);
1454   if (!use_cifg) {
1455     TFLITE_DCHECK(input_to_input_effective_bias);
1456     TFLITE_DCHECK(recurrent_to_input_effective_bias);
1457   }
1458   const bool use_projection = (projection_weight_ptr != nullptr);
1459   if (use_projection) {
1460     TFLITE_DCHECK(projection_effective_bias);
1461   }
1462   if (!use_cifg) {
1463     // Calculate the input gate. (If not CIFG.)
1464     CalculateLstmGateInteger8x8_16(
1465         input_ptr, input_to_input_weight_ptr, input_to_input_effective_bias,
1466         effective_input_to_input_scale_a, effective_input_to_input_scale_b,
1467         output_state_ptr, recurrent_to_input_weight_ptr,
1468         recurrent_to_input_effective_bias, effective_recurrent_to_input_scale_a,
1469         effective_recurrent_to_input_scale_b, cell_state_ptr,
1470         cell_to_input_weight_ptr, effective_cell_to_input_scale_a,
1471         effective_cell_to_input_scale_b, layer_norm_input_weight_ptr,
1472         input_gate_bias_ptr, layer_norm_input_scale_a, layer_norm_input_scale_b,
1473         input_variance_guard, n_batch, n_input, n_output, n_cell,
1474         kTfLiteActSigmoid, input_gate_scratch, context, scratch5);
1475   }
1476   // Calculate the forget gate.
1477   CalculateLstmGateInteger8x8_16(
1478       input_ptr, input_to_forget_weight_ptr, input_to_forget_effective_bias,
1479       effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
1480       output_state_ptr, recurrent_to_forget_weight_ptr,
1481       recurrent_to_forget_effective_bias, effective_recurrent_to_forget_scale_a,
1482       effective_recurrent_to_forget_scale_b, cell_state_ptr,
1483       cell_to_forget_weight_ptr, effective_cell_to_forget_scale_a,
1484       effective_cell_to_forget_scale_b, layer_norm_forget_weight_ptr,
1485       forget_gate_bias_ptr, layer_norm_forget_scale_a,
1486       layer_norm_forget_scale_b, forget_variance_guard, n_batch, n_input,
1487       n_output, n_cell, kTfLiteActSigmoid, forget_gate_scratch, context,
1488       scratch5);
1489   // Calculate the cell update gate.
1490   CalculateLstmGateInteger8x8_16(
1491       input_ptr, input_to_cell_weight_ptr, input_to_cell_effective_bias,
1492       effective_input_to_cell_scale_a, effective_input_to_cell_scale_b,
1493       output_state_ptr, recurrent_to_cell_weight_ptr,
1494       recurrent_to_cell_effective_bias, effective_recurrent_to_cell_scale_a,
1495       effective_recurrent_to_cell_scale_b, cell_state_ptr,
1496       /*cell_to_gate_weights=*/nullptr, /*cell_to_gate_scale_a=*/0,
1497       /*cell_to_gate_scale_b=*/0, layer_norm_cell_weight_ptr,
1498       cell_gate_bias_ptr, layer_norm_cell_scale_a, layer_norm_cell_scale_b,
1499       cell_variance_guard, n_batch, n_input, n_output, n_cell, kTfLiteActTanh,
1500       cell_gate_scratch, context, scratch5);
1501   // Update the cell state.
1502   UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr, cell_state_scale,
1503                         input_gate_scratch, forget_gate_scratch,
1504                         cell_gate_scratch, use_cifg, quantized_cell_clip);
1505   // Calculate the output gate.
1506   CalculateLstmGateInteger8x8_16(
1507       input_ptr, input_to_output_weight_ptr, input_to_output_effective_bias,
1508       effective_input_to_output_scale_a, effective_input_to_output_scale_b,
1509       output_state_ptr, recurrent_to_output_weight_ptr,
1510       recurrent_to_output_effective_bias, effective_recurrent_to_output_scale_a,
1511       effective_recurrent_to_output_scale_b, cell_state_ptr,
1512       cell_to_output_weight_ptr, effective_cell_to_output_scale_a,
1513       effective_cell_to_output_scale_b, layer_norm_output_weight_ptr,
1514       output_gate_bias_ptr, layer_norm_output_scale_a,
1515       layer_norm_output_scale_b, output_variance_guard, n_batch, n_input,
1516       n_output, n_cell, kTfLiteActSigmoid, output_gate_scratch, context,
1517       scratch5);
1518   // Update the output state.
1519   CalculateLstmOutputInteger8x8_16(
1520       n_batch, n_cell, n_output, cell_state_ptr, cell_state_scale,
1521       output_gate_scratch, effective_hidden_scale_a, effective_hidden_scale_b,
1522       hidden_zp, projection_weight_ptr, effective_proj_scale_a,
1523       effective_proj_scale_b, projection_effective_bias, output_state_zp,
1524       quantized_proj_clip, output_state_ptr, context, scratch0, scratch4,
1525       scratch5);
1526   // Copy output state to the output. Note that unlike float or hybrid, output
1527   // is always contiguous.
1528   std::copy_n(output_state_ptr, n_batch * n_output, output_ptr);
1529 }
1530 
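// `quantized_cell_clip` and `cell_state_scale` above follow the convention
// that the int16 cell state stores q with real value q * 2^cell_state_scale.
// A rough sketch of how a float cell clip would map into that domain
// (illustrative only; the actual value is precomputed outside this file,
// typically when the op is prepared):
inline int16_t QuantizeCellClipSketch(float cell_clip, int cell_state_scale) {
  // Divide by 2^cell_state_scale, then clamp to the int16 range.
  const float scaled = cell_clip * ldexpf(1.0f, -cell_state_scale);
  const float clamped = std::min(std::max(scaled, -32768.0f), 32767.0f);
  return static_cast<int16_t>(roundf(clamped));
}
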
1531 // Fully quantized lstm kernel for 8 bit gate matmul output.
1532 //
1533 // Input tensor of size n_batch * n_input:
1534 //   input_ptr
1535 //
1536 // LSTM weights:
1537 // Quantized input weights of size 'n_cell * n_input':
1538 //   input_to_input_weight_ptr            - optional
1539 //   input_to_forget_weight_ptr           - optional
1540 //   input_to_cell_weight_ptr             - optional
1541 //   input_to_output_weight_ptr           - optional
1542 //
1543 // Quantized recurrent weights of size 'n_cell * n_output':
1544 //   recurrent_to_input_weight_ptr        - optional
1545 //   recurrent_to_forget_weights_ptr
1546 //   recurrent_to_cell_weights_ptr
1547 //   recurrent_to_output_weights_ptr
1548 //
1549 // Quantized peephole weights of size 'n_cell', representing diagonal matrices.
1550 //   cell_to_input_weights               - optional
1551 //   cell_to_forget_weights              - optional
1552 //   cell_to_output_weights              - optional
1553 //
1554 // Quantized projection weights of size 'n_output * n_cell'
1555 //   projection_weight_ptr                     - optional
1556 //
1557 // Weight scales (scalars) for each of the weights above.
1558 //   effective_input_to_input_scale_a    - optional
1559 //   effective_input_to_input_scale_b    - optional
1560 //   effective_input_to_forget_scale_a
1561 //   effective_input_to_forget_scale_b
1562 //   effective_input_to_cell_scale_a
1563 //   effective_input_to_cell_scale_b
1564 //   effective_input_to_output_scale_a
1565 //   effective_input_to_output_scale_b
1566 //   effective_recurrent_to_input_scale_a    - optional
1567 //   effective_recurrent_to_input_scale_b    - optional
1568 //   effective_recurrent_to_forget_scale_a
1569 //   effective_recurrent_to_forget_scale_b
1570 //   effective_recurrent_to_cell_scale_a
1571 //   effective_recurrent_to_cell_scale_b
1572 //   effective_recurrent_to_output_scale_a
1573 //   effective_recurrent_to_output_scale_b
1574 //   effective_proj_scale_a                  - optional
1575 //   effective_proj_scale_b                  - optional
1576 //
1577 // Gate biases of size 'n_cell':
1578 //   input_gate_bias_ptr                 - optional
1579 //   forget_gate_bias_ptr
1580 //   cell_gate_bias_ptr
1581 //   output_gate_bias_ptr
1582 //
1583 // Layer norm coefficients of size 'n_cell', representing diagonal matrices.
1584 //   layer_norm_input_weight_ptr    - optional
1585 //   layer_norm_forget_weight_ptr   - optional
1586 //   layer_norm_cell_weight_ptr     - optional
1587 //   layer_norm_output_weight_ptr   - optional
1588 //
1589 // Layer norm scales of size 'n_cell'.
1590 //   layer_norm_input_scale_a     - optional
1591 //   layer_norm_input_scale_b     - optional
1592 //   layer_norm_forget_scale_a    - optional
1593 //   layer_norm_forget_scale_b    - optional
1594 //   layer_norm_cell_scale_a      - optional
1595 //   layer_norm_cell_scale_b      - optional
1596 //   layer_norm_output_scale_a    - optional
1597 //   layer_norm_output_scale_b    - optional
1598 //
1599 // Scalar values:
1600 //   quantized_cell_clip: quantized clip value for cell.
1601 //   quantized_proj_clip: quantized clip value for projection.
1602 //   cell_state_scale: the power of two scale for cell state.
1603 //
1604 // Zero points:
1605 //   input_zp: zero point for input tensor.
1606 //   output_state_zp: zero point of output state.
1607 //   hidden_zp: zero point for hidden state.
1608 //
1609 // Temporary pre-allocated storage for the calculation. Each is of size n_cell *
1610 // n_batch.
1611 //   scratch0
1612 //   scratch1
1613 //   scratch2
1614 //   scratch3
1615 //   scratch4
1616 //   scratch5
1617 //   scratch6
1618 //   scratch7
1619 //
1620 // Outputs:
1621 //   output_state_ptr - size 'n_batch * n_output'
1622 //   cell_state_ptr   - size 'n_batch * n_cell'
1623 //   output_ptr       - size 'n_batch * n_output'
1624 //
1625 // The zero point calculation could be moved into Prepare() for better performance.
1626 // TODO(b/159947023): scratch5 is unused, remove.
1627 inline void LstmStepInteger8x8_8(
1628     const int8_t* input_ptr, int32_t input_zp,
1629     const int8_t* input_to_input_weight_ptr,
1630     int32_t effective_input_to_input_scale_a,
1631     int32_t effective_input_to_input_scale_b,
1632     const int8_t* input_to_forget_weight_ptr,
1633     int32_t effective_input_to_forget_scale_a,
1634     int32_t effective_input_to_forget_scale_b,
1635     const int8_t* input_to_cell_weight_ptr,
1636     int32_t effective_input_to_cell_scale_a,
1637     int32_t effective_input_to_cell_scale_b,
1638     const int8_t* input_to_output_weight_ptr,
1639     int32_t effective_input_to_output_scale_a,
1640     int32_t effective_input_to_output_scale_b,
1641     const int8_t* recurrent_to_input_weight_ptr,
1642     int32_t effective_recurrent_to_input_scale_a,
1643     int32_t effective_recurrent_to_input_scale_b,
1644     const int8_t* recurrent_to_forget_weight_ptr,
1645     int32_t effective_recurrent_to_forget_scale_a,
1646     int32_t effective_recurrent_to_forget_scale_b,
1647     const int8_t* recurrent_to_cell_weight_ptr,
1648     int32_t effective_recurrent_to_cell_scale_a,
1649     int32_t effective_recurrent_to_cell_scale_b,
1650     const int8_t* recurrent_to_output_weight_ptr,
1651     int32_t effective_recurrent_to_output_scale_a,
1652     int32_t effective_recurrent_to_output_scale_b,
1653     const int8_t* cell_to_input_weight_ptr,
1654     int32_t effective_cell_to_input_scale_a,
1655     int32_t effective_cell_to_input_scale_b,
1656     const int8_t* cell_to_forget_weight_ptr,
1657     int32_t effective_cell_to_forget_scale_a,
1658     int32_t effective_cell_to_forget_scale_b,
1659     const int8_t* cell_to_output_weight_ptr,
1660     int32_t effective_cell_to_output_scale_a,
1661     int32_t effective_cell_to_output_scale_b,
1662     const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a,
1663     int32_t effective_proj_scale_b, const int16_t* layer_norm_input_weight_ptr,
1664     int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
1665     const int16_t* layer_norm_forget_weight_ptr,
1666     int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
1667     const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
1668     int32_t layer_norm_cell_scale_b,
1669     const int16_t* layer_norm_output_weight_ptr,
1670     int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
1671     const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
1672     const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
1673     const int32_t* projection_bias_ptr, const TfLiteLSTMParams* params,
1674     const int32_t* intermediate_scale_a, const int32_t* intermediate_scale_b,
1675     const int32_t* intermediate_zp, int16_t quantized_cell_clip,
1676     int8_t quantized_proj_clip, int n_batch, int n_cell, int n_input,
1677     int n_output, int output_batch_leading_dim, int8_t* output_state_ptr,
1678     int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
1679     int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
1680     int16_t* scratch4, int16_t* scratch5, int16_t* scratch6,
1681     int16_t* scratch7) {
1682   // TODO(b/159066113): scratch5 is unused, remove.
1683 
1684   ruy::profiler::ScopeLabel label("LstmStepInteger8x8_8");
1685   // Make named scratch buffers for the different gates.
1686   int16_t* forget_gate_scratch = scratch2;
1687   int16_t* cell_gate_scratch = scratch3;
1688   int16_t* output_gate_scratch = scratch4;
1689   // Only the CIFG variant is supported here, so no input gate is computed.
1690 
1691   // Calculate the forget gate.
1692   CalculateLstmGateInteger8x8_8(
1693       input_ptr, input_zp, input_to_forget_weight_ptr,
1694       effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
1695       intermediate_scale_a[2], intermediate_scale_b[2], intermediate_zp[4],
1696       output_state_ptr, output_state_zp, recurrent_to_forget_weight_ptr,
1697       effective_recurrent_to_forget_scale_a,
1698       effective_recurrent_to_forget_scale_b, intermediate_scale_a[3],
1699       intermediate_scale_b[3], intermediate_zp[5], layer_norm_forget_weight_ptr,
1700       layer_norm_forget_scale_a, layer_norm_forget_scale_b,
1701       forget_gate_bias_ptr, n_batch, n_input, n_output, n_cell,
1702       kTfLiteActSigmoid, forget_gate_scratch, scratch0, scratch1);
1703   // Calculate the cell update gate.
1704   CalculateLstmGateInteger8x8_8(
1705       input_ptr, input_zp, input_to_cell_weight_ptr,
1706       effective_input_to_cell_scale_a, effective_input_to_cell_scale_b,
1707       intermediate_scale_a[4], intermediate_scale_b[4], intermediate_zp[7],
1708       output_state_ptr, output_state_zp, recurrent_to_cell_weight_ptr,
1709       effective_recurrent_to_cell_scale_a, effective_recurrent_to_cell_scale_b,
1710       intermediate_scale_a[5], intermediate_scale_b[5], intermediate_zp[8],
1711       layer_norm_cell_weight_ptr, layer_norm_cell_scale_a,
1712       layer_norm_cell_scale_b, cell_gate_bias_ptr, n_batch, n_input, n_output,
1713       n_cell, kTfLiteActTanh, cell_gate_scratch, scratch0, scratch1);
1714   // Update the cell state.
1715   UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr,
1716                         /*cell_state_scale=*/-15, /*input_gate=*/nullptr,
1717                         forget_gate_scratch, cell_gate_scratch,
1718                         /*use_cifg=*/true, quantized_cell_clip);
1719   // Calculate the output gate.
1720   CalculateLstmGateInteger8x8_8(
1721       input_ptr, input_zp, input_to_output_weight_ptr,
1722       effective_input_to_output_scale_a, effective_input_to_output_scale_b,
1723       intermediate_scale_a[6], intermediate_scale_b[6], intermediate_zp[10],
1724       output_state_ptr, output_state_zp, recurrent_to_output_weight_ptr,
1725       effective_recurrent_to_output_scale_a,
1726       effective_recurrent_to_output_scale_b, intermediate_scale_a[11],
1727       intermediate_scale_b[7], intermediate_zp[7], layer_norm_output_weight_ptr,
1728       layer_norm_output_scale_a, layer_norm_output_scale_b,
1729       output_gate_bias_ptr, n_batch, n_input, n_output, n_cell,
1730       kTfLiteActSigmoid, output_gate_scratch, scratch0, scratch1);
1731   // Update the output state.
1732   CalculateLstmOutputInteger8x8_8(
1733       n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
1734       projection_weight_ptr, effective_proj_scale_a, effective_proj_scale_b,
1735       projection_bias_ptr, output_state_zp, quantized_proj_clip,
1736       output_state_ptr, scratch2);
1737   // Copy output state to the output. Note that unlike float or hybrid, output
1738   // is always contiguous.
1739   std::copy_n(output_state_ptr, n_batch * n_output, output_ptr);
1740 }
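
// The (effective_*_scale_a, effective_*_scale_b) pairs threaded through the
// integer kernels above follow TFLite's usual quantized-multiplier convention:
// a real rescale factor is approximated as multiplier * 2^shift, where the
// multiplier is an int32 with 31 fractional bits. A rough sketch of applying
// one such pair to an accumulator, ignoring the rounding and saturation that
// the real helpers perform (illustrative only; not called by the kernels):
inline int32_t ApplyEffectiveScaleSketch(int64_t acc, int32_t multiplier,
                                         int32_t shift) {
  int64_t scaled =
      (acc * static_cast<int64_t>(multiplier)) / (int64_t{1} << 31);
  if (shift >= 0) {
    scaled *= int64_t{1} << shift;  // Positive shift scales up.
  } else {
    scaled /= int64_t{1} << -shift;  // Negative shift scales down.
  }
  return static_cast<int32_t>(scaled);
}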
1741 
1742 }  // namespace
1743 
1744 // LINT.IfChange
1745 TfLiteStatus EvalFloat(
1746     const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
1747     const TfLiteTensor* input_to_forget_weights,
1748     const TfLiteTensor* input_to_cell_weights,
1749     const TfLiteTensor* input_to_output_weights,
1750     const TfLiteTensor* recurrent_to_input_weights,
1751     const TfLiteTensor* recurrent_to_forget_weights,
1752     const TfLiteTensor* recurrent_to_cell_weights,
1753     const TfLiteTensor* recurrent_to_output_weights,
1754     const TfLiteTensor* cell_to_input_weights,
1755     const TfLiteTensor* cell_to_forget_weights,
1756     const TfLiteTensor* cell_to_output_weights,
1757     const TfLiteTensor* input_layer_norm_coefficients,
1758     const TfLiteTensor* forget_layer_norm_coefficients,
1759     const TfLiteTensor* cell_layer_norm_coefficients,
1760     const TfLiteTensor* output_layer_norm_coefficients,
1761     const TfLiteTensor* aux_input,
1762     const TfLiteTensor* aux_input_to_input_weights,
1763     const TfLiteTensor* aux_input_to_forget_weights,
1764     const TfLiteTensor* aux_input_to_cell_weights,
1765     const TfLiteTensor* aux_input_to_output_weights,
1766     const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
1767     const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
1768     const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
1769     const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
1770     int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* output_state,
1771     TfLiteTensor* cell_state, TfLiteTensor* output,
1772     CpuBackendContext* context) {
1773   TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
1774   int max_time, n_batch;
1775   if (input->dims->size == 3) {
1776     max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
1777     n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
1778   } else {
1779     max_time = 1;
1780     n_batch = input->dims->data[0];
1781   }
1782   const int n_input = input->dims->data[input->dims->size - 1];
1783   const int aux_input_size =
1784       (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
1785 
1786   // n_cell and n_output will be the same size when there is no projection.
1787   const int n_cell = input_to_output_weights->dims->data[0];
1788   const int n_output = recurrent_to_output_weights->dims->data[1];
1789 
1790   // Since we have already checked that weights are all there or none, we can
1791   // check the existence of only one to the get the condition.
1792   const bool use_cifg = (input_to_input_weights == nullptr);
1793 
1794   // Index the scratch buffers pointers to the global scratch buffer.
1795   float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer);
1796   float* input_gate_scratch = nullptr;
1797   float* cell_gate_scratch = nullptr;
1798   float* forget_gate_scratch = nullptr;
1799   float* output_gate_scratch = nullptr;
1800   float* accumulation_scratch_buffer = nullptr;
1801   if (use_cifg) {
1802     cell_gate_scratch = scratch_buffer_ptr;
1803     forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
1804     output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
1805     accumulation_scratch_buffer = scratch_buffer_ptr + 3 * n_cell * n_batch;
1806   } else {
1807     input_gate_scratch = scratch_buffer_ptr;
1808     cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
1809     forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
1810     output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
1811     accumulation_scratch_buffer = scratch_buffer_ptr + 4 * n_cell * n_batch;
1812   }
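  // The flat scratch buffer is laid out as consecutive [n_batch x n_cell]
  // slabs, one per gate, followed by the accumulation scratch. Illustrative
  // example (non-CIFG, n_batch = 2, n_cell = 3): the forget gate scratch
  // occupies elements [12, 18) and the accumulation scratch starts at 24.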
1813 
1814   const int output_batch_leading_dim =
1815       output->dims->data[output->dims->size - 1];
1816   if (time_major) {
1817     // Loop through the sequence.
1818     const int input_step = n_batch * n_input;
1819     const int output_step = n_batch * output_batch_leading_dim;
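    // With a time-major layout all batches of one time step are contiguous,
    // so each step advances the input pointer by a whole [n_batch x n_input]
    // slice. Illustrative example: n_batch = 2, n_input = 8 gives
    // input_step = 16 floats per time step.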
1820     for (int t = 0; t < max_time; t++) {
1821       // If forward_sequence is true, step forward through time; otherwise
1822       // step backward.
1823       const int t_rel = forward_sequence ? t : max_time - t - 1;
1824       const float* input_ptr = GetTensorData<float>(input) + t_rel * input_step;
1825       const float* aux_input_ptr = nullptr;
1826       if (aux_input) {
1827         aux_input_ptr = GetTensorData<float>(aux_input) + t_rel * input_step;
1828       }
1829       float* output_ptr =
1830           GetTensorData<float>(output) + t_rel * output_step + output_offset;
1831 
1832       LstmStepFloat(
1833           input_ptr, GetTensorData<float>(input_to_input_weights),
1834           GetTensorData<float>(input_to_forget_weights),
1835           GetTensorData<float>(input_to_cell_weights),
1836           GetTensorData<float>(input_to_output_weights), aux_input_ptr,
1837           GetTensorData<float>(aux_input_to_input_weights),
1838           GetTensorData<float>(aux_input_to_forget_weights),
1839           GetTensorData<float>(aux_input_to_cell_weights),
1840           GetTensorData<float>(aux_input_to_output_weights),
1841           GetTensorData<float>(recurrent_to_input_weights),
1842           GetTensorData<float>(recurrent_to_forget_weights),
1843           GetTensorData<float>(recurrent_to_cell_weights),
1844           GetTensorData<float>(recurrent_to_output_weights),
1845           GetTensorData<float>(cell_to_input_weights),
1846           GetTensorData<float>(cell_to_forget_weights),
1847           GetTensorData<float>(cell_to_output_weights),
1848           GetTensorData<float>(input_layer_norm_coefficients),
1849           GetTensorData<float>(forget_layer_norm_coefficients),
1850           GetTensorData<float>(cell_layer_norm_coefficients),
1851           GetTensorData<float>(output_layer_norm_coefficients),
1852           GetTensorData<float>(input_gate_bias),
1853           GetTensorData<float>(forget_gate_bias),
1854           GetTensorData<float>(cell_gate_bias),
1855           GetTensorData<float>(output_gate_bias),
1856           GetTensorData<float>(projection_weights),
1857           GetTensorData<float>(projection_bias), params, n_batch, n_cell,
1858           n_input, aux_input_size, n_output, output_batch_leading_dim,
1859           GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
1860           input_gate_scratch, forget_gate_scratch, cell_gate_scratch,
1861           output_gate_scratch, accumulation_scratch_buffer, output_ptr,
1862           context);
1863     }
1864   } else {
1865     for (int b = 0; b < n_batch; b++) {
1866       const int input_step = n_input;
1867       const int output_step = output_batch_leading_dim;
1868       for (int t = 0; t < max_time; t++) {
1869         // If forward_sequence is true, step forward through time;
1870         // otherwise step backward.
1871         const int t_rel = forward_sequence ? t : max_time - t - 1;
1872         const int time_offset = b * max_time + t_rel;
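        // In the batch-major layout each sample's sequence is contiguous, so
        // element (batch b, time t_rel) sits at flat index
        // b * max_time + t_rel along the time axis.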
1873         const float* input_ptr =
1874             GetTensorData<float>(input) + time_offset * input_step;
1875         const float* aux_input_ptr = nullptr;
1876         if (aux_input) {
1877           aux_input_ptr =
1878               GetTensorData<float>(aux_input) + time_offset * input_step;
1879         }
1880         float* output_ptr = GetTensorData<float>(output) +
1881                             time_offset * output_step + output_offset;
1882 
1883         // Offset the {output,cell}_state pointers to the right batch.
1884         float* output_state_ptr =
1885             GetTensorData<float>(output_state) + b * output_batch_leading_dim;
1886         float* cell_state_ptr = GetTensorData<float>(cell_state) + b * n_cell;
1887         // Offset the scratch pointers to the right batch.
1888         float* input_gate_scratch_ptr =
1889             input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
1890         float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
1891         float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
1892         float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
1893 
1894         LstmStepFloat(
1895             input_ptr, GetTensorData<float>(input_to_input_weights),
1896             GetTensorData<float>(input_to_forget_weights),
1897             GetTensorData<float>(input_to_cell_weights),
1898             GetTensorData<float>(input_to_output_weights), aux_input_ptr,
1899             GetTensorData<float>(aux_input_to_input_weights),
1900             GetTensorData<float>(aux_input_to_forget_weights),
1901             GetTensorData<float>(aux_input_to_cell_weights),
1902             GetTensorData<float>(aux_input_to_output_weights),
1903             GetTensorData<float>(recurrent_to_input_weights),
1904             GetTensorData<float>(recurrent_to_forget_weights),
1905             GetTensorData<float>(recurrent_to_cell_weights),
1906             GetTensorData<float>(recurrent_to_output_weights),
1907             GetTensorData<float>(cell_to_input_weights),
1908             GetTensorData<float>(cell_to_forget_weights),
1909             GetTensorData<float>(cell_to_output_weights),
1910             GetTensorData<float>(input_layer_norm_coefficients),
1911             GetTensorData<float>(forget_layer_norm_coefficients),
1912             GetTensorData<float>(cell_layer_norm_coefficients),
1913             GetTensorData<float>(output_layer_norm_coefficients),
1914             GetTensorData<float>(input_gate_bias),
1915             GetTensorData<float>(forget_gate_bias),
1916             GetTensorData<float>(cell_gate_bias),
1917             GetTensorData<float>(output_gate_bias),
1918             GetTensorData<float>(projection_weights),
1919             GetTensorData<float>(projection_bias), params, /*n_batch=*/1,
1920             n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
1921             output_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
1922             forget_gate_scratch_ptr, cell_gate_scratch_ptr,
1923             output_gate_scratch_ptr, accumulation_scratch_buffer, output_ptr,
1924             context);
1925       }
1926     }
1927   }
1928   return kTfLiteOk;
1929 }
1930 // LINT.ThenChange(//tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc)
1931 
1932 TfLiteStatus EvalHybrid(
1933     const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
1934     const TfLiteTensor* input_to_input_weights_ledger,
1935     const TfLiteTensor* input_to_forget_weights,
1936     const TfLiteTensor* input_to_forget_weights_ledger,
1937     const TfLiteTensor* input_to_cell_weights,
1938     const TfLiteTensor* input_to_cell_weights_ledger,
1939     const TfLiteTensor* input_to_output_weights,
1940     const TfLiteTensor* input_to_output_weights_ledger,
1941     const TfLiteTensor* recurrent_to_input_weights,
1942     const TfLiteTensor* recurrent_to_input_weights_ledger,
1943     const TfLiteTensor* recurrent_to_forget_weights,
1944     const TfLiteTensor* recurrent_to_forget_weights_ledger,
1945     const TfLiteTensor* recurrent_to_cell_weights,
1946     const TfLiteTensor* recurrent_to_cell_weights_ledger,
1947     const TfLiteTensor* recurrent_to_output_weights,
1948     const TfLiteTensor* recurrent_to_output_weights_ledger,
1949     const TfLiteTensor* cell_to_input_weights,
1950     const TfLiteTensor* cell_to_forget_weights,
1951     const TfLiteTensor* cell_to_output_weights,
1952     const TfLiteTensor* input_layer_norm_coefficients,
1953     const TfLiteTensor* forget_layer_norm_coefficients,
1954     const TfLiteTensor* cell_layer_norm_coefficients,
1955     const TfLiteTensor* output_layer_norm_coefficients,
1956     const TfLiteTensor* aux_input,
1957     const TfLiteTensor* aux_input_to_input_weights,
1958     const TfLiteTensor* aux_input_to_forget_weights,
1959     const TfLiteTensor* aux_input_to_cell_weights,
1960     const TfLiteTensor* aux_input_to_output_weights,
1961     const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
1962     const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
1963     const TfLiteTensor* projection_weights,
1964     const TfLiteTensor* projection_weights_ledger,
1965     const TfLiteTensor* projection_bias, const TfLiteLSTMParams* params,
1966     bool forward_sequence, bool time_major, int output_offset,
1967     TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
1968     TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf,
1969     TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
1970     TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
1971     TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
1972     TfLiteTensor* output_state, TfLiteTensor* cell_state,
1973     TfLiteTensor* output_scratch_buffer, TfLiteTensor* output,
1974     TfLiteTensor* input_zp, TfLiteTensor* aux_input_zp,
1975     TfLiteTensor* output_state_zp, TfLiteTensor* row_sums, int row_sums_size,
1976     bool* compute_row_sums, CpuBackendContext* context) {
1977   TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
1978   const int n_input = input->dims->data[input->dims->size - 1];
1979   int max_time, n_batch;
1980   if (input->dims->size == 2) {
1981     max_time = 1;
1982     n_batch = input->dims->data[0];
1983   } else {
1984     max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
1985     n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
1986   }
1987   const int aux_input_size =
1988       (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
1989   // n_cell and n_output will be the same size when there is no projection.
1990   const int n_cell = input_to_output_weights->dims->data[0];
1991   const int n_output = recurrent_to_output_weights->dims->data[1];
1992 
1993   // Since we have already checked that weights are all there or none, we can
1994   // check the existence of only one to get the condition.
1995   const bool use_cifg = (input_to_input_weights == nullptr);
1996 
1997   float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer);
1998   float* input_gate_scratch = nullptr;
1999   float* cell_gate_scratch = nullptr;
2000   float* forget_gate_scratch = nullptr;
2001   float* output_gate_scratch = nullptr;
2002   if (use_cifg) {
2003     cell_gate_scratch = scratch_buffer_ptr;
2004     forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
2005     output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
2006   } else {
2007     input_gate_scratch = scratch_buffer_ptr;
2008     cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
2009     forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
2010     output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
2011   }
2012 
2013   const int output_batch_leading_dim =
2014       output->dims->data[output->dims->size - 1];
2015 
2016   int32_t* input_zp_ptr = nullptr;
2017   int32_t* aux_input_zp_ptr = nullptr;
2018   int32_t* output_state_zp_ptr = nullptr;
2019   int32_t* row_sums_ptr = nullptr;
2020   if (params->asymmetric_quantize_inputs) {
2021     input_zp_ptr = GetTensorData<int32_t>(input_zp);
2022     aux_input_zp_ptr = GetTensorData<int32_t>(aux_input_zp);
2023     output_state_zp_ptr = GetTensorData<int32_t>(output_state_zp);
2024     row_sums_ptr = GetTensorData<int32_t>(row_sums);
2025   }
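  // With asymmetric input quantization the int8 matmuls need the cached
  // per-row weight sums and the activation zero points to compensate for the
  // input offsets; without it these pointers stay null and inputs are
  // quantized symmetrically.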
2026 
2027   if (time_major) {
2028     // Feed the sequence into the LSTM step-by-step.
2029     const int input_step = n_batch * n_input;
2030     const int output_step = n_batch * output_batch_leading_dim;
2031     for (int t = 0; t < max_time; t++) {
2032       // If forward_sequence is true, step forward through time; otherwise
2033       // step backward.
2034       const int t_rel = forward_sequence ? t : max_time - t - 1;
2035       const float* input_ptr = GetTensorData<float>(input) + t_rel * input_step;
2036       const float* aux_input_ptr = nullptr;
2037       if (aux_input) {
2038         aux_input_ptr = GetTensorData<float>(aux_input) + t_rel * input_step;
2039       }
2040       float* output_ptr =
2041           GetTensorData<float>(output) + t_rel * output_step + output_offset;
2042       LstmStepHybrid(
2043           input_ptr, GetTensorData<int8_t>(input_to_input_weights),
2044           GetTensorData<uint8_t>(input_to_input_weights_ledger),
2045           GetTensorScale(input_to_input_weights),
2046           GetTensorData<int8_t>(input_to_forget_weights),
2047           GetTensorData<uint8_t>(input_to_forget_weights_ledger),
2048           GetTensorScale(input_to_forget_weights),
2049           GetTensorData<int8_t>(input_to_cell_weights),
2050           GetTensorData<uint8_t>(input_to_cell_weights_ledger),
2051           GetTensorScale(input_to_cell_weights),
2052           GetTensorData<int8_t>(input_to_output_weights),
2053           GetTensorData<uint8_t>(input_to_output_weights_ledger),
2054           GetTensorScale(input_to_output_weights), aux_input_ptr,
2055           GetTensorData<int8_t>(aux_input_to_input_weights),
2056           GetTensorScale(aux_input_to_input_weights),
2057           GetTensorData<int8_t>(aux_input_to_forget_weights),
2058           GetTensorScale(aux_input_to_forget_weights),
2059           GetTensorData<int8_t>(aux_input_to_cell_weights),
2060           GetTensorScale(aux_input_to_cell_weights),
2061           GetTensorData<int8_t>(aux_input_to_output_weights),
2062           GetTensorScale(aux_input_to_output_weights),
2063           GetTensorData<int8_t>(recurrent_to_input_weights),
2064           GetTensorData<uint8_t>(recurrent_to_input_weights_ledger),
2065           GetTensorScale(recurrent_to_input_weights),
2066           GetTensorData<int8_t>(recurrent_to_forget_weights),
2067           GetTensorData<uint8_t>(recurrent_to_forget_weights_ledger),
2068           GetTensorScale(recurrent_to_forget_weights),
2069           GetTensorData<int8_t>(recurrent_to_cell_weights),
2070           GetTensorData<uint8_t>(recurrent_to_cell_weights_ledger),
2071           GetTensorScale(recurrent_to_cell_weights),
2072           GetTensorData<int8_t>(recurrent_to_output_weights),
2073           GetTensorData<uint8_t>(recurrent_to_output_weights_ledger),
2074           GetTensorScale(recurrent_to_output_weights),
2075           GetTensorData<int8_t>(cell_to_input_weights),
2076           GetTensorScale(cell_to_input_weights),
2077           GetTensorData<int8_t>(cell_to_forget_weights),
2078           GetTensorScale(cell_to_forget_weights),
2079           GetTensorData<int8_t>(cell_to_output_weights),
2080           GetTensorScale(cell_to_output_weights),
2081           GetTensorData<float>(input_layer_norm_coefficients),
2082           GetTensorData<float>(forget_layer_norm_coefficients),
2083           GetTensorData<float>(cell_layer_norm_coefficients),
2084           GetTensorData<float>(output_layer_norm_coefficients),
2085           GetTensorData<float>(input_gate_bias),
2086           GetTensorData<float>(forget_gate_bias),
2087           GetTensorData<float>(cell_gate_bias),
2088           GetTensorData<float>(output_gate_bias),
2089           GetTensorData<int8_t>(projection_weights),
2090           GetTensorData<uint8_t>(projection_weights_ledger),
2091           GetTensorScale(projection_weights),
2092           GetTensorData<float>(projection_bias), params, n_batch, n_cell,
2093           n_input, aux_input_size, n_output, output_batch_leading_dim,
2094           input_gate_scratch, forget_gate_scratch, cell_gate_scratch,
2095           output_gate_scratch, GetTensorData<float>(input_sf),
2096           GetTensorData<float>(aux_input_sf),
2097           GetTensorData<float>(output_state_sf),
2098           GetTensorData<float>(prod_scaling_factors),
2099           GetTensorData<float>(recovered_cell_weights),
2100           GetTensorData<int8_t>(input_quantized),
2101           GetTensorData<int8_t>(aux_input_quantized),
2102           GetTensorData<int8_t>(output_state_quantized),
2103           GetTensorData<int8_t>(cell_state_quantized),
2104           GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
2105           GetTensorData<int32_t>(output_scratch_buffer), output_ptr,
2106           input_zp_ptr, aux_input_zp_ptr, output_state_zp_ptr, row_sums_ptr,
2107           row_sums_size, compute_row_sums, params->asymmetric_quantize_inputs,
2108           context);
2109     }
2110   } else {
2111     for (int b = 0; b < n_batch; b++) {
2112       const int input_step = n_input;
2113       const int output_step = output_batch_leading_dim;
2114       for (int t = 0; t < max_time; t++) {
2115         // If forward_sequence is true, step forward through time;
2116         // otherwise step backward.
2117         const int t_rel = forward_sequence ? t : max_time - t - 1;
2118         const int time_offset = b * max_time + t_rel;
2119         const float* input_ptr =
2120             GetTensorData<float>(input) + time_offset * input_step;
2121         const float* aux_input_ptr = nullptr;
2122         if (aux_input) {
2123           aux_input_ptr =
2124               GetTensorData<float>(aux_input) + time_offset * input_step;
2125         }
2126         float* output_ptr = GetTensorData<float>(output) +
2127                             time_offset * output_step + output_offset;
2128 
2129         // Offset the {output,cell}_state pointers to the right batch.
2130         float* output_state_ptr =
2131             GetTensorData<float>(output_state) + b * output_batch_leading_dim;
2132         float* cell_state_ptr = GetTensorData<float>(cell_state) + b * n_cell;
2133         // Offset the scratch pointers to the right batch.
2134         float* input_gate_scratch_ptr =
2135             input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
2136         float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
2137         float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
2138         float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
2139 
2140         LstmStepHybrid(
2141             input_ptr, GetTensorData<int8_t>(input_to_input_weights),
2142             GetTensorData<uint8_t>(input_to_input_weights_ledger),
2143             GetTensorScale(input_to_input_weights),
2144             GetTensorData<int8_t>(input_to_forget_weights),
2145             GetTensorData<uint8_t>(input_to_forget_weights_ledger),
2146             GetTensorScale(input_to_forget_weights),
2147             GetTensorData<int8_t>(input_to_cell_weights),
2148             GetTensorData<uint8_t>(input_to_cell_weights_ledger),
2149             GetTensorScale(input_to_cell_weights),
2150             GetTensorData<int8_t>(input_to_output_weights),
2151             GetTensorData<uint8_t>(input_to_output_weights_ledger),
2152             GetTensorScale(input_to_output_weights), aux_input_ptr,
2153             GetTensorData<int8_t>(aux_input_to_input_weights),
2154             GetTensorScale(aux_input_to_input_weights),
2155             GetTensorData<int8_t>(aux_input_to_forget_weights),
2156             GetTensorScale(aux_input_to_forget_weights),
2157             GetTensorData<int8_t>(aux_input_to_cell_weights),
2158             GetTensorScale(aux_input_to_cell_weights),
2159             GetTensorData<int8_t>(aux_input_to_output_weights),
2160             GetTensorScale(aux_input_to_output_weights),
2161             GetTensorData<int8_t>(recurrent_to_input_weights),
2162             GetTensorData<uint8_t>(recurrent_to_input_weights_ledger),
2163             GetTensorScale(recurrent_to_input_weights),
2164             GetTensorData<int8_t>(recurrent_to_forget_weights),
2165             GetTensorData<uint8_t>(recurrent_to_forget_weights_ledger),
2166             GetTensorScale(recurrent_to_forget_weights),
2167             GetTensorData<int8_t>(recurrent_to_cell_weights),
2168             GetTensorData<uint8_t>(recurrent_to_cell_weights_ledger),
2169             GetTensorScale(recurrent_to_cell_weights),
2170             GetTensorData<int8_t>(recurrent_to_output_weights),
2171             GetTensorData<uint8_t>(recurrent_to_output_weights_ledger),
2172             GetTensorScale(recurrent_to_output_weights),
2173             GetTensorData<int8_t>(cell_to_input_weights),
2174             GetTensorScale(cell_to_input_weights),
2175             GetTensorData<int8_t>(cell_to_forget_weights),
2176             GetTensorScale(cell_to_forget_weights),
2177             GetTensorData<int8_t>(cell_to_output_weights),
2178             GetTensorScale(cell_to_output_weights),
2179             GetTensorData<float>(input_layer_norm_coefficients),
2180             GetTensorData<float>(forget_layer_norm_coefficients),
2181             GetTensorData<float>(cell_layer_norm_coefficients),
2182             GetTensorData<float>(output_layer_norm_coefficients),
2183             GetTensorData<float>(input_gate_bias),
2184             GetTensorData<float>(forget_gate_bias),
2185             GetTensorData<float>(cell_gate_bias),
2186             GetTensorData<float>(output_gate_bias),
2187             GetTensorData<int8_t>(projection_weights),
2188             GetTensorData<uint8_t>(projection_weights_ledger),
2189             GetTensorScale(projection_weights),
2190             GetTensorData<float>(projection_bias), params,
2191             /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
2192             output_batch_leading_dim, input_gate_scratch_ptr,
2193             forget_gate_scratch_ptr, cell_gate_scratch_ptr,
2194             output_gate_scratch_ptr, GetTensorData<float>(input_sf),
2195             GetTensorData<float>(aux_input_sf),
2196             GetTensorData<float>(output_state_sf),
2197             GetTensorData<float>(prod_scaling_factors),
2198             GetTensorData<float>(recovered_cell_weights),
2199             GetTensorData<int8_t>(input_quantized),
2200             GetTensorData<int8_t>(aux_input_quantized),
2201             GetTensorData<int8_t>(output_state_quantized),
2202             GetTensorData<int8_t>(cell_state_quantized), output_state_ptr,
2203             cell_state_ptr, GetTensorData<int32_t>(output_scratch_buffer),
2204             output_ptr, input_zp_ptr, aux_input_zp_ptr, output_state_zp_ptr,
2205             row_sums_ptr, row_sums_size, compute_row_sums,
2206             params->asymmetric_quantize_inputs, context);
2207       }
2208     }
2209   }
2210 
2211   return kTfLiteOk;
2212 }
2213 
2214 TfLiteStatus EvalInteger8x8_16(
2215     const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
2216     const TfLiteTensor* input_to_forget_weights,
2217     const TfLiteTensor* input_to_cell_weights,
2218     const TfLiteTensor* input_to_output_weights,
2219     const TfLiteTensor* recurrent_to_input_weights,
2220     const TfLiteTensor* recurrent_to_forget_weights,
2221     const TfLiteTensor* recurrent_to_cell_weights,
2222     const TfLiteTensor* recurrent_to_output_weights,
2223     const TfLiteTensor* cell_to_input_weights,
2224     const TfLiteTensor* cell_to_forget_weights,
2225     const TfLiteTensor* cell_to_output_weights,
2226     const TfLiteTensor* input_layer_norm_coefficients,
2227     const TfLiteTensor* forget_layer_norm_coefficients,
2228     const TfLiteTensor* cell_layer_norm_coefficients,
2229     const TfLiteTensor* output_layer_norm_coefficients,
2230     const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
2231     const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
2232     const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
2233     const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
2234     const lstm_eval::IntegerLstmParameter* integer_lstm_param,
2235     TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output,
2236     TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
2237     TfLiteTensor* scratch3, TfLiteTensor* scratch4, TfLiteTensor* scratch5,
2238     CpuBackendContext* context) {
2239   TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
2240   const int n_input = input->dims->data[input->dims->size - 1];
2241   int max_time, n_batch;
2242   if (input->dims->size == 2) {
2243     max_time = 1;
2244     n_batch = input->dims->data[0];
2245   } else {
2246     max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
2247     n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
2248   }
2249 
2250   // n_cell and n_output will be the same size when there is no projection.
2251   const int n_cell = input_to_output_weights->dims->data[0];
2252   const int n_output = recurrent_to_output_weights->dims->data[1];
2253 
2254   // Zero point of the output state (activation).
2255   int output_state_zp = output_state->params.zero_point;
2256 
2257   // Get params for time/batch/sequence.
2258   const int output_batch_leading_dim =
2259       output->dims->data[output->dims->size - 1];
2260 
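  // Each effective_*_scale_a / effective_*_scale_b pair passed below encodes
  // a real rescale factor as a fixed-point multiplier plus a power-of-two
  // exponent, roughly scale ~= scale_a * 2^(scale_b - 31). Illustrative
  // example: scale_a = 1 << 30 with scale_b = 1 represents a scale of ~1.0.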
2261   if (time_major) {
2262     const int input_step = n_batch * n_input;
2263     const int output_step = n_batch * output_batch_leading_dim;
2264     for (int t = 0; t < max_time; t++) {
2265       const int t_rel = t;
2266       int8_t* output_ptr = GetTensorData<int8_t>(output) + t_rel * output_step;
2267       const int8_t* input_ptr =
2268           GetTensorData<int8_t>(input) + t_rel * input_step;
2269       LstmStepInteger8x8_16(
2270           input_ptr, GetTensorData<int8_t>(input_to_input_weights),
2271           integer_lstm_param->effective_input_to_input_scale_a,
2272           integer_lstm_param->effective_input_to_input_scale_b,
2273           GetTensorData<int8_t>(input_to_forget_weights),
2274           integer_lstm_param->effective_input_to_forget_scale_a,
2275           integer_lstm_param->effective_input_to_forget_scale_b,
2276           GetTensorData<int8_t>(input_to_cell_weights),
2277           integer_lstm_param->effective_input_to_cell_scale_a,
2278           integer_lstm_param->effective_input_to_cell_scale_b,
2279           GetTensorData<int8_t>(input_to_output_weights),
2280           integer_lstm_param->effective_input_to_output_scale_a,
2281           integer_lstm_param->effective_input_to_output_scale_b,
2282           GetTensorData<int8_t>(recurrent_to_input_weights),
2283           integer_lstm_param->effective_recurrent_to_input_scale_a,
2284           integer_lstm_param->effective_recurrent_to_input_scale_b,
2285           GetTensorData<int8_t>(recurrent_to_forget_weights),
2286           integer_lstm_param->effective_recurrent_to_forget_scale_a,
2287           integer_lstm_param->effective_recurrent_to_forget_scale_b,
2288           GetTensorData<int8_t>(recurrent_to_cell_weights),
2289           integer_lstm_param->effective_recurrent_to_cell_scale_a,
2290           integer_lstm_param->effective_recurrent_to_cell_scale_b,
2291           GetTensorData<int8_t>(recurrent_to_output_weights),
2292           integer_lstm_param->effective_recurrent_to_output_scale_a,
2293           integer_lstm_param->effective_recurrent_to_output_scale_b,
2294           GetTensorData<int16_t>(cell_to_input_weights),
2295           integer_lstm_param->effective_cell_to_input_scale_a,
2296           integer_lstm_param->effective_cell_to_input_scale_b,
2297           GetTensorData<int16_t>(cell_to_forget_weights),
2298           integer_lstm_param->effective_cell_to_forget_scale_a,
2299           integer_lstm_param->effective_cell_to_forget_scale_b,
2300           GetTensorData<int16_t>(cell_to_output_weights),
2301           integer_lstm_param->effective_cell_to_output_scale_a,
2302           integer_lstm_param->effective_cell_to_output_scale_b,
2303           GetTensorData<int8_t>(projection_weights),
2304           integer_lstm_param->effective_proj_scale_a,
2305           integer_lstm_param->effective_proj_scale_b,
2306           integer_lstm_param->hidden_zp,
2307           integer_lstm_param->effective_hidden_scale_a,
2308           integer_lstm_param->effective_hidden_scale_b,
2309           GetTensorData<int16_t>(input_layer_norm_coefficients),
2310           integer_lstm_param->layer_norm_input_scale_a,
2311           integer_lstm_param->layer_norm_input_scale_b,
2312           GetTensorData<int16_t>(forget_layer_norm_coefficients),
2313           integer_lstm_param->layer_norm_forget_scale_a,
2314           integer_lstm_param->layer_norm_forget_scale_b,
2315           GetTensorData<int16_t>(cell_layer_norm_coefficients),
2316           integer_lstm_param->layer_norm_cell_scale_a,
2317           integer_lstm_param->layer_norm_cell_scale_b,
2318           GetTensorData<int16_t>(output_layer_norm_coefficients),
2319           integer_lstm_param->layer_norm_output_scale_a,
2320           integer_lstm_param->layer_norm_output_scale_b,
2321           GetTensorData<int32_t>(input_gate_bias),
2322           GetTensorData<int32_t>(forget_gate_bias),
2323           GetTensorData<int32_t>(cell_gate_bias),
2324           GetTensorData<int32_t>(output_gate_bias),
2325           integer_lstm_param->quantized_cell_clip,
2326           integer_lstm_param->quantized_proj_clip,
2327           integer_lstm_param->cell_scale,
2328           integer_lstm_param->input_variance_guard,
2329           integer_lstm_param->forget_variance_guard,
2330           integer_lstm_param->cell_variance_guard,
2331           integer_lstm_param->output_variance_guard,
2332           integer_lstm_param->input_to_forget_effective_bias.get(),
2333           integer_lstm_param->recurrent_to_forget_effective_bias.get(),
2334           integer_lstm_param->input_to_cell_effective_bias.get(),
2335           integer_lstm_param->recurrent_to_cell_effective_bias.get(),
2336           integer_lstm_param->input_to_output_effective_bias.get(),
2337           integer_lstm_param->recurrent_to_output_effective_bias.get(),
2338           integer_lstm_param->input_to_input_effective_bias.get(),
2339           integer_lstm_param->recurrent_to_input_effective_bias.get(),
2340           integer_lstm_param->projection_effective_bias.get(), n_batch, n_cell,
2341           n_input, n_output, GetTensorData<int8_t>(output_state),
2342           output_state_zp, GetTensorData<int16_t>(cell_state), output_ptr,
2343           GetTensorData<int16_t>(scratch0), GetTensorData<int16_t>(scratch1),
2344           GetTensorData<int16_t>(scratch2), GetTensorData<int16_t>(scratch3),
2345           GetTensorData<int8_t>(scratch4), GetTensorData<int32_t>(scratch5),
2346           context);
2347     }
2348   } else {
2349     for (int b = 0; b < n_batch; b++) {
2350       const int input_step = n_input;
2351       const int output_step = output_batch_leading_dim;
2352       for (int t = 0; t < max_time; t++) {
2353         // If forward_sequence is true, step forward through time;
2354         // otherwise step backward.
2355         const int t_rel = forward_sequence ? t : max_time - t - 1;
2356         const int time_offset = b * max_time + t_rel;
2357         const int8_t* input_ptr =
2358             GetTensorData<int8_t>(input) + time_offset * input_step;
2359         int8_t* output_ptr =
2360             GetTensorData<int8_t>(output) + time_offset * output_step;
2361 
2362         // Offset the {output,cell}_state pointers to the right batch.
2363         int8_t* output_state_ptr =
2364             GetTensorData<int8_t>(output_state) + b * output_batch_leading_dim;
2365         int16_t* cell_state_ptr =
2366             GetTensorData<int16_t>(cell_state) + b * n_cell;
2367 
2368         LstmStepInteger8x8_16(
2369             input_ptr, GetTensorData<int8_t>(input_to_input_weights),
2370             integer_lstm_param->effective_input_to_input_scale_a,
2371             integer_lstm_param->effective_input_to_input_scale_b,
2372             GetTensorData<int8_t>(input_to_forget_weights),
2373             integer_lstm_param->effective_input_to_forget_scale_a,
2374             integer_lstm_param->effective_input_to_forget_scale_b,
2375             GetTensorData<int8_t>(input_to_cell_weights),
2376             integer_lstm_param->effective_input_to_cell_scale_a,
2377             integer_lstm_param->effective_input_to_cell_scale_b,
2378             GetTensorData<int8_t>(input_to_output_weights),
2379             integer_lstm_param->effective_input_to_output_scale_a,
2380             integer_lstm_param->effective_input_to_output_scale_b,
2381             GetTensorData<int8_t>(recurrent_to_input_weights),
2382             integer_lstm_param->effective_recurrent_to_input_scale_a,
2383             integer_lstm_param->effective_recurrent_to_input_scale_b,
2384             GetTensorData<int8_t>(recurrent_to_forget_weights),
2385             integer_lstm_param->effective_recurrent_to_forget_scale_a,
2386             integer_lstm_param->effective_recurrent_to_forget_scale_b,
2387             GetTensorData<int8_t>(recurrent_to_cell_weights),
2388             integer_lstm_param->effective_recurrent_to_cell_scale_a,
2389             integer_lstm_param->effective_recurrent_to_cell_scale_b,
2390             GetTensorData<int8_t>(recurrent_to_output_weights),
2391             integer_lstm_param->effective_recurrent_to_output_scale_a,
2392             integer_lstm_param->effective_recurrent_to_output_scale_b,
2393             GetTensorData<int16_t>(cell_to_input_weights),
2394             integer_lstm_param->effective_cell_to_input_scale_a,
2395             integer_lstm_param->effective_cell_to_input_scale_b,
2396             GetTensorData<int16_t>(cell_to_forget_weights),
2397             integer_lstm_param->effective_cell_to_forget_scale_a,
2398             integer_lstm_param->effective_cell_to_forget_scale_b,
2399             GetTensorData<int16_t>(cell_to_output_weights),
2400             integer_lstm_param->effective_cell_to_output_scale_a,
2401             integer_lstm_param->effective_cell_to_output_scale_b,
2402             GetTensorData<int8_t>(projection_weights),
2403             integer_lstm_param->effective_proj_scale_a,
2404             integer_lstm_param->effective_proj_scale_b,
2405             integer_lstm_param->hidden_zp,
2406             integer_lstm_param->effective_hidden_scale_a,
2407             integer_lstm_param->effective_hidden_scale_b,
2408             GetTensorData<int16_t>(input_layer_norm_coefficients),
2409             integer_lstm_param->layer_norm_input_scale_a,
2410             integer_lstm_param->layer_norm_input_scale_b,
2411             GetTensorData<int16_t>(forget_layer_norm_coefficients),
2412             integer_lstm_param->layer_norm_forget_scale_a,
2413             integer_lstm_param->layer_norm_forget_scale_b,
2414             GetTensorData<int16_t>(cell_layer_norm_coefficients),
2415             integer_lstm_param->layer_norm_cell_scale_a,
2416             integer_lstm_param->layer_norm_cell_scale_b,
2417             GetTensorData<int16_t>(output_layer_norm_coefficients),
2418             integer_lstm_param->layer_norm_output_scale_a,
2419             integer_lstm_param->layer_norm_output_scale_b,
2420             GetTensorData<int32_t>(input_gate_bias),
2421             GetTensorData<int32_t>(forget_gate_bias),
2422             GetTensorData<int32_t>(cell_gate_bias),
2423             GetTensorData<int32_t>(output_gate_bias),
2424             integer_lstm_param->quantized_cell_clip,
2425             integer_lstm_param->quantized_proj_clip,
2426             integer_lstm_param->cell_scale,
2427             integer_lstm_param->input_variance_guard,
2428             integer_lstm_param->forget_variance_guard,
2429             integer_lstm_param->cell_variance_guard,
2430             integer_lstm_param->output_variance_guard,
2431             integer_lstm_param->input_to_forget_effective_bias.get(),
2432             integer_lstm_param->recurrent_to_forget_effective_bias.get(),
2433             integer_lstm_param->input_to_cell_effective_bias.get(),
2434             integer_lstm_param->recurrent_to_cell_effective_bias.get(),
2435             integer_lstm_param->input_to_output_effective_bias.get(),
2436             integer_lstm_param->recurrent_to_output_effective_bias.get(),
2437             integer_lstm_param->input_to_input_effective_bias.get(),
2438             integer_lstm_param->recurrent_to_input_effective_bias.get(),
2439             integer_lstm_param->projection_effective_bias.get(), /*n_batch=*/1,
2440             n_cell, n_input, n_output, output_state_ptr, output_state_zp,
2441             cell_state_ptr, output_ptr, GetTensorData<int16_t>(scratch0),
2442             GetTensorData<int16_t>(scratch1), GetTensorData<int16_t>(scratch2),
2443             GetTensorData<int16_t>(scratch3), GetTensorData<int8_t>(scratch4),
2444             GetTensorData<int32_t>(scratch5), context);
2445       }
2446     }
2447   }
2448 
2449   return kTfLiteOk;
2450 }
2451 
2452 TfLiteStatus EvalInteger8x8_8(
2453     const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
2454     const TfLiteTensor* input_to_forget_weights,
2455     const TfLiteTensor* input_to_cell_weights,
2456     const TfLiteTensor* input_to_output_weights,
2457     const TfLiteTensor* recurrent_to_input_weights,
2458     const TfLiteTensor* recurrent_to_forget_weights,
2459     const TfLiteTensor* recurrent_to_cell_weights,
2460     const TfLiteTensor* recurrent_to_output_weights,
2461     const TfLiteTensor* cell_to_input_weights,
2462     const TfLiteTensor* cell_to_forget_weights,
2463     const TfLiteTensor* cell_to_output_weights,
2464     const TfLiteTensor* input_layer_norm_coefficients,
2465     const TfLiteTensor* forget_layer_norm_coefficients,
2466     const TfLiteTensor* cell_layer_norm_coefficients,
2467     const TfLiteTensor* output_layer_norm_coefficients,
2468     const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
2469     const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
2470     const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
2471     const TfLiteLSTMParams* params, TfLiteTensor* output_state,
2472     TfLiteTensor* cell_state, TfLiteTensor* output,
2473     const lstm_eval::IntegerLstmParameter* integer_lstm_param,
2474     TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
2475     TfLiteTensor* scratch3, TfLiteTensor* scratch4, TfLiteTensor* scratch5,
2476     TfLiteTensor* scratch6, TfLiteTensor* scratch7) {
2477   TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
2478   const int n_input = input->dims->data[input->dims->size - 1];
2479   int max_time, n_batch;
2480   if (input->dims->size == 2) {
2481     max_time = 1;
2482     n_batch = input->dims->data[0];
2483   } else {
2484     max_time = input->dims->data[0];
2485     n_batch = input->dims->data[1];
2486   }
2487 
2488   // n_cell and n_output will be the same size when there is no projection.
2489   const int n_cell = input_to_output_weights->dims->data[0];
2490   const int n_output = recurrent_to_output_weights->dims->data[1];
2491 
2492   const int32_t input_zp = input->params.zero_point;
2493   const int32_t output_state_zp = output_state->params.zero_point;
2494 
2495   // Get params for time/batch/sequence.
2496   const int output_batch_leading_dim =
2497       output->dims->data[output->dims->size - 1];
2498   const int input_step = n_batch * n_input;
2499   const int output_step = n_batch * output_batch_leading_dim;
2500 
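  // Unlike the other variants, this path takes no time_major or
  // forward_sequence flags: the outer dimension is always treated as time
  // and the sequence is walked forward (t_rel == t).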
2501   for (int t = 0; t < max_time; t++) {
2502     const int t_rel = t;
2503     int8_t* output_ptr = GetTensorData<int8_t>(output) + t_rel * output_step;
2504     // Input can be int8 asymmetric or int16 symmetric.
2505     const int8_t* input_ptr = GetTensorData<int8_t>(input) + t_rel * input_step;
2506     lstm_eval::LstmStepInteger8x8_8(
2507         input_ptr, input_zp,
2508 
2509         GetTensorData<int8_t>(input_to_input_weights),
2510         integer_lstm_param->effective_input_to_input_scale_a,
2511         integer_lstm_param->effective_input_to_input_scale_b,
2512 
2513         GetTensorData<int8_t>(input_to_forget_weights),
2514         integer_lstm_param->effective_input_to_forget_scale_a,
2515         integer_lstm_param->effective_input_to_forget_scale_b,
2516 
2517         GetTensorData<int8_t>(input_to_cell_weights),
2518         integer_lstm_param->effective_input_to_cell_scale_a,
2519         integer_lstm_param->effective_input_to_cell_scale_b,
2520 
2521         GetTensorData<int8_t>(input_to_output_weights),
2522         integer_lstm_param->effective_input_to_output_scale_a,
2523         integer_lstm_param->effective_input_to_output_scale_b,
2524 
2525         GetTensorData<int8_t>(recurrent_to_input_weights),
2526         integer_lstm_param->effective_recurrent_to_input_scale_a,
2527         integer_lstm_param->effective_recurrent_to_input_scale_b,
2528 
2529         GetTensorData<int8_t>(recurrent_to_forget_weights),
2530         integer_lstm_param->effective_recurrent_to_forget_scale_a,
2531         integer_lstm_param->effective_recurrent_to_forget_scale_b,
2532 
2533         GetTensorData<int8_t>(recurrent_to_cell_weights),
2534         integer_lstm_param->effective_recurrent_to_cell_scale_a,
2535         integer_lstm_param->effective_recurrent_to_cell_scale_b,
2536 
2537         GetTensorData<int8_t>(recurrent_to_output_weights),
2538         integer_lstm_param->effective_recurrent_to_output_scale_a,
2539         integer_lstm_param->effective_recurrent_to_output_scale_b,
2540 
2541         GetTensorData<int8_t>(cell_to_input_weights),
2542         integer_lstm_param->effective_cell_to_input_scale_a,
2543         integer_lstm_param->effective_cell_to_input_scale_b,
2544 
2545         GetTensorData<int8_t>(cell_to_forget_weights),
2546         integer_lstm_param->effective_cell_to_forget_scale_a,
2547         integer_lstm_param->effective_cell_to_forget_scale_b,
2548 
2549         GetTensorData<int8_t>(cell_to_output_weights),
2550         integer_lstm_param->effective_cell_to_output_scale_a,
2551         integer_lstm_param->effective_cell_to_output_scale_b,
2552 
2553         GetTensorData<int8_t>(projection_weights),
2554         integer_lstm_param->effective_proj_scale_a,
2555         integer_lstm_param->effective_proj_scale_b,
2556 
2557         GetTensorData<int16_t>(input_layer_norm_coefficients),
2558         integer_lstm_param->layer_norm_input_scale_a,
2559         integer_lstm_param->layer_norm_input_scale_b,
2560 
2561         GetTensorData<int16_t>(forget_layer_norm_coefficients),
2562         integer_lstm_param->layer_norm_forget_scale_a,
2563         integer_lstm_param->layer_norm_forget_scale_b,
2564 
2565         GetTensorData<int16_t>(cell_layer_norm_coefficients),
2566         integer_lstm_param->layer_norm_cell_scale_a,
2567         integer_lstm_param->layer_norm_cell_scale_b,
2568 
2569         GetTensorData<int16_t>(output_layer_norm_coefficients),
2570         integer_lstm_param->layer_norm_output_scale_a,
2571         integer_lstm_param->layer_norm_output_scale_b,
2572 
2573         GetTensorData<int32_t>(input_gate_bias),
2574         GetTensorData<int32_t>(forget_gate_bias),
2575         GetTensorData<int32_t>(cell_gate_bias),
2576         GetTensorData<int32_t>(output_gate_bias),
2577         GetTensorData<int32_t>(projection_bias),
2578 
2579         params, integer_lstm_param->intermediate_scale_a,
2580         integer_lstm_param->intermediate_scale_b,
2581         integer_lstm_param->intermediate_zp,
2582         integer_lstm_param->quantized_cell_clip,
2583         integer_lstm_param->quantized_proj_clip, n_batch, n_cell, n_input,
2584         n_output, output_batch_leading_dim, GetTensorData<int8_t>(output_state),
2585         output_state_zp, GetTensorData<int16_t>(cell_state), output_ptr,
2586         GetTensorData<int8_t>(scratch0), GetTensorData<int8_t>(scratch1),
2587         GetTensorData<int16_t>(scratch2), GetTensorData<int16_t>(scratch3),
2588         GetTensorData<int16_t>(scratch4), GetTensorData<int16_t>(scratch5),
2589         GetTensorData<int16_t>(scratch6), GetTensorData<int16_t>(scratch7));
2590   }
2591 
2592   return kTfLiteOk;
2593 }
2594 
2595 }  // namespace lstm_eval
2596 }  // namespace builtin
2597 }  // namespace ops
2598 }  // namespace tflite
2599