xref: /aosp_15_r20/external/pytorch/c10/core/impl/LocalDispatchKeySet.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/DispatchKeySet.h>
4 #include <c10/macros/Export.h>
5 
6 // TLS management for DispatchKeySet (the "local" DispatchKeySet(s))
7 //
8 // This manages two thread-local DispatchKeySets:
9 //
10 //  - The included type set, which adds a tensor type for consideration
11 //    in dispatch.  (For example, you might add Profiling to
12 //    the included type set to turn on profiling on all tensor operations.)
13 //
14 //  - The excluded type set, which disqualifies a tensor type from dispatch.
15 //    (For example, after redispatching on variable, we disqualify
16 //    Autograd so we don't attempt to handle variable again.)
17 //    (Exclusion wins over inclusion.)
18 //
19 // NB: Originally, I implemented the excluded type set as storing the inverted
20 // set, but TLS is defined to be zero-initialized, so this doesn't actually work
21 // (if it's inverted, you want the set to be -1 initialized).
22 
23 namespace c10::impl {
24 
25 // POD version of LocalDispatchKeySet.  Declared here just so that
26 // we can put it in the guards.
27 // This struct encapsulates special handling for TLS initialization
28 // in set_included()/included() API so that they reflect the truth.
29 // If you want to create PODLocalDispatchKeySet with non-zero state,
30 // use set_included() instead of default constructor.
31 struct C10_API PODLocalDispatchKeySet {
32   uint64_t included_;
33   uint64_t excluded_;
34 
35   // See Note [TLS Initialization]
includedPODLocalDispatchKeySet36   DispatchKeySet included() const {
37     return DispatchKeySet(DispatchKeySet::RAW, included_) ^
38         c10::default_included_set;
39   }
excludedPODLocalDispatchKeySet40   DispatchKeySet excluded() const {
41     return DispatchKeySet(DispatchKeySet::RAW, excluded_) ^
42         c10::default_excluded_set;
43   }
44 
set_includedPODLocalDispatchKeySet45   void set_included(DispatchKeySet x) {
46     included_ = (x ^ c10::default_included_set).raw_repr();
47   }
set_excludedPODLocalDispatchKeySet48   void set_excluded(DispatchKeySet x) {
49     excluded_ = (x ^ c10::default_excluded_set).raw_repr();
50   }
51 };
52 static_assert(
53     std::is_trivial_v<PODLocalDispatchKeySet>,
54     "PODLocalDispatchKeySet must be a POD type.");
55 
56 struct C10_API LocalDispatchKeySet {
LocalDispatchKeySetLocalDispatchKeySet57   /* implicit */ LocalDispatchKeySet(PODLocalDispatchKeySet x)
58       : included_(x.included()), excluded_(x.excluded()) {}
59   DispatchKeySet included_;
60   DispatchKeySet excluded_;
61 };
62 
63 // thread_local variables cannot be C10_API on Windows.
64 // Inlining this seems to break AutoDispatchBelowAutograd on Android.
65 #if defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE)
66 C10_API LocalDispatchKeySet tls_local_dispatch_key_set();
67 #else // defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE)
68 extern C10_API thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set;
69 
tls_local_dispatch_key_set()70 inline C10_API LocalDispatchKeySet tls_local_dispatch_key_set() {
71   // Don't let people fiddle with the thread_local directly just
72   // because they include this header.
73   return raw_local_dispatch_key_set;
74 }
75 #endif // defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE)
76 
77 // Internal, use ThreadLocalStateGuard
78 C10_API void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set);
79 
80 // RAII API for manipulating the thread-local dispatch state.
81 
82 class C10_API IncludeDispatchKeyGuard {
83  public:
84   IncludeDispatchKeyGuard(DispatchKeySet);
IncludeDispatchKeyGuard(DispatchKey k)85   IncludeDispatchKeyGuard(DispatchKey k)
86       : IncludeDispatchKeyGuard(DispatchKeySet(k)) {}
87   IncludeDispatchKeyGuard(const IncludeDispatchKeyGuard&) = delete;
88   IncludeDispatchKeyGuard operator=(const IncludeDispatchKeyGuard&) = delete;
89   IncludeDispatchKeyGuard(IncludeDispatchKeyGuard&&) = delete;
90   IncludeDispatchKeyGuard operator=(IncludeDispatchKeyGuard&&) = delete;
91   ~IncludeDispatchKeyGuard();
92 
93  private:
94   // A little micro-optimization to save us from tls_get_addr call
95   // on destruction
96   PODLocalDispatchKeySet* tls_;
97   DispatchKeySet include_;
98 };
99 
100 class C10_API ExcludeDispatchKeyGuard {
101  public:
102   ExcludeDispatchKeyGuard(DispatchKeySet);
ExcludeDispatchKeyGuard(DispatchKey k)103   ExcludeDispatchKeyGuard(DispatchKey k)
104       : ExcludeDispatchKeyGuard(DispatchKeySet(k)) {}
105   ExcludeDispatchKeyGuard(const ExcludeDispatchKeyGuard&) = delete;
106   ExcludeDispatchKeyGuard operator=(const ExcludeDispatchKeyGuard&) = delete;
107   ExcludeDispatchKeyGuard(ExcludeDispatchKeyGuard&&) = delete;
108   ExcludeDispatchKeyGuard operator=(ExcludeDispatchKeyGuard&&) = delete;
109   ~ExcludeDispatchKeyGuard();
110 
111  private:
112   // A little micro-optimization to save us from tls_get_addr call
113   // on destruction
114   PODLocalDispatchKeySet* tls_;
115   DispatchKeySet exclude_;
116 };
117 
118 struct C10_API ForceDispatchKeyGuard {
119  public:
ForceDispatchKeyGuardForceDispatchKeyGuard120   ForceDispatchKeyGuard()
121       : saved_keyset_(c10::impl::tls_local_dispatch_key_set()) {}
ForceDispatchKeyGuardForceDispatchKeyGuard122   ForceDispatchKeyGuard(c10::impl::LocalDispatchKeySet key_set)
123       : ForceDispatchKeyGuard() {
124     c10::impl::_force_tls_local_dispatch_key_set(key_set);
125   }
ForceDispatchKeyGuardForceDispatchKeyGuard126   ForceDispatchKeyGuard(
127       c10::DispatchKeySet include,
128       c10::DispatchKeySet exclude)
129       : ForceDispatchKeyGuard() {
130     auto updated_set = saved_keyset_;
131     updated_set.included_ = include;
132     updated_set.excluded_ = exclude;
133     c10::impl::_force_tls_local_dispatch_key_set(updated_set);
134   }
~ForceDispatchKeyGuardForceDispatchKeyGuard135   ~ForceDispatchKeyGuard() {
136     c10::impl::_force_tls_local_dispatch_key_set(saved_keyset_);
137   }
138 
139  private:
140   c10::impl::LocalDispatchKeySet saved_keyset_;
141 };
142 
// Non-RAII API for manipulating the thread-local dispatch state.
// Please prefer the RAII API.  The non-RAII API may be useful when
// the included/excluded state of a given DispatchKey must span
// many calls from Python into C++, so that you cannot conveniently
// use an RAII guard.
//
// Example use case: a Python context manager that includes a certain
// DispatchKey, to ensure ops running under the context manager dispatch
// through that DispatchKey's registered overrides.
//
// The non-RAII API is less efficient than the RAII guards because both the
// getter and the setter will do a tls_get_addr lookup (the RAII struct only
// needs one!)

156 
157 C10_API bool tls_is_dispatch_key_excluded(DispatchKey x);
158 C10_API void tls_set_dispatch_key_excluded(DispatchKey x, bool desired_state);
159 C10_API bool tls_is_dispatch_key_included(DispatchKey x);
160 C10_API void tls_set_dispatch_key_included(DispatchKey x, bool desired_state);
161 C10_API bool tls_is_dispatch_keyset_excluded(DispatchKeySet ks);
162 C10_API bool tls_is_dispatch_keyset_included(DispatchKeySet ks);
163 
164 } // namespace c10::impl
165