1 #include <ATen/PythonTorchFunctionTLS.h> 2 #include <c10/core/TensorImpl.h> 3 4 namespace at::impl { 5 6 static thread_local PythonTorchFunctionTLS pythonTorchFunctionState; 7 push_onto_stack(std::shared_ptr<SafePyObject> mode)8void PythonTorchFunctionTLS::push_onto_stack(std::shared_ptr<SafePyObject> mode) { 9 pythonTorchFunctionState.stack_.push_back(std::move(mode)); 10 } 11 pop_stack()12const std::shared_ptr<SafePyObject> PythonTorchFunctionTLS::pop_stack() { 13 TORCH_CHECK(!pythonTorchFunctionState.stack_.empty(), "trying to pop from empty mode stack"); 14 auto out = pythonTorchFunctionState.stack_.back(); 15 pythonTorchFunctionState.stack_.pop_back(); 16 return out; 17 } 18 get_stack_at(int64_t idx)19const std::shared_ptr<SafePyObject>& PythonTorchFunctionTLS::get_stack_at(int64_t idx) { 20 TORCH_CHECK(idx < static_cast<int64_t>(pythonTorchFunctionState.stack_.size()), "Tried to get stack at idx that's too big"); 21 return pythonTorchFunctionState.stack_[idx]; 22 } 23 stack_len()24int64_t PythonTorchFunctionTLS::stack_len() { 25 return static_cast<int64_t>(pythonTorchFunctionState.stack_.size()); 26 } 27 set_disabled_state(TorchFunctionDisabledState disabled_state)28void PythonTorchFunctionTLS::set_disabled_state(TorchFunctionDisabledState disabled_state) { 29 pythonTorchFunctionState.disabled_state_ = disabled_state; 30 } 31 get_disabled_state()32TorchFunctionDisabledState PythonTorchFunctionTLS::get_disabled_state() { 33 return pythonTorchFunctionState.disabled_state_; 34 } 35 set_state(const PythonTorchFunctionTLS & state)36void PythonTorchFunctionTLS::set_state(const PythonTorchFunctionTLS& state) { 37 pythonTorchFunctionState = state; 38 } 39 get_state()40const PythonTorchFunctionTLS& PythonTorchFunctionTLS::get_state() { 41 return pythonTorchFunctionState; 42 } 43 torch_function_mode_enabled()44bool torch_function_mode_enabled() { 45 return PythonTorchFunctionTLS::get_disabled_state() != TorchFunctionDisabledState::ALL_DISABLED && 46 PythonTorchFunctionTLS::stack_len() > 0; 47 } 48 49 // This is needed to disambiguate the ternary torch function disabled states torch_function_all_disabled()50bool torch_function_all_disabled() { 51 return PythonTorchFunctionTLS::get_disabled_state() == TorchFunctionDisabledState::ALL_DISABLED; 52 } 53 54 } // namespace at::impl 55