xref: /aosp_15_r20/external/pytorch/c10/core/InferenceMode.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/AutogradState.h>
4 #include <c10/core/DispatchKey.h>
5 #include <c10/core/DispatchKeySet.h>
6 #include <c10/core/impl/LocalDispatchKeySet.h>
7 #include <c10/macros/Export.h>
8 
9 namespace c10 {
10 
11 // A RAII, thread local (!) guard that enables or disables inference mode upon
12 // construction, and sets it back to the original value upon destruction.
13 struct C10_API InferenceMode {
14   // Note [Expected TLS state in InferenceMode]:
15   //   InferenceMode: ADInplaceOrView not in
16   //   raw_local_dispatch_key_set.included(),
17   //                  Autograd in raw_local_dispatch_key_set.excluded()
18   //                  GradMode is disabled.
19   //   NormalMode: ADInplaceOrView in raw_local_dispatch_key_set.included(),
20   //               Autograd not in raw_local_dispatch_key_set.excluded()
21   //               GradMode is enabled by default unless toggled manually
22   //               through other APIs, e.g. NoGradGuard.
23   //
24   // Invariant:
25   // - ADInplaceOrView is never in the excluded set
26   // - Autograd is never in the included set
27   // - Setting InferenceMode will set GradMode accordingly, but not vice versa.
28   //
29   //  1. Why do we put ADInplaceOrView in included set outside InferenceMode?
30   //
31   //     Inplace update to inference tensor outside InferenceMode is not
32   //     allowed. See Note [Inplace update inference tensor] for more details.
33   //     Without going through ADInplaceOrView kernel, we cannot throw error
34   //     for `inference_tensor.add_(1)` case.
35   //
36   // 2. Why not put ADInplaceOrView in the excluded set inside InferenceMode?
37   //
38   //    For example:
39   //    torch::Tensor a = torch::ones({1, 2, 3}).set_requires_grad(true);
40   //    torch::Tensor k = a + 2;
41   //    {
42   //      c10::InferenceMode guard(true);
43   //      k.add_(2);
44   //    }
45   //    `k.add_(2)` still need to go through ADInplaceOrView kernel so that it's
46   //    prepared for future autograd.
47   //
48   // 3. Why does setting InferenceMode also set GradMode?
49   //
50   //    This is required since InferenceMode is a faster and more restrictive
51   //    version of NoGradGuard. All runtime checks using GradMode::is_enabled()
52   //    are applicable to InferenceMode as well, e.g.
53   //    `tensorTypeInCurrentExecutionContext` in interpreter.cpp.
54   InferenceMode(bool enabled = true)
prev_modeInferenceMode55       : prev_mode(AutogradState::get_tls_state()),
56         prev_keyset(c10::impl::tls_local_dispatch_key_set()) {
57     // Enabling inference mode means disabling grad modes
58     // And disabling inference mode means enabling grad modes
59     AutogradState::set_tls_state(AutogradState(
60         /* grad_mode */ !enabled,
61         /* inference_mode */ enabled,
62         /* fw_grad_mode */ !enabled,
63         /* multithreading_enabled*/ !enabled));
64     DispatchKeySet included = enabled
65         ? prev_keyset.included_.remove(c10::DispatchKey::ADInplaceOrView)
66         : prev_keyset.included_.add(c10::DispatchKey::ADInplaceOrView);
67     DispatchKeySet excluded = enabled
68         ? (prev_keyset.excluded_ | c10::autograd_dispatch_keyset)
69         : (prev_keyset.excluded_ - c10::autograd_dispatch_keyset);
70     c10::impl::PODLocalDispatchKeySet cur_keyset{};
71     cur_keyset.set_included(included);
72     cur_keyset.set_excluded(excluded);
73     c10::impl::_force_tls_local_dispatch_key_set(cur_keyset);
74   }
75 
~InferenceModeInferenceMode76   ~InferenceMode() {
77     AutogradState::set_tls_state(prev_mode);
78     c10::impl::_force_tls_local_dispatch_key_set(prev_keyset);
79   }
80   static bool is_enabled();
81 
82  private:
83   AutogradState prev_mode;
84   c10::impl::LocalDispatchKeySet prev_keyset;
85 };
86 } // namespace c10
87