xref: /aosp_15_r20/external/pytorch/aten/src/ATen/ThreadLocalPythonObjects.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <c10/core/SafePyObject.h>
#include <c10/macros/Macros.h>
#include <memory>
#include <string>
#include <unordered_map>
6 
7 namespace at::impl {
8 
9 struct TORCH_API ThreadLocalPythonObjects {
10   static void set(const std::string& key, std::shared_ptr<SafePyObject> value);
11   static const std::shared_ptr<SafePyObject>& get(const std::string& key);
12   static bool contains(const std::string& key);
13 
14   static const ThreadLocalPythonObjects& get_state();
15   static void set_state(ThreadLocalPythonObjects state);
16 
17  private:
18   std::unordered_map<std::string, std::shared_ptr<c10::SafePyObject>> obj_dict_;
19 };
20 
21 } // namespace at::impl
22