xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/torch_dispatch_mode.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/impl/TorchDispatchModeTLS.h>
4 
5 namespace torch::torch_dispatch_mode {
6 
7 struct StashTorchDispatchModeGuard {
8  public:
StashTorchDispatchModeGuardStashTorchDispatchModeGuard9   StashTorchDispatchModeGuard() {
10     if (c10::impl::TorchDispatchModeTLS::any_modes_set(
11             /*skip_infra_modes=*/true)) {
12       saved_mode_ = c10::impl::TorchDispatchModeTLS::pop_stack();
13     } else {
14       auto mode_and_key =
15           c10::impl::TorchDispatchModeTLS::pop_highest_infra_mode();
16       saved_mode_ = std::move(std::get<0>(mode_and_key));
17       saved_mode_key_ = std::get<1>(mode_and_key);
18     }
19   }
20 
~StashTorchDispatchModeGuardStashTorchDispatchModeGuard21   ~StashTorchDispatchModeGuard() {
22     if (saved_mode_key_ != std::nullopt) {
23       c10::impl::TorchDispatchModeTLS::set_mode(
24           saved_mode_, saved_mode_key_.value());
25     } else {
26       c10::impl::TorchDispatchModeTLS::push_non_infra_mode_onto_stack(
27           std::move(saved_mode_));
28     }
29   }
30 
get_cur_modeStashTorchDispatchModeGuard31   const std::shared_ptr<c10::impl::PyObject_TorchDispatchMode>& get_cur_mode() {
32     return saved_mode_;
33   }
34 
35  private:
36   std::shared_ptr<c10::impl::PyObject_TorchDispatchMode> saved_mode_;
37   std::optional<c10::impl::TorchDispatchModeKey> saved_mode_key_;
38 };
39 
40 struct StashTorchDispatchStackGuard {
41  public:
StashTorchDispatchStackGuardStashTorchDispatchStackGuard42   StashTorchDispatchStackGuard() {
43     auto old = c10::impl::TorchDispatchModeTLS::get_state();
44     c10::impl::TorchDispatchModeTLS::set_state(std::move(saved_state_));
45     saved_state_ = std::move(old);
46   }
47 
~StashTorchDispatchStackGuardStashTorchDispatchStackGuard48   ~StashTorchDispatchStackGuard() {
49     c10::impl::TorchDispatchModeTLS::set_state(std::move(saved_state_));
50   }
51 
52  private:
53   c10::impl::TorchDispatchModeTLS saved_state_;
54 };
55 
56 } // namespace torch::torch_dispatch_mode
57