1 #pragma once 2 3 #include <c10/core/SafePyObject.h> 4 #include <c10/macros/Export.h> 5 6 namespace c10::impl { 7 8 enum class TorchDispatchModeKey : int8_t { 9 FAKE, 10 PROXY, 11 FUNCTIONAL, 12 NUM_MODE_KEYS 13 }; 14 15 using PyObject_TorchDispatchMode = SafePyObjectT<TorchDispatchModeKey>; 16 17 struct C10_API TorchDispatchModeTLS { 18 // This API is NOT invariant safe. 19 // It must not take in an infra mode that uses TorchDispatchModeKey 20 // If you're pushing an infra mode onto the stack, we expect 21 // you to use set_mode 22 static void push_non_infra_mode_onto_stack( 23 std::shared_ptr<PyObject_TorchDispatchMode> mode); 24 // Pops the top mode of the stack, 25 // giving precedence to user modes before attempting to pop 26 // any infra modes 27 static const std::shared_ptr<PyObject_TorchDispatchMode> pop_stack(); 28 // Returns the highest-priority infra mode on the stack, 29 // along with its mode key. 30 static const std:: 31 tuple<std::shared_ptr<PyObject_TorchDispatchMode>, TorchDispatchModeKey> 32 pop_highest_infra_mode(); 33 34 static const std::shared_ptr<PyObject_TorchDispatchMode>& get_stack_at( 35 int64_t idx); 36 static int64_t stack_len(); 37 38 static const std::optional<std::shared_ptr<PyObject_TorchDispatchMode>> 39 get_mode(TorchDispatchModeKey mode_key); 40 static const std::optional<std::shared_ptr<PyObject_TorchDispatchMode>> 41 unset_mode(TorchDispatchModeKey mode_key); 42 static void set_mode( 43 const std::shared_ptr<PyObject_TorchDispatchMode>& mode, 44 TorchDispatchModeKey mode_key); 45 46 static const TorchDispatchModeTLS& get_state(); 47 static void set_state(TorchDispatchModeTLS state); 48 49 static bool any_modes_set(bool skip_infra_modes = false); 50 51 private: 52 std::vector<std::shared_ptr<PyObject_TorchDispatchMode>> stack_; 53 // Users are allowed to push multiple ProxyTorchDispatchMode objects onto the 54 // stack 55 // However, we only allow a single FakeTensorMode onto the stack at a time 56 // (Pushing additional FakeTensorModes onto the stack is a no-op) 57 std::array< 58 std::optional<std::shared_ptr<PyObject_TorchDispatchMode>>, 59 static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS)> 60 infra_modes_; 61 }; 62 63 C10_API bool dispatch_mode_enabled(); 64 65 C10_API std::string to_string(TorchDispatchModeKey mode_key); 66 67 } // namespace c10::impl 68