1 #include <c10/core/DispatchKey.h>
2 #include <c10/core/SafePyObject.h>
3 #include <c10/core/impl/LocalDispatchKeySet.h>
4 #include <c10/core/impl/TorchDispatchModeTLS.h>
5 #include <c10/util/irange.h>
6
7 #include <utility>
8
9 namespace c10::impl {
10
// Per-thread dispatch mode state. It holds the user mode stack (stack_) and
// the per-key infra mode slots (infra_modes_). Every static member of
// TorchDispatchModeTLS defined below reads or writes this object.
thread_local TorchDispatchModeTLS torchDispatchModeState;
12
any_modes_set(bool skip_infra_modes)13 bool TorchDispatchModeTLS::any_modes_set(bool skip_infra_modes) {
14 if (!torchDispatchModeState.stack_.empty())
15 return true;
16 if (!skip_infra_modes) {
17 for (const auto i : c10::irange(
18 static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS))) {
19 if (torchDispatchModeState.infra_modes_[i] != std::nullopt) {
20 return true;
21 }
22 }
23 }
24 return false;
25 }
26
push_non_infra_mode_onto_stack(std::shared_ptr<PyObject_TorchDispatchMode> mode)27 void TorchDispatchModeTLS::push_non_infra_mode_onto_stack(
28 std::shared_ptr<PyObject_TorchDispatchMode> mode) {
29 if (!any_modes_set()) {
30 c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
31 c10::impl::tls_set_dispatch_key_included(
32 DispatchKey::PythonTLSSnapshot, true);
33 }
34 torchDispatchModeState.stack_.push_back(std::move(mode));
35 }
36
37 const std::shared_ptr<PyObject_TorchDispatchMode> TorchDispatchModeTLS::
pop_stack()38 pop_stack() {
39 std::shared_ptr<PyObject_TorchDispatchMode> out;
40 if (!torchDispatchModeState.stack_.empty()) {
41 out = torchDispatchModeState.stack_.back();
42 torchDispatchModeState.stack_.pop_back();
43 } else {
44 for (int64_t i =
45 static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS) - 1;
46 i >= 0;
47 --i) {
48 if (torchDispatchModeState.infra_modes_[i].has_value()) {
49 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
50 out = std::move(torchDispatchModeState.infra_modes_[i].value());
51 torchDispatchModeState.infra_modes_[i] = std::nullopt;
52 break;
53 }
54 }
55 }
56 TORCH_CHECK(out, "trying to pop from empty mode stack");
57 if (!any_modes_set()) {
58 c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
59 c10::impl::tls_set_dispatch_key_included(
60 DispatchKey::PythonTLSSnapshot, false);
61 }
62 return out;
63 }
64 const std::
65 tuple<std::shared_ptr<PyObject_TorchDispatchMode>, TorchDispatchModeKey>
pop_highest_infra_mode()66 TorchDispatchModeTLS::pop_highest_infra_mode() {
67 for (int64_t i = static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS) - 1;
68 i >= 0;
69 --i) {
70 if (torchDispatchModeState.infra_modes_[i].has_value()) {
71 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
72 auto out_mode = torchDispatchModeState.infra_modes_[i].value();
73 torchDispatchModeState.infra_modes_[i] = std::nullopt;
74 if (!any_modes_set()) {
75 c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
76 c10::impl::tls_set_dispatch_key_included(
77 DispatchKey::PythonTLSSnapshot, false);
78 }
79 return std::make_tuple(
80 std::move(out_mode), static_cast<TorchDispatchModeKey>(i));
81 }
82 }
83 TORCH_CHECK(
84 false, "Called pop_highest_infra_mode, but no infra modes were active.")
85 }
86
87 const std::shared_ptr<PyObject_TorchDispatchMode>& TorchDispatchModeTLS::
get_stack_at(int64_t idx)88 get_stack_at(int64_t idx) {
89 TORCH_CHECK(idx < stack_len(), "Tried to get stack at idx that's too big");
90 // Our "logical" stack includes both:
91 // - any user modes (the entire torchDispatchModeState.stack_)
92 // - any infra modes (members of torchDispatchModeState.infra_modes_ that are
93 // not None)
94
95 // idx == 0 means the "bottom" of the stack, which starts with any infra
96 // modes (iterating from lowest-priority to highest-priority).
97 auto curr_idx = idx;
98 for (const auto i :
99 c10::irange(static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS))) {
100 if (torchDispatchModeState.infra_modes_[i].has_value()) {
101 if (curr_idx == 0) {
102 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
103 return torchDispatchModeState.infra_modes_[i].value();
104 }
105 curr_idx -= 1;
106 }
107 }
108 // At this point, we're guaranteed that curr_idx < stack_.size()
109 return torchDispatchModeState.stack_[curr_idx];
110 }
111
stack_len()112 int64_t TorchDispatchModeTLS::stack_len() {
113 auto stack_len = static_cast<int64_t>(torchDispatchModeState.stack_.size());
114 int64_t infra_modes_len = 0;
115 for (const auto i :
116 c10::irange(static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS))) {
117 if (torchDispatchModeState.infra_modes_[i] != std::nullopt) {
118 infra_modes_len += 1;
119 }
120 }
121 return stack_len + infra_modes_len;
122 }
123
124 const std::optional<std::shared_ptr<PyObject_TorchDispatchMode>>
get_mode(TorchDispatchModeKey mode_key)125 TorchDispatchModeTLS::get_mode(TorchDispatchModeKey mode_key) {
126 return torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)];
127 }
128
set_mode(const std::shared_ptr<PyObject_TorchDispatchMode> & mode,TorchDispatchModeKey mode_key)129 void TorchDispatchModeTLS::set_mode(
130 const std::shared_ptr<PyObject_TorchDispatchMode>& mode,
131 TorchDispatchModeKey mode_key) {
132 TORCH_CHECK(
133 torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)] ==
134 std::nullopt,
135 "trying to set the current ",
136 to_string(mode_key),
137 ", but one already exists");
138
139 if (!any_modes_set()) {
140 c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
141 c10::impl::tls_set_dispatch_key_included(
142 DispatchKey::PythonTLSSnapshot, true);
143 }
144
145 torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)] = mode;
146 }
147
148 const std::optional<std::shared_ptr<PyObject_TorchDispatchMode>>
unset_mode(TorchDispatchModeKey mode_key)149 TorchDispatchModeTLS::unset_mode(TorchDispatchModeKey mode_key) {
150 auto out = torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)];
151 torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)] =
152 std::nullopt;
153 if (out.has_value() && !any_modes_set()) {
154 c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
155 c10::impl::tls_set_dispatch_key_included(
156 DispatchKey::PythonTLSSnapshot, false);
157 }
158 return out;
159 }
160
get_state()161 const TorchDispatchModeTLS& TorchDispatchModeTLS::get_state() {
162 return torchDispatchModeState;
163 }
164
set_state(TorchDispatchModeTLS state)165 void TorchDispatchModeTLS::set_state(TorchDispatchModeTLS state) {
166 torchDispatchModeState = std::move(state);
167 if (!any_modes_set()) {
168 c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
169 c10::impl::tls_set_dispatch_key_included(
170 DispatchKey::PythonTLSSnapshot, false);
171 } else {
172 c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
173 c10::impl::tls_set_dispatch_key_included(
174 DispatchKey::PythonTLSSnapshot, true);
175 }
176 }
177
178 // UTIL
179
dispatch_mode_enabled()180 bool dispatch_mode_enabled() {
181 return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python) &&
182 TorchDispatchModeTLS::any_modes_set();
183 }
184
to_string(TorchDispatchModeKey mode_key)185 std::string to_string(TorchDispatchModeKey mode_key) {
186 switch (mode_key) {
187 case TorchDispatchModeKey::PROXY:
188 return "ProxyTorchDispatchMode";
189 case TorchDispatchModeKey::FAKE:
190 return "FakeTensorMode";
191 default:
192 return "UNKNOWN_MODE";
193 }
194 }
195
196 } // namespace c10::impl
197