xref: /aosp_15_r20/external/pytorch/aten/src/ATen/functorch/BatchRulesDecompositions.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 
2 // Copyright (c) Facebook, Inc. and its affiliates.
3 // All rights reserved.
4 //
5 // This source code is licensed under the BSD-style license found in the
6 // LICENSE file in the root directory of this source tree.
7 
8 #include <ATen/FunctionalTensorWrapper.h>
9 #include <ATen/Operators.h>
10 #include <ATen/core/dispatch/Dispatcher.h>
11 #include <ATen/functorch/BatchRulesHelper.h>
12 #include <ATen/functorch/BatchedFallback.h>
13 #include <ATen/functorch/DynamicLayer.h>
14 #include <ATen/functorch/PlumbingHelper.h>
15 
16 namespace at::functorch {
17 
// Registration helpers used inside a TORCH_LIBRARY_IMPL body whose library
// handle is named `m`.
//
// OP_DECOMPOSE(op)
//   Registers `native::op` as the kernel for the operator named "op".
//   The static_cast pins `native::op` to the function-pointer type of
//   ATEN_FN(op), which disambiguates it if `native::op` is overloaded.
//
// OP_DECOMPOSE2(op, overload)
//   Same, for the operator named "op.overload"; the target type comes from
//   ATEN_FN2(op, overload).
//
// Both expansions already end in ';', so a call site may be written with or
// without its own trailing semicolon.
#define OP_DECOMPOSE(op)  m.impl(#op, static_cast<decltype(&ATEN_FN(op))>(native::op));
#define OP_DECOMPOSE2(op, overload)  m.impl(#op "." #overload, static_cast<decltype(&ATEN_FN2(op, overload))>(native::op));
20 
// Kernels registered for the FuncTorchVmapMode dispatch key.
// Each listed aten op is bound directly to its at::native implementation.
// NOTE(review): these look like the randomness-related ops (dropout family and
// scaled-dot-product attention) — confirm the reason they need this key.
TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
  OP_DECOMPOSE(alpha_dropout_);
  OP_DECOMPOSE(dropout_);
  OP_DECOMPOSE(feature_alpha_dropout_);
  OP_DECOMPOSE(feature_dropout_);
  OP_DECOMPOSE(dropout);
  OP_DECOMPOSE(_scaled_dot_product_attention_math);
  OP_DECOMPOSE(scaled_dot_product_attention);
}
30 
unsupportedData(const c10::OperatorHandle & op,torch::jit::Stack * stack)31 static void unsupportedData(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
32     TORCH_CHECK(false, "mutating directly with `.data` under vmap transform is not allowed.");
33 }
34 
TORCH_LIBRARY_IMPL(aten,FuncTorchBatchedDecomposition,m)35 TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
36   OP_DECOMPOSE2(__and__, Scalar);
37   OP_DECOMPOSE2(__and__, Tensor);
38   OP_DECOMPOSE2(__iand__, Tensor);
39   OP_DECOMPOSE2(__iand__, Scalar);
40   OP_DECOMPOSE2(__ior__, Tensor);
41   OP_DECOMPOSE2(__ior__, Scalar);
42   OP_DECOMPOSE2(__ixor__, Tensor);
43   OP_DECOMPOSE2(__ixor__, Scalar);
44   OP_DECOMPOSE2(__or__, Tensor);
45   OP_DECOMPOSE2(__or__, Scalar);
46   OP_DECOMPOSE2(__xor__, Tensor);
47   OP_DECOMPOSE2(__xor__, Scalar);
48   OP_DECOMPOSE(_batch_norm_impl_index);
49   OP_DECOMPOSE(absolute);
50   OP_DECOMPOSE(absolute_);
51   OP_DECOMPOSE(arctan2);
52   OP_DECOMPOSE(arctan2_);
53   OP_DECOMPOSE(argsort);
54   OP_DECOMPOSE2(argsort, stable);
55   OP_DECOMPOSE(avg_pool1d);
56   OP_DECOMPOSE(adaptive_max_pool1d);
57   OP_DECOMPOSE(adaptive_avg_pool1d);
58   m.impl("adaptive_avg_pool2d", native::adaptive_avg_pool2d_symint);
59   m.impl("adaptive_avg_pool3d", native::adaptive_avg_pool3d_symint);
60   OP_DECOMPOSE(adjoint);
61   OP_DECOMPOSE(arccos);
62   OP_DECOMPOSE(arccos_);
63   OP_DECOMPOSE(arccosh);
64   OP_DECOMPOSE(arccosh_);
65   OP_DECOMPOSE(arcsin);
66   OP_DECOMPOSE(arcsin_);
67   OP_DECOMPOSE(arcsinh);
68   OP_DECOMPOSE(arcsinh_);
69   OP_DECOMPOSE(arctan);
70   OP_DECOMPOSE(arctan_);
71   OP_DECOMPOSE(arctanh);
72   OP_DECOMPOSE(arctanh_);
73   OP_DECOMPOSE(atleast_1d);
74   OP_DECOMPOSE2(atleast_1d, Sequence);
75   OP_DECOMPOSE(atleast_2d);
76   OP_DECOMPOSE2(atleast_2d, Sequence);
77   OP_DECOMPOSE(atleast_3d);
78   OP_DECOMPOSE2(atleast_3d, Sequence);
79   OP_DECOMPOSE(batch_norm);
80   OP_DECOMPOSE(broadcast_tensors);
81   m.impl("broadcast_to", native::broadcast_to_symint);
82   OP_DECOMPOSE(cartesian_prod);
83   OP_DECOMPOSE(cdist);
84   OP_DECOMPOSE(chunk);
85   OP_DECOMPOSE(clip);
86   OP_DECOMPOSE2(clip, Tensor );
87   OP_DECOMPOSE(concat);
88   OP_DECOMPOSE(conj_physical);
89   OP_DECOMPOSE(contiguous);
90   OP_DECOMPOSE(combinations);
91   OP_DECOMPOSE(corrcoef);
92   OP_DECOMPOSE(cosine_embedding_loss);
93   OP_DECOMPOSE(cosine_similarity);
94   OP_DECOMPOSE(cov);
95   OP_DECOMPOSE(cross);
96   m.impl("cross_entropy_loss", native::cross_entropy_loss_symint);
97   OP_DECOMPOSE2(cumulative_trapezoid, x);
98   OP_DECOMPOSE2(cumulative_trapezoid, dx);
99   OP_DECOMPOSE2(dsplit, int);
100   OP_DECOMPOSE2(dsplit, array);
101   OP_DECOMPOSE(det);
102   OP_DECOMPOSE(diff);
103   OP_DECOMPOSE(diag);
104   OP_DECOMPOSE(dstack);
105   OP_DECOMPOSE(einsum);
106   m.impl("embedding_backward", native::embedding_backward_symint);
107   OP_DECOMPOSE(expand_as);
108   m.impl("fft_fft", native::fft_fft_symint);
109   OP_DECOMPOSE(fft_fftshift);
110   m.impl("fft_fft2", native::fft_fft2_symint);
111   m.impl("fft_fftn", native::fft_fftn_symint);
112   m.impl("fft_hfft", native::fft_hfft_symint);
113   m.impl("fft_hfft2", native::fft_hfft2_symint);
114   m.impl("fft_hfftn", native::fft_hfftn_symint);
115   m.impl("fft_ifft", native::fft_ifft_symint);
116   OP_DECOMPOSE(fft_ifftshift);
117   m.impl("fft_ifft2", native::fft_ifft2_symint);
118   m.impl("fft_ifftn", native::fft_ifftn_symint);
119   m.impl("fft_ihfft", native::fft_ihfft_symint);
120   m.impl("fft_irfft", native::fft_irfft_symint);
121   m.impl("fft_irfft2", native::fft_irfft2_symint);
122   m.impl("fft_irfftn", native::fft_irfftn_symint);
123   m.impl("fft_rfft", native::fft_rfft_symint);
124   m.impl("fft_rfft2", native::fft_rfft2_symint);
125   m.impl("fft_rfftn", native::fft_rfftn_symint);
126   OP_DECOMPOSE(fix);
127   OP_DECOMPOSE(fliplr);
128   OP_DECOMPOSE(flipud);
129   OP_DECOMPOSE2(flatten, using_ints);
130   OP_DECOMPOSE2(float_power, Tensor_Tensor);
131   OP_DECOMPOSE2(float_power, Tensor_Scalar);
132   OP_DECOMPOSE2(float_power, Scalar);
133   OP_DECOMPOSE(gather_backward);
134   OP_DECOMPOSE(ger);
135   OP_DECOMPOSE2(gradient, scalarint);
136   OP_DECOMPOSE2(gradient, scalararray);
137   OP_DECOMPOSE2(gradient, array);
138   OP_DECOMPOSE2(gradient, scalarrayint);
139   OP_DECOMPOSE2(gradient, scalarrayarray);
140   OP_DECOMPOSE2(gradient, tensorarrayint);
141   OP_DECOMPOSE2(gradient, tensorarray);
142   OP_DECOMPOSE2(greater_equal, Tensor );
143   OP_DECOMPOSE2(greater_equal, Scalar );
144   OP_DECOMPOSE2(greater, Tensor );
145   OP_DECOMPOSE(grid_sampler);
146   OP_DECOMPOSE(group_norm);
147   OP_DECOMPOSE(hinge_embedding_loss);
148   OP_DECOMPOSE2(hsplit, int);
149   OP_DECOMPOSE2(hsplit, array);
150   OP_DECOMPOSE(hstack);
151   m.impl("index_select_backward", native::index_select_backward_symint);
152   OP_DECOMPOSE(inner);
153   OP_DECOMPOSE(inverse);
154   OP_DECOMPOSE(isfinite);
155   OP_DECOMPOSE(isreal);
156   OP_DECOMPOSE(concatenate);
157   OP_DECOMPOSE(instance_norm);
158   OP_DECOMPOSE(kron);
159   OP_DECOMPOSE(l1_loss);
160   m.impl("layer_norm", native::layer_norm_symint);
161   OP_DECOMPOSE2(ldexp, Tensor);
162   OP_DECOMPOSE2(less_equal, Tensor );
163   OP_DECOMPOSE2(less, Tensor );
164   OP_DECOMPOSE(linear);
165   OP_DECOMPOSE(linalg_cond);
166   OP_DECOMPOSE(linalg_cholesky);
167   OP_DECOMPOSE(linalg_det);
168   OP_DECOMPOSE(linalg_eigvalsh);
169   OP_DECOMPOSE(linalg_eigvals);
170   OP_DECOMPOSE(linalg_inv);
171   OP_DECOMPOSE(linalg_lu_factor);
172   OP_DECOMPOSE(linalg_matmul);
173   OP_DECOMPOSE(linalg_matrix_norm);
174   OP_DECOMPOSE2(linalg_matrix_norm, str_ord);
175   OP_DECOMPOSE(linalg_multi_dot);
176   OP_DECOMPOSE(linalg_norm);
177   OP_DECOMPOSE2(linalg_norm, ord_str);
178   OP_DECOMPOSE(linalg_eigh);
179   OP_DECOMPOSE(linalg_solve);
180   OP_DECOMPOSE(linalg_solve_ex);
181   OP_DECOMPOSE(linalg_svd);
182   OP_DECOMPOSE(linalg_svdvals);
183   OP_DECOMPOSE(linalg_pinv);
184   OP_DECOMPOSE(linalg_tensorinv);
185   OP_DECOMPOSE2(linalg_pinv, atol_rtol_float);
186   m.impl("linalg_vander", native::linalg_vander_symint);
187   OP_DECOMPOSE(cumprod_backward);
188   OP_DECOMPOSE(linalg_matrix_power);
189   OP_DECOMPOSE(linalg_vecdot);
190   OP_DECOMPOSE(log_sigmoid);
191   OP_DECOMPOSE(logdet);
192   OP_DECOMPOSE2(log_softmax, int);
193   OP_DECOMPOSE(_lu_with_info);
194   OP_DECOMPOSE(matmul);
195   OP_DECOMPOSE(matrix_H);
196   OP_DECOMPOSE(matrix_power);
197   OP_DECOMPOSE2(max, other );
198   OP_DECOMPOSE(max_pool1d);
199   OP_DECOMPOSE(max_pool1d_with_indices);
200   OP_DECOMPOSE(max_pool2d);
201   OP_DECOMPOSE(max_pool3d);
202   OP_DECOMPOSE(meshgrid);
203   OP_DECOMPOSE2(meshgrid, indexing);
204   OP_DECOMPOSE(mH);
205   OP_DECOMPOSE2(min, other );
206   OP_DECOMPOSE2(moveaxis, intlist);
207   OP_DECOMPOSE2(movedim, int);
208   OP_DECOMPOSE2(movedim, intlist);
209   OP_DECOMPOSE(msort);
210   OP_DECOMPOSE(mT);
211   OP_DECOMPOSE(nanmean);
212   m.impl("narrow", native::narrow_symint);
213   OP_DECOMPOSE(negative);
214   OP_DECOMPOSE2(frobenius_norm, dim);
215   OP_DECOMPOSE2(nuclear_norm, dim);
216   OP_DECOMPOSE(nuclear_norm);
217   m.impl("nll_loss_nd", native::nll_loss_nd_symint);
218   m.impl("nll_loss", native::nll_loss_symint);
219   m.impl("nll_loss2d", native::nll_loss2d_symint);
220   OP_DECOMPOSE2(not_equal, Tensor );
221   OP_DECOMPOSE(outer);
222   OP_DECOMPOSE(pairwise_distance);
223   OP_DECOMPOSE(pinverse);
224   OP_DECOMPOSE(poisson_nll_loss);
225   OP_DECOMPOSE(positive);
226   OP_DECOMPOSE(qr);
227   OP_DECOMPOSE(ravel);
228   m.impl("repeat_interleave.self_int", static_cast<decltype(&ATEN_FN2(repeat_interleave, self_int))>(native::repeat_interleave_symint));
229   m.impl("repeat_interleave.self_Tensor", static_cast<decltype(&ATEN_FN2(repeat_interleave, self_Tensor))>(native::repeat_interleave_symint));
230   m.impl("reshape", native::reshape_symint);
231   OP_DECOMPOSE(resolve_conj);
232   OP_DECOMPOSE(resolve_neg);
233   OP_DECOMPOSE(rms_norm);
234   OP_DECOMPOSE(row_stack);
235   OP_DECOMPOSE(rrelu);
236   OP_DECOMPOSE(rrelu_);
237   OP_DECOMPOSE(relu6);
238   OP_DECOMPOSE(relu6_);
239   OP_DECOMPOSE(prelu);
240   OP_DECOMPOSE2(softmax, int);
241   OP_DECOMPOSE(special_gammainc);
242   OP_DECOMPOSE(special_gammaincc);
243   OP_DECOMPOSE(special_logit);
244   OP_DECOMPOSE(special_log_softmax);
245   OP_DECOMPOSE(special_logsumexp);
246   OP_DECOMPOSE(special_multigammaln);
247   OP_DECOMPOSE(special_polygamma);
248   OP_DECOMPOSE(special_softmax);
249   OP_DECOMPOSE(special_digamma);
250   OP_DECOMPOSE(special_erf);
251   OP_DECOMPOSE(special_erfc);
252   OP_DECOMPOSE(special_erfinv);
253   OP_DECOMPOSE(special_exp2);
254   OP_DECOMPOSE(special_expm1);
255   OP_DECOMPOSE(special_expit);
256   OP_DECOMPOSE(special_gammaln);
257   OP_DECOMPOSE(special_i0);
258   OP_DECOMPOSE(special_log1p);
259   OP_DECOMPOSE(special_ndtr);
260   OP_DECOMPOSE(special_psi);
261   OP_DECOMPOSE(special_round);
262   OP_DECOMPOSE(special_sinc);
263   OP_DECOMPOSE(special_xlogy);
264   OP_DECOMPOSE2(special_xlogy, other_scalar);
265   OP_DECOMPOSE2(special_xlogy, self_scalar);
266 
267 
268   m.impl("split.sizes", native::split_symint);
269   OP_DECOMPOSE(square);
270   OP_DECOMPOSE(numpy_T);
271   OP_DECOMPOSE(reshape_as);
272   OP_DECOMPOSE(slogdet);
273   OP_DECOMPOSE2(result_type, Tensor);
274   OP_DECOMPOSE2(result_type, Scalar);
275   OP_DECOMPOSE2(result_type, Scalar_Tensor);
276   OP_DECOMPOSE2(result_type, Scalar_Scalar);
277   OP_DECOMPOSE(is_same_size);
278   OP_DECOMPOSE(view_as);
279   OP_DECOMPOSE2(size, int);
280   OP_DECOMPOSE(is_complex);
281   OP_DECOMPOSE(std);
282   OP_DECOMPOSE(selu);
283   OP_DECOMPOSE(selu_);
284   OP_DECOMPOSE2(std, dim);
285   OP_DECOMPOSE(std_mean);
286   OP_DECOMPOSE2(std_mean, dim);
287   OP_DECOMPOSE(swapaxes);
288   OP_DECOMPOSE2(subtract, Tensor);
289   m.impl("sum_to_size", native::sum_to_size_symint);
290   OP_DECOMPOSE(svd);
291   OP_DECOMPOSE(swapdims);
292   OP_DECOMPOSE(take_along_dim);
293   OP_DECOMPOSE(tensordot);
294   m.impl("tensor_split.indices", native::tensor_split_indices_symint);
295   m.impl("tensor_split.sections", native::tensor_split_sections_symint);
296   OP_DECOMPOSE(_test_check_tensor);
297   m.impl("tile", native::tile_symint);
298   OP_DECOMPOSE2(trapezoid, x);
299   OP_DECOMPOSE2(trapezoid, dx);
300   OP_DECOMPOSE2(trapz, x);
301   OP_DECOMPOSE2(trapz, dx);
302   OP_DECOMPOSE(unsafe_chunk);
303   m.impl("value_selecting_reduction_backward", native::value_selecting_reduction_backward_symint);
304   OP_DECOMPOSE(var);
305   OP_DECOMPOSE2(var, dim);
306   OP_DECOMPOSE(var_mean);
307   OP_DECOMPOSE2(var_mean, dim);
308   OP_DECOMPOSE2(vsplit, int);
309   OP_DECOMPOSE2(vsplit, array);
310   OP_DECOMPOSE(vstack);
311   OP_DECOMPOSE2(where, ScalarOther);
312   OP_DECOMPOSE2(where, ScalarSelf);
313   OP_DECOMPOSE2(where, Scalar);
314   OP_DECOMPOSE(orgqr);
315   m.impl("unflatten.int", native::unflatten_symint);
316   m.impl("_convolution_double_backward", native::_convolution_double_backward);
317   m.impl("conv_transpose1d", native::conv_transpose1d_symint);
318   m.impl("conv_transpose2d.input", native::conv_transpose2d_symint);
319   m.impl("conv_transpose3d.input", native::conv_transpose3d_symint);
320   m.impl("conv1d", native::conv1d_symint);
321   m.impl("conv2d", native::conv2d_symint);
322   m.impl("conv3d", native::conv3d_symint);
323   m.impl("conv1d.padding", native::conv1d_padding_symint);
324   m.impl("conv2d.padding", native::conv2d_padding_symint);
325   m.impl("conv3d.padding", native::conv3d_padding_symint);
326   m.impl("_convolution_mode", native::_convolution_mode_symint);
327   OP_DECOMPOSE(type_as);
328   OP_DECOMPOSE(linalg_diagonal);
329   OP_DECOMPOSE(diagonal_copy);
330   OP_DECOMPOSE(alias_copy);
331   m.impl("as_strided_copy", native::as_strided_copy_symint);
332   m.impl("pad", native::pad_symint);
333   m.impl("_pad_circular", native::_pad_circular_symint);
334   OP_DECOMPOSE(swapdims_);
335   OP_DECOMPOSE(swapaxes_);
336   OP_DECOMPOSE(unfold_copy);
337   // Easy way to decompose upsample*.vec overloads instead of introducing *_symint methods
338   // if used OP_DECOMPOSE2.
339   m.impl("upsample_bilinear2d.vec", native::upsample_bilinear2d);
340   m.impl("upsample_bicubic2d.vec", native::upsample_bicubic2d);
341   m.impl("_upsample_bilinear2d_aa.vec", native::_upsample_bilinear2d_aa);
342   m.impl("_upsample_bicubic2d_aa.vec", native::_upsample_bicubic2d_aa);
343   m.impl("upsample_linear1d.vec", native::upsample_linear1d);
344   m.impl("upsample_nearest1d.vec", native::upsample_nearest1d);
345   m.impl("upsample_nearest2d.vec", native::upsample_nearest2d);
346   m.impl("upsample_nearest3d.vec", native::upsample_nearest3d);
347   m.impl("upsample_trilinear3d.vec", native::upsample_trilinear3d);
348 
349   // views on complex tensor
350   OP_DECOMPOSE(imag);
351   OP_DECOMPOSE(real);
352 
353   // divide, alias for div
354   OP_DECOMPOSE2(divide, Tensor);
355   OP_DECOMPOSE2(divide_, Tensor);
356   OP_DECOMPOSE2(divide, Scalar);
357   OP_DECOMPOSE2(divide, Tensor_mode);
358   OP_DECOMPOSE2(divide_, Tensor_mode);
359   OP_DECOMPOSE2(divide, Scalar_mode);
360   OP_DECOMPOSE2(divide_, Scalar_mode);
361 
362   // divide, alias for div
363   OP_DECOMPOSE2(true_divide, Tensor);
364   OP_DECOMPOSE2(true_divide_, Tensor);
365   OP_DECOMPOSE2(true_divide, Scalar);
366   OP_DECOMPOSE2(true_divide_, Scalar);
367 
368   // multiply, alias for mul
369   OP_DECOMPOSE2(multiply, Tensor)
370   OP_DECOMPOSE2(multiply_, Tensor)
371   OP_DECOMPOSE2(multiply, Scalar)
372   OP_DECOMPOSE2(multiply_, Scalar)
373 
374   OP_DECOMPOSE2(linalg_matrix_rank, atol_rtol_tensor);
375   OP_DECOMPOSE2(linalg_matrix_rank, atol_rtol_float);
376   OP_DECOMPOSE(linalg_ldl_factor);
377 
378   // comparison ops
379   OP_DECOMPOSE2(greater, Scalar);
380   OP_DECOMPOSE2(less_equal, Scalar);
381   OP_DECOMPOSE2(less, Scalar);
382   OP_DECOMPOSE2(not_equal, Scalar);
383   m.impl("_has_compatible_shallow_copy_type", torch::CppFunction::makeFromBoxedFunction<&unsupportedData>());
384 
385   // to.*
386   OP_DECOMPOSE2(to, device);
387   OP_DECOMPOSE2(to, dtype);
388   OP_DECOMPOSE2(to, dtype_layout);
389   OP_DECOMPOSE2(to, other);
390 
391   // Random ops that are also registered here
392   OP_DECOMPOSE(dropout);
393   OP_DECOMPOSE(_scaled_dot_product_attention_math);
394   OP_DECOMPOSE(scaled_dot_product_attention);
395 }
396 
397 } // namespace at::functorch
398