// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/core/dispatch/Dispatcher.h>

#include <utility>

namespace at::functorch {

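// Builds per-sample row offsets for indexing into a weight whose batch dim has
// been folded into the embedding dim, i.e. a (bdim_size * num_embeddings, D)
// table: arange(0, B*E, step=E), viewed as [bdim_size, 1, ..., 1] so it
// broadcasts against `indices`. For example, with bdim_size = 2,
// num_embeddings = 3 and 2-D indices, this returns [[0], [3]].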
static Tensor getStepTensor(const Tensor& indices, const c10::SymInt& bdim_size, const c10::SymInt& num_embeddings) {
  // [batch_size, 1, 1, 1, ..., 1]
  c10::SymDimVector view_shape(indices.dim(), 1);
  view_shape[0] = bdim_size;
  auto range = at::arange(0, bdim_size * num_embeddings, num_embeddings, indices.options());
  return range.view_symint(view_shape);
}

static std::tuple<Tensor, std::optional<int64_t>> embedding_batch_rule(
    const Tensor& weight, std::optional<int64_t> weight_bdim,
    const Tensor& indices, std::optional<int64_t> indices_bdim,
    c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse) {
  if (!weight_bdim && indices_bdim) {
    // B*, ED -> B*D
    auto result = at::embedding_symint(weight, indices, std::move(padding_idx), scale_grad_by_freq, sparse);
    return std::make_tuple(std::move(result), indices_bdim);
  } else if (weight_bdim && !indices_bdim) {
    // *, BED -> *, E(BD) -> *(BD) -> *BD
    const auto batch_size = weight.size(*weight_bdim);
    const auto weight_ = reshape_dim_into(*weight_bdim, /*embedding_dim*/1, weight);
    auto result = at::embedding_symint(weight_, indices, std::move(padding_idx), scale_grad_by_freq, sparse);
    result = reshape_dim_outof(-1, batch_size, result);
    return std::make_tuple(result, result.dim() - 2);
  }
  TORCH_INTERNAL_ASSERT(weight_bdim && indices_bdim);
  // B*, BED -> B*, (BE)D -> B*D
  // We'll need to do something extra: add (0, E, 2*E, ...) to the indices.
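  // Concretely, with B = 2 batches of E = 3 embeddings each, weight_ below is
  // (2*3, D): rows 0..2 hold batch 0's table, rows 3..5 hold batch 1's, so
  // batch 1's indices must be shifted by num_embeddings = 3.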
  const auto batch_size = weight.size(*weight_bdim);
  const auto num_embeddings = weight.size((*weight_bdim == 0) ? 1 : 0);
  const auto weight_ = reshape_dim_into(*weight_bdim, 0, weight);
  auto indices_ = moveBatchDimToFront(indices, indices_bdim);

  const auto range = getStepTensor(indices, batch_size, num_embeddings);
  indices_ = indices_ + range;
  auto result = at::embedding_symint(weight_, indices_, std::move(padding_idx), scale_grad_by_freq, sparse);
  return std::make_tuple(std::move(result), 0);
}

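// Batch rule for embedding_dense_backward. If only grad is batched, the batch
// dim is folded into grad's trailing embedding dim and split back out of dim 1
// of the resulting (num_weights, B*D) weight gradient. Otherwise the indices
// are offset via getStepTensor, one (B * num_weights, D) gradient is computed
// with padding disabled, the batch is split back out of dim 0, and the
// padding_idx row of every per-sample table is zeroed afterwards.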
static std::tuple<Tensor, std::optional<int64_t>>
embedding_dense_backward_batch_rule(
    const Tensor& grad_, std::optional<int64_t> grad_bdim,
    const Tensor& indices_, std::optional<int64_t> indices_bdim,
    c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq) {
  Tensor grad = grad_;
  Tensor indices = indices_;
  if (!indices_bdim && grad_bdim) {
    const auto bdim_size = grad.sym_size(*grad_bdim);
    grad = reshape_dim_into(*grad_bdim, -1, grad);
    auto result = at::embedding_dense_backward_symint(
        grad, indices, std::move(num_weights), std::move(padding_idx), scale_grad_by_freq);
    result = reshape_dim_outof_symint(1, bdim_size, result);
    return std::make_tuple(std::move(result), 1);
  }
  const auto bdim_size = indices.size(*indices_bdim);
  indices = moveBatchDimToFront(indices, indices_bdim);
  grad = moveBatchDimToFront(grad, grad_bdim);
  grad = ensure_has_bdim(grad, grad_bdim.has_value(), bdim_size);
  const auto range = getStepTensor(indices, bdim_size, num_weights);
  auto result = at::embedding_dense_backward_symint(
      grad, indices + range, num_weights * bdim_size, -1, scale_grad_by_freq);
  result = reshape_dim_outof(0, bdim_size, result);
  // Fill in the padding. We can't do it in the embedding_dense_backward call
  // because we need to fill in multiple rows!
  if (padding_idx >= 0) {
    result.select_symint(1, std::move(padding_idx)).fill_(0);
  }
  return std::make_tuple(std::move(result), 0);
}

/**
 * The grid sample batch rule breaks down into 3 cases:
 *   case 1 (input is batched, grid is not):
 *     fold the batch into input's channel dim, split it back out of the output's channel dim
 *     2d:
 *       input: N(BC)H_{in}W_{in}, grid: NH_{out}W_{out}2
 *       output: N(BC)H_{out}W_{out}
 *     3d:
 *       input: N(BC)D_{in}H_{in}W_{in}, grid: ND_{out}H_{out}W_{out}3
 *       output: N(BC)D_{out}H_{out}W_{out}
 *   case 2 (input is not batched, grid is batched):
 *     fold the batch into grid's first spatial dim (H_out / D_out), split it back out of the output's matching dim
 *     2d:
 *       input: NCH_{in}W_{in}, grid: N(BH_{out})W_{out}2
 *       output: NC(BH_{out})W_{out}
 *     3d:
 *       input: NCD_{in}H_{in}W_{in}, grid: N(BD_{out})H_{out}W_{out}3
 *       output: NC(BD_{out})H_{out}W_{out}
 *   case 3 (input and grid are both batched):
 *     fold the batch of both input and grid into dim 0, split it back out of the output's dim 0
 *     2d:
 *       input: (BN)CH_{in}W_{in}, grid: (BN)H_{out}W_{out}2
 *       output: (BN)CH_{out}W_{out}
 *     3d:
 *       input: (BN)CD_{in}H_{in}W_{in}, grid: (BN)D_{out}H_{out}W_{out}3
 *       output: (BN)CD_{out}H_{out}W_{out}
 */
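// In each batched case the vmap dim is folded into one of the physical dims
// above (C, the first output spatial dim, or N), since the grid_sampler
// kernels only accept fixed-rank 4-D/5-D tensors, and is split back out at the
// corresponding dim of the output.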
template<typename F, F Func, typename... ExtraArgs>
std::tuple<Tensor, std::optional<int64_t>>
grid_sample_batch_rule(const Tensor& input, std::optional<int64_t> input_bdim, const Tensor& grid, std::optional<int64_t> grid_bdim, ExtraArgs... extra_args) {
  std::tuple<Tensor, std::optional<int64_t>> result;
  if (input_bdim && !grid_bdim) {
    auto new_input = reshape_dim_into(*input_bdim, 1, input);
    auto out = Func(new_input, grid, std::forward<ExtraArgs>(extra_args)...);
    out = reshape_dim_outof(1, input.sizes()[*input_bdim], out);
    result = std::make_tuple(std::move(out), 1);
  } else if (!input_bdim && grid_bdim) {
    // grid of N(BH)W2 -> output of NC(BH)W, or grid of N(BD)HW3 -> output of NC(BD)HW
    auto new_grid = reshape_dim_into(*grid_bdim, 1, grid);
    auto out = Func(input, new_grid, std::forward<ExtraArgs>(extra_args)...);
    out = reshape_dim_outof(2, grid.sizes()[*grid_bdim], out);
    result = std::make_tuple(std::move(out), 2);
  } else if (input_bdim && grid_bdim) {
    auto new_input = reshape_dim_into(*input_bdim, 0, input);
    auto new_grid = reshape_dim_into(*grid_bdim, 0, grid);
    auto out = Func(new_input, new_grid, std::forward<ExtraArgs>(extra_args)...);
    out = reshape_dim_outof(0, input.sizes()[*input_bdim], out);
    result = std::make_tuple(std::move(out), 0);
  } else {
    result = std::make_tuple(Func(input, grid, std::forward<ExtraArgs>(extra_args)...), std::nullopt);
  }
  return result;
}

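// Common preprocessing for the grid_sampler backward rules: give grad_output,
// input, and grid a batch dim if they lack one, move it to the front, and fold
// it into dim 0 so the non-batched backward kernel can be called once on
// (B*N, ...)-shaped tensors.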
static std::tuple<Tensor, Tensor, Tensor, int64_t>
grid_sample_backward_helper_in(
    const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
    const Tensor& input, std::optional<int64_t> input_bdim,
    const Tensor& grid, std::optional<int64_t> grid_bdim) {

  auto batch_size = get_bdim_size3(
      grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim);

  auto grad_output_ = moveBatchDimToFront(grad_output, grad_output_bdim);
  grad_output_ = ensure_has_bdim(grad_output_, grad_output_bdim.has_value(), batch_size);
  grad_output_ = reshape_dim_into(0, 0, grad_output_);

  auto input_ = moveBatchDimToFront(input, input_bdim);
  input_ = ensure_has_bdim(input_, input_bdim.has_value(), batch_size);
  input_ = reshape_dim_into(0, 0, input_);

  auto grid_ = moveBatchDimToFront(grid, grid_bdim);
  grid_ = ensure_has_bdim(grid_, grid_bdim.has_value(), batch_size);
  grid_ = reshape_dim_into(0, 0, grid_);

  return std::make_tuple(std::move(grad_output_), std::move(input_), std::move(grid_), batch_size);
}

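// Undoes the flattening from grid_sample_backward_helper_in: splits the batch
// of size bdim_size back out of the given dims of grad_input and grad_grid.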
static std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
grid_sample_backward_helper_out(
    const std::tuple<Tensor, Tensor> & bw_out,
    std::optional<int64_t> grad_input_out_bdim,
    std::optional<int64_t> grad_grid_out_bdim,
    int64_t bdim_size) {
  auto grad_input = std::get<0>(bw_out);
  auto grad_grid = std::get<1>(bw_out);
  grad_input = reshape_dim_outof(*grad_input_out_bdim, bdim_size, grad_input);
  grad_grid = reshape_dim_outof(*grad_grid_out_bdim, bdim_size, grad_grid);
  auto result = std::make_tuple(grad_input, grad_input_out_bdim, grad_grid, grad_grid_out_bdim);
  return result;
}


template<typename F, F Func, typename... ExtraArgs>
std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
grid_sample_backward_batch_rule(
    const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
    const Tensor& input, std::optional<int64_t> input_bdim,
    const Tensor& grid, std::optional<int64_t> grid_bdim,
    ExtraArgs... extra_args) {

  auto new_bw_input = grid_sample_backward_helper_in(
      grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim);

  auto new_grad_output = std::get<0>(new_bw_input);
  auto new_input = std::get<1>(new_bw_input);
  auto new_grid = std::get<2>(new_bw_input);
  int64_t batch_size = std::get<3>(new_bw_input);

  auto bw_out = Func(new_grad_output, new_input, new_grid, std::forward<ExtraArgs>(extra_args)...);

  return grid_sample_backward_helper_out(bw_out, 0, 0, batch_size);
}

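// Same wrapper for cudnn_grid_sampler_backward, which takes its arguments in
// (input, grid, grad_output) order rather than (grad_output, input, grid).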
template<typename F, F Func>
std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
cudnn_grid_sample_backward_batch_rule(
    const Tensor& input, std::optional<int64_t> input_bdim,
    const Tensor& grid, std::optional<int64_t> grid_bdim,
    const Tensor& grad_output, std::optional<int64_t> grad_output_bdim) {

  auto new_bw_input = grid_sample_backward_helper_in(
      grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim);

  auto new_grad_output = std::get<0>(new_bw_input);
  auto new_input = std::get<1>(new_bw_input);
  auto new_grid = std::get<2>(new_bw_input);
  int64_t bdim_size = std::get<3>(new_bw_input);

  auto bw_out = Func(new_input, new_grid, new_grad_output);

  return grid_sample_backward_helper_out(bw_out, 0, 0, bdim_size);
}

// TODO: replace with targetable functionalization
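// Decomposing one_hot into zeros_symint + scatter lets vmap batch it through
// ops that already have batch rules; the value-range checks are delegated to
// scatter (see the disabled checks below).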
static Tensor one_hot_decomposition_hack(const Tensor &self, int64_t num_classes) {
    TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensors.");
    auto shape = self.sym_sizes().vec();

    // An empty tensor can be converted to a one-hot representation,
    // but shape inference is not possible.
    if (self.sym_numel() == 0) {
        if (num_classes <= 0) {
            AT_ERROR("Can not infer total number of classes from empty tensor.");
        } else {
            shape.push_back(num_classes);
            return at::empty_symint(shape, self.options());
        }
    }

    TORCH_CHECK(num_classes > 0, "When vmap-ing torch.nn.functional.one_hot, please "
        "provide an explicit positive num_classes argument.");

    // Disabling all of the following checks. This is OK because scatter has checks too.
    // Maybe one_hot should be a primitive wrt autograd so we don't have to deal with this.
    // // non-empty tensor
    // if (self.device().type() != at::kCUDA) {
    //   //for cuda, rely on device assert thrown by scatter
    //   TORCH_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative.");
    // }
    // if (self.device().type() != at::kCUDA) {
    //   //rely on device asserts from scatter to avoid sync here
    //   TORCH_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes.");
    // }

    shape.push_back(num_classes);
    Tensor ret = at::zeros_symint(shape, self.options());
    return ret.scatter(-1, self.unsqueeze(-1), 1);
}

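// Helper for the upsample_*_backward ops: the typelist peels off the
// (grad_output, output_size, input_size) parameters so that any remaining
// arguments (e.g. align_corners, scale factors) are forwarded unchanged.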
template <typename A, A a, typename C>
struct UpsampleBackwardBatchRuleHelper;

template <typename F, F Func, typename A, typename B, typename C, typename... T>
struct UpsampleBackwardBatchRuleHelper<F, Func, typelist<A, B, C, T...>> {
  static std::tuple<Tensor, std::optional<int64_t>> apply(
      const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
      c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size,
      T... extra_args) {
    auto grad_output_ = reshape_dim_into(*grad_output_bdim, 0, grad_output);
    TORCH_INTERNAL_ASSERT(!input_size.empty());

    // input_size describes the logical (per-vmap-sample) input, so its leading
    // dim is N; after folding the vmap dim into grad_output's dim 0, the
    // physical leading dim must be B * N, i.e. grad_output_'s dim 0.
    c10::SymDimVector physical_input_size(input_size.begin(), input_size.end());
    physical_input_size[0] = grad_output_.sym_sizes()[0];

    auto out = Func(
        grad_output_,
        output_size,
        physical_input_size,
        std::forward<T>(extra_args)...);
    return std::make_tuple(reshape_dim_outof_symint(0, grad_output.sym_sizes()[*grad_output_bdim], out), 0);
  }

};

template <typename A, A a, typename C>
struct GridSampleBatchRuleHelper;

template <typename F, F Func, typename T1, typename T2, typename... T>
struct GridSampleBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
  static std::tuple<Tensor, std::optional<int64_t>> apply(
      const Tensor& input, std::optional<int64_t> input_batch_dim,
      const Tensor& grid, std::optional<int64_t> grid_batch_dim,
      T... extra_args) {
    return grid_sample_batch_rule<F, Func, T...>(
        input, input_batch_dim, grid, grid_batch_dim, std::forward<T>(extra_args)...);
  }
};

template <typename A, A a, typename C>
struct GridSampleBackwardBatchRuleHelper;

template <typename F, F Func, typename T1, typename T2, typename T3, typename... T>
struct GridSampleBackwardBatchRuleHelper<F, Func, typelist<T1, T2, T3, T...>> {
  static std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>> apply(
      const Tensor& grad_output, std::optional<int64_t> grad_output_batch_dim,
      const Tensor& input, std::optional<int64_t> input_batch_dim,
      const Tensor& grid, std::optional<int64_t> grid_batch_dim,
      T... extra_args) {
    return grid_sample_backward_batch_rule<F, Func, T...>(
        grad_output, grad_output_batch_dim,
        input, input_batch_dim,
        grid, grid_batch_dim,
        std::forward<T>(extra_args)...);
  }
};

template <typename F, F Func>
struct CudnnGridSampleBackwardBatchRuleHelper {
  static std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>> apply(
      const Tensor& input, std::optional<int64_t> input_batch_dim,
      const Tensor& grid, std::optional<int64_t> grid_batch_dim,
      const Tensor& grad_output, std::optional<int64_t> grad_output_batch_dim) {
    return cudnn_grid_sample_backward_batch_rule<F, Func>(
        input, input_batch_dim,
        grid, grid_batch_dim,
        grad_output, grad_output_batch_dim
    );
  }
};

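// These macros instantiate the helper structs above for a concrete ATen op,
// using c10::guts::function_traits to recover the op's parameter typelist so
// that trailing arguments are forwarded to the batch rule untouched.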
#define GRID_SAMPLE_BATCH_RULE(fn) SINGLE_ARG(\
    GridSampleBatchRuleHelper<\
      decltype(&ATEN_FN(fn)),\
      &ATEN_FN(fn),\
      c10::guts::function_traits<decltype(ATEN_FN(fn))>::parameter_types>::apply)

#define GRID_SAMPLE_BW_BATCH_RULE(fn) SINGLE_ARG(\
    GridSampleBackwardBatchRuleHelper<\
      decltype(&ATEN_FN(fn)),\
      &ATEN_FN(fn),\
      c10::guts::function_traits<decltype(ATEN_FN(fn))>::parameter_types>::apply)

#define CUDNN_GRID_SAMPLE_BW_BATCH_RULE(fn)\
    CudnnGridSampleBackwardBatchRuleHelper<decltype(&ATEN_FN(fn)), &ATEN_FN(fn)>::apply

#define UPSAMPLE_BACKWARD(op) VMAP_SUPPORT(op, SINGLE_ARG(\
    UpsampleBackwardBatchRuleHelper<\
      decltype(&ATEN_FN(op)),\
      &ATEN_FN(op),\
      c10::guts::function_traits<decltype(ATEN_FN(op))>::parameter_types>::apply))


TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
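  // VMAP_SUPPORT(op, rule) registers `rule` as op's batch rule. EXISTING_BDIM
  // and friends come from BatchRulesHelper.h; roughly, they reuse the op's own
  // leading batch dimension by folding the vmap dim into it and splitting it
  // back out of the result.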
  EXISTING_BDIM(im2col);
  EXISTING_BDIM(col2im);

  VMAP_SUPPORT(embedding, embedding_batch_rule);
  VMAP_SUPPORT(embedding_dense_backward, embedding_dense_backward_batch_rule);

  VMAP_SUPPORT(grid_sampler_2d, GRID_SAMPLE_BATCH_RULE(grid_sampler));
  VMAP_SUPPORT(grid_sampler_2d_backward, GRID_SAMPLE_BW_BATCH_RULE(grid_sampler_2d_backward));

  VMAP_SUPPORT(grid_sampler_3d, GRID_SAMPLE_BATCH_RULE(grid_sampler));
  VMAP_SUPPORT(grid_sampler_3d_backward, GRID_SAMPLE_BW_BATCH_RULE(grid_sampler_3d_backward));
  VMAP_SUPPORT(cudnn_grid_sampler_backward, CUDNN_GRID_SAMPLE_BW_BATCH_RULE(cudnn_grid_sampler_backward));

  VMAP_SUPPORT(cudnn_grid_sampler, GRID_SAMPLE_BATCH_RULE(cudnn_grid_sampler));

  EXISTING_BDIM(pixel_shuffle);
  EXISTING_BDIM(pixel_unshuffle);
  EXISTING_BDIM(channel_shuffle);

  VARIADIC_BDIMS(constant_pad_nd);
  EXISTING_BDIM(reflection_pad1d);
  EXISTING_BDIM(reflection_pad2d);
  EXISTING_BDIM(reflection_pad3d);
  EXISTING_BDIM(replication_pad1d);
  EXISTING_BDIM(replication_pad2d);
  EXISTING_BDIM(replication_pad3d);

  EXISTING_BDIM_ALL_BOXED(replication_pad1d_backward);
  EXISTING_BDIM_ALL_BOXED(replication_pad2d_backward);
  EXISTING_BDIM_ALL_BOXED(replication_pad3d_backward);

  EXISTING_BDIM_ALL_BOXED(reflection_pad1d_backward);
  EXISTING_BDIM_ALL_BOXED(reflection_pad2d_backward);
  EXISTING_BDIM_ALL_BOXED(reflection_pad3d_backward);

  EXISTING_BDIM(upsample_bicubic2d);
  EXISTING_BDIM(upsample_bilinear2d);
  EXISTING_BDIM(upsample_linear1d);
  EXISTING_BDIM(upsample_nearest1d);
  EXISTING_BDIM(upsample_nearest2d);
  EXISTING_BDIM(upsample_nearest3d);
  EXISTING_BDIM(upsample_trilinear3d);
  EXISTING_BDIM(_upsample_bilinear2d_aa);
  EXISTING_BDIM(_upsample_bicubic2d_aa);

  UPSAMPLE_BACKWARD(upsample_bicubic2d_backward);
  UPSAMPLE_BACKWARD(upsample_bilinear2d_backward);
  UPSAMPLE_BACKWARD(upsample_linear1d_backward);
  UPSAMPLE_BACKWARD(upsample_nearest1d_backward);
  UPSAMPLE_BACKWARD(upsample_nearest2d_backward);
  UPSAMPLE_BACKWARD(upsample_nearest3d_backward);
  UPSAMPLE_BACKWARD(upsample_trilinear3d_backward);
  UPSAMPLE_BACKWARD(_upsample_bilinear2d_aa_backward);
  UPSAMPLE_BACKWARD(_upsample_bicubic2d_aa_backward);

  m.impl("one_hot", one_hot_decomposition_hack);
}

} // namespace at::functorch