1 #pragma once 2 3 #include <c10/core/impl/TorchDispatchModeTLS.h> 4 5 namespace torch::torch_dispatch_mode { 6 7 struct StashTorchDispatchModeGuard { 8 public: StashTorchDispatchModeGuardStashTorchDispatchModeGuard9 StashTorchDispatchModeGuard() { 10 if (c10::impl::TorchDispatchModeTLS::any_modes_set( 11 /*skip_infra_modes=*/true)) { 12 saved_mode_ = c10::impl::TorchDispatchModeTLS::pop_stack(); 13 } else { 14 auto mode_and_key = 15 c10::impl::TorchDispatchModeTLS::pop_highest_infra_mode(); 16 saved_mode_ = std::move(std::get<0>(mode_and_key)); 17 saved_mode_key_ = std::get<1>(mode_and_key); 18 } 19 } 20 ~StashTorchDispatchModeGuardStashTorchDispatchModeGuard21 ~StashTorchDispatchModeGuard() { 22 if (saved_mode_key_ != std::nullopt) { 23 c10::impl::TorchDispatchModeTLS::set_mode( 24 saved_mode_, saved_mode_key_.value()); 25 } else { 26 c10::impl::TorchDispatchModeTLS::push_non_infra_mode_onto_stack( 27 std::move(saved_mode_)); 28 } 29 } 30 get_cur_modeStashTorchDispatchModeGuard31 const std::shared_ptr<c10::impl::PyObject_TorchDispatchMode>& get_cur_mode() { 32 return saved_mode_; 33 } 34 35 private: 36 std::shared_ptr<c10::impl::PyObject_TorchDispatchMode> saved_mode_; 37 std::optional<c10::impl::TorchDispatchModeKey> saved_mode_key_; 38 }; 39 40 struct StashTorchDispatchStackGuard { 41 public: StashTorchDispatchStackGuardStashTorchDispatchStackGuard42 StashTorchDispatchStackGuard() { 43 auto old = c10::impl::TorchDispatchModeTLS::get_state(); 44 c10::impl::TorchDispatchModeTLS::set_state(std::move(saved_state_)); 45 saved_state_ = std::move(old); 46 } 47 ~StashTorchDispatchStackGuardStashTorchDispatchStackGuard48 ~StashTorchDispatchStackGuard() { 49 c10::impl::TorchDispatchModeTLS::set_state(std::move(saved_state_)); 50 } 51 52 private: 53 c10::impl::TorchDispatchModeTLS saved_state_; 54 }; 55 56 } // namespace torch::torch_dispatch_mode 57