1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // This source code is licensed under the BSD-style license found in the
5 // LICENSE file in the root directory of this source tree.
6
7 #include <torch/library.h>
8 #include <ATen/ATen.h>
9 #include <ATen/functorch/LegacyVmapTransforms.h>
10 #include <ATen/functorch/BatchedTensorImpl.h>
11 #include <ATen/functorch/PlumbingHelper.h>
12 #include <ATen/functorch/DynamicLayer.h>
13 #include <ATen/core/dispatch/Dispatcher.h>
14
// functorch's vmap is implemented by two dispatch keys: FuncTorchBatched and
// FuncTorchVmapMode. This file contains the FuncTorchVmapMode registrations:
// a fallthrough fallback for every operator, plus kernels that raise an error
// for operations we do not support inside vmap, even when they are called on
// regular (non-batched) Tensors.
19
20 namespace at::functorch {
21
unsupportedRandomOp(const c10::OperatorHandle & op,torch::jit::Stack * stack)22 static void unsupportedRandomOp(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
23 TORCH_CHECK(false, "vmap: We do not support calling out variants of random operations inside of vmap. ",
24 "Please use non-out variants as a workaround");
25 }
26
// Catch-all registration for the FuncTorchVmapMode key: every operator that
// has no specific FuncTorchVmapMode kernel falls through this key to the next
// one in dispatch order. Only the operators registered explicitly in the
// `aten` block of this file are intercepted.
TORCH_LIBRARY_IMPL(_, FuncTorchVmapMode, m) {
  auto passThrough = torch::CppFunction::makeFallthrough();
  m.fallback(std::move(passThrough));
}
30
nyiRandomOp(const c10::OperatorHandle & op,torch::jit::Stack * stack)31 static void nyiRandomOp(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
32 TORCH_CHECK(false, "vmap: we do not yet support ", op.schema().operator_name(),
33 ". Please file an issue");
34 }
35
// Registration helpers for the FuncTorchVmapMode kernels below. Each one
// expands to a single `m.impl(...)` expression, so it can only be used inside
// a TORCH_LIBRARY_IMPL body (which provides `m`). None of them ends with a
// semicolon: the call site supplies it, so `UNSUPPORTED_RANDOM2(x, y);` is one
// statement rather than a statement followed by an empty `;`.
//
// VMAP_MODE_REGISTER_BOXED(name, fn): register the boxed function `fn` as the
// kernel for the operator named by the string `name` ("op" or "op.overload").
#define VMAP_MODE_REGISTER_BOXED(name, fn) \
  m.impl(name, torch::CppFunction::makeFromBoxedFunction<&fn>())

// Register unsupportedRandomOp (always-error for out= random variants).
#define UNSUPPORTED_RANDOM(op) \
  VMAP_MODE_REGISTER_BOXED(#op, unsupportedRandomOp)

#define UNSUPPORTED_RANDOM2(op, overload) \
  VMAP_MODE_REGISTER_BOXED(#op "." #overload, unsupportedRandomOp)

// Register nyiRandomOp (not-yet-implemented error naming the operator).
#define NYI_RANDOM(op) \
  VMAP_MODE_REGISTER_BOXED(#op, nyiRandomOp)

#define NYI_RANDOM2(op, overload) \
  VMAP_MODE_REGISTER_BOXED(#op "." #overload, nyiRandomOp)
47
// FuncTorchVmapMode kernels for specific aten overloads. For the overloads
// listed here, these kernels are used instead of the fallthrough fallback, so
// calling them inside vmap raises an error rather than silently running.
TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
  // out= variants of random operations are never supported under vmap.
  UNSUPPORTED_RANDOM2(bernoulli, out);
  UNSUPPORTED_RANDOM2(rand, generator_out);
  UNSUPPORTED_RANDOM2(rand, out);
  UNSUPPORTED_RANDOM2(randint, generator_out);
  UNSUPPORTED_RANDOM2(randint, out);
  // randint(low, high, size, out=...): the low/high overloads are out=
  // variants too; without these entries they fell through instead of raising.
  UNSUPPORTED_RANDOM2(randint, low_generator_out);
  UNSUPPORTED_RANDOM2(randint, low_out);
  UNSUPPORTED_RANDOM2(randn, generator_out);
  UNSUPPORTED_RANDOM2(randn, out);
  UNSUPPORTED_RANDOM2(randperm, generator_out);
  UNSUPPORTED_RANDOM2(randperm, out);
  UNSUPPORTED_RANDOM2(multinomial, out);
  // All four (mean, std) combinations of normal's out= overloads; the
  // Tensor-mean/float-std one was previously missing.
  UNSUPPORTED_RANDOM2(normal, float_Tensor_out);
  UNSUPPORTED_RANDOM2(normal, Tensor_Tensor_out);
  UNSUPPORTED_RANDOM2(normal, Tensor_float_out);
  UNSUPPORTED_RANDOM2(normal, float_float_out);
  UNSUPPORTED_RANDOM2(rrelu_with_noise, out);

  // Random operations vmap does not implement yet.
  NYI_RANDOM(rrelu_with_noise);
  NYI_RANDOM(rrelu_with_noise_);
  NYI_RANDOM(rrelu_);
  NYI_RANDOM(rrelu);
}
69
70 } // namespace at::functorch
71