#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <c10/util/irange.h>
#include <torch/csrc/autograd/variable.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_has_same_storage_numel.h>
#include <ATen/ops/_new_zeros_with_same_feature_meta.h>
#include <ATen/ops/zeros.h>
#endif

namespace torch::autograd {

using at::Tensor;

// [Forward Grad View/inplace]
// It is important to us to allow view and inplace to work with dual Tensors.
// These operations should either compute the right gradient or raise a
// user-friendly error.

// The basic case where all Tensors are dual Tensors is as follows:
// # Have:
// # foo is a dual Tensor that is not a view
// # bar is a dual Tensor of appropriate size (depending on cases) that is
// #   not a view
//
// # Case 1: no view
// foo.copy_(bar)
//
// # Case 2: with view, propagate from view to base
// view = foo[0]
// view.copy_(bar)
//
// # Case 3: with view, propagate from base to view
// view = foo[0]
// foo.copy_(bar)
//
// # In all three cases, the forward grad of foo must be properly updated.
// # In the second and third cases, the forward grad of view must match
// # the one of foo for the subset they have in common.
//
// All these cases can be handled by the following layout constraint on the
// forward grad:
// - A Tensor and its forward grad (for all levels) must have the same
//   metadata (size, stride, conj/neg bit and storage offset). Storage offset
//   must be part of this metadata because of as_strided. The conj/neg bit must
//   be part of this metadata because of ops like `real`.
// - View operations must create a forward grad that is a view of the base's
//   forward grad.
// - Inplace operations must modify the input's forward grad inplace.
//
// This layout constraint is ensured in the `set_fw_grad` function below.
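//
// As a rough illustration of cases 1-3 through the user-facing Python forward
// AD API (a sketch only; the shapes and variable names below are made up for
// exposition):
//   import torch
//   import torch.autograd.forward_ad as fwAD
//
//   with fwAD.dual_level():
//       foo = fwAD.make_dual(torch.rand(2, 3), torch.rand(2, 3))
//       bar = fwAD.make_dual(torch.rand(3), torch.rand(3))
//       view = foo[0]
//       view.copy_(bar)  # Case 2: inplace write into a view of a dual Tensor
//       # Because view's forward grad is itself a view of foo's forward grad
//       # (layout constraint above), fwAD.unpack_dual(foo).tangent[0] now
//       # matches fwAD.unpack_dual(bar).tangent.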

// More complex cases arise when non-dual Tensors interact with dual Tensors.
// The two most important cases are:
//
// # Have:
// # foo is a regular Tensor that is not a view
// # bar is a dual Tensor of appropriate size (depending on cases) that is
// #   not a view
//
// # Case 4: Changes on the view must propagate to its base
// view = foo[0]
// # view is still a regular Tensor here
// view.copy_(bar)
// # Now both view and foo are dual Tensors with appropriate forward grad
//
// # Case 5: Changes on the base must propagate to all its views
// view = foo[0]
// # view is still a regular Tensor here
// foo.copy_(bar)
// # Now both view and foo are dual Tensors with appropriate forward grad
//
// # NB: there is a case 6 involving changes on a view propagating to other
// # views, but it is fully described by the two cases above and is skipped in
// # this discussion.
//
// Case 4 is handled by set_fw_grad by properly setting the forward grad of the
// base if needed. Case 5 is handled in fw_grad by reading the forward grad from
// the base if needed.
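//
// For illustration, cases 4 and 5 through the Python forward AD API (a sketch
// only, reusing the imports from the sketch above; shapes and names are made
// up for exposition):
//   with fwAD.dual_level():
//       foo = torch.rand(2, 3)   # regular Tensor, not dual
//       view = foo[0]            # regular view
//       bar = fwAD.make_dual(torch.rand(3), torch.ones(3))
//       view.copy_(bar)          # Case 4: foo becomes dual too, with a
//                                # tangent of ones in row 0 and zeros elsewhere
//       # Case 5 is the symmetric situation: writing a dual Tensor into foo
//       # itself makes fwAD.unpack_dual(view).tangent read through to (a view
//       # of) foo's tangent.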

namespace utils {

// Enforcing that the metadata between the primal and tangent are the same has
// two goals:
// - When properties of the primal are checked in composite ops to determine
//   control flow, the code path decided upon is also reasonable for the tangent
// - Make sure that when the same as_strided is applied to both the primal and
//   the tangent, it behaves similarly.
//
// We do that by checking:
// 1) the storages have the same properties: size and conj/neg-ness
// 2) the same indices refer to the same elements in storage
//    (we are more strict than necessary here to satisfy goal 1)
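//
// For example (pseudocode only; has_same_meta is an internal C++ helper and is
// not exposed to Python):
//   a = torch.rand(3, 3)
//   has_same_meta(a, a.clone())   -> true
//   has_same_meta(a, a.t())       -> false (strides differ)
//   has_same_meta(a, a[0:1])      -> false (sizes differ)
//   c = torch.rand(3, 3, dtype=torch.cfloat)
//   has_same_meta(c, c.conj())    -> false (conj bits differ)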
bool has_same_meta(const Variable& base, const Variable& other) {
  if (!base.defined() || !other.defined()) {
    return false;
  }
  // 1) The storages have the same properties
  if (!at::_has_same_storage_numel(base, other)) {
    return false;
  }
  if (base.is_conj() != other.is_conj() || base.is_neg() != other.is_neg()) {
    return false;
  }

  // Technically, dim and size belong to check (2), so we shouldn't really care
  // if a zero-numel tensor violates them. But since these properties (unlike
  // offset and strides) often determine control flow in composite ops, it is
  // useful to enforce that they match for primal and tangent here so nothing
  // funny happens later (see goal 1).
  if (base.dim() != other.dim()) {
    return false;
  }
  for (const auto i : c10::irange(base.dim())) {
    if (base.sym_sizes()[i] != other.sym_sizes()[i]) {
      return false;
    }
  }

  // The check below will always be vacuously true for 0-element tensors
  if (base.sym_numel() == 0 && other.sym_numel() == 0) {
    return true;
  }

  // 2) The same indices refer to the same elements in storage
  if (base.sym_storage_offset() != other.sym_storage_offset()) {
    return false;
  }

  for (const auto i : c10::irange(base.dim())) {
    if (base.sym_strides()[i] != other.sym_strides()[i] &&
        base.sym_sizes()[i] != 1 && base.sym_sizes()[i] != 0) {
      return false;
    }
  }
  return true;
}

} // namespace utils

// This function will ensure that the fw_grad_ is properly a view of the base
// for inplace ops on Tensors that do not have a forward grad originally.
void AutogradMeta::set_fw_grad(
    const at::TensorBase& new_grad_base,
    const at::TensorBase& self_base,
    uint64_t level,
    bool is_inplace_op) {
  TORCH_CHECK(
      !new_grad_base._fw_grad(level).defined(),
      "Setting a forward grad that "
      "itself has a forward gradient at the same level ",
      level,
      " is not supported.");
  TORCH_INTERNAL_ASSERT(
      (new_grad_base.is_floating_point() || new_grad_base.is_complex()) &&
          (self_base.is_floating_point() || self_base.is_complex()),
      "Expected both tensor and its forward grad to be floating point or complex");
  // Lazy initialization
  {
    std::lock_guard<std::mutex> lock(mutex_);
    if (!fw_grad_) {
      fw_grad_ = std::make_shared<ForwardGrad>();
    }
  }
  if (fw_grad_->contains(level)) {
    // Setting the forward grad again is only allowed if it is a no-op.
    // We do allow this case to simplify writing codegen for inplace ops.
    TORCH_INTERNAL_ASSERT(
        new_grad_base.defined(),
        "Cannot set a forward grad that is an undefined Tensor. Use "
        "_fw_primal(level) to get a new Tensor with this forward grad unset.");

    TORCH_INTERNAL_ASSERT(
        is_inplace_op,
        "Only inplace operations can re-set the forward grad of a Tensor that "
        "already has one.");

    TORCH_INTERNAL_ASSERT(
        fw_grad_->value(level).is_same(new_grad_base),
        "Cannot set a value of a forward grad if it "
        "already exists. Inplace operations should modify it inplace.");
  } else {
    // TODO(alband) remove this spurious version counter bump
    Tensor new_grad(new_grad_base);
    at::OptionalTensorRef self_ref(self_base);
    const Tensor& self = *self_ref;

    TORCH_CHECK(
        self.is_same_size(new_grad),
        "Trying to set a forward gradient that has a different size than that "
        "of the original Tensor, this is not supported. Tensor is of size ",
        self.sizes(),
        " while the given "
        "forward gradient is of size ",
        new_grad.sizes(),
        ".");

    if (is_inplace_op && is_view_) {
      auto this_view_meta = static_cast<DifferentiableViewMeta*>(this);

      // For inplace ops on a Tensor that does not already have a forward grad
      // and is a view, we propagate the tangent to the base and ensure that
      // the new_grad is a view of that base's tangent. This ensures that case
      // 4 from [Forward Grad View/inplace] above works fine.
      // What happens in this long if statement is:
      // - Check if the base already has a grad
      // - If not, set a new fw_grad for it full of zeros
      // - Take a view of the base's forward grad
      // - Copy the given new_grad into this view
      // - Use this view as the new new_grad
      if (this_view_meta->has_fw_view()) {
        auto& view_info = this_view_meta->get_forward_view();
        auto& base = view_info.base_;

        if (!base._fw_grad(level).defined()) {
          // Enforce same meta here to make sure that the view op below is
          // always valid
          Tensor new_base_fw_grad;
          if (utils::has_same_meta(new_grad, base) &&
              utils::has_same_meta(new_grad, self)) {
            // TODO extend this special case to when the underlying storage of
            // new_grad can be re-used.
            new_base_fw_grad = new_grad;
          } else {
            new_base_fw_grad =
                at::_new_zeros_with_same_feature_meta(new_grad, base);
            new_base_fw_grad._set_conj(base.is_conj());
            new_base_fw_grad._set_neg(base.is_neg());

            // Update new_grad to be a view of the base
            Tensor new_fw_grad_value;
            if (view_info.has_view_fn()) {
              new_fw_grad_value = view_info.view_fn()(new_base_fw_grad);
            } else {
              new_fw_grad_value = new_base_fw_grad.as_strided(
                  self.sizes(), self.strides(), self.storage_offset());
            }

            new_fw_grad_value.copy_(new_grad);
            new_grad = new_fw_grad_value;
          }

          base._set_fw_grad(new_base_fw_grad, level, /* is_inplace_op */ false);
        }
      }
    }

    // Enforce the basic layout constraint
    if (!utils::has_same_meta(new_grad, self)) {
      if (is_view_) {
        auto this_view_meta = static_cast<DifferentiableViewMeta*>(this);
        TORCH_INTERNAL_ASSERT(
            !this_view_meta->has_fw_view(),
            "Expected the tangent of the output of a forward differentiable "
            "view operation to have the same layout as the primal");
      }
      auto res = at::_new_zeros_with_same_feature_meta(new_grad, self);
      res._set_conj(self.is_conj());
      res._set_neg(self.is_neg());
      res.copy_(new_grad);
      new_grad = res;
    }

    fw_grad_->set_value(new_grad, level);
  }
}

const Variable& AutogradMeta::fw_grad(
    uint64_t level,
    const at::TensorBase& self) const {
  // TLS that disables forward AD.
  if (!c10::AutogradState::get_tls_state().get_fw_grad_mode()) {
    return ForwardGrad::undef_grad();
  }

  // Ensure that concurrent fw_grad() "reads" are thread safe
  std::lock_guard<std::mutex> lock(mutex_);

  const auto& direct_fw_grad =
      fw_grad_ ? fw_grad_->value(level) : ForwardGrad::undef_grad();

  if (!direct_fw_grad.defined() && is_view_) {
    // For views that don't have a forward grad, check if their base has one
    // that has been defined by an inplace operation.
    // This ensures that case 5 from [Forward Grad View/inplace] above works
    // fine.
    auto const_view_meta =
        static_cast<const torch::autograd::DifferentiableViewMeta*>(this);
    // This is ok to do as we ONLY modify fw_grad_ and this field is properly
    // locked in all methods.
    if (const_view_meta->has_fw_view()) {
      const auto& view_info = const_view_meta->get_forward_view();
      const auto& base = view_info.base_;

      const auto& base_val = base._fw_grad(level);
      if (base_val.defined()) {
        // Lazy initialization of fw_grad_
        const_view_meta->fw_grad_ = std::make_shared<ForwardGrad>();

        Variable new_val;
        if (view_info.has_view_fn()) {
          new_val = view_info.view_fn()(base_val);
        } else {
          new_val = base_val.as_strided(
              self.sizes(), self.strides(), self.storage_offset());
        }

        const_view_meta->fw_grad_->set_value(new_val, level);
        return const_view_meta->fw_grad_->value(level);
      }
    }
  }
  return direct_fw_grad;
}

} // namespace torch::autograd