#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <c10/util/irange.h>
#include <torch/csrc/autograd/variable.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_has_same_storage_numel.h>
#include <ATen/ops/_new_zeros_with_same_feature_meta.h>
#include <ATen/ops/zeros.h>
#endif

namespace torch::autograd {

using at::Tensor;

// [Forward Grad View/inplace]
// It is important to us to allow view and inplace operations to work with dual
// Tensors. These operations should either compute the right gradient or raise
// a user-friendly error.

// The basic case where all Tensors are dual Tensors is as follows:
//     # Have:
//     #   foo is a dual Tensor that is not a view
//     #   bar is a dual Tensor of appropriate size (depending on cases) that is
//     not a view
//
//     # Case 1: no view
//     foo.copy_(bar)
//
//     # Case 2: with view, propagate from view to base
//     view = foo[0]
//     view.copy_(bar)
//
//     # Case 3: with view, propagate from base to view
//     view = foo[0]
//     foo.copy_(bar)
//
//     # In all three cases, the forward grad of foo must be properly updated.
//     # In the second and third cases, the forward grad of view must match
//     # the one of foo for the subset they have in common.
//
// All these cases can be handled by the following layout constraint on the
// forward grad:
//   - A Tensor and its forward grad (for all levels) must have the same
//     metadata (size, stride, conj/neg bit and storage offset). Storage offset
//     must be in this metadata because of as_strided. conj/neg bit must be
//     part of this metadata because of ops like `real`.
//   - View operations must create a forward grad that is a view of the base's
//     forward grad.
//   - Inplace operations must modify the input's forward grad inplace.
//
// This layout constraint is ensured in the `set_fw_grad` function below.
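//
// As an illustration, here is a minimal Python-level sketch of what this
// constraint means for a user of forward AD (it assumes the public
// torch.autograd.forward_ad API and is not part of this file):
//
//     import torch
//     import torch.autograd.forward_ad as fwAD
//
//     with fwAD.dual_level():
//         dual = fwAD.make_dual(torch.randn(2, 3), torch.randn(2, 3))
//         view = dual[0]
//         tangent = fwAD.unpack_dual(view).tangent
//         # The view's tangent shares the primal view's metadata
//         assert tangent.size() == view.size()
//         assert tangent.stride() == view.stride()
//         assert tangent.storage_offset() == view.storage_offset()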

// More complex cases arise when non-dual Tensors interact with dual Tensors.
// The two most important cases are:
//
//     # Have:
//     #   foo is a regular Tensor that is not a view
//     #   bar is a dual Tensor of appropriate size (depending on cases) that is
//     not a view
//
//     # Case 4: Changes on the view must propagate to its base
//     view = foo[0]
//     # view is still a regular Tensor here
//     view.copy_(bar)
//     # Now both view and foo are dual Tensors with appropriate forward grads
//
//     # Case 5: Changes on the base must propagate to all its views
//     view = foo[0]
//     # view is still a regular Tensor here
//     foo.copy_(bar)
//     # Now both view and foo are dual Tensors with appropriate forward grads
//
//     # NB there is a case 6 involving changes on a view propagating to other
//     # views, but it is fully described by the other two and is skipped in
//     # this discussion.
//
// Case 4 is handled by set_fw_grad by properly setting the forward grad of the
// base if needed. Case 5 is handled in fw_grad by reading the forward grad from
// the base if needed.
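//
// For concreteness, a minimal Python-level sketch of case 4 (again assuming the
// public torch.autograd.forward_ad API; the exact values follow from the
// zero-initialization described in set_fw_grad below):
//
//     import torch
//     import torch.autograd.forward_ad as fwAD
//
//     with fwAD.dual_level():
//         foo = torch.zeros(2, 3)                 # regular Tensor, not a view
//         bar = fwAD.make_dual(torch.zeros(3), torch.ones(3))
//         view = foo[0]                           # still a regular Tensor
//         view.copy_(bar)                         # case 4
//         foo_tangent = fwAD.unpack_dual(foo).tangent
//         # foo now has a forward grad: ones where the view was written,
//         # zeros everywhere else
//         assert torch.all(foo_tangent[0] == 1)
//         assert torch.all(foo_tangent[1:] == 0)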

namespace utils {

// Enforcing that the metadata of the primal and the tangent is the same has
// two goals:
// - When properties of the primal are checked in composite ops to determine
//   control flow, the code path decided upon is also reasonable for the tangent
// - Make sure that when the same as_strided is applied to both primal and
//   tangent, it behaves similarly.
//
// We do that by checking:
//   1) the storages have the same properties: size and conj/neg-ness
//   2) the same indices refer to the same elements in storage
//      (we are more strict than necessary here to satisfy goal 1)
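// For example, a rough Python-level analogue of what this check accepts
// (illustrative only, not the exact C++ logic):
//     x = torch.randn(2, 3)
//     x.clone()           # same sizes, strides and offset    -> same meta
//     x.transpose(0, 1)   # same storage, different strides   -> different meta
//     x[:, 1:]            # different sizes and offset        -> different meta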
bool has_same_meta(const Variable& base, const Variable& other) {
  if (!base.defined() || !other.defined()) {
    return false;
  }
  // 1) The storages have the same properties
  if (!at::_has_same_storage_numel(base, other)) {
    return false;
  }
  if (base.is_conj() != other.is_conj() || base.is_neg() != other.is_neg()) {
    return false;
  }

  // Technically dim and size belong as part of (2), so we shouldn't really care
  // if a zero-numel tensor violates these. But since these properties
  // (unlike offset and strides) often determine control flow in composite ops
  // it is useful to enforce that they match for primal and tangent here so
  // nothing funny happens later (See goal 1).
  if (base.dim() != other.dim()) {
    return false;
  }
  for (const auto i : c10::irange(base.dim())) {
    if (base.sym_sizes()[i] != other.sym_sizes()[i]) {
      return false;
    }
  }

  // The check below will always be vacuously true for 0-element tensors
  if (base.sym_numel() == 0 && other.sym_numel() == 0) {
    return true;
  }

  // 2) The same indices refer to the same elements in storage
  if (base.sym_storage_offset() != other.sym_storage_offset()) {
    return false;
  }

  for (const auto i : c10::irange(base.dim())) {
    if (base.sym_strides()[i] != other.sym_strides()[i] &&
        base.sym_sizes()[i] != 1 && base.sym_sizes()[i] != 0) {
      return false;
    }
  }
  return true;
}

} // namespace utils

// This function will ensure that fw_grad_ is properly a view of the base's
// forward grad for inplace ops on Tensors that do not have a forward grad
// originally.
void AutogradMeta::set_fw_grad(
    const at::TensorBase& new_grad_base,
    const at::TensorBase& self_base,
    uint64_t level,
    bool is_inplace_op) {
  TORCH_CHECK(
      !new_grad_base._fw_grad(level).defined(),
      "Setting a forward grad that "
      "itself has a forward gradient at the same level ",
      level,
      " is not supported.");
  TORCH_INTERNAL_ASSERT(
      (new_grad_base.is_floating_point() || new_grad_base.is_complex()) &&
          (self_base.is_floating_point() || self_base.is_complex()),
      "Expected both tensor and its forward grad to be floating point or complex");
  // Lazy initialization
  {
    std::lock_guard<std::mutex> lock(mutex_);
    if (!fw_grad_) {
      fw_grad_ = std::make_shared<ForwardGrad>();
    }
  }
  if (fw_grad_->contains(level)) {
    // Setting the forward grad again is only allowed if it is a no-op.
    // We do allow this case to simplify writing codegen for inplace ops.
    TORCH_INTERNAL_ASSERT(
        new_grad_base.defined(),
        "Cannot set a forward grad that is an undefined Tensor. Use "
        "_fw_primal(level) to get a new Tensor with this forward grad unset.");

    TORCH_INTERNAL_ASSERT(
        is_inplace_op,
        "Only inplace operations can re-set the forward grad of a Tensor that "
        "already has one.");

    TORCH_INTERNAL_ASSERT(
        fw_grad_->value(level).is_same(new_grad_base),
        "Cannot set a value of a forward grad if it "
        "already exists. Inplace operations should modify it inplace.");
  } else {
    // TODO(alband) remove this spurious version counter bump
    Tensor new_grad(new_grad_base);
    at::OptionalTensorRef self_ref(self_base);
    const Tensor& self = *self_ref;

    TORCH_CHECK(
        self.is_same_size(new_grad),
        "Trying to set a forward gradient that has a different size than that "
        "of the original Tensor; this is not supported. Tensor is of size ",
        self.sizes(),
        " while the given "
        "forward gradient is of size ",
        new_grad.sizes(),
        ".");

    if (is_inplace_op && is_view_) {
      auto this_view_meta = static_cast<DifferentiableViewMeta*>(this);

      // For inplace ops on a Tensor that does not already have a forward grad
      // and is a view, we propagate the tangent to the base and ensure that the
      // new_grad is a view of that base's tangent. This ensures that case 4
      // from [Forward Grad View/inplace] above works fine.
      // What happens in this long if statement is:
      //   - Check if the base already has a grad
      //   - If not, set a new fw_grad for it full of zeros
      //   - Take a view of the base's forward grad
      //   - Copy the given new_grad into this view
      //   - Use this view as the new new_grad
      if (this_view_meta->has_fw_view()) {
        auto& view_info = this_view_meta->get_forward_view();
        auto& base = view_info.base_;

        if (!base._fw_grad(level).defined()) {
          // Enforce same meta here to make sure that the view op below is
          // always valid
          Tensor new_base_fw_grad;
          if (utils::has_same_meta(new_grad, base) &&
              utils::has_same_meta(new_grad, self)) {
            // TODO extend this special case to when the underlying storage of
            // new_grad can be re-used.
            new_base_fw_grad = new_grad;
          } else {
            new_base_fw_grad =
                at::_new_zeros_with_same_feature_meta(new_grad, base);
            new_base_fw_grad._set_conj(base.is_conj());
            new_base_fw_grad._set_neg(base.is_neg());

            // Update new_grad to be a view of the base
            Tensor new_fw_grad_value;
            if (view_info.has_view_fn()) {
              new_fw_grad_value = view_info.view_fn()(new_base_fw_grad);
            } else {
              new_fw_grad_value = new_base_fw_grad.as_strided(
                  self.sizes(), self.strides(), self.storage_offset());
            }

            new_fw_grad_value.copy_(new_grad);
            new_grad = new_fw_grad_value;
          }

          base._set_fw_grad(new_base_fw_grad, level, /* is_inplace_op */ false);
        }
      }
    }

    // Enforce the basic layout constraint
    if (!utils::has_same_meta(new_grad, self)) {
      if (is_view_) {
        auto this_view_meta = static_cast<DifferentiableViewMeta*>(this);
        TORCH_INTERNAL_ASSERT(
            !this_view_meta->has_fw_view(),
            "Expected the tangent of the output of a forward differentiable "
            "view operation to have the same layout as the primal");
      }
      auto res = at::_new_zeros_with_same_feature_meta(new_grad, self);
      res._set_conj(self.is_conj());
      res._set_neg(self.is_neg());
      res.copy_(new_grad);
      new_grad = res;
    }

    fw_grad_->set_value(new_grad, level);
  }
}

const Variable& AutogradMeta::fw_grad(
    uint64_t level,
    const at::TensorBase& self) const {
  // TLS that disables forward AD.
  if (!c10::AutogradState::get_tls_state().get_fw_grad_mode()) {
    return ForwardGrad::undef_grad();
  }

  // Ensure that concurrent fw_grad() "reads" are thread safe
  std::lock_guard<std::mutex> lock(mutex_);

  const auto& direct_fw_grad =
      fw_grad_ ? fw_grad_->value(level) : ForwardGrad::undef_grad();

  if (!direct_fw_grad.defined() && is_view_) {
    // For views that don't have a forward grad, check if their base has one
    // that has been defined by an inplace operation.
    // This ensures that case 5 from [Forward Grad View/inplace] above works
    // fine.
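    // E.g. (Python-level sketch, assuming torch.autograd.forward_ad): after
    //     view = foo[0]; foo.copy_(bar)   # bar is dual, foo and view were not
    // unpack_dual(view).tangent is lazily materialized here as a view of
    // unpack_dual(foo).tangent.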
    auto const_view_meta =
        static_cast<const torch::autograd::DifferentiableViewMeta*>(this);
    // This is ok to do as we ONLY modify fw_grad_ and this field is properly
    // locked in all methods
    if (const_view_meta->has_fw_view()) {
      const auto& view_info = const_view_meta->get_forward_view();
      const auto& base = view_info.base_;

      const auto& base_val = base._fw_grad(level);
      if (base_val.defined()) {
        // Lazy initialization of fw_grad_
        const_view_meta->fw_grad_ = std::make_shared<ForwardGrad>();

        Variable new_val;
        if (view_info.has_view_fn()) {
          new_val = view_info.view_fn()(base_val);
        } else {
          new_val = base_val.as_strided(
              self.sizes(), self.strides(), self.storage_offset());
        }

        const_view_meta->fw_grad_->set_value(new_val, level);
        return const_view_meta->fw_grad_->value(level);
      }
    }
  }
  return direct_fw_grad;
}

} // namespace torch::autograd