1 #include <c10/core/impl/LocalDispatchKeySet.h>
2
3 namespace c10::impl {
4
// NB: POD, must be zero initialized!
// Note [TLS Initialization]
// We wanted raw_local_dispatch_key_set to be initialized with non-zero state,
// e.g. BackendSelect and ADInplaceOrView in the included set. But certain
// Windows compilers (e.g. the one used in ARVR tests) only allow TLS to be
// zero-initialized. To preserve the invariant that the raw TLS storage of the
// default state is zero, we obtain the actual included keyset by XORing
// raw_local_dispatch_key_set.included_ with c10::default_included_set. This
// logic is encapsulated in struct PODLocalDispatchKeySet.
//
// This per-thread storage backs every accessor, setter and RAII guard in this
// file. The code here touches it only through the included()/excluded() and
// set_included()/set_excluded() accessors, never via the raw fields, so the
// XOR encoding described above stays in one place.
thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set;
15
#if defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE)
// Returns a snapshot of this thread's dispatch-key state, converted from the
// raw POD thread-local storage into a LocalDispatchKeySet value.
// NOTE(review): this out-of-line definition is only compiled on MSVC, Android
// and iOS; presumably other platforms get a definition from the header —
// confirm in LocalDispatchKeySet.h.
LocalDispatchKeySet tls_local_dispatch_key_set() {
  return LocalDispatchKeySet(raw_local_dispatch_key_set);
}
#endif // defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE)
21
_force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set)22 void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set) {
23 raw_local_dispatch_key_set.set_included(key_set.included_);
24 raw_local_dispatch_key_set.set_excluded(key_set.excluded_);
25 }
26
// An RAII guard could snapshot and restore the entire state (the entire
// DispatchKeySet) rather than snapshotting and restoring only the state of its
// assigned DispatchKeySet. I'm not sure which is better. If only the RAII API
// is used, the two choices are indistinguishable.
//
// However, if the guard snapshots and restores the entire DispatchKeySet, its
// interaction with the non-RAII API changes. Consider this sequence of events:
// - An RAII guard is declared for a particular DispatchKeySet, but snapshots
//   the entire current DispatchKeySet.
// - A call to the non-RAII API changes the state for DispatchKeys outside the
//   assigned set.
// - The RAII guard goes out of scope and restores the entire DispatchKeySet it
//   snapshotted. This restores the state for its own assigned DispatchKeys, but
//   it also wipes out the state for the other DispatchKeys set by the non-RAII
//   API.
// For that reason, the guards below record and undo only the keys they
// themselves add.
45
46 // RAII API
47
IncludeDispatchKeyGuard(DispatchKeySet include)48 IncludeDispatchKeyGuard::IncludeDispatchKeyGuard(DispatchKeySet include)
49 : tls_(&raw_local_dispatch_key_set), include_(include - tls_->included()) {
50 if (!include_.empty()) {
51 tls_->set_included(tls_->included() | include_);
52 }
53 }
54
~IncludeDispatchKeyGuard()55 IncludeDispatchKeyGuard::~IncludeDispatchKeyGuard() {
56 if (!include_.empty()) {
57 tls_->set_included(tls_->included() - include_);
58 }
59 }
60
ExcludeDispatchKeyGuard(DispatchKeySet exclude)61 ExcludeDispatchKeyGuard::ExcludeDispatchKeyGuard(DispatchKeySet exclude)
62 : tls_(&raw_local_dispatch_key_set), exclude_(exclude - tls_->excluded()) {
63 if (!exclude_.empty()) {
64 tls_->set_excluded(tls_->excluded() | exclude_);
65 }
66 }
67
~ExcludeDispatchKeyGuard()68 ExcludeDispatchKeyGuard::~ExcludeDispatchKeyGuard() {
69 if (!exclude_.empty()) {
70 tls_->set_excluded(tls_->excluded() - exclude_);
71 }
72 }
73
74 // Non-RAII API
75 // Please prefer using the RAII API. See declarations in LocalDispatchKeySet.h
76 // for details.
77
tls_is_dispatch_key_excluded(DispatchKey x)78 bool tls_is_dispatch_key_excluded(DispatchKey x) {
79 return raw_local_dispatch_key_set.excluded().has(x);
80 }
81
tls_set_dispatch_key_excluded(DispatchKey x,bool desired_state)82 void tls_set_dispatch_key_excluded(DispatchKey x, bool desired_state) {
83 auto* tls = &raw_local_dispatch_key_set;
84 bool current_state = tls->excluded().has(x);
85 if (desired_state != current_state) {
86 if (desired_state) {
87 tls->set_excluded(tls->excluded().add(x));
88 } else {
89 tls->set_excluded(tls->excluded().remove(x));
90 }
91 }
92 }
93
tls_is_dispatch_key_included(DispatchKey x)94 bool tls_is_dispatch_key_included(DispatchKey x) {
95 return raw_local_dispatch_key_set.included().has(x);
96 }
97
tls_set_dispatch_key_included(DispatchKey x,bool desired_state)98 void tls_set_dispatch_key_included(DispatchKey x, bool desired_state) {
99 auto* tls = &raw_local_dispatch_key_set;
100 bool current_state = tls->included().has(x);
101 if (desired_state != current_state) {
102 if (desired_state) {
103 tls->set_included(tls->included().add(x));
104 } else {
105 tls->set_included(tls->included().remove(x));
106 }
107 }
108 }
109
tls_is_dispatch_keyset_excluded(DispatchKeySet ks)110 bool tls_is_dispatch_keyset_excluded(DispatchKeySet ks) {
111 return raw_local_dispatch_key_set.excluded().isSupersetOf(ks);
112 }
113
tls_is_dispatch_keyset_included(DispatchKeySet ks)114 bool tls_is_dispatch_keyset_included(DispatchKeySet ks) {
115 return raw_local_dispatch_key_set.included().isSupersetOf(ks);
116 }
117 } // namespace c10::impl
118