xref: /aosp_15_r20/external/pytorch/c10/core/impl/LocalDispatchKeySet.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/core/impl/LocalDispatchKeySet.h>
2 
3 namespace c10::impl {
4 
// NB: POD, must be zero initialized!
// Note [TLS Initialization]
// We wanted raw_local_dispatch_key_set to be initialized with non-zero state,
// e.g. BackendSelect and ADInplaceOrView in the included set.  But certain
// Windows compilers (e.g. the one used in ARVR tests) only allow TLS to be
// zero-initialized.  To preserve the invariant that the raw TLS storage of the
// default state is zero, we obtain the actual included keyset by XORing
// raw_local_dispatch_key_set.included_ with c10::default_included_set.  This
// logic is encapsulated in struct PODLocalDispatchKeySet.
//
// Per-thread dispatch key state used by the functions below.  Because of the
// XOR encoding described above, read and write it through the
// included()/excluded()/set_included()/set_excluded() accessors.
thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set;
15 
#if defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE)
// Out-of-line accessor for the calling thread's dispatch key state.  This
// definition is only compiled for MSVC, Android and iPhone builds.
// Returns, by value, a LocalDispatchKeySet converted from the raw
// thread-local POD storage.
LocalDispatchKeySet tls_local_dispatch_key_set() {
  LocalDispatchKeySet snapshot = raw_local_dispatch_key_set;
  return snapshot;
}
#endif // defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE)
21 
_force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set)22 void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set) {
23   raw_local_dispatch_key_set.set_included(key_set.included_);
24   raw_local_dispatch_key_set.set_excluded(key_set.excluded_);
25 }
26 
// An RAII guard could snapshot and restore the entire state (the entire
// DispatchKeySet) as opposed to only snapshotting and restoring the state of
// its assigned DispatchKeySet. I'm not sure which is better.  If only the RAII
// API is used, the two choices are not distinguishable.
//
// However, if the guard chooses to snapshot and restore the entire
// DispatchKeySet, the interaction with the non-RAII API changes.  Consider this
// sequence of events:
// - An RAII guard is declared for a particular DispatchKeySet, but snapshots
//   the entire current DispatchKeySet.
// - A call to the non-RAII API changes the state for DispatchKeys outside the
//   assigned set.
// - The RAII guard goes out of scope, restoring the entire DispatchKeySet it
//   snapshotted (which restores the state for its own assigned DispatchKey and
//   wipes out the state for the other DispatchKeys set by the non-RAII API).
//
// The guards below take the narrower approach: each one records only the keys
// it changed itself and undoes only those on destruction.
45 
46 // RAII API
47 
IncludeDispatchKeyGuard(DispatchKeySet include)48 IncludeDispatchKeyGuard::IncludeDispatchKeyGuard(DispatchKeySet include)
49     : tls_(&raw_local_dispatch_key_set), include_(include - tls_->included()) {
50   if (!include_.empty()) {
51     tls_->set_included(tls_->included() | include_);
52   }
53 }
54 
~IncludeDispatchKeyGuard()55 IncludeDispatchKeyGuard::~IncludeDispatchKeyGuard() {
56   if (!include_.empty()) {
57     tls_->set_included(tls_->included() - include_);
58   }
59 }
60 
ExcludeDispatchKeyGuard(DispatchKeySet exclude)61 ExcludeDispatchKeyGuard::ExcludeDispatchKeyGuard(DispatchKeySet exclude)
62     : tls_(&raw_local_dispatch_key_set), exclude_(exclude - tls_->excluded()) {
63   if (!exclude_.empty()) {
64     tls_->set_excluded(tls_->excluded() | exclude_);
65   }
66 }
67 
~ExcludeDispatchKeyGuard()68 ExcludeDispatchKeyGuard::~ExcludeDispatchKeyGuard() {
69   if (!exclude_.empty()) {
70     tls_->set_excluded(tls_->excluded() - exclude_);
71   }
72 }
73 
74 // Non-RAII API
75 // Please prefer using the RAII API. See declarations in LocalDispatchKeySet.h
76 // for details.
77 
tls_is_dispatch_key_excluded(DispatchKey x)78 bool tls_is_dispatch_key_excluded(DispatchKey x) {
79   return raw_local_dispatch_key_set.excluded().has(x);
80 }
81 
tls_set_dispatch_key_excluded(DispatchKey x,bool desired_state)82 void tls_set_dispatch_key_excluded(DispatchKey x, bool desired_state) {
83   auto* tls = &raw_local_dispatch_key_set;
84   bool current_state = tls->excluded().has(x);
85   if (desired_state != current_state) {
86     if (desired_state) {
87       tls->set_excluded(tls->excluded().add(x));
88     } else {
89       tls->set_excluded(tls->excluded().remove(x));
90     }
91   }
92 }
93 
tls_is_dispatch_key_included(DispatchKey x)94 bool tls_is_dispatch_key_included(DispatchKey x) {
95   return raw_local_dispatch_key_set.included().has(x);
96 }
97 
tls_set_dispatch_key_included(DispatchKey x,bool desired_state)98 void tls_set_dispatch_key_included(DispatchKey x, bool desired_state) {
99   auto* tls = &raw_local_dispatch_key_set;
100   bool current_state = tls->included().has(x);
101   if (desired_state != current_state) {
102     if (desired_state) {
103       tls->set_included(tls->included().add(x));
104     } else {
105       tls->set_included(tls->included().remove(x));
106     }
107   }
108 }
109 
tls_is_dispatch_keyset_excluded(DispatchKeySet ks)110 bool tls_is_dispatch_keyset_excluded(DispatchKeySet ks) {
111   return raw_local_dispatch_key_set.excluded().isSupersetOf(ks);
112 }
113 
tls_is_dispatch_keyset_included(DispatchKeySet ks)114 bool tls_is_dispatch_keyset_included(DispatchKeySet ks) {
115   return raw_local_dispatch_key_set.included().isSupersetOf(ks);
116 }
117 } // namespace c10::impl
118