1
2 // Copyright (c) Facebook, Inc. and its affiliates.
3 // All rights reserved.
4 //
5 // This source code is licensed under the BSD-style license found in the
6 // LICENSE file in the root directory of this source tree.
7
8 #include <ATen/FunctionalTensorWrapper.h>
9 #include <ATen/Operators.h>
10 #include <ATen/core/dispatch/Dispatcher.h>
11 #include <ATen/functorch/BatchRulesHelper.h>
12 #include <ATen/functorch/BatchedFallback.h>
13 #include <ATen/functorch/DynamicLayer.h>
14 #include <ATen/functorch/PlumbingHelper.h>
15
16 namespace at::functorch {
17
// Registration helpers used inside the TORCH_LIBRARY_IMPL blocks below; both
// expect a library handle named `m` to be in scope at the point of use.
//
// OP_DECOMPOSE(op)
//   Calls m.impl("<op>", native::<op>), where the schema name is the
//   stringified macro argument. native::op is cast to the pointer type of
//   ATEN_FN(op), so a specific signature is selected even when native::op is
//   an overloaded name.
//   NOTE(review): ATEN_FN / ATEN_FN2 are defined outside this file; they
//   presumably name the generated operator entry point -- confirm.
//
// OP_DECOMPOSE2(op, overload)
//   Same, for an overloaded schema: the registered name is "<op>.<overload>"
//   (built by adjacent string-literal concatenation) and the cast target is
//   the type of ATEN_FN2(op, overload).
//
// Both expansions already end in ';'. A call site may therefore be written
// with or without its own trailing semicolon; writing one merely adds an empty
// statement. Whitespace around a macro argument (e.g. `Tensor )`) does not end
// up in the stringified name.
#define OP_DECOMPOSE(op)  m.impl(#op, static_cast<decltype(&ATEN_FN(op))>(native::op));
#define OP_DECOMPOSE2(op, overload)  m.impl(#op"."#overload, static_cast<decltype(&ATEN_FN2(op, overload))>(native::op));
20
// Registrations for the `aten` library under the FuncTorchVmapMode key.
// Each OP_DECOMPOSE(op) line binds the schema "op" to native::op, i.e. the
// listed dropout-family and scaled-dot-product-attention ops are routed to
// their native implementations for this key.
// NOTE(review): presumably these are handled at the mode key because they are
// random ops (the same three trailing ops are labelled "Random ops" where they
// are registered again for FuncTorchBatchedDecomposition) -- confirm.
TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
  // In-place dropout variants.
  OP_DECOMPOSE(alpha_dropout_);
  OP_DECOMPOSE(dropout_);
  OP_DECOMPOSE(feature_alpha_dropout_);
  OP_DECOMPOSE(feature_dropout_);

  // Also registered for FuncTorchBatchedDecomposition further down this file.
  OP_DECOMPOSE(dropout);
  OP_DECOMPOSE(_scaled_dot_product_attention_math);
  OP_DECOMPOSE(scaled_dot_product_attention);
}
30
unsupportedData(const c10::OperatorHandle & op,torch::jit::Stack * stack)31 static void unsupportedData(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
32 TORCH_CHECK(false, "mutating directly with `.data` under vmap transform is not allowed.");
33 }
34
// Registrations for the `aten` library under the FuncTorchBatchedDecomposition
// key. Every entry binds an operator schema name to an implementation in
// at::native:
//   * OP_DECOMPOSE(op)            -> "op"           bound to native::op
//   * OP_DECOMPOSE2(op, overload) -> "op.overload"  bound to native::op
//   * m.impl("name", native::fn)  -> used directly where the native function
//     has a different name from the schema (mostly the *_symint variants).
// The order of entries is historical; it is preserved as-is.
TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
  OP_DECOMPOSE2(__and__, Scalar);
  OP_DECOMPOSE2(__and__, Tensor);
  OP_DECOMPOSE2(__iand__, Tensor);
  OP_DECOMPOSE2(__iand__, Scalar);
  OP_DECOMPOSE2(__ior__, Tensor);
  OP_DECOMPOSE2(__ior__, Scalar);
  OP_DECOMPOSE2(__ixor__, Tensor);
  OP_DECOMPOSE2(__ixor__, Scalar);
  OP_DECOMPOSE2(__or__, Tensor);
  OP_DECOMPOSE2(__or__, Scalar);
  OP_DECOMPOSE2(__xor__, Tensor);
  OP_DECOMPOSE2(__xor__, Scalar);
  OP_DECOMPOSE(_batch_norm_impl_index);
  OP_DECOMPOSE(absolute);
  OP_DECOMPOSE(absolute_);
  OP_DECOMPOSE(arctan2);
  OP_DECOMPOSE(arctan2_);
  OP_DECOMPOSE(argsort);
  OP_DECOMPOSE2(argsort, stable);
  OP_DECOMPOSE(avg_pool1d);
  OP_DECOMPOSE(adaptive_max_pool1d);
  OP_DECOMPOSE(adaptive_avg_pool1d);
  m.impl("adaptive_avg_pool2d", native::adaptive_avg_pool2d_symint);
  m.impl("adaptive_avg_pool3d", native::adaptive_avg_pool3d_symint);
  OP_DECOMPOSE(adjoint);
  OP_DECOMPOSE(arccos);
  OP_DECOMPOSE(arccos_);
  OP_DECOMPOSE(arccosh);
  OP_DECOMPOSE(arccosh_);
  OP_DECOMPOSE(arcsin);
  OP_DECOMPOSE(arcsin_);
  OP_DECOMPOSE(arcsinh);
  OP_DECOMPOSE(arcsinh_);
  OP_DECOMPOSE(arctan);
  OP_DECOMPOSE(arctan_);
  OP_DECOMPOSE(arctanh);
  OP_DECOMPOSE(arctanh_);
  OP_DECOMPOSE(atleast_1d);
  OP_DECOMPOSE2(atleast_1d, Sequence);
  OP_DECOMPOSE(atleast_2d);
  OP_DECOMPOSE2(atleast_2d, Sequence);
  OP_DECOMPOSE(atleast_3d);
  OP_DECOMPOSE2(atleast_3d, Sequence);
  OP_DECOMPOSE(batch_norm);
  OP_DECOMPOSE(broadcast_tensors);
  m.impl("broadcast_to", native::broadcast_to_symint);
  OP_DECOMPOSE(cartesian_prod);
  OP_DECOMPOSE(cdist);
  OP_DECOMPOSE(chunk);
  OP_DECOMPOSE(clip);
  OP_DECOMPOSE2(clip, Tensor);
  OP_DECOMPOSE(concat);
  OP_DECOMPOSE(conj_physical);
  OP_DECOMPOSE(contiguous);
  OP_DECOMPOSE(combinations);
  OP_DECOMPOSE(corrcoef);
  OP_DECOMPOSE(cosine_embedding_loss);
  OP_DECOMPOSE(cosine_similarity);
  OP_DECOMPOSE(cov);
  OP_DECOMPOSE(cross);
  m.impl("cross_entropy_loss", native::cross_entropy_loss_symint);
  OP_DECOMPOSE2(cumulative_trapezoid, x);
  OP_DECOMPOSE2(cumulative_trapezoid, dx);
  OP_DECOMPOSE2(dsplit, int);
  OP_DECOMPOSE2(dsplit, array);
  OP_DECOMPOSE(det);
  OP_DECOMPOSE(diff);
  OP_DECOMPOSE(diag);
  OP_DECOMPOSE(dstack);
  OP_DECOMPOSE(einsum);
  m.impl("embedding_backward", native::embedding_backward_symint);
  OP_DECOMPOSE(expand_as);
  m.impl("fft_fft", native::fft_fft_symint);
  OP_DECOMPOSE(fft_fftshift);
  m.impl("fft_fft2", native::fft_fft2_symint);
  m.impl("fft_fftn", native::fft_fftn_symint);
  m.impl("fft_hfft", native::fft_hfft_symint);
  m.impl("fft_hfft2", native::fft_hfft2_symint);
  m.impl("fft_hfftn", native::fft_hfftn_symint);
  m.impl("fft_ifft", native::fft_ifft_symint);
  OP_DECOMPOSE(fft_ifftshift);
  m.impl("fft_ifft2", native::fft_ifft2_symint);
  m.impl("fft_ifftn", native::fft_ifftn_symint);
  m.impl("fft_ihfft", native::fft_ihfft_symint);
  m.impl("fft_irfft", native::fft_irfft_symint);
  m.impl("fft_irfft2", native::fft_irfft2_symint);
  m.impl("fft_irfftn", native::fft_irfftn_symint);
  m.impl("fft_rfft", native::fft_rfft_symint);
  m.impl("fft_rfft2", native::fft_rfft2_symint);
  m.impl("fft_rfftn", native::fft_rfftn_symint);
  OP_DECOMPOSE(fix);
  OP_DECOMPOSE(fliplr);
  OP_DECOMPOSE(flipud);
  OP_DECOMPOSE2(flatten, using_ints);
  OP_DECOMPOSE2(float_power, Tensor_Tensor);
  OP_DECOMPOSE2(float_power, Tensor_Scalar);
  OP_DECOMPOSE2(float_power, Scalar);
  OP_DECOMPOSE(gather_backward);
  OP_DECOMPOSE(ger);
  OP_DECOMPOSE2(gradient, scalarint);
  OP_DECOMPOSE2(gradient, scalararray);
  OP_DECOMPOSE2(gradient, array);
  OP_DECOMPOSE2(gradient, scalarrayint);
  OP_DECOMPOSE2(gradient, scalarrayarray);
  OP_DECOMPOSE2(gradient, tensorarrayint);
  OP_DECOMPOSE2(gradient, tensorarray);
  OP_DECOMPOSE2(greater_equal, Tensor);
  OP_DECOMPOSE2(greater_equal, Scalar);
  OP_DECOMPOSE2(greater, Tensor);
  OP_DECOMPOSE(grid_sampler);
  OP_DECOMPOSE(group_norm);
  OP_DECOMPOSE(hinge_embedding_loss);
  OP_DECOMPOSE2(hsplit, int);
  OP_DECOMPOSE2(hsplit, array);
  OP_DECOMPOSE(hstack);
  m.impl("index_select_backward", native::index_select_backward_symint);
  OP_DECOMPOSE(inner);
  OP_DECOMPOSE(inverse);
  OP_DECOMPOSE(isfinite);
  OP_DECOMPOSE(isreal);
  OP_DECOMPOSE(concatenate);
  OP_DECOMPOSE(instance_norm);
  OP_DECOMPOSE(kron);
  OP_DECOMPOSE(l1_loss);
  m.impl("layer_norm", native::layer_norm_symint);
  OP_DECOMPOSE2(ldexp, Tensor);
  OP_DECOMPOSE2(less_equal, Tensor);
  OP_DECOMPOSE2(less, Tensor);
  OP_DECOMPOSE(linear);
  OP_DECOMPOSE(linalg_cond);
  OP_DECOMPOSE(linalg_cholesky);
  OP_DECOMPOSE(linalg_det);
  OP_DECOMPOSE(linalg_eigvalsh);
  OP_DECOMPOSE(linalg_eigvals);
  OP_DECOMPOSE(linalg_inv);
  OP_DECOMPOSE(linalg_lu_factor);
  OP_DECOMPOSE(linalg_matmul);
  OP_DECOMPOSE(linalg_matrix_norm);
  OP_DECOMPOSE2(linalg_matrix_norm, str_ord);
  OP_DECOMPOSE(linalg_multi_dot);
  OP_DECOMPOSE(linalg_norm);
  OP_DECOMPOSE2(linalg_norm, ord_str);
  OP_DECOMPOSE(linalg_eigh);
  OP_DECOMPOSE(linalg_solve);
  OP_DECOMPOSE(linalg_solve_ex);
  OP_DECOMPOSE(linalg_svd);
  OP_DECOMPOSE(linalg_svdvals);
  OP_DECOMPOSE(linalg_pinv);
  OP_DECOMPOSE(linalg_tensorinv);
  OP_DECOMPOSE2(linalg_pinv, atol_rtol_float);
  m.impl("linalg_vander", native::linalg_vander_symint);
  OP_DECOMPOSE(cumprod_backward);
  OP_DECOMPOSE(linalg_matrix_power);
  OP_DECOMPOSE(linalg_vecdot);
  OP_DECOMPOSE(log_sigmoid);
  OP_DECOMPOSE(logdet);
  OP_DECOMPOSE2(log_softmax, int);
  OP_DECOMPOSE(_lu_with_info);
  OP_DECOMPOSE(matmul);
  OP_DECOMPOSE(matrix_H);
  OP_DECOMPOSE(matrix_power);
  OP_DECOMPOSE2(max, other);
  OP_DECOMPOSE(max_pool1d);
  OP_DECOMPOSE(max_pool1d_with_indices);
  OP_DECOMPOSE(max_pool2d);
  OP_DECOMPOSE(max_pool3d);
  OP_DECOMPOSE(meshgrid);
  OP_DECOMPOSE2(meshgrid, indexing);
  OP_DECOMPOSE(mH);
  OP_DECOMPOSE2(min, other);
  OP_DECOMPOSE2(moveaxis, intlist);
  OP_DECOMPOSE2(movedim, int);
  OP_DECOMPOSE2(movedim, intlist);
  OP_DECOMPOSE(msort);
  OP_DECOMPOSE(mT);
  OP_DECOMPOSE(nanmean);
  m.impl("narrow", native::narrow_symint);
  OP_DECOMPOSE(negative);
  OP_DECOMPOSE2(frobenius_norm, dim);
  OP_DECOMPOSE2(nuclear_norm, dim);
  OP_DECOMPOSE(nuclear_norm);
  m.impl("nll_loss_nd", native::nll_loss_nd_symint);
  m.impl("nll_loss", native::nll_loss_symint);
  m.impl("nll_loss2d", native::nll_loss2d_symint);
  OP_DECOMPOSE2(not_equal, Tensor);
  OP_DECOMPOSE(outer);
  OP_DECOMPOSE(pairwise_distance);
  OP_DECOMPOSE(pinverse);
  OP_DECOMPOSE(poisson_nll_loss);
  OP_DECOMPOSE(positive);
  OP_DECOMPOSE(qr);
  OP_DECOMPOSE(ravel);
  // Both repeat_interleave overloads map onto the overloaded native function
  // repeat_interleave_symint; the cast picks the signature for each schema.
  m.impl("repeat_interleave.self_int", static_cast<decltype(&ATEN_FN2(repeat_interleave, self_int))>(native::repeat_interleave_symint));
  m.impl("repeat_interleave.self_Tensor", static_cast<decltype(&ATEN_FN2(repeat_interleave, self_Tensor))>(native::repeat_interleave_symint));
  m.impl("reshape", native::reshape_symint);
  OP_DECOMPOSE(resolve_conj);
  OP_DECOMPOSE(resolve_neg);
  OP_DECOMPOSE(rms_norm);
  OP_DECOMPOSE(row_stack);
  OP_DECOMPOSE(rrelu);
  OP_DECOMPOSE(rrelu_);
  OP_DECOMPOSE(relu6);
  OP_DECOMPOSE(relu6_);
  OP_DECOMPOSE(prelu);
  OP_DECOMPOSE2(softmax, int);
  OP_DECOMPOSE(special_gammainc);
  OP_DECOMPOSE(special_gammaincc);
  OP_DECOMPOSE(special_logit);
  OP_DECOMPOSE(special_log_softmax);
  OP_DECOMPOSE(special_logsumexp);
  OP_DECOMPOSE(special_multigammaln);
  OP_DECOMPOSE(special_polygamma);
  OP_DECOMPOSE(special_softmax);
  OP_DECOMPOSE(special_digamma);
  OP_DECOMPOSE(special_erf);
  OP_DECOMPOSE(special_erfc);
  OP_DECOMPOSE(special_erfinv);
  OP_DECOMPOSE(special_exp2);
  OP_DECOMPOSE(special_expm1);
  OP_DECOMPOSE(special_expit);
  OP_DECOMPOSE(special_gammaln);
  OP_DECOMPOSE(special_i0);
  OP_DECOMPOSE(special_log1p);
  OP_DECOMPOSE(special_ndtr);
  OP_DECOMPOSE(special_psi);
  OP_DECOMPOSE(special_round);
  OP_DECOMPOSE(special_sinc);
  OP_DECOMPOSE(special_xlogy);
  OP_DECOMPOSE2(special_xlogy, other_scalar);
  OP_DECOMPOSE2(special_xlogy, self_scalar);

  m.impl("split.sizes", native::split_symint);
  OP_DECOMPOSE(square);
  OP_DECOMPOSE(numpy_T);
  OP_DECOMPOSE(reshape_as);
  OP_DECOMPOSE(slogdet);
  OP_DECOMPOSE2(result_type, Tensor);
  OP_DECOMPOSE2(result_type, Scalar);
  OP_DECOMPOSE2(result_type, Scalar_Tensor);
  OP_DECOMPOSE2(result_type, Scalar_Scalar);
  OP_DECOMPOSE(is_same_size);
  OP_DECOMPOSE(view_as);
  OP_DECOMPOSE2(size, int);
  OP_DECOMPOSE(is_complex);
  OP_DECOMPOSE(std);
  OP_DECOMPOSE(selu);
  OP_DECOMPOSE(selu_);
  OP_DECOMPOSE2(std, dim);
  OP_DECOMPOSE(std_mean);
  OP_DECOMPOSE2(std_mean, dim);
  OP_DECOMPOSE(swapaxes);
  OP_DECOMPOSE2(subtract, Tensor);
  m.impl("sum_to_size", native::sum_to_size_symint);
  OP_DECOMPOSE(svd);
  OP_DECOMPOSE(swapdims);
  OP_DECOMPOSE(take_along_dim);
  OP_DECOMPOSE(tensordot);
  m.impl("tensor_split.indices", native::tensor_split_indices_symint);
  m.impl("tensor_split.sections", native::tensor_split_sections_symint);
  OP_DECOMPOSE(_test_check_tensor);
  m.impl("tile", native::tile_symint);
  OP_DECOMPOSE2(trapezoid, x);
  OP_DECOMPOSE2(trapezoid, dx);
  OP_DECOMPOSE2(trapz, x);
  OP_DECOMPOSE2(trapz, dx);
  OP_DECOMPOSE(unsafe_chunk);
  m.impl("value_selecting_reduction_backward", native::value_selecting_reduction_backward_symint);
  OP_DECOMPOSE(var);
  OP_DECOMPOSE2(var, dim);
  OP_DECOMPOSE(var_mean);
  OP_DECOMPOSE2(var_mean, dim);
  OP_DECOMPOSE2(vsplit, int);
  OP_DECOMPOSE2(vsplit, array);
  OP_DECOMPOSE(vstack);
  OP_DECOMPOSE2(where, ScalarOther);
  OP_DECOMPOSE2(where, ScalarSelf);
  OP_DECOMPOSE2(where, Scalar);
  OP_DECOMPOSE(orgqr);
  m.impl("unflatten.int", native::unflatten_symint);
  m.impl("_convolution_double_backward", native::_convolution_double_backward);
  m.impl("conv_transpose1d", native::conv_transpose1d_symint);
  m.impl("conv_transpose2d.input", native::conv_transpose2d_symint);
  m.impl("conv_transpose3d.input", native::conv_transpose3d_symint);
  m.impl("conv1d", native::conv1d_symint);
  m.impl("conv2d", native::conv2d_symint);
  m.impl("conv3d", native::conv3d_symint);
  m.impl("conv1d.padding", native::conv1d_padding_symint);
  m.impl("conv2d.padding", native::conv2d_padding_symint);
  m.impl("conv3d.padding", native::conv3d_padding_symint);
  m.impl("_convolution_mode", native::_convolution_mode_symint);
  OP_DECOMPOSE(type_as);
  OP_DECOMPOSE(linalg_diagonal);
  OP_DECOMPOSE(diagonal_copy);
  OP_DECOMPOSE(alias_copy);
  m.impl("as_strided_copy", native::as_strided_copy_symint);
  m.impl("pad", native::pad_symint);
  m.impl("_pad_circular", native::_pad_circular_symint);
  OP_DECOMPOSE(swapdims_);
  OP_DECOMPOSE(swapaxes_);
  OP_DECOMPOSE(unfold_copy);

  // The upsample*.vec overloads are registered with a plain m.impl call:
  // going through OP_DECOMPOSE2 would require introducing *_symint methods.
  m.impl("upsample_bilinear2d.vec", native::upsample_bilinear2d);
  m.impl("upsample_bicubic2d.vec", native::upsample_bicubic2d);
  m.impl("_upsample_bilinear2d_aa.vec", native::_upsample_bilinear2d_aa);
  m.impl("_upsample_bicubic2d_aa.vec", native::_upsample_bicubic2d_aa);
  m.impl("upsample_linear1d.vec", native::upsample_linear1d);
  m.impl("upsample_nearest1d.vec", native::upsample_nearest1d);
  m.impl("upsample_nearest2d.vec", native::upsample_nearest2d);
  m.impl("upsample_nearest3d.vec", native::upsample_nearest3d);
  m.impl("upsample_trilinear3d.vec", native::upsample_trilinear3d);

  // views on complex tensor
  OP_DECOMPOSE(imag);
  OP_DECOMPOSE(real);

  // divide, alias for div
  OP_DECOMPOSE2(divide, Tensor);
  OP_DECOMPOSE2(divide_, Tensor);
  OP_DECOMPOSE2(divide, Scalar);
  OP_DECOMPOSE2(divide, Tensor_mode);
  OP_DECOMPOSE2(divide_, Tensor_mode);
  OP_DECOMPOSE2(divide, Scalar_mode);
  OP_DECOMPOSE2(divide_, Scalar_mode);

  // true_divide, alias for div
  OP_DECOMPOSE2(true_divide, Tensor);
  OP_DECOMPOSE2(true_divide_, Tensor);
  OP_DECOMPOSE2(true_divide, Scalar);
  OP_DECOMPOSE2(true_divide_, Scalar);

  // multiply, alias for mul
  OP_DECOMPOSE2(multiply, Tensor);
  OP_DECOMPOSE2(multiply_, Tensor);
  OP_DECOMPOSE2(multiply, Scalar);
  OP_DECOMPOSE2(multiply_, Scalar);

  OP_DECOMPOSE2(linalg_matrix_rank, atol_rtol_tensor);
  OP_DECOMPOSE2(linalg_matrix_rank, atol_rtol_float);
  OP_DECOMPOSE(linalg_ldl_factor);

  // comparison ops
  OP_DECOMPOSE2(greater, Scalar);
  OP_DECOMPOSE2(less_equal, Scalar);
  OP_DECOMPOSE2(less, Scalar);
  OP_DECOMPOSE2(not_equal, Scalar);

  // Not a decomposition: this op is bound to the boxed unsupportedData kernel,
  // which always raises the "mutating directly with `.data`" error.
  m.impl("_has_compatible_shallow_copy_type", torch::CppFunction::makeFromBoxedFunction<&unsupportedData>());

  // to.*
  OP_DECOMPOSE2(to, device);
  OP_DECOMPOSE2(to, dtype);
  OP_DECOMPOSE2(to, dtype_layout);
  OP_DECOMPOSE2(to, other);

  // Random ops that are also registered here
  OP_DECOMPOSE(dropout);
  OP_DECOMPOSE(_scaled_dot_product_attention_math);
  OP_DECOMPOSE(scaled_dot_product_attention);
}
396
397 } // namespace at::functorch
398