xref: /aosp_15_r20/external/pytorch/test/cpp/api/parallel.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <c10/util/irange.h>
4 #include <torch/csrc/autograd/functions/comm.h>
5 #include <torch/nn/module.h>
6 #include <torch/nn/modules/conv.h>
7 #include <torch/nn/modules/linear.h>
8 #include <torch/nn/parallel/data_parallel.h>
9 #include <torch/nn/pimpl.h>
10 #include <torch/optim/sgd.h>
11 #include <torch/types.h>
12 #include <torch/utils.h>
13 
14 #include <test/cpp/api/support.h>
15 
16 #include <iostream>
17 #include <memory>
18 #include <utility>
19 #include <vector>
20 
21 using namespace torch::autograd;
22 using namespace torch::nn;
23 
24 struct ParallelTest : torch::test::SeedingFixture {};
25 
// Scatters a 10-element CPU tensor over two CUDA devices and checks that the
// pieces are correct and that gradients flow back to the CPU input.
TEST_F(ParallelTest, DifferentiableScatter_MultiCUDA) {
  const torch::Device cuda0(torch::kCUDA, 0);
  const torch::Device cuda1(torch::kCUDA, 1);
  Scatter scatter({cuda0, cuda1});

  auto input = torch::ones(10, torch::requires_grad(true));
  auto pieces = scatter.apply({input});

  // Ten elements over two devices: five elements per piece.
  ASSERT_EQ(pieces.size(), 2);
  ASSERT_EQ(pieces[0].size(0), 5);
  ASSERT_EQ(pieces[1].size(0), 5);

  // Re-assembling the pieces on the CPU must reproduce the input.
  auto first_on_cpu = pieces[0].to(torch::kCPU);
  auto second_on_cpu = pieces[1].to(torch::kCPU);
  ASSERT_TRUE(torch::cat({first_on_cpu, second_on_cpu}).allclose(input));

  // Combine both pieces on the second device so that the backward pass has to
  // go through the scatter to reach `input`.
  torch::Tensor sum = pieces[0].to(cuda1) + pieces[1];
  sum.backward(torch::ones_like(sum));

  const auto& grad = input.grad();
  ASSERT_TRUE(grad.defined());
  ASSERT_TRUE(grad.device().is_cpu());
  ASSERT_EQ(grad.sum().item<int32_t>(), 10);
}
47 
// Gathers two 5-element tensors living on different CUDA devices onto device 1
// and checks the gathered values as well as the gradients of both sources.
TEST_F(ParallelTest, DifferentiableGather_MultiCUDA) {
  const torch::Device cuda0(torch::kCUDA, 0);
  const torch::Device cuda1(torch::kCUDA, 1);
  Gather gather(cuda1);

  auto a = torch::ones(5, torch::requires_grad(true).device(cuda0));
  auto b = torch::ones(5, torch::requires_grad(true).device(cuda1));

  auto gathered = gather.apply({a, b});
  ASSERT_EQ(gathered.size(), 1);
  torch::Tensor output = gathered.front();

  ASSERT_EQ(output.size(0), 10);
  ASSERT_EQ(output.device(), cuda1);

  // The first half corresponds to `a` (compared on a's device), the second
  // half to `b`, which already lives on the destination device.
  auto halves = output.chunk(2);
  ASSERT_TRUE(halves[0].to(cuda0).allclose(a));
  ASSERT_TRUE(halves[1].allclose(b));

  output.backward(torch::ones_like(output));

  // Each source must receive its gradient on its own device.
  const auto& grad_a = a.grad();
  ASSERT_TRUE(grad_a.defined());
  ASSERT_EQ(grad_a.device(), cuda0);
  ASSERT_EQ(grad_a.sum().item<int32_t>(), 5);

  const auto& grad_b = b.grad();
  ASSERT_TRUE(grad_b.defined());
  ASSERT_EQ(grad_b.device(), cuda1);
  ASSERT_EQ(grad_b.sum().item<int32_t>(), 5);
}
75 
// Replicates a Linear module onto two CUDA devices and checks that each
// replica lives on its device, matches the original numerically, and does not
// share storage with the original.
TEST_F(ParallelTest, Replicate_MultiCUDA) {
  const torch::Device cuda0(torch::kCUDA, 0);
  const torch::Device cuda1(torch::kCUDA, 1);

  Linear linear(3, 4);
  auto replicas = parallel::replicate(linear, {cuda0, cuda1});
  ASSERT_EQ(replicas.size(), 2);

  auto reference = linear->parameters();
  const std::size_t parameter_count = reference.size();

  // --- first replica ---
  auto first = replicas[0]->parameters();
  for (const auto& p : first) {
    ASSERT_EQ(p.device(), cuda0);
  }
  // Move the replica to the CPU before comparing with the CPU original.
  replicas[0]->to(torch::kCPU);
  ASSERT_EQ(first.size(), reference.size());
  for (std::size_t k = 0; k < parameter_count; ++k) {
    ASSERT_TRUE(first[k].allclose(reference[k]));
    ASSERT_TRUE(first[k].data_ptr<float>() != reference[k].data_ptr<float>());
  }

  // --- second replica ---
  auto second = replicas[1]->parameters();
  for (const auto& p : second) {
    ASSERT_EQ(p.device(), cuda1);
  }
  replicas[1]->to(torch::kCPU);
  ASSERT_EQ(second.size(), reference.size());
  for (std::size_t k = 0; k < parameter_count; ++k) {
    ASSERT_TRUE(second[k].allclose(reference[k]));
    ASSERT_TRUE(second[k].data_ptr<float>() != reference[k].data_ptr<float>());
  }
}
110 
// Runs three clones of the same Linear module (CPU, cuda:0, cuda:1) through
// parallel_apply and checks that every output stays on its module's device and
// agrees with the CPU result.
TEST_F(ParallelTest, ParallelApply_MultiCUDA) {
  const torch::Device cuda0(torch::kCUDA, 0);
  const torch::Device cuda1(torch::kCUDA, 1);

  Linear on_cpu(3, 4);

  Linear on_cuda0(std::dynamic_pointer_cast<LinearImpl>(on_cpu->clone()));
  on_cuda0->to(cuda0);

  Linear on_cuda1(std::dynamic_pointer_cast<LinearImpl>(on_cpu->clone()));
  on_cuda1->to(cuda1);

  std::vector<Linear> modules = {on_cpu, on_cuda0, on_cuda1};
  std::vector<torch::Tensor> inputs = {
      torch::ones({2, 3}),
      torch::ones({2, 3}, torch::device(cuda0)),
      torch::ones({2, 3}, torch::device(cuda1))};

  auto results = parallel::parallel_apply(modules, inputs);
  ASSERT_EQ(results.size(), 3);

  const auto& cpu_result = results[0];
  ASSERT_TRUE(cpu_result.device().is_cpu());

  ASSERT_EQ(results[1].device(), cuda0);
  ASSERT_TRUE(results[1].to(torch::kCPU).allclose(cpu_result));

  ASSERT_EQ(results[2].device(), cuda1);
  ASSERT_TRUE(results[2].to(torch::kCPU).allclose(cpu_result));
}
137 
// Passes an explicit device list to parallel_apply and checks that every
// output ends up on the device listed for it, even though the module itself
// always produces a CPU tensor.
TEST_F(ParallelTest, ParallelApplyWithDifferentOutputDevice_MultiCUDA) {
  // Module whose forward ignores its input and returns a fixed int32 tensor.
  struct M : torch::nn::Module {
    torch::Tensor forward(torch::Tensor /*input*/) {
      return torch::ones(5, torch::kInt32);
    }
  };

  std::vector<std::shared_ptr<M>> modules;
  std::vector<torch::Tensor> inputs;
  for (int k = 0; k < 3; ++k) {
    modules.push_back(std::make_shared<M>());
    inputs.push_back(torch::empty({}));
  }
  std::vector<torch::Device> devices = {
      torch::Device(torch::kCUDA, 1),
      torch::Device(torch::kCUDA, 0),
      torch::Device(torch::kCPU)};

  auto results = parallel::parallel_apply(modules, inputs, devices);
  ASSERT_EQ(results.size(), 3);

  const auto first_device = results[0].device();
  ASSERT_TRUE(first_device.is_cuda());
  ASSERT_EQ(first_device, torch::Device(torch::kCUDA, 1));

  const auto second_device = results[1].device();
  ASSERT_TRUE(second_device.is_cuda());
  ASSERT_EQ(second_device, torch::Device(torch::kCUDA, 0));

  ASSERT_TRUE(results[2].device().is_cpu());
}
163 
// An exception thrown inside a module's forward must surface from
// data_parallel with its original message.
TEST_F(ParallelTest, ParallelApplyRethrowsException_MultiCUDA) {
  // Module whose forward always throws.
  struct M : torch::nn::Cloneable<M> {
    void reset() override {}
    torch::Tensor forward(torch::Tensor /*input*/) {
      throw std::runtime_error("Badness!");
    }
  };

  auto failing_module = std::make_shared<M>();
  auto batch = torch::ones({10, 3});
  ASSERT_THROWS_WITH(
      parallel::data_parallel(failing_module, batch), "Badness!");
}
176 
// data_parallel must return its result on the requested output device, both
// when no device list is given and when a single device is listed.
TEST_F(
    ParallelTest,
    DataParallelPlacesTheOutputOnTheRequestedDevice_MultiCUDA) {
  struct M : torch::nn::Cloneable<M> {
    void reset() override {}
    torch::Tensor forward(torch::Tensor /*input*/) {
      // Created on the CPU here; the result handed back by data_parallel is
      // expected to be on the output device.
      return torch::ones(3);
    }
  };

  const torch::Device output_device(torch::kCUDA, 1);
  auto m = std::make_shared<M>();
  auto input = torch::ones({10, 3});

  {
    // No explicit device list.
    auto result = parallel::data_parallel(
        m, input, /*devices=*/torch::nullopt, /*output_device=*/output_device);
    ASSERT_TRUE(result.defined());
    ASSERT_TRUE(result.device().is_cuda());
    ASSERT_EQ(result.device().index(), 1);
  }

  {
    // Single-device case (where we don't scatter/gather).
    std::vector<torch::Device> only_cuda0{torch::Device(torch::kCUDA, 0)};
    auto result = parallel::data_parallel(
        m, input, /*devices=*/only_cuda0, /*output_device=*/output_device);
    ASSERT_TRUE(result.defined());
    ASSERT_TRUE(result.device().is_cuda());
    ASSERT_EQ(result.device().index(), 1);
  }
}
211 
// Each replica reports the index of the device its input lives on; the
// gathered output must therefore list every CUDA device index in order.
TEST_F(ParallelTest, DataParallelUsesAllAvailableCUDADevices_CUDA) {
  struct M : torch::nn::Cloneable<M> {
    void reset() override {}
    torch::Tensor forward(torch::Tensor input) {
      return torch::tensor({input.device().index()});
    }
  };

  auto m = std::make_shared<M>();
  const auto device_count = torch::cuda::device_count();

  // At least two rows per device, and never fewer than ten rows overall.
  const int rows = std::max(10, static_cast<int>(2 * device_count));
  auto input = torch::ones({rows, 3});
  auto output = parallel::data_parallel(m, input);

  ASSERT_EQ(output.numel(), device_count);
  for (std::size_t device = 0; device < device_count; ++device) {
    ASSERT_EQ(output[device].item<int32_t>(), device);
  }
}
230 
TEST_F(ParallelTest,DataParallelNumericalEquivalence_MultiCUDA)231 TEST_F(ParallelTest, DataParallelNumericalEquivalence_MultiCUDA) {
232   struct M : torch::nn::Cloneable<M> {
233     M() {
234       reset();
235     }
236 
237     void reset() override {
238       conv = register_module(
239           "conv",
240           torch::nn::Conv2d(torch::nn::Conv2dOptions(2, 2, /*kernel_size=*/2)));
241       fc = register_module("fc", torch::nn::Linear(8, 2));
242     }
243 
244     torch::Tensor forward(torch::Tensor x) {
245       x = conv->forward(x);
246       x = torch::relu(x);
247       x = x.view({-1, 8});
248       x = fc->forward(x);
249       return torch::log_softmax(x, /*dim=*/1);
250     }
251 
252     torch::nn::Conv2d conv{nullptr};
253     torch::nn::Linear fc{nullptr};
254   };
255 
256   // prepare modules and inputs
257   auto input = torch::ones({16, 2, 3, 3});
258   auto input_dp = torch::ones({16, 2, 3, 3});
259   auto model = std::make_shared<M>();
260   auto model_dp = std::dynamic_pointer_cast<M>(model->clone());
261 
262   // run 3 training iterations
263   for (const auto i : c10::irange(3)) {
264     input += i;
265     input_dp += i;
266 
267     // non-prallel training
268     torch::optim::SGD optim(model->parameters(), torch::optim::SGDOptions(0.1));
269     auto output = model->forward(input);
270     auto loss = torch::mse_loss(output, torch::zeros_like(output));
271     loss.backward();
272     optim.step();
273 
274     // data-parallel training
275     torch::optim::SGD optim_dp(
276         model_dp->parameters(), torch::optim::SGDOptions(0.1));
277     auto output_dp = parallel::data_parallel(model_dp, input_dp);
278     auto loss_dp = torch::mse_loss(output_dp, torch::zeros_like(output_dp));
279     loss_dp.backward();
280     optim_dp.step();
281 
282     // make sure that weights are the same
283     model->to(torch::kCPU);
284     model_dp->to(torch::kCPU);
285     auto params = model->parameters();
286     auto params_dp = model_dp->parameters();
287     ASSERT_EQ(params.size(), params_dp.size());
288     for (auto it = params.begin(), it_dp = params_dp.begin();
289          it != params.end() && it_dp != params.end();
290          ++it, ++it_dp) {
291       ASSERT_TRUE(torch::allclose(*it, *it_dp));
292     }
293   }
294 }
295