// torch/csrc/autograd/saved_variable.h
#pragma once

#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/forward_grad.h>
#include <torch/csrc/autograd/saved_variable_hooks.h>

#include <ATen/core/Tensor.h>

#include <cstdint>
#include <memory>

namespace torch::autograd {

using Variable = at::Tensor;
struct Node;

TORCH_API extern const char* ERR_BACKWARD_TWICE;

/// A snapshot of a variable at a certain version. A `SavedVariable` stores
/// enough information to reconstruct a variable from a certain point in time.
class TORCH_API SavedVariable {
 public:
  SavedVariable() = default;
  SavedVariable(
      const Variable& variable,
      bool is_output,
      bool is_inplace_on_view = false);
  SavedVariable(
      const std::optional<Variable>& variable,
      bool is_output,
      bool is_inplace_on_view = false);
  SavedVariable(SavedVariable&&) = default;
  SavedVariable& operator=(SavedVariable&&) = default;
  ~SavedVariable() {
    if (fw_grad_) {
      // See note [ Using ForwardGrad ]
      fw_grad_->clear();
    }
  }

  /// Reconstructs the saved variable. Pass `saved_for` as the gradient
  /// function if constructing the `SavedVariable` with it would have caused a
  /// circular reference.
  Variable unpack(std::shared_ptr<Node> saved_for = nullptr) const;
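  //
  // A minimal usage sketch. `input` and `node` are hypothetical stand-ins
  // for a tensor being saved and the Node that owns this SavedVariable:
  //
  //   SavedVariable saved(input, /*is_output=*/false);
  //   // ... later, during the backward pass ...
  //   Variable unpacked = saved.unpack();
  //
  // If the saved tensor is an output of `node` itself (so storing its
  // grad_fn here would have created a reference cycle), pass the owning
  // Node back in at unpack time:
  //
  //   Variable unpacked = saved.unpack(node);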

  void register_hooks(std::unique_ptr<SavedVariableHooks>&& hooks);
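  //
  // A sketch of a custom hook, assuming the SavedVariableHooks interface
  // from saved_variable_hooks.h (call_pack_hook/call_unpack_hook). The
  // IdentityHooks class below is illustrative, not part of this header:
  //
  //   struct IdentityHooks : SavedVariableHooks {
  //     void call_pack_hook(const at::Tensor& tensor) override {
  //       stored_ = tensor; // "pack" by holding the tensor unchanged
  //     }
  //     at::Tensor call_unpack_hook() override {
  //       return stored_;
  //     }
  //    private:
  //     at::Tensor stored_;
  //   };
  //
  //   saved.register_hooks(std::make_unique<IdentityHooks>());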

  void reset_data();

  bool has_hooks() const {
    return (bool)hooks_;
  }

 private:
  // This field contains either:
  // 1. the variable to save, or
  // 2. its tensor_data.
  // If storing the variable itself would create a circular reference,
  // we fall into the second case, and its metadata is also saved separately.
  // In that case, the grad_fn must be passed to the unpack function when
  // reconstructing the Variable (except when we are doing an inplace
  // operation on a view; see below). The field saved_original_ below
  // reflects the two cases: its value is true in the first case and false
  // in the second.
  // The value data_.defined() can be false in three cases:
  // 1. The SavedVariable was constructed without a Tensor (the value to save
  //    is None); in that case was_default_constructed_ is kept at true.
  // 2. The saved variable has been released by calling
  //    SavedVariable::reset_data(), typically during the backward pass.
  // 3. Hooks have been registered. In that case, hooks_ is defined instead.
  // Note that the value of saved_original_ only reflects what happened during
  // the construction of the SavedVariable: if saved_original_ is true, we
  // saved the original tensor in data_, but if the user registers hooks, we
  // will no longer have it (even though saved_original_ remains true).
  at::Tensor data_;
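  //
  // Illustration of the second case: for `y = x.exp()`, the backward node
  // for exp saves its own output `y`. Storing `y` together with its grad_fn
  // would create the cycle y -> grad_fn -> SavedVariable -> y, so only y's
  // tensor_data is stored here, and unpack(grad_fn) re-attaches the autograd
  // metadata captured by save_metadata() below.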

  // This field is used to store the forward AD gradients associated with
  // the saved Tensor. Note that this shared_ptr must never be shared with
  // either the saved Tensor or the unpacked Tensor. See note [ Using
  // ForwardGrad ].
  std::shared_ptr<ForwardGrad> fw_grad_;

  // Weak version of grad_fn_ that prevents leaks in rebase_history() for
  // inplace views.
  // This variable is used when the user chooses to create a SavedVariable
  // with is_inplace_on_view = true.
  // In that case, the grad_fn passed to the unpack function at unwrapping
  // time is unused.
  std::weak_ptr<Node> weak_grad_fn_;

  uint32_t saved_version_ = 0;
  uint32_t output_nr_ = 0;
  bool was_default_constructed_ = true;
  bool is_inplace_on_view_ = false;
  bool saved_original_ = false;
  bool is_leaf_ = false;
  bool is_output_ = false;

  // Hooks are a pair of functions, pack_hook/unpack_hook, that provide
  // fine-grained control over how the SavedVariable saves its data.
  // pack_hook is called upon registration, while unpack_hook is called when
  // unpacking.
  std::unique_ptr<SavedVariableHooks> hooks_;
  // The fields grad_fn_, grad_accumulator_, and requires_grad_ are only used
  // if hooks are defined. They are set before pack_hook is called and used
  // after unpack_hook is called.
  std::shared_ptr<Node> grad_fn_;
  // In the usual case where a leaf tensor is the input, we expect its grad
  // accumulator to be kept alive by the graph. The reason SavedVariable holds
  // an owning reference is to support the case where a custom autograd
  // Function saves an intermediate.
  std::shared_ptr<Node> grad_accumulator_;
  bool requires_grad_ = false;

  void save_metadata(const Variable& data);
  static std::unique_ptr<SavedVariableHooks> get_default_hooks();
  void set_hooks_and_pack_data(
      std::unique_ptr<SavedVariableHooks>&& hooks,
      const Variable& data);
};
} // namespace torch::autograd