1 #pragma once 2 3 #include <c10/core/InferenceMode.h> 4 #include <c10/core/impl/LocalDispatchKeySet.h> 5 #include <c10/util/Exception.h> 6 #include <c10/util/ThreadLocalDebugInfo.h> 7 8 #include <ATen/FuncTorchTLS.h> 9 #include <ATen/PythonTorchFunctionTLS.h> 10 #include <ATen/SavedTensorHooks.h> 11 #include <ATen/ThreadLocalPythonObjects.h> 12 #include <ATen/record_function.h> 13 #include <c10/core/impl/PythonDispatcherTLS.h> 14 #include <c10/core/impl/TorchDispatchModeTLS.h> 15 16 namespace at { 17 18 // Thread local state contains values that are preserved across 19 // thread boundaries (e.g. at::launch/JIT fork, autograd). 20 // Note at::parallel_for doesn't preserve TLS across thread boundaries. 21 class TORCH_API ThreadLocalState { 22 public: 23 // Saves the thread local variables' values and 24 // returns them as a ThreadLocalState 25 ThreadLocalState(); 26 27 // set_grad_mode - force the value of the grad mode TLS in 28 // the current state object. This is used for example in the 29 // autograd engine. 30 void set_grad_mode(bool enabled); 31 32 // set_multithreading_enabled - force the value of the multithreadinmaximum 33 // threads TLS in 34 // the current state object. This is used for example in the 35 // autograd engine. 36 void set_multithreading_enabled(bool enabled); 37 38 // Sets thread local variables in the current thread, 39 // according to the thread boundary specified 40 static void setThreadLocalState(const ThreadLocalState& state); 41 42 private: 43 c10::impl::LocalDispatchKeySet dispatch_key_; 44 45 // ThreadLocalDebugInfo does not change after being created 46 // with DebugInfoGuard 47 std::shared_ptr<c10::ThreadLocalDebugInfo> debug_info_; 48 49 // RecordFunction TLS 50 RecordFunctionTLS rf_tls_; 51 52 // TLS for out-of-tree functorch 53 // See NOTE [functorch TLS in pytorch/pytorch] for why this needs to be a 54 // pointer (spoiler alert: it's due to the indirection) 55 // This needs to be a shared_ptr instead of a unique_ptr because 56 // ThreadLocalState is copy-able and does indeed get copied. Maybe we can 57 // consider adding an explicit copy constructor for ThreadLocalState in the 58 // future but I didn't want to add one just for this. 59 std::shared_ptr<const functorch::FuncTorchTLSBase> functorch_tls_; 60 61 // TLS for AutogradModes 62 AutogradState autograd_tls_; 63 64 // TLS for enable_torch_dispatch_mode 65 c10::impl::TorchDispatchModeTLS torch_dispatch_mode_state_; 66 67 // TLS for enable_python_dispatcher 68 c10::impl::PyInterpreter* python_dispatcher_state_; 69 70 // TLS for __torch_function__ (mode and disable_torch_function) 71 at::impl::PythonTorchFunctionTLS python_torch_function_state_; 72 73 // TLS for saved tensors default hooks 74 at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_; 75 76 bool functionalization_reapply_views_state_; 77 78 // TLS for arbitrary python objects that is registered via hooks 79 at::impl::ThreadLocalPythonObjects saved_objects_; 80 81 #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && \ 82 !defined(BUILD_LITE_INTERPRETER) 83 // TLS for autocast dtypes 84 std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES> 85 autocast_dtypes_; 86 #endif 87 88 friend class ThreadLocalStateGuard; 89 }; 90 91 // Guard to set and reset the thread local state 92 class TORCH_API ThreadLocalStateGuard { 93 public: ThreadLocalStateGuard(const ThreadLocalState & state)94 explicit ThreadLocalStateGuard(const ThreadLocalState& state) 95 : prev_state_(ThreadLocalState()) { 96 // set the given state across the thread boundary 97 ThreadLocalState::setThreadLocalState(state); 98 } 99 ~ThreadLocalStateGuard()100 ~ThreadLocalStateGuard() { 101 // restore previously set variables 102 ThreadLocalState::setThreadLocalState(prev_state_); 103 } 104 105 private: 106 // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) 107 const ThreadLocalState prev_state_; 108 }; 109 110 template <typename T> wrapPropagateTLSState(T callback)111auto wrapPropagateTLSState(T callback) { 112 return [tls_state = ThreadLocalState(), 113 callback = std::move(callback)](auto&&... args) { 114 ThreadLocalStateGuard g(tls_state); 115 // Propagate value returned by callback(). 116 return callback(std::forward<decltype(args)>(args)...); 117 }; 118 } 119 120 } // namespace at 121