xref: /aosp_15_r20/external/pytorch/aten/src/ATen/ThreadLocalState.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <array>
#include <memory>
#include <utility>

#include <c10/core/InferenceMode.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/util/Exception.h>
#include <c10/util/ThreadLocalDebugInfo.h>

#include <ATen/FuncTorchTLS.h>
#include <ATen/PythonTorchFunctionTLS.h>
#include <ATen/SavedTensorHooks.h>
#include <ATen/ThreadLocalPythonObjects.h>
#include <ATen/record_function.h>
#include <c10/core/impl/PythonDispatcherTLS.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>
15 
16 namespace at {
17 
18 // Thread local state contains values that are preserved across
19 // thread boundaries (e.g. at::launch/JIT fork, autograd).
20 // Note at::parallel_for doesn't preserve TLS across thread boundaries.
21 class TORCH_API ThreadLocalState {
22  public:
23   // Saves the thread local variables' values and
24   // returns them as a ThreadLocalState
25   ThreadLocalState();
26 
27   // set_grad_mode - force the value of the grad mode TLS in
28   //  the current state object. This is used for example in the
29   //  autograd engine.
30   void set_grad_mode(bool enabled);
31 
32   // set_multithreading_enabled - force the value of the multithreadinmaximum
33   // threads TLS in
34   //  the current state object. This is used for example in the
35   //  autograd engine.
36   void set_multithreading_enabled(bool enabled);
37 
38   // Sets thread local variables in the current thread,
39   // according to the thread boundary specified
40   static void setThreadLocalState(const ThreadLocalState& state);
41 
42  private:
43   c10::impl::LocalDispatchKeySet dispatch_key_;
44 
45   // ThreadLocalDebugInfo does not change after being created
46   // with DebugInfoGuard
47   std::shared_ptr<c10::ThreadLocalDebugInfo> debug_info_;
48 
49   // RecordFunction TLS
50   RecordFunctionTLS rf_tls_;
51 
52   // TLS for out-of-tree functorch
53   // See NOTE [functorch TLS in pytorch/pytorch] for why this needs to be a
54   // pointer (spoiler alert: it's due to the indirection)
55   // This needs to be a shared_ptr instead of a unique_ptr because
56   // ThreadLocalState is copy-able and does indeed get copied. Maybe we can
57   // consider adding an explicit copy constructor for ThreadLocalState in the
58   // future but I didn't want to add one just for this.
59   std::shared_ptr<const functorch::FuncTorchTLSBase> functorch_tls_;
60 
61   // TLS for AutogradModes
62   AutogradState autograd_tls_;
63 
64   // TLS for enable_torch_dispatch_mode
65   c10::impl::TorchDispatchModeTLS torch_dispatch_mode_state_;
66 
67   // TLS for enable_python_dispatcher
68   c10::impl::PyInterpreter* python_dispatcher_state_;
69 
70   // TLS for __torch_function__ (mode and disable_torch_function)
71   at::impl::PythonTorchFunctionTLS python_torch_function_state_;
72 
73   // TLS for saved tensors default hooks
74   at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_;
75 
76   bool functionalization_reapply_views_state_;
77 
78   // TLS for arbitrary python objects that is registered via hooks
79   at::impl::ThreadLocalPythonObjects saved_objects_;
80 
81 #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && \
82     !defined(BUILD_LITE_INTERPRETER)
83   // TLS for autocast dtypes
84   std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES>
85       autocast_dtypes_;
86 #endif
87 
88   friend class ThreadLocalStateGuard;
89 };
90 
91 // Guard to set and reset the thread local state
92 class TORCH_API ThreadLocalStateGuard {
93  public:
ThreadLocalStateGuard(const ThreadLocalState & state)94   explicit ThreadLocalStateGuard(const ThreadLocalState& state)
95       : prev_state_(ThreadLocalState()) {
96     // set the given state across the thread boundary
97     ThreadLocalState::setThreadLocalState(state);
98   }
99 
~ThreadLocalStateGuard()100   ~ThreadLocalStateGuard() {
101     // restore previously set variables
102     ThreadLocalState::setThreadLocalState(prev_state_);
103   }
104 
105  private:
106   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
107   const ThreadLocalState prev_state_;
108 };
109 
110 template <typename T>
wrapPropagateTLSState(T callback)111 auto wrapPropagateTLSState(T callback) {
112   return [tls_state = ThreadLocalState(),
113           callback = std::move(callback)](auto&&... args) {
114     ThreadLocalStateGuard g(tls_state);
115     // Propagate value returned by callback().
116     return callback(std::forward<decltype(args)>(args)...);
117   };
118 }
119 
120 } // namespace at
121