xref: /aosp_15_r20/external/pytorch/aten/src/ATen/ThreadLocalPythonObjects.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
#include <c10/core/SafePyObject.h>
#include <c10/macros/Macros.h>

#include <memory>
#include <string>
#include <unordered_map>
6*da0073e9SAndroid Build Coastguard Worker 
7*da0073e9SAndroid Build Coastguard Worker namespace at::impl {
8*da0073e9SAndroid Build Coastguard Worker 
9*da0073e9SAndroid Build Coastguard Worker struct TORCH_API ThreadLocalPythonObjects {
10*da0073e9SAndroid Build Coastguard Worker   static void set(const std::string& key, std::shared_ptr<SafePyObject> value);
11*da0073e9SAndroid Build Coastguard Worker   static const std::shared_ptr<SafePyObject>& get(const std::string& key);
12*da0073e9SAndroid Build Coastguard Worker   static bool contains(const std::string& key);
13*da0073e9SAndroid Build Coastguard Worker 
14*da0073e9SAndroid Build Coastguard Worker   static const ThreadLocalPythonObjects& get_state();
15*da0073e9SAndroid Build Coastguard Worker   static void set_state(ThreadLocalPythonObjects state);
16*da0073e9SAndroid Build Coastguard Worker 
17*da0073e9SAndroid Build Coastguard Worker  private:
18*da0073e9SAndroid Build Coastguard Worker   std::unordered_map<std::string, std::shared_ptr<c10::SafePyObject>> obj_dict_;
19*da0073e9SAndroid Build Coastguard Worker };
20*da0073e9SAndroid Build Coastguard Worker 
21*da0073e9SAndroid Build Coastguard Worker } // namespace at::impl
22