1 #include <c10/core/TensorImpl.h> 2 #include <ATen/ThreadLocalPythonObjects.h> 3 #include <c10/util/Exception.h> 4 5 #include <utility> 6 7 namespace at::impl { 8 9 static thread_local ThreadLocalPythonObjects py_objects; 10 11 set(const std::string & key,std::shared_ptr<SafePyObject> value)12void ThreadLocalPythonObjects::set(const std::string& key, std::shared_ptr<SafePyObject> value) { 13 py_objects.obj_dict_[key] = std::move(value); 14 } 15 get(const std::string & key)16const std::shared_ptr<SafePyObject>& ThreadLocalPythonObjects::get(const std::string& key) { 17 TORCH_CHECK(py_objects.obj_dict_.count(key)); 18 return py_objects.obj_dict_[key]; 19 } 20 contains(const std::string & key)21bool ThreadLocalPythonObjects::contains(const std::string& key) { 22 return py_objects.obj_dict_.count(key); 23 } 24 set_state(ThreadLocalPythonObjects state)25void ThreadLocalPythonObjects::set_state(ThreadLocalPythonObjects state) { 26 py_objects = std::move(state); 27 } 28 get_state()29const ThreadLocalPythonObjects& ThreadLocalPythonObjects::get_state() { 30 return py_objects; 31 } 32 33 34 } // namespace at::impl 35