// Source: /aosp_15_r20/external/pytorch/aten/src/ATen/LegacyVmapMode.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <c10/core/impl/LocalDispatchKeySet.h>

#include <cstdint>

5*da0073e9SAndroid Build Coastguard Worker namespace at::impl {
6*da0073e9SAndroid Build Coastguard Worker 
7*da0073e9SAndroid Build Coastguard Worker // VmapMode contains a thread local count of how many nested vmaps
8*da0073e9SAndroid Build Coastguard Worker // we are currently inside. That number is known as the `vmap level`.
9*da0073e9SAndroid Build Coastguard Worker // VmapMode is used in the implementation of the Python `torch.vmap` API.
10*da0073e9SAndroid Build Coastguard Worker //
11*da0073e9SAndroid Build Coastguard Worker // NOTE: this is NOT the c++ api for torch.vmap. That doesn't exist yet.
12*da0073e9SAndroid Build Coastguard Worker 
13*da0073e9SAndroid Build Coastguard Worker struct TORCH_API VmapMode {
14*da0073e9SAndroid Build Coastguard Worker   // Returns the vmap level, aka the count of how many nested vmaps we're in.
15*da0073e9SAndroid Build Coastguard Worker   static int64_t current_vmap_level();
16*da0073e9SAndroid Build Coastguard Worker 
17*da0073e9SAndroid Build Coastguard Worker   // Increment the count of nested vmaps. If this causes the vmap level to be
18*da0073e9SAndroid Build Coastguard Worker   // greater than 0, then it enables DispatchKey::VmapMode on all tensors.
19*da0073e9SAndroid Build Coastguard Worker   static int64_t increment_nesting();
20*da0073e9SAndroid Build Coastguard Worker 
21*da0073e9SAndroid Build Coastguard Worker   // Decrements the count of nested vmaps. If this causes the vmap level to be
22*da0073e9SAndroid Build Coastguard Worker   // equal to 0, then it disables DispatchKey::VmapMode on all tensors.
23*da0073e9SAndroid Build Coastguard Worker   static int64_t decrement_nesting();
24*da0073e9SAndroid Build Coastguard Worker };
25*da0073e9SAndroid Build Coastguard Worker 
26*da0073e9SAndroid Build Coastguard Worker } // namespace at::impl
27