#include <gtest/gtest.h>

#include <c10/util/irange.h>
#include <torch/torch.h>

#include <test/cpp/api/support.h>

#include <algorithm>
#include <iostream>
#include <random>
#include <sstream>
#include <string>

using namespace torch::nn;

namespace rnn_utils = torch::nn::utils::rnn;

struct NNUtilsTest : torch::test::SeedingFixture {};
struct PackedSequenceTest : torch::test::SeedingFixture {};

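// clip_grad_norm_ rescales all gradients in-place so that their combined
// norm is at most max_norm, and returns the total norm measured before
// clipping. The assertions below check both the returned value and the
// scaling behavior for several norm types, including the infinity norm.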
TEST_F(NNUtilsTest, ClipGradNorm) {
  auto l = Linear(10, 10);
  float max_norm = 2;
  auto compute_norm = [&](float norm_type) -> float {
    float total_norm = 0.0;
    if (norm_type != std::numeric_limits<float>::infinity()) {
      for (const auto& p : l->parameters()) {
        total_norm +=
            p.grad().data().abs().pow(norm_type).sum().item().toFloat();
      }
      return std::pow(total_norm, 1.0 / norm_type);
    } else {
      for (const auto& p : l->parameters()) {
        auto param_max = p.grad().data().abs().max().item().toFloat();
        if (param_max > total_norm) {
          total_norm = param_max;
        }
      }
      return total_norm;
    }
  };
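  // After clipping, every gradient should have been multiplied by the same
  // factor, so dividing the clipped grads by the originals yields a vector
  // whose standard deviation is ~0.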
  auto compare_scaling =
      [&](const std::vector<torch::Tensor>& grads) -> torch::Tensor {
    std::vector<torch::Tensor> p_scale;
    for (const auto i : c10::irange(grads.size())) {
      auto param = l->parameters()[i];
      auto grad = grads[i];
      p_scale.push_back(param.grad().data().div(grad).view(-1));
    }
    auto scale = torch::cat(p_scale);
    return scale; // need to assert std is 0.
  };

  std::vector<torch::Tensor> grads = {
      torch::arange(1.0, 101).view({10, 10}),
      torch::ones({10}).div(1000),
  };
  std::vector<float> norm_types = {
      0.5,
      1.5,
      2.0,
      4.0,
      std::numeric_limits<float>::infinity(),
  };
  for (auto norm_type : norm_types) {
    for (const auto i : c10::irange(grads.size())) {
      l->parameters()[i].mutable_grad() =
          grads[i].clone().view_as(l->parameters()[i].data());
    }
    auto norm_before = compute_norm(norm_type);
    auto norm = utils::clip_grad_norm_(l->parameters(), max_norm, norm_type);
    auto norm_after = compute_norm(norm_type);
    ASSERT_FLOAT_EQ(norm, norm_before);
    ASSERT_NEAR(norm_after, max_norm, 1e-6);
    ASSERT_LE(norm_after, max_norm);
    auto scaled = compare_scaling(grads);
    ASSERT_NEAR(0, scaled.std().item().toFloat(), 1e-7);
  }
  // Small gradients should be left unchanged
  grads = {
      torch::rand({10, 10}).div(10000),
      torch::ones(10).div(500),
  };
  for (auto norm_type : norm_types) {
    for (const auto i : c10::irange(grads.size())) {
      l->parameters()[i].grad().data().copy_(grads[i]);
    }
    auto norm_before = compute_norm(norm_type);
    auto norm = utils::clip_grad_norm_(l->parameters(), max_norm, norm_type);
    auto norm_after = compute_norm(norm_type);
    ASSERT_FLOAT_EQ(norm, norm_before);
    ASSERT_FLOAT_EQ(norm_before, norm_after);
    ASSERT_LE(norm_after, max_norm);
    auto scaled = compare_scaling(grads);
    ASSERT_NEAR(0, scaled.std().item().toFloat(), 1e-7);
    ASSERT_FLOAT_EQ(scaled[0].item().toFloat(), 1);
  }
  // should accept a single tensor as input
  auto p1 = torch::randn({10, 10});
  auto p2 = torch::randn({10, 10});
  auto g = torch::arange(1., 101).view({10, 10});
  p1.mutable_grad() = g.clone();
  p2.mutable_grad() = g.clone();
  for (const auto norm_type : norm_types) {
    utils::clip_grad_norm_(p1, max_norm, norm_type);
    utils::clip_grad_norm_({p2}, max_norm, norm_type);
    ASSERT_TRUE(torch::allclose(p1.grad(), p2.grad()));
  }
}

// Check that clip_grad_norm_ raises an error if the norm of a gradient
// is non-finite
TEST_F(NNUtilsTest, ClipGradNormErrorIfNonfinite) {
  double inf = std::numeric_limits<double>::infinity();
  double nan = std::numeric_limits<double>::quiet_NaN();

  using Vector = std::vector<double>;

  Vector norms_pos = {0.1, 1, 2, 3.5, inf};
  Vector norms_neg = {-0.1, -1, -2, -3.5};
  Vector norms_neg_plus_0 = {0, -0.1, -1, -2, -3.5};
  Vector norms_except_0 = {0.1, 1, 2, 3.5, inf, -0.1, -1, -2, -3.5};
  Vector norms_all = {0, 0.1, 1, 2, 3.5, inf, -0.1, -1, -2, -3.5};

  // Each entry in test_cases has the following values, in this order:
  //
  // grad_only_one_elem    If true, only one element of the parameter's
  //                       gradient is set to the scalar grad, and the
  //                       rest of the elements are 0. If false, all grad
  //                       elements are equal to the scalar.
  //
  // prefix_finite_grad_param  If true, prefix a parameter that has a grad
  //                           of 1.
  //
  // scalars           Scalars to use as the parameter's grad, through
  //                   multiplication
  //
  // norms_nonfinite   Norm types that should produce a nonfinite total norm
  //
  // norms_finite      Norm types that should produce a finite total norm
  std::vector<std::tuple<bool, bool, Vector, Vector, Vector>> test_cases({
      // Test errors from an infinite grad
      std::make_tuple(
          false, false, Vector({inf, -inf}), norms_except_0, Vector({0})),
      std::make_tuple(
          false, true, Vector({inf, -inf}), norms_pos, norms_neg_plus_0),
      std::make_tuple(
          true, false, Vector({inf, -inf}), norms_pos, norms_neg_plus_0),
      std::make_tuple(
          true, true, Vector({inf, -inf}), norms_pos, norms_neg_plus_0),

      // Test errors from a NaN grad
      std::make_tuple(false, false, Vector({nan}), norms_except_0, Vector({0})),
      std::make_tuple(false, true, Vector({nan}), norms_except_0, Vector({0})),
      std::make_tuple(true, false, Vector({nan}), norms_except_0, Vector({0})),
      std::make_tuple(true, true, Vector({nan}), norms_except_0, Vector({0})),

      // Test a grad that should never error
      std::make_tuple(false, false, Vector({2e22, -2e22}), Vector(), norms_all),
      std::make_tuple(false, true, Vector({2e22, -2e22}), Vector(), norms_all),
      std::make_tuple(true, false, Vector({2e22, -2e22}), Vector(), norms_all),
      std::make_tuple(true, true, Vector({2e22, -2e22}), Vector(), norms_all),

      // Test a grad that will overflow to inf for only some norm orders
      std::make_tuple(
          false,
          false,
          Vector({2e200, -2e200}),
          Vector({3.5, 2, -2, -3.5}),
          Vector({inf, 1, 0.1, 0, -1, -0.1})),
      std::make_tuple(
          false,
          true,
          Vector({2e200, -2e200}),
          Vector({3.5, 2}),
          Vector({inf, 1, 0.1, 0, -1, -0.1, -2, -3.5})),
      std::make_tuple(
          true,
          false,
          Vector({2e200, -2e200}),
          Vector({3.5, 2}),
          Vector({inf, 1, 0.1, 0, -1, -0.1, -2, -3.5})),
      std::make_tuple(
          true,
          true,
          Vector({2e200, -2e200}),
          Vector({3.5, 2}),
          Vector({inf, 1, 0.1, 0, -1, -0.1, -2, -3.5})),
  });

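  // gen_parameters builds a 10-element parameter whose gradient comes from
  // backward() through a multiplication by `scalar` (either in one element
  // or in all of them), optionally preceded by a one-element parameter with
  // a finite grad of 1.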
  auto gen_parameters = [](double scalar,
                           bool grad_only_one_elem,
                           bool prefix_finite_grad_param,
                           torch::DeviceType device_type) {
    auto param = torch::ones(
        10,
        torch::TensorOptions()
            .dtype(torch::kDouble)
            .device(device_type)
            .requires_grad(true));
    if (grad_only_one_elem) {
      param[1].mul(scalar).sum().backward();
    } else {
      param.mul(scalar).sum().backward();
    }

    std::vector<torch::Tensor> parameters;
    if (prefix_finite_grad_param) {
      auto prefix_param = torch::ones(
          1,
          torch::TensorOptions()
              .dtype(torch::kDouble)
              .device(device_type)
              .requires_grad(true));
      prefix_param.mul(1).sum().backward();
      parameters.push_back(prefix_param);
    }
    parameters.push_back(param);

    return parameters;
  };

  auto run_test_case = [&gen_parameters](
                           double norm_type,
                           bool error_if_nonfinite,
                           double scalar,
                           bool grad_only_one_elem,
                           bool prefix_finite_grad_param,
                           bool is_norm_nonfinite,
                           torch::DeviceType device_type) {
    std::stringstream ss;
    ss << "device: " << device_type << ", norm_type: " << norm_type
       << ", error_if_nonfinite: " << error_if_nonfinite
       << ", scalar: " << scalar
       << ", grad_only_one_elem: " << grad_only_one_elem
       << ", prefix_finite_grad_param: " << prefix_finite_grad_param
       << ", is_norm_nonfinite: " << is_norm_nonfinite;
    std::string msg = ss.str();

    auto parameters = gen_parameters(
        scalar, grad_only_one_elem, prefix_finite_grad_param, device_type);

    if (is_norm_nonfinite && error_if_nonfinite) {
      std::vector<torch::Tensor> grads_before;
      // NOLINTNEXTLINE(performance-for-range-copy)
      for (auto p : parameters) {
        // NOLINTNEXTLINE(performance-inefficient-vector-operation)
        grads_before.push_back(p.grad().clone());
      }
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
      EXPECT_THROW(
          utils::clip_grad_norm_(parameters, 1., norm_type, true),
          std::exception)
          << msg;
      // Grads should not change if an error is thrown
      for (const auto p_idx : c10::irange(parameters.size())) {
        ASSERT_TRUE(torch::allclose(
            parameters[p_idx].grad(),
            grads_before[p_idx],
            1.0,
            0.0,
            /*equal_nan*/ true))
            << msg;
      }
    } else {
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
      EXPECT_NO_THROW(
          utils::clip_grad_norm_(parameters, 1., norm_type, error_if_nonfinite))
          << msg;
    }
  };

  for (auto device_type : {torch::kCPU, torch::kCUDA}) {
    if (device_type == torch::kCUDA && !torch::cuda::is_available()) {
      continue;
    }
    for (auto test_case : test_cases) {
      auto grad_only_one_elem = std::get<0>(test_case);
      auto prefix_finite_grad_param = std::get<1>(test_case);
      auto scalars = std::get<2>(test_case);
      auto norms_nonfinite = std::get<3>(test_case);
      auto norms_finite = std::get<4>(test_case);

      for (auto error_if_nonfinite : {false, true}) {
        for (auto scalar : scalars) {
          for (auto norm_type : norms_nonfinite) {
            run_test_case(
                norm_type,
                error_if_nonfinite,
                scalar,
                grad_only_one_elem,
                prefix_finite_grad_param,
                true,
                device_type);
          }

          for (auto norm_type : norms_finite) {
            run_test_case(
                norm_type,
                error_if_nonfinite,
                scalar,
                grad_only_one_elem,
                prefix_finite_grad_param,
                false,
                device_type);
          }
        }
      }
    }
  }
}

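// clip_grad_value_ clamps every gradient element into the range
// [-clip_value, clip_value] in-place; parameters whose grad is undefined
// are skipped.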
TEST_F(NNUtilsTest, ClipGradValue) {
  auto l = Linear(10, 10);
  float clip_value = 2.5;

  torch::Tensor grad_w = torch::arange(-50., 50).view({10, 10}).div_(5);
  torch::Tensor grad_b = torch::ones({10}).mul_(2);
  std::vector<std::vector<torch::Tensor>> grad_lists = {
      {grad_w, grad_b}, {grad_w, torch::Tensor()}};
  for (auto grad_list : grad_lists) {
    for (const auto i : c10::irange(grad_list.size())) {
      auto p = l->parameters()[i];
      auto g = grad_list[i];
      p.mutable_grad() = g.defined() ? g.clone().view_as(p.data()) : g;
    }

    utils::clip_grad_value_(l->parameters(), clip_value);
    for (const auto& p : l->parameters()) {
      if (p.grad().defined()) {
        ASSERT_LE(p.grad().data().max().item().toFloat(), clip_value);
        ASSERT_GE(p.grad().data().min().item().toFloat(), -clip_value);
      }
    }
  }

  // Should accept a single Tensor as input
  auto p1 = torch::randn({10, 10});
  auto p2 = torch::randn({10, 10});
  auto g = torch::arange(-50., 50).view({10, 10}).div_(5);
  p1.mutable_grad() = g.clone();
  p2.mutable_grad() = g.clone();
  utils::clip_grad_value_(p1, clip_value);
  utils::clip_grad_value_({p2}, clip_value);
  ASSERT_TRUE(torch::allclose(p1.grad(), p2.grad()));
}

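// parameters_to_vector flattens a list of parameters into one 1-D tensor;
// vector_to_parameters copies such a vector back into the parameters,
// slice by slice. For Sequential(Conv2d(3, 10, 5), Linear(10, 20)) the
// flattened length is 980: conv weight 10*3*5*5 = 750 plus bias 10, and
// linear weight 20*10 = 200 plus bias 20.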
TEST_F(NNUtilsTest, ConvertParameters) {
  std::vector<torch::Tensor> parameters{
      torch::arange(9, torch::kFloat32),
      torch::arange(9, torch::kFloat32).view({3, 3}),
      torch::arange(8, torch::kFloat32).view({2, 2, 2})};

  auto expected = torch::cat(
      {torch::arange(9, torch::kFloat32),
       torch::arange(9, torch::kFloat32).view(-1),
       torch::arange(8, torch::kFloat32).view(-1)});
  auto vector = utils::parameters_to_vector(parameters);
  ASSERT_TRUE(vector.allclose(expected));

  std::vector<torch::Tensor> zero_parameters{
      torch::zeros({9}, torch::kFloat32),
      torch::zeros({9}, torch::kFloat32).view({3, 3}),
      torch::zeros({8}, torch::kFloat32).view({2, 2, 2})};

  utils::vector_to_parameters(vector, zero_parameters);
  for (const auto i : c10::irange(zero_parameters.size())) {
    ASSERT_TRUE(zero_parameters[i].allclose(parameters[i]));
  }

  {
    auto conv1 = Conv2d(3, 10, 5);
    auto fc1 = Linear(10, 20);
    auto model = Sequential(conv1, fc1);

    auto vec = utils::parameters_to_vector(model->parameters());
    ASSERT_EQ(vec.size(0), 980);
  }
  {
    auto conv1 = Conv2d(3, 10, 5);
    auto fc1 = Linear(10, 20);
    auto model = Sequential(conv1, fc1);

    auto vec = torch::arange(0., 980);
    utils::vector_to_parameters(vec, model->parameters());

    auto sample = model->parameters()[0][0][0][0];
    ASSERT_TRUE(torch::equal(sample.data(), vec.data().slice(0, 0, 5)));
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-non-const-global-variables)
int64_t PackedSequenceTest_batch_size = 5;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-non-const-global-variables)
int64_t PackedSequenceTest_max_length = 6;

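// Builds a batch of random 1-D sequences with random lengths (strictly less
// than max_length, since torch::randint's upper bound is exclusive), sorted
// by decreasing length as required by enforce_sorted packing.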
std::vector<torch::Tensor> PackedSequenceTest_ordered_sequence(
    torch::ScalarType tensor_type) {
  std::vector<torch::Tensor> seqs;
  seqs.reserve(PackedSequenceTest_batch_size);
  for (const auto i : c10::irange(PackedSequenceTest_batch_size)) {
    (void)i; // Suppress unused variable warning
    seqs.emplace_back(torch::empty(
        {torch::randint(1, PackedSequenceTest_max_length, {1}).item<int64_t>()},
        tensor_type));
  }
  for (auto& s : seqs) {
    s.random_(-128, 128);
  }
  std::sort(
      seqs.begin(),
      seqs.end(),
      [&](const torch::Tensor& t1, const torch::Tensor& t2) {
        return t1.size(0) > t2.size(0);
      });
  return seqs;
}

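// Pads the ordered sequences into a single [max_len, batch_size] tensor and
// returns it together with the per-sequence lengths.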
std::tuple<torch::Tensor, torch::Tensor> PackedSequenceTest_padded_sequence(
    torch::ScalarType tensor_type) {
  // Create Tensor of random padded sequences
  auto ordered = PackedSequenceTest_ordered_sequence(tensor_type);
  auto lengths = torch::empty({(int64_t)ordered.size()}, torch::kInt64);
  for (const auto i : c10::irange(ordered.size())) {
    lengths[i] = ordered[i].size(0);
  }
  auto padded_tensor = rnn_utils::pad_sequence(ordered);
  return std::make_tuple(padded_tensor, lengths);
}

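// "Equal" compares two PackedSequences by value; "same" (below) additionally
// requires that they refer to the identical underlying tensors.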
void assert_is_equal_packed_sequence(
    const rnn_utils::PackedSequence& a,
    const rnn_utils::PackedSequence& b) {
  ASSERT_TRUE(torch::allclose(a.data(), b.data()));
  ASSERT_TRUE(torch::allclose(a.batch_sizes(), b.batch_sizes()));
  ASSERT_TRUE(
      (!a.sorted_indices().defined() && !b.sorted_indices().defined()) ||
      torch::allclose(a.sorted_indices(), b.sorted_indices()));
  ASSERT_TRUE(
      (!a.unsorted_indices().defined() && !b.unsorted_indices().defined()) ||
      torch::allclose(a.unsorted_indices(), b.unsorted_indices()));
}

void assert_is_same_packed_sequence(
    const rnn_utils::PackedSequence& a,
    const rnn_utils::PackedSequence& b) {
  ASSERT_TRUE(a.data().is_same(b.data()));
  ASSERT_TRUE(a.batch_sizes().is_same(b.batch_sizes()));
  ASSERT_TRUE(a.sorted_indices().is_same(b.sorted_indices()));
  ASSERT_TRUE(a.unsorted_indices().is_same(b.unsorted_indices()));
}

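// pack_padded_sequence with enforce_sorted=true must reject a batch whose
// lengths are not in decreasing order.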
TEST_F(PackedSequenceTest, WrongOrder) {
  auto a = torch::ones({25, 300});
  auto b = torch::ones({22, 300});
  auto b_a = rnn_utils::pad_sequence({b, a});
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_THROW(
      rnn_utils::pack_padded_sequence(
          b_a,
          torch::tensor({22, 25}),
          /*batch_first=*/false,
          /*enforce_sorted=*/true),
      c10::Error);
}

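// pad_packed_sequence accepts a total_length at least as large as the longest
// sequence and right-pads the result up to it; smaller values must raise.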
TEST_F(PackedSequenceTest, TotalLength) {
  auto [padded, lengths] = PackedSequenceTest_padded_sequence(torch::kFloat);
  int64_t max_length = torch::max(lengths).item<int64_t>();
  rnn_utils::PackedSequence packed =
      rnn_utils::pack_padded_sequence(padded, lengths);

  // test that an error is raised if total_length < max_length
  for (int64_t total_length : std::vector<int64_t>{-1, 0, max_length - 1}) {
    for (bool batch_first : std::vector<bool>{true, false}) {
      auto err_fn = [&]() {
        rnn_utils::pad_packed_sequence(
            packed,
            /*batch_first=*/batch_first,
            /*padding_value=*/0.0,
            /*total_length=*/total_length);
      };
      ASSERT_THROWS_WITH(
          err_fn(),
          "Expected total_length to be at least the length of the longest sequence in input");
    }
  }

  // test that pad_packed_sequence returns results of correct length
  for (bool batch_first : std::vector<bool>{true, false}) {
    auto no_extra_pad = std::get<0>(
        rnn_utils::pad_packed_sequence(packed, /*batch_first=*/batch_first));
    for (int64_t total_length_delta : std::vector<int64_t>{0, 1, 8}) {
      int64_t total_length = max_length + total_length_delta;
      auto [unpacked, lengths_out] = rnn_utils::pad_packed_sequence(
          packed,
          /*batch_first=*/batch_first,
          /*padding_value=*/0.0,
          /*total_length=*/total_length);
      ASSERT_TRUE(torch::allclose(lengths, lengths_out));
      ASSERT_EQ(unpacked.size(batch_first ? 1 : 0), total_length);
      torch::Tensor ref_output, extra_pad;
      if (total_length_delta == 0) {
        ref_output = no_extra_pad;
      } else if (batch_first) {
        extra_pad = torch::zeros(
            {PackedSequenceTest_batch_size, total_length_delta},
            no_extra_pad.options());
        ref_output = torch::cat({no_extra_pad, extra_pad}, 1);
      } else {
        extra_pad = torch::zeros(
            {total_length_delta, PackedSequenceTest_batch_size},
            no_extra_pad.options());
        ref_output = torch::cat({no_extra_pad, extra_pad}, 0);
      }
      ASSERT_TRUE(torch::allclose(unpacked, ref_output));
    }
  }
}

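// PackedSequence::to (and the cpu()/cuda() shortcuts) should hand back the
// identical tensors when no conversion is needed, and value-equal copies
// when moving across devices.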
TEST_F(PackedSequenceTest, To) {
  for (bool enforce_sorted : std::vector<bool>{true, false}) {
    auto [padded, lengths] = PackedSequenceTest_padded_sequence(torch::kInt);
    rnn_utils::PackedSequence a = rnn_utils::pack_padded_sequence(
                                      padded,
                                      lengths,
                                      /*batch_first=*/false,
                                      /*enforce_sorted=*/enforce_sorted)
                                      .cpu();

    assert_is_same_packed_sequence(a, a.to(torch::kCPU));
    assert_is_same_packed_sequence(a, a.cpu());
    assert_is_same_packed_sequence(
        a, a.to(torch::device(torch::kCPU).dtype(torch::kInt32)));

    if (torch::cuda::is_available()) {
      auto b = a.cuda();
      assert_is_same_packed_sequence(b, b.to(torch::kCUDA));
      assert_is_same_packed_sequence(b, b.cuda());
      assert_is_equal_packed_sequence(a, b.to(torch::kCPU));
      assert_is_equal_packed_sequence(b, a.to(torch::kCUDA));
      assert_is_equal_packed_sequence(
          a, b.to(torch::device(torch::kCPU).dtype(torch::kInt32)));
      assert_is_same_packed_sequence(b, b.to(torch::kInt32));
    }
  }
}

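// pack_sequence packs a list of variable-length sequences directly. The
// result must round-trip through pad_packed_sequence and agree with
// pack_padded_sequence applied to the equivalent padded batch. Packed data
// is stored time-major: for {1, 2, 3}, {4, 5}, {6} it is {1, 4, 6, 2, 5, 3}.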
TEST_F(NNUtilsTest, PackSequence) {
  auto _compatibility_test = [&](torch::ArrayRef<torch::Tensor> sequences,
                                 torch::Tensor lengths,
                                 bool batch_first,
                                 bool enforce_sorted = false) {
    torch::Tensor padded = rnn_utils::pad_sequence(sequences, batch_first);
    rnn_utils::PackedSequence packed =
        rnn_utils::pack_sequence(sequences, enforce_sorted);
    std::tuple<torch::Tensor, torch::Tensor> unpacked =
        rnn_utils::pad_packed_sequence(packed, batch_first);
    ASSERT_TRUE(torch::allclose(padded, std::get<0>(unpacked)));
    rnn_utils::PackedSequence pack_padded = rnn_utils::pack_padded_sequence(
        padded, lengths, batch_first, enforce_sorted);
    assert_is_equal_packed_sequence(packed, pack_padded);
  };

  // single dimensional
  auto a = torch::tensor({1, 2, 3});
  auto b = torch::tensor({4, 5});
  auto c = torch::tensor({6});
  rnn_utils::PackedSequence packed =
      rnn_utils::pack_sequence({a, b, c}, /*enforce_sorted=*/false);
  auto expected = torch::tensor({1, 4, 6, 2, 5, 3});
  ASSERT_TRUE(torch::allclose(packed.batch_sizes(), torch::tensor({3, 2, 1})));
  ASSERT_TRUE(torch::allclose(packed.data(), expected));
  ASSERT_TRUE(
      torch::allclose(packed.sorted_indices(), torch::tensor({0, 1, 2})));
  ASSERT_TRUE(
      torch::allclose(packed.unsorted_indices(), torch::tensor({0, 1, 2})));

  rnn_utils::PackedSequence packed_unsorted =
      rnn_utils::pack_sequence({b, c, a}, /*enforce_sorted=*/false);
  ASSERT_TRUE(
      torch::allclose(packed_unsorted.batch_sizes(), torch::tensor({3, 2, 1})));
  ASSERT_TRUE(torch::allclose(packed_unsorted.data(), expected));
  ASSERT_TRUE(torch::allclose(
      packed_unsorted.sorted_indices(), torch::tensor({2, 0, 1})));
  ASSERT_TRUE(torch::allclose(
      packed_unsorted.unsorted_indices(), torch::tensor({1, 2, 0})));

  // single dimensional, enforce_sorted = true
  rnn_utils::PackedSequence packed_enforce_sorted =
      rnn_utils::pack_sequence({a, b, c}, /*enforce_sorted=*/true);
  ASSERT_TRUE(torch::allclose(
      packed_enforce_sorted.batch_sizes(), torch::tensor({3, 2, 1})));
  ASSERT_TRUE(torch::allclose(packed_enforce_sorted.data(), expected));
  ASSERT_FALSE(packed_enforce_sorted.sorted_indices().defined());
  ASSERT_FALSE(packed_enforce_sorted.unsorted_indices().defined());

  ASSERT_THROWS_WITH(
      rnn_utils::pack_sequence({b, c, a}, /*enforce_sorted=*/true),
      "must be sorted in decreasing order");

  ASSERT_THROWS_WITH(
      rnn_utils::pack_sequence({b, c, a}, /*enforce_sorted=*/true),
      "You can pass `enforce_sorted=False`");

  // more dimensions
  int64_t maxlen = 9;
  for (int64_t num_dim : std::vector<int64_t>{0, 1, 2, 3}) {
    std::vector<torch::Tensor> sequences;
    std::vector<int64_t> lengths_vec;
    std::vector<int64_t> trailing_dims(num_dim, 4);
    for (int64_t i = maxlen; i > 0; i--) {
      int64_t seq_len = i * i;
      lengths_vec.emplace_back(seq_len);
      std::vector<int64_t> tensor_sizes{seq_len, 5};
      tensor_sizes.insert(
          tensor_sizes.end(), trailing_dims.begin(), trailing_dims.end());
      sequences.emplace_back(torch::rand(tensor_sizes));
    }
    std::vector<torch::Tensor> unsorted_sequences;
    for (const auto& s : sequences) {
      // NOLINTNEXTLINE(performance-inefficient-vector-operation)
      unsorted_sequences.emplace_back(s.clone());
    }
    std::shuffle(
        std::begin(unsorted_sequences),
        std::end(unsorted_sequences),
        std::default_random_engine{});

    std::vector<int64_t> unsorted_sequences_lengths_vec;
    for (const auto& t : unsorted_sequences) {
      // NOLINTNEXTLINE(performance-inefficient-vector-operation)
      unsorted_sequences_lengths_vec.emplace_back(t.size(0));
    }

    // compatibility with other utilities
    for (bool batch_first : std::vector<bool>{true, false}) {
      for (bool enforce_sorted : std::vector<bool>{true, false}) {
        _compatibility_test(
            sequences, torch::tensor(lengths_vec), batch_first, enforce_sorted);
      }
      _compatibility_test(
          unsorted_sequences,
          torch::tensor(unsorted_sequences_lengths_vec),
          batch_first);
    }
  }
}

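// Builds padded batches with known contents, packs them, and checks the
// packed data and batch_sizes, the pad_packed_sequence round-trip, and that
// gradients flow only through the non-padding positions.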
TEST_F(NNUtilsTest, PackPaddedSequence) {
  auto generate_test_case = [&](torch::ArrayRef<int64_t> sorted_lengths,
                                bool should_shuffle) {
    auto pad = [&](torch::Tensor tensor, int64_t length) {
      std::vector<int64_t> tensor_sizes{length - tensor.size(0)};
      tensor_sizes.insert(
          tensor_sizes.end(),
          tensor.sizes().slice(1).begin(),
          tensor.sizes().slice(1).end());
      return torch::cat({tensor, torch::zeros(tensor_sizes, tensor.options())});
    };
    int64_t max_length = sorted_lengths[0];
    torch::Tensor batch_sizes = torch::empty({max_length}, torch::kInt64);
    for (int64_t i = 1; i < max_length + 1; i++) {
      int64_t total = 0;
      for (const auto& x : sorted_lengths) {
        if (x >= i) {
          total++;
        }
      }
      batch_sizes[i - 1] = total;
    }
    std::vector<torch::Tensor> tensors_to_be_cat;
    for (int64_t i = 1; i < static_cast<int64_t>(sorted_lengths.size() + 1);
         i++) {
      int64_t l = sorted_lengths.at(i - 1);
      tensors_to_be_cat.emplace_back(pad(
          i * 100 + torch::arange(1., 5 * l + 1).view({l, 1, 5}), max_length));
    }
    auto padded = torch::cat(tensors_to_be_cat, 1);
    std::vector<torch::Tensor> expected_data_vec;
    for (const auto n : c10::irange(batch_sizes.size(0))) {
      int64_t batch_size = batch_sizes[n].item<int64_t>();
      for (const auto i : c10::irange(batch_size)) {
        expected_data_vec.emplace_back(
            torch::arange(1., 6) + (i + 1) * 100 + 5 * n);
      }
    }
    auto expected_data = torch::stack(expected_data_vec, /*dim=*/0);

    torch::Tensor unsorted_indices, lengths;
    if (should_shuffle) {
      // Shuffle the padded sequence to create an unsorted sequence
      std::vector<int64_t> permutation;
      for (const auto i : c10::irange(sorted_lengths.size())) {
        permutation.emplace_back(i);
      }
      std::shuffle(
          std::begin(permutation),
          std::end(permutation),
          std::default_random_engine{});

      unsorted_indices = torch::tensor(permutation);
      padded = padded.index_select(1, unsorted_indices);
      lengths = torch::tensor(sorted_lengths).index_select(0, unsorted_indices);
    } else {
      unsorted_indices = torch::Tensor();
      lengths = torch::tensor(sorted_lengths);
    }

    return std::make_tuple(
        padded.requires_grad_(),
        lengths,
        expected_data,
        batch_sizes,
        unsorted_indices);
  };

  std::vector<std::pair<std::vector<int64_t>, bool>> test_cases = {
      // sorted_lengths, should_shuffle
      {{10, 8, 4, 2, 2, 2, 1}, false},
      {{11, 10, 8, 6, 4, 3, 1}, false},
      {{11, 10, 8, 6, 4, 3, 1}, true}};

  for (const auto& test_case : test_cases) {
    for (bool batch_first : std::vector<bool>{true, false}) {
      // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
      std::vector<int64_t> sorted_lengths = std::get<0>(test_case);
      bool should_shuffle = std::get<1>(test_case);

      auto [padded, lengths, expected_data, batch_sizes, unsorted_indices] =
          generate_test_case(sorted_lengths, should_shuffle);

      auto src = padded;
      if (batch_first) {
        src = src.transpose(0, 1);
      }

      // check output
      rnn_utils::PackedSequence packed = rnn_utils::pack_padded_sequence(
          src,
          lengths,
          /*batch_first=*/batch_first,
          /*enforce_sorted=*/!should_shuffle);
      ASSERT_TRUE(torch::allclose(packed.data(), expected_data));
      ASSERT_TRUE(torch::allclose(packed.batch_sizes(), batch_sizes));
      ASSERT_TRUE(
          (!packed.unsorted_indices().defined() &&
           !unsorted_indices.defined()) ||
          torch::allclose(packed.unsorted_indices(), unsorted_indices));

      // test inverse
      auto [unpacked, unpacked_len] =
          rnn_utils::pad_packed_sequence(packed, /*batch_first=*/batch_first);
      ASSERT_TRUE(torch::allclose(unpacked, src));
      ASSERT_TRUE(torch::allclose(unpacked_len, lengths));

      // check grad
      if (padded.grad().defined()) {
        torch::NoGradGuard no_grad;
        padded.grad().zero_();
      }
      torch::Tensor grad_output;
      {
        torch::NoGradGuard no_grad;
        grad_output = unpacked.clone().normal_();
      }
      unpacked.backward(grad_output);
      if (batch_first) {
        grad_output.transpose_(0, 1);
      }
      for (const auto i : c10::irange(lengths.size(0))) {
        int64_t l = lengths[i].item<int64_t>();
        ASSERT_TRUE(torch::allclose(
            padded.grad().narrow(0, 0, l).select(1, i),
            grad_output.narrow(0, 0, l).select(1, i)));
        if (l < 10) {
          ASSERT_EQ(
              padded.grad()
                  .narrow(0, l, padded.grad().size(0) - l)
                  .select(1, i)
                  .abs()
                  .sum()
                  .item<double>(),
              0);
        }
      }
    }
  }

  // test error messages
  ASSERT_THROWS_WITH(
      rnn_utils::pack_padded_sequence(
          torch::randn({3, 3}), torch::tensor({1, 3, 2})),
      "You can pass `enforce_sorted=False`");
  ASSERT_THROWS_WITH(
      rnn_utils::pack_padded_sequence(torch::randn({0, 0}), torch::tensor({})),
      "empty tensor");
}

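// pad_sequence stacks variable-length sequences into one padded tensor,
// optionally batch-first, with a configurable padding value and padding side
// ("right" pads at the end of each sequence, "left" at the beginning).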
TEST_F(NNUtilsTest, PadSequence) {
  auto pad = [&](const torch::Tensor& tensor, int64_t length) {
    torch::NoGradGuard no_grad;
    std::vector<int64_t> tensor_sizes{length - tensor.size(0)};
    tensor_sizes.insert(
        tensor_sizes.end(),
        tensor.sizes().slice(1).begin(),
        tensor.sizes().slice(1).end());
    return torch::cat({tensor, torch::zeros(tensor_sizes, tensor.options())});
  };

  // single dimensional
  auto a = torch::tensor({1, 2, 3});
  auto b = torch::tensor({4, 5});
  auto c = torch::tensor({6});

  torch::Tensor expected, padded;

  // batch_first = true
  expected = torch::tensor({{4, 5, 0}, {1, 2, 3}, {6, 0, 0}});
  padded = rnn_utils::pad_sequence({b, a, c}, true);
  ASSERT_TRUE(padded.allclose(expected));

  // batch_first = false
  padded = rnn_utils::pad_sequence({b, a, c});
  ASSERT_TRUE(padded.allclose(expected.transpose(0, 1)));

  // padding_side = "left", batch_first = true
  expected = torch::tensor({{0, 4, 5}, {1, 2, 3}, {0, 0, 6}});
  padded = rnn_utils::pad_sequence({b, a, c}, true, 0, "left");
  ASSERT_TRUE(padded.allclose(expected));

  // padding_side = "left", batch_first = false
  padded = rnn_utils::pad_sequence({b, a, c}, false, 0, "left");
  ASSERT_TRUE(padded.allclose(expected.transpose(0, 1)));

  // pad with non-zero value
  expected = torch::tensor({{4, 5, 1}, {1, 2, 3}, {6, 1, 1}});
  padded = rnn_utils::pad_sequence({b, a, c}, true, 1);
  ASSERT_TRUE(padded.allclose(expected));

  // Test pad sorted sequence
  expected = torch::tensor({{1, 2, 3}, {4, 5, 0}, {6, 0, 0}});
  padded = rnn_utils::pad_sequence({a, b, c}, true);
  ASSERT_TRUE(padded.allclose(expected));

  // more dimensions
  int64_t maxlen = 9;
  for (int64_t num_dim : std::vector<int64_t>{0, 1, 2, 3}) {
    std::vector<torch::Tensor> sequences;
    std::vector<int64_t> trailing_dims(num_dim, 4);
    for (int64_t i = 1; i < maxlen + 1; i++) {
      int64_t seq_len = i * i;
      std::vector<int64_t> tensor_sizes{seq_len, 5};
      tensor_sizes.insert(
          tensor_sizes.end(), trailing_dims.begin(), trailing_dims.end());
      sequences.emplace_back(torch::rand(tensor_sizes));
    }
    std::shuffle(
        std::begin(sequences),
        std::end(sequences),
        std::default_random_engine{});
    std::vector<torch::Tensor> expected_tensors;
    for (const torch::Tensor& seq : sequences) {
      // NOLINTNEXTLINE(performance-inefficient-vector-operation)
      expected_tensors.emplace_back(pad(seq, maxlen * maxlen));
    }

    // batch_first = true
    auto expected = torch::stack(expected_tensors);
    auto padded = rnn_utils::pad_sequence(sequences, true);
    ASSERT_TRUE(padded.allclose(expected));

    // batch_first = false
    padded = rnn_utils::pad_sequence(sequences);
    ASSERT_TRUE(padded.allclose(expected.transpose(0, 1)));

    // reset expected_tensors for padding_side = "left"
    expected_tensors.clear();
    for (const torch::Tensor& seq : sequences) {
      // NOLINTNEXTLINE(performance-inefficient-vector-operation)
      expected_tensors.emplace_back(
          torch::flip(pad(torch::flip(seq, {0}), maxlen * maxlen), {0}));
    }
    expected = torch::stack(expected_tensors);
    // padding_side = "left", batch_first = true
    padded = rnn_utils::pad_sequence(sequences, true, 0, "left");
    ASSERT_TRUE(padded.allclose(expected));

    // padding_side = "left", batch_first = false
    padded = rnn_utils::pad_sequence(sequences, false, 0, "left");
    ASSERT_TRUE(padded.allclose(expected.transpose(0, 1)));
  }
}
894