1 #include <torch/csrc/autograd/functions/comm.h>
2
3 #include <ATen/core/functional.h>
4 #include <torch/csrc/autograd/function.h>
5 #include <torch/csrc/autograd/functions/utils.h>
6 #include <torch/csrc/autograd/variable.h>
7 #include <torch/csrc/cuda/comm.h>
8
9 #include <ATen/ATen.h>
10 #include <ATen/cuda/CUDAContext.h>
11
12 #include <memory>
13 #include <vector>
14
15 namespace torch::autograd {
// Scatter node: splits one input tensor into per-device chunks (the
// gradient of this node is a Gather back onto the input's device).
//
// devices:           destination devices, one output chunk per device.
// chunk_sizes:       optional explicit split sizes along `dim`; forwarded
//                    as-is to torch::cuda::scatter.
// dim:               dimension along which the input is chunked.
// streams:           optional per-device CUDA streams, forwarded to
//                    torch::cuda::scatter.
// unsqueeze_scalars: when true, apply() indexes each 1-element chunk down
//                    to a scalar (set by Gather when it had to view
//                    scalar inputs as 1-element vectors).
Scatter::Scatter(
    std::vector<at::Device> devices,
    std::optional<std::vector<int64_t>> chunk_sizes,
    int64_t dim,
    std::optional<std::vector<std::optional<at::cuda::CUDAStream>>> streams,
    bool unsqueeze_scalars)
    : devices_(std::move(devices)),
      chunk_sizes_(std::move(chunk_sizes)),
      dim_(dim),
      streams_(std::move(streams)),
      unsqueeze_scalars_(unsqueeze_scalars) {}
27
// Defaulted destructor, defined out-of-line in this translation unit.
Scatter::~Scatter() = default;
29
apply(variable_list && inputs)30 variable_list Scatter::apply(variable_list&& inputs) {
31 AT_ASSERT(inputs.size() == 1);
32 auto& input = inputs.front();
33
34 std::shared_ptr<Node> grad_fn;
35 if (compute_requires_grad(input)) {
36 grad_fn =
37 std::make_shared<Gather>(/*destination_device=*/input.device(), dim_);
38 grad_fn->set_next_edges(collect_next_edges(input));
39 }
40
41 auto device_indices = fmap(devices_, [](const at::Device& device) -> int64_t {
42 return device.index();
43 });
44 auto tensors =
45 torch::cuda::scatter(input, device_indices, chunk_sizes_, dim_, streams_);
46
47 std::vector<Variable> variables;
48 variables.reserve(tensors.size());
49 for (auto& tensor : tensors) {
50 AT_ASSERT(tensor.defined());
51 if (unsqueeze_scalars_) {
52 AT_ASSERT(tensor.dim() == 1 && tensor.numel() == 1);
53 variables.push_back(tensor[0]);
54 } else {
55 variables.push_back(std::move(tensor));
56 }
57 }
58
59 if (grad_fn) {
60 set_history(variables, grad_fn);
61 }
62
63 return variables;
64 }
65
// Gather node: concatenates per-device tensors onto `destination_device`
// along dimension `dim` (the gradient of this node is a Scatter back to
// the source devices).
Gather::Gather(const at::Device& destination_device, int64_t dim)
    : destination_device_(destination_device), dim_(dim) {}
68
// Defaulted destructor, defined out-of-line in this translation unit.
Gather::~Gather() = default;
70
apply(variable_list && inputs)71 variable_list Gather::apply(variable_list&& inputs) {
72 bool all_are_zero_dim = true;
73 for (const auto& input : inputs) {
74 TORCH_CHECK(
75 input.is_cuda(),
76 "All inputs to Gather must be CUDA tensors, got ",
77 input.toString());
78 if (input.dim() > 0) {
79 all_are_zero_dim = false;
80 }
81 }
82
83 const bool unsqueeze_scalars = all_are_zero_dim && dim_ == 0;
84 if (unsqueeze_scalars) {
85 TORCH_WARN(
86 "Was asked to gather along dimension 0, but all "
87 "input tensors were scalars; will instead unsqueeze "
88 "and return a vector.");
89 }
90
91 std::shared_ptr<Node> grad_fn;
92 // compute this before moving variables from `inputs`
93 if (compute_requires_grad(inputs)) {
94 std::vector<at::Device> source_devices;
95 source_devices.reserve(inputs.size());
96 std::vector<int64_t> input_sizes;
97 input_sizes.reserve(inputs.size());
98 for (auto& input : inputs) {
99 source_devices.push_back(input.device());
100 input_sizes.push_back(input.size(dim_));
101 }
102 grad_fn = std::make_shared<Scatter>(
103 std::move(source_devices),
104 std::move(input_sizes),
105 dim_,
106 /*streams=*/std::nullopt,
107 /*unsqueeze_scalars=*/unsqueeze_scalars);
108 grad_fn->set_next_edges(collect_next_edges(inputs));
109 }
110
111 std::vector<at::Tensor> tensors;
112 tensors.reserve(inputs.size());
113 for (auto& variable : inputs) {
114 if (unsqueeze_scalars) {
115 tensors.push_back(variable.view(1));
116 } else {
117 tensors.push_back(std::move(variable));
118 }
119 }
120
121 // Disable the autograd during the actual computation
122 // torch::cuda::gather does not return a view or change things inplace
123 // so no need for extra logic here
124 at::Tensor variable;
125 {
126 at::AutoDispatchBelowAutograd mode;
127 // This is special logic for torch::cuda::gather!
128 const auto destination_index =
129 destination_device_.is_cpu() ? -1 : destination_device_.index();
130 variable = torch::cuda::gather(tensors, dim_, destination_index);
131 }
132 if (grad_fn) {
133 set_history(variable, grad_fn);
134 }
135 return {variable};
136 }
137
138 } // namespace torch::autograd
139