xref: /aosp_15_r20/external/pytorch/aten/src/ATen/ThreadLocalState.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <c10/core/InferenceMode.h>
4*da0073e9SAndroid Build Coastguard Worker #include <c10/core/impl/LocalDispatchKeySet.h>
5*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
6*da0073e9SAndroid Build Coastguard Worker #include <c10/util/ThreadLocalDebugInfo.h>
7*da0073e9SAndroid Build Coastguard Worker 
8*da0073e9SAndroid Build Coastguard Worker #include <ATen/FuncTorchTLS.h>
9*da0073e9SAndroid Build Coastguard Worker #include <ATen/PythonTorchFunctionTLS.h>
10*da0073e9SAndroid Build Coastguard Worker #include <ATen/SavedTensorHooks.h>
11*da0073e9SAndroid Build Coastguard Worker #include <ATen/ThreadLocalPythonObjects.h>
12*da0073e9SAndroid Build Coastguard Worker #include <ATen/record_function.h>
13*da0073e9SAndroid Build Coastguard Worker #include <c10/core/impl/PythonDispatcherTLS.h>
14*da0073e9SAndroid Build Coastguard Worker #include <c10/core/impl/TorchDispatchModeTLS.h>
15*da0073e9SAndroid Build Coastguard Worker 
16*da0073e9SAndroid Build Coastguard Worker namespace at {
17*da0073e9SAndroid Build Coastguard Worker 
18*da0073e9SAndroid Build Coastguard Worker // Thread local state contains values that are preserved across
19*da0073e9SAndroid Build Coastguard Worker // thread boundaries (e.g. at::launch/JIT fork, autograd).
20*da0073e9SAndroid Build Coastguard Worker // Note at::parallel_for doesn't preserve TLS across thread boundaries.
21*da0073e9SAndroid Build Coastguard Worker class TORCH_API ThreadLocalState {
22*da0073e9SAndroid Build Coastguard Worker  public:
23*da0073e9SAndroid Build Coastguard Worker   // Saves the thread local variables' values and
24*da0073e9SAndroid Build Coastguard Worker   // returns them as a ThreadLocalState
25*da0073e9SAndroid Build Coastguard Worker   ThreadLocalState();
26*da0073e9SAndroid Build Coastguard Worker 
27*da0073e9SAndroid Build Coastguard Worker   // set_grad_mode - force the value of the grad mode TLS in
28*da0073e9SAndroid Build Coastguard Worker   //  the current state object. This is used for example in the
29*da0073e9SAndroid Build Coastguard Worker   //  autograd engine.
30*da0073e9SAndroid Build Coastguard Worker   void set_grad_mode(bool enabled);
31*da0073e9SAndroid Build Coastguard Worker 
32*da0073e9SAndroid Build Coastguard Worker   // set_multithreading_enabled - force the value of the multithreadinmaximum
33*da0073e9SAndroid Build Coastguard Worker   // threads TLS in
34*da0073e9SAndroid Build Coastguard Worker   //  the current state object. This is used for example in the
35*da0073e9SAndroid Build Coastguard Worker   //  autograd engine.
36*da0073e9SAndroid Build Coastguard Worker   void set_multithreading_enabled(bool enabled);
37*da0073e9SAndroid Build Coastguard Worker 
38*da0073e9SAndroid Build Coastguard Worker   // Sets thread local variables in the current thread,
39*da0073e9SAndroid Build Coastguard Worker   // according to the thread boundary specified
40*da0073e9SAndroid Build Coastguard Worker   static void setThreadLocalState(const ThreadLocalState& state);
41*da0073e9SAndroid Build Coastguard Worker 
42*da0073e9SAndroid Build Coastguard Worker  private:
43*da0073e9SAndroid Build Coastguard Worker   c10::impl::LocalDispatchKeySet dispatch_key_;
44*da0073e9SAndroid Build Coastguard Worker 
45*da0073e9SAndroid Build Coastguard Worker   // ThreadLocalDebugInfo does not change after being created
46*da0073e9SAndroid Build Coastguard Worker   // with DebugInfoGuard
47*da0073e9SAndroid Build Coastguard Worker   std::shared_ptr<c10::ThreadLocalDebugInfo> debug_info_;
48*da0073e9SAndroid Build Coastguard Worker 
49*da0073e9SAndroid Build Coastguard Worker   // RecordFunction TLS
50*da0073e9SAndroid Build Coastguard Worker   RecordFunctionTLS rf_tls_;
51*da0073e9SAndroid Build Coastguard Worker 
52*da0073e9SAndroid Build Coastguard Worker   // TLS for out-of-tree functorch
53*da0073e9SAndroid Build Coastguard Worker   // See NOTE [functorch TLS in pytorch/pytorch] for why this needs to be a
54*da0073e9SAndroid Build Coastguard Worker   // pointer (spoiler alert: it's due to the indirection)
55*da0073e9SAndroid Build Coastguard Worker   // This needs to be a shared_ptr instead of a unique_ptr because
56*da0073e9SAndroid Build Coastguard Worker   // ThreadLocalState is copy-able and does indeed get copied. Maybe we can
57*da0073e9SAndroid Build Coastguard Worker   // consider adding an explicit copy constructor for ThreadLocalState in the
58*da0073e9SAndroid Build Coastguard Worker   // future but I didn't want to add one just for this.
59*da0073e9SAndroid Build Coastguard Worker   std::shared_ptr<const functorch::FuncTorchTLSBase> functorch_tls_;
60*da0073e9SAndroid Build Coastguard Worker 
61*da0073e9SAndroid Build Coastguard Worker   // TLS for AutogradModes
62*da0073e9SAndroid Build Coastguard Worker   AutogradState autograd_tls_;
63*da0073e9SAndroid Build Coastguard Worker 
64*da0073e9SAndroid Build Coastguard Worker   // TLS for enable_torch_dispatch_mode
65*da0073e9SAndroid Build Coastguard Worker   c10::impl::TorchDispatchModeTLS torch_dispatch_mode_state_;
66*da0073e9SAndroid Build Coastguard Worker 
67*da0073e9SAndroid Build Coastguard Worker   // TLS for enable_python_dispatcher
68*da0073e9SAndroid Build Coastguard Worker   c10::impl::PyInterpreter* python_dispatcher_state_;
69*da0073e9SAndroid Build Coastguard Worker 
70*da0073e9SAndroid Build Coastguard Worker   // TLS for __torch_function__ (mode and disable_torch_function)
71*da0073e9SAndroid Build Coastguard Worker   at::impl::PythonTorchFunctionTLS python_torch_function_state_;
72*da0073e9SAndroid Build Coastguard Worker 
73*da0073e9SAndroid Build Coastguard Worker   // TLS for saved tensors default hooks
74*da0073e9SAndroid Build Coastguard Worker   at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_;
75*da0073e9SAndroid Build Coastguard Worker 
76*da0073e9SAndroid Build Coastguard Worker   bool functionalization_reapply_views_state_;
77*da0073e9SAndroid Build Coastguard Worker 
78*da0073e9SAndroid Build Coastguard Worker   // TLS for arbitrary python objects that is registered via hooks
79*da0073e9SAndroid Build Coastguard Worker   at::impl::ThreadLocalPythonObjects saved_objects_;
80*da0073e9SAndroid Build Coastguard Worker 
81*da0073e9SAndroid Build Coastguard Worker #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && \
82*da0073e9SAndroid Build Coastguard Worker     !defined(BUILD_LITE_INTERPRETER)
83*da0073e9SAndroid Build Coastguard Worker   // TLS for autocast dtypes
84*da0073e9SAndroid Build Coastguard Worker   std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES>
85*da0073e9SAndroid Build Coastguard Worker       autocast_dtypes_;
86*da0073e9SAndroid Build Coastguard Worker #endif
87*da0073e9SAndroid Build Coastguard Worker 
88*da0073e9SAndroid Build Coastguard Worker   friend class ThreadLocalStateGuard;
89*da0073e9SAndroid Build Coastguard Worker };
90*da0073e9SAndroid Build Coastguard Worker 
91*da0073e9SAndroid Build Coastguard Worker // Guard to set and reset the thread local state
92*da0073e9SAndroid Build Coastguard Worker class TORCH_API ThreadLocalStateGuard {
93*da0073e9SAndroid Build Coastguard Worker  public:
ThreadLocalStateGuard(const ThreadLocalState & state)94*da0073e9SAndroid Build Coastguard Worker   explicit ThreadLocalStateGuard(const ThreadLocalState& state)
95*da0073e9SAndroid Build Coastguard Worker       : prev_state_(ThreadLocalState()) {
96*da0073e9SAndroid Build Coastguard Worker     // set the given state across the thread boundary
97*da0073e9SAndroid Build Coastguard Worker     ThreadLocalState::setThreadLocalState(state);
98*da0073e9SAndroid Build Coastguard Worker   }
99*da0073e9SAndroid Build Coastguard Worker 
~ThreadLocalStateGuard()100*da0073e9SAndroid Build Coastguard Worker   ~ThreadLocalStateGuard() {
101*da0073e9SAndroid Build Coastguard Worker     // restore previously set variables
102*da0073e9SAndroid Build Coastguard Worker     ThreadLocalState::setThreadLocalState(prev_state_);
103*da0073e9SAndroid Build Coastguard Worker   }
104*da0073e9SAndroid Build Coastguard Worker 
105*da0073e9SAndroid Build Coastguard Worker  private:
106*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
107*da0073e9SAndroid Build Coastguard Worker   const ThreadLocalState prev_state_;
108*da0073e9SAndroid Build Coastguard Worker };
109*da0073e9SAndroid Build Coastguard Worker 
110*da0073e9SAndroid Build Coastguard Worker template <typename T>
wrapPropagateTLSState(T callback)111*da0073e9SAndroid Build Coastguard Worker auto wrapPropagateTLSState(T callback) {
112*da0073e9SAndroid Build Coastguard Worker   return [tls_state = ThreadLocalState(),
113*da0073e9SAndroid Build Coastguard Worker           callback = std::move(callback)](auto&&... args) {
114*da0073e9SAndroid Build Coastguard Worker     ThreadLocalStateGuard g(tls_state);
115*da0073e9SAndroid Build Coastguard Worker     // Propagate value returned by callback().
116*da0073e9SAndroid Build Coastguard Worker     return callback(std::forward<decltype(args)>(args)...);
117*da0073e9SAndroid Build Coastguard Worker   };
118*da0073e9SAndroid Build Coastguard Worker }
119*da0073e9SAndroid Build Coastguard Worker 
120*da0073e9SAndroid Build Coastguard Worker } // namespace at
121