xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/functions/accumulate_grad.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/CachedTensorUtils.h>
4 #include <ATen/LegacyBatchedTensorImpl.h>
5 #include <ATen/TensorOperators.h>
6 #include <torch/csrc/Export.h>
7 #include <torch/csrc/autograd/function.h>
8 #include <torch/csrc/autograd/utils/grad_layout_contract.h>
9 #include <torch/csrc/autograd/variable.h>
10 
11 #ifndef AT_PER_OPERATOR_HEADERS
12 #include <ATen/Functions.h>
13 #else
14 #include <ATen/ops/_sparse_coo_tensor_unsafe.h>
15 #endif
16 
17 #include <mutex>
18 
19 namespace torch::autograd {
20 
21 #define CHECK_RESULT(RESULT, VAR)                                          \
22   if (!(RESULT.is_sparse() || VAR.is_sparse() || RESULT.is_sparse_csr() || \
23         VAR.is_sparse_csr())) {                                            \
24     if (!utils::obeys_layout_contract(RESULT, VAR)) {                      \
25       TORCH_WARN_ONCE(                                                     \
26           "grad and param do not obey the gradient layout contract. "      \
27           "This is not an error, but may impair performance.\n"            \
28           "grad.sizes() = ",                                               \
29           RESULT.sizes(),                                                  \
30           ", strides() = ",                                                \
31           RESULT.strides(),                                                \
32           "\n",                                                            \
33           "param.sizes() = ",                                              \
34           VAR.sizes(),                                                     \
35           ", strides() = ",                                                \
36           VAR.strides());                                                  \
37     }                                                                      \
38   }
39 
40 struct TORCH_API AccumulateGrad : public Node {
41   explicit AccumulateGrad(Variable variable_);
42 
43   variable_list apply(variable_list&& grads) override;
44 
tensor_pre_hooksAccumulateGrad45   std::vector<std::unique_ptr<FunctionPreHook>>& tensor_pre_hooks() noexcept
46       override {
47     // NB: Since the AccumulateGrad Node is only a weak ref from the Tensor,
48     //     it can be destroyed even though the Tensor is still alive (contrary
49     //     to all other Nodes). So we must lazily read the Tensor hooks here.
50     return impl::hooks(variable);
51   }
52 
tensor_post_acc_grad_hooksAccumulateGrad53   std::unique_ptr<PostAccumulateGradHook>& tensor_post_acc_grad_hooks() noexcept
54       override {
55     // NB: Since the AccumulateGrad Node is only a weak ref from the Tensor,
56     //     it can be destroyed even though the Tensor is still alive (contrary
57     //     to all other Nodes). So we must lazily read the Tensor hooks here.
58     return impl::post_acc_grad_hooks(variable);
59   }
60 
61   // Given a variable with its current grad as variable_grad, accumulates
62   // new_grad into variable_grad if in place accumulation is possible.
63   // Otherwise, uses 'update_grad' to update the grad for the variable.
64 
65   // "Gradient Layout Contract"
66   //
67   // AccumulateGrad tries to stash strided (non-sparse) grads with memory layout
68   // (strides) such that variables and grads interact efficiently in later
69   // optimizer kernels, and grads interact efficiently with c10d::Reducer.cpp.
70   //
71   // Specifically, AccumulateGrad tries to ensure the following
72   // (cf torch/csrc/autograd/utils/grad_layout_contract.h):
73   //   (1) if variable.is_non_overlapping_and_dense(), the stashed grad's
74   //       strides match variable.
75   //   (2) else, stashed grad is rowmajor contiguous.
76   // If variable's grad does not exist (!variable_grad.defined())
77   // AccumulateGrad steals new_grad if it's stealable and obeys the contract
78   // already, otherwise it deep copies new_grad into an obedient clone.
79   //
80   // If variable's grad already exists (variable_grad.defined()), new_grad must
81   // be added to variable_grad.  If we aren't setting up for double backward
82   // (!GradMode::is_enabled()), AccumulateGrad performs "variable_grad +=
83   // new_grad" in-place, which keeps variable_grad's layout. We assume (hope)
84   // variable_grad was created obeying (1) or (2) at some point in the past.
85   //
86   // If we are setting up for double backward, AccumulateGrad updates the grad
87   // out-of-place via "variable_grad + new_grad."  TensorIterator operator+
88   // decides result's layout.  Typically TensorIterator matches strides of the
89   // first arg, so we once again assume (hope) variable_grad was originally
90   // created obeying (1) or (2).
91   //
92   // AccumulateGrad does not enforce the contract with 100% certainty. Examples:
93   //  - If a user manually permutes a param or its grad, then runs a fwd+bwd,
94   //    variable_grad += new_grad keeps variable_grad's layout without
95   //    rechecking the contract.
96   //  - If TensorIterator changes its corner cases about operator+'s result
97   //    (for example, giving more or less priority to channels_last inputs, see
98   //    https://github.com/pytorch/pytorch/pull/37968) the result may not obey.
99   //
100   // Fortunately, if a given grad doesn't satisfy (1) or (2), the penalty is
101   // degraded performance in Reducer.cpp or optimizer kernels, not death by
102   // assert or silently bad numerics.
103 
104   // variable: the variable whose grad we're accumulating.
105   // variable_grad: the current grad for the variable.
106   // new_grad: new grad we want to accumulate for the variable.
107   // num_expected_refs: the number of refs we expect to hold internally
108   //                    such that it is safe to avoid cloning the grad
109   //                    if use_count() of the grad is less than or equal
110   //                    to this value (in addition to post_hooks).
111   // update_grad: Function that is used to update grad for the variable.
112   //              The argument to the function is a Tensor which
113   //              is used to set a new value for the grad.
114   template <typename T>
accumulateGradAccumulateGrad115   static void accumulateGrad(
116       const Variable& variable,
117       at::Tensor& variable_grad,
118       const at::Tensor& new_grad,
119       size_t num_expected_refs,
120       const T& update_grad) {
121     if (!variable_grad.defined()) {
122       if (!GradMode::is_enabled() && !new_grad.is_sparse() &&
123           !new_grad.is_sparse_csr() &&
124           !(variable.is_sparse_csr() && new_grad.layout() == at::kStrided) &&
125           at::caching::adjusted_use_count(new_grad) <= num_expected_refs &&
126           (new_grad.is_mkldnn() ||
127            utils::obeys_layout_contract(new_grad, variable))) {
128         // we aren't setting up for double-backward
129         // not sparse
130         // no other user-visible tensor references new_grad
131         // new_grad obeys the "Gradient Layout Contract", there has a special
132         // case, For MKLDNN tensor, which is a opaque tensor, assuming it obeys
133         // layout_contract. Under these conditions, we can steal new_grad
134         // without a deep copy.
135         update_grad(new_grad.detach());
136       } else if (
137           !GradMode::is_enabled() && new_grad.is_sparse() &&
138           new_grad._indices().is_contiguous() &&
139           new_grad._values().is_contiguous() &&
140           // Use count for indices and values should always be <=1 since the
141           // SparseTensor should be the only one holding a reference to these.
142           new_grad._indices().use_count() <= 1 &&
143           new_grad._values().use_count() <= 1 &&
144           new_grad.use_count() <= num_expected_refs) {
145         // Can't detach sparse tensor (since metadata changes are not allowed
146         // after detach), so just create a new one for the grad which is a
147         // shallow copy. We need a shallow copy so that modifying the original
148         // grad tensor doesn't modify the grad we accumulate.
149         // We only skip clone if indices and values themselves are contiguous
150         // for backward compatibility reasons. Since without this optimization,
151         // earlier we would clone the entire SparseTensor which cloned indices
152         // and values.
153         // For details see https://github.com/pytorch/pytorch/issues/34375.
154 
155         // No scenario where we expect this to be true currently
156         TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
157             !at::caching::is_cached_tensor(new_grad._indices()) &&
158             !at::caching::is_cached_tensor(new_grad._values()) &&
159             !at::caching::is_cached_tensor(new_grad));
160 
161         update_grad(at::_sparse_coo_tensor_unsafe(
162             new_grad._indices(),
163             new_grad._values(),
164             new_grad.sizes(),
165             new_grad.options()));
166       } else {
167         if (new_grad.is_sparse() || new_grad.is_sparse_csr() ||
168             new_grad.is_nested()) {
169           update_grad(new_grad.clone());
170         } else {
171           if (new_grad.is_mkldnn()) {
172             update_grad(new_grad.clone());
173           } else {
174             // Deep copies new_grad according to the "Gradient Layout Contract."
175             update_grad(utils::clone_obey_contract(new_grad, variable));
176           }
177         }
178       }
179     } else if (!GradMode::is_enabled()) {
180       // This case is not strictly necessary, but it makes the first-order only
181       // case slightly more efficient.
182       if (variable_grad.is_sparse() && !new_grad.is_sparse()) {
183         // If `variable_grad` is sparse and `new_grad` is not sparse, their
184         // sum is not sparse, and we must change the TensorImpl type of
185         // `variable_grad` for it to store the result. However, changing the
186         // TensorImpl type of a tensor requires changing the tensor itself, and
187         // thus in this case we have to change the grad tensor.
188         auto result = new_grad + variable_grad;
189         CHECK_RESULT(result, variable);
190         update_grad(std::move(result));
191       } else if (!at::inplaceIsVmapCompatible(variable_grad, new_grad)) {
192         // Ideally we'd perform an in-place operation to avoid changing
193         // the grad tensor. However, if that's impossible because the grads
194         // are vmap-incompatible (See NOTE: [vmap-incompatible in-place
195         // operations]), then we just add them out-of-place.
196         auto result = variable_grad + new_grad;
197         CHECK_RESULT(result, variable);
198         update_grad(std::move(result));
199       } else {
200         // In this case we can avoid changing the grad tensor. There are three
201         // scenarios when we'll hit this case:
202         //
203         // 1. `variable_grad` is sparse, and `new_grad` is sparse.
204         // 2. `variable_grad` is dense, and `new_grad` is sparse.
205         // 3. `variable_grad` is dense, and `new_grad` is dense.
206         // 4. `variable_grad` is mkldnn, and `new_grad` is mkldnn.
207         //
208         // In all of these four cases, `variable_grad += new_grad` is a
209         // valid operation which adds `new_grad` to `variable_grad` in
210         // place. `variable_grad` is thus still referring to the same tensor
211         // after the operation.
212         // Also DistributedDataParallel(DDP) package relies on grad being
213         // mutated in place for saving peak memory usage. DDP will still
214         // work correctly if it is mutated out of place here, but DDP will
215         // maintain one extra copy of grad tensors in buffer and thus
216         // increase peak memory usage.
217         variable_grad += new_grad;
218         CHECK_RESULT(variable_grad, variable);
219         // ^ We could enforce the contract more aggressively here by writing:
220         // if (variable_grad.is_sparse() || new_grad.is_sparse()) {
221         //   variable_grad += new_grad;
222         // } else if (obeys_layout_contract(variable_grad, variable)) {
223         //   variable_grad += new_grad;
224         // } else {
225         //   result = at::empty_strided(variable.sizes(), variable.strides(),
226         //                              variable.options().memory_format(std::nullopt));
227         //   update_grad(at::native::add_out(result, variable_grad,
228         //   new_grad, 1.0);
229         // }
230         // However, that accumulation is sometimes in place and sometimes not,
231         // which may break user code.
232       }
233     } else {
234       at::Tensor result;
235       if (variable_grad.is_sparse() && !new_grad.is_sparse()) {
236         // CPU backend throws an error on sparse + dense, so prefer dense +
237         // sparse here.
238         result = new_grad + variable_grad;
239       } else {
240         // Assumes operator+ result typically matches strides of first arg,
241         // and hopes variable_grad was originally created obeying layout
242         // contract.
243         result = variable_grad + new_grad;
244       }
245       CHECK_RESULT(result, variable);
246       update_grad(std::move(result));
247       // ^ We could enforce the contract more aggressively here by saying
248       // if (obeys_layout_contract(new_grad, variable)) {
249       //   update_grad(new_grad + variable_grad);
250       // } else {
251       //   update_grad(variable_grad + new_grad);
252       // }
253       // such that the stashed grad is likely to have the right strides if
254       // either variable_grad or new_grad already has the right strides.
255       // We could enforce the contract with certainty by saying
256       // auto result = variable_grad + new_grad (or vice versa), checking
257       // result's layout, and copying to an obedient clone if necessary before
258       // update_grad. The copy would require another gmem pass.  We can't create
259       // empty result with the right layout then add_out into it with a single
260       // kernel, because GradMode is enabled in this branch, and add_out isn't
261       // differentiable. Maybe more trouble than it's worth.
262     }
263   }
264 
265   void compiled_args(CompiledNodeArgs& args) override;
266   variable_list apply_with_saved(
267       const variable_list& inputs,
268       SwapSavedVariables& saved) override;
269 
270   Variable variable;
271 };
272 
273 #undef CHECK_RESULT
274 
275 } // namespace torch::autograd
276