xref: /aosp_15_r20/external/pytorch/aten/src/ATen/PythonTorchFunctionTLS.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/PythonTorchFunctionTLS.h>
2 #include <c10/core/TensorImpl.h>
3 
4 namespace at::impl {
5 
6 static thread_local PythonTorchFunctionTLS pythonTorchFunctionState;
7 
push_onto_stack(std::shared_ptr<SafePyObject> mode)8 void PythonTorchFunctionTLS::push_onto_stack(std::shared_ptr<SafePyObject> mode) {
9   pythonTorchFunctionState.stack_.push_back(std::move(mode));
10 }
11 
pop_stack()12 const std::shared_ptr<SafePyObject> PythonTorchFunctionTLS::pop_stack() {
13   TORCH_CHECK(!pythonTorchFunctionState.stack_.empty(), "trying to pop from empty mode stack");
14   auto out = pythonTorchFunctionState.stack_.back();
15   pythonTorchFunctionState.stack_.pop_back();
16   return out;
17 }
18 
get_stack_at(int64_t idx)19 const std::shared_ptr<SafePyObject>& PythonTorchFunctionTLS::get_stack_at(int64_t idx) {
20   TORCH_CHECK(idx < static_cast<int64_t>(pythonTorchFunctionState.stack_.size()), "Tried to get stack at idx that's too big");
21   return pythonTorchFunctionState.stack_[idx];
22 }
23 
stack_len()24 int64_t PythonTorchFunctionTLS::stack_len() {
25   return static_cast<int64_t>(pythonTorchFunctionState.stack_.size());
26 }
27 
set_disabled_state(TorchFunctionDisabledState disabled_state)28 void PythonTorchFunctionTLS::set_disabled_state(TorchFunctionDisabledState disabled_state) {
29   pythonTorchFunctionState.disabled_state_ = disabled_state;
30 }
31 
get_disabled_state()32 TorchFunctionDisabledState PythonTorchFunctionTLS::get_disabled_state() {
33   return pythonTorchFunctionState.disabled_state_;
34 }
35 
set_state(const PythonTorchFunctionTLS & state)36 void PythonTorchFunctionTLS::set_state(const PythonTorchFunctionTLS& state) {
37   pythonTorchFunctionState = state;
38 }
39 
get_state()40 const PythonTorchFunctionTLS& PythonTorchFunctionTLS::get_state() {
41   return pythonTorchFunctionState;
42 }
43 
torch_function_mode_enabled()44 bool torch_function_mode_enabled() {
45   return PythonTorchFunctionTLS::get_disabled_state() != TorchFunctionDisabledState::ALL_DISABLED &&
46          PythonTorchFunctionTLS::stack_len() > 0;
47 }
48 
49 // This is needed to disambiguate the ternary torch function disabled states
torch_function_all_disabled()50 bool torch_function_all_disabled() {
51   return PythonTorchFunctionTLS::get_disabled_state() == TorchFunctionDisabledState::ALL_DISABLED;
52 }
53 
54 } // namespace at::impl
55