1 #include <torch/csrc/autograd/FunctionsManual.h>
2 #include <torch/csrc/autograd/functions/basic_ops.h>
3 #include <torch/csrc/autograd/functions/utils.h>
4 #include <torch/csrc/autograd/variable.h>
5
6 #include <ATen/ATen.h>
7 #include <ATen/AccumulateType.h>
8 #include <ATen/Dispatch.h>
9 #include <ATen/ExpandUtils.h>
10 #include <ATen/LegacyBatchedTensorImpl.h>
11 #include <ATen/ScalarOps.h>
12 #include <ATen/SparseCsrTensorUtils.h>
13 #include <ATen/TensorSubclassLikeUtils.h>
14 #include <ATen/Utils.h>
15 #include <ATen/WrapDimUtils.h>
16 #include <ATen/WrapDimUtilsMulti.h>
17 #include <ATen/core/Reduction.h>
18 #include <ATen/core/grad_mode.h>
19 #include <ATen/native/Activation.h>
20 #include <ATen/native/IndexingUtils.h>
21 #include <ATen/native/LinearAlgebraUtils.h>
22 #include <ATen/native/SparseTensorUtils.h>
23 #include <ATen/native/nested/NestedTensorUtils.h>
24 #include <c10/core/TensorOptions.h>
25 #include <c10/util/OptionalArrayRef.h>
26 #include <c10/util/SmallBuffer.h>
27 #include <c10/util/accumulate.h>
28 #include <c10/util/irange.h>
29
30 #include <algorithm>
31 #include <ciso646>
32 #include <functional>
33 #include <numeric>
34 #include <utility>
35
36 // Helper functions for autogenerated code
37 // These used to be inlined into the codegened Functions.cpp
38
39 namespace torch::autograd::generated::details {
40
41 using at::areAnyTensorSubclassLike;
42 using at::IntArrayRef;
43 using at::OptionalIntArrayRef;
44 using at::Scalar;
45 using at::Tensor;
46 using at::TensorList;
47
48 const char* kCudnnDoubleBackwardMsg =
49 "Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API. To run double backwards, please disable the CuDNN backend temporarily while running the forward pass of your RNN. For example: \nwith torch.backends.cudnn.flags(enabled=False):\n output = model(inputs)";
50
51 Tensor apply_loss_reduction(const Tensor& unreduced, int64_t reduction) {
52 if (reduction == at::Reduction::Mean) {
53 return unreduced.mean();
54 } else if (reduction == at::Reduction::Sum) {
55 return unreduced.sum();
56 }
57 return unreduced;
58 }
59
60 static bool isDefined(const std::optional<Tensor>& t) {
61 return t.has_value() && t->defined();
62 }
63
64 Tensor toNonOptTensor(const std::optional<Tensor>& t) {
65 return t.has_value() ? *t : Tensor();
66 }
67
68 Tensor toNonOptFwGrad(const std::optional<Tensor>& t) {
69 return (t.has_value() && t->defined()) ? t->_fw_grad(/*level */ 0) : Tensor();
70 }
71
72 Tensor toNonOptPrimal(const std::optional<Tensor>& t) {
73 if (t.has_value() && t->defined()) {
74 if (t->unsafeGetTensorImpl()->is_wrapped_number()) {
75 return *t;
76 }
77 return t->_fw_primal(/* level */ 0);
78 }
79 return Tensor();
80 }
81
82 void copy_range(variable_list& out, IndexRange range, const Tensor& t) {
83 TORCH_CHECK(range.second <= out.size());
84 TORCH_CHECK(
85 range.second - range.first == 1, "inconsistent range for Tensor output");
86 out[range.first] = t;
87 }
88
89 void copy_range(variable_list& out, IndexRange range, at::ArrayRef<Tensor> t) {
90 TORCH_CHECK(range.second <= out.size());
91 TORCH_CHECK(
92 range.second - range.first == t.size(),
93 "inconsistent range for TensorList output");
94 std::copy(
95 t.begin(), t.end(), out.begin() + static_cast<int64_t>(range.first));
96 }
97
98 Tensor copysign_tensor_self_backward(
99 const Tensor& grad,
100 const Tensor& self,
101 const Tensor& result) {
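// d copysign(self, other) / d self is result / self (i.e. +1 or -1);
// the subgradient at self == 0 is taken to be 0.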
102 auto ratio = result / self;
103 ratio.masked_fill_(self == 0, 0);
104 return grad * ratio;
105 }
106
107 template <typename T>
108 T not_implemented_base(const char* name, const char* reason) {
109 std::string msg =
110 c10::str("the derivative for '", name, "' is not implemented.");
111 if (reason[0] != '\0') {
112 msg = c10::str(msg, " ", reason);
113 }
114 TORCH_CHECK_NOT_IMPLEMENTED(false, msg);
115 }
116
117 Tensor not_implemented(const char* name, const char* reason) {
118 return not_implemented_base<Tensor>(name, reason);
119 }
120
121 std::vector<Tensor> not_implemented_list(const char* name, const char* reason) {
122 return not_implemented_base<std::vector<Tensor>>(name, reason);
123 }
124
125 Tensor maybe_multiply(const Tensor& t, const Scalar& s) {
126 bool is_one = false;
127 if (s.isFloatingPoint()) {
128 is_one = s.toSymFloat() == 1;
129 } else if (s.isIntegral(true)) {
130 is_one = s.toSymInt() == 1;
131 }
132
133 if (is_one) {
134 return t;
135 } else {
136 return t * s;
137 }
138 }
139
140 int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim) {
141 int64_t size = 1;
142 if (sizes.empty()) {
143 return 1;
144 }
145 for (auto d : dim) {
146 d = at::maybe_wrap_dim(d, static_cast<int64_t>(sizes.size()));
147 size *= sizes[d];
148 }
149 return size;
150 }
151
152 static c10::SymInt _safe_size(c10::SymIntArrayRef sizes, c10::IntArrayRef dim) {
153 c10::SymInt size = 1;
154 if (sizes.empty()) {
155 return 1;
156 }
157 for (auto d : dim) {
158 d = at::maybe_wrap_dim(d, static_cast<int64_t>(sizes.size()));
159 size *= sizes[d];
160 }
161 return size;
162 }
163
164 Tensor handle_r_to_c(ScalarType self_st, Tensor gradient_result) {
165 if (!at::isComplexType(self_st) && gradient_result.is_complex()) {
166 // R -> C
167 return at::real(gradient_result);
168 }
169 return gradient_result;
170 }
171
172 static Tensor handle_r_to_c(const Tensor& self, Tensor gradient_result) {
173 if (!self.is_complex() && gradient_result.is_complex()) {
174 // R -> C
175 return at::real(gradient_result);
176 }
177 return gradient_result;
178 }
179
180 Tensor restore_reduced_dims(
181 const Tensor& output,
182 IntArrayRef dims,
183 bool keepdim) {
184 if (keepdim) {
185 return output;
186 }
187 auto total_dims = output.dim() + dims.size();
188 std::vector<c10::SymInt> target_shape(total_dims, 0);
189 for (int64_t i : dims) {
190 if (i < 0) {
191 i = static_cast<int64_t>(total_dims) + i;
192 }
193 target_shape[i] = 1;
194 }
195 int64_t j = 0;
196 for (const c10::SymInt& i : output.sym_sizes()) {
197 while (target_shape[j] > 0)
198 j++;
199 target_shape[j++] = i;
200 }
201 return output.reshape_symint(target_shape);
202 }
203
204 Tensor scale_grad_by_count(
205 const Tensor& grad,
206 const Tensor& mask,
207 IntArrayRef dims) {
208 return (grad / mask.sum(dims, true)) * mask;
209 }
210
211 Tensor amaxamin_jvp(
212 const Tensor& x,
213 const Tensor& dx,
214 const Tensor& result,
215 IntArrayRef dim,
216 bool keepdim) {
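// Where several entries tie for the max/min, their tangents are averaged,
// hence the division by mask.sum below.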
217 auto mask = x == restore_reduced_dims(result, dim, keepdim);
218 return at::where(mask, dx, 0.).sum(dim, keepdim) / mask.sum(dim, keepdim);
219 }
220
221 std::tuple<Tensor, Tensor> _euclidean_dist_backward(
222 const Tensor& grad,
223 const Tensor& x1,
224 const Tensor& x2,
225 const Tensor& res) {
226 if (!grad.defined()) {
227 return std::tuple<Tensor, Tensor>(Tensor(), Tensor());
228 }
229 // handle case at 0 where we return a subgradient containing 0
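// Since res_ij = ||x1_i - x2_j||, d res_ij / d x1_i = (x1_i - x2_j) / res_ij;
// contracting that with grad gives the two matmul terms below (and analogously for x2).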
230 Tensor ratio = grad / res;
231 ratio.masked_fill_(res == 0, 0);
232 return std::tuple<Tensor, Tensor>{
233 x1 * ratio.sum(-1, true) - ratio.matmul(x2),
234 x2 * ratio.sum(-2, false).unsqueeze(-1) - ratio.mT().matmul(x1)};
235 }
236
237 Tensor norm_backward(
238 const Tensor& grad,
239 const Tensor& self,
240 const std::optional<Scalar>& p_,
241 const Tensor& norm) {
242 return norm_backward(grad, self, p_, norm, {}, true);
243 }
244
245 Tensor norm_backward(
246 Tensor grad,
247 const Tensor& self,
248 const std::optional<Scalar>& p_,
249 Tensor norm,
250 IntArrayRef dim,
251 bool keepdim) {
252 // NB: We mask fill the NaNs in the output to be zero but still do float
253 // division
254 // by zero, which ASAN complains about. One way to appease ASAN is to fill
255 // the problematic values with something arbitrary before the division,
256 // but we decide not to due to the perf hit. Instead we just silence ASAN
257 // where necessary
258 size_t ndim = self.dim();
259 double p = p_.value_or(2.0).toDouble();
260 Tensor self_scaled;
261 Tensor scale_v;
262
263 if (!keepdim && self.dim() != 0) {
264 grad = unsqueeze_multiple(grad, dim, ndim);
265 norm = unsqueeze_multiple(norm, dim, ndim);
266 }
267
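// The branches below specialize d||x||_p / dx = sgn(x) * |x|^(p-1) / ||x||_p^(p-1)
// for p = 0, 1, 2 and inf, masking out divisions by zero.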
268 if (p == 0.0) {
269 return {};
270 } else if (p == 1.0) {
271 return self.sgn() * grad;
272 } else if (p == 2.0) {
273 return grad * (self / norm).masked_fill_(norm == 0, 0);
274 } else if (std::isinf(p)) {
275 // Derivative of amax(abs(self), dim, keepdim) but respecting nans
276 // We create a mask of `argmax`: it's argmax if self.abs() == norm or it's
277 // NaN
278 auto self_abs = self.abs();
279 auto mask = self_abs.eq(norm).logical_or(self_abs.isnan());
280 return self.sgn() * ((grad / mask.sum(dim, true)) * mask);
281 } else if (p < 1.0) {
282 self_scaled =
283 self.sgn() * self.abs().pow_(p - 1).masked_fill_(self == 0, 0);
284 return self_scaled * grad * norm.pow(1 - p);
285 } else if (p < 2.0) {
286 self_scaled = self.sgn() * self.abs().pow_(p - 1);
287 scale_v = grad / norm.pow(p - 1);
288 scale_v.masked_fill_(norm == 0, 0);
289 return self_scaled * scale_v;
290 } else {
291 self_scaled = self * self.abs().pow_(p - 2);
292 scale_v = grad / norm.pow(p - 1);
293 scale_v.masked_fill_(norm == 0, 0);
294 return self_scaled * scale_v;
295 }
296 }
297
298 // See norm_backward above for a note on ignoring the sanitizer
299 Tensor norm_jvp(
300 const Tensor& self_p,
301 const Tensor& self_t,
302 const std::optional<Scalar>& p_,
303 Tensor norm,
304 IntArrayRef dim,
305 bool keepdim) {
306 // NB: currently norm_jvp is also reused for dist's jvp (which has two
307 // differentiable inputs)
308 // but self_t still cannot be a ZT because that would require both self_t
309 // and other_t to be ZT
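// The JVP follows d||x||_p = ||x||_p^(1-p) * sum_i |x_i|^(p-1) * Re(sgn(x_i) * conj(dx_i)),
// specialized below for p = 0, 1, 2 and inf.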
310 TORCH_INTERNAL_ASSERT(!self_t._is_zerotensor());
311 size_t ndim = self_p.dim(); // composite compliance?
312 double p = p_.value_or(2.0).toDouble();
313
314 if (p == 0.0) {
315 return at::zeros_like(norm);
316 } else if (p == 1.0) {
317 auto result = self_p.sgn();
318 result = areAnyTensorSubclassLike({self_t}) ? result.mul(self_t.conj())
319 : result.mul_(self_t.conj());
320 result = at::real(result);
321 return result.sum(dim, keepdim);
322 } else if (p == 2.0) {
323 auto result = self_p.mul(self_t.conj());
324 result = at::real(result);
325 result = result.sum(dim, keepdim);
326 return result.div_(norm).masked_fill_(norm == 0, 0);
327 } else if (std::isinf(p)) {
328 if (!keepdim && self_p.dim() != 0) {
329 norm = unsqueeze_multiple(norm, dim, ndim);
330 }
331 const auto self_isnan = self_p.isnan();
332 const auto norm_isnan = norm.isnan();
333 const auto& self_and_norm_isnan = areAnyTensorSubclassLike({norm})
334 ? self_isnan.logical_and(norm_isnan)
335 : self_isnan.logical_and_(norm_isnan);
336 const auto is_eq_max =
337 (self_p.abs() == norm).logical_or_(self_and_norm_isnan).type_as(norm);
338 auto nb_max = is_eq_max.count_nonzero(dim);
339 if (self_p.dim() != 0) {
340 nb_max = unsqueeze_multiple(nb_max, dim, ndim);
341 }
342 return (at::real(self_p.sgn() * self_t.conj()) * is_eq_max / nb_max)
343 .sum(dim, keepdim);
344 } else if (p < 1.0) {
345 auto sumpow_t = (self_p.abs().pow_(p - 1).masked_fill_(self_p == 0, 0) *
346 at::real(self_p.sgn() * self_t.conj()))
347 .sum(dim, keepdim);
348 return sumpow_t * norm.pow(1 - p);
349 } else if (p < 2.0) {
350 auto sumpow_t =
351 (self_p.abs().pow_(p - 1) * at::real(self_p.sgn() * self_t.conj()))
352 .sum(dim, keepdim);
353 auto out = sumpow_t / norm.pow(p - 1);
354 return out.masked_fill_(norm == 0, 0);
355 } else {
356 auto sumpow_t =
357 (self_p.abs().pow_(p - 2) * at::real(self_p * self_t.conj()))
358 .sum(dim, keepdim);
359 auto out = sumpow_t / norm.pow(p - 1);
360 return out.masked_fill_(norm == 0, 0);
361 }
362 }
363
364 Tensor norm_jvp(
365 const Tensor& self_p,
366 const Tensor& self_t,
367 const std::optional<Scalar>& p_,
368 Tensor norm) {
369 return norm_jvp(self_p, self_t, p_, std::move(norm), {}, true);
370 }
371
372 Tensor _nested_from_padded_backward(
373 const Tensor& grad,
374 const Tensor& input,
375 bool do_transform_0213) {
376 if (do_transform_0213) {
377 auto new_sizes = {
378 input.size(0), input.size(2), (input.size(1) * input.size(3))};
379 auto out = grad.to_padded_tensor(0, new_sizes);
380 auto expand_last_dim_size = {
381 input.size(0), input.size(2), input.size(1), input.size(3)};
382 return out.view(expand_last_dim_size).permute({0, 2, 1, 3});
383 }
384 return grad.to_padded_tensor(0, input.sizes());
385 }
386
387 std::tuple<Tensor, Tensor, Tensor> linear_double_backward(
388 const variable_list& grads,
389 const Tensor& self,
390 const Tensor& grad_output,
391 const Tensor& weight) {
392 if (!grad_output.defined()) {
393 return std::make_tuple(Tensor(), Tensor(), Tensor());
394 }
395
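// grads = {g_input, g_weight, g_bias} are the incoming gradients w.r.t. the outputs of
// linear's backward. With out = self @ weight^T + bias, the backward is
//   grad_input = grad_output @ weight, grad_weight = grad_output^T @ self,
//   grad_bias = grad_output.sum(0),
// and the products below are the second-order terms that follow from it.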
396 Tensor grad_self, grad_grad_output, grad_weight;
397
398 if (grads[1].defined()) {
399 grad_self =
400 (grad_output.dim() == 1 ? grad_output.unsqueeze(0) : grad_output)
401 .matmul(grads[1]);
402 if (grad_output.dim() == 1) {
403 grad_self = grad_self.squeeze(0);
404 }
405 }
406 if (grads[0].defined()) {
407 grad_weight =
408 (grad_output.dim() == 1 ? grad_output.unsqueeze(1) : grad_output.mT())
409 .matmul(grads[0].dim() == 1 ? grads[0].unsqueeze(0) : grads[0]);
410 }
411
412 if (grads[0].defined() || grads[1].defined() || grads[2].defined()) {
413 grad_grad_output = at::zeros_like(grad_output);
414 if (grad_output.dim() == 1) {
415 grad_grad_output = grad_grad_output.unsqueeze(0);
416 }
417 }
418
419 if (grads[0].defined()) {
420 grad_grad_output = grad_grad_output +
421 (grads[0].dim() == 1 ? grads[0].unsqueeze(0) : grads[0])
422 .matmul(weight.mT());
423 }
424 if (grads[1].defined()) {
425 grad_grad_output = grad_grad_output +
426 (self.dim() == 1 ? self.unsqueeze(0) : self).matmul(grads[1].mT());
427 }
428 if (grads[2].defined()) {
429 grad_grad_output = grad_grad_output + grads[2];
430 }
431 if (grad_grad_output.defined() && grad_output.dim() == 1) {
432 grad_grad_output = grad_grad_output.squeeze(0);
433 }
434
435 return std::make_tuple(
436 std::move(grad_self),
437 std::move(grad_grad_output),
438 std::move(grad_weight));
439 }
440
441 Tensor linalg_vector_norm_jvp(
442 const Tensor& self_p,
443 const Tensor& self_t,
444 const Scalar& scalar_ord,
445 Tensor norm,
446 const at::OptionalIntArrayRef& opt_dim,
447 bool keepdim) {
448 // No need to handle the dtype arg as it's handled via broadcasting in the
449 // function
450 auto dim = opt_dim.value_or(IntArrayRef({}));
451 return norm_jvp(self_p, self_t, scalar_ord, std::move(norm), dim, keepdim);
452 }
453
454 Tensor linalg_vector_norm_backward(
455 Tensor grad,
456 const Tensor& self,
457 const Scalar& scalar_ord,
458 Tensor norm,
459 const at::OptionalIntArrayRef& opt_dim,
460 bool keepdim) {
461 // No need to handle the dtype arg as it's handled via broadcasting in the
462 // function
463 auto dim = opt_dim.value_or(IntArrayRef({}));
464 return norm_backward(
465 std::move(grad), self, scalar_ord, std::move(norm), dim, keepdim);
466 }
467
468 Tensor pow_backward(Tensor grad, const Tensor& self, const Scalar& exponent) {
469 if (exponent.equal(0.0)) {
470 return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
471 } else {
472 auto grad_lambda = [&](auto exp) {
473 return grad * (exp * self.pow(exp - 1)).conj();
474 };
475 Tensor out = (exponent.isComplex())
476 ? grad_lambda(exponent.toComplexDouble())
477 : grad_lambda(exponent.toDouble());
478 return handle_r_to_c(self, std::move(out));
479 }
480 }
481
482 Tensor pow_backward_self(
483 const Tensor& grad,
484 const Tensor& self,
485 const Tensor& exponent) {
486 auto out = at::where(
487 exponent == 0.0,
488 at::zeros({}, grad.options()),
489 grad * (exponent * self.pow(exponent - 1)).conj());
490 return handle_r_to_c(self, std::move(out));
491 }
492
493 // Caveats:
494 // We define d(a^b)/db at a = 0 and b < 0 to be -inf. This is due to
495 // d(a^b)/db -> -inf for a fixed b as a -> +0
496 // Currently, tensorflow defines d(a^b)/db = nan for a = 0 and b < 0.
497 //
498 // We define d(a^b)/db = 0 for a = 0 and b = 0 by continuity as
499 // d(a^b)/db = 0 for a > 0 and b -> +0.
500 // Currently, tensorflow agrees with us.
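// In the generic case, d(a^b)/db = a^b * log(a) = result * log(self), which is what the
// non-zero branch below applies (conjugated, per the complex backward convention).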
501 Tensor pow_backward_exponent(
502 const Tensor& grad,
503 const Tensor& self,
504 const Tensor& exponent,
505 const Tensor& result) {
506 Tensor cond;
507 if (exponent.is_complex()) {
508 auto is_real_exp =
509 at::logical_and(at::imag(exponent) == 0, at::real(exponent) >= 0);
510 cond = at::logical_and(self == 0, is_real_exp);
511 } else {
512 cond = at::logical_and(self == 0, exponent >= 0);
513 }
514 auto promoted_dtype = at::result_type(self, exponent);
515 // `.to()` is no-op if dtype is same.
516 auto self_ = self.to(promoted_dtype);
517
518 auto out =
519 grad *
520 at::where(
521 cond, at::zeros({}, grad.options()), (result * self_.log()).conj());
522 return handle_r_to_c(exponent, std::move(out));
523 }
524
525 Tensor pow_backward_exponent(
526 const Tensor& grad,
527 const Scalar& base,
528 const Tensor& exponent,
529 const Tensor& result) {
530 auto grad_lambda = [](const Tensor& a, const Scalar& b) {
531 return (a * b.log()).conj();
532 };
533 auto base_ = exponent.is_complex() && !base.isComplex()
534 ? base.toComplexDouble()
535 : base;
536 if (base.equal(0.0)) {
537 auto cond = [](auto exp) {
538 if (exp.is_complex()) {
539 return at::logical_and(at::imag(exp) == 0, at::real(exp) >= 0);
540 } else {
541 return exp >= 0;
542 }
543 };
544 auto out = grad *
545 at::where(cond(exponent),
546 at::zeros({}, grad.options()),
547 grad_lambda(result, base_));
548 return handle_r_to_c(exponent, std::move(out));
549 } else {
550 auto out = grad * grad_lambda(result, base_);
551 return handle_r_to_c(exponent, std::move(out));
552 }
553 }
554
555 Tensor angle_backward(const Tensor& grad, const Tensor& self) {
556 if (self.is_complex()) {
557 return at::where(
558 self == 0.0,
559 at::zeros({}, self.options()),
560 grad * self / self.abs().pow(2) *
561 Scalar(c10::complex<double>{0.0, 1.0}));
562 } else {
563 return at::zeros_like(self, at::MemoryFormat::Preserve);
564 }
565 }
566
567 Tensor mvlgamma_backward(const Tensor& grad, const Tensor& self, int64_t p) {
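// d/dx mvlgamma(x, p) = sum_{i=1..p} digamma(x + (1 - i) / 2); `args` enumerates the
// shifted arguments and the digammas are summed over the last dimension.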
568 Tensor args = at::arange(
569 -static_cast<double>(p) / 2. + 0.5,
570 0.5,
571 0.5,
572 // use strided here regardless of self's layout; useful for e.g. NJT
573 self.options().layout(c10::kStrided));
574 args = args.add(self.unsqueeze(-1));
575 return grad * args.digamma_().sum(-1);
576 }
577
578 Tensor sgn_backward(const Tensor& x, const Tensor& gx, const Tensor& sgn) {
579 if (x.is_complex()) {
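// For complex x, sgn(x) = x / |x| has Wirtinger derivatives
// d sgn / dx = 1 / (2|x|) and d sgn / d conj(x) = -sgn(x)^2 / (2|x|);
// the expression below combines them per the complex autograd convention,
// with subgradient 0 at x == 0.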
580 auto abs = x.abs();
581 return ((gx - (sgn * sgn) * gx.conj()) / (2. * abs))
582 .masked_fill_(abs == 0., 0.);
583 } else {
584 return at::_efficientzerotensor(sgn.sizes(), sgn.options());
585 }
586 }
587
588 Tensor masked_fill_backward(const Tensor& grad, const Tensor& mask) {
589 // masked_select does not work well with functorch, as its shape is
590 // data-dependent
591 return areAnyTensorSubclassLike({grad, mask})
592 ? at::where(mask, grad, 0).sum()
593 : grad.masked_select(mask).sum();
594 }
595
596 template <typename T>
597 Tensor mul_tensor_backward(const Tensor& grad, T other, ScalarType self_st) {
598 auto out = grad * other.conj();
599 return handle_r_to_c(self_st, std::move(out));
600 }
601 template Tensor mul_tensor_backward(const Tensor&, Tensor, ScalarType);
602 template Tensor mul_tensor_backward(const Tensor&, Scalar, ScalarType);
603
604 template <typename T>
605 Tensor div_tensor_self_backward(
606 const Tensor& grad,
607 T other,
608 ScalarType self_st,
609 const std::optional<c10::string_view>& rounding_mode) {
610 if (rounding_mode.has_value()) {
611 return at::zeros_like(grad, grad.options().dtype(self_st));
612 }
613
614 auto result = grad / other.conj();
615 return handle_r_to_c(self_st, std::move(result));
616 }
617 template Tensor div_tensor_self_backward(
618 const Tensor&,
619 Tensor,
620 ScalarType,
621 const std::optional<c10::string_view>&);
622 template Tensor div_tensor_self_backward(
623 const Tensor&,
624 Scalar,
625 ScalarType,
626 const std::optional<c10::string_view>&);
627
628 template <typename T>
629 Tensor div_tensor_self_backward(
630 const Tensor& grad,
631 T other,
632 ScalarType self_st) {
633 return div_tensor_self_backward(
634 grad, std::move(other), self_st, std::nullopt);
635 }
636 template Tensor div_tensor_self_backward(const Tensor&, Tensor, ScalarType);
637 template Tensor div_tensor_self_backward(const Tensor&, Scalar, ScalarType);
638
639 Tensor div_tensor_other_backward(
640 const Tensor& grad,
641 const Tensor& self,
642 const Tensor& other,
643 const std::optional<c10::string_view>& rounding_mode) {
644 if (rounding_mode.has_value()) {
645 return at::zeros_like(grad, grad.options().dtype(other.scalar_type()));
646 }
647
648 auto result = -grad * ((self / other) / other).conj();
649 return handle_r_to_c(other, std::move(result));
650 }
651
652 Tensor div_tensor_other_backward(
653 const Tensor& grad,
654 const Tensor& self,
655 const Tensor& other) {
656 return div_tensor_other_backward(grad, self, other, std::nullopt);
657 }
658
659 Tensor permute_backwards(const Tensor& grad, IntArrayRef fwd_dims) {
660 // invert the permutation
661 auto ndims = fwd_dims.size();
662 std::vector<int64_t> dims(ndims);
663 for (const auto i : c10::irange(ndims)) {
664 dims[at::maybe_wrap_dim(fwd_dims[i], static_cast<int64_t>(ndims))] =
665 static_cast<int64_t>(i);
666 }
667 return grad.permute(dims);
668 }
669
670 Tensor rad2deg_backward(const Tensor& grad) {
671 constexpr double M_180_PI =
672 57.295779513082320876798154814105170332405472466564;
673 return at::mul(grad, Scalar(M_180_PI));
674 }
675
676 Tensor deg2rad_backward(const Tensor& grad) {
677 constexpr double M_PI_180 =
678 0.017453292519943295769236907684886127134428718885417;
679 return at::mul(grad, Scalar(M_PI_180));
680 }
681
682 Tensor unsqueeze_multiple(
683 const Tensor& t,
684 OptionalIntArrayRef opt_dim,
685 size_t n_dims) {
686 if (opt_dim.has_value()) {
687 IntArrayRef dim = opt_dim.value();
688 auto dim_size = dim.size();
689 // Optimisation for two common cases
690 if (dim_size == 0) {
691 return t;
692 } else if (dim_size == 1) {
693 return t.unsqueeze(dim[0]);
694 }
695 }
696 auto dims_to_unsqueeze = at::dim_list_to_bitset(opt_dim, n_dims);
697 Tensor res = t;
698 for (const auto i : c10::irange(n_dims)) {
699 if (dims_to_unsqueeze[i]) {
700 res = res.unsqueeze(static_cast<int64_t>(i));
701 }
702 }
703 return res;
704 }
705
706 Tensor sum_backward(
707 const Tensor& grad,
708 c10::SymIntArrayRef sizes,
709 OptionalIntArrayRef opt_dims,
710 bool keepdim) {
711 if (!keepdim && !sizes.empty()) {
712 if (opt_dims.has_value() && !opt_dims.value().empty()) {
713 return unsqueeze_multiple(grad, opt_dims, sizes.size())
714 .expand_symint(sizes);
715 }
716 }
717 return grad.expand_symint(sizes);
718 }
719
720 Tensor sum_backward(
721 const Tensor& grad,
722 c10::SymIntArrayRef sizes,
723 c10::IntArrayRef dims,
724 bool keepdim) {
725 if (!keepdim && !sizes.empty() && !dims.empty()) {
726 // we are only using the `keepdim=true` path for SymInts for now
727 TORCH_CHECK_NOT_IMPLEMENTED(
728 false,
729 "Only the keepdim=true path is implemented to support symints in autograd");
730 } else {
731 return grad.expand_symint(sizes);
732 }
733 }
734
735 Tensor nansum_backward(
736 const Tensor& grad,
737 const Tensor& self,
738 at::OptionalIntArrayRef dims,
739 bool keepdim) {
740 return sum_backward(grad, self.sym_sizes(), dims, keepdim) *
741 self.isnan().logical_not();
742 }
743
744 Tensor mean_backward(
745 const Tensor& grad,
746 c10::SymIntArrayRef shape,
747 OptionalIntArrayRef opt_dim,
748 c10::SymInt numel,
749 bool keepdim) {
750 bool is_all_reduce = !opt_dim.has_value() || opt_dim.value().empty();
751 auto n =
752 is_all_reduce ? std::move(numel) : _safe_size(shape, opt_dim.value());
753 return sum_backward(grad, shape, opt_dim, keepdim) / std::move(n);
754 }
755
756 std::vector<c10::SymInt> reverse_list_symint(const c10::SymIntArrayRef list) {
757 auto result = std::vector<c10::SymInt>();
758 result.reserve(list.size());
759 for (auto iter = list.rbegin(); iter != list.rend(); iter++) {
760 result.push_back(*iter);
761 }
762 return result;
763 }
764
765 std::vector<int64_t> reverse_list(const IntArrayRef list) {
766 auto result = std::vector<int64_t>();
767 result.reserve(list.size());
768 for (auto iter = list.rbegin(); iter != list.rend(); iter++) {
769 result.push_back(*iter);
770 }
771 return result;
772 }
773
774 Tensor prod_safe_zeros_backward(
775 const Tensor& grad,
776 const Tensor& inp,
777 int64_t dim) {
778 if (inp.sym_numel() == 0) {
779 // When input has a zero sized dimension (empty tensor),
780 // we don't need to actually compute the grads.
781 // So we just expand `grad` to the shape of `input`.
782 return grad.expand_as(inp);
783 }
784
785 if (inp.sym_size(dim) == 1) {
786 return grad;
787 }
788
789 auto ones_size = inp.sym_sizes().vec();
790 ones_size[dim] = 1;
791 Tensor ones = at::ones_symint(ones_size, grad.options());
792 Tensor exclusive_normal_nocp =
793 at::cat({ones, inp.narrow_symint(dim, 0, inp.sym_size(dim) - 1)}, dim);
794 Tensor exclusive_normal = exclusive_normal_nocp.cumprod(dim);
795
796 Tensor narrow_reverse =
797 inp.narrow_symint(dim, 1, inp.sym_size(dim) - 1).flip(dim);
798 Tensor exclusive_reverse_nocp =
799 at::cat({std::move(ones), std::move(narrow_reverse)}, dim);
800 Tensor exclusive_reverse = exclusive_reverse_nocp.cumprod(dim).flip(dim);
801
802 return grad * (exclusive_normal * exclusive_reverse).conj();
803 }
804
805 // note that the gradient for prod is equivalent to:
806 // cumprod(exclusive, normal) * cumprod(exclusive, reverse), e.g.:
807 // input: [ a, b, c]
808 // cumprod(exclusive, normal): [1 , a, a * b]
809 // cumprod(exclusive, reverse): [b * c, c, 1]
810 // product: [b * c, a * c, a * b]
811 // and this is safe for inputs containing 0s.
812 Tensor prod_backward(
813 const Tensor& grad,
814 const Tensor& input,
815 const Tensor& result) {
816 if (input.dim() == 0) {
817 return grad;
818 }
819 if (input.is_meta() || isTensorSubclassLike(input)) {
820 // For Composite Compliance, always take the safer (and slower) path
821 return prod_safe_zeros_backward(grad, input.contiguous().view(-1), 0)
822 .view_as(input);
823 }
824 Tensor zero_idx = (input == 0).nonzero();
825 if (zero_idx.sym_numel() == 0) {
826 return grad * (result / input).conj();
827 } else if (!at::GradMode::is_enabled() && zero_idx.sym_size(0) > 1) {
828 return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
829 } else {
830 return prod_safe_zeros_backward(grad, input.contiguous().view(-1), 0)
831 .view_as(input);
832 }
833 }
834
835 Tensor prod_backward(
836 Tensor grad,
837 const Tensor& input,
838 Tensor result,
839 int64_t dim,
840 bool keepdim) {
841 if (input.dim() == 0) {
842 return grad;
843 }
844 dim = at::maybe_wrap_dim(dim, static_cast<int64_t>(input.sym_sizes().size()));
845 if (!keepdim) {
846 // `prod` reduces the dimension at `dim`,
847 // so, unsqueeze `grad` and `result` at dim.
848 grad = grad.unsqueeze(dim);
849 result = result.unsqueeze(dim);
850 }
851 if (input.is_meta() || isTensorSubclassLike(input)) {
852 // For Composite Compliance, always take the safer (and slower) path
853 return prod_safe_zeros_backward(grad, input, dim);
854 }
855
856 Tensor zero_mask = (input == 0);
857 Tensor slice_zero_count = zero_mask.sum(dim, true);
858 int64_t total_zeros = slice_zero_count.sum().item<int64_t>();
859 if (total_zeros == 0) {
860 return grad * (result / input).conj();
861 } else {
862 return prod_safe_zeros_backward(grad, input, dim);
863 }
864 }
865
866 template <typename solve_f>
867 static Tensor generic_solve_jvp(
868 solve_f solve,
869 const Tensor& X,
870 const Tensor& A,
871 const Tensor& dA,
872 const Tensor& dB) {
873 auto is_vector_case = at::native::linalg_solve_is_vector_rhs(dA, dB);
874 auto dA_contrib =
875 is_vector_case ? dA.matmul(X.unsqueeze(-1)).squeeze(-1) : dA.matmul(X);
876 // In general,
877 // dX = solve(A, dB - dA_contrib), but this behavior is different for
878 // lu_solve. Refer to lu_solve_jvp for more details on this.
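// This follows from differentiating A X = B:  dA X + A dX = dB  =>  dX = A^{-1}(dB - dA X).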
879 return solve(A, dB, dA_contrib);
880 }
881
882 Tensor cumsum_backward(const Tensor& grad, int64_t dim) {
883 // Trivial case
884 if (grad.sym_numel() <= 1 || grad.sym_size(dim) == 1) {
885 return grad;
886 }
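// y_i = sum_{j <= i} x_j, so dL/dx_k = sum_{i >= k} grad_i: a reversed cumulative sum
// of grad along dim.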
887 return grad.flip(dim).cumsum(dim).flip(dim);
888 }
889
890 Tensor logsumexp_backward(
891 Tensor grad,
892 const Tensor& self,
893 Tensor result,
894 IntArrayRef dim,
895 bool keepdim) {
896 if (!keepdim && self.dim() != 0) {
897 grad = unsqueeze_multiple(grad, dim, self.sym_sizes().size());
898 result = unsqueeze_multiple(result, dim, self.sym_sizes().size());
899 }
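// d logsumexp(x) / dx_i = exp(x_i - logsumexp(x)), i.e. softmax(x)_i.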
900 return grad * (self - result).exp().conj();
901 }
902
903 Tensor logcumsumexp_backward(
904 Tensor grad,
905 const Tensor& self,
906 const Tensor& result,
907 int64_t dim) {
908 if (grad.dim() == 0 || grad.sym_numel() == 0) {
909 return grad;
910 }
911
912 // Reference: https://github.com/tensorflow/tensorflow/blob/
913 // 2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863
914
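// Since result_i = logcumsumexp(x)_i, d result_i / dx_k = exp(x_k - result_i) for k <= i.
// The backward is thus exp(x) * rev_cumsum(grad * exp(-result)); splitting grad into its
// positive and negative parts lets this be evaluated in log space for real inputs.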
915 auto scalar_min = AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(
916 at::ScalarType::BFloat16,
917 at::typeMetaToScalarType(grad.dtype()),
918 "logcumsumexp_backward",
919 []() { return c10::Scalar(std::numeric_limits<scalar_t>::lowest()); });
920
921 auto reverse_logcumsumexp = [dim](auto x) {
922 return at::flip(at::logcumsumexp(at::flip(x, {dim}), dim), {dim});
923 };
924
925 if (!at::is_complex(grad)) {
926 auto grad_min = at::scalar_tensor(scalar_min, grad.options());
927 auto log_abs_grad = grad.abs().log();
928 auto log_grad_positive = at::where(grad > 0, log_abs_grad, grad_min);
929 auto log_grad_negative = at::where(grad < 0, log_abs_grad, grad_min);
930
931 auto output_pos =
932 (reverse_logcumsumexp(log_grad_positive - result) + self).exp();
933 auto output_neg =
934 (reverse_logcumsumexp(log_grad_negative - result) + self).exp();
935
936 return output_pos - output_neg;
937 } else {
938 // no trick separating the positive and negative required
939 auto log_grad = grad.conj().log();
940 auto output = (reverse_logcumsumexp(log_grad - result) + self).exp();
941 return output.conj();
942 }
943 }
944
945 Tensor logcumsumexp_jvp(
946 const Tensor& self_p,
947 const Tensor& self_t,
948 int64_t dim) {
949 // Mostly taken from logsumexp_jvp
950
951 // NB: for simplicity, we recompute some values that can be reused from
952 // forward
953 auto self_p_exp = [&self_p, dim]() {
954 if (!at::is_complex(self_p)) {
955 return (self_p - std::get<0>(at::max(self_p, dim, true)))
956 .exp(); // Use the exp-normalize trick
957 } else {
958 // at::max doesn't support complex128
959 return self_p.exp();
960 }
961 }();
962
963 auto cumsumexp_p = self_p_exp.cumsum(dim);
964
965 TORCH_INTERNAL_ASSERT(!self_t._is_zerotensor())
966
967 constexpr double eps = 1e-13;
968
969 if (areAnyTensorSubclassLike({self_p, self_t})) {
970 auto result = (self_p_exp * self_t).cumsum(dim);
971 result /= cumsumexp_p.add_(eps);
972 return result;
973 } else {
974 self_p_exp *= self_t;
975 auto cumsumexp_t = self_p_exp.cumsum(dim);
976 return cumsumexp_t /= cumsumexp_p.add_(eps);
977 }
978 }
979
980 Tensor unbind_backward(const variable_list& grads, int64_t dim) {
981 c10::SymIntArrayRef sizes;
982 at::TensorOptions o;
983 for (const auto& v : grads) {
984 if (v.defined()) {
985 sizes = v.sym_sizes();
986 o = static_cast<Tensor>(v).options();
987 break;
988 }
989 }
990 auto grads_tensors = fmap(grads, [&](const Variable& v) {
991 return (
992 v.defined() ? static_cast<Tensor>(v)
993 : at::zeros({}, o).expand_symint(sizes));
994 });
995 return at::stack(grads_tensors, dim);
996 }
997
998 Tensor unbind_backward_nested(
999 const variable_list& grads,
1000 const Tensor& nt_sizes,
1001 int64_t dim,
1002 const at::TensorOptions& options) {
1003 std::vector<Tensor> grads_tensors;
1004 for (int64_t i : c10::irange(static_cast<int64_t>(grads.size()))) {
1005 if (grads[i].defined()) {
1006 grads_tensors.push_back(static_cast<Tensor>(grads[i]));
1007 } else {
1008 const auto component_size = nt_sizes[i].contiguous();
1009 const c10::IntArrayRef grad_size(
1010 component_size.data_ptr<int64_t>(), component_size.size(0));
1011 grads_tensors.push_back(at::zeros(grad_size, options));
1012 }
1013 }
1014
1015 return at::_nested_tensor_from_tensor_list(grads_tensors);
1016 }
1017
1018 Tensor unbind_backward_nested_jagged(
1019 const variable_list& grads,
1020 const Tensor& self,
1021 int64_t dim) {
1022 TORCH_INTERNAL_ASSERT(
1023 dim == 0, "unbind_backward_nested_jagged() only supports dim=0")
1024 auto grad_nt = at::zeros_like(self);
1025 auto unbound_grads = grad_nt.unbind();
1026 for (int64_t i : c10::irange(static_cast<int64_t>(grads.size()))) {
1027 if (grads[i].defined()) {
1028 unbound_grads[i].copy_(static_cast<Tensor>(grads[i]));
1029 }
1030 }
1031
1032 return grad_nt;
1033 }
1034
1035 Tensor unsqueeze_to(const Tensor& self, c10::SymIntArrayRef sym_sizes) {
1036 auto result = self;
1037
1038 auto nDims = sym_sizes.size();
1039 for (const auto dim : c10::irange(nDims)) {
1040 if (sym_sizes[dim] == 1) {
1041 result = result.unsqueeze(static_cast<int64_t>(dim));
1042 }
1043 }
1044 return result;
1045 }
1046
1047 Tensor unsqueeze_to(
1048 const Tensor& self,
1049 IntArrayRef dims,
1050 c10::SymIntArrayRef sym_sizes) {
1051 const auto ndim = sym_sizes.size();
1052 auto mask = at::dim_list_to_bitset(dims, ndim);
1053
1054 Tensor result = self;
1055 for (const auto d : c10::irange(ndim)) {
1056 if (mask.test(d) && sym_sizes[d] == 1) {
1057 result = result.unsqueeze(static_cast<int64_t>(d));
1058 }
1059 }
1060 return result;
1061 }
1062
1063 Tensor unsqueeze_to(
1064 const Tensor& self,
1065 int64_t dim,
1066 c10::SymIntArrayRef sym_sizes) {
1067 return unsqueeze_to(self, IntArrayRef{dim}, sym_sizes);
1068 }
1069
1070 std::vector<Tensor> cat_tensors_backward(
1071 const Tensor& grad,
1072 const std::vector<std::vector<c10::SymInt>>& sizes,
1073 const std::vector<ScalarType>& dtypes,
1074 int64_t dim) {
1075 std::vector<Tensor> grad_inputs(sizes.size());
1076 if (!grad.defined()) {
1077 return grad_inputs;
1078 }
1079 dim = at::legacy_cat_wrap_dim_symint(dim, sizes);
1080 c10::SymInt accumulate = 0;
1081
1082 Tensor grad_;
1083 bool grad_is_complex = grad.is_complex();
1084 if (grad_is_complex) {
1085 grad_ = at::real(grad);
1086 }
1087 for (const auto i : c10::irange(sizes.size())) {
1088 Tensor grad_val;
1089 if (!at::isComplexType(dtypes[i]) && grad_is_complex) {
1090 // R -> C
1091 grad_val = grad_;
1092 } else {
1093 grad_val = grad;
1094 }
1095 auto& shape = sizes[i];
1096 // If input was empty tensor, gradInput should be empty tensor.
1097 if (shape.size() == 1) {
1098 if (TORCH_GUARD_SIZE_OBLIVIOUS(shape[0].sym_eq(0))) {
1099 grad_inputs[i] = at::zeros({0}, grad_val.options());
1100 continue;
1101 }
1102 }
1103 const auto& size = shape[dim];
1104 accumulate += size;
1105 grad_inputs[i] = grad_val.narrow_symint(dim, accumulate - size, size);
1106 }
1107 return grad_inputs;
1108 }
1109
1110 std::vector<Tensor> stack_tensors_backward(
1111 const Tensor& grad,
1112 int64_t dim,
1113 const std::vector<ScalarType>& dtypes) {
1114 std::vector<Tensor> grad_inputs(dtypes.size());
1115 if (!grad.defined()) {
1116 return grad_inputs;
1117 }
1118 bool grad_is_complex = grad.is_complex();
1119 for (const auto i : c10::irange(dtypes.size())) {
1120 auto gr = grad.select(dim, static_cast<int64_t>(i));
1121 if (grad_is_complex && !at::isComplexType(dtypes[i])) {
1122 gr = at::real(gr);
1123 }
1124 grad_inputs[i] = gr;
1125 }
1126 return grad_inputs;
1127 }
1128
1129 std::vector<Tensor> block_diag_backward(
1130 const Tensor& grad,
1131 const std::vector<std::vector<int64_t>>& sizes,
1132 const std::vector<ScalarType>& dtypes) {
1133 std::vector<Tensor> grad_inputs(sizes.size());
1134 if (!grad.defined()) {
1135 return grad_inputs;
1136 }
1137 Tensor real_view_of_grad;
1138 bool grad_is_complex = grad.is_complex();
1139 if (grad_is_complex) {
1140 real_view_of_grad = at::real(grad);
1141 }
1142
1143 int64_t cur_dim0 = 0;
1144 int64_t cur_dim1 = 0;
1145
1146 for (const auto i : c10::irange(sizes.size())) {
1147 // R -> C
1148 Tensor grad_val = (!at::isComplexType(dtypes[i]) && grad_is_complex)
1149 ? real_view_of_grad
1150 : grad;
1151
1152 auto& shape = sizes[i];
1153 // If input was empty tensor, gradInput should be empty tensor.
1154 if (shape.size() == 1 && shape[0] == 0) {
1155 grad_inputs[i] = at::zeros({0}, grad_val.options());
1156 continue;
1157 }
1158 // 0d case
1159 int64_t dim0 = 1;
1160 int64_t dim1 = 1;
1161 // 2d case
1162 if (shape.size() == 2) {
1163 dim0 = shape[0];
1164 dim1 = shape[1];
1165 // 1d case
1166 } else if (shape.size() == 1) {
1167 dim1 = shape[0];
1168 }
1169 auto slice = grad_val.slice(0, cur_dim0, cur_dim0 + dim0)
1170 .slice(1, cur_dim1, cur_dim1 + dim1);
1171 if (shape.size() == 1) {
1172 slice = slice.squeeze(-1);
1173 } else if (shape.empty()) {
1174 slice = slice.squeeze(-1).squeeze(-1);
1175 }
1176 grad_inputs[i] = slice;
1177 cur_dim0 += dim0;
1178 cur_dim1 += dim1;
1179 }
1180 return grad_inputs;
1181 }
1182
1183 Tensor clamp_backward(
1184 const Tensor& grad,
1185 const Tensor& self,
1186 const std::optional<Scalar>& min,
1187 const std::optional<Scalar>& max) {
1188 // clamp: gradients not defined on min and max, so we return the subgradient 1
1189 // for these cases.
1190 if (max && min) {
1191 auto zero = at::scalar_tensor(0., grad.options());
1192 return where((self >= *min).logical_and_(self <= *max), grad, zero);
1193 } else if (min) {
1194 auto zero = at::scalar_tensor(0., grad.options());
1195 return where(self >= *min, grad, zero);
1196 } else if (max) {
1197 auto zero = at::scalar_tensor(0., grad.options());
1198 return where(self <= *max, grad, zero);
1199 } else {
1200 return grad;
1201 }
1202 }
1203
1204 Tensor clamp_backward(
1205 const Tensor& grad,
1206 const Tensor& self,
1207 const Tensor& min,
1208 const Tensor& max) {
1209 // clamp: gradients not defined on min and max, so we return the subgradient 1
1210 // for these cases.
1211 if (max.defined() && min.defined()) {
1212 auto zero = at::scalar_tensor(0., grad.options());
1213 const auto self_ge_min = self >= min;
1214 const auto self_le_max = self <= max;
1215 const auto& pred = areAnyTensorSubclassLike({self, min, max})
1216 ? self_ge_min.logical_and(self_le_max)
1217 : self_ge_min.logical_and_(self_le_max);
1218 return where(pred, grad, zero);
1219 } else if (min.defined()) {
1220 auto zero = at::scalar_tensor(0., grad.options());
1221 return where(self >= min, grad, zero);
1222 } else if (max.defined()) {
1223 auto zero = at::scalar_tensor(0., grad.options());
1224 return where(self <= max, grad, zero);
1225 } else {
1226 return grad;
1227 }
1228 }
1229
1230 std::tuple<at::Tensor, at::Tensor> clamp_backward_min_max(
1231 const Tensor& grad,
1232 const Tensor& self,
1233 const Tensor& min,
1234 const Tensor& max,
1235 const std::array<bool, 2>& grad_input_mask) {
1236 // If min > max, min has no gradient
1237 std::tuple<at::Tensor, at::Tensor> ret;
1238 if (!grad.defined()) {
1239 return ret;
1240 }
1241
1242 auto zero = at::scalar_tensor(0., grad.options());
1243 if (max.defined() && min.defined()) {
1244 if (grad_input_mask[0]) {
1245 const auto self_lt_min = self < min;
1246 const auto min_lt_max = min < max;
1247 const auto& pred = areAnyTensorSubclassLike({self, min, max})
1248 ? self_lt_min.logical_and(min_lt_max)
1249 : self_lt_min.logical_and_(min_lt_max);
1250 std::get<0>(ret) = where(pred, grad, zero);
1251 }
1252 if (grad_input_mask[1]) {
1253 const auto self_gt_max = self > max;
1254 const auto max_lt_min = max < min;
1255 const auto& pred = areAnyTensorSubclassLike({self, min, max})
1256 ? self_gt_max.logical_or(max_lt_min)
1257 : self_gt_max.logical_or_(max_lt_min);
1258 std::get<1>(ret) = where(pred, grad, zero);
1259 }
1260 } else if (min.defined() && grad_input_mask[0]) {
1261 std::get<0>(ret) = where(self < min, grad, zero);
1262 } else if (max.defined() && grad_input_mask[1]) {
1263 std::get<1>(ret) = where(self > max, grad, zero);
1264 }
1265 return ret;
1266 }
1267
1268 at::Tensor clamp_jvp(
1269 const Tensor& self_p,
1270 const Tensor& self_t,
1271 const Tensor& min_p,
1272 const Tensor& min_t,
1273 const Tensor& max_p,
1274 const Tensor& max_t) {
1275 if (min_p.defined() && max_p.defined()) {
1276 return where(
1277 min_p > max_p,
1278 max_t,
1279 where(self_p < min_p, min_t, where(self_p > max_p, max_t, self_t)));
1280 } else if (min_p.defined()) {
1281 return where(self_p > min_p, self_t, min_t);
1282 } else if (max_p.defined()) {
1283 return where(self_p < max_p, self_t, max_t);
1284 } else {
1285 return self_t;
1286 }
1287 }
1288
1289 Tensor convolution_jvp(
1290 const Tensor& input_p,
1291 const Tensor& input_t,
1292 const Tensor& weight_p,
1293 const Tensor& weight_t,
1294 const Tensor& bias_p,
1295 const Tensor& bias_t,
1296 at::SymIntArrayRef stride,
1297 at::SymIntArrayRef padding,
1298 at::SymIntArrayRef dilation,
1299 bool transposed,
1300 at::SymIntArrayRef output_padding,
1301 const c10::SymInt& groups) {
1302 auto bias_t_opt =
1303 bias_t.defined() ? std::optional<at::Tensor>(bias_t) : std::nullopt;
1304 return (
1305 at::convolution_symint(
1306 input_t,
1307 weight_p,
1308 std::nullopt,
1309 stride,
1310 padding,
1311 dilation,
1312 transposed,
1313 output_padding,
1314 groups) +
1315 at::convolution_symint(
1316 input_p,
1317 weight_t,
1318 bias_t_opt,
1319 stride,
1320 padding,
1321 dilation,
1322 transposed,
1323 output_padding,
1324 groups));
1325 }
1326
1327 Tensor _convolution_jvp(
1328 const Tensor& input_p,
1329 const Tensor& input_t,
1330 const Tensor& weight_p,
1331 const Tensor& weight_t,
1332 const Tensor& bias_p,
1333 const Tensor& bias_t,
1334 at::SymIntArrayRef stride,
1335 at::SymIntArrayRef padding,
1336 at::SymIntArrayRef dilation,
1337 bool transposed,
1338 at::SymIntArrayRef output_padding,
1339 const c10::SymInt& groups,
1340 bool benchmark,
1341 bool deterministic,
1342 bool cudnn_enabled,
1343 bool allow_tf32) {
1344 auto bias_t_opt =
1345 bias_t.defined() ? std::optional<at::Tensor>(bias_t) : std::nullopt;
1346 return (
1347 at::_convolution_symint(
1348 input_t,
1349 weight_p,
1350 std::nullopt,
1351 stride,
1352 padding,
1353 dilation,
1354 transposed,
1355 output_padding,
1356 groups,
1357 benchmark,
1358 deterministic,
1359 cudnn_enabled,
1360 allow_tf32) +
1361 at::_convolution_symint(
1362 input_p,
1363 weight_t,
1364 bias_t_opt,
1365 stride,
1366 padding,
1367 dilation,
1368 transposed,
1369 output_padding,
1370 groups,
1371 benchmark,
1372 deterministic,
1373 cudnn_enabled,
1374 allow_tf32));
1375 }
1376
1377 Tensor convolution_backward_jvp_grad_bias(
1378 const Tensor& grad_out_t,
1379 const Tensor& grad_bias) {
1380 if (!grad_bias.defined()) {
1381 return Tensor();
1382 }
1383 int64_t dim = grad_out_t.dim() - 2;
1384 if (dim == 1) {
1385 // Cannot pass initializer list due to overload ambiguity
1386 auto dimlist = std::vector<int64_t>{0, 2};
1387 return grad_out_t.sum(dimlist);
1388 } else if (dim == 2) {
1389 return grad_out_t.sum({0, 2, 3});
1390 } else if (dim == 3) {
1391 return grad_out_t.sum({0, 2, 3, 4});
1392 } else {
1393 TORCH_INTERNAL_ASSERT(
1394 false,
1395 "convolution_backward_jvp_grad_bias expected dim of grad_out_t to be 3, 4, or 5, but got: ",
1396 grad_out_t.dim());
1397 }
1398 }
1399
1400 // This function is used by load_derivatives.py to replace tensor.strides()
1401 // calls that appear in derivative formulas. If the tensor has requires_grad
1402 // set, this function returns its strides or an empty array if the tensor
1403 // is sparse. If requires_grad is not set, an empty array is returned since
1404 // there will be no backward pass. There is one special case: if the input is an
1405 // MKLDNN tensor and has requires_grad set, we just return an empty array; the
1406 // reason is that an MKLDNN tensor is an opaque tensor which has no stride info.
1407 //
1408 // This function only supports the case where `input` is the tensor whose
1409 // single derivative is being calculated.
1410 //
1411 // This function does not support `self` derivatives for inplace functions.
1412 //
1413 // Args:
1414 // input Tensor to call .strides() on
1415 // input_name Name of `input` tensor, from derivative formula
1416 at::SymIntArrayRef strides_or_error(
1417 const Tensor& input,
1418 c10::string_view const& input_name) {
1419 // TODO: Ideally, this function would never be called if requires_grad is
1420 // not set. Once codegen is updated to avoid the call, we can remove this
1421 // check.
1422 if (input.requires_grad()) {
1423 if (input.is_mkldnn())
1424 return {};
1425 if (input.is_sparse() || at::sparse_csr::is_sparse_compressed(input))
1426 return {};
1427 return input.sym_strides();
1428 } else {
1429 return {};
1430 }
1431 }
1432
1433 Tensor mm_mat1_backward(
1434 const Tensor& grad,
1435 const Tensor& mat2,
1436 at::SymIntArrayRef mat1_sizes,
1437 at::SymIntArrayRef mat1_strides,
1438 c10::Layout mat1_layout,
1439 const Scalar& alpha) {
1440 if (grad.layout() == c10::kStrided && mat2.layout() == c10::kStrided &&
1441 mat1_layout == c10::kStrided) {
1442 // if input was column-major, return grad in column-major order for efficiency
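// using (A @ B)^T = B^T @ A^T: grad @ mat2^H == (mat2.conj() @ grad^T)^T,
// which yields a column-major result directly.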
1443 if (mat1_strides[0] == 1 && mat1_strides[1] == mat1_sizes[0]) {
1444 return maybe_multiply(mat2.conj().mm(grad.t()).t(), alpha.conj());
1445 }
1446 }
1447
1448 // General fallback, should work for any layout
1449 return maybe_multiply(grad.mm(mat2.t().conj()), alpha.conj());
1450 }
1451
1452 Tensor mm_mat2_backward(
1453 const Tensor& grad,
1454 const Tensor& mat1,
1455 at::SymIntArrayRef mat2_sizes,
1456 at::SymIntArrayRef mat2_strides,
1457 c10::Layout mat2_layout,
1458 const Scalar& alpha) {
1459 if (grad.layout() == c10::kStrided && mat1.layout() == c10::kStrided &&
1460 mat2_layout == c10::kStrided) {
1461 // if input was column-major, return grad in column-major order for efficiency
1462 if (mat2_strides[0] == 1 && mat2_strides[1] == mat2_sizes[0]) {
1463 return maybe_multiply(grad.t().mm(mat1.conj()).t(), alpha.conj());
1464 }
1465 }
1466
1467 // General fallback, should work for any layout
1468 return maybe_multiply(mat1.t().conj().mm(grad), alpha.conj());
1469 }
1470
1471 Tensor mm_mat1_sparse_backward(
1472 const Tensor& grad,
1473 const Tensor& mat1,
1474 const Tensor& mat2,
1475 const Scalar& alpha) {
1476 if (grad.layout() == c10::kStrided && mat2.layout() == c10::kStrided &&
1477 mat1.is_sparse()) {
1478 auto sparse = mat1.coalesce();
1479 Tensor grad_sparse = maybe_multiply(grad.mm(mat2.conj().t()), alpha);
1480 return grad_sparse.sparse_mask(sparse);
1481 } else if (
1482 grad.layout() == c10::kStrided && mat2.layout() == c10::kStrided &&
1483 mat1.is_sparse_csr()) {
1484 // zero must have mat1's sparsity pattern:
1485 auto zero = mat1.clone();
1486 zero.values().zero_();
1487 return at::sparse_sampled_addmm(zero, grad, mat2.mH(), 1.0, alpha);
1488 } else if (
1489 grad.layout() == c10::kStrided && mat2.layout() == c10::kStrided &&
1490 mat1.layout() == c10::kStrided) {
1491 return maybe_multiply(grad.mm(mat2.mH()), alpha);
1492 }
1493 TORCH_CHECK(
1494 false,
1495 "sparse_addmm_sparse_backward: unsupported combination of layouts",
1496 ", grad: ",
1497 grad.layout(),
1498 ", mat1: ",
1499 mat1.layout(),
1500 ", mat2: ",
1501 mat2.layout());
1502 }
1503
1504 static Tensor sparse_mask_like_grad(
1505 const Tensor& x,
1506 const Tensor& gx,
1507 bool accumulate_matches) {
1508 if (x.is_coalesced() && gx.is_coalesced()) {
1509 if (x._nnz() >= gx._nnz()) {
1510 // search into x is faster
1511 return gx._sparse_mask_projection(x, accumulate_matches);
1512 } else {
1513 // search into gx is faster
1514 return gx.sparse_mask(x);
1515 }
1516 } else if (x.is_coalesced()) {
1517 return gx.sparse_mask(x);
1518 } else if (gx.is_coalesced()) {
1519 return gx._sparse_mask_projection(x, accumulate_matches);
1520 } else {
1521 if (x._nnz() >= gx._nnz()) {
1522 // gx.coalesce() is likely faster
1523 return gx.coalesce()._sparse_mask_projection(x, accumulate_matches);
1524 } else {
1525 // x.coalesce() is likely faster
1526 return gx.sparse_mask(x.coalesce());
1527 }
1528 }
1529 }
1530
1531 std::tuple<Tensor, Tensor, Tensor> sparse_sampled_addmm_backward(
1532 const Tensor& grad,
1533 const Tensor& self,
1534 const std::optional<Tensor>& mat1,
1535 const std::optional<Tensor>& mat2,
1536 const Scalar& alpha,
1537 const Scalar& beta,
1538 const std::array<bool, 3>& grad_input_mask) {
1539 if (!grad.defined()) {
1540 return std::make_tuple(Tensor{}, Tensor{}, Tensor{});
1541 }
1542
1543 const auto grad_projected = grad.sparse_mask(self);
1544 const auto self_requires_grad = grad_input_mask[0];
1545 const auto mat1_requires_grad = grad_input_mask[1];
1546 const auto mat2_requires_grad = grad_input_mask[2];
1547 return std::make_tuple(
1548 self_requires_grad ? maybe_multiply(grad, beta.conj()) : Tensor{},
1549 mat1_requires_grad
1550 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
1551 ? maybe_multiply(grad_projected.mm(mat2->mH()), alpha.conj())
1552 : Tensor{},
1553 mat2_requires_grad
1554 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
1555 ? maybe_multiply(mat1->mH().mm(grad_projected), alpha.conj())
1556 : Tensor{});
1557 }
1558
1559 Tensor sparse_mask_backward(
1560 const Tensor& grad,
1561 const Tensor& mask,
1562 const c10::Layout self_layout) {
1563 // NOTE: sparse_mask accumulates matches, so the backward step has to
1564 // accumulate as well.
1565 const auto self_grad =
1566 sparse_mask_like_grad(mask, grad, /*accumulate_matches=*/true);
1567 return self_layout == at::kStrided ? self_grad.to_dense() : self_grad;
1568 }
1569
1570 Tensor sparse_sparse_matmul_backward(
1571 const Tensor& grad,
1572 const Tensor& a,
1573 const Tensor& b,
1574 int64_t grad_order) {
1575 /*
1576 To implement the backward algorithm for sparse matrix-matrix matmul (SPMM) we
1577 can start from the following definition for dense tensors:
1578
1579 c = a @ b
1580 then
1581 a_grad = c_grad @ b^H
1582 b_grad = a^H @ c_grad
1583
1584 So for sparse matrices we can use the following definition:
1585
1586 if grad_order == 0:
1587 a_grad = sparse_matrix_mask(c_grad @ b^H, mask=a)
1588 else:
1589 b_grad = sparse_matrix_mask(a^H @ c_grad, mask=b)
1590 */
1591 TORCH_CHECK(
1592 grad_order == 0 || grad_order == 1,
1593       "sparse_sparse_matmul_backward: grad_order must be 0 or 1");
1594
1595 // NOTE: _sparse_sparse_matmul returns a coalesced gradient,
1596   // hence there is no need to accumulate matches.
1597 if (grad_order == 0) {
1598 auto a_grad = _sparse_sparse_matmul(grad, b.conj().t());
1599 return sparse_mask_like_grad(a, a_grad, /*accumulate_matches=*/false);
1600 }
1601 auto b_grad = _sparse_sparse_matmul(a.conj().t(), grad);
1602 return sparse_mask_like_grad(b, b_grad, /*accumulate_matches=*/false);
1603 }
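// A rough Python sketch of the masking behaviour described above (COO inputs
// assumed; to_dense() is only used to obtain a scalar loss for backward):
//
//   a = torch.eye(3).to_sparse().requires_grad_()
//   b = torch.eye(3).to_sparse().requires_grad_()
//   c = torch.sparse.mm(a, b)
//   c.to_dense().sum().backward()
//   # a.grad is sparse and shares a's sparsity pattern: entries of
//   # c_grad @ b^H outside that pattern are masked away.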
1604
1605 Tensor renorm_backward(
1606 const Tensor& grad,
1607 const Tensor& self,
1608 const Scalar& p,
1609 int64_t dim,
1610 const Scalar& maxnorm) {
1611 auto n = self.dim();
1612 dim = c10::maybe_wrap_dim(dim, n);
1613 auto reduce_dims = at::DimVector(n);
1614 std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
1615 reduce_dims.erase(reduce_dims.begin() + dim);
1616
1617 auto acc_type =
1618 at::toAccumulateType(self.scalar_type(), self.device().type());
1619 auto norm = at::linalg_vector_norm(
1620 self, p, reduce_dims, /*keepdim=*/true, /*dtype=*/acc_type);
1621
1622 const auto real_acc_type = c10::toRealValueType(acc_type);
1623 auto grad_output = (self.conj() * grad);
1624 // vector_norm output is real, so grad_output must also be real
1625 if (real_acc_type != acc_type) {
1626 grad_output = at::real(grad_output);
1627 }
1628 grad_output =
1629 grad_output.sum(reduce_dims, /*keepdim=*/true, /*dtype=*/real_acc_type);
1630 auto nb = norm_backward(
1631 std::move(grad_output), self, p, norm, reduce_dims, /*keepdim=*/true);
1632
1633 auto invnorm = (norm + 1e-7).reciprocal();
1634 auto grad_norm = maxnorm * invnorm * (grad - invnorm * nb);
1635 return at::where(norm > maxnorm, grad_norm.to(grad.scalar_type()), grad);
1636 }
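// A quick sanity-check sketch (Python, double precision) of the gradient
// assembled above, via gradcheck on the public renorm op:
//
//   x = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
//   torch.autograd.gradcheck(
//       lambda t: t.renorm(p=2, dim=0, maxnorm=1.0), (x,))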
1637
1638 Tensor renorm_jvp(
1639 const Tensor& self_p,
1640 const Tensor& self_t,
1641 const Scalar& p,
1642 int64_t dim,
1643 const Scalar& maxnorm) {
1644 auto self_sizes = self_p.sizes();
1645 dim = at::maybe_wrap_dim(dim, static_cast<int64_t>(self_sizes.size()));
1646
1647 at::DimVector reduce_dims(self_sizes.size());
1648 std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
1649 reduce_dims.erase(reduce_dims.begin() + dim);
1650
1651 // For cuda half, calculate norm in float precision then cast
1652 // normalization factor to half
1653 auto dtype = self_p.scalar_type();
1654 auto acc_type = at::toAccumulateType(dtype, /*is_cuda=*/true);
1655 Tensor norm = [&self_p, &p, &reduce_dims, acc_type, dtype]() {
1656 if (acc_type != dtype) {
1657 return at::linalg_vector_norm(
1658 self_p,
1659 p.toDouble(),
1660 reduce_dims,
1661 /*keepdim=*/true,
1662 /*dtype=*/acc_type);
1663 } else {
1664 return at::linalg_vector_norm(
1665 self_p,
1666 p.toDouble(),
1667 reduce_dims,
1668 /*keepdim=*/true);
1669 }
1670 }();
1671
1672 auto double_maxnorm = maxnorm.toDouble();
1673 auto invnorm = (norm + 1e-7).reciprocal();
1674 auto factor = invnorm * double_maxnorm;
1675
1676 return where(
1677 norm > double_maxnorm,
1678 factor *
1679 (self_t -
1680 self_p * invnorm *
1681 norm_jvp(
1682 self_p, self_t, p, norm, reduce_dims, /*keepdim=*/true)),
1683 self_t);
1684 }
1685
1686 Tensor repeat_backward(
1687 Tensor grad,
1688 c10::SymIntArrayRef repeats,
1689 c10::SymIntArrayRef input_shape) {
1690 auto find_iter = std::find(repeats.cbegin(), repeats.cend(), 0);
1691 if (find_iter != repeats.cend()) {
1692 return at::zeros_symint(input_shape, grad.options());
1693 }
1694 const auto input_dims = input_shape.size();
1695 auto num_unsqueezed = grad.dim() - input_dims;
1696 for (const auto i : c10::irange(num_unsqueezed)) {
1697 (void)i; // Suppress unused variable warning
1698 grad = grad.sum(0, false);
1699 }
1700
1701 at::SymDimVector grad_size;
1702 at::DimVector sum_dims;
1703 for (const auto dim : c10::irange(input_dims)) {
1704 const auto& repeat = repeats[dim + num_unsqueezed];
1705 // Reshape gradient (repeat > 1)
1706 // Index: [..., dim , ...] [..., dim , dim+1 , ...]
1707 // Shape: From [..., dimsize, ...] to [..., repeat, dimsize/repeat, ...]
1708 // The gradient tensor at 'dim' is reshaped to 'repeat' times of input
1709 // tensor. Then, sum up gradients over repeated tensors along 'dim', and
1710 // reduce shape from 'repeat * dimsize/repeat' to 'dimsize/repeat'
1711 // ('input_dimsize'). Example:
1712 // Size(3, 2) Size(6, 2)
1713 // [[v1_0, v1_1],
1714 // [v1_2, v1_3],
1715 // [[v0, v1], repeat(2, 1) [v1_4, v1_5],
1716 // [v2, v3], -------------> [v2_0, v2_1],
1717 // [v4, v5]] [v2_2, v2_3],
1718 // [v2_4, v2_5]]
1719 //
1720 // input grad (3, 2) reshape (2, 3, 2) output grad (6, 2)
1721 // [[[g1_0, g1_1], [[g1_0, g1_1],
1722 // [g1_2, g1_3], [g1_2, g1_3],
1723 // [[g1_0+g2_0, g1_1+g2_1], [g1_4, g1_5]], [g1_4, g1_5],
1724 // [g1_2+g2_2, g1_3+g2_3], [g2_0, g2_1], [[g2_0, g2_1],
1725 // [g1_4+g2_4, g1_5+g2_5]] [g2_2, g2_3], [g2_2, g2_3],
1726 // [g2_4, g2_5]] [g2_4, g2_5]]]
1727 //
1728 // If gradient tensor is reshaped to [..., dimsize/repeat, repeat, ...] and
1729 // then sum over 'dim+1'. The gradient for input is not correctly aligned
1730 // with input. Example:
1731 // input grad (3, 2) reshape (3, 2, 2) output grad (6, 2)
1732 // [[[g1_0, g1_1], [[g1_0, g1_1],
1733 // [g1_2, g1_3]], [g1_2, g1_3],
1734 // [[g1_0+g1_2, g1_1+g1_3], [[g1_4, g1_5], [g1_4, g1_5],
1735 // [g1_4+g2_0, g1_5+g2_1], [g2_0, g2_1]], [g2_0, g2_1],
1736 // [g2_2+g2_4, g2_3+g2_5]] [[g2_2, g2_3], [g2_2, g2_3],
1737 // [g2_4, g2_5]]] [g2_4, g2_5]]
1738 if (repeat != 1) {
1739 grad_size.push_back(repeat);
1740 sum_dims.push_back(static_cast<int64_t>(grad_size.size() - 1));
1741 }
1742 // Don't need to reshape gradient into (repeat, input_shape[dim]) (repeat ==
1743 // 1)
1744 grad_size.push_back(input_shape[dim]);
1745 }
1746 // One-time Reshape & Sum
1747 // Reshape gradient to grad_size:
1748 // 1. If repeat equals to 1, append input size at that dimension,
1749 // 2. If repeat is larger than 1, append both repeat and input size at that
1750 // dimension.
1751 // Sum over all "repeat" dimensions from sum_dims:
1752 // Example:
1753 // Input Size (2, 3, 4, 5)
1754 // repeat [4, 1, 9, 3]
1755 // output/grad Size (8, 3, 36, 15)
1756 // grad_size [4, 2, 3, 9, 4, 3, 5]
1757 // sum_dims [0, 3, 5]
1758
1759 // When repeat 1 time over all original dimensions, the empty sum_dims will
1760 // reduce the whole grad tensor into a scalar rather than keeping original
1761 // dimensions.
1762 if (!sum_dims.empty()) {
1763 grad = grad.reshape_symint(grad_size);
1764 grad = grad.sum(sum_dims);
1765 }
1766 return grad;
1767 }
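// A small Python sketch of the reshape-and-sum scheme above, for an input of
// size (3, 2) repeated with repeats (2, 1):
//
//   g = torch.ones(6, 2)              # incoming grad for x.repeat(2, 1)
//   # grad_size = [2, 3, 2], sum_dims = [0]
//   gx = g.reshape(2, 3, 2).sum(0)    # == torch.full((3, 2), 2.)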
1768
1769 // p1m == 1 - p
1770 Tensor _fused_dropout_backward(
1771 const Tensor& grad,
1772 const Tensor& mask,
1773 double p1m) {
1774 if (grad.requires_grad()) {
1775 // Use autograd-friendly backward if double backward is required
1776 return grad * (mask.type_as(grad) * (1. / p1m));
1777 } else {
1778 return at::_masked_scale(grad, mask, 1. / p1m);
1779 }
1780 }
1781
1782 // scale == (1 / (1 - prob))
1783 Tensor infinitely_differentiable_native_dropout_backward(
1784 const Tensor& grad,
1785 const Tensor& mask,
1786 double scale) {
1787 return grad * (mask.type_as(grad) * scale);
1788 }
1789
1790 Tensor native_dropout_double_backward(
1791 const Tensor& ggI,
1792 const Tensor& grad,
1793 const Tensor& mask,
1794 double scale) {
1795 return ggI.type_as(grad) * (mask.type_as(grad) * scale);
1796 }
1797
1798 Tensor evenly_distribute_backward(
1799 const Tensor& grad,
1800 const Tensor& input,
1801 const Tensor& value) {
1802 bool any_tensor_subclass_like =
1803 areAnyTensorSubclassLike({grad, input, value});
1804 if (any_tensor_subclass_like || input.is_cuda()) {
1805 const auto input_isnan = input.isnan();
1806 const auto value_isnan = value.isnan();
1807 const auto& input_and_value_isnan = any_tensor_subclass_like
1808 ? input_isnan.logical_and(value_isnan)
1809 : input_isnan.logical_and_(value_isnan);
1810 const auto mask = (input == value).logical_or_(input_and_value_isnan);
1811 return mask * (grad / mask.sum());
1812 } else {
1813 auto mask = value.isnan().item<bool>() ? input.isnan() : input == value;
1814 return grad.new_zeros(input.sizes(), input.options())
1815 .masked_fill_(mask, grad / mask.sum());
1816 }
1817 }
1818
1819 Tensor evenly_read_jvp(
1820 const Tensor& fw_grad,
1821 const Tensor& input,
1822 const Tensor& value) {
1823 auto mask = (input == value);
1824 auto count = mask.sum();
1825 auto grad_output = fw_grad / count;
1826 return at::sum(mask * grad_output);
1827 }
1828
1829 Tensor var_backward(
1830 Tensor grad,
1831 const Tensor& self,
1832 at::OptionalIntArrayRef dim_opt,
1833 const std::optional<at::Scalar>& correction_opt,
1834 bool keepdim) {
1835 const auto correction = correction_opt.value_or(1).toSymFloat();
1836 if (self.dim() == 0 || !dim_opt.has_value()) {
1837 const auto dof = c10::SymFloat(self.sym_numel()) - correction;
1838 if (dof <= 0) {
1839 // when n == correction, 2 / (n - correction) is infinity
1840 // when self == self.mean(), we return NaN because infinity * 0 = NaN
1841 // otherwise, we return infinity because infinity * c = infinity, for all
1842 // c > 0
1843 return grad *
1844 at::where(
1845 self == self.mean(),
1846 std::numeric_limits<double>::quiet_NaN(),
1847 std::numeric_limits<double>::infinity());
1848 } else {
1849 return (c10::SymFloat(2.0) / dof) * grad * (self - self.mean());
1850 }
1851 }
1852 auto dim = dim_opt.value();
1853 if (!keepdim && self.dim() > 1) {
1854 grad = unsqueeze_multiple(grad, dim, self.sym_sizes().size());
1855 }
1856 const c10::SymFloat rnumel(_safe_size(self.sym_sizes(), dim));
1857 return (c10::SymFloat(2.0) / (rnumel - correction)) * grad *
1858 (self - self.mean(dim, /*keepdim=*/true));
1859 }
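// For reference, the 2 / (n - correction) factor above comes from
// differentiating the reduction directly:
//
//   var(x) = 1 / (n - correction) * \sum_j (x_j - mean)^2
//   d var / d x_i = 2 / (n - correction) * (x_i - mean)
//
// (the d mean / d x_i terms cancel because \sum_j (x_j - mean) == 0).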
1860
1861 Tensor std_backward(
1862 const Tensor& result,
1863 const Tensor& grad,
1864 const Tensor& self,
1865 at::OptionalIntArrayRef dim,
1866 const std::optional<c10::Scalar>& correction_opt,
1867 bool keepdim) {
1868 auto grad_var = (grad / (result * 2)).masked_fill_(result == 0, 0);
1869 return var_backward(std::move(grad_var), self, dim, correction_opt, keepdim);
1870 }
1871
1872 Tensor var_mean_backward(
1873 const Tensor& gvar,
1874 const Tensor& gmean,
1875 const Tensor& self,
1876 at::OptionalIntArrayRef dim_opt,
1877 const std::optional<c10::Scalar>& correction_opt,
1878 bool keepdim) {
1879 Tensor gself;
1880 if (gvar.defined()) {
1881 gself = var_backward(gvar, self, dim_opt, correction_opt, keepdim);
1882 }
1883 if (gmean.defined()) {
1884 auto aux = mean_backward(
1885 gmean,
1886 self.sym_sizes(),
1887 dim_opt.value_or(IntArrayRef({})),
1888 self.sym_numel(),
1889 keepdim);
1890 gself = gself.defined() ? gself + aux : std::move(aux);
1891 }
1892 return gself;
1893 }
1894
1895 Tensor std_mean_backward(
1896 const Tensor& gstd,
1897 const Tensor& gmean,
1898 const Tensor& self,
1899 const Tensor& std,
1900 at::OptionalIntArrayRef dim_opt,
1901 const std::optional<c10::Scalar>& correction_opt,
1902 bool keepdim) {
1903 Tensor gself;
1904 if (gstd.defined()) {
1905 gself = std_backward(std, gstd, self, dim_opt, correction_opt, keepdim);
1906 }
1907 if (gmean.defined()) {
1908 auto aux = mean_backward(
1909 gmean,
1910 self.sym_sizes(),
1911 dim_opt.value_or(IntArrayRef({})),
1912 self.sym_numel(),
1913 keepdim);
1914 gself = gself.defined() ? gself + aux : std::move(aux);
1915 }
1916 return gself;
1917 }
1918
1919 Tensor cholesky_jvp(const Tensor& dA, const Tensor& L, bool upper) {
1920 at::NoTF32Guard disable_tf32;
1921 // Let A = LL^H
1922 // dA = dLL^H + L(dL)^H
1923 // L^{-1}dA(L^{-H}) = L^{-1}dL + (L^{-1}dL)^H
1924 // = sym(L^{-1}dL)
1925 // where sym(X) = X + X^H
1926 // A short computation gives that the inverse of sym is given by
1927 // \pi(X) = X.tril() - 0.5*diag(X)
1928 // so
1929 // dL = L\pi(L^{-1}dA(L^{-H}))
1930
1931 // Precondition: dA is symmetric/Hermitian
1932 auto L_ = upper ? L.mH() : L;
1933 auto dL = at::linalg_solve_triangular(L_, dA, /*upper=*/false, /*left=*/true);
1934 dL = at::linalg_solve_triangular(L_.mH(), dL, /*upper=*/true, /*left=*/false);
1935 dL = dL.tril() - dL.diagonal(0, -2, -1).mul(0.5).diag_embed();
1936 dL = L_.matmul(dL);
1937 return upper ? dL.mH() : std::move(dL);
1938 }
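// A quick numerical check of the expression above (Python sketch; the
// finite-difference quotient should approach the JVP as eps -> 0):
//
//   A = torch.rand(3, 3, dtype=torch.double)
//   A = A @ A.mH + 3 * torch.eye(3, dtype=torch.double)  # make A SPD
//   dA = torch.rand(3, 3, dtype=torch.double)
//   dA = dA + dA.mH                                       # Hermitian perturbation
//   L0 = torch.linalg.cholesky(A)
//   eps = 1e-6
//   dL_fd = (torch.linalg.cholesky(A + eps * dA) - L0) / eps
//   # dL_fd ~= L0 @ pi(L0^{-1} @ dA @ L0^{-H}) per the derivation above.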
1939
1940 Tensor cholesky_backward(const Tensor& gL, bool upper, const Tensor& L) {
1941 at::NoTF32Guard disable_tf32;
1942 // From cholesky_jvp we have that
1943 // dL = L\pi(L^{-1}dA(L^-H))
1944 //
1945 // Let gL be the projection into the lower-triangular gradient wrt L. Taking
1946 // adjoints we have gA = L^{-H}\pi^*((L^HgL).tril())L^{-1} where \pi^*(X) =
1947 // 0.5 * (X + X^H - diag(X)) The only non-standard point of this derivation is
1948 // noting that the adjoint to multiplying on the left by a lower triangular
1949 // matrix L is multiplying by L^H and then projecting back to the lower
1950 // triangular matrices (hence the .tril() projection) Note that the gradient
1951 // is symmetric and not triangular.
1952 auto L_ = upper ? L.mH() : L;
1953 auto gL_ = upper ? gL.mH() : gL;
1954
1955 // Nb. We don't need to compute gL_ = gL.tril() as
1956 // tril(L^H gL) = tril(L^H (triu(gL, 1) + tril(gL)))
1957 // = tril(L^H tril(gL)) + tril(L^H triu(gL, 1))
1958 // = tril(L^H tril(gL))
1959 // since tril(L^H triu(gL, 1)) = 0, as L^H triu(gL, 1) is upper triangular
1960 auto gA = L_.mH().matmul(gL_).tril();
1961 // Equivalent to 0.5 * (gA + gA^H - diag(gA))
1962 gA = 0.5 * (gA + gA.tril(-1).mH());
1963 gA = at::linalg_solve_triangular(L_.mH(), gA, /*upper=*/true, /*left=*/true);
1964 gA = at::linalg_solve_triangular(L_, gA, /*upper=*/false, /*left=*/false);
1965 return gA;
1966 }
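// A gradcheck sketch (Python, double precision); composing with
// t @ t.mH + 3 * I keeps the argument symmetric positive-definite under
// gradcheck's entrywise perturbations:
//
//   a = torch.rand(3, 3, dtype=torch.double, requires_grad=True)
//   torch.autograd.gradcheck(
//       lambda t: torch.linalg.cholesky(
//           t @ t.mH + 3 * torch.eye(3, dtype=torch.double)),
//       (a,))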
1967
1968 Tensor cholesky_inverse_backward(
1969 const Tensor& grad,
1970 const Tensor& L,
1971 bool upper,
1972 const Tensor& inverse) {
1973 at::NoTF32Guard disable_tf32;
1974 Tensor grad_L;
1975 if (grad.defined()) {
1976 Tensor common_term = grad + grad.mH();
1977 common_term = at::matmul(inverse, at::matmul(common_term, inverse));
1978 if (upper) {
1979 grad_L = -at::matmul(L, common_term);
1980 } else {
1981 grad_L = -at::matmul(common_term, L);
1982 }
1983 }
1984
1985 return grad_L;
1986 }
1987
1988 // If X = (L L^H)^{-1} with L lower-triangular with a real positive diagonal,
1989 // then dX = K^H + K, where
1990 // K = L^{-H} dL^{-1} [dL^{-1} = -L^{-1} dL L^{-1}]
1991 // = -L^{-H} L^{-1} dL L^{-1} [L^{-H} L^{-1} = X]
1992 // = -X dL L^{-1} [X = X^H = L^{-H} L^{-1} = L^{-1} L^{-H}]
1993 // = -X dL X L^{H}.
1994 // If X = (U^H U)^{-1} with U upper-triangular with a real positive diagonal,
1995 // then K becomes
1996 // K = -X dU^H X U
1997 Tensor cholesky_inverse_jvp(
1998 const Tensor& F,
1999 const Tensor& dF,
2000 const Tensor& X,
2001 bool upper) {
2002 at::NoTF32Guard disable_tf32;
2003 const auto CF = upper ? F : F.mH();
2004 const auto dCF = upper ? dF.mH() : dF;
2005 const auto partial_dX = -X.matmul(dCF).matmul(X).matmul(CF);
2006 return partial_dX + partial_dX.mH();
2007 }
2008
2009 // The formula for forward AD is adapted from
2010 //
2011 // Golub, Gene H., and Victor Pereyra. "The Differentiation of Pseudo-Inverses
2012 // and Nonlinear Least Squares Problems Whose Variables Separate." SIAM Journal
2013 // on Numerical Analysis 10(2). (1973). 413-432. doi: 10.1137/0710036
2014 //
2015 // We present a short derivation below:
2016 // Let Ap := pinv(A), then Ap is the unique matrix such that
2017 //
2018 // Ap A Ap = Ap [1]
2019 // A Ap A = A [2]
2020 //
2021 // By differentiating [1] we get:
2022 //
2023 // dAp = dAp A Ap + Ap dA Ap + Ap A dAp [3]
2024 //
2025 // In the rhs of [3] the products involving dAp could be expressed as products
2026 // of Ap^i, A^j, dA^k with i, j, k in {1, H}, where X^H = X.mH(). To prove that,
2027 // note (A Ap)^H = A Ap and (Ap A)^H = Ap A, which could be shown by taking the
2028 // product between the SVD decompositions of A and Ap. Consider the
2029 // conjugate-transposed [2]: (A Ap A)^H = A^H (A Ap) = A^H. By differentiating
2030 // it we get: dA^H A Ap + A^H dA Ap + A^H A dAp = dA^H. By multiplying from the
2031 // left by Ap^H and using Ap^H A^H = (A Ap)^H = A Ap: Ap^H dA^H A Ap + A Ap dA
2032 // Ap + A Ap A dAp = Ap^H dA^H. By multiplying from the left by Ap and by
2033 // applying [1] and [2] repeatedly until impossible we get: Ap Ap^H dA^H A Ap +
2034 // Ap dA Ap + Ap A dAp = Ap Ap^H dA^H. By rearranging the terms:
2035 //
2036 // Ap A dAp = -Ap dA Ap + Ap Ap^H dA^H (I - A Ap) [4],
2037 // which is one of the summands in [3].
2038 //
2039 // Similar, by differentiating the transpose-conjugated [2] written differently,
2040 // i.e. (A Ap A)^H = Ap A A^H = A^H we will get an expression for dAp A Ap,
2041 // which is
2042 //
2043 // dAp A Ap = -Ap dA Ap + (I - Ap A) dA^H Ap^H Ap [5].
2044 //
2045 // By plugging in [4] and [5] into [3] we get the forward AD formula for pinv:
2046 //
2047 // dAp = -Ap dA Ap + (I - Ap A) dA^H Ap^H Ap + Ap Ap^H dA^H (I - A Ap).
2048 Tensor pinv_jvp(const Tensor& A, const Tensor& pinvA, const Tensor& dA) {
2049 at::NoTF32Guard disable_tf32;
2050 auto m = A.size(-2);
2051 auto n = A.size(-1);
2052 auto dAh = dA.mH();
2053 auto pinvAh = pinvA.mH();
2054 // optimization to produce matrices of the smallest dimension
2055 if (m <= n) {
2056 auto K = pinvAh.matmul(dAh);
2057 return pinvA.matmul(K - K.mH() - K.matmul(A.matmul(pinvA))) +
2058 (dAh - pinvA.matmul(A.matmul(dAh))).matmul(pinvAh.matmul(pinvA));
2059 } else {
2060 auto K = pinvA.matmul(dA);
2061 auto Kh = K.mH();
2062 return (Kh - K - pinvA.matmul(A).matmul(Kh)).matmul(pinvA) +
2063 (pinvA.matmul(pinvAh)).matmul(dAh - (dAh.matmul(A)).matmul(pinvA));
2064 }
2065 }
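// The formula can be exercised through the public forward-AD API; a rough
// Python sketch:
//
//   import torch.autograd.forward_ad as fwAD
//   A = torch.randn(5, 3, dtype=torch.double)
//   dA = torch.randn(5, 3, dtype=torch.double)
//   with fwAD.dual_level():
//       dual = fwAD.make_dual(A, dA)
//       dpinvA = fwAD.unpack_dual(torch.linalg.pinv(dual)).tangent
//   # dpinvA should agree with the expression derived above.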
2066
2067 Tensor pinv_backward(const Tensor& grad, const Tensor& pinvA, const Tensor& A) {
2068 at::NoTF32Guard disable_tf32;
2069 auto m = A.sym_size(-2);
2070 auto n = A.sym_size(-1);
2071 auto pinvAh = pinvA.mH();
2072 auto gradh = grad.mH();
2073 // optimization to produce matrices of the smallest dimension
2074 if (m <= n) {
2075 auto K = gradh.matmul(pinvA);
2076 auto KpinvAh = K.matmul(pinvAh);
2077 return -(pinvA.matmul(K)).mH() + KpinvAh -
2078 (A.matmul(pinvA)).matmul(KpinvAh) +
2079 (pinvAh.matmul(pinvA)).matmul(gradh - K.matmul(A));
2080 } else {
2081 auto K = pinvA.matmul(gradh);
2082 auto pinvAhK = pinvAh.matmul(K);
2083 return -(K.matmul(pinvA)).mH() +
2084 (gradh - A.matmul(K)).matmul(pinvA).matmul(pinvAh) + pinvAhK -
2085 pinvAhK.matmul(pinvA).matmul(A);
2086 }
2087 }
2088
2089 Tensor chunk_backward_nested(
2090 const std::vector<torch::autograd::Variable>& grads,
2091 const Tensor& self,
2092 int64_t chunks,
2093 int64_t dim) {
2094 TORCH_INTERNAL_ASSERT(
2095 self.layout() == c10::kJagged,
2096 "Nested Strided Tensor doesn't support chunk backward.")
2097 dim = at::maybe_wrap_dim(dim, self.dim());
2098 TORCH_INTERNAL_ASSERT(
2099 dim != 0, "Nested Tensor doesn't support chunk backward on dim=0 yet.")
2100 Tensor ret = at::zeros_like(self);
2101 std::vector<Tensor> rets = at::chunk(ret, chunks, dim);
2102 for (const auto j : c10::irange(grads.size())) {
2103 if (grads[j].defined()) {
2104 rets[j].copy_(grads[j]);
2105 }
2106 }
2107 return ret;
2108 }
2109
2110 Tensor split_with_sizes_backward(
2111 const std::vector<torch::autograd::Variable>& grads,
2112 c10::SymIntArrayRef split_sizes,
2113 int64_t dim,
2114 c10::SymIntArrayRef sizes,
2115 const at::TensorOptions& options) {
2116 dim = at::maybe_wrap_dim(dim, static_cast<int64_t>(sizes.size()));
2117
2118   // It's possible that some of the grads are not defined (they represent
2119   // tensors of all zeros). Since at::cat can't handle those, define them here.
2120 std::vector<Tensor> grads_all_defined(grads.size());
2121 for (const auto j : c10::irange(grads.size())) {
2122 if (grads[j].defined()) {
2123 grads_all_defined[j] = grads[j];
2124 } else {
2125 const auto& length = split_sizes[j];
2126 auto grad_size = sizes.vec();
2127 grad_size[dim] = length;
2128 grads_all_defined[j] = at::zeros_symint(grad_size, options);
2129 }
2130 }
2131
2132 auto ret = at::cat(grads_all_defined, dim);
2133 return ret;
2134 }
2135
2136 Tensor _nested_split_with_sizes_backward(
2137 const std::vector<torch::autograd::Variable>& grads,
2138 c10::SymIntArrayRef split_sizes,
2139 int64_t dim,
2140 const Tensor& nt_sizes,
2141 const at::TensorOptions& options) {
2142 // add 1 to account for batch dim
2143 dim = at::maybe_wrap_dim(dim, static_cast<int64_t>(nt_sizes.size(1)) + 1);
2144   // It's possible that some of the grads are not defined (they represent
2145   // tensors of all zeros). Since at::cat can't handle those, define them here.
2146 std::vector<Tensor> grads_all_defined;
2147 for (int64_t i : c10::irange(static_cast<int64_t>(grads.size()))) {
2148 if (grads[i].defined()) {
2149 grads_all_defined.push_back(static_cast<Tensor>(grads[i]));
2150 } else {
2151 const auto& length = split_sizes[i].guard_int(__FILE__, __LINE__);
2152 auto nt_split_size = nt_sizes.clone();
2153 auto nt_split_size_ptr = nt_split_size.data_ptr<int64_t>();
2154 for (int64_t j : c10::irange(static_cast<int64_t>(nt_sizes.size(0)))) {
2155 // subtract 1 to account for batch dim
2156 nt_split_size_ptr
2157 [j * static_cast<int64_t>(nt_sizes.size(1)) + (dim - 1)] = length;
2158 }
2159 Tensor zeros_buffer = at::zeros(
2160 {at::native::get_numel_from_nested_size_tensor(nt_split_size)},
2161 options);
2162 auto nt_split_grad = at::native::wrap_buffer(zeros_buffer, nt_split_size);
2163 grads_all_defined.push_back(nt_split_grad);
2164 }
2165 }
2166
2167 auto ret = at::cat(grads_all_defined, dim);
2168 return ret;
2169 }
2170
2171 Tensor split_backward(
2172 const std::vector<torch::autograd::Variable>& grads,
2173 const c10::SymInt& split_size,
2174 int64_t dim,
2175 c10::SymIntArrayRef sym_sizes,
2176 const at::TensorOptions& options) {
2177 dim = at::maybe_wrap_dim(dim, static_cast<int64_t>(sym_sizes.size()));
2178 const auto& dim_size = sym_sizes[dim];
2179 auto num_splits = grads.size();
2180 std::vector<c10::SymInt> split_sizes(num_splits, split_size);
2181 split_sizes[num_splits - 1] =
2182 split_size - (split_size * num_splits - dim_size);
2183 return split_with_sizes_backward(grads, split_sizes, dim, sym_sizes, options);
2184 }
2185
2186 Tensor max_pool_double_backward(
2187 const Tensor& grad,
2188 const Tensor& indices,
2189 int dim) {
2190 AT_ASSERT(indices.dim() >= dim);
2191 // handle non-empty inputs
2192 if (indices.sym_numel() != 0) {
2193 auto size = indices.sym_sizes().slice(0, indices.dim() - dim).vec();
2194 size.emplace_back(-1);
2195 auto indices_view = indices.view_symint(size);
2196 const auto memory_format = indices.suggest_memory_format();
2197 return grad.contiguous(memory_format)
2198 .view_symint(size)
2199 .gather(-1, indices_view)
2200 .view_symint(indices.sym_sizes());
2201 }
2202 // handle empty inputs
2203 else {
2204 return at::empty_like(indices, grad.options());
2205 }
2206 }
2207
2208 Tensor error_for_max_pool2d_double_backward() { // This is mps-only.
2209 TORCH_CHECK(
2210 false,
2211 "max_pool2d with `return_indices=False` is not infinitely differentiable.",
2212 " If you want to calculate higher order derivatives, e.g. second order,",
2213 " set `return_indices=True`.");
2214 return Tensor();
2215 }
2216
2217 Tensor glu_double_backward(
2218 const Tensor& grad,
2219 const Tensor& grad_output,
2220 const Tensor& input,
2221 int64_t dim) {
2222 auto& gO = grad_output;
2223 auto input_size = input.size(dim) / 2;
2224 auto first_half = input.narrow(dim, 0, input_size);
2225 auto second_half = input.narrow(dim, input_size, input_size);
2226 auto sig_second_half = second_half.sigmoid();
2227 auto one_sub_sig_second_half = 1 - sig_second_half;
2228 auto sig_one_sub_sig = sig_second_half * one_sub_sig_second_half;
2229
2230 auto ggI_first_half = grad.narrow(dim, 0, input_size);
2231 auto ggI_second_half = grad.narrow(dim, input_size, input_size);
2232 auto ggI_second_half_times_first_half = ggI_second_half * first_half;
2233
2234 auto gI_first_half = ggI_second_half * gO * sig_one_sub_sig;
2235 auto second_order_sh = sig_one_sub_sig * one_sub_sig_second_half -
2236 sig_second_half * sig_one_sub_sig;
2237 auto gI_second_half =
2238 ggI_second_half_times_first_half * gO * second_order_sh +
2239 ggI_first_half * gO * sig_one_sub_sig;
2240 return at::cat({std::move(gI_first_half), std::move(gI_second_half)}, dim);
2241 }
2242
2243 Tensor glu_double_backward_grad_output(
2244 const Tensor& grad,
2245 const Tensor& input,
2246 int64_t dim) {
2247 if (dim < 0)
2248 dim += input.dim();
2249 auto sizes = input.sizes().vec();
2250 sizes[dim] /= 2;
2251 auto tmp = grad * glu_backward(at::ones(sizes, input.options()), input, dim);
2252 return tmp.narrow(dim, 0, sizes[dim]) +
2253 tmp.narrow(dim, sizes[dim], sizes[dim]);
2254 }
2255
2256 Tensor infinitely_differentiable_silu_backward(
2257 const Tensor& grad_output,
2258 const Tensor& input) {
2259 const Tensor sigmoid = input.sigmoid();
2260 return grad_output * sigmoid * (1.0 + input * (1.0 - sigmoid));
2261 }
2262
2263 Tensor infinitely_differentiable_mish_backward(
2264 const Tensor& grad_output,
2265 const Tensor& input) {
2266 const Tensor sigmoid = input.sigmoid();
2267 const Tensor softplus = input.exp().log1p();
2268 const Tensor tanh_softplus = softplus.tanh();
2269 return grad_output *
2270 (tanh_softplus + input * sigmoid * (1.0 - tanh_softplus * tanh_softplus));
2271 }
2272
2273 Tensor infinitely_differentiable_logit_backward(
2274 const Tensor& grad,
2275 const Tensor& self,
2276 std::optional<double> eps) {
2277 if (eps) {
2278 const double lo = eps.value();
2279 const double hi = 1.0 - lo;
2280 return at::where(
2281 at::logical_and(self >= lo, self <= hi),
2282 grad / (self * (1.0 - self)),
2283 at::zeros({}, self.options()));
2284 } else {
2285 return at::where(
2286 at::logical_and(self >= 0.0, self <= 1.0),
2287 grad / (self * (1.0 - self)),
2288 at::empty({}, self.options())
2289 .fill_(std::numeric_limits<double>::quiet_NaN()));
2290 }
2291 }
2292
2293 Tensor binary_cross_entropy_target_backward(
2294 const Tensor& grad,
2295 const Tensor& self,
2296 const Tensor& target,
2297 const std::optional<Tensor>& weight,
2298 int64_t reduction) {
2299 auto grad_target = at::logit(self).neg_();
2300
2301 if (!areAnyTensorSubclassLike({grad})) {
2302 grad_target.mul_(grad);
2303 } else {
2304 grad_target = grad_target * grad;
2305 }
2306
2307 if (isDefined(weight)) {
2308 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2309 if (!isTensorSubclassLike(weight.value())) {
2310 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2311 grad_target.mul_(weight.value());
2312 } else {
2313 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2314 grad_target = grad_target * weight.value();
2315 }
2316 }
2317
2318 if (reduction == at::Reduction::Mean) {
2319 grad_target.div_(target.sym_numel());
2320 }
2321
2322 return grad_target;
2323 }
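// For reference, the target gradient above follows from
//
//   BCE(x, y) = -w * [y * log(x) + (1 - y) * log(1 - x)]
//   d BCE / d y = -w * [log(x) - log(1 - x)] = -w * logit(x)
//
// which is why grad_target starts from at::logit(self).neg_().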
2324
2325 Tensor binary_cross_entropy_double_backward_target(
2326 const Tensor& grad,
2327 const Tensor& grad_output,
2328 const Tensor& self,
2329 const Tensor& target,
2330 const std::optional<Tensor>& weight,
2331 int64_t reduction) {
2332 auto res = -grad * grad_output;
2333
2334 if (isDefined(weight)) {
2335 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2336 res = isTensorSubclassLike(weight.value())
2337 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2338 ? res.mul(weight.value())
2339 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2340 : res.mul_(weight.value());
2341 }
2342
2343 auto neg_self = 1 - self;
2344 auto denom =
2345 isTensorSubclassLike(self) ? neg_self.mul(self) : neg_self.mul_(self);
2346 {
2347 at::NoGradGuard guard;
2348 // Default eps in binary_cross_entropy for ALL dtypes
2349 // TODO: probably change this to a dtype-dependent value
2350 double eps = 1e-12;
2351 denom.clamp_min_(eps);
2352 }
2353
2354 res = isTensorSubclassLike(denom) ? res.div(denom) : res.div_(denom);
2355
2356 if (reduction == at::Reduction::Mean) {
2357 res.div_(target.sym_numel());
2358 }
2359
2360 return res;
2361 }
2362
2363 Tensor binary_cross_entropy_with_logits_backward(
2364 const Tensor& grad,
2365 const Tensor& input,
2366 const Tensor& target,
2367 const std::optional<Tensor>& weight,
2368 const std::optional<Tensor>& pos_weight,
2369 int64_t reduction) {
2370 // Trivial case
2371 if (grad._is_zerotensor()) {
2372 return at::_efficientzerotensor(input.sizes(), input.options());
2373 }
2374
2375   // -w * [ pos * y * (1 - sigmoid(x)) - (1 - y) * sigmoid(x) ] * grad
2376
2377   // If there are subclassed tensors, use the out-of-place version.
2378 Tensor grad_input;
2379 if (isDefined(pos_weight)) {
2380 // pos_weight might need to be broadcasted, thus mul(target) is not inplace.
2381 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2382 auto t = pos_weight->mul(target);
2383 grad_input = at::areAnyTensorSubclassLike({input, target}) ||
2384 at::GradMode::is_enabled()
2385 ? t.add(1).sub(target).mul(input.sigmoid()).sub(t)
2386 : t.add(1).sub_(target).mul_(input.sigmoid()).sub_(t);
2387 } else {
2388 grad_input = at::areAnyTensorSubclassLike({input, target}) ||
2389 at::GradMode::is_enabled()
2390 ? input.sigmoid().sub(target)
2391 : input.sigmoid().sub_(target);
2392 }
2393
2394 if (at::isTensorSubclassLike(grad) || at::GradMode::is_enabled()) {
2395 grad_input = grad_input.mul(grad);
2396 } else {
2397 grad_input.mul_(grad);
2398 }
2399
2400 if (isDefined(weight)) {
2401 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2402 if (at::isTensorSubclassLike(*weight) || at::GradMode::is_enabled()) {
2403 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2404 grad_input = grad_input.mul(*weight);
2405 } else {
2406 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2407 grad_input.mul_(*weight);
2408 }
2409 }
2410
2411 if (reduction == at::Reduction::Mean) {
2412 grad_input.div_(input.sym_numel());
2413 }
2414
2415 return grad_input;
2416 }
2417
2418 Tensor binary_cross_entropy_with_logits_target_backward(
2419 const Tensor& grad_output,
2420 const Tensor& self,
2421 const Tensor& target,
2422 const std::optional<Tensor>& weight,
2423 const std::optional<Tensor>& pos_weight,
2424 int64_t reduction) {
2425 if (grad_output._is_zerotensor()) {
2426 return at::_efficientzerotensor(target.sizes(), target.options());
2427 }
2428
2429 Tensor grad_target;
2430 if (isDefined(pos_weight)) {
2431 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2432 if (areAnyTensorSubclassLike({*pos_weight, grad_output})) {
2433 grad_target = at::log_sigmoid(-self)
2434 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2435 .sub(at::log_sigmoid(self).mul(*pos_weight))
2436 .mul(grad_output);
2437 } else {
2438 grad_target = at::log_sigmoid(-self)
2439 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2440 .sub_(at::log_sigmoid(self).mul_(*pos_weight))
2441 .mul_(grad_output);
2442 }
2443 } else {
2444 grad_target = -self * grad_output;
2445 }
2446
2447 if (isDefined(weight)) {
2448 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2449 if (at::isTensorSubclassLike(*weight)) {
2450 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2451 grad_target = grad_target.mul(*weight);
2452 } else {
2453 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2454 grad_target.mul_(*weight);
2455 }
2456 }
2457
2458 if (reduction == at::Reduction::Mean) {
2459 grad_target.div_(target.sym_numel());
2460 }
2461
2462 return grad_target;
2463 }
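// For reference, with p = pos_weight the loss is
//
//   L(x, y) = -w * [p * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]
//
// and, using log(1 - sigmoid(x)) = log_sigmoid(-x),
//
//   d L / d y = w * [log_sigmoid(-x) - p * log_sigmoid(x)]
//
// which (times grad_output) is what the branches above compute; for p == 1 it
// reduces to -x, matching the else branch.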
2464
2465 Tensor log_sigmoid_double_backward(const Tensor& grad, const Tensor& input) {
2466 auto z = input.sigmoid();
2467 return grad * (z - 1) * z;
2468 }
2469
2470 Tensor softmax_double_backward(
2471 const Tensor& grad,
2472 const Tensor& grad_output,
2473 int dim,
2474 const Tensor& output) {
2475 return grad_output * grad - (output * grad_output).sum(dim, true) * grad -
2476 grad_output * (output * grad).sum(dim, true);
2477 }
2478
2479 // NOTE: [How to write vmap-compatible backward formulas]
2480 //
2481 // See NOTE: [vmap-incompatible in-place operations] for what it means for an
2482 // in-place operation to be incompatible with vmap.
2483 //
2484 // If an in-place operation used in a backward formula is vmap-incompatible,
2485 // then as developers we have the following options:
2486 //
2487 // - If the in-place operation directly followed the creation of a tensor with
2488 // a factory function like at::zeros(...), we should replace the factory with
2489 // a corresponding grad.new_zeros(...) call. The grad.new_zeros(...) call
2490 // propagates the batch dims to the resulting tensor.
2491 // For example:
2492 // Before: at::zeros(input.sizes(), grad.options()).copy_(grad)
2493 // After: grad.new_zeros(input.sizes()).copy_(grad)
2494 //
2495 // - If the in-place operation followed some sequence of operations, and we
2496 //   want to be able to vmap over the backward formula as-is (this is
2497 //   usually the case for simple (<15 loc) backward formulas), then use
2498 // areAnyTensorSubclassLike to guard the operation. For example:
2499 // c = a * b
2500 // Before: c.mul_(grad)
2501 // After: c = !areAnyTensorSubclassLike({c, grad}) ? c.mul_(grad) : c *
2502 // grad
2503 //
2504 // - If we don't want to vmap directly over the backward formula (e.g., if the
2505 // backward formula is too complicated or has a lot of vmap-incompatible
2506 //   operations), then register the backward formula as an operator and
2507 // eventually write a batching rule for it.
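// A rough Python sketch of why this matters in practice: per-sample gradients
// vmap over these backward formulas (torch.func API, PyTorch 2.x assumed):
//
//   def loss(x):
//       return (x * x).sum()
//   xs = torch.randn(8, 3)
//   per_sample_grads = torch.vmap(torch.func.grad(loss))(xs)   # shape (8, 3)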
2508
2509 Tensor binary_cross_entropy_double_backward(
2510 const Tensor& grad_output,
2511 const Tensor& grad,
2512 const Tensor& input,
2513 const Tensor& target,
2514 const std::optional<Tensor>& weight,
2515 int64_t reduction) {
2516 auto eps = 1e-12;
2517 auto inp_pl_eps = input + eps;
2518 auto one_m_inp_pl_eps = 1 - input + eps;
2519 // gradient wrt input
2520 auto gI = (input * input - 2 * input * target + target) /
2521 (inp_pl_eps.pow(2) * one_m_inp_pl_eps.pow(2));
2522 if (!areAnyTensorSubclassLike({gI, grad})) {
2523 gI *= (grad * grad_output);
2524 } else {
2525 gI = gI * (grad * grad_output);
2526 }
2527
2528 if (isDefined(weight)) {
2529 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2530 if (!isTensorSubclassLike(*weight)) {
2531 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2532 gI *= *weight;
2533 } else {
2534 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2535 gI = gI.mul(*weight);
2536 }
2537 }
2538 if (reduction == at::Reduction::Mean) {
2539 return gI / input.sym_numel();
2540 }
2541
2542 return gI;
2543 }
2544
2545 Tensor binary_cross_entropy_double_backward_grad_output(
2546 const Tensor& grad,
2547 const Tensor& input,
2548 const Tensor& target,
2549 const std::optional<Tensor>& weight,
2550 int64_t reduction) {
2551 auto eps = 1e-12;
2552 // gradient wrt grad_output
2553 auto ggO = (input - target) / ((input + eps) * (1 - input + eps));
2554 if (!areAnyTensorSubclassLike({ggO, grad})) {
2555 ggO *= grad;
2556 } else {
2557 ggO = ggO * grad;
2558 }
2559
2560 if (isDefined(weight)) {
2561 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2562 if (!isTensorSubclassLike(*weight)) {
2563 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2564 ggO *= *weight;
2565 } else {
2566 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2567 ggO = ggO.mul(*weight);
2568 }
2569 }
2570 if (reduction == at::Reduction::Mean) {
2571 return ggO / input.sym_numel();
2572 }
2573 return ggO;
2574 }
2575
2576 Tensor smooth_l1_loss_double_backward(
2577 const Tensor& grad,
2578 const Tensor& input,
2579 const Tensor& target,
2580 int64_t reduction,
2581 double beta) {
2582 // special case to protect against a divide-by-zero.
2583 if (beta == 0) {
2584 return at::zeros(grad.sizes(), grad.options());
2585 }
2586 auto d = (input - target).abs();
2587 auto grad_input = grad * (d < beta).type_as(grad) / beta;
2588 if (reduction == at::Reduction::Mean) {
2589 grad_input /= input.sym_numel();
2590 }
2591 return grad_input;
2592 }
2593
2594 Tensor huber_loss_double_backward(
2595 const Tensor& grad,
2596 const Tensor& input,
2597 const Tensor& target,
2598 int64_t reduction,
2599 double delta) {
2600 auto d = (input - target).abs();
2601 auto grad_input = grad * (d < delta);
2602 if (reduction == at::Reduction::Mean) {
2603 grad_input /= input.sym_numel();
2604 }
2605 return grad_input;
2606 }
2607
2608 Tensor huber_loss_double_backward_grad_output(
2609 const Tensor& grad,
2610 const Tensor& grad_output,
2611 const Tensor& input,
2612 const Tensor& target,
2613 int64_t reduction,
2614 double delta) {
2615 if (reduction == at::Reduction::None) {
2616 return huber_loss_backward(grad, input, target, reduction, delta);
2617 }
2618 auto r = huber_loss_backward(
2619 ones_like(grad_output), input, target, reduction, delta);
2620 return (r * grad).sum();
2621 }
2622
2623 Tensor mse_loss_double_backward(
2624 const Tensor& grad,
2625 const Tensor& input,
2626 int64_t reduction) {
2627 auto grad_input = 2 * grad;
2628 if (reduction == at::Reduction::Mean) {
2629 grad_input /= input.sym_numel();
2630 }
2631 return grad_input;
2632 }
2633
2634 Tensor soft_margin_loss_double_backward(
2635 const Tensor& grad,
2636 const Tensor& input,
2637 const Tensor& target,
2638 int64_t reduction) {
2639 auto z = (input * -target).exp();
2640 auto zplus1 = z + 1;
2641 auto grad_input = grad * (target * target) * z / (zplus1 * zplus1);
2642 if (reduction == at::Reduction::Mean) {
2643 grad_input /= input.sym_numel();
2644 }
2645 return grad_input;
2646 }
2647
2648 Tensor soft_margin_loss_double_backward_grad_output(
2649 const Tensor& grad,
2650 const Tensor& grad_output,
2651 const Tensor& input,
2652 const Tensor& target,
2653 int64_t reduction) {
2654 if (reduction == at::Reduction::None) {
2655 return soft_margin_loss_backward(grad, input, target, reduction);
2656 }
2657 auto r = soft_margin_loss_backward(
2658 ones_like(grad_output), input, target, reduction);
2659 return (r * grad).sum();
2660 }
2661
2662 Tensor softplus_double_backward(
2663 const Tensor& grad,
2664 const Tensor& input,
2665 const Scalar& beta,
2666 const Scalar& threshold) {
2667 auto x = (input * beta);
2668 return sigmoid_backward(grad, x.sigmoid()) * (x < threshold).type_as(grad) *
2669 beta;
2670 }
2671
2672 // NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
2673 //
2674 // `storage_offset` is ignored for simplicity in this note. If you just want the
2675 // full algorithm without explanation, scroll down to bottom of this note.
2676 //
2677 // Implementing the backward of as_strided is tricky because you have to deal
2678 // with mappings that map one memory location to multiple indices, i.e., the
2679 // output tensor has multiple indices pointing to **overlapping** memory
2680 // addresses. This can happen in all sorts of weird cases. For example,
2681 //
2682 // x = torch.randn(15)
2683 // x.as_strided([3, 3], [1, 0]) # "expand" case
2684 // x.as_strided([3, 3], [2, 1]) # "size too large" case
2685 // x.as_strided([3, 2], [3, 6]) # res[2, 0] points to 2*3 + 0*6 = 6
2686 // # res[0, 1] points to 0*3 + 1*6 = 6
2687 //
2688 // Here is the general strategy we apply in implementing as_strided backward:
2689 // 0. ??? (optimization step. we will talk about this later)
2690 // 1. Create some underlying flattened tensor as if it is the base tensor
2691 // representing the contiguous memory storage for both input and output.
2692 // 2. Use the output geometry to scatter (or index_add) the gradients into
2693 // this storage tensor.
2694 // 3. ??? (fix for input tensor with overlapping memory. we will talk about
2695 // this later)
2696 // 4. Return the as_strided view of the storage tensor using input geometry.
2697 //
2698 // In step (2), if the output tensor doesn't have overlapping memory, we can
2699 // safely scatter (`storage.as_strided(output_geometry).copy_(grad)`);
2700 // otherwise, we must use `index_add` as gradients at different indices may need
2701 // to be summed to a single location.
2702 //
2703 // For example, in this case:
2704 //
2705 // x = torch.randn(3)
2706 // y = x.as_strided([3, 3], [1, 0]) # "expand" case
2707 // # size [ 3, 3]
2708 // # stride [ 1, 0]
2709 // y.backward() # step (1): contiguous storage tensor `s` of size 3, which
2710 // is large enough to be used as underlying storage
2711 // for `x` and `y`.
2712 // s = [ 0, 0, 0]
2713 // # step (2): since `y` has overlapping memory, index_add grad
2714 //                 into `s` based on `y`'s geometry, i.e.,
2715 // s[i * y.stride(0) + j * y.stride(1)] += gy[i, j].
2716 // s = [ 3, 3, 3]
2717 // # step (4): as_strided view `s` using `x`'s geometry
2718 // s = [ 3, 3, 3]
2719 // grad_input = s.as_strided(x.size(), x.stride())
2720 // = s.as_strided([3], [1])
2721 // = [ 3, 3, 3]
2722 //
2723 // This is exactly what we would get if using `expand`. However, here the input
2724 // tensor doesn't have overlapping memory. If it does, we must add an extra step
2725 // before (4). Consider this case:
2726 //
2727 // t = torch.randn(3)
2728 // x = t.expand(3, 3) # input with overlapping memory
2729 // # size [3, 3]
2730 // # stride [0, 1]
2731 // y = x.as_strided([1], [1]) # contiguous output
2732 // # size [1]
2733 // # stride [1]
2734 // y.backward() # step (1): contiguous storage tensor `s` of size 3, which
2735 // is large enough to be used as underlying storage
2736 // for `x` and `y`.
2737 // s = [ 0, 0, 0]
2738 //                # step (2): scatter grad into `s` based on `y`'s geometry
2739 // s = [ 1, 0, 0]
2740 // # step (4): as_strided view `s` using `x`'s geometry
2741 // s = [ 1, 0, 0]
2742 //                  grad_input = s.as_strided(x.size(), x.stride())
2743 // = s.as_strided([3, 3], [0, 1])
2744 // = [[ 1, 0, 0],
2745 // [ 1, 0, 0],
2746 // [ 1, 0, 0]]
2747 // Is this result correct?
2748 //
2749 // The `x.as_strided([1], [1])` call is obviously equivalent to
2750 // `x[(0,) * x.dim()].view(1)` for any `x`. But autograd through the second
2751 // gives gradient `[ [ 1, 0, 0], [ 0, 0, 0], [ 0, 0, 0]]`. For this specific
2752 // case, indexing `x` at any index in first column is also equivalent, and
2753 // yields a gradient of shape `[3 x 3]` containing eight 0's and one 1. There is
2754 // an `x.size(1)`-times difference between these gradients computed from other
2755 // PyTorch ops and the gradient we got from as_strided.
2756 //
2757 // You might conclude that the gradients from as_strided is wrong. However,
2758 // let's first see why they are actually reasonable. Consider the pointwise
2759 // perturbations by `delta` anywhere in the first column of `x`. It will lead to
2760 // a `delta` change in the same memory location, and then `y` will change by
2761 // `delta`. So one can say the gradient should be exactly 1 at the first column,
2762 // as given by our above procedure.
2763 //
2764 // The numerical gradients above only match the analytical results because
2765 // strides and memory locations are taken into account in the forward pass,
2766 // i.e., this op (including both forward and backward) is
2767 // layout-aware.
2768 //
2769 // However, in PyTorch, most (probably all) other ops (forward and backward) are
2770 // layout-agnostic. E.g.,
2771 //
2772 // t = torch.randn(1)
2773 // x = t.expand(2)
2774 // y = x.sum()
2775 // y.backward()
2776 //
2777 // Layout-agnostic autograd (as it is currently in PyTorch) will give you
2778 //
2779 // gy = 1
2780 // gx = [ 1, 1] # SumBackward: torch.ones_like(x)
2781 // gt = [ 2] # ExpandBackward: gx.sum()
2782 //
2783 // Note that `gx = [ 1, 1]`. However, if you perturb any value in `x` by `delta`
2784 // (the other will also change by `delta`), `y` will change by `2 * delta`. So
2785 // the gradients, if strides are taken into consideration, should be 2.
2786 //
2787 // Layout-aware autograd should give you
2788 //
2789 // gy = 1
2790 // gx = [ 2, 2] # Because the backward considers the fact that the input `x`
2791 // # is already expanded.
2792 // gt = [ 2] # Layout-aware backward of expand is just a slicing because
2793 // # the previous backward should have already taken care of
2794 // # strides and made sure that gradients are the same along the
2795 // # expanded dimension.
2796 //
2797 // As shown above, these two types are not compatible. Therefore, we must either
2798 // make as_strided layout-agnostic, or make all other ops layout-aware.
2799 //
2800 // It is difficult to support layout-aware autograd (at least in the current
2801 // codebase structure), because it would mean
2802 // 1. storing tensor geometries of every input tensor for backward
2803 //  2. depending on input geometry, the gradient computed from backward changes
2804 // 3. ideally enforcing gradient of T to always have same strides as T
2805 // (although these two methods only differ when it comes to overlapping memory)
2806 //
2807 // Therefore, we must formulate `as_strided` in a layout-agnostic way, i.e.,
2808 // giving the same output regardless of the input layout. We consider
2809 // `input.stride()` as a separate independent fixed argument `input_stride`.
2810 // Then, `as_strided(input, size, stride)` can be thought of as:
2811 // 1. "Scatter" each value of `input` into a "storage" using storage location
2812 // computed from the value's index in `input`, `input.size()` and
2813 // `input_stride`, but if N values end up in the same location, the value
2814 //        is the average of those N values (they will be the same value anyway).
2815 //
2816 // Formal description:
2817 //      Denote the set of all input indices that point to the same storage
2818 // location `storage[n]` as `S(n)`, i.e.,
2819 //
2820 // S(n) = { index : <index, input_stride> == n, index is valid given
2821 // input.size() },
2822 //
2823 // where `<x, y>` is the dot product between `x` and `y`.
2824 //
2825 // Then, the process is:
2826 //
2827 // storage[n] = Avg { S(n) }
2828 //
2829 //      Note that all values in `S(n)` are the same (they point to the same
2830 //      memory location anyway), so this step doesn't change anything, but it
2831 //      effectively removes the dependency on the layout of `input`.
2832 // I.e., the result holds fixed regardless of the layout of `input`, as
2833 // long as `input_stride` is fixed.
2834 //
2835 // NOTE: for forward pass, we can equivalently simply select any one of
2836 // `S(n)` as `storage[n]`. However, considering this as an average
2837 // operation makes backward easier (so all values in set
2838 // `{ grad_input[i] : i in S(n) }` are the same, and it can use the
2839 // same geometry as input).
2840 // 2. As usual, return the as_strided view of `storage` using required output
2841 // `size` and `stride`.
2842 //
2843 // To backward through this layout-agnostic version, we simply add the following
2844 // step:
2845 // .... (scatter gradients into the storage tensor using output geometry)
2846 // 3. For all storage location n, `storage[n] /= |S(n)|`.
2847 // .... (return as_strided view of the storage tensor using input geometry)
2848 //
2849 // Finally, we note that these general operations are expensive, so we apply the
2850 // following optimizations:
2851 // Add step (0): For all output dimension `d` with output stride 0, sum the
2852 // gradients along dimension `d` (don't keepdim), and remove
2853 // dimension `d` from output size and stride.
2854 // (An optimization for "expand" cases so we may avoid step (3))
2855 // Only apply step (3) when input tensor has overlapping memory.
2856 //
2857 // FULL ALGORITHM:
2858 // 0. For all output dimension `d` with output stride 0, sum the gradients
2859 // along dimension `d` (don't keepdim), and remove dimension `d` from
2860 // output size and stride.
2861 // 1. Create some underlying flattened tensor as if it is the base tensor
2862 // representing the contiguous memory storage for both input and output.
2863 // 2. Use the output geometry to scatter (or index_add) the gradients into
2864 // this storage tensor `storage`.
2865 // 3. If input tensor has overlapping memory,
2866 // For all storage location `i`, `storage[i] /= N(i)`, where `N(i)` is the
2867 // number of indices in input geometry pointing to the same storage
2868 // location `i` (i.e., `|S(i)|` in equations above).
2869 // 4. Return the as_strided view of the storage tensor using input geometry.
2870 //
2871 // See NOTE [ Detecting Memory Overlap Within A Strided Tensor ] on how to
2872 // roughly detect overlapping memory.
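//
// A rough Python sketch of the full algorithm (the flat_indices helper and the
// geometry variables are illustrative only, not real APIs):
//
//   storage = grad.new_zeros(storage_numel)                 # step (1)
//   out_idx = flat_indices(out_size, out_stride)            # stride-0 dims already reduced, step (0)
//   storage.index_add_(0, out_idx, grad.reshape(-1))        # step (2)
//   if input_has_internal_overlap:                          # step (3)
//       inp_idx = flat_indices(inp_size, inp_stride)
//       count = torch.zeros_like(storage).index_add_(
//           0, inp_idx, torch.ones_like(inp_idx, dtype=storage.dtype))
//       storage = storage / count.clamp_min(1)
//   grad_input = storage.as_strided(inp_size, inp_stride)   # step (4)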
2873
2874 // NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
2875 //
2876 // Checking memory overlap within a strided tensor is the special case of
2877 // detecting memory overlap of two strided tensors, where the two tensors start
2878 // at the same memory address. The latter is HARD (see #8212).
2879 //
2880 // But even this special case isn't simple. This note describes a check for an
2881 // even more constrained, simpler case where we can be certain that there is no
2882 // overlap.
2883 //
2884 // The checking algorithm can be described as:
2885 // 0. Return [ pass check ] if any dimension has size 0
2886 // 1. Ignore all dimensions that have size 1
2887 // 2. If no remaining dimensions, return [ pass check ]
2888 // 3. Sort the remaining dimensions according to the strides decreasingly
2889 // 4. Check that for each dimension k,
2890 //
2891 // stride[k] > \sum_{ i > k } (size[i] - 1) * stride[i]
2892 //
2893 // That is equivalent to, after reordering the dimensions so strides are
2894 // in decreasing order, checking that stride of each dimension is larger
2895 // than the maximum memory offset in a slice at that dimension.
2896 //
2897 // Obviously this check passes for contiguous tensors ( the dimensions will be
2898 // already sorted with LHS = stride[0] = \prod size[i] being exactly 1 larger
2899 // than RHS ). Similarly, the check passes for tensors contiguous in all but
2900 // the last dimension, and LHS = stride[0] = stride[-1] * \prod size[i] being
2901 // exactly stride[-1] larger than RHS. (*)
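//
// A compact Python sketch of the check above (the function name is
// illustrative only):
//
//   def may_have_internal_overlap(sizes, strides):
//       if any(sz == 0 for sz in sizes):
//           return False                                  # step (0): pass
//       dims = [(st, sz) for st, sz in zip(strides, sizes) if sz != 1]  # step (1)
//       dims.sort(reverse=True)                           # step (3)
//       max_offset = 0
//       for st, sz in reversed(dims):                     # step (4), small strides first
//           if st <= max_offset:
//               return True                               # cannot rule out overlap
//           max_offset += (sz - 1) * st
//       return False                                      # passes the check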
2902 //
2903 // We will show that these view operations, including all our view operations
2904 // *except for* general as_strided and unfold, also preserve this invariant:
2905 //
2906 // alias: Obviously preserves
2907 //
2908 // expand: All changed dimensions are removed in step (1)
2909 //
2910 // view: Consider the input dimensions as grouped into consecutive
2911 //              dimension "blocks", where dimensions are contiguous in each
2912 //              one. view only works when the output dimensions can also be
2913 //              grouped into the same consecutive blocks with the same ordering.
2914 //
2915 // NB: this means that the number of elements and stride of the
2916 // last dimension in each block is the same in input and
2917 // output. (**)
2918 //
2919 // Notation:
2920 // Consider a single such block B,
2921 //              ... B_prev[-1]], [ B[0], ..., B[i], ..., B[k] = B[-1] ], [ B_next[0], ...
2922 //                        start--^^^^                  ^^^^^^^^^^^^--end
2923 //
2924 // Each B[i] denotes a dimension index such that B[i] = B[0] + i.
2925 //
2926 //            We first show that if a tensor (i.e., input) satisfies the
2927 //            invariant, then after sorting, the dimensions within each
2928 //            block still remain consecutive. (***)
2929 //
2930 // After removing dimensions of size 1, the dimensions within a
2931 //            block are already sorted by strides in descending order. So
2932 // sorting all dimensions will not change the relative ordering
2933 // among them.
2934 //
2935 // Assume that some block B is not consecutive after sorting,
2936 // i.e., there exists a dimension d between B[0] and B[-1] in
2937 // sorted order.
2938 //
2939 // By (*), we know that
2940 //                 stride[B[0]]
2941 //                   =  \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]]
2942 //                          + stride[B[-1]]
2943 //                   <  \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]]
2944 //                          + stride[d]
2945 //                   <= \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]]
2946 //                          + (size[d] - 1) * stride[d]
2947 //                   <= \sum_{j > B[0]} (size[j] - 1) * stride[j],
2948 //
2949 //               where the first  <  comes from sorting,
2950 //                     the second <= comes from the fact that dimension d
2951 //                                   exists after step (1) and thus must have
2952 //                                   size greater than 1, and
2953 //                     the third  <= comes from the fact that each term in
2954 //                                   the sum is non-negative.
2955 //
2956 //            Then we have a contradiction, as this means the invariant
2957 //            cannot be satisfied at B[0]. So the original proposition is true.
2958 //
2959 // Now that we established the above claim (***), we consider the
2960 // view operation as first sorting the dimensions (i.e., blocks),
2961 //            apply the original view (since it only cares about dimensions
2962 //            being consecutive and contiguous within each block), and then
2963 //            undo the sort.
2964 //
2965 // Consider a single block B in the output,
2966 // ... ], [ B[0], ..., B[i], ..., B[k] = B[-1] ], [ ...
2967 // start--^^^^ ^^^^^^^^^^^^--end
2968 //
2969 // By (*), we know that for all i
2970 //                stride[B[i]]
2971 //                  = stride[B[-1]]
2972 //                      + \sum_{j=i+1}^{k} (size[B[j]] - 1) * stride[B[j]]
2973 //
2974 // Then the invariant is obviously satisfied at every dimension
2975 // in this block if it is satisfied at dimension B[-1]. It only
2976 // remains to show that it is satisfied at the last dimension in
2977 // each block.
2978 //
2979 // Since the same blocks are present in both input and output
2980 // with the same ordering, we will abuse the notation in the
2981 // following statements.
2982 //
2983 // By (*), we know that the following holds for both input and
2984 // output, for any block B:
2985 // \sum_{i > B[-1]} (size[i] - 1) * stride[i]
2986 //                = \sum_{block B' after B} \sum_{j in B'} (size[j] - 1) * stride[j]
2987 //                = \sum_{block B' after B} (numel(B') - 1) * stride[B'[-1]].
2988 //
2989 //              Each term of the last sum depends only on numel(B') and
2990 //              stride[B'[-1]], so by (**) this quantity remains the same in
2991 //              input and output. So both
2992 // \sum_{i > B[-1]} (size[i] - 1) * stride[i]
2993 // and
2994 // stride[B[-1]]
2995 // are the same in input and output.
2996 //
2997 // These two quantities are exactly the LHS and RHS of the
2998 // invariant inequality. Since by assumption the invariant is
2999 // satisfied in input at B[-1], it is also satisfied in output at
3000 // B[-1]. This concludes the proof.
3001 //
3002 // squeeze: Special case of view
3003 //
3004 // unsqueeze: Special case of view
3005 //
3006 // slice: Consider slicing dimension i with step = k >= 1.
3007 //
3008 // Let stride' and size' be the output strides and sizes. We have
3009 //
3010 // stride'[i] = k * stride[i]
3011 //               size'[i]   <= (size[i] - 1) / k + 1,  i.e. (size'[i] - 1) * k <= size[i] - 1
3012 //
3013 // If size'[i] = 1, invariant is obviously satisfied as we are
3014 //             just removing a dimension (after step (1)).
3015 //
3016 // Assume size'[i] > 1.
3017 //
3018 // By assumption, the invariant is satisfied at every dimension
3019 // in input.
3020 //
3021 // For any dimension j, if stride[j] > stride[i], we have
3022 // stride'[j] = stride[j]
3023 //                          >  (size[i] - 1) * stride[i]
3024 //                          >= (size'[i] - 1) * k * stride[i]
3025 //                          =  (size'[i] - 1) * stride'[i]
3026 //                          >= stride'[i],
3027 //             where the last step uses size'[i] >= 2.
3028 //
3029 // If stride[j] < stride[i], we have
3030 // stride'[j] = stride[j] < stride[i] <= stride'[i].
3031 //
3032 // So the sorting order remains unchanged after slice.
3033 //
3034 // Since
3035 // (size'[i] - 1) * stride'[i]
3036 //                 =  (size'[i] - 1) * k * stride[i]
3037 //                 <= (size[i] - 1) * stride[i],
3038 //             where the inequality again uses
3039 //               (size'[i] - 1) * k <= size[i] - 1,
3040 //             the term from this dimension i in the invariant inequality at
3041 //             other dimensions can only decrease after slice. So the
3042 //             invariant is preserved.
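//
//             (Example, added for illustration: a contiguous 4 x 6 tensor has
//             strides (6, 1); slicing dimension 1 with step 2 gives sizes
//             (4, 3) and strides (6, 2), and the check still passes, since
//             6 > (3 - 1) * 2 = 4.)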
3043 //
3044 // narrow: Special case of slice
3045 //
3046 // select: narrow + squeeze
3047 //
3048 // permute: Sorting makes permutation of dimensions irrelevant
3049 //
3050 // transpose: Sorting makes swapping dimensions irrelevant
3051 //
3052 // diagonal: Effectively merging two dimensions i and j into a new
3053 // dimension k s.t.
3054 // stride'[k] = stride[i] + stride[j]
3055 // size'[k] <= min(size[i], size[j]),
3056 // where stride and size are on the input, and stride' and size'
3057 // are on the output.
3058 //
3059 //            Assume that size[i] > 1 and size[j] > 1. If either has size 1,
3060 //            then this is an unsqueeze on that dimension.
3061 //
3062 //            WLOG, say stride[i] <= stride[j].
3063 //
3064 // Each dimension d in input with stride[d] > stride[j] has
3065 // stride'[d] = stride[d]
3066 // > (size[i] - 1) * stride[i] + (size[j] - 1) *
3067 // stride[j]
3068 // >= stride[i] + stride[j]
3069 // = stride[k].
3070 // So, considering the sorted dimensions, this is effectively
3071 // removing i, and replacing j with k.
3072 //
3073 // For dimensions d with stride[i] < stride[d] < stride[j], the
3074 // term from dimension i is removed in the invariant inequality.
3075 // For dimensions d with stride[d] > stride[j], we have
3076 // (size'[k] - 1) * stride'[k]
3077 // <= (min(size[i], size[j]) - 1) * (stride[i] + stride[j])
3078 // <= (size[i] - 1) * stride[i] + (size[j] - 1) * stride[j],
3079 // so the term from i and j in the invariant can only decrease.
3080 //
3081 //            So this generally relaxes the constraint, and thus the
3082 //            invariant is preserved.
3083
3084 // This implements steps (2)~(4) of the algorithm in
3085 // NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
3086 // Helper for as_strided_backward
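//
// (Illustrative sketch, added here and not part of the original comment; a
// caller that already has SymInt sizes/strides, as as_strided_backward below
// does, would use it roughly as
//   std::vector<c10::SymInt> sizes{2, 3}, strides{2, 1};
//   bool overlap = _maybe_overlapping_memory(sizes, strides);  // true, since
//   // indices (1, 0) and (0, 2) both hit storage location 2
// whereas strides {3, 1} would give false.)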
_maybe_overlapping_memory(c10::SymIntArrayRef sizes,c10::SymIntArrayRef strides)3087 static inline bool _maybe_overlapping_memory(
3088 c10::SymIntArrayRef sizes,
3089 c10::SymIntArrayRef strides) {
3090 if (!sizes.empty()) {
3091 std::vector<std::size_t> argsort(sizes.size());
3092 std::iota(argsort.begin(), argsort.end(), 0);
3093 std::sort(
3094 argsort.begin(), argsort.end(), [&](std::size_t i, std::size_t j) {
3095 return strides[i] < strides[j];
3096 });
3097
3098 c10::SymInt max_index_in_slice = 0;
3099 for (auto i : argsort) {
3100 const auto& stride_ = strides[i];
3101 if (stride_ <= max_index_in_slice) {
3102 return true;
3103 }
3104 max_index_in_slice += stride_ * (sizes[i] - 1);
3105 }
3106 }
3107 return false;
3108 }
3109
3110 // Returns the minimum storage size needed to contain a tensor with the given
3111 // sizes, strides, and storage_offset. Helper for as_strided_backward.
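//
// (Example, added for illustration: sizes = {2, 3}, strides = {3, 1} and
// storage_offset = 0 give 1 + (2 - 1) * 3 + (3 - 1) * 1 = 6, i.e. exactly the
// number of elements of a contiguous 2 x 3 tensor.)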
_min_storage_size(c10::SymIntArrayRef sizes,c10::SymIntArrayRef strides,c10::SymInt storage_offset)3112 static inline c10::SymInt _min_storage_size(
3113 c10::SymIntArrayRef sizes,
3114 c10::SymIntArrayRef strides,
3115 c10::SymInt storage_offset) {
3116 c10::SymInt storage_size = storage_offset + 1;
3117 auto dim = sizes.size();
3118 for (const auto i : c10::irange(dim)) {
3119 const auto& size_i = sizes[i];
3120 if (size_i == 0) {
3121 return storage_offset;
3122 }
3123 storage_size += (size_i - 1) * strides[i];
3124 }
3125 return storage_size;
3126 }
3127
3128 // See NOTE [ as_strided Backward and layout-aware/agnostic autograd ] for
3129 // explanation
as_strided_backward(Tensor grad,const TensorGeometry & input_geometry,c10::SymIntArrayRef sym_sizes,c10::SymIntArrayRef sym_strides,const std::optional<c10::SymInt> & sym_storage_offset_)3130 Tensor as_strided_backward(
3131 Tensor grad,
3132 const TensorGeometry& input_geometry,
3133 c10::SymIntArrayRef sym_sizes,
3134 c10::SymIntArrayRef sym_strides,
3135 const std::optional<c10::SymInt>& sym_storage_offset_) {
3136 // For output geometry,
3137 // check for size 0 dimensions,
3138 // skip size 1 dimensions,
3139 // reduce grad on expanded dims (stride=0, size>1)
3140   // This is step (0) of the algorithm in NOTE [ as_strided Backward and
3141   // layout-aware/agnostic autograd ], and steps (0)~(1) of the algorithm in
3142   // NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
3143 // on output geometry
3144 auto sym_storage_offset =
3145 sym_storage_offset_.value_or(input_geometry.sym_storage_offset());
3146 auto odim = grad.dim();
3147 std::vector<c10::SymInt> out_sizes_, out_strides_;
3148 out_sizes_.reserve(odim);
3149 out_strides_.reserve(odim);
3150 for (int64_t i = odim - 1; i >= 0; i--) {
3151 const auto& size_i = sym_sizes[i];
3152 const auto& stride_i = sym_strides[i];
3153 if (size_i == 0) {
3154 return at::zeros_symint(input_geometry.sym_sizes(), grad.options());
3155 } else if (size_i == 1) {
3156 grad = grad.squeeze(i);
3157 } else if (stride_i == 0) {
3158 grad = grad.sum(i, false);
3159 } else {
3160 out_sizes_.insert(out_sizes_.begin(), size_i);
3161 out_strides_.insert(out_strides_.begin(), stride_i);
3162 }
3163 }
3164 // Step (2)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A
3165 // Strided Tensor ]
3166 // on output geometry
3167 auto out_maybe_overlap = _maybe_overlapping_memory(out_sizes_, out_strides_);
3168
3169 // For input geometry,
3170 // check for size 0 dimensions,
3171 // skip size 1 dimensions,
3172 // Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A
3173 // Strided Tensor ]
3174 // on input geometry
3175 auto idim = input_geometry.dim();
3176 auto inp_sizes = input_geometry.sym_sizes(),
3177 inp_strides = input_geometry.sym_strides();
3178 std::vector<c10::SymInt> inp_sizes_, inp_strides_;
3179 inp_sizes_.reserve(idim);
3180 inp_strides_.reserve(idim);
3181 for (int64_t i = idim - 1; i >= 0; i--) {
3182 const auto& size_i = inp_sizes[i];
3183 const auto& stride_i = inp_strides[i];
3184 if (size_i == 0) {
3185 return at::zeros_symint(input_geometry.sym_sizes(), grad.options());
3186 } else if (size_i != 1) {
3187 inp_sizes_.insert(inp_sizes_.begin(), size_i);
3188 inp_strides_.insert(inp_strides_.begin(), stride_i);
3189 }
3190 }
3191 // Step (1)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A
3192 // Strided Tensor ]
3193 // on input geometry
3194 auto inp_maybe_overlap = _maybe_overlapping_memory(inp_sizes_, inp_strides_);
3195
3196 // Rest of this function implements
3197 // Step (1)~(4) for the algorithm in NOTE [ as_strided Backward and
3198 // layout-aware/agnostic autograd ]
3199 // TODO: Raise if not all output values are visible in input geometry.
3200 // Technically speaking, if you treat those values as constants, not
3201 // raising is fine, and mathematically correct. However, these values
3202 // really are contained in some base tensor, and by treating them as
3203 // constants we are ignoring this tight dependency. Therefore, it is
3204 // more sensible to raise here.
3205
3206 // Step (1): create underlying tensor as "storage"
3207 auto shared_offset =
3208 // TODO: symint-ify. Do we need a min() and max() for SymInts?
3209 input_geometry.sym_storage_offset().min(sym_storage_offset);
3210 auto inp_effective_offset =
3211 input_geometry.sym_storage_offset() - shared_offset;
3212 auto out_effective_offset = sym_storage_offset - shared_offset;
3213 auto base_size1 =
3214 _min_storage_size(inp_sizes_, inp_strides_, inp_effective_offset);
3215 auto base_size2 =
3216 _min_storage_size(out_sizes_, out_strides_, out_effective_offset);
3217 auto base_size = base_size1.max(base_size2);
3218 auto storage = grad.new_zeros_symint(c10::SymIntArrayRef(base_size));
3219
3220 // prepare indices tensor if we will do index_add_ later
3221 std::optional<at::Tensor> flatten_full_indices;
3222 if (inp_maybe_overlap || out_maybe_overlap) {
3223 flatten_full_indices =
3224 // TODO: should we symint-ify arange? Need SymScalar.
3225 at::arange(
3226 0,
3227 base_size.guard_int(__FILE__, __LINE__),
3228 grad.options().dtype(at::kLong));
3229 }
3230
3231 // Step (2): use output geometry to scatter gradients into storage
3232 if (out_maybe_overlap) {
3233 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
3234 auto out_indices = flatten_full_indices->as_strided_symint(
3235 out_sizes_, out_strides_, out_effective_offset);
3236 storage.index_add_(0, out_indices.reshape(-1), grad.reshape(-1));
3237 } else {
3238 // assume that new tensors have 0 storage offset
3239 storage.as_strided_symint(out_sizes_, out_strides_, out_effective_offset)
3240 .copy_(grad);
3241 }
3242
3243 // Step (3): if input tensor has overlapping memory, divide scattered gradient
3244 // at storage[i] by the number of times i shows up in input geometry
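  // (Illustration, added here: for an input with sizes = (2, 2) and
  // strides = (1, 1), the four input indices map to storage offsets
  // 0, 1, 1, 2, so `count` becomes [1, 2, 1] and the gradient accumulated at
  // storage offset 1 is divided by 2 before the final as_strided view below
  // replays it into both input positions that alias that location.)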
3245 if (inp_maybe_overlap) {
3246 auto count = at::zeros_like(storage, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
3247 auto inp_indices =
3248 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
3249 flatten_full_indices
3250 ->as_strided_symint(inp_sizes_, inp_strides_, inp_effective_offset)
3251 .reshape(-1);
3252 count.index_add_(
3253 0, inp_indices, at::ones({1}, grad.options()).expand_as(inp_indices));
3254 storage.div_(count); // this will give nan outside visible range
3255 }
3256 // Step (4): return as_strided view of the storage tensor with input geometry
3257 return storage.as_strided_symint(
3258 inp_sizes, inp_strides, inp_effective_offset);
3259 }
3260
as_strided_scatter_backward(const Tensor & grad,const TensorGeometry & input_geometry,const TensorGeometry & src_geometry,c10::SymIntArrayRef sizes,c10::SymIntArrayRef strides,std::optional<c10::SymInt> storage_offset)3261 Tensor as_strided_scatter_backward(
3262 const Tensor& grad,
3263 const TensorGeometry& input_geometry,
3264 const TensorGeometry& src_geometry,
3265 c10::SymIntArrayRef sizes,
3266 c10::SymIntArrayRef strides,
3267 std::optional<c10::SymInt> storage_offset) {
3268 // Note [as_strided_scatter backward support]
3269 // as_strided_scatter handling for autograd is a beast, and is non-trivial to
3270   // implement for arbitrarily strided inputs. Most uses of as_strided with
3271   // functionalization only care about the contiguous case anyway, so for now
3272   // the general case is not implemented. When autograd is being used, we ban
3273   // non-contiguous inputs, so we can assume that the input was a contiguous
3274   // tensor. Also, we'll take the perf hit and contiguify grad for now.
3275 auto grad_ = grad.contiguous();
3276 auto grad_slice = grad_.as_strided_symint(sizes, strides, storage_offset);
3277 auto result_buffer = grad_.new_zeros_symint(input_geometry.sym_sizes());
3278 auto result = result_buffer.as_strided_symint(
3279 input_geometry.sym_sizes(), input_geometry.sym_strides());
3280 auto result_slice = result_buffer.as_strided_symint(
3281 sizes, strides, std::move(storage_offset));
3282 result_slice.copy_(grad_slice);
3283 return result;
3284 }
3285
atan2_backward(const Tensor & grad,const Tensor & self,const Tensor & other,std::array<bool,2> output_mask)3286 std::tuple<Tensor, Tensor> atan2_backward(
3287 const Tensor& grad,
3288 const Tensor& self,
3289 const Tensor& other,
3290 std::array<bool, 2> output_mask) {
3291 if (!grad.defined()) {
3292 return std::tuple<Tensor, Tensor>{Tensor(), Tensor()};
3293 }
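  // For reference:
  //   d/dself  atan2(self, other) =  other / (self^2 + other^2)
  //   d/dother atan2(self, other) = -self  / (self^2 + other^2)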
3294 auto recip = (self * self + other * other).reciprocal();
3295 return std::tuple<Tensor, Tensor>{
3296 output_mask[0] ? grad * other * recip : Tensor(),
3297 output_mask[1] ? grad * -self * recip : Tensor()};
3298 }
3299
gelu_double_backward(const Tensor & ggI,const Tensor & gO,const Tensor & input,c10::string_view approximate)3300 Tensor gelu_double_backward(
3301 const Tensor& ggI,
3302 const Tensor& gO,
3303 const Tensor& input,
3304 c10::string_view approximate) {
3305 // if (at::native::get_gelutype_enum(approximate) ==
3306 // at::native::GeluType::Tanh) {
3307 if (approximate == "tanh") {
3308 constexpr auto kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
3309 constexpr auto kKappa = 0.044715;
3310
3311 auto inner = kBeta * (input + kKappa * pow(input, 3));
3312 auto tanh_inner = tanh(inner);
3313 auto sech_inner = 1 / cosh(inner);
3314
3315 auto f = 0.5 * input;
3316 auto g = 1 - tanh_inner * tanh_inner;
3317 auto h = kBeta * (1 + 3 * kKappa * input * input);
3318
3319 auto f_prime_gh = 0.5 * g * h;
3320
3321 auto g_prime = (2 * sech_inner) * (-sech_inner * tanh_inner) * h;
3322 auto g_prime_fh = f * h * g_prime;
3323
3324 auto h_prime = 6 * kKappa * input * kBeta;
3325 auto h_prime_fg = f * g * h_prime;
3326
3327 // left_derivative = f_prime_gh
3328 // right_derivative = f_prime_gh + g_prime_fh + h_prime_fg
3329 // dgrad_dX = left_derivative + right_derivative
3330 auto gI = ggI * gO * (2 * f_prime_gh + g_prime_fh + h_prime_fg);
3331 return gI;
3332 } else {
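    // For reference (derivation added here): for the exact (erf) formulation,
    // gelu(x) = x * Phi(x), where Phi is the standard normal CDF and
    // phi(x) = kBeta * exp(-x^2 / 2) its PDF (kBeta below is 1 / sqrt(2 * pi)).
    // Then gelu'(x) = Phi(x) + x * phi(x) and, using phi'(x) = -x * phi(x),
    //   gelu''(x) = 2 * phi(x) - x^2 * phi(x),
    // which is `dgrad_dInput` below.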
3333 constexpr auto kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5;
3334 auto input_sq = input * input;
3335 auto pdf = kBeta * at::exp(-0.5 * input_sq);
3336 auto dgrad_dInput = 2 * pdf - input_sq * pdf;
3337 auto gI = ggI * gO * dgrad_dInput;
3338 return gI;
3339 }
3340 }
3341
elu_double_backward(const Tensor & grad,const Tensor & grad_output,const Scalar & alpha,const Scalar & scale,const Scalar & input_scale,bool is_result,const Tensor & self_or_result)3342 Tensor elu_double_backward(
3343 const Tensor& grad,
3344 const Tensor& grad_output,
3345 const Scalar& alpha,
3346 const Scalar& scale,
3347 const Scalar& input_scale,
3348 bool is_result,
3349 const Tensor& self_or_result) {
3350 if (is_result) {
3351 return grad * grad_output * input_scale *
3352 (self_or_result < 0).type_as(grad);
3353 } else {
3354 return at::elu_backward(
3355 grad * grad_output * input_scale,
3356 alpha,
3357 scale,
3358 input_scale,
3359 is_result,
3360 self_or_result) *
3361 (self_or_result < 0).type_as(grad);
3362 }
3363 }
3364
slice_backward_wrapper(const at::Tensor & grad,const c10::SymIntArrayRef & input_sizes,int64_t dim,std::optional<c10::SymInt> start,std::optional<c10::SymInt> end,c10::SymInt step)3365 Tensor slice_backward_wrapper(
3366 const at::Tensor& grad,
3367 const c10::SymIntArrayRef& input_sizes,
3368 int64_t dim,
3369 std::optional<c10::SymInt> start,
3370 std::optional<c10::SymInt> end,
3371 c10::SymInt step) {
3372 auto start_val = start.has_value() ? start.value() : 0;
3373 auto end_val = end.has_value() ? end.value() : INT64_MAX;
3374
3375 return slice_backward_symint(
3376 grad,
3377 input_sizes,
3378 dim,
3379 std::move(start_val),
3380 std::move(end_val),
3381 std::move(step));
3382 }
3383
linalg_svd_jvp(const Tensor & dA,const Tensor & U_,const Tensor & S,const Tensor & Vh_,const bool full_matrices)3384 std::tuple<Tensor, Tensor, Tensor> linalg_svd_jvp(
3385 const Tensor& dA,
3386 const Tensor& U_,
3387 const Tensor& S,
3388 const Tensor& Vh_,
3389 const bool full_matrices) {
3390 at::NoTF32Guard disable_tf32;
3391 // See svd_backward for the derivation
3392 // With sym(X) = X + X^H, we implement
3393 // dU = U (sym(dX S) / E + i Im(diag(dX)) / (2S))
3394 // if m > n
3395 // dU = [dU for m == n] + (I_m - UU^H) dA V S^{-1}
3396 // dS = Re(diag(dP))
3397 // dV = V (sym(S dX) / E - i Im(diag(dX)) / (2S))
3398 // if m < n
3399 // dV = [dV for m == n] + (I_n - VV^H) (dA)^H U S^{-1}
3400 // dVh = dV^H
3401 // with dP = U^H dA V
3402 // dX = dP - dS
3403 // E_{jk} = S_k^2 - S_j^2 if j != k
3404 // 1 otherwise
3405
3406 // Checks compute_uv=true
3407 TORCH_INTERNAL_ASSERT(U_.dim() >= 2 && Vh_.dim() >= 2);
3408
3409 const auto is_complex = dA.is_complex();
3410 const auto m = dA.size(-2);
3411 const auto n = dA.size(-1);
3412 const auto k = S.size(-1);
3413
3414 const auto U = full_matrices ? U_.narrow(-1, 0, k) : U_;
3415 const auto Vh = full_matrices ? Vh_.narrow(-2, 0, k) : Vh_;
3416 const auto V = Vh.mH();
3417
3418 // dP = U^H dA V
3419 auto dP = m >= n ? at::matmul(U.mH(), at::matmul(dA, V))
3420 : at::matmul(at::matmul(U.mH(), dA), V);
3421
3422 auto dS =
3423 is_complex ? at::real(dP.diagonal(0, -2, -1)) : dP.diagonal(0, -2, -1);
3424
3425 // dX = dP - dS
3426 dP = dP - dS.diag_embed();
3427
3428 auto E = [&S] {
3429 const auto S2 = S * S;
3430 auto ret = S2.unsqueeze(-2) - S2.unsqueeze(-1);
3431     // Any number a != 0 would work, as we are just going to use it to
3432     // compute 0 / a later on
3433 ret.diagonal(0, -2, -1).fill_(1);
3434 return ret;
3435 }();
3436
3437 const auto sym = [](const Tensor& X) { return X + X.mH(); };
3438
3439 // diag(dP) / (2S)
3440 auto diagdP2S = is_complex ? dP.diagonal(0, -2, -1).div(2. * S) : Tensor{};
3441
3442   // dU = U (sym(dP S) / E + i Im(diag(dP)) / (2S))
3443 auto dU = [&] {
3444 auto dUaux = sym(dP * S.unsqueeze(-2)) / E;
3445 if (is_complex) {
3446 dUaux = dUaux + diagdP2S.diag_embed();
3447 }
3448 return at::matmul(U, dUaux);
3449 }();
3450 if (m > n) {
3451 // dU += (I_m - UU^H) dA V S^{-1}
3452 const auto dAVSinv = at::matmul(dA, V / S.unsqueeze(-2));
3453 dU = dU + dAVSinv - at::matmul(U, at::matmul(U.mH(), dAVSinv));
3454
3455 // To "fix" the full_matrices case (the full_matrices case should not be
3456 // differentiable...)
3457 if (full_matrices) {
3458 auto shape = dU.sizes().vec();
3459 shape.end()[-1] = m - n;
3460 dU = at::cat({dU, dU.new_zeros(shape)}, /*dim=*/-1);
3461 }
3462 }
3463
3464 // dVh = -sym(S dP) / E + i Im(diag(dP)) / (2S)
3465 // Perf: We negate the S as it's the smallest tensor in the equation
3466 auto dVh = [&] {
3467 auto dVhaux = sym(dP * (-S).unsqueeze(-1)) / E;
3468 if (is_complex) {
3469 dVhaux = dVhaux + diagdP2S.diag_embed();
3470 }
3471 return at::matmul(dVhaux, Vh);
3472 }();
3473 if (m < n) {
3474 // dVh += S^{-1} U^H dA (I_n - VV^H)
3475 const auto UHdASinv = at::matmul(U.mH() / S.unsqueeze(-1), dA);
3476 dVh = dVh + UHdASinv - at::matmul(at::matmul(UHdASinv, V), Vh);
3477
3478 // To "fix" the full_matrices case (the full_matrices case should not be
3479 // differentiable...)
3480 if (full_matrices) {
3481 auto shape = dVh.sizes().vec();
3482 shape.end()[-2] = n - m;
3483 dVh = at::cat({dVh, dVh.new_zeros(shape)}, /*dim=*/-2);
3484 }
3485 }
3486
3487 return std::make_tuple(std::move(dU), std::move(dS), std::move(dVh));
3488 }
3489
svd_backward(const Tensor & gU,const Tensor & gS,const Tensor & gVh,const Tensor & U,const Tensor & S,const Tensor & Vh)3490 Tensor svd_backward(
3491 const Tensor& gU,
3492 const Tensor& gS,
3493 const Tensor& gVh,
3494 const Tensor& U,
3495 const Tensor& S,
3496 const Tensor& Vh) {
3497 at::NoTF32Guard disable_tf32;
3498 // Throughout both the real and complex case we assume A has distinct singular
3499 // values. Furthermore, if A is rectangular or complex, we assume it's
3500 // full-rank.
3501 //
3502 //
3503 // The real case (A \in R)
3504 // See e.g. https://j-towns.github.io/papers/svd-derivative.pdf
3505 //
3506 // Denote by skew(X) = X - X^T, and by A o B the coordinatewise product, then
3507 // if m == n
3508 // gA = U [(skew(U^T gU) / E)S + S(skew(V^T gV) / E) + I o gS ]V^T
3509 // where E_{jk} = S_k^2 - S_j^2 if j != k and 1 otherwise
3510 //
3511 // if m > n
3512 // gA = [term in m == n] + (I_m - UU^T)gU S^{-1} V^T
3513 // if m < n
3514 // gA = [term in m == n] + U S^{-1} (gV)^T (I_n - VV^T)
3515 //
3516 //
3517 // The complex case (A \in C)
3518 // This one is trickier because the svd is not locally unique.
3519 // Denote L = diag(e^{i\theta_k}), then we have that if A = USV^H, then (UL,
3520 // S, VL) is another valid SVD decomposition of A as A = ULS(VL)^H =
3521 // ULSL^{-1}V^H = USV^H, since L, S and L^{-1} commute, since they are all
3522 // diagonal.
3523 //
3524 // Assume wlog that n >= k in what follows, as otherwise we could reason about
3525 // A^H. Denote by St_k(C^n) = {A \in C^{n,k} | A^H A = I_k} the complex
3526 // Stiefel manifold. What this invariance means is that the svd decomposition
3527   // is not a map svd: C^{n x k} -> St_k(C^n) x R^k x St_k(C^k) (where St_k(C^k)
3528   // is simply the unitary group U(k)) but a map svd: C^{n x k} -> M x R^k where
3529   // M is the manifold given by quotienting St_k(C^n) x U(k) by the action (U,
3530 // V) -> (UL, VL) with L as above. Note that M is a manifold, because the
3531 // action is free and proper (as U(1)^k \iso (S^1)^k is compact). For this
3532   // reason, pi : St_k(C^n) x U(k) -> M forms a principal bundle.
3533 //
3534   // To think about M, consider the case k = 1. Then, we have the bundle
3535 // pi : St_1(C^n) x U(1) -> M
3536 // now, St_1(C^n) are just vectors of norm 1 in C^n. That's exactly the sphere
3537   // of dimension 2n-1 in C^n \iso R^{2n}, i.e. S^{2n-1} = { z \in C^n | z^H z = 1}.
3538 // Then, in this case, we're quotienting out U(1) completely, so we get that
3539 // pi : S^{2n-1} x U(1) -> CP(n-1)
3540 // where CP(n-1) is the complex projective space of dimension n-1.
3541 // In other words, M is just the complex projective space, and pi is (pretty
3542 // similar to) the usual principal bundle from S^{2n-1} to CP(n-1). The case k
3543 // > 1 is the same, but requiring a linear independence condition between the
3544 // vectors from the different S^{2n-1} or CP(n-1).
3545 //
3546 // Note that this is a U(1)^k-bundle. In plain words, this means that the
3547 // fibres of this bundle, i.e. pi^{-1}(x) for x \in M are isomorphic to U(1) x
3548 // ... x U(1). This is obvious as, if pi(U,V) = x, pi^{-1}(x) = {(U
3549 // diag(e^{i\theta}), V diag(e^{i\theta})) | \theta \in R^k}
3550 // = {(U diag(z), V diag(z)) | z \in U(1)^k}
3551 // since U(1) = {z \in C | |z| = 1}.
3552 //
3553 // The big issue here is that M with its induced metric is not locally
3554 // isometric to St_k(C^n) x U(k). [The why is rather technical, but you can
3555   // see that the horizontal distribution is not involutive, and hence not
3556   // integrable, by Frobenius' theorem.] What this means in plain words is
3557 // that, no matter how we choose to return the U and V from the SVD, we won't
3558 // be able to simply differentiate wrt. U and V and call it a day. An example
3559 // of a case where we can do this is when performing an eigendecomposition on
3560 // a real matrix that happens to have real eigendecomposition. In this case,
3561 // even though you can rescale the eigenvectors by any real number, you can
3562 // choose them of norm 1 and call it a day. In the eigenvector case, we are
3563 // using that you can isometrically embed S^{n-1} into R^n. In the svd case,
3564 // we need to work with the "quotient manifold" M explicitly, which is
3565 // slightly more technically challenging.
3566 //
3567 // Since the columns of U and V are not uniquely defined, but are
3568   // representatives of certain equivalence classes which represent elements of
3569 // M, the user may not depend on the particular representative that we return
3570 // from the SVD. In particular, if the loss function depends on U or V, it
3571 // must be invariant under the transformation (U, V) -> (UL, VL) with L =
3572 // diag(e^{i\theta})), for every \theta \in R^k. In more geometrical terms,
3573 // this means that the loss function should be constant on the fibres, or, in
3574 // other words, the gradient along the fibres should be zero. We may see this
3575 // by checking that the gradients as element in the tangent space T_{(U,
3576 // V)}(St(n,k) x U(k)) are normal to the fibres. Differentiating the map (U,
3577 // V) -> (UL, VL), we see that the space tangent to the fibres is given by
3578 // Vert_{(U, V)}(St(n,k) x U(k)) = { i[U, V]diag(\theta) | \theta in R^k}
3579 // where [U, V] denotes the vertical concatenation of U and V to form an (n+k,
3580 // k) matrix. Then, solving <i[U,V]diag(\theta), [S, T]> = 0 for two matrices
3581 // S, T \in T_{(U, V)}(St(n,k) x U(k)) where <A, B> = Re tr(A^H B) is the
3582 // canonical (real) inner product in C^{n x k} we get that the function is
3583 // invariant under action of U(1)^k iff Im(diag(U^H gU + V^H gV)) = 0
3584 //
3585   // Using this in the derivation of the forward AD, and writing
3586   // sym(X) = X + X^H, we get that the forward AD for SVD in the complex case is
3587   //     dU = U (sym(dX S) / E + i Im(diag(dX)) / (2S))
3588   // if m > n
3589   //     dU = [dU for m == n] + (I_m - UU^H) dA V S^{-1}
3590   // dS = Re(diag(dP))
3591   // dV = V (sym(S dX) / E - i Im(diag(dX)) / (2S))
3592   // if m < n
3593   //     dV = [dV for m == n] + (I_n - VV^H) (dA)^H U S^{-1}
3594   //     dVh = dV^H
3595 // with dP = U^H dA V
3596 // dX = dP - dS
3597 // E_{jk} = S_k^2 - S_j^2 if j != k
3598 // 1 otherwise
3599 //
3600 // Similarly, writing skew(X) = X - X^H
3601 // the adjoint wrt. the canonical metric is given by
3602 // if m == n
3603   //   gA = U [(skew(U^H gU) / E) S + i Im(diag(U^H gU)) / S
3604   //          + S (skew(V^H gV) / E) + I o gS] V^H
3605 // if m > n
3606 // gA = [term in m == n] + (I_m - UU^H)gU S^{-1} V^H
3607 // if m < n
3608 // gA = [term in m == n] + U S^{-1} (gV)^H (I_n - VV^H)
3609   // where we have used that Im(diag(U^H gU)) = - Im(diag(V^H gV)) to group the
3610 // diagonal imaginary terms into one that just depends on U^H gU.
3611
3612 // Checks compute_uv=true
3613 TORCH_INTERNAL_ASSERT(U.dim() >= 2 && Vh.dim() >= 2);
3614
3615 // Trivial case
3616 if (!gS.defined() && !gU.defined() && !gVh.defined()) {
3617 return {};
3618 }
3619
3620 const auto m = U.sym_size(-2);
3621 const auto n = Vh.sym_size(-1);
3622
3623 // Optimisation for svdvals: gA = U @ diag(gS) @ Vh
3624 if (!gU.defined() && !gVh.defined()) {
3625 return m >= n ? at::matmul(U, gS.unsqueeze(-1) * Vh)
3626 : at::matmul(U * gS.unsqueeze(-2), Vh);
3627 }
3628 // At this point, at least one of gU, gVh is defined
3629
3630 const bool is_complex = U.is_complex();
3631 const auto skew = [](const Tensor& A) { return A - A.mH(); };
3632 const auto UhgU = gU.defined() ? skew(at::matmul(U.mH(), gU)) : Tensor{};
3633 const auto VhgV = gVh.defined() ? skew(at::matmul(Vh, gVh.mH())) : Tensor{};
3634
3635 // Check for the invariance of the loss function, i.e.
3636 // Im(diag(U^H gU)) + Im(diag(V^H gV)) = 0
3637 if (is_complex) {
3638 const auto imdiag_UhgU =
3639 gU.defined() ? at::imag(UhgU.diagonal(0, -2, -1)) : at::zeros_like(S);
3640 const auto imdiag_VhgV =
3641 gVh.defined() ? at::imag(VhgV.diagonal(0, -2, -1)) : at::zeros_like(S);
3642 // Rather lax atol and rtol, as we don't want false positives
3643 TORCH_CHECK(
3644 at::allclose(imdiag_UhgU, -imdiag_VhgV, /*rtol=*/1e-2, /*atol=*/1e-2),
3645 "svd_backward: The singular vectors in the complex case are specified up to multiplication "
3646 "by e^{i phi}. The specified loss function depends on this phase term, making "
3647 "it ill-defined.");
3648 }
3649
3650   // gA = ((U^H gU) / E) S + S ((V^H gV) / E)
3651   //        + I o (gS + diag(U^H gU) / (2 * S))
3652 Tensor gA = [&] {
3653 // ret holds everything but the diagonal of gA
3654 auto ret = [&] {
3655 const auto E = [&S] {
3656 const auto S2 = S * S;
3657 auto ret = S2.unsqueeze(-2) - S2.unsqueeze(-1);
3658         // Any number a != 0 would work, as we are just going to use it to
3659         // compute 0 / a later on
3660 ret.diagonal(0, -2, -1).fill_(1);
3661 return ret;
3662 }();
3663
3664 if (gU.defined()) {
3665 if (gVh.defined()) {
3666 return (UhgU * S.unsqueeze(-2) + S.unsqueeze(-1) * VhgV) / E;
3667 } else {
3668 return (UhgU / E) * S.unsqueeze(-2);
3669 }
3670 } else { // gVh.defined();
3671 return S.unsqueeze(-1) * (VhgV / E);
3672 }
3673 }();
3674 // Fill the diagonal
3675 if (gS.defined()) {
3676 ret = ret + gS.diag_embed();
3677 }
3678 if (is_complex && gU.defined() && gVh.defined()) {
3679 ret = ret + (UhgU.diagonal(0, -2, -1) / (2. * S)).diag_embed();
3680 }
3681 return ret;
3682 }();
3683
3684 if (m > n && gU.defined()) {
3685 // gA = [UgA + (I_m - UU^H)gU S^{-1}]V^H
3686 gA = at::matmul(U, gA);
3687 const auto gUSinv = gU / S.unsqueeze(-2);
3688 gA = gA + gUSinv - at::matmul(U, at::matmul(U.mH(), gUSinv));
3689 gA = at::matmul(gA, Vh);
3690 } else if (m < n && gVh.defined()) {
3691 // gA = U[gA V^H + S^{-1} (gV)^H (I_n - VV^H)]
3692 gA = at::matmul(gA, Vh);
3693 const auto SinvgVh = gVh / S.unsqueeze(-1);
3694 gA = gA + SinvgVh - at::matmul(at::matmul(SinvgVh, Vh.mH()), Vh);
3695 gA = at::matmul(U, gA);
3696 } else {
3697 // gA = U gA V^H
3698 gA = m >= n ? at::matmul(U, at::matmul(gA, Vh))
3699 : at::matmul(at::matmul(U, gA), Vh);
3700 }
3701
3702 return gA;
3703 }
3704
linalg_eig_backward(const Tensor & gL,const Tensor & gV,const Tensor & L,const Tensor & V,const bool is_hermitian,const bool symeig_eigenvectors)3705 Tensor linalg_eig_backward(
3706 const Tensor& gL,
3707 const Tensor& gV,
3708 const Tensor& L,
3709 const Tensor& V,
3710 const bool is_hermitian,
3711 const bool symeig_eigenvectors) {
3712 at::NoTF32Guard disable_tf32;
3713 // https://arxiv.org/pdf/1701.00392.pdf Eq 4.77
3714 // For A = VLV^{-1}, denoting the gradients gA, gV and gL, we have
3715 // gA = V^{-H}(diag_embed(gL) + (V^H gV -V^HV diag(real(V^H gV))) / E*)V^H
3716 // Where:
3717 // - E_{ij} = L_j - L_i if i != j
3718 // 1 otherwise
3719 // - diag_embed takes a vector into a diagonal matrix
3720 // - diag zeroes out elements outside of the diagonal
3721
3722 // Note: the term '-V^HV diag(real(V^H gV))' comes from the fact that the
3723 // eigenvalue decomposition is returned with eigenvectors normalized to have
3724 // norm one.
3725
3726 // Note: The Hermitian case is a simplification of this formula using that
3727 // V^{-1} = V^H and that L is real
3728
3729   // This check can only be triggered in the backward of torch.symeig
3730 TORCH_CHECK(
3731 symeig_eigenvectors,
3732 "linalg_eig_backward: torch.symeig(A, eigenvectors=False) is not differentiable. ",
3733 "Use torch.linalg.eigvalsh(A) instead.");
3734
3735 // Trivial case
3736 if (!gL.defined() && !gV.defined()) {
3737 return {};
3738 }
3739
3740 // Shortcut for linalg.eigvals/eigvalsh
3741 // Compute V^-H gL V^H
3742 if (!gV.defined()) {
3743 if (is_hermitian) {
3744 return at::matmul(V * gL.unsqueeze(-2), V.mH());
3745 } else {
3746 return at::linalg_solve(V.mH(), gL.unsqueeze(-1) * V.mH());
3747 }
3748 }
3749 auto VhgV = at::matmul(V.mH(), gV);
3750 const auto diag_VhgV = VhgV.diagonal(0, -2, -1);
3751
3752 if (V.is_complex() && !at::isTensorSubclassLike(diag_VhgV)) {
3753 // Check invariance of the loss function wrt the transformation
3754 // V -> V * e^{i\phi} for an arbitrary phi in RR^n
3755 const auto imdiag_VhgV = at::imag(diag_VhgV);
3756 TORCH_CHECK(
3757 at::allclose(
3758 imdiag_VhgV,
3759 at::zeros_like(imdiag_VhgV),
3760 /*rtol=*/1e-2,
3761 /*atol=*/1e-2),
3762 is_hermitian ? "linalg_eigh_backward" : "linalg_eig_backward",
3763 ": The eigenvectors in the complex case are specified up to multiplication ",
3764 "by e^{i phi}. The specified loss function depends on this quantity, so it is ill-defined.");
3765 }
3766
3767 if (is_hermitian) {
3768 // Project onto the tangent space at the identity of U(n), that is, the
3769 // skew-Hermitian matrices
3770 VhgV = 0.5 * (VhgV - VhgV.mH());
3771 } else {
3772 // Project onto the tangent space at V^H V of complex matrices with columns
3773 // of norm 1
3774 VhgV = VhgV - at::matmul(V.mH(), V * at::real(diag_VhgV).unsqueeze(-2));
3775 }
3776
3777 auto gA = [&, VhgV = std::move(VhgV)] {
3778 auto Econj = [&L] {
3779 auto Lconj = L.conj();
3780 auto ret = Lconj.unsqueeze(-2) - Lconj.unsqueeze(-1);
3781 ret.diagonal(0, -2, -1).fill_(1.);
3782 return ret;
3783 }();
3784
3785 auto ret = VhgV.div_(Econj);
3786
3787 if (gL.defined()) {
3788 // For CompositeCompliance, if `gL` is subclass but `ret`
3789 // is a regular Tensor, then use out-of-place version of diagonal
3790 // copy aka `diagonal_scatter`.
3791 if (at::isTensorSubclassLike(gL)) {
3792 ret = ret.diagonal_scatter(gL, 0, -2, -1);
3793 } else {
3794 ret.diagonal(0, -2, -1).copy_(gL);
3795 }
3796 }
3797 return ret;
3798 }();
3799
3800 // Conjugate by V^{-H}
3801 if (is_hermitian) {
3802 return at::matmul(V, at::matmul(gA, V.mH()));
3803 } else {
3804 return at::linalg_solve(V.mH(), at::matmul(gA, V.mH()));
3805 }
3806 }
3807
linalg_eig_jvp(const Tensor & dA,const Tensor & L,const Tensor & V,const bool is_hermitian)3808 std::tuple<Tensor, Tensor> linalg_eig_jvp(
3809 const Tensor& dA,
3810 const Tensor& L,
3811 const Tensor& V,
3812 const bool is_hermitian) {
3813 at::NoTF32Guard disable_tf32;
3814 // https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
3815 // see also https://arxiv.org/pdf/1701.00392.pdf Eqs. (4.60) and (4.63)
3816 // Note that neither of the formulas in these pdfs are correct, as they do not
3817 // assume that the eigenvectors are of unit norm. As such, they are missing
3818   // the diagonal term in dV. We implement
3819   //   dL = diag(dP)   dV = dX - V Re(diag(V^H dX))   where dP = V^{-1} dA V,
3820   //   dX = V ((dP - diag(dP)) / E),   E_{ij} = L_j - L_i if i != j, 1 otherwise
3821
3822 // Precondition: if is_hermitian == true, then dA is Hermitian
3823 const auto to_complex = [](const Tensor& A) {
3824 return A.to(c10::toComplexType(A.scalar_type()));
3825 };
3826
3827 const auto dP = is_hermitian
3828 ? at::matmul(at::matmul(V.mH(), dA), V)
3829 : at::linalg_solve(V, at::matmul(to_complex(dA), V));
3830 auto dL = is_hermitian && dA.is_complex() ? at::real(dP.diagonal(0, -2, -1))
3831 : dP.diagonal(0, -2, -1);
3832 auto dV = [&dP, &V, &L, is_hermitian] {
3833 auto dX = [&] {
3834 auto ret = dP / (L.unsqueeze(-2) - L.unsqueeze(-1));
3835 ret.diagonal(0, -2, -1).zero_();
3836 ret = at::matmul(V, ret);
3837 return ret;
3838 }();
3839
3840 if (is_hermitian) {
3841 return dX;
3842 } else {
3843 return dX -
3844 V *
3845 at::real(at::matmul(V.mH(), dX).diagonal(0, -2, -1)).unsqueeze(-2);
3846 }
3847 }();
3848 return std::make_pair(std::move(dL), std::move(dV));
3849 }
3850
linalg_lstsq_jvp(const Tensor & A,const Tensor & B,const Tensor & dA,const Tensor & dB)3851 Tensor linalg_lstsq_jvp(
3852 const Tensor& A,
3853 const Tensor& B,
3854 const Tensor& dA,
3855 const Tensor& dB) {
3856 at::NoTF32Guard disable_tf32;
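  // For reference: the least-squares solution is X = pinv(A) B, so by the
  // product rule
  //   dX = d(pinv(A)) B + pinv(A) dB,
  // with d(pinv(A)) computed by pinv_jvp below.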
3857 auto pinvA = at::linalg_pinv(A);
3858 auto dpinvA = pinv_jvp(A, pinvA, dA);
3859 auto dX = dpinvA.matmul(B) + pinvA.matmul(dB);
3860 return dX;
3861 }
3862
linalg_lstsq_backward(const Tensor & gX_,const Tensor & A,const Tensor & B_,const std::array<bool,2> & grad_input_mask)3863 std::tuple<Tensor, Tensor> linalg_lstsq_backward(
3864 const Tensor& gX_,
3865 const Tensor& A,
3866 const Tensor& B_,
3867 const std::array<bool, 2>& grad_input_mask) {
3868 at::NoTF32Guard disable_tf32;
3869 auto A_requires_grad = grad_input_mask[0];
3870 auto B_requires_grad = grad_input_mask[1];
3871 if (!gX_.defined() || (!A_requires_grad && !B_requires_grad)) {
3872 return {};
3873 }
3874
3875 const bool vector_case = at::native::linalg_solve_is_vector_rhs(A, B_);
3876 const auto vector_to_matrix = [vector_case](const Tensor& X) {
3877 return vector_case ? X.unsqueeze(-1) : X;
3878 };
3879 const auto matrix_to_vector = [vector_case](const Tensor& X) {
3880 return vector_case ? X.squeeze(-1) : X;
3881 };
3882
3883 auto gX = vector_to_matrix(gX_);
3884 auto B = vector_to_matrix(B_);
3885 Tensor pinvA = at::linalg_pinv(A);
3886 Tensor A_grad, B_grad;
3887 if (A_requires_grad) {
3888 auto pinvA_grad = gX.matmul(B.mH());
3889 A_grad = pinv_backward(pinvA_grad, pinvA, A);
3890 }
3891
3892 if (B_requires_grad) {
3893 // Equivalent to
3894 // B_grad = std::get<0>(at::linalg_lstsq(A.mH(), gX, rcond, driver));
3895 // but we avoid this approach as `gelsy` is non-deterministic
3896 B_grad = matrix_to_vector(pinvA.mH().matmul(gX));
3897 }
3898
3899 return std::make_tuple(A_grad, B_grad);
3900 }
3901
linalg_qr_jvp(const Tensor & dA,const Tensor & Q,const Tensor & R,const c10::string_view mode)3902 std::tuple<Tensor, Tensor> linalg_qr_jvp(
3903 const Tensor& dA,
3904 const Tensor& Q,
3905 const Tensor& R,
3906 const c10::string_view mode) {
3907 // dA = dQR + QdR
3908 //
3909 // Case m >= n
3910 // We can put dQ in terms of dR
3911 // dQ = dAR^{-1} - QdRR^{-1}
3912 // Then we have
3913 // Q^H dA R^{-1} = Q^HdQ + dRR^{-1}
3914 // where Q^HdQ is skew Hermitian and dRR^{-1} is upper triangular
3915 // Define sym(X) = X + X^H
3916 // sym(dRR^{-1}) = sym(Q^H dA R^{-1})
3917 // and define syminv(X) = triu(X) - 0.5 * diag(X) the inverse of
3918 // sym : Triu(k, diag \in \mathbb{R}) -> Her(k) to give
3919 // dR = syminv(sym(Q^H dA R^{-1}))R
3920 //
3921 // Case m < n
3922 // Put dR as a function of dQ
3923 // dR = Q^H dA - Q^H dQ R
3924 // Let X_1 be the main m x m submatrix of a matrix X \in C^{m x n}
3925   // Q^H dA_1 R_1^{-1} = Q^H dQ + dR_1 R_1^{-1}
3926 // Define trilIm(X) = X.tril(-1) + i * Im diag(X)
3927   // trilIm(Q^H dQ) = trilIm(Q^H dA_1 R_1^{-1})
3928 // and define trilIminv(X) = X - X^H - i*Im diag(X). This is the inverse of
3929 // trilIm : Skew_C(m) -> Tril(m, imaginary diag)
3930 // Note that it is just the inverse when the inputs are skew-Hermitian, not
3931   // necessarily when the inputs are arbitrary matrices. We then get
3932   //   dQ = Q trilImInv(trilIm(Q^H dA_1 R_1^{-1}))
3933 at::NoTF32Guard disable_tf32;
3934
3935 auto [compute_q, reduced] = at::native::_parse_qr_mode(mode);
3936
3937 TORCH_CHECK(
3938 compute_q,
3939 "The derivative of linalg.qr depends on Q, which is not computed when "
3940 "mode='r'. Please use linalg.qr(A, mode='reduced') if you are "
3941 "going to differentiate through linalg.qr.");
3942 auto m = dA.size(-2);
3943 auto n = dA.size(-1);
3944
3945 TORCH_CHECK(
3946 reduced || m <= n,
3947 "The QR decomposition is not differentiable when "
3948 "mode='complete' and nrows > ncols.");
3949 if (m >= n) {
3950 const auto sym = [](const Tensor& X) { return X + X.mH(); };
3951 const auto syminv = [](const Tensor& X) {
3952 auto ret = X.triu();
3953 ret.diagonal(0, -2, -1).mul_(0.5);
3954 return ret;
3955 };
3956 auto dARinv =
3957 at::linalg_solve_triangular(R, dA, /*upper=*/true, /*left=*/false);
3958 auto dR = syminv(sym(Q.mH().matmul(dARinv)));
3959 auto dQ = dARinv - Q.matmul(dR);
3960 dR = dR.matmul(R);
3961 return std::make_tuple(std::move(dQ), std::move(dR));
3962 } else {
3963 const auto trilim = [](const Tensor& X) {
3964 if (X.is_complex()) {
3965 auto ret = X.tril();
3966 at::real(ret.diagonal(0, -2, -1)).zero_();
3967 return ret;
3968 } else {
3969 return X.tril(-1);
3970 }
3971 };
3972 const auto triliminv = [](const Tensor& X) {
3973 if (X.is_complex()) {
3974 auto ret = X - X.mH();
3975 ret.diagonal(0, -2, -1).mul_(0.5);
3976 return ret;
3977 } else {
3978 return X - X.mT();
3979 }
3980 };
3981
3982 auto QHdA = Q.mH().matmul(dA);
3983 auto QHdA1Rinv = at::linalg_solve_triangular(
3984 R.narrow(-1, 0, m),
3985 QHdA.narrow(-1, 0, m),
3986 /*upper=*/true,
3987 /*left=*/false);
3988 auto dQ = triliminv(trilim(QHdA1Rinv));
3989 auto dR = QHdA - dQ.matmul(R);
3990 dQ = Q.matmul(dQ);
3991 return std::make_tuple(std::move(dQ), std::move(dR));
3992 }
3993 }
3994
linalg_qr_backward(const Tensor & gQ,const Tensor & gR,const Tensor & Q,const Tensor & R,const c10::string_view mode)3995 Tensor linalg_qr_backward(
3996 const Tensor& gQ,
3997 const Tensor& gR,
3998 const Tensor& Q,
3999 const Tensor& R,
4000 const c10::string_view mode) {
4001 // Nb. We won't be too formal below, as writing this proof formally is a pain
4002 // We'll link here a formal writing of all this at some point in the future
4003 //
4004 // Case m >= n
4005 // dQ = dAR^{-1} - Qsyminv(sym(Q^H dA R^{-1}))
4006 // dR = syminv(sym(Q^H dA R^{-1}))R
4007 //
4008 // With the notation from the JVP formula, the only two computations that we
4009   // need are syminv*(R) = 0.5 * (R.triu() + R.triu()^H - Re diag(R)) and
4010   // sym*(X) = 2 * X. Using these, after a few simplifications we get that
4011   //   gA = (gQ + Q syminvadj(triu(gR R^H - Q^H gQ))) R^{-H}
4012 //
4013 // Case m < n
4014 // dR = Q^H dA - Q^H dQ R
4015   // dQ = Q trilImInv(trilIm(Q^H dA_1 R_1^{-1}))
4016 //
4017 // In this case trilIm*(X) = X (it's the trivial embedding)
4018 // while trilImInv*(X) = tril(Y) - 0.5 * diag(Y)
4019 // with Y = X - X^H
4020 //
4021   // We also have that if X \in C^{m, n} and pi(X) = X_1
4022   // projects X onto its leading m x m submatrix, then
4023 // pi*(X) = cat(X, 0_{m,n-m}, dim=-1)
4024 //
4025 // Using this, we get that
4026 // gA = QgR + pi*(Q trilImInv*(Q^H gQ - gR R^H)R_1^{-H})
4027 at::NoTF32Guard disable_tf32;
4028
4029 auto [compute_q, reduced] = at::native::_parse_qr_mode(mode);
4030
4031 TORCH_CHECK(
4032 compute_q,
4033 "The derivative of linalg.qr depends on Q, which is not computed when "
4034 "mode='r'. Please use linalg.qr(A, mode='reduced') if you are "
4035 "going to differentiate through linalg.qr.");
4036
4037 auto m = Q.sym_size(-2);
4038 auto n = R.sym_size(-1);
4039
4040 TORCH_CHECK(
4041 reduced || m <= n,
4042 "The QR decomposition is not differentiable when "
4043 "mode='complete' and nrows > ncols.");
4044
4045 if (!gQ.defined() && !gR.defined()) {
4046 return {};
4047 }
4048
4049 Tensor gA;
4050 if (gQ.defined()) {
4051 if (gR.defined()) {
4052 gA = gR.matmul(R.mH()) - Q.mH().matmul(gQ);
4053 } else {
4054 gA = -Q.mH().matmul(gQ);
4055 }
4056 } else {
4057 gA = gR.matmul(R.mH());
4058 }
4059 if (m >= n) {
4060 const auto syminvadj = [](const Tensor& X) {
4061 auto ret = X + X.mH();
4062 at::real(ret.diagonal(0, -2, -1)).mul_(0.5);
4063 return ret;
4064 };
4065 gA = Q.matmul(syminvadj(gA.triu()));
4066 if (gQ.defined()) {
4067 gA = gA + gQ;
4068 }
4069 gA = at::linalg_solve_triangular(
4070 R.mH(), gA, /*upper*/ false, /*left*/ false);
4071 return gA;
4072 } else {
4073 auto trilImInvAdjSkew = [](const Tensor& X) {
4074 auto ret = (X - X.mH()).tril();
4075 if (X.is_complex()) {
4076 at::imag(ret.diagonal(0, -2, -1)).mul_(0.5);
4077 }
4078 return ret;
4079 };
4080 gA = Q.matmul(trilImInvAdjSkew(-gA));
4081 gA = at::linalg_solve_triangular(
4082 R.narrow_symint(-1, 0, m).mH(), gA, /*upper*/ false, /*left*/ false);
4083 auto shape = R.sym_sizes().vec();
4084 shape.end()[-1] = n - m;
4085 gA = at::cat({gA, gA.new_zeros_symint(shape)}, /*dim=*/-1);
4086 if (gR.defined()) {
4087 gA = gA + Q.matmul(gR);
4088 }
4089 return gA;
4090 }
4091 }
4092
4093 // Based on:
4094 //
4095 // Mathias, Roy.
4096 // A Chain Rule for Matrix Functions and Applications.
4097 // SIAM J. Matrix Anal. Appl. 17 (1996): 610-620.
4098
4099 template <typename func_t>
differential_analytic_matrix_function(const Tensor & self,const Tensor & grad,const func_t & matrix_function,const bool adjoint)4100 Tensor differential_analytic_matrix_function(
4101 const Tensor& self,
4102 const Tensor& grad,
4103 const func_t& matrix_function,
4104 const bool adjoint // Choose between forward (adjoint=false) or backward AD
4105 // (adjoint=true)
4106 ) {
4107 // Given an analytic matrix function, this computes the differential (forward
4108 // AD) or the adjoint of the differential (backward AD)
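  // Mathias' result (see the reference above) states that for an analytic f,
  //   f([[A, E], [0, A]]) = [[f(A), Df(A)[E]], [0, f(A)]],
  // where Df(A)[E] is the Frechet derivative (differential) of f at A applied
  // to E. So we build that 2n x 2n block matrix below and read off its
  // top-right n x n block after applying matrix_function.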
4109 auto A = adjoint ? self.transpose(-2, -1).conj() : self;
4110 auto meta_grad_sizes = A.sym_sizes().vec();
4111 meta_grad_sizes[A.dim() - 2] *= 2;
4112 meta_grad_sizes[A.dim() - 1] *= 2;
4113
4114 auto n = A.sym_size(-1);
4115 Tensor meta_grad;
4116 // For Composite Compliance, we can't copy a Subclass into a Regular Tensor,
4117 // so we use out-of-place ops with equivalent output.
4118 // NOTE: We can't use `new_zeros` directly as both `A` and `grad` can
4119 // be Tensor Subclass and we don't want to make assumption about which
4120 // one to choose for creating output buffer.
4121 // eg. if both are BatchedTensor at different level.
4122 if (areAnyTensorSubclassLike({A, grad})) {
4123 meta_grad = at::cat(
4124 {at::cat({A, grad}, -1),
4125 at::cat({at::zeros_like(A), std::move(A)}, -1)},
4126 -2);
4127 } else {
4128 meta_grad = at::zeros_symint(meta_grad_sizes, grad.options());
4129 meta_grad.narrow_symint(-2, 0, n).narrow_symint(-1, 0, n).copy_(A);
4130 meta_grad.narrow_symint(-2, n, n).narrow_symint(-1, n, n).copy_(A);
4131 meta_grad.narrow_symint(-2, 0, n).narrow_symint(-1, n, n).copy_(grad);
4132 }
4133
4134 return matrix_function(meta_grad).narrow_symint(-2, 0, n).narrow_symint(
4135 -1, n, n);
4136 }
4137
linalg_matrix_exp_differential(const Tensor & self,const Tensor & grad,bool adjoint)4138 Tensor linalg_matrix_exp_differential(
4139 const Tensor& self,
4140 const Tensor& grad,
4141 bool adjoint) {
4142 at::NoTF32Guard disable_tf32;
4143
4144 return differential_analytic_matrix_function(
4145 self, grad, at::linalg_matrix_exp, /* adjoint */ adjoint);
4146 }
4147
4148 template <typename F1, typename F2, typename... Ts>
masked_fmap(const Tensor & mask,const F1 & f1,const F2 & f2,const Tensor & t,const Ts &...ts)4149 Tensor masked_fmap(
4150 const Tensor& mask,
4151 const F1& f1,
4152 const F2& f2,
4153 const Tensor& t,
4154 const Ts&... ts) {
4155 // This function takes two functions f1 and f2 and a (variadic) list of
4156 // tensors, and creates a new tensor of the same shape as the first element of
4157 // the list of tensors by applying the function f1 to the tensors for which
4158   // the mask is true and f2 to the tensors for which the mask is false. This
4159 // function is used when we have a formula that works for, say, all
4160 // non-singular inputs and another one for when the inputs are singular. See
4161 // for example det_backward
4162
4163 // Precondition for the n == 0 case to make sense
4164 TORCH_INTERNAL_ASSERT(t.sym_numel() != 0);
4165 auto t_masked = t.index({mask});
4166 auto n = t_masked.sym_numel();
4167 if (n == t.sym_numel()) {
4168 return f1(t, ts...);
4169 } else if (n == 0) {
4170 return f2(t, ts...);
4171 } else {
4172 // Equivalent to
4173 // ret = torch.empty_like(t)
4174 // ret[mask] = f1(t1[mask], ..., tn[mask])
4175 // ret[~mask] = f2(t1[~mask], ..., tn[~mask])
4176 auto not_mask = mask.logical_not();
4177 return at::empty_like(t)
4178 .index_put_({mask}, f1(t_masked, ts.index({mask})...))
4179 .index_put_(
4180 {not_mask}, f2(t.index({not_mask}), ts.index({not_mask})...));
4181 }
4182 }
4183
linalg_det_jvp(const Tensor & dA,const Tensor & det,const Tensor & LU,const Tensor & pivots,const bool use_A_T)4184 Tensor linalg_det_jvp(
4185 const Tensor& dA,
4186 const Tensor& det,
4187 const Tensor& LU,
4188 const Tensor& pivots,
4189 const bool use_A_T) {
4190 // (d det)_A(E) = tr(A^{-1}E)*det
4191 // We use that the determinant is C^1 to approximate the gradient of singular
4192   // inputs. Since we never differentiate over forward AD, we don't need to deal
4193   // with further gradients, as we do in linalg_det_backward
4194 auto eps = at::native::_get_epsilon(c10::toRealValueType(LU.scalar_type()));
4195 auto LU_ =
4196 LU + at::diag_embed(at::where(LU.diagonal(0, -2, -1) == 0., eps, 0.));
4197 auto AinvE =
4198 at::linalg_lu_solve(LU_, pivots, dA, /*left=*/true, /*adjoint=*/use_A_T);
4199 return AinvE.diagonal(0, -2, -1).sum(-1) * det;
4200 }
4201
linalg_det_backward(const Tensor & grad,const Tensor & det,const Tensor & A,const Tensor & LU,const Tensor & pivots)4202 Tensor linalg_det_backward(
4203 const Tensor& grad,
4204 const Tensor& det,
4205 const Tensor& A,
4206 const Tensor& LU,
4207 const Tensor& pivots) {
4208 at::NoTF32Guard disable_tf32;
4209 // A.numel() == 0 necessary for the singular case
4210 if (!grad.defined() || A.sym_numel() == 0) {
4211 return {};
4212 }
4213
4214 // The gradient G is the matrix solving
4215 // A.mH G = det(A).conj() * grad * I
4216 auto d_diag = grad * det.conj();
4217   // Optimisation: make it F-transposed, as that's what lu_solve expects
4218 auto d = at::diag_embed(d_diag.unsqueeze(-1).expand_as(pivots)).mT();
4219 auto eps = at::native::_get_epsilon(c10::toRealValueType(LU.scalar_type()));
4220
4221 // Optimisation if we are not going to compute higher-order gradients
4222 if (!at::GradMode::is_enabled()) {
4223     // The formula is given by the solution of A^H X = det(A).conj() * grad * I
4224     // when A is invertible. det is C^1, so if A is not invertible, we can apply
4225     // a perturbation to the LU decomposition and use the resulting matrix as a
4226     // non-singular approximation
4227 auto LU_ =
4228 LU + at::diag_embed(at::where(LU.diagonal(0, -2, -1) == 0., eps, 0.));
4229 auto use_A_T = A.is_contiguous() && !A.is_complex();
4230 return at::linalg_lu_solve(
4231 LU_, pivots, d, /*left=*/true, /*adjoint=*/!use_A_T);
4232 } else {
4233 // If we want to compute higher-order gradients, we need to recompute the
4234 // LU decomposition so that autograd computes the correct gradients wrt
4235 // to A (cf. solve_backward)
4236 auto non_singular =
4237 [](const Tensor& A, const Tensor& d, const Tensor& /*grad*/) {
4238 return at::linalg_solve(A.mH(), d);
4239 };
4240
4241 // The derivative may be then computed explicitly by noting that the
4242 // gradient of the derivative of the determinant is given in terms of the
4243 // adjugate of a matrix. The adjugate of a singular matrix may be computed
4244 // as per https://nhigham.com/2020/06/16/what-is-the-adjugate-of-a-matrix/
4245 auto singular = [](const Tensor& A,
4246 const Tensor& /*d*/,
4247 const Tensor& grad) {
4248 auto [U, S, Vh] = at::linalg_svd(A);
4249 auto alpha = (at::linalg_det(U) * at::linalg_det(Vh)).conj() * grad;
4250 auto D = prod_safe_zeros_backward(alpha.unsqueeze(-1), S, S.dim() - 1);
4251 return (U * D.unsqueeze(-2)).matmul(Vh);
4252 };
4253
4254 // We could use the singular formula for all inputs but we try to filter out
4255 // some inputs via the masking, as computing an SVD is about 100 times
4256 // slower than computing an lu_solve on GPU
4257 // For tensor subclasses, we can't call masked_fmap as it calls
4258 // index({mask}) which needs to call item to compute the number of elements
4259 // in the result.
4260
4261 if (areAnyTensorSubclassLike({A, d, grad})) {
4262 return singular(A, d, grad);
4263 } else {
4264 return masked_fmap(
4265 det.abs() < 100. * eps, singular, non_singular, A, d, grad);
4266 }
4267 }
4268 }
4269
slogdet_jvp(const Tensor & LU,const Tensor & pivots,const Tensor & dA,const Tensor & sign,const bool use_A_T)4270 std::tuple<Tensor, Tensor> slogdet_jvp(
4271 const Tensor& LU,
4272 const Tensor& pivots,
4273 const Tensor& dA,
4274 const Tensor& sign,
4275 const bool use_A_T) {
4276 // No need to handle the singular case separately as we do in det since
4277 // this function is not differentiable on singular matrices
4278 auto trAinvE = at::linalg_lu_solve(LU, pivots, dA, /*left*/ true, use_A_T)
4279 .diagonal(0, -2, -1)
4280 .sum(-1);
4281 if (LU.is_complex()) {
4282 auto i = c10::complex<double>{0.0, 1.0};
4283 return std::make_tuple(at::imag(trAinvE) * (i * sign), at::real(trAinvE));
4284 } else {
4285 return std::make_tuple(
4286 at::_efficientzerotensor(sign.sizes(), sign.options()), trAinvE);
4287 }
4288 }
4289
slogdet_backward(const Tensor & grad_sign,const Tensor & grad_logabsdet,const Tensor & A,const Tensor & signdet,const Tensor & LU,const Tensor & pivots)4290 Tensor slogdet_backward(
4291 const Tensor& grad_sign,
4292 const Tensor& grad_logabsdet,
4293 const Tensor& A,
4294 const Tensor& signdet,
4295 const Tensor& LU,
4296 const Tensor& pivots) {
4297 // We compute the complex case, as the real case follows from it
4298 // Forward AD
4299 // d (logabsdet)_A(E) = Re(tr(A^{-1}E))
4300 // d (signdet)_A(E) = sgn * Im(tr(A^{-1}E)) * i
4301 // So
4302 // d (logabsdet)*_A(g) = gA^{-H}
4303 // Now, to compute the adjoint of d(signdet), note that
4304 // Re(z * Im(w)) = Re(-Re(z)iw)
4305 // So, let g \in C,
4306 // <g, d(signdet)_A(E)> = Re(g.conj() * sgn * i * Im(A^{-1}E))
4307 // = Re(Re(g.conj() * sgn * i) * -i * A^{-1}E)
4308 // = Re(Im(g.conj() * sgn) * i * A^{-1}E)
4309 // = <Im(g.conj() * sgn) * -i * A^{-H}, E>
4310 // As such,
4311   //   (d slogdet)*_A(g_sign, g_abs) = (g_abs - i Im(g_sign.conj() * sgn)) * A^{-H}
4312
4313 if (!grad_sign.defined() && !grad_logabsdet.defined()) {
4314 return {};
4315 }
4316
4317 auto is_complex = A.is_complex();
4318
4319 // In the real case grad_sign is always zero
4320 if (!is_complex && !grad_logabsdet.defined()) {
4321 return {};
4322 }
4323
4324 auto g = grad_logabsdet;
4325 if (is_complex) {
4326 if (grad_sign.defined()) {
4327 auto i = c10::complex<double>{0.0, 1.0};
4328 if (g.defined()) {
4329 g = g - i * at::imag(grad_sign.conj() * signdet);
4330 } else {
4331 g = -i * at::imag(grad_sign.conj() * signdet);
4332 }
4333 } else {
4334 // Cast to complex explicitly
4335 g = g.to(A.scalar_type());
4336 }
4337 }
4338
4339 // No need to handle the singular case separately here (as we do in det)
4340 // since this function is not differentiable on singular matrices
4341   // Optimisation: make it F-transposed, as that's what lu_solve expects
4342 auto d = at::diag_embed(g.unsqueeze(-1).expand_as(pivots)).mT();
4343 if (!at::GradMode::is_enabled()) {
4344 auto use_A_T = A.is_contiguous() && !A.is_complex();
4345 return at::linalg_lu_solve(
4346 LU, pivots, d, /*left=*/true, /*adjoint=*/!use_A_T);
4347 } else {
4348 // If we want to compute further gradients, we need to recompute the LU
4349 // decomposition so that autograd computes the correct gradients w.r.t. A
4350 // (cf. solve_backward)
4351 return at::linalg_solve(A.mH(), d);
4352 }
4353 }
4354
4355 // Reference:
4356 // https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
4357 // Sec. 2.3.1 Matrix inverse product
4358 std::tuple<Tensor, Tensor> triangular_solve_backward(
4359 const Tensor& grad_x,
4360 const Tensor& grad_m,
4361 const Tensor& b,
4362 const Tensor& a,
4363 const Tensor& x,
4364 const bool upper,
4365 const bool transpose,
4366 const bool unitriangular,
4367 std::array<bool, 2> output_mask) {
4368 at::NoTF32Guard disable_tf32;
4369 Tensor grad_b, grad_a;
4370 if (grad_x.defined() || grad_m.defined()) {
4371 if (grad_x.defined()) {
4372 grad_b = std::get<0>(
4373 grad_x.triangular_solve(a.conj(), upper, !transpose, unitriangular));
4374 if (output_mask[1]) {
4375 grad_a =
4376 transpose ? -x.conj().matmul(grad_b.mT()) : -grad_b.matmul(x.mH());
4377 if (upper) {
4378 grad_a = grad_a.triu((int)unitriangular);
4379 } else {
4380 grad_a = grad_a.tril(-((int)unitriangular));
4381 }
4382 }
4383 }
4384 if (!grad_a.defined()) {
4385 grad_a = at::zeros({1}, a.options()).expand_as(a);
4386 }
4387 if (!grad_b.defined()) {
4388 grad_b = at::zeros({1}, b.options()).expand_as(b);
4389 }
4390 if (output_mask[1] && grad_m.defined()) {
4391 grad_a = grad_a.add(grad_m);
4392 }
4393 }
4394 return std::tuple<Tensor, Tensor>{grad_b, grad_a};
4395 }
4396
4397 Tensor triangular_solve_jvp(
4398 const Tensor& X,
4399 const Tensor& A,
4400 const Tensor& dA,
4401 const Tensor& dB,
4402 const bool upper,
4403 const bool transpose,
4404 const bool unitriangular) {
4405 return generic_solve_jvp(
4406 [&](const Tensor& A, const Tensor& dB, const Tensor& dA_contrib) {
4407 return std::get<0>(at::triangular_solve(
4408 dB - dA_contrib, A, upper, transpose, unitriangular));
4409 },
4410 X,
4411 A,
4412 dA,
4413 dB);
4414 }
4415
4416 Tensor linalg_solve_triangular_forward_AD(
4417 const Tensor& A_t,
4418 const Tensor& B_t,
4419 const Tensor& A,
4420 const Tensor& X,
4421 const bool upper,
4422 const bool left,
4423 const bool unitriangular) {
4424 at::NoTF32Guard disable_tf32;
4425 // The forward AD formula (for left = true) is A^{-1}(B_t - A_tX)
4426 // For the derivation see:
4427 // [Note: Forward / Backward AD solve_triangular]
4428 const Tensor proj_A_t = upper ? A_t.triu(static_cast<int>(unitriangular))
4429 : A_t.tril(-static_cast<int>(unitriangular));
4430 const Tensor X_t =
4431 B_t - (left ? at::matmul(proj_A_t, X) : at::matmul(X, proj_A_t));
4432 return at::linalg_solve_triangular(A, X_t, upper, left, unitriangular);
4433 }
4434
4435 std::tuple<Tensor, Tensor> linalg_solve_triangular_backward(
4436 const Tensor& grad,
4437 const Tensor& A,
4438 const Tensor& X,
4439 const bool upper,
4440 const bool left,
4441 const bool unitriangular,
4442 std::array<bool, 2> output_mask) {
4443 at::NoTF32Guard disable_tf32;
4444 const bool A_requires_grad = output_mask[0];
4445 const bool B_requires_grad = output_mask[1];
4446 // [Note: Forward / Backward AD solve_triangular]
4447 // Assume left=true for simplicity.
4448 // Remark: A solver computes A^{-1}B
4449 //
4450 // Forward AD:
4451 // If f(A) = A^{-1}, differentiating the equation A^{-1}A = I_n gives
4452 // (df)_A(E) = -A^{-1}EA^{-1}
4453 // As such, if g(A,B) = A^{-1}B,
4454 // (dg)_(A,B)(E_A, E_B) = -A^{-1}E_AA^{-1}B + A^{-1}E_B
4455 // = A^{-1}(E_B - E_AX)
4456
4457 // Backward AD:
4458 // Denoting the gradients by G_A, G_B, solving the above gives
4459 // G_B = A^{-H}G_X
4460 // G_A = -A^{-H}G_XX^H = -G_B X^H
4461 //
4462 // Note that we don't need to store B for either the forward or the backward
4463 //
4464 // These formulas work for a general solver of linear equations.
4465 // Let's prove now that when A is triangular, G_A is the projection onto the
4466 // triangular matrices of the formula above, i.e. simply taking triu (resp.
4467 // tril) in the formula above. This is because, since the triangular matrices
4468 // form a vector space, the tangent space at any point is itself the space of
4469 // triangular matrices. The result follows from reasoning similar to that at
4470 // the end of [Note: eigh backward]. Something similar happens for
4471 // `unitriangular`, only that in this case the tangent space is the set of
4472 // lower-triangular matrices with zeros on the diagonal.
4473
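// As a concrete reading of the note above (illustrative only): with upper = true,
// left = true and unitriangular = false, the projected gradients reduce to
//   G_B = A^{-H} G_X          (a triangular solve against A^H)
//   G_A = triu(-G_B X^H)
// which is what the code below computes.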
4474 if (!grad.defined() || (!A_requires_grad && !B_requires_grad)) {
4475 return std::make_tuple(Tensor{}, Tensor{});
4476 }
4477 // We always need to compute G_B
4478 const Tensor A_H = A.mH();
4479 const Tensor G_B =
4480 at::linalg_solve_triangular(A_H, grad, !upper, left, unitriangular);
4481
4482 if (A_requires_grad) {
4483 const Tensor X_H = X.mH();
4484 Tensor G_A = left ? -at::matmul(G_B, X_H) : -at::matmul(X_H, G_B);
4485 G_A = upper ? G_A.triu(static_cast<int>(unitriangular))
4486 : G_A.tril(-static_cast<int>(unitriangular));
4487 return std::make_tuple(G_A, B_requires_grad ? G_B : Tensor{});
4488 } else {
4489 return std::make_tuple(Tensor{}, G_B);
4490 }
4491 }
4492
4493 std::tuple<Tensor, Tensor> cholesky_solve_backward(
4494 const Tensor& grad_x,
4495 const Tensor& self,
4496 const Tensor& input2,
4497 const Tensor& result,
4498 const bool upper,
4499 std::array<bool, 2> output_mask) {
4500 at::NoTF32Guard disable_tf32;
4501 Tensor grad_self, grad_input2;
4502 if (grad_x.defined()) {
4503 grad_self = grad_x.cholesky_solve(input2, /*upper=*/upper);
4504
4505 if (output_mask[1]) {
4506 Tensor common_term = at::matmul(grad_self, result.mH());
4507 common_term = common_term + common_term.mH();
4508
4509 if (upper) {
4510 grad_input2 = -at::matmul(input2, common_term);
4511 } else {
4512 grad_input2 = -at::matmul(common_term, input2);
4513 }
4514 }
4515 }
4516 return std::tuple<Tensor, Tensor>{grad_self, grad_input2};
4517 }
4518
4519 Tensor cholesky_solve_jvp(
4520 const Tensor& X,
4521 const Tensor& U,
4522 const Tensor& dU,
4523 const Tensor& dB,
4524 const bool upper) {
4525 at::NoTF32Guard disable_tf32;
4526 auto dK = upper ? dU.mH().matmul(U) : dU.matmul(U.mH());
4527 auto dA = dK + dK.mH();
4528 return generic_solve_jvp(
4529 [&](const Tensor& A, const Tensor& dB, const Tensor& dA_contrib) {
4530 return at::cholesky_solve(dB - dA_contrib, A, upper);
4531 },
4532 X,
4533 /*A=*/U,
4534 dA,
4535 dB);
4536 }
4537
4538 Tensor fft_c2r_backward(
4539 const Tensor& grad,
4540 IntArrayRef dim,
4541 int64_t normalization) {
4542 // Forward is C2R (onesided)
4543 // Think of onesided C2R irfft as
4544 // 1. fill the other half by conjugate symmetry
4545 // 2. inverse C2C ifft
4546 // 3. discard the complex dimension
4547 // So backward is
4548 // 1. R2C rfft (essentially add dummy complex dimension, and dft)
4549 // 2. accumulate gradient by conjugate symmetry
4550 // since rfft results follow conjugate symmetry, we only need to
4551 // double some entries from onesided rfft results, i.e., the ones with
4552 // their reflected indices also landing out of the onesided range. So
4553 // consider the index of last dim:
4554 // i. idx = 0.
4555 // Reflected to (N - 0) % N = 0. Not doubled.
4556 //   ii.  0 < idx < floor(N/2) (last).
4557 //            N > N - idx > ceil(N/2), so the reflected index N - idx lands
4558 //            outside the onesided range [0, floor(N/2)]. Doubled.
4559 // iii. idx = floor(N/2) = N/2 (last) when N even.
4560 // Reflected to (N - N/2) % N = N/2. Not doubled.
4561 // iv. idx = floor(N/2) = (N-1)/2 (last) when N odd.
4562 // Reflected to (N - (N-1)/2) % N = (N+1)/2. Doubled.
4563 // Therefore, needs to double
4564 // idx = 1, 2, ..., N/2 - 1 when N even
4565 // idx = 1, 2, ..., (N-1)/2 when N odd
4566 // that is
4567 // idx = 1, 2, ..., N - (floor(N/2) + 1)
4568 // = 1, 2, ..., N - onesided_length
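// Worked example (illustrative only): for N = 4 the onesided length is 3
// (indices 0, 1, 2) and only idx = 1 is doubled, matching
// N - onesided_length = 4 - 3 = 1 doubled entry. For N = 5 the onesided length
// is 3 and idx = 1, 2 are doubled, matching 5 - 3 = 2 doubled entries.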
4569 auto gI = at::_fft_r2c(grad, dim, normalization, /*onesided=*/true);
4570
4571 auto double_length = grad.sym_size(dim.back()) - gI.sym_size(dim.back());
4572 if (double_length > 0) { // also covers case when signal size is zero
4573 gI.narrow_symint(dim.back(), 1, double_length).mul_(2);
4574 }
4575 return gI;
4576 }
4577
4578 Tensor fft_r2c_backward(
4579 const Tensor& grad,
4580 at::IntArrayRef dim,
4581 int64_t normalization,
4582 bool onesided,
4583 const c10::SymInt& last_dim_size) {
4584 if (!onesided) {
4585 return at::real(at::_fft_c2c(grad, dim, normalization, /*forward=*/false));
4586 }
4587
4588 // Forward is R2C (onesided)
4589 // Think of onesided R2C rfft as
4590 // 1. view as complex numbers (fill complex dim with zeros)
4591 // 2. C2C fft
4592 // 3. discard half of results
4593 // So backward is
4594 // 1. fill the other half with zeros (with `zero_grad_shape` below)
4595 //      (C2C ifft only takes two-sided inputs, so we need to fill here)
4596 // 2. inverse C2C ifft
4597 // 3. discard the complex dim
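// Shape sketch (illustrative only): if the original real signal had
// last_dim_size = N, the onesided grad has size N/2 + 1 along `dim.back()`,
// so `zero_length = N - (N/2 + 1)` entries are zero-filled below before the
// two-sided C2C ifft.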
4598 auto half_sizes = grad.sym_sizes();
4599 std::vector<c10::SymInt> new_grad_shape(half_sizes.begin(), half_sizes.end());
4600 const auto last_dim =
4601 at::maybe_wrap_dim(dim.back(), static_cast<int64_t>(half_sizes.size()));
4602 new_grad_shape[last_dim] = last_dim_size;
4603
4604 const auto zero_length = last_dim_size - grad.sym_size(dim.back());
4605 auto complex_full_grad =
4606 zero_length > 0 ? grad.new_zeros_symint(new_grad_shape) : grad;
4607 if (zero_length > 0) {
4608 complex_full_grad.slice_symint(last_dim, 0, half_sizes[last_dim])
4609 .copy_(grad);
4610 }
4611 return at::real(
4612 at::_fft_c2c(complex_full_grad, dim, normalization, /*forward=*/false));
4613 }
4614
4615 // Helper for batchnorm_double_backward
4616 static Tensor sum_exclude_dim1(const Tensor& to_sum, bool keepdim = true) {
4617 auto r = to_sum.sum(0, keepdim);
4618 int64_t start_point_exclusive = keepdim ? 1 : 0;
4619 for (int64_t dim = r.dim() - 1; dim > start_point_exclusive; dim--) {
4620 r = r.sum(dim, keepdim);
4621 }
4622 return r;
4623 }
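// Usage sketch (illustrative only): for a (N, C, H, W) input, sum_exclude_dim1
// reduces over N, H and W, returning shape (1, C, 1, 1) with keepdim = true,
// or (C,) with keepdim = false.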
4624
4625 // Helper for batchnorm_double_backward
4626 // similar to expand_as below, but doesn't do the expand_as; operates as if
4627 // reductions were done with keepdim=True
4628 static Tensor unsqueeze_dim1(const Tensor& src, const Tensor& target) {
4629 auto src_expanded = src;
4630 while (src_expanded.sizes().size() < target.sizes().size() - 1) {
4631 src_expanded = src_expanded.unsqueeze(1);
4632 }
4633 if (src_expanded.sizes().size() == target.sizes().size() - 1) {
4634 src_expanded = src_expanded.unsqueeze(0);
4635 }
4636 return src_expanded;
4637 }
4638
4639 // Helper for batchnorm_double_backward
4640 // gamma/ggG/ggB are 1-dimensional and represent dim==1, so we can't do a
4641 // straight expansion, because it wouldn't follow the broadcasting rules.
4642 static Tensor expand_as_dim1(const Tensor& src, const Tensor& target) {
4643 auto src_expanded = src;
4644 while (src_expanded.sizes().size() < target.sizes().size() - 1) {
4645 src_expanded = src_expanded.unsqueeze(1);
4646 }
4647 return src_expanded.expand_as(target);
4648 }
4649
4650 std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(
4651 const Tensor& input,
4652 const std::optional<Tensor>& gamma,
4653 const Tensor& ggI,
4654 const Tensor& ggG,
4655 const Tensor& ggB,
4656 const Tensor& gO,
4657 const std::optional<Tensor>& running_mean,
4658 const std::optional<Tensor>& running_var,
4659 bool training,
4660 double eps,
4661 const std::optional<Tensor>& save_mean,
4662 const std::optional<Tensor>& save_invstd,
4663 std::array<bool, 3> output_mask) {
4664 bool affine = isDefined(gamma);
4665 // TODO: Do we have a ScalarOrTensor type? Would such a thing exist?
4666 Tensor gamma_expanded;
4667 Tensor ggG_expanded, ggB_expanded;
4668 if (affine) {
4669 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
4670 gamma_expanded = expand_as_dim1(*gamma, input);
4671 if (ggG.defined()) {
4672 ggG_expanded = expand_as_dim1(ggG, input);
4673 }
4674 if (ggB.defined()) {
4675 ggB_expanded = expand_as_dim1(ggB, input);
4676 }
4677 } else {
4678 gamma_expanded = at::ones({}, input.options());
4679 }
4680
4681 // define some terms we will reuse
4682 auto M = input.size(0);
4683 for (auto s : input.sizes().slice(2)) {
4684 M *= s;
4685 }
4686 // for half inputs, save_mean, save_invstd are float (ideally, we would cast
4687 // everything else, but not now)
4688 auto mu = unsqueeze_dim1(
4689 training ? toNonOptTensor(save_mean).to(input.scalar_type())
4690 : toNonOptTensor(running_mean),
4691 input);
4692 auto input_sub_mu = input - mu;
4693 auto sigma2_eps_neg_1_2 = unsqueeze_dim1(
4694 training ? toNonOptTensor(save_invstd).to(input.scalar_type())
4695 : toNonOptTensor(running_var).add(Scalar(eps)).pow(-0.5),
4696 input);
4697 auto sigma2_eps_neg_1 = sigma2_eps_neg_1_2.pow(2);
4698 auto sigma2_eps_neg_3_2 = sigma2_eps_neg_1_2.pow(3);
4699
4700 // calculate gI
4701 auto input_mu_sigma2_neg_3_2 = input_sub_mu * sigma2_eps_neg_3_2;
4702 auto gOinmu_sum = sum_exclude_dim1(gO * input_sub_mu);
4703 auto gO_sum = sum_exclude_dim1(gO);
4704
4705 Tensor gI;
4706 if (ggI.defined() && training) {
4707 auto ggI_sum = sum_exclude_dim1(ggI);
4708 auto ggIinmu_sum = sum_exclude_dim1(ggI * input_sub_mu);
4709 auto all_sub = ((ggI_sum * gO_sum).div_(M))
4710 .sub_(sum_exclude_dim1(gO * ggI))
4711 .add_((sigma2_eps_neg_1 * gOinmu_sum * ggIinmu_sum)
4712 .mul_(3. / static_cast<double>(M)));
4713 auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(M);
4714 auto gI_1t =
4715 (ggIinmu_sum * sigma2_eps_neg_3_2).div_(M) * (gO_sum.div(M) - gO);
4716 auto gI_2t =
4717 (gOinmu_sum * sigma2_eps_neg_3_2).div_(M) * (ggI_sum.div(M) - ggI);
4718 gI = gamma_expanded * (gI_0t.add_(gI_1t).add_(gI_2t));
4719 }
4720
4721 // add contribution of gamma term to gI
4722 Tensor gI_G_term;
4723 if (affine && ggG.defined()) {
4724 if (training) {
4725 auto t0 = gO * sigma2_eps_neg_1_2;
4726 auto t1 = (sigma2_eps_neg_1_2 * gO_sum).div_(-M);
4727 auto t2 = (input_mu_sigma2_neg_3_2 * sum_exclude_dim1(gO * input_sub_mu))
4728 .div_(-M);
4729 gI_G_term = ggG_expanded * (t0.add_(t1).add_(t2));
4730 gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term;
4731 } else {
4732 gI_G_term = ggG_expanded * sigma2_eps_neg_1_2 * gO;
4733 gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term;
4734 }
4735 }
4736
4737 // this is the first backward's grad_input
4738 auto first_back_grad_input = [&](const Tensor& gO,
4739 const Tensor& gamma) -> Tensor {
4740 auto h0 = (gamma * sigma2_eps_neg_1_2).div_(M);
4741 auto h1 = (M * gO)
4742 .sub_(sum_exclude_dim1(gO))
4743 .sub_(
4744 input_sub_mu.mul(sigma2_eps_neg_1) *
4745 sum_exclude_dim1(gO * input_sub_mu));
4746 return h0 * h1;
4747 };
4748
4749 // calculate gG
4750 Tensor gG;
4751 if (affine && ggI.defined()) {
4752 if (training) {
4753 // gG is just the first backwards with the gamma term removed (then shaped
4754 // properly)
4755 gG = ggI *
4756 first_back_grad_input(gO, at::ones({}, sigma2_eps_neg_1_2.options()));
4757 gG = sum_exclude_dim1(gG, false);
4758 } else {
4759 gG = sum_exclude_dim1(ggI * gO * sigma2_eps_neg_1_2, false);
4760 }
4761 }
4762
4763 // calculate ggO
4764 Tensor ggO;
4765 // contribution of input term
4766 if (ggI.defined()) {
4767 if (training) {
4768 ggO = first_back_grad_input(ggI, gamma_expanded);
4769 } else {
4770 ggO = ggI * sigma2_eps_neg_1_2 * gamma_expanded;
4771 }
4772 }
4773 if (ggG.defined()) {
4774 auto ggO_G_term = ggG_expanded * input_sub_mu * sigma2_eps_neg_1_2;
4775 ggO = ggO.defined() ? ggO.add_(ggO_G_term) : ggO_G_term;
4776 }
4777 if (ggB.defined()) {
4778 auto ggO_B_term = std::move(ggB_expanded);
4779 ggO = ggO.defined() ? ggO.add_(ggO_B_term) : ggO_B_term;
4780 }
4781
4782 if (output_mask[1] && !gG.defined()) {
4783 AT_ASSERTM(affine, "gamma should always be defined when it requires grad");
4784 }
4785
4786 return std::tuple<Tensor, Tensor, Tensor>{gI, gG, ggO};
4787 }
4788
4789 std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
4790 const Tensor& input_t,
4791 const std::optional<Tensor>& gamma,
4792 const Tensor& ggI,
4793 const Tensor& ggG,
4794 const Tensor& ggB,
4795 const Tensor& gO_t,
4796 const Tensor& save_mean_t,
4797 const Tensor& save_invstd_t,
4798 c10::SymIntArrayRef normalized_shape,
4799 std::array<bool, 3> output_mask) {
4800 const auto normalized_ndim = normalized_shape.size();
4801 const auto input_shape = input_t.sizes();
4802 const auto input_ndim = input_t.dim();
4803 const auto axis = input_ndim - normalized_ndim;
4804 const int64_t M =
4805 c10::multiply_integers(input_shape.cbegin(), input_shape.cbegin() + axis);
4806 const int64_t N =
4807 c10::multiply_integers(input_shape.cbegin() + axis, input_shape.cend());
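// Worked example (illustrative only): an input of shape (2, 3, 4, 5) with
// normalized_shape = (4, 5) gives axis = 2, M = 2 * 3 = 6 and N = 4 * 5 = 20,
// so the computation below runs on (M, N) = (6, 20) reshaped views.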
4809
4810 auto input = input_t.reshape({M, N});
4811 auto gO = gO_t.reshape({M, N});
4812 auto save_mean = save_mean_t.reshape({M, 1});
4813 auto save_invstd = save_invstd_t.reshape({M, 1});
4814
4815 bool affine = isDefined(gamma);
4816 Tensor gamma_expanded;
4817 Tensor ggG_expanded, ggB_expanded;
4818 if (affine) {
4819 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
4820 gamma_expanded = gamma->reshape({1, N});
4821 if (ggG.defined()) {
4822 ggG_expanded = ggG.reshape({1, N});
4823 }
4824 if (ggB.defined()) {
4825 ggB_expanded = ggB.reshape({1, N});
4826 }
4827 } else {
4828 gamma_expanded = at::ones({1}, input.options());
4829 }
4830
4831 Tensor ggI_expanded;
4832 if (ggI.defined()) {
4833 ggI_expanded = ggI.reshape({M, N});
4834 }
4835
4836 // for half inputs, save_mean, save_invstd are float
4837 // (ideally, we would cast everything else, but not now)
4838 auto mu = save_mean.to(input.scalar_type());
4839 auto input_sub_mu = input - mu;
4840 auto sigma2_eps_neg_1_2 = save_invstd.to(input.scalar_type());
4841 auto sigma2_eps_neg_1 = sigma2_eps_neg_1_2.pow(2);
4842 auto sigma2_eps_neg_3_2 = sigma2_eps_neg_1_2.pow(3);
4843
4844 Tensor gI;
4845 // calculate gI
4846 auto input_mu_sigma2_neg_3_2 = input_sub_mu * sigma2_eps_neg_3_2;
4847
4848 if (ggI.defined()) {
4849 auto gxhat = gO * gamma_expanded;
4850 auto gxhat_mu_sum = (gxhat * input_sub_mu).sum(1, true);
4851 auto gxhat_sum = gxhat.sum(1, true);
4852
4853 auto ggI_sum = ggI_expanded.sum(1, true);
4854 auto ggI_mu_sum = (ggI_expanded * input_sub_mu).sum(1, true);
4855
4856 auto all_sub = ((ggI_sum * gxhat_sum).div_(N))
4857 .sub_((ggI_expanded * gxhat).sum(1, true))
4858 .add_((sigma2_eps_neg_1 * gxhat_mu_sum * ggI_mu_sum)
4859 .mul_(3. / static_cast<double>(N)));
4860 auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(N);
4861 auto gI_1t =
4862 (ggI_mu_sum * sigma2_eps_neg_3_2).div_(N) * (gxhat_sum.div(N) - gxhat);
4863 auto gI_2t = (gxhat_mu_sum * sigma2_eps_neg_3_2).div_(N) *
4864 (ggI_sum.div(N) - ggI_expanded);
4865
4866 gI = (gI_0t.add_(gI_1t).add_(gI_2t));
4867 }
4868
4869 // add contribution of gamma term to gI
4870 if (affine && ggG.defined()) {
4871 auto t0 = gO * ggG_expanded * sigma2_eps_neg_1_2;
4872 auto t1 = (sigma2_eps_neg_1_2 * (gO * ggG_expanded).sum(1, true)).div_(-N);
4873 auto t2 = (input_mu_sigma2_neg_3_2 *
4874 (gO * ggG_expanded * input_sub_mu).sum(1, true))
4875 .div_(-N);
4876 auto gI_G_term = t0.add_(t1).add_(t2);
4877 gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term;
4878 }
4879
4880 if (gI.defined()) {
4882 gI = gI.reshape_as(input_t);
4883 }
4884
4885 // this is the grad_input for the first backward function
4886 auto first_bwd_fn_grad_input = [&](const Tensor& gO_local,
4887 const Tensor& gamma_local) -> Tensor {
4888 auto h0 = (gamma_local * sigma2_eps_neg_1_2).div_(N);
4889 auto h1 = (N * gO_local)
4890 .sub_(gO_local.sum(1, true))
4891 .sub_(
4892 input_sub_mu.mul(sigma2_eps_neg_1) *
4893 (gO_local * input_sub_mu).sum(1, true));
4894 return h0 * h1;
4895 };
4896
4897 // calculate gG
4898 Tensor gG;
4899 if (affine && ggI.defined()) {
4900 gG = first_bwd_fn_grad_input(
4901 ggI_expanded, at::ones({}, sigma2_eps_neg_1_2.options()));
4902 gG = (gO * gG).sum(0);
4903 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
4904 gG = gG.reshape_as(*gamma);
4905 }
4906
4907 // calculate ggO
4908 Tensor ggO;
4909 // contribution of input term
4910 if (ggI.defined()) {
4911 ggO = first_bwd_fn_grad_input(ggI_expanded, gamma_expanded);
4912 }
4913 if (ggG.defined()) {
4914 auto ggO_G_term = ggG_expanded * input_sub_mu * sigma2_eps_neg_1_2;
4915 ggO = ggO.defined() ? ggO.add_(ggO_G_term) : ggO_G_term;
4916 }
4917 if (ggB.defined()) {
4918 auto ggO_B_term = std::move(ggB_expanded);
4919 ggO = ggO.defined() ? ggO.add_(ggO_B_term) : ggO_B_term;
4920 }
4921 if (ggO.defined()) {
4922 ggO = ggO.expand({M, N}).reshape_as(input_t);
4923 }
4924
4925 if (output_mask[1] && !gG.defined()) {
4926 AT_ASSERTM(affine, "gamma should always be defined when it requires grad");
4927 }
4928
4929 return std::tuple<Tensor, Tensor, Tensor>{gI, gG, ggO};
4930 }
4931
4932 std::tuple<Tensor, Tensor, Tensor>
4933 infinitely_differentiable_native_group_norm_backward(
4934 const Tensor& dY,
4935 const Tensor& dmean,
4936 const Tensor& drstd,
4937 const Tensor& X,
4938 const Tensor& mean,
4939 const Tensor& rstd,
4940 const std::optional<Tensor>& gamma,
4941 c10::SymInt N,
4942 const c10::SymInt& C,
4943 c10::SymInt HxW,
4944 int64_t group,
4945 double eps,
4946 std::array<bool, 3> grad_input_mask) {
4947 const int64_t G = group;
4948 const auto D = C / G;
4949 c10::SymFloat s = c10::SymFloat(1.0) / c10::SymFloat(D * HxW);
4950 Tensor dX;
4951 Tensor dgamma;
4952 Tensor dbeta;
4953 const Tensor X_tensor = X.reshape_symint({N, G, D, HxW});
4954 const Tensor mean_tensor = mean.reshape_symint({N, G, 1, 1});
4955 const Tensor rstd_tensor = rstd.reshape_symint({N, G, 1, 1});
4956 Tensor dY_tensor;
4957 Tensor ds;
4958 Tensor db;
4959 if (dY.defined()) {
4960 dY_tensor = dY.reshape_symint({N, G, D, std::move(HxW)});
4961 ds = (dY_tensor * X_tensor).sum(3).unsqueeze_(-1);
4962 db = dY_tensor.sum(3).unsqueeze_(-1);
4963 }
4964 if (grad_input_mask[0]) {
4965 Tensor gamma_tensor;
4966 if (isDefined(gamma)) {
4967 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
4968 gamma_tensor = gamma->reshape_symint({1, G, D, 1});
4969 }
4970 const Tensor var =
4971 ((rstd_tensor * rstd_tensor).reciprocal_() - eps).clamp_min(0);
4972 const Tensor rstd_cube = rstd_tensor * rstd_tensor * rstd_tensor;
4973 Tensor dvar;
4974 if (drstd.defined()) {
4975 dvar = -0.5 * rstd_cube * drstd.view_symint({N, G, 1, 1});
4976 }
4977 if (dY.defined()) {
4978 const Tensor a =
4979 isDefined(gamma) ? rstd_tensor * gamma_tensor : rstd_tensor;
4980 Tensor b = (isDefined(gamma) ? (ds * gamma_tensor).sum(2) : ds.sum(2))
4981 .unsqueeze_(-2);
4982 Tensor c = (isDefined(gamma) ? (db * gamma_tensor).sum(2) : db.sum(2))
4983 .unsqueeze_(-2);
4984 b = (c * mean_tensor - b) * rstd_cube * s;
4985 c = -b * mean_tensor - c * rstd_tensor * std::move(s);
4986 dX = a * dY_tensor + b * X_tensor + c;
4987 if (dmean.defined() && drstd.defined()) {
4988 dX += var_mean_backward(
4989 dvar,
4990 dmean.view_symint({std::move(N), G, 1, 1}),
4991 X_tensor,
4992 IntArrayRef{2, 3},
4993 0,
4994 true);
4995 }
4996 dX = dX.reshape_as(X);
4997 } else if (dmean.defined() && drstd.defined()) {
4998 dX = var_mean_backward(
4999 dvar,
5000 dmean.view_symint({std::move(N), G, 1, 1}),
5001 X_tensor,
5002 IntArrayRef{2, 3},
5003 0,
5004 true)
5005 .reshape_as(X);
5006 }
5007 }
5008 if (grad_input_mask[1] && dY.defined()) {
5009 dgamma = ((ds - db * mean_tensor) * rstd_tensor)
5010 .sum(0)
5011 .reshape_as(toNonOptTensor(gamma));
5012 }
5013 if (grad_input_mask[2] && dY.defined()) {
5014 dbeta = db.sum(0).reshape_as(toNonOptTensor(gamma));
5015 }
5016
5017 return std::make_tuple(dX, dgamma, dbeta);
5018 }
5019
5020 std::tuple<Tensor, Tensor, Tensor> _trilinear_backward(
5021 const Tensor& grad_out,
5022 const std::optional<Tensor>& i1,
5023 const std::optional<Tensor>& i2,
5024 const std::optional<Tensor>& i3,
5025 IntArrayRef expand1,
5026 IntArrayRef expand2,
5027 IntArrayRef expand3,
5028 IntArrayRef sumdim,
5029 std::array<bool, 3> grad_mask) {
5030 Tensor grad_i1, grad_i2, grad_i3;
5031 if (grad_out.defined()) {
5032 if (grad_mask[0])
5033 grad_i1 =
5034 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
5035 at::_trilinear(grad_out, *i2, *i3, sumdim, expand2, expand3, expand1);
5036 if (grad_mask[1])
5037 grad_i2 =
5038 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
5039 at::_trilinear(*i1, grad_out, *i3, expand1, sumdim, expand3, expand2);
5040 if (grad_mask[2])
5041 grad_i3 =
5042 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
5043 at::_trilinear(*i1, *i2, grad_out, expand1, expand2, sumdim, expand3);
5044 }
5045 return std::tuple<Tensor, Tensor, Tensor>(grad_i1, grad_i2, grad_i3);
5046 }
5047
5048 Tensor log1p_backward(const Tensor& grad, const Tensor& self) {
5049 // We must conditionally initialize this using to_dense if self is sparse,
5050 // since sparse addition is not supported without an exact shape match
5051 Tensor self_p1_conj;
5052 if (self.layout() == c10::kSparse || self.layout() == c10::kSparseCsr ||
5053 self.layout() == c10::kSparseCsc || self.layout() == c10::kSparseBsr ||
5054 self.layout() == c10::kSparseBsc) {
5055 // The warning only applies to the sparsity of self, dense grad is never
5056 // materialized so if self is strided and grad is sparse nothing unexpected
5057 // happens memory wise
5058 TORCH_WARN(
5059 "log1p_backward: received self with sparse layout, but backward requires materialization of a dense tensor with this shape");
5060 self_p1_conj = (self.to_dense() + 1).conj();
5061 } else {
5062 // Although calling self.to_dense() would just return self when it has
5063 // strided layout, that would break functorch tests.
5064 self_p1_conj = (self + 1).conj();
5065 }
5066 if (grad.layout() == c10::kSparse || grad.layout() == c10::kSparseCsr ||
5067 grad.layout() == c10::kSparseCsc || grad.layout() == c10::kSparseBsr ||
5068 grad.layout() == c10::kSparseBsc) {
5069 // If grad is sparse we can't divide by the n-d (self + 1).conj(), so we
5070 // must multiply by the reciprocal instead; the layout of grad is preserved,
5071 // which is important for gradcheck
5072 return grad * self_p1_conj.reciprocal_();
5073 }
5074 return grad / self_p1_conj;
5075 }
5076
5077 Tensor sinc_backward(const Tensor& grad, const Tensor& self) {
5078 auto self_pi = self * M_PI;
5079 auto self_squared_pi = self * self * M_PI;
5080 auto out = grad *
5081 ((self_pi * self_pi.cos() - self_pi.sin()) / self_squared_pi).conj();
5082 return at::where(self_squared_pi == 0.0, at::zeros({}, grad.options()), out);
5083 }
5084
5085 // Because the backward of pad(input, pads) is just pad(grad_output, [-p for p
5086 // in pads])
5087 Tensor constant_pad_nd_backward(const Tensor& grad, c10::SymIntArrayRef pad) {
5088 auto negated_pad = pad.vec();
5089 std::transform(
5090 negated_pad.cbegin(),
5091 negated_pad.cend(),
5092 negated_pad.begin(),
5093 // NOLINTNEXTLINE(modernize-use-transparent-functors)
5094 std::negate<c10::SymInt>());
5095 return at::constant_pad_nd_symint(grad, negated_pad, 0);
5096 }
5097
5098 Tensor embedding_dense_double_backward_symint(
5099 const Tensor& grad,
5100 const Tensor& indices,
5101 const c10::SymInt& padding_idx) {
5102 // since first backward takes care of scaling by frequency,
5103 // we don't need to worry about it here.
5104 auto gg_weight = grad.index_select(0, indices.reshape(-1));
5105
5106 // reshape gradient as per the shape of indices
5107 auto size = indices.sizes().vec();
5108 size.push_back(-1);
5109
5110 if (padding_idx >= 0) {
5111 gg_weight.masked_fill_((indices == padding_idx).reshape({-1, 1}), 0);
5112 }
5113 return gg_weight.view(size);
5114 }
5115
5116 Tensor index_backward(
5117 Tensor zeros_like_self,
5118 const torch::List<std::optional<Tensor>>& indices,
5119 const Tensor& grad) {
5120 return (areAnyTensorSubclassLike({zeros_like_self, grad}) ||
5121 areAnyOptionalTensorSubclassLike(indices))
5122 ? zeros_like_self.index_put(indices, grad, true)
5123 : at::_index_put_impl_(zeros_like_self, indices, grad, true, true);
5124 }
5125
5126 Tensor _cudnn_ctc_loss_backward(
5127 const Tensor& grad_out,
5128 const Tensor& loss,
5129 const Tensor& raw_grad,
5130 bool zero_infinity) {
5131 if (zero_infinity) {
5132 return at::where(
5133 loss.unsqueeze(0).unsqueeze(2) == 0,
5134 at::zeros({}, raw_grad.options()),
5135 raw_grad * grad_out.unsqueeze(0).unsqueeze(2));
5136 } else {
5137 return raw_grad * grad_out.unsqueeze(0).unsqueeze(2);
5138 }
5139 }
5140
5141 bool any_variable_defined(const variable_list& variables) {
5142 for (const auto& variable : variables) {
5143 if (variable.defined()) {
5144 return true;
5145 }
5146 }
5147 return false;
5148 }
5149
5150 // Derivations for the householder_product.backward method.
5151 //
5152 // Given a sequence of vectors v_1, ..., v_n and a sequence of scalars tau_1,
5153 // ..., tau_k, the torch.linalg.householder_product computes the first n columns
5154 // of the following product: Q = (I - tau_1 v_1 v_1^H) ... (I - tau_k v_k
5155 // v_k^H). Let
5156 // H_i(sigma) := I - sigma v_i v_i^H, so Q = (H_1(sigma_1) ...
5157 // H_k(sigma_k)[:, :k]; H_i_minus = H_1(tau_1) ... H_{i - 1}(tau_{i - 1}),
5158 // with H_1_minus := I; H_i_plus = H_{i + 1}(tau_{i + 1}) ... H_k(tau_k)
5159 // with H_k_plus := I;
5160 //
5161 // Forward AD:
5162 // dQ = sum_{i = 1}^k H_i_minus (-dtau_i v_i v_i^H - tau_i dv_i v_i^H - tau_i
5163 // v_i dv_i^H) H_i_plus.
5164 //
5165 // Backward AD:
5166 // Tr(Q_grad^H dQ) = sum_{i = 1}^k Tr(H_i_plus Q_grad^H H_i_minus (-dtau_i v_i
5167 // v_i^H - tau_i dv_i v_i^H - tau_i v_i dv_i^H)). Let K_i := H_i_plus Q_grad^H
5168 // H_i_minus, then the gradients are v_i_grad = (-tau_i v_i^H K_i)^H - tau_i K_i
5169 // v_i, tau_i_grad = Tr(-v_i^H K_i v_i).conj(). NOTE: the algorithm ignores
5170 // that only n columns of Q are observed, so there is no need to recompute Q
5171 // to full completion.
5172 //
5173 // Note that K_{i + 1} = H_{i + 1}^{-1} K_i H_i, so we can compute v_i_grad,
5174 // tau_i_grad one by one by just efficiently updating K_i if that is possible.
5175 // Multiplying with H_i from the right could be done with matrix-vector
5176 // products, but what about the inverse H_{i + 1}^{-1} and does it even exist?
5177 // Luckily, under some assumptions, H_{i + 1}^{-1} exists and admits a
5178 // representation as H_i(sigma_i) for some sigma_i, so the left update can
5179 // also be done with matrix-vector rather than matrix-matrix products.
5180 //
5181 // Let H(tau) := I - tau v v^H.
5182 // H(tau) has eigenvalues 1 with multiplicity (m - 1) with eigenvectors
5183 // orthogonal to v, and an eigenvalue (1 - tau ||v||^2) with the corresponding
5184 // eigenvector v / ||v||. If (1 - tau ||v||^2) != 0, H(tau) is invertible, and
5185 // with sigma defined as sigma := tau / (||v||^2 tau - 1) we get that
5186 // H(tau) H(sigma) = H(sigma) H(tau) = I, so H(sigma) is the
5187 // inverse of H(tau).
5188 //
5189 // WARNING: the algorithm below assumes that H_i(tau_i) are all invertible, so
5190 // it expects that (1 - tau_i ||v_i||^2) != 0 for all i.
5191 // We would like to point out that if there is H_i(tau_i) which is not
5192 // invertible, the householder_product is still differentiable! We will not be
5193 // able to compute K_i efficiently in such cases, however, as evaluating each
5194 // K_i will amount to calls to ORGQR to be able to compute H_i_plus.
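// A quick check of the inverse claim above (illustrative only): with
// sigma := tau / (||v||^2 tau - 1),
//   H(tau) H(sigma) = I - (tau + sigma - tau sigma ||v||^2) v v^H,
// and the coefficient tau + sigma - tau sigma ||v||^2 vanishes by construction,
// so H(sigma) = H(tau)^{-1} whenever (1 - tau ||v||^2) != 0.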
5195
5196 // This function computes either the product between
5197 // (I - tau u v^H) and K (in-place or not) with `condition_with_I = true`, or
5198 // between
5199 // (-tau u v^H) and K (out-of-place only) with `condition_with_I = false`.
5200 // Parameter `left` controls whether the matrix K is multiplied from the left or
5201 // from the right.
5202 // Additionally, when the computation is done in-place, we exploit that the
5203 // first `k` coordinates of `u_full/v_full` are zeros.
5204 static Tensor apply_simple_transformation(
5205 const c10::SymInt& m,
5206 const c10::SymInt& k,
5207 const Tensor& u_full,
5208 const Tensor& v_full,
5209 const Tensor& t,
5210 Tensor& K,
5211 bool modify_K_in_place = true,
5212 bool condition_with_I = true,
5213 bool left = true) {
5214 // we assume u_full is a vector of dimension (..., m, 1), t is a scalar of
5215 // dimension (..., 1)
5216
5217 // TODO: matrix-vector products in the code below are dispatched to
5218 // matrix-matrix products. We either need to extend matmul to support batched
5219 // matrix-vector products, or implement a batched variant of mv. We could
5220 // enable mv for inputs which are not batched, but we don't, so as to avoid
5221 // code duplication.
5222
5223 // returns (I - t u v^H) K or -t u v^H K
5224 if (left) {
5225 if (modify_K_in_place) {
5226 auto v = u_full.narrow_symint(-2, k, m - k);
5227 auto u = v_full.narrow_symint(-2, k, m - k)
5228 .mH()
5229 .matmul(K.narrow_symint(-2, k, m - k));
5230 K.narrow_symint(-2, k, m - k).sub_((t.unsqueeze(-1) * v) * u);
5231 return K;
5232 } else {
5233 auto transformation = (t.unsqueeze(-1) * u_full) * v_full.mH().matmul(K);
5234 return condition_with_I ? K - transformation : -transformation;
5235 }
5236 }
5237 // returns K (I - t u v^H) or -K t u v^H
5238 else {
5239 if (modify_K_in_place) {
5240 auto v = u_full.narrow_symint(-2, k, m - k);
5241 auto u =
5242 K.narrow_symint(-1, k, m - k)
5243 .matmul(t.unsqueeze(-1) * v_full.narrow_symint(-2, k, m - k));
5244 K.narrow_symint(-1, k, m - k).sub_(u * v.mH());
5245 return K;
5246 } else {
5247 auto transformation = K.matmul(t.unsqueeze(-1) * u_full) * v_full.mH();
5248 return condition_with_I ? K - transformation : -transformation;
5249 }
5250 }
5251 }
5252
5253 std::tuple<Tensor, Tensor> householder_product_backward(
5254 const Tensor& grad,
5255 const Tensor& result,
5256 const Tensor& input_,
5257 const Tensor& tau,
5258 const bool flip_order) {
5259 // NOTE on `flip_order`: when flip_order is true,
5260 // the algorithm below reverses the processing direction from
5261 // range(k) to range(k - 1, -1, -1) in the main loop, and left/right
5262 // Householder projection applications get flipped.
5263 // The comments below about the algorithmic details assume flip_order = false.
5264 if (!grad.defined() || input_.sym_numel() == 0 || tau.sym_numel() == 0) {
5265 return std::tuple<Tensor, Tensor>(Tensor(), Tensor());
5266 }
5267 auto m = input_.sym_size(-2);
5268 // guard_int is due to irange calls below
5269 auto k = tau.sym_size(-1).guard_int(__FILE__, __LINE__);
5270
5271 // forward operates only over the lower triangular part with the assumption
5272 // that the diagonal of input is filled with 1s.
5273 auto input = input_.tril(-1);
5274 input.diagonal(0, -2, -1).fill_(1.0);
5275
5276 // compute sigma such that
5277 // H(sigma_i) == H(tau_i)^{-1}.
5278 // If the input to householder_product comes from GEQRF,
5279 // we will never encounter ||v_i||^2 tau_i == 1, so H(tau_i) will always be
5280 // invertible. This follows from the documentation
5281 // https://www.netlib.org/lapack/lug/node128.html, and tau always satisfying
5282 // the condition |tau|^2 ||v||^2 == 2 * Re(tau).
5283 auto input_first_k_cols = input.narrow(-1, 0, k);
5284 auto input_first_k_cols_norm_squared =
5285 (input_first_k_cols * input_first_k_cols.conj()).sum(-2);
5286 auto sigma = tau / (tau * input_first_k_cols_norm_squared - 1.0);
5287
5288 auto K = result.matmul(grad.mH());
5289
5290 // The algorithm updates K by multiplying it from the left/right with
5291 // Householder reflectors. If only single backward is run, we modify K
5292 // in-place and exploit triangularity of the input. With higher order
5293 // derivatives we cannot rewrite the storage of K, hence we use much less
5294 // efficient out-of-place methods.
5295 //
5296 // if only first-order derivative is expected, we can modify K in-place for
5297 // better performance
5298 bool modify_K_in_place = !at::GradMode::is_enabled();
5299
5300 // This method exploits that at k-th iteration vector v_k has only elements
5301 // v_k[k:] which are non-zero.
5302 auto update_grad = [&m](
5303 int64_t k,
5304 const Tensor& v_full,
5305 const Tensor& t,
5306 const Tensor& K) -> std::tuple<Tensor, Tensor> {
5307 // v_full is a vector of dimension (..., m, 1), t is a scalar of dimension
5308 // (..., 1)
5309 auto v = v_full.narrow_symint(-2, k, m - k);
5310 auto vHK = v.mH().matmul(K.narrow_symint(-2, k, m - k));
5311 auto Kv = K.narrow_symint(-1, k, m - k).matmul(v);
5312 auto t_unsqueezed = t.unsqueeze(-1);
5313 auto v_grad = (-t_unsqueezed * vHK).conj().squeeze(-2) -
5314 (t_unsqueezed * Kv).squeeze(-1);
5315 auto tau_grad = -(vHK.narrow_symint(-1, k, m - k).matmul(v)).conj();
5316 return std::make_tuple(v_grad.unsqueeze(-1), tau_grad.squeeze(-1));
5317 };
5318
5319 auto apply_householder_reflector = [m, modify_K_in_place](
5320 int64_t k,
5321 const Tensor& v_full,
5322 const Tensor& t,
5323 Tensor& K,
5324 bool left = true) -> Tensor {
5325 return apply_simple_transformation(
5326 m,
5327 k,
5328 v_full,
5329 v_full,
5330 t,
5331 K,
5332 modify_K_in_place,
5333 /*condition_with_I=*/true,
5334 left);
5335 };
5336
5337 const auto flip_i = [flip_order, k](int64_t i) -> int64_t {
5338 return !flip_order ? i : k - i - 1;
5339 };
5340 const auto next_i = [flip_order](int64_t i) -> int64_t {
5341 return !flip_order ? ++i : --i;
5342 };
5343 const auto apply_left = !flip_order;
5344
5345 // K <- H_0^{-1} @ K
5346 const auto zero_idx = flip_i(0);
5347 K = apply_householder_reflector(
5348 zero_idx,
5349 input.narrow(-1, zero_idx, 1),
5350 sigma.narrow(-1, zero_idx, 1),
5351 K,
5352 /*left=*/apply_left);
5353
5354 Tensor input_grad, tau_grad;
5355 // For Composite Compliance, we can't copy a Subclass into a Regular Tensor,
5356 // so we use out-of-place ops with equivalent output.
5357 // NOTE: We can't use `new_zeros` directly as `input`, 'tau' or `grad` can
5358 // be Tensor Subclass and we don't want to make assumption about which
5359 // one to choose for creating output buffer.
5360 // eg. if both are BatchedTensor at different level.
5361 if (areAnyTensorSubclassLike({input, tau, K})) {
5362 // k + 1 if input_grads hold a matrix of zeros for inactive parts of input.
5363 auto input_grads = std::vector<Tensor>(k < input.sym_size(-1) ? k + 1 : k);
5364 auto tau_grads = std::vector<Tensor>(k);
5365
5366 for (const auto i_idx : c10::irange(k)) {
5367 auto i = flip_i(i_idx);
5368 // NOTE: narrow will unsqueeze(-1)
5369 auto v_i = input.narrow(-1, i, 1);
5370 auto t_i = tau.narrow(-1, i, 1);
5371
5372 std::tie(input_grads[i], tau_grads[i]) = update_grad(i, v_i, t_i, K);
5373
5374 // K <- H_{i + 1}^{-1} @ K @ H_i
5375 if (i != flip_i(k - 1)) {
5376 auto i_next = next_i(i);
5377 auto v_i_next = input.narrow(-1, i_next, 1);
5378 auto s_i_next = sigma.narrow(-1, i_next, 1);
5379 K = apply_householder_reflector(
5380 i_next, v_i_next, s_i_next, K, /*left=*/apply_left);
5381 K = apply_householder_reflector(i, v_i, t_i, K, /*left=*/!apply_left);
5382 }
5383 }
5384
5385 // Only first k columns are active in forward.
5386 // zero gradients for the inactive input.
5387 if (k < input.sym_size(-1)) {
5388 auto zero_grad_shape =
5389 at::SymDimVector(input_.sym_sizes().slice(0, input_.dim() - 1));
5390 zero_grad_shape.push_back(input.sym_size(-1) - k);
5391 auto zero_grad = at::zeros_symint(zero_grad_shape, input_.options());
5392 input_grads[k] = zero_grad;
5393 }
5394
5395 input_grad = at::cat(input_grads, -1);
5396 tau_grad = at::cat(tau_grads, -1);
5397 } else {
5398 input_grad = at::zeros_like(input_);
5399 tau_grad = at::zeros_like(tau);
5400 for (const auto i_idx : c10::irange(k)) {
5401 auto i = flip_i(i_idx);
5402 // NOTE: narrow will unsqueeze(-1)
5403 auto v_i = input.narrow(-1, i, 1);
5404 auto t_i = tau.narrow(-1, i, 1);
5405
5406 auto [v_i_grad, tau_i_grad] = update_grad(i, v_i, t_i, K);
5407 input_grad.select(-1, i).copy_(v_i_grad.squeeze(-1));
5408 tau_grad.select(-1, i).copy_(tau_i_grad.squeeze(-1));
5409
5410 // K <- H_{i + 1}^{-1} @ K @ H_i
5411 if (i != flip_i(k - 1)) {
5412 auto i_next = next_i(i);
5413 auto v_i_next = input.narrow(-1, i_next, 1);
5414 auto s_i_next = sigma.narrow(-1, i_next, 1);
5415 K = apply_householder_reflector(
5416 i_next, v_i_next, s_i_next, K, /*left=*/apply_left);
5417 K = apply_householder_reflector(i, v_i, t_i, K, /*left=*/!apply_left);
5418 }
5419 }
5420 }
5421
5422 // forward operates only over the lower-triangular part of the input
5423 // excluding the main diagonal, hence the gradient is also lower-triangular.
5424 input_grad.tril_(-1);
5425
5426 return std::make_tuple(input_grad, tau_grad);
5427 }
5428
5429 // We refer to the derivations described above the method
5430 // `apply_simple_transformation`
5431 Tensor householder_product_jvp(
5432 const Tensor& dV_,
5433 const Tensor& dtau,
5434 const Tensor& prod,
5435 const Tensor& V_,
5436 const Tensor& tau) {
5437 auto m = V_.sym_size(-2);
5438 auto k = tau.size(-1);
5439
5440 // forward operates only over the lower triangular part with the assumption
5441 // that the diagonal of input is filled with 1s.
5442 auto V = V_.tril(-1);
5443 V.diagonal(0, -2, -1).fill_(1.0);
5444 auto dV = dV_.tril(-1);
5445
5446 // compute sigma such that
5447 // H(sigma_i) == H(tau_i)^{-1}.
5448 // If the input to householder_product comes from GEQRF,
5449 // we will never encounter ||v_i||^2 tau_i == 1, so H(tau_i) will always be
5450 // invertible. This follows from the documentation
5451 // https://www.netlib.org/lapack/lug/node128.html, and tau always satisfying
5452 // the condition |tau|^2 ||v||^2 == 2 * Re(tau).
5453 auto V_first_k_cols = V.narrow(-1, 0, k);
5454 auto V_first_k_cols_norm_squared =
5455 (V_first_k_cols * V_first_k_cols.conj()).sum(-2);
5456 auto sigma = tau / (tau * V_first_k_cols_norm_squared - 1.0);
5457
5458 auto apply_householder_reflector = [m](const Tensor& v_full,
5459 const Tensor& t,
5460 Tensor& K,
5461 bool left = true) -> Tensor {
5462 return apply_simple_transformation(
5463 // setting `modify_K_in_place = true` causes CUDA memory leaks in OpInfo
5464 // tests of forward AD; for that reason we ignore `k` by passing zero
5465 m,
5466 /*k=*/0,
5467 v_full,
5468 v_full,
5469 t,
5470 K,
5471 /*modify_K_in_place=*/false,
5472 /*condition_with_I=*/true,
5473 left);
5474 };
5475
5476 // computes (-t u v^H) K
5477 auto apply_simple_product = [m](const Tensor& u_full,
5478 const Tensor& v_full,
5479 const Tensor& t,
5480 Tensor& K) -> Tensor {
5481 return apply_simple_transformation(
5482 // since `modify_K_in_place = false`, we can ignore `k` and pass an
5483 // arbitrary value
5484 m,
5485 /*k=*/0,
5486 u_full,
5487 v_full,
5488 t,
5489 K,
5490 /*modify_K_in_place=*/false,
5491 /*condition_with_I=*/false,
5492 /*left=*/true);
5493 };
5494
5495 auto H_plus = prod.detach().clone();
5496 IntArrayRef batch_vector_shape(V.sizes().data(), V.dim() - 1);
5497 auto H_minus =
5498 at::diag_embed(at::ones({1}, V.options()).expand(batch_vector_shape));
5499
5500 auto dprod = at::zeros_like(prod);
5501 for (const auto i : c10::irange(k)) {
5502 auto v_i = V.narrow(-1, i, 1);
5503 auto dv_i = dV.narrow(-1, i, 1);
5504 auto tau_i = tau.narrow(-1, i, 1);
5505 auto dtau_i = dtau.narrow(-1, i, 1);
5506 auto sigma_i = sigma.narrow(-1, i, 1);
5507
5508 H_plus = apply_householder_reflector(v_i, sigma_i, H_plus, /*left=*/true);
5509
5510 // `H_minus_dH_i_H_plus` = H_1 * ... * H_{i-1} dH_i * H_{i+1} * ...
5511 auto H_minus_dH_i_H_plus = H_minus.matmul(
5512 apply_simple_product(v_i, v_i, dtau_i, H_plus) +
5513 apply_simple_product(dv_i, v_i, tau_i, H_plus) +
5514 apply_simple_product(v_i, dv_i, tau_i, H_plus));
5515 // For Composite Compliance, if `intermediate` is a Tensor-Subclass,
5516 // we use out-of-place variant of add.
5517 if (at::isTensorSubclassLike(H_minus_dH_i_H_plus)) {
5518 dprod = dprod.add(H_minus_dH_i_H_plus);
5519 } else {
5520 dprod.add_(H_minus_dH_i_H_plus);
5521 }
5522
5523 H_minus = apply_householder_reflector(v_i, tau_i, H_minus, /*left=*/false);
5524 }
5525
5526 return dprod;
5527 }
5528
5529 std::tuple<Tensor, Tensor, Tensor> ormqr_backward(
5530 const Tensor& grad,
5531 const Tensor& result,
5532 const Tensor& self,
5533 const Tensor& tau,
5534 const Tensor& other,
5535 bool left,
5536 bool transpose,
5537 std::array<bool, 3> grad_output_mask) {
5538 Tensor self_grad, tau_grad, other_grad;
5539
5540 if (!grad.defined()) {
5541 return std::make_tuple(self_grad, tau_grad, other_grad);
5542 }
5543
5544 const auto self_requires_grad = grad_output_mask[0];
5545 const auto tau_requires_grad = grad_output_mask[1];
5546 const auto other_requires_grad = grad_output_mask[2];
5547
5548 if (other_requires_grad) {
5549 other_grad = at::ormqr(self, tau, grad, left, !transpose);
5550 }
5551 if (self_requires_grad || tau_requires_grad) {
5552 if (left ^ transpose) {
5553 // Assume left = true and transpose = false. The case with
5554 // left = false and transpose = true is very similar, with just
5555 // transposed arguments passed into householder_product_backward.
5556 // Ormqr computes B = H_1 * ... * H_k * A.
5557 // The sensitivity wrt H_i is given by (see notes in
5558 // householder_product_backward) Tr(H_i_plus B B_grad^H H_i_minus dH_i),
5559 // so, since householder_product_backward respects `for i in range(k)`, we
5560 // could reuse householder_product_backward with
5561 // householder_product_backward.grad = grad and
5562 // householder_product_backward.result = result.
5563 const auto hpb_grad = !transpose ? grad : grad.mH();
5564 const auto hpb_result = !transpose ? result : result.mH();
5565 std::tie(self_grad, tau_grad) =
5566 householder_product_backward(hpb_grad, hpb_result, self, tau);
5567 } else {
5568 // Assuming left = false and transpose = false. The case with
5569 // left = true and transpose = true is very similar, with just
5570 // transposed arguments passed into householder_product_backward.
5571 // In this case Ormqr computes B = H_1 * ... * H_k * A and the sensitivity
5572 // wrt H_i becomes Tr(H_i_plus B_grad^H B H_i_minus dH_i).
5573 // We could see that the role of `grad` and `result` in
5574 // householder_product_backward gets "swapped" and "transposed" and that
5575 // in order to compute H_k_grad efficiently we would need to compute grads
5576 // in reversed order (`for i in range(k - 1, -1, -1)`). Hence we reuse
5577 // householder_product_backward with householder_product_backward.grad =
5578 // result.mH, householder_product_backward.result = grad.mH,
5579 // householder_product_backward.flip_order = true.
5580 const auto hpb_grad = !transpose ? result.mH() : result;
5581 const auto hpb_result = !transpose ? grad.mH() : grad;
5582 std::tie(self_grad, tau_grad) = householder_product_backward(
5583 hpb_grad, hpb_result, self, tau, /*flip_order=*/true);
5584 }
5585 }
5586
5587 return std::make_tuple(self_grad, tau_grad, other_grad);
5588 }
5589
5590 std::tuple<Tensor, Tensor> polar_backward(
5591 const Tensor& grad,
5592 const Tensor& result) {
5593 Tensor grad_abs, grad_angle;
5594 if (grad.defined()) {
5595 auto grad_conj = grad.conj();
5596 grad_abs = at::real(grad_conj * at::sgn(result));
5597 auto result_mul_1_j = result * Scalar(c10::complex<double>{0.0, 1.0});
5598 grad_angle = at::real(grad_conj * result_mul_1_j);
5599 }
5600 return std::make_tuple(grad_abs, grad_angle);
5601 }
5602
5603 Tensor i1_backward(
5604 const Tensor& grad,
5605 const Tensor& self,
5606 const Tensor& result) {
5607 return AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "i1_backward", [&]() {
5608 // For x = 0, the correct gradient is 0.5;
5609 // however, due to floating point computation we get NaN.
5610 // So we manually update the gradient for x = 0.
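// (Explanatory aside: near x = 0 the series expansion gives
// i1(x) = x/2 + x^3/16 + ..., so the true derivative at 0 is exactly 0.5.)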
5611 auto eps = std::numeric_limits<scalar_t>::epsilon();
5612 auto self_is_not_tiny = self.abs() > eps;
5613
5614 // The following `where` is needed, as `where` computes gradients,
5615 // even for the part which didn't affect the output.
5616 // Look at https://github.com/pytorch/pytorch/issues/52248
5617 // Update if and when this is fixed.
5618 auto safe_self =
5619 at::where(self_is_not_tiny, self, at::full({}, eps, self.options()));
5620 auto gradx = (safe_self.i0() - (result * safe_self.reciprocal()));
5621 return grad *
5622 at::where(self_is_not_tiny, gradx, at::full({}, 0.5, self.options()));
5623 });
5624 }
5625
5626 Tensor i1e_backward(
5627 const Tensor& grad,
5628 const Tensor& self,
5629 const Tensor& result) {
5630 return AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "i1e_backward", [&]() {
5631 // For x = 0, the correct gradient is 0.5;
5632 // however, due to floating point computation we get NaN.
5633 // So we manually update the gradient for x = 0.
5634 auto eps = std::numeric_limits<scalar_t>::epsilon();
5635 auto self_is_not_tiny = self.abs() > eps;
5636
5637 // The following `where` is needed, as `where` computes gradients,
5638 // even for the part which didn't affect the output.
5639 // Look at https://github.com/pytorch/pytorch/issues/52248
5640 // Update if and when this is fixed.
5641 auto safe_self =
5642 at::where(self_is_not_tiny, self, at::full({}, eps, self.options()));
5643 auto gradx =
5644 (at::special_i0e(safe_self) -
5645 result * (safe_self.sgn() + safe_self.reciprocal()));
5646 return grad *
5647 at::where(self_is_not_tiny, gradx, at::full({}, 0.5, self.options()));
5648 });
5649 }
5650
5651 // lu_solve is a map (LU, P, B) -> (PLU)^{-1} B,
5652 // where LU = L + U - I and P is a permutation matrix, and is fixed.
5653 //
5654 // Let 1 = ones_like(LU),
5655 // 1_U = 1.triu(),
5656 // 1_L = 1.tril(-1)
5657 // * := the Hadamard (element-wise) product
5658 //
5659 // Forward AD:
5660 //
5661 // Let X := U^{-1} L^{-1} P^T B be the output of the function.
5662 // Also, the LU input of the function could be represented as
5663 // LU = (L - I) + U.
5664 //
5665 // Differentiating LU = L + U - I produces:
5666 // dLU = dL + dU.
5667 // Noting that dL and dU are lower- and upper-triangular, respectively,
5668 // and that the diagonal of L is never explicitly exposed, so
5669 // diag(dL) = 0, it follows
5670 // dL = dLU * 1_L,
5671 // dU = dLU * 1_U.
5672 //
5673 // Differentiating X = U^{-1} L^{-1} P^T B produces:
5674 // dX = dU^{-1} L^{-1} P^T B + U^{-1} dL^{-1} P^T B + U^{-1} L^{-1} P^T dB
5675 // Note that for any invertible matrix A we have A A^{-1} = I, hence
5676 // dA A^{-1} + A dA^{-1} = 0 => dA^{-1} = -A^{-1} dA A^{-1}.
5677 // Inserting it back into the definition of dX gives:
5678 // dX = -U^{-1} dU U^{-1} L^{-1} P^T B - U^{-1} L^{-1} dL L^{-1} P^T B + U^{-1} L^{-1} P^T dB
5679 // dX = -U^{-1} dU X - U^{-1} L^{-1} dL U X + U^{-1} L^{-1} P^T dB
5680 //
5681 // Backward AD:
5682 //
5683 // Using the definition of dL, dU from above:
5684 // Tr(L_grad^H dL) + Tr(U_grad^H dU)
5685 //     = Tr(L_grad^H (dLU * 1_L)) + Tr(U_grad^H (dLU * 1_U))
5686 //     = [using Tr(A (B * C)) = Tr((A * B^T) C)]
5687 //     = Tr((L_grad^H * 1_L^T) dLU)
5688 //       + Tr((U_grad^H * 1_U^T) dLU),
5689 // hence
5690 // LU_grad = L_grad * 1_L + U_grad * 1_U (!!!)
5691 //
5692 // Then, transposing the formula for dX above we get:
5693 // B_grad = P L^{-H} U^{-H} X_grad = lu_solve(X_grad, LU_data, LU_pivots, /*adjoint=*/true)
5694 // U_grad = -U^{-H} X_grad X^H,    L_grad = L^{-H} U_grad U^H
5695 // After inserting U_grad and L_grad into (!!!) we get the value for LU_grad.
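// For instance (hedged sketch, not part of the library), the B_grad line above
// corresponds to a single adjoint solve such as
//   auto B_grad = at::linalg_lu_solve(LU, pivots, X_grad,
//                                     /*left=*/true, /*adjoint=*/true);
// assuming LU/pivots come from at::linalg_lu_factor and X_grad has a matching
// batch shape.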
5696
5697 Tensor linalg_lu_solve_LU(
5698 const Tensor& gX,
5699 const Tensor& LU,
5700 const Tensor& pivots,
5701 const Tensor& X,
5702 const bool left,
5703 const bool adjoint) {
5704 // From linalg_lu_solve_jvp we have that:
5705 // left = True, adjoint = True: A^HX = B
5706 // left = True, adjoint = False: AX = B
5707 // left = False, adjoint = True: AX^H = B^H
5708 // left = False, adjoint = False: A^HX^H = B^H
5709 // let op_1(A) = A^H or op_1(A) = A according to the list above
5710 // same with op_2(X) and op_3(B)
5711 // We have that letting S = lu_solve(LU, pivots, dB, left, adjoint)
5712 // the JVP formula reads
5713 // if left != adjoint:
5714 // dX = op_2(-U^{-1}(dU + L^{-1}dL U)op_2(X)) + S
5715 // else:
5716 // dX = op_2(op_1(-op_3(X)^H P(LdUU^{-1} + dL)L^{-1} P^T)) + S
5717 // So computing the adjoint of this operation we get that, using an auxiliary
5718 // variable gR if left != adjoint:
5719 // gR = U^{-H}op_2(-gX)op_2(X)^H
5720 // gU = gR.triu()
5721 // gL = (L^{-H} gR U^H).tril(-1)
5722 // else:
5723 // gR = -P^T op_3(X)op_1(op_2(gX))PL^{-H}
5724 // gL = gR.tril(-1)
5725 // gU = (L^H gR U^{-H}).triu()
5726 // gLU = gL + gU
5727
5728 at::NoTF32Guard disable_tf32;
5729 auto [P, L, U] = at::lu_unpack(
5730 LU, pivots, /*unpack_data=*/true, /*unpack_pivots=*/left == adjoint);
5731 // TODO Optimise the order of the operations to avoid operating on large
5732 // tensors unnecessarily
5733 // The logic should be: if n < k == left, then multiply gX and X first
5734 // (as is done now); otherwise multiply them last
5735 if (left != adjoint) {
5736 // gR = U^{-H}op_2(-gX)op_2(X)^H
5737 auto gR = at::linalg_solve_triangular(
5738 U.mH(),
5739 -(left ? gX : gX.mH()).matmul(left ? X.mH() : X),
5740 /*upper*/ false);
5741 // gL = (L^{-H} gR U^H).tril(-1)
5742 auto gL = at::linalg_solve_triangular(
5743 L.mH(),
5744 gR.matmul(U.mH()),
5745 /*upper*/ true,
5746 /*left*/ true,
5747 /*unitriangular*/ true)
5748                      .tril(-1);
5750 return gL + gR.triu();
5751 } else {
5752 // gR = -P^T op_3(X)op_1(op_2(gX))P
5753 auto gR =
5754 -P.mT().matmul(left ? X : X.mH()).matmul(left ? gX.mH() : gX).matmul(P);
5755 // gR = gR L^{-H}
5756 gR = at::linalg_solve_triangular(
5757 L.mH(), gR, /*upper*/ true, /*left*/ false, /*unitriangular*/ true);
5758 // gU = (L^H gR U^{-H}).triu()
5759 auto gU = at::linalg_solve_triangular(
5760 U.mH(), L.mH().matmul(gR), /*upper*/ false, /*left*/ false)
5761 .triu();
5762 return gR.tril(-1) + gU;
5763 }
5764 }
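
// [Editor's illustrative sketch; not part of the generated code and not
// called anywhere.] Shows the inputs linalg_lu_solve_LU expects: LU/pivots
// come from at::linalg_lu_factor, X is the primal output of
// at::linalg_lu_solve, and gX is the incoming cotangent. The helper name is
// hypothetical.
[[maybe_unused]] static Tensor _sketch_linalg_lu_solve_LU_usage() {
  const auto A = at::randn({4, 4}, at::kDouble);
  const auto B = at::randn({4, 2}, at::kDouble);
  auto [LU, pivots] = at::linalg_lu_factor(A);
  const auto X = at::linalg_lu_solve(LU, pivots, B);
  const auto gX = at::randn_like(X);
  // Cotangent w.r.t. the packed LU factor for left=true, adjoint=false.
  return linalg_lu_solve_LU(gX, LU, pivots, X, /*left=*/true, /*adjoint=*/false);
}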
5765
5766 Tensor linalg_lu_solve_jvp(
5767 const Tensor& X,
5768 const Tensor& LU,
5769 const Tensor& pivots,
5770 const Tensor& dLU,
5771 const Tensor& dB,
5772 const bool left,
5773 const bool adjoint) {
5774 // We write the derivation in terms of some adjoint operations, as otherwise
5775 // we would need to write down 4 different proofs with 4 different
5776 // implementations, which would be painful to derive and maintain. Below we
5777 // just use that X -> X^H is linear, so it commutes with the derivative. The
5778 // derivation follows by differentiating op_1(PLU)op_2(X) = op_3(B).
5779
5780 // left = True, adjoint = True: A^HX = B
5781 // left = True, adjoint = False: AX = B
5782 // left = False, adjoint = True: AX^H = B^H
5783 // left = False, adjoint = False: A^HX^H = B^H
5784 // let op_1(A) = A^H or op_1(A) = A according to the list above
5785 // same with op_2(X) and op_3(B)
5786 // We have that letting S = lu_solve(LU, pivots, dB, left, adjoint)
5787 // the JVP formula reads
5788 // dX = op_2(op_1(-U^{-1}(dUU^{-1} + L^{-1}dL)L^{-1} P^T)op_3(B)) + S
5789
5790 at::NoTF32Guard disable_tf32;
5791 auto S = at::linalg_lu_solve(LU, pivots, dB, left, adjoint);
5792 if (left != adjoint) {
5793 // We see that when left != adjoint, op_1(A) = A, and we can substitute
5794 // A^{-1}op_3(B) by op_2(X), so dX = op_2(-U^{-1}(dU + L^{-1}dL U)op_2(X)) + S
5795 // Let R = -U^{-1}(dU + L^{-1}dL U)
5796 auto R = at::linalg_solve_triangular(
5797 LU,
5798 dLU.tril(-1),
5799 /*upper*/ false,
5800 /*left*/ true,
5801 /*unitriangular*/ true);
5802 auto U = LU.triu();
5803 R = -at::linalg_solve_triangular(
5804 U, dLU.triu() + R.matmul(U), /*upper*/ true);
5805 // dX = op_2(R op_2(X)) + S
5806 return (left ? R.matmul(X) : X.matmul(R.mH())) + S;
5807 } else {
5808 // We see that when left == adjoint, op_1(A) = A^H
5809 // dX = op_2(op_1(-op_3(B)^H U^{-1}(dUU^{-1} + L^{-1}dL)L^{-1} P^T)) + S
5810 // Now, note that whenever adjoint == left, we have that
5811 // op_3(B)^H A^{-1} = op_3(X)^H
5812 // We can then rewrite the formula above in terms of X as
5813 // dX = op_2(op_1(-op_3(X)^H P(LdUU^{-1} + dL)L^{-1} P^T)) + S
5814 auto [P, L, U] = at::lu_unpack(LU, pivots);
5815 // Compute V = op_3(X)^H
5816 auto V = left ? X.mH() : X;
5817 // Compute the inner parens LdUU^{-1} + dL
5818 auto R = at::linalg_solve_triangular(
5819 U, L.matmul(dLU.triu()), /*upper*/ true, /*left*/ false) +
5820 dLU.tril(-1);
5821 // dX = op_2(op_1(-op_3(X)^H PRL^{-1} P^T)) + S
5822 R = at::linalg_solve_triangular(
5823 L,
5824 -V.matmul(P).matmul(R),
5825 /*upper*/ false,
5826 /*left*/ false,
5827 /*unitriangular*/ true)
5828 .matmul(P.mT());
5829 // dX = op_2(R^H) + S
5830 return (left ? R.mH() : std::move(R)) + S;
5831 }
5832 }
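
// [Editor's illustrative sketch; not wired into the generated code.] A
// finite-difference sanity check of the JVP above: moving the packed LU
// factor and B along (dLU, dB) should move the solution by approximately
// linalg_lu_solve_jvp(...). Sizes and tolerances are arbitrary; the helper
// name is hypothetical.
[[maybe_unused]] static bool _sketch_check_linalg_lu_solve_jvp() {
  const double eps = 1e-6;
  const auto A = at::randn({3, 3}, at::kDouble);
  const auto B = at::randn({3, 2}, at::kDouble);
  auto [LU, pivots] = at::linalg_lu_factor(A);
  const auto dLU = at::randn_like(LU);
  const auto dB = at::randn_like(B);
  const auto X = at::linalg_lu_solve(LU, pivots, B);
  const auto dX =
      linalg_lu_solve_jvp(X, LU, pivots, dLU, dB, /*left=*/true, /*adjoint=*/false);
  const auto X_pert = at::linalg_lu_solve(LU + eps * dLU, pivots, B + eps * dB);
  return at::allclose((X_pert - X) / eps, dX, /*rtol=*/1e-3, /*atol=*/1e-3);
}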
5833
5834 Tensor linalg_solve_jvp(
5835 const Tensor& dA,
5836 const Tensor& dB,
5837 const Tensor& X,
5838 const Tensor& LU,
5839 const Tensor& pivots,
5840 const bool left,
5841 const bool use_A_T) {
5842 at::NoTF32Guard disable_tf32;
5843 // For left=True (left=False is analogous)
5844 // dX = A^{-1}(dB - dAX)
5845
5846 // [NumPy compat] Case where the rhs is a vector.
5847 // We denote with an underscore vectors that have been converted to matrices
5848 // by `unsqueeze(-1)`
5849 const bool vector_case = at::native::linalg_solve_is_vector_rhs(LU, X);
5850 const auto vector_to_matrix = [vector_case](const Tensor& X) {
5851 return vector_case ? X.unsqueeze(-1) : X;
5852 };
5853 const auto matrix_to_vector = [vector_case](const Tensor& X) {
5854 return vector_case ? X.squeeze(-1) : X;
5855 };
5856
5857 // This case is disallowed in the primal operation as A.shape = (*, 1, 1)
5858 TORCH_INTERNAL_ASSERT(left || !vector_case);
5859
5860 auto X_ = vector_to_matrix(X);
5861 auto dB_ = vector_to_matrix(dB);
5862 auto R_ = left ? dA.matmul(X_) : X_.matmul(dA);
5863 auto dX_ =
5864 at::linalg_lu_solve(LU, pivots, dB_ - R_, left, /*adjoint*/ use_A_T);
5865 return matrix_to_vector(dX_);
5866 }
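
// [Editor's illustrative sketch; not wired into the generated code.] The
// left=true formula above, dX = A^{-1}(dB - dA X), written with a plain
// solve. LU/pivots factor A itself here, so use_A_T is false. The helper
// name is hypothetical.
[[maybe_unused]] static bool _sketch_check_linalg_solve_jvp() {
  const auto A = at::randn({3, 3}, at::kDouble);
  const auto B = at::randn({3, 2}, at::kDouble);
  const auto dA = at::randn_like(A);
  const auto dB = at::randn_like(B);
  const auto X = at::linalg_solve(A, B);
  auto [LU, pivots] = at::linalg_lu_factor(A);
  const auto dX =
      linalg_solve_jvp(dA, dB, X, LU, pivots, /*left=*/true, /*use_A_T=*/false);
  return at::allclose(dX, at::linalg_solve(A, dB - dA.matmul(X)),
                      /*rtol=*/1e-4, /*atol=*/1e-6);
}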
5867
5868 std::tuple<Tensor, Tensor> linalg_solve_backward(
5869 const Tensor& gX,
5870 const Tensor& X,
5871 const Tensor& A,
5872 const Tensor& LU,
5873 const Tensor& pivots,
5874 const bool left,
5875 const bool B_requires_grad) {
5876 // for X = A^{-1}B
5877 // gB = A^{-H}gX
5878 // gA = -gB X^H
5879 at::NoTF32Guard disable_tf32;
5880 const auto A_requires_grad = A.requires_grad();
5881 if (!gX.defined() || (!A_requires_grad && !B_requires_grad)) {
5882 return {};
5883 }
5884
5885 // [NumPy compat] Case where the rhs is a vector.
5886 // We denote with an underscore vectors that have been converted to matrices
5887 // by `unsqueeze(-1)`
5888 const bool vector_case = at::native::linalg_solve_is_vector_rhs(LU, X);
5889 const auto vector_to_matrix = [vector_case](const Tensor& X) {
5890 return vector_case ? X.unsqueeze(-1) : X;
5891 };
5892 const auto matrix_to_vector = [vector_case](const Tensor& X) {
5893 return vector_case ? X.squeeze(-1) : X;
5894 };
5895
5896 // If the user is going to compute higher order gradients, then we need to
5897 // recompute the LU and the pivots
5898 Tensor gB_;
5899 if (at::GradMode::is_enabled()) {
5900 gB_ = at::linalg_solve(A.mH(), vector_to_matrix(gX), left);
5901 } else {
5902 const auto use_A_T = A.is_contiguous() && !A.is_complex();
5903 gB_ = at::linalg_lu_solve(
5904 LU, pivots, vector_to_matrix(gX), left, /*adjoint*/ !use_A_T);
5905 }
5906
5907 Tensor gA_;
5908 if (A_requires_grad) {
5909 auto X_ = vector_to_matrix(X);
5910 gA_ = left ? -gB_.matmul(X_.mH()) : -X_.mH().matmul(gB_);
5911 }
5912 return std::make_tuple(
5913 A_requires_grad ? std::move(gA_) : Tensor{},
5914 B_requires_grad ? matrix_to_vector(gB_) : Tensor{});
5915 }
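
// [Editor's illustrative sketch; not wired into the generated code.] The
// cotangent formulas in the comment above for left=true, spelled out with
// plain solves: gB = A^{-H} gX and gA = -gB X^H. The helper name is
// hypothetical.
[[maybe_unused]] static std::tuple<Tensor, Tensor>
_sketch_linalg_solve_backward_formulas() {
  const auto A = at::randn({3, 3}, at::kDouble);
  const auto B = at::randn({3, 2}, at::kDouble);
  const auto X = at::linalg_solve(A, B);
  const auto gX = at::randn_like(X);
  const auto gB = at::linalg_solve(A.mH(), gX);
  const auto gA = -gB.matmul(X.mH());
  return std::make_tuple(gA, gB);
}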
5916
5917 Tensor solve_jvp(
5918 const Tensor& X,
5919 const Tensor& A,
5920 const Tensor& dA,
5921 const Tensor& dB) {
5922 return generic_solve_jvp(
5923 [](const Tensor& A, const Tensor& dB, const Tensor& dA_contrib) {
5924 return at::linalg_solve(A, dB - dA_contrib);
5925 },
5926 X,
5927 A,
5928 dA,
5929 dB);
5930 }
5931
5932 Tensor lu_unpack_backward(
5933 const Tensor& L_grad,
5934 const Tensor& U_grad,
5935 const c10::SymInt& m,
5936 const c10::SymInt& n) {
5937 if (!L_grad.defined() && !U_grad.defined()) {
5938 return {};
5939 }
5940 const auto k = std::min(m, n);
5941
5942 // Getters for the principal and complementary part of the matrices
5943 const auto get_L1 = [m, k](const Tensor& L) {
5944 return m == k ? L.tril(-1) : L.narrow_symint(-2, 0, k).tril(-1);
5945 };
5946 const auto get_L2 = [m, k](const Tensor& L) {
5947 return L.narrow_symint(-2, k, m - k);
5948 };
5949 const auto get_U1 = [n, k](const Tensor& U) {
5950 return n == k ? U.triu() : U.narrow_symint(-1, 0, k).triu();
5951 };
5952 const auto get_U2 = [n, k](const Tensor& U) {
5953 return U.narrow_symint(-1, k, n - k);
5954 };
5955
5956 if (L_grad.defined()) {
5957 if (U_grad.defined()) {
5958 if (m == n) {
5959 return L_grad.tril(-1) + U_grad.triu();
5960 } else {
5961 auto A1_grad = get_L1(L_grad) + get_U1(U_grad);
5962 auto A2_grad = m > n ? get_L2(L_grad) : get_U2(U_grad);
5963 const auto dim = m > n ? -2 : -1;
5964 return at::cat({std::move(A1_grad), std::move(A2_grad)}, /*dim=*/dim);
5965 }
5966 } else {
5967 if (m >= n) {
5968 return L_grad.tril(-1);
5969 } else {
5970 auto size = L_grad.sym_sizes().vec();
5971 size.end()[-1] = n - m;
5972 return at::cat(
5973 {L_grad.tril(-1), at::zeros_symint(size, L_grad.options())},
5974 /*dim=*/-1);
5975 }
5976 }
5977 } else {
5978 if (n >= m) {
5979 return U_grad.triu();
5980 } else {
5981 auto size = U_grad.sym_sizes().vec();
5982 size.end()[-2] = m - n;
5983 return at::cat(
5984 {U_grad.triu(), at::zeros_symint(size, U_grad.options())},
5985 /*dim=*/-2);
5986 }
5987 }
5988 }
5989
5990 Tensor cat_jvp(const at::ITensorListRef& tensors, int64_t dim) {
5991 Tensor out_fw_grad;
5992
5993 auto materialized = tensors.materialize();
5994 auto any_defined = false;
5995 for (const Tensor& t : materialized) {
5996 any_defined |= isFwGradDefined(t);
5997 }
5998
5999 if (any_defined) {
6000 std::vector<Tensor> fw_grads;
6001
6002 for (const Tensor& t : materialized) {
6003 fw_grads.push_back(
6004 isFwGradDefined(t)
6005 ? t._fw_grad(/*level*/ 0)
6006 : at::_efficientzerotensor(t.sizes(), t.options()));
6007 }
6008
6009 out_fw_grad = at::cat(fw_grads, dim);
6010 }
6011
6012 return out_fw_grad;
6013 }
6014
6015 Tensor block_diag_jvp(at::TensorList tensors) {
6016 Tensor out_fw_grad;
6017
6018 auto any_defined = false;
6019 for (const auto& t : tensors) {
6020 any_defined |= isFwGradDefined(t);
6021 }
6022
6023 if (any_defined) {
6024 std::vector<Tensor> fw_grads;
6025 fw_grads.reserve(tensors.size());
6026
6027 for (const auto& t : tensors) {
6028 fw_grads.push_back(
6029 isFwGradDefined(t)
6030 ? t._fw_grad(/*level*/ 0)
6031 : at::_efficientzerotensor(t.sizes(), t.options()));
6032 }
6033
6034 out_fw_grad = at::block_diag(fw_grads);
6035 }
6036
6037 return out_fw_grad;
6038 }
6039
6040 Tensor stack_jvp(at::TensorList tensors, int64_t dim) {
6041 // Basically copy of cat_jvp above
6042 // TODO: consolidate with the logic of cat_jvp
6043 Tensor out_fw_grad;
6044
6045 auto any_defined = false;
6046 for (const auto& t : tensors) {
6047 any_defined |= isFwGradDefined(t);
6048 }
6049
6050 if (any_defined) {
6051 std::vector<Tensor> fw_grads;
6052
6053 for (auto& t : tensors) {
6054 fw_grads.push_back(
6055 isFwGradDefined(t)
6056 ? t._fw_grad(/*level*/ 0)
6057 : at::_efficientzerotensor(t.sizes(), t.options()));
6058 }
6059 out_fw_grad = at::stack(fw_grads, dim);
6060 }
6061 return out_fw_grad;
6062 }
6063
6064 Tensor cumprod_jvp(
6065 const Tensor& self_t,
6066 const Tensor& self_p,
6067 const Tensor& result,
6068 int dim) {
6069 // Generic formula when no 0. is involved
6070 Tensor gradient = (self_t / self_p).cumsum(dim) * result;
6071
6072 // Note that we have to use at::where below as we are removing nans
6073
6074 if (self_p.dim() == 0) {
6075 gradient.masked_fill_(self_p.eq(0), self_t);
6076 return gradient;
6077 } else {
6078 // For input (a, 0, b, 0, c) with gradients (t0, t1, t2, t3, t4)
6079 // The output of cumprod is (a, 0, 0, 0, 0)
6080 // The gradient we want to compute is (t0, a*t1, a*b*t1, 0, 0)
6081 // We do this by:
6082 // Get a mask of all zeros (0, 1, 0, 1, 0)
6083 auto mask_zeros = self_p.eq(0);
6084 // Get a mask of the first zero for each dim (0, 1, 0, 0, 0)
6085 auto mask_first_zero = mask_zeros.logical_and(mask_zeros.cumsum(dim).eq(1));
6086
6087 // Get the new grad value that should be used after any zero happened:
6088 // (X, a*t1, a*b*t1, 0, 0) = cumprod((a, t1, b, 0, c))
6089 auto new_grad = at::where(mask_first_zero, self_t, self_p).cumprod(dim);
6090
6091 // Get a mask of everything after the first zero: (0, 1, 1, 1, 1)
6092 auto mask_after_first_zero = mask_first_zero.cumsum(dim);
6093
6094 // Do the final replacement
6095 return at::where(
6096 mask_after_first_zero.to(ScalarType::Bool), new_grad, gradient);
6097 }
6098 }
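
// [Editor's illustrative sketch; not wired into the generated code.] A
// finite-difference check of the zero-handling path documented above, using
// an input that contains zeros. The helper name is hypothetical.
[[maybe_unused]] static bool _sketch_check_cumprod_jvp_with_zeros() {
  const double eps = 1e-6;
  auto self_p = at::ones({5}, at::kDouble);
  self_p.narrow(0, 1, 1).zero_();  // input is (1, 0, 1, 0, 1)
  self_p.narrow(0, 3, 1).zero_();
  const auto self_t = at::randn_like(self_p);
  const auto result = self_p.cumprod(0);
  const auto jvp = cumprod_jvp(self_t, self_p, result, /*dim=*/0);
  const auto fd = ((self_p + eps * self_t).cumprod(0) - result) / eps;
  return at::allclose(jvp, fd, /*rtol=*/1e-3, /*atol=*/1e-3);
}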
6099
6100 // Helper for {batch,layer,group}_norms below
6101 // Computes the jvp for `1 / input.std(dims, keepdim)`
6102 static Tensor _invstd_jvp(
6103 const Tensor& input_p,
6104 const Tensor& input_t,
6105 const Tensor& mean_p,
6106 const Tensor& invstd_p,
6107 IntArrayRef dims,
6108 int64_t numel,
6109 bool keepdim) {
6110 Tensor invstd_t;
6111 if (areAnyTensorSubclassLike({input_t, input_p, mean_p, invstd_p}) ||
6112 input_t._is_zerotensor()) {
6113 invstd_t = -invstd_p.pow(3) * (input_t - input_t.mean(dims, true)) *
6114 (input_p - mean_p);
6115 } else {
6116 invstd_t = input_t - input_t.mean(dims, true);
6117 invstd_t *= input_p - mean_p;
6118 invstd_t *= -invstd_p.pow(3);
6119 }
6120 invstd_t = invstd_t.sum(dims, keepdim);
6121 invstd_t /= numel;
6122 return invstd_t;
6123 }
6124
6125 // Helper for {batch,layer,group}_norms below only
6126 // Computes the jvp for `(input - input.mean(dims)) * input.invstd(dims)`
6127 static Tensor _norm_jvp(
6128 const Tensor& input_p,
6129 const Tensor& input_t,
6130 const Tensor& mean_p,
6131 const Tensor& invstd_p,
6132 IntArrayRef dims,
6133 int64_t numel) {
6134 auto invstd_t =
6135 _invstd_jvp(input_p, input_t, mean_p, invstd_p, dims, numel, true);
6136 Tensor result_t;
6137 if (areAnyTensorSubclassLike({input_t, input_p, mean_p, invstd_p}) ||
6138 input_t._is_zerotensor()) {
6139 result_t = (input_t - input_t.mean(dims, true)) * invstd_p +
6140 (input_p - mean_p) * invstd_t;
6141 } else {
6142 result_t = input_t - input_t.mean(dims, true);
6143 result_t *= invstd_p;
6144 auto temp = input_p - mean_p;
6145 temp *= invstd_t;
6146 result_t += temp;
6147 }
6148 return result_t;
6149 }
6150
6151 // Helper for {batch,layer,group}_norms below only
6152 // Computes the jvp for `input * weight + bias` where weight and bias may be
6153 // undefined. Possibly modifies input_t in place.
6154 static Tensor _affine_jvp(
6155 const std::optional<Tensor>& input_p,
6156 Tensor& input_t,
6157 const Tensor& weight_p,
6158 const Tensor& weight_t,
6159 const Tensor& bias_t) {
6160 // We allow input_p to be optional because if weight_p isn't defined,
6161 // it may be possible to avoid computing input_p
6162 TORCH_INTERNAL_ASSERT(input_p.has_value() == weight_p.defined());
6163 if (weight_p.defined()) {
6164 if (areAnyTensorSubclassLike(
6165 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
6166 {input_p.value(), input_t, weight_p, weight_t}) ||
6167 input_t._is_zerotensor() || weight_t._is_zerotensor()) {
6168 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
6169 input_t = input_t * weight_p + input_p.value() * weight_t;
6170 } else {
6171 input_t *= weight_p;
6172 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
6173 auto temp = input_p.value();
6174 temp *= weight_t;
6175 input_t += temp;
6176 }
6177 }
6178 if (bias_t.defined()) {
6179 if (areAnyTensorSubclassLike({input_t, bias_t}) ||
6180 input_t._is_zerotensor()) {
6181 input_t = input_t + bias_t;
6182 } else {
6183 input_t += bias_t;
6184 }
6185 }
6186 return input_t;
6187 }
6188
6189 Tensor batch_norm_jvp(
6190 const Tensor& input_p,
6191 const Tensor& input_t,
6192 const Tensor& weight_p,
6193 const Tensor& weight_t,
6194 const Tensor& bias_p,
6195 const Tensor& bias_t,
6196 const std::optional<Tensor>& running_mean,
6197 const std::optional<Tensor>& running_var,
6198 const Tensor& saved_mean,
6199 const Tensor& saved_invstd,
6200 bool train,
6201 double eps) {
6202 auto dims = std::vector<int64_t>{};
6203 auto view_size = input_t.sizes().vec();
6204 int64_t numel = 1;
6205 for (const auto dim : c10::irange(view_size.size())) {
6206 if (dim != 1) {
6207 numel *= input_t.size(static_cast<int64_t>(dim));
6208 view_size[dim] = 1;
6209 dims.push_back(static_cast<int64_t>(dim));
6210 }
6211 }
6212 Tensor mean_p;
6213 Tensor invstd_p;
6214 Tensor result_t;
6215 if (train) {
6216 mean_p = saved_mean.view(view_size);
6217 invstd_p = saved_invstd.view(view_size);
6218 result_t = _norm_jvp(input_p, input_t, mean_p, invstd_p, dims, numel);
6219 } else {
6220 TORCH_INTERNAL_ASSERT(
6221 running_mean.has_value() && running_var.has_value(),
6222 "Expect running_mean and running_var to have value when train=false");
6223 TORCH_CHECK(
6224 !running_mean.value()._fw_grad(/*level=*/0).defined() &&
6225 !running_var.value()._fw_grad(/*level=*/0).defined(),
6226 "batch_norm is not differentiable wrt running_mean and running_var, they cannot have forward grad defined");
6227 mean_p = running_mean.value().view(view_size);
6228 invstd_p =
6229 (1 / at::sqrt(running_var.value() + at::Scalar(eps))).view(view_size);
6230 result_t = input_t * invstd_p;
6231 }
6232
6233 std::optional<Tensor> result_p = weight_p.defined()
6234 ? std::optional<Tensor>((input_p - mean_p) * invstd_p)
6235 : std::nullopt;
6236 return _affine_jvp(
6237 result_p,
6238 result_t,
6239 weight_p.defined() ? weight_p.view(view_size) : weight_p,
6240 weight_t.defined() ? weight_t.view(view_size) : weight_t,
6241 bias_t.defined() ? bias_t.view(view_size) : bias_t);
6242 }
6243
6244 Tensor layer_norm_jvp(
6245 const Tensor& input_p,
6246 const Tensor& input_t,
6247 const Tensor& weight_p,
6248 const Tensor& weight_t,
6249 const Tensor& bias_p,
6250 const Tensor& bias_t,
6251 const Tensor& saved_mean,
6252 const Tensor& saved_invstd,
6253 c10::SymIntArrayRef normalized_shape) {
6254 auto dims = std::vector<int64_t>{};
6255 auto view_size = input_t.sizes().vec();
6256 auto view_size_affine = input_t.sizes().vec();
6257
6258 int64_t numel = 1;
6259 for (const auto i : c10::irange(view_size.size())) {
6260 if (i < view_size.size() - normalized_shape.size()) {
6261 view_size_affine[i] = 1;
6262 } else {
6263 numel *= input_t.size(static_cast<int64_t>(i));
6264 view_size[i] = 1;
6265 dims.push_back(static_cast<int64_t>(i));
6266 }
6267 }
6268 auto mean_p = saved_mean.view(view_size);
6269 auto invstd_p = saved_invstd.view(view_size);
6270 auto result_t = _norm_jvp(input_p, input_t, mean_p, invstd_p, dims, numel);
6271
6272 std::optional<Tensor> result_p = weight_p.defined()
6273 ? std::optional<Tensor>((input_p - mean_p) * invstd_p)
6274 : std::nullopt;
6275 return _affine_jvp(
6276 result_p,
6277 result_t,
6278 weight_p.defined() ? weight_p.view(view_size_affine) : weight_p,
6279 weight_t.defined() ? weight_t.view(view_size_affine) : weight_t,
6280 bias_t.defined() ? bias_t.view(view_size_affine) : bias_t);
6281 }
6282
6283 Tensor group_norm_jvp(
6284 const Tensor& input_p,
6285 const Tensor& input_t,
6286 const Tensor& weight_p,
6287 const Tensor& weight_t,
6288 const Tensor& bias_p,
6289 const Tensor& bias_t,
6290 const Tensor& saved_mean,
6291 const Tensor& saved_invstd,
6292 int64_t groups) {
6293 auto input_shape = input_p.sizes();
6294 int64_t N = input_p.size(0);
6295 int64_t C = input_p.size(1);
6296
6297 auto input_t_reshaped = input_t.view({1, N * groups, N ? -1 : 1});
6298 auto input_p_reshaped = input_p.view({1, N * groups, N ? -1 : 1});
6299
6300 auto result_t = batch_norm_jvp(
6301 input_p_reshaped,
6302 input_t_reshaped,
6303 /*weight_p=*/{},
6304 /*weight_t=*/{},
6305 /*bias_p=*/{},
6306 /*bias_t=*/{},
6307 /*running_mean=*/{},
6308 /*running_var=*/{},
6309 saved_mean,
6310 saved_invstd,
6311 /*train=*/true,
6312 /*eps=*/0)
6313 .view(input_shape);
6314
6315 std::optional<Tensor> result_p = std::nullopt;
6316 if (weight_p.defined()) {
6317 std::vector<int64_t> view_size(input_t_reshaped.dim(), 1);
6318 view_size[1] = input_t_reshaped.size(1);
6319 result_p = ((input_p_reshaped - saved_mean.view(view_size)) *
6320 saved_invstd.view(view_size))
6321 .view(input_shape);
6322 }
6323 std::vector<int64_t> affine_param_shape(input_p.dim(), 1);
6324 affine_param_shape[1] = C;
6325
6326 return _affine_jvp(
6327 result_p,
6328 result_t,
6329 weight_p.defined() ? weight_p.view(affine_param_shape) : weight_p,
6330 weight_t.defined() ? weight_t.view(affine_param_shape) : weight_t,
6331 bias_t.defined() ? bias_t.view(affine_param_shape) : bias_t);
6332 }
6333
6334 Tensor group_norm_mean_jvp(
6335 const Tensor& input_t,
6336 const Tensor& mean_p,
6337 int64_t groups) {
6338 int64_t N = input_t.size(0);
6339 std::array<int64_t, 3> view_shape = {1, N * groups, N ? -1 : 1};
6340 auto input_t_reshaped = input_t.view(view_shape);
6341 return input_t_reshaped.mean({2}, false).view_as(mean_p);
6342 }
6343
6344 Tensor group_norm_invstd_jvp(
6345 const Tensor& input_p,
6346 const Tensor& input_t,
6347 const Tensor& mean_p,
6348 const Tensor& invstd_p,
6349 int64_t groups) {
6350 int64_t N = input_p.size(0);
6351
6352 std::vector<int64_t> view_shape = {1, N * groups, N ? -1 : 1};
6353
6354 auto input_t_reshaped = input_t.view(view_shape);
6355 auto input_p_reshaped = input_p.view(view_shape);
6356
6357 return _invstd_jvp(
6358 input_t_reshaped,
6359 input_p_reshaped,
6360 mean_p.view(view_shape),
6361 invstd_p.view(view_shape),
6362 /*dims=*/{2},
6363 /*numel=*/input_t_reshaped.size(2),
6364 /*keepdim=*/false)
6365 .view_as(invstd_p);
6366 }
6367
6368 Tensor gather_with_keepdimed_indices(
6369 const Tensor& input,
6370 int64_t dim,
6371 const Tensor& indices,
6372 bool keepdim) {
6373 auto full_indices = indices;
6374 if (!keepdim) {
6375 full_indices = indices.unsqueeze(dim);
6376 }
6377 auto out_fw_grad = at::gather(input, dim, full_indices);
6378 if (!keepdim) {
6379 out_fw_grad = out_fw_grad.squeeze(dim);
6380 }
6381
6382 return out_fw_grad;
6383 }
6384
6385 // Let A in \C^{m \times n}, then its pivoted LU decomposition is
6386 // A = P L U, where P is a permutation matrix.
6387 //
6388 // Useful notation:
6389 // Let o denote the elementwise, or Hadamard, product.
6390 // k := min(m, n)
6391 // 1 := ones(k, k),
6392 // 1_U = 1.triu();
6393 // 1_L = 1 - 1_U (note the diagonal is zero)
6394 // For a matrix A, A^H := A.mH()
6395 //
6396 // Below we derive the backward algorithm for the case when m <= n.
6397 // The case m > n could be obtained using the same idea.
6398 // Since we assume m <= n, the LU decomposition of A could be written as
6399 // A = (A1 | A2) = P L (U1 | U2) where A1, U1 in \C^{m \times m}, A2, U2 in
6400 // \C^{m, n - m}
6401 //
6402 // Forward AD:
6403 //
6404 // dA = P dL U + P L dU => [left-multiply P^T]
6405 // (P^T dA1 | P^T dA2) = (dL U1 + L dU1 | dL U2 + L dU2) (*)
6406 // From (*):
6407 // P^T dA1 = dL U1 + L dU1 => [left-multiply by L^{-1}, right-multiply by U1^{-1}]
6408 // L^{-1} P^T dA1 U1^{-1} = L^{-1} dL + dU1 U1^{-1} (**).
6409 // Note that L is lower-triangular, and so is its inverse, hence L^{-1} dL is lower-triangular.
6410 // Also, since the diagonal of L (all ones) is never exposed explicitly (packed
6411 // representation), the diagonal of dL is zero, and hence diag(L^{-1} dL) = 0.
6412 // Assuming that U1 is full-rank, similarly, dU1 U1^{-1} is upper-triangular.
6413 // Combining these observations we conclude:
6414 //
6415 // L^{-1} dL = (L^{-1} P^T dA1 U1^{-1}) o 1_L,
6416 // dU1 U1^{-1} = (L^{-1} P^T dA1 U1^{-1}) o 1_U.
6417 //
6418 // Hence,
6419 // dL = L [(L^{-1} P^T dA1 U1^{-1}) o 1_L],
6420 // dU1 = [(L^{-1} P^T dA1 U1^{-1}) o 1_U] U1.
6421 // As for dU2, from (*) it follows
6422 // P^T dA2 = dL U2 + L dU2 =>
6423 // dU2 = L^{-1} (P^T dA2 - dL U2).
6424 //
6425 // Backward AD:
6426 //
6427 // The following equality comes very handy:
6428 // Tr(A (B o C)) = Tr((A o B^T) C) (!)
6429 // or in other words, given that X -> B o X is a pointwise operation
6430 // its Jacobian is diagonal, so its differential is self-adjoint
6431 // <A, B o C> = <A o B, C>
6432 //
6433 // Tr(A_grad^H dA) = Tr(L_grad^H dL) + Tr(U_grad^H dU), then
6434 //
6435 // Tr(L_grad^H dL) = Tr(L_grad^H L [(L^{-1} P^T dA1 U1^{-1}) o 1_L])
6436 //                 = [using (!)]
6437 //                 = Tr((L_grad^H L o 1_L^T) L^{-1} P^T dA1 U1^{-1})
6438 //                 = [using the cyclic property of Tr]
6439 //                 = Tr(U1^{-1} (L_grad^H L o 1_L^T) L^{-1} P^T dA1)
6440 //
6441 // Similarly, using (!) and the cyclic property of the trace operator:
6442 // Tr(U_grad^H dU) = Tr(U1_grad^H dU1) + Tr(U2_grad^H dU2)
6443 // = Tr(U1^{-1} (U1 U1_grad^H o 1_U^T) L^{-1} P^T dA1)
6444 // + Tr(U2_grad^H L^{-1} P^T dA2)
6445 // - Tr(U1^{-1} (U2 U2_grad^H o 1_L^T) L^{-1} P^T dA1)
6446 //
6447 // By combining the matrices to the left from dA1 and dA2 and then applying
6448 // conjugate transposition, we finally arrive at:
6449 //
6450 // A1_grad = P L^{-H} [L^H L_grad o 1_L + U1_grad U1^H o 1_U - U2_grad U2^H o 1_L] U1^{-H}
6451 // A2_grad = P L^{-H} U2_grad
6452 Tensor linalg_lu_backward(
6453 const Tensor& L_grad,
6454 const Tensor& U_grad,
6455 const Tensor& P,
6456 const Tensor& L,
6457 const Tensor& U,
6458 const bool pivot) {
6459 at::NoTF32Guard disable_tf32;
6460 // Return early if there's nothing to do
6461 if (!L_grad.defined() && !U_grad.defined()) {
6462 return {};
6463 }
6464
6465 // L.shape == (..., m, k)
6466 // U.shape == (..., k, n)
6467 auto m = L.sym_size(-2);
6468 auto n = U.sym_size(-1);
6469 auto k = std::min(m, n);
6470
6471 if (m == n) {
6472 // A_grad = P L^{-H} [L^H L_grad o 1_L + U_grad U^H o 1_U] U^{-H},
6473 auto A_grad = L_grad.defined() ? L.mH().matmul(L_grad).tril(-1) : Tensor{};
6474 if (U_grad.defined()) {
6475 A_grad = A_grad.defined() ? A_grad + U_grad.matmul(U.mH()).triu()
6476 : U_grad.matmul(U.mH()).triu();
6477 }
6478 A_grad = at::linalg_solve_triangular(
6479 U.mH(),
6480 A_grad,
6481 /*upper=*/false,
6482 /*left=*/false);
6483 A_grad = at::linalg_solve_triangular(
6484 L.mH(),
6485 A_grad,
6486 /*upper=*/true,
6487 /*left=*/true,
6488 /*unitriangular=*/true);
6489
6490 return pivot ? P.matmul(A_grad) : std::move(A_grad);
6491 } else if (m < n) {
6492 // Wide case
6493 // A1_grad = P L^{-H} [U1_grad + (L^H L_grad o 1_L - U_grad U^H o 1_U) U1^{-H}]
6494 // A2_grad = P L^{-H} U2_grad
6495 const auto get_U1 = [n, k](const Tensor& U) {
6496 return n == k ? U : U.narrow_symint(-1, 0, k);
6497 };
6498 const auto get_U2 = [n, k](const Tensor& U) {
6499 return U.narrow_symint(-1, k, n - k);
6500 };
6501
6502 auto A_grad = L_grad.defined() ? L.mH().matmul(L_grad) : Tensor{};
6503 if (U_grad.defined()) {
6504 A_grad = A_grad.defined() ? A_grad - U_grad.triu().matmul(U.mH())
6505 : -U_grad.triu().matmul(U.mH());
6506 }
6507 A_grad = at::linalg_solve_triangular(
6508 get_U1(U).mH(),
6509 A_grad.tril(-1),
6510 /*upper=*/false,
6511 /*left=*/false);
6512
6513 if (U_grad.defined()) {
6514 A_grad =
6515 at::cat({A_grad + get_U1(U_grad).triu(), get_U2(U_grad)}, /*dim=*/-1);
6516 }
6517
6518 A_grad = at::linalg_solve_triangular(
6519 L.mH(),
6520 A_grad,
6521 /*upper=*/true,
6522 /*left=*/true,
6523 /*unitriangular=*/true);
6524
6525 if (!U_grad.defined()) {
6526 A_grad = at::cat({A_grad, at::zeros_like(get_U2(U))}, /*dim=*/-1);
6527 }
6528 if (pivot) {
6529 A_grad = P.matmul(A_grad);
6530 }
6531 return A_grad;
6532 } else {
6533 // Tall case
6534 // A1_grad = P [L1_grad + L^{-H} (U_grad U^H o 1_U - L^H L_grad o 1_L)] U^{-H}
6535 // A2_grad = P L2_grad U^{-H}
6536
6537 const auto get_L1 = [m, k](const Tensor& L) {
6538 return m == k ? L : L.narrow_symint(-2, 0, k);
6539 };
6540 const auto get_L2 = [m, k](const Tensor& L) {
6541 return L.narrow_symint(-2, k, m - k);
6542 };
6543
6544 auto A_grad = U_grad.defined() ? U_grad.matmul(U.mH()) : Tensor{};
6545 if (L_grad.defined()) {
6546 A_grad = A_grad.defined() ? A_grad - L.mH().matmul(L_grad.tril(-1))
6547 : -L.mH().matmul(L_grad.tril(-1));
6548 }
6549 A_grad = at::linalg_solve_triangular(
6550 get_L1(L).mH(),
6551 A_grad.triu(),
6552 /*upper=*/true,
6553 /*left=*/true,
6554 /*unitriangular=*/true);
6555
6556 if (L_grad.defined()) {
6557 A_grad = at::cat(
6558 {A_grad + get_L1(L_grad).tril(-1), get_L2(L_grad)}, /*dim=*/-2);
6559 }
6560
6561 A_grad = at::linalg_solve_triangular(
6562 U.mH(),
6563 A_grad,
6564 /*upper=*/false,
6565 /*left=*/false);
6566
6567 if (!L_grad.defined()) {
6568 A_grad = at::cat({A_grad, at::zeros_like(get_L2(L))}, /*dim=*/-2);
6569 }
6570 if (pivot) {
6571 A_grad = P.matmul(A_grad);
6572 }
6573 return A_grad;
6574 }
6575 }
6576
6577 Tensor lu_factor_ex_backward(
6578 const Tensor& grad,
6579 const Tensor& LU,
6580 const Tensor& pivs,
6581 const bool pivot) {
6582 auto [P, L, U] =
6583 at::lu_unpack(LU, pivs, /*unpack_data=*/true, /*unpack_pivots*/ pivot);
6584
6585 // L.shape == (..., m, k)
6586 // U.shape == (..., k, n)
6587 const auto m = LU.sym_size(-2);
6588 const auto n = LU.sym_size(-1);
6589 const auto k = std::min(m, n);
6590 const auto L_grad = grad.narrow_symint(-1, 0, k);
6591 const auto U_grad = grad.narrow_symint(-2, 0, k);
6592 return linalg_lu_backward(
6593 /*L_grad=*/L_grad, /*U_grad=*/U_grad, P, L, U, pivot);
6594 }
6595
6596 // This function is based on the forward AD derivations outlined
6597 // in the description to the linalg_lu_backward function.
6598 std::tuple<Tensor, Tensor> linalg_lu_jvp(
6599 const Tensor& dA,
6600 const Tensor& P,
6601 const Tensor& L,
6602 const Tensor& U,
6603 const bool pivot) {
6604 at::NoTF32Guard disable_tf32;
6605
6606 auto m = dA.size(-2);
6607 auto n = dA.size(-1);
6608 auto k = std::min(m, n);
6609
6610 auto PdA = pivot ? P.mT().matmul(dA) : dA;
6611
6612 // similar to the backward implementation, we also consider block structures
6613 // such as: for a matrix A of size m x n we decompose it as A = (A1 | A2) with
6614 // A1 of size m x m if m <= n and A = (A1^T | A2^T)^T with A1 of size n x n if
6615 // m > n.
6616 auto PdA1 = PdA.narrow(-2, 0, k).narrow(-1, 0, k);
6617 auto L1 = L.narrow(-2, 0, k).narrow(-1, 0, k);
6618 auto U1 = U.narrow(-2, 0, k).narrow(-1, 0, k);
6619
6620 // We form the matrix dK with two calls to triangular_solve (ideally the
6621 // second one would be done in place; see the TODO below): dK = L1^{-1} PdA1 U1^{-1}
6622 auto dK = at::linalg_solve_triangular(
6623 L1, PdA1, /*upper=*/false, /*left=*/true, /*unitriangular*/ true);
6624
6625 // TODO We should be able to do this in-place. At the moment it raises:
6626 // RuntimeError: linalg_solve_triangular(): functions with out=...
6627 // arguments don't support automatic differentiation, but one of the
6628 // arguments requires grad.
6629
6630 // at::linalg_solve_triangular_out(dK, U1, dK, /*upper=*/true,
6631 // /*left=*/false);
6632 dK = at::linalg_solve_triangular(U1, dK, /*upper=*/true, /*left=*/false);
6633
6634 auto dL1 = L1.matmul(dK.tril(-1));
6635 auto dU1 = dK.triu().matmul(U1);
6636
6637 if (m == n) {
6638 return std::make_tuple(std::move(dL1), std::move(dU1));
6639 } else if (m < n) {
6640 // we only need to update dU2 defined as
6641 //   dU2 := L1^{-1} PdA2 - dK.tril(-1) U2
6642 const auto PdA2 = PdA.narrow(-1, k, n - k);
6643 const auto U2 = U.narrow(-1, k, n - k);
6644 auto dU2 =
6645 at::linalg_solve_triangular(
6646 L1, PdA2, /*upper=*/false, /*left=*/true, /*unitriangular*/ true) -
6647 dK.tril(-1).matmul(U2);
6648 return std::make_tuple(
6649 std::move(dL1), at::cat({std::move(dU1), std::move(dU2)}, /*dim=*/-1));
6650 } else {
6651 // we only need to update dL2 defined as
6652 // dL2 := PdA2 U^{-1} - L2 dK.triu()
6653 const auto PdA2 = PdA.narrow(-2, k, m - k);
6654 const auto L2 = L.narrow(-2, k, m - k);
6655 auto dL2 =
6656 at::linalg_solve_triangular(U1, PdA2, /*upper=*/true, /*left=*/false) -
6657 L2.matmul(dK.triu());
6658 return std::make_tuple(
6659 at::cat({std::move(dL1), std::move(dL2)}, /*dim=*/-2), std::move(dU1));
6660 }
6661 }
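
// [Editor's illustrative sketch; not wired into the generated code.] A
// finite-difference check of the LU JVP above for a square input. For a
// generic random A the pivot choices are stable under a tiny perturbation,
// so P stays fixed and (L, U) move along (dL, dU). The helper name is
// hypothetical.
[[maybe_unused]] static bool _sketch_check_linalg_lu_jvp() {
  const double eps = 1e-6;
  const auto A = at::randn({3, 3}, at::kDouble);
  const auto dA = at::randn_like(A);
  auto [P, L, U] = at::linalg_lu(A);
  auto [dL, dU] = linalg_lu_jvp(dA, P, L, U, /*pivot=*/true);
  auto [P2, L2, U2] = at::linalg_lu(A + eps * dA);
  return at::equal(P, P2) && at::allclose((L2 - L) / eps, dL, 1e-3, 1e-3) &&
      at::allclose((U2 - U) / eps, dU, 1e-3, 1e-3);
}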
6662
6663 Tensor lu_factor_ex_jvp(
6664 const Tensor& dA,
6665 const Tensor& LU,
6666 const Tensor& pivs,
6667 const bool pivot) {
6668 auto [P, L, U] =
6669 at::lu_unpack(LU, pivs, /*unpack_data=*/true, /*unpack_pivots=*/pivot);
6670 auto [dL, dU] = linalg_lu_jvp(dA, P, L, U, pivot);
6671
6672 auto m = dA.size(-2);
6673 auto n = dA.size(-1);
6674 if (m >= n) {
6675 dL.narrow(-2, 0, n).add_(dU);
6676 return dL;
6677 } else {
6678 dU.narrow(-1, 0, m).add_(dL);
6679 return dU;
6680 }
6681 }
6682
6683 Tensor logsumexp_jvp(
6684 const Tensor& self_p,
6685 const Tensor& self_t,
6686 IntArrayRef dim,
6687 bool keepdim) {
6688 // NB: for simplicity, we recompute some values that can be reused from
6689 // forward
6690 auto self_p_exp = [&self_p, &dim]() {
6691 if (self_p.sym_numel() > 0) {
6692 // Use only the real part for complex tensors
6693 return (self_p - at::amax(at::real(self_p), dim, true))
6694 .exp(); // Use the exp-normalize trick
6695 } else {
6696 // amax fails if numel() == 0, in which case it doesn't matter anyway
6697 return self_p.exp();
6698 }
6699 }();
6700
6701 auto sumexp_p = self_p_exp.sum(dim, keepdim);
6702
6703 // NB: it's OK for logsumexp_jvp to be reused for formulas like
6704 // softmax/log_softmax
6705 // that only have one differentiable input, because that means self_t is
6706 // never a zerotensor
6707 TORCH_INTERNAL_ASSERT(!self_t._is_zerotensor())
6708 if (areAnyTensorSubclassLike({self_p, self_t})) {
6709 auto result = (self_p_exp * self_t).sum(dim, keepdim);
6710 result /= sumexp_p;
6711 return result;
6712 } else {
6713 self_p_exp *= self_t;
6714 auto sumexp_t = self_p_exp.sum(dim, keepdim);
6715 return sumexp_t /= sumexp_p;
6716 }
6717 }
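
// [Editor's illustrative sketch; not wired into the generated code.] The JVP
// above is <softmax(self_p), self_t> reduced over `dim`; a quick
// finite-difference check. The helper name is hypothetical.
[[maybe_unused]] static bool _sketch_check_logsumexp_jvp() {
  const double eps = 1e-6;
  const auto self_p = at::randn({4, 5}, at::kDouble);
  const auto self_t = at::randn_like(self_p);
  const auto jvp = logsumexp_jvp(self_p, self_t, /*dim=*/{1}, /*keepdim=*/false);
  const auto fd =
      (at::logsumexp(self_p + eps * self_t, {1}) - at::logsumexp(self_p, {1})) / eps;
  return at::allclose(jvp, fd, /*rtol=*/1e-3, /*atol=*/1e-3);
}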
6718
6719 Tensor safe_logsumexp_jvp(
6720 const Tensor& self_p,
6721 const Tensor& self_t,
6722 IntArrayRef dim,
6723 bool keepdim) {
6724 auto lse_jvp = logsumexp_jvp(self_p, self_t, dim, keepdim);
6725 const auto neg_inf = at::scalar_tensor(
6726 -std::numeric_limits<float>::infinity(),
6727 at::TensorOptions().dtype(lse_jvp.dtype()).device(lse_jvp.device()));
6728 const auto masked = self_p.eq(neg_inf);
6729 const auto masked_rows = all(masked, dim, true);
6730 const auto zero = at::scalar_tensor(
6731 0.0, at::TensorOptions().dtype(lse_jvp.dtype()).device(lse_jvp.device()));
6732 return at::where(masked_rows, zero, lse_jvp);
6733 }
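
// [Editor's illustrative sketch; not wired into the generated code.] Rows of
// self_p that are entirely -inf would make logsumexp_jvp produce NaNs; the
// "safe" variant above zeroes them out instead. The helper name is
// hypothetical.
[[maybe_unused]] static void _sketch_safe_logsumexp_jvp_all_inf_row() {
  auto self_p = at::randn({2, 3}, at::kDouble);
  self_p.narrow(0, 0, 1).fill_(-std::numeric_limits<double>::infinity());
  const auto self_t = at::randn_like(self_p);
  const auto jvp =
      safe_logsumexp_jvp(self_p, self_t, /*dim=*/{1}, /*keepdim=*/true);
  // The all -inf row gets a zero tangent rather than NaN.
  TORCH_INTERNAL_ASSERT(jvp[0].eq(0.).all().item<bool>());
}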
6734
6735 Tensor warn_backwards(const Tensor& grad_output) {
6736 TORCH_WARN("Warn from backward");
6737 return grad_output;
6738 }
6739
6740 // This function only exists because cuDNN does not support bias gradient
6741 // computation and it's not easy to slice a std::tuple to return only grad_input
6742 // / grad_weight from convolution_backward. It will be removed when the
6743 // cudnn_convolution and cudnn_convolution_transpose go away.
6744 std::tuple<Tensor, Tensor> _cudnn_convolution_backward(
6745 const at::Tensor& self,
6746 const at::Tensor& grad_output,
6747 const at::Tensor& weight,
6748 at::SymIntArrayRef padding,
6749 at::SymIntArrayRef output_padding,
6750 at::SymIntArrayRef stride,
6751 at::SymIntArrayRef dilation,
6752 bool transposed,
6753 c10::SymInt groups,
6754 ::std::array<bool, 2> output_mask) {
6755 if (!grad_output.defined()) {
6756 return std::tuple<Tensor, Tensor>();
6757 }
6758
6759 // Just call the general backward and ignore the bias gradient part.
6760 std::tuple<Tensor, Tensor, Tensor> grad_inputs =
6761 at::convolution_backward_symint(
6762 grad_output,
6763 self,
6764 weight,
6765 std::nullopt,
6766 stride,
6767 padding,
6768 dilation,
6769 transposed,
6770 output_padding,
6771 std::move(groups),
6772 {output_mask[0], output_mask[1], false});
6773 std::tuple<Tensor, Tensor> result =
6774 std::make_tuple(std::get<0>(grad_inputs), std::get<1>(grad_inputs));
6775 return result;
6776 }
6777
6778 Tensor scatter_reduce_jvp(
6779 const Tensor& self_p,
6780 const Tensor& self_t,
6781 int dim,
6782 const Tensor& index,
6783 const Tensor& src_p,
6784 const Tensor& src_t,
6785 c10::string_view reduce,
6786 bool include_self,
6787 const Tensor& result) {
6788 if (reduce == "sum" || reduce == "mean") {
6789 // The function is linear
6790 return at::scatter_reduce(self_t, dim, index, src_t, reduce, include_self);
6791 // auto mask = x == restore_reduced_dims(result, dim, keepdim);
6792 // return at::where(mask, dx, 0.).sum(dim, keepdim) / mask.sum(dim,
6793 // keepdim);
6794 } else if (reduce == "amin" || reduce == "amax") {
6795 auto gather_result = at::gather(result, dim, index);
6796 auto mask_self = self_p == result;
6797 auto mask_src = src_p == gather_result;
6798 auto masked_src_t = at::where(mask_src, src_t, 0.);
6799 auto div =
6800 mask_self.to(self_t.dtype())
6801 .scatter_reduce(
6802 dim, index, mask_src.to(self_t.dtype()), "sum", include_self);
6803 return at::where(mask_self, self_t, 0.)
6804 .scatter_reduce(dim, index, masked_src_t, "sum", include_self)
6805 .div(div);
6806 } else {
6807 // Not implemented
6808 return Tensor{};
6809 }
6810 }
6811
6812 std::tuple<Tensor, Tensor> scatter_reduce_backward(
6813 const Tensor& grad,
6814 const Tensor& self,
6815 int dim,
6816 const Tensor& index,
6817 const Tensor& src,
6818 c10::string_view reduce,
6819 bool include_self,
6820 const Tensor& result) {
6821 Tensor grad_self, grad_src;
6822
6823 // FIXME: complex gradients not handled correctly
6824 // For now this is ok as scatter_reduce isn't added to the whitelist
6825 // in tools/autograd/gen_variable_type.py
6826
6827 if (!grad.defined()) {
6828 return std::make_tuple(grad_self, grad_src);
6829 }
6830
6831 if (reduce == "sum") {
6832 grad_self = grad;
6833 grad_src = grad.gather(dim, index);
6834 } else if (reduce == "prod") {
6835 // Explicitly compute exclusive prod for elements in self/src that are 0
6836 Tensor masked_self = self.masked_fill(self == 0, 1);
6837 Tensor masked_self_result =
6838 masked_self.scatter_reduce(dim, index, src, reduce, include_self);
6839 grad_self = grad * masked_self_result / masked_self;
6840 Tensor src_zero = src == 0;
6841 Tensor src_num_zeros =
6842 zeros_like(self)
6843 .scatter_add(dim, index, src_zero.to(self.dtype()))
6844 .gather(dim, index);
6845 Tensor src_single_zero = bitwise_and(src_zero, src_num_zeros == 1);
6846 // For src positions with src_single_zero, grad * result.gather(dim,index) /
6847 // src.masked_fill(src_zero, 1) would incorrectly propagate zeros as the
6848 // gradient
6849 Tensor masked_src = src.masked_fill(src_single_zero, 1);
6850 Tensor masked_src_result =
6851 self.scatter_reduce(dim, index, masked_src, reduce, include_self);
6852 Tensor grad_src1 = where(
6853 src_single_zero,
6854 (grad * masked_src_result).gather(dim, index),
6855 (grad * result).gather(dim, index) / src.masked_fill(src_zero, 1));
6856 // GradMode::is_enabled() - adding the autograd Node is a no-op if autograd
6857 // is disabled; this also avoids having the item() call in the usual case.
6858 if (GradMode::is_enabled() && (src_num_zeros > 1).any().item<bool>()) {
6859 auto node = std::make_shared<DelayedError>(
6860 "scatter_reduce(): Double backward is unsupported for src when >1 zeros in src are scattered to the same position in self",
6861 /* num inputs */ 1);
6862 auto result = node->apply({std::move(grad_src1)});
6863 grad_src = result[0];
6864 } else {
6865 grad_src = grad_src1;
6866 }
6867 } else if (reduce == "mean") {
6868 Tensor N = include_self ? ones_like(grad) : zeros_like(grad);
6869 N = N.scatter_add(dim, index, ones_like(src));
6870 N.masked_fill_(N == 0, 1);
6871 grad_self = grad / N;
6872 Tensor N_src = N.gather(dim, index);
6873 grad_src = grad.gather(dim, index) / N_src;
6874 } else if (reduce == "amax" || reduce == "amin") {
6875 // Evenly distribute gradient when there are multiple max/mins
6876 Tensor value = result.gather(dim, index);
6877 Tensor self_is_result = (self == result).to(self.scalar_type());
6878 Tensor src_is_result = (src == value).to(self.scalar_type());
6879 Tensor N_to_distribute =
6880 self_is_result.scatter_add(dim, index, src_is_result);
6881 Tensor grad_distributed = grad / N_to_distribute;
6882 grad_self = (self == result) * grad_distributed;
6883 grad_src = (src == value) * grad_distributed.gather(dim, index);
6884 } else {
6885 AT_ERROR(
6886 "Expected 'reduce' to be one of 'sum', 'prod', 'mean', 'amax', 'amin' but got ",
6887 reduce,
6888 ".");
6889 }
6890
6891 if (!include_self) {
6892 grad_self = grad_self.scatter(dim, index, 0);
6893 }
6894
6895 return std::make_tuple(grad_self, grad_src);
6896 }
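
// [Editor's illustrative sketch; not wired into the generated code.] For
// reduce="sum" the backward above reduces to an identity for self and a
// gather for src; a minimal shape-level example. The helper name is
// hypothetical.
[[maybe_unused]] static void _sketch_scatter_reduce_sum_backward() {
  const auto self = at::randn({4}, at::kDouble);
  const auto src = at::randn({3}, at::kDouble);
  const auto index = at::zeros({3}, at::kLong);  // scatter everything into slot 0
  const auto result =
      at::scatter_reduce(self, /*dim=*/0, index, src, "sum", /*include_self=*/true);
  const auto grad = at::ones_like(result);
  auto [grad_self, grad_src] = scatter_reduce_backward(
      grad, self, /*dim=*/0, index, src, "sum", /*include_self=*/true, result);
  TORCH_INTERNAL_ASSERT(at::equal(grad_self, grad));
  TORCH_INTERNAL_ASSERT(at::equal(grad_src, grad.gather(0, index)));
}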
6897
6898 Tensor _to_copy_backward(
6899 const Tensor& grad_,
6900 const c10::TensorOptions& self_options) {
6901 // Handle R->C copies without raising a warning
6902 const auto self_type = self_options.dtype().toScalarType();
6903 auto grad = c10::MaybeOwned<at::Tensor>::borrowed(grad_);
6904 if (!c10::isComplexType(self_type) && grad->is_complex()) {
6905 grad = c10::MaybeOwned<at::Tensor>::owned(at::real(grad_));
6906 }
6907
6908 return grad->to(self_options, /*non_blocking=*/false, /*copy=*/false);
6909 }
6910
6911 std::tuple<Tensor, Tensor> index_reduce_backward(
6912 const Tensor& grad,
6913 const Tensor& self,
6914 int dim,
6915 const Tensor& index,
6916 const Tensor& source,
6917 c10::string_view reduce,
6918 bool include_self,
6919 const Tensor& result) {
6920 Tensor grad_self, grad_src;
6921
6922 // FIXME: index_add's backward formula has a special case for source.dim == 0
6923 // but this case seems to throw the error "IndexError: dimension specified as
6924 // 0 but tensor has no dimensions". Look into whether this case is reachable
6925 // and should be covered here.
6926
6927 if (!grad.defined()) {
6928 return std::make_tuple(grad_self, grad_src);
6929 }
6930
6931 if (reduce == "prod") {
6932 Tensor masked_self = self.masked_fill(self == 0, 1);
6933 Tensor masked_self_result =
6934 masked_self.index_reduce(dim, index, source, reduce, include_self);
6935 grad_self = grad * masked_self_result / masked_self;
6936 Tensor src_zero = source == 0;
6937 Tensor src_num_zeros = zeros_like(self)
6938 .index_add(dim, index, src_zero.to(self.dtype()))
6939 .index_select(dim, index);
6940 Tensor src_single_zero = bitwise_and(src_zero, src_num_zeros == 1);
6941 // For src positions with src_single_zero, (grad *
6942 // result).index_select(dim,index) / source.masked_fill(src_zero, 1) would
6943 // incorrectly propagate zeros as the gradient
6944 Tensor masked_src = source.masked_fill(src_single_zero, 1);
6945 Tensor masked_src_result =
6946 self.index_reduce(dim, index, masked_src, reduce, include_self);
6947 Tensor grad_src1 = where(
6948 src_single_zero,
6949 (grad * masked_src_result).index_select(dim, index),
6950 (grad * result).index_select(dim, index) /
6951 source.masked_fill(src_zero, 1));
6952 // GradMode::is_enabled() - adding the autograd Node is a no-op if autograd
6953 // is disabled; this also avoids having the item() call in the usual case.
6954 if (GradMode::is_enabled() && (src_num_zeros > 1).any().item<bool>()) {
6955 auto node = std::make_shared<DelayedError>(
6956 "index_reduce(): Double backward is unsupported for source when >1 zeros in source are scattered to the same position in self",
6957 /* num inputs */ 1);
6958 auto result = node->apply({std::move(grad_src1)});
6959 grad_src = result[0];
6960 } else {
6961 grad_src = grad_src1;
6962 }
6963 } else if (reduce == "mean") {
6964 Tensor N = include_self ? ones_like(grad) : zeros_like(grad);
6965 N = N.index_add(dim, index, ones_like(source));
6966 N.masked_fill_(N == 0, 1);
6967 grad_self = grad / N;
6968 Tensor N_src = N.index_select(dim, index);
6969 grad_src = grad.index_select(dim, index) / N_src;
6970 } else if (reduce == "amax" || reduce == "amin") {
6971 Tensor value = result.index_select(dim, index);
6972 Tensor self_is_result = (self == result).to(self.scalar_type());
6973 Tensor source_is_result = (source == value).to(self.scalar_type());
6974 Tensor N_to_distribute =
6975 self_is_result.index_add(dim, index, source_is_result);
6976 Tensor grad_distributed = grad / N_to_distribute;
6977 grad_self = self_is_result * grad_distributed;
6978 grad_src = source_is_result * grad_distributed.index_select(dim, index);
6979 } else {
6980 AT_ERROR(
6981 "Expected 'reduce' to be one of 'prod', 'amax', 'amin' or 'mean' but got ",
6982 reduce,
6983 ".");
6984 }
6985
6986 if (!include_self) {
6987 grad_self = grad_self.index_fill(dim, index, 0);
6988 }
6989
6990 return std::make_tuple(grad_self, grad_src);
6991 }
6992
6993 Tensor take_backward(
6994 const Tensor& grad,
6995 const Tensor& self,
6996 const Tensor& indices) {
6997 Tensor grad_self = at::zeros_like(self);
6998 // For Composite Compliance,
6999 // if `grad` and `indices` are CCT but `grad_self` is not
7000 // then we use the out-of-place variant of `put`.
7001 if (areAnyTensorSubclassLike({grad, indices})) {
7002 return grad_self.put(indices, grad, true);
7003 }
7004 return grad_self.put_(indices, grad, true);
7005 }
7006
7007 Tensor to_sparse_backward(
7008 const Tensor& grad,
7009 const c10::Layout self_layout,
7010 const c10::OptionalArrayRef<c10::SymInt>& self_blocksize) {
7011 // Path for strided and nested
7012 if (self_layout == c10::kStrided) {
7013 return grad.to_dense();
7014 } else {
7015 OptionalIntArrayRef blocksize = std::nullopt;
7016 if (self_blocksize.has_value()) {
7017 blocksize = c10::asIntArrayRefSlowOpt(*self_blocksize);
7018 }
7019 return grad.to_sparse(self_layout, blocksize);
7020 }
7021 }
7022
7023 std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor>
7024 mkldnn_rnn_layer_differentiable_backward(
7025 const Tensor& input,
7026 const Tensor& weight0,
7027 const Tensor& weight1,
7028 const Tensor& weight2,
7029 const Tensor& weight3,
7030 const Tensor& hx_,
7031 const Tensor& cx_tmp,
7032 const Tensor& output,
7033 const Tensor& hy_,
7034 const Tensor& cy_,
7035 const std::optional<Tensor>& grad_output_r_opt,
7036 const std::optional<Tensor>& grad_hy_r_opt,
7037 const std::optional<Tensor>& grad_cy_r_opt,
7038 bool reverse,
7039 int64_t mode,
7040 int64_t hidden_size,
7041 int64_t num_layers,
7042 bool has_biases,
7043 bool train,
7044 bool bidirectional,
7045 at::IntArrayRef batch_sizes,
7046 bool batch_first,
7047 const at::Tensor& workspace) {
7048 const Tensor& grad_output_r =
7049 c10::value_or_else(grad_output_r_opt, [] { return Tensor(); });
7050 const Tensor& grad_hy_r =
7051 c10::value_or_else(grad_hy_r_opt, [] { return Tensor(); });
7052 const Tensor& grad_cy_r =
7053 c10::value_or_else(grad_cy_r_opt, [] { return Tensor(); });
7054 if (!grad_output_r.defined() && !grad_hy_r.defined() &&
7055 !grad_cy_r.defined()) {
7056 return std::make_tuple(
7057 Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor());
7058 }
7059 auto grad_output = grad_output_r.defined()
7060 ? grad_output_r.contiguous()
7061 : at::zeros_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
7062 auto grad_hy = grad_hy_r.defined()
7063 ? grad_hy_r.contiguous()
7064 : at::zeros_like(hx_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
7065 auto grad_cy = cx_tmp.defined()
7066 ? (grad_cy_r.defined()
7067 ? grad_cy_r.contiguous()
7068 : at::zeros_like(cx_tmp, LEGACY_CONTIGUOUS_MEMORY_FORMAT))
7069 : grad_cy_r.contiguous();
7070 Tensor bias_ih, bias_hh;
7071 if (has_biases) {
7072 bias_ih = weight2;
7073 bias_hh = weight3;
7074 } else {
7075 bias_ih = at::zeros(
7076 {4 /* num_bias_gates of LSTM */ * hidden_size}, weight0.options());
7077 bias_hh = at::zeros(
7078 {4 /* num_bias_gates of LSTM */ * hidden_size}, weight0.options());
7079 }
7080 const auto& input_ = input;
7081 auto hx_prev = hx_;
7082 auto cx_prev = cx_tmp;
7083
7084 // Re-calculate the gates and hidden states of this layer; they are used in
7085 // the backward pass below.
7086 int64_t seq_length = input.size(0);
7087 std::vector<std::tuple<Tensor, Tensor, Tensor, Tensor>> layer_gates(
7088 seq_length);
7089 std::vector<std::tuple<Tensor, Tensor>> layer_states(seq_length + 1);
7090 layer_states[0] = std::make_tuple(hx_, cx_tmp);
7091 for (int64_t seq = 1; seq < seq_length + 1; seq++) {
7092 auto hx = hx_prev;
7093 auto cx = cx_prev;
7094 auto x_index = reverse ? seq_length - seq : seq - 1;
7095 auto gate = at::linear(input_[x_index], weight0, bias_ih)
7096 .add_(at::linear(hx, weight1, bias_hh));
7097 auto chunked_gates = gate.unsafe_chunk(4, 1);
7098 auto i = chunked_gates[0].sigmoid_();
7099 auto f = chunked_gates[1].sigmoid_();
7100 auto g = chunked_gates[2].tanh_();
7101 auto o = chunked_gates[3].sigmoid_();
7102 layer_gates[x_index] = std::make_tuple(i, f, g, o);
7103 auto cy = (f * cx).add(i * g);
7104 auto hy = o * cy.tanh();
7105 layer_states[seq] = std::make_tuple(hy, cy);
7106 hx_prev = hy;
7107 cx_prev = cy;
7108 }
7109
7110 Tensor dx, dWx, dWh, db, db_, dprev_h, dprev_c, dWh_, dWx_;
7111 Tensor new_grad_hy, d1, dgp, dip, dfp, dop, do_, dg, df, di, da;
7112 std::vector<at::Tensor> layer_dx(seq_length);
7113 for (int64_t seq = seq_length - 1; seq >= 0; seq--) {
7114 int64_t x_index = reverse ? seq_length - seq - 1 : seq;
7115 auto i = std::get<0>(layer_gates[x_index]);
7116 auto f = std::get<1>(layer_gates[x_index]);
7117 auto g = std::get<2>(layer_gates[x_index]);
7118 auto o = std::get<3>(layer_gates[x_index]);
7119 auto hy = std::get<0>(layer_states[seq + 1]);
7120 auto cy = std::get<1>(layer_states[seq + 1]);
7121 auto hx = std::get<0>(layer_states[seq]);
7122 auto cx = std::get<1>(layer_states[seq]);
7123 new_grad_hy = grad_output[x_index].add(grad_hy);
7124 d1 = grad_cy.add(new_grad_hy * o * (1 - cy.tanh() * cy.tanh()));
7125 dgp = d1 * i;
7126 dip = d1 * g;
7127 dprev_c = d1 * f;
7128 dfp = d1 * cx;
7129 dop = new_grad_hy * cy.tanh();
7130 do_ = dop * o * (1 - o);
7131 dg = dgp * (1 - g * g);
7132 df = dfp * f * (1 - f);
7133 di = dip * i * (1 - i);
7134 da = at::cat({di, df, dg, do_}, 1);
7135 db_ = at::sum(da, 0);
7136 dx = at::matmul(da, weight0);
7137 dx = at::unsqueeze(dx, 0);
7138 dprev_h = at::matmul(da, weight1);
7139 dWx_ = at::matmul(da.transpose(0, 1), input_[x_index]);
7140 dWh_ = at::matmul(da.transpose(0, 1), hx);
7141 if (seq == seq_length - 1) {
7142 db = db_;
7143 dWx = dWx_;
7144 dWh = dWh_;
7145 } else {
7146 db += db_;
7147 dWx += dWx_;
7148 dWh += dWh_;
7149 }
7150 layer_dx[x_index] = dx;
7151 grad_hy = dprev_h;
7152 grad_cy = dprev_c;
7153 }
7154
7155 auto cat_layer_dx = at::cat(layer_dx, 0);
7156 return std::make_tuple(cat_layer_dx, dWx, dWh, db, db, dprev_h, dprev_c);
7157 }
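
// [Editor's illustrative sketch; not wired into the generated code.] The
// single LSTM cell that the loop above re-computes and then differentiates
// by hand: gates = W_ih x + b_ih + W_hh h + b_hh, split into (i, f, g, o),
// cy = f * cx + i * g, hy = o * tanh(cy). The helper name is hypothetical.
[[maybe_unused]] static std::tuple<Tensor, Tensor> _sketch_lstm_cell_forward(
    const Tensor& x,
    const Tensor& hx,
    const Tensor& cx,
    const Tensor& w_ih,
    const Tensor& w_hh,
    const Tensor& b_ih,
    const Tensor& b_hh) {
  auto gates = at::linear(x, w_ih, b_ih) + at::linear(hx, w_hh, b_hh);
  auto chunked = gates.chunk(4, /*dim=*/1);
  const auto i = chunked[0].sigmoid();
  const auto f = chunked[1].sigmoid();
  const auto g = chunked[2].tanh();
  const auto o = chunked[3].sigmoid();
  const auto cy = f * cx + i * g;
  const auto hy = o * cy.tanh();
  return std::make_tuple(hy, cy);
}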
7158
7159 Tensor values_backward(const Tensor& grad, const Tensor& self) {
7160 Tensor grad_self;
7161 if (grad.defined()) {
7162 if (self.layout() == c10::kSparse) {
7163 return at::_sparse_coo_tensor_unsafe_symint(
7164 self.indices(),
7165 grad,
7166 self.sym_sizes(),
7167 self.options(),
7168 /*is_coalesced=*/true);
7169 } else if (at::sparse_csr::is_sparse_compressed(self)) {
7170 auto [compressed_indices, plain_indices] =
7171 at::sparse_csr::getCompressedPlainIndices(self);
7172 return at::_sparse_compressed_tensor_unsafe_symint(
7173 compressed_indices,
7174 plain_indices,
7175 grad,
7176 self.sym_sizes(),
7177 self.options());
7178 } else {
7179 TORCH_CHECK_NOT_IMPLEMENTED(
7180 false,
7181 "values backward with respect to self with layout ",
7182 self.layout());
7183 }
7184 }
7185 return grad_self;
7186 }
7187
7188 } // namespace torch::autograd::generated::details
7189