xref: /aosp_15_r20/external/pytorch/c10/core/AutogradState.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/macros/Export.h>
4 
5 namespace c10 {
6 
7 // Structure used to pack all the thread local boolean
8 // flags used by autograd
9 struct C10_API AutogradState {
10   static AutogradState& get_tls_state();
11   static void set_tls_state(AutogradState state);
12 
AutogradStateAutogradState13   AutogradState(
14       bool grad_mode,
15       bool inference_mode,
16       bool fw_grad_mode,
17       bool multithreading_enabled)
18       : grad_mode_(grad_mode),
19         inference_mode_(inference_mode),
20         fw_grad_mode_(fw_grad_mode),
21         multithreading_enabled_(multithreading_enabled),
22         view_replay_enabled_(false) {}
23 
set_grad_modeAutogradState24   void set_grad_mode(bool enabled) {
25     grad_mode_ = enabled;
26   }
27 
set_fw_grad_modeAutogradState28   void set_fw_grad_mode(bool enabled) {
29     fw_grad_mode_ = enabled;
30   }
31 
set_inference_modeAutogradState32   void set_inference_mode(bool enabled) {
33     inference_mode_ = enabled;
34   }
35 
set_multithreading_enabledAutogradState36   void set_multithreading_enabled(bool multithreading_enabled) {
37     multithreading_enabled_ = multithreading_enabled;
38   }
39 
set_view_replay_enabledAutogradState40   void set_view_replay_enabled(bool view_replay_enabled) {
41     view_replay_enabled_ = view_replay_enabled;
42   }
43 
get_grad_modeAutogradState44   bool get_grad_mode() const {
45     return grad_mode_;
46   }
47 
get_fw_grad_modeAutogradState48   bool get_fw_grad_mode() const {
49     return fw_grad_mode_;
50   }
51 
get_inference_modeAutogradState52   bool get_inference_mode() const {
53     return inference_mode_;
54   }
55 
get_multithreading_enabledAutogradState56   bool get_multithreading_enabled() const {
57     return multithreading_enabled_;
58   }
59 
get_view_replay_enabledAutogradState60   bool get_view_replay_enabled() const {
61     return view_replay_enabled_;
62   }
63 
64  private:
65   bool grad_mode_ : 1;
66   bool inference_mode_ : 1;
67   bool fw_grad_mode_ : 1;
68   bool multithreading_enabled_ : 1;
69   bool view_replay_enabled_ : 1;
70 };
71 
72 } // namespace c10
73