1*da0073e9SAndroid Build Coastguard Worker #pragma once 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Worker #include <c10/core/impl/LocalDispatchKeySet.h> 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Worker namespace at::impl { 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Worker // VmapMode contains a thread local count of how many nested vmaps 8*da0073e9SAndroid Build Coastguard Worker // we are currently inside. That number is known as the `vmap level`. 9*da0073e9SAndroid Build Coastguard Worker // VmapMode is used in the implementation of the Python `torch.vmap` API. 10*da0073e9SAndroid Build Coastguard Worker // 11*da0073e9SAndroid Build Coastguard Worker // NOTE: this is NOT the c++ api for torch.vmap. That doesn't exist yet. 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Worker struct TORCH_API VmapMode { 14*da0073e9SAndroid Build Coastguard Worker // Returns the vmap level, aka the count of how many nested vmaps we're in. 15*da0073e9SAndroid Build Coastguard Worker static int64_t current_vmap_level(); 16*da0073e9SAndroid Build Coastguard Worker 17*da0073e9SAndroid Build Coastguard Worker // Increment the count of nested vmaps. If this causes the vmap level to be 18*da0073e9SAndroid Build Coastguard Worker // greater than 0, then it enables DispatchKey::VmapMode on all tensors. 19*da0073e9SAndroid Build Coastguard Worker static int64_t increment_nesting(); 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Worker // Decrements the count of nested vmaps. If this causes the vmap level to be 22*da0073e9SAndroid Build Coastguard Worker // equal to 0, then it disables DispatchKey::VmapMode on all tensors. 23*da0073e9SAndroid Build Coastguard Worker static int64_t decrement_nesting(); 24*da0073e9SAndroid Build Coastguard Worker }; 25*da0073e9SAndroid Build Coastguard Worker 26*da0073e9SAndroid Build Coastguard Worker } // namespace at::impl 27