1 #pragma once
2
3 #include <c10/core/DispatchKeySet.h>
4 #include <c10/macros/Export.h>
5
6 // TLS management for DispatchKeySet (the "local" DispatchKeySet(s))
7 //
8 // This manages two thread-local DispatchKeySets:
9 //
10 // - The included type set, which adds a tensor type for consideration
11 // in dispatch. (For example, you might add Profiling to
12 // the included type set to turn on profiling on all tensor operations.)
13 //
14 // - The excluded type set, which disqualifies a tensor type from dispatch.
15 // (For example, after redispatching on variable, we disqualify
16 // Autograd so we don't attempt to handle variable again.)
17 // (Exclusion wins over inclusion.)
18 //
19 // NB: Originally, I implemented the excluded type set as storing the inverted
20 // set, but TLS is defined to be zero-initialized, so this doesn't actually work
21 // (if it's inverted, you want the set to be -1 initialized).
22
23 namespace c10::impl {
24
25 // POD version of LocalDispatchKeySet. Declared here just so that
26 // we can put it in the guards.
27 // This struct encapsulates special handling for TLS initialization
28 // in set_included()/included() API so that they reflect the truth.
29 // If you want to create PODLocalDispatchKeySet with non-zero state,
30 // use set_included() instead of default constructor.
31 struct C10_API PODLocalDispatchKeySet {
32 uint64_t included_;
33 uint64_t excluded_;
34
35 // See Note [TLS Initialization]
includedPODLocalDispatchKeySet36 DispatchKeySet included() const {
37 return DispatchKeySet(DispatchKeySet::RAW, included_) ^
38 c10::default_included_set;
39 }
excludedPODLocalDispatchKeySet40 DispatchKeySet excluded() const {
41 return DispatchKeySet(DispatchKeySet::RAW, excluded_) ^
42 c10::default_excluded_set;
43 }
44
set_includedPODLocalDispatchKeySet45 void set_included(DispatchKeySet x) {
46 included_ = (x ^ c10::default_included_set).raw_repr();
47 }
set_excludedPODLocalDispatchKeySet48 void set_excluded(DispatchKeySet x) {
49 excluded_ = (x ^ c10::default_excluded_set).raw_repr();
50 }
51 };
52 static_assert(
53 std::is_trivial_v<PODLocalDispatchKeySet>,
54 "PODLocalDispatchKeySet must be a POD type.");
55
56 struct C10_API LocalDispatchKeySet {
LocalDispatchKeySetLocalDispatchKeySet57 /* implicit */ LocalDispatchKeySet(PODLocalDispatchKeySet x)
58 : included_(x.included()), excluded_(x.excluded()) {}
59 DispatchKeySet included_;
60 DispatchKeySet excluded_;
61 };
62
63 // thread_local variables cannot be C10_API on Windows.
64 // Inlining this seems to break AutoDispatchBelowAutograd on Android.
65 #if defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE)
66 C10_API LocalDispatchKeySet tls_local_dispatch_key_set();
67 #else // defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE)
68 extern C10_API thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set;
69
tls_local_dispatch_key_set()70 inline C10_API LocalDispatchKeySet tls_local_dispatch_key_set() {
71 // Don't let people fiddle with the thread_local directly just
72 // because they include this header.
73 return raw_local_dispatch_key_set;
74 }
75 #endif // defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE)
76
77 // Internal, use ThreadLocalStateGuard
78 C10_API void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set);
79
80 // RAII API for manipulating the thread-local dispatch state.
81
82 class C10_API IncludeDispatchKeyGuard {
83 public:
84 IncludeDispatchKeyGuard(DispatchKeySet);
IncludeDispatchKeyGuard(DispatchKey k)85 IncludeDispatchKeyGuard(DispatchKey k)
86 : IncludeDispatchKeyGuard(DispatchKeySet(k)) {}
87 IncludeDispatchKeyGuard(const IncludeDispatchKeyGuard&) = delete;
88 IncludeDispatchKeyGuard operator=(const IncludeDispatchKeyGuard&) = delete;
89 IncludeDispatchKeyGuard(IncludeDispatchKeyGuard&&) = delete;
90 IncludeDispatchKeyGuard operator=(IncludeDispatchKeyGuard&&) = delete;
91 ~IncludeDispatchKeyGuard();
92
93 private:
94 // A little micro-optimization to save us from tls_get_addr call
95 // on destruction
96 PODLocalDispatchKeySet* tls_;
97 DispatchKeySet include_;
98 };
99
100 class C10_API ExcludeDispatchKeyGuard {
101 public:
102 ExcludeDispatchKeyGuard(DispatchKeySet);
ExcludeDispatchKeyGuard(DispatchKey k)103 ExcludeDispatchKeyGuard(DispatchKey k)
104 : ExcludeDispatchKeyGuard(DispatchKeySet(k)) {}
105 ExcludeDispatchKeyGuard(const ExcludeDispatchKeyGuard&) = delete;
106 ExcludeDispatchKeyGuard operator=(const ExcludeDispatchKeyGuard&) = delete;
107 ExcludeDispatchKeyGuard(ExcludeDispatchKeyGuard&&) = delete;
108 ExcludeDispatchKeyGuard operator=(ExcludeDispatchKeyGuard&&) = delete;
109 ~ExcludeDispatchKeyGuard();
110
111 private:
112 // A little micro-optimization to save us from tls_get_addr call
113 // on destruction
114 PODLocalDispatchKeySet* tls_;
115 DispatchKeySet exclude_;
116 };
117
118 struct C10_API ForceDispatchKeyGuard {
119 public:
ForceDispatchKeyGuardForceDispatchKeyGuard120 ForceDispatchKeyGuard()
121 : saved_keyset_(c10::impl::tls_local_dispatch_key_set()) {}
ForceDispatchKeyGuardForceDispatchKeyGuard122 ForceDispatchKeyGuard(c10::impl::LocalDispatchKeySet key_set)
123 : ForceDispatchKeyGuard() {
124 c10::impl::_force_tls_local_dispatch_key_set(key_set);
125 }
ForceDispatchKeyGuardForceDispatchKeyGuard126 ForceDispatchKeyGuard(
127 c10::DispatchKeySet include,
128 c10::DispatchKeySet exclude)
129 : ForceDispatchKeyGuard() {
130 auto updated_set = saved_keyset_;
131 updated_set.included_ = include;
132 updated_set.excluded_ = exclude;
133 c10::impl::_force_tls_local_dispatch_key_set(updated_set);
134 }
~ForceDispatchKeyGuardForceDispatchKeyGuard135 ~ForceDispatchKeyGuard() {
136 c10::impl::_force_tls_local_dispatch_key_set(saved_keyset_);
137 }
138
139 private:
140 c10::impl::LocalDispatchKeySet saved_keyset_;
141 };
142
143 // Non-RAII API for manipulating the thread-local dispatch state.
144 // Please prefer the RAII API. The non-RAII API may be useful when
145 // the included/excluded state of a given DispatchKey must span
146 // many calls from the Python to the C++, so you cannot conveniently
147 // use an RAII guard.
148 //
149 // Example use case: a Python context manager that includes a certain
150 // DispatchKey, to ensure ops running under the context manager dispatch
151 // through that DispatchKey's registered overrides.
152 //
153 // The non-RAII API is less efficient than the RAII guards because both the
154 // getter and setter will do a tls_getaddr lookup (the RAII struct only needs
155 // one!)
156
157 C10_API bool tls_is_dispatch_key_excluded(DispatchKey x);
158 C10_API void tls_set_dispatch_key_excluded(DispatchKey x, bool desired_state);
159 C10_API bool tls_is_dispatch_key_included(DispatchKey x);
160 C10_API void tls_set_dispatch_key_included(DispatchKey x, bool desired_state);
161 C10_API bool tls_is_dispatch_keyset_excluded(DispatchKeySet ks);
162 C10_API bool tls_is_dispatch_keyset_included(DispatchKeySet ks);
163
164 } // namespace c10::impl
165