1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // Unit test for TFLite Bidirectional LSTM op.
16
17 #include <tuple>
18 #include <vector>
19
20 #include <gtest/gtest.h>
21 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
22 #include "tensorflow/lite/kernels/test_util.h"
23 #include "tensorflow/lite/schema/schema_generated.h"
24
25 namespace tflite {
26 namespace {
27
28 using ::testing::ElementsAreArray;
29
30 class BidirectionalLSTMOpModel : public SingleOpModel {
31 public:
BidirectionalLSTMOpModel(int n_batch,int n_input,int n_cell,int n_output,int sequence_length,bool use_cifg,bool use_peephole,bool use_projection_weights,bool use_projection_bias,bool merge_outputs,bool use_aux_input,float cell_clip,float proj_clip,bool quantize_weights,bool time_major,const std::vector<std::vector<int>> & input_shapes,bool asymmetric_quantize_inputs=false)32 BidirectionalLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
33 int sequence_length, bool use_cifg,
34 bool use_peephole, bool use_projection_weights,
35 bool use_projection_bias, bool merge_outputs,
36 bool use_aux_input, float cell_clip, float proj_clip,
37 bool quantize_weights, bool time_major,
38 const std::vector<std::vector<int>>& input_shapes,
39 bool asymmetric_quantize_inputs = false)
40 : n_batch_(n_batch),
41 n_input_(n_input),
42 n_fw_cell_(n_cell),
43 n_bw_cell_(n_cell),
44 n_fw_output_(n_output),
45 n_bw_output_(n_output),
46 sequence_length_(sequence_length),
47 quantize_weights_(quantize_weights) {
48 input_ = AddInput(TensorType_FLOAT32);
49 const auto weight_type =
50 quantize_weights_ ? TensorType_UINT8 : TensorType_FLOAT32;
51
52 if (use_cifg) {
53 fw_input_to_input_weights_ = AddNullInput();
54 } else {
55 fw_input_to_input_weights_ = AddInput(weight_type);
56 }
57
58 fw_input_to_forget_weights_ = AddInput(weight_type);
59 fw_input_to_cell_weights_ = AddInput(weight_type);
60 fw_input_to_output_weights_ = AddInput(weight_type);
61
62 if (use_cifg) {
63 fw_recurrent_to_input_weights_ = AddNullInput();
64 } else {
65 fw_recurrent_to_input_weights_ = AddInput(weight_type);
66 }
67
68 fw_recurrent_to_forget_weights_ = AddInput(weight_type);
69 fw_recurrent_to_cell_weights_ = AddInput(weight_type);
70 fw_recurrent_to_output_weights_ = AddInput(weight_type);
71
72 if (use_peephole) {
73 if (use_cifg) {
74 fw_cell_to_input_weights_ = AddNullInput();
75 } else {
76 fw_cell_to_input_weights_ = AddInput(weight_type);
77 }
78 fw_cell_to_forget_weights_ = AddInput(weight_type);
79 fw_cell_to_output_weights_ = AddInput(weight_type);
80 } else {
81 fw_cell_to_input_weights_ = AddNullInput();
82 fw_cell_to_forget_weights_ = AddNullInput();
83 fw_cell_to_output_weights_ = AddNullInput();
84 }
85
86 if (use_cifg) {
87 fw_input_gate_bias_ = AddNullInput();
88 } else {
89 fw_input_gate_bias_ = AddInput(TensorType_FLOAT32);
90 }
91 fw_forget_gate_bias_ = AddInput(TensorType_FLOAT32);
92 fw_cell_gate_bias_ = AddInput(TensorType_FLOAT32);
93 fw_output_gate_bias_ = AddInput(TensorType_FLOAT32);
94
95 if (use_projection_weights) {
96 fw_projection_weights_ = AddInput(TensorType_FLOAT32);
97 if (use_projection_bias) {
98 fw_projection_bias_ = AddInput(TensorType_FLOAT32);
99 } else {
100 fw_projection_bias_ = AddNullInput();
101 }
102 } else {
103 fw_projection_weights_ = AddNullInput();
104 fw_projection_bias_ = AddNullInput();
105 }
106
107 if (use_cifg) {
108 bw_input_to_input_weights_ = AddNullInput();
109 } else {
110 bw_input_to_input_weights_ = AddInput(weight_type);
111 }
112
113 bw_input_to_forget_weights_ = AddInput(weight_type);
114 bw_input_to_cell_weights_ = AddInput(weight_type);
115 bw_input_to_output_weights_ = AddInput(weight_type);
116
117 if (use_cifg) {
118 bw_recurrent_to_input_weights_ = AddNullInput();
119 } else {
120 bw_recurrent_to_input_weights_ = AddInput(weight_type);
121 }
122
123 bw_recurrent_to_forget_weights_ = AddInput(weight_type);
124 bw_recurrent_to_cell_weights_ = AddInput(weight_type);
125 bw_recurrent_to_output_weights_ = AddInput(weight_type);
126
127 if (use_peephole) {
128 if (use_cifg) {
129 bw_cell_to_input_weights_ = AddNullInput();
130 } else {
131 bw_cell_to_input_weights_ = AddInput(weight_type);
132 }
133 bw_cell_to_forget_weights_ = AddInput(weight_type);
134 bw_cell_to_output_weights_ = AddInput(weight_type);
135 } else {
136 bw_cell_to_input_weights_ = AddNullInput();
137 bw_cell_to_forget_weights_ = AddNullInput();
138 bw_cell_to_output_weights_ = AddNullInput();
139 }
140
141 if (use_cifg) {
142 bw_input_gate_bias_ = AddNullInput();
143 } else {
144 bw_input_gate_bias_ = AddInput(TensorType_FLOAT32);
145 }
146 bw_forget_gate_bias_ = AddInput(TensorType_FLOAT32);
147 bw_cell_gate_bias_ = AddInput(TensorType_FLOAT32);
148 bw_output_gate_bias_ = AddInput(TensorType_FLOAT32);
149
150 if (use_projection_weights) {
151 bw_projection_weights_ = AddInput(weight_type);
152 if (use_projection_bias) {
153 bw_projection_bias_ = AddInput(TensorType_FLOAT32);
154 } else {
155 bw_projection_bias_ = AddNullInput();
156 }
157 } else {
158 bw_projection_weights_ = AddNullInput();
159 bw_projection_bias_ = AddNullInput();
160 }
161
162 // Adding the 2 input state tensors.
163 fw_input_activation_state_ = AddVariableInput(
164 TensorData{TensorType_FLOAT32, {n_fw_output_ * n_batch_}});
165 fw_input_cell_state_ = AddVariableInput(
166 TensorData{TensorType_FLOAT32, {n_fw_cell_ * n_batch_}});
167
168 // Adding the 2 input state tensors.
169 bw_input_activation_state_ = AddVariableInput(
170 TensorData{TensorType_FLOAT32, {n_bw_output_ * n_batch_}});
171 bw_input_cell_state_ = AddVariableInput(
172 TensorData{TensorType_FLOAT32, {n_bw_cell_ * n_batch_}});
173
174 fw_output_ = AddOutput(TensorType_FLOAT32);
175
176 if (!merge_outputs) {
177 bw_output_ = AddOutput(TensorType_FLOAT32);
178 }
179
180 if (use_aux_input) {
181 aux_input_ = AddInput(TensorType_FLOAT32);
182 fw_aux_input_to_input_weights_ = AddInput(weight_type);
183 fw_aux_input_to_forget_weights_ = AddInput(weight_type);
184 fw_aux_input_to_cell_weights_ = AddInput(weight_type);
185 fw_aux_input_to_output_weights_ = AddInput(weight_type);
186 bw_aux_input_to_input_weights_ = AddInput(weight_type);
187 bw_aux_input_to_forget_weights_ = AddInput(weight_type);
188 bw_aux_input_to_cell_weights_ = AddInput(weight_type);
189 bw_aux_input_to_output_weights_ = AddInput(weight_type);
190 } else {
191 aux_input_ = AddNullInput();
192 fw_aux_input_to_input_weights_ = AddNullInput();
193 fw_aux_input_to_forget_weights_ = AddNullInput();
194 fw_aux_input_to_cell_weights_ = AddNullInput();
195 fw_aux_input_to_output_weights_ = AddNullInput();
196 bw_aux_input_to_input_weights_ = AddNullInput();
197 bw_aux_input_to_forget_weights_ = AddNullInput();
198 bw_aux_input_to_cell_weights_ = AddNullInput();
199 bw_aux_input_to_output_weights_ = AddNullInput();
200 }
201
202 SetBuiltinOp(
203 BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
204 BuiltinOptions_BidirectionalSequenceLSTMOptions,
205 CreateBidirectionalSequenceLSTMOptions(
206 builder_, ActivationFunctionType_TANH, cell_clip, proj_clip,
207 merge_outputs, time_major, asymmetric_quantize_inputs)
208 .Union());
209 BuildInterpreter(input_shapes);
210 }
211
PopulateWeightTensor(int tensor_id,const std::vector<float> & f)212 void PopulateWeightTensor(int tensor_id, const std::vector<float>& f) {
213 if (quantize_weights_) {
214 SymmetricQuantizeAndPopulate(tensor_id, f);
215 } else {
216 PopulateTensor(tensor_id, f);
217 }
218 }
219
220 // Set weights in forward and backward cells to be the same.
SetInputToInputWeights(const std::vector<float> & f)221 void SetInputToInputWeights(const std::vector<float>& f) {
222 PopulateWeightTensor(fw_input_to_input_weights_, f);
223 PopulateWeightTensor(bw_input_to_input_weights_, f);
224 }
225
SetInputToForgetWeights(const std::vector<float> & f)226 void SetInputToForgetWeights(const std::vector<float>& f) {
227 PopulateWeightTensor(fw_input_to_forget_weights_, f);
228 PopulateWeightTensor(bw_input_to_forget_weights_, f);
229 }
230
SetInputToCellWeights(const std::vector<float> & f)231 void SetInputToCellWeights(const std::vector<float>& f) {
232 PopulateWeightTensor(fw_input_to_cell_weights_, f);
233 PopulateWeightTensor(bw_input_to_cell_weights_, f);
234 }
235
SetInputToOutputWeights(const std::vector<float> & f)236 void SetInputToOutputWeights(const std::vector<float>& f) {
237 PopulateWeightTensor(fw_input_to_output_weights_, f);
238 PopulateWeightTensor(bw_input_to_output_weights_, f);
239 }
240
SetRecurrentToInputWeights(const std::vector<float> & f)241 void SetRecurrentToInputWeights(const std::vector<float>& f) {
242 PopulateWeightTensor(fw_recurrent_to_input_weights_, f);
243 PopulateWeightTensor(bw_recurrent_to_input_weights_, f);
244 }
245
SetRecurrentToForgetWeights(const std::vector<float> & f)246 void SetRecurrentToForgetWeights(const std::vector<float>& f) {
247 PopulateWeightTensor(fw_recurrent_to_forget_weights_, f);
248 PopulateWeightTensor(bw_recurrent_to_forget_weights_, f);
249 }
250
SetRecurrentToCellWeights(const std::vector<float> & f)251 void SetRecurrentToCellWeights(const std::vector<float>& f) {
252 PopulateWeightTensor(fw_recurrent_to_cell_weights_, f);
253 PopulateWeightTensor(bw_recurrent_to_cell_weights_, f);
254 }
255
SetRecurrentToOutputWeights(const std::vector<float> & f)256 void SetRecurrentToOutputWeights(const std::vector<float>& f) {
257 PopulateWeightTensor(fw_recurrent_to_output_weights_, f);
258 PopulateWeightTensor(bw_recurrent_to_output_weights_, f);
259 }
260
SetCellToInputWeights(const std::vector<float> & f)261 void SetCellToInputWeights(const std::vector<float>& f) {
262 PopulateWeightTensor(fw_cell_to_input_weights_, f);
263 PopulateWeightTensor(bw_cell_to_input_weights_, f);
264 }
265
SetCellToForgetWeights(const std::vector<float> & f)266 void SetCellToForgetWeights(const std::vector<float>& f) {
267 PopulateWeightTensor(fw_cell_to_forget_weights_, f);
268 PopulateWeightTensor(bw_cell_to_forget_weights_, f);
269 }
270
SetCellToOutputWeights(const std::vector<float> & f)271 void SetCellToOutputWeights(const std::vector<float>& f) {
272 PopulateWeightTensor(fw_cell_to_output_weights_, f);
273 PopulateWeightTensor(bw_cell_to_output_weights_, f);
274 }
275
SetInputGateBias(const std::vector<float> & f)276 void SetInputGateBias(const std::vector<float>& f) {
277 PopulateTensor(fw_input_gate_bias_, f);
278 PopulateTensor(bw_input_gate_bias_, f);
279 }
280
SetForgetGateBias(const std::vector<float> & f)281 void SetForgetGateBias(const std::vector<float>& f) {
282 PopulateTensor(fw_forget_gate_bias_, f);
283 PopulateTensor(bw_forget_gate_bias_, f);
284 }
285
SetCellBias(const std::vector<float> & f)286 void SetCellBias(const std::vector<float>& f) {
287 PopulateTensor(fw_cell_gate_bias_, f);
288 PopulateTensor(bw_cell_gate_bias_, f);
289 }
290
SetOutputGateBias(const std::vector<float> & f)291 void SetOutputGateBias(const std::vector<float>& f) {
292 PopulateTensor(fw_output_gate_bias_, f);
293 PopulateTensor(bw_output_gate_bias_, f);
294 }
295
SetProjectionWeights(const std::vector<float> & f)296 void SetProjectionWeights(const std::vector<float>& f) {
297 PopulateWeightTensor(fw_projection_weights_, f);
298 PopulateWeightTensor(bw_projection_weights_, f);
299 }
300
SetProjectionBias(const std::vector<float> & f)301 void SetProjectionBias(const std::vector<float>& f) {
302 PopulateTensor(fw_projection_bias_, f);
303 PopulateTensor(bw_projection_bias_, f);
304 }
305
SetInput(int offset,float * begin,float * end)306 void SetInput(int offset, float* begin, float* end) {
307 PopulateTensor(input_, offset, begin, end);
308 }
309
SetAuxInput(int offset,float * begin,float * end)310 void SetAuxInput(int offset, float* begin, float* end) {
311 PopulateTensor(aux_input_, offset, begin, end);
312 }
313
SetAuxInputToInputWeights(const std::vector<float> & f)314 void SetAuxInputToInputWeights(const std::vector<float>& f) {
315 PopulateWeightTensor(fw_aux_input_to_input_weights_, f);
316 PopulateWeightTensor(bw_aux_input_to_input_weights_, f);
317 }
318
SetAuxInputToForgetWeights(const std::vector<float> & f)319 void SetAuxInputToForgetWeights(const std::vector<float>& f) {
320 PopulateWeightTensor(fw_aux_input_to_forget_weights_, f);
321 PopulateWeightTensor(bw_aux_input_to_forget_weights_, f);
322 }
323
SetAuxInputToCellWeights(const std::vector<float> & f)324 void SetAuxInputToCellWeights(const std::vector<float>& f) {
325 PopulateWeightTensor(fw_aux_input_to_cell_weights_, f);
326 PopulateWeightTensor(bw_aux_input_to_cell_weights_, f);
327 }
328
SetAuxInputToOutputWeights(const std::vector<float> & f)329 void SetAuxInputToOutputWeights(const std::vector<float>& f) {
330 PopulateWeightTensor(fw_aux_input_to_output_weights_, f);
331 PopulateWeightTensor(bw_aux_input_to_output_weights_, f);
332 }
333
GetFwOutput()334 std::vector<float> GetFwOutput() { return ExtractVector<float>(fw_output_); }
GetBwOutput()335 std::vector<float> GetBwOutput() { return ExtractVector<float>(bw_output_); }
336
num_inputs()337 int num_inputs() { return n_input_; }
num_fw_outputs()338 int num_fw_outputs() { return n_fw_output_; }
num_bw_outputs()339 int num_bw_outputs() { return n_bw_output_; }
num_fw_cells()340 int num_fw_cells() { return n_fw_cell_; }
num_bw_cells()341 int num_bw_cells() { return n_bw_cell_; }
num_batches()342 int num_batches() { return n_batch_; }
sequence_length()343 int sequence_length() { return sequence_length_; }
344
345 private:
346 int input_;
347 int fw_input_to_input_weights_;
348 int fw_input_to_forget_weights_;
349 int fw_input_to_cell_weights_;
350 int fw_input_to_output_weights_;
351
352 int fw_recurrent_to_input_weights_;
353 int fw_recurrent_to_forget_weights_;
354 int fw_recurrent_to_cell_weights_;
355 int fw_recurrent_to_output_weights_;
356
357 int fw_cell_to_input_weights_;
358 int fw_cell_to_forget_weights_;
359 int fw_cell_to_output_weights_;
360
361 int fw_input_gate_bias_;
362 int fw_forget_gate_bias_;
363 int fw_cell_gate_bias_;
364 int fw_output_gate_bias_;
365
366 int fw_projection_weights_;
367 int fw_projection_bias_;
368
369 int bw_input_to_input_weights_;
370 int bw_input_to_forget_weights_;
371 int bw_input_to_cell_weights_;
372 int bw_input_to_output_weights_;
373
374 int bw_recurrent_to_input_weights_;
375 int bw_recurrent_to_forget_weights_;
376 int bw_recurrent_to_cell_weights_;
377 int bw_recurrent_to_output_weights_;
378
379 int bw_cell_to_input_weights_;
380 int bw_cell_to_forget_weights_;
381 int bw_cell_to_output_weights_;
382
383 int bw_input_gate_bias_;
384 int bw_forget_gate_bias_;
385 int bw_cell_gate_bias_;
386 int bw_output_gate_bias_;
387
388 int bw_projection_weights_;
389 int bw_projection_bias_;
390
391 int fw_input_activation_state_;
392 int fw_input_cell_state_;
393 int bw_input_activation_state_;
394 int bw_input_cell_state_;
395
396 int fw_output_;
397 int bw_output_;
398
399 int aux_input_;
400 int fw_aux_input_to_input_weights_;
401 int fw_aux_input_to_forget_weights_;
402 int fw_aux_input_to_cell_weights_;
403 int fw_aux_input_to_output_weights_;
404 int bw_aux_input_to_input_weights_;
405 int bw_aux_input_to_forget_weights_;
406 int bw_aux_input_to_cell_weights_;
407 int bw_aux_input_to_output_weights_;
408
409 int n_batch_;
410 int n_input_;
411 int n_fw_cell_;
412 int n_bw_cell_;
413 int n_fw_output_;
414 int n_bw_output_;
415 int sequence_length_;
416
417 bool quantize_weights_;
418 };
419
420 // Declare LSTMOpTest as a parameterized test.
421 class LSTMOpTest
422 : public ::testing::TestWithParam<::testing::tuple<bool, bool>> {};
423
424 INSTANTIATE_TEST_SUITE_P(QuantizationOrNot, LSTMOpTest,
425 ::testing::Combine(
426 /*quantize_weights*/ ::testing::Bool(),
427 /*asymmetric_quantize_inputs*/ ::testing::Bool()));
428
// Single-batch, time-major run with no CIFG, peephole, projection, clipping or
// auxiliary input. The forward and backward outputs are separate tensors
// (merge_outputs=false), and each is compared against its own golden values.
TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
  const bool quantize_weights = std::get<0>(GetParam());
  const bool asymmetric_quantize_inputs = std::get<1>(GetParam());

  const int n_batch = 1;
  const int n_input = 2;
  // Without a projection layer the cell and output widths coincide.
  const int n_cell = 4;
  const int n_output = 4;
  const int sequence_length = 3;

  // Shapes for one direction's weights and biases, in registration order.
  // Unused optional tensors (peephole, projection) get empty shapes.
  const std::vector<std::vector<int>> per_cell_shapes = {
      {n_cell, n_input},   // input_to_input_weight
      {n_cell, n_input},   // input_to_forget_weight
      {n_cell, n_input},   // input_to_cell_weight
      {n_cell, n_input},   // input_to_output_weight
      {n_cell, n_output},  // recurrent_to_input_weight
      {n_cell, n_output},  // recurrent_to_forget_weight
      {n_cell, n_output},  // recurrent_to_cell_weight
      {n_cell, n_output},  // recurrent_to_output_weight
      {0},                 // cell_to_input_weight
      {0},                 // cell_to_forget_weight
      {0},                 // cell_to_output_weight
      {n_cell},            // input_gate_bias
      {n_cell},            // forget_gate_bias
      {n_cell},            // cell_gate_bias
      {n_cell},            // output_gate_bias
      {0, 0},              // projection_weight
      {0},                 // projection_bias
  };

  std::vector<std::vector<int>> input_shapes;
  input_shapes.push_back({sequence_length, n_batch, n_input});  // input
  // Forward cell tensors, then backward cell tensors.
  for (int direction = 0; direction < 2; ++direction) {
    input_shapes.insert(input_shapes.end(), per_cell_shapes.begin(),
                        per_cell_shapes.end());
  }
  // Forward then backward (activation_state, cell_state) pairs.
  for (int direction = 0; direction < 2; ++direction) {
    input_shapes.push_back({n_batch, n_output});
    input_shapes.push_back({n_batch, n_cell});
  }
  // Unused auxiliary input plus its eight forward/backward gate weights.
  input_shapes.push_back({sequence_length, n_batch, 0});
  input_shapes.insert(input_shapes.end(), 8, std::vector<int>{0});

  BidirectionalLSTMOpModel lstm(
      n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
      /*use_peephole=*/false, /*use_projection_weights=*/false,
      /*use_projection_bias=*/false, /*merge_outputs=*/false,
      /*use_aux_input=*/false, /*cell_clip=*/0.0,
      /*proj_clip=*/0.0, quantize_weights, /*time_major=*/true, input_shapes,
      asymmetric_quantize_inputs);

  // Input-to-gate weights, each of shape {n_cell, n_input}.
  lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
                               -0.34550029, 0.04266912, -0.15680569,
                               -0.34856534, 0.43890524});
  lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
                                -0.31343272, -0.40032279, 0.44781327,
                                0.01387155, -0.35593212});
  lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
                              -0.20583314, 0.44344562, 0.22077113,
                              -0.29909778});
  lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
                                0.40525138, 0.44272184, 0.03897077, -0.1556896,
                                0.19487578});

  // Recurrent-to-gate weights, each of shape {n_cell, n_output}.
  lstm.SetRecurrentToInputWeights(
      {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
       -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
       -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
  lstm.SetRecurrentToForgetWeights(
      {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
       -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
       0.28053468, 0.01560611, -0.20127171, -0.01140004});
  lstm.SetRecurrentToCellWeights(
      {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
       -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
       -0.46367589, 0.26016325, -0.03894562, -0.16368064});
  lstm.SetRecurrentToOutputWeights(
      {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
       0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
       -0.51818722, -0.15390486, 0.0468148, 0.39922136});

  // Gate biases; only the forget gate is biased.
  lstm.SetInputGateBias({0., 0., 0., 0.});
  lstm.SetForgetGateBias({1., 1., 1., 1.});
  lstm.SetCellBias({0., 0., 0., 0.});
  lstm.SetOutputGateBias({0., 0., 0., 0.});

  // sequence_length time steps of n_input values each (n_batch is 1).
  std::vector<float> lstm_input = {2., 3., 3., 4., 1., 1.};
  const std::vector<float> lstm_fw_golden_output = {
      -0.02973187, 0.1229473,  0.20885126, -0.15358765,
      -0.03716109, 0.12507336, 0.41193449, -0.20860538,
      -0.15053082, 0.09120187, 0.24278517, -0.12222792};
  const std::vector<float> lstm_bw_golden_output = {
      -0.0806187, 0.139077,  0.400476, -0.197842,
      -0.0332076, 0.123838,  0.309777, -0.17621,
      -0.0490733, 0.0739237, 0.067706, -0.0208124};

  float* const input_begin = lstm_input.data();
  lstm.SetInput(0, input_begin,
                input_begin + lstm.num_inputs() * lstm.sequence_length());

  ASSERT_EQ(lstm.Invoke(), kTfLiteOk);

  // Quantized weights introduce error, so widen the tolerance for them.
  const double tolerance = quantize_weights ? 1e-2 : 1e-5;

  const int fw_count = lstm.num_fw_outputs() * lstm.sequence_length();
  const std::vector<float> fw_expected(lstm_fw_golden_output.begin(),
                                       lstm_fw_golden_output.begin() + fw_count);
  EXPECT_THAT(lstm.GetFwOutput(),
              ElementsAreArray(ArrayFloatNear(fw_expected, tolerance)));

  const int bw_count = lstm.num_bw_outputs() * lstm.sequence_length();
  const std::vector<float> bw_expected(lstm_bw_golden_output.begin(),
                                       lstm_bw_golden_output.begin() + bw_count);
  EXPECT_THAT(lstm.GetBwOutput(),
              ElementsAreArray(ArrayFloatNear(bw_expected, tolerance)));
}
592
593 // Same as the previous test, yet with a single merged output tensor and n_batch
594 // of 2.
TEST_P(LSTMOpTest,BlackBoxTestMergedOutput)595 TEST_P(LSTMOpTest, BlackBoxTestMergedOutput) {
596 const int n_batch = 2;
597 const int n_input = 2;
598 // n_cell and n_output have the same size when there is no projection.
599 const int n_cell = 4;
600 const int n_output = 4;
601 const int sequence_length = 3;
602 auto params = GetParam();
603 const bool quantize_weights = std::get<0>(params);
604 const bool asymmetric_quantize_inputs = std::get<1>(params);
605
606 BidirectionalLSTMOpModel lstm(
607 n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
608 /*use_peephole=*/false, /*use_projection_weights=*/false,
609 /*use_projection_bias=*/false, /*merge_outputs=*/true,
610 /*use_aux_input=*/false, /*cell_clip=*/0.0,
611 /*proj_clip=*/0.0, quantize_weights, /*time_major=*/true,
612 {
613 {sequence_length, n_batch, n_input}, // input tensor
614
615 // Forward cell
616 {n_cell, n_input}, // input_to_input_weight tensor
617 {n_cell, n_input}, // input_to_forget_weight tensor
618 {n_cell, n_input}, // input_to_cell_weight tensor
619 {n_cell, n_input}, // input_to_output_weight tensor
620
621 {n_cell, n_output}, // recurrent_to_input_weight tensor
622 {n_cell, n_output}, // recurrent_to_forget_weight tensor
623 {n_cell, n_output}, // recurrent_to_cell_weight tensor
624 {n_cell, n_output}, // recurrent_to_output_weight tensor
625
626 {0}, // cell_to_input_weight tensor
627 {0}, // cell_to_forget_weight tensor
628 {0}, // cell_to_output_weight tensor
629
630 {n_cell}, // input_gate_bias tensor
631 {n_cell}, // forget_gate_bias tensor
632 {n_cell}, // cell_gate_bias tensor
633 {n_cell}, // output_gate_bias tensor
634
635 {0, 0}, // projection_weight tensor
636 {0}, // projection_bias tensor
637
638 // Backward cell
639 {n_cell, n_input}, // input_to_input_weight tensor
640 {n_cell, n_input}, // input_to_forget_weight tensor
641 {n_cell, n_input}, // input_to_cell_weight tensor
642 {n_cell, n_input}, // input_to_output_weight tensor
643
644 {n_cell, n_output}, // recurrent_to_input_weight tensor
645 {n_cell, n_output}, // recurrent_to_forget_weight tensor
646 {n_cell, n_output}, // recurrent_to_cell_weight tensor
647 {n_cell, n_output}, // recurrent_to_output_weight tensor
648
649 {0}, // cell_to_input_weight tensor
650 {0}, // cell_to_forget_weight tensor
651 {0}, // cell_to_output_weight tensor
652
653 {n_cell}, // input_gate_bias tensor
654 {n_cell}, // forget_gate_bias tensor
655 {n_cell}, // cell_gate_bias tensor
656 {n_cell}, // output_gate_bias tensor
657
658 {0, 0}, // projection_weight tensor
659 {0}, // projection_bias tensor
660
661 {n_batch, n_output}, // activation_state tensor
662 {n_batch, n_cell}, // cell_state tensor
663
664 {n_batch, n_output}, // activation_state tensor
665 {n_batch, n_cell}, // cell_state tensor
666
667 {sequence_length, n_batch, 0}, // aux_input tensor
668 {0}, // aux_fw_input_to_input tensor
669 {0}, // aux_fw_input_to_forget tensor
670 {0}, // aux_fw_input_to_cell tensor
671 {0}, // aux_fw_input_to_output tensor
672 {0}, // aux_bw_input_to_input tensor
673 {0}, // aux_bw_input_to_forget tensor
674 {0}, // aux_bw_input_to_cell tensor
675 {0}, // aux_bw_input_to_output tensor
676 },
677 asymmetric_quantize_inputs);
678
679 lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
680 -0.34550029, 0.04266912, -0.15680569,
681 -0.34856534, 0.43890524});
682
683 lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
684 -0.20583314, 0.44344562, 0.22077113,
685 -0.29909778});
686
687 lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
688 -0.31343272, -0.40032279, 0.44781327,
689 0.01387155, -0.35593212});
690
691 lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
692 0.40525138, 0.44272184, 0.03897077, -0.1556896,
693 0.19487578});
694
695 lstm.SetInputGateBias({0., 0., 0., 0.});
696
697 lstm.SetCellBias({0., 0., 0., 0.});
698
699 lstm.SetForgetGateBias({1., 1., 1., 1.});
700
701 lstm.SetOutputGateBias({0., 0., 0., 0.});
702
703 lstm.SetRecurrentToInputWeights(
704 {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
705 -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
706 -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
707
708 lstm.SetRecurrentToCellWeights(
709 {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
710 -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
711 -0.46367589, 0.26016325, -0.03894562, -0.16368064});
712
713 lstm.SetRecurrentToForgetWeights(
714 {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
715 -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
716 0.28053468, 0.01560611, -0.20127171, -0.01140004});
717
718 lstm.SetRecurrentToOutputWeights(
719 {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
720 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
721 -0.51818722, -0.15390486, 0.0468148, 0.39922136});
722
723 // Input should have n_input * sequence_length many values.
724 static float lstm_input[] = {2., 3., 2., 3., 3., 4., 3., 4., 1., 1., 1., 1.};
725 static float lstm_fw_golden_output[] = {
726 -0.02973187, 0.1229473, 0.20885126, -0.15358765, -0.02973187,
727 0.1229473, 0.20885126, -0.15358765, -0.03716109, 0.12507336,
728 0.41193449, -0.20860538, -0.03716109, 0.12507336, 0.41193449,
729 -0.20860538, -0.15053082, 0.09120187, 0.24278517, -0.12222792,
730 -0.15053082, 0.09120187, 0.24278517, -0.12222792};
731 static float lstm_bw_golden_output[] = {
732 -0.0806187, 0.139077, 0.400476, -0.197842, -0.0806187, 0.139077,
733 0.400476, -0.197842, -0.0332076, 0.123838, 0.309777, -0.17621,
734 -0.0332076, 0.123838, 0.309777, -0.17621, -0.0490733, 0.0739237,
735 0.067706, -0.0208124, -0.0490733, 0.0739237, 0.067706, -0.0208124};
736
737 float* batch0_start = lstm_input;
738 float* batch0_end = batch0_start + lstm.num_inputs() * lstm.num_batches() *
739 lstm.sequence_length();
740
741 lstm.SetInput(0, batch0_start, batch0_end);
742
743 ASSERT_EQ(lstm.Invoke(), kTfLiteOk);
744
745 std::vector<float> merged_expected;
746 for (int k = 0; k < lstm.sequence_length() * lstm.num_batches(); k++) {
747 merged_expected.insert(
748 merged_expected.end(),
749 lstm_fw_golden_output + k * lstm.num_fw_outputs(),
750 lstm_fw_golden_output + (k + 1) * lstm.num_fw_outputs());
751 merged_expected.insert(
752 merged_expected.end(),
753 lstm_bw_golden_output + k * lstm.num_bw_outputs(),
754 lstm_bw_golden_output + (k + 1) * lstm.num_bw_outputs());
755 }
756 EXPECT_THAT(lstm.GetFwOutput(),
757 ElementsAreArray(ArrayFloatNear(merged_expected,
758 quantize_weights ? 1e-2 : 1e-5)));
759 }
760
TEST(LSTMOpTest,BlackBoxTestNoCifgNoPeepholeNoProjectionNoClippingReverse)761 TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClippingReverse) {
762 const int n_batch = 1;
763 const int n_input = 2;
764 // n_cell and n_output have the same size when there is no projection.
765 const int n_cell = 4;
766 const int n_output = 4;
767 const int sequence_length = 3;
768
769 BidirectionalLSTMOpModel lstm(
770 n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
771 /*use_peephole=*/false, /*use_projection_weights=*/false,
772 /*use_projection_bias=*/false, /*merge_outputs=*/false,
773 /*use_aux_input=*/false, /*cell_clip=*/0.0,
774 /*proj_clip=*/0.0, /*quantize_weights=*/false, /*time_major=*/true,
775 {
776 {sequence_length, n_batch, n_input}, // input tensor
777
778 // Forward cell
779 {n_cell, n_input}, // input_to_input_weight tensor
780 {n_cell, n_input}, // input_to_forget_weight tensor
781 {n_cell, n_input}, // input_to_cell_weight tensor
782 {n_cell, n_input}, // input_to_output_weight tensor
783
784 {n_cell, n_output}, // recurrent_to_input_weight tensor
785 {n_cell, n_output}, // recurrent_to_forget_weight tensor
786 {n_cell, n_output}, // recurrent_to_cell_weight tensor
787 {n_cell, n_output}, // recurrent_to_output_weight tensor
788
789 {0}, // cell_to_input_weight tensor
790 {0}, // cell_to_forget_weight tensor
791 {0}, // cell_to_output_weight tensor
792
793 {n_cell}, // input_gate_bias tensor
794 {n_cell}, // forget_gate_bias tensor
795 {n_cell}, // cell_gate_bias tensor
796 {n_cell}, // output_gate_bias tensor
797
798 {0, 0}, // projection_weight tensor
799 {0}, // projection_bias tensor
800
801 // Backward cell
802 {n_cell, n_input}, // input_to_input_weight tensor
803 {n_cell, n_input}, // input_to_forget_weight tensor
804 {n_cell, n_input}, // input_to_cell_weight tensor
805 {n_cell, n_input}, // input_to_output_weight tensor
806
807 {n_cell, n_output}, // recurrent_to_input_weight tensor
808 {n_cell, n_output}, // recurrent_to_forget_weight tensor
809 {n_cell, n_output}, // recurrent_to_cell_weight tensor
810 {n_cell, n_output}, // recurrent_to_output_weight tensor
811
812 {0}, // cell_to_input_weight tensor
813 {0}, // cell_to_forget_weight tensor
814 {0}, // cell_to_output_weight tensor
815
816 {n_cell}, // input_gate_bias tensor
817 {n_cell}, // forget_gate_bias tensor
818 {n_cell}, // cell_gate_bias tensor
819 {n_cell}, // output_gate_bias tensor
820
821 {0, 0}, // projection_weight tensor
822 {0}, // projection_bias tensor
823
824 {n_batch, n_output}, // activation_state tensor
825 {n_batch, n_cell}, // cell_state tensor
826
827 {n_batch, n_output}, // activation_state tensor
828 {n_batch, n_cell}, // cell_state tensor
829
830 {sequence_length, n_batch, 0}, // aux_input tensor
831 {0}, // aux_fw_input_to_input tensor
832 {0}, // aux_fw_input_to_forget tensor
833 {0}, // aux_fw_input_to_cell tensor
834 {0}, // aux_fw_input_to_output tensor
835 {0}, // aux_bw_input_to_input tensor
836 {0}, // aux_bw_input_to_forget tensor
837 {0}, // aux_bw_input_to_cell tensor
838 {0}, // aux_bw_input_to_output tensor
839 });
840
841 lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
842 -0.34550029, 0.04266912, -0.15680569,
843 -0.34856534, 0.43890524});
844
845 lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
846 -0.20583314, 0.44344562, 0.22077113,
847 -0.29909778});
848
849 lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
850 -0.31343272, -0.40032279, 0.44781327,
851 0.01387155, -0.35593212});
852
853 lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
854 0.40525138, 0.44272184, 0.03897077, -0.1556896,
855 0.19487578});
856
857 lstm.SetInputGateBias({0., 0., 0., 0.});
858
859 lstm.SetCellBias({0., 0., 0., 0.});
860
861 lstm.SetForgetGateBias({1., 1., 1., 1.});
862
863 lstm.SetOutputGateBias({0., 0., 0., 0.});
864
865 lstm.SetRecurrentToInputWeights(
866 {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
867 -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
868 -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
869
870 lstm.SetRecurrentToCellWeights(
871 {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
872 -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
873 -0.46367589, 0.26016325, -0.03894562, -0.16368064});
874
875 lstm.SetRecurrentToForgetWeights(
876 {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
877 -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
878 0.28053468, 0.01560611, -0.20127171, -0.01140004});
879
880 lstm.SetRecurrentToOutputWeights(
881 {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
882 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
883 -0.51818722, -0.15390486, 0.0468148, 0.39922136});
884
885 // Input should have n_input * sequence_length many values.
886 // Check reversed inputs.
887 static float lstm_input_reversed[] = {1., 1., 3., 4., 2., 3.};
888 static float lstm_fw_golden_output[] = {
889 -0.02973187, 0.1229473, 0.20885126, -0.15358765,
890 -0.03716109, 0.12507336, 0.41193449, -0.20860538,
891 -0.15053082, 0.09120187, 0.24278517, -0.12222792};
892 static float lstm_bw_golden_output[] = {
893 -0.0806187, 0.139077, 0.400476, -0.197842, -0.0332076, 0.123838,
894 0.309777, -0.17621, -0.0490733, 0.0739237, 0.067706, -0.0208124};
895
896 float* batch0_start = lstm_input_reversed;
897 float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
898
899 lstm.SetInput(0, batch0_start, batch0_end);
900
901 ASSERT_EQ(lstm.Invoke(), kTfLiteOk);
902
903 std::vector<float> fw_expected;
904 for (int s = 0; s < lstm.sequence_length(); s++) {
905 float* fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs();
906 float* fw_golden_end = fw_golden_start + lstm.num_fw_outputs();
907 fw_expected.insert(fw_expected.begin(), fw_golden_start, fw_golden_end);
908 }
909 EXPECT_THAT(lstm.GetBwOutput(),
910 ElementsAreArray(ArrayFloatNear(fw_expected)));
911
912 std::vector<float> bw_expected;
913 for (int s = 0; s < lstm.sequence_length(); s++) {
914 float* bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs();
915 float* bw_golden_end = bw_golden_start + lstm.num_bw_outputs();
916 bw_expected.insert(bw_expected.begin(), bw_golden_start, bw_golden_end);
917 }
918 EXPECT_THAT(lstm.GetFwOutput(),
919 ElementsAreArray(ArrayFloatNear(bw_expected)));
920 }
921
TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
  const int n_batch = 1;
  const int n_input = 2;
  // Without a projection layer the output width matches the cell width.
  const int n_cell = 4;
  const int n_output = 4;
  const int sequence_length = 3;

  // Shapes of one direction's 17 parameter tensors. The forward and backward
  // cells are configured identically, so this list is appended twice below.
  // With use_cifg the input-gate tensors are passed as empty shapes; with
  // use_peephole only the forget/output peephole weights are non-empty; the
  // projection tensors are empty.
  const std::vector<std::vector<int>> direction_shapes = {
      {0, 0},             // input_to_input_weight tensor
      {n_cell, n_input},  // input_to_forget_weight tensor
      {n_cell, n_input},  // input_to_cell_weight tensor
      {n_cell, n_input},  // input_to_output_weight tensor

      {0, 0},              // recurrent_to_input_weight tensor
      {n_cell, n_output},  // recurrent_to_forget_weight tensor
      {n_cell, n_output},  // recurrent_to_cell_weight tensor
      {n_cell, n_output},  // recurrent_to_output_weight tensor

      {0},       // cell_to_input_weight tensor
      {n_cell},  // cell_to_forget_weight tensor
      {n_cell},  // cell_to_output_weight tensor

      {0},       // input_gate_bias tensor
      {n_cell},  // forget_gate_bias tensor
      {n_cell},  // cell_gate_bias tensor
      {n_cell},  // output_gate_bias tensor

      {0, 0},  // projection_weight tensor
      {0},     // projection_bias tensor
  };

  std::vector<std::vector<int>> input_shapes = {
      {sequence_length, n_batch, n_input},  // input tensor
  };
  // Forward cell parameters first, then backward cell parameters.
  for (int direction = 0; direction < 2; ++direction) {
    input_shapes.insert(input_shapes.end(), direction_shapes.begin(),
                        direction_shapes.end());
  }
  // Forward state pair first, then backward state pair.
  for (int direction = 0; direction < 2; ++direction) {
    input_shapes.push_back({n_batch, n_output});  // activation_state tensor
    input_shapes.push_back({n_batch, n_cell});    // cell_state tensor
  }
  // The auxiliary input is disabled: a zero-width aux input followed by the
  // empty aux weights (input/forget/cell/output, forward then backward).
  input_shapes.push_back({sequence_length, n_batch, 0});  // aux_input tensor
  const int num_aux_weights = 8;
  for (int i = 0; i < num_aux_weights; ++i) {
    input_shapes.push_back({0});
  }

  BidirectionalLSTMOpModel lstm(
      n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true,
      /*use_peephole=*/true, /*use_projection_weights=*/false,
      /*use_projection_bias=*/false, /*merge_outputs=*/false,
      /*use_aux_input=*/false, /*cell_clip=*/0.0,
      /*proj_clip=*/0.0, /*quantize_weights=*/false, /*time_major=*/true,
      input_shapes);

  // Forget gate.
  lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988,
                                -0.3633365, -0.22755712, 0.28253698,
                                0.24407166, 0.33826375});
  lstm.SetRecurrentToForgetWeights(
      {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827,
       0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795,
       -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349});
  lstm.SetCellToForgetWeights(
      {0.47485286, -0.51955009, -0.24458408, 0.31544167});
  lstm.SetForgetGateBias({1., 1., 1., 1.});

  // Cell gate.
  lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726,
                              0.05100781, 0.04717243, 0.48944736,
                              -0.38535351, -0.17212132});
  lstm.SetRecurrentToCellWeights(
      {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711,
       0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004,
       0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288,
       0.21193194});
  lstm.SetCellBias({0., 0., 0., 0.});

  // Output gate.
  lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593,
                                -0.09426838, -0.44257352, 0.54939759,
                                0.01533556, 0.42751634});
  lstm.SetRecurrentToOutputWeights(
      {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174,
       -0.05115908, -0.33941799, 0.23364776, 0.11178309, 0.09481031,
       -0.26424935, 0.46261835, 0.50248802, 0.26114327, -0.43736315,
       0.33149987});
  lstm.SetCellToOutputWeights(
      {-0.17135078, 0.82760304, 0.85573703, -0.77109635});
  lstm.SetOutputGateBias({0., 0., 0., 0.});

  // sequence_length (3) steps of n_input (2) values for the single batch.
  static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
  static float lstm_fw_golden_output[] = {
      -0.36444446, -0.00352185, 0.12886585, -0.05163646,
      -0.42312205, -0.01218222, 0.24201041, -0.08124574,
      -0.358325,   -0.04621704, 0.21641694, -0.06471302};
  static float lstm_bw_golden_output[] = {
      -0.401685, -0.0232794, 0.288642,    -0.123074,
      -0.42915,  -0.00871577, 0.20912,    -0.103567,
      -0.166398, -0.00486649, 0.0697471, -0.0537578};

  lstm.SetInput(0, lstm_input,
                lstm_input + lstm.num_inputs() * lstm.sequence_length());

  ASSERT_EQ(lstm.Invoke(), kTfLiteOk);

  // Each direction's output is compared against its whole golden sequence,
  // in the order it is stored.
  const int fw_size = lstm.num_fw_outputs() * lstm.sequence_length();
  const std::vector<float> fw_expected(lstm_fw_golden_output,
                                       lstm_fw_golden_output + fw_size);
  EXPECT_THAT(lstm.GetFwOutput(),
              ElementsAreArray(ArrayFloatNear(fw_expected)));

  const int bw_size = lstm.num_bw_outputs() * lstm.sequence_length();
  const std::vector<float> bw_expected(lstm_bw_golden_output,
                                       lstm_bw_golden_output + bw_size);
  EXPECT_THAT(lstm.GetBwOutput(),
              ElementsAreArray(ArrayFloatNear(bw_expected)));
}
1071
TEST(LSTMOpTest,
     BlackBoxTestWithCifgWithPeepholeNoProjectionNoClippingReversed) {
  const int n_batch = 1;
  const int n_input = 2;
  // Without a projection layer the output width matches the cell width.
  const int n_cell = 4;
  const int n_output = 4;
  const int sequence_length = 3;

  // Shapes of one direction's 17 parameter tensors. The forward and backward
  // cells are configured identically, so this list is appended twice below.
  // With use_cifg the input-gate tensors are passed as empty shapes; with
  // use_peephole only the forget/output peephole weights are non-empty; the
  // projection tensors are empty.
  const std::vector<std::vector<int>> direction_shapes = {
      {0, 0},             // input_to_input_weight tensor
      {n_cell, n_input},  // input_to_forget_weight tensor
      {n_cell, n_input},  // input_to_cell_weight tensor
      {n_cell, n_input},  // input_to_output_weight tensor

      {0, 0},              // recurrent_to_input_weight tensor
      {n_cell, n_output},  // recurrent_to_forget_weight tensor
      {n_cell, n_output},  // recurrent_to_cell_weight tensor
      {n_cell, n_output},  // recurrent_to_output_weight tensor

      {0},       // cell_to_input_weight tensor
      {n_cell},  // cell_to_forget_weight tensor
      {n_cell},  // cell_to_output_weight tensor

      {0},       // input_gate_bias tensor
      {n_cell},  // forget_gate_bias tensor
      {n_cell},  // cell_gate_bias tensor
      {n_cell},  // output_gate_bias tensor

      {0, 0},  // projection_weight tensor
      {0},     // projection_bias tensor
  };

  std::vector<std::vector<int>> input_shapes = {
      {sequence_length, n_batch, n_input},  // input tensor
  };
  // Forward cell parameters first, then backward cell parameters.
  for (int direction = 0; direction < 2; ++direction) {
    input_shapes.insert(input_shapes.end(), direction_shapes.begin(),
                        direction_shapes.end());
  }
  // Forward state pair first, then backward state pair.
  for (int direction = 0; direction < 2; ++direction) {
    input_shapes.push_back({n_batch, n_output});  // activation_state tensor
    input_shapes.push_back({n_batch, n_cell});    // cell_state tensor
  }
  // The auxiliary input is disabled: a zero-width aux input followed by the
  // empty aux weights (input/forget/cell/output, forward then backward).
  input_shapes.push_back({sequence_length, n_batch, 0});  // aux_input tensor
  const int num_aux_weights = 8;
  for (int i = 0; i < num_aux_weights; ++i) {
    input_shapes.push_back({0});
  }

  BidirectionalLSTMOpModel lstm(
      n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true,
      /*use_peephole=*/true, /*use_projection_weights=*/false,
      /*use_projection_bias=*/false, /*merge_outputs=*/false,
      /*use_aux_input=*/false, /*cell_clip=*/0.0,
      /*proj_clip=*/0.0, /*quantize_weights=*/false, /*time_major=*/true,
      input_shapes);

  // Forget gate.
  lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988,
                                -0.3633365, -0.22755712, 0.28253698,
                                0.24407166, 0.33826375});
  lstm.SetRecurrentToForgetWeights(
      {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827,
       0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795,
       -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349});
  lstm.SetCellToForgetWeights(
      {0.47485286, -0.51955009, -0.24458408, 0.31544167});
  lstm.SetForgetGateBias({1., 1., 1., 1.});

  // Cell gate.
  lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726,
                              0.05100781, 0.04717243, 0.48944736,
                              -0.38535351, -0.17212132});
  lstm.SetRecurrentToCellWeights(
      {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711,
       0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004,
       0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288,
       0.21193194});
  lstm.SetCellBias({0., 0., 0., 0.});

  // Output gate.
  lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593,
                                -0.09426838, -0.44257352, 0.54939759,
                                0.01533556, 0.42751634});
  lstm.SetRecurrentToOutputWeights(
      {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174,
       -0.05115908, -0.33941799, 0.23364776, 0.11178309, 0.09481031,
       -0.26424935, 0.46261835, 0.50248802, 0.26114327, -0.43736315,
       0.33149987});
  lstm.SetCellToOutputWeights(
      {-0.17135078, 0.82760304, 0.85573703, -0.77109635});
  lstm.SetOutputGateBias({0., 0., 0., 0.});

  // Input steps {1,1},{3,4},{2,3}: the sequence {2,3},{3,4},{1,1} played
  // backwards in time. The golden outputs below are kept in the order of the
  // un-reversed sequence.
  static float lstm_input_reversed[] = {1., 1., 3., 4., 2., 3.};
  static float lstm_fw_golden_output[] = {
      -0.36444446, -0.00352185, 0.12886585, -0.05163646,
      -0.42312205, -0.01218222, 0.24201041, -0.08124574,
      -0.358325,   -0.04621704, 0.21641694, -0.06471302};
  static float lstm_bw_golden_output[] = {
      -0.401685, -0.0232794, 0.288642,    -0.123074,
      -0.42915,  -0.00871577, 0.20912,    -0.103567,
      -0.166398, -0.00486649, 0.0697471, -0.0537578};

  lstm.SetInput(0, lstm_input_reversed,
                lstm_input_reversed +
                    lstm.num_inputs() * lstm.sequence_length());

  ASSERT_EQ(lstm.Invoke(), kTfLiteOk);

  // Reversing the input in time swaps the roles of the two directions: the
  // backward output is expected to equal the forward golden output with its
  // time steps in reverse order, and vice versa. Steps are appended from the
  // last one to the first.
  const int fw_step = lstm.num_fw_outputs();
  std::vector<float> expected_bw_output;
  for (int s = lstm.sequence_length() - 1; s >= 0; --s) {
    const float* step_begin = lstm_fw_golden_output + s * fw_step;
    expected_bw_output.insert(expected_bw_output.end(), step_begin,
                              step_begin + fw_step);
  }
  EXPECT_THAT(lstm.GetBwOutput(),
              ElementsAreArray(ArrayFloatNear(expected_bw_output)));

  const int bw_step = lstm.num_bw_outputs();
  std::vector<float> expected_fw_output;
  for (int s = lstm.sequence_length() - 1; s >= 0; --s) {
    const float* step_begin = lstm_bw_golden_output + s * bw_step;
    expected_fw_output.insert(expected_fw_output.end(), step_begin,
                              step_begin + bw_step);
  }
  EXPECT_THAT(lstm.GetFwOutput(),
              ElementsAreArray(ArrayFloatNear(expected_fw_output)));
}
1224
TEST(LSTMOpTest,BlackBoxTestWithPeepholeWithProjectionNoClipping)1225 TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
1226 const int n_batch = 2;
1227 const int n_input = 5;
1228 const int n_cell = 20;
1229 const int n_output = 16;
1230 const int sequence_length = 4;
1231
1232 BidirectionalLSTMOpModel lstm(
1233 n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
1234 /*use_peephole=*/true, /*use_projection_weights=*/true,
1235 /*use_projection_bias=*/false, /*merge_outputs=*/false,
1236 /*use_aux_input=*/false, /*cell_clip=*/0.0,
1237 /*proj_clip=*/0.0, /*quantize_weights=*/false, /*time_major=*/true,
1238 {
1239 {sequence_length, n_batch, n_input}, // input tensor
1240
1241 {n_cell, n_input}, // input_to_input_weight tensor
1242 {n_cell, n_input}, // input_to_forget_weight tensor
1243 {n_cell, n_input}, // input_to_cell_weight tensor
1244 {n_cell, n_input}, // input_to_output_weight tensor
1245
1246 {n_cell, n_output}, // recurrent_to_input_weight tensor
1247 {n_cell, n_output}, // recurrent_to_forget_weight tensor
1248 {n_cell, n_output}, // recurrent_to_cell_weight tensor
1249 {n_cell, n_output}, // recurrent_to_output_weight tensor
1250
1251 {n_cell}, // cell_to_input_weight tensor
1252 {n_cell}, // cell_to_forget_weight tensor
1253 {n_cell}, // cell_to_output_weight tensor
1254
1255 {n_cell}, // input_gate_bias tensor
1256 {n_cell}, // forget_gate_bias tensor
1257 {n_cell}, // cell_gate_bias tensor
1258 {n_cell}, // output_gate_bias tensor
1259
1260 {n_output, n_cell}, // projection_weight tensor
1261 {0}, // projection_bias tensor
1262
1263 {n_cell, n_input}, // input_to_input_weight tensor
1264 {n_cell, n_input}, // input_to_forget_weight tensor
1265 {n_cell, n_input}, // input_to_cell_weight tensor
1266 {n_cell, n_input}, // input_to_output_weight tensor
1267
1268 {n_cell, n_output}, // recurrent_to_input_weight tensor
1269 {n_cell, n_output}, // recurrent_to_forget_weight tensor
1270 {n_cell, n_output}, // recurrent_to_cell_weight tensor
1271 {n_cell, n_output}, // recurrent_to_output_weight tensor
1272
1273 {n_cell}, // cell_to_input_weight tensor
1274 {n_cell}, // cell_to_forget_weight tensor
1275 {n_cell}, // cell_to_output_weight tensor
1276
1277 {n_cell}, // input_gate_bias tensor
1278 {n_cell}, // forget_gate_bias tensor
1279 {n_cell}, // cell_gate_bias tensor
1280 {n_cell}, // output_gate_bias tensor
1281
1282 {n_output, n_cell}, // projection_weight tensor
1283 {0}, // projection_bias tensor
1284
1285 {n_batch, n_output}, // activation_state tensor
1286 {n_batch, n_cell}, // cell_state tensor
1287
1288 {n_batch, n_output}, // activation_state tensor
1289 {n_batch, n_cell}, // cell_state tensor
1290
1291 {sequence_length, n_batch, 0}, // aux_input tensor
1292 {0}, // aux_fw_input_to_input tensor
1293 {0}, // aux_fw_input_to_forget tensor
1294 {0}, // aux_fw_input_to_cell tensor
1295 {0}, // aux_fw_input_to_output tensor
1296 {0}, // aux_bw_input_to_input tensor
1297 {0}, // aux_bw_input_to_forget tensor
1298 {0}, // aux_bw_input_to_cell tensor
1299 {0}, // aux_bw_input_to_output tensor
1300 });
1301
1302 lstm.SetInputToInputWeights(
1303 {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
1304 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048,
1305 -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385,
1306 -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282,
1307 -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627,
1308 -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226,
1309 -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059,
1310 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698,
1311 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206,
1312 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585,
1313 -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063,
1314 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603,
1315 -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682,
1316 -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988,
1317 -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764,
1318 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476,
1319 -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012,
1320 -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604,
1321 -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654,
1322 -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677});
1323
1324 lstm.SetInputToForgetWeights(
1325 {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236,
1326 -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505,
1327 -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495,
1328 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323,
1329 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421,
1330 -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887,
1331 -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791,
1332 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059,
1333 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068,
1334 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905,
1335 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605,
1336 -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464,
1337 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506,
1338 -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063,
1339 -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375,
1340 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553,
1341 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353,
1342 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717,
1343 -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371,
1344 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496});
1345
1346 lstm.SetInputToCellWeights(
1347 {-0.04580283, -0.09549462, -0.032418985, -0.06454633,
1348 -0.043528453, 0.043018587, -0.049152344, -0.12418144,
1349 -0.078985475, -0.07596889, 0.019484362, -0.11434962,
1350 -0.0074034138, -0.06314844, -0.092981495, 0.0062155537,
1351 -0.025034338, -0.0028890965, 0.048929527, 0.06235075,
1352 0.10665918, -0.032036792, -0.08505916, -0.10843358,
1353 -0.13002433, -0.036816437, -0.02130134, -0.016518239,
1354 0.0047691227, -0.0025825808, 0.066017866, 0.029991534,
1355 -0.10652836, -0.1037554, -0.13056071, -0.03266643,
1356 -0.033702414, -0.006473424, -0.04611692, 0.014419339,
1357 -0.025174323, 0.0396852, 0.081777506, 0.06157468,
1358 0.10210095, -0.009658194, 0.046511717, 0.03603906,
1359 0.0069369148, 0.015960095, -0.06507666, 0.09551598,
1360 0.053568836, 0.06408714, 0.12835667, -0.008714329,
1361 -0.20211966, -0.12093674, 0.029450472, 0.2849013,
1362 -0.029227901, 0.1164364, -0.08560263, 0.09941786,
1363 -0.036999565, -0.028842626, -0.0033637602, -0.017012902,
1364 -0.09720865, -0.11193351, -0.029155117, -0.017936034,
1365 -0.009768936, -0.04223324, -0.036159635, 0.06505112,
1366 -0.021742892, -0.023377212, -0.07221364, -0.06430552,
1367 0.05453865, 0.091149814, 0.06387331, 0.007518393,
1368 0.055960953, 0.069779344, 0.046411168, 0.10509911,
1369 0.07463894, 0.0075130584, 0.012850982, 0.04555431,
1370 0.056955688, 0.06555285, 0.050801456, -0.009862683,
1371 0.00826772, -0.026555609, -0.0073611983, -0.0014897042});
1372
1373 lstm.SetInputToOutputWeights(
1374 {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918,
1375 -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534,
1376 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722,
1377 -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761,
1378 -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394,
1379 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154,
1380 -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135,
1381 -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564,
1382 -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047,
1383 -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304,
1384 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946,
1385 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646,
1386 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813,
1387 -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403,
1388 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415,
1389 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495,
1390 -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158,
1391 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295,
1392 -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739,
1393 -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956});
1394
1395 lstm.SetInputGateBias(
1396 {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216,
1397 -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339,
1398 -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818,
1399 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196});
1400
1401 lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696,
1402 0.11098921, 0.15378423, 0.09263801, 0.09790885,
1403 0.09508917, 0.061199076, 0.07665568, -0.015443159,
1404 -0.03499149, 0.046190713, 0.08895977, 0.10899629,
1405 0.40694186, 0.06030037, 0.012413437, -0.06108739});
1406
1407 lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873,
1408 -0.1483596, -0.10639995, -0.091433935, 0.058573797,
1409 -0.06809782, -0.07889636, -0.043246906, -0.09829136,
1410 -0.4279842, 0.034901652, 0.18797937, 0.0075234566,
1411 0.016178843, 0.1749513, 0.13975595, 0.92058027});
1412
1413 lstm.SetOutputGateBias(
1414 {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795,
1415 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895,
1416 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149,
1417 -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877});
1418
1419 lstm.SetRecurrentToInputWeights(
1420 {-0.001374326, -0.078856036, 0.10672688, 0.029162422,
1421 -0.11585556, 0.02557986, -0.13446963, -0.035785314,
1422 -0.01244275, 0.025961924, -0.02337298, -0.044228926,
1423 -0.055839065, -0.046598054, -0.010546039, -0.06900766,
1424 0.027239809, 0.022582639, -0.013296484, -0.05459212,
1425 0.08981, -0.045407712, 0.08682226, -0.06867011,
1426 -0.14390695, -0.02916037, 0.000996957, 0.091420636,
1427 0.14283475, -0.07390571, -0.06402044, 0.062524505,
1428 -0.093129106, 0.04860203, -0.08364217, -0.08119002,
1429 0.009352075, 0.22920375, 0.0016303885, 0.11583097,
1430 -0.13732095, 0.012405723, -0.07551853, 0.06343048,
1431 0.12162708, -0.031923793, -0.014335606, 0.01790974,
1432 -0.10650317, -0.0724401, 0.08554849, -0.05727212,
1433 0.06556731, -0.042729504, -0.043227166, 0.011683251,
1434 -0.013082158, -0.029302018, -0.010899579, -0.062036745,
1435 -0.022509435, -0.00964907, -0.01567329, 0.04260106,
1436 -0.07787477, -0.11576462, 0.017356863, 0.048673786,
1437 -0.017577527, -0.05527947, -0.082487635, -0.040137455,
1438 -0.10820036, -0.04666372, 0.022746278, -0.07851417,
1439 0.01068115, 0.032956902, 0.022433773, 0.0026891115,
1440 0.08944216, -0.0685835, 0.010513544, 0.07228705,
1441 0.02032331, -0.059686817, -0.0005566496, -0.086984694,
1442 0.040414046, -0.1380399, 0.094208956, -0.05722982,
1443 0.012092817, -0.04989123, -0.086576, -0.003399834,
1444 -0.04696032, -0.045747425, 0.10091314, 0.048676282,
1445 -0.029037097, 0.031399418, -0.0040285117, 0.047237843,
1446 0.09504992, 0.041799378, -0.049185462, -0.031518843,
1447 -0.10516937, 0.026374253, 0.10058866, -0.0033195973,
1448 -0.041975245, 0.0073591834, 0.0033782164, -0.004325073,
1449 -0.10167381, 0.042500053, -0.01447153, 0.06464186,
1450 -0.017142897, 0.03312627, 0.009205989, 0.024138335,
1451 -0.011337001, 0.035530265, -0.010912711, 0.0706555,
1452 -0.005894094, 0.051841937, -0.1401738, -0.02351249,
1453 0.0365468, 0.07590991, 0.08838724, 0.021681072,
1454 -0.10086113, 0.019608743, -0.06195883, 0.077335775,
1455 0.023646897, -0.095322326, 0.02233014, 0.09756986,
1456 -0.048691444, -0.009579111, 0.07595467, 0.11480546,
1457 -0.09801813, 0.019894179, 0.08502348, 0.004032281,
1458 0.037211012, 0.068537936, -0.048005626, -0.091520436,
1459 -0.028379958, -0.01556313, 0.06554592, -0.045599163,
1460 -0.01672207, -0.020169014, -0.011877351, -0.20212261,
1461 0.010889619, 0.0047078193, 0.038385306, 0.08540671,
1462 -0.017140968, -0.0035865551, 0.016678626, 0.005633034,
1463 0.015963363, 0.00871737, 0.060130805, 0.028611384,
1464 0.10109069, -0.015060172, -0.07894427, 0.06401885,
1465 0.011584063, -0.024466386, 0.0047652307, -0.09041358,
1466 0.030737216, -0.0046374933, 0.14215417, -0.11823516,
1467 0.019899689, 0.006106124, -0.027092824, 0.0786356,
1468 0.05052217, -0.058925, -0.011402121, -0.024987547,
1469 -0.0013661642, -0.06832946, -0.015667673, -0.1083353,
1470 -0.00096863037, -0.06988685, -0.053350925, -0.027275559,
1471 -0.033664223, -0.07978348, -0.025200296, -0.017207067,
1472 -0.058403496, -0.055697463, 0.005798788, 0.12965427,
1473 -0.062582195, 0.0013350133, -0.10482091, 0.0379771,
1474 0.072521195, -0.0029455067, -0.13797039, -0.03628521,
1475 0.013806405, -0.017858358, -0.01008298, -0.07700066,
1476 -0.017081132, 0.019358726, 0.0027079724, 0.004635139,
1477 0.062634714, -0.02338735, -0.039547626, -0.02050681,
1478 0.03385117, -0.083611414, 0.002862572, -0.09421313,
1479 0.058618143, -0.08598433, 0.00972939, 0.023867095,
1480 -0.053934585, -0.023203006, 0.07452513, -0.048767887,
1481 -0.07314807, -0.056307215, -0.10433547, -0.06440842,
1482 0.04328182, 0.04389765, -0.020006588, -0.09076438,
1483 -0.11652589, -0.021705797, 0.03345259, -0.010329105,
1484 -0.025767034, 0.013057034, -0.07316461, -0.10145612,
1485 0.06358255, 0.18531723, 0.07759293, 0.12006465,
1486 0.1305557, 0.058638252, -0.03393652, 0.09622831,
1487 -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845,
1488 -0.005644518, 0.06857898, -0.12598175, -0.035084512,
1489 0.03156317, -0.12794146, -0.031963028, 0.04692781,
1490 0.030070418, 0.0071660685, -0.095516115, -0.004643372,
1491 0.040170413, -0.062104587, -0.0037324072, 0.0554317,
1492 0.08184801, -0.019164372, 0.06791302, 0.034257166,
1493 -0.10307039, 0.021943003, 0.046745934, 0.0790918,
1494 -0.0265588, -0.007824208, 0.042546265, -0.00977924,
1495 -0.0002440307, -0.017384544, -0.017990116, 0.12252321,
1496 -0.014512694, -0.08251313, 0.08861942, 0.13589665,
1497 0.026351685, 0.012641483, 0.07466548, 0.044301085,
1498 -0.045414884, -0.051112458, 0.03444247, -0.08502782,
1499 -0.04106223, -0.028126027, 0.028473156, 0.10467447});
1500
1501 lstm.SetRecurrentToForgetWeights(
1502 {-0.057784554, -0.026057621, -0.068447545, -0.022581743,
1503 0.14811787, 0.10826372, 0.09471067, 0.03987225,
1504 -0.0039523416, 0.00030638507, 0.053185795, 0.10572994,
1505 0.08414449, -0.022036452, -0.00066928595, -0.09203576,
1506 0.032950465, -0.10985798, -0.023809856, 0.0021431844,
1507 -0.02196096, -0.00326074, 0.00058621005, -0.074678116,
1508 -0.06193199, 0.055729095, 0.03736828, 0.020123724,
1509 0.061878487, -0.04729229, 0.034919553, -0.07585433,
1510 -0.04421272, -0.044019096, 0.085488975, 0.04058006,
1511 -0.06890133, -0.030951202, -0.024628663, -0.07672815,
1512 0.034293607, 0.08556707, -0.05293577, -0.033561368,
1513 -0.04899627, 0.0241671, 0.015736353, -0.095442444,
1514 -0.029564252, 0.016493602, -0.035026584, 0.022337519,
1515 -0.026871363, 0.004780428, 0.0077918363, -0.03601621,
1516 0.016435321, -0.03263031, -0.09543275, -0.047392778,
1517 0.013454138, 0.028934088, 0.01685226, -0.086110644,
1518 -0.046250615, -0.01847454, 0.047608484, 0.07339695,
1519 0.034546845, -0.04881143, 0.009128804, -0.08802852,
1520 0.03761666, 0.008096139, -0.014454086, 0.014361001,
1521 -0.023502491, -0.0011840804, -0.07607001, 0.001856849,
1522 -0.06509276, -0.006021153, -0.08570962, -0.1451793,
1523 0.060212336, 0.055259194, 0.06974018, 0.049454916,
1524 -0.027794661, -0.08077226, -0.016179763, 0.1169753,
1525 0.17213494, -0.0056326236, -0.053934924, -0.0124349,
1526 -0.11520337, 0.05409887, 0.088759385, 0.0019655675,
1527 0.0042065294, 0.03881498, 0.019844765, 0.041858196,
1528 -0.05695512, 0.047233116, 0.038937137, -0.06542224,
1529 0.014429736, -0.09719407, 0.13908425, -0.05379757,
1530 0.012321099, 0.082840554, -0.029899208, 0.044217527,
1531 0.059855383, 0.07711018, -0.045319796, 0.0948846,
1532 -0.011724666, -0.0033288454, -0.033542685, -0.04764985,
1533 -0.13873616, 0.040668588, 0.034832682, -0.015319203,
1534 -0.018715994, 0.046002675, 0.0599172, -0.043107376,
1535 0.0294216, -0.002314414, -0.022424703, 0.0030315618,
1536 0.0014641669, 0.0029166266, -0.11878115, 0.013738511,
1537 0.12375372, -0.0006038222, 0.029104086, 0.087442465,
1538 0.052958444, 0.07558703, 0.04817258, 0.044462286,
1539 -0.015213451, -0.08783778, -0.0561384, -0.003008196,
1540 0.047060397, -0.002058388, 0.03429439, -0.018839769,
1541 0.024734668, 0.024614193, -0.042046934, 0.09597743,
1542 -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786,
1543 -0.02558259, -0.022822596, -0.023273505, -0.02464396,
1544 -0.10991725, -0.006240552, 0.0074488563, 0.024044557,
1545 0.04383914, -0.046476185, 0.028658995, 0.060410924,
1546 0.050786525, 0.009452605, -0.0073054377, -0.024810238,
1547 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517,
1548 0.015898481, 0.021362653, -0.030262267, 0.016587038,
1549 -0.011442813, 0.041154444, -0.007631438, -0.03423484,
1550 -0.010977775, 0.036152758, 0.0066366293, 0.11915515,
1551 0.02318443, -0.041350313, 0.021485701, -0.10906167,
1552 -0.028218046, -0.00954771, 0.020531068, -0.11995105,
1553 -0.03672871, 0.024019798, 0.014255957, -0.05221243,
1554 -0.00661567, -0.04630967, 0.033188973, 0.10107534,
1555 -0.014027541, 0.030796422, -0.10270911, -0.035999842,
1556 0.15443139, 0.07684145, 0.036571592, -0.035900835,
1557 -0.0034699554, 0.06209149, 0.015920248, -0.031122351,
1558 -0.03858649, 0.01849943, 0.13872518, 0.01503974,
1559 0.069941424, -0.06948533, -0.0088794185, 0.061282158,
1560 -0.047401894, 0.03100163, -0.041533746, -0.10430945,
1561 0.044574402, -0.01425562, -0.024290353, 0.034563623,
1562 0.05866852, 0.023947537, -0.09445152, 0.035450947,
1563 0.02247216, -0.0042998926, 0.061146557, -0.10250651,
1564 0.020881841, -0.06747029, 0.10062043, -0.0023941975,
1565 0.03532124, -0.016341697, 0.09685456, -0.016764693,
1566 0.051808182, 0.05875331, -0.04536488, 0.001626336,
1567 -0.028892258, -0.01048663, -0.009793449, -0.017093895,
1568 0.010987891, 0.02357273, -0.00010856845, 0.0099760275,
1569 -0.001845119, -0.03551521, 0.0018358806, 0.05763657,
1570 -0.01769146, 0.040995963, 0.02235177, -0.060430344,
1571 0.11475477, -0.023854522, 0.10071741, 0.0686208,
1572 -0.014250481, 0.034261297, 0.047418304, 0.08562733,
1573 -0.030519066, 0.0060542435, 0.014653856, -0.038836084,
1574 0.04096551, 0.032249358, -0.08355519, -0.026823482,
1575 0.056386515, -0.010401743, -0.028396193, 0.08507674,
1576 0.014410365, 0.020995233, 0.17040324, 0.11511526,
1577 0.02459721, 0.0066619175, 0.025853224, -0.023133837,
1578 -0.081302024, 0.017264642, -0.009585969, 0.09491168,
1579 -0.051313367, 0.054532815, -0.014298593, 0.10657464,
1580 0.007076659, 0.10964551, 0.0409152, 0.008275321,
1581 -0.07283536, 0.07937492, 0.04192024, -0.1075027});
1582
1583 lstm.SetRecurrentToCellWeights(
1584 {-0.037322544, 0.018592842, 0.0056175636, -0.06253426,
1585 0.055647098, -0.05713207, -0.05626563, 0.005559383,
1586 0.03375411, -0.025757805, -0.088049285, 0.06017052,
1587 -0.06570978, 0.007384076, 0.035123326, -0.07920549,
1588 0.053676967, 0.044480428, -0.07663568, 0.0071805613,
1589 0.08089997, 0.05143358, 0.038261272, 0.03339287,
1590 -0.027673481, 0.044746667, 0.028349208, 0.020090483,
1591 -0.019443132, -0.030755889, -0.0040000007, 0.04465846,
1592 -0.021585021, 0.0031670958, 0.0053199246, -0.056117613,
1593 -0.10893326, 0.076739706, -0.08509834, -0.027997585,
1594 0.037871376, 0.01449768, -0.09002357, -0.06111149,
1595 -0.046195522, 0.0422062, -0.005683705, -0.1253618,
1596 -0.012925729, -0.04890792, 0.06985068, 0.037654128,
1597 0.03398274, -0.004781977, 0.007032333, -0.031787455,
1598 0.010868644, -0.031489216, 0.09525667, 0.013939797,
1599 0.0058680447, 0.0167067, 0.02668468, -0.04797466,
1600 -0.048885044, -0.12722108, 0.035304096, 0.06554885,
1601 0.00972396, -0.039238118, -0.05159735, -0.11329045,
1602 0.1613692, -0.03750952, 0.06529313, -0.071974665,
1603 -0.11769596, 0.015524369, -0.0013754242, -0.12446318,
1604 0.02786344, -0.014179351, 0.005264273, 0.14376344,
1605 0.015983658, 0.03406988, -0.06939408, 0.040699873,
1606 0.02111075, 0.09669095, 0.041345075, -0.08316494,
1607 -0.07684199, -0.045768797, 0.032298047, -0.041805092,
1608 0.0119405, 0.0061010392, 0.12652606, 0.0064572375,
1609 -0.024950314, 0.11574242, 0.04508852, -0.04335324,
1610 0.06760663, -0.027437469, 0.07216407, 0.06977076,
1611 -0.05438599, 0.034033038, -0.028602652, 0.05346137,
1612 0.043184172, -0.037189785, 0.10420091, 0.00882477,
1613 -0.054019816, -0.074273005, -0.030617684, -0.0028467078,
1614 0.024302477, -0.0038869337, 0.005332455, 0.0013399826,
1615 0.04361412, -0.007001822, 0.09631092, -0.06702025,
1616 -0.042049985, -0.035070654, -0.04103342, -0.10273396,
1617 0.0544271, 0.037184782, -0.13150354, -0.0058036847,
1618 -0.008264958, 0.042035464, 0.05891794, 0.029673764,
1619 0.0063542654, 0.044788733, 0.054816857, 0.062257513,
1620 -0.00093483756, 0.048938446, -0.004952862, -0.007730018,
1621 -0.04043371, -0.017094059, 0.07229206, -0.023670016,
1622 -0.052195564, -0.025616996, -0.01520939, 0.045104615,
1623 -0.007376126, 0.003533447, 0.006570588, 0.056037236,
1624 0.12436656, 0.051817212, 0.028532185, -0.08686856,
1625 0.11868599, 0.07663395, -0.07323171, 0.03463402,
1626 -0.050708205, -0.04458982, -0.11590894, 0.021273347,
1627 0.1251325, -0.15313013, -0.12224372, 0.17228661,
1628 0.023029093, 0.086124025, 0.006445803, -0.03496501,
1629 0.028332196, 0.04449512, -0.042436164, -0.026587414,
1630 -0.006041347, -0.09292539, -0.05678812, 0.03897832,
1631 0.09465633, 0.008115513, -0.02171956, 0.08304309,
1632 0.071401566, 0.019622514, 0.032163795, -0.004167056,
1633 0.02295182, 0.030739572, 0.056506045, 0.004612461,
1634 0.06524936, 0.059999723, 0.046395954, -0.0045512207,
1635 -0.1335546, -0.030136576, 0.11584653, -0.014678886,
1636 0.0020118146, -0.09688814, -0.0790206, 0.039770417,
1637 -0.0329582, 0.07922767, 0.029322514, 0.026405897,
1638 0.04207835, -0.07073373, 0.063781224, 0.0859677,
1639 -0.10925287, -0.07011058, 0.048005477, 0.03438226,
1640 -0.09606514, -0.006669445, -0.043381985, 0.04240257,
1641 -0.06955775, -0.06769346, 0.043903265, -0.026784198,
1642 -0.017840602, 0.024307009, -0.040079936, -0.019946516,
1643 0.045318738, -0.12233574, 0.026170589, 0.0074471775,
1644 0.15978073, 0.10185836, 0.10298046, -0.015476589,
1645 -0.039390966, -0.072174534, 0.0739445, -0.1211869,
1646 -0.0347889, -0.07943156, 0.014809798, -0.12412325,
1647 -0.0030663363, 0.039695457, 0.0647603, -0.08291318,
1648 -0.018529687, -0.004423833, 0.0037507233, 0.084633216,
1649 -0.01514876, -0.056505352, -0.012800942, -0.06994386,
1650 0.012962922, -0.031234352, 0.07029052, 0.016418684,
1651 0.03618972, 0.055686004, -0.08663945, -0.017404709,
1652 -0.054761406, 0.029065743, 0.052404847, 0.020238016,
1653 0.0048197987, -0.0214882, 0.07078733, 0.013016777,
1654 0.06262858, 0.009184685, 0.020785125, -0.043904778,
1655 -0.0270329, -0.03299152, -0.060088247, -0.015162964,
1656 -0.001828936, 0.12642565, -0.056757294, 0.013586685,
1657 0.09232601, -0.035886683, 0.06000002, 0.05229691,
1658 -0.052580316, -0.082029596, -0.010794592, 0.012947712,
1659 -0.036429964, -0.085508935, -0.13127148, -0.017744139,
1660 0.031502828, 0.036232427, -0.031581745, 0.023051167,
1661 -0.05325106, -0.03421577, 0.028793324, -0.034633752,
1662 -0.009881397, -0.043551125, -0.018609839, 0.0019097115,
1663 -0.008799762, 0.056595087, 0.0022273948, 0.055752404});
1664
1665 lstm.SetRecurrentToOutputWeights({
1666 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415,
1667 -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349,
1668 -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948,
1669 -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774,
1670 -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125,
1671 -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224,
1672 -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088,
1673 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867,
1674 -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728,
1675 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607,
1676 -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928,
1677 -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462,
1678 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879,
1679 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698,
1680 -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146,
1681 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345,
1682 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166,
1683 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203,
1684 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743,
1685 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415,
1686 -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618,
1687 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891,
1688 -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015,
1689 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109,
1690 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886,
1691 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396,
1692 -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282,
1693 -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025,
1694 -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575,
1695 -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277,
1696 -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719,
1697 -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215,
1698 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483,
1699 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102,
1700 -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775,
1701 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841,
1702 -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656,
1703 -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286,
1704 -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309,
1705 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545,
1706 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754,
1707 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831,
1708 -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697,
1709 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453,
1710 -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222,
1711 -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989,
1712 -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827,
1713 -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949,
1714 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819,
1715 -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954,
1716 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228,
1717 -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001,
1718 -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939,
1719 -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556,
1720 -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718,
1721 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893,
1722 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974,
1723 -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485,
1724 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856,
1725 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853,
1726 -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019,
1727 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024,
1728 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994,
1729 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621,
1730 });
1731
1732 lstm.SetCellToInputWeights(
1733 {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458,
1734 -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174,
1735 -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047,
1736 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175});
1737
1738 lstm.SetCellToForgetWeights(
1739 {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276,
1740 -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
1741 -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774,
1742 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355});
1743
1744 lstm.SetCellToOutputWeights(
1745 {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764,
1746 -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544,
1747 -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817,
1748 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733});
1749
1750 lstm.SetProjectionWeights(
1751 {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832,
1752 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683,
1753 -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931,
1754 -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476,
1755 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067,
1756 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787,
1757 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588,
1758 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285,
1759 -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949,
1760 -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768,
1761 -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929,
1762 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504,
1763 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946,
1764 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117,
1765 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253,
1766 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456,
1767 -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552,
1768 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797,
1769 -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272,
1770 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165,
1771 -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922,
1772 -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548,
1773 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786,
1774 -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722,
1775 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318,
1776 -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776,
1777 -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307,
1778 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969,
1779 -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593,
1780 -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515,
1781 -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288,
1782 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723,
1783 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097,
1784 -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209,
1785 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268,
1786 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139,
1787 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707,
1788 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871,
1789 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553,
1790 -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702,
1791 -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615,
1792 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187,
1793 -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388,
1794 -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709,
1795 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263,
1796 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777,
1797 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935,
1798 -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641,
1799 -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996,
1800 -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318,
1801 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437,
1802 -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079,
1803 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237,
1804 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415,
1805 -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124,
1806 -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943,
1807 -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311,
1808 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013,
1809 -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364,
1810 -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543,
1811 -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102,
1812 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906,
1813 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955,
1814 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656});
1815
1816 static float lstm_input[][20] = {
1817 {// Batch0: 4 (input_sequence_size) * 5 (n_input)
1818 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386,
1819 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199,
1820 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339},
1821
1822 {// Batch1: 4 (input_sequence_size) * 5 (n_input)
1823 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260,
1824 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485,
1825 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}};
1826
1827 static float lstm_fw_golden_output[][64] = {
1828 {// Batch0: 4 (input_sequence_size) * 16 (n_output)
1829 -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576,
1830 -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004,
1831 -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147,
1832 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363,
1833 -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322,
1834 -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308,
1835 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794,
1836 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474,
1837 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827,
1838 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512,
1839 -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407,
1840 -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193,
1841 0.0286833, 0.00824207, 0.0264887, 0.0305169},
1842 {// Batch1: 4 (input_sequence_size) * 16 (n_output)
1843 -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926,
1844 -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232,
1845 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954,
1846 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507,
1847 -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039,
1848 -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233,
1849 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378,
1850 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034,
1851 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789,
1852 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855,
1853 -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679,
1854 -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181,
1855 0.0412031, 0.0118723, 0.0239643, 0.0394009}};
1856
1857 static float lstm_combined_golden_output[][64] = {
1858 {-0.022014, 0.073544, -0.002235, 0.040068, -0.037136, -0.052788,
1859 0.075325, -0.029378, 0.024298, -0.07733, -0.030674, -0.060229,
1860 0.040599, 0.011608, 0.042005, 0.045977, -0.039225, 0.076294,
1861 0.000735, 0.032852, -0.069869, -0.053312, 0.073527, -0.028136,
1862 0.021585, -0.102679, -0.004327, -0.043304, 0.072861, 0.027077,
1863 0.034558, 0.068292, -0.036292, 0.069832, -0.003032, 0.053829,
1864 -0.043821, -0.072713, 0.085029, -0.040374, 0.020014, -0.104521,
1865 -0.034504, -0.059759, 0.062569, 0.025652, 0.049306, 0.061189,
1866 -0.025146, 0.079643, -0.005188, 0.033080, -0.048079, -0.048082,
1867 0.069369, -0.028900, 0.024572, -0.077547, -0.022517, -0.054477,
1868 0.038857, 0.013336, 0.043234, 0.044788},
1869 {-0.039186, 0.070792, -0.005913, 0.02642, -0.068274, -0.05022,
1870 0.061444, -0.031241, 0.014996, -0.094544, -0.004146, -0.03464,
1871 0.058981, 0.026097, 0.039781, 0.058408, -0.031887, 0.069252,
1872 0.00576, 0.054062, -0.042801, -0.059974, 0.085272, -0.034453,
1873 0.026097, -0.0959, -0.031164, -0.058699, 0.06839, 0.020512,
1874 0.044727, 0.063609, -0.039863, 0.084819, -0.003909, 0.028666,
1875 -0.075677, -0.045125, 0.070379, -0.033895, 0.022111, -0.097184,
1876 -0.004921, -0.040851, 0.062316, 0.017435, 0.041437, 0.064568,
1877 -0.039656, 0.060726, -0.003402, 0.036854, -0.056503, -0.058554,
1878 0.068588, -0.034879, 0.01352, -0.09962, -0.01434, -0.039505,
1879 0.065133, 0.024321, 0.038473, 0.062438}};
1880
1881 for (int i = 0; i < lstm.sequence_length(); i++) {
1882 float* batch0_start = lstm_input[0] + i * lstm.num_inputs();
1883 float* batch0_end = batch0_start + lstm.num_inputs();
1884
1885 lstm.SetInput(2 * i * lstm.num_inputs(), batch0_start, batch0_end);
1886
1887 float* batch1_start = lstm_input[1] + i * lstm.num_inputs();
1888 float* batch1_end = batch1_start + lstm.num_inputs();
1889 lstm.SetInput((2 * i + 1) * lstm.num_inputs(), batch1_start, batch1_end);
1890 }
1891
1892 ASSERT_EQ(lstm.Invoke(), kTfLiteOk);
1893
1894 std::vector<float> expected;
1895 for (int i = 0; i < lstm.sequence_length(); i++) {
1896 float* golden_start_batch0 =
1897 lstm_fw_golden_output[0] + i * lstm.num_fw_outputs();
1898 float* golden_end_batch0 = golden_start_batch0 + lstm.num_fw_outputs();
1899 float* golden_start_batch1 =
1900 lstm_fw_golden_output[1] + i * lstm.num_fw_outputs();
1901 float* golden_end_batch1 = golden_start_batch1 + lstm.num_fw_outputs();
1902 expected.insert(expected.end(), golden_start_batch0, golden_end_batch0);
1903 expected.insert(expected.end(), golden_start_batch1, golden_end_batch1);
1904 }
1905 EXPECT_THAT(lstm.GetFwOutput(), ElementsAreArray(ArrayFloatNear(expected)));
1906
1907 // Check if the sum of forward backward matches the golden.
1908 expected.clear();
1909 for (int i = 0; i < lstm.sequence_length(); i++) {
1910 float* golden_start_batch0 =
1911 lstm_combined_golden_output[0] + i * lstm.num_fw_outputs();
1912 float* golden_end_batch0 = golden_start_batch0 + lstm.num_fw_outputs();
1913 float* golden_start_batch1 =
1914 lstm_combined_golden_output[1] + i * lstm.num_fw_outputs();
1915 float* golden_end_batch1 = golden_start_batch1 + lstm.num_fw_outputs();
1916 expected.insert(expected.end(), golden_start_batch0, golden_end_batch0);
1917 expected.insert(expected.end(), golden_start_batch1, golden_end_batch1);
1918 }
1919
1920 std::vector<float> combined;
1921 for (int i = 0; i < lstm.GetFwOutput().size(); ++i) {
1922 combined.push_back(lstm.GetFwOutput()[i] + lstm.GetBwOutput()[i]);
1923 }
1924 EXPECT_THAT(combined, ElementsAreArray(ArrayFloatNear(expected)));
1925 }
1926
// Same as above, but with batch-major input/output (time_major=false).
TEST(LSTMOpTest,BlackBoxTestWithPeepholeWithProjectionNoClippingBatchMajor)1928 TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClippingBatchMajor) {
1929 const int n_batch = 2;
1930 const int n_input = 5;
1931 const int n_cell = 20;
1932 const int n_output = 16;
1933 const int sequence_length = 4;
1934
1935 BidirectionalLSTMOpModel lstm(
1936 n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
1937 /*use_peephole=*/true, /*use_projection_weights=*/true,
1938 /*use_projection_bias=*/false, /*merge_outputs=*/false,
1939 /*use_aux_input=*/false, /*cell_clip=*/0.0,
1940 /*proj_clip=*/0.0, /*quantize_weights=*/false, /*time_major=*/false,
1941 {
1942 {n_batch, sequence_length, n_input}, // input tensor
1943
1944 {n_cell, n_input}, // input_to_input_weight tensor
1945 {n_cell, n_input}, // input_to_forget_weight tensor
1946 {n_cell, n_input}, // input_to_cell_weight tensor
1947 {n_cell, n_input}, // input_to_output_weight tensor
1948
1949 {n_cell, n_output}, // recurrent_to_input_weight tensor
1950 {n_cell, n_output}, // recurrent_to_forget_weight tensor
1951 {n_cell, n_output}, // recurrent_to_cell_weight tensor
1952 {n_cell, n_output}, // recurrent_to_output_weight tensor
1953
1954 {n_cell}, // cell_to_input_weight tensor
1955 {n_cell}, // cell_to_forget_weight tensor
1956 {n_cell}, // cell_to_output_weight tensor
1957
1958 {n_cell}, // input_gate_bias tensor
1959 {n_cell}, // forget_gate_bias tensor
1960 {n_cell}, // cell_gate_bias tensor
1961 {n_cell}, // output_gate_bias tensor
1962
1963 {n_output, n_cell}, // projection_weight tensor
1964 {0}, // projection_bias tensor
1965
1966 {n_cell, n_input}, // input_to_input_weight tensor
1967 {n_cell, n_input}, // input_to_forget_weight tensor
1968 {n_cell, n_input}, // input_to_cell_weight tensor
1969 {n_cell, n_input}, // input_to_output_weight tensor
1970
1971 {n_cell, n_output}, // recurrent_to_input_weight tensor
1972 {n_cell, n_output}, // recurrent_to_forget_weight tensor
1973 {n_cell, n_output}, // recurrent_to_cell_weight tensor
1974 {n_cell, n_output}, // recurrent_to_output_weight tensor
1975
1976 {n_cell}, // cell_to_input_weight tensor
1977 {n_cell}, // cell_to_forget_weight tensor
1978 {n_cell}, // cell_to_output_weight tensor
1979
1980 {n_cell}, // input_gate_bias tensor
1981 {n_cell}, // forget_gate_bias tensor
1982 {n_cell}, // cell_gate_bias tensor
1983 {n_cell}, // output_gate_bias tensor
1984
1985 {n_output, n_cell}, // projection_weight tensor
1986 {0}, // projection_bias tensor
1987
1988 {n_batch, n_output}, // activation_state tensor
1989 {n_batch, n_cell}, // cell_state tensor
1990
1991 {n_batch, n_output}, // activation_state tensor
1992 {n_batch, n_cell}, // cell_state tensor
1993
1994 {n_batch, sequence_length, 0}, // aux_input tensor
1995 {0}, // aux_fw_input_to_input tensor
1996 {0}, // aux_fw_input_to_forget tensor
1997 {0}, // aux_fw_input_to_cell tensor
1998 {0}, // aux_fw_input_to_output tensor
1999 {0}, // aux_bw_input_to_input tensor
2000 {0}, // aux_bw_input_to_forget tensor
2001 {0}, // aux_bw_input_to_cell tensor
2002 {0}, // aux_bw_input_to_output tensor
2003 });
2004
2005 lstm.SetInputToInputWeights(
2006 {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
2007 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048,
2008 -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385,
2009 -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282,
2010 -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627,
2011 -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226,
2012 -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059,
2013 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698,
2014 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206,
2015 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585,
2016 -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063,
2017 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603,
2018 -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682,
2019 -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988,
2020 -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764,
2021 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476,
2022 -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012,
2023 -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604,
2024 -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654,
2025 -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677});
2026
2027 lstm.SetInputToForgetWeights(
2028 {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236,
2029 -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505,
2030 -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495,
2031 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323,
2032 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421,
2033 -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887,
2034 -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791,
2035 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059,
2036 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068,
2037 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905,
2038 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605,
2039 -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464,
2040 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506,
2041 -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063,
2042 -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375,
2043 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553,
2044 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353,
2045 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717,
2046 -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371,
2047 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496});
2048
2049 lstm.SetInputToCellWeights(
2050 {-0.04580283, -0.09549462, -0.032418985, -0.06454633,
2051 -0.043528453, 0.043018587, -0.049152344, -0.12418144,
2052 -0.078985475, -0.07596889, 0.019484362, -0.11434962,
2053 -0.0074034138, -0.06314844, -0.092981495, 0.0062155537,
2054 -0.025034338, -0.0028890965, 0.048929527, 0.06235075,
2055 0.10665918, -0.032036792, -0.08505916, -0.10843358,
2056 -0.13002433, -0.036816437, -0.02130134, -0.016518239,
2057 0.0047691227, -0.0025825808, 0.066017866, 0.029991534,
2058 -0.10652836, -0.1037554, -0.13056071, -0.03266643,
2059 -0.033702414, -0.006473424, -0.04611692, 0.014419339,
2060 -0.025174323, 0.0396852, 0.081777506, 0.06157468,
2061 0.10210095, -0.009658194, 0.046511717, 0.03603906,
2062 0.0069369148, 0.015960095, -0.06507666, 0.09551598,
2063 0.053568836, 0.06408714, 0.12835667, -0.008714329,
2064 -0.20211966, -0.12093674, 0.029450472, 0.2849013,
2065 -0.029227901, 0.1164364, -0.08560263, 0.09941786,
2066 -0.036999565, -0.028842626, -0.0033637602, -0.017012902,
2067 -0.09720865, -0.11193351, -0.029155117, -0.017936034,
2068 -0.009768936, -0.04223324, -0.036159635, 0.06505112,
2069 -0.021742892, -0.023377212, -0.07221364, -0.06430552,
2070 0.05453865, 0.091149814, 0.06387331, 0.007518393,
2071 0.055960953, 0.069779344, 0.046411168, 0.10509911,
2072 0.07463894, 0.0075130584, 0.012850982, 0.04555431,
2073 0.056955688, 0.06555285, 0.050801456, -0.009862683,
2074 0.00826772, -0.026555609, -0.0073611983, -0.0014897042});
2075
2076 lstm.SetInputToOutputWeights(
2077 {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918,
2078 -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534,
2079 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722,
2080 -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761,
2081 -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394,
2082 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154,
2083 -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135,
2084 -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564,
2085 -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047,
2086 -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304,
2087 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946,
2088 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646,
2089 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813,
2090 -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403,
2091 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415,
2092 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495,
2093 -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158,
2094 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295,
2095 -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739,
2096 -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956});
2097
2098 lstm.SetInputGateBias(
2099 {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216,
2100 -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339,
2101 -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818,
2102 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196});
2103
2104 lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696,
2105 0.11098921, 0.15378423, 0.09263801, 0.09790885,
2106 0.09508917, 0.061199076, 0.07665568, -0.015443159,
2107 -0.03499149, 0.046190713, 0.08895977, 0.10899629,
2108 0.40694186, 0.06030037, 0.012413437, -0.06108739});
2109
2110 lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873,
2111 -0.1483596, -0.10639995, -0.091433935, 0.058573797,
2112 -0.06809782, -0.07889636, -0.043246906, -0.09829136,
2113 -0.4279842, 0.034901652, 0.18797937, 0.0075234566,
2114 0.016178843, 0.1749513, 0.13975595, 0.92058027});
2115
2116 lstm.SetOutputGateBias(
2117 {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795,
2118 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895,
2119 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149,
2120 -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877});
2121
2122 lstm.SetRecurrentToInputWeights(
2123 {-0.001374326, -0.078856036, 0.10672688, 0.029162422,
2124 -0.11585556, 0.02557986, -0.13446963, -0.035785314,
2125 -0.01244275, 0.025961924, -0.02337298, -0.044228926,
2126 -0.055839065, -0.046598054, -0.010546039, -0.06900766,
2127 0.027239809, 0.022582639, -0.013296484, -0.05459212,
2128 0.08981, -0.045407712, 0.08682226, -0.06867011,
2129 -0.14390695, -0.02916037, 0.000996957, 0.091420636,
2130 0.14283475, -0.07390571, -0.06402044, 0.062524505,
2131 -0.093129106, 0.04860203, -0.08364217, -0.08119002,
2132 0.009352075, 0.22920375, 0.0016303885, 0.11583097,
2133 -0.13732095, 0.012405723, -0.07551853, 0.06343048,
2134 0.12162708, -0.031923793, -0.014335606, 0.01790974,
2135 -0.10650317, -0.0724401, 0.08554849, -0.05727212,
2136 0.06556731, -0.042729504, -0.043227166, 0.011683251,
2137 -0.013082158, -0.029302018, -0.010899579, -0.062036745,
2138 -0.022509435, -0.00964907, -0.01567329, 0.04260106,
2139 -0.07787477, -0.11576462, 0.017356863, 0.048673786,
2140 -0.017577527, -0.05527947, -0.082487635, -0.040137455,
2141 -0.10820036, -0.04666372, 0.022746278, -0.07851417,
2142 0.01068115, 0.032956902, 0.022433773, 0.0026891115,
2143 0.08944216, -0.0685835, 0.010513544, 0.07228705,
2144 0.02032331, -0.059686817, -0.0005566496, -0.086984694,
2145 0.040414046, -0.1380399, 0.094208956, -0.05722982,
2146 0.012092817, -0.04989123, -0.086576, -0.003399834,
2147 -0.04696032, -0.045747425, 0.10091314, 0.048676282,
2148 -0.029037097, 0.031399418, -0.0040285117, 0.047237843,
2149 0.09504992, 0.041799378, -0.049185462, -0.031518843,
2150 -0.10516937, 0.026374253, 0.10058866, -0.0033195973,
2151 -0.041975245, 0.0073591834, 0.0033782164, -0.004325073,
2152 -0.10167381, 0.042500053, -0.01447153, 0.06464186,
2153 -0.017142897, 0.03312627, 0.009205989, 0.024138335,
2154 -0.011337001, 0.035530265, -0.010912711, 0.0706555,
2155 -0.005894094, 0.051841937, -0.1401738, -0.02351249,
2156 0.0365468, 0.07590991, 0.08838724, 0.021681072,
2157 -0.10086113, 0.019608743, -0.06195883, 0.077335775,
2158 0.023646897, -0.095322326, 0.02233014, 0.09756986,
2159 -0.048691444, -0.009579111, 0.07595467, 0.11480546,
2160 -0.09801813, 0.019894179, 0.08502348, 0.004032281,
2161 0.037211012, 0.068537936, -0.048005626, -0.091520436,
2162 -0.028379958, -0.01556313, 0.06554592, -0.045599163,
2163 -0.01672207, -0.020169014, -0.011877351, -0.20212261,
2164 0.010889619, 0.0047078193, 0.038385306, 0.08540671,
2165 -0.017140968, -0.0035865551, 0.016678626, 0.005633034,
2166 0.015963363, 0.00871737, 0.060130805, 0.028611384,
2167 0.10109069, -0.015060172, -0.07894427, 0.06401885,
2168 0.011584063, -0.024466386, 0.0047652307, -0.09041358,
2169 0.030737216, -0.0046374933, 0.14215417, -0.11823516,
2170 0.019899689, 0.006106124, -0.027092824, 0.0786356,
2171 0.05052217, -0.058925, -0.011402121, -0.024987547,
2172 -0.0013661642, -0.06832946, -0.015667673, -0.1083353,
2173 -0.00096863037, -0.06988685, -0.053350925, -0.027275559,
2174 -0.033664223, -0.07978348, -0.025200296, -0.017207067,
2175 -0.058403496, -0.055697463, 0.005798788, 0.12965427,
2176 -0.062582195, 0.0013350133, -0.10482091, 0.0379771,
2177 0.072521195, -0.0029455067, -0.13797039, -0.03628521,
2178 0.013806405, -0.017858358, -0.01008298, -0.07700066,
2179 -0.017081132, 0.019358726, 0.0027079724, 0.004635139,
2180 0.062634714, -0.02338735, -0.039547626, -0.02050681,
2181 0.03385117, -0.083611414, 0.002862572, -0.09421313,
2182 0.058618143, -0.08598433, 0.00972939, 0.023867095,
2183 -0.053934585, -0.023203006, 0.07452513, -0.048767887,
2184 -0.07314807, -0.056307215, -0.10433547, -0.06440842,
2185 0.04328182, 0.04389765, -0.020006588, -0.09076438,
2186 -0.11652589, -0.021705797, 0.03345259, -0.010329105,
2187 -0.025767034, 0.013057034, -0.07316461, -0.10145612,
2188 0.06358255, 0.18531723, 0.07759293, 0.12006465,
2189 0.1305557, 0.058638252, -0.03393652, 0.09622831,
2190 -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845,
2191 -0.005644518, 0.06857898, -0.12598175, -0.035084512,
2192 0.03156317, -0.12794146, -0.031963028, 0.04692781,
2193 0.030070418, 0.0071660685, -0.095516115, -0.004643372,
2194 0.040170413, -0.062104587, -0.0037324072, 0.0554317,
2195 0.08184801, -0.019164372, 0.06791302, 0.034257166,
2196 -0.10307039, 0.021943003, 0.046745934, 0.0790918,
2197 -0.0265588, -0.007824208, 0.042546265, -0.00977924,
2198 -0.0002440307, -0.017384544, -0.017990116, 0.12252321,
2199 -0.014512694, -0.08251313, 0.08861942, 0.13589665,
2200 0.026351685, 0.012641483, 0.07466548, 0.044301085,
2201 -0.045414884, -0.051112458, 0.03444247, -0.08502782,
2202 -0.04106223, -0.028126027, 0.028473156, 0.10467447});
2203
2204 lstm.SetRecurrentToForgetWeights(
2205 {-0.057784554, -0.026057621, -0.068447545, -0.022581743,
2206 0.14811787, 0.10826372, 0.09471067, 0.03987225,
2207 -0.0039523416, 0.00030638507, 0.053185795, 0.10572994,
2208 0.08414449, -0.022036452, -0.00066928595, -0.09203576,
2209 0.032950465, -0.10985798, -0.023809856, 0.0021431844,
2210 -0.02196096, -0.00326074, 0.00058621005, -0.074678116,
2211 -0.06193199, 0.055729095, 0.03736828, 0.020123724,
2212 0.061878487, -0.04729229, 0.034919553, -0.07585433,
2213 -0.04421272, -0.044019096, 0.085488975, 0.04058006,
2214 -0.06890133, -0.030951202, -0.024628663, -0.07672815,
2215 0.034293607, 0.08556707, -0.05293577, -0.033561368,
2216 -0.04899627, 0.0241671, 0.015736353, -0.095442444,
2217 -0.029564252, 0.016493602, -0.035026584, 0.022337519,
2218 -0.026871363, 0.004780428, 0.0077918363, -0.03601621,
2219 0.016435321, -0.03263031, -0.09543275, -0.047392778,
2220 0.013454138, 0.028934088, 0.01685226, -0.086110644,
2221 -0.046250615, -0.01847454, 0.047608484, 0.07339695,
2222 0.034546845, -0.04881143, 0.009128804, -0.08802852,
2223 0.03761666, 0.008096139, -0.014454086, 0.014361001,
2224 -0.023502491, -0.0011840804, -0.07607001, 0.001856849,
2225 -0.06509276, -0.006021153, -0.08570962, -0.1451793,
2226 0.060212336, 0.055259194, 0.06974018, 0.049454916,
2227 -0.027794661, -0.08077226, -0.016179763, 0.1169753,
2228 0.17213494, -0.0056326236, -0.053934924, -0.0124349,
2229 -0.11520337, 0.05409887, 0.088759385, 0.0019655675,
2230 0.0042065294, 0.03881498, 0.019844765, 0.041858196,
2231 -0.05695512, 0.047233116, 0.038937137, -0.06542224,
2232 0.014429736, -0.09719407, 0.13908425, -0.05379757,
2233 0.012321099, 0.082840554, -0.029899208, 0.044217527,
2234 0.059855383, 0.07711018, -0.045319796, 0.0948846,
2235 -0.011724666, -0.0033288454, -0.033542685, -0.04764985,
2236 -0.13873616, 0.040668588, 0.034832682, -0.015319203,
2237 -0.018715994, 0.046002675, 0.0599172, -0.043107376,
2238 0.0294216, -0.002314414, -0.022424703, 0.0030315618,
2239 0.0014641669, 0.0029166266, -0.11878115, 0.013738511,
2240 0.12375372, -0.0006038222, 0.029104086, 0.087442465,
2241 0.052958444, 0.07558703, 0.04817258, 0.044462286,
2242 -0.015213451, -0.08783778, -0.0561384, -0.003008196,
2243 0.047060397, -0.002058388, 0.03429439, -0.018839769,
2244 0.024734668, 0.024614193, -0.042046934, 0.09597743,
2245 -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786,
2246 -0.02558259, -0.022822596, -0.023273505, -0.02464396,
2247 -0.10991725, -0.006240552, 0.0074488563, 0.024044557,
2248 0.04383914, -0.046476185, 0.028658995, 0.060410924,
2249 0.050786525, 0.009452605, -0.0073054377, -0.024810238,
2250 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517,
2251 0.015898481, 0.021362653, -0.030262267, 0.016587038,
2252 -0.011442813, 0.041154444, -0.007631438, -0.03423484,
2253 -0.010977775, 0.036152758, 0.0066366293, 0.11915515,
2254 0.02318443, -0.041350313, 0.021485701, -0.10906167,
2255 -0.028218046, -0.00954771, 0.020531068, -0.11995105,
2256 -0.03672871, 0.024019798, 0.014255957, -0.05221243,
2257 -0.00661567, -0.04630967, 0.033188973, 0.10107534,
2258 -0.014027541, 0.030796422, -0.10270911, -0.035999842,
2259 0.15443139, 0.07684145, 0.036571592, -0.035900835,
2260 -0.0034699554, 0.06209149, 0.015920248, -0.031122351,
2261 -0.03858649, 0.01849943, 0.13872518, 0.01503974,
2262 0.069941424, -0.06948533, -0.0088794185, 0.061282158,
2263 -0.047401894, 0.03100163, -0.041533746, -0.10430945,
2264 0.044574402, -0.01425562, -0.024290353, 0.034563623,
2265 0.05866852, 0.023947537, -0.09445152, 0.035450947,
2266 0.02247216, -0.0042998926, 0.061146557, -0.10250651,
2267 0.020881841, -0.06747029, 0.10062043, -0.0023941975,
2268 0.03532124, -0.016341697, 0.09685456, -0.016764693,
2269 0.051808182, 0.05875331, -0.04536488, 0.001626336,
2270 -0.028892258, -0.01048663, -0.009793449, -0.017093895,
2271 0.010987891, 0.02357273, -0.00010856845, 0.0099760275,
2272 -0.001845119, -0.03551521, 0.0018358806, 0.05763657,
2273 -0.01769146, 0.040995963, 0.02235177, -0.060430344,
2274 0.11475477, -0.023854522, 0.10071741, 0.0686208,
2275 -0.014250481, 0.034261297, 0.047418304, 0.08562733,
2276 -0.030519066, 0.0060542435, 0.014653856, -0.038836084,
2277 0.04096551, 0.032249358, -0.08355519, -0.026823482,
2278 0.056386515, -0.010401743, -0.028396193, 0.08507674,
2279 0.014410365, 0.020995233, 0.17040324, 0.11511526,
2280 0.02459721, 0.0066619175, 0.025853224, -0.023133837,
2281 -0.081302024, 0.017264642, -0.009585969, 0.09491168,
2282 -0.051313367, 0.054532815, -0.014298593, 0.10657464,
2283 0.007076659, 0.10964551, 0.0409152, 0.008275321,
2284 -0.07283536, 0.07937492, 0.04192024, -0.1075027});
2285
2286 lstm.SetRecurrentToCellWeights(
2287 {-0.037322544, 0.018592842, 0.0056175636, -0.06253426,
2288 0.055647098, -0.05713207, -0.05626563, 0.005559383,
2289 0.03375411, -0.025757805, -0.088049285, 0.06017052,
2290 -0.06570978, 0.007384076, 0.035123326, -0.07920549,
2291 0.053676967, 0.044480428, -0.07663568, 0.0071805613,
2292 0.08089997, 0.05143358, 0.038261272, 0.03339287,
2293 -0.027673481, 0.044746667, 0.028349208, 0.020090483,
2294 -0.019443132, -0.030755889, -0.0040000007, 0.04465846,
2295 -0.021585021, 0.0031670958, 0.0053199246, -0.056117613,
2296 -0.10893326, 0.076739706, -0.08509834, -0.027997585,
2297 0.037871376, 0.01449768, -0.09002357, -0.06111149,
2298 -0.046195522, 0.0422062, -0.005683705, -0.1253618,
2299 -0.012925729, -0.04890792, 0.06985068, 0.037654128,
2300 0.03398274, -0.004781977, 0.007032333, -0.031787455,
2301 0.010868644, -0.031489216, 0.09525667, 0.013939797,
2302 0.0058680447, 0.0167067, 0.02668468, -0.04797466,
2303 -0.048885044, -0.12722108, 0.035304096, 0.06554885,
2304 0.00972396, -0.039238118, -0.05159735, -0.11329045,
2305 0.1613692, -0.03750952, 0.06529313, -0.071974665,
2306 -0.11769596, 0.015524369, -0.0013754242, -0.12446318,
2307 0.02786344, -0.014179351, 0.005264273, 0.14376344,
2308 0.015983658, 0.03406988, -0.06939408, 0.040699873,
2309 0.02111075, 0.09669095, 0.041345075, -0.08316494,
2310 -0.07684199, -0.045768797, 0.032298047, -0.041805092,
2311 0.0119405, 0.0061010392, 0.12652606, 0.0064572375,
2312 -0.024950314, 0.11574242, 0.04508852, -0.04335324,
2313 0.06760663, -0.027437469, 0.07216407, 0.06977076,
2314 -0.05438599, 0.034033038, -0.028602652, 0.05346137,
2315 0.043184172, -0.037189785, 0.10420091, 0.00882477,
2316 -0.054019816, -0.074273005, -0.030617684, -0.0028467078,
2317 0.024302477, -0.0038869337, 0.005332455, 0.0013399826,
2318 0.04361412, -0.007001822, 0.09631092, -0.06702025,
2319 -0.042049985, -0.035070654, -0.04103342, -0.10273396,
2320 0.0544271, 0.037184782, -0.13150354, -0.0058036847,
2321 -0.008264958, 0.042035464, 0.05891794, 0.029673764,
2322 0.0063542654, 0.044788733, 0.054816857, 0.062257513,
2323 -0.00093483756, 0.048938446, -0.004952862, -0.007730018,
2324 -0.04043371, -0.017094059, 0.07229206, -0.023670016,
2325 -0.052195564, -0.025616996, -0.01520939, 0.045104615,
2326 -0.007376126, 0.003533447, 0.006570588, 0.056037236,
2327 0.12436656, 0.051817212, 0.028532185, -0.08686856,
2328 0.11868599, 0.07663395, -0.07323171, 0.03463402,
2329 -0.050708205, -0.04458982, -0.11590894, 0.021273347,
2330 0.1251325, -0.15313013, -0.12224372, 0.17228661,
2331 0.023029093, 0.086124025, 0.006445803, -0.03496501,
2332 0.028332196, 0.04449512, -0.042436164, -0.026587414,
2333 -0.006041347, -0.09292539, -0.05678812, 0.03897832,
2334 0.09465633, 0.008115513, -0.02171956, 0.08304309,
2335 0.071401566, 0.019622514, 0.032163795, -0.004167056,
2336 0.02295182, 0.030739572, 0.056506045, 0.004612461,
2337 0.06524936, 0.059999723, 0.046395954, -0.0045512207,
2338 -0.1335546, -0.030136576, 0.11584653, -0.014678886,
2339 0.0020118146, -0.09688814, -0.0790206, 0.039770417,
2340 -0.0329582, 0.07922767, 0.029322514, 0.026405897,
2341 0.04207835, -0.07073373, 0.063781224, 0.0859677,
2342 -0.10925287, -0.07011058, 0.048005477, 0.03438226,
2343 -0.09606514, -0.006669445, -0.043381985, 0.04240257,
2344 -0.06955775, -0.06769346, 0.043903265, -0.026784198,
2345 -0.017840602, 0.024307009, -0.040079936, -0.019946516,
2346 0.045318738, -0.12233574, 0.026170589, 0.0074471775,
2347 0.15978073, 0.10185836, 0.10298046, -0.015476589,
2348 -0.039390966, -0.072174534, 0.0739445, -0.1211869,
2349 -0.0347889, -0.07943156, 0.014809798, -0.12412325,
2350 -0.0030663363, 0.039695457, 0.0647603, -0.08291318,
2351 -0.018529687, -0.004423833, 0.0037507233, 0.084633216,
2352 -0.01514876, -0.056505352, -0.012800942, -0.06994386,
2353 0.012962922, -0.031234352, 0.07029052, 0.016418684,
2354 0.03618972, 0.055686004, -0.08663945, -0.017404709,
2355 -0.054761406, 0.029065743, 0.052404847, 0.020238016,
2356 0.0048197987, -0.0214882, 0.07078733, 0.013016777,
2357 0.06262858, 0.009184685, 0.020785125, -0.043904778,
2358 -0.0270329, -0.03299152, -0.060088247, -0.015162964,
2359 -0.001828936, 0.12642565, -0.056757294, 0.013586685,
2360 0.09232601, -0.035886683, 0.06000002, 0.05229691,
2361 -0.052580316, -0.082029596, -0.010794592, 0.012947712,
2362 -0.036429964, -0.085508935, -0.13127148, -0.017744139,
2363 0.031502828, 0.036232427, -0.031581745, 0.023051167,
2364 -0.05325106, -0.03421577, 0.028793324, -0.034633752,
2365 -0.009881397, -0.043551125, -0.018609839, 0.0019097115,
2366 -0.008799762, 0.056595087, 0.0022273948, 0.055752404});
2367
2368 lstm.SetRecurrentToOutputWeights({
2369 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415,
2370 -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349,
2371 -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948,
2372 -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774,
2373 -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125,
2374 -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224,
2375 -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088,
2376 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867,
2377 -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728,
2378 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607,
2379 -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928,
2380 -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462,
2381 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879,
2382 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698,
2383 -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146,
2384 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345,
2385 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166,
2386 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203,
2387 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743,
2388 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415,
2389 -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618,
2390 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891,
2391 -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015,
2392 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109,
2393 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886,
2394 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396,
2395 -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282,
2396 -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025,
2397 -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575,
2398 -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277,
2399 -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719,
2400 -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215,
2401 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483,
2402 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102,
2403 -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775,
2404 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841,
2405 -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656,
2406 -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286,
2407 -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309,
2408 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545,
2409 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754,
2410 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831,
2411 -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697,
2412 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453,
2413 -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222,
2414 -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989,
2415 -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827,
2416 -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949,
2417 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819,
2418 -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954,
2419 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228,
2420 -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001,
2421 -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939,
2422 -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556,
2423 -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718,
2424 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893,
2425 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974,
2426 -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485,
2427 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856,
2428 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853,
2429 -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019,
2430 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024,
2431 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994,
2432 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621,
2433 });
2434
2435 lstm.SetCellToInputWeights(
2436 {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458,
2437 -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174,
2438 -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047,
2439 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175});
2440
2441 lstm.SetCellToForgetWeights(
2442 {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276,
2443 -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
2444 -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774,
2445 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355});
2446
2447 lstm.SetCellToOutputWeights(
2448 {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764,
2449 -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544,
2450 -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817,
2451 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733});
2452
2453 lstm.SetProjectionWeights(
2454 {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832,
2455 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683,
2456 -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931,
2457 -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476,
2458 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067,
2459 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787,
2460 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588,
2461 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285,
2462 -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949,
2463 -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768,
2464 -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929,
2465 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504,
2466 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946,
2467 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117,
2468 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253,
2469 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456,
2470 -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552,
2471 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797,
2472 -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272,
2473 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165,
2474 -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922,
2475 -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548,
2476 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786,
2477 -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722,
2478 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318,
2479 -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776,
2480 -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307,
2481 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969,
2482 -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593,
2483 -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515,
2484 -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288,
2485 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723,
2486 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097,
2487 -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209,
2488 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268,
2489 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139,
2490 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707,
2491 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871,
2492 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553,
2493 -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702,
2494 -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615,
2495 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187,
2496 -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388,
2497 -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709,
2498 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263,
2499 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777,
2500 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935,
2501 -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641,
2502 -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996,
2503 -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318,
2504 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437,
2505 -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079,
2506 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237,
2507 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415,
2508 -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124,
2509 -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943,
2510 -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311,
2511 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013,
2512 -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364,
2513 -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543,
2514 -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102,
2515 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906,
2516 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955,
2517 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656});
2518
2519 static float lstm_input[][20] = {
2520 {// Batch0: 4 (input_sequence_size) * 5 (n_input)
2521 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386,
2522 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199,
2523 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339},
2524
2525 {// Batch1: 4 (input_sequence_size) * 5 (n_input)
2526 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260,
2527 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485,
2528 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}};
2529
2530 static float lstm_fw_golden_output[][64] = {
2531 {// Batch0: 4 (input_sequence_size) * 16 (n_output)
2532 -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576,
2533 -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004,
2534 -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147,
2535 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363,
2536 -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322,
2537 -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308,
2538 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794,
2539 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474,
2540 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827,
2541 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512,
2542 -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407,
2543 -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193,
2544 0.0286833, 0.00824207, 0.0264887, 0.0305169},
2545 {// Batch1: 4 (input_sequence_size) * 16 (n_output)
2546 -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926,
2547 -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232,
2548 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954,
2549 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507,
2550 -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039,
2551 -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233,
2552 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378,
2553 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034,
2554 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789,
2555 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855,
2556 -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679,
2557 -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181,
2558 0.0412031, 0.0118723, 0.0239643, 0.0394009}};
2559
2560 static float lstm_combined_golden_output[][64] = {
2561 {-0.022014, 0.073544, -0.002235, 0.040068, -0.037136, -0.052788,
2562 0.075325, -0.029378, 0.024298, -0.07733, -0.030674, -0.060229,
2563 0.040599, 0.011608, 0.042005, 0.045977, -0.039225, 0.076294,
2564 0.000735, 0.032852, -0.069869, -0.053312, 0.073527, -0.028136,
2565 0.021585, -0.102679, -0.004327, -0.043304, 0.072861, 0.027077,
2566 0.034558, 0.068292, -0.036292, 0.069832, -0.003032, 0.053829,
2567 -0.043821, -0.072713, 0.085029, -0.040374, 0.020014, -0.104521,
2568 -0.034504, -0.059759, 0.062569, 0.025652, 0.049306, 0.061189,
2569 -0.025146, 0.079643, -0.005188, 0.033080, -0.048079, -0.048082,
2570 0.069369, -0.028900, 0.024572, -0.077547, -0.022517, -0.054477,
2571 0.038857, 0.013336, 0.043234, 0.044788},
2572 {-0.039186, 0.070792, -0.005913, 0.02642, -0.068274, -0.05022,
2573 0.061444, -0.031241, 0.014996, -0.094544, -0.004146, -0.03464,
2574 0.058981, 0.026097, 0.039781, 0.058408, -0.031887, 0.069252,
2575 0.00576, 0.054062, -0.042801, -0.059974, 0.085272, -0.034453,
2576 0.026097, -0.0959, -0.031164, -0.058699, 0.06839, 0.020512,
2577 0.044727, 0.063609, -0.039863, 0.084819, -0.003909, 0.028666,
2578 -0.075677, -0.045125, 0.070379, -0.033895, 0.022111, -0.097184,
2579 -0.004921, -0.040851, 0.062316, 0.017435, 0.041437, 0.064568,
2580 -0.039656, 0.060726, -0.003402, 0.036854, -0.056503, -0.058554,
2581 0.068588, -0.034879, 0.01352, -0.09962, -0.01434, -0.039505,
2582 0.065133, 0.024321, 0.038473, 0.062438}};
2583
2584 const int input_sequence_size = lstm.sequence_length() * lstm.num_inputs();
2585 EXPECT_EQ(input_sequence_size, 20);
2586 float* batch0_start = lstm_input[0];
2587 float* batch0_end = batch0_start + input_sequence_size;
2588 lstm.SetInput(0, batch0_start, batch0_end);
2589
2590 float* batch1_start = lstm_input[1];
2591 float* batch1_end = batch1_start + input_sequence_size;
2592 lstm.SetInput(input_sequence_size, batch1_start, batch1_end);
2593
2594 ASSERT_EQ(lstm.Invoke(), kTfLiteOk);
2595
2596 const int output_sequence_size =
2597 lstm.sequence_length() * lstm.num_fw_outputs();
2598 EXPECT_EQ(output_sequence_size, 64);
2599 std::vector<float> expected;
2600 const float* golden_start_batch0 = lstm_fw_golden_output[0];
2601 const float* golden_end_batch0 = golden_start_batch0 + output_sequence_size;
2602 expected.insert(expected.end(), golden_start_batch0, golden_end_batch0);
2603
2604 const float* golden_start_batch1 = lstm_fw_golden_output[1];
2605 const float* golden_end_batch1 = golden_start_batch1 + output_sequence_size;
2606 expected.insert(expected.end(), golden_start_batch1, golden_end_batch1);
2607 EXPECT_THAT(lstm.GetFwOutput(), ElementsAreArray(ArrayFloatNear(expected)));
2608
2609 // Check if the sum of forward backward matches the golden.
2610 expected.clear();
2611 golden_start_batch0 = lstm_combined_golden_output[0];
2612 golden_end_batch0 = golden_start_batch0 + output_sequence_size;
2613 expected.insert(expected.end(), golden_start_batch0, golden_end_batch0);
2614
2615 golden_start_batch1 = lstm_combined_golden_output[1];
2616 golden_end_batch1 = golden_start_batch1 + output_sequence_size;
2617 expected.insert(expected.end(), golden_start_batch1, golden_end_batch1);
2618
2619 std::vector<float> combined;
2620 for (int i = 0; i < lstm.GetFwOutput().size(); ++i) {
2621 combined.push_back(lstm.GetFwOutput()[i] + lstm.GetBwOutput()[i]);
2622 }
2623 EXPECT_THAT(combined, ElementsAreArray(ArrayFloatNear(expected)));
2624 }
2625
2626 // Same as the no cifg no peephole no projection no clipping test, but have an
2627 // aux input (without aux input weights), this is the case when stacking but no
2628 // cross-links.
TEST_P(LSTMOpTest,BlackBoxTestWithAuxInputZeroAuxWeight)2629 TEST_P(LSTMOpTest, BlackBoxTestWithAuxInputZeroAuxWeight) {
2630 const int n_batch = 1;
2631 const int n_input = 2;
2632 // n_cell and n_output have the same size when there is no projection.
2633 const int n_cell = 4;
2634 const int n_output = 4;
2635 const int sequence_length = 3;
2636 auto params = GetParam();
2637 const bool quantize_weights = std::get<0>(params);
2638 const bool asymmetric_quantize_inputs = std::get<1>(params);
2639
2640 BidirectionalLSTMOpModel lstm(
2641 n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
2642 /*use_peephole=*/false, /*use_projection_weights=*/false,
2643 /*use_projection_bias=*/false, /*merge_outputs=*/false,
2644 /*use_aux_input=*/true, /*cell_clip=*/0.0,
2645 /*proj_clip=*/0.0, quantize_weights, /*time_major=*/true,
2646 {
2647 {sequence_length, n_batch, n_input}, // input tensor
2648
2649 // Forward cell
2650 {n_cell, n_input}, // input_to_input_weight tensor
2651 {n_cell, n_input}, // input_to_forget_weight tensor
2652 {n_cell, n_input}, // input_to_cell_weight tensor
2653 {n_cell, n_input}, // input_to_output_weight tensor
2654
2655 {n_cell, n_output}, // recurrent_to_input_weight tensor
2656 {n_cell, n_output}, // recurrent_to_forget_weight tensor
2657 {n_cell, n_output}, // recurrent_to_cell_weight tensor
2658 {n_cell, n_output}, // recurrent_to_output_weight tensor
2659
2660 {0}, // cell_to_input_weight tensor
2661 {0}, // cell_to_forget_weight tensor
2662 {0}, // cell_to_output_weight tensor
2663
2664 {n_cell}, // input_gate_bias tensor
2665 {n_cell}, // forget_gate_bias tensor
2666 {n_cell}, // cell_gate_bias tensor
2667 {n_cell}, // output_gate_bias tensor
2668
2669 {0, 0}, // projection_weight tensor
2670 {0}, // projection_bias tensor
2671
2672 // Backward cell
2673 {n_cell, n_input}, // input_to_input_weight tensor
2674 {n_cell, n_input}, // input_to_forget_weight tensor
2675 {n_cell, n_input}, // input_to_cell_weight tensor
2676 {n_cell, n_input}, // input_to_output_weight tensor
2677
2678 {n_cell, n_output}, // recurrent_to_input_weight tensor
2679 {n_cell, n_output}, // recurrent_to_forget_weight tensor
2680 {n_cell, n_output}, // recurrent_to_cell_weight tensor
2681 {n_cell, n_output}, // recurrent_to_output_weight tensor
2682
2683 {0}, // cell_to_input_weight tensor
2684 {0}, // cell_to_forget_weight tensor
2685 {0}, // cell_to_output_weight tensor
2686
2687 {n_cell}, // input_gate_bias tensor
2688 {n_cell}, // forget_gate_bias tensor
2689 {n_cell}, // cell_gate_bias tensor
2690 {n_cell}, // output_gate_bias tensor
2691
2692 {0, 0}, // projection_weight tensor
2693 {0}, // projection_bias tensor
2694
2695 {n_batch, n_output}, // activation_state tensor
2696 {n_batch, n_cell}, // cell_state tensor
2697
2698 {n_batch, n_output}, // activation_state tensor
2699 {n_batch, n_cell}, // cell_state tensor
2700
2701 {sequence_length, n_batch, n_input}, // aux_input tensor
2702 {n_cell, n_input}, // aux_fw_input_to_input tensor
2703 {n_cell, n_input}, // aux_fw_input_to_forget tensor
2704 {n_cell, n_input}, // aux_fw_input_to_cell tensor
2705 {n_cell, n_input}, // aux_fw_input_to_output tensor
2706 {n_cell, n_input}, // aux_bw_input_to_input tensor
2707 {n_cell, n_input}, // aux_bw_input_to_forget tensor
2708 {n_cell, n_input}, // aux_bw_input_to_cell tensor
2709 {n_cell, n_input}, // aux_bw_input_to_output tensor
2710 },
2711 asymmetric_quantize_inputs);
2712
2713 lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
2714 -0.34550029, 0.04266912, -0.15680569,
2715 -0.34856534, 0.43890524});
2716
2717 lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
2718 -0.20583314, 0.44344562, 0.22077113,
2719 -0.29909778});
2720
2721 lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
2722 -0.31343272, -0.40032279, 0.44781327,
2723 0.01387155, -0.35593212});
2724
2725 lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
2726 0.40525138, 0.44272184, 0.03897077, -0.1556896,
2727 0.19487578});
2728
2729 lstm.SetInputGateBias({0., 0., 0., 0.});
2730
2731 lstm.SetCellBias({0., 0., 0., 0.});
2732
2733 lstm.SetForgetGateBias({1., 1., 1., 1.});
2734
2735 lstm.SetOutputGateBias({0., 0., 0., 0.});
2736
2737 lstm.SetRecurrentToInputWeights(
2738 {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
2739 -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
2740 -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
2741
2742 lstm.SetRecurrentToCellWeights(
2743 {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
2744 -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
2745 -0.46367589, 0.26016325, -0.03894562, -0.16368064});
2746
2747 lstm.SetRecurrentToForgetWeights(
2748 {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
2749 -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
2750 0.28053468, 0.01560611, -0.20127171, -0.01140004});
2751
2752 lstm.SetRecurrentToOutputWeights(
2753 {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
2754 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
2755 -0.51818722, -0.15390486, 0.0468148, 0.39922136});
2756
2757 // Input should have n_input * sequence_length many values.
2758 static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
2759 static float lstm_fw_golden_output[] = {
2760 -0.02973187, 0.1229473, 0.20885126, -0.15358765,
2761 -0.03716109, 0.12507336, 0.41193449, -0.20860538,
2762 -0.15053082, 0.09120187, 0.24278517, -0.12222792};
2763 static float lstm_bw_golden_output[] = {
2764 -0.0806187, 0.139077, 0.400476, -0.197842, -0.0332076, 0.123838,
2765 0.309777, -0.17621, -0.0490733, 0.0739237, 0.067706, -0.0208124};
2766
2767 float* batch0_start = lstm_input;
2768 float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
2769
2770 lstm.SetInput(0, batch0_start, batch0_end);
2771 // Aux input and input are the same, so we should observe the same outputs
2772 // as there's no aux input.
2773 lstm.SetAuxInput(0, batch0_start, batch0_end);
2774 std::vector<float> dummy_weights(n_cell * n_input, 0.0f);
2775 lstm.SetAuxInputToInputWeights(dummy_weights);
2776 lstm.SetAuxInputToForgetWeights(dummy_weights);
2777 lstm.SetAuxInputToCellWeights(dummy_weights);
2778 lstm.SetAuxInputToOutputWeights(dummy_weights);
2779
2780 ASSERT_EQ(lstm.Invoke(), kTfLiteOk);
2781
2782 float* fw_golden_start = lstm_fw_golden_output;
2783 float* fw_golden_end =
2784 fw_golden_start + lstm.num_fw_outputs() * lstm.sequence_length();
2785 std::vector<float> fw_expected;
2786 fw_expected.insert(fw_expected.end(), fw_golden_start, fw_golden_end);
2787 EXPECT_THAT(lstm.GetFwOutput(),
2788 ElementsAreArray(
2789 ArrayFloatNear(fw_expected, quantize_weights ? 1e-2 : 1e-5)));
2790
2791 float* bw_golden_start = lstm_bw_golden_output;
2792 float* bw_golden_end =
2793 bw_golden_start + lstm.num_bw_outputs() * lstm.sequence_length();
2794 std::vector<float> bw_expected;
2795 bw_expected.insert(bw_expected.end(), bw_golden_start, bw_golden_end);
2796 EXPECT_THAT(lstm.GetBwOutput(),
2797 ElementsAreArray(
2798 ArrayFloatNear(bw_expected, quantize_weights ? 1e-2 : 1e-5)));
2799 }
2800
2801 // Same as the no cifg no peephole no projection no clipping test, but have an
2802 // aux input with non-zero weights.
TEST_P(LSTMOpTest,BlackBoxTestWithAuxInput)2803 TEST_P(LSTMOpTest, BlackBoxTestWithAuxInput) {
2804 const int n_batch = 1;
2805 const int n_input = 2;
2806 // n_cell and n_output have the same size when there is no projection.
2807 const int n_cell = 4;
2808 const int n_output = 4;
2809 const int sequence_length = 3;
2810 auto params = GetParam();
2811 const bool quantize_weights = std::get<0>(params);
2812 const bool asymmetric_quantize_inputs = std::get<1>(params);
2813
2814 BidirectionalLSTMOpModel lstm(
2815 n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
2816 /*use_peephole=*/false, /*use_projection_weights=*/false,
2817 /*use_projection_bias=*/false, /*merge_outputs=*/false,
2818 /*use_aux_input=*/true, /*cell_clip=*/0.0,
2819 /*proj_clip=*/0.0, quantize_weights, /*time_major=*/true,
2820 {
2821 {sequence_length, n_batch, n_input}, // input tensor
2822
2823 // Forward cell
2824 {n_cell, n_input}, // input_to_input_weight tensor
2825 {n_cell, n_input}, // input_to_forget_weight tensor
2826 {n_cell, n_input}, // input_to_cell_weight tensor
2827 {n_cell, n_input}, // input_to_output_weight tensor
2828
2829 {n_cell, n_output}, // recurrent_to_input_weight tensor
2830 {n_cell, n_output}, // recurrent_to_forget_weight tensor
2831 {n_cell, n_output}, // recurrent_to_cell_weight tensor
2832 {n_cell, n_output}, // recurrent_to_output_weight tensor
2833
2834 {0}, // cell_to_input_weight tensor
2835 {0}, // cell_to_forget_weight tensor
2836 {0}, // cell_to_output_weight tensor
2837
2838 {n_cell}, // input_gate_bias tensor
2839 {n_cell}, // forget_gate_bias tensor
2840 {n_cell}, // cell_gate_bias tensor
2841 {n_cell}, // output_gate_bias tensor
2842
2843 {0, 0}, // projection_weight tensor
2844 {0}, // projection_bias tensor
2845
2846 // Backward cell
2847 {n_cell, n_input}, // input_to_input_weight tensor
2848 {n_cell, n_input}, // input_to_forget_weight tensor
2849 {n_cell, n_input}, // input_to_cell_weight tensor
2850 {n_cell, n_input}, // input_to_output_weight tensor
2851
2852 {n_cell, n_output}, // recurrent_to_input_weight tensor
2853 {n_cell, n_output}, // recurrent_to_forget_weight tensor
2854 {n_cell, n_output}, // recurrent_to_cell_weight tensor
2855 {n_cell, n_output}, // recurrent_to_output_weight tensor
2856
2857 {0}, // cell_to_input_weight tensor
2858 {0}, // cell_to_forget_weight tensor
2859 {0}, // cell_to_output_weight tensor
2860
2861 {n_cell}, // input_gate_bias tensor
2862 {n_cell}, // forget_gate_bias tensor
2863 {n_cell}, // cell_gate_bias tensor
2864 {n_cell}, // output_gate_bias tensor
2865
2866 {0, 0}, // projection_weight tensor
2867 {0}, // projection_bias tensor
2868
2869 {n_batch, n_output}, // activation_state tensor
2870 {n_batch, n_cell}, // cell_state tensor
2871
2872 {n_batch, n_output}, // activation_state tensor
2873 {n_batch, n_cell}, // cell_state tensor
2874
2875 {sequence_length, n_batch, n_input}, // aux_input tensor
2876 {n_cell, n_input}, // aux_fw_input_to_input tensor
2877 {n_cell, n_input}, // aux_fw_input_to_forget tensor
2878 {n_cell, n_input}, // aux_fw_input_to_cell tensor
2879 {n_cell, n_input}, // aux_fw_input_to_output tensor
2880 {n_cell, n_input}, // aux_bw_input_to_input tensor
2881 {n_cell, n_input}, // aux_bw_input_to_forget tensor
2882 {n_cell, n_input}, // aux_bw_input_to_cell tensor
2883 {n_cell, n_input}, // aux_bw_input_to_output tensor
2884 },
2885 asymmetric_quantize_inputs);
2886
2887 lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
2888 -0.34550029, 0.04266912, -0.15680569,
2889 -0.34856534, 0.43890524});
2890
2891 lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
2892 -0.20583314, 0.44344562, 0.22077113,
2893 -0.29909778});
2894
2895 lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
2896 -0.31343272, -0.40032279, 0.44781327,
2897 0.01387155, -0.35593212});
2898
2899 lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
2900 0.40525138, 0.44272184, 0.03897077, -0.1556896,
2901 0.19487578});
2902
2903 lstm.SetInputGateBias({0., 0., 0., 0.});
2904
2905 lstm.SetCellBias({0., 0., 0., 0.});
2906
2907 lstm.SetForgetGateBias({1., 1., 1., 1.});
2908
2909 lstm.SetOutputGateBias({0., 0., 0., 0.});
2910
2911 lstm.SetRecurrentToInputWeights(
2912 {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
2913 -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
2914 -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
2915
2916 lstm.SetRecurrentToCellWeights(
2917 {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
2918 -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
2919 -0.46367589, 0.26016325, -0.03894562, -0.16368064});
2920
2921 lstm.SetRecurrentToForgetWeights(
2922 {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
2923 -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
2924 0.28053468, 0.01560611, -0.20127171, -0.01140004});
2925
2926 lstm.SetRecurrentToOutputWeights(
2927 {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
2928 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
2929 -0.51818722, -0.15390486, 0.0468148, 0.39922136});
2930
2931 // Input should have n_input * sequence_length many values.
2932 static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
2933 static float lstm_fw_golden_output[] = {
2934 0.153335, 0.542754, 0.708602, 0.742855, 0.247581, 0.835739,
2935 0.947797, 0.958177, 0.410892, 0.672268, 0.761909, 0.829133};
2936 static float lstm_bw_golden_output[] = {
2937 0.342275, 0.883431, 0.955930, 0.975621, 0.204939, 0.806858,
2938 0.914849, 0.934871, 0.123236, 0.373087, 0.465377, 0.517630};
2939
2940 lstm.SetAuxInputToInputWeights({0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
2941 lstm.SetAuxInputToForgetWeights({0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 1.0});
2942 lstm.SetAuxInputToCellWeights({0.5, 0.6, 0.7, 0.8, 0.5, 0.6, 0.7, 0.8});
2943 lstm.SetAuxInputToOutputWeights({0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
2944
2945 float* batch0_start = lstm_input;
2946 float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
2947
2948 lstm.SetInput(0, batch0_start, batch0_end);
2949 lstm.SetAuxInput(0, batch0_start, batch0_end);
2950
2951 ASSERT_EQ(lstm.Invoke(), kTfLiteOk);
2952
2953 float* fw_golden_start = lstm_fw_golden_output;
2954 float* fw_golden_end =
2955 fw_golden_start + lstm.num_fw_outputs() * lstm.sequence_length();
2956 std::vector<float> fw_expected;
2957 fw_expected.insert(fw_expected.end(), fw_golden_start, fw_golden_end);
2958 EXPECT_THAT(lstm.GetFwOutput(),
2959 ElementsAreArray(
2960 ArrayFloatNear(fw_expected, quantize_weights ? 1e-2 : 1e-5)));
2961
2962 float* bw_golden_start = lstm_bw_golden_output;
2963 float* bw_golden_end =
2964 bw_golden_start + lstm.num_bw_outputs() * lstm.sequence_length();
2965 std::vector<float> bw_expected;
2966 bw_expected.insert(bw_expected.end(), bw_golden_start, bw_golden_end);
2967 EXPECT_THAT(lstm.GetBwOutput(),
2968 ElementsAreArray(
2969 ArrayFloatNear(bw_expected, quantize_weights ? 1e-2 : 1e-5)));
2970 }
2971
2972 } // namespace
2973 } // namespace tflite
2974