xref: /aosp_15_r20/external/pytorch/test/cpp/api/parallel.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h>
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
4*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/autograd/functions/comm.h>
5*da0073e9SAndroid Build Coastguard Worker #include <torch/nn/module.h>
6*da0073e9SAndroid Build Coastguard Worker #include <torch/nn/modules/conv.h>
7*da0073e9SAndroid Build Coastguard Worker #include <torch/nn/modules/linear.h>
8*da0073e9SAndroid Build Coastguard Worker #include <torch/nn/parallel/data_parallel.h>
9*da0073e9SAndroid Build Coastguard Worker #include <torch/nn/pimpl.h>
10*da0073e9SAndroid Build Coastguard Worker #include <torch/optim/sgd.h>
11*da0073e9SAndroid Build Coastguard Worker #include <torch/types.h>
12*da0073e9SAndroid Build Coastguard Worker #include <torch/utils.h>
13*da0073e9SAndroid Build Coastguard Worker 
14*da0073e9SAndroid Build Coastguard Worker #include <test/cpp/api/support.h>
15*da0073e9SAndroid Build Coastguard Worker 
16*da0073e9SAndroid Build Coastguard Worker #include <iostream>
17*da0073e9SAndroid Build Coastguard Worker #include <memory>
18*da0073e9SAndroid Build Coastguard Worker #include <utility>
19*da0073e9SAndroid Build Coastguard Worker #include <vector>
20*da0073e9SAndroid Build Coastguard Worker 
21*da0073e9SAndroid Build Coastguard Worker using namespace torch::autograd;
22*da0073e9SAndroid Build Coastguard Worker using namespace torch::nn;
23*da0073e9SAndroid Build Coastguard Worker 
24*da0073e9SAndroid Build Coastguard Worker struct ParallelTest : torch::test::SeedingFixture {};
25*da0073e9SAndroid Build Coastguard Worker 
// Verifies that Scatter splits a CPU tensor evenly across two CUDA devices
// and that the scatter is differentiable: gradients flow back to the
// original CPU input.
TEST_F(ParallelTest, DifferentiableScatter_MultiCUDA) {
  Scatter scatter(
      {torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)});

  auto input = torch::ones(10, torch::requires_grad(true));
  auto output = scatter.apply({input});

  // The 10-element input should be split into two 5-element chunks,
  // one per device.
  ASSERT_EQ(output.size(), 2);
  ASSERT_EQ(output[0].size(0), 5);
  ASSERT_EQ(output[1].size(0), 5);

  // Concatenating the chunks (moved back to CPU) must reproduce the input.
  ASSERT_TRUE(torch::cat({output[0].to(torch::kCPU), output[1].to(torch::kCPU)})
                  .allclose(input));

  torch::Tensor sum = output[0].to({torch::kCUDA, 1}) + output[1];
  sum.backward(torch::ones_like(sum));

  // The gradient must land on the original (CPU) input; a gradient of ones
  // over 10 elements sums to 10.
  ASSERT_TRUE(input.grad().defined());
  ASSERT_TRUE(input.grad().device().is_cpu());
  ASSERT_EQ(input.grad().sum().item<int32_t>(), 10);
}
47*da0073e9SAndroid Build Coastguard Worker 
// Verifies that Gather concatenates tensors from two CUDA devices onto the
// target device and that gradients flow back to each source tensor on its
// original device.
TEST_F(ParallelTest, DifferentiableGather_MultiCUDA) {
  Gather gather(torch::Device(torch::kCUDA, 1));

  auto a = torch::ones(5, torch::requires_grad(true).device(torch::kCUDA, 0));
  auto b = torch::ones(5, torch::requires_grad(true).device(torch::kCUDA, 1));

  auto outputs = gather.apply({a, b});
  ASSERT_EQ(outputs.size(), 1);
  torch::Tensor output = outputs.front();

  // Two 5-element inputs are gathered into one 10-element tensor on cuda:1.
  ASSERT_EQ(output.size(0), 10);
  ASSERT_EQ(output.device(), torch::Device(torch::kCUDA, 1));

  // Each half of the gathered tensor must match the corresponding input.
  auto chunks = output.chunk(2);
  ASSERT_TRUE(chunks[0].to({torch::kCUDA, 0}).allclose(a));
  ASSERT_TRUE(chunks[1].allclose(b));

  output.backward(torch::ones_like(output));

  // Gradients must arrive on the device each input lives on.
  ASSERT_TRUE(a.grad().defined());
  ASSERT_EQ(a.grad().device(), torch::Device(torch::kCUDA, 0));
  ASSERT_EQ(a.grad().sum().item<int32_t>(), 5);

  ASSERT_TRUE(b.grad().defined());
  ASSERT_EQ(b.grad().device(), torch::Device(torch::kCUDA, 1));
  ASSERT_EQ(b.grad().sum().item<int32_t>(), 5);
}
75*da0073e9SAndroid Build Coastguard Worker 
// Verifies that parallel::replicate produces one module copy per device,
// with parameters that are numerically equal to the original's but backed
// by distinct storage (i.e. true copies, not views).
TEST_F(ParallelTest, Replicate_MultiCUDA) {
  Linear linear(3, 4);
  auto replicas = parallel::replicate(
      linear, {torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)});
  ASSERT_EQ(replicas.size(), 2);

  auto original_parameters = linear->parameters();

  // First replica: parameters must live on cuda:0.
  auto replica1_parameters = replicas[0]->parameters();
  for (auto& parameter : replica1_parameters) {
    ASSERT_EQ(parameter.device(), torch::Device(torch::kCUDA, 0));
  }
  // Move to CPU so values can be compared (and data_ptr<float> accessed)
  // against the original CPU module.
  replicas[0]->to(torch::kCPU);
  ASSERT_EQ(replica1_parameters.size(), original_parameters.size());
  for (const auto i : c10::irange(original_parameters.size())) {
    ASSERT_TRUE(replica1_parameters[i].allclose(original_parameters[i]));
    // Equal values but distinct storage: replication must deep-copy.
    ASSERT_TRUE(
        replica1_parameters[i].data_ptr<float>() !=
        original_parameters[i].data_ptr<float>());
  }

  // Second replica: parameters must live on cuda:1.
  auto replica2_parameters = replicas[1]->parameters();
  for (auto& parameter : replica2_parameters) {
    ASSERT_EQ(parameter.device(), torch::Device(torch::kCUDA, 1));
  }
  replicas[1]->to(torch::kCPU);
  ASSERT_EQ(replica2_parameters.size(), original_parameters.size());
  for (const auto i : c10::irange(original_parameters.size())) {
    ASSERT_TRUE(replica2_parameters[i].allclose(original_parameters[i]));
    ASSERT_TRUE(
        replica2_parameters[i].data_ptr<float>() !=
        original_parameters[i].data_ptr<float>());
  }
}
110*da0073e9SAndroid Build Coastguard Worker 
// Verifies that parallel::parallel_apply runs each module on its own input,
// keeps each output on the module's device, and produces numerically
// identical results for identical (cloned) modules and inputs.
TEST_F(ParallelTest, ParallelApply_MultiCUDA) {
  Linear a(3, 4);

  // b and c are deep clones of a, moved to cuda:0 and cuda:1 respectively,
  // so all three modules share the same weights.
  Linear b(std::dynamic_pointer_cast<LinearImpl>(a->clone()));
  b->to({torch::kCUDA, 0});

  Linear c(std::dynamic_pointer_cast<LinearImpl>(a->clone()));
  c->to({torch::kCUDA, 1});

  std::vector<Linear> modules = {a, b, c};
  std::vector<torch::Tensor> inputs = {
      torch::ones({2, 3}),
      torch::ones({2, 3}, torch::device({torch::kCUDA, 0})),
      torch::ones({2, 3}, torch::device({torch::kCUDA, 1}))};

  auto outputs = parallel::parallel_apply(modules, inputs);

  ASSERT_EQ(outputs.size(), 3);
  ASSERT_TRUE(outputs[0].device().is_cpu());

  // Outputs stay on their module's device and match the CPU reference.
  ASSERT_EQ(outputs[1].device(), torch::Device(torch::kCUDA, 0));
  ASSERT_TRUE(outputs[1].to(torch::kCPU).allclose(outputs[0]));

  ASSERT_EQ(outputs[2].device(), torch::Device(torch::kCUDA, 1));
  ASSERT_TRUE(outputs[2].to(torch::kCPU).allclose(outputs[0]));
}
137*da0073e9SAndroid Build Coastguard Worker 
// Verifies that the explicit `devices` argument of parallel_apply controls
// where each output ends up, even though the module itself always produces
// a CPU tensor.
TEST_F(ParallelTest, ParallelApplyWithDifferentOutputDevice_MultiCUDA) {
  struct M : torch::nn::Module {
    // Ignores its input and returns a fresh CPU tensor; parallel_apply is
    // responsible for moving it to the requested device.
    torch::Tensor forward(torch::Tensor /*input*/) {
      return torch::ones(5, torch::kInt32);
    }
  };

  std::vector<std::shared_ptr<M>> modules = {
      std::make_shared<M>(), std::make_shared<M>(), std::make_shared<M>()};
  std::vector<torch::Tensor> inputs = {
      torch::empty({}), torch::empty({}), torch::empty({})};
  std::vector<torch::Device> devices = {
      {torch::kCUDA, 1}, {torch::kCUDA, 0}, {torch::kCPU}};

  auto outputs = parallel::parallel_apply(modules, inputs, devices);

  // Each output must land on the device requested for its slot.
  ASSERT_EQ(outputs.size(), 3);
  ASSERT_TRUE(outputs[0].device().is_cuda());
  ASSERT_EQ(outputs[0].device(), torch::Device(torch::kCUDA, 1));

  ASSERT_TRUE(outputs[1].device().is_cuda());
  ASSERT_EQ(outputs[1].device(), torch::Device(torch::kCUDA, 0));

  ASSERT_TRUE(outputs[2].device().is_cpu());
}
163*da0073e9SAndroid Build Coastguard Worker 
// Verifies that an exception thrown inside a worker thread during
// data_parallel is captured and rethrown on the calling thread with its
// original message.
TEST_F(ParallelTest, ParallelApplyRethrowsException_MultiCUDA) {
  struct M : torch::nn::Cloneable<M> {
    void reset() override {}
    // Always throws; the error must propagate out of data_parallel.
    torch::Tensor forward(torch::Tensor /*input*/) {
      throw std::runtime_error("Badness!");
    }
  };

  auto m = std::make_shared<M>();
  auto input = torch::ones({10, 3});
  ASSERT_THROWS_WITH(parallel::data_parallel(m, input), "Badness!");
}
176*da0073e9SAndroid Build Coastguard Worker 
// Verifies that data_parallel honors the `output_device` argument, both in
// the multi-device (scatter/gather) path and in the single-device fast path
// that skips scatter/gather entirely.
TEST_F(
    ParallelTest,
    DataParallelPlacesTheOutputOnTheRequestedDevice_MultiCUDA) {
  struct M : torch::nn::Cloneable<M> {
    void reset() override {}
    torch::Tensor forward(torch::Tensor /*input*/) {
      // The returned tensor should be on the output device.
      return torch::ones(3);
    }
  };
  auto m = std::make_shared<M>();
  auto input = torch::ones({10, 3});
  {
    // Multi-device path: devices defaulted (all available), output on cuda:1.
    auto output = parallel::data_parallel(
        m,
        input,
        /*devices=*/torch::nullopt,
        /*output_device=*/torch::Device(torch::kCUDA, 1));
    ASSERT_TRUE(output.defined());
    ASSERT_TRUE(output.device().is_cuda());
    ASSERT_EQ(output.device().index(), 1);
  }
  {
    // Verify for the single-device case (where we don't scatter/gather).
    auto output = parallel::data_parallel(
        m,
        input,
        /*devices=*/std::vector<torch::Device>{torch::Device(torch::kCUDA, 0)},
        /*output_device=*/torch::Device(torch::kCUDA, 1));
    ASSERT_TRUE(output.defined());
    ASSERT_TRUE(output.device().is_cuda());
    ASSERT_EQ(output.device().index(), 1);
  }
}
211*da0073e9SAndroid Build Coastguard Worker 
// Verifies that data_parallel with default devices spreads the work across
// every available CUDA device: the module reports the device index it ran
// on, and the gathered output must contain each index exactly once, in order.
TEST_F(ParallelTest, DataParallelUsesAllAvailableCUDADevices_CUDA) {
  struct M : torch::nn::Cloneable<M> {
    void reset() override {}
    // Returns a 1-element tensor holding the device index this replica
    // received its input on.
    torch::Tensor forward(torch::Tensor input) {
      return torch::tensor({input.device().index()});
    }
  };

  auto m = std::make_shared<M>();
  const auto device_count = torch::cuda::device_count();
  // Make the batch large enough that every device gets at least one chunk.
  auto input = torch::ones({std::max(10, int(2 * device_count)), 3});
  auto output = parallel::data_parallel(m, input);

  ASSERT_EQ(output.numel(), device_count);
  for (const auto i : c10::irange(device_count)) {
    ASSERT_EQ(output[i].item<int32_t>(), i);
  }
}
230*da0073e9SAndroid Build Coastguard Worker 
TEST_F(ParallelTest,DataParallelNumericalEquivalence_MultiCUDA)231*da0073e9SAndroid Build Coastguard Worker TEST_F(ParallelTest, DataParallelNumericalEquivalence_MultiCUDA) {
232*da0073e9SAndroid Build Coastguard Worker   struct M : torch::nn::Cloneable<M> {
233*da0073e9SAndroid Build Coastguard Worker     M() {
234*da0073e9SAndroid Build Coastguard Worker       reset();
235*da0073e9SAndroid Build Coastguard Worker     }
236*da0073e9SAndroid Build Coastguard Worker 
237*da0073e9SAndroid Build Coastguard Worker     void reset() override {
238*da0073e9SAndroid Build Coastguard Worker       conv = register_module(
239*da0073e9SAndroid Build Coastguard Worker           "conv",
240*da0073e9SAndroid Build Coastguard Worker           torch::nn::Conv2d(torch::nn::Conv2dOptions(2, 2, /*kernel_size=*/2)));
241*da0073e9SAndroid Build Coastguard Worker       fc = register_module("fc", torch::nn::Linear(8, 2));
242*da0073e9SAndroid Build Coastguard Worker     }
243*da0073e9SAndroid Build Coastguard Worker 
244*da0073e9SAndroid Build Coastguard Worker     torch::Tensor forward(torch::Tensor x) {
245*da0073e9SAndroid Build Coastguard Worker       x = conv->forward(x);
246*da0073e9SAndroid Build Coastguard Worker       x = torch::relu(x);
247*da0073e9SAndroid Build Coastguard Worker       x = x.view({-1, 8});
248*da0073e9SAndroid Build Coastguard Worker       x = fc->forward(x);
249*da0073e9SAndroid Build Coastguard Worker       return torch::log_softmax(x, /*dim=*/1);
250*da0073e9SAndroid Build Coastguard Worker     }
251*da0073e9SAndroid Build Coastguard Worker 
252*da0073e9SAndroid Build Coastguard Worker     torch::nn::Conv2d conv{nullptr};
253*da0073e9SAndroid Build Coastguard Worker     torch::nn::Linear fc{nullptr};
254*da0073e9SAndroid Build Coastguard Worker   };
255*da0073e9SAndroid Build Coastguard Worker 
256*da0073e9SAndroid Build Coastguard Worker   // prepare modules and inputs
257*da0073e9SAndroid Build Coastguard Worker   auto input = torch::ones({16, 2, 3, 3});
258*da0073e9SAndroid Build Coastguard Worker   auto input_dp = torch::ones({16, 2, 3, 3});
259*da0073e9SAndroid Build Coastguard Worker   auto model = std::make_shared<M>();
260*da0073e9SAndroid Build Coastguard Worker   auto model_dp = std::dynamic_pointer_cast<M>(model->clone());
261*da0073e9SAndroid Build Coastguard Worker 
262*da0073e9SAndroid Build Coastguard Worker   // run 3 training iterations
263*da0073e9SAndroid Build Coastguard Worker   for (const auto i : c10::irange(3)) {
264*da0073e9SAndroid Build Coastguard Worker     input += i;
265*da0073e9SAndroid Build Coastguard Worker     input_dp += i;
266*da0073e9SAndroid Build Coastguard Worker 
267*da0073e9SAndroid Build Coastguard Worker     // non-prallel training
268*da0073e9SAndroid Build Coastguard Worker     torch::optim::SGD optim(model->parameters(), torch::optim::SGDOptions(0.1));
269*da0073e9SAndroid Build Coastguard Worker     auto output = model->forward(input);
270*da0073e9SAndroid Build Coastguard Worker     auto loss = torch::mse_loss(output, torch::zeros_like(output));
271*da0073e9SAndroid Build Coastguard Worker     loss.backward();
272*da0073e9SAndroid Build Coastguard Worker     optim.step();
273*da0073e9SAndroid Build Coastguard Worker 
274*da0073e9SAndroid Build Coastguard Worker     // data-parallel training
275*da0073e9SAndroid Build Coastguard Worker     torch::optim::SGD optim_dp(
276*da0073e9SAndroid Build Coastguard Worker         model_dp->parameters(), torch::optim::SGDOptions(0.1));
277*da0073e9SAndroid Build Coastguard Worker     auto output_dp = parallel::data_parallel(model_dp, input_dp);
278*da0073e9SAndroid Build Coastguard Worker     auto loss_dp = torch::mse_loss(output_dp, torch::zeros_like(output_dp));
279*da0073e9SAndroid Build Coastguard Worker     loss_dp.backward();
280*da0073e9SAndroid Build Coastguard Worker     optim_dp.step();
281*da0073e9SAndroid Build Coastguard Worker 
282*da0073e9SAndroid Build Coastguard Worker     // make sure that weights are the same
283*da0073e9SAndroid Build Coastguard Worker     model->to(torch::kCPU);
284*da0073e9SAndroid Build Coastguard Worker     model_dp->to(torch::kCPU);
285*da0073e9SAndroid Build Coastguard Worker     auto params = model->parameters();
286*da0073e9SAndroid Build Coastguard Worker     auto params_dp = model_dp->parameters();
287*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(params.size(), params_dp.size());
288*da0073e9SAndroid Build Coastguard Worker     for (auto it = params.begin(), it_dp = params_dp.begin();
289*da0073e9SAndroid Build Coastguard Worker          it != params.end() && it_dp != params.end();
290*da0073e9SAndroid Build Coastguard Worker          ++it, ++it_dp) {
291*da0073e9SAndroid Build Coastguard Worker       ASSERT_TRUE(torch::allclose(*it, *it_dp));
292*da0073e9SAndroid Build Coastguard Worker     }
293*da0073e9SAndroid Build Coastguard Worker   }
294*da0073e9SAndroid Build Coastguard Worker }
295