xref: /aosp_15_r20/external/pytorch/c10/core/impl/TorchDispatchModeTLS.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/core/DispatchKey.h>
2 #include <c10/core/SafePyObject.h>
3 #include <c10/core/impl/LocalDispatchKeySet.h>
4 #include <c10/core/impl/TorchDispatchModeTLS.h>
5 #include <c10/util/irange.h>
6 
7 #include <utility>
8 
9 namespace c10::impl {
10 
11 thread_local TorchDispatchModeTLS torchDispatchModeState;
12 
any_modes_set(bool skip_infra_modes)13 bool TorchDispatchModeTLS::any_modes_set(bool skip_infra_modes) {
14   if (!torchDispatchModeState.stack_.empty())
15     return true;
16   if (!skip_infra_modes) {
17     for (const auto i : c10::irange(
18              static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS))) {
19       if (torchDispatchModeState.infra_modes_[i] != std::nullopt) {
20         return true;
21       }
22     }
23   }
24   return false;
25 }
26 
push_non_infra_mode_onto_stack(std::shared_ptr<PyObject_TorchDispatchMode> mode)27 void TorchDispatchModeTLS::push_non_infra_mode_onto_stack(
28     std::shared_ptr<PyObject_TorchDispatchMode> mode) {
29   if (!any_modes_set()) {
30     c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
31     c10::impl::tls_set_dispatch_key_included(
32         DispatchKey::PythonTLSSnapshot, true);
33   }
34   torchDispatchModeState.stack_.push_back(std::move(mode));
35 }
36 
37 const std::shared_ptr<PyObject_TorchDispatchMode> TorchDispatchModeTLS::
pop_stack()38     pop_stack() {
39   std::shared_ptr<PyObject_TorchDispatchMode> out;
40   if (!torchDispatchModeState.stack_.empty()) {
41     out = torchDispatchModeState.stack_.back();
42     torchDispatchModeState.stack_.pop_back();
43   } else {
44     for (int64_t i =
45              static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS) - 1;
46          i >= 0;
47          --i) {
48       if (torchDispatchModeState.infra_modes_[i].has_value()) {
49         // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
50         out = std::move(torchDispatchModeState.infra_modes_[i].value());
51         torchDispatchModeState.infra_modes_[i] = std::nullopt;
52         break;
53       }
54     }
55   }
56   TORCH_CHECK(out, "trying to pop from empty mode stack");
57   if (!any_modes_set()) {
58     c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
59     c10::impl::tls_set_dispatch_key_included(
60         DispatchKey::PythonTLSSnapshot, false);
61   }
62   return out;
63 }
64 const std::
65     tuple<std::shared_ptr<PyObject_TorchDispatchMode>, TorchDispatchModeKey>
pop_highest_infra_mode()66     TorchDispatchModeTLS::pop_highest_infra_mode() {
67   for (int64_t i = static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS) - 1;
68        i >= 0;
69        --i) {
70     if (torchDispatchModeState.infra_modes_[i].has_value()) {
71       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
72       auto out_mode = torchDispatchModeState.infra_modes_[i].value();
73       torchDispatchModeState.infra_modes_[i] = std::nullopt;
74       if (!any_modes_set()) {
75         c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
76         c10::impl::tls_set_dispatch_key_included(
77             DispatchKey::PythonTLSSnapshot, false);
78       }
79       return std::make_tuple(
80           std::move(out_mode), static_cast<TorchDispatchModeKey>(i));
81     }
82   }
83   TORCH_CHECK(
84       false, "Called pop_highest_infra_mode, but no infra modes were active.")
85 }
86 
87 const std::shared_ptr<PyObject_TorchDispatchMode>& TorchDispatchModeTLS::
get_stack_at(int64_t idx)88     get_stack_at(int64_t idx) {
89   TORCH_CHECK(idx < stack_len(), "Tried to get stack at idx that's too big");
90   // Our "logical" stack includes both:
91   // - any user modes (the entire torchDispatchModeState.stack_)
92   // - any infra modes (members of torchDispatchModeState.infra_modes_ that are
93   // not None)
94 
95   // idx == 0 means the "bottom" of the stack, which starts with any infra
96   // modes (iterating from lowest-priority to highest-priority).
97   auto curr_idx = idx;
98   for (const auto i :
99        c10::irange(static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS))) {
100     if (torchDispatchModeState.infra_modes_[i].has_value()) {
101       if (curr_idx == 0) {
102         // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
103         return torchDispatchModeState.infra_modes_[i].value();
104       }
105       curr_idx -= 1;
106     }
107   }
108   // At this point, we're guaranteed that curr_idx < stack_.size()
109   return torchDispatchModeState.stack_[curr_idx];
110 }
111 
stack_len()112 int64_t TorchDispatchModeTLS::stack_len() {
113   auto stack_len = static_cast<int64_t>(torchDispatchModeState.stack_.size());
114   int64_t infra_modes_len = 0;
115   for (const auto i :
116        c10::irange(static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS))) {
117     if (torchDispatchModeState.infra_modes_[i] != std::nullopt) {
118       infra_modes_len += 1;
119     }
120   }
121   return stack_len + infra_modes_len;
122 }
123 
124 const std::optional<std::shared_ptr<PyObject_TorchDispatchMode>>
get_mode(TorchDispatchModeKey mode_key)125 TorchDispatchModeTLS::get_mode(TorchDispatchModeKey mode_key) {
126   return torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)];
127 }
128 
set_mode(const std::shared_ptr<PyObject_TorchDispatchMode> & mode,TorchDispatchModeKey mode_key)129 void TorchDispatchModeTLS::set_mode(
130     const std::shared_ptr<PyObject_TorchDispatchMode>& mode,
131     TorchDispatchModeKey mode_key) {
132   TORCH_CHECK(
133       torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)] ==
134           std::nullopt,
135       "trying to set the current ",
136       to_string(mode_key),
137       ", but one already exists");
138 
139   if (!any_modes_set()) {
140     c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
141     c10::impl::tls_set_dispatch_key_included(
142         DispatchKey::PythonTLSSnapshot, true);
143   }
144 
145   torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)] = mode;
146 }
147 
148 const std::optional<std::shared_ptr<PyObject_TorchDispatchMode>>
unset_mode(TorchDispatchModeKey mode_key)149 TorchDispatchModeTLS::unset_mode(TorchDispatchModeKey mode_key) {
150   auto out = torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)];
151   torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)] =
152       std::nullopt;
153   if (out.has_value() && !any_modes_set()) {
154     c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
155     c10::impl::tls_set_dispatch_key_included(
156         DispatchKey::PythonTLSSnapshot, false);
157   }
158   return out;
159 }
160 
get_state()161 const TorchDispatchModeTLS& TorchDispatchModeTLS::get_state() {
162   return torchDispatchModeState;
163 }
164 
set_state(TorchDispatchModeTLS state)165 void TorchDispatchModeTLS::set_state(TorchDispatchModeTLS state) {
166   torchDispatchModeState = std::move(state);
167   if (!any_modes_set()) {
168     c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
169     c10::impl::tls_set_dispatch_key_included(
170         DispatchKey::PythonTLSSnapshot, false);
171   } else {
172     c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
173     c10::impl::tls_set_dispatch_key_included(
174         DispatchKey::PythonTLSSnapshot, true);
175   }
176 }
177 
178 // UTIL
179 
dispatch_mode_enabled()180 bool dispatch_mode_enabled() {
181   return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python) &&
182       TorchDispatchModeTLS::any_modes_set();
183 }
184 
to_string(TorchDispatchModeKey mode_key)185 std::string to_string(TorchDispatchModeKey mode_key) {
186   switch (mode_key) {
187     case TorchDispatchModeKey::PROXY:
188       return "ProxyTorchDispatchMode";
189     case TorchDispatchModeKey::FAKE:
190       return "FakeTensorMode";
191     default:
192       return "UNKNOWN_MODE";
193   }
194 }
195 
196 } // namespace c10::impl
197