#include <gtest/gtest.h>

#include <c10/util/irange.h>
#include <torch/torch.h>

#include <algorithm>
#include <memory>
#include <vector>

#include <test/cpp/api/support.h>

using namespace torch::nn;
using namespace torch::test;

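// Test fixture for torch::nn::Sequential. The tests below cover construction
// (from modules, shared_ptrs, and ModuleHolders), element access, forward()
// chaining, extend(), cloning, and pretty-printing.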
struct SequentialTest : torch::test::SeedingFixture {};

TEST_F(SequentialTest, CanContainThings) {
  Sequential sequential(Linear(3, 4), ReLU(), BatchNorm1d(3));
}

TEST_F(SequentialTest, ConstructsFromSharedPointer) {
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int value;
    int forward() {
      return value;
    }
  };
  Sequential sequential(
      std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3));
  ASSERT_EQ(sequential->size(), 3);

  Sequential sequential_named(
      {{"m1", std::make_shared<M>(1)},
       {std::string("m2"), std::make_shared<M>(2)},
       {"m3", std::make_shared<M>(3)}});
  ASSERT_EQ(sequential_named->size(), 3);
}

TEST_F(SequentialTest, ConstructsFromConcreteType) {
  static int copy_count;

  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
    M(const M& other) : torch::nn::Module(other) {
      copy_count++;
    }
    int value;
    int forward() {
      return value;
    }
  };

  copy_count = 0;
  Sequential sequential(M(1), M(2), M(3));
  ASSERT_EQ(sequential->size(), 3);
  // NOTE: The current implementation expects each module to be copied exactly
  // once, which happens when the module is passed into `std::make_shared<T>()`.
  // TODO: Find a way to avoid copying, and then delete the copy constructor of
  // `M`.
  ASSERT_EQ(copy_count, 3);

  copy_count = 0;
  Sequential sequential_named(
      {{"m1", M(1)}, {std::string("m2"), M(2)}, {"m3", M(3)}});
  ASSERT_EQ(sequential_named->size(), 3);
  ASSERT_EQ(copy_count, 3);
}

TEST_F(SequentialTest, ConstructsFromModuleHolder) {
  struct MImpl : torch::nn::Module {
    explicit MImpl(int value_) : value(value_) {}
    int forward() {
      return value;
    }
    int value;
  };

  struct M : torch::nn::ModuleHolder<MImpl> {
    using torch::nn::ModuleHolder<MImpl>::ModuleHolder;
    using torch::nn::ModuleHolder<MImpl>::get;
  };

  Sequential sequential(M(1), M(2), M(3));
  ASSERT_EQ(sequential->size(), 3);

  Sequential sequential_named(
      {{"m1", M(1)}, {std::string("m2"), M(2)}, {"m3", M(3)}});
  ASSERT_EQ(sequential_named->size(), 3);
}

TEST_F(SequentialTest, PushBackAddsAnElement) {
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int forward() {
      return value;
    }
    int value;
  };

  // Test unnamed submodules
  Sequential sequential;
  ASSERT_EQ(sequential->size(), 0);
  ASSERT_TRUE(sequential->is_empty());
  sequential->push_back(Linear(3, 4));
  ASSERT_EQ(sequential->size(), 1);
  sequential->push_back(std::make_shared<M>(1));
  ASSERT_EQ(sequential->size(), 2);
  sequential->push_back(M(2));
  ASSERT_EQ(sequential->size(), 3);

  // Mix named and unnamed submodules; unnamed ones get their index as the key
  Sequential sequential_named;
  ASSERT_EQ(sequential_named->size(), 0);
  ASSERT_TRUE(sequential_named->is_empty());

  sequential_named->push_back(Linear(3, 4));
  ASSERT_EQ(sequential_named->size(), 1);
  ASSERT_EQ(sequential_named->named_children()[0].key(), "0");
  sequential_named->push_back(std::string("linear2"), Linear(3, 4));
  ASSERT_EQ(sequential_named->size(), 2);
  ASSERT_EQ(sequential_named->named_children()[1].key(), "linear2");

  sequential_named->push_back("shared_m1", std::make_shared<M>(1));
  ASSERT_EQ(sequential_named->size(), 3);
  ASSERT_EQ(sequential_named->named_children()[2].key(), "shared_m1");
  sequential_named->push_back(std::make_shared<M>(1));
  ASSERT_EQ(sequential_named->size(), 4);
  ASSERT_EQ(sequential_named->named_children()[3].key(), "3");

  sequential_named->push_back(M(1));
  ASSERT_EQ(sequential_named->size(), 5);
  ASSERT_EQ(sequential_named->named_children()[4].key(), "4");
  sequential_named->push_back(std::string("m2"), M(1));
  ASSERT_EQ(sequential_named->size(), 6);
  ASSERT_EQ(sequential_named->named_children()[5].key(), "m2");

  // Named and unnamed AnyModules
  Sequential sequential_any;
  auto a = torch::nn::AnyModule(torch::nn::Linear(1, 2));
  ASSERT_EQ(sequential_any->size(), 0);
  ASSERT_TRUE(sequential_any->is_empty());
  sequential_any->push_back(a);
  ASSERT_EQ(sequential_any->size(), 1);
  ASSERT_EQ(sequential_any->named_children()[0].key(), "0");
  sequential_any->push_back("fc", a);
  ASSERT_EQ(sequential_any->size(), 2);
  ASSERT_EQ(sequential_any->named_children()[1].key(), "fc");
}

TEST_F(SequentialTest, AccessWithAt) {
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int forward() {
      return value;
    }
    int value;
  };
  std::vector<std::shared_ptr<M>> modules = {
      std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)};

  Sequential sequential;
  for (auto& module : modules) {
    sequential->push_back(module);
  }
  ASSERT_EQ(sequential->size(), 3);

  // returns the correct module for a given index
  for (const auto i : c10::irange(modules.size())) {
    ASSERT_EQ(&sequential->at<M>(i), modules[i].get());
  }

  // throws for a bad index
  ASSERT_THROWS_WITH(
      sequential->at<M>(modules.size() + 1), "Index out of range");
  ASSERT_THROWS_WITH(
      sequential->at<M>(modules.size() + 1000000), "Index out of range");
}

TEST_F(SequentialTest, AccessWithPtr) {
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int forward() {
      return value;
    }
    int value;
  };
  std::vector<std::shared_ptr<M>> modules = {
      std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)};

  Sequential sequential;
  for (auto& module : modules) {
    sequential->push_back(module);
  }
  ASSERT_EQ(sequential->size(), 3);

  // returns the correct module for a given index
  for (const auto i : c10::irange(modules.size())) {
    ASSERT_EQ(sequential->ptr(i).get(), modules[i].get());
    ASSERT_EQ(sequential[i].get(), modules[i].get());
    ASSERT_EQ(sequential->ptr<M>(i).get(), modules[i].get());
  }

  // throws for a bad index
  ASSERT_THROWS_WITH(sequential->ptr(modules.size() + 1), "Index out of range");
  ASSERT_THROWS_WITH(
      sequential->ptr(modules.size() + 1000000), "Index out of range");
}

TEST_F(SequentialTest, CallingForwardOnEmptySequentialIsDisallowed) {
  Sequential empty;
  ASSERT_THROWS_WITH(
      empty->forward<int>(), "Cannot call forward() on an empty Sequential");
}

TEST_F(SequentialTest, CallingForwardChainsCorrectly) {
  struct MockModule : torch::nn::Module {
    explicit MockModule(int value) : expected(value) {}
    int expected;
    int forward(int value) {
      assert(value == expected);
      return value + 1;
    }
  };

  Sequential sequential(MockModule{1}, MockModule{2}, MockModule{3});

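  // Each module asserts on the value it receives and returns value + 1, so the
  // chain turns 1 into 4 after three modules.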
  ASSERT_EQ(sequential->forward<int>(1), 4);
}

TEST_F(SequentialTest, CallingForwardWithTheWrongReturnTypeThrows) {
  struct M : public torch::nn::Module {
    int forward() {
      return 5;
    }
  };

  Sequential sequential(M{});
  ASSERT_EQ(sequential->forward<int>(), 5);
  ASSERT_THROWS_WITH(
      sequential->forward<float>(),
      "The type of the return value is int, but you asked for type float");
}

TEST_F(SequentialTest, TheReturnTypeOfForwardDefaultsToTensor) {
  struct M : public torch::nn::Module {
    torch::Tensor forward(torch::Tensor v) {
      return v;
    }
  };

  Sequential sequential(M{});
  auto variable = torch::ones({3, 3}, torch::requires_grad());
  ASSERT_TRUE(sequential->forward(variable).equal(variable));
}

TEST_F(SequentialTest, ForwardReturnsTheLastValue) {
  torch::manual_seed(0);
  Sequential sequential(Linear(10, 3), Linear(3, 5), Linear(5, 100));

  auto x = torch::randn({1000, 10}, torch::requires_grad());
  auto y = sequential->forward(x);
  ASSERT_EQ(y.ndimension(), 2);
  ASSERT_EQ(y.size(0), 1000);
  ASSERT_EQ(y.size(1), 100);
}

TEST_F(SequentialTest, SanityCheckForHoldingStandardModules) {
  Sequential sequential(
      Linear(10, 3),
      Conv2d(1, 2, 3),
      Dropout(0.5),
      BatchNorm2d(5),
      Embedding(4, 10),
      LSTM(4, 5));
}

TEST_F(SequentialTest, ExtendPushesModulesFromOtherSequential) {
  struct A : torch::nn::Module {
    int forward(int x) {
      return x;
    }
  };
  struct B : torch::nn::Module {
    int forward(int x) {
      return x;
    }
  };
  struct C : torch::nn::Module {
    int forward(int x) {
      return x;
    }
  };
  struct D : torch::nn::Module {
    int forward(int x) {
      return x;
    }
  };
  Sequential a(A{}, B{});
  Sequential b(C{}, D{});
  a->extend(*b);

  ASSERT_EQ(a->size(), 4);
  ASSERT_TRUE(a[0]->as<A>());
  ASSERT_TRUE(a[1]->as<B>());
  ASSERT_TRUE(a[2]->as<C>());
  ASSERT_TRUE(a[3]->as<D>());

  ASSERT_EQ(b->size(), 2);
  ASSERT_TRUE(b[0]->as<C>());
  ASSERT_TRUE(b[1]->as<D>());

  std::vector<std::shared_ptr<A>> c = {
      std::make_shared<A>(), std::make_shared<A>()};
  b->extend(c);

  ASSERT_EQ(b->size(), 4);
  ASSERT_TRUE(b[0]->as<C>());
  ASSERT_TRUE(b[1]->as<D>());
  ASSERT_TRUE(b[2]->as<A>());
  ASSERT_TRUE(b[3]->as<A>());
}

TEST_F(SequentialTest, HasReferenceSemantics) {
  Sequential first(Linear(2, 3), Linear(4, 4), Linear(4, 5));
  Sequential second(first);

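  // Sequential is a ModuleHolder, so copying the handle copies the underlying
  // shared_ptr rather than the modules: both handles refer to the same impl.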
  ASSERT_EQ(first.get(), second.get());
  ASSERT_EQ(first->size(), second->size());
  ASSERT_TRUE(std::equal(
      first->begin(),
      first->end(),
      second->begin(),
      [](const AnyModule& first, const AnyModule& second) {
        return &first == &second;
      }));
}

TEST_F(SequentialTest, IsCloneable) {
  Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3));
  Sequential clone =
      std::dynamic_pointer_cast<SequentialImpl>(sequential->clone());
  ASSERT_EQ(sequential->size(), clone->size());

  for (size_t i = 0; i < sequential->size(); ++i) {
    // The modules should be the same kind (type).
    ASSERT_EQ(sequential[i]->name(), clone[i]->name());
    // But not pointer-equal (distinct objects).
    ASSERT_NE(sequential[i], clone[i]);
  }

  // Verify that the clone is deep, i.e. parameters of modules are cloned too.

  torch::NoGradGuard no_grad;

  auto params1 = sequential->named_parameters();
  auto params2 = clone->named_parameters();
  ASSERT_EQ(params1.size(), params2.size());
  for (auto& param : params1) {
    ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()]));
    ASSERT_EQ(param->device(), params2[param.key()].device());
    ASSERT_TRUE(param->allclose(params2[param.key()]));
    param->add_(2);
  }
  for (auto& param : params1) {
    ASSERT_FALSE(param->allclose(params2[param.key()]));
  }
}

TEST_F(SequentialTest, RegistersElementsAsSubmodules) {
  Sequential sequential(Linear(10, 3), Conv2d(1, 2, 3), Dropout2d(0.5));

  auto modules = sequential->children();
  ASSERT_TRUE(modules[0]->as<Linear>());
  ASSERT_TRUE(modules[1]->as<Conv2d>());
  ASSERT_TRUE(modules[2]->as<Dropout2d>());
}

TEST_F(SequentialTest, CloneToDevice_CUDA) {
  Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3));
  torch::Device device(torch::kCUDA, 0);
  Sequential clone =
      std::dynamic_pointer_cast<SequentialImpl>(sequential->clone(device));
  for (const auto& p : clone->parameters()) {
    ASSERT_EQ(p.device(), device);
  }
  for (const auto& b : clone->buffers()) {
    ASSERT_EQ(b.device(), device);
  }
}

TEST_F(SequentialTest, PrettyPrintSequential) {
  Sequential sequential(
      Linear(10, 3),
      Conv2d(1, 2, 3),
      Dropout(0.5),
      BatchNorm2d(5),
      Embedding(4, 10),
      LSTM(4, 5));
  ASSERT_EQ(
      c10::str(sequential),
      "torch::nn::Sequential(\n"
      "  (0): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
      "  (1): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
      "  (2): torch::nn::Dropout(p=0.5, inplace=false)\n"
      "  (3): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
      "  (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
      "  (5): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
      ")");

  Sequential sequential_named(
      {{"linear", Linear(10, 3)},
       {"conv2d", Conv2d(1, 2, 3)},
       {"dropout", Dropout(0.5)},
       {"batchnorm2d", BatchNorm2d(5)},
       {"embedding", Embedding(4, 10)},
       {"lstm", LSTM(4, 5)}});
  ASSERT_EQ(
      c10::str(sequential_named),
      "torch::nn::Sequential(\n"
      "  (linear): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
      "  (conv2d): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
      "  (dropout): torch::nn::Dropout(p=0.5, inplace=false)\n"
      "  (batchnorm2d): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
      "  (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
      "  (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
      ")");
}

TEST_F(SequentialTest, ModuleForwardMethodOptionalArg) {
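  // Each block below runs a module whose forward() takes optional (defaulted)
  // arguments through Sequential, supplying only the required inputs, and
  // checks the output against hard-coded expected values.

  // ConvTranspose1d (and the 2d/3d variants in the following blocks) with a
  // fixed arange weight; the optional output_size argument is left at its
  // default.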
  {
    Sequential sequential(
        Identity(),
        ConvTranspose1d(ConvTranspose1dOptions(3, 2, 3).stride(1).bias(false)));
    std::dynamic_pointer_cast<ConvTranspose1dImpl>(sequential[1])
        ->weight.set_data(torch::arange(18.).reshape({3, 2, 3}));
    auto x = torch::arange(30.).reshape({2, 3, 5});
    auto y = sequential->forward(x);
    auto expected = torch::tensor(
        {{{150., 333., 552., 615., 678., 501., 276.},
          {195., 432., 714., 804., 894., 654., 357.}},
         {{420., 918., 1497., 1560., 1623., 1176., 636.},
          {600., 1287., 2064., 2154., 2244., 1599., 852.}}});
    ASSERT_TRUE(torch::allclose(y, expected));
  }
  {
    Sequential sequential(
        Identity(),
        ConvTranspose2d(ConvTranspose2dOptions(3, 2, 3).stride(1).bias(false)));
    std::dynamic_pointer_cast<ConvTranspose2dImpl>(sequential[1])
        ->weight.set_data(torch::arange(54.).reshape({3, 2, 3, 3}));
    auto x = torch::arange(75.).reshape({1, 3, 5, 5});
    auto y = sequential->forward(x);
    auto expected = torch::tensor(
        {{{{2250., 4629., 7140., 7311., 7482., 5133., 2640.},
           {4995., 10272., 15837., 16206., 16575., 11364., 5841.},
           {8280., 17019., 26226., 26820., 27414., 18783., 9648.},
           {9225., 18954., 29196., 29790., 30384., 20808., 10683.},
           {10170., 20889., 32166., 32760., 33354., 22833., 11718.},
           {7515., 15420., 23721., 24144., 24567., 16800., 8613.},
           {4140., 8487., 13044., 13269., 13494., 9219., 4722.}},
          {{2925., 6006., 9246., 9498., 9750., 6672., 3423.},
           {6480., 13296., 20454., 20985., 21516., 14712., 7542.},
           {10710., 21960., 33759., 34596., 35433., 24210., 12402.},
           {12060., 24705., 37944., 38781., 39618., 27045., 13842.},
           {13410., 27450., 42129., 42966., 43803., 29880., 15282.},
           {9810., 20064., 30768., 31353., 31938., 21768., 11124.},
           {5355., 10944., 16770., 17076., 17382., 11838., 6045.}}}});
    ASSERT_TRUE(torch::allclose(y, expected));
  }
  {
    Sequential sequential(
        Identity(),
        ConvTranspose3d(ConvTranspose3dOptions(2, 2, 2).stride(1).bias(false)));
    std::dynamic_pointer_cast<ConvTranspose3dImpl>(sequential[1])
        ->weight.set_data(torch::arange(32.).reshape({2, 2, 2, 2, 2}));
    auto x = torch::arange(16.).reshape({1, 2, 2, 2, 2});
    auto y = sequential->forward(x);
    auto expected = torch::tensor(
        {{{{{128., 280., 154.}, {304., 664., 364.}, {184., 400., 218.}},
           {{352., 768., 420.}, {832., 1808., 984.}, {496., 1072., 580.}},
           {{256., 552., 298.}, {592., 1272., 684.}, {344., 736., 394.}}},
          {{{192., 424., 234.}, {464., 1016., 556.}, {280., 608., 330.}},
           {{544., 1184., 644.}, {1280., 2768., 1496.}, {752., 1616., 868.}},
           {{384., 824., 442.}, {880., 1880., 1004.}, {504., 1072., 570.}}}}});
    ASSERT_TRUE(torch::allclose(y, expected));
  }
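  // EmbeddingBag built from a pretrained weight; with the default `mean` mode
  // the output for indices {1, 0} is the mean of rows 1 and 0 of `weight`.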
  {
    auto weight = torch::tensor({{1., 2.3, 3.}, {4., 5.1, 6.3}});
    Sequential sequential(Identity(), EmbeddingBag::from_pretrained(weight));
    auto x = torch::tensor({{1, 0}}, torch::kLong);
    auto y = sequential->forward(x);
    auto expected = torch::tensor({2.5000, 3.7000, 4.6500});
    ASSERT_TRUE(torch::allclose(y, expected));
  }
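  // MultiheadAttention: Sequential forwards all three inputs (query, key,
  // value); the module returns an (attn_output, attn_output_weights) tuple and
  // the optional masks are not supplied.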
  {
    torch::manual_seed(0);

    int64_t embed_dim = 8;
    int64_t num_heads = 4;
    int64_t batch_size = 8;
    int64_t src_len = 3;
    int64_t tgt_len = 1;

    auto query = torch::ones({batch_size, tgt_len, embed_dim});
    auto key = torch::ones({batch_size, src_len, embed_dim});
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto value = key;

    Sequential sequential(MultiheadAttention(embed_dim, num_heads));
    auto output = sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(
        query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1));

    auto attn_output = std::get<0>(output);
    auto attn_output_expected = torch::tensor(
        {{{0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674,
           -0.0056,
           0.1324,
           0.0922,
           0.0160,
           -0.0934,
           -0.1700,
           0.1663}}});
    ASSERT_TRUE(
        torch::allclose(attn_output, attn_output_expected, 1e-05, 2e-04));

    auto attn_output_weights = std::get<1>(output);
    auto attn_output_weights_expected = torch::tensor(
        {{{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}}});
    ASSERT_TRUE(torch::allclose(
        attn_output_weights, attn_output_weights_expected, 1e-05, 2e-04));
  }
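  // MaxUnpool1d/2d/3d take two required inputs (values and indices), which
  // Sequential forwards together; the optional output_size is left unset.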
  {
    auto indices = torch::tensor({{{1, 3, 4}}}, torch::kLong);
    auto x = torch::tensor({{{2, 4, 5}}}, torch::dtype(torch::kFloat));
    Sequential sequential(MaxUnpool1d(3));
    auto y = sequential->forward(x, indices);
    auto expected =
        torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat);
    ASSERT_TRUE(torch::allclose(y, expected));
  }
  {
    auto indices = torch::tensor(
        {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
         {{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}},
        torch::kLong);
    auto x = torch::tensor(
        {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
         {{{31, 33, 34}, {41, 43, 44}, {46, 48, 49}}}},
        torch::dtype(torch::kFloat));
    Sequential sequential(
        MaxUnpool2d(MaxUnpool2dOptions(3).stride(2).padding(1)));
    auto y = sequential->forward(x, indices);
    auto expected = torch::tensor(
        {{{{0, 0, 0, 0, 0},
           {0, 6, 0, 8, 9},
           {0, 0, 0, 0, 0},
           {0, 16, 0, 18, 19},
           {0, 21, 0, 23, 24}}},
         {{{0, 0, 0, 0, 0},
           {0, 31, 0, 33, 34},
           {0, 0, 0, 0, 0},
           {0, 41, 0, 43, 44},
           {0, 46, 0, 48, 49}}}},
        torch::kFloat);
    ASSERT_TRUE(torch::allclose(y, expected));
  }
  {
    auto indices = torch::tensor({{{{{26}}}}}, torch::kLong);
    auto x = torch::tensor(
        {{{{{26}}}}}, torch::dtype(torch::kFloat).requires_grad(true));
    Sequential sequential(MaxUnpool3d(3));
    auto y = sequential->forward(x, indices);
    auto expected = torch::tensor(
        {{{{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
           {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
           {{0, 0, 0}, {0, 0, 0}, {0, 0, 26}}}}},
        torch::kFloat);
    ASSERT_TRUE(torch::allclose(y, expected));
  }
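  // RNN, LSTM, and GRU run without an initial hidden state (an optional
  // argument); each returns (output, hidden state), with LSTM nesting its
  // (h, c) pair.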
  {
    torch::manual_seed(0);
    Sequential sequential(Identity(), RNN(2, 3));
    auto x = torch::ones({2, 3, 2});
    auto rnn_output =
        sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(x);
    auto expected_output = torch::tensor(
        {{{-0.0645, -0.7274, 0.4531},
          {-0.0645, -0.7274, 0.4531},
          {-0.0645, -0.7274, 0.4531}},
         {{-0.3970, -0.6950, 0.6009},
          {-0.3970, -0.6950, 0.6009},
          {-0.3970, -0.6950, 0.6009}}});
    ASSERT_TRUE(torch::allclose(
        std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
  }
  {
    torch::manual_seed(0);
    Sequential sequential(Identity(), LSTM(2, 3));
    auto x = torch::ones({2, 3, 2});
    auto rnn_output = sequential->forward<
        std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>>>(x);
    auto expected_output = torch::tensor(
        {{{-0.2693, -0.1240, 0.0744},
          {-0.2693, -0.1240, 0.0744},
          {-0.2693, -0.1240, 0.0744}},
         {{-0.3889, -0.1919, 0.1183},
          {-0.3889, -0.1919, 0.1183},
          {-0.3889, -0.1919, 0.1183}}});
    ASSERT_TRUE(torch::allclose(
        std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
  }
  {
    torch::manual_seed(0);
    Sequential sequential(Identity(), GRU(2, 3));
    auto x = torch::ones({2, 3, 2});
    auto rnn_output =
        sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(x);
    auto expected_output = torch::tensor(
        {{{-0.1134, 0.0467, 0.2336},
          {-0.1134, 0.0467, 0.2336},
          {-0.1134, 0.0467, 0.2336}},
         {{-0.1189, 0.0502, 0.2960},
          {-0.1189, 0.0502, 0.2960},
          {-0.1189, 0.0502, 0.2960}}});
    ASSERT_TRUE(torch::allclose(
        std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
  }
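  // The *Cell variants also run without an initial hidden state; RNNCell and
  // GRUCell return a single tensor, while LSTMCell returns an (h, c) pair.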
  {
    torch::manual_seed(0);
    Sequential sequential(Identity(), RNNCell(2, 3));
    auto x = torch::ones({2, 2});
    auto rnn_output = sequential->forward<torch::Tensor>(x);
    auto expected_output =
        torch::tensor({{-0.0645, -0.7274, 0.4531}, {-0.0645, -0.7274, 0.4531}});
    ASSERT_TRUE(torch::allclose(rnn_output, expected_output, 1e-05, 2e-04));
  }
  {
    torch::manual_seed(0);
    Sequential sequential(Identity(), LSTMCell(2, 3));
    auto x = torch::ones({2, 2});
    auto rnn_output =
        sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(x);
    auto expected_output =
        torch::tensor({{-0.2693, -0.1240, 0.0744}, {-0.2693, -0.1240, 0.0744}});
    ASSERT_TRUE(torch::allclose(
        std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
  }
  {
    torch::manual_seed(0);
    Sequential sequential(Identity(), GRUCell(2, 3));
    auto x = torch::ones({2, 2});
    auto rnn_output = sequential->forward<torch::Tensor>(x);
    auto expected_output =
        torch::tensor({{-0.1134, 0.0467, 0.2336}, {-0.1134, 0.0467, 0.2336}});
    ASSERT_TRUE(torch::allclose(rnn_output, expected_output, 1e-05, 2e-04));
  }
}