/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.h"

#include <algorithm>
#include <cstdio>
#include <vector>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/lstm_shared.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/tools/optimize/calibration/calibration_logger.h"

namespace tflite {
namespace optimize {
namespace calibration {
namespace builtin {

namespace {

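// Computes one LSTM gate from its contributing terms. Roughly:
//   gate = activation(input_to_gate_weights * input
//                     + aux_input_to_gate_weights * aux_input
//                     + recurrent_to_gate_weights * output_state
//                     + cell_to_gate_weights .* cell_state   (peephole only)
//                     + gate_bias),
// where, for layer-norm LSTM, mean/stddev normalization and the coefficient
// scaling are applied before the bias. When layer norm is enabled, the
// pre-normalization accumulator is also logged so the calibrator can collect
// min/max ranges for the gate's intermediate tensor.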
inline void CalculateLstmGateFloat(
    const float* input, const float* input_to_gate_weights,
    const float* aux_input, const float* aux_input_to_gate_weights,
    const float* output_state, const float* recurrent_to_gate_weights,
    const float* cell_state, const float* cell_to_gate_weights,
    const float* layer_norm_coefficients, const float* gate_bias,
    const int n_batch, const int n_input, const int n_aux_input,
    const int n_output, const int n_cell,
    const TfLiteFusedActivation activation, float* gate,
    const bool is_input_all_zeros, const bool is_aux_input_all_zeros,
    Logger* logger, int intermediate_tensor_index, const int subgraph_index,
    ErrorReporter* error_reporter) {
  const bool use_peephole = (cell_to_gate_weights != nullptr);
  const bool use_layer_norm = (layer_norm_coefficients != nullptr);

  // Initialize scratch buffers with bias for regular lstm or initialize with
  // zero for layer norm lstm.
  if (use_layer_norm) {
    std::fill_n(gate, n_cell * n_batch, 0.0f);
  } else {
    tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate);
  }
  // For each batch and cell: compute input_weight * input.
  // Skip if input is all zeros.
  if (!is_input_all_zeros) {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_gate_weights, n_cell, n_input, input, n_batch, gate);
  }
  // For each batch and cell: compute aux_input_weight * aux_input.
  // Skip if auxiliary input is not available or all zeros.
  if (!is_aux_input_all_zeros) {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(aux_input_to_gate_weights,
                                                      n_cell, n_aux_input,
                                                      aux_input, n_batch, gate);
  }
  // For each batch and cell: compute recurrent_weight * output_state.
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      recurrent_to_gate_weights, n_cell, n_output, output_state, n_batch, gate);
  // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
  if (use_peephole) {
    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        cell_to_gate_weights, n_cell, cell_state, n_batch, gate);
  }
  // Do layer normalization (if layer norm LSTM)
  if (use_layer_norm) {
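    // Calibration hook: record the raw gate accumulator before normalization
    // so a quantization range can be derived for this intermediate tensor.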
    logger->LogTensorValue(subgraph_index, intermediate_tensor_index, gate,
                           n_cell * n_batch, error_reporter);

    tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
    tensor_utils::VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell,
                                                gate, n_batch, gate);
    tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
  }
  // Apply activation
  tensor_utils::ApplyActivationToVector(gate, n_batch * n_cell, activation,
                                        gate);
}

// TODO(b/159066113): This is the exact same function as UpdateLstmCellFloat in
// kernels/lstm_eval.cc, make that public and remove this.
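// Updates the cell state using the standard LSTM recurrence:
//   cell_state = forget_gate .* cell_state + input_gate .* cell_gate
// For CIFG, input_gate is derived as (1 - forget_gate), reusing the
// forget_gate buffer as scratch. A positive clip value bounds the result to
// [-clip, clip].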
void UpdateLstmCellFloat(int n_batch, int n_cell, float* cell_state,
                         const float* input_gate, float* forget_gate,
                         const float* cell_gate, bool use_cifg, float clip) {
  tensor_utils::VectorVectorCwiseProduct(forget_gate, cell_state,
                                         n_batch * n_cell, cell_state);

  if (use_cifg) {
    // With CIFG, input_gate = 1-forget_gate. Use the forget_gate array as
    // scratch, as input_gate array is not allocated in this case. (Be careful
    // not to write to the scratch before reading the forget gate data.)
    float* scratch = forget_gate;
    tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch);
    tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_gate, scratch, n_batch * n_cell, cell_state);
  } else {
    tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_gate, input_gate, n_batch * n_cell, cell_state);
  }
  if (clip > 0.0f) {
    tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip);
  }
}

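// Computes the LSTM output:
//   output_state = output_gate .* activation(cell_state),
// followed by an optional projection (with optional bias and clipping). The
// pre-projection value is logged so the calibrator can record a range for
// the corresponding intermediate tensor.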
void CalculateLstmOutputCalibration(
    int n_batch, int n_cell, int n_output, const float* cell_state,
    const float* output_gate, TfLiteFusedActivation activation,
    const float* projection_weights, const float* projection_bias,
    const float proj_clip, float* output_state, float* scratch, Logger* logger,
    int intermediate_tensor_index, const int subgraph_index,
    ErrorReporter* error_reporter) {
  tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
                                        activation, scratch);
  tensor_utils::VectorVectorCwiseProduct(output_gate, scratch, n_batch * n_cell,
                                         scratch);

  logger->LogTensorValue(subgraph_index, intermediate_tensor_index, scratch,
                         n_cell * n_batch, error_reporter);

  const bool use_projection = (projection_weights != nullptr);
  const bool use_projection_bias = (projection_bias != nullptr);

  if (use_projection) {
    if (use_projection_bias) {
      tensor_utils::VectorBatchVectorAssign(projection_bias, n_output, n_batch,
                                            output_state);
    } else {
      std::fill_n(output_state, n_batch * n_output, 0.0f);
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        projection_weights, n_output, n_cell, scratch, n_batch, output_state);
    if (proj_clip > 0.0f) {
      tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip);
    }
  } else {
    std::copy_n(scratch, n_batch * n_output, output_state);
  }
}

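// Performs one LSTM time step for a whole batch: computes the input (unless
// CIFG), forget, cell, and output gates, updates the cell state, and writes
// the (possibly projected) output. intermediate_tensor_indexes[0..4] name the
// tensors whose values are logged along the way. Apart from the logging, this
// closely follows LstmStepFloat in kernels/lstm_eval.cc.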
inline void LstmStepCalibration(
    const float* input_ptr, const float* input_to_input_weights_ptr,
    const float* input_to_forget_weights_ptr,
    const float* input_to_cell_weights_ptr,
    const float* input_to_output_weights_ptr, const float* aux_input_ptr,
    const float* aux_input_to_input_weights_ptr,
    const float* aux_input_to_forget_weights_ptr,
    const float* aux_input_to_cell_weights_ptr,
    const float* aux_input_to_output_weights_ptr,
    const float* recurrent_to_input_weights_ptr,
    const float* recurrent_to_forget_weights_ptr,
    const float* recurrent_to_cell_weights_ptr,
    const float* recurrent_to_output_weights_ptr,
    const float* cell_to_input_weights_ptr,
    const float* cell_to_forget_weights_ptr,
    const float* cell_to_output_weights_ptr,
    const float* input_layer_norm_coefficients_ptr,
    const float* forget_layer_norm_coefficients_ptr,
    const float* cell_layer_norm_coefficients_ptr,
    const float* output_layer_norm_coefficients_ptr,
    const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
    const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
    const float* projection_weights_ptr, const float* projection_bias_ptr,
    const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
    int n_aux_input, int n_output, int output_batch_leading_dim,
    float* output_state_ptr, float* cell_state_ptr, float* scratch0,
    float* scratch1, float* scratch2, float* scratch3, float* output_ptr,
    Logger* logger, const std::vector<int>& intermediate_tensor_indexes,
    const int subgraph_index, ErrorReporter* error_reporter) {
  ruy::profiler::ScopeLabel label("LstmStepCalibration");
  // Since we have already checked that the weights are either all present or
  // all absent, checking one of them is enough to determine CIFG mode.
  const bool use_cifg = (input_to_input_weights_ptr == nullptr);

  // Make named scratch buffers.
  float* input_gate_scratch = scratch0;
  float* forget_gate_scratch = scratch1;
  float* cell_gate_scratch = scratch2;
  float* output_gate_scratch = scratch3;

  // Check if inputs are all zeros so we can skip some computations.
  const bool is_input_all_zeros =
      tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
  const bool is_aux_input_all_zeros =
      (aux_input_ptr == nullptr ||
       tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
  if (!use_cifg) {
    // Calculate the input gate. (If not CIFG.)
    CalculateLstmGateFloat(
        input_ptr, input_to_input_weights_ptr, aux_input_ptr,
        aux_input_to_input_weights_ptr, output_state_ptr,
        recurrent_to_input_weights_ptr, cell_state_ptr,
        cell_to_input_weights_ptr, input_layer_norm_coefficients_ptr,
        input_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
        /*activation=*/kTfLiteActSigmoid, input_gate_scratch,
        is_input_all_zeros, is_aux_input_all_zeros, logger,
        intermediate_tensor_indexes[0], subgraph_index, error_reporter);
  }
  // Calculate the forget gate.
  CalculateLstmGateFloat(
      input_ptr, input_to_forget_weights_ptr, aux_input_ptr,
      aux_input_to_forget_weights_ptr, output_state_ptr,
      recurrent_to_forget_weights_ptr, cell_state_ptr,
      cell_to_forget_weights_ptr, forget_layer_norm_coefficients_ptr,
      forget_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
      /*activation=*/kTfLiteActSigmoid, forget_gate_scratch, is_input_all_zeros,
      is_aux_input_all_zeros, logger, intermediate_tensor_indexes[1],
      subgraph_index, error_reporter);
  // Calculate the cell update gate.
  CalculateLstmGateFloat(
      input_ptr, input_to_cell_weights_ptr, aux_input_ptr,
      aux_input_to_cell_weights_ptr, output_state_ptr,
      recurrent_to_cell_weights_ptr, /*cell_state=*/nullptr,
      /*cell_to_gate_weights=*/nullptr, cell_layer_norm_coefficients_ptr,
      cell_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
      params->activation, cell_gate_scratch, is_input_all_zeros,
      is_aux_input_all_zeros, logger, intermediate_tensor_indexes[2],
      subgraph_index, error_reporter);
  // Update the cell state.
  UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch,
                      forget_gate_scratch, cell_gate_scratch, use_cifg,
                      params->cell_clip);
  // Calculate output gate.
  CalculateLstmGateFloat(
      input_ptr, input_to_output_weights_ptr, aux_input_ptr,
      aux_input_to_output_weights_ptr, output_state_ptr,
      recurrent_to_output_weights_ptr, cell_state_ptr,
      cell_to_output_weights_ptr, output_layer_norm_coefficients_ptr,
      output_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
      /*activation=*/kTfLiteActSigmoid, output_gate_scratch, is_input_all_zeros,
      is_aux_input_all_zeros, logger, intermediate_tensor_indexes[3],
      subgraph_index, error_reporter);
  // Update the output state.
  CalculateLstmOutputCalibration(
      n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
      params->activation, projection_weights_ptr, projection_bias_ptr,
      params->proj_clip, output_state_ptr, scratch2, logger,
      intermediate_tensor_indexes[4], subgraph_index, error_reporter);
  // Copy output state to the output. Note that the output's rows may not be
  // contiguous (output_batch_leading_dim != n_output).
  for (int b = 0; b < n_batch; b++) {
    std::copy_n(output_state_ptr + b * n_output, n_output,
                output_ptr + b * output_batch_leading_dim);
  }
}

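// Runs the float LSTM over the full sequence, one LstmStepCalibration call
// per time step. Supports time-major ([max_time, n_batch, ...]) and
// batch-major ([n_batch, max_time, ...]) inputs, as well as reversed
// traversal when forward_sequence is false.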
TfLiteStatus EvalCalibration(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_output_weights,
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_output_weights,
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    const TfLiteTensor* input_layer_norm_coefficients,
    const TfLiteTensor* forget_layer_norm_coefficients,
    const TfLiteTensor* cell_layer_norm_coefficients,
    const TfLiteTensor* output_layer_norm_coefficients,
    const TfLiteTensor* aux_input,
    const TfLiteTensor* aux_input_to_input_weights,
    const TfLiteTensor* aux_input_to_forget_weights,
    const TfLiteTensor* aux_input_to_cell_weights,
    const TfLiteTensor* aux_input_to_output_weights,
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
    int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* output_state,
    TfLiteTensor* cell_state, TfLiteTensor* output, Logger* logger,
    const std::vector<int>& intermediate_tensor_indexes,
    const int subgraph_index, ErrorReporter* error_reporter) {
  TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
  int max_time, n_batch;
  if (input->dims->size == 3) {
    max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
    n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
  } else {
    max_time = 1;
    n_batch = input->dims->data[0];
  }
  const int n_input = input->dims->data[input->dims->size - 1];
  const int aux_input_size =
      (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;

  // n_cell and n_output will be the same size when there is no projection.
  const int n_cell = input_to_output_weights->dims->data[0];
  const int n_output = recurrent_to_output_weights->dims->data[1];

  // Since we have already checked that the weights are either all present or
  // all absent, checking one of them is enough to determine CIFG mode.
  const bool use_cifg = (input_to_input_weights == nullptr);

  // Index the named scratch buffer pointers into the global scratch buffer.
  float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer);
  float* input_gate_scratch = nullptr;
  float* cell_gate_scratch = nullptr;
  float* forget_gate_scratch = nullptr;
  float* output_gate_scratch = nullptr;
  if (use_cifg) {
    cell_gate_scratch = scratch_buffer_ptr;
    forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
    output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
  } else {
    input_gate_scratch = scratch_buffer_ptr;
    cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
    forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
    output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
  }

  const int output_batch_leading_dim =
      output->dims->data[output->dims->size - 1];
  if (time_major) {
    // Loop through the sequence.
    const int input_step = n_batch * n_input;
    const int output_step = n_batch * output_batch_leading_dim;
    for (int t = 0; t < max_time; t++) {
      // If this is the forward_sequence, step forward, otherwise step
      // backwards.
      const int t_rel = forward_sequence ? t : max_time - t - 1;
      const float* input_ptr = GetTensorData<float>(input) + t_rel * input_step;
      const float* aux_input_ptr = nullptr;
      if (aux_input) {
        aux_input_ptr = GetTensorData<float>(aux_input) + t_rel * input_step;
      }
      float* output_ptr_time =
          GetTensorData<float>(output) + t_rel * output_step + output_offset;

      LstmStepCalibration(
          input_ptr, GetTensorData<float>(input_to_input_weights),
          GetTensorData<float>(input_to_forget_weights),
          GetTensorData<float>(input_to_cell_weights),
          GetTensorData<float>(input_to_output_weights), aux_input_ptr,
          GetTensorData<float>(aux_input_to_input_weights),
          GetTensorData<float>(aux_input_to_forget_weights),
          GetTensorData<float>(aux_input_to_cell_weights),
          GetTensorData<float>(aux_input_to_output_weights),
          GetTensorData<float>(recurrent_to_input_weights),
          GetTensorData<float>(recurrent_to_forget_weights),
          GetTensorData<float>(recurrent_to_cell_weights),
          GetTensorData<float>(recurrent_to_output_weights),
          GetTensorData<float>(cell_to_input_weights),
          GetTensorData<float>(cell_to_forget_weights),
          GetTensorData<float>(cell_to_output_weights),
          GetTensorData<float>(input_layer_norm_coefficients),
          GetTensorData<float>(forget_layer_norm_coefficients),
          GetTensorData<float>(cell_layer_norm_coefficients),
          GetTensorData<float>(output_layer_norm_coefficients),
          GetTensorData<float>(input_gate_bias),
          GetTensorData<float>(forget_gate_bias),
          GetTensorData<float>(cell_gate_bias),
          GetTensorData<float>(output_gate_bias),
          GetTensorData<float>(projection_weights),
          GetTensorData<float>(projection_bias), params, n_batch, n_cell,
          n_input, aux_input_size, n_output, output_batch_leading_dim,
          GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
          input_gate_scratch, forget_gate_scratch, cell_gate_scratch,
          output_gate_scratch, output_ptr_time, logger,
          intermediate_tensor_indexes, subgraph_index, error_reporter);
    }
  } else {
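    // Batch-major input: process one batch element at a time, offsetting the
    // state and scratch pointers to the current batch row and stepping with
    // n_batch == 1.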
    for (int b = 0; b < n_batch; b++) {
      const int input_step = n_input;
      const int output_step = output_batch_leading_dim;
      for (int t = 0; t < max_time; t++) {
        // If this is the forward_sequence, step forward, otherwise step
        // backwards.
        const int t_rel = forward_sequence ? t : max_time - t - 1;
        const int time_offset = b * max_time + t_rel;
        const float* input_ptr =
            GetTensorData<float>(input) + time_offset * input_step;
        const float* aux_input_ptr = nullptr;
        if (aux_input) {
          aux_input_ptr =
              GetTensorData<float>(aux_input) + time_offset * input_step;
        }
        float* output_ptr = GetTensorData<float>(output) +
                            time_offset * output_step + output_offset;

        // Offset the {output,cell}_state pointers to the right batch.
        float* output_state_ptr =
            GetTensorData<float>(output_state) + b * output_batch_leading_dim;
        float* cell_state_ptr = GetTensorData<float>(cell_state) + b * n_cell;
        // Offset the scratch pointers to the right batch.
        float* input_gate_scratch_ptr =
            input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
        float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
        float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
        float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;

        LstmStepCalibration(
            input_ptr, GetTensorData<float>(input_to_input_weights),
            GetTensorData<float>(input_to_forget_weights),
            GetTensorData<float>(input_to_cell_weights),
            GetTensorData<float>(input_to_output_weights), aux_input_ptr,
            GetTensorData<float>(aux_input_to_input_weights),
            GetTensorData<float>(aux_input_to_forget_weights),
            GetTensorData<float>(aux_input_to_cell_weights),
            GetTensorData<float>(aux_input_to_output_weights),
            GetTensorData<float>(recurrent_to_input_weights),
            GetTensorData<float>(recurrent_to_forget_weights),
            GetTensorData<float>(recurrent_to_cell_weights),
            GetTensorData<float>(recurrent_to_output_weights),
            GetTensorData<float>(cell_to_input_weights),
            GetTensorData<float>(cell_to_forget_weights),
            GetTensorData<float>(cell_to_output_weights),
            GetTensorData<float>(input_layer_norm_coefficients),
            GetTensorData<float>(forget_layer_norm_coefficients),
            GetTensorData<float>(cell_layer_norm_coefficients),
            GetTensorData<float>(output_layer_norm_coefficients),
            GetTensorData<float>(input_gate_bias),
            GetTensorData<float>(forget_gate_bias),
            GetTensorData<float>(cell_gate_bias),
            GetTensorData<float>(output_gate_bias),
            GetTensorData<float>(projection_weights),
            GetTensorData<float>(projection_bias), params, /*n_batch=*/1,
            n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
            output_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
            forget_gate_scratch_ptr, cell_gate_scratch_ptr,
            output_gate_scratch_ptr, output_ptr, logger,
            intermediate_tensor_indexes, subgraph_index, error_reporter);
      }
    }
  }
  return kTfLiteOk;
}

struct OpData {
  // Which kernel type to use. Full kernel (24 inputs) or basic kernel (5
  // inputs). Please note the 20-input full kernel is deprecated and only kept
  // here for backward compatibility.
  TfLiteLSTMKernelType kernel_type;

  // Whether the LSTM uses layer normalization.
  bool use_layer_norm;

  // This field is only used by the full kernel.
  int scratch_tensor_index;
};

// Gathers the input, weight, bias, state, and output tensors of the LSTM
// node, extracts the LSTM-specific parameters and intermediate tensor
// indices, and dispatches float evaluation (with calibration logging) to
// EvalCalibration.
TfLiteStatus lstm_eval(TfLiteContext* context, int subgraph_index,
                       TfLiteNode* node, LSTMType lstm_type, Logger* logger,
                       ErrorReporter* error_reporter) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node,
                            ops::builtin::lstm::full::kInputTensor, &input));

  const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kInputToInputWeightsTensor);
  const TfLiteTensor* input_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node,
                   ops::builtin::lstm::full::kInputToForgetWeightsTensor,
                   &input_to_forget_weights));
  const TfLiteTensor* input_to_cell_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node,
                            ops::builtin::lstm::full::kInputToCellWeightsTensor,
                            &input_to_cell_weights));
  const TfLiteTensor* input_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node,
                   ops::builtin::lstm::full::kInputToOutputWeightsTensor,
                   &input_to_output_weights));

  const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kRecurrentToInputWeightsTensor);
  const TfLiteTensor* recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node,
                   ops::builtin::lstm::full::kRecurrentToForgetWeightsTensor,
                   &recurrent_to_forget_weights));
  const TfLiteTensor* recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node,
                   ops::builtin::lstm::full::kRecurrentToCellWeightsTensor,
                   &recurrent_to_cell_weights));
  const TfLiteTensor* recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node,
                   ops::builtin::lstm::full::kRecurrentToOutputWeightsTensor,
                   &recurrent_to_output_weights));

  const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kCellToInputWeightsTensor);
  const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kCellToForgetWeightsTensor);
  const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kCellToOutputWeightsTensor);

  const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
      context, node,
      ops::builtin::lstm::full::kInputLayerNormCoefficientsTensor);
  const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
      context, node,
      ops::builtin::lstm::full::kForgetLayerNormCoefficientsTensor);
  const TfLiteTensor* cell_layer_norm_coefficients = GetOptionalInputTensor(
      context, node,
      ops::builtin::lstm::full::kCellLayerNormCoefficientsTensor);
  const TfLiteTensor* output_layer_norm_coefficients = GetOptionalInputTensor(
      context, node,
      ops::builtin::lstm::full::kOutputLayerNormCoefficientsTensor);

  const TfLiteTensor* input_gate_bias = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kInputGateBiasTensor);
  const TfLiteTensor* forget_gate_bias;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node,
                            ops::builtin::lstm::full::kForgetGateBiasTensor,
                            &forget_gate_bias));
  const TfLiteTensor* cell_gate_bias;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, ops::builtin::lstm::full::kCellGateBiasTensor,
                   &cell_gate_bias));
  const TfLiteTensor* output_gate_bias;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node,
                            ops::builtin::lstm::full::kOutputGateBiasTensor,
                            &output_gate_bias));

  const TfLiteTensor* projection_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kProjectionWeightsTensor);
  const TfLiteTensor* projection_bias = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kProjectionBiasTensor);

  // Get the temporary scratch buffer allocated for this node.
  TfLiteTensor* scratch_buffer;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/0, &scratch_buffer));

  TfLiteTensor* output_state = GetVariableInput(
      context, node, ops::builtin::lstm::full::kOutputStateTensor);
  TF_LITE_ENSURE(context, output_state != nullptr);
  TfLiteTensor* cell_state = GetVariableInput(
      context, node, ops::builtin::lstm::full::kCellStateTensor);
  TF_LITE_ENSURE(context, cell_state != nullptr);

  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(
      context, GetOutputSafe(context, node,
                             ops::builtin::lstm::full::kOutputTensor, &output));

  std::vector<int> intermediate_tensor_indexes(node->intermediates->size);
  // LSTM expects 5 intermediate tensors: one per gate, plus the output before
  // projection.
  TF_LITE_ENSURE_EQ(context, node->intermediates->size, 5);
  for (int i = 0; i < node->intermediates->size; ++i) {
    intermediate_tensor_indexes[i] = node->intermediates->data[i];
  }

  TfLiteLSTMParams lstm_params;
  bool time_major = true;
  switch (lstm_type) {
    case LSTMType::kLSTM: {
      lstm_params = *(static_cast<TfLiteLSTMParams*>(node->builtin_data));
      time_major = true;
      break;
    }
    case LSTMType::kUnidirectionalSequenceLSTM: {
      const auto* params = static_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
          node->builtin_data);
      // Copy out the LSTM-specific params so they can be passed to the
      // evaluation function.
      lstm_params.activation = params->activation;
      lstm_params.cell_clip = params->cell_clip;
      lstm_params.proj_clip = params->proj_clip;
      lstm_params.asymmetric_quantize_inputs =
          params->asymmetric_quantize_inputs;
      time_major = params->time_major;
      break;
    }
    default:
      return kTfLiteError;
  }

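  // Calibration runs the float kernel, so only float weights are supported.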
  switch (input_to_output_weights->type) {
    case kTfLiteFloat32: {
      return EvalCalibration(
          input, input_to_input_weights, input_to_forget_weights,
          input_to_cell_weights, input_to_output_weights,
          recurrent_to_input_weights, recurrent_to_forget_weights,
          recurrent_to_cell_weights, recurrent_to_output_weights,
          cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
          input_layer_norm_coefficients, forget_layer_norm_coefficients,
          cell_layer_norm_coefficients, output_layer_norm_coefficients,
          /*aux_input=*/nullptr,
          /*aux_input_to_input_weights=*/nullptr,
          /*aux_input_to_forget_weights=*/nullptr,
          /*aux_input_to_cell_weights=*/nullptr,
          /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
          forget_gate_bias, cell_gate_bias, output_gate_bias,
          projection_weights, projection_bias, &lstm_params,
          /*forward_sequence=*/true,
          /*time_major=*/time_major,
          /*output_offset=*/0, scratch_buffer, output_state, cell_state, output,
          logger, intermediate_tensor_indexes, subgraph_index, error_reporter);
    }
    case kTfLiteUInt8:
    case kTfLiteInt8:
    default:
      TF_LITE_REPORT_ERROR(error_reporter,
                           "Only float models can be calibrated.");
      return kTfLiteError;
  }
  return kTfLiteOk;
}
}  // namespace

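// Logging kernel for the builtin LSTM op: runs the float evaluation while
// recording intermediate tensor statistics through the given Logger.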
TfLiteStatus lstm_logging_kernel(TfLiteContext* context,
                                 const int subgraph_index, TfLiteNode* node,
                                 Logger* logger,
                                 ErrorReporter* error_reporter) {
  return lstm_eval(context, subgraph_index, node, LSTMType::kLSTM, logger,
                   error_reporter);
}

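// Logging kernel for the builtin UNIDIRECTIONAL_SEQUENCE_LSTM op; it shares
// the LSTM evaluation path but honors the op's time_major flag.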
TfLiteStatus unidirectional_sequence_lstm_logging_kernel(
    TfLiteContext* context, const int subgraph_index, TfLiteNode* node,
    Logger* logger, ErrorReporter* error_reporter) {
  return lstm_eval(context, subgraph_index, node,
                   LSTMType::kUnidirectionalSequenceLSTM, logger,
                   error_reporter);
}

}  // namespace builtin
}  // namespace calibration
}  // namespace optimize
}  // namespace tflite