1*da0073e9SAndroid Build Coastguard Worker #pragma once 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Worker #include <c10/core/InferenceMode.h> 4*da0073e9SAndroid Build Coastguard Worker #include <c10/core/impl/LocalDispatchKeySet.h> 5*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h> 6*da0073e9SAndroid Build Coastguard Worker #include <c10/util/ThreadLocalDebugInfo.h> 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Worker #include <ATen/FuncTorchTLS.h> 9*da0073e9SAndroid Build Coastguard Worker #include <ATen/PythonTorchFunctionTLS.h> 10*da0073e9SAndroid Build Coastguard Worker #include <ATen/SavedTensorHooks.h> 11*da0073e9SAndroid Build Coastguard Worker #include <ATen/ThreadLocalPythonObjects.h> 12*da0073e9SAndroid Build Coastguard Worker #include <ATen/record_function.h> 13*da0073e9SAndroid Build Coastguard Worker #include <c10/core/impl/PythonDispatcherTLS.h> 14*da0073e9SAndroid Build Coastguard Worker #include <c10/core/impl/TorchDispatchModeTLS.h> 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Worker namespace at { 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard Worker // Thread local state contains values that are preserved across 19*da0073e9SAndroid Build Coastguard Worker // thread boundaries (e.g. at::launch/JIT fork, autograd). 20*da0073e9SAndroid Build Coastguard Worker // Note at::parallel_for doesn't preserve TLS across thread boundaries. 21*da0073e9SAndroid Build Coastguard Worker class TORCH_API ThreadLocalState { 22*da0073e9SAndroid Build Coastguard Worker public: 23*da0073e9SAndroid Build Coastguard Worker // Saves the thread local variables' values and 24*da0073e9SAndroid Build Coastguard Worker // returns them as a ThreadLocalState 25*da0073e9SAndroid Build Coastguard Worker ThreadLocalState(); 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Worker // set_grad_mode - force the value of the grad mode TLS in 28*da0073e9SAndroid Build Coastguard Worker // the current state object. This is used for example in the 29*da0073e9SAndroid Build Coastguard Worker // autograd engine. 30*da0073e9SAndroid Build Coastguard Worker void set_grad_mode(bool enabled); 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Worker // set_multithreading_enabled - force the value of the multithreadinmaximum 33*da0073e9SAndroid Build Coastguard Worker // threads TLS in 34*da0073e9SAndroid Build Coastguard Worker // the current state object. This is used for example in the 35*da0073e9SAndroid Build Coastguard Worker // autograd engine. 36*da0073e9SAndroid Build Coastguard Worker void set_multithreading_enabled(bool enabled); 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Worker // Sets thread local variables in the current thread, 39*da0073e9SAndroid Build Coastguard Worker // according to the thread boundary specified 40*da0073e9SAndroid Build Coastguard Worker static void setThreadLocalState(const ThreadLocalState& state); 41*da0073e9SAndroid Build Coastguard Worker 42*da0073e9SAndroid Build Coastguard Worker private: 43*da0073e9SAndroid Build Coastguard Worker c10::impl::LocalDispatchKeySet dispatch_key_; 44*da0073e9SAndroid Build Coastguard Worker 45*da0073e9SAndroid Build Coastguard Worker // ThreadLocalDebugInfo does not change after being created 46*da0073e9SAndroid Build Coastguard Worker // with DebugInfoGuard 47*da0073e9SAndroid Build Coastguard Worker std::shared_ptr<c10::ThreadLocalDebugInfo> debug_info_; 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker // RecordFunction TLS 50*da0073e9SAndroid Build Coastguard Worker RecordFunctionTLS rf_tls_; 51*da0073e9SAndroid Build Coastguard Worker 52*da0073e9SAndroid Build Coastguard Worker // TLS for out-of-tree functorch 53*da0073e9SAndroid Build Coastguard Worker // See NOTE [functorch TLS in pytorch/pytorch] for why this needs to be a 54*da0073e9SAndroid Build Coastguard Worker // pointer (spoiler alert: it's due to the indirection) 55*da0073e9SAndroid Build Coastguard Worker // This needs to be a shared_ptr instead of a unique_ptr because 56*da0073e9SAndroid Build Coastguard Worker // ThreadLocalState is copy-able and does indeed get copied. Maybe we can 57*da0073e9SAndroid Build Coastguard Worker // consider adding an explicit copy constructor for ThreadLocalState in the 58*da0073e9SAndroid Build Coastguard Worker // future but I didn't want to add one just for this. 59*da0073e9SAndroid Build Coastguard Worker std::shared_ptr<const functorch::FuncTorchTLSBase> functorch_tls_; 60*da0073e9SAndroid Build Coastguard Worker 61*da0073e9SAndroid Build Coastguard Worker // TLS for AutogradModes 62*da0073e9SAndroid Build Coastguard Worker AutogradState autograd_tls_; 63*da0073e9SAndroid Build Coastguard Worker 64*da0073e9SAndroid Build Coastguard Worker // TLS for enable_torch_dispatch_mode 65*da0073e9SAndroid Build Coastguard Worker c10::impl::TorchDispatchModeTLS torch_dispatch_mode_state_; 66*da0073e9SAndroid Build Coastguard Worker 67*da0073e9SAndroid Build Coastguard Worker // TLS for enable_python_dispatcher 68*da0073e9SAndroid Build Coastguard Worker c10::impl::PyInterpreter* python_dispatcher_state_; 69*da0073e9SAndroid Build Coastguard Worker 70*da0073e9SAndroid Build Coastguard Worker // TLS for __torch_function__ (mode and disable_torch_function) 71*da0073e9SAndroid Build Coastguard Worker at::impl::PythonTorchFunctionTLS python_torch_function_state_; 72*da0073e9SAndroid Build Coastguard Worker 73*da0073e9SAndroid Build Coastguard Worker // TLS for saved tensors default hooks 74*da0073e9SAndroid Build Coastguard Worker at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_; 75*da0073e9SAndroid Build Coastguard Worker 76*da0073e9SAndroid Build Coastguard Worker bool functionalization_reapply_views_state_; 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker // TLS for arbitrary python objects that is registered via hooks 79*da0073e9SAndroid Build Coastguard Worker at::impl::ThreadLocalPythonObjects saved_objects_; 80*da0073e9SAndroid Build Coastguard Worker 81*da0073e9SAndroid Build Coastguard Worker #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && \ 82*da0073e9SAndroid Build Coastguard Worker !defined(BUILD_LITE_INTERPRETER) 83*da0073e9SAndroid Build Coastguard Worker // TLS for autocast dtypes 84*da0073e9SAndroid Build Coastguard Worker std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES> 85*da0073e9SAndroid Build Coastguard Worker autocast_dtypes_; 86*da0073e9SAndroid Build Coastguard Worker #endif 87*da0073e9SAndroid Build Coastguard Worker 88*da0073e9SAndroid Build Coastguard Worker friend class ThreadLocalStateGuard; 89*da0073e9SAndroid Build Coastguard Worker }; 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker // Guard to set and reset the thread local state 92*da0073e9SAndroid Build Coastguard Worker class TORCH_API ThreadLocalStateGuard { 93*da0073e9SAndroid Build Coastguard Worker public: ThreadLocalStateGuard(const ThreadLocalState & state)94*da0073e9SAndroid Build Coastguard Worker explicit ThreadLocalStateGuard(const ThreadLocalState& state) 95*da0073e9SAndroid Build Coastguard Worker : prev_state_(ThreadLocalState()) { 96*da0073e9SAndroid Build Coastguard Worker // set the given state across the thread boundary 97*da0073e9SAndroid Build Coastguard Worker ThreadLocalState::setThreadLocalState(state); 98*da0073e9SAndroid Build Coastguard Worker } 99*da0073e9SAndroid Build Coastguard Worker ~ThreadLocalStateGuard()100*da0073e9SAndroid Build Coastguard Worker ~ThreadLocalStateGuard() { 101*da0073e9SAndroid Build Coastguard Worker // restore previously set variables 102*da0073e9SAndroid Build Coastguard Worker ThreadLocalState::setThreadLocalState(prev_state_); 103*da0073e9SAndroid Build Coastguard Worker } 104*da0073e9SAndroid Build Coastguard Worker 105*da0073e9SAndroid Build Coastguard Worker private: 106*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) 107*da0073e9SAndroid Build Coastguard Worker const ThreadLocalState prev_state_; 108*da0073e9SAndroid Build Coastguard Worker }; 109*da0073e9SAndroid Build Coastguard Worker 110*da0073e9SAndroid Build Coastguard Worker template <typename T> wrapPropagateTLSState(T callback)111*da0073e9SAndroid Build Coastguard Workerauto wrapPropagateTLSState(T callback) { 112*da0073e9SAndroid Build Coastguard Worker return [tls_state = ThreadLocalState(), 113*da0073e9SAndroid Build Coastguard Worker callback = std::move(callback)](auto&&... args) { 114*da0073e9SAndroid Build Coastguard Worker ThreadLocalStateGuard g(tls_state); 115*da0073e9SAndroid Build Coastguard Worker // Propagate value returned by callback(). 116*da0073e9SAndroid Build Coastguard Worker return callback(std::forward<decltype(args)>(args)...); 117*da0073e9SAndroid Build Coastguard Worker }; 118*da0073e9SAndroid Build Coastguard Worker } 119*da0073e9SAndroid Build Coastguard Worker 120*da0073e9SAndroid Build Coastguard Worker } // namespace at 121