1 #pragma once 2 3 #include <c10/core/AutogradState.h> 4 #include <c10/core/DispatchKey.h> 5 #include <c10/core/DispatchKeySet.h> 6 #include <c10/core/impl/LocalDispatchKeySet.h> 7 #include <c10/macros/Export.h> 8 9 namespace c10 { 10 11 // A RAII, thread local (!) guard that enables or disables inference mode upon 12 // construction, and sets it back to the original value upon destruction. 13 struct C10_API InferenceMode { 14 // Note [Expected TLS state in InferenceMode]: 15 // InferenceMode: ADInplaceOrView not in 16 // raw_local_dispatch_key_set.included(), 17 // Autograd in raw_local_dispatch_key_set.excluded() 18 // GradMode is disabled. 19 // NormalMode: ADInplaceOrView in raw_local_dispatch_key_set.included(), 20 // Autograd not in raw_local_dispatch_key_set.excluded() 21 // GradMode is enabled by default unless toggled manually 22 // through other APIs, e.g. NoGradGuard. 23 // 24 // Invariant: 25 // - ADInplaceOrView is never in the excluded set 26 // - Autograd is never in the included set 27 // - Setting InferenceMode will set GradMode accordingly, but not vice versa. 28 // 29 // 1. Why do we put ADInplaceOrView in included set outside InferenceMode? 30 // 31 // Inplace update to inference tensor outside InferenceMode is not 32 // allowed. See Note [Inplace update inference tensor] for more details. 33 // Without going through ADInplaceOrView kernel, we cannot throw error 34 // for `inference_tensor.add_(1)` case. 35 // 36 // 2. Why not put ADInplaceOrView in the excluded set inside InferenceMode? 37 // 38 // For example: 39 // torch::Tensor a = torch::ones({1, 2, 3}).set_requires_grad(true); 40 // torch::Tensor k = a + 2; 41 // { 42 // c10::InferenceMode guard(true); 43 // k.add_(2); 44 // } 45 // `k.add_(2)` still need to go through ADInplaceOrView kernel so that it's 46 // prepared for future autograd. 47 // 48 // 3. Why does setting InferenceMode also set GradMode? 49 // 50 // This is required since InferenceMode is a faster and more restrictive 51 // version of NoGradGuard. All runtime checks using GradMode::is_enabled() 52 // are applicable to InferenceMode as well, e.g. 53 // `tensorTypeInCurrentExecutionContext` in interpreter.cpp. 54 InferenceMode(bool enabled = true) prev_modeInferenceMode55 : prev_mode(AutogradState::get_tls_state()), 56 prev_keyset(c10::impl::tls_local_dispatch_key_set()) { 57 // Enabling inference mode means disabling grad modes 58 // And disabling inference mode means enabling grad modes 59 AutogradState::set_tls_state(AutogradState( 60 /* grad_mode */ !enabled, 61 /* inference_mode */ enabled, 62 /* fw_grad_mode */ !enabled, 63 /* multithreading_enabled*/ !enabled)); 64 DispatchKeySet included = enabled 65 ? prev_keyset.included_.remove(c10::DispatchKey::ADInplaceOrView) 66 : prev_keyset.included_.add(c10::DispatchKey::ADInplaceOrView); 67 DispatchKeySet excluded = enabled 68 ? (prev_keyset.excluded_ | c10::autograd_dispatch_keyset) 69 : (prev_keyset.excluded_ - c10::autograd_dispatch_keyset); 70 c10::impl::PODLocalDispatchKeySet cur_keyset{}; 71 cur_keyset.set_included(included); 72 cur_keyset.set_excluded(excluded); 73 c10::impl::_force_tls_local_dispatch_key_set(cur_keyset); 74 } 75 ~InferenceModeInferenceMode76 ~InferenceMode() { 77 AutogradState::set_tls_state(prev_mode); 78 c10::impl::_force_tls_local_dispatch_key_set(prev_keyset); 79 } 80 static bool is_enabled(); 81 82 private: 83 AutogradState prev_mode; 84 c10::impl::LocalDispatchKeySet prev_keyset; 85 }; 86 } // namespace c10 87