xref: /aosp_15_r20/external/pytorch/test/cpp/api/sequential.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h>
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
4*da0073e9SAndroid Build Coastguard Worker #include <torch/torch.h>
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker #include <algorithm>
7*da0073e9SAndroid Build Coastguard Worker #include <memory>
8*da0073e9SAndroid Build Coastguard Worker #include <vector>
9*da0073e9SAndroid Build Coastguard Worker 
10*da0073e9SAndroid Build Coastguard Worker #include <test/cpp/api/support.h>
11*da0073e9SAndroid Build Coastguard Worker 
12*da0073e9SAndroid Build Coastguard Worker using namespace torch::nn;
13*da0073e9SAndroid Build Coastguard Worker using namespace torch::test;
14*da0073e9SAndroid Build Coastguard Worker 
15*da0073e9SAndroid Build Coastguard Worker struct SequentialTest : torch::test::SeedingFixture {};
16*da0073e9SAndroid Build Coastguard Worker 
// A Sequential can be built from a heterogeneous mix of standard modules.
TEST_F(SequentialTest, CanContainThings) {
  Sequential sequential(Linear(3, 4), ReLU(), BatchNorm1d(3));
  // Besides constructing without throwing, every module handed to the
  // constructor must actually have been stored.
  ASSERT_EQ(sequential->size(), 3);
}
20*da0073e9SAndroid Build Coastguard Worker 
// A Sequential can be constructed from std::shared_ptr<Module> arguments,
// either positionally or as (name, module) pairs.
TEST_F(SequentialTest, ConstructsFromSharedPointer) {
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int value;
    int forward() {
      return value;
    }
  };
  Sequential sequential(
      std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3));
  ASSERT_EQ(sequential->size(), 3);

  Sequential sequential_named(
      {{"m1", std::make_shared<M>(1)},
       {std::string("m2"), std::make_shared<M>(2)},
       {"m3", std::make_shared<M>(3)}});
  // Check the named container itself (this used to re-check `sequential`,
  // leaving the named constructor unverified).
  ASSERT_EQ(sequential_named->size(), 3);
}
39*da0073e9SAndroid Build Coastguard Worker 
// A Sequential can be constructed from concrete module values; each value is
// copied exactly once on its way into the container.
TEST_F(SequentialTest, ConstructsFromConcreteType) {
  static int copy_count;

  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
    M(const M& other) : torch::nn::Module(other) {
      copy_count++;
    }
    int value;
    int forward() {
      return value;
    }
  };

  copy_count = 0;
  Sequential sequential(M(1), M(2), M(3));
  ASSERT_EQ(sequential->size(), 3);
  // NOTE: The current implementation expects each module to be copied exactly
  // once, which happens when the module is passed into `std::make_shared<T>()`.
  // TODO: Find a way to avoid copying, and then delete the copy constructor of
  // `M`.
  ASSERT_EQ(copy_count, 3);

  copy_count = 0;
  Sequential sequential_named(
      {{"m1", M(1)}, {std::string("m2"), M(2)}, {"m3", M(3)}});
  // Check the named container itself (this used to re-check `sequential`).
  ASSERT_EQ(sequential_named->size(), 3);
  ASSERT_EQ(copy_count, 3);
}
70*da0073e9SAndroid Build Coastguard Worker 
// A Sequential can be constructed from user-defined ModuleHolder wrappers,
// either positionally or as (name, holder) pairs.
TEST_F(SequentialTest, ConstructsFromModuleHolder) {
  struct MImpl : torch::nn::Module {
    explicit MImpl(int value_) : value(value_) {}
    int forward() {
      return value;
    }
    int value;
  };

  struct M : torch::nn::ModuleHolder<MImpl> {
    using torch::nn::ModuleHolder<MImpl>::ModuleHolder;
    using torch::nn::ModuleHolder<MImpl>::get;
  };

  Sequential sequential(M(1), M(2), M(3));
  ASSERT_EQ(sequential->size(), 3);

  Sequential sequential_named(
      {{"m1", M(1)}, {std::string("m2"), M(2)}, {"m3", M(3)}});
  // Check the named container itself (this used to re-check `sequential`).
  ASSERT_EQ(sequential_named->size(), 3);
}
92*da0073e9SAndroid Build Coastguard Worker 
TEST_F(SequentialTest,PushBackAddsAnElement)93*da0073e9SAndroid Build Coastguard Worker TEST_F(SequentialTest, PushBackAddsAnElement) {
94*da0073e9SAndroid Build Coastguard Worker   struct M : torch::nn::Module {
95*da0073e9SAndroid Build Coastguard Worker     explicit M(int value_) : value(value_) {}
96*da0073e9SAndroid Build Coastguard Worker     int forward() {
97*da0073e9SAndroid Build Coastguard Worker       return value;
98*da0073e9SAndroid Build Coastguard Worker     }
99*da0073e9SAndroid Build Coastguard Worker     int value;
100*da0073e9SAndroid Build Coastguard Worker   };
101*da0073e9SAndroid Build Coastguard Worker 
102*da0073e9SAndroid Build Coastguard Worker   // Test unnamed submodules
103*da0073e9SAndroid Build Coastguard Worker   Sequential sequential;
104*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(sequential->size(), 0);
105*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(sequential->is_empty());
106*da0073e9SAndroid Build Coastguard Worker   sequential->push_back(Linear(3, 4));
107*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(sequential->size(), 1);
108*da0073e9SAndroid Build Coastguard Worker   sequential->push_back(std::make_shared<M>(1));
109*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(sequential->size(), 2);
110*da0073e9SAndroid Build Coastguard Worker   sequential->push_back(M(2));
111*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(sequential->size(), 3);
112*da0073e9SAndroid Build Coastguard Worker 
113*da0073e9SAndroid Build Coastguard Worker   // Mix named and unnamed submodules
114*da0073e9SAndroid Build Coastguard Worker   Sequential sequential_named;
115*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(sequential_named->size(), 0);
116*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(sequential_named->is_empty());
117*da0073e9SAndroid Build Coastguard Worker 
118*da0073e9SAndroid Build Coastguard Worker   sequential_named->push_back(Linear(3, 4));
119*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(sequential_named->size(), 1);
120*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(sequential_named->named_children()[0].key(), "0");
121*da0073e9SAndroid Build Coastguard Worker   sequential_named->push_back(std::string("linear2"), Linear(3, 4));
122*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(sequential_named->size(), 2);
123*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(sequential_named->named_children()[1].key(), "linear2");
124*da0073e9SAndroid Build Coastguard Worker 
125*da0073e9SAndroid Build Coastguard Worker   sequential_named->push_back("shared_m1", std::make_shared<M>(1));
126*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(sequential_named->size(), 3);
127*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(sequential_named->named_children()[2].key(), "shared_m1");
128*da0073e9SAndroid Build Coastguard Worker   sequential_named->push_back(std::make_shared<M>(1));
129*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(sequential_named->size(), 4);
130*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(sequential_named->named_children()[3].key(), "3");
131*da0073e9SAndroid Build Coastguard Worker 
132*da0073e9SAndroid Build Coastguard Worker   sequential_named->push_back(M(1));
133*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(sequential_named->size(), 5);
134*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(sequential_named->named_children()[4].key(), "4");
135*da0073e9SAndroid Build Coastguard Worker   sequential_named->push_back(std::string("m2"), M(1));
136*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(sequential_named->size(), 6);
137*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(sequential_named->named_children()[5].key(), "m2");
138*da0073e9SAndroid Build Coastguard Worker 
139*da0073e9SAndroid Build Coastguard Worker   // named and unnamed AnyModule's
140*da0073e9SAndroid Build Coastguard Worker   Sequential sequential_any;
141*da0073e9SAndroid Build Coastguard Worker   auto a = torch::nn::AnyModule(torch::nn::Linear(1, 2));
142*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(sequential_any->size(), 0);
143*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(sequential_any->is_empty());
144*da0073e9SAndroid Build Coastguard Worker   sequential_any->push_back(a);
145*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(sequential_any->size(), 1);
146*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(sequential_any->named_children()[0].key(), "0");
147*da0073e9SAndroid Build Coastguard Worker   sequential_any->push_back("fc", a);
148*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(sequential_any->size(), 2);
149*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(sequential_any->named_children()[1].key(), "fc");
150*da0073e9SAndroid Build Coastguard Worker }
151*da0073e9SAndroid Build Coastguard Worker 
// at<T>(i) returns a reference to the i-th module, downcast to T, and rejects
// any index outside [0, size()).
TEST_F(SequentialTest, AccessWithAt) {
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int forward() {
      return value;
    }
    int value;
  };
  std::vector<std::shared_ptr<M>> modules = {
      std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)};

  Sequential sequential;
  for (auto& module : modules) {
    sequential->push_back(module);
  }
  ASSERT_EQ(sequential->size(), 3);

  // returns the correct module for a given index
  for (const auto i : c10::irange(modules.size())) {
    ASSERT_EQ(&sequential->at<M>(i), modules[i].get());
  }

  // throws for a bad index, including the first index past the end -- the
  // boundary most likely to be broken by an off-by-one in the bounds check.
  ASSERT_THROWS_WITH(sequential->at<M>(modules.size()), "Index out of range");
  ASSERT_THROWS_WITH(
      sequential->at<M>(modules.size() + 1), "Index out of range");
  ASSERT_THROWS_WITH(
      sequential->at<M>(modules.size() + 1000000), "Index out of range");
}
180*da0073e9SAndroid Build Coastguard Worker 
// ptr(i), ptr<T>(i) and operator[] return the i-th module's shared_ptr, and
// ptr(i) rejects any index outside [0, size()).
TEST_F(SequentialTest, AccessWithPtr) {
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int forward() {
      return value;
    }
    int value;
  };
  std::vector<std::shared_ptr<M>> modules = {
      std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)};

  Sequential sequential;
  for (auto& module : modules) {
    sequential->push_back(module);
  }
  ASSERT_EQ(sequential->size(), 3);

  // returns the correct module for a given index
  for (const auto i : c10::irange(modules.size())) {
    ASSERT_EQ(sequential->ptr(i).get(), modules[i].get());
    ASSERT_EQ(sequential[i].get(), modules[i].get());
    ASSERT_EQ(sequential->ptr<M>(i).get(), modules[i].get());
  }

  // throws for a bad index, including the first index past the end -- the
  // boundary most likely to be broken by an off-by-one in the bounds check.
  ASSERT_THROWS_WITH(sequential->ptr(modules.size()), "Index out of range");
  ASSERT_THROWS_WITH(sequential->ptr(modules.size() + 1), "Index out of range");
  ASSERT_THROWS_WITH(
      sequential->ptr(modules.size() + 1000000), "Index out of range");
}
210*da0073e9SAndroid Build Coastguard Worker 
// forward() on a Sequential that holds no modules must raise instead of
// producing a value.
TEST_F(SequentialTest, CallingForwardOnEmptySequentialIsDisallowed) {
  Sequential nothing_inside;
  ASSERT_THROWS_WITH(
      nothing_inside->forward<int>(),
      "Cannot call forward() on an empty Sequential");
}
216*da0073e9SAndroid Build Coastguard Worker 
// forward() feeds each module's output into the next module, in insertion
// order: 1 -> MockModule{1} -> 2 -> MockModule{2} -> 3 -> MockModule{3} -> 4.
TEST_F(SequentialTest, CallingForwardChainsCorrectly) {
  struct MockModule : torch::nn::Module {
    explicit MockModule(int value) : expected(value) {}
    int expected;
    int forward(int value) {
      // Use a gtest check rather than `assert`: `assert` is compiled out when
      // NDEBUG is defined, so release builds could not detect mis-ordered
      // chaining. EXPECT_* (not ASSERT_*) because this function returns int.
      EXPECT_EQ(value, expected);
      return value + 1;
    }
  };

  Sequential sequential(MockModule{1}, MockModule{2}, MockModule{3});

  ASSERT_EQ(sequential->forward<int>(1), 4);
}
231*da0073e9SAndroid Build Coastguard Worker 
// forward<T>() succeeds when T matches the last module's return type and
// throws a descriptive error when it does not.
TEST_F(SequentialTest, CallingForwardWithTheWrongReturnTypeThrows) {
  struct ReturnsFive : public torch::nn::Module {
    int forward() {
      return 5;
    }
  };

  Sequential seq(ReturnsFive{});
  // Matching type: the value comes straight through.
  ASSERT_EQ(seq->forward<int>(), 5);
  // Mismatched type: rejected with a message naming both types.
  ASSERT_THROWS_WITH(
      seq->forward<float>(),
      "The type of the return value is int, but you asked for type float");
}
245*da0073e9SAndroid Build Coastguard Worker 
// Calling forward() with no explicit template argument yields a torch::Tensor.
TEST_F(SequentialTest, TheReturnTypeOfForwardDefaultsToTensor) {
  struct Passthrough : public torch::nn::Module {
    torch::Tensor forward(torch::Tensor v) {
      return v;
    }
  };

  Sequential seq(Passthrough{});
  auto input = torch::ones({3, 3}, torch::requires_grad());
  torch::Tensor output = seq->forward(input);
  ASSERT_TRUE(output.equal(input));
}
257*da0073e9SAndroid Build Coastguard Worker 
// forward() returns the output of the final module in the chain.
TEST_F(SequentialTest, ForwardReturnsTheLastValue) {
  torch::manual_seed(0);
  Sequential stack(Linear(10, 3), Linear(3, 5), Linear(5, 100));

  auto input = torch::randn({1000, 10}, torch::requires_grad());
  auto output = stack->forward(input);
  // The trailing Linear(5, 100) fixes the feature dimension of the result.
  ASSERT_EQ(output.ndimension(), 2);
  ASSERT_EQ(output.size(0), 1000);
  ASSERT_EQ(output.size(1), 100);
}
268*da0073e9SAndroid Build Coastguard Worker 
// A Sequential can hold a variety of built-in module types side by side.
TEST_F(SequentialTest, SanityCheckForHoldingStandardModules) {
  Sequential sequential(
      Linear(10, 3),
      Conv2d(1, 2, 3),
      Dropout(0.5),
      BatchNorm2d(5),
      Embedding(4, 10),
      LSTM(4, 5));
  // Besides constructing without throwing, all six modules must be stored.
  ASSERT_EQ(sequential->size(), 6);
}
278*da0073e9SAndroid Build Coastguard Worker 
TEST_F(SequentialTest,ExtendPushesModulesFromOtherSequential)279*da0073e9SAndroid Build Coastguard Worker TEST_F(SequentialTest, ExtendPushesModulesFromOtherSequential) {
280*da0073e9SAndroid Build Coastguard Worker   struct A : torch::nn::Module {
281*da0073e9SAndroid Build Coastguard Worker     int forward(int x) {
282*da0073e9SAndroid Build Coastguard Worker       return x;
283*da0073e9SAndroid Build Coastguard Worker     }
284*da0073e9SAndroid Build Coastguard Worker   };
285*da0073e9SAndroid Build Coastguard Worker   struct B : torch::nn::Module {
286*da0073e9SAndroid Build Coastguard Worker     int forward(int x) {
287*da0073e9SAndroid Build Coastguard Worker       return x;
288*da0073e9SAndroid Build Coastguard Worker     }
289*da0073e9SAndroid Build Coastguard Worker   };
290*da0073e9SAndroid Build Coastguard Worker   struct C : torch::nn::Module {
291*da0073e9SAndroid Build Coastguard Worker     int forward(int x) {
292*da0073e9SAndroid Build Coastguard Worker       return x;
293*da0073e9SAndroid Build Coastguard Worker     }
294*da0073e9SAndroid Build Coastguard Worker   };
295*da0073e9SAndroid Build Coastguard Worker   struct D : torch::nn::Module {
296*da0073e9SAndroid Build Coastguard Worker     int forward(int x) {
297*da0073e9SAndroid Build Coastguard Worker       return x;
298*da0073e9SAndroid Build Coastguard Worker     }
299*da0073e9SAndroid Build Coastguard Worker   };
300*da0073e9SAndroid Build Coastguard Worker   Sequential a(A{}, B{});
301*da0073e9SAndroid Build Coastguard Worker   Sequential b(C{}, D{});
302*da0073e9SAndroid Build Coastguard Worker   a->extend(*b);
303*da0073e9SAndroid Build Coastguard Worker 
304*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(a->size(), 4);
305*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(a[0]->as<A>());
306*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(a[1]->as<B>());
307*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(a[2]->as<C>());
308*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(a[3]->as<D>());
309*da0073e9SAndroid Build Coastguard Worker 
310*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(b->size(), 2);
311*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(b[0]->as<C>());
312*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(b[1]->as<D>());
313*da0073e9SAndroid Build Coastguard Worker 
314*da0073e9SAndroid Build Coastguard Worker   std::vector<std::shared_ptr<A>> c = {
315*da0073e9SAndroid Build Coastguard Worker       std::make_shared<A>(), std::make_shared<A>()};
316*da0073e9SAndroid Build Coastguard Worker   b->extend(c);
317*da0073e9SAndroid Build Coastguard Worker 
318*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(b->size(), 4);
319*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(b[0]->as<C>());
320*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(b[1]->as<D>());
321*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(b[2]->as<A>());
322*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(b[3]->as<A>());
323*da0073e9SAndroid Build Coastguard Worker }
324*da0073e9SAndroid Build Coastguard Worker 
// Copying a Sequential holder shares the underlying impl rather than cloning
// it: both holders observe the very same AnyModule elements.
TEST_F(SequentialTest, HasReferenceSemantics) {
  Sequential original(Linear(2, 3), Linear(4, 4), Linear(4, 5));
  Sequential alias(original);

  ASSERT_EQ(original.get(), alias.get());
  ASSERT_EQ(original->size(), alias->size());
  // Element identity (address equality), not value equality.
  const auto same_object = [](const AnyModule& lhs, const AnyModule& rhs) {
    return &lhs == &rhs;
  };
  ASSERT_TRUE(std::equal(
      original->begin(), original->end(), alias->begin(), same_object));
}
339*da0073e9SAndroid Build Coastguard Worker 
// clone() produces a deep copy: same module kinds, distinct module objects,
// and independent parameter storage.
TEST_F(SequentialTest, IsCloneable) {
  Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3));
  Sequential clone =
      std::dynamic_pointer_cast<SequentialImpl>(sequential->clone());
  ASSERT_EQ(sequential->size(), clone->size());

  // c10::irange for consistency with the other index loops in this file.
  for (const auto i : c10::irange(sequential->size())) {
    // The modules should be the same kind (type).
    ASSERT_EQ(sequential[i]->name(), clone[i]->name());
    // But not pointer-equal (distinct objects).
    ASSERT_NE(sequential[i], clone[i]);
  }

  // Verify that the clone is deep, i.e. parameters of modules are cloned too.

  torch::NoGradGuard no_grad;

  auto params1 = sequential->named_parameters();
  auto params2 = clone->named_parameters();
  ASSERT_EQ(params1.size(), params2.size());
  for (auto& param : params1) {
    // Look the cloned counterpart up once rather than once per assertion.
    const auto& cloned = params2[param.key()];
    ASSERT_FALSE(pointer_equal(param.value(), cloned));
    ASSERT_EQ(param->device(), cloned.device());
    ASSERT_TRUE(param->allclose(cloned));
    // Mutate only the original; a deep clone must not see this change.
    param->add_(2);
  }
  for (auto& param : params1) {
    ASSERT_FALSE(param->allclose(params2[param.key()]));
  }
}
370*da0073e9SAndroid Build Coastguard Worker 
// Each element of a Sequential is registered as a child submodule, in order.
TEST_F(SequentialTest, RegistersElementsAsSubmodules) {
  Sequential sequential(Linear(10, 3), Conv2d(1, 2, 3), Dropout2d(0.5));

  auto modules = sequential->children();
  // Guard the indexing below: if registration regressed and fewer children
  // came back, modules[i] would read past the end of the vector.
  ASSERT_EQ(modules.size(), 3);
  ASSERT_TRUE(modules[0]->as<Linear>());
  ASSERT_TRUE(modules[1]->as<Conv2d>());
  ASSERT_TRUE(modules[2]->as<Dropout2d>());
}
379*da0073e9SAndroid Build Coastguard Worker 
// clone(device) places every parameter and buffer of the copy on `device`.
TEST_F(SequentialTest, CloneToDevice_CUDA) {
  Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3));
  torch::Device cuda0(torch::kCUDA, 0);
  Sequential clone =
      std::dynamic_pointer_cast<SequentialImpl>(sequential->clone(cuda0));
  for (const auto& parameter : clone->parameters()) {
    ASSERT_EQ(parameter.device(), cuda0);
  }
  for (const auto& buffer : clone->buffers()) {
    ASSERT_EQ(buffer.device(), cuda0);
  }
}
392*da0073e9SAndroid Build Coastguard Worker 
TEST_F(SequentialTest,PrettyPrintSequential)393*da0073e9SAndroid Build Coastguard Worker TEST_F(SequentialTest, PrettyPrintSequential) {
394*da0073e9SAndroid Build Coastguard Worker   Sequential sequential(
395*da0073e9SAndroid Build Coastguard Worker       Linear(10, 3),
396*da0073e9SAndroid Build Coastguard Worker       Conv2d(1, 2, 3),
397*da0073e9SAndroid Build Coastguard Worker       Dropout(0.5),
398*da0073e9SAndroid Build Coastguard Worker       BatchNorm2d(5),
399*da0073e9SAndroid Build Coastguard Worker       Embedding(4, 10),
400*da0073e9SAndroid Build Coastguard Worker       LSTM(4, 5));
401*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(
402*da0073e9SAndroid Build Coastguard Worker       c10::str(sequential),
403*da0073e9SAndroid Build Coastguard Worker       "torch::nn::Sequential(\n"
404*da0073e9SAndroid Build Coastguard Worker       "  (0): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
405*da0073e9SAndroid Build Coastguard Worker       "  (1): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
406*da0073e9SAndroid Build Coastguard Worker       "  (2): torch::nn::Dropout(p=0.5, inplace=false)\n"
407*da0073e9SAndroid Build Coastguard Worker       "  (3): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
408*da0073e9SAndroid Build Coastguard Worker       "  (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
409*da0073e9SAndroid Build Coastguard Worker       "  (5): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
410*da0073e9SAndroid Build Coastguard Worker       ")");
411*da0073e9SAndroid Build Coastguard Worker 
412*da0073e9SAndroid Build Coastguard Worker   Sequential sequential_named(
413*da0073e9SAndroid Build Coastguard Worker       {{"linear", Linear(10, 3)},
414*da0073e9SAndroid Build Coastguard Worker        {"conv2d", Conv2d(1, 2, 3)},
415*da0073e9SAndroid Build Coastguard Worker        {"dropout", Dropout(0.5)},
416*da0073e9SAndroid Build Coastguard Worker        {"batchnorm2d", BatchNorm2d(5)},
417*da0073e9SAndroid Build Coastguard Worker        {"embedding", Embedding(4, 10)},
418*da0073e9SAndroid Build Coastguard Worker        {"lstm", LSTM(4, 5)}});
419*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(
420*da0073e9SAndroid Build Coastguard Worker       c10::str(sequential_named),
421*da0073e9SAndroid Build Coastguard Worker       "torch::nn::Sequential(\n"
422*da0073e9SAndroid Build Coastguard Worker       "  (linear): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
423*da0073e9SAndroid Build Coastguard Worker       "  (conv2d): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
424*da0073e9SAndroid Build Coastguard Worker       "  (dropout): torch::nn::Dropout(p=0.5, inplace=false)\n"
425*da0073e9SAndroid Build Coastguard Worker       "  (batchnorm2d): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
426*da0073e9SAndroid Build Coastguard Worker       "  (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
427*da0073e9SAndroid Build Coastguard Worker       "  (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
428*da0073e9SAndroid Build Coastguard Worker       ")");
429*da0073e9SAndroid Build Coastguard Worker }
430*da0073e9SAndroid Build Coastguard Worker 
TEST_F(SequentialTest,ModuleForwardMethodOptionalArg)431*da0073e9SAndroid Build Coastguard Worker TEST_F(SequentialTest, ModuleForwardMethodOptionalArg) {
432*da0073e9SAndroid Build Coastguard Worker   {
433*da0073e9SAndroid Build Coastguard Worker     Sequential sequential(
434*da0073e9SAndroid Build Coastguard Worker         Identity(),
435*da0073e9SAndroid Build Coastguard Worker         ConvTranspose1d(ConvTranspose1dOptions(3, 2, 3).stride(1).bias(false)));
436*da0073e9SAndroid Build Coastguard Worker     std::dynamic_pointer_cast<ConvTranspose1dImpl>(sequential[1])
437*da0073e9SAndroid Build Coastguard Worker         ->weight.set_data(torch::arange(18.).reshape({3, 2, 3}));
438*da0073e9SAndroid Build Coastguard Worker     auto x = torch::arange(30.).reshape({2, 3, 5});
439*da0073e9SAndroid Build Coastguard Worker     auto y = sequential->forward(x);
440*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::tensor(
441*da0073e9SAndroid Build Coastguard Worker         {{{150., 333., 552., 615., 678., 501., 276.},
442*da0073e9SAndroid Build Coastguard Worker           {195., 432., 714., 804., 894., 654., 357.}},
443*da0073e9SAndroid Build Coastguard Worker          {{420., 918., 1497., 1560., 1623., 1176., 636.},
444*da0073e9SAndroid Build Coastguard Worker           {600., 1287., 2064., 2154., 2244., 1599., 852.}}});
445*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(y, expected));
446*da0073e9SAndroid Build Coastguard Worker   }
447*da0073e9SAndroid Build Coastguard Worker   {
448*da0073e9SAndroid Build Coastguard Worker     Sequential sequential(
449*da0073e9SAndroid Build Coastguard Worker         Identity(),
450*da0073e9SAndroid Build Coastguard Worker         ConvTranspose2d(ConvTranspose2dOptions(3, 2, 3).stride(1).bias(false)));
451*da0073e9SAndroid Build Coastguard Worker     std::dynamic_pointer_cast<ConvTranspose2dImpl>(sequential[1])
452*da0073e9SAndroid Build Coastguard Worker         ->weight.set_data(torch::arange(54.).reshape({3, 2, 3, 3}));
453*da0073e9SAndroid Build Coastguard Worker     auto x = torch::arange(75.).reshape({1, 3, 5, 5});
454*da0073e9SAndroid Build Coastguard Worker     auto y = sequential->forward(x);
455*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::tensor(
456*da0073e9SAndroid Build Coastguard Worker         {{{{2250., 4629., 7140., 7311., 7482., 5133., 2640.},
457*da0073e9SAndroid Build Coastguard Worker            {4995., 10272., 15837., 16206., 16575., 11364., 5841.},
458*da0073e9SAndroid Build Coastguard Worker            {8280., 17019., 26226., 26820., 27414., 18783., 9648.},
459*da0073e9SAndroid Build Coastguard Worker            {9225., 18954., 29196., 29790., 30384., 20808., 10683.},
460*da0073e9SAndroid Build Coastguard Worker            {10170., 20889., 32166., 32760., 33354., 22833., 11718.},
461*da0073e9SAndroid Build Coastguard Worker            {7515., 15420., 23721., 24144., 24567., 16800., 8613.},
462*da0073e9SAndroid Build Coastguard Worker            {4140., 8487., 13044., 13269., 13494., 9219., 4722.}},
463*da0073e9SAndroid Build Coastguard Worker           {{2925., 6006., 9246., 9498., 9750., 6672., 3423.},
464*da0073e9SAndroid Build Coastguard Worker            {6480., 13296., 20454., 20985., 21516., 14712., 7542.},
465*da0073e9SAndroid Build Coastguard Worker            {10710., 21960., 33759., 34596., 35433., 24210., 12402.},
466*da0073e9SAndroid Build Coastguard Worker            {12060., 24705., 37944., 38781., 39618., 27045., 13842.},
467*da0073e9SAndroid Build Coastguard Worker            {13410., 27450., 42129., 42966., 43803., 29880., 15282.},
468*da0073e9SAndroid Build Coastguard Worker            {9810., 20064., 30768., 31353., 31938., 21768., 11124.},
469*da0073e9SAndroid Build Coastguard Worker            {5355., 10944., 16770., 17076., 17382., 11838., 6045.}}}});
470*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(y, expected));
471*da0073e9SAndroid Build Coastguard Worker   }
472*da0073e9SAndroid Build Coastguard Worker   {
473*da0073e9SAndroid Build Coastguard Worker     Sequential sequential(
474*da0073e9SAndroid Build Coastguard Worker         Identity(),
475*da0073e9SAndroid Build Coastguard Worker         ConvTranspose3d(ConvTranspose3dOptions(2, 2, 2).stride(1).bias(false)));
476*da0073e9SAndroid Build Coastguard Worker     std::dynamic_pointer_cast<ConvTranspose3dImpl>(sequential[1])
477*da0073e9SAndroid Build Coastguard Worker         ->weight.set_data(torch::arange(32.).reshape({2, 2, 2, 2, 2}));
478*da0073e9SAndroid Build Coastguard Worker     auto x = torch::arange(16.).reshape({1, 2, 2, 2, 2});
479*da0073e9SAndroid Build Coastguard Worker     auto y = sequential->forward(x);
480*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::tensor(
481*da0073e9SAndroid Build Coastguard Worker         {{{{{128., 280., 154.}, {304., 664., 364.}, {184., 400., 218.}},
482*da0073e9SAndroid Build Coastguard Worker            {{352., 768., 420.}, {832., 1808., 984.}, {496., 1072., 580.}},
483*da0073e9SAndroid Build Coastguard Worker            {{256., 552., 298.}, {592., 1272., 684.}, {344., 736., 394.}}},
484*da0073e9SAndroid Build Coastguard Worker           {{{192., 424., 234.}, {464., 1016., 556.}, {280., 608., 330.}},
485*da0073e9SAndroid Build Coastguard Worker            {{544., 1184., 644.}, {1280., 2768., 1496.}, {752., 1616., 868.}},
486*da0073e9SAndroid Build Coastguard Worker            {{384., 824., 442.}, {880., 1880., 1004.}, {504., 1072., 570.}}}}});
487*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(y, expected));
488*da0073e9SAndroid Build Coastguard Worker   }
489*da0073e9SAndroid Build Coastguard Worker   {
490*da0073e9SAndroid Build Coastguard Worker     auto weight = torch::tensor({{1., 2.3, 3.}, {4., 5.1, 6.3}});
491*da0073e9SAndroid Build Coastguard Worker     Sequential sequential(Identity(), EmbeddingBag::from_pretrained(weight));
492*da0073e9SAndroid Build Coastguard Worker     auto x = torch::tensor({{1, 0}}, torch::kLong);
493*da0073e9SAndroid Build Coastguard Worker     auto y = sequential->forward(x);
494*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::tensor({2.5000, 3.7000, 4.6500});
495*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(y, expected));
496*da0073e9SAndroid Build Coastguard Worker   }
497*da0073e9SAndroid Build Coastguard Worker   {
498*da0073e9SAndroid Build Coastguard Worker     torch::manual_seed(0);
499*da0073e9SAndroid Build Coastguard Worker 
500*da0073e9SAndroid Build Coastguard Worker     int64_t embed_dim = 8;
501*da0073e9SAndroid Build Coastguard Worker     int64_t num_heads = 4;
502*da0073e9SAndroid Build Coastguard Worker     int64_t batch_size = 8;
503*da0073e9SAndroid Build Coastguard Worker     int64_t src_len = 3;
504*da0073e9SAndroid Build Coastguard Worker     int64_t tgt_len = 1;
505*da0073e9SAndroid Build Coastguard Worker 
506*da0073e9SAndroid Build Coastguard Worker     auto query = torch::ones({batch_size, tgt_len, embed_dim});
507*da0073e9SAndroid Build Coastguard Worker     auto key = torch::ones({batch_size, src_len, embed_dim});
508*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
509*da0073e9SAndroid Build Coastguard Worker     auto value = key;
510*da0073e9SAndroid Build Coastguard Worker 
511*da0073e9SAndroid Build Coastguard Worker     Sequential sequential(MultiheadAttention(embed_dim, num_heads));
512*da0073e9SAndroid Build Coastguard Worker     auto output = sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(
513*da0073e9SAndroid Build Coastguard Worker         query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1));
514*da0073e9SAndroid Build Coastguard Worker 
515*da0073e9SAndroid Build Coastguard Worker     auto attn_output = std::get<0>(output);
516*da0073e9SAndroid Build Coastguard Worker     auto attn_output_expected = torch::tensor(
517*da0073e9SAndroid Build Coastguard Worker         {{{0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
518*da0073e9SAndroid Build Coastguard Worker           {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
519*da0073e9SAndroid Build Coastguard Worker           {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
520*da0073e9SAndroid Build Coastguard Worker           {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
521*da0073e9SAndroid Build Coastguard Worker           {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
522*da0073e9SAndroid Build Coastguard Worker           {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
523*da0073e9SAndroid Build Coastguard Worker           {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
524*da0073e9SAndroid Build Coastguard Worker           {0.0674,
525*da0073e9SAndroid Build Coastguard Worker            -0.0056,
526*da0073e9SAndroid Build Coastguard Worker            0.1324,
527*da0073e9SAndroid Build Coastguard Worker            0.0922,
528*da0073e9SAndroid Build Coastguard Worker            0.0160,
529*da0073e9SAndroid Build Coastguard Worker            -0.0934,
530*da0073e9SAndroid Build Coastguard Worker            -0.1700,
531*da0073e9SAndroid Build Coastguard Worker            0.1663}}});
532*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(
533*da0073e9SAndroid Build Coastguard Worker         torch::allclose(attn_output, attn_output_expected, 1e-05, 2e-04));
534*da0073e9SAndroid Build Coastguard Worker 
535*da0073e9SAndroid Build Coastguard Worker     auto attn_output_weights = std::get<1>(output);
536*da0073e9SAndroid Build Coastguard Worker     auto attn_output_weights_expected = torch::tensor(
537*da0073e9SAndroid Build Coastguard Worker         {{{0.3333, 0.3333, 0.3333}},
538*da0073e9SAndroid Build Coastguard Worker          {{0.3333, 0.3333, 0.3333}},
539*da0073e9SAndroid Build Coastguard Worker          {{0.3333, 0.3333, 0.3333}},
540*da0073e9SAndroid Build Coastguard Worker          {{0.3333, 0.3333, 0.3333}},
541*da0073e9SAndroid Build Coastguard Worker          {{0.3333, 0.3333, 0.3333}},
542*da0073e9SAndroid Build Coastguard Worker          {{0.3333, 0.3333, 0.3333}},
543*da0073e9SAndroid Build Coastguard Worker          {{0.3333, 0.3333, 0.3333}},
544*da0073e9SAndroid Build Coastguard Worker          {{0.3333, 0.3333, 0.3333}}});
545*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
546*da0073e9SAndroid Build Coastguard Worker         attn_output_weights, attn_output_weights_expected, 1e-05, 2e-04));
547*da0073e9SAndroid Build Coastguard Worker   }
548*da0073e9SAndroid Build Coastguard Worker   {
549*da0073e9SAndroid Build Coastguard Worker     auto indices = torch::tensor({{{1, 3, 4}}}, torch::kLong);
550*da0073e9SAndroid Build Coastguard Worker     auto x = torch::tensor({{{2, 4, 5}}}, torch::dtype(torch::kFloat));
551*da0073e9SAndroid Build Coastguard Worker     Sequential sequential(MaxUnpool1d(3));
552*da0073e9SAndroid Build Coastguard Worker     auto y = sequential->forward(x, indices);
553*da0073e9SAndroid Build Coastguard Worker     auto expected =
554*da0073e9SAndroid Build Coastguard Worker         torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat);
555*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(y, expected));
556*da0073e9SAndroid Build Coastguard Worker   }
557*da0073e9SAndroid Build Coastguard Worker   {
558*da0073e9SAndroid Build Coastguard Worker     auto indices = torch::tensor(
559*da0073e9SAndroid Build Coastguard Worker         {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
560*da0073e9SAndroid Build Coastguard Worker          {{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}},
561*da0073e9SAndroid Build Coastguard Worker         torch::kLong);
562*da0073e9SAndroid Build Coastguard Worker     auto x = torch::tensor(
563*da0073e9SAndroid Build Coastguard Worker         {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
564*da0073e9SAndroid Build Coastguard Worker          {{{31, 33, 34}, {41, 43, 44}, {46, 48, 49}}}},
565*da0073e9SAndroid Build Coastguard Worker         torch::dtype(torch::kFloat));
566*da0073e9SAndroid Build Coastguard Worker     Sequential sequential(
567*da0073e9SAndroid Build Coastguard Worker         MaxUnpool2d(MaxUnpool2dOptions(3).stride(2).padding(1)));
568*da0073e9SAndroid Build Coastguard Worker     auto y = sequential->forward(x, indices);
569*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::tensor(
570*da0073e9SAndroid Build Coastguard Worker         {{{{0, 0, 0, 0, 0},
571*da0073e9SAndroid Build Coastguard Worker            {0, 6, 0, 8, 9},
572*da0073e9SAndroid Build Coastguard Worker            {0, 0, 0, 0, 0},
573*da0073e9SAndroid Build Coastguard Worker            {0, 16, 0, 18, 19},
574*da0073e9SAndroid Build Coastguard Worker            {0, 21, 0, 23, 24}}},
575*da0073e9SAndroid Build Coastguard Worker          {{{0, 0, 0, 0, 0},
576*da0073e9SAndroid Build Coastguard Worker            {0, 31, 0, 33, 34},
577*da0073e9SAndroid Build Coastguard Worker            {0, 0, 0, 0, 0},
578*da0073e9SAndroid Build Coastguard Worker            {0, 41, 0, 43, 44},
579*da0073e9SAndroid Build Coastguard Worker            {0, 46, 0, 48, 49}}}},
580*da0073e9SAndroid Build Coastguard Worker         torch::kFloat);
581*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(y, expected));
582*da0073e9SAndroid Build Coastguard Worker   }
583*da0073e9SAndroid Build Coastguard Worker   {
584*da0073e9SAndroid Build Coastguard Worker     auto indices = torch::tensor({{{{{26}}}}}, torch::kLong);
585*da0073e9SAndroid Build Coastguard Worker     auto x = torch::tensor(
586*da0073e9SAndroid Build Coastguard Worker         {{{{{26}}}}}, torch::dtype(torch::kFloat).requires_grad(true));
587*da0073e9SAndroid Build Coastguard Worker     Sequential sequential(MaxUnpool3d(3));
588*da0073e9SAndroid Build Coastguard Worker     auto y = sequential->forward(x, indices);
589*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::tensor(
590*da0073e9SAndroid Build Coastguard Worker         {{{{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
591*da0073e9SAndroid Build Coastguard Worker            {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
592*da0073e9SAndroid Build Coastguard Worker            {{0, 0, 0}, {0, 0, 0}, {0, 0, 26}}}}},
593*da0073e9SAndroid Build Coastguard Worker         torch::kFloat);
594*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(y, expected));
595*da0073e9SAndroid Build Coastguard Worker   }
596*da0073e9SAndroid Build Coastguard Worker   {
597*da0073e9SAndroid Build Coastguard Worker     torch::manual_seed(0);
598*da0073e9SAndroid Build Coastguard Worker     Sequential sequential(Identity(), RNN(2, 3));
599*da0073e9SAndroid Build Coastguard Worker     auto x = torch::ones({2, 3, 2});
600*da0073e9SAndroid Build Coastguard Worker     auto rnn_output =
601*da0073e9SAndroid Build Coastguard Worker         sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(x);
602*da0073e9SAndroid Build Coastguard Worker     auto expected_output = torch::tensor(
603*da0073e9SAndroid Build Coastguard Worker         {{{-0.0645, -0.7274, 0.4531},
604*da0073e9SAndroid Build Coastguard Worker           {-0.0645, -0.7274, 0.4531},
605*da0073e9SAndroid Build Coastguard Worker           {-0.0645, -0.7274, 0.4531}},
606*da0073e9SAndroid Build Coastguard Worker          {{-0.3970, -0.6950, 0.6009},
607*da0073e9SAndroid Build Coastguard Worker           {-0.3970, -0.6950, 0.6009},
608*da0073e9SAndroid Build Coastguard Worker           {-0.3970, -0.6950, 0.6009}}});
609*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
610*da0073e9SAndroid Build Coastguard Worker         std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
611*da0073e9SAndroid Build Coastguard Worker   }
612*da0073e9SAndroid Build Coastguard Worker   {
613*da0073e9SAndroid Build Coastguard Worker     torch::manual_seed(0);
614*da0073e9SAndroid Build Coastguard Worker     Sequential sequential(Identity(), LSTM(2, 3));
615*da0073e9SAndroid Build Coastguard Worker     auto x = torch::ones({2, 3, 2});
616*da0073e9SAndroid Build Coastguard Worker     auto rnn_output = sequential->forward<
617*da0073e9SAndroid Build Coastguard Worker         std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>>>(x);
618*da0073e9SAndroid Build Coastguard Worker     auto expected_output = torch::tensor(
619*da0073e9SAndroid Build Coastguard Worker         {{{-0.2693, -0.1240, 0.0744},
620*da0073e9SAndroid Build Coastguard Worker           {-0.2693, -0.1240, 0.0744},
621*da0073e9SAndroid Build Coastguard Worker           {-0.2693, -0.1240, 0.0744}},
622*da0073e9SAndroid Build Coastguard Worker          {{-0.3889, -0.1919, 0.1183},
623*da0073e9SAndroid Build Coastguard Worker           {-0.3889, -0.1919, 0.1183},
624*da0073e9SAndroid Build Coastguard Worker           {-0.3889, -0.1919, 0.1183}}});
625*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
626*da0073e9SAndroid Build Coastguard Worker         std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
627*da0073e9SAndroid Build Coastguard Worker   }
628*da0073e9SAndroid Build Coastguard Worker   {
629*da0073e9SAndroid Build Coastguard Worker     torch::manual_seed(0);
630*da0073e9SAndroid Build Coastguard Worker     Sequential sequential(Identity(), GRU(2, 3));
631*da0073e9SAndroid Build Coastguard Worker     auto x = torch::ones({2, 3, 2});
632*da0073e9SAndroid Build Coastguard Worker     auto rnn_output =
633*da0073e9SAndroid Build Coastguard Worker         sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(x);
634*da0073e9SAndroid Build Coastguard Worker     auto expected_output = torch::tensor(
635*da0073e9SAndroid Build Coastguard Worker         {{{-0.1134, 0.0467, 0.2336},
636*da0073e9SAndroid Build Coastguard Worker           {-0.1134, 0.0467, 0.2336},
637*da0073e9SAndroid Build Coastguard Worker           {-0.1134, 0.0467, 0.2336}},
638*da0073e9SAndroid Build Coastguard Worker          {{-0.1189, 0.0502, 0.2960},
639*da0073e9SAndroid Build Coastguard Worker           {-0.1189, 0.0502, 0.2960},
640*da0073e9SAndroid Build Coastguard Worker           {-0.1189, 0.0502, 0.2960}}});
641*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
642*da0073e9SAndroid Build Coastguard Worker         std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
643*da0073e9SAndroid Build Coastguard Worker   }
644*da0073e9SAndroid Build Coastguard Worker   {
645*da0073e9SAndroid Build Coastguard Worker     torch::manual_seed(0);
646*da0073e9SAndroid Build Coastguard Worker     Sequential sequential(Identity(), RNNCell(2, 3));
647*da0073e9SAndroid Build Coastguard Worker     auto x = torch::ones({2, 2});
648*da0073e9SAndroid Build Coastguard Worker     auto rnn_output = sequential->forward<torch::Tensor>(x);
649*da0073e9SAndroid Build Coastguard Worker     auto expected_output =
650*da0073e9SAndroid Build Coastguard Worker         torch::tensor({{-0.0645, -0.7274, 0.4531}, {-0.0645, -0.7274, 0.4531}});
651*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(rnn_output, expected_output, 1e-05, 2e-04));
652*da0073e9SAndroid Build Coastguard Worker   }
653*da0073e9SAndroid Build Coastguard Worker   {
654*da0073e9SAndroid Build Coastguard Worker     torch::manual_seed(0);
655*da0073e9SAndroid Build Coastguard Worker     Sequential sequential(Identity(), LSTMCell(2, 3));
656*da0073e9SAndroid Build Coastguard Worker     auto x = torch::ones({2, 2});
657*da0073e9SAndroid Build Coastguard Worker     auto rnn_output =
658*da0073e9SAndroid Build Coastguard Worker         sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(x);
659*da0073e9SAndroid Build Coastguard Worker     auto expected_output =
660*da0073e9SAndroid Build Coastguard Worker         torch::tensor({{-0.2693, -0.1240, 0.0744}, {-0.2693, -0.1240, 0.0744}});
661*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(
662*da0073e9SAndroid Build Coastguard Worker         std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
663*da0073e9SAndroid Build Coastguard Worker   }
664*da0073e9SAndroid Build Coastguard Worker   {
665*da0073e9SAndroid Build Coastguard Worker     torch::manual_seed(0);
666*da0073e9SAndroid Build Coastguard Worker     Sequential sequential(Identity(), GRUCell(2, 3));
667*da0073e9SAndroid Build Coastguard Worker     auto x = torch::ones({2, 2});
668*da0073e9SAndroid Build Coastguard Worker     auto rnn_output = sequential->forward<torch::Tensor>(x);
669*da0073e9SAndroid Build Coastguard Worker     auto expected_output =
670*da0073e9SAndroid Build Coastguard Worker         torch::tensor({{-0.1134, 0.0467, 0.2336}, {-0.1134, 0.0467, 0.2336}});
671*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::allclose(rnn_output, expected_output, 1e-05, 2e-04));
672*da0073e9SAndroid Build Coastguard Worker   }
673*da0073e9SAndroid Build Coastguard Worker }
674