xref: /aosp_15_r20/external/pytorch/c10/core/impl/TorchDispatchModeTLS.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <c10/core/SafePyObject.h>
#include <c10/macros/Export.h>

#include <array>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <vector>
5 
6 namespace c10::impl {
7 
8 enum class TorchDispatchModeKey : int8_t {
9   FAKE,
10   PROXY,
11   FUNCTIONAL,
12   NUM_MODE_KEYS
13 };
14 
15 using PyObject_TorchDispatchMode = SafePyObjectT<TorchDispatchModeKey>;
16 
17 struct C10_API TorchDispatchModeTLS {
18   // This API is NOT invariant safe.
19   // It must not take in an infra mode that uses TorchDispatchModeKey
20   // If you're pushing an infra mode onto the stack, we expect
21   // you to use set_mode
22   static void push_non_infra_mode_onto_stack(
23       std::shared_ptr<PyObject_TorchDispatchMode> mode);
24   // Pops the top mode of the stack,
25   // giving precedence to user modes before attempting to pop
26   // any infra modes
27   static const std::shared_ptr<PyObject_TorchDispatchMode> pop_stack();
28   // Returns the highest-priority infra mode on the stack,
29   // along with its mode key.
30   static const std::
31       tuple<std::shared_ptr<PyObject_TorchDispatchMode>, TorchDispatchModeKey>
32       pop_highest_infra_mode();
33 
34   static const std::shared_ptr<PyObject_TorchDispatchMode>& get_stack_at(
35       int64_t idx);
36   static int64_t stack_len();
37 
38   static const std::optional<std::shared_ptr<PyObject_TorchDispatchMode>>
39   get_mode(TorchDispatchModeKey mode_key);
40   static const std::optional<std::shared_ptr<PyObject_TorchDispatchMode>>
41   unset_mode(TorchDispatchModeKey mode_key);
42   static void set_mode(
43       const std::shared_ptr<PyObject_TorchDispatchMode>& mode,
44       TorchDispatchModeKey mode_key);
45 
46   static const TorchDispatchModeTLS& get_state();
47   static void set_state(TorchDispatchModeTLS state);
48 
49   static bool any_modes_set(bool skip_infra_modes = false);
50 
51  private:
52   std::vector<std::shared_ptr<PyObject_TorchDispatchMode>> stack_;
53   // Users are allowed to push multiple ProxyTorchDispatchMode objects onto the
54   // stack
55   // However, we only allow a single FakeTensorMode onto the stack at a time
56   // (Pushing additional FakeTensorModes onto the stack is a no-op)
57   std::array<
58       std::optional<std::shared_ptr<PyObject_TorchDispatchMode>>,
59       static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS)>
60       infra_modes_;
61 };
62 
63 C10_API bool dispatch_mode_enabled();
64 
65 C10_API std::string to_string(TorchDispatchModeKey mode_key);
66 
67 } // namespace c10::impl
68