// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/core/dispatch/Dispatcher.h>

#include <utility>

namespace at::functorch {

static Tensor getStepTensor(const Tensor& indices, const c10::SymInt& bdim_size, const c10::SymInt& num_embeddings) {
  // [batch_size, 1, 1, 1, ..., 1]
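  // e.g. bdim_size = 3, num_embeddings = E -> [0, E, 2E], viewed as [3, 1, ..., 1]
  // so it broadcasts against indices whose batch dim has been moved to the front.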
  c10::SymDimVector view_shape(indices.dim(), 1);
  view_shape[0] = bdim_size;
  auto range = at::arange(0, bdim_size * num_embeddings, num_embeddings, indices.options());
  return range.view_symint(view_shape);
}

static std::tuple<Tensor, std::optional<int64_t>> embedding_batch_rule(
    const Tensor& weight, std::optional<int64_t> weight_bdim,
    const Tensor& indices, std::optional<int64_t> indices_bdim,
    c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse) {
  if (!weight_bdim && indices_bdim) {
    // B*, ED -> B*D
    auto result = at::embedding_symint(weight, indices, std::move(padding_idx), scale_grad_by_freq, sparse);
    return std::make_tuple(std::move(result), indices_bdim);
  } else if (weight_bdim && !indices_bdim) {
    // *, BED -> *, E(BD) -> *(BD) -> *BD
    const auto batch_size = weight.size(*weight_bdim);
    const auto weight_ = reshape_dim_into(*weight_bdim, /*embedding_dim*/1, weight);
    auto result = at::embedding_symint(weight_, indices, std::move(padding_idx), scale_grad_by_freq, sparse);
    result = reshape_dim_outof(-1, batch_size, result);
    return std::make_tuple(result, result.dim() - 2);
  }
  TORCH_INTERNAL_ASSERT(weight_bdim && indices_bdim);
  // B*, BED -> B*, (BE)D -> B*D
  // We'll need to do something extra: add (0, E, 2*E, ...) to the indices.
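  // e.g. B = 2, E = 3: batch 0 keeps its indices in [0, 3), while batch 1's indices
  // are shifted by 3 so they select rows from the second block of the (B*E, D) weight.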
  const auto batch_size = weight.size(*weight_bdim);
  const auto num_embeddings = weight.size((*weight_bdim == 0) ? 1 : 0);
  const auto weight_ = reshape_dim_into(*weight_bdim, 0, weight);
  auto indices_ = moveBatchDimToFront(indices, indices_bdim);

  const auto range = getStepTensor(indices, batch_size, num_embeddings);
  indices_ = indices_ + range;
  auto result = at::embedding_symint(weight_, indices_, std::move(padding_idx), scale_grad_by_freq, sparse);
  return std::make_tuple(std::move(result), 0);
}

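// embedding_dense_backward gets the same treatment: when the indices are batched, we
// offset them by b * num_weights so each batch accumulates into its own block of a
// (B * num_weights, D) gradient, then split that block dim back out at the end.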
static std::tuple<Tensor, std::optional<int64_t>>
embedding_dense_backward_batch_rule(
    const Tensor& grad_, std::optional<int64_t> grad_bdim,
    const Tensor& indices_, std::optional<int64_t> indices_bdim,
    c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq) {
  Tensor grad = grad_;
  Tensor indices = indices_;
  if (!indices_bdim && grad_bdim) {
    const auto bdim_size = grad.sym_size(*grad_bdim);
    grad = reshape_dim_into(*grad_bdim, -1, grad);
    auto result = at::embedding_dense_backward_symint(
        grad, indices, std::move(num_weights), std::move(padding_idx), scale_grad_by_freq);
    result = reshape_dim_outof_symint(1, bdim_size, result);
    return std::make_tuple(std::move(result), 1);
  }
  const auto bdim_size = indices.size(*indices_bdim);
  indices = moveBatchDimToFront(indices, indices_bdim);
  grad = moveBatchDimToFront(grad, grad_bdim);
  grad = ensure_has_bdim(grad, grad_bdim.has_value(), bdim_size);
  const auto range = getStepTensor(indices, bdim_size, num_weights);
  auto result = at::embedding_dense_backward_symint(
      grad, indices + range, num_weights * bdim_size, -1, scale_grad_by_freq);
  result = reshape_dim_outof(0, bdim_size, result);
  // Fill in the padding. We can't do it in the embedding_dense_backward call
  // because we need to fill in multiple rows!
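  // After the reshape above, the padding row appears once per batch entry
  // (result[b, padding_idx]); we pass padding_idx = -1 to the backward call and
  // zero out the whole dim-1 slice here so every batch's padding row is handled.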
  if (padding_idx >= 0) {
    result.select_symint(1, std::move(padding_idx)).fill_(0);
  }
  return std::make_tuple(std::move(result), 0);
}

/**
 * grid sample batch rule breaks down into 3 cases:
 *   case 1 (input is batched, grid is not):
 *     batch input along first dimension, unpack along first dimension
 *     2d:
 *       input: N(BC)H_{in}W_{in}, grid: NH_{out}W_{out}2
 *       output: N(BC)H_{out}W_{out}
 *     3d:
 *       input: N(BC)D_{in}H_{in}W_{in}, grid: ND_{out}H_{out}W_{out}3
 *       output: N(BC)D_{out}H_{out}W_{out}
 *   case 2 (input is not batched, grid is batched):
 *     batch grid along second dimension, unpack along second dimension
 *     2d:
 *       input: NCH_{in}W_{in}, grid: N(BH_{out})W_{out}2
 *       output: NC(BH_{out})W_{out}
 *     3d:
 *       input: NCD_{in}H_{in}W_{in}, grid: N(BD_{out})H_{out}W_{out}3
 *       output: NC(BD_{out})H_{out}W_{out}
 *   case 3 (input and grid are both batched):
 *     batch grid and input along 0th dimension, unpack along 0th dimension
 *     2d:
 *       input: (BN)CH_{in}W_{in}, grid: (BN)H_{out}W_{out}2
 *       output: (BN)CH_{out}W_{out}
 *     3d:
 *       input: (BN)CD_{in}H_{in}W_{in}, grid: (BN)D_{out}H_{out}W_{out}3
 *       output: (BN)CD_{out}H_{out}W_{out}
 */
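// For example, the 2d version of case 1 with B=4, N=2, C=3 and the batch dim at the
// front: the input is viewed as 2 x (4*3) x H_in x W_in, grid sample runs once on it,
// and the (4*3) dim of the 2 x (4*3) x H_out x W_out output is split back apart,
// leaving the batch dim at position 1 (which is what gets returned as the output bdim).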
template<typename F, F Func, typename... ExtraArgs>
std::tuple<Tensor, std::optional<int64_t>>
grid_sample_batch_rule(const Tensor& input, std::optional<int64_t> input_bdim, const Tensor& grid, std::optional<int64_t> grid_bdim, ExtraArgs... extra_args) {
  std::tuple<Tensor, std::optional<int64_t>> result;
  if (input_bdim && !grid_bdim) {
    auto new_input = reshape_dim_into(*input_bdim, 1, input);
    auto out = Func(new_input, grid, std::forward<ExtraArgs>(extra_args)...);
    out = reshape_dim_outof(1, input.sizes()[*input_bdim], out);
    result = std::make_tuple(std::move(out), 1);
  } else if (!input_bdim && grid_bdim) {
    // grid of N(BH)W2 -> NC(BH)W or grid of N(BD)HW3 -> NC(BD)HW
    auto new_grid = reshape_dim_into(*grid_bdim, 1, grid);
    auto out = Func(input, new_grid, std::forward<ExtraArgs>(extra_args)...);
    out = reshape_dim_outof(2, grid.sizes()[*grid_bdim], out);
    result = std::make_tuple(std::move(out), 2);
  } else if (input_bdim && grid_bdim) {
    auto new_input = reshape_dim_into(*input_bdim, 0, input);
    auto new_grid = reshape_dim_into(*grid_bdim, 0, grid);
    auto out = Func(new_input, new_grid, std::forward<ExtraArgs>(extra_args)...);
    out = reshape_dim_outof(0, input.sizes()[*input_bdim], out);
    result = std::make_tuple(std::move(out), 0);
  } else {
    result = std::make_tuple(Func(input, grid, std::forward<ExtraArgs>(extra_args)...), std::nullopt);
  }
  return result;
}

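// Flattens the batch dim of grad_output, input, and grid into each tensor's leading
// dim (adding a size-1 bdim first where one is missing), so the unbatched backward
// kernel can be called once on (B*N)-shaped tensors. Also returns the batch size
// needed to undo the flattening afterwards.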
static std::tuple<Tensor, Tensor, Tensor, int64_t>
grid_sample_backward_helper_in(
    const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
    const Tensor& input, std::optional<int64_t> input_bdim,
    const Tensor& grid, std::optional<int64_t> grid_bdim) {

  auto batch_size = get_bdim_size3(
      grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim);

  auto grad_output_ = moveBatchDimToFront(grad_output, grad_output_bdim);
  grad_output_ = ensure_has_bdim(grad_output_, grad_output_bdim.has_value(), batch_size);
  grad_output_ = reshape_dim_into(0, 0, grad_output_);

  auto input_ = moveBatchDimToFront(input, input_bdim);
  input_ = ensure_has_bdim(input_, input_bdim.has_value(), batch_size);
  input_ = reshape_dim_into(0, 0, input_);

  auto grid_ = moveBatchDimToFront(grid, grid_bdim);
  grid_ = ensure_has_bdim(grid_, grid_bdim.has_value(), batch_size);
  grid_ = reshape_dim_into(0, 0, grid_);

  return std::make_tuple(std::move(grad_output_), std::move(input_), std::move(grid_), batch_size);
}

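// Splits the batch dim (of size bdim_size) back out of the two gradients returned by
// the backward kernel, at the requested output bdim positions.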
static std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
grid_sample_backward_helper_out(
    const std::tuple<Tensor, Tensor>& bw_out,
    std::optional<int64_t> grad_input_out_bdim,
    std::optional<int64_t> grad_grid_out_bdim,
    int64_t bdim_size) {
  auto grad_input = std::get<0>(bw_out);
  auto grad_grid = std::get<1>(bw_out);
  grad_input = reshape_dim_outof(*grad_input_out_bdim, bdim_size, grad_input);
  grad_grid = reshape_dim_outof(*grad_grid_out_bdim, bdim_size, grad_grid);
  auto result = std::make_tuple(grad_input, grad_input_out_bdim, grad_grid, grad_grid_out_bdim);
  return result;
}


template<typename F, F Func, typename... ExtraArgs>
std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
grid_sample_backward_batch_rule(
    const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
    const Tensor& input, std::optional<int64_t> input_bdim,
    const Tensor& grid, std::optional<int64_t> grid_bdim,
    ExtraArgs... extra_args) {

  auto new_bw_input = grid_sample_backward_helper_in(
      grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim);

  auto new_grad_output = std::get<0>(new_bw_input);
  auto new_input = std::get<1>(new_bw_input);
  auto new_grid = std::get<2>(new_bw_input);
  int64_t batch_size = std::get<3>(new_bw_input);

  auto bw_out = Func(new_grad_output, new_input, new_grid, std::forward<ExtraArgs>(extra_args)...);

  return grid_sample_backward_helper_out(bw_out, 0, 0, batch_size);
}

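// Same flattening strategy as grid_sample_backward_batch_rule, except that
// cudnn_grid_sampler_backward takes its arguments as (input, grid, grad_output)
// and has no extra arguments.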
template<typename F, F Func>
std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
cudnn_grid_sample_backward_batch_rule(
    const Tensor& input, std::optional<int64_t> input_bdim,
    const Tensor& grid, std::optional<int64_t> grid_bdim,
    const Tensor& grad_output, std::optional<int64_t> grad_output_bdim) {

  auto new_bw_input = grid_sample_backward_helper_in(
      grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim);

  auto new_grad_output = std::get<0>(new_bw_input);
  auto new_input = std::get<1>(new_bw_input);
  auto new_grid = std::get<2>(new_bw_input);
  int64_t bdim_size = std::get<3>(new_bw_input);

  auto bw_out = Func(new_input, new_grid, new_grad_output);

  return grid_sample_backward_helper_out(bw_out, 0, 0, bdim_size);
}

// TODO: replace with targetable functionalization
static Tensor one_hot_decomposition_hack(const Tensor& self, int64_t num_classes) {
  TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor.");
  auto shape = self.sym_sizes().vec();

  // empty tensor could be converted to one hot representation,
  // but shape inference is not possible.
  if (self.sym_numel() == 0) {
    if (num_classes <= 0) {
      AT_ERROR("Can not infer total number of classes from empty tensor.");
    } else {
      shape.push_back(num_classes);
      return at::empty_symint(shape, self.options());
    }
  }

  TORCH_CHECK(num_classes > 0, "When vmap-ing torch.nn.functional.one_hot, please "
      "provide an explicit positive num_classes argument.");

  // Disabling all of the following checks. This is OK because scatter has checks too.
  // Maybe one_hot should be a primitive wrt autograd so we don't have to deal with this.
  // // non-empty tensor
  // if (self.device().type() != at::kCUDA) {
  //   //for cuda, rely on device assert thrown by scatter
  //   TORCH_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative.");
  // }
  // if (self.device().type() != at::kCUDA) {
  //   //rely on device asserts from scatter to avoid sync here
  //   TORCH_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes.");
  // }

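  // One-hot via scatter: start from zeros of shape [*self.shape, num_classes] and
  // write a 1 at each class index along the last dim,
  // e.g. self = [0, 2], num_classes = 3 -> [[1, 0, 0], [0, 0, 1]].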
  shape.push_back(num_classes);
  Tensor ret = at::zeros_symint(shape, self.options());
  return ret.scatter(-1, self.unsqueeze(-1), 1);
}

template <typename A, A a, typename C>
struct UpsampleBackwardBatchRuleHelper;

template <typename F, F Func, typename A, typename B, typename C, typename... T>
struct UpsampleBackwardBatchRuleHelper<F, Func, typelist<A, B, C, T...>> {
  static std::tuple<Tensor, std::optional<int64_t>> apply(
      const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
      c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size,
      T... extra_args) {
    auto grad_output_ = reshape_dim_into(*grad_output_bdim, 0, grad_output);
    TORCH_INTERNAL_ASSERT(!input_size.empty());

    // input_size describes the logical (per-example) input, but grad_output_ now has
    // the vmap batch dim flattened into dim 0, so patch dim 0 to the physical size.
    c10::SymDimVector physical_input_size(input_size.begin(), input_size.end());
    physical_input_size[0] = grad_output_.sym_sizes()[0];
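    // e.g. a logical input_size of (N, C, H, W) with vmap batch size B becomes (B*N, C, H, W).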

    auto out = Func(
        grad_output_,
        output_size,
        physical_input_size,
        std::forward<T>(extra_args)...);
    return std::make_tuple(reshape_dim_outof_symint(0, grad_output.sym_sizes()[*grad_output_bdim], out), 0);
  }

};

template <typename A, A a, typename C>
struct GridSampleBatchRuleHelper;

template <typename F, F Func, typename T1, typename T2, typename... T>
struct GridSampleBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
  static std::tuple<Tensor, std::optional<int64_t>> apply(
      const Tensor& input, std::optional<int64_t> input_batch_dim,
      const Tensor& grid, std::optional<int64_t> grid_batch_dim,
      T... extra_args) {
    return grid_sample_batch_rule<F, Func, T...>(
        input, input_batch_dim, grid, grid_batch_dim, std::forward<T>(extra_args)...);
  }
};

template <typename A, A a, typename C>
struct GridSampleBackwardBatchRuleHelper;

template <typename F, F Func, typename T1, typename T2, typename T3, typename... T>
struct GridSampleBackwardBatchRuleHelper<F, Func, typelist<T1, T2, T3, T...>> {
  static std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>> apply(
      const Tensor& grad_output, std::optional<int64_t> grad_output_batch_dim,
      const Tensor& input, std::optional<int64_t> input_batch_dim,
      const Tensor& grid, std::optional<int64_t> grid_batch_dim,
      T... extra_args) {
    return grid_sample_backward_batch_rule<F, Func, T...>(
        grad_output, grad_output_batch_dim,
        input, input_batch_dim,
        grid, grid_batch_dim,
        std::forward<T>(extra_args)...);
  }
};

template <typename F, F Func>
struct CudnnGridSampleBackwardBatchRuleHelper {
  static std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>> apply(
      const Tensor& input, std::optional<int64_t> input_batch_dim,
      const Tensor& grid, std::optional<int64_t> grid_batch_dim,
      const Tensor& grad_output, std::optional<int64_t> grad_output_batch_dim) {
    return cudnn_grid_sample_backward_batch_rule<F, Func>(
        input, input_batch_dim,
        grid, grid_batch_dim,
        grad_output, grad_output_batch_dim
    );
  }
};

#define GRID_SAMPLE_BATCH_RULE(fn) SINGLE_ARG(\
    GridSampleBatchRuleHelper<\
      decltype(&ATEN_FN(fn)),\
      &ATEN_FN(fn),\
      c10::guts::function_traits<decltype(ATEN_FN(fn))>::parameter_types>::apply)

#define GRID_SAMPLE_BW_BATCH_RULE(fn) SINGLE_ARG(\
    GridSampleBackwardBatchRuleHelper<\
      decltype(&ATEN_FN(fn)),\
      &ATEN_FN(fn),\
      c10::guts::function_traits<decltype(ATEN_FN(fn))>::parameter_types>::apply)

#define CUDNN_GRID_SAMPLE_BW_BATCH_RULE(fn)\
    CudnnGridSampleBackwardBatchRuleHelper<decltype(&ATEN_FN(fn)), &ATEN_FN(fn)>::apply

#define UPSAMPLE_BACKWARD(op) VMAP_SUPPORT(op, SINGLE_ARG(\
    UpsampleBackwardBatchRuleHelper<\
      decltype(&ATEN_FN(op)),\
      &ATEN_FN(op),\
      c10::guts::function_traits<decltype(ATEN_FN(op))>::parameter_types>::apply))


TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  EXISTING_BDIM(im2col);
  EXISTING_BDIM(col2im);

  VMAP_SUPPORT(embedding, embedding_batch_rule);
  VMAP_SUPPORT(embedding_dense_backward, embedding_dense_backward_batch_rule);

  VMAP_SUPPORT(grid_sampler_2d, GRID_SAMPLE_BATCH_RULE(grid_sampler));
  VMAP_SUPPORT(grid_sampler_2d_backward, GRID_SAMPLE_BW_BATCH_RULE(grid_sampler_2d_backward));

  VMAP_SUPPORT(grid_sampler_3d, GRID_SAMPLE_BATCH_RULE(grid_sampler));
  VMAP_SUPPORT(grid_sampler_3d_backward, GRID_SAMPLE_BW_BATCH_RULE(grid_sampler_3d_backward));
  VMAP_SUPPORT(cudnn_grid_sampler_backward, CUDNN_GRID_SAMPLE_BW_BATCH_RULE(cudnn_grid_sampler_backward));

  VMAP_SUPPORT(cudnn_grid_sampler, GRID_SAMPLE_BATCH_RULE(cudnn_grid_sampler));

  EXISTING_BDIM(pixel_shuffle);
  EXISTING_BDIM(pixel_unshuffle);
  EXISTING_BDIM(channel_shuffle);

  VARIADIC_BDIMS(constant_pad_nd);
  EXISTING_BDIM(reflection_pad1d);
  EXISTING_BDIM(reflection_pad2d);
  EXISTING_BDIM(reflection_pad3d);
  EXISTING_BDIM(replication_pad1d);
  EXISTING_BDIM(replication_pad2d);
  EXISTING_BDIM(replication_pad3d);

  EXISTING_BDIM_ALL_BOXED(replication_pad1d_backward);
  EXISTING_BDIM_ALL_BOXED(replication_pad2d_backward);
  EXISTING_BDIM_ALL_BOXED(replication_pad3d_backward);

  EXISTING_BDIM_ALL_BOXED(reflection_pad1d_backward);
  EXISTING_BDIM_ALL_BOXED(reflection_pad2d_backward);
  EXISTING_BDIM_ALL_BOXED(reflection_pad3d_backward);

  EXISTING_BDIM(upsample_bicubic2d);
  EXISTING_BDIM(upsample_bilinear2d);
  EXISTING_BDIM(upsample_linear1d);
  EXISTING_BDIM(upsample_nearest1d);
  EXISTING_BDIM(upsample_nearest2d);
  EXISTING_BDIM(upsample_nearest3d);
  EXISTING_BDIM(upsample_trilinear3d);
  EXISTING_BDIM(_upsample_bilinear2d_aa);
  EXISTING_BDIM(_upsample_bicubic2d_aa);

  UPSAMPLE_BACKWARD(upsample_bicubic2d_backward);
  UPSAMPLE_BACKWARD(upsample_bilinear2d_backward);
  UPSAMPLE_BACKWARD(upsample_linear1d_backward);
  UPSAMPLE_BACKWARD(upsample_nearest1d_backward);
  UPSAMPLE_BACKWARD(upsample_nearest2d_backward);
  UPSAMPLE_BACKWARD(upsample_nearest3d_backward);
  UPSAMPLE_BACKWARD(upsample_trilinear3d_backward);
  UPSAMPLE_BACKWARD(_upsample_bilinear2d_aa_backward);
  UPSAMPLE_BACKWARD(_upsample_bicubic2d_aa_backward);

  m.impl("one_hot", one_hot_decomposition_hack);
}

} // namespace at::functorch