1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // Unit test for TFLite Bidirectional LSTM op.
16 
17 #include <tuple>
18 #include <vector>
19 
20 #include <gtest/gtest.h>
21 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
22 #include "tensorflow/lite/kernels/test_util.h"
23 #include "tensorflow/lite/schema/schema_generated.h"
24 
25 namespace tflite {
26 namespace {
27 
28 using ::testing::ElementsAreArray;
29 
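// Test harness for the BIDIRECTIONAL_SEQUENCE_LSTM op. The constructor adds
// every forward and backward weight, bias, state, and auxiliary tensor in the
// order the kernel expects, then builds the interpreter.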
30 class BidirectionalLSTMOpModel : public SingleOpModel {
31  public:
32   BidirectionalLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
33                            int sequence_length, bool use_cifg,
34                            bool use_peephole, bool use_projection_weights,
35                            bool use_projection_bias, bool merge_outputs,
36                            bool use_aux_input, float cell_clip, float proj_clip,
37                            bool quantize_weights, bool time_major,
38                            const std::vector<std::vector<int>>& input_shapes,
39                            bool asymmetric_quantize_inputs = false)
40       : n_batch_(n_batch),
41         n_input_(n_input),
42         n_fw_cell_(n_cell),
43         n_bw_cell_(n_cell),
44         n_fw_output_(n_output),
45         n_bw_output_(n_output),
46         sequence_length_(sequence_length),
47         quantize_weights_(quantize_weights) {
48     input_ = AddInput(TensorType_FLOAT32);
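    // When quantize_weights is true the weight tensors are created as uint8 and
    // filled via SymmetricQuantizeAndPopulate (hybrid execution path).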
49     const auto weight_type =
50         quantize_weights_ ? TensorType_UINT8 : TensorType_FLOAT32;
51 
52     if (use_cifg) {
53       fw_input_to_input_weights_ = AddNullInput();
54     } else {
55       fw_input_to_input_weights_ = AddInput(weight_type);
56     }
57 
58     fw_input_to_forget_weights_ = AddInput(weight_type);
59     fw_input_to_cell_weights_ = AddInput(weight_type);
60     fw_input_to_output_weights_ = AddInput(weight_type);
61 
62     if (use_cifg) {
63       fw_recurrent_to_input_weights_ = AddNullInput();
64     } else {
65       fw_recurrent_to_input_weights_ = AddInput(weight_type);
66     }
67 
68     fw_recurrent_to_forget_weights_ = AddInput(weight_type);
69     fw_recurrent_to_cell_weights_ = AddInput(weight_type);
70     fw_recurrent_to_output_weights_ = AddInput(weight_type);
71 
72     if (use_peephole) {
73       if (use_cifg) {
74         fw_cell_to_input_weights_ = AddNullInput();
75       } else {
76         fw_cell_to_input_weights_ = AddInput(weight_type);
77       }
78       fw_cell_to_forget_weights_ = AddInput(weight_type);
79       fw_cell_to_output_weights_ = AddInput(weight_type);
80     } else {
81       fw_cell_to_input_weights_ = AddNullInput();
82       fw_cell_to_forget_weights_ = AddNullInput();
83       fw_cell_to_output_weights_ = AddNullInput();
84     }
85 
86     if (use_cifg) {
87       fw_input_gate_bias_ = AddNullInput();
88     } else {
89       fw_input_gate_bias_ = AddInput(TensorType_FLOAT32);
90     }
91     fw_forget_gate_bias_ = AddInput(TensorType_FLOAT32);
92     fw_cell_gate_bias_ = AddInput(TensorType_FLOAT32);
93     fw_output_gate_bias_ = AddInput(TensorType_FLOAT32);
94 
95     if (use_projection_weights) {
96       fw_projection_weights_ = AddInput(TensorType_FLOAT32);
97       if (use_projection_bias) {
98         fw_projection_bias_ = AddInput(TensorType_FLOAT32);
99       } else {
100         fw_projection_bias_ = AddNullInput();
101       }
102     } else {
103       fw_projection_weights_ = AddNullInput();
104       fw_projection_bias_ = AddNullInput();
105     }
106 
107     if (use_cifg) {
108       bw_input_to_input_weights_ = AddNullInput();
109     } else {
110       bw_input_to_input_weights_ = AddInput(weight_type);
111     }
112 
113     bw_input_to_forget_weights_ = AddInput(weight_type);
114     bw_input_to_cell_weights_ = AddInput(weight_type);
115     bw_input_to_output_weights_ = AddInput(weight_type);
116 
117     if (use_cifg) {
118       bw_recurrent_to_input_weights_ = AddNullInput();
119     } else {
120       bw_recurrent_to_input_weights_ = AddInput(weight_type);
121     }
122 
123     bw_recurrent_to_forget_weights_ = AddInput(weight_type);
124     bw_recurrent_to_cell_weights_ = AddInput(weight_type);
125     bw_recurrent_to_output_weights_ = AddInput(weight_type);
126 
127     if (use_peephole) {
128       if (use_cifg) {
129         bw_cell_to_input_weights_ = AddNullInput();
130       } else {
131         bw_cell_to_input_weights_ = AddInput(weight_type);
132       }
133       bw_cell_to_forget_weights_ = AddInput(weight_type);
134       bw_cell_to_output_weights_ = AddInput(weight_type);
135     } else {
136       bw_cell_to_input_weights_ = AddNullInput();
137       bw_cell_to_forget_weights_ = AddNullInput();
138       bw_cell_to_output_weights_ = AddNullInput();
139     }
140 
141     if (use_cifg) {
142       bw_input_gate_bias_ = AddNullInput();
143     } else {
144       bw_input_gate_bias_ = AddInput(TensorType_FLOAT32);
145     }
146     bw_forget_gate_bias_ = AddInput(TensorType_FLOAT32);
147     bw_cell_gate_bias_ = AddInput(TensorType_FLOAT32);
148     bw_output_gate_bias_ = AddInput(TensorType_FLOAT32);
149 
150     if (use_projection_weights) {
151       bw_projection_weights_ = AddInput(weight_type);
152       if (use_projection_bias) {
153         bw_projection_bias_ = AddInput(TensorType_FLOAT32);
154       } else {
155         bw_projection_bias_ = AddNullInput();
156       }
157     } else {
158       bw_projection_weights_ = AddNullInput();
159       bw_projection_bias_ = AddNullInput();
160     }
161 
162     // Adding the 2 forward input state tensors.
163     fw_input_activation_state_ = AddVariableInput(
164         TensorData{TensorType_FLOAT32, {n_fw_output_ * n_batch_}});
165     fw_input_cell_state_ = AddVariableInput(
166         TensorData{TensorType_FLOAT32, {n_fw_cell_ * n_batch_}});
167 
168     // Adding the 2 backward input state tensors.
169     bw_input_activation_state_ = AddVariableInput(
170         TensorData{TensorType_FLOAT32, {n_bw_output_ * n_batch_}});
171     bw_input_cell_state_ = AddVariableInput(
172         TensorData{TensorType_FLOAT32, {n_bw_cell_ * n_batch_}});
173 
174     fw_output_ = AddOutput(TensorType_FLOAT32);
175 
176     if (!merge_outputs) {
177       bw_output_ = AddOutput(TensorType_FLOAT32);
178     }
179 
180     if (use_aux_input) {
181       aux_input_ = AddInput(TensorType_FLOAT32);
182       fw_aux_input_to_input_weights_ = AddInput(weight_type);
183       fw_aux_input_to_forget_weights_ = AddInput(weight_type);
184       fw_aux_input_to_cell_weights_ = AddInput(weight_type);
185       fw_aux_input_to_output_weights_ = AddInput(weight_type);
186       bw_aux_input_to_input_weights_ = AddInput(weight_type);
187       bw_aux_input_to_forget_weights_ = AddInput(weight_type);
188       bw_aux_input_to_cell_weights_ = AddInput(weight_type);
189       bw_aux_input_to_output_weights_ = AddInput(weight_type);
190     } else {
191       aux_input_ = AddNullInput();
192       fw_aux_input_to_input_weights_ = AddNullInput();
193       fw_aux_input_to_forget_weights_ = AddNullInput();
194       fw_aux_input_to_cell_weights_ = AddNullInput();
195       fw_aux_input_to_output_weights_ = AddNullInput();
196       bw_aux_input_to_input_weights_ = AddNullInput();
197       bw_aux_input_to_forget_weights_ = AddNullInput();
198       bw_aux_input_to_cell_weights_ = AddNullInput();
199       bw_aux_input_to_output_weights_ = AddNullInput();
200     }
201 
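    // Register the op with TANH activation and the clipping/merge/time-major
    // options passed in, then build the interpreter with the given input shapes.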
202     SetBuiltinOp(
203         BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
204         BuiltinOptions_BidirectionalSequenceLSTMOptions,
205         CreateBidirectionalSequenceLSTMOptions(
206             builder_, ActivationFunctionType_TANH, cell_clip, proj_clip,
207             merge_outputs, time_major, asymmetric_quantize_inputs)
208             .Union());
209     BuildInterpreter(input_shapes);
210   }
211 
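  // Populates a weight tensor, symmetrically quantizing the values first when
  // the model was built with quantize_weights set.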
212   void PopulateWeightTensor(int tensor_id, const std::vector<float>& f) {
213     if (quantize_weights_) {
214       SymmetricQuantizeAndPopulate(tensor_id, f);
215     } else {
216       PopulateTensor(tensor_id, f);
217     }
218   }
219 
220   // Set weights in forward and backward cells to be the same.
221   void SetInputToInputWeights(const std::vector<float>& f) {
222     PopulateWeightTensor(fw_input_to_input_weights_, f);
223     PopulateWeightTensor(bw_input_to_input_weights_, f);
224   }
225 
226   void SetInputToForgetWeights(const std::vector<float>& f) {
227     PopulateWeightTensor(fw_input_to_forget_weights_, f);
228     PopulateWeightTensor(bw_input_to_forget_weights_, f);
229   }
230 
231   void SetInputToCellWeights(const std::vector<float>& f) {
232     PopulateWeightTensor(fw_input_to_cell_weights_, f);
233     PopulateWeightTensor(bw_input_to_cell_weights_, f);
234   }
235 
236   void SetInputToOutputWeights(const std::vector<float>& f) {
237     PopulateWeightTensor(fw_input_to_output_weights_, f);
238     PopulateWeightTensor(bw_input_to_output_weights_, f);
239   }
240 
241   void SetRecurrentToInputWeights(const std::vector<float>& f) {
242     PopulateWeightTensor(fw_recurrent_to_input_weights_, f);
243     PopulateWeightTensor(bw_recurrent_to_input_weights_, f);
244   }
245 
246   void SetRecurrentToForgetWeights(const std::vector<float>& f) {
247     PopulateWeightTensor(fw_recurrent_to_forget_weights_, f);
248     PopulateWeightTensor(bw_recurrent_to_forget_weights_, f);
249   }
250 
251   void SetRecurrentToCellWeights(const std::vector<float>& f) {
252     PopulateWeightTensor(fw_recurrent_to_cell_weights_, f);
253     PopulateWeightTensor(bw_recurrent_to_cell_weights_, f);
254   }
255 
256   void SetRecurrentToOutputWeights(const std::vector<float>& f) {
257     PopulateWeightTensor(fw_recurrent_to_output_weights_, f);
258     PopulateWeightTensor(bw_recurrent_to_output_weights_, f);
259   }
260 
261   void SetCellToInputWeights(const std::vector<float>& f) {
262     PopulateWeightTensor(fw_cell_to_input_weights_, f);
263     PopulateWeightTensor(bw_cell_to_input_weights_, f);
264   }
265 
266   void SetCellToForgetWeights(const std::vector<float>& f) {
267     PopulateWeightTensor(fw_cell_to_forget_weights_, f);
268     PopulateWeightTensor(bw_cell_to_forget_weights_, f);
269   }
270 
271   void SetCellToOutputWeights(const std::vector<float>& f) {
272     PopulateWeightTensor(fw_cell_to_output_weights_, f);
273     PopulateWeightTensor(bw_cell_to_output_weights_, f);
274   }
275 
276   void SetInputGateBias(const std::vector<float>& f) {
277     PopulateTensor(fw_input_gate_bias_, f);
278     PopulateTensor(bw_input_gate_bias_, f);
279   }
280 
281   void SetForgetGateBias(const std::vector<float>& f) {
282     PopulateTensor(fw_forget_gate_bias_, f);
283     PopulateTensor(bw_forget_gate_bias_, f);
284   }
285 
286   void SetCellBias(const std::vector<float>& f) {
287     PopulateTensor(fw_cell_gate_bias_, f);
288     PopulateTensor(bw_cell_gate_bias_, f);
289   }
290 
291   void SetOutputGateBias(const std::vector<float>& f) {
292     PopulateTensor(fw_output_gate_bias_, f);
293     PopulateTensor(bw_output_gate_bias_, f);
294   }
295 
296   void SetProjectionWeights(const std::vector<float>& f) {
297     PopulateWeightTensor(fw_projection_weights_, f);
298     PopulateWeightTensor(bw_projection_weights_, f);
299   }
300 
301   void SetProjectionBias(const std::vector<float>& f) {
302     PopulateTensor(fw_projection_bias_, f);
303     PopulateTensor(bw_projection_bias_, f);
304   }
305 
306   void SetInput(int offset, float* begin, float* end) {
307     PopulateTensor(input_, offset, begin, end);
308   }
309 
310   void SetAuxInput(int offset, float* begin, float* end) {
311     PopulateTensor(aux_input_, offset, begin, end);
312   }
313 
314   void SetAuxInputToInputWeights(const std::vector<float>& f) {
315     PopulateWeightTensor(fw_aux_input_to_input_weights_, f);
316     PopulateWeightTensor(bw_aux_input_to_input_weights_, f);
317   }
318 
319   void SetAuxInputToForgetWeights(const std::vector<float>& f) {
320     PopulateWeightTensor(fw_aux_input_to_forget_weights_, f);
321     PopulateWeightTensor(bw_aux_input_to_forget_weights_, f);
322   }
323 
324   void SetAuxInputToCellWeights(const std::vector<float>& f) {
325     PopulateWeightTensor(fw_aux_input_to_cell_weights_, f);
326     PopulateWeightTensor(bw_aux_input_to_cell_weights_, f);
327   }
328 
329   void SetAuxInputToOutputWeights(const std::vector<float>& f) {
330     PopulateWeightTensor(fw_aux_input_to_output_weights_, f);
331     PopulateWeightTensor(bw_aux_input_to_output_weights_, f);
332   }
333 
334   std::vector<float> GetFwOutput() { return ExtractVector<float>(fw_output_); }
335   std::vector<float> GetBwOutput() { return ExtractVector<float>(bw_output_); }
336 
337   int num_inputs() { return n_input_; }
338   int num_fw_outputs() { return n_fw_output_; }
339   int num_bw_outputs() { return n_bw_output_; }
340   int num_fw_cells() { return n_fw_cell_; }
341   int num_bw_cells() { return n_bw_cell_; }
342   int num_batches() { return n_batch_; }
343   int sequence_length() { return sequence_length_; }
344 
345  private:
346   int input_;
347   int fw_input_to_input_weights_;
348   int fw_input_to_forget_weights_;
349   int fw_input_to_cell_weights_;
350   int fw_input_to_output_weights_;
351 
352   int fw_recurrent_to_input_weights_;
353   int fw_recurrent_to_forget_weights_;
354   int fw_recurrent_to_cell_weights_;
355   int fw_recurrent_to_output_weights_;
356 
357   int fw_cell_to_input_weights_;
358   int fw_cell_to_forget_weights_;
359   int fw_cell_to_output_weights_;
360 
361   int fw_input_gate_bias_;
362   int fw_forget_gate_bias_;
363   int fw_cell_gate_bias_;
364   int fw_output_gate_bias_;
365 
366   int fw_projection_weights_;
367   int fw_projection_bias_;
368 
369   int bw_input_to_input_weights_;
370   int bw_input_to_forget_weights_;
371   int bw_input_to_cell_weights_;
372   int bw_input_to_output_weights_;
373 
374   int bw_recurrent_to_input_weights_;
375   int bw_recurrent_to_forget_weights_;
376   int bw_recurrent_to_cell_weights_;
377   int bw_recurrent_to_output_weights_;
378 
379   int bw_cell_to_input_weights_;
380   int bw_cell_to_forget_weights_;
381   int bw_cell_to_output_weights_;
382 
383   int bw_input_gate_bias_;
384   int bw_forget_gate_bias_;
385   int bw_cell_gate_bias_;
386   int bw_output_gate_bias_;
387 
388   int bw_projection_weights_;
389   int bw_projection_bias_;
390 
391   int fw_input_activation_state_;
392   int fw_input_cell_state_;
393   int bw_input_activation_state_;
394   int bw_input_cell_state_;
395 
396   int fw_output_;
397   int bw_output_;
398 
399   int aux_input_;
400   int fw_aux_input_to_input_weights_;
401   int fw_aux_input_to_forget_weights_;
402   int fw_aux_input_to_cell_weights_;
403   int fw_aux_input_to_output_weights_;
404   int bw_aux_input_to_input_weights_;
405   int bw_aux_input_to_forget_weights_;
406   int bw_aux_input_to_cell_weights_;
407   int bw_aux_input_to_output_weights_;
408 
409   int n_batch_;
410   int n_input_;
411   int n_fw_cell_;
412   int n_bw_cell_;
413   int n_fw_output_;
414   int n_bw_output_;
415   int sequence_length_;
416 
417   bool quantize_weights_;
418 };
419 
420 // Declare LSTMOpTest as a parameterized test.
421 class LSTMOpTest
422     : public ::testing::TestWithParam<::testing::tuple<bool, bool>> {};
423 
424 INSTANTIATE_TEST_SUITE_P(QuantizationOrNot, LSTMOpTest,
425                          ::testing::Combine(
426                              /*quantize_weights*/ ::testing::Bool(),
427                              /*asymmetric_quantize_inputs*/ ::testing::Bool()));
428 
429 TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
430   const int n_batch = 1;
431   const int n_input = 2;
432   // n_cell and n_output have the same size when there is no projection.
433   const int n_cell = 4;
434   const int n_output = 4;
435   const int sequence_length = 3;
436   auto params = GetParam();
437   const bool quantize_weights = std::get<0>(params);
438   const bool asymmetric_quantize_inputs = std::get<1>(params);
439 
440   BidirectionalLSTMOpModel lstm(
441       n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
442       /*use_peephole=*/false, /*use_projection_weights=*/false,
443       /*use_projection_bias=*/false, /*merge_outputs=*/false,
444       /*use_aux_input=*/false, /*cell_clip=*/0.0,
445       /*proj_clip=*/0.0, quantize_weights, /*time_major=*/true,
446       {
447           {sequence_length, n_batch, n_input},  // input tensor
448 
449           // Forward cell
450           {n_cell, n_input},  // input_to_input_weight tensor
451           {n_cell, n_input},  // input_to_forget_weight tensor
452           {n_cell, n_input},  // input_to_cell_weight tensor
453           {n_cell, n_input},  // input_to_output_weight tensor
454 
455           {n_cell, n_output},  // recurrent_to_input_weight tensor
456           {n_cell, n_output},  // recurrent_to_forget_weight tensor
457           {n_cell, n_output},  // recurrent_to_cell_weight tensor
458           {n_cell, n_output},  // recurrent_to_output_weight tensor
459 
460           {0},  // cell_to_input_weight tensor
461           {0},  // cell_to_forget_weight tensor
462           {0},  // cell_to_output_weight tensor
463 
464           {n_cell},  // input_gate_bias tensor
465           {n_cell},  // forget_gate_bias tensor
466           {n_cell},  // cell_gate_bias tensor
467           {n_cell},  // output_gate_bias tensor
468 
469           {0, 0},  // projection_weight tensor
470           {0},     // projection_bias tensor
471 
472           // Backward cell
473           {n_cell, n_input},  // input_to_input_weight tensor
474           {n_cell, n_input},  // input_to_forget_weight tensor
475           {n_cell, n_input},  // input_to_cell_weight tensor
476           {n_cell, n_input},  // input_to_output_weight tensor
477 
478           {n_cell, n_output},  // recurrent_to_input_weight tensor
479           {n_cell, n_output},  // recurrent_to_forget_weight tensor
480           {n_cell, n_output},  // recurrent_to_cell_weight tensor
481           {n_cell, n_output},  // recurrent_to_output_weight tensor
482 
483           {0},  // cell_to_input_weight tensor
484           {0},  // cell_to_forget_weight tensor
485           {0},  // cell_to_output_weight tensor
486 
487           {n_cell},  // input_gate_bias tensor
488           {n_cell},  // forget_gate_bias tensor
489           {n_cell},  // cell_gate_bias tensor
490           {n_cell},  // output_gate_bias tensor
491 
492           {0, 0},  // projection_weight tensor
493           {0},     // projection_bias tensor
494 
495           {n_batch, n_output},  // activation_state tensor
496           {n_batch, n_cell},    // cell_state tensor
497 
498           {n_batch, n_output},  // activation_state tensor
499           {n_batch, n_cell},    // cell_state tensor
500 
501           {sequence_length, n_batch, 0},  // aux_input tensor
502           {0},                            // aux_fw_input_to_input tensor
503           {0},                            // aux_fw_input_to_forget tensor
504           {0},                            // aux_fw_input_to_cell tensor
505           {0},                            // aux_fw_input_to_output tensor
506           {0},                            // aux_bw_input_to_input tensor
507           {0},                            // aux_bw_input_to_forget tensor
508           {0},                            // aux_bw_input_to_cell tensor
509           {0},                            // aux_bw_input_to_output tensor
510       },
511       asymmetric_quantize_inputs);
512 
513   lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
514                                -0.34550029, 0.04266912, -0.15680569,
515                                -0.34856534, 0.43890524});
516 
517   lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
518                               -0.20583314, 0.44344562, 0.22077113,
519                               -0.29909778});
520 
521   lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
522                                 -0.31343272, -0.40032279, 0.44781327,
523                                 0.01387155, -0.35593212});
524 
525   lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
526                                 0.40525138, 0.44272184, 0.03897077, -0.1556896,
527                                 0.19487578});
528 
529   lstm.SetInputGateBias({0., 0., 0., 0.});
530 
531   lstm.SetCellBias({0., 0., 0., 0.});
532 
533   lstm.SetForgetGateBias({1., 1., 1., 1.});
534 
535   lstm.SetOutputGateBias({0., 0., 0., 0.});
536 
537   lstm.SetRecurrentToInputWeights(
538       {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
539        -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
540        -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
541 
542   lstm.SetRecurrentToCellWeights(
543       {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
544        -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
545        -0.46367589, 0.26016325, -0.03894562, -0.16368064});
546 
547   lstm.SetRecurrentToForgetWeights(
548       {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
549        -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
550        0.28053468, 0.01560611, -0.20127171, -0.01140004});
551 
552   lstm.SetRecurrentToOutputWeights(
553       {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
554        0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
555        -0.51818722, -0.15390486, 0.0468148, 0.39922136});
556 
557   // Input should have n_input * sequence_length many values.
558   static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
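  // Golden outputs are laid out time-major: sequence_length * n_batch * n_output
  // values per direction.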
559   static float lstm_fw_golden_output[] = {
560       -0.02973187, 0.1229473,  0.20885126, -0.15358765,
561       -0.03716109, 0.12507336, 0.41193449, -0.20860538,
562       -0.15053082, 0.09120187, 0.24278517, -0.12222792};
563   static float lstm_bw_golden_output[] = {
564       -0.0806187, 0.139077, 0.400476,   -0.197842, -0.0332076, 0.123838,
565       0.309777,   -0.17621, -0.0490733, 0.0739237, 0.067706,   -0.0208124};
566 
567   float* batch0_start = lstm_input;
568   float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
569 
570   lstm.SetInput(0, batch0_start, batch0_end);
571 
572   ASSERT_EQ(lstm.Invoke(), kTfLiteOk);
573 
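  // The hybrid (quantized-weight) path is compared with a looser tolerance than
  // the float path.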
574   float* fw_golden_start = lstm_fw_golden_output;
575   float* fw_golden_end =
576       fw_golden_start + lstm.num_fw_outputs() * lstm.sequence_length();
577   std::vector<float> fw_expected;
578   fw_expected.insert(fw_expected.end(), fw_golden_start, fw_golden_end);
579   EXPECT_THAT(lstm.GetFwOutput(),
580               ElementsAreArray(
581                   ArrayFloatNear(fw_expected, quantize_weights ? 1e-2 : 1e-5)));
582 
583   float* bw_golden_start = lstm_bw_golden_output;
584   float* bw_golden_end =
585       bw_golden_start + lstm.num_bw_outputs() * lstm.sequence_length();
586   std::vector<float> bw_expected;
587   bw_expected.insert(bw_expected.end(), bw_golden_start, bw_golden_end);
588   EXPECT_THAT(lstm.GetBwOutput(),
589               ElementsAreArray(
590                   ArrayFloatNear(bw_expected, quantize_weights ? 1e-2 : 1e-5)));
591 }
592 
593 // Same as the previous test, but with a single merged output tensor and an
594 // n_batch of 2.
595 TEST_P(LSTMOpTest, BlackBoxTestMergedOutput) {
596   const int n_batch = 2;
597   const int n_input = 2;
598   // n_cell and n_output have the same size when there is no projection.
599   const int n_cell = 4;
600   const int n_output = 4;
601   const int sequence_length = 3;
602   auto params = GetParam();
603   const bool quantize_weights = std::get<0>(params);
604   const bool asymmetric_quantize_inputs = std::get<1>(params);
605 
606   BidirectionalLSTMOpModel lstm(
607       n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
608       /*use_peephole=*/false, /*use_projection_weights=*/false,
609       /*use_projection_bias=*/false, /*merge_outputs=*/true,
610       /*use_aux_input=*/false, /*cell_clip=*/0.0,
611       /*proj_clip=*/0.0, quantize_weights, /*time_major=*/true,
612       {
613           {sequence_length, n_batch, n_input},  // input tensor
614 
615           // Forward cell
616           {n_cell, n_input},  // input_to_input_weight tensor
617           {n_cell, n_input},  // input_to_forget_weight tensor
618           {n_cell, n_input},  // input_to_cell_weight tensor
619           {n_cell, n_input},  // input_to_output_weight tensor
620 
621           {n_cell, n_output},  // recurrent_to_input_weight tensor
622           {n_cell, n_output},  // recurrent_to_forget_weight tensor
623           {n_cell, n_output},  // recurrent_to_cell_weight tensor
624           {n_cell, n_output},  // recurrent_to_output_weight tensor
625 
626           {0},  // cell_to_input_weight tensor
627           {0},  // cell_to_forget_weight tensor
628           {0},  // cell_to_output_weight tensor
629 
630           {n_cell},  // input_gate_bias tensor
631           {n_cell},  // forget_gate_bias tensor
632           {n_cell},  // cell_gate_bias tensor
633           {n_cell},  // output_gate_bias tensor
634 
635           {0, 0},  // projection_weight tensor
636           {0},     // projection_bias tensor
637 
638           // Backward cell
639           {n_cell, n_input},  // input_to_input_weight tensor
640           {n_cell, n_input},  // input_to_forget_weight tensor
641           {n_cell, n_input},  // input_to_cell_weight tensor
642           {n_cell, n_input},  // input_to_output_weight tensor
643 
644           {n_cell, n_output},  // recurrent_to_input_weight tensor
645           {n_cell, n_output},  // recurrent_to_forget_weight tensor
646           {n_cell, n_output},  // recurrent_to_cell_weight tensor
647           {n_cell, n_output},  // recurrent_to_output_weight tensor
648 
649           {0},  // cell_to_input_weight tensor
650           {0},  // cell_to_forget_weight tensor
651           {0},  // cell_to_output_weight tensor
652 
653           {n_cell},  // input_gate_bias tensor
654           {n_cell},  // forget_gate_bias tensor
655           {n_cell},  // cell_gate_bias tensor
656           {n_cell},  // output_gate_bias tensor
657 
658           {0, 0},  // projection_weight tensor
659           {0},     // projection_bias tensor
660 
661           {n_batch, n_output},  // activation_state tensor
662           {n_batch, n_cell},    // cell_state tensor
663 
664           {n_batch, n_output},  // activation_state tensor
665           {n_batch, n_cell},    // cell_state tensor
666 
667           {sequence_length, n_batch, 0},  // aux_input tensor
668           {0},                            // aux_fw_input_to_input tensor
669           {0},                            // aux_fw_input_to_forget tensor
670           {0},                            // aux_fw_input_to_cell tensor
671           {0},                            // aux_fw_input_to_output tensor
672           {0},                            // aux_bw_input_to_input tensor
673           {0},                            // aux_bw_input_to_forget tensor
674           {0},                            // aux_bw_input_to_cell tensor
675           {0},                            // aux_bw_input_to_output tensor
676       },
677       asymmetric_quantize_inputs);
678 
679   lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
680                                -0.34550029, 0.04266912, -0.15680569,
681                                -0.34856534, 0.43890524});
682 
683   lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
684                               -0.20583314, 0.44344562, 0.22077113,
685                               -0.29909778});
686 
687   lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
688                                 -0.31343272, -0.40032279, 0.44781327,
689                                 0.01387155, -0.35593212});
690 
691   lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
692                                 0.40525138, 0.44272184, 0.03897077, -0.1556896,
693                                 0.19487578});
694 
695   lstm.SetInputGateBias({0., 0., 0., 0.});
696 
697   lstm.SetCellBias({0., 0., 0., 0.});
698 
699   lstm.SetForgetGateBias({1., 1., 1., 1.});
700 
701   lstm.SetOutputGateBias({0., 0., 0., 0.});
702 
703   lstm.SetRecurrentToInputWeights(
704       {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
705        -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
706        -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
707 
708   lstm.SetRecurrentToCellWeights(
709       {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
710        -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
711        -0.46367589, 0.26016325, -0.03894562, -0.16368064});
712 
713   lstm.SetRecurrentToForgetWeights(
714       {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
715        -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
716        0.28053468, 0.01560611, -0.20127171, -0.01140004});
717 
718   lstm.SetRecurrentToOutputWeights(
719       {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
720        0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
721        -0.51818722, -0.15390486, 0.0468148, 0.39922136});
722 
723   // Input should have n_input * n_batch * sequence_length many values.
724   static float lstm_input[] = {2., 3., 2., 3., 3., 4., 3., 4., 1., 1., 1., 1.};
725   static float lstm_fw_golden_output[] = {
726       -0.02973187, 0.1229473,   0.20885126,  -0.15358765, -0.02973187,
727       0.1229473,   0.20885126,  -0.15358765, -0.03716109, 0.12507336,
728       0.41193449,  -0.20860538, -0.03716109, 0.12507336,  0.41193449,
729       -0.20860538, -0.15053082, 0.09120187,  0.24278517,  -0.12222792,
730       -0.15053082, 0.09120187,  0.24278517,  -0.12222792};
731   static float lstm_bw_golden_output[] = {
732       -0.0806187, 0.139077,   0.400476,   -0.197842, -0.0806187, 0.139077,
733       0.400476,   -0.197842,  -0.0332076, 0.123838,  0.309777,   -0.17621,
734       -0.0332076, 0.123838,   0.309777,   -0.17621,  -0.0490733, 0.0739237,
735       0.067706,   -0.0208124, -0.0490733, 0.0739237, 0.067706,   -0.0208124};
736 
737   float* batch0_start = lstm_input;
738   float* batch0_end = batch0_start + lstm.num_inputs() * lstm.num_batches() *
739                                          lstm.sequence_length();
740 
741   lstm.SetInput(0, batch0_start, batch0_end);
742 
743   ASSERT_EQ(lstm.Invoke(), kTfLiteOk);
744 
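  // With merge_outputs=true the single output tensor interleaves the forward and
  // backward activations for each (time, batch) step.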
745   std::vector<float> merged_expected;
746   for (int k = 0; k < lstm.sequence_length() * lstm.num_batches(); k++) {
747     merged_expected.insert(
748         merged_expected.end(),
749         lstm_fw_golden_output + k * lstm.num_fw_outputs(),
750         lstm_fw_golden_output + (k + 1) * lstm.num_fw_outputs());
751     merged_expected.insert(
752         merged_expected.end(),
753         lstm_bw_golden_output + k * lstm.num_bw_outputs(),
754         lstm_bw_golden_output + (k + 1) * lstm.num_bw_outputs());
755   }
756   EXPECT_THAT(lstm.GetFwOutput(),
757               ElementsAreArray(ArrayFloatNear(merged_expected,
758                                               quantize_weights ? 1e-2 : 1e-5)));
759 }
760 
761 TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClippingReverse) {
762   const int n_batch = 1;
763   const int n_input = 2;
764   // n_cell and n_output have the same size when there is no projection.
765   const int n_cell = 4;
766   const int n_output = 4;
767   const int sequence_length = 3;
768 
769   BidirectionalLSTMOpModel lstm(
770       n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
771       /*use_peephole=*/false, /*use_projection_weights=*/false,
772       /*use_projection_bias=*/false, /*merge_outputs=*/false,
773       /*use_aux_input=*/false, /*cell_clip=*/0.0,
774       /*proj_clip=*/0.0, /*quantize_weights=*/false, /*time_major=*/true,
775       {
776           {sequence_length, n_batch, n_input},  // input tensor
777 
778           // Forward cell
779           {n_cell, n_input},  // input_to_input_weight tensor
780           {n_cell, n_input},  // input_to_forget_weight tensor
781           {n_cell, n_input},  // input_to_cell_weight tensor
782           {n_cell, n_input},  // input_to_output_weight tensor
783 
784           {n_cell, n_output},  // recurrent_to_input_weight tensor
785           {n_cell, n_output},  // recurrent_to_forget_weight tensor
786           {n_cell, n_output},  // recurrent_to_cell_weight tensor
787           {n_cell, n_output},  // recurrent_to_output_weight tensor
788 
789           {0},  // cell_to_input_weight tensor
790           {0},  // cell_to_forget_weight tensor
791           {0},  // cell_to_output_weight tensor
792 
793           {n_cell},  // input_gate_bias tensor
794           {n_cell},  // forget_gate_bias tensor
795           {n_cell},  // cell_gate_bias tensor
796           {n_cell},  // output_gate_bias tensor
797 
798           {0, 0},  // projection_weight tensor
799           {0},     // projection_bias tensor
800 
801           // Backward cell
802           {n_cell, n_input},  // input_to_input_weight tensor
803           {n_cell, n_input},  // input_to_forget_weight tensor
804           {n_cell, n_input},  // input_to_cell_weight tensor
805           {n_cell, n_input},  // input_to_output_weight tensor
806 
807           {n_cell, n_output},  // recurrent_to_input_weight tensor
808           {n_cell, n_output},  // recurrent_to_forget_weight tensor
809           {n_cell, n_output},  // recurrent_to_cell_weight tensor
810           {n_cell, n_output},  // recurrent_to_output_weight tensor
811 
812           {0},  // cell_to_input_weight tensor
813           {0},  // cell_to_forget_weight tensor
814           {0},  // cell_to_output_weight tensor
815 
816           {n_cell},  // input_gate_bias tensor
817           {n_cell},  // forget_gate_bias tensor
818           {n_cell},  // cell_gate_bias tensor
819           {n_cell},  // output_gate_bias tensor
820 
821           {0, 0},  // projection_weight tensor
822           {0},     // projection_bias tensor
823 
824           {n_batch, n_output},  // activation_state tensor
825           {n_batch, n_cell},    // cell_state tensor
826 
827           {n_batch, n_output},  // activation_state tensor
828           {n_batch, n_cell},    // cell_state tensor
829 
830           {sequence_length, n_batch, 0},  // aux_input tensor
831           {0},                            // aux_fw_input_to_input tensor
832           {0},                            // aux_fw_input_to_forget tensor
833           {0},                            // aux_fw_input_to_cell tensor
834           {0},                            // aux_fw_input_to_output tensor
835           {0},                            // aux_bw_input_to_input tensor
836           {0},                            // aux_bw_input_to_forget tensor
837           {0},                            // aux_bw_input_to_cell tensor
838           {0},                            // aux_bw_input_to_output tensor
839       });
840 
841   lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
842                                -0.34550029, 0.04266912, -0.15680569,
843                                -0.34856534, 0.43890524});
844 
845   lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
846                               -0.20583314, 0.44344562, 0.22077113,
847                               -0.29909778});
848 
849   lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
850                                 -0.31343272, -0.40032279, 0.44781327,
851                                 0.01387155, -0.35593212});
852 
853   lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
854                                 0.40525138, 0.44272184, 0.03897077, -0.1556896,
855                                 0.19487578});
856 
857   lstm.SetInputGateBias({0., 0., 0., 0.});
858 
859   lstm.SetCellBias({0., 0., 0., 0.});
860 
861   lstm.SetForgetGateBias({1., 1., 1., 1.});
862 
863   lstm.SetOutputGateBias({0., 0., 0., 0.});
864 
865   lstm.SetRecurrentToInputWeights(
866       {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
867        -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
868        -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
869 
870   lstm.SetRecurrentToCellWeights(
871       {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
872        -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
873        -0.46367589, 0.26016325, -0.03894562, -0.16368064});
874 
875   lstm.SetRecurrentToForgetWeights(
876       {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
877        -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
878        0.28053468, 0.01560611, -0.20127171, -0.01140004});
879 
880   lstm.SetRecurrentToOutputWeights(
881       {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
882        0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
883        -0.51818722, -0.15390486, 0.0468148, 0.39922136});
884 
885   // Input should have n_input * sequence_length many values.
886   // Check reversed inputs.
887   static float lstm_input_reversed[] = {1., 1., 3., 4., 2., 3.};
888   static float lstm_fw_golden_output[] = {
889       -0.02973187, 0.1229473,  0.20885126, -0.15358765,
890       -0.03716109, 0.12507336, 0.41193449, -0.20860538,
891       -0.15053082, 0.09120187, 0.24278517, -0.12222792};
892   static float lstm_bw_golden_output[] = {
893       -0.0806187, 0.139077, 0.400476,   -0.197842, -0.0332076, 0.123838,
894       0.309777,   -0.17621, -0.0490733, 0.0739237, 0.067706,   -0.0208124};
895 
896   float* batch0_start = lstm_input_reversed;
897   float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
898 
899   lstm.SetInput(0, batch0_start, batch0_end);
900 
901   ASSERT_EQ(lstm.Invoke(), kTfLiteOk);
902 
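  // With the input reversed in time, the forward golden sequence should show up
  // (time-reversed) on the backward output and vice versa.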
903   std::vector<float> fw_expected;
904   for (int s = 0; s < lstm.sequence_length(); s++) {
905     float* fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs();
906     float* fw_golden_end = fw_golden_start + lstm.num_fw_outputs();
907     fw_expected.insert(fw_expected.begin(), fw_golden_start, fw_golden_end);
908   }
909   EXPECT_THAT(lstm.GetBwOutput(),
910               ElementsAreArray(ArrayFloatNear(fw_expected)));
911 
912   std::vector<float> bw_expected;
913   for (int s = 0; s < lstm.sequence_length(); s++) {
914     float* bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs();
915     float* bw_golden_end = bw_golden_start + lstm.num_bw_outputs();
916     bw_expected.insert(bw_expected.begin(), bw_golden_start, bw_golden_end);
917   }
918   EXPECT_THAT(lstm.GetFwOutput(),
919               ElementsAreArray(ArrayFloatNear(bw_expected)));
920 }
921 
922 TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
923   const int n_batch = 1;
924   const int n_input = 2;
925   // n_cell and n_output have the same size when there is no projection.
926   const int n_cell = 4;
927   const int n_output = 4;
928   const int sequence_length = 3;
929 
930   BidirectionalLSTMOpModel lstm(
931       n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true,
932       /*use_peephole=*/true, /*use_projection_weights=*/false,
933       /*use_projection_bias=*/false, /*merge_outputs=*/false,
934       /*use_aux_input=*/false, /*cell_clip=*/0.0,
935       /*proj_clip=*/0.0, /*quantize_weights=*/false, /*time_major=*/true,
936       {
937           {sequence_length, n_batch, n_input},  // input tensor
938 
939           {0, 0},             // input_to_input_weight tensor
940           {n_cell, n_input},  // input_to_forget_weight tensor
941           {n_cell, n_input},  // input_to_cell_weight tensor
942           {n_cell, n_input},  // input_to_output_weight tensor
943 
944           {0, 0},              // recurrent_to_input_weight tensor
945           {n_cell, n_output},  // recurrent_to_forget_weight tensor
946           {n_cell, n_output},  // recurrent_to_cell_weight tensor
947           {n_cell, n_output},  // recurrent_to_output_weight tensor
948 
949           {0},       // cell_to_input_weight tensor
950           {n_cell},  // cell_to_forget_weight tensor
951           {n_cell},  // cell_to_output_weight tensor
952 
953           {0},       // input_gate_bias tensor
954           {n_cell},  // forget_gate_bias tensor
955           {n_cell},  // cell_gate_bias tensor
956           {n_cell},  // output_gate_bias tensor
957 
958           {0, 0},  // projection_weight tensor
959           {0},     // projection_bias tensor
960 
961           {0, 0},             // input_to_input_weight tensor
962           {n_cell, n_input},  // input_to_forget_weight tensor
963           {n_cell, n_input},  // input_to_cell_weight tensor
964           {n_cell, n_input},  // input_to_output_weight tensor
965 
966           {0, 0},              // recurrent_to_input_weight tensor
967           {n_cell, n_output},  // recurrent_to_forget_weight tensor
968           {n_cell, n_output},  // recurrent_to_cell_weight tensor
969           {n_cell, n_output},  // recurrent_to_output_weight tensor
970 
971           {0},       // cell_to_input_weight tensor
972           {n_cell},  // cell_to_forget_weight tensor
973           {n_cell},  // cell_to_output_weight tensor
974 
975           {0},       // input_gate_bias tensor
976           {n_cell},  // forget_gate_bias tensor
977           {n_cell},  // cell_gate_bias tensor
978           {n_cell},  // output_gate_bias tensor
979 
980           {0, 0},  // projection_weight tensor
981           {0},     // projection_bias tensor
982 
983           {n_batch, n_output},  // activation_state tensor
984           {n_batch, n_cell},    // cell_state tensor
985 
986           {n_batch, n_output},  // activation_state tensor
987           {n_batch, n_cell},    // cell_state tensor
988 
989           {sequence_length, n_batch, 0},  // aux_input tensor
990           {0},                            // aux_fw_input_to_input tensor
991           {0},                            // aux_fw_input_to_forget tensor
992           {0},                            // aux_fw_input_to_cell tensor
993           {0},                            // aux_fw_input_to_output tensor
994           {0},                            // aux_bw_input_to_input tensor
995           {0},                            // aux_bw_input_to_forget tensor
996           {0},                            // aux_bw_input_to_cell tensor
997           {0},                            // aux_bw_input_to_output tensor
998       });
999 
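  // With CIFG the input gate is coupled to the forget gate, so no
  // input-to-input weights, recurrent-to-input weights, or input gate bias are set.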
1000   lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
1001                               0.04717243, 0.48944736, -0.38535351,
1002                               -0.17212132});
1003 
1004   lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988,
1005                                 -0.3633365, -0.22755712, 0.28253698, 0.24407166,
1006                                 0.33826375});
1007 
1008   lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593,
1009                                 -0.09426838, -0.44257352, 0.54939759,
1010                                 0.01533556, 0.42751634});
1011 
1012   lstm.SetCellBias({0., 0., 0., 0.});
1013 
1014   lstm.SetForgetGateBias({1., 1., 1., 1.});
1015 
1016   lstm.SetOutputGateBias({0., 0., 0., 0.});
1017 
1018   lstm.SetRecurrentToCellWeights(
1019       {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711,
1020        0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004,
1021        0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288,
1022        0.21193194});
1023 
1024   lstm.SetRecurrentToForgetWeights(
1025       {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827,
1026        0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795,
1027        -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349});
1028 
1029   lstm.SetRecurrentToOutputWeights(
1030       {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908,
1031        -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835,
1032        0.50248802, 0.26114327, -0.43736315, 0.33149987});
1033 
1034   lstm.SetCellToForgetWeights(
1035       {0.47485286, -0.51955009, -0.24458408, 0.31544167});
1036   lstm.SetCellToOutputWeights(
1037       {-0.17135078, 0.82760304, 0.85573703, -0.77109635});
1038 
1039   static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
1040   static float lstm_fw_golden_output[] = {
1041       -0.36444446, -0.00352185, 0.12886585, -0.05163646,
1042       -0.42312205, -0.01218222, 0.24201041, -0.08124574,
1043       -0.358325,   -0.04621704, 0.21641694, -0.06471302};
1044   static float lstm_bw_golden_output[] = {
1045       -0.401685, -0.0232794, 0.288642,  -0.123074,   -0.42915,  -0.00871577,
1046       0.20912,   -0.103567,  -0.166398, -0.00486649, 0.0697471, -0.0537578};
1047 
1048   float* batch0_start = lstm_input;
1049   float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
1050 
1051   lstm.SetInput(0, batch0_start, batch0_end);
1052 
1053   ASSERT_EQ(lstm.Invoke(), kTfLiteOk);
1054 
1055   float* fw_golden_start = lstm_fw_golden_output;
1056   float* fw_golden_end =
1057       fw_golden_start + lstm.num_fw_outputs() * lstm.sequence_length();
1058   std::vector<float> fw_expected;
1059   fw_expected.insert(fw_expected.end(), fw_golden_start, fw_golden_end);
1060   EXPECT_THAT(lstm.GetFwOutput(),
1061               ElementsAreArray(ArrayFloatNear(fw_expected)));
1062 
1063   float* bw_golden_start = lstm_bw_golden_output;
1064   float* bw_golden_end =
1065       bw_golden_start + lstm.num_bw_outputs() * lstm.sequence_length();
1066   std::vector<float> bw_expected;
1067   bw_expected.insert(bw_expected.end(), bw_golden_start, bw_golden_end);
1068   EXPECT_THAT(lstm.GetBwOutput(),
1069               ElementsAreArray(ArrayFloatNear(bw_expected)));
1070 }
1071 
1072 TEST(LSTMOpTest,
1073      BlackBoxTestWithCifgWithPeepholeNoProjectionNoClippingReversed) {
1074   const int n_batch = 1;
1075   const int n_input = 2;
1076   // n_cell and n_output have the same size when there is no projection.
1077   const int n_cell = 4;
1078   const int n_output = 4;
1079   const int sequence_length = 3;
1080 
1081   BidirectionalLSTMOpModel lstm(
1082       n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true,
1083       /*use_peephole=*/true, /*use_projection_weights=*/false,
1084       /*use_projection_bias=*/false, /*merge_outputs=*/false,
1085       /*use_aux_input=*/false, /*cell_clip=*/0.0,
1086       /*proj_clip=*/0.0, /*quantize_weights=*/false, /*time_major=*/true,
1087       {
1088           {sequence_length, n_batch, n_input},  // input tensor
1089 
1090           {0, 0},             // input_to_input_weight tensor
1091           {n_cell, n_input},  // input_to_forget_weight tensor
1092           {n_cell, n_input},  // input_to_cell_weight tensor
1093           {n_cell, n_input},  // input_to_output_weight tensor
1094 
1095           {0, 0},              // recurrent_to_input_weight tensor
1096           {n_cell, n_output},  // recurrent_to_forget_weight tensor
1097           {n_cell, n_output},  // recurrent_to_cell_weight tensor
1098           {n_cell, n_output},  // recurrent_to_output_weight tensor
1099 
1100           {0},       // cell_to_input_weight tensor
1101           {n_cell},  // cell_to_forget_weight tensor
1102           {n_cell},  // cell_to_output_weight tensor
1103 
1104           {0},       // input_gate_bias tensor
1105           {n_cell},  // forget_gate_bias tensor
1106           {n_cell},  // cell_gate_bias tensor
1107           {n_cell},  // output_gate_bias tensor
1108 
1109           {0, 0},  // projection_weight tensor
1110           {0},     // projection_bias tensor
1111 
1112           {0, 0},             // input_to_input_weight tensor
1113           {n_cell, n_input},  // input_to_forget_weight tensor
1114           {n_cell, n_input},  // input_to_cell_weight tensor
1115           {n_cell, n_input},  // input_to_output_weight tensor
1116 
1117           {0, 0},              // recurrent_to_input_weight tensor
1118           {n_cell, n_output},  // recurrent_to_forget_weight tensor
1119           {n_cell, n_output},  // recurrent_to_cell_weight tensor
1120           {n_cell, n_output},  // recurrent_to_output_weight tensor
1121 
1122           {0},       // cell_to_input_weight tensor
1123           {n_cell},  // cell_to_forget_weight tensor
1124           {n_cell},  // cell_to_output_weight tensor
1125 
1126           {0},       // input_gate_bias tensor
1127           {n_cell},  // forget_gate_bias tensor
1128           {n_cell},  // cell_gate_bias tensor
1129           {n_cell},  // output_gate_bias tensor
1130 
1131           {0, 0},  // projection_weight tensor
1132           {0},     // projection_bias tensor
1133 
1134           {n_batch, n_output},  // activation_state tensor
1135           {n_batch, n_cell},    // cell_state tensor
1136 
1137           {n_batch, n_output},  // activation_state tensor
1138           {n_batch, n_cell},    // cell_state tensor
1139 
1140           {sequence_length, n_batch, 0},  // aux_input tensor
1141           {0},                            // aux_fw_input_to_input tensor
1142           {0},                            // aux_fw_input_to_forget tensor
1143           {0},                            // aux_fw_input_to_cell tensor
1144           {0},                            // aux_fw_input_to_output tensor
1145           {0},                            // aux_bw_input_to_input tensor
1146           {0},                            // aux_bw_input_to_forget tensor
1147           {0},                            // aux_bw_input_to_cell tensor
1148           {0},                            // aux_bw_input_to_output tensor
1149       });
1150 
1151   lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
1152                               0.04717243, 0.48944736, -0.38535351,
1153                               -0.17212132});
1154 
1155   lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988,
1156                                 -0.3633365, -0.22755712, 0.28253698, 0.24407166,
1157                                 0.33826375});
1158 
1159   lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593,
1160                                 -0.09426838, -0.44257352, 0.54939759,
1161                                 0.01533556, 0.42751634});
1162 
1163   lstm.SetCellBias({0., 0., 0., 0.});
1164 
1165   lstm.SetForgetGateBias({1., 1., 1., 1.});
1166 
1167   lstm.SetOutputGateBias({0., 0., 0., 0.});
1168 
1169   lstm.SetRecurrentToCellWeights(
1170       {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711,
1171        0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004,
1172        0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288,
1173        0.21193194});
1174 
1175   lstm.SetRecurrentToForgetWeights(
1176       {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827,
1177        0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795,
1178        -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349});
1179 
1180   lstm.SetRecurrentToOutputWeights(
1181       {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908,
1182        -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835,
1183        0.50248802, 0.26114327, -0.43736315, 0.33149987});
1184 
1185   lstm.SetCellToForgetWeights(
1186       {0.47485286, -0.51955009, -0.24458408, 0.31544167});
1187   lstm.SetCellToOutputWeights(
1188       {-0.17135078, 0.82760304, 0.85573703, -0.77109635});
1189 
1190   static float lstm_input_reversed[] = {1., 1., 3., 4., 2., 3.};
1191   static float lstm_fw_golden_output[] = {
1192       -0.36444446, -0.00352185, 0.12886585, -0.05163646,
1193       -0.42312205, -0.01218222, 0.24201041, -0.08124574,
1194       -0.358325,   -0.04621704, 0.21641694, -0.06471302};
1195   static float lstm_bw_golden_output[] = {
1196       -0.401685, -0.0232794, 0.288642,  -0.123074,   -0.42915,  -0.00871577,
1197       0.20912,   -0.103567,  -0.166398, -0.00486649, 0.0697471, -0.0537578};
1198 
1199   float* batch0_start = lstm_input_reversed;
1200   float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
1201 
1202   lstm.SetInput(0, batch0_start, batch0_end);
1203 
1204   ASSERT_EQ(lstm.Invoke(), kTfLiteOk);
1205 
1206   std::vector<float> fw_expected;
1207   for (int s = 0; s < lstm.sequence_length(); s++) {
1208     float* fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs();
1209     float* fw_golden_end = fw_golden_start + lstm.num_fw_outputs();
1210     fw_expected.insert(fw_expected.begin(), fw_golden_start, fw_golden_end);
1211   }
1212   EXPECT_THAT(lstm.GetBwOutput(),
1213               ElementsAreArray(ArrayFloatNear(fw_expected)));
1214 
1215   std::vector<float> bw_expected;
1216   for (int s = 0; s < lstm.sequence_length(); s++) {
1217     float* bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs();
1218     float* bw_golden_end = bw_golden_start + lstm.num_bw_outputs();
1219     bw_expected.insert(bw_expected.begin(), bw_golden_start, bw_golden_end);
1220   }
1221   EXPECT_THAT(lstm.GetFwOutput(),
1222               ElementsAreArray(ArrayFloatNear(bw_expected)));
1223 }
1224 
1226 TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
1226   const int n_batch = 2;
1227   const int n_input = 5;
1228   const int n_cell = 20;
1229   const int n_output = 16;
1230   const int sequence_length = 4;
1231 
1232   BidirectionalLSTMOpModel lstm(
1233       n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
1234       /*use_peephole=*/true, /*use_projection_weights=*/true,
1235       /*use_projection_bias=*/false, /*merge_outputs=*/false,
1236       /*use_aux_input=*/false, /*cell_clip=*/0.0,
1237       /*proj_clip=*/0.0, /*quantize_weights=*/false, /*time_major=*/true,
1238       {
1239           {sequence_length, n_batch, n_input},  // input tensor
1240 
1241           {n_cell, n_input},  // input_to_input_weight tensor
1242           {n_cell, n_input},  // input_to_forget_weight tensor
1243           {n_cell, n_input},  // input_to_cell_weight tensor
1244           {n_cell, n_input},  // input_to_output_weight tensor
1245 
1246           {n_cell, n_output},  // recurrent_to_input_weight tensor
1247           {n_cell, n_output},  // recurrent_to_forget_weight tensor
1248           {n_cell, n_output},  // recurrent_to_cell_weight tensor
1249           {n_cell, n_output},  // recurrent_to_output_weight tensor
1250 
1251           {n_cell},  // cell_to_input_weight tensor
1252           {n_cell},  // cell_to_forget_weight tensor
1253           {n_cell},  // cell_to_output_weight tensor
1254 
1255           {n_cell},  // input_gate_bias tensor
1256           {n_cell},  // forget_gate_bias tensor
1257           {n_cell},  // cell_gate_bias tensor
1258           {n_cell},  // output_gate_bias tensor
1259 
1260           {n_output, n_cell},  // projection_weight tensor
1261           {0},                 // projection_bias tensor
1262 
1263           {n_cell, n_input},  // input_to_input_weight tensor
1264           {n_cell, n_input},  // input_to_forget_weight tensor
1265           {n_cell, n_input},  // input_to_cell_weight tensor
1266           {n_cell, n_input},  // input_to_output_weight tensor
1267 
1268           {n_cell, n_output},  // recurrent_to_input_weight tensor
1269           {n_cell, n_output},  // recurrent_to_forget_weight tensor
1270           {n_cell, n_output},  // recurrent_to_cell_weight tensor
1271           {n_cell, n_output},  // recurrent_to_output_weight tensor
1272 
1273           {n_cell},  // cell_to_input_weight tensor
1274           {n_cell},  // cell_to_forget_weight tensor
1275           {n_cell},  // cell_to_output_weight tensor
1276 
1277           {n_cell},  // input_gate_bias tensor
1278           {n_cell},  // forget_gate_bias tensor
1279           {n_cell},  // cell_gate_bias tensor
1280           {n_cell},  // output_gate_bias tensor
1281 
1282           {n_output, n_cell},  // projection_weight tensor
1283           {0},                 // projection_bias tensor
1284 
1285           {n_batch, n_output},  // activation_state tensor
1286           {n_batch, n_cell},    // cell_state tensor
1287 
1288           {n_batch, n_output},  // activation_state tensor
1289           {n_batch, n_cell},    // cell_state tensor
1290 
1291           {sequence_length, n_batch, 0},  // aux_input tensor
1292           {0},                            // aux_fw_input_to_input tensor
1293           {0},                            // aux_fw_input_to_forget tensor
1294           {0},                            // aux_fw_input_to_cell tensor
1295           {0},                            // aux_fw_input_to_output tensor
1296           {0},                            // aux_bw_input_to_input tensor
1297           {0},                            // aux_bw_input_to_forget tensor
1298           {0},                            // aux_bw_input_to_cell tensor
1299           {0},                            // aux_bw_input_to_output tensor
1300       });
1301 
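  // A single set of weight and bias values follows; the test model helper is
  // assumed to copy each set into both the forward and backward cell tensors,
  // so the two directions run with identical parameters here.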
1302   lstm.SetInputToInputWeights(
1303       {0.021393683,  0.06124551,    0.046905167,  -0.014657677,  -0.03149463,
1304        0.09171803,   0.14647801,    0.10797193,   -0.0057968358, 0.0019193048,
1305        -0.2726754,   0.10154029,    -0.018539885, 0.080349885,   -0.10262385,
1306        -0.022599787, -0.09121155,   -0.008675967, -0.045206103,  -0.0821282,
1307        -0.008045952, 0.015478081,   0.055217247,  0.038719587,   0.044153627,
1308        -0.06453243,  0.05031825,    -0.046935108, -0.008164439,  0.014574226,
1309        -0.1671009,   -0.15519552,   -0.16819797,  -0.13971269,   -0.11953059,
1310        0.25005487,   -0.22790983,   0.009855087,  -0.028140958,  -0.11200698,
1311        0.11295408,   -0.0035217577, 0.054485075,  0.05184695,    0.064711206,
1312        0.10989193,   0.11674786,    0.03490607,   0.07727357,    0.11390585,
1313        -0.1863375,   -0.1034451,    -0.13945189,  -0.049401227,  -0.18767063,
1314        0.042483903,  0.14233552,    0.13832581,   0.18350165,    0.14545603,
1315        -0.028545704, 0.024939531,   0.050929718,  0.0076203286,  -0.0029723682,
1316        -0.042484224, -0.11827596,   -0.09171104,  -0.10808628,   -0.16327988,
1317        -0.2273378,   -0.0993647,    -0.017155107, 0.0023917493,  0.049272764,
1318        0.0038534778, 0.054764505,   0.089753784,  0.06947234,    0.08014476,
1319        -0.04544234,  -0.0497073,    -0.07135631,  -0.048929106,  -0.004042012,
1320        -0.009284026, 0.018042054,   0.0036860977, -0.07427302,   -0.11434604,
1321        -0.018995456, 0.031487543,   0.012834908,  0.019977754,   0.044256654,
1322        -0.39292613,  -0.18519334,   -0.11651281,  -0.06809892,   0.011373677});
1323 
1324   lstm.SetInputToForgetWeights(
1325       {-0.0018401089, -0.004852237,  0.03698424,   0.014181704,   0.028273236,
1326        -0.016726194,  -0.05249759,   -0.10204261,  0.00861066,    -0.040979505,
1327        -0.009899187,  0.01923892,    -0.028177269, -0.08535103,   -0.14585495,
1328        0.10662567,    -0.01909731,   -0.017883534, -0.0047269356, -0.045103323,
1329        0.0030784295,  0.076784775,   0.07463696,   0.094531395,   0.0814421,
1330        -0.12257899,   -0.033945758,  -0.031303465, 0.045630626,   0.06843887,
1331        -0.13492945,   -0.012480007,  -0.0811829,   -0.07224499,   -0.09628791,
1332        0.045100946,   0.0012300825,  0.013964662,  0.099372394,   0.02543059,
1333        0.06958324,    0.034257296,   0.0482646,    0.06267997,    0.052625068,
1334        0.12784666,    0.07077897,    0.025725935,  0.04165009,    0.07241905,
1335        0.018668644,   -0.037377294,  -0.06277783,  -0.08833636,   -0.040120605,
1336        -0.011405586,  -0.007808335,  -0.010301386, -0.005102167,  0.027717464,
1337        0.05483423,    0.11449111,    0.11289652,   0.10939839,    0.13396506,
1338        -0.08402166,   -0.01901462,   -0.044678304, -0.07720565,   0.014350063,
1339        -0.11757958,   -0.0652038,    -0.08185733,  -0.076754324,  -0.092614375,
1340        0.10405491,    0.052960336,   0.035755895,  0.035839386,   -0.012540553,
1341        0.036881298,   0.02913376,    0.03420159,   0.05448447,    -0.054523353,
1342        0.02582715,    0.02327355,    -0.011857179, -0.0011980024, -0.034641717,
1343        -0.026125094,  -0.17582615,   -0.15923657,  -0.27486774,   -0.0006143371,
1344        0.0001771948,  -8.470171e-05, 0.02651807,   0.045790765,   0.06956496});
1345 
1346   lstm.SetInputToCellWeights(
1347       {-0.04580283,   -0.09549462,   -0.032418985,  -0.06454633,
1348        -0.043528453,  0.043018587,   -0.049152344,  -0.12418144,
1349        -0.078985475,  -0.07596889,   0.019484362,   -0.11434962,
1350        -0.0074034138, -0.06314844,   -0.092981495,  0.0062155537,
1351        -0.025034338,  -0.0028890965, 0.048929527,   0.06235075,
1352        0.10665918,    -0.032036792,  -0.08505916,   -0.10843358,
1353        -0.13002433,   -0.036816437,  -0.02130134,   -0.016518239,
1354        0.0047691227,  -0.0025825808, 0.066017866,   0.029991534,
1355        -0.10652836,   -0.1037554,    -0.13056071,   -0.03266643,
1356        -0.033702414,  -0.006473424,  -0.04611692,   0.014419339,
1357        -0.025174323,  0.0396852,     0.081777506,   0.06157468,
1358        0.10210095,    -0.009658194,  0.046511717,   0.03603906,
1359        0.0069369148,  0.015960095,   -0.06507666,   0.09551598,
1360        0.053568836,   0.06408714,    0.12835667,    -0.008714329,
1361        -0.20211966,   -0.12093674,   0.029450472,   0.2849013,
1362        -0.029227901,  0.1164364,     -0.08560263,   0.09941786,
1363        -0.036999565,  -0.028842626,  -0.0033637602, -0.017012902,
1364        -0.09720865,   -0.11193351,   -0.029155117,  -0.017936034,
1365        -0.009768936,  -0.04223324,   -0.036159635,  0.06505112,
1366        -0.021742892,  -0.023377212,  -0.07221364,   -0.06430552,
1367        0.05453865,    0.091149814,   0.06387331,    0.007518393,
1368        0.055960953,   0.069779344,   0.046411168,   0.10509911,
1369        0.07463894,    0.0075130584,  0.012850982,   0.04555431,
1370        0.056955688,   0.06555285,    0.050801456,   -0.009862683,
1371        0.00826772,    -0.026555609,  -0.0073611983, -0.0014897042});
1372 
1373   lstm.SetInputToOutputWeights(
1374       {-0.0998932,   -0.07201956,  -0.052803773,  -0.15629593,  -0.15001918,
1375        -0.07650751,  0.02359855,   -0.075155355,  -0.08037709,  -0.15093534,
1376        0.029517552,  -0.04751393,  0.010350531,   -0.02664851,  -0.016839722,
1377        -0.023121163, 0.0077019283, 0.012851257,   -0.05040649,  -0.0129761,
1378        -0.021737747, -0.038305793, -0.06870586,   -0.01481247,  -0.001285394,
1379        0.10124236,   0.083122835,  0.053313006,   -0.062235646, -0.075637154,
1380        -0.027833903, 0.029774971,  0.1130802,     0.09218906,   0.09506135,
1381        -0.086665764, -0.037162706, -0.038880914,  -0.035832845, -0.014481564,
1382        -0.09825003,  -0.12048569,  -0.097665586,  -0.05287633,  -0.0964047,
1383        -0.11366429,  0.035777505,  0.13568819,    0.052451383,  0.050649304,
1384        0.05798951,   -0.021852335, -0.099848844,  0.014740475,  -0.078897946,
1385        0.04974699,   0.014160473,  0.06973932,    0.04964942,   0.033364646,
1386        0.08190124,   0.025535367,  0.050893165,   0.048514254,  0.06945813,
1387        -0.078907564, -0.06707616,  -0.11844508,   -0.09986688,  -0.07509403,
1388        0.06263226,   0.14925587,   0.20188436,    0.12098451,   0.14639415,
1389        0.0015017595, -0.014267382, -0.03417257,   0.012711468,  0.0028300495,
1390        -0.024758482, -0.05098548,  -0.0821182,    0.014225672,  0.021544158,
1391        0.08949725,   0.07505268,   -0.0020780868, 0.04908258,   0.06476295,
1392        -0.022907063, 0.027562456,  0.040185735,   0.019567577,  -0.015598739,
1393        -0.049097303, -0.017121866, -0.083368234,  -0.02332002,  -0.0840956});
1394 
1395   lstm.SetInputGateBias(
1396       {0.02234832,  0.14757581,   0.18176508,  0.10380666,  0.053110216,
1397        -0.06928846, -0.13942584,  -0.11816189, 0.19483899,  0.03652339,
1398        -0.10250295, 0.036714908,  -0.18426876, 0.036065217, 0.21810818,
1399        0.02383196,  -0.043370757, 0.08690144,  -0.04444982, 0.00030581196});
1400 
1401   lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696,
1402                           0.11098921,  0.15378423,   0.09263801,  0.09790885,
1403                           0.09508917,  0.061199076,  0.07665568,  -0.015443159,
1404                           -0.03499149, 0.046190713,  0.08895977,  0.10899629,
1405                           0.40694186,  0.06030037,   0.012413437, -0.06108739});
1406 
1407   lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132,   0.033463873,
1408                     -0.1483596,   -0.10639995,  -0.091433935, 0.058573797,
1409                     -0.06809782,  -0.07889636,  -0.043246906, -0.09829136,
1410                     -0.4279842,   0.034901652,  0.18797937,   0.0075234566,
1411                     0.016178843,  0.1749513,    0.13975595,   0.92058027});
1412 
1413   lstm.SetOutputGateBias(
1414       {0.046159424,  -0.0012809046, 0.03563469,   0.12648113, 0.027195795,
1415        0.35373217,   -0.018957434,  0.008907322,  -0.0762701, 0.12018895,
1416        0.04216877,   0.0022856654,  0.040952638,  0.3147856,  0.08225149,
1417        -0.057416286, -0.14995944,   -0.008040261, 0.13208859, 0.029760877});
1418 
1419   lstm.SetRecurrentToInputWeights(
1420       {-0.001374326,   -0.078856036,   0.10672688,    0.029162422,
1421        -0.11585556,    0.02557986,     -0.13446963,   -0.035785314,
1422        -0.01244275,    0.025961924,    -0.02337298,   -0.044228926,
1423        -0.055839065,   -0.046598054,   -0.010546039,  -0.06900766,
1424        0.027239809,    0.022582639,    -0.013296484,  -0.05459212,
1425        0.08981,        -0.045407712,   0.08682226,    -0.06867011,
1426        -0.14390695,    -0.02916037,    0.000996957,   0.091420636,
1427        0.14283475,     -0.07390571,    -0.06402044,   0.062524505,
1428        -0.093129106,   0.04860203,     -0.08364217,   -0.08119002,
1429        0.009352075,    0.22920375,     0.0016303885,  0.11583097,
1430        -0.13732095,    0.012405723,    -0.07551853,   0.06343048,
1431        0.12162708,     -0.031923793,   -0.014335606,  0.01790974,
1432        -0.10650317,    -0.0724401,     0.08554849,    -0.05727212,
1433        0.06556731,     -0.042729504,   -0.043227166,  0.011683251,
1434        -0.013082158,   -0.029302018,   -0.010899579,  -0.062036745,
1435        -0.022509435,   -0.00964907,    -0.01567329,   0.04260106,
1436        -0.07787477,    -0.11576462,    0.017356863,   0.048673786,
1437        -0.017577527,   -0.05527947,    -0.082487635,  -0.040137455,
1438        -0.10820036,    -0.04666372,    0.022746278,   -0.07851417,
1439        0.01068115,     0.032956902,    0.022433773,   0.0026891115,
1440        0.08944216,     -0.0685835,     0.010513544,   0.07228705,
1441        0.02032331,     -0.059686817,   -0.0005566496, -0.086984694,
1442        0.040414046,    -0.1380399,     0.094208956,   -0.05722982,
1443        0.012092817,    -0.04989123,    -0.086576,     -0.003399834,
1444        -0.04696032,    -0.045747425,   0.10091314,    0.048676282,
1445        -0.029037097,   0.031399418,    -0.0040285117, 0.047237843,
1446        0.09504992,     0.041799378,    -0.049185462,  -0.031518843,
1447        -0.10516937,    0.026374253,    0.10058866,    -0.0033195973,
1448        -0.041975245,   0.0073591834,   0.0033782164,  -0.004325073,
1449        -0.10167381,    0.042500053,    -0.01447153,   0.06464186,
1450        -0.017142897,   0.03312627,     0.009205989,   0.024138335,
1451        -0.011337001,   0.035530265,    -0.010912711,  0.0706555,
1452        -0.005894094,   0.051841937,    -0.1401738,    -0.02351249,
1453        0.0365468,      0.07590991,     0.08838724,    0.021681072,
1454        -0.10086113,    0.019608743,    -0.06195883,   0.077335775,
1455        0.023646897,    -0.095322326,   0.02233014,    0.09756986,
1456        -0.048691444,   -0.009579111,   0.07595467,    0.11480546,
1457        -0.09801813,    0.019894179,    0.08502348,    0.004032281,
1458        0.037211012,    0.068537936,    -0.048005626,  -0.091520436,
1459        -0.028379958,   -0.01556313,    0.06554592,    -0.045599163,
1460        -0.01672207,    -0.020169014,   -0.011877351,  -0.20212261,
1461        0.010889619,    0.0047078193,   0.038385306,   0.08540671,
1462        -0.017140968,   -0.0035865551,  0.016678626,   0.005633034,
1463        0.015963363,    0.00871737,     0.060130805,   0.028611384,
1464        0.10109069,     -0.015060172,   -0.07894427,   0.06401885,
1465        0.011584063,    -0.024466386,   0.0047652307,  -0.09041358,
1466        0.030737216,    -0.0046374933,  0.14215417,    -0.11823516,
1467        0.019899689,    0.006106124,    -0.027092824,  0.0786356,
1468        0.05052217,     -0.058925,      -0.011402121,  -0.024987547,
1469        -0.0013661642,  -0.06832946,    -0.015667673,  -0.1083353,
1470        -0.00096863037, -0.06988685,    -0.053350925,  -0.027275559,
1471        -0.033664223,   -0.07978348,    -0.025200296,  -0.017207067,
1472        -0.058403496,   -0.055697463,   0.005798788,   0.12965427,
1473        -0.062582195,   0.0013350133,   -0.10482091,   0.0379771,
1474        0.072521195,    -0.0029455067,  -0.13797039,   -0.03628521,
1475        0.013806405,    -0.017858358,   -0.01008298,   -0.07700066,
1476        -0.017081132,   0.019358726,    0.0027079724,  0.004635139,
1477        0.062634714,    -0.02338735,    -0.039547626,  -0.02050681,
1478        0.03385117,     -0.083611414,   0.002862572,   -0.09421313,
1479        0.058618143,    -0.08598433,    0.00972939,    0.023867095,
1480        -0.053934585,   -0.023203006,   0.07452513,    -0.048767887,
1481        -0.07314807,    -0.056307215,   -0.10433547,   -0.06440842,
1482        0.04328182,     0.04389765,     -0.020006588,  -0.09076438,
1483        -0.11652589,    -0.021705797,   0.03345259,    -0.010329105,
1484        -0.025767034,   0.013057034,    -0.07316461,   -0.10145612,
1485        0.06358255,     0.18531723,     0.07759293,    0.12006465,
1486        0.1305557,      0.058638252,    -0.03393652,   0.09622831,
1487        -0.16253184,    -2.4580743e-06, 0.079869635,   -0.070196845,
1488        -0.005644518,   0.06857898,     -0.12598175,   -0.035084512,
1489        0.03156317,     -0.12794146,    -0.031963028,  0.04692781,
1490        0.030070418,    0.0071660685,   -0.095516115,  -0.004643372,
1491        0.040170413,    -0.062104587,   -0.0037324072, 0.0554317,
1492        0.08184801,     -0.019164372,   0.06791302,    0.034257166,
1493        -0.10307039,    0.021943003,    0.046745934,   0.0790918,
1494        -0.0265588,     -0.007824208,   0.042546265,   -0.00977924,
1495        -0.0002440307,  -0.017384544,   -0.017990116,  0.12252321,
1496        -0.014512694,   -0.08251313,    0.08861942,    0.13589665,
1497        0.026351685,    0.012641483,    0.07466548,    0.044301085,
1498        -0.045414884,   -0.051112458,   0.03444247,    -0.08502782,
1499        -0.04106223,    -0.028126027,   0.028473156,   0.10467447});
1500 
1501   lstm.SetRecurrentToForgetWeights(
1502       {-0.057784554,  -0.026057621,  -0.068447545,   -0.022581743,
1503        0.14811787,    0.10826372,    0.09471067,     0.03987225,
1504        -0.0039523416, 0.00030638507, 0.053185795,    0.10572994,
1505        0.08414449,    -0.022036452,  -0.00066928595, -0.09203576,
1506        0.032950465,   -0.10985798,   -0.023809856,   0.0021431844,
1507        -0.02196096,   -0.00326074,   0.00058621005,  -0.074678116,
1508        -0.06193199,   0.055729095,   0.03736828,     0.020123724,
1509        0.061878487,   -0.04729229,   0.034919553,    -0.07585433,
1510        -0.04421272,   -0.044019096,  0.085488975,    0.04058006,
1511        -0.06890133,   -0.030951202,  -0.024628663,   -0.07672815,
1512        0.034293607,   0.08556707,    -0.05293577,    -0.033561368,
1513        -0.04899627,   0.0241671,     0.015736353,    -0.095442444,
1514        -0.029564252,  0.016493602,   -0.035026584,   0.022337519,
1515        -0.026871363,  0.004780428,   0.0077918363,   -0.03601621,
1516        0.016435321,   -0.03263031,   -0.09543275,    -0.047392778,
1517        0.013454138,   0.028934088,   0.01685226,     -0.086110644,
1518        -0.046250615,  -0.01847454,   0.047608484,    0.07339695,
1519        0.034546845,   -0.04881143,   0.009128804,    -0.08802852,
1520        0.03761666,    0.008096139,   -0.014454086,   0.014361001,
1521        -0.023502491,  -0.0011840804, -0.07607001,    0.001856849,
1522        -0.06509276,   -0.006021153,  -0.08570962,    -0.1451793,
1523        0.060212336,   0.055259194,   0.06974018,     0.049454916,
1524        -0.027794661,  -0.08077226,   -0.016179763,   0.1169753,
1525        0.17213494,    -0.0056326236, -0.053934924,   -0.0124349,
1526        -0.11520337,   0.05409887,    0.088759385,    0.0019655675,
1527        0.0042065294,  0.03881498,    0.019844765,    0.041858196,
1528        -0.05695512,   0.047233116,   0.038937137,    -0.06542224,
1529        0.014429736,   -0.09719407,   0.13908425,     -0.05379757,
1530        0.012321099,   0.082840554,   -0.029899208,   0.044217527,
1531        0.059855383,   0.07711018,    -0.045319796,   0.0948846,
1532        -0.011724666,  -0.0033288454, -0.033542685,   -0.04764985,
1533        -0.13873616,   0.040668588,   0.034832682,    -0.015319203,
1534        -0.018715994,  0.046002675,   0.0599172,      -0.043107376,
1535        0.0294216,     -0.002314414,  -0.022424703,   0.0030315618,
1536        0.0014641669,  0.0029166266,  -0.11878115,    0.013738511,
1537        0.12375372,    -0.0006038222, 0.029104086,    0.087442465,
1538        0.052958444,   0.07558703,    0.04817258,     0.044462286,
1539        -0.015213451,  -0.08783778,   -0.0561384,     -0.003008196,
1540        0.047060397,   -0.002058388,  0.03429439,     -0.018839769,
1541        0.024734668,   0.024614193,   -0.042046934,   0.09597743,
1542        -0.0043254104, 0.04320769,    0.0064070094,   -0.0019131786,
1543        -0.02558259,   -0.022822596,  -0.023273505,   -0.02464396,
1544        -0.10991725,   -0.006240552,  0.0074488563,   0.024044557,
1545        0.04383914,    -0.046476185,  0.028658995,    0.060410924,
1546        0.050786525,   0.009452605,   -0.0073054377,  -0.024810238,
1547        0.0052906186,  0.0066939713,  -0.0020913032,  0.014515517,
1548        0.015898481,   0.021362653,   -0.030262267,   0.016587038,
1549        -0.011442813,  0.041154444,   -0.007631438,   -0.03423484,
1550        -0.010977775,  0.036152758,   0.0066366293,   0.11915515,
1551        0.02318443,    -0.041350313,  0.021485701,    -0.10906167,
1552        -0.028218046,  -0.00954771,   0.020531068,    -0.11995105,
1553        -0.03672871,   0.024019798,   0.014255957,    -0.05221243,
1554        -0.00661567,   -0.04630967,   0.033188973,    0.10107534,
1555        -0.014027541,  0.030796422,   -0.10270911,    -0.035999842,
1556        0.15443139,    0.07684145,    0.036571592,    -0.035900835,
1557        -0.0034699554, 0.06209149,    0.015920248,    -0.031122351,
1558        -0.03858649,   0.01849943,    0.13872518,     0.01503974,
1559        0.069941424,   -0.06948533,   -0.0088794185,  0.061282158,
1560        -0.047401894,  0.03100163,    -0.041533746,   -0.10430945,
1561        0.044574402,   -0.01425562,   -0.024290353,   0.034563623,
1562        0.05866852,    0.023947537,   -0.09445152,    0.035450947,
1563        0.02247216,    -0.0042998926, 0.061146557,    -0.10250651,
1564        0.020881841,   -0.06747029,   0.10062043,     -0.0023941975,
1565        0.03532124,    -0.016341697,  0.09685456,     -0.016764693,
1566        0.051808182,   0.05875331,    -0.04536488,    0.001626336,
1567        -0.028892258,  -0.01048663,   -0.009793449,   -0.017093895,
1568        0.010987891,   0.02357273,    -0.00010856845, 0.0099760275,
1569        -0.001845119,  -0.03551521,   0.0018358806,   0.05763657,
1570        -0.01769146,   0.040995963,   0.02235177,     -0.060430344,
1571        0.11475477,    -0.023854522,  0.10071741,     0.0686208,
1572        -0.014250481,  0.034261297,   0.047418304,    0.08562733,
1573        -0.030519066,  0.0060542435,  0.014653856,    -0.038836084,
1574        0.04096551,    0.032249358,   -0.08355519,    -0.026823482,
1575        0.056386515,   -0.010401743,  -0.028396193,   0.08507674,
1576        0.014410365,   0.020995233,   0.17040324,     0.11511526,
1577        0.02459721,    0.0066619175,  0.025853224,    -0.023133837,
1578        -0.081302024,  0.017264642,   -0.009585969,   0.09491168,
1579        -0.051313367,  0.054532815,   -0.014298593,   0.10657464,
1580        0.007076659,   0.10964551,    0.0409152,      0.008275321,
1581        -0.07283536,   0.07937492,    0.04192024,     -0.1075027});
1582 
1583   lstm.SetRecurrentToCellWeights(
1584       {-0.037322544,   0.018592842,   0.0056175636,  -0.06253426,
1585        0.055647098,    -0.05713207,   -0.05626563,   0.005559383,
1586        0.03375411,     -0.025757805,  -0.088049285,  0.06017052,
1587        -0.06570978,    0.007384076,   0.035123326,   -0.07920549,
1588        0.053676967,    0.044480428,   -0.07663568,   0.0071805613,
1589        0.08089997,     0.05143358,    0.038261272,   0.03339287,
1590        -0.027673481,   0.044746667,   0.028349208,   0.020090483,
1591        -0.019443132,   -0.030755889,  -0.0040000007, 0.04465846,
1592        -0.021585021,   0.0031670958,  0.0053199246,  -0.056117613,
1593        -0.10893326,    0.076739706,   -0.08509834,   -0.027997585,
1594        0.037871376,    0.01449768,    -0.09002357,   -0.06111149,
1595        -0.046195522,   0.0422062,     -0.005683705,  -0.1253618,
1596        -0.012925729,   -0.04890792,   0.06985068,    0.037654128,
1597        0.03398274,     -0.004781977,  0.007032333,   -0.031787455,
1598        0.010868644,    -0.031489216,  0.09525667,    0.013939797,
1599        0.0058680447,   0.0167067,     0.02668468,    -0.04797466,
1600        -0.048885044,   -0.12722108,   0.035304096,   0.06554885,
1601        0.00972396,     -0.039238118,  -0.05159735,   -0.11329045,
1602        0.1613692,      -0.03750952,   0.06529313,    -0.071974665,
1603        -0.11769596,    0.015524369,   -0.0013754242, -0.12446318,
1604        0.02786344,     -0.014179351,  0.005264273,   0.14376344,
1605        0.015983658,    0.03406988,    -0.06939408,   0.040699873,
1606        0.02111075,     0.09669095,    0.041345075,   -0.08316494,
1607        -0.07684199,    -0.045768797,  0.032298047,   -0.041805092,
1608        0.0119405,      0.0061010392,  0.12652606,    0.0064572375,
1609        -0.024950314,   0.11574242,    0.04508852,    -0.04335324,
1610        0.06760663,     -0.027437469,  0.07216407,    0.06977076,
1611        -0.05438599,    0.034033038,   -0.028602652,  0.05346137,
1612        0.043184172,    -0.037189785,  0.10420091,    0.00882477,
1613        -0.054019816,   -0.074273005,  -0.030617684,  -0.0028467078,
1614        0.024302477,    -0.0038869337, 0.005332455,   0.0013399826,
1615        0.04361412,     -0.007001822,  0.09631092,    -0.06702025,
1616        -0.042049985,   -0.035070654,  -0.04103342,   -0.10273396,
1617        0.0544271,      0.037184782,   -0.13150354,   -0.0058036847,
1618        -0.008264958,   0.042035464,   0.05891794,    0.029673764,
1619        0.0063542654,   0.044788733,   0.054816857,   0.062257513,
1620        -0.00093483756, 0.048938446,   -0.004952862,  -0.007730018,
1621        -0.04043371,    -0.017094059,  0.07229206,    -0.023670016,
1622        -0.052195564,   -0.025616996,  -0.01520939,   0.045104615,
1623        -0.007376126,   0.003533447,   0.006570588,   0.056037236,
1624        0.12436656,     0.051817212,   0.028532185,   -0.08686856,
1625        0.11868599,     0.07663395,    -0.07323171,   0.03463402,
1626        -0.050708205,   -0.04458982,   -0.11590894,   0.021273347,
1627        0.1251325,      -0.15313013,   -0.12224372,   0.17228661,
1628        0.023029093,    0.086124025,   0.006445803,   -0.03496501,
1629        0.028332196,    0.04449512,    -0.042436164,  -0.026587414,
1630        -0.006041347,   -0.09292539,   -0.05678812,   0.03897832,
1631        0.09465633,     0.008115513,   -0.02171956,   0.08304309,
1632        0.071401566,    0.019622514,   0.032163795,   -0.004167056,
1633        0.02295182,     0.030739572,   0.056506045,   0.004612461,
1634        0.06524936,     0.059999723,   0.046395954,   -0.0045512207,
1635        -0.1335546,     -0.030136576,  0.11584653,    -0.014678886,
1636        0.0020118146,   -0.09688814,   -0.0790206,    0.039770417,
1637        -0.0329582,     0.07922767,    0.029322514,   0.026405897,
1638        0.04207835,     -0.07073373,   0.063781224,   0.0859677,
1639        -0.10925287,    -0.07011058,   0.048005477,   0.03438226,
1640        -0.09606514,    -0.006669445,  -0.043381985,  0.04240257,
1641        -0.06955775,    -0.06769346,   0.043903265,   -0.026784198,
1642        -0.017840602,   0.024307009,   -0.040079936,  -0.019946516,
1643        0.045318738,    -0.12233574,   0.026170589,   0.0074471775,
1644        0.15978073,     0.10185836,    0.10298046,    -0.015476589,
1645        -0.039390966,   -0.072174534,  0.0739445,     -0.1211869,
1646        -0.0347889,     -0.07943156,   0.014809798,   -0.12412325,
1647        -0.0030663363,  0.039695457,   0.0647603,     -0.08291318,
1648        -0.018529687,   -0.004423833,  0.0037507233,  0.084633216,
1649        -0.01514876,    -0.056505352,  -0.012800942,  -0.06994386,
1650        0.012962922,    -0.031234352,  0.07029052,    0.016418684,
1651        0.03618972,     0.055686004,   -0.08663945,   -0.017404709,
1652        -0.054761406,   0.029065743,   0.052404847,   0.020238016,
1653        0.0048197987,   -0.0214882,    0.07078733,    0.013016777,
1654        0.06262858,     0.009184685,   0.020785125,   -0.043904778,
1655        -0.0270329,     -0.03299152,   -0.060088247,  -0.015162964,
1656        -0.001828936,   0.12642565,    -0.056757294,  0.013586685,
1657        0.09232601,     -0.035886683,  0.06000002,    0.05229691,
1658        -0.052580316,   -0.082029596,  -0.010794592,  0.012947712,
1659        -0.036429964,   -0.085508935,  -0.13127148,   -0.017744139,
1660        0.031502828,    0.036232427,   -0.031581745,  0.023051167,
1661        -0.05325106,    -0.03421577,   0.028793324,   -0.034633752,
1662        -0.009881397,   -0.043551125,  -0.018609839,  0.0019097115,
1663        -0.008799762,   0.056595087,   0.0022273948,  0.055752404});
1664 
1665   lstm.SetRecurrentToOutputWeights({
1666       0.025825322,   -0.05813119,  0.09495884,   -0.045984812,   -0.01255415,
1667       -0.0026479573, -0.08196161,  -0.054914974, -0.0046604523,  -0.029587349,
1668       -0.044576716,  -0.07480124,  -0.082868785, 0.023254942,    0.027502948,
1669       -0.0039728214, -0.08683098,  -0.08116779,  -0.014675607,   -0.037924774,
1670       -0.023314456,  -0.007401714, -0.09255757,  0.029460307,    -0.08829125,
1671       -0.005139627,  -0.08989442,  -0.0555066,   0.13596267,     -0.025062224,
1672       -0.048351806,  -0.03850004,  0.07266485,   -0.022414139,   0.05940088,
1673       0.075114764,   0.09597592,   -0.010211725, -0.0049794707,  -0.011523867,
1674       -0.025980417,  0.072999895,  0.11091378,   -0.081685916,   0.014416728,
1675       0.043229222,   0.034178585,  -0.07530371,  0.035837382,    -0.085607,
1676       -0.007721233,  -0.03287832,  -0.043848954, -0.06404588,    -0.06632928,
1677       -0.073643476,  0.008214239,  -0.045984086, 0.039764922,    0.03474462,
1678       0.060612556,   -0.080590084, 0.049127717,  0.04151091,     -0.030063879,
1679       0.008801774,   -0.023021035, -0.019558564, 0.05158114,     -0.010947698,
1680       -0.011825728,  0.0075720972, 0.0699727,    -0.0039981045,  0.069350146,
1681       0.08799282,    0.016156472,  0.035502106,  0.11695009,     0.006217345,
1682       0.13392477,    -0.037875112, 0.025745004,  0.08940699,     -0.00924166,
1683       0.0046702605,  -0.036598757, -0.08811812,  0.10522024,     -0.032441203,
1684       0.008176899,   -0.04454919,  0.07058152,   0.0067963637,   0.039206743,
1685       0.03259838,    0.03725492,   -0.09515802,  0.013326398,    -0.052055415,
1686       -0.025676316,  0.03198509,   -0.015951829, -0.058556724,   0.036879618,
1687       0.043357447,   0.028362012,  -0.05908629,  0.0059240665,   -0.04995891,
1688       -0.019187413,  0.0276265,    -0.01628143,  0.0025863599,   0.08800015,
1689       0.035250366,   -0.022165963, -0.07328642,  -0.009415526,   -0.07455109,
1690       0.11690406,    0.0363299,    0.07411125,   0.042103454,    -0.009660886,
1691       0.019076364,   0.018299393,  -0.046004917, 0.08891175,     0.0431396,
1692       -0.026327137,  -0.051502608, 0.08979574,   -0.051670972,   0.04940282,
1693       -0.07491107,   -0.021240504, 0.022596184,  -0.034280192,   0.060163025,
1694       -0.058211457,  -0.051837247, -0.01349775,  -0.04639988,    -0.035936575,
1695       -0.011681591,  0.064818054,  0.0073146066, -0.021745546,   -0.043124277,
1696       -0.06471268,   -0.07053354,  -0.029321948, -0.05330136,    0.016933719,
1697       -0.053782392,  0.13747959,   -0.1361751,   -0.11569455,    0.0033329215,
1698       0.05693899,    -0.053219706, 0.063698,     0.07977434,     -0.07924483,
1699       0.06936997,    0.0034815092, -0.007305279, -0.037325785,   -0.07251102,
1700       -0.033633437,  -0.08677009,  0.091591336,  -0.14165086,    0.021752775,
1701       0.019683983,   0.0011612234, -0.058154266, 0.049996935,    0.0288841,
1702       -0.0024567875, -0.14345716,  0.010955264,  -0.10234828,    0.1183656,
1703       -0.0010731248, -0.023590032, -0.072285876, -0.0724771,     -0.026382286,
1704       -0.0014920527, 0.042667855,  0.0018776858, 0.02986552,     0.009814309,
1705       0.0733756,     0.12289186,   0.018043943,  -0.0458958,     0.049412545,
1706       0.033632483,   0.05495232,   0.036686596,  -0.013781798,   -0.010036754,
1707       0.02576849,    -0.08307328,  0.010112348,  0.042521734,    -0.05869831,
1708       -0.071689695,  0.03876447,   -0.13275425,  -0.0352966,     -0.023077697,
1709       0.10285965,    0.084736146,  0.15568255,   -0.00040734606, 0.027835453,
1710       -0.10292561,   -0.032401145, 0.10053256,   -0.026142767,   -0.08271222,
1711       -0.0030240538, -0.016368777, 0.1070414,    0.042672627,    0.013456989,
1712       -0.0437609,    -0.022309763, 0.11576483,   0.04108048,     0.061026827,
1713       -0.0190714,    -0.0869359,   0.037901703,  0.0610107,      0.07202949,
1714       0.01675338,    0.086139716,  -0.08795751,  -0.014898893,   -0.023771819,
1715       -0.01965048,   0.007955471,  -0.043740474, 0.03346837,     -0.10549954,
1716       0.090567775,   0.042013682,  -0.03176985,  0.12569028,     -0.02421228,
1717       -0.029526481,  0.023851605,  0.031539805,  0.05292009,     -0.02344001,
1718       -0.07811758,   -0.08834428,  0.10094801,   0.16594367,     -0.06861939,
1719       -0.021256343,  -0.041093912, -0.06669611,  0.035498552,    0.021757556,
1720       -0.09302526,   -0.015403468, -0.06614931,  -0.051798206,   -0.013874718,
1721       0.03630673,    0.010412845,  -0.08077351,  0.046185967,    0.0035662893,
1722       0.03541868,    -0.094149634, -0.034814864, 0.003128424,    -0.020674974,
1723       -0.03944324,   -0.008110165, -0.11113267,  0.08484226,     0.043586485,
1724       0.040582247,   0.0968012,    -0.065249965, -0.028036479,   0.0050708856,
1725       0.0017462453,  0.0326779,    0.041296225,  0.09164146,     -0.047743853,
1726       -0.015952192,  -0.034451712, 0.084197424,  -0.05347844,    -0.11768019,
1727       0.085926116,   -0.08251791,  -0.045081906, 0.0948852,      0.068401024,
1728       0.024856757,   0.06978981,   -0.057309967, -0.012775832,   -0.0032452994,
1729       0.01977615,    -0.041040014, -0.024264973, 0.063464895,    0.05431621,
1730   });
1731 
1732   lstm.SetCellToInputWeights(
1733       {0.040369894, 0.030746894,  0.24704495,  0.018586371,  -0.037586458,
1734        -0.15312155, -0.11812848,  -0.11465643, 0.20259799,   0.11418174,
1735        -0.10116027, -0.011334949, 0.12411352,  -0.076769054, -0.052169047,
1736        0.21198851,  -0.38871562,  -0.09061183, -0.09683246,  -0.21929175});
1737 
1738   lstm.SetCellToForgetWeights(
1739       {-0.01998659,  -0.15568835,  -0.24248174,   -0.012770197, 0.041331276,
1740        -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
1741        -0.047248036, 0.021479502,  0.033189066,   0.11952997,   -0.020432774,
1742        0.64658105,   -0.06650122,  -0.03467612,   0.095340036,  0.23647355});
1743 
1744   lstm.SetCellToOutputWeights(
1745       {0.08286371,  -0.08261836, -0.51210177, 0.002913762, 0.17764764,
1746        -0.5495371,  -0.08460716, -0.24552552, 0.030037103, 0.04123544,
1747        -0.11940523, 0.007358328, 0.1890978,   0.4833202,   -0.34441817,
1748        0.36312827,  -0.26375428, 0.1457655,   -0.19724406, 0.15548733});
1749 
1750   lstm.SetProjectionWeights(
1751       {-0.009802181,  0.09401916,    0.0717386,     -0.13895074,  0.09641832,
1752        0.060420845,   0.08539281,    0.054285463,   0.061395317,  0.034448683,
1753        -0.042991187,  0.019801661,   -0.16840284,   -0.015726732, -0.23041931,
1754        -0.024478018,  -0.10959692,   -0.013875541,  0.18600968,   -0.061274476,
1755        0.0138165,     -0.08160894,   -0.07661644,   0.032372914,  0.16169067,
1756        0.22465782,    -0.03993472,   -0.004017731,  0.08633481,   -0.28869787,
1757        0.08682067,    0.17240396,    0.014975425,   0.056431185,  0.031037588,
1758        0.16702051,    0.0077946745,  0.15140012,    0.29405436,   0.120285,
1759        -0.188994,     -0.027265169,  0.043389652,   -0.022061434, 0.014777949,
1760        -0.20203483,   0.094781205,   0.19100232,    0.13987629,   -0.036132768,
1761        -0.06426278,   -0.05108664,   0.13221376,    0.009441198,  -0.16715929,
1762        0.15859416,    -0.040437475,  0.050779544,   -0.022187516, 0.012166504,
1763        0.027685808,   -0.07675938,   -0.0055694645, -0.09444123,  0.0046453946,
1764        0.050794356,   0.10770313,    -0.20790008,   -0.07149004,  -0.11425117,
1765        0.008225835,   -0.035802525,  0.14374903,    0.15262283,   0.048710253,
1766        0.1847461,     -0.007487823,  0.11000021,    -0.09542012,  0.22619456,
1767        -0.029149994,  0.08527916,    0.009043713,   0.0042746216, 0.016261552,
1768        0.022461696,   0.12689082,    -0.043589946,  -0.12035478,  -0.08361797,
1769        -0.050666027,  -0.1248618,    -0.1275799,    -0.071875185, 0.07377272,
1770        0.09944291,    -0.18897448,   -0.1593054,    -0.06526116,  -0.040107165,
1771        -0.004618631,  -0.067624845,  -0.007576253,  0.10727444,   0.041546922,
1772        -0.20424393,   0.06907816,    0.050412357,   0.00724631,   0.039827548,
1773        0.12449835,    0.10747581,    0.13708383,    0.09134148,   -0.12617786,
1774        -0.06428341,   0.09956831,    0.1208086,     -0.14676677,  -0.0727722,
1775        0.1126304,     0.010139365,   0.015571211,   -0.038128063, 0.022913318,
1776        -0.042050496,  0.16842307,    -0.060597885,  0.10531834,   -0.06411776,
1777        -0.07451711,   -0.03410368,   -0.13393489,   0.06534304,   0.003620307,
1778        0.04490757,    0.05970546,    0.05197996,    0.02839995,   0.10434969,
1779        -0.013699693,  -0.028353551,  -0.07260381,   0.047201227,  -0.024575593,
1780        -0.036445823,  0.07155557,    0.009672501,   -0.02328883,  0.009533515,
1781        -0.03606021,   -0.07421458,   -0.028082801,  -0.2678904,   -0.13221288,
1782        0.18419984,    -0.13012612,   -0.014588381,  -0.035059117, -0.04824723,
1783        0.07830115,    -0.056184657,  0.03277091,    0.025466874,  0.14494097,
1784        -0.12522776,   -0.098633975,  -0.10766018,   -0.08317623,  0.08594209,
1785        0.07749552,    0.039474737,   0.1776665,     -0.07409566,  -0.0477268,
1786        0.29323658,    0.10801441,    0.1154011,     0.013952499,  0.10739139,
1787        0.10708251,    -0.051456142,  0.0074137426,  -0.10430189,  0.10034707,
1788        0.045594677,   0.0635285,     -0.0715442,    -0.089667566, -0.10811871,
1789        0.00026344223, 0.08298446,    -0.009525053,  0.006585689,  -0.24567553,
1790        -0.09450807,   0.09648481,    0.026996298,   -0.06419476,  -0.04752702,
1791        -0.11063944,   -0.23441927,   -0.17608605,   -0.052156363, 0.067035615,
1792        0.19271925,    -0.0032889997, -0.043264326,  0.09663576,   -0.057112187,
1793        -0.10100678,   0.0628376,     0.04447668,    0.017961001,  -0.10094388,
1794        -0.10190601,   0.18335468,    0.10494553,    -0.052095775, -0.0026118709,
1795        0.10539724,    -0.04383912,   -0.042349473,  0.08438151,   -0.1947263,
1796        0.02251204,    0.11216432,    -0.10307853,   0.17351969,   -0.039091777,
1797        0.08066188,    -0.00561982,   0.12633002,    0.11335965,   -0.0088127935,
1798        -0.019777594,  0.06864014,    -0.059751723,  0.016233567,  -0.06894641,
1799        -0.28651384,   -0.004228674,  0.019708522,   -0.16305895,  -0.07468996,
1800        -0.0855457,    0.099339016,   -0.07580735,   -0.13775392,  0.08434318,
1801        0.08330512,    -0.12131499,   0.031935584,   0.09180414,   -0.08876437,
1802        -0.08049874,   0.008753825,   0.03498998,    0.030215185,  0.03907079,
1803        0.089751154,   0.029194152,   -0.03337423,   -0.019092513, 0.04331237,
1804        0.04299654,    -0.036394123,  -0.12915532,   0.09793732,   0.07512415,
1805        -0.11319543,   -0.032502122,  0.15661901,    0.07671967,   -0.005491124,
1806        -0.19379048,   -0.218606,     0.21448623,    0.017840758,  0.1416943,
1807        -0.07051762,   0.19488361,    0.02664691,    -0.18104725,  -0.09334311,
1808        0.15026465,    -0.15493552,   -0.057762887,  -0.11604192,  -0.262013,
1809        -0.01391798,   0.012185008,   0.11156489,    -0.07483202,  0.06693364,
1810        -0.26151478,   0.046425626,   0.036540434,   -0.16435726,  0.17338543,
1811        -0.21401681,   -0.11385144,   -0.08283257,   -0.069031075, 0.030635102,
1812        0.010969227,   0.11109743,    0.010919218,   0.027526086,  0.13519906,
1813        0.01891392,    -0.046839405,  -0.040167913,  0.017953383,  -0.09700955,
1814        0.0061885654,  -0.07000971,   0.026893595,   -0.038844477, 0.14543656});
1815 
1816   static float lstm_input[][20] = {
1817       {// Batch0: 4 (input_sequence_size) * 5 (n_input)
1818        0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386,
1819        0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199,
1820        0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339},
1821 
1822       {// Batch1: 4 (input_sequence_size) * 5 (n_input)
1823        0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260,
1824        0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485,
1825        0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}};
1826 
1827   static float lstm_fw_golden_output[][64] = {
1828       {// Batch0: 4 (input_sequence_size) * 16 (n_output)
1829        -0.00396806, 0.029352,     -0.00279226, 0.0159977,   -0.00835576,
1830        -0.0211779,  0.0283512,    -0.0114597,  0.00907307,  -0.0244004,
1831        -0.0152191,  -0.0259063,   0.00914318,  0.00415118,  0.017147,
1832        0.0134203,   -0.0166936,   0.0381209,   0.000889694, 0.0143363,
1833        -0.0328911,  -0.0234288,   0.0333051,   -0.012229,   0.0110322,
1834        -0.0457725,  -0.000832209, -0.0202817,  0.0327257,   0.0121308,
1835        0.0155969,   0.0312091,    -0.0213783,  0.0350169,   0.000324794,
1836        0.0276012,   -0.0263374,   -0.0371449,  0.0446149,   -0.0205474,
1837        0.0103729,   -0.0576349,   -0.0150052,  -0.0292043,  0.0376827,
1838        0.0136115,   0.0243435,    0.0354492,   -0.0189322,  0.0464512,
1839        -0.00251373, 0.0225745,    -0.0308346,  -0.0317124,  0.0460407,
1840        -0.0189395,  0.0149363,    -0.0530162,  -0.0150767,  -0.0340193,
1841        0.0286833,   0.00824207,   0.0264887,   0.0305169},
1842       {// Batch1: 4 (input_sequence_size) * 16 (n_output)
1843        -0.013869,    0.0287268,   -0.00334693, 0.00733398,  -0.0287926,
1844        -0.0186926,   0.0193662,   -0.0115437,  0.00422612,  -0.0345232,
1845        0.00223253,   -0.00957321, 0.0210624,   0.013331,    0.0150954,
1846        0.02168,      -0.0141913,  0.0322082,   0.00227024,  0.0260507,
1847        -0.0188721,   -0.0296489,  0.0399134,   -0.0160509,  0.0116039,
1848        -0.0447318,   -0.0150515,  -0.0277406,  0.0316596,   0.0118233,
1849        0.0214762,    0.0293641,   -0.0204549,  0.0450315,   -0.00117378,
1850        0.0167673,    -0.0375007,  -0.0238314,  0.038784,    -0.0174034,
1851        0.0131743,    -0.0506589,  -0.0048447,  -0.0240239,  0.0325789,
1852        0.00790065,   0.0220157,   0.0333314,   -0.0264787,  0.0387855,
1853        -0.000764675, 0.0217599,   -0.037537,   -0.0335206,  0.0431679,
1854        -0.0211424,   0.010203,    -0.062785,   -0.00832363, -0.025181,
1855        0.0412031,    0.0118723,   0.0239643,   0.0394009}};
1856 
1857   static float lstm_combined_golden_output[][64] = {
1858       {-0.022014, 0.073544,  -0.002235, 0.040068,  -0.037136, -0.052788,
1859        0.075325,  -0.029378, 0.024298,  -0.07733,  -0.030674, -0.060229,
1860        0.040599,  0.011608,  0.042005,  0.045977,  -0.039225, 0.076294,
1861        0.000735,  0.032852,  -0.069869, -0.053312, 0.073527,  -0.028136,
1862        0.021585,  -0.102679, -0.004327, -0.043304, 0.072861,  0.027077,
1863        0.034558,  0.068292,  -0.036292, 0.069832,  -0.003032, 0.053829,
1864        -0.043821, -0.072713, 0.085029,  -0.040374, 0.020014,  -0.104521,
1865        -0.034504, -0.059759, 0.062569,  0.025652,  0.049306,  0.061189,
1866        -0.025146, 0.079643,  -0.005188, 0.033080,  -0.048079, -0.048082,
1867        0.069369,  -0.028900, 0.024572,  -0.077547, -0.022517, -0.054477,
1868        0.038857,  0.013336,  0.043234,  0.044788},
1869       {-0.039186, 0.070792,  -0.005913, 0.02642,   -0.068274, -0.05022,
1870        0.061444,  -0.031241, 0.014996,  -0.094544, -0.004146, -0.03464,
1871        0.058981,  0.026097,  0.039781,  0.058408,  -0.031887, 0.069252,
1872        0.00576,   0.054062,  -0.042801, -0.059974, 0.085272,  -0.034453,
1873        0.026097,  -0.0959,   -0.031164, -0.058699, 0.06839,   0.020512,
1874        0.044727,  0.063609,  -0.039863, 0.084819,  -0.003909, 0.028666,
1875        -0.075677, -0.045125, 0.070379,  -0.033895, 0.022111,  -0.097184,
1876        -0.004921, -0.040851, 0.062316,  0.017435,  0.041437,  0.064568,
1877        -0.039656, 0.060726,  -0.003402, 0.036854,  -0.056503, -0.058554,
1878        0.068588,  -0.034879, 0.01352,   -0.09962,  -0.01434,  -0.039505,
1879        0.065133,  0.024321,  0.038473,  0.062438}};
1880 
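  // The input tensor is time-major ([sequence_length, n_batch, n_input]), so
  // time step i stores batch0's slice at offset 2 * i * n_input and batch1's
  // slice at offset (2 * i + 1) * n_input.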
1881   for (int i = 0; i < lstm.sequence_length(); i++) {
1882     float* batch0_start = lstm_input[0] + i * lstm.num_inputs();
1883     float* batch0_end = batch0_start + lstm.num_inputs();
1884 
1885     lstm.SetInput(2 * i * lstm.num_inputs(), batch0_start, batch0_end);
1886 
1887     float* batch1_start = lstm_input[1] + i * lstm.num_inputs();
1888     float* batch1_end = batch1_start + lstm.num_inputs();
1889     lstm.SetInput((2 * i + 1) * lstm.num_inputs(), batch1_start, batch1_end);
1890   }
1891 
1892   ASSERT_EQ(lstm.Invoke(), kTfLiteOk);
1893 
1894   std::vector<float> expected;
1895   for (int i = 0; i < lstm.sequence_length(); i++) {
1896     float* golden_start_batch0 =
1897         lstm_fw_golden_output[0] + i * lstm.num_fw_outputs();
1898     float* golden_end_batch0 = golden_start_batch0 + lstm.num_fw_outputs();
1899     float* golden_start_batch1 =
1900         lstm_fw_golden_output[1] + i * lstm.num_fw_outputs();
1901     float* golden_end_batch1 = golden_start_batch1 + lstm.num_fw_outputs();
1902     expected.insert(expected.end(), golden_start_batch0, golden_end_batch0);
1903     expected.insert(expected.end(), golden_start_batch1, golden_end_batch1);
1904   }
1905   EXPECT_THAT(lstm.GetFwOutput(), ElementsAreArray(ArrayFloatNear(expected)));
1906 
1907   // Check that the element-wise sum of the forward and backward outputs
1908   // matches the combined golden output.
1908   expected.clear();
1909   for (int i = 0; i < lstm.sequence_length(); i++) {
1910     float* golden_start_batch0 =
1911         lstm_combined_golden_output[0] + i * lstm.num_fw_outputs();
1912     float* golden_end_batch0 = golden_start_batch0 + lstm.num_fw_outputs();
1913     float* golden_start_batch1 =
1914         lstm_combined_golden_output[1] + i * lstm.num_fw_outputs();
1915     float* golden_end_batch1 = golden_start_batch1 + lstm.num_fw_outputs();
1916     expected.insert(expected.end(), golden_start_batch0, golden_end_batch0);
1917     expected.insert(expected.end(), golden_start_batch1, golden_end_batch1);
1918   }
1919 
1920   std::vector<float> combined;
1921   for (int i = 0; i < lstm.GetFwOutput().size(); ++i) {
1922     combined.push_back(lstm.GetFwOutput()[i] + lstm.GetBwOutput()[i]);
1923   }
1924   EXPECT_THAT(combined, ElementsAreArray(ArrayFloatNear(expected)));
1925 }
1926 
1927 // Same as above but with batch_major input/output.
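// With time_major=false the input and aux_input tensors are shaped
// [n_batch, sequence_length, n_input] rather than [sequence_length, n_batch,
// n_input], so each batch's full sequence is stored contiguously.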
1929 TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClippingBatchMajor) {
1929   const int n_batch = 2;
1930   const int n_input = 5;
1931   const int n_cell = 20;
1932   const int n_output = 16;
1933   const int sequence_length = 4;
1934 
1935   BidirectionalLSTMOpModel lstm(
1936       n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
1937       /*use_peephole=*/true, /*use_projection_weights=*/true,
1938       /*use_projection_bias=*/false, /*merge_outputs=*/false,
1939       /*use_aux_input=*/false, /*cell_clip=*/0.0,
1940       /*proj_clip=*/0.0, /*quantize_weights=*/false, /*time_major=*/false,
1941       {
1942           {n_batch, sequence_length, n_input},  // input tensor
1943 
1944           {n_cell, n_input},  // input_to_input_weight tensor
1945           {n_cell, n_input},  // input_to_forget_weight tensor
1946           {n_cell, n_input},  // input_to_cell_weight tensor
1947           {n_cell, n_input},  // input_to_output_weight tensor
1948 
1949           {n_cell, n_output},  // recurrent_to_input_weight tensor
1950           {n_cell, n_output},  // recurrent_to_forget_weight tensor
1951           {n_cell, n_output},  // recurrent_to_cell_weight tensor
1952           {n_cell, n_output},  // recurrent_to_output_weight tensor
1953 
1954           {n_cell},  // cell_to_input_weight tensor
1955           {n_cell},  // cell_to_forget_weight tensor
1956           {n_cell},  // cell_to_output_weight tensor
1957 
1958           {n_cell},  // input_gate_bias tensor
1959           {n_cell},  // forget_gate_bias tensor
1960           {n_cell},  // cell_gate_bias tensor
1961           {n_cell},  // output_gate_bias tensor
1962 
1963           {n_output, n_cell},  // projection_weight tensor
1964           {0},                 // projection_bias tensor
1965 
1966           {n_cell, n_input},  // input_to_input_weight tensor
1967           {n_cell, n_input},  // input_to_forget_weight tensor
1968           {n_cell, n_input},  // input_to_cell_weight tensor
1969           {n_cell, n_input},  // input_to_output_weight tensor
1970 
1971           {n_cell, n_output},  // recurrent_to_input_weight tensor
1972           {n_cell, n_output},  // recurrent_to_forget_weight tensor
1973           {n_cell, n_output},  // recurrent_to_cell_weight tensor
1974           {n_cell, n_output},  // recurrent_to_output_weight tensor
1975 
1976           {n_cell},  // cell_to_input_weight tensor
1977           {n_cell},  // cell_to_forget_weight tensor
1978           {n_cell},  // cell_to_output_weight tensor
1979 
1980           {n_cell},  // input_gate_bias tensor
1981           {n_cell},  // forget_gate_bias tensor
1982           {n_cell},  // cell_gate_bias tensor
1983           {n_cell},  // output_gate_bias tensor
1984 
1985           {n_output, n_cell},  // projection_weight tensor
1986           {0},                 // projection_bias tensor
1987 
1988           {n_batch, n_output},  // activation_state tensor
1989           {n_batch, n_cell},    // cell_state tensor
1990 
1991           {n_batch, n_output},  // activation_state tensor
1992           {n_batch, n_cell},    // cell_state tensor
1993 
1994           {n_batch, sequence_length, 0},  // aux_input tensor
1995           {0},                            // aux_fw_input_to_input tensor
1996           {0},                            // aux_fw_input_to_forget tensor
1997           {0},                            // aux_fw_input_to_cell tensor
1998           {0},                            // aux_fw_input_to_output tensor
1999           {0},                            // aux_bw_input_to_input tensor
2000           {0},                            // aux_bw_input_to_forget tensor
2001           {0},                            // aux_bw_input_to_cell tensor
2002           {0},                            // aux_bw_input_to_output tensor
2003       });
2004 
2005   lstm.SetInputToInputWeights(
2006       {0.021393683,  0.06124551,    0.046905167,  -0.014657677,  -0.03149463,
2007        0.09171803,   0.14647801,    0.10797193,   -0.0057968358, 0.0019193048,
2008        -0.2726754,   0.10154029,    -0.018539885, 0.080349885,   -0.10262385,
2009        -0.022599787, -0.09121155,   -0.008675967, -0.045206103,  -0.0821282,
2010        -0.008045952, 0.015478081,   0.055217247,  0.038719587,   0.044153627,
2011        -0.06453243,  0.05031825,    -0.046935108, -0.008164439,  0.014574226,
2012        -0.1671009,   -0.15519552,   -0.16819797,  -0.13971269,   -0.11953059,
2013        0.25005487,   -0.22790983,   0.009855087,  -0.028140958,  -0.11200698,
2014        0.11295408,   -0.0035217577, 0.054485075,  0.05184695,    0.064711206,
2015        0.10989193,   0.11674786,    0.03490607,   0.07727357,    0.11390585,
2016        -0.1863375,   -0.1034451,    -0.13945189,  -0.049401227,  -0.18767063,
2017        0.042483903,  0.14233552,    0.13832581,   0.18350165,    0.14545603,
2018        -0.028545704, 0.024939531,   0.050929718,  0.0076203286,  -0.0029723682,
2019        -0.042484224, -0.11827596,   -0.09171104,  -0.10808628,   -0.16327988,
2020        -0.2273378,   -0.0993647,    -0.017155107, 0.0023917493,  0.049272764,
2021        0.0038534778, 0.054764505,   0.089753784,  0.06947234,    0.08014476,
2022        -0.04544234,  -0.0497073,    -0.07135631,  -0.048929106,  -0.004042012,
2023        -0.009284026, 0.018042054,   0.0036860977, -0.07427302,   -0.11434604,
2024        -0.018995456, 0.031487543,   0.012834908,  0.019977754,   0.044256654,
2025        -0.39292613,  -0.18519334,   -0.11651281,  -0.06809892,   0.011373677});
2026 
2027   lstm.SetInputToForgetWeights(
2028       {-0.0018401089, -0.004852237,  0.03698424,   0.014181704,   0.028273236,
2029        -0.016726194,  -0.05249759,   -0.10204261,  0.00861066,    -0.040979505,
2030        -0.009899187,  0.01923892,    -0.028177269, -0.08535103,   -0.14585495,
2031        0.10662567,    -0.01909731,   -0.017883534, -0.0047269356, -0.045103323,
2032        0.0030784295,  0.076784775,   0.07463696,   0.094531395,   0.0814421,
2033        -0.12257899,   -0.033945758,  -0.031303465, 0.045630626,   0.06843887,
2034        -0.13492945,   -0.012480007,  -0.0811829,   -0.07224499,   -0.09628791,
2035        0.045100946,   0.0012300825,  0.013964662,  0.099372394,   0.02543059,
2036        0.06958324,    0.034257296,   0.0482646,    0.06267997,    0.052625068,
2037        0.12784666,    0.07077897,    0.025725935,  0.04165009,    0.07241905,
2038        0.018668644,   -0.037377294,  -0.06277783,  -0.08833636,   -0.040120605,
2039        -0.011405586,  -0.007808335,  -0.010301386, -0.005102167,  0.027717464,
2040        0.05483423,    0.11449111,    0.11289652,   0.10939839,    0.13396506,
2041        -0.08402166,   -0.01901462,   -0.044678304, -0.07720565,   0.014350063,
2042        -0.11757958,   -0.0652038,    -0.08185733,  -0.076754324,  -0.092614375,
2043        0.10405491,    0.052960336,   0.035755895,  0.035839386,   -0.012540553,
2044        0.036881298,   0.02913376,    0.03420159,   0.05448447,    -0.054523353,
2045        0.02582715,    0.02327355,    -0.011857179, -0.0011980024, -0.034641717,
2046        -0.026125094,  -0.17582615,   -0.15923657,  -0.27486774,   -0.0006143371,
2047        0.0001771948,  -8.470171e-05, 0.02651807,   0.045790765,   0.06956496});
2048 
2049   lstm.SetInputToCellWeights(
2050       {-0.04580283,   -0.09549462,   -0.032418985,  -0.06454633,
2051        -0.043528453,  0.043018587,   -0.049152344,  -0.12418144,
2052        -0.078985475,  -0.07596889,   0.019484362,   -0.11434962,
2053        -0.0074034138, -0.06314844,   -0.092981495,  0.0062155537,
2054        -0.025034338,  -0.0028890965, 0.048929527,   0.06235075,
2055        0.10665918,    -0.032036792,  -0.08505916,   -0.10843358,
2056        -0.13002433,   -0.036816437,  -0.02130134,   -0.016518239,
2057        0.0047691227,  -0.0025825808, 0.066017866,   0.029991534,
2058        -0.10652836,   -0.1037554,    -0.13056071,   -0.03266643,
2059        -0.033702414,  -0.006473424,  -0.04611692,   0.014419339,
2060        -0.025174323,  0.0396852,     0.081777506,   0.06157468,
2061        0.10210095,    -0.009658194,  0.046511717,   0.03603906,
2062        0.0069369148,  0.015960095,   -0.06507666,   0.09551598,
2063        0.053568836,   0.06408714,    0.12835667,    -0.008714329,
2064        -0.20211966,   -0.12093674,   0.029450472,   0.2849013,
2065        -0.029227901,  0.1164364,     -0.08560263,   0.09941786,
2066        -0.036999565,  -0.028842626,  -0.0033637602, -0.017012902,
2067        -0.09720865,   -0.11193351,   -0.029155117,  -0.017936034,
2068        -0.009768936,  -0.04223324,   -0.036159635,  0.06505112,
2069        -0.021742892,  -0.023377212,  -0.07221364,   -0.06430552,
2070        0.05453865,    0.091149814,   0.06387331,    0.007518393,
2071        0.055960953,   0.069779344,   0.046411168,   0.10509911,
2072        0.07463894,    0.0075130584,  0.012850982,   0.04555431,
2073        0.056955688,   0.06555285,    0.050801456,   -0.009862683,
2074        0.00826772,    -0.026555609,  -0.0073611983, -0.0014897042});
2075 
2076   lstm.SetInputToOutputWeights(
2077       {-0.0998932,   -0.07201956,  -0.052803773,  -0.15629593,  -0.15001918,
2078        -0.07650751,  0.02359855,   -0.075155355,  -0.08037709,  -0.15093534,
2079        0.029517552,  -0.04751393,  0.010350531,   -0.02664851,  -0.016839722,
2080        -0.023121163, 0.0077019283, 0.012851257,   -0.05040649,  -0.0129761,
2081        -0.021737747, -0.038305793, -0.06870586,   -0.01481247,  -0.001285394,
2082        0.10124236,   0.083122835,  0.053313006,   -0.062235646, -0.075637154,
2083        -0.027833903, 0.029774971,  0.1130802,     0.09218906,   0.09506135,
2084        -0.086665764, -0.037162706, -0.038880914,  -0.035832845, -0.014481564,
2085        -0.09825003,  -0.12048569,  -0.097665586,  -0.05287633,  -0.0964047,
2086        -0.11366429,  0.035777505,  0.13568819,    0.052451383,  0.050649304,
2087        0.05798951,   -0.021852335, -0.099848844,  0.014740475,  -0.078897946,
2088        0.04974699,   0.014160473,  0.06973932,    0.04964942,   0.033364646,
2089        0.08190124,   0.025535367,  0.050893165,   0.048514254,  0.06945813,
2090        -0.078907564, -0.06707616,  -0.11844508,   -0.09986688,  -0.07509403,
2091        0.06263226,   0.14925587,   0.20188436,    0.12098451,   0.14639415,
2092        0.0015017595, -0.014267382, -0.03417257,   0.012711468,  0.0028300495,
2093        -0.024758482, -0.05098548,  -0.0821182,    0.014225672,  0.021544158,
2094        0.08949725,   0.07505268,   -0.0020780868, 0.04908258,   0.06476295,
2095        -0.022907063, 0.027562456,  0.040185735,   0.019567577,  -0.015598739,
2096        -0.049097303, -0.017121866, -0.083368234,  -0.02332002,  -0.0840956});
2097 
2098   lstm.SetInputGateBias(
2099       {0.02234832,  0.14757581,   0.18176508,  0.10380666,  0.053110216,
2100        -0.06928846, -0.13942584,  -0.11816189, 0.19483899,  0.03652339,
2101        -0.10250295, 0.036714908,  -0.18426876, 0.036065217, 0.21810818,
2102        0.02383196,  -0.043370757, 0.08690144,  -0.04444982, 0.00030581196});
2103 
2104   lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696,
2105                           0.11098921,  0.15378423,   0.09263801,  0.09790885,
2106                           0.09508917,  0.061199076,  0.07665568,  -0.015443159,
2107                           -0.03499149, 0.046190713,  0.08895977,  0.10899629,
2108                           0.40694186,  0.06030037,   0.012413437, -0.06108739});
2109 
2110   lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132,   0.033463873,
2111                     -0.1483596,   -0.10639995,  -0.091433935, 0.058573797,
2112                     -0.06809782,  -0.07889636,  -0.043246906, -0.09829136,
2113                     -0.4279842,   0.034901652,  0.18797937,   0.0075234566,
2114                     0.016178843,  0.1749513,    0.13975595,   0.92058027});
2115 
2116   lstm.SetOutputGateBias(
2117       {0.046159424,  -0.0012809046, 0.03563469,   0.12648113, 0.027195795,
2118        0.35373217,   -0.018957434,  0.008907322,  -0.0762701, 0.12018895,
2119        0.04216877,   0.0022856654,  0.040952638,  0.3147856,  0.08225149,
2120        -0.057416286, -0.14995944,   -0.008040261, 0.13208859, 0.029760877});
2121 
2122   lstm.SetRecurrentToInputWeights(
2123       {-0.001374326,   -0.078856036,   0.10672688,    0.029162422,
2124        -0.11585556,    0.02557986,     -0.13446963,   -0.035785314,
2125        -0.01244275,    0.025961924,    -0.02337298,   -0.044228926,
2126        -0.055839065,   -0.046598054,   -0.010546039,  -0.06900766,
2127        0.027239809,    0.022582639,    -0.013296484,  -0.05459212,
2128        0.08981,        -0.045407712,   0.08682226,    -0.06867011,
2129        -0.14390695,    -0.02916037,    0.000996957,   0.091420636,
2130        0.14283475,     -0.07390571,    -0.06402044,   0.062524505,
2131        -0.093129106,   0.04860203,     -0.08364217,   -0.08119002,
2132        0.009352075,    0.22920375,     0.0016303885,  0.11583097,
2133        -0.13732095,    0.012405723,    -0.07551853,   0.06343048,
2134        0.12162708,     -0.031923793,   -0.014335606,  0.01790974,
2135        -0.10650317,    -0.0724401,     0.08554849,    -0.05727212,
2136        0.06556731,     -0.042729504,   -0.043227166,  0.011683251,
2137        -0.013082158,   -0.029302018,   -0.010899579,  -0.062036745,
2138        -0.022509435,   -0.00964907,    -0.01567329,   0.04260106,
2139        -0.07787477,    -0.11576462,    0.017356863,   0.048673786,
2140        -0.017577527,   -0.05527947,    -0.082487635,  -0.040137455,
2141        -0.10820036,    -0.04666372,    0.022746278,   -0.07851417,
2142        0.01068115,     0.032956902,    0.022433773,   0.0026891115,
2143        0.08944216,     -0.0685835,     0.010513544,   0.07228705,
2144        0.02032331,     -0.059686817,   -0.0005566496, -0.086984694,
2145        0.040414046,    -0.1380399,     0.094208956,   -0.05722982,
2146        0.012092817,    -0.04989123,    -0.086576,     -0.003399834,
2147        -0.04696032,    -0.045747425,   0.10091314,    0.048676282,
2148        -0.029037097,   0.031399418,    -0.0040285117, 0.047237843,
2149        0.09504992,     0.041799378,    -0.049185462,  -0.031518843,
2150        -0.10516937,    0.026374253,    0.10058866,    -0.0033195973,
2151        -0.041975245,   0.0073591834,   0.0033782164,  -0.004325073,
2152        -0.10167381,    0.042500053,    -0.01447153,   0.06464186,
2153        -0.017142897,   0.03312627,     0.009205989,   0.024138335,
2154        -0.011337001,   0.035530265,    -0.010912711,  0.0706555,
2155        -0.005894094,   0.051841937,    -0.1401738,    -0.02351249,
2156        0.0365468,      0.07590991,     0.08838724,    0.021681072,
2157        -0.10086113,    0.019608743,    -0.06195883,   0.077335775,
2158        0.023646897,    -0.095322326,   0.02233014,    0.09756986,
2159        -0.048691444,   -0.009579111,   0.07595467,    0.11480546,
2160        -0.09801813,    0.019894179,    0.08502348,    0.004032281,
2161        0.037211012,    0.068537936,    -0.048005626,  -0.091520436,
2162        -0.028379958,   -0.01556313,    0.06554592,    -0.045599163,
2163        -0.01672207,    -0.020169014,   -0.011877351,  -0.20212261,
2164        0.010889619,    0.0047078193,   0.038385306,   0.08540671,
2165        -0.017140968,   -0.0035865551,  0.016678626,   0.005633034,
2166        0.015963363,    0.00871737,     0.060130805,   0.028611384,
2167        0.10109069,     -0.015060172,   -0.07894427,   0.06401885,
2168        0.011584063,    -0.024466386,   0.0047652307,  -0.09041358,
2169        0.030737216,    -0.0046374933,  0.14215417,    -0.11823516,
2170        0.019899689,    0.006106124,    -0.027092824,  0.0786356,
2171        0.05052217,     -0.058925,      -0.011402121,  -0.024987547,
2172        -0.0013661642,  -0.06832946,    -0.015667673,  -0.1083353,
2173        -0.00096863037, -0.06988685,    -0.053350925,  -0.027275559,
2174        -0.033664223,   -0.07978348,    -0.025200296,  -0.017207067,
2175        -0.058403496,   -0.055697463,   0.005798788,   0.12965427,
2176        -0.062582195,   0.0013350133,   -0.10482091,   0.0379771,
2177        0.072521195,    -0.0029455067,  -0.13797039,   -0.03628521,
2178        0.013806405,    -0.017858358,   -0.01008298,   -0.07700066,
2179        -0.017081132,   0.019358726,    0.0027079724,  0.004635139,
2180        0.062634714,    -0.02338735,    -0.039547626,  -0.02050681,
2181        0.03385117,     -0.083611414,   0.002862572,   -0.09421313,
2182        0.058618143,    -0.08598433,    0.00972939,    0.023867095,
2183        -0.053934585,   -0.023203006,   0.07452513,    -0.048767887,
2184        -0.07314807,    -0.056307215,   -0.10433547,   -0.06440842,
2185        0.04328182,     0.04389765,     -0.020006588,  -0.09076438,
2186        -0.11652589,    -0.021705797,   0.03345259,    -0.010329105,
2187        -0.025767034,   0.013057034,    -0.07316461,   -0.10145612,
2188        0.06358255,     0.18531723,     0.07759293,    0.12006465,
2189        0.1305557,      0.058638252,    -0.03393652,   0.09622831,
2190        -0.16253184,    -2.4580743e-06, 0.079869635,   -0.070196845,
2191        -0.005644518,   0.06857898,     -0.12598175,   -0.035084512,
2192        0.03156317,     -0.12794146,    -0.031963028,  0.04692781,
2193        0.030070418,    0.0071660685,   -0.095516115,  -0.004643372,
2194        0.040170413,    -0.062104587,   -0.0037324072, 0.0554317,
2195        0.08184801,     -0.019164372,   0.06791302,    0.034257166,
2196        -0.10307039,    0.021943003,    0.046745934,   0.0790918,
2197        -0.0265588,     -0.007824208,   0.042546265,   -0.00977924,
2198        -0.0002440307,  -0.017384544,   -0.017990116,  0.12252321,
2199        -0.014512694,   -0.08251313,    0.08861942,    0.13589665,
2200        0.026351685,    0.012641483,    0.07466548,    0.044301085,
2201        -0.045414884,   -0.051112458,   0.03444247,    -0.08502782,
2202        -0.04106223,    -0.028126027,   0.028473156,   0.10467447});
2203 
2204   lstm.SetRecurrentToForgetWeights(
2205       {-0.057784554,  -0.026057621,  -0.068447545,   -0.022581743,
2206        0.14811787,    0.10826372,    0.09471067,     0.03987225,
2207        -0.0039523416, 0.00030638507, 0.053185795,    0.10572994,
2208        0.08414449,    -0.022036452,  -0.00066928595, -0.09203576,
2209        0.032950465,   -0.10985798,   -0.023809856,   0.0021431844,
2210        -0.02196096,   -0.00326074,   0.00058621005,  -0.074678116,
2211        -0.06193199,   0.055729095,   0.03736828,     0.020123724,
2212        0.061878487,   -0.04729229,   0.034919553,    -0.07585433,
2213        -0.04421272,   -0.044019096,  0.085488975,    0.04058006,
2214        -0.06890133,   -0.030951202,  -0.024628663,   -0.07672815,
2215        0.034293607,   0.08556707,    -0.05293577,    -0.033561368,
2216        -0.04899627,   0.0241671,     0.015736353,    -0.095442444,
2217        -0.029564252,  0.016493602,   -0.035026584,   0.022337519,
2218        -0.026871363,  0.004780428,   0.0077918363,   -0.03601621,
2219        0.016435321,   -0.03263031,   -0.09543275,    -0.047392778,
2220        0.013454138,   0.028934088,   0.01685226,     -0.086110644,
2221        -0.046250615,  -0.01847454,   0.047608484,    0.07339695,
2222        0.034546845,   -0.04881143,   0.009128804,    -0.08802852,
2223        0.03761666,    0.008096139,   -0.014454086,   0.014361001,
2224        -0.023502491,  -0.0011840804, -0.07607001,    0.001856849,
2225        -0.06509276,   -0.006021153,  -0.08570962,    -0.1451793,
2226        0.060212336,   0.055259194,   0.06974018,     0.049454916,
2227        -0.027794661,  -0.08077226,   -0.016179763,   0.1169753,
2228        0.17213494,    -0.0056326236, -0.053934924,   -0.0124349,
2229        -0.11520337,   0.05409887,    0.088759385,    0.0019655675,
2230        0.0042065294,  0.03881498,    0.019844765,    0.041858196,
2231        -0.05695512,   0.047233116,   0.038937137,    -0.06542224,
2232        0.014429736,   -0.09719407,   0.13908425,     -0.05379757,
2233        0.012321099,   0.082840554,   -0.029899208,   0.044217527,
2234        0.059855383,   0.07711018,    -0.045319796,   0.0948846,
2235        -0.011724666,  -0.0033288454, -0.033542685,   -0.04764985,
2236        -0.13873616,   0.040668588,   0.034832682,    -0.015319203,
2237        -0.018715994,  0.046002675,   0.0599172,      -0.043107376,
2238        0.0294216,     -0.002314414,  -0.022424703,   0.0030315618,
2239        0.0014641669,  0.0029166266,  -0.11878115,    0.013738511,
2240        0.12375372,    -0.0006038222, 0.029104086,    0.087442465,
2241        0.052958444,   0.07558703,    0.04817258,     0.044462286,
2242        -0.015213451,  -0.08783778,   -0.0561384,     -0.003008196,
2243        0.047060397,   -0.002058388,  0.03429439,     -0.018839769,
2244        0.024734668,   0.024614193,   -0.042046934,   0.09597743,
2245        -0.0043254104, 0.04320769,    0.0064070094,   -0.0019131786,
2246        -0.02558259,   -0.022822596,  -0.023273505,   -0.02464396,
2247        -0.10991725,   -0.006240552,  0.0074488563,   0.024044557,
2248        0.04383914,    -0.046476185,  0.028658995,    0.060410924,
2249        0.050786525,   0.009452605,   -0.0073054377,  -0.024810238,
2250        0.0052906186,  0.0066939713,  -0.0020913032,  0.014515517,
2251        0.015898481,   0.021362653,   -0.030262267,   0.016587038,
2252        -0.011442813,  0.041154444,   -0.007631438,   -0.03423484,
2253        -0.010977775,  0.036152758,   0.0066366293,   0.11915515,
2254        0.02318443,    -0.041350313,  0.021485701,    -0.10906167,
2255        -0.028218046,  -0.00954771,   0.020531068,    -0.11995105,
2256        -0.03672871,   0.024019798,   0.014255957,    -0.05221243,
2257        -0.00661567,   -0.04630967,   0.033188973,    0.10107534,
2258        -0.014027541,  0.030796422,   -0.10270911,    -0.035999842,
2259        0.15443139,    0.07684145,    0.036571592,    -0.035900835,
2260        -0.0034699554, 0.06209149,    0.015920248,    -0.031122351,
2261        -0.03858649,   0.01849943,    0.13872518,     0.01503974,
2262        0.069941424,   -0.06948533,   -0.0088794185,  0.061282158,
2263        -0.047401894,  0.03100163,    -0.041533746,   -0.10430945,
2264        0.044574402,   -0.01425562,   -0.024290353,   0.034563623,
2265        0.05866852,    0.023947537,   -0.09445152,    0.035450947,
2266        0.02247216,    -0.0042998926, 0.061146557,    -0.10250651,
2267        0.020881841,   -0.06747029,   0.10062043,     -0.0023941975,
2268        0.03532124,    -0.016341697,  0.09685456,     -0.016764693,
2269        0.051808182,   0.05875331,    -0.04536488,    0.001626336,
2270        -0.028892258,  -0.01048663,   -0.009793449,   -0.017093895,
2271        0.010987891,   0.02357273,    -0.00010856845, 0.0099760275,
2272        -0.001845119,  -0.03551521,   0.0018358806,   0.05763657,
2273        -0.01769146,   0.040995963,   0.02235177,     -0.060430344,
2274        0.11475477,    -0.023854522,  0.10071741,     0.0686208,
2275        -0.014250481,  0.034261297,   0.047418304,    0.08562733,
2276        -0.030519066,  0.0060542435,  0.014653856,    -0.038836084,
2277        0.04096551,    0.032249358,   -0.08355519,    -0.026823482,
2278        0.056386515,   -0.010401743,  -0.028396193,   0.08507674,
2279        0.014410365,   0.020995233,   0.17040324,     0.11511526,
2280        0.02459721,    0.0066619175,  0.025853224,    -0.023133837,
2281        -0.081302024,  0.017264642,   -0.009585969,   0.09491168,
2282        -0.051313367,  0.054532815,   -0.014298593,   0.10657464,
2283        0.007076659,   0.10964551,    0.0409152,      0.008275321,
2284        -0.07283536,   0.07937492,    0.04192024,     -0.1075027});
2285 
2286   lstm.SetRecurrentToCellWeights(
2287       {-0.037322544,   0.018592842,   0.0056175636,  -0.06253426,
2288        0.055647098,    -0.05713207,   -0.05626563,   0.005559383,
2289        0.03375411,     -0.025757805,  -0.088049285,  0.06017052,
2290        -0.06570978,    0.007384076,   0.035123326,   -0.07920549,
2291        0.053676967,    0.044480428,   -0.07663568,   0.0071805613,
2292        0.08089997,     0.05143358,    0.038261272,   0.03339287,
2293        -0.027673481,   0.044746667,   0.028349208,   0.020090483,
2294        -0.019443132,   -0.030755889,  -0.0040000007, 0.04465846,
2295        -0.021585021,   0.0031670958,  0.0053199246,  -0.056117613,
2296        -0.10893326,    0.076739706,   -0.08509834,   -0.027997585,
2297        0.037871376,    0.01449768,    -0.09002357,   -0.06111149,
2298        -0.046195522,   0.0422062,     -0.005683705,  -0.1253618,
2299        -0.012925729,   -0.04890792,   0.06985068,    0.037654128,
2300        0.03398274,     -0.004781977,  0.007032333,   -0.031787455,
2301        0.010868644,    -0.031489216,  0.09525667,    0.013939797,
2302        0.0058680447,   0.0167067,     0.02668468,    -0.04797466,
2303        -0.048885044,   -0.12722108,   0.035304096,   0.06554885,
2304        0.00972396,     -0.039238118,  -0.05159735,   -0.11329045,
2305        0.1613692,      -0.03750952,   0.06529313,    -0.071974665,
2306        -0.11769596,    0.015524369,   -0.0013754242, -0.12446318,
2307        0.02786344,     -0.014179351,  0.005264273,   0.14376344,
2308        0.015983658,    0.03406988,    -0.06939408,   0.040699873,
2309        0.02111075,     0.09669095,    0.041345075,   -0.08316494,
2310        -0.07684199,    -0.045768797,  0.032298047,   -0.041805092,
2311        0.0119405,      0.0061010392,  0.12652606,    0.0064572375,
2312        -0.024950314,   0.11574242,    0.04508852,    -0.04335324,
2313        0.06760663,     -0.027437469,  0.07216407,    0.06977076,
2314        -0.05438599,    0.034033038,   -0.028602652,  0.05346137,
2315        0.043184172,    -0.037189785,  0.10420091,    0.00882477,
2316        -0.054019816,   -0.074273005,  -0.030617684,  -0.0028467078,
2317        0.024302477,    -0.0038869337, 0.005332455,   0.0013399826,
2318        0.04361412,     -0.007001822,  0.09631092,    -0.06702025,
2319        -0.042049985,   -0.035070654,  -0.04103342,   -0.10273396,
2320        0.0544271,      0.037184782,   -0.13150354,   -0.0058036847,
2321        -0.008264958,   0.042035464,   0.05891794,    0.029673764,
2322        0.0063542654,   0.044788733,   0.054816857,   0.062257513,
2323        -0.00093483756, 0.048938446,   -0.004952862,  -0.007730018,
2324        -0.04043371,    -0.017094059,  0.07229206,    -0.023670016,
2325        -0.052195564,   -0.025616996,  -0.01520939,   0.045104615,
2326        -0.007376126,   0.003533447,   0.006570588,   0.056037236,
2327        0.12436656,     0.051817212,   0.028532185,   -0.08686856,
2328        0.11868599,     0.07663395,    -0.07323171,   0.03463402,
2329        -0.050708205,   -0.04458982,   -0.11590894,   0.021273347,
2330        0.1251325,      -0.15313013,   -0.12224372,   0.17228661,
2331        0.023029093,    0.086124025,   0.006445803,   -0.03496501,
2332        0.028332196,    0.04449512,    -0.042436164,  -0.026587414,
2333        -0.006041347,   -0.09292539,   -0.05678812,   0.03897832,
2334        0.09465633,     0.008115513,   -0.02171956,   0.08304309,
2335        0.071401566,    0.019622514,   0.032163795,   -0.004167056,
2336        0.02295182,     0.030739572,   0.056506045,   0.004612461,
2337        0.06524936,     0.059999723,   0.046395954,   -0.0045512207,
2338        -0.1335546,     -0.030136576,  0.11584653,    -0.014678886,
2339        0.0020118146,   -0.09688814,   -0.0790206,    0.039770417,
2340        -0.0329582,     0.07922767,    0.029322514,   0.026405897,
2341        0.04207835,     -0.07073373,   0.063781224,   0.0859677,
2342        -0.10925287,    -0.07011058,   0.048005477,   0.03438226,
2343        -0.09606514,    -0.006669445,  -0.043381985,  0.04240257,
2344        -0.06955775,    -0.06769346,   0.043903265,   -0.026784198,
2345        -0.017840602,   0.024307009,   -0.040079936,  -0.019946516,
2346        0.045318738,    -0.12233574,   0.026170589,   0.0074471775,
2347        0.15978073,     0.10185836,    0.10298046,    -0.015476589,
2348        -0.039390966,   -0.072174534,  0.0739445,     -0.1211869,
2349        -0.0347889,     -0.07943156,   0.014809798,   -0.12412325,
2350        -0.0030663363,  0.039695457,   0.0647603,     -0.08291318,
2351        -0.018529687,   -0.004423833,  0.0037507233,  0.084633216,
2352        -0.01514876,    -0.056505352,  -0.012800942,  -0.06994386,
2353        0.012962922,    -0.031234352,  0.07029052,    0.016418684,
2354        0.03618972,     0.055686004,   -0.08663945,   -0.017404709,
2355        -0.054761406,   0.029065743,   0.052404847,   0.020238016,
2356        0.0048197987,   -0.0214882,    0.07078733,    0.013016777,
2357        0.06262858,     0.009184685,   0.020785125,   -0.043904778,
2358        -0.0270329,     -0.03299152,   -0.060088247,  -0.015162964,
2359        -0.001828936,   0.12642565,    -0.056757294,  0.013586685,
2360        0.09232601,     -0.035886683,  0.06000002,    0.05229691,
2361        -0.052580316,   -0.082029596,  -0.010794592,  0.012947712,
2362        -0.036429964,   -0.085508935,  -0.13127148,   -0.017744139,
2363        0.031502828,    0.036232427,   -0.031581745,  0.023051167,
2364        -0.05325106,    -0.03421577,   0.028793324,   -0.034633752,
2365        -0.009881397,   -0.043551125,  -0.018609839,  0.0019097115,
2366        -0.008799762,   0.056595087,   0.0022273948,  0.055752404});
2367 
2368   lstm.SetRecurrentToOutputWeights({
2369       0.025825322,   -0.05813119,  0.09495884,   -0.045984812,   -0.01255415,
2370       -0.0026479573, -0.08196161,  -0.054914974, -0.0046604523,  -0.029587349,
2371       -0.044576716,  -0.07480124,  -0.082868785, 0.023254942,    0.027502948,
2372       -0.0039728214, -0.08683098,  -0.08116779,  -0.014675607,   -0.037924774,
2373       -0.023314456,  -0.007401714, -0.09255757,  0.029460307,    -0.08829125,
2374       -0.005139627,  -0.08989442,  -0.0555066,   0.13596267,     -0.025062224,
2375       -0.048351806,  -0.03850004,  0.07266485,   -0.022414139,   0.05940088,
2376       0.075114764,   0.09597592,   -0.010211725, -0.0049794707,  -0.011523867,
2377       -0.025980417,  0.072999895,  0.11091378,   -0.081685916,   0.014416728,
2378       0.043229222,   0.034178585,  -0.07530371,  0.035837382,    -0.085607,
2379       -0.007721233,  -0.03287832,  -0.043848954, -0.06404588,    -0.06632928,
2380       -0.073643476,  0.008214239,  -0.045984086, 0.039764922,    0.03474462,
2381       0.060612556,   -0.080590084, 0.049127717,  0.04151091,     -0.030063879,
2382       0.008801774,   -0.023021035, -0.019558564, 0.05158114,     -0.010947698,
2383       -0.011825728,  0.0075720972, 0.0699727,    -0.0039981045,  0.069350146,
2384       0.08799282,    0.016156472,  0.035502106,  0.11695009,     0.006217345,
2385       0.13392477,    -0.037875112, 0.025745004,  0.08940699,     -0.00924166,
2386       0.0046702605,  -0.036598757, -0.08811812,  0.10522024,     -0.032441203,
2387       0.008176899,   -0.04454919,  0.07058152,   0.0067963637,   0.039206743,
2388       0.03259838,    0.03725492,   -0.09515802,  0.013326398,    -0.052055415,
2389       -0.025676316,  0.03198509,   -0.015951829, -0.058556724,   0.036879618,
2390       0.043357447,   0.028362012,  -0.05908629,  0.0059240665,   -0.04995891,
2391       -0.019187413,  0.0276265,    -0.01628143,  0.0025863599,   0.08800015,
2392       0.035250366,   -0.022165963, -0.07328642,  -0.009415526,   -0.07455109,
2393       0.11690406,    0.0363299,    0.07411125,   0.042103454,    -0.009660886,
2394       0.019076364,   0.018299393,  -0.046004917, 0.08891175,     0.0431396,
2395       -0.026327137,  -0.051502608, 0.08979574,   -0.051670972,   0.04940282,
2396       -0.07491107,   -0.021240504, 0.022596184,  -0.034280192,   0.060163025,
2397       -0.058211457,  -0.051837247, -0.01349775,  -0.04639988,    -0.035936575,
2398       -0.011681591,  0.064818054,  0.0073146066, -0.021745546,   -0.043124277,
2399       -0.06471268,   -0.07053354,  -0.029321948, -0.05330136,    0.016933719,
2400       -0.053782392,  0.13747959,   -0.1361751,   -0.11569455,    0.0033329215,
2401       0.05693899,    -0.053219706, 0.063698,     0.07977434,     -0.07924483,
2402       0.06936997,    0.0034815092, -0.007305279, -0.037325785,   -0.07251102,
2403       -0.033633437,  -0.08677009,  0.091591336,  -0.14165086,    0.021752775,
2404       0.019683983,   0.0011612234, -0.058154266, 0.049996935,    0.0288841,
2405       -0.0024567875, -0.14345716,  0.010955264,  -0.10234828,    0.1183656,
2406       -0.0010731248, -0.023590032, -0.072285876, -0.0724771,     -0.026382286,
2407       -0.0014920527, 0.042667855,  0.0018776858, 0.02986552,     0.009814309,
2408       0.0733756,     0.12289186,   0.018043943,  -0.0458958,     0.049412545,
2409       0.033632483,   0.05495232,   0.036686596,  -0.013781798,   -0.010036754,
2410       0.02576849,    -0.08307328,  0.010112348,  0.042521734,    -0.05869831,
2411       -0.071689695,  0.03876447,   -0.13275425,  -0.0352966,     -0.023077697,
2412       0.10285965,    0.084736146,  0.15568255,   -0.00040734606, 0.027835453,
2413       -0.10292561,   -0.032401145, 0.10053256,   -0.026142767,   -0.08271222,
2414       -0.0030240538, -0.016368777, 0.1070414,    0.042672627,    0.013456989,
2415       -0.0437609,    -0.022309763, 0.11576483,   0.04108048,     0.061026827,
2416       -0.0190714,    -0.0869359,   0.037901703,  0.0610107,      0.07202949,
2417       0.01675338,    0.086139716,  -0.08795751,  -0.014898893,   -0.023771819,
2418       -0.01965048,   0.007955471,  -0.043740474, 0.03346837,     -0.10549954,
2419       0.090567775,   0.042013682,  -0.03176985,  0.12569028,     -0.02421228,
2420       -0.029526481,  0.023851605,  0.031539805,  0.05292009,     -0.02344001,
2421       -0.07811758,   -0.08834428,  0.10094801,   0.16594367,     -0.06861939,
2422       -0.021256343,  -0.041093912, -0.06669611,  0.035498552,    0.021757556,
2423       -0.09302526,   -0.015403468, -0.06614931,  -0.051798206,   -0.013874718,
2424       0.03630673,    0.010412845,  -0.08077351,  0.046185967,    0.0035662893,
2425       0.03541868,    -0.094149634, -0.034814864, 0.003128424,    -0.020674974,
2426       -0.03944324,   -0.008110165, -0.11113267,  0.08484226,     0.043586485,
2427       0.040582247,   0.0968012,    -0.065249965, -0.028036479,   0.0050708856,
2428       0.0017462453,  0.0326779,    0.041296225,  0.09164146,     -0.047743853,
2429       -0.015952192,  -0.034451712, 0.084197424,  -0.05347844,    -0.11768019,
2430       0.085926116,   -0.08251791,  -0.045081906, 0.0948852,      0.068401024,
2431       0.024856757,   0.06978981,   -0.057309967, -0.012775832,   -0.0032452994,
2432       0.01977615,    -0.041040014, -0.024264973, 0.063464895,    0.05431621,
2433   });
2434 
2435   lstm.SetCellToInputWeights(
2436       {0.040369894, 0.030746894,  0.24704495,  0.018586371,  -0.037586458,
2437        -0.15312155, -0.11812848,  -0.11465643, 0.20259799,   0.11418174,
2438        -0.10116027, -0.011334949, 0.12411352,  -0.076769054, -0.052169047,
2439        0.21198851,  -0.38871562,  -0.09061183, -0.09683246,  -0.21929175});
2440 
2441   lstm.SetCellToForgetWeights(
2442       {-0.01998659,  -0.15568835,  -0.24248174,   -0.012770197, 0.041331276,
2443        -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
2444        -0.047248036, 0.021479502,  0.033189066,   0.11952997,   -0.020432774,
2445        0.64658105,   -0.06650122,  -0.03467612,   0.095340036,  0.23647355});
2446 
2447   lstm.SetCellToOutputWeights(
2448       {0.08286371,  -0.08261836, -0.51210177, 0.002913762, 0.17764764,
2449        -0.5495371,  -0.08460716, -0.24552552, 0.030037103, 0.04123544,
2450        -0.11940523, 0.007358328, 0.1890978,   0.4833202,   -0.34441817,
2451        0.36312827,  -0.26375428, 0.1457655,   -0.19724406, 0.15548733});
2452 
2453   lstm.SetProjectionWeights(
2454       {-0.009802181,  0.09401916,    0.0717386,     -0.13895074,  0.09641832,
2455        0.060420845,   0.08539281,    0.054285463,   0.061395317,  0.034448683,
2456        -0.042991187,  0.019801661,   -0.16840284,   -0.015726732, -0.23041931,
2457        -0.024478018,  -0.10959692,   -0.013875541,  0.18600968,   -0.061274476,
2458        0.0138165,     -0.08160894,   -0.07661644,   0.032372914,  0.16169067,
2459        0.22465782,    -0.03993472,   -0.004017731,  0.08633481,   -0.28869787,
2460        0.08682067,    0.17240396,    0.014975425,   0.056431185,  0.031037588,
2461        0.16702051,    0.0077946745,  0.15140012,    0.29405436,   0.120285,
2462        -0.188994,     -0.027265169,  0.043389652,   -0.022061434, 0.014777949,
2463        -0.20203483,   0.094781205,   0.19100232,    0.13987629,   -0.036132768,
2464        -0.06426278,   -0.05108664,   0.13221376,    0.009441198,  -0.16715929,
2465        0.15859416,    -0.040437475,  0.050779544,   -0.022187516, 0.012166504,
2466        0.027685808,   -0.07675938,   -0.0055694645, -0.09444123,  0.0046453946,
2467        0.050794356,   0.10770313,    -0.20790008,   -0.07149004,  -0.11425117,
2468        0.008225835,   -0.035802525,  0.14374903,    0.15262283,   0.048710253,
2469        0.1847461,     -0.007487823,  0.11000021,    -0.09542012,  0.22619456,
2470        -0.029149994,  0.08527916,    0.009043713,   0.0042746216, 0.016261552,
2471        0.022461696,   0.12689082,    -0.043589946,  -0.12035478,  -0.08361797,
2472        -0.050666027,  -0.1248618,    -0.1275799,    -0.071875185, 0.07377272,
2473        0.09944291,    -0.18897448,   -0.1593054,    -0.06526116,  -0.040107165,
2474        -0.004618631,  -0.067624845,  -0.007576253,  0.10727444,   0.041546922,
2475        -0.20424393,   0.06907816,    0.050412357,   0.00724631,   0.039827548,
2476        0.12449835,    0.10747581,    0.13708383,    0.09134148,   -0.12617786,
2477        -0.06428341,   0.09956831,    0.1208086,     -0.14676677,  -0.0727722,
2478        0.1126304,     0.010139365,   0.015571211,   -0.038128063, 0.022913318,
2479        -0.042050496,  0.16842307,    -0.060597885,  0.10531834,   -0.06411776,
2480        -0.07451711,   -0.03410368,   -0.13393489,   0.06534304,   0.003620307,
2481        0.04490757,    0.05970546,    0.05197996,    0.02839995,   0.10434969,
2482        -0.013699693,  -0.028353551,  -0.07260381,   0.047201227,  -0.024575593,
2483        -0.036445823,  0.07155557,    0.009672501,   -0.02328883,  0.009533515,
2484        -0.03606021,   -0.07421458,   -0.028082801,  -0.2678904,   -0.13221288,
2485        0.18419984,    -0.13012612,   -0.014588381,  -0.035059117, -0.04824723,
2486        0.07830115,    -0.056184657,  0.03277091,    0.025466874,  0.14494097,
2487        -0.12522776,   -0.098633975,  -0.10766018,   -0.08317623,  0.08594209,
2488        0.07749552,    0.039474737,   0.1776665,     -0.07409566,  -0.0477268,
2489        0.29323658,    0.10801441,    0.1154011,     0.013952499,  0.10739139,
2490        0.10708251,    -0.051456142,  0.0074137426,  -0.10430189,  0.10034707,
2491        0.045594677,   0.0635285,     -0.0715442,    -0.089667566, -0.10811871,
2492        0.00026344223, 0.08298446,    -0.009525053,  0.006585689,  -0.24567553,
2493        -0.09450807,   0.09648481,    0.026996298,   -0.06419476,  -0.04752702,
2494        -0.11063944,   -0.23441927,   -0.17608605,   -0.052156363, 0.067035615,
2495        0.19271925,    -0.0032889997, -0.043264326,  0.09663576,   -0.057112187,
2496        -0.10100678,   0.0628376,     0.04447668,    0.017961001,  -0.10094388,
2497        -0.10190601,   0.18335468,    0.10494553,    -0.052095775, -0.0026118709,
2498        0.10539724,    -0.04383912,   -0.042349473,  0.08438151,   -0.1947263,
2499        0.02251204,    0.11216432,    -0.10307853,   0.17351969,   -0.039091777,
2500        0.08066188,    -0.00561982,   0.12633002,    0.11335965,   -0.0088127935,
2501        -0.019777594,  0.06864014,    -0.059751723,  0.016233567,  -0.06894641,
2502        -0.28651384,   -0.004228674,  0.019708522,   -0.16305895,  -0.07468996,
2503        -0.0855457,    0.099339016,   -0.07580735,   -0.13775392,  0.08434318,
2504        0.08330512,    -0.12131499,   0.031935584,   0.09180414,   -0.08876437,
2505        -0.08049874,   0.008753825,   0.03498998,    0.030215185,  0.03907079,
2506        0.089751154,   0.029194152,   -0.03337423,   -0.019092513, 0.04331237,
2507        0.04299654,    -0.036394123,  -0.12915532,   0.09793732,   0.07512415,
2508        -0.11319543,   -0.032502122,  0.15661901,    0.07671967,   -0.005491124,
2509        -0.19379048,   -0.218606,     0.21448623,    0.017840758,  0.1416943,
2510        -0.07051762,   0.19488361,    0.02664691,    -0.18104725,  -0.09334311,
2511        0.15026465,    -0.15493552,   -0.057762887,  -0.11604192,  -0.262013,
2512        -0.01391798,   0.012185008,   0.11156489,    -0.07483202,  0.06693364,
2513        -0.26151478,   0.046425626,   0.036540434,   -0.16435726,  0.17338543,
2514        -0.21401681,   -0.11385144,   -0.08283257,   -0.069031075, 0.030635102,
2515        0.010969227,   0.11109743,    0.010919218,   0.027526086,  0.13519906,
2516        0.01891392,    -0.046839405,  -0.040167913,  0.017953383,  -0.09700955,
2517        0.0061885654,  -0.07000971,   0.026893595,   -0.038844477, 0.14543656});
2518 
2519   static float lstm_input[][20] = {
2520       {// Batch0: 4 (sequence_length) * 5 (n_input)
2521        0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386,
2522        0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199,
2523        0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339},
2524 
2525       {// Batch1: 4 (input_sequence_size) * 5 (n_input)
2526        0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260,
2527        0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485,
2528        0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}};
2529 
2530   static float lstm_fw_golden_output[][64] = {
2531       {// Batch0: 4 (sequence_length) * 16 (n_output)
2532        -0.00396806, 0.029352,     -0.00279226, 0.0159977,   -0.00835576,
2533        -0.0211779,  0.0283512,    -0.0114597,  0.00907307,  -0.0244004,
2534        -0.0152191,  -0.0259063,   0.00914318,  0.00415118,  0.017147,
2535        0.0134203,   -0.0166936,   0.0381209,   0.000889694, 0.0143363,
2536        -0.0328911,  -0.0234288,   0.0333051,   -0.012229,   0.0110322,
2537        -0.0457725,  -0.000832209, -0.0202817,  0.0327257,   0.0121308,
2538        0.0155969,   0.0312091,    -0.0213783,  0.0350169,   0.000324794,
2539        0.0276012,   -0.0263374,   -0.0371449,  0.0446149,   -0.0205474,
2540        0.0103729,   -0.0576349,   -0.0150052,  -0.0292043,  0.0376827,
2541        0.0136115,   0.0243435,    0.0354492,   -0.0189322,  0.0464512,
2542        -0.00251373, 0.0225745,    -0.0308346,  -0.0317124,  0.0460407,
2543        -0.0189395,  0.0149363,    -0.0530162,  -0.0150767,  -0.0340193,
2544        0.0286833,   0.00824207,   0.0264887,   0.0305169},
2545       {// Batch1: 4 (sequence_length) * 16 (n_output)
2546        -0.013869,    0.0287268,   -0.00334693, 0.00733398,  -0.0287926,
2547        -0.0186926,   0.0193662,   -0.0115437,  0.00422612,  -0.0345232,
2548        0.00223253,   -0.00957321, 0.0210624,   0.013331,    0.0150954,
2549        0.02168,      -0.0141913,  0.0322082,   0.00227024,  0.0260507,
2550        -0.0188721,   -0.0296489,  0.0399134,   -0.0160509,  0.0116039,
2551        -0.0447318,   -0.0150515,  -0.0277406,  0.0316596,   0.0118233,
2552        0.0214762,    0.0293641,   -0.0204549,  0.0450315,   -0.00117378,
2553        0.0167673,    -0.0375007,  -0.0238314,  0.038784,    -0.0174034,
2554        0.0131743,    -0.0506589,  -0.0048447,  -0.0240239,  0.0325789,
2555        0.00790065,   0.0220157,   0.0333314,   -0.0264787,  0.0387855,
2556        -0.000764675, 0.0217599,   -0.037537,   -0.0335206,  0.0431679,
2557        -0.0211424,   0.010203,    -0.062785,   -0.00832363, -0.025181,
2558        0.0412031,    0.0118723,   0.0239643,   0.0394009}};
2559 
2560   static float lstm_combined_golden_output[][64] = {
2561       {-0.022014, 0.073544,  -0.002235, 0.040068,  -0.037136, -0.052788,
2562        0.075325,  -0.029378, 0.024298,  -0.07733,  -0.030674, -0.060229,
2563        0.040599,  0.011608,  0.042005,  0.045977,  -0.039225, 0.076294,
2564        0.000735,  0.032852,  -0.069869, -0.053312, 0.073527,  -0.028136,
2565        0.021585,  -0.102679, -0.004327, -0.043304, 0.072861,  0.027077,
2566        0.034558,  0.068292,  -0.036292, 0.069832,  -0.003032, 0.053829,
2567        -0.043821, -0.072713, 0.085029,  -0.040374, 0.020014,  -0.104521,
2568        -0.034504, -0.059759, 0.062569,  0.025652,  0.049306,  0.061189,
2569        -0.025146, 0.079643,  -0.005188, 0.033080,  -0.048079, -0.048082,
2570        0.069369,  -0.028900, 0.024572,  -0.077547, -0.022517, -0.054477,
2571        0.038857,  0.013336,  0.043234,  0.044788},
2572       {-0.039186, 0.070792,  -0.005913, 0.02642,   -0.068274, -0.05022,
2573        0.061444,  -0.031241, 0.014996,  -0.094544, -0.004146, -0.03464,
2574        0.058981,  0.026097,  0.039781,  0.058408,  -0.031887, 0.069252,
2575        0.00576,   0.054062,  -0.042801, -0.059974, 0.085272,  -0.034453,
2576        0.026097,  -0.0959,   -0.031164, -0.058699, 0.06839,   0.020512,
2577        0.044727,  0.063609,  -0.039863, 0.084819,  -0.003909, 0.028666,
2578        -0.075677, -0.045125, 0.070379,  -0.033895, 0.022111,  -0.097184,
2579        -0.004921, -0.040851, 0.062316,  0.017435,  0.041437,  0.064568,
2580        -0.039656, 0.060726,  -0.003402, 0.036854,  -0.056503, -0.058554,
2581        0.068588,  -0.034879, 0.01352,   -0.09962,  -0.01434,  -0.039505,
2582        0.065133,  0.024321,  0.038473,  0.062438}};
2583 
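  // Each batch provides 4 time steps * 5 inputs = 20 values.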
2584   const int input_sequence_size = lstm.sequence_length() * lstm.num_inputs();
2585   EXPECT_EQ(input_sequence_size, 20);
2586   float* batch0_start = lstm_input[0];
2587   float* batch0_end = batch0_start + input_sequence_size;
2588   lstm.SetInput(0, batch0_start, batch0_end);
2589 
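  // The second batch is written starting at element offset input_sequence_size.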
2590   float* batch1_start = lstm_input[1];
2591   float* batch1_end = batch1_start + input_sequence_size;
2592   lstm.SetInput(input_sequence_size, batch1_start, batch1_end);
2593 
2594   ASSERT_EQ(lstm.Invoke(), kTfLiteOk);
2595 
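  // Each direction produces 4 time steps * 16 outputs = 64 values per batch.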
2596   const int output_sequence_size =
2597       lstm.sequence_length() * lstm.num_fw_outputs();
2598   EXPECT_EQ(output_sequence_size, 64);
2599   std::vector<float> expected;
2600   const float* golden_start_batch0 = lstm_fw_golden_output[0];
2601   const float* golden_end_batch0 = golden_start_batch0 + output_sequence_size;
2602   expected.insert(expected.end(), golden_start_batch0, golden_end_batch0);
2603 
2604   const float* golden_start_batch1 = lstm_fw_golden_output[1];
2605   const float* golden_end_batch1 = golden_start_batch1 + output_sequence_size;
2606   expected.insert(expected.end(), golden_start_batch1, golden_end_batch1);
2607   EXPECT_THAT(lstm.GetFwOutput(), ElementsAreArray(ArrayFloatNear(expected)));
2608 
2609   // Check that the sum of the forward and backward outputs matches the golden.
2610   expected.clear();
2611   golden_start_batch0 = lstm_combined_golden_output[0];
2612   golden_end_batch0 = golden_start_batch0 + output_sequence_size;
2613   expected.insert(expected.end(), golden_start_batch0, golden_end_batch0);
2614 
2615   golden_start_batch1 = lstm_combined_golden_output[1];
2616   golden_end_batch1 = golden_start_batch1 + output_sequence_size;
2617   expected.insert(expected.end(), golden_start_batch1, golden_end_batch1);
2618 
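  // Element-wise sum of the forward and backward outputs.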
2619   std::vector<float> combined;
2620   for (size_t i = 0; i < lstm.GetFwOutput().size(); ++i) {
2621     combined.push_back(lstm.GetFwOutput()[i] + lstm.GetBwOutput()[i]);
2622   }
2623   EXPECT_THAT(combined, ElementsAreArray(ArrayFloatNear(expected)));
2624 }
2625 
2626 // Same as the no-cifg, no-peephole, no-projection, no-clipping test, but with
2627 // an aux input whose weights are all zero. This is the stacking case with no
2628 // cross-links.
2629 TEST_P(LSTMOpTest, BlackBoxTestWithAuxInputZeroAuxWeight) {
2630   const int n_batch = 1;
2631   const int n_input = 2;
2632   // n_cell and n_output have the same size when there is no projection.
2633   const int n_cell = 4;
2634   const int n_output = 4;
2635   const int sequence_length = 3;
2636   auto params = GetParam();
2637   const bool quantize_weights = std::get<0>(params);
2638   const bool asymmetric_quantize_inputs = std::get<1>(params);
2639 
2640   BidirectionalLSTMOpModel lstm(
2641       n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
2642       /*use_peephole=*/false, /*use_projection_weights=*/false,
2643       /*use_projection_bias=*/false, /*merge_outputs=*/false,
2644       /*use_aux_input=*/true, /*cell_clip=*/0.0,
2645       /*proj_clip=*/0.0, quantize_weights, /*time_major=*/true,
2646       {
2647           {sequence_length, n_batch, n_input},  // input tensor
2648 
2649           // Forward cell
2650           {n_cell, n_input},  // input_to_input_weight tensor
2651           {n_cell, n_input},  // input_to_forget_weight tensor
2652           {n_cell, n_input},  // input_to_cell_weight tensor
2653           {n_cell, n_input},  // input_to_output_weight tensor
2654 
2655           {n_cell, n_output},  // recurrent_to_input_weight tensor
2656           {n_cell, n_output},  // recurrent_to_forget_weight tensor
2657           {n_cell, n_output},  // recurrent_to_cell_weight tensor
2658           {n_cell, n_output},  // recurrent_to_output_weight tensor
2659 
2660           {0},  // cell_to_input_weight tensor
2661           {0},  // cell_to_forget_weight tensor
2662           {0},  // cell_to_output_weight tensor
2663 
2664           {n_cell},  // input_gate_bias tensor
2665           {n_cell},  // forget_gate_bias tensor
2666           {n_cell},  // cell_gate_bias tensor
2667           {n_cell},  // output_gate_bias tensor
2668 
2669           {0, 0},  // projection_weight tensor
2670           {0},     // projection_bias tensor
2671 
2672           // Backward cell
2673           {n_cell, n_input},  // input_to_input_weight tensor
2674           {n_cell, n_input},  // input_to_forget_weight tensor
2675           {n_cell, n_input},  // input_to_cell_weight tensor
2676           {n_cell, n_input},  // input_to_output_weight tensor
2677 
2678           {n_cell, n_output},  // recurrent_to_input_weight tensor
2679           {n_cell, n_output},  // recurrent_to_forget_weight tensor
2680           {n_cell, n_output},  // recurrent_to_cell_weight tensor
2681           {n_cell, n_output},  // recurrent_to_output_weight tensor
2682 
2683           {0},  // cell_to_input_weight tensor
2684           {0},  // cell_to_forget_weight tensor
2685           {0},  // cell_to_output_weight tensor
2686 
2687           {n_cell},  // input_gate_bias tensor
2688           {n_cell},  // forget_gate_bias tensor
2689           {n_cell},  // cell_gate_bias tensor
2690           {n_cell},  // output_gate_bias tensor
2691 
2692           {0, 0},  // projection_weight tensor
2693           {0},     // projection_bias tensor
2694 
2695           {n_batch, n_output},  // activation_state tensor
2696           {n_batch, n_cell},    // cell_state tensor
2697 
2698           {n_batch, n_output},  // activation_state tensor
2699           {n_batch, n_cell},    // cell_state tensor
2700 
2701           {sequence_length, n_batch, n_input},  // aux_input tensor
2702           {n_cell, n_input},                    // aux_fw_input_to_input tensor
2703           {n_cell, n_input},                    // aux_fw_input_to_forget tensor
2704           {n_cell, n_input},                    // aux_fw_input_to_cell tensor
2705           {n_cell, n_input},                    // aux_fw_input_to_output tensor
2706           {n_cell, n_input},                    // aux_bw_input_to_input tensor
2707           {n_cell, n_input},                    // aux_bw_input_to_forget tensor
2708           {n_cell, n_input},                    // aux_bw_input_to_cell tensor
2709           {n_cell, n_input},                    // aux_bw_input_to_output tensor
2710       },
2711       asymmetric_quantize_inputs);
2712 
2713   lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
2714                                -0.34550029, 0.04266912, -0.15680569,
2715                                -0.34856534, 0.43890524});
2716 
2717   lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
2718                               -0.20583314, 0.44344562, 0.22077113,
2719                               -0.29909778});
2720 
2721   lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
2722                                 -0.31343272, -0.40032279, 0.44781327,
2723                                 0.01387155, -0.35593212});
2724 
2725   lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
2726                                 0.40525138, 0.44272184, 0.03897077, -0.1556896,
2727                                 0.19487578});
2728 
2729   lstm.SetInputGateBias({0., 0., 0., 0.});
2730 
2731   lstm.SetCellBias({0., 0., 0., 0.});
2732 
2733   lstm.SetForgetGateBias({1., 1., 1., 1.});
2734 
2735   lstm.SetOutputGateBias({0., 0., 0., 0.});
2736 
2737   lstm.SetRecurrentToInputWeights(
2738       {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
2739        -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
2740        -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
2741 
2742   lstm.SetRecurrentToCellWeights(
2743       {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
2744        -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
2745        -0.46367589, 0.26016325, -0.03894562, -0.16368064});
2746 
2747   lstm.SetRecurrentToForgetWeights(
2748       {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
2749        -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
2750        0.28053468, 0.01560611, -0.20127171, -0.01140004});
2751 
2752   lstm.SetRecurrentToOutputWeights(
2753       {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
2754        0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
2755        -0.51818722, -0.15390486, 0.0468148, 0.39922136});
2756 
2757   // The input should contain n_input * sequence_length values.
2758   static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
2759   static float lstm_fw_golden_output[] = {
2760       -0.02973187, 0.1229473,  0.20885126, -0.15358765,
2761       -0.03716109, 0.12507336, 0.41193449, -0.20860538,
2762       -0.15053082, 0.09120187, 0.24278517, -0.12222792};
2763   static float lstm_bw_golden_output[] = {
2764       -0.0806187, 0.139077, 0.400476,   -0.197842, -0.0332076, 0.123838,
2765       0.309777,   -0.17621, -0.0490733, 0.0739237, 0.067706,   -0.0208124};
2766 
2767   float* batch0_start = lstm_input;
2768   float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
2769 
2770   lstm.SetInput(0, batch0_start, batch0_end);
2771   // The aux input equals the main input, but the aux weights below are all
2772   // zero, so the outputs should match the test without an aux input.
2773   lstm.SetAuxInput(0, batch0_start, batch0_end);
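  // All-zero aux weights, so the aux input contributes nothing to any gate.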
2774   std::vector<float> dummy_weights(n_cell * n_input, 0.0f);
2775   lstm.SetAuxInputToInputWeights(dummy_weights);
2776   lstm.SetAuxInputToForgetWeights(dummy_weights);
2777   lstm.SetAuxInputToCellWeights(dummy_weights);
2778   lstm.SetAuxInputToOutputWeights(dummy_weights);
2779 
2780   ASSERT_EQ(lstm.Invoke(), kTfLiteOk);
2781 
2782   float* fw_golden_start = lstm_fw_golden_output;
2783   float* fw_golden_end =
2784       fw_golden_start + lstm.num_fw_outputs() * lstm.sequence_length();
2785   std::vector<float> fw_expected;
2786   fw_expected.insert(fw_expected.end(), fw_golden_start, fw_golden_end);
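  // Weight quantization loses precision, so allow a looser error tolerance.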
2787   EXPECT_THAT(lstm.GetFwOutput(),
2788               ElementsAreArray(
2789                   ArrayFloatNear(fw_expected, quantize_weights ? 1e-2 : 1e-5)));
2790 
2791   float* bw_golden_start = lstm_bw_golden_output;
2792   float* bw_golden_end =
2793       bw_golden_start + lstm.num_bw_outputs() * lstm.sequence_length();
2794   std::vector<float> bw_expected;
2795   bw_expected.insert(bw_expected.end(), bw_golden_start, bw_golden_end);
2796   EXPECT_THAT(lstm.GetBwOutput(),
2797               ElementsAreArray(
2798                   ArrayFloatNear(bw_expected, quantize_weights ? 1e-2 : 1e-5)));
2799 }
2800 
2801 // Same as the no-cifg, no-peephole, no-projection, no-clipping test, but with
2802 // an aux input that has non-zero weights.
2803 TEST_P(LSTMOpTest, BlackBoxTestWithAuxInput) {
2804   const int n_batch = 1;
2805   const int n_input = 2;
2806   // n_cell and n_output have the same size when there is no projection.
2807   const int n_cell = 4;
2808   const int n_output = 4;
2809   const int sequence_length = 3;
2810   auto params = GetParam();
2811   const bool quantize_weights = std::get<0>(params);
2812   const bool asymmetric_quantize_inputs = std::get<1>(params);
2813 
2814   BidirectionalLSTMOpModel lstm(
2815       n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
2816       /*use_peephole=*/false, /*use_projection_weights=*/false,
2817       /*use_projection_bias=*/false, /*merge_outputs=*/false,
2818       /*use_aux_input=*/true, /*cell_clip=*/0.0,
2819       /*proj_clip=*/0.0, quantize_weights, /*time_major=*/true,
2820       {
2821           {sequence_length, n_batch, n_input},  // input tensor
2822 
2823           // Forward cell
2824           {n_cell, n_input},  // input_to_input_weight tensor
2825           {n_cell, n_input},  // input_to_forget_weight tensor
2826           {n_cell, n_input},  // input_to_cell_weight tensor
2827           {n_cell, n_input},  // input_to_output_weight tensor
2828 
2829           {n_cell, n_output},  // recurrent_to_input_weight tensor
2830           {n_cell, n_output},  // recurrent_to_forget_weight tensor
2831           {n_cell, n_output},  // recurrent_to_cell_weight tensor
2832           {n_cell, n_output},  // recurrent_to_output_weight tensor
2833 
2834           {0},  // cell_to_input_weight tensor
2835           {0},  // cell_to_forget_weight tensor
2836           {0},  // cell_to_output_weight tensor
2837 
2838           {n_cell},  // input_gate_bias tensor
2839           {n_cell},  // forget_gate_bias tensor
2840           {n_cell},  // cell_gate_bias tensor
2841           {n_cell},  // output_gate_bias tensor
2842 
2843           {0, 0},  // projection_weight tensor
2844           {0},     // projection_bias tensor
2845 
2846           // Backward cell
2847           {n_cell, n_input},  // input_to_input_weight tensor
2848           {n_cell, n_input},  // input_to_forget_weight tensor
2849           {n_cell, n_input},  // input_to_cell_weight tensor
2850           {n_cell, n_input},  // input_to_output_weight tensor
2851 
2852           {n_cell, n_output},  // recurrent_to_input_weight tensor
2853           {n_cell, n_output},  // recurrent_to_forget_weight tensor
2854           {n_cell, n_output},  // recurrent_to_cell_weight tensor
2855           {n_cell, n_output},  // recurrent_to_output_weight tensor
2856 
2857           {0},  // cell_to_input_weight tensor
2858           {0},  // cell_to_forget_weight tensor
2859           {0},  // cell_to_output_weight tensor
2860 
2861           {n_cell},  // input_gate_bias tensor
2862           {n_cell},  // forget_gate_bias tensor
2863           {n_cell},  // cell_gate_bias tensor
2864           {n_cell},  // output_gate_bias tensor
2865 
2866           {0, 0},  // projection_weight tensor
2867           {0},     // projection_bias tensor
2868 
2869           {n_batch, n_output},  // activation_state tensor
2870           {n_batch, n_cell},    // cell_state tensor
2871 
2872           {n_batch, n_output},  // activation_state tensor
2873           {n_batch, n_cell},    // cell_state tensor
2874 
2875           {sequence_length, n_batch, n_input},  // aux_input tensor
2876           {n_cell, n_input},                    // aux_fw_input_to_input tensor
2877           {n_cell, n_input},                    // aux_fw_input_to_forget tensor
2878           {n_cell, n_input},                    // aux_fw_input_to_cell tensor
2879           {n_cell, n_input},                    // aux_fw_input_to_output tensor
2880           {n_cell, n_input},                    // aux_bw_input_to_input tensor
2881           {n_cell, n_input},                    // aux_bw_input_to_forget tensor
2882           {n_cell, n_input},                    // aux_bw_input_to_cell tensor
2883           {n_cell, n_input},                    // aux_bw_input_to_output tensor
2884       },
2885       asymmetric_quantize_inputs);
2886 
2887   lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
2888                                -0.34550029, 0.04266912, -0.15680569,
2889                                -0.34856534, 0.43890524});
2890 
2891   lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
2892                               -0.20583314, 0.44344562, 0.22077113,
2893                               -0.29909778});
2894 
2895   lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
2896                                 -0.31343272, -0.40032279, 0.44781327,
2897                                 0.01387155, -0.35593212});
2898 
2899   lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
2900                                 0.40525138, 0.44272184, 0.03897077, -0.1556896,
2901                                 0.19487578});
2902 
2903   lstm.SetInputGateBias({0., 0., 0., 0.});
2904 
2905   lstm.SetCellBias({0., 0., 0., 0.});
2906 
2907   lstm.SetForgetGateBias({1., 1., 1., 1.});
2908 
2909   lstm.SetOutputGateBias({0., 0., 0., 0.});
2910 
2911   lstm.SetRecurrentToInputWeights(
2912       {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
2913        -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
2914        -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
2915 
2916   lstm.SetRecurrentToCellWeights(
2917       {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
2918        -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
2919        -0.46367589, 0.26016325, -0.03894562, -0.16368064});
2920 
2921   lstm.SetRecurrentToForgetWeights(
2922       {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
2923        -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
2924        0.28053468, 0.01560611, -0.20127171, -0.01140004});
2925 
2926   lstm.SetRecurrentToOutputWeights(
2927       {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
2928        0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
2929        -0.51818722, -0.15390486, 0.0468148, 0.39922136});
2930 
2931   // The input should contain n_input * sequence_length values.
2932   static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
2933   static float lstm_fw_golden_output[] = {
2934       0.153335, 0.542754, 0.708602, 0.742855, 0.247581, 0.835739,
2935       0.947797, 0.958177, 0.410892, 0.672268, 0.761909, 0.829133};
2936   static float lstm_bw_golden_output[] = {
2937       0.342275, 0.883431, 0.955930, 0.975621, 0.204939, 0.806858,
2938       0.914849, 0.934871, 0.123236, 0.373087, 0.465377, 0.517630};
2939 
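  // Non-zero aux weights: the aux input (set equal to the main input below)
  // now feeds every gate.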
2940   lstm.SetAuxInputToInputWeights({0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
2941   lstm.SetAuxInputToForgetWeights({0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 1.0});
2942   lstm.SetAuxInputToCellWeights({0.5, 0.6, 0.7, 0.8, 0.5, 0.6, 0.7, 0.8});
2943   lstm.SetAuxInputToOutputWeights({0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
2944 
2945   float* batch0_start = lstm_input;
2946   float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
2947 
2948   lstm.SetInput(0, batch0_start, batch0_end);
2949   lstm.SetAuxInput(0, batch0_start, batch0_end);
2950 
2951   ASSERT_EQ(lstm.Invoke(), kTfLiteOk);
2952 
2953   float* fw_golden_start = lstm_fw_golden_output;
2954   float* fw_golden_end =
2955       fw_golden_start + lstm.num_fw_outputs() * lstm.sequence_length();
2956   std::vector<float> fw_expected;
2957   fw_expected.insert(fw_expected.end(), fw_golden_start, fw_golden_end);
2958   EXPECT_THAT(lstm.GetFwOutput(),
2959               ElementsAreArray(
2960                   ArrayFloatNear(fw_expected, quantize_weights ? 1e-2 : 1e-5)));
2961 
2962   float* bw_golden_start = lstm_bw_golden_output;
2963   float* bw_golden_end =
2964       bw_golden_start + lstm.num_bw_outputs() * lstm.sequence_length();
2965   std::vector<float> bw_expected;
2966   bw_expected.insert(bw_expected.end(), bw_golden_start, bw_golden_end);
2967   EXPECT_THAT(lstm.GetBwOutput(),
2968               ElementsAreArray(
2969                   ArrayFloatNear(bw_expected, quantize_weights ? 1e-2 : 1e-5)));
2970 }
2971 
2972 }  // namespace
2973 }  // namespace tflite
2974