1 #pragma once 2 3 #include <c10/macros/Export.h> 4 5 namespace c10 { 6 7 // Structure used to pack all the thread local boolean 8 // flags used by autograd 9 struct C10_API AutogradState { 10 static AutogradState& get_tls_state(); 11 static void set_tls_state(AutogradState state); 12 AutogradStateAutogradState13 AutogradState( 14 bool grad_mode, 15 bool inference_mode, 16 bool fw_grad_mode, 17 bool multithreading_enabled) 18 : grad_mode_(grad_mode), 19 inference_mode_(inference_mode), 20 fw_grad_mode_(fw_grad_mode), 21 multithreading_enabled_(multithreading_enabled), 22 view_replay_enabled_(false) {} 23 set_grad_modeAutogradState24 void set_grad_mode(bool enabled) { 25 grad_mode_ = enabled; 26 } 27 set_fw_grad_modeAutogradState28 void set_fw_grad_mode(bool enabled) { 29 fw_grad_mode_ = enabled; 30 } 31 set_inference_modeAutogradState32 void set_inference_mode(bool enabled) { 33 inference_mode_ = enabled; 34 } 35 set_multithreading_enabledAutogradState36 void set_multithreading_enabled(bool multithreading_enabled) { 37 multithreading_enabled_ = multithreading_enabled; 38 } 39 set_view_replay_enabledAutogradState40 void set_view_replay_enabled(bool view_replay_enabled) { 41 view_replay_enabled_ = view_replay_enabled; 42 } 43 get_grad_modeAutogradState44 bool get_grad_mode() const { 45 return grad_mode_; 46 } 47 get_fw_grad_modeAutogradState48 bool get_fw_grad_mode() const { 49 return fw_grad_mode_; 50 } 51 get_inference_modeAutogradState52 bool get_inference_mode() const { 53 return inference_mode_; 54 } 55 get_multithreading_enabledAutogradState56 bool get_multithreading_enabled() const { 57 return multithreading_enabled_; 58 } 59 get_view_replay_enabledAutogradState60 bool get_view_replay_enabled() const { 61 return view_replay_enabled_; 62 } 63 64 private: 65 bool grad_mode_ : 1; 66 bool inference_mode_ : 1; 67 bool fw_grad_mode_ : 1; 68 bool multithreading_enabled_ : 1; 69 bool view_replay_enabled_ : 1; 70 }; 71 72 } // namespace c10 73