xref: /aosp_15_r20/external/pytorch/torch/csrc/dynamo/extra_state.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/dynamo/extra_state.h>
2 
3 #include <torch/csrc/dynamo/cache_entry.h>
4 #include <torch/csrc/dynamo/cpython_defs.h>
5 #include <torch/csrc/dynamo/debug_macros.h>
6 #include <torch/csrc/dynamo/framelocals_mapping.h>
7 #include <torch/csrc/dynamo/guards.h>
8 #include <torch/csrc/utils/python_compat.h>
9 
10 #if IS_PYTHON_3_12_PLUS
11 #define _PyCode_GetExtra PyUnstable_Code_GetExtra
12 #define _PyCode_SetExtra PyUnstable_Code_SetExtra
13 #endif
14 
15 Py_ssize_t extra_index = -1;
16 
get_first_entry()17 CacheEntry* ExtraState::get_first_entry() {
18   if (this->cache_entry_list.empty()) {
19     return nullptr;
20   }
21   return &this->cache_entry_list.front();
22 }
23 
move_to_front(CacheEntry * cache_entry)24 void ExtraState::move_to_front(CacheEntry* cache_entry) {
25   CHECK(cache_entry->_owner == this);
26   CHECK(!this->cache_entry_list.empty());
27   CHECK(cache_entry == &*cache_entry->_owner_loc);
28   this->cache_entry_list.splice(
29       this->cache_entry_list.begin(),
30       this->cache_entry_list,
31       cache_entry->_owner_loc);
32 }
33 
invalidate(CacheEntry * cache_entry)34 void ExtraState::invalidate(CacheEntry* cache_entry) {
35   CHECK(cache_entry->_owner == this);
36   CHECK(!this->cache_entry_list.empty());
37   CHECK(cache_entry == &*cache_entry->_owner_loc);
38   this->cache_entry_list.erase(cache_entry->_owner_loc);
39 }
40 
is_extra_state_unset(ExtraState * extra_state)41 static bool is_extra_state_unset(ExtraState* extra_state) {
42   return extra_state == nullptr || extra_state == SKIP_CODE ||
43       extra_state == SKIP_CODE_RECURSIVE;
44 }
45 
extract_cache_entry(ExtraState * extra_state)46 CacheEntry* extract_cache_entry(ExtraState* extra_state) {
47   if (is_extra_state_unset(extra_state)) {
48     return nullptr;
49   }
50   return extra_state->get_first_entry();
51 }
52 
extract_frame_state(ExtraState * extra_state)53 FrameState* extract_frame_state(ExtraState* extra_state) {
54   if (is_extra_state_unset(extra_state)) {
55     return nullptr;
56   }
57   return (FrameState*)extra_state->frame_state.ptr();
58 }
59 
get_extra_state(PyCodeObject * code)60 ExtraState* get_extra_state(PyCodeObject* code) {
61   ExtraState* extra = nullptr;
62   _PyCode_GetExtra((PyObject*)code, extra_index, (void**)&extra);
63   return extra;
64 }
65 
destroy_extra_state(void * obj)66 void destroy_extra_state(void* obj) {
67   ExtraState* extra = (ExtraState*)obj;
68   if (!is_extra_state_unset(extra)) {
69     delete extra;
70   }
71 }
72 
set_extra_state(PyCodeObject * code,ExtraState * extra_state)73 void set_extra_state(PyCodeObject* code, ExtraState* extra_state) {
74   ExtraState* old_extra_state = get_extra_state(code);
75   CHECK(is_extra_state_unset(extra_state) || old_extra_state != extra_state);
76   _PyCode_SetExtra((PyObject*)code, extra_index, extra_state);
77 }
78 
init_and_set_extra_state(PyCodeObject * code)79 ExtraState* init_and_set_extra_state(PyCodeObject* code) {
80   // Invariant - Extra state should not have been set before, therefore it
81   // should be nullptr.
82   CHECK(get_extra_state(code) == nullptr);
83   ExtraState* extra_state = new ExtraState();
84   NULL_CHECK(extra_state);
85   set_extra_state(code, extra_state);
86   // freed by destroy_extra_state (since we need to pass these objects to C)
87   // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
88   return extra_state;
89 }
90 
backend_match(PyObject * saved_backend,PyObject * backend)91 bool backend_match(PyObject* saved_backend, PyObject* backend) {
92   // Pointer equality check for common case
93   if (saved_backend != backend) {
94     // The Py_TYPE check should not be required but there is a pre-existing
95     // issue where backend is possibly deallocated (or nullptr) and causes
96     // segfaults. Check test - test_inplace_custom_op_intermediate
97     return (
98         Py_TYPE(saved_backend) == Py_TYPE(backend) &&
99         PyObject_RichCompareBool(saved_backend, backend, Py_EQ));
100   }
101   return true;
102 }
103 
lookup(ExtraState * extra_state,PyObject * f_locals,PyObject * backend)104 PyObject* lookup(
105     ExtraState* extra_state,
106     PyObject* f_locals,
107     PyObject* backend) {
108   size_t index = 0;
109   CacheEntry* found = nullptr;
110   py::handle locals(f_locals);
111   for (CacheEntry& cache_entry : extra_state->cache_entry_list) {
112     // Check backend. Py_False means run only mode.
113 
114     bool valid =
115         backend == Py_False || backend_match(cache_entry.backend, backend);
116 
117     if (valid) {
118       try {
119         // TODO(anijain2305) - Clean this up when enable_cpp_guard_manager is
120         // True by default
121         if (cache_entry.root_mgr != nullptr) {
122           valid = torch::dynamo::run_root_guard_manager(
123               cache_entry.root_mgr, f_locals);
124         } else {
125           valid = cache_entry.check_fn(locals).cast<bool>();
126         }
127       } catch (py::error_already_set& e) {
128         if (guard_error_hook) {
129           py::handle guard_error_hook_handle(guard_error_hook);
130           guard_error_hook_handle(
131               cache_entry.check_fn,
132               cache_entry.code,
133               locals,
134               index,
135               index == extra_state->cache_entry_list.size() - 1);
136         }
137         // this function is called from C, so we cannot repropagate
138         // the exception
139         e.restore();
140         return nullptr;
141       }
142     }
143     if (valid) {
144       found = &cache_entry;
145       break;
146     }
147     ++index;
148   }
149   if (found) {
150     extra_state->move_to_front(found);
151     return found->code.ptr();
152   }
153   return py::none().ptr();
154 }
155 
create_cache_entry(ExtraState * extra_state,PyObject * guarded_code,PyObject * backend)156 CacheEntry* create_cache_entry(
157     ExtraState* extra_state,
158     PyObject* guarded_code,
159     PyObject* backend) {
160   extra_state->cache_entry_list.emplace_front(guarded_code, backend);
161   auto new_iter = extra_state->cache_entry_list.begin();
162   new_iter->_owner = extra_state;
163   new_iter->_owner_loc = new_iter;
164   // Set check_fn references to extra_state and CacheEntry
165   // Warning: lifetime is controlled by C++!
166   py::handle check_fn = py::handle(guarded_code).attr("check_fn");
167   check_fn.attr("cache_entry") =
168       py::cast(*new_iter, py::return_value_policy::reference);
169   check_fn.attr("extra_state") =
170       py::cast(extra_state, py::return_value_policy::reference);
171   return &*new_iter;
172 }
173 
// Debug helper: returns a Python list of non-owning references to every
// CacheEntry attached to `code_obj`, in list order (front first).
// The list is empty when no real ExtraState is attached.
// Raises TypeError when `code_obj` is not a `types.CodeType` instance.
py::list _debug_get_cache_entry_list(const py::handle& code_obj) {
  py::object code_type = py::module::import("types").attr("CodeType");
  if (!py::isinstance(code_obj, code_type)) {
    throw py::type_error("expected a code object!");
  }
  ExtraState* extra =
      get_extra_state(reinterpret_cast<PyCodeObject*>(code_obj.ptr()));
  py::list entries;
  if (is_extra_state_unset(extra)) {
    return entries;
  }
  for (CacheEntry& entry : extra->cache_entry_list) {
    entries.append(py::cast(entry, py::return_value_policy::reference));
  }
  return entries;
}
188