1 #include <torch/csrc/dynamo/extra_state.h>
2
3 #include <torch/csrc/dynamo/cache_entry.h>
4 #include <torch/csrc/dynamo/cpython_defs.h>
5 #include <torch/csrc/dynamo/debug_macros.h>
6 #include <torch/csrc/dynamo/framelocals_mapping.h>
7 #include <torch/csrc/dynamo/guards.h>
8 #include <torch/csrc/utils/python_compat.h>
9
// When IS_PYTHON_3_12_PLUS is set, map the underscore-prefixed code-extra
// accessors used throughout this file onto the PyUnstable_Code_* names.
10 #if IS_PYTHON_3_12_PLUS
11 #define _PyCode_GetExtra PyUnstable_Code_GetExtra
12 #define _PyCode_SetExtra PyUnstable_Code_SetExtra
13 #endif
14
// Index handed to _PyCode_GetExtra / _PyCode_SetExtra to locate the
// ExtraState slot on a code object. It starts at -1; the code that assigns
// the real index is not visible in this file.
15 Py_ssize_t extra_index = -1;
16
get_first_entry()17 CacheEntry* ExtraState::get_first_entry() {
18 if (this->cache_entry_list.empty()) {
19 return nullptr;
20 }
21 return &this->cache_entry_list.front();
22 }
23
move_to_front(CacheEntry * cache_entry)24 void ExtraState::move_to_front(CacheEntry* cache_entry) {
25 CHECK(cache_entry->_owner == this);
26 CHECK(!this->cache_entry_list.empty());
27 CHECK(cache_entry == &*cache_entry->_owner_loc);
28 this->cache_entry_list.splice(
29 this->cache_entry_list.begin(),
30 this->cache_entry_list,
31 cache_entry->_owner_loc);
32 }
33
invalidate(CacheEntry * cache_entry)34 void ExtraState::invalidate(CacheEntry* cache_entry) {
35 CHECK(cache_entry->_owner == this);
36 CHECK(!this->cache_entry_list.empty());
37 CHECK(cache_entry == &*cache_entry->_owner_loc);
38 this->cache_entry_list.erase(cache_entry->_owner_loc);
39 }
40
is_extra_state_unset(ExtraState * extra_state)41 static bool is_extra_state_unset(ExtraState* extra_state) {
42 return extra_state == nullptr || extra_state == SKIP_CODE ||
43 extra_state == SKIP_CODE_RECURSIVE;
44 }
45
extract_cache_entry(ExtraState * extra_state)46 CacheEntry* extract_cache_entry(ExtraState* extra_state) {
47 if (is_extra_state_unset(extra_state)) {
48 return nullptr;
49 }
50 return extra_state->get_first_entry();
51 }
52
extract_frame_state(ExtraState * extra_state)53 FrameState* extract_frame_state(ExtraState* extra_state) {
54 if (is_extra_state_unset(extra_state)) {
55 return nullptr;
56 }
57 return (FrameState*)extra_state->frame_state.ptr();
58 }
59
get_extra_state(PyCodeObject * code)60 ExtraState* get_extra_state(PyCodeObject* code) {
61 ExtraState* extra = nullptr;
62 _PyCode_GetExtra((PyObject*)code, extra_index, (void**)&extra);
63 return extra;
64 }
65
destroy_extra_state(void * obj)66 void destroy_extra_state(void* obj) {
67 ExtraState* extra = (ExtraState*)obj;
68 if (!is_extra_state_unset(extra)) {
69 delete extra;
70 }
71 }
72
set_extra_state(PyCodeObject * code,ExtraState * extra_state)73 void set_extra_state(PyCodeObject* code, ExtraState* extra_state) {
74 ExtraState* old_extra_state = get_extra_state(code);
75 CHECK(is_extra_state_unset(extra_state) || old_extra_state != extra_state);
76 _PyCode_SetExtra((PyObject*)code, extra_index, extra_state);
77 }
78
init_and_set_extra_state(PyCodeObject * code)79 ExtraState* init_and_set_extra_state(PyCodeObject* code) {
80 // Invariant - Extra state should not have been set before, therefore it
81 // should be nullptr.
82 CHECK(get_extra_state(code) == nullptr);
83 ExtraState* extra_state = new ExtraState();
84 NULL_CHECK(extra_state);
85 set_extra_state(code, extra_state);
86 // freed by destroy_extra_state (since we need to pass these objects to C)
87 // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
88 return extra_state;
89 }
90
backend_match(PyObject * saved_backend,PyObject * backend)91 bool backend_match(PyObject* saved_backend, PyObject* backend) {
92 // Pointer equality check for common case
93 if (saved_backend != backend) {
94 // The Py_TYPE check should not be required but there is a pre-existing
95 // issue where backend is possibly deallocated (or nullptr) and causes
96 // segfaults. Check test - test_inplace_custom_op_intermediate
97 return (
98 Py_TYPE(saved_backend) == Py_TYPE(backend) &&
99 PyObject_RichCompareBool(saved_backend, backend, Py_EQ));
100 }
101 return true;
102 }
103
lookup(ExtraState * extra_state,PyObject * f_locals,PyObject * backend)104 PyObject* lookup(
105 ExtraState* extra_state,
106 PyObject* f_locals,
107 PyObject* backend) {
108 size_t index = 0;
109 CacheEntry* found = nullptr;
110 py::handle locals(f_locals);
111 for (CacheEntry& cache_entry : extra_state->cache_entry_list) {
112 // Check backend. Py_False means run only mode.
113
114 bool valid =
115 backend == Py_False || backend_match(cache_entry.backend, backend);
116
117 if (valid) {
118 try {
119 // TODO(anijain2305) - Clean this up when enable_cpp_guard_manager is
120 // True by default
121 if (cache_entry.root_mgr != nullptr) {
122 valid = torch::dynamo::run_root_guard_manager(
123 cache_entry.root_mgr, f_locals);
124 } else {
125 valid = cache_entry.check_fn(locals).cast<bool>();
126 }
127 } catch (py::error_already_set& e) {
128 if (guard_error_hook) {
129 py::handle guard_error_hook_handle(guard_error_hook);
130 guard_error_hook_handle(
131 cache_entry.check_fn,
132 cache_entry.code,
133 locals,
134 index,
135 index == extra_state->cache_entry_list.size() - 1);
136 }
137 // this function is called from C, so we cannot repropagate
138 // the exception
139 e.restore();
140 return nullptr;
141 }
142 }
143 if (valid) {
144 found = &cache_entry;
145 break;
146 }
147 ++index;
148 }
149 if (found) {
150 extra_state->move_to_front(found);
151 return found->code.ptr();
152 }
153 return py::none().ptr();
154 }
155
create_cache_entry(ExtraState * extra_state,PyObject * guarded_code,PyObject * backend)156 CacheEntry* create_cache_entry(
157 ExtraState* extra_state,
158 PyObject* guarded_code,
159 PyObject* backend) {
160 extra_state->cache_entry_list.emplace_front(guarded_code, backend);
161 auto new_iter = extra_state->cache_entry_list.begin();
162 new_iter->_owner = extra_state;
163 new_iter->_owner_loc = new_iter;
164 // Set check_fn references to extra_state and CacheEntry
165 // Warning: lifetime is controlled by C++!
166 py::handle check_fn = py::handle(guarded_code).attr("check_fn");
167 check_fn.attr("cache_entry") =
168 py::cast(*new_iter, py::return_value_policy::reference);
169 check_fn.attr("extra_state") =
170 py::cast(extra_state, py::return_value_policy::reference);
171 return &*new_iter;
172 }
173
// Debug helper: returns a Python list with one non-owning wrapper per
// CacheEntry stored on `code_obj`, in list order. The list is empty when the
// code object has no real ExtraState (unset or skip marker).
// Throws py::type_error if `code_obj` is not a `types.CodeType` instance.
py::list _debug_get_cache_entry_list(const py::handle& code_obj) {
  const py::object code_type = py::module::import("types").attr("CodeType");
  if (!py::isinstance(code_obj, code_type)) {
    throw py::type_error("expected a code object!");
  }

  auto* code = reinterpret_cast<PyCodeObject*>(code_obj.ptr());
  ExtraState* state = get_extra_state(code);

  py::list entries;
  if (is_extra_state_unset(state)) {
    return entries;
  }
  for (CacheEntry& entry : state->cache_entry_list) {
    entries.append(py::cast(entry, py::return_value_policy::reference));
  }
  return entries;
}
188