#pragma once

#include <torch/cuda.h>
#include <torch/nn/module.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <ATen/core/functional.h>
#include <torch/csrc/autograd/functions/comm.h>
#include <torch/csrc/autograd/functions/utils.h>

#include <ATen/Device.h>
#include <ATen/Parallel.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#include <cstddef>
#include <exception>
#include <memory>
#include <mutex>
#include <vector>

namespace torch {
namespace nn {

namespace {

// Note [Replicating Modules]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
//
// Module replication is implemented in the following two steps:
// 1) create a module replica on each destination device using Module.clone().
// 2) manually add a gradient edge pointing from every parameter X in every
//    module replica to the same parameter X in the original module, using
//    ReduceAdd as the grad_fn.
//
// ReduceAdd can ONLY be used during the backward pass of data parallel. The
// forward pass cannot use this function, as it does not set up the gradient
// function and history at all. Do NOT try to use ReduceAdd for any other
// purpose.
//
// NB: An alternative is to add Broadcast and ReduceAddCoalesce to
// torch/csrc/autograd/functions/comm.cpp as normal autograd functions,
// implement a Replicatable class (similar to Cloneable) and add it as a friend
// class in Module.h. In the forward pass, the Replicatable could use the
// Broadcast function to replicate every module parameter and set gradient
// functions using ReduceAddCoalesce (similar to how it is implemented in
// Python). However, unlike in Python, where changes to
// Linear._parameters["weight"] also apply to Linear.weight (using Linear as an
// example), in C++ Linear.weight and Linear.parameters_["weight"] are two
// distinct tensor objects pointing to the same TensorImpl. Assigning a new
// tensor to Linear.parameters_["weight"] will not change Linear.weight. To
// make this work, we would have to:
// 1) force every module to also inherit from Replicatable
// 2) force every module to implement an additional function, e.g.,
//    Replicatable::load_params(), to pick up changes from parameters_ to their
//    own member fields.
// This would be overkill, as Replicatable would only be used in data_parallel,
// not even in DDP.
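//
// A minimal conceptual sketch of the two steps above (hypothetical names
// `MyModule`, `device`, `original_param`, and `replica_param`; the real logic
// lives in replicate() and replicate_grad_edges() below):
//
//   // 1) clone the module onto a destination device
//   auto replica =
//       std::dynamic_pointer_cast<MyModule>(module->clone(device));
//   // 2) point each replica parameter's grad_fn back at the original
//   //    parameter, using ReduceAdd (defined below) as the grad_fn
//   auto grad_fn = std::make_shared<ReduceAdd>(original_param.device());
//   grad_fn->set_next_edges(autograd::collect_next_edges(original_param));
//   autograd::set_history(replica_param, grad_fn);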

// Autograd function for the replicate step in data parallel. This is only used
// in data parallel, and should not be exposed as a user API.
struct ReduceAdd : public autograd::Node {
  explicit ReduceAdd(const at::Device& destination_device)
      : destination_device_(destination_device){};
  ~ReduceAdd() override {}

  autograd::variable_list apply(autograd::variable_list&& inputs) override {
    TORCH_CHECK(
        !torch::autograd::compute_requires_grad(inputs),
        "ReduceAdd can only be used during the backward pass of data parallel.");

    Tensor output = torch::zeros_like(inputs[0], {destination_device_});

    for (auto& input : inputs) {
      TORCH_CHECK(
          input.sizes() == inputs[0].sizes(),
          "All inputs of ReduceAdd must have the same size, but got ",
          input.sizes(),
          " and ",
          inputs[0].sizes());

      TORCH_CHECK(
          input.dtype() == inputs[0].dtype(),
          "All inputs of ReduceAdd must have the same dtype, but got ",
          input.dtype(),
          " and ",
          inputs[0].dtype());

      // TODO: use nccl reduce
      output.add_(input.to(destination_device_));
    }

    return {output};
  }

 private:
  at::Device destination_device_;
};

} // namespace

// A friend function of Module that recursively sets gradient edges pointing
// from every parameter X in every module replica to the same parameter X in
// the original module. See [Replicating Modules]
template <typename ModuleType>
void replicate_grad_edges(
    const std::shared_ptr<Module>& module,
    const std::vector<std::shared_ptr<ModuleType>>& replicas,
    const std::vector<Device>& devices) {
  for (auto& parameter : module->named_parameters(/*recurse=*/false)) {
    auto grad_fn = std::make_shared<ReduceAdd>((*parameter).device());
    grad_fn->set_next_edges(autograd::collect_next_edges(*parameter));

    for (const auto i : c10::irange(devices.size())) {
      autograd::set_history(replicas[i]->parameters_[parameter.key()], grad_fn);
    }
  }

  for (auto& buffer : module->named_buffers(/*recurse=*/false)) {
    if (buffer.value().requires_grad()) {
      auto grad_fn = std::make_shared<ReduceAdd>((*buffer).device());
      grad_fn->set_next_edges(autograd::collect_next_edges(*buffer));

      for (const auto i : c10::irange(devices.size())) {
        autograd::set_history(replicas[i]->buffers_[buffer.key()], grad_fn);
      }
    }
  }

  for (auto& child : module->children_) {
    std::vector<std::shared_ptr<Module>> child_replicas;
    child_replicas.reserve(devices.size());
    for (auto& replica : replicas) {
      child_replicas.push_back(replica->children_[child.key()]);
    }

    // recursively set gradient edges for all children
    replicate_grad_edges(*child, child_replicas, devices);
  }
}

namespace parallel {

/// Replicates a module on the given list of devices.
/// A replica is created by calling `clone()` on the module. For this, the
/// module must inherit from `nn::Cloneable`, or define its own `clone()`
/// method, which is expected to perform a deep copy of the module.
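///
/// A minimal usage sketch (assuming two CUDA devices are available):
///
///   auto model = std::make_shared<torch::nn::LinearImpl>(3, 4);
///   std::vector<torch::Device> devices = {
///       torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)};
///   auto replicas = torch::nn::parallel::replicate(model, devices);
///   // replicas[0] and replicas[1] are deep copies living on CUDA:0 and
///   // CUDA:1, respectively, with gradient edges back to `model`.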
template <typename ModuleType>
std::vector<std::shared_ptr<ModuleType>> replicate(
    const std::shared_ptr<ModuleType>& module,
    const std::vector<Device>& devices) {
  std::vector<std::shared_ptr<ModuleType>> replicas;
  replicas.reserve(devices.size());
  for (const auto& device : devices) {
    replicas.push_back(
        std::dynamic_pointer_cast<ModuleType>(module->clone(device)));
  }
  // Configure gradient edges to point from replica parameters to original
  // module parameters. See [Replicating Modules]
  replicate_grad_edges(module, replicas, devices);
  return replicas;
}

/// Replicates a module holder on the given list of devices.
/// This method allows calling `replicate()` with a module holder, such as
/// `Linear`.
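///
/// A minimal usage sketch (assuming two CUDA devices are available):
///
///   torch::nn::Linear model(3, 4);
///   std::vector<torch::Device> devices = {
///       torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)};
///   auto replicas = torch::nn::parallel::replicate(model, devices);
///   // `replicas` holds one module holder per device.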
template <typename ModuleType>
std::vector<ModuleHolder<ModuleType>> replicate(
    const ModuleHolder<ModuleType>& module,
    const std::vector<Device>& devices) {
  auto ptrs = replicate(module.ptr(), devices);
  return std::vector<ModuleHolder<ModuleType>>(ptrs.begin(), ptrs.end());
}

/// Applies the given inputs to the given modules in a parallel fashion.
/// Conceptually, a thread is spawned for each `(module, input)` pair, in which
/// `forward()` is called on the module with its corresponding input. The
/// outputs of the individual calls are stored in a vector and returned.
///
/// The first exception caught by any thread is stashed and rethrown after all
/// threads have completed their operation.
///
/// Further remarks:
/// 1. The length of the module container must match the length of the inputs.
/// 2. If a list of devices is supplied, it must match the list of modules in
/// length. Each device will be set as the current default device during the
/// invocation of the respective module. This means any tensors allocated on the
/// default device inside the module will be constructed on this device.
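///
/// A minimal usage sketch (assuming two CUDA devices and inputs that have
/// already been split, one chunk per device):
///
///   std::vector<torch::Device> devices = {
///       torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)};
///   torch::nn::Linear model(3, 4);
///   auto replicas = torch::nn::parallel::replicate(model, devices);
///   std::vector<torch::Tensor> inputs = {
///       torch::randn({8, 3}, devices[0]), torch::randn({8, 3}, devices[1])};
///   auto outputs =
///       torch::nn::parallel::parallel_apply(replicas, inputs, devices);
///   // outputs[i] is the forward result of replica i, left on devices[i].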
template <typename ModuleType>
std::vector<Tensor> parallel_apply(
    std::vector<ModuleType>& modules,
    const std::vector<Tensor>& inputs,
    const std::optional<std::vector<Device>>& devices = std::nullopt) {
  TORCH_CHECK(
      modules.size() == inputs.size(), "Must have as many inputs as modules");
  if (devices) {
    TORCH_CHECK(
        modules.size() == devices->size(),
        "Must have as many devices as modules");
  }

  std::vector<Tensor> outputs(modules.size());
  std::mutex mutex;

  // std::exception_ptr can be passed between threads:
  // > An instance of std::exception_ptr may be passed to another function,
  // > possibly on another thread, where the exception may be rethrown [...].
  // https://en.cppreference.com/w/cpp/error/exception_ptr
  std::exception_ptr exception;

  at::parallel_for(
      /*begin=*/0,
      /*end=*/modules.size(),
      /*grain_size=*/1,
      [&modules, &inputs, &devices, &outputs, &mutex, &exception](
          int64_t index, int64_t stop) {
        for (; index < stop; ++index) {
          try {
            auto output = modules[index]->forward(inputs[index]);
            output =
                output.to(devices ? (*devices)[index] : inputs[index].device());
            std::lock_guard<std::mutex> lock(mutex);
            outputs[index] = output;
          } catch (...) {
            std::lock_guard<std::mutex> lock(mutex);
            if (!exception) {
              exception = std::current_exception();
            }
          }
        }
      });

  if (exception) {
    std::rethrow_exception(exception);
  }

  return outputs;
}

/// Evaluates `module(input)` in parallel across the given `devices`. If
/// `devices` is not supplied, the invocation is parallelized across all
/// available CUDA devices. If `output_device` is supplied, the final, combined
/// tensor will be placed on this device. If not, it defaults to the first
/// device in `devices`.
///
/// In detail, this method performs the following four distinct steps:
/// 1. *Scatter* the input to the given devices,
/// 2. *Replicate* (deep clone) the model on each device,
/// 3. *Evaluate* each module with its input on its device,
/// 4. *Gather* the outputs of each replica into a single output tensor, located
/// on the `output_device`.
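///
/// A minimal usage sketch (assuming at least one CUDA device is available):
///
///   torch::nn::Linear model(3, 4);
///   auto input = torch::randn({16, 3});
///   auto output = torch::nn::parallel::data_parallel(model, input);
///   // `output` is gathered on the first CUDA device by default.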
template <typename ModuleType>
Tensor data_parallel(
    ModuleType module,
    Tensor input,
    std::optional<std::vector<Device>> devices = std::nullopt,
    std::optional<Device> output_device = std::nullopt,
    int64_t dim = 0) {
  if (!devices) {
    const auto device_count = torch::cuda::device_count();
    TORCH_CHECK(
        device_count > 0, "Expected at least one CUDA device to be available");
    devices = std::vector<Device>();
    devices->reserve(device_count);
    for (const auto index : c10::irange(device_count)) {
      devices->emplace_back(kCUDA, static_cast<torch::DeviceIndex>(index));
    }
  }
  if (!output_device) {
    output_device = devices->front();
  }

  if (devices->size() == 1) {
    module->to(devices->front());
    input = input.to(devices->front());
    return module->forward(std::move(input)).to(*output_device);
  }

  autograd::Scatter scatter(*devices, /*chunk_sizes=*/std::nullopt, dim);
  auto scattered_inputs = fmap<Tensor>(scatter.apply({std::move(input)}));
  // The input tensor might not be big enough to be scattered across all
  // available devices.
  if (scattered_inputs.size() < devices->size()) {
    devices->resize(
        scattered_inputs.size(),
        Device(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES));
  }

  auto replicas = replicate(module, *devices);
  auto outputs = parallel_apply(replicas, scattered_inputs, *devices);
  return autograd::Gather(*output_device, dim)
      .apply(fmap<autograd::Variable>(std::move(outputs)))
      .front();
}

} // namespace parallel
} // namespace nn
} // namespace torch