xref: /aosp_15_r20/external/pytorch/aten/src/ATen/SavedTensorHooks.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/SafePyObject.h>
4 #include <c10/macros/Export.h>
5 #include <c10/util/python_stub.h>
6 #include <optional>
7 #include <stack>
8 #include <string>
9 
10 #include <utility>
11 
12 namespace at {
13 
14 namespace impl {
15 
16 struct TORCH_API SavedTensorDefaultHooksTLS {
17   // PyObject is defined in c10/util/python_stub.h
18   std::stack<std::pair<c10::SafePyObject, c10::SafePyObject>> stack;
19 
20   // See NOTE: [Disabling SavedTensorDefaultHooks] for context
21   // NOTE: [disabled_error_message invariant]
22   // disabled_error_message is nullopt IFF Saved Tensor hooks is enabled
23   // We did this for efficiency (so we didn't have to keep a separate bool
24   // around)
25   std::optional<std::string> disabled_error_message;
26 
27   // See NOTE: [Deferring tensor pack/unpack hooks until runtime]
28   bool is_tracing = false;
29 };
30 
31 } // namespace impl
32 
33 struct TORCH_API SavedTensorDefaultHooks {
34   static void push_hooks(
35       c10::SafePyObject pack_hook,
36       c10::SafePyObject unpack_hook);
37   static std::pair<c10::SafePyObject, c10::SafePyObject> pop_hooks();
38   static std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
39   get_hooks();
40   static void lazy_initialize();
41 
42   static const impl::SavedTensorDefaultHooksTLS& get_tls_state();
43   static void set_tls_state(const impl::SavedTensorDefaultHooksTLS& tls);
44 
45   // NOTE: [Disabling SavedTensorDefaultHooks]
46   // A developer of a PyTorch feature may choose to disable SavedTensorDefault
47   // hooks, especially if their feature does not work with it. If they are
48   // disabled, then the following will raise an error:
49   // - Attempting to push_hooks
50   // - calling disable(message) with a non-zero stack (hooks) size
51   static void disable(const std::string& error_message);
52   static void enable();
53   static bool is_enabled();
54   static const std::optional<std::string>& get_disabled_error_message();
55 
56   // NOTE: [Deferring tensor pack/unpack hooks until runtime]
57   // To preserve eager semantics of pack/unpack hooks firing only once per saved
58   // variable, Dynamo/AOTAutograd need to defer hook firing until runtime. Using
59   // disable() would loud error at trace time, and pushing a no-op hook would
60   // fail when the traced code is wrapped in a disable_saved_tensors_hooks ctx.
61   // To do so, we disable these hooks during tracing. See
62   // https://github.com/pytorch/pytorch/issues/113263.
63   static bool set_tracing(bool is_tracing);
64 };
65 
66 } // namespace at
67