xref: /aosp_15_r20/external/pytorch/aten/src/ATen/functorch/BatchedFallback.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // This source code is licensed under the BSD-style license found in the
5 // LICENSE file in the root directory of this source tree.
6 
7 #pragma once
8 #include <ATen/ATen.h>
9 #include <ATen/core/op_registration/op_registration.h>
10 #include <torch/library.h>
11 
12 namespace at::functorch {
13 
14 // This file contains code for the vmap fallback (also known as the
15 // BatchedTensor fallback or the Batched fallback). This code runs
16 // when an operation doesn't have a batching rule implemented.
17 
18 // If an operator doesn't have a batching rule implemented, then we fall back
19 // to this implementation. The fallback doesn't work on out= variants or
20 // view operations; that is, it works only for out-of-place operations and
21 // in-place non-view operations.
22 //
23 // For out-of-place operations, the fallback effectively takes all of the
24 // BatchedTensors in `stack`, slices them, and runs `op` on all of the
25 // corresponding slices to produce slices of the outputs. The output slices
26 // are then combined with `torch.stack` to create the final
27 // returns.
28 //
29 // The performance of the fallback is not very good because it introduces an
30 // extra copy from stacking the sliced outputs. Because of this, we prefer to
31 // write batching rules for operators whenever possible.
// Boxed fallback kernel: on entry `stack` holds the arguments for `op`;
// afterwards slow_fallback (below) reads the op's results back from the
// front of `stack`.
32 void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
// Same boxed signature as batchedTensorForLoopFallback.
// NOTE(review): presumably the for-loop fallback for nested tensors; the
// name is the only evidence here, so confirm against BatchedFallback.cpp.
33 void batchedNestedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
34
// Same boxed signature as the fallbacks above.
// NOTE(review): presumably the kernel that raises an error rather than
// computing results (cf. setVmapFallbackEnabled below). Confirm.
35 void vmapErrorFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
36 
37 // The vmap fallback emits a warning by default; the warning can be disabled
38 // if the user finds it too noisy.
// Returns true if the vmap fallback warning is currently enabled.
39 TORCH_API bool isVmapFallbackWarningEnabled();
// Turns the vmap fallback warning on (`enabled == true`) or off.
40 TORCH_API void setVmapFallbackWarningEnabled(bool enabled);
41 
42 // Used for testing. The vmap fallback is enabled by default. When the
43 // fallback is disabled, reaching it raises an error.
// Returns true if the vmap fallback is currently enabled.
44 TORCH_API bool isVmapFallbackEnabled();
// Enables (`enabled == true`) or disables the vmap fallback.
45 TORCH_API void setVmapFallbackEnabled(bool enabled);
46 
vector_to_result(const std::vector<IValue> & buffer)47 template <typename A> A vector_to_result(const std::vector<IValue>& buffer) {
48   return buffer[0].to<A>();
49 }
vector_to_result(const std::vector<IValue> & buffer)50 template <typename A, typename B> std::tuple<A, B> vector_to_result(const std::vector<IValue>& buffer) {
51   return std::make_tuple(buffer[0].to<A>(), buffer[1].to<B>());
52 }
vector_to_result(const std::vector<IValue> & buffer)53 template <typename A, typename B, typename C> std::tuple<A, B, C> vector_to_result(const std::vector<IValue>& buffer) {
54   return std::make_tuple(buffer[0].to<A>(), buffer[1].to<B>(), buffer[2].to<B>());
55 }
56 
57 // slow_fallback is a way to call the vmap fallback inside some boxed kernel.
58 // There is probably some better way to metaprogram this.
59 template <typename Ret>
slow_fallback(const c10::OperatorHandle & op,ArrayRef<IValue> args)60 Ret slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
61   std::vector<IValue> stack(args.begin(), args.end());
62   batchedTensorForLoopFallback(op, &stack);
63   return vector_to_result<Ret>(stack);
64 }
65 
66 template <typename A, typename B>
slow_fallback(const c10::OperatorHandle & op,ArrayRef<IValue> args)67 std::tuple<A, B> slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
68   std::vector<IValue> stack(args.begin(), args.end());
69   batchedTensorForLoopFallback(op, &stack);
70   return vector_to_result<A, B>(stack);
71 }
72 
73 template <typename A, typename B, typename C>
slow_fallback(const c10::OperatorHandle & op,ArrayRef<IValue> args)74 std::tuple<A, B, C> slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
75   std::vector<IValue> stack(args.begin(), args.end());
76   batchedTensorForLoopFallback(op, &stack);
77   return vector_to_result<A, B, C>(stack);
78 }
79 
80 
81 } // namespace at::functorch
82