xref: /aosp_15_r20/external/pytorch/aten/src/ATen/LegacyBatchedFallback.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker #include <ATen/ATen.h>
3*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/op_registration/op_registration.h>
4*da0073e9SAndroid Build Coastguard Worker #include <torch/library.h>
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker namespace at {
7*da0073e9SAndroid Build Coastguard Worker 
8*da0073e9SAndroid Build Coastguard Worker // If an operator doesn't have a batching rule implemented then we fallback
9*da0073e9SAndroid Build Coastguard Worker // to this implementation. The fallback only works on out-of-place operators
10*da0073e9SAndroid Build Coastguard Worker // that return only tensors with new memory. (e.g., no in-place operators, no
11*da0073e9SAndroid Build Coastguard Worker // view operations).
12*da0073e9SAndroid Build Coastguard Worker //
13*da0073e9SAndroid Build Coastguard Worker // The fallback effectively takes all of the BatchedTensors in `stack`, slices
14*da0073e9SAndroid Build Coastguard Worker // them, and runs `op` on all of the corresponding slices to produce slices
15*da0073e9SAndroid Build Coastguard Worker // of the outputs. The output slices then get `torch.stack`ed to create the
16*da0073e9SAndroid Build Coastguard Worker // final returns.
17*da0073e9SAndroid Build Coastguard Worker //
18*da0073e9SAndroid Build Coastguard Worker // The performance of the fallback is not very good because it introduces an
19*da0073e9SAndroid Build Coastguard Worker // extra copy from stacking the sliced outputs. Because of this, we prefer to
20*da0073e9SAndroid Build Coastguard Worker // write batching rules for operators whenever possible.
21*da0073e9SAndroid Build Coastguard Worker void batchedTensorForLoopFallback(
22*da0073e9SAndroid Build Coastguard Worker     const c10::OperatorHandle& op,
23*da0073e9SAndroid Build Coastguard Worker     torch::jit::Stack* stack);
24*da0073e9SAndroid Build Coastguard Worker 
25*da0073e9SAndroid Build Coastguard Worker } // namespace at
26