/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.h"

#include <algorithm>
#include <cstdio>
#include <vector>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/lstm_shared.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/tools/optimize/calibration/calibration_logger.h"

namespace tflite {
namespace optimize {
namespace calibration {
namespace builtin {

namespace {

inline void CalculateLstmGateFloat(
    const float* input, const float* input_to_gate_weights,
    const float* aux_input, const float* aux_input_to_gate_weights,
    const float* output_state, const float* recurrent_to_gate_weights,
    const float* cell_state, const float* cell_to_gate_weights,
    const float* layer_norm_coefficients, const float* gate_bias,
    const int n_batch, const int n_input, const int n_aux_input,
    const int n_output, const int n_cell,
    const TfLiteFusedActivation activation, float* gate,
    const bool is_input_all_zeros, const bool is_aux_input_all_zeros,
    Logger* logger, int intermediate_tensor_index, const int subgraph_index,
    ErrorReporter* error_reporter) {
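  // In effect, for each batch row this computes
  //   gate = activation(input_to_gate_weights * input +
  //                     aux_input_to_gate_weights * aux_input +
  //                     recurrent_to_gate_weights * output_state +
  //                     cell_to_gate_weights .* cell_state + gate_bias),
  // with optional layer normalization of the accumulated sum; the
  // pre-normalization values are what get logged for calibration.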
  const bool use_peephole = (cell_to_gate_weights != nullptr);
  const bool use_layer_norm = (layer_norm_coefficients != nullptr);

  // Initialize scratch buffers with bias for regular lstm or initialize with
  // zero for layer norm lstm.
  if (use_layer_norm) {
    std::fill_n(gate, n_cell * n_batch, 0.0f);
  } else {
    tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate);
  }
  // For each batch and cell: compute input_weight * input.
  // Skip if input is all zeros.
  if (!is_input_all_zeros) {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_gate_weights, n_cell, n_input, input, n_batch, gate);
  }
  // For each batch and cell: compute aux_input_weight * aux_input.
  // Skip if auxiliary input is not available or all zeros.
  if (!is_aux_input_all_zeros) {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(aux_input_to_gate_weights,
                                                      n_cell, n_aux_input,
                                                      aux_input, n_batch, gate);
  }
  // For each batch and cell: compute recurrent_weight * output_state.
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      recurrent_to_gate_weights, n_cell, n_output, output_state, n_batch, gate);
  // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
  if (use_peephole) {
    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        cell_to_gate_weights, n_cell, cell_state, n_batch, gate);
  }
  // Do layer normalization (if layer norm LSTM)
  if (use_layer_norm) {
    logger->LogTensorValue(subgraph_index, intermediate_tensor_index, gate,
                           n_cell * n_batch, error_reporter);

    tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
    tensor_utils::VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell,
                                                gate, n_batch, gate);
    tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
  }
  // Apply activation
  tensor_utils::ApplyActivationToVector(gate, n_batch * n_cell, activation,
                                        gate);
}

// TODO(b/159066113): This is the exact same function as UpdateLstmCellFloat in
// kernels/lstm_eval.cc, make that public and remove this.
void UpdateLstmCellFloat(int n_batch, int n_cell, float* cell_state,
                         const float* input_gate, float* forget_gate,
                         const float* cell_gate, bool use_cifg, float clip) {
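  // Computes cell_state = forget_gate .* cell_state + input_gate .* cell_gate,
  // where input_gate is taken as (1 - forget_gate) when CIFG is used, and then
  // optionally clips the result to [-clip, clip].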
  tensor_utils::VectorVectorCwiseProduct(forget_gate, cell_state,
                                         n_batch * n_cell, cell_state);

  if (use_cifg) {
    // With CIFG, input_gate = 1-forget_gate. Use the forget_gate array as
    // scratch, as input_gate array is not allocated in this case. (Be careful
    // not to write to the scratch before reading the forget gate data.)
    float* scratch = forget_gate;
    tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch);
    tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_gate, scratch, n_batch * n_cell, cell_state);
  } else {
    tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_gate, input_gate, n_batch * n_cell, cell_state);
  }
  if (clip > 0.0f) {
    tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip);
  }
}

void CalculateLstmOutputCalibration(
    int n_batch, int n_cell, int n_output, const float* cell_state,
    const float* output_gate, TfLiteFusedActivation activation,
    const float* projection_weights, const float* projection_bias,
    const float proj_clip, float* output_state, float* scratch, Logger* logger,
    int intermediate_tensor_index, const int subgraph_index,
    ErrorReporter* error_reporter) {
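  // Computes the cell output output_gate .* activation(cell_state), logs it
  // for calibration, and writes it to output_state, applying the optional
  // projection (projection_weights * output + projection_bias, with clipping)
  // when projection weights are present.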
  tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
                                        activation, scratch);
  tensor_utils::VectorVectorCwiseProduct(output_gate, scratch, n_batch * n_cell,
                                         scratch);

  logger->LogTensorValue(subgraph_index, intermediate_tensor_index, scratch,
                         n_cell * n_batch, error_reporter);

  const bool use_projection = (projection_weights != nullptr);
  const bool use_projection_bias = (projection_bias != nullptr);

  if (use_projection) {
    if (use_projection_bias) {
      tensor_utils::VectorBatchVectorAssign(projection_bias, n_output, n_batch,
                                            output_state);
    } else {
      std::fill_n(output_state, n_batch * n_output, 0.0f);
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        projection_weights, n_output, n_cell, scratch, n_batch, output_state);
    if (proj_clip > 0.0f) {
      tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip);
    }
  } else {
    std::copy_n(scratch, n_batch * n_output, output_state);
  }
}

inline void LstmStepCalibration(
    const float* input_ptr, const float* input_to_input_weights_ptr,
    const float* input_to_forget_weights_ptr,
    const float* input_to_cell_weights_ptr,
    const float* input_to_output_weights_ptr, const float* aux_input_ptr,
    const float* aux_input_to_input_weights_ptr,
    const float* aux_input_to_forget_weights_ptr,
    const float* aux_input_to_cell_weights_ptr,
    const float* aux_input_to_output_weights_ptr,
    const float* recurrent_to_input_weights_ptr,
    const float* recurrent_to_forget_weights_ptr,
    const float* recurrent_to_cell_weights_ptr,
    const float* recurrent_to_output_weights_ptr,
    const float* cell_to_input_weights_ptr,
    const float* cell_to_forget_weights_ptr,
    const float* cell_to_output_weights_ptr,
    const float* input_layer_norm_coefficients_ptr,
    const float* forget_layer_norm_coefficients_ptr,
    const float* cell_layer_norm_coefficients_ptr,
    const float* output_layer_norm_coefficients_ptr,
    const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
    const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
    const float* projection_weights_ptr, const float* projection_bias_ptr,
    const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
    int n_aux_input, int n_output, int output_batch_leading_dim,
    float* output_state_ptr, float* cell_state_ptr, float* scratch0,
    float* scratch1, float* scratch2, float* scratch3, float* output_ptr,
    Logger* logger, const std::vector<int>& intermediate_tensor_indexes,
    const int subgraph_index, ErrorReporter* error_reporter) {
  ruy::profiler::ScopeLabel label("LstmStepCalibration");
  // Since we have already checked that weights are all there or none, we can
  // check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights_ptr == nullptr);

  // Make named scratch buffers.
  float* input_gate_scratch = scratch0;
  float* forget_gate_scratch = scratch1;
  float* cell_gate_scratch = scratch2;
  float* output_gate_scratch = scratch3;
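  // Each scratch buffer holds n_batch * n_cell floats, one buffer per gate.
  // Under CIFG the input gate scratch (scratch0) is unused and may be null.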

  // Check if inputs are all zeros so we can skip some computations.
  const bool is_input_all_zeros =
      tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
  const bool is_aux_input_all_zeros =
      (aux_input_ptr == nullptr ||
       tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
  if (!use_cifg) {
    // Calculate the input gate. (If not CIFG.)
    CalculateLstmGateFloat(
        input_ptr, input_to_input_weights_ptr, aux_input_ptr,
        aux_input_to_input_weights_ptr, output_state_ptr,
        recurrent_to_input_weights_ptr, cell_state_ptr,
        cell_to_input_weights_ptr, input_layer_norm_coefficients_ptr,
        input_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
        /*activation=*/kTfLiteActSigmoid, input_gate_scratch,
        is_input_all_zeros, is_aux_input_all_zeros, logger,
        intermediate_tensor_indexes[0], subgraph_index, error_reporter);
  }
  // Calculate the forget gate.
  CalculateLstmGateFloat(
      input_ptr, input_to_forget_weights_ptr, aux_input_ptr,
      aux_input_to_forget_weights_ptr, output_state_ptr,
      recurrent_to_forget_weights_ptr, cell_state_ptr,
      cell_to_forget_weights_ptr, forget_layer_norm_coefficients_ptr,
      forget_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
      /*activation=*/kTfLiteActSigmoid, forget_gate_scratch, is_input_all_zeros,
      is_aux_input_all_zeros, logger, intermediate_tensor_indexes[1],
      subgraph_index, error_reporter);
  // Calculate the cell update gate.
  CalculateLstmGateFloat(
      input_ptr, input_to_cell_weights_ptr, aux_input_ptr,
      aux_input_to_cell_weights_ptr, output_state_ptr,
      recurrent_to_cell_weights_ptr, /*cell_state=*/nullptr,
      /*cell_to_gate_weights=*/nullptr, cell_layer_norm_coefficients_ptr,
      cell_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
      params->activation, cell_gate_scratch, is_input_all_zeros,
      is_aux_input_all_zeros, logger, intermediate_tensor_indexes[2],
      subgraph_index, error_reporter);
  // Update the cell state.
  UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch,
                      forget_gate_scratch, cell_gate_scratch, use_cifg,
                      params->cell_clip);
  // Calculate output gate.
  CalculateLstmGateFloat(
      input_ptr, input_to_output_weights_ptr, aux_input_ptr,
      aux_input_to_output_weights_ptr, output_state_ptr,
      recurrent_to_output_weights_ptr, cell_state_ptr,
      cell_to_output_weights_ptr, output_layer_norm_coefficients_ptr,
      output_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
      /*activation=*/kTfLiteActSigmoid, output_gate_scratch, is_input_all_zeros,
      is_aux_input_all_zeros, logger, intermediate_tensor_indexes[3],
      subgraph_index, error_reporter);
  // Update the output state.
  CalculateLstmOutputCalibration(
      n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
      params->activation, projection_weights_ptr, projection_bias_ptr,
      params->proj_clip, output_state_ptr, scratch2, logger,
      intermediate_tensor_indexes[4], subgraph_index, error_reporter);
  // Copy output state to the output. Note that the output's rows may not be
  // contiguous (output_batch_leading_dim != n_output).
  for (int b = 0; b < n_batch; b++) {
    std::copy_n(output_state_ptr + b * n_output, n_output,
                output_ptr + b * output_batch_leading_dim);
  }
}

TfLiteStatus EvalCalibration(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_output_weights,
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_output_weights,
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    const TfLiteTensor* input_layer_norm_coefficients,
    const TfLiteTensor* forget_layer_norm_coefficients,
    const TfLiteTensor* cell_layer_norm_coefficients,
    const TfLiteTensor* output_layer_norm_coefficients,
    const TfLiteTensor* aux_input,
    const TfLiteTensor* aux_input_to_input_weights,
    const TfLiteTensor* aux_input_to_forget_weights,
    const TfLiteTensor* aux_input_to_cell_weights,
    const TfLiteTensor* aux_input_to_output_weights,
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
    int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* output_state,
    TfLiteTensor* cell_state, TfLiteTensor* output, Logger* logger,
    const std::vector<int>& intermediate_tensor_indexes,
    const int subgraph_index, ErrorReporter* error_reporter) {
  TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
  int max_time, n_batch;
  if (input->dims->size == 3) {
    max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
    n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
  } else {
    max_time = 1;
    n_batch = input->dims->data[0];
  }
  const int n_input = input->dims->data[input->dims->size - 1];
  const int aux_input_size =
      (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;

  // n_cell and n_output will be the same size when there is no projection.
  const int n_cell = input_to_output_weights->dims->data[0];
  const int n_output = recurrent_to_output_weights->dims->data[1];

  // Since we have already checked that weights are all there or none, we can
  // check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights == nullptr);

  // Index the scratch buffer pointers into the global scratch buffer.
  float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer);
  float* input_gate_scratch = nullptr;
  float* cell_gate_scratch = nullptr;
  float* forget_gate_scratch = nullptr;
  float* output_gate_scratch = nullptr;
  if (use_cifg) {
    cell_gate_scratch = scratch_buffer_ptr;
    forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
    output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
  } else {
    input_gate_scratch = scratch_buffer_ptr;
    cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
    forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
    output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
  }
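  // Note: with CIFG the input gate is derived as (1 - forget_gate), so only
  // three gate buffers are carved out of the shared scratch tensor.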

  const int output_batch_leading_dim =
      output->dims->data[output->dims->size - 1];
  if (time_major) {
    // Loop through the sequence.
    const int input_step = n_batch * n_input;
    const int output_step = n_batch * output_batch_leading_dim;
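    // Time-major layout: the input is [max_time, n_batch, ...], so each time
    // step advances the data pointers by one full batch.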
    for (int t = 0; t < max_time; t++) {
      // If this is the forward_sequence, step forward, otherwise step
      // backwards.
      const int t_rel = forward_sequence ? t : max_time - t - 1;
      const float* input_ptr = GetTensorData<float>(input) + t_rel * input_step;
      const float* aux_input_ptr = nullptr;
      if (aux_input) {
        aux_input_ptr = GetTensorData<float>(aux_input) + t_rel * input_step;
      }
      float* output_ptr_time =
          GetTensorData<float>(output) + t_rel * output_step + output_offset;

      LstmStepCalibration(
          input_ptr, GetTensorData<float>(input_to_input_weights),
          GetTensorData<float>(input_to_forget_weights),
          GetTensorData<float>(input_to_cell_weights),
          GetTensorData<float>(input_to_output_weights), aux_input_ptr,
          GetTensorData<float>(aux_input_to_input_weights),
          GetTensorData<float>(aux_input_to_forget_weights),
          GetTensorData<float>(aux_input_to_cell_weights),
          GetTensorData<float>(aux_input_to_output_weights),
          GetTensorData<float>(recurrent_to_input_weights),
          GetTensorData<float>(recurrent_to_forget_weights),
          GetTensorData<float>(recurrent_to_cell_weights),
          GetTensorData<float>(recurrent_to_output_weights),
          GetTensorData<float>(cell_to_input_weights),
          GetTensorData<float>(cell_to_forget_weights),
          GetTensorData<float>(cell_to_output_weights),
          GetTensorData<float>(input_layer_norm_coefficients),
          GetTensorData<float>(forget_layer_norm_coefficients),
          GetTensorData<float>(cell_layer_norm_coefficients),
          GetTensorData<float>(output_layer_norm_coefficients),
          GetTensorData<float>(input_gate_bias),
          GetTensorData<float>(forget_gate_bias),
          GetTensorData<float>(cell_gate_bias),
          GetTensorData<float>(output_gate_bias),
          GetTensorData<float>(projection_weights),
          GetTensorData<float>(projection_bias), params, n_batch, n_cell,
          n_input, aux_input_size, n_output, output_batch_leading_dim,
          GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
          input_gate_scratch, forget_gate_scratch, cell_gate_scratch,
          output_gate_scratch, output_ptr_time, logger,
          intermediate_tensor_indexes, subgraph_index, error_reporter);
    }
  } else {
    for (int b = 0; b < n_batch; b++) {
      const int input_step = n_input;
      const int output_step = output_batch_leading_dim;
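      // Batch-major layout: the input is [n_batch, max_time, ...], so each
      // sequence is processed independently with an effective batch size of 1.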
      for (int t = 0; t < max_time; t++) {
        // If this is the forward_sequence, step forward, otherwise step
        // backwards.
        const int t_rel = forward_sequence ? t : max_time - t - 1;
        const int time_offset = b * max_time + t_rel;
        const float* input_ptr =
            GetTensorData<float>(input) + time_offset * input_step;
        const float* aux_input_ptr = nullptr;
        if (aux_input) {
          aux_input_ptr =
              GetTensorData<float>(aux_input) + time_offset * input_step;
        }
        float* output_ptr = GetTensorData<float>(output) +
                            time_offset * output_step + output_offset;

        // Offset the {output,cell}_state pointers to the right batch.
        float* output_state_ptr =
            GetTensorData<float>(output_state) + b * output_batch_leading_dim;
        float* cell_state_ptr = GetTensorData<float>(cell_state) + b * n_cell;
        // Offset the scratch pointers to the right batch.
        float* input_gate_scratch_ptr =
            input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
        float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
        float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
        float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;

        LstmStepCalibration(
            input_ptr, GetTensorData<float>(input_to_input_weights),
            GetTensorData<float>(input_to_forget_weights),
            GetTensorData<float>(input_to_cell_weights),
            GetTensorData<float>(input_to_output_weights), aux_input_ptr,
            GetTensorData<float>(aux_input_to_input_weights),
            GetTensorData<float>(aux_input_to_forget_weights),
            GetTensorData<float>(aux_input_to_cell_weights),
            GetTensorData<float>(aux_input_to_output_weights),
            GetTensorData<float>(recurrent_to_input_weights),
            GetTensorData<float>(recurrent_to_forget_weights),
            GetTensorData<float>(recurrent_to_cell_weights),
            GetTensorData<float>(recurrent_to_output_weights),
            GetTensorData<float>(cell_to_input_weights),
            GetTensorData<float>(cell_to_forget_weights),
            GetTensorData<float>(cell_to_output_weights),
            GetTensorData<float>(input_layer_norm_coefficients),
            GetTensorData<float>(forget_layer_norm_coefficients),
            GetTensorData<float>(cell_layer_norm_coefficients),
            GetTensorData<float>(output_layer_norm_coefficients),
            GetTensorData<float>(input_gate_bias),
            GetTensorData<float>(forget_gate_bias),
            GetTensorData<float>(cell_gate_bias),
            GetTensorData<float>(output_gate_bias),
            GetTensorData<float>(projection_weights),
            GetTensorData<float>(projection_bias), params, /*n_batch=*/1,
            n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
            output_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
            forget_gate_scratch_ptr, cell_gate_scratch_ptr,
            output_gate_scratch_ptr, output_ptr, logger,
            intermediate_tensor_indexes, subgraph_index, error_reporter);
      }
    }
  }
  return kTfLiteOk;
}

struct OpData {
  // Which kernel type to use. Full kernel (24 inputs) or basic kernel (5
  // inputs).
  // Please note the 20-input full kernel is deprecated and only kept
  // here for backward compatibility.
  TfLiteLSTMKernelType kernel_type;

  // Whether the LSTM uses layer normalization.
  bool use_layer_norm;

  // These fields are only used by full kernel.
  int scratch_tensor_index;
};

// Gathers the input, weight, bias and state tensors of the LSTM node and runs
// the float calibration evaluation on them, logging intermediate values.
TfLiteStatus lstm_eval(TfLiteContext* context, int subgraph_index,
                       TfLiteNode* node, LSTMType lstm_type, Logger* logger,
                       ErrorReporter* error_reporter) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node,
                            ops::builtin::lstm::full::kInputTensor, &input));

  const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kInputToInputWeightsTensor);
  const TfLiteTensor* input_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node,
                   ops::builtin::lstm::full::kInputToForgetWeightsTensor,
                   &input_to_forget_weights));
  const TfLiteTensor* input_to_cell_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node,
                            ops::builtin::lstm::full::kInputToCellWeightsTensor,
                            &input_to_cell_weights));
  const TfLiteTensor* input_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node,
                   ops::builtin::lstm::full::kInputToOutputWeightsTensor,
                   &input_to_output_weights));

  const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kRecurrentToInputWeightsTensor);
  const TfLiteTensor* recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node,
                   ops::builtin::lstm::full::kRecurrentToForgetWeightsTensor,
                   &recurrent_to_forget_weights));
  const TfLiteTensor* recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node,
                   ops::builtin::lstm::full::kRecurrentToCellWeightsTensor,
                   &recurrent_to_cell_weights));
  const TfLiteTensor* recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node,
                   ops::builtin::lstm::full::kRecurrentToOutputWeightsTensor,
                   &recurrent_to_output_weights));

  const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kCellToInputWeightsTensor);
  const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kCellToForgetWeightsTensor);
  const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kCellToOutputWeightsTensor);

  const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
      context, node,
      ops::builtin::lstm::full::kInputLayerNormCoefficientsTensor);
  const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
      context, node,
      ops::builtin::lstm::full::kForgetLayerNormCoefficientsTensor);
  const TfLiteTensor* cell_layer_norm_coefficients = GetOptionalInputTensor(
      context, node,
      ops::builtin::lstm::full::kCellLayerNormCoefficientsTensor);
  const TfLiteTensor* output_layer_norm_coefficients = GetOptionalInputTensor(
      context, node,
      ops::builtin::lstm::full::kOutputLayerNormCoefficientsTensor);

  const TfLiteTensor* input_gate_bias = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kInputGateBiasTensor);
  const TfLiteTensor* forget_gate_bias;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node,
                            ops::builtin::lstm::full::kForgetGateBiasTensor,
                            &forget_gate_bias));
  const TfLiteTensor* cell_gate_bias;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, ops::builtin::lstm::full::kCellGateBiasTensor,
                   &cell_gate_bias));
  const TfLiteTensor* output_gate_bias;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node,
                            ops::builtin::lstm::full::kOutputGateBiasTensor,
                            &output_gate_bias));

  const TfLiteTensor* projection_weights = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kProjectionWeightsTensor);
  const TfLiteTensor* projection_bias = GetOptionalInputTensor(
      context, node, ops::builtin::lstm::full::kProjectionBiasTensor);

  // Index the scratch buffer pointers into the global scratch buffer.
  TfLiteTensor* scratch_buffer;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/0, &scratch_buffer));

  TfLiteTensor* output_state = GetVariableInput(
      context, node, ops::builtin::lstm::full::kOutputStateTensor);
  TF_LITE_ENSURE(context, output_state != nullptr);
  TfLiteTensor* cell_state = GetVariableInput(
      context, node, ops::builtin::lstm::full::kCellStateTensor);
  TF_LITE_ENSURE(context, cell_state != nullptr);

  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(
      context, GetOutputSafe(context, node,
                             ops::builtin::lstm::full::kOutputTensor, &output));

  std::vector<int> intermediate_tensor_indexes(node->intermediates->size);
  // LSTM expects 5 intermediate tensors.
  TF_LITE_ENSURE_EQ(context, node->intermediates->size, 5);
  for (int i = 0; i < node->intermediates->size; ++i) {
    intermediate_tensor_indexes[i] = node->intermediates->data[i];
  }
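  // The five intermediates map to the four gates (input, forget, cell and
  // output, logged before layer normalization) plus the cell output before
  // projection; these are the tensors recorded by the calibration logger.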

  TfLiteLSTMParams lstm_params;
  bool time_major = true;
  switch (lstm_type) {
    case LSTMType::kLSTM: {
      lstm_params = *(static_cast<TfLiteLSTMParams*>(node->builtin_data));
      time_major = true;
      break;
    }
    case LSTMType::kUnidirectionalSequenceLSTM: {
      const auto* params = static_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
          node->builtin_data);
      // Copy out the LSTM specific params so they can be passed in the
      // function.
      lstm_params.activation = params->activation;
      lstm_params.cell_clip = params->cell_clip;
      lstm_params.proj_clip = params->proj_clip;
      lstm_params.asymmetric_quantize_inputs =
          params->asymmetric_quantize_inputs;
      time_major = params->time_major;
      break;
    }
    default:
      return kTfLiteError;
  }

  switch (input_to_output_weights->type) {
    case kTfLiteFloat32: {
      return EvalCalibration(
          input, input_to_input_weights, input_to_forget_weights,
          input_to_cell_weights, input_to_output_weights,
          recurrent_to_input_weights, recurrent_to_forget_weights,
          recurrent_to_cell_weights, recurrent_to_output_weights,
          cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
          input_layer_norm_coefficients, forget_layer_norm_coefficients,
          cell_layer_norm_coefficients, output_layer_norm_coefficients,
          /*aux_input=*/nullptr,
          /*aux_input_to_input_weights=*/nullptr,
          /*aux_input_to_forget_weights=*/nullptr,
          /*aux_input_to_cell_weights=*/nullptr,
          /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
          forget_gate_bias, cell_gate_bias, output_gate_bias,
          projection_weights, projection_bias, &lstm_params,
          /*forward_sequence=*/true,
          /*time_major=*/time_major,
          /*output_offset=*/0, scratch_buffer, output_state, cell_state, output,
          logger, intermediate_tensor_indexes, subgraph_index, error_reporter);
    }
    case kTfLiteUInt8:
    case kTfLiteInt8:
    default:
      printf("Error. Only float models can be calibrated.\n");
      return kTfLiteError;
  }
  return kTfLiteOk;
}
}  // namespace

TfLiteStatus lstm_logging_kernel(TfLiteContext* context,
                                 const int subgraph_index, TfLiteNode* node,
                                 Logger* logger,
                                 ErrorReporter* error_reporter) {
  return lstm_eval(context, subgraph_index, node, LSTMType::kLSTM, logger,
                   error_reporter);
}

TfLiteStatus unidirectional_sequence_lstm_logging_kernel(
    TfLiteContext* context, const int subgraph_index, TfLiteNode* node,
    Logger* logger, ErrorReporter* error_reporter) {
  return lstm_eval(context, subgraph_index, node,
                   LSTMType::kUnidirectionalSequenceLSTM, logger,
                   error_reporter);
}

}  // namespace builtin
}  // namespace calibration
}  // namespace optimize
}  // namespace tflite