xref: /aosp_15_r20/external/pytorch/torch/csrc/dynamo/cache_entry.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/dynamo/cache_entry.h>
2 #include <torch/csrc/dynamo/guards.h>
3 
4 #include <torch/csrc/dynamo/debug_macros.h>
5 #include <torch/csrc/dynamo/extra_state.h>
6 
CacheEntry(const py::handle & guarded_code,PyObject * backend)7 CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend)
8     : backend(backend) {
9   this->check_fn = guarded_code.attr("check_fn");
10   this->code = guarded_code.attr("code");
11   this->compile_id = guarded_code.attr("compile_id");
12   // TODO - clean this up when enable_cpp_guard_manager is True by default
13   if (py::hasattr(this->check_fn, "root")) {
14     this->root_mgr = torch::dynamo::convert_to_root_guard_manager(
15         this->check_fn.attr("root"));
16   }
17 }
18 
19 // NOLINTNEXTLINE(bugprone-exception-escape)
~CacheEntry()20 CacheEntry::~CacheEntry() {
21   // prevent check_fn from use-after-free when invalidating
22   this->check_fn.attr("cache_entry") = py::none();
23   this->check_fn.attr("extra_state") = py::none();
24 }
25 
next()26 py::object CacheEntry::next() {
27   NULL_CHECK(this->_owner);
28   auto it = this->_owner_loc;
29   ++it;
30   if (it == this->_owner->cache_entry_list.end()) {
31     return py::none();
32   }
33   return py::cast(*it, py::return_value_policy::reference);
34 }
35 
CacheEntry_get_code(CacheEntry * e)36 PyCodeObject* CacheEntry_get_code(CacheEntry* e) {
37   return (PyCodeObject*)e->code.ptr();
38 }
39 
CacheEntry_to_obj(CacheEntry * e)40 PyObject* CacheEntry_to_obj(CacheEntry* e) {
41   if (!e) {
42     return py::none().release().ptr();
43   }
44   return py::cast(e, py::return_value_policy::reference).release().ptr();
45 }
46 
get_backend(PyObject * callback)47 PyObject* get_backend(PyObject* callback) {
48   py::handle handle = py::handle(callback);
49   while (py::hasattr(handle, "_torchdynamo_orig_callable")) {
50     handle = handle.attr("_torchdynamo_orig_callable");
51   }
52   return handle.ptr();
53 }
54