xref: /aosp_15_r20/external/pytorch/aten/src/ATen/VmapModeRegistrations.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/library.h>
2 #include <ATen/core/boxing/KernelFunction.h>
3 
4 using torch::CppFunction;
5 
6 namespace at {
7 
// Note: [DispatchKey::VmapMode usage]
// Whenever we're inside a vmap, all Tensors dispatch on this key. At the moment,
// this key is used to disable random operations inside of vmap. If you are looking
// for Batching Rules, those are registered with DispatchKey::Batched instead.
//
// Note: [Ambiguity of random operations inside vmap]
// Random operations have an ambiguity: it isn't clear whether they should
// apply the same randomness across the whole batch or different randomness to
// each batch element. For example:
//
// >>> vmap(lambda t: torch.rand(1))(torch.zeros(5))
// Should the above return the same random number 5 times, or a different one
// each time?
//
// We haven't made a decision on that yet, so we are temporarily banning random
// operations inside of vmap while we gather user feedback.
22 
unsupportedRandomOp(Args...args)23 template <typename... Args> Tensor unsupportedRandomOp(Args... args) {
24   TORCH_CHECK(false, "vmap: We do not yet support calling random operations inside of vmap. ",
25               "Please perform random operations outside of vmap as a workaround");
26 }
27 
unsupportedRandomOp_(Args...args)28 template <typename... Args> Tensor& unsupportedRandomOp_(Args... args) {
29   TORCH_CHECK(false, "vmap: We do not yet support calling random operations inside of vmap. ",
30               "Please perform random operations outside of vmap as a workaround");
31 }
32 
TORCH_LIBRARY_IMPL(_, VmapMode, m) {
  // Default for every operator under the VmapMode key: fall through to the next
  // dispatch key. Only the random operators registered in the aten block below
  // get a dedicated (failing) VmapMode kernel.
  m.fallback(CppFunction::makeFallthrough());
}
36 
TORCH_LIBRARY_IMPL(aten,VmapMode,m)37 TORCH_LIBRARY_IMPL(aten, VmapMode, m) {
38   // NB: I'd really like to register a special kernel like
39   // CppFunction::makeNamedNotSupported() to avoid listing out the types of everything.
40   // However, registering e.g. CppFunction::makeNamedNotSupported() as an implementation
41   // only works for operators that support boxing.
42 #define TENSOROPTIONS std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>
43 
44   // random operations (out-of-place)
45   m.impl("bernoulli", unsupportedRandomOp<const Tensor&, std::optional<Generator>>);
46   m.impl("bernoulli.out", unsupportedRandomOp_<const Tensor&, std::optional<Generator>, Tensor&>);
47   m.impl("bernoulli.p", unsupportedRandomOp<const Tensor&, double, std::optional<Generator>>);
48   m.impl("bernoulli_.Tensor", unsupportedRandomOp_<Tensor&, const Tensor&, std::optional<Generator>>);
49   m.impl("bernoulli_.float", unsupportedRandomOp_<Tensor&, double, std::optional<Generator>>);
50 
51   m.impl("cauchy_", unsupportedRandomOp_<Tensor&, double, double, std::optional<Generator>>);
52   m.impl("exponential_", unsupportedRandomOp_<Tensor&, double, std::optional<Generator>>);
53   m.impl("geometric_", unsupportedRandomOp_<Tensor&, double, std::optional<Generator>>);
54   m.impl("log_normal_", unsupportedRandomOp_<Tensor&, double, double, std::optional<Generator>>);
55   m.impl("multinomial", unsupportedRandomOp<const Tensor&, int64_t, bool, std::optional<Generator>>);
56   m.impl("multinomial.out", unsupportedRandomOp_<const Tensor&, int64_t, bool, std::optional<Generator>, Tensor&>);
57 
58   m.impl("normal.Tensor_float", unsupportedRandomOp<const Tensor&, double, std::optional<Generator>>);
59   m.impl("normal.Tensor_float_out", unsupportedRandomOp_<const Tensor&, double, std::optional<Generator>, Tensor&>);
60   m.impl("normal.float_Tensor_out", unsupportedRandomOp_<double, const Tensor&, std::optional<Generator>, Tensor&>);
61   m.impl("normal.float_Tensor", unsupportedRandomOp<double, const Tensor&, std::optional<Generator>>);
62   m.impl("normal.Tensor_Tensor", unsupportedRandomOp<const Tensor&, const Tensor&, std::optional<Generator>>);
63   m.impl("normal.Tensor_Tensor_out", unsupportedRandomOp_<const Tensor&, const Tensor&, std::optional<Generator>, Tensor&>);
64   m.impl("normal.float_float", unsupportedRandomOp<double, double, IntArrayRef, std::optional<Generator>, TENSOROPTIONS>);
65   m.impl("normal.float_float_out", unsupportedRandomOp_<double, double, IntArrayRef, std::optional<Generator>, Tensor&>);
66   m.impl("normal_", unsupportedRandomOp_<Tensor&, double, double, std::optional<Generator>>);
67 
68   m.impl("poisson", unsupportedRandomOp<const Tensor&, std::optional<Generator>>);
69 
70   m.impl("random_.from", unsupportedRandomOp_<Tensor&, int64_t, std::optional<int64_t>, std::optional<Generator>>);
71   m.impl("random_.to", unsupportedRandomOp_<Tensor&, int64_t, std::optional<Generator>>);
72   m.impl("random_", unsupportedRandomOp_<Tensor&, std::optional<Generator>>);
73 
74   m.impl("rand_like", unsupportedRandomOp<const Tensor&, TENSOROPTIONS, std::optional<MemoryFormat>>);
75   m.impl("randn_like", unsupportedRandomOp<const Tensor&, TENSOROPTIONS, std::optional<MemoryFormat>>);
76 
77   m.impl("randint_like", unsupportedRandomOp<const Tensor&, int64_t, TENSOROPTIONS, std::optional<MemoryFormat>>);
78   m.impl("randint_like.low_dtype", unsupportedRandomOp<const Tensor&, int64_t, int64_t, TENSOROPTIONS, std::optional<MemoryFormat>>);
79 
80   m.impl("rand", unsupportedRandomOp<IntArrayRef, TENSOROPTIONS>);
81   m.impl("rand.generator", unsupportedRandomOp<IntArrayRef, std::optional<Generator>, TENSOROPTIONS>);
82   m.impl("rand.names", unsupportedRandomOp<IntArrayRef, std::optional<DimnameList>, TENSOROPTIONS>);
83   m.impl("rand.generator_with_names", unsupportedRandomOp<IntArrayRef, std::optional<Generator>, std::optional<DimnameList>, TENSOROPTIONS>);
84   m.impl("rand.out", unsupportedRandomOp_<IntArrayRef, Tensor&>);
85   m.impl("rand.generator_out", unsupportedRandomOp_<IntArrayRef, std::optional<Generator>, Tensor&>);
86 
87   m.impl("randn", unsupportedRandomOp<IntArrayRef, TENSOROPTIONS>);
88   m.impl("randn.generator", unsupportedRandomOp<IntArrayRef, std::optional<Generator>, TENSOROPTIONS>);
89   m.impl("randn.names", unsupportedRandomOp<IntArrayRef, std::optional<DimnameList>, TENSOROPTIONS>);
90   m.impl("randn.generator_with_names", unsupportedRandomOp<IntArrayRef, std::optional<Generator>, std::optional<DimnameList>, TENSOROPTIONS>);
91   m.impl("randn.out", unsupportedRandomOp_<IntArrayRef, Tensor&>);
92   m.impl("randn.generator_out", unsupportedRandomOp_<IntArrayRef, std::optional<Generator>, Tensor&>);
93 
94   m.impl("randperm", unsupportedRandomOp<int64_t, TENSOROPTIONS>);
95   m.impl("randperm.generator", unsupportedRandomOp<int64_t, std::optional<Generator>, TENSOROPTIONS>);
96   m.impl("randperm.out", unsupportedRandomOp_<int64_t, Tensor&>);
97   m.impl("randperm.generator_out", unsupportedRandomOp_<int64_t, std::optional<Generator>, Tensor&>);
98 
99   m.impl("randint", unsupportedRandomOp<int64_t, IntArrayRef, TENSOROPTIONS>);
100   m.impl("randint.generator", unsupportedRandomOp<int64_t, IntArrayRef, std::optional<Generator>, TENSOROPTIONS>);
101   m.impl("randint.low", unsupportedRandomOp<int64_t, int64_t, IntArrayRef, TENSOROPTIONS>);
102   m.impl("randint.low_generator", unsupportedRandomOp<int64_t, int64_t, IntArrayRef, std::optional<Generator>, TENSOROPTIONS>);
103   m.impl("randint.out", unsupportedRandomOp_<int64_t, IntArrayRef, Tensor&>);
104   m.impl("randint.generator_out", unsupportedRandomOp_<int64_t, IntArrayRef, std::optional<Generator>, Tensor&>);
105   m.impl("randint.low_out", unsupportedRandomOp_<int64_t, int64_t, IntArrayRef, Tensor&>);
106   m.impl("randint.low_generator_out", unsupportedRandomOp_<int64_t, int64_t, IntArrayRef, std::optional<Generator>, Tensor&>);
107 
108   m.impl("uniform_", unsupportedRandomOp_<Tensor&, double, double, std::optional<Generator>>);
109 
110 #undef TENSOROPTIONS
111 }
112 
113 
114 } // namespace at
115