xref: /aosp_15_r20/external/pytorch/aten/src/ATen/functorch/VmapModeRegistrations.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // This source code is licensed under the BSD-style license found in the
5 // LICENSE file in the root directory of this source tree.
6 
7 #include <torch/library.h>
8 #include <ATen/ATen.h>
9 #include <ATen/functorch/LegacyVmapTransforms.h>
10 #include <ATen/functorch/BatchedTensorImpl.h>
11 #include <ATen/functorch/PlumbingHelper.h>
12 #include <ATen/functorch/DynamicLayer.h>
13 #include <ATen/core/dispatch/Dispatcher.h>
14 
// functorch's vmap is implemented by two dispatch keys: FuncTorchBatched and
// FuncTorchVmapMode. This file contains the FuncTorchVmapMode registrations:
// every operator falls through by default, and a handful of random operations
// are overridden to raise an error because vmap does not support them on
// regular Tensors.
19 
20 namespace at::functorch {
21 
unsupportedRandomOp(const c10::OperatorHandle & op,torch::jit::Stack * stack)22 static void unsupportedRandomOp(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
23   TORCH_CHECK(false, "vmap: We do not support calling out variants of random operations inside of vmap. ",
24               "Please use non-out variants as a workaround");
25 }
26 
TORCH_LIBRARY_IMPL(_, FuncTorchVmapMode, m) {
  // Catch-all for every namespace: by default an operator that reaches the
  // FuncTorchVmapMode key simply falls through to the next key. Only operators
  // given an explicit FuncTorchVmapMode kernel are intercepted.
  auto fallthrough_kernel = torch::CppFunction::makeFallthrough();
  m.fallback(std::move(fallthrough_kernel));
}
30 
nyiRandomOp(const c10::OperatorHandle & op,torch::jit::Stack * stack)31 static void nyiRandomOp(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
32   TORCH_CHECK(false, "vmap: we do not yet support ", op.schema().operator_name(),
33               ". Please file an issue");
34 }
35 
// Registration helpers for the aten FuncTorchVmapMode library. Each expands to
// a single m.impl(...) call. The call binds the operator named by `op`
// (or "op.overload") to one of this file's erroring boxed kernels:
//   UNSUPPORTED_RANDOM / UNSUPPORTED_RANDOM2 -> unsupportedRandomOp
//   NYI_RANDOM / NYI_RANDOM2                 -> nyiRandomOp
// The helpers assume a torch::Library named `m` is in scope, as it is inside a
// TORCH_LIBRARY_IMPL body.
// The expansions deliberately carry no trailing ';'. The call site supplies it,
// so `NYI_RANDOM(foo);` is one statement rather than a call followed by an
// empty statement.
#define UNSUPPORTED_RANDOM(op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&unsupportedRandomOp>())

#define UNSUPPORTED_RANDOM2(op, overload) \
  m.impl(#op "." #overload, torch::CppFunction::makeFromBoxedFunction<&unsupportedRandomOp>())

#define NYI_RANDOM(op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&nyiRandomOp>())

#define NYI_RANDOM2(op, overload) \
  m.impl(#op "." #overload, torch::CppFunction::makeFromBoxedFunction<&nyiRandomOp>())
47 
TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
  // out= overloads of random operations. Under vmap these are rejected
  // outright, and the error points the user at the non-out variant.
  const char* const kOutVariantsOfRandomOps[] = {
      "bernoulli.out",
      "rand.generator_out",
      "rand.out",
      "randint.generator_out",
      "randint.out",
      "randn.generator_out",
      "randn.out",
      "randperm.generator_out",
      "randperm.out",
      "multinomial.out",
      "normal.float_Tensor_out",
      "normal.Tensor_Tensor_out",
      "normal.float_float_out",
      "rrelu_with_noise.out",
  };
  for (const char* qualified_name : kOutVariantsOfRandomOps) {
    m.impl(qualified_name, torch::CppFunction::makeFromBoxedFunction<&unsupportedRandomOp>());
  }

  // rrelu-family operations that vmap does not handle yet. Calling one of them
  // under vmap raises an error that names the operator.
  const char* const kNotYetSupportedRandomOps[] = {
      "rrelu_with_noise",
      "rrelu_with_noise_",
      "rrelu_",
      "rrelu",
  };
  for (const char* qualified_name : kNotYetSupportedRandomOps) {
    m.impl(qualified_name, torch::CppFunction::makeFromBoxedFunction<&nyiRandomOp>());
  }
}
69 
70 } // namespace at::functorch
71