xref: /aosp_15_r20/external/pytorch/aten/src/ATen/ThreadLocalPythonObjects.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/core/TensorImpl.h>
2 #include <ATen/ThreadLocalPythonObjects.h>
3 #include <c10/util/Exception.h>
4 
5 #include <utility>
6 
7 namespace at::impl {
8 
9 static thread_local ThreadLocalPythonObjects py_objects;
10 
11 
set(const std::string & key,std::shared_ptr<SafePyObject> value)12 void ThreadLocalPythonObjects::set(const std::string& key, std::shared_ptr<SafePyObject> value) {
13   py_objects.obj_dict_[key] = std::move(value);
14 }
15 
get(const std::string & key)16 const std::shared_ptr<SafePyObject>& ThreadLocalPythonObjects::get(const std::string& key) {
17   TORCH_CHECK(py_objects.obj_dict_.count(key));
18   return py_objects.obj_dict_[key];
19 }
20 
contains(const std::string & key)21 bool ThreadLocalPythonObjects::contains(const std::string& key) {
22   return py_objects.obj_dict_.count(key);
23 }
24 
set_state(ThreadLocalPythonObjects state)25 void ThreadLocalPythonObjects::set_state(ThreadLocalPythonObjects state) {
26   py_objects = std::move(state);
27 }
28 
get_state()29 const ThreadLocalPythonObjects& ThreadLocalPythonObjects::get_state() {
30   return py_objects;
31 }
32 
33 
34 } // namespace at::impl
35