xref: /aosp_15_r20/external/pytorch/aten/src/ATen/functorch/BatchRulesDynamic.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // This source code is licensed under the BSD-style license found in the
5 // LICENSE file in the root directory of this source tree.
6 
7 #include <ATen/ATen.h>
8 #include <ATen/functorch/BatchRulesHelper.h>
9 #include <ATen/functorch/BatchedFallback.h>
10 #include <ATen/core/dispatch/Dispatcher.h>
11 #include <c10/util/Metaprogramming.h>
12 
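// This file contains batching rules for operations that return Tensors of
// dynamic (data-dependent) shape, and for operations that turn tensor data
// into a scalar or bool (e.g. .item(), is_nonzero, allclose). We generally
// don't support either kind with vmap, so the rules here raise errors.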
16 
17 
18 namespace at::functorch {
19 
20 namespace {
unsupportedDynamicOp(const c10::OperatorHandle & op,torch::jit::Stack * stack)21 void unsupportedDynamicOp(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
22     TORCH_CHECK(false, "vmap: We do not support batching operators that can output dynamic shape. ",
23         "Attempted to vmap over ", op.schema().operator_name(), ". ",
24         "Please voice your support in https://github.com/pytorch/functorch/issues/256");
25 }
// Registers `unsupportedDynamicOp` as the kernel for aten operator `op`.
// Relies on a library handle named `m` being in scope (the third argument of
// TORCH_LIBRARY_IMPL in this file).
// The expansion intentionally has no trailing semicolon: call sites are
// written as `UNSUPPORTED_DYNAMIC(foo);`, and an extra `;` here would emit an
// empty statement per use (flagged by clang's -Wextra-semi-stmt).
#define UNSUPPORTED_DYNAMIC(op) \
    m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&unsupportedDynamicOp>())
28 
29 // NB: item and is_nonzero can decompose to this...
unsupportedLocalScalarDense(const c10::OperatorHandle & op,torch::jit::Stack * stack)30 void unsupportedLocalScalarDense(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
31     TORCH_CHECK(false,
32         "vmap: It looks like you're either (1) calling .item() on a Tensor or ",
33         "(2) attempting to use a Tensor in some data-dependent control flow or ",
34         "(3) encountering this error in PyTorch internals. ",
35         "For (1): we don't support vmap over calling .item() on a Tensor, please try to ",
36         "rewrite what you're doing with other operations. ",
37         "For (2): If you're doing some ",
38         "control flow instead, we don't support that yet, please shout over at ",
39         "https://github.com/pytorch/functorch/issues/257 . ",
40         "For (3): please file an issue.");
41 }
42 
unsupportedItem(const c10::OperatorHandle & op,torch::jit::Stack * stack)43 void unsupportedItem(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
44     TORCH_CHECK(false,
45         "vmap: It looks like you're calling .item() on a Tensor. ",
46         "We don't support vmap over calling .item() on a Tensor, please try to ",
47         "rewrite what you're doing with other operations. If error is occurring ",
48         "somewhere inside PyTorch internals, please file a bug report.");
49 }
50 
unsupportedIsNonzero(const c10::OperatorHandle & op,torch::jit::Stack * stack)51 void unsupportedIsNonzero(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
52     TORCH_CHECK(false,
53         "vmap: It looks like you're attempting to use a Tensor in some ",
54         "data-dependent control flow. ",
55         "We don't support that yet, please shout over at ",
56         "https://github.com/pytorch/functorch/issues/257 .");
57 }
58 
unsupportedAllclose(const c10::OperatorHandle & op,torch::jit::Stack * stack)59 void unsupportedAllclose(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
60     TORCH_CHECK(false,
61         "vmap over torch.allclose isn't supported yet. Please voice your ",
62         "support over at github.com/pytorch/functorch/issues/275");
63 }
64 }
65 
// Installs error-raising boxed kernels on the FuncTorchBatched key for aten
// operators that vmap cannot batch, so calling them under vmap raises.
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  // Operators whose output shape depends on input values; all of them share
  // the same generic error kernel, registered in this order.
  for (const char* dynamicOp : {"nonzero",
                                "where",
                                "unique_dim",
                                "unique_consecutive",
                                "unique_dim_consecutive",
                                "_unique2"}) {
    m.impl(dynamicOp, torch::CppFunction::makeFromBoxedFunction<&unsupportedDynamicOp>());
  }

  // Operators that pull a tensor's data out as a scalar or bool; each one
  // gets its own tailored error message.
  m.impl("_local_scalar_dense",
         torch::CppFunction::makeFromBoxedFunction<&unsupportedLocalScalarDense>());
  m.impl("item", torch::CppFunction::makeFromBoxedFunction<&unsupportedItem>());
  m.impl("is_nonzero", torch::CppFunction::makeFromBoxedFunction<&unsupportedIsNonzero>());
  m.impl("allclose", torch::CppFunction::makeFromBoxedFunction<&unsupportedAllclose>());
}
78 
79 } // namespace at::functorch
80