1 #pragma once 2 3 #include <c10/core/SafePyObject.h> 4 #include <c10/macros/Export.h> 5 #include <c10/util/python_stub.h> 6 #include <optional> 7 #include <stack> 8 #include <string> 9 10 #include <utility> 11 12 namespace at { 13 14 namespace impl { 15 16 struct TORCH_API SavedTensorDefaultHooksTLS { 17 // PyObject is defined in c10/util/python_stub.h 18 std::stack<std::pair<c10::SafePyObject, c10::SafePyObject>> stack; 19 20 // See NOTE: [Disabling SavedTensorDefaultHooks] for context 21 // NOTE: [disabled_error_message invariant] 22 // disabled_error_message is nullopt IFF Saved Tensor hooks is enabled 23 // We did this for efficiency (so we didn't have to keep a separate bool 24 // around) 25 std::optional<std::string> disabled_error_message; 26 27 // See NOTE: [Deferring tensor pack/unpack hooks until runtime] 28 bool is_tracing = false; 29 }; 30 31 } // namespace impl 32 33 struct TORCH_API SavedTensorDefaultHooks { 34 static void push_hooks( 35 c10::SafePyObject pack_hook, 36 c10::SafePyObject unpack_hook); 37 static std::pair<c10::SafePyObject, c10::SafePyObject> pop_hooks(); 38 static std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>> 39 get_hooks(); 40 static void lazy_initialize(); 41 42 static const impl::SavedTensorDefaultHooksTLS& get_tls_state(); 43 static void set_tls_state(const impl::SavedTensorDefaultHooksTLS& tls); 44 45 // NOTE: [Disabling SavedTensorDefaultHooks] 46 // A developer of a PyTorch feature may choose to disable SavedTensorDefault 47 // hooks, especially if their feature does not work with it. If they are 48 // disabled, then the following will raise an error: 49 // - Attempting to push_hooks 50 // - calling disable(message) with a non-zero stack (hooks) size 51 static void disable(const std::string& error_message); 52 static void enable(); 53 static bool is_enabled(); 54 static const std::optional<std::string>& get_disabled_error_message(); 55 56 // NOTE: [Deferring tensor pack/unpack hooks until runtime] 57 // To preserve eager semantics of pack/unpack hooks firing only once per saved 58 // variable, Dynamo/AOTAutograd need to defer hook firing until runtime. Using 59 // disable() would loud error at trace time, and pushing a no-op hook would 60 // fail when the traced code is wrapped in a disable_saved_tensors_hooks ctx. 61 // To do so, we disable these hooks during tracing. See 62 // https://github.com/pytorch/pytorch/issues/113263. 63 static bool set_tracing(bool is_tracing); 64 }; 65 66 } // namespace at 67