#include <gtest/gtest.h>

#include <c10/util/irange.h>
#include <torch/torch.h>

#include <test/cpp/api/support.h>

#include <algorithm>
#include <iostream>
#include <random>
#include <sstream>
#include <string>

using namespace torch::nn;

namespace rnn_utils = torch::nn::utils::rnn;

struct NNUtilsTest : torch::test::SeedingFixture {};
struct PackedSequenceTest : torch::test::SeedingFixture {};

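// Verifies that clip_grad_norm_ returns the pre-clipping total norm and
// rescales all gradients uniformly so the post-clipping norm does not exceed
// max_norm, for several norm types including the infinity norm.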
TEST_F(NNUtilsTest, ClipGradNorm) {
  auto l = Linear(10, 10);
  float max_norm = 2;
  auto compute_norm = [&](float norm_type) -> float {
    float total_norm = 0.0;
    if (norm_type != std::numeric_limits<float>::infinity()) {
      for (const auto& p : l->parameters()) {
        total_norm +=
            p.grad().data().abs().pow(norm_type).sum().item().toFloat();
      }
      return std::pow(total_norm, 1.0 / norm_type);
    } else {
      for (const auto& p : l->parameters()) {
        auto param_max = p.grad().data().abs().max().item().toFloat();
        if (param_max > total_norm) {
          total_norm = param_max;
        }
      }
      return total_norm;
    }
  };
  auto compare_scaling =
      [&](const std::vector<torch::Tensor>& grads) -> torch::Tensor {
    std::vector<torch::Tensor> p_scale;
    for (const auto i : c10::irange(grads.size())) {
      auto param = l->parameters()[i];
      auto grad = grads[i];
      p_scale.push_back(param.grad().data().div(grad).view(-1));
    }
    auto scale = torch::cat(p_scale);
    return scale; // need to assert std is 0.
  };

  std::vector<torch::Tensor> grads = {
      torch::arange(1.0, 101).view({10, 10}),
      torch::ones({10}).div(1000),
  };
  std::vector<float> norm_types = {
      0.5,
      1.5,
      2.0,
      4.0,
      std::numeric_limits<float>::infinity(),
  };
  for (auto norm_type : norm_types) {
    for (const auto i : c10::irange(grads.size())) {
      l->parameters()[i].mutable_grad() =
          grads[i].clone().view_as(l->parameters()[i].data());
    }
    auto norm_before = compute_norm(norm_type);
    auto norm = utils::clip_grad_norm_(l->parameters(), max_norm, norm_type);
    auto norm_after = compute_norm(norm_type);
    ASSERT_FLOAT_EQ(norm, norm_before);
    ASSERT_NEAR(norm_after, max_norm, 1e-6);
    ASSERT_LE(norm_after, max_norm);
    auto scaled = compare_scaling(grads);
    ASSERT_NEAR(0, scaled.std().item().toFloat(), 1e-7);
  }
  // Small gradients should be left unchanged
  grads = {
      torch::rand({10, 10}).div(10000),
      torch::ones(10).div(500),
  };
  for (auto norm_type : norm_types) {
    for (const auto i : c10::irange(grads.size())) {
      l->parameters()[i].grad().data().copy_(grads[i]);
    }
    auto norm_before = compute_norm(norm_type);
    auto norm = utils::clip_grad_norm_(l->parameters(), max_norm, norm_type);
    auto norm_after = compute_norm(norm_type);
    ASSERT_FLOAT_EQ(norm, norm_before);
    ASSERT_FLOAT_EQ(norm_before, norm_after);
    ASSERT_LE(norm_after, max_norm);
    auto scaled = compare_scaling(grads);
    ASSERT_NEAR(0, scaled.std().item().toFloat(), 1e-7);
    ASSERT_FLOAT_EQ(scaled[0].item().toFloat(), 1);
  }
  // Should accept a single Tensor as input
  auto p1 = torch::randn({10, 10});
  auto p2 = torch::randn({10, 10});
  auto g = torch::arange(1., 101).view({10, 10});
  p1.mutable_grad() = g.clone();
  p2.mutable_grad() = g.clone();
  for (const auto norm_type : norm_types) {
    utils::clip_grad_norm_(p1, max_norm, norm_type);
    utils::clip_grad_norm_({p2}, max_norm, norm_type);
    ASSERT_TRUE(torch::allclose(p1.grad(), p2.grad()));
  }
}

// Check that clip_grad_norm_ raises an error if the norm of a gradient
// is non-finite
TEST_F(NNUtilsTest, ClipGradNormErrorIfNonfinite) {
  double inf = std::numeric_limits<double>::infinity();
  double nan = std::numeric_limits<double>::quiet_NaN();

  using Vector = std::vector<double>;

  Vector norms_pos = {0.1, 1, 2, 3.5, inf};
  Vector norms_neg = {-0.1, -1, -2, -3.5};
  Vector norms_neg_plus_0 = {0, -0.1, -1, -2, -3.5};
  Vector norms_except_0 = {0.1, 1, 2, 3.5, inf, -0.1, -1, -2, -3.5};
  Vector norms_all = {0, 0.1, 1, 2, 3.5, inf, -0.1, -1, -2, -3.5};

  // Each entry in test_cases has the following values, in this order:
  //
  // grad_only_one_elem       If True, only one element of the parameter's
  //                          gradient is set to the scalar grad, and the
  //                          rest of the elements are 0. If False, all grad
  //                          elements are equal to the scalar.
  //
  // prefix_finite_grad_param If True, prefix a parameter that has a grad
  //                          of 1.
  //
  // scalars                  Scalars to use as the parameter's grad, through
  //                          multiplication
  //
  // norms_nonfinite          Norm types that should produce nonfinite total norm
  //
  // norms_finite             Norm types that should produce finite total norm
  std::vector<std::tuple<bool, bool, Vector, Vector, Vector>> test_cases({
      // Test errors from an infinite grad
      std::make_tuple(
          false, false, Vector({inf, -inf}), norms_except_0, Vector({0})),
      std::make_tuple(
          false, true, Vector({inf, -inf}), norms_pos, norms_neg_plus_0),
      std::make_tuple(
          true, false, Vector({inf, -inf}), norms_pos, norms_neg_plus_0),
      std::make_tuple(
          false, true, Vector({inf, -inf}), norms_pos, norms_neg_plus_0),

      // Test errors from a NaN grad
      std::make_tuple(false, false, Vector({nan}), norms_except_0, Vector({0})),
      std::make_tuple(false, true, Vector({nan}), norms_except_0, Vector({0})),
      std::make_tuple(true, false, Vector({nan}), norms_except_0, Vector({0})),
      std::make_tuple(true, true, Vector({nan}), norms_except_0, Vector({0})),

      // Test a grad that should never error
      std::make_tuple(false, false, Vector({2e22, -2e22}), Vector(), norms_all),
      std::make_tuple(false, true, Vector({2e22, -2e22}), Vector(), norms_all),
      std::make_tuple(true, false, Vector({2e22, -2e22}), Vector(), norms_all),
      std::make_tuple(true, true, Vector({2e22, -2e22}), Vector(), norms_all),

      // Test a grad that will overflow to inf for only some norm orders
      std::make_tuple(
          false,
          false,
          Vector({2e200, -2e200}),
          Vector({3.5, 2, -2, -3.5}),
          Vector({inf, 1, 0.1, 0, -1, -0.1})),
      std::make_tuple(
          false,
          true,
          Vector({2e200, -2e200}),
          Vector({3.5, 2}),
          Vector({inf, 1, 0.1, 0, -1, -0.1, -2, -3.5})),
      std::make_tuple(
          true,
          false,
          Vector({2e200, -2e200}),
          Vector({3.5, 2}),
          Vector({inf, 1, 0.1, 0, -1, -0.1, -2, -3.5})),
      std::make_tuple(
          false,
          true,
          Vector({2e200, -2e200}),
          Vector({3.5, 2}),
          Vector({inf, 1, 0.1, 0, -1, -0.1, -2, -3.5})),
  });

  auto gen_parameters = [](double scalar,
                           bool grad_only_one_elem,
                           bool prefix_finite_grad_param,
                           torch::DeviceType device_type) {
    auto param = torch::ones(
        10,
        torch::TensorOptions()
            .dtype(torch::kDouble)
            .device(device_type)
            .requires_grad(true));
    if (grad_only_one_elem) {
      param[1].mul(scalar).sum().backward();
    } else {
      param.mul(scalar).sum().backward();
    }

    std::vector<torch::Tensor> parameters;
    if (prefix_finite_grad_param) {
      auto prefix_param = torch::ones(
          1,
          torch::TensorOptions()
              .dtype(torch::kDouble)
              .device(device_type)
              .requires_grad(true));
      prefix_param.mul(1).sum().backward();
      parameters.push_back(prefix_param);
    }
    parameters.push_back(param);

    return parameters;
  };

  auto run_test_case = [&gen_parameters](
                           double norm_type,
                           bool error_if_nonfinite,
                           double scalar,
                           bool grad_only_one_elem,
                           bool prefix_finite_grad_param,
                           bool is_norm_nonfinite,
                           torch::DeviceType device_type) {
    std::stringstream ss;
    ss << "device: " << device_type << ", norm_type: " << norm_type
       << ", error_if_nonfinite: " << error_if_nonfinite
       << ", scalar: " << scalar
       << ", grad_only_one_elem: " << grad_only_one_elem
       << ", prefix_finite_grad_param: " << prefix_finite_grad_param
       << ", is_norm_nonfinite: " << is_norm_nonfinite;
    std::string msg = ss.str();

    auto parameters = gen_parameters(
        scalar, grad_only_one_elem, prefix_finite_grad_param, device_type);

    if (is_norm_nonfinite && error_if_nonfinite) {
      std::vector<torch::Tensor> grads_before;
      // NOLINTNEXTLINE(performance-for-range-copy)
      for (auto p : parameters) {
        // NOLINTNEXTLINE(performance-inefficient-vector-operation)
        grads_before.push_back(p.grad().clone());
      }
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
      EXPECT_THROW(
          utils::clip_grad_norm_(parameters, 1., norm_type, true),
          std::exception)
          << msg;
      // Grads should not change if error is thrown
      for (const auto p_idx : c10::irange(parameters.size())) {
        ASSERT_TRUE(torch::allclose(
            parameters[p_idx].grad(),
            grads_before[p_idx],
            1.0,
            0.0,
            /*equal_nan*/ true))
            << msg;
      }
    } else {
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
      EXPECT_NO_THROW(
          utils::clip_grad_norm_(parameters, 1., norm_type, error_if_nonfinite))
          << msg;
    }
  };

  for (auto device_type : {torch::kCPU, torch::kCUDA}) {
    if (device_type == torch::kCUDA && !torch::cuda::is_available()) {
      continue;
    }
    for (auto test_case : test_cases) {
      auto grad_only_one_elem = std::get<0>(test_case);
      auto prefix_finite_grad_param = std::get<1>(test_case);
      auto scalars = std::get<2>(test_case);
      auto norms_nonfinite = std::get<3>(test_case);
      auto norms_finite = std::get<4>(test_case);

      for (auto error_if_nonfinite : {false, true}) {
        for (auto scalar : scalars) {
          for (auto norm_type : norms_nonfinite) {
            run_test_case(
                norm_type,
                error_if_nonfinite,
                scalar,
                grad_only_one_elem,
                prefix_finite_grad_param,
                true,
                device_type);
          }

          for (auto norm_type : norms_finite) {
            run_test_case(
                norm_type,
                error_if_nonfinite,
                scalar,
                grad_only_one_elem,
                prefix_finite_grad_param,
                false,
                device_type);
          }
        }
      }
    }
  }
}

TEST_F(NNUtilsTest, ClipGradValue) {
  auto l = Linear(10, 10);
  float clip_value = 2.5;

  torch::Tensor grad_w = torch::arange(-50., 50).view({10, 10}).div_(5);
  torch::Tensor grad_b = torch::ones({10}).mul_(2);
  std::vector<std::vector<torch::Tensor>> grad_lists = {
      {grad_w, grad_b}, {grad_w, torch::Tensor()}};
  for (auto grad_list : grad_lists) {
    for (const auto i : c10::irange(grad_list.size())) {
      auto p = l->parameters()[i];
      auto g = grad_list[i];
      p.mutable_grad() = g.defined() ? g.clone().view_as(p.data()) : g;
    }

    utils::clip_grad_value_(l->parameters(), clip_value);
    for (const auto& p : l->parameters()) {
      if (p.grad().defined()) {
        ASSERT_LE(p.grad().data().max().item().toFloat(), clip_value);
        ASSERT_GE(p.grad().data().min().item().toFloat(), -clip_value);
      }
    }
  }

  // Should accept a single Tensor as input
  auto p1 = torch::randn({10, 10});
  auto p2 = torch::randn({10, 10});
  auto g = torch::arange(-50., 50).view({10, 10}).div_(5);
  p1.mutable_grad() = g.clone();
  p2.mutable_grad() = g.clone();
  utils::clip_grad_value_(p1, clip_value);
  utils::clip_grad_value_({p2}, clip_value);
  ASSERT_TRUE(torch::allclose(p1.grad(), p2.grad()));
}

TEST_F(NNUtilsTest, ConvertParameters) {
  std::vector<torch::Tensor> parameters{
      torch::arange(9, torch::kFloat32),
      torch::arange(9, torch::kFloat32).view({3, 3}),
      torch::arange(8, torch::kFloat32).view({2, 2, 2})};

  auto expected = torch::cat(
      {torch::arange(9, torch::kFloat32),
       torch::arange(9, torch::kFloat32).view(-1),
       torch::arange(8, torch::kFloat32).view(-1)});
  auto vector = utils::parameters_to_vector(parameters);
  ASSERT_TRUE(vector.allclose(expected));

  std::vector<torch::Tensor> zero_parameters{
      torch::zeros({9}, torch::kFloat32),
      torch::zeros({9}, torch::kFloat32).view({3, 3}),
      torch::zeros({8}, torch::kFloat32).view({2, 2, 2})};

  utils::vector_to_parameters(vector, zero_parameters);
  for (const auto i : c10::irange(zero_parameters.size())) {
    ASSERT_TRUE(zero_parameters[i].allclose(parameters[i]));
  }

  {
    auto conv1 = Conv2d(3, 10, 5);
    auto fc1 = Linear(10, 20);
    auto model = Sequential(conv1, fc1);

    auto vec = utils::parameters_to_vector(model->parameters());
    ASSERT_EQ(vec.size(0), 980);
  }
  {
    auto conv1 = Conv2d(3, 10, 5);
    auto fc1 = Linear(10, 20);
    auto model = Sequential(conv1, fc1);

    auto vec = torch::arange(0., 980);
    utils::vector_to_parameters(vec, model->parameters());

    auto sample = model->parameters()[0][0][0][0];
    ASSERT_TRUE(torch::equal(sample.data(), vec.data().slice(0, 0, 5)));
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-non-const-global-variables)
int64_t PackedSequenceTest_batch_size = 5;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-non-const-global-variables)
int64_t PackedSequenceTest_max_length = 6;

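// Builds a batch of 1-D sequences with random lengths in
// [1, PackedSequenceTest_max_length), filled with random values and sorted by
// decreasing length so they can be packed with enforce_sorted=true.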
std::vector<torch::Tensor> PackedSequenceTest_ordered_sequence(
    torch::ScalarType tensor_type) {
  std::vector<torch::Tensor> seqs;
  seqs.reserve(PackedSequenceTest_batch_size);
  for (const auto i : c10::irange(PackedSequenceTest_batch_size)) {
    (void)i; // Suppress unused variable warning
    seqs.emplace_back(torch::empty(
        {torch::randint(1, PackedSequenceTest_max_length, {1}).item<int64_t>()},
        tensor_type));
  }
  for (auto& s : seqs) {
    s.random_(-128, 128);
  }
  std::sort(
      seqs.begin(),
      seqs.end(),
      [&](const torch::Tensor& t1, const torch::Tensor& t2) {
        return t1.size(0) > t2.size(0);
      });
  return seqs;
}

std::tuple<torch::Tensor, torch::Tensor> PackedSequenceTest_padded_sequence(
    torch::ScalarType tensor_type) {
  // Create Tensor of random padded sequences
  auto ordered = PackedSequenceTest_ordered_sequence(tensor_type);
  auto lengths = torch::empty({(int64_t)ordered.size()}, torch::kInt64);
  for (const auto i : c10::irange(ordered.size())) {
    lengths[i] = ordered[i].size(0);
  }
  auto padded_tensor = rnn_utils::pad_sequence(ordered);
  return std::make_tuple(padded_tensor, lengths);
}

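// Asserts that two PackedSequences are numerically equal: same data, same
// batch sizes, and matching sorted/unsorted index tensors (or both undefined).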
void assert_is_equal_packed_sequence(
    const rnn_utils::PackedSequence& a,
    const rnn_utils::PackedSequence& b) {
  ASSERT_TRUE(torch::allclose(a.data(), b.data()));
  ASSERT_TRUE(torch::allclose(a.batch_sizes(), b.batch_sizes()));
  ASSERT_TRUE(
      (!a.sorted_indices().defined() && !b.sorted_indices().defined()) ||
      torch::allclose(a.sorted_indices(), b.sorted_indices()));
  ASSERT_TRUE(
      (!a.unsorted_indices().defined() && !b.unsorted_indices().defined()) ||
      torch::allclose(a.unsorted_indices(), b.unsorted_indices()));
}

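// Asserts that two PackedSequences share the exact same underlying tensors
// (identity, not just equal values).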
void assert_is_same_packed_sequence(
    const rnn_utils::PackedSequence& a,
    const rnn_utils::PackedSequence& b) {
  ASSERT_TRUE(a.data().is_same(b.data()));
  ASSERT_TRUE(a.batch_sizes().is_same(b.batch_sizes()));
  ASSERT_TRUE(a.sorted_indices().is_same(b.sorted_indices()));
  ASSERT_TRUE(a.unsorted_indices().is_same(b.unsorted_indices()));
}

TEST_F(PackedSequenceTest, WrongOrder) {
  auto a = torch::ones({25, 300});
  auto b = torch::ones({22, 300});
  auto b_a = rnn_utils::pad_sequence({b, a});
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_THROW(
      rnn_utils::pack_padded_sequence(
          b_a,
          torch::tensor({22, 25}),
          /*batch_first=*/false,
          /*enforce_sorted=*/true),
      c10::Error);
}

TEST_F(PackedSequenceTest, TotalLength) {
  auto [padded, lengths] = PackedSequenceTest_padded_sequence(torch::kFloat);
  int64_t max_length = torch::max(lengths).item<int64_t>();
  rnn_utils::PackedSequence packed =
      rnn_utils::pack_padded_sequence(padded, lengths);

  // test that an error is raised if total_length < max_length
  for (int64_t total_length : std::vector<int64_t>{-1, 0, max_length - 1}) {
    for (bool batch_first : std::vector<bool>{true, false}) {
      auto err_fn = [&]() {
        rnn_utils::pad_packed_sequence(
            packed,
            /*batch_first=*/batch_first,
            /*padding_value=*/0.0,
            /*total_length=*/total_length);
      };
      ASSERT_THROWS_WITH(
          err_fn(),
          "Expected total_length to be at least the length of the longest sequence in input");
    }
  }

  // test that pad_packed_sequence returns results of correct length
  for (bool batch_first : std::vector<bool>{true, false}) {
    auto no_extra_pad = std::get<0>(
        rnn_utils::pad_packed_sequence(packed, /*batch_first=*/batch_first));
    for (int64_t total_length_delta : std::vector<int64_t>{0, 1, 8}) {
      int64_t total_length = max_length + total_length_delta;
      auto [unpacked, lengths_out] = rnn_utils::pad_packed_sequence(
          packed,
          /*batch_first=*/batch_first,
          /*padding_value=*/0.0,
          /*total_length=*/total_length);
      ASSERT_TRUE(torch::allclose(lengths, lengths_out));
      ASSERT_EQ(unpacked.size(batch_first ? 1 : 0), total_length);
      torch::Tensor ref_output, extra_pad;
      if (total_length_delta == 0) {
        ref_output = no_extra_pad;
      } else if (batch_first) {
        extra_pad = torch::zeros(
            {PackedSequenceTest_batch_size, total_length_delta},
            no_extra_pad.options());
        ref_output = torch::cat({no_extra_pad, extra_pad}, 1);
      } else {
        extra_pad = torch::zeros(
            {total_length_delta, PackedSequenceTest_batch_size},
            no_extra_pad.options());
        ref_output = torch::cat({no_extra_pad, extra_pad}, 0);
      }
      ASSERT_TRUE(torch::allclose(unpacked, ref_output));
    }
  }
}

TEST_F(PackedSequenceTest, To) {
  for (bool enforce_sorted : std::vector<bool>{true, false}) {
    auto [padded, lengths] = PackedSequenceTest_padded_sequence(torch::kInt);
    rnn_utils::PackedSequence a =
        rnn_utils::pack_padded_sequence(
            padded,
            lengths,
            /*batch_first=*/false,
            /*enforce_sorted=*/enforce_sorted)
            .cpu();

    assert_is_same_packed_sequence(a, a.to(torch::kCPU));
    assert_is_same_packed_sequence(a, a.cpu());
    assert_is_same_packed_sequence(
        a, a.to(torch::device(torch::kCPU).dtype(torch::kInt32)));

    if (torch::cuda::is_available()) {
      auto b = a.cuda();
      assert_is_same_packed_sequence(b, b.to(torch::kCUDA));
      assert_is_same_packed_sequence(b, b.cuda());
      assert_is_equal_packed_sequence(a, b.to(torch::kCPU));
      assert_is_equal_packed_sequence(b, a.to(torch::kCUDA));
      assert_is_equal_packed_sequence(
          a, b.to(torch::device(torch::kCPU).dtype(torch::kInt32)));
      assert_is_same_packed_sequence(b, b.to(torch::kInt32));
    }
  }
}

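// Exercises pack_sequence on sorted and unsorted inputs: checks the packed
// data, batch_sizes, and sorted/unsorted indices, and verifies consistency
// with pad_sequence / pack_padded_sequence / pad_packed_sequence.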
TEST_F(NNUtilsTest, PackSequence) {
  auto _compatibility_test = [&](torch::ArrayRef<torch::Tensor> sequences,
                                 torch::Tensor lengths,
                                 bool batch_first,
                                 bool enforce_sorted = false) {
    torch::Tensor padded = rnn_utils::pad_sequence(sequences, batch_first);
    rnn_utils::PackedSequence packed =
        rnn_utils::pack_sequence(sequences, enforce_sorted);
    std::tuple<torch::Tensor, torch::Tensor> unpacked =
        rnn_utils::pad_packed_sequence(packed, batch_first);
    ASSERT_TRUE(torch::allclose(padded, std::get<0>(unpacked)));
    rnn_utils::PackedSequence pack_padded = rnn_utils::pack_padded_sequence(
        padded, lengths, batch_first, enforce_sorted);
    assert_is_equal_packed_sequence(packed, pack_padded);
  };

  // single dimensional
  auto a = torch::tensor({1, 2, 3});
  auto b = torch::tensor({4, 5});
  auto c = torch::tensor({6});
  rnn_utils::PackedSequence packed =
      rnn_utils::pack_sequence({a, b, c}, /*enforce_sorted=*/false);
  auto expected = torch::tensor({1, 4, 6, 2, 5, 3});
  ASSERT_TRUE(torch::allclose(packed.batch_sizes(), torch::tensor({3, 2, 1})));
  ASSERT_TRUE(torch::allclose(packed.data(), expected));
  ASSERT_TRUE(
      torch::allclose(packed.sorted_indices(), torch::tensor({0, 1, 2})));
  ASSERT_TRUE(
      torch::allclose(packed.unsorted_indices(), torch::tensor({0, 1, 2})));

  rnn_utils::PackedSequence packed_unsorted =
      rnn_utils::pack_sequence({b, c, a}, /*enforce_sorted=*/false);
  ASSERT_TRUE(
      torch::allclose(packed_unsorted.batch_sizes(), torch::tensor({3, 2, 1})));
  ASSERT_TRUE(torch::allclose(packed_unsorted.data(), expected));
  ASSERT_TRUE(torch::allclose(
      packed_unsorted.sorted_indices(), torch::tensor({2, 0, 1})));
  ASSERT_TRUE(torch::allclose(
      packed_unsorted.unsorted_indices(), torch::tensor({1, 2, 0})));

  // single dimensional, enforce_sorted = True
  rnn_utils::PackedSequence packed_enforce_sorted =
      rnn_utils::pack_sequence({a, b, c}, /*enforce_sorted=*/true);
  ASSERT_TRUE(torch::allclose(
      packed_enforce_sorted.batch_sizes(), torch::tensor({3, 2, 1})));
  ASSERT_TRUE(torch::allclose(packed_enforce_sorted.data(), expected));
  ASSERT_FALSE(packed_enforce_sorted.sorted_indices().defined());
  ASSERT_FALSE(packed_enforce_sorted.unsorted_indices().defined());

  ASSERT_THROWS_WITH(
      rnn_utils::pack_sequence({b, c, a}, /*enforce_sorted=*/true),
      "must be sorted in decreasing order");

  ASSERT_THROWS_WITH(
      rnn_utils::pack_sequence({b, c, a}, /*enforce_sorted=*/true),
      "You can pass `enforce_sorted=False`");

  // more dimensions
  int64_t maxlen = 9;
  for (int64_t num_dim : std::vector<int64_t>{0, 1, 2, 3}) {
    std::vector<torch::Tensor> sequences;
    std::vector<int64_t> lengths_vec;
    std::vector<int64_t> trailing_dims(num_dim, 4);
    for (int64_t i = maxlen; i > 0; i--) {
      int64_t seq_len = i * i;
      lengths_vec.emplace_back(seq_len);
      std::vector<int64_t> tensor_sizes{seq_len, 5};
      tensor_sizes.insert(
          tensor_sizes.end(), trailing_dims.begin(), trailing_dims.end());
      sequences.emplace_back(torch::rand(tensor_sizes));
    }
    std::vector<torch::Tensor> unsorted_sequences;
    for (const auto& s : sequences) {
      // NOLINTNEXTLINE(performance-inefficient-vector-operation)
      unsorted_sequences.emplace_back(s.clone());
    }
    std::shuffle(
        std::begin(unsorted_sequences),
        std::end(unsorted_sequences),
        std::default_random_engine{});

    std::vector<int64_t> unsorted_sequences_lengths_vec;
    for (const auto& t : unsorted_sequences) {
      // NOLINTNEXTLINE(performance-inefficient-vector-operation)
      unsorted_sequences_lengths_vec.emplace_back(t.size(0));
    }

    // compatibility with other utilities
    for (bool batch_first : std::vector<bool>{true, false}) {
      for (bool enforce_sorted : std::vector<bool>{true, false}) {
        _compatibility_test(
            sequences, torch::tensor(lengths_vec), batch_first, enforce_sorted);
      }
      _compatibility_test(
          unsorted_sequences,
          torch::tensor(unsorted_sequences_lengths_vec),
          batch_first);
    }
  }
}

TEST_F(NNUtilsTest, PackPaddedSequence) {
  auto generate_test_case = [&](torch::ArrayRef<int64_t> sorted_lengths,
                                bool should_shuffle) {
    auto pad = [&](torch::Tensor tensor, int64_t length) {
      std::vector<int64_t> tensor_sizes{length - tensor.size(0)};
      tensor_sizes.insert(
          tensor_sizes.end(),
          tensor.sizes().slice(1).begin(),
          tensor.sizes().slice(1).end());
      return torch::cat({tensor, torch::zeros(tensor_sizes, tensor.options())});
    };
    int64_t max_length = sorted_lengths[0];
    torch::Tensor batch_sizes = torch::empty({max_length}, torch::kInt64);
    for (int64_t i = 1; i < max_length + 1; i++) {
      int64_t total = 0;
      for (const auto& x : sorted_lengths) {
        if (x >= i) {
          total++;
        }
      }
      batch_sizes[i - 1] = total;
    }
    std::vector<torch::Tensor> tensors_to_be_cat;
    for (int64_t i = 1; i < static_cast<int64_t>(sorted_lengths.size() + 1);
         i++) {
      int64_t l = sorted_lengths.at(i - 1);
      tensors_to_be_cat.emplace_back(pad(
          i * 100 + torch::arange(1., 5 * l + 1).view({l, 1, 5}), max_length));
    }
    auto padded = torch::cat(tensors_to_be_cat, 1);
    std::vector<torch::Tensor> expected_data_vec;
    for (const auto n : c10::irange(batch_sizes.size(0))) {
      int64_t batch_size = batch_sizes[n].item<int64_t>();
      for (const auto i : c10::irange(batch_size)) {
        expected_data_vec.emplace_back(
            torch::arange(1., 6) + (i + 1) * 100 + 5 * n);
      }
    }
    auto expected_data = torch::stack(expected_data_vec, /*dim=*/0);

    torch::Tensor unsorted_indices, lengths;
    if (should_shuffle) {
      // Shuffle the padded sequence to create an unsorted sequence
      std::vector<int64_t> permutation;
      for (const auto i : c10::irange(sorted_lengths.size())) {
        permutation.emplace_back(i);
      }
      std::shuffle(
          std::begin(permutation),
          std::end(permutation),
          std::default_random_engine{});

      unsorted_indices = torch::tensor(permutation);
      padded = padded.index_select(1, unsorted_indices);
      lengths = torch::tensor(sorted_lengths).index_select(0, unsorted_indices);
    } else {
      unsorted_indices = torch::Tensor();
      lengths = torch::tensor(sorted_lengths);
    }

    return std::make_tuple(
        padded.requires_grad_(),
        lengths,
        expected_data,
        batch_sizes,
        unsorted_indices);
  };

  std::vector<std::pair<std::vector<int64_t>, bool>> test_cases = {
      // sorted_lengths, should_shuffle
      {{10, 8, 4, 2, 2, 2, 1}, false},
      {{11, 10, 8, 6, 4, 3, 1}, false},
      {{11, 10, 8, 6, 4, 3, 1}, true}};

  for (const auto& test_case : test_cases) {
    for (bool batch_first : std::vector<bool>{true, false}) {
      // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
      std::vector<int64_t> sorted_lengths = std::get<0>(test_case);
      bool should_shuffle = std::get<1>(test_case);

      auto [padded, lengths, expected_data, batch_sizes, unsorted_indices] =
          generate_test_case(sorted_lengths, should_shuffle);

      auto src = padded;
      if (batch_first) {
        src = src.transpose(0, 1);
      }

      // check output
      rnn_utils::PackedSequence packed = rnn_utils::pack_padded_sequence(
          src,
          lengths,
          /*batch_first=*/batch_first,
          /*enforce_sorted=*/!should_shuffle);
      ASSERT_TRUE(torch::allclose(packed.data(), expected_data));
      ASSERT_TRUE(torch::allclose(packed.batch_sizes(), batch_sizes));
      ASSERT_TRUE(
          (!packed.unsorted_indices().defined() &&
           !unsorted_indices.defined()) ||
          torch::allclose(packed.unsorted_indices(), unsorted_indices));

      // test inverse
      auto [unpacked, unpacked_len] =
          rnn_utils::pad_packed_sequence(packed, /*batch_first=*/batch_first);
      ASSERT_TRUE(torch::allclose(unpacked, src));
      ASSERT_TRUE(torch::allclose(unpacked_len, lengths));

      // check grad
      if (padded.grad().defined()) {
        torch::NoGradGuard no_grad;
        padded.grad().zero_();
      }
      torch::Tensor grad_output;
      {
        torch::NoGradGuard no_grad;
        grad_output = unpacked.clone().normal_();
      }
      unpacked.backward(grad_output);
      if (batch_first) {
        grad_output.transpose_(0, 1);
      }
      for (const auto i : c10::irange(lengths.size(0))) {
        int64_t l = lengths[i].item<int64_t>();
        ASSERT_TRUE(torch::allclose(
            padded.grad().narrow(0, 0, l).select(1, i),
            grad_output.narrow(0, 0, l).select(1, i)));
        if (l < 10) {
          ASSERT_EQ(
              padded.grad()
                  .narrow(0, l, padded.grad().size(0) - l)
                  .select(1, i)
                  .abs()
                  .sum()
                  .item<double>(),
              0);
        }
      }
    }
  }

  // test error messages
  ASSERT_THROWS_WITH(
      rnn_utils::pack_padded_sequence(
          torch::randn({3, 3}), torch::tensor({1, 3, 2})),
      "You can pass `enforce_sorted=False`");
  ASSERT_THROWS_WITH(
      rnn_utils::pack_padded_sequence(torch::randn({0, 0}), torch::tensor({})),
      "empty tensor");
}

TEST_F(NNUtilsTest, PadSequence) {
  auto pad = [&](const torch::Tensor& tensor, int64_t length) {
    torch::NoGradGuard no_grad;
    std::vector<int64_t> tensor_sizes{length - tensor.size(0)};
    tensor_sizes.insert(
        tensor_sizes.end(),
        tensor.sizes().slice(1).begin(),
        tensor.sizes().slice(1).end());
    return torch::cat({tensor, torch::zeros(tensor_sizes, tensor.options())});
  };

  // single dimensional
  auto a = torch::tensor({1, 2, 3});
  auto b = torch::tensor({4, 5});
  auto c = torch::tensor({6});

  torch::Tensor expected, padded;

  // batch_first = true
  expected = torch::tensor({{4, 5, 0}, {1, 2, 3}, {6, 0, 0}});
  padded = rnn_utils::pad_sequence({b, a, c}, true);
  ASSERT_TRUE(padded.allclose(expected));

  // batch_first = false
  padded = rnn_utils::pad_sequence({b, a, c});
  ASSERT_TRUE(padded.allclose(expected.transpose(0, 1)));

  // padding_side = "left", batch_first = true
  expected = torch::tensor({{0, 4, 5}, {1, 2, 3}, {0, 0, 6}});
  padded = rnn_utils::pad_sequence({b, a, c}, true, 0, "left");
  ASSERT_TRUE(padded.allclose(expected));

  // padding_side = "left", batch_first = false
  padded = rnn_utils::pad_sequence({b, a, c}, false, 0, "left");
  ASSERT_TRUE(padded.allclose(expected.transpose(0, 1)));

  // pad with non-zero value
  expected = torch::tensor({{4, 5, 1}, {1, 2, 3}, {6, 1, 1}});
  padded = rnn_utils::pad_sequence({b, a, c}, true, 1);
  ASSERT_TRUE(padded.allclose(expected));

  // Test pad sorted sequence
  expected = torch::tensor({{1, 2, 3}, {4, 5, 0}, {6, 0, 0}});
  padded = rnn_utils::pad_sequence({a, b, c}, true);
  ASSERT_TRUE(padded.allclose(expected));

  // more dimensions
  int64_t maxlen = 9;
  for (int64_t num_dim : std::vector<int64_t>{0, 1, 2, 3}) {
    std::vector<torch::Tensor> sequences;
    std::vector<int64_t> trailing_dims(num_dim, 4);
    for (int64_t i = 1; i < maxlen + 1; i++) {
      int64_t seq_len = i * i;
      std::vector<int64_t> tensor_sizes{seq_len, 5};
      tensor_sizes.insert(
          tensor_sizes.end(), trailing_dims.begin(), trailing_dims.end());
      sequences.emplace_back(torch::rand(tensor_sizes));
    }
    std::shuffle(
        std::begin(sequences),
        std::end(sequences),
        std::default_random_engine{});
    std::vector<torch::Tensor> expected_tensors;
    for (const torch::Tensor& seq : sequences) {
      // NOLINTNEXTLINE(performance-inefficient-vector-operation)
      expected_tensors.emplace_back(pad(seq, maxlen * maxlen));
    }

    // batch first = true
    auto expected = torch::stack(expected_tensors);
    auto padded = rnn_utils::pad_sequence(sequences, true);
    ASSERT_TRUE(padded.allclose(expected));

    // batch first = false
    padded = rnn_utils::pad_sequence(sequences);
    ASSERT_TRUE(padded.allclose(expected.transpose(0, 1)));

    // reset expected_tensors for padding_side
    expected_tensors.clear();
    for (const torch::Tensor& seq : sequences) {
      // NOLINTNEXTLINE(performance-inefficient-vector-operation)
      expected_tensors.emplace_back(
          torch::flip(pad(torch::flip(seq, {0}), maxlen * maxlen), {0}));
    }
    expected = torch::stack(expected_tensors);
    // padding_side = "left", batch_first = true
    padded = rnn_utils::pad_sequence(sequences, true, 0, "left");
    ASSERT_TRUE(padded.allclose(expected));

    // padding_side = "left", batch_first = false
    padded = rnn_utils::pad_sequence(sequences, false, 0, "left");
    ASSERT_TRUE(padded.allclose(expected.transpose(0, 1)));
  }
}