1 #include <torch/csrc/autograd/FunctionsManual.h>
2 #include <torch/csrc/autograd/functions/basic_ops.h>
3 #include <torch/csrc/autograd/functions/utils.h>
4 #include <torch/csrc/autograd/variable.h>
5 
6 #include <ATen/ATen.h>
7 #include <ATen/AccumulateType.h>
8 #include <ATen/Dispatch.h>
9 #include <ATen/ExpandUtils.h>
10 #include <ATen/LegacyBatchedTensorImpl.h>
11 #include <ATen/ScalarOps.h>
12 #include <ATen/SparseCsrTensorUtils.h>
13 #include <ATen/TensorSubclassLikeUtils.h>
14 #include <ATen/Utils.h>
15 #include <ATen/WrapDimUtils.h>
16 #include <ATen/WrapDimUtilsMulti.h>
17 #include <ATen/core/Reduction.h>
18 #include <ATen/core/grad_mode.h>
19 #include <ATen/native/Activation.h>
20 #include <ATen/native/IndexingUtils.h>
21 #include <ATen/native/LinearAlgebraUtils.h>
22 #include <ATen/native/SparseTensorUtils.h>
23 #include <ATen/native/nested/NestedTensorUtils.h>
24 #include <c10/core/TensorOptions.h>
25 #include <c10/util/OptionalArrayRef.h>
26 #include <c10/util/SmallBuffer.h>
27 #include <c10/util/accumulate.h>
28 #include <c10/util/irange.h>
29 
30 #include <algorithm>
31 #include <ciso646>
32 #include <functional>
33 #include <numeric>
34 #include <utility>
35 
36 // Helper functions for autogenerated code
37 // These used to be inlined into the codegened Functions.cpp
38 
39 namespace torch::autograd::generated::details {
40 
41 using at::areAnyTensorSubclassLike;
42 using at::IntArrayRef;
43 using at::OptionalIntArrayRef;
44 using at::Scalar;
45 using at::Tensor;
46 using at::TensorList;
47 
48 const char* kCudnnDoubleBackwardMsg =
49     "Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API. To run double backwards, please disable the CuDNN backend temporarily while running the forward pass of your RNN. For example: \nwith torch.backends.cudnn.flags(enabled=False):\n    output = model(inputs)";
50 
51 Tensor apply_loss_reduction(const Tensor& unreduced, int64_t reduction) {
52   if (reduction == at::Reduction::Mean) {
53     return unreduced.mean();
54   } else if (reduction == at::Reduction::Sum) {
55     return unreduced.sum();
56   }
57   return unreduced;
58 }
59 
60 static bool isDefined(const std::optional<Tensor>& t) {
61   return t.has_value() && t->defined();
62 }
63 
64 Tensor toNonOptTensor(const std::optional<Tensor>& t) {
65   return t.has_value() ? *t : Tensor();
66 }
67 
68 Tensor toNonOptFwGrad(const std::optional<Tensor>& t) {
69   return (t.has_value() && t->defined()) ? t->_fw_grad(/*level */ 0) : Tensor();
70 }
71 
72 Tensor toNonOptPrimal(const std::optional<Tensor>& t) {
73   if (t.has_value() && t->defined()) {
74     if (t->unsafeGetTensorImpl()->is_wrapped_number()) {
75       return *t;
76     }
77     return t->_fw_primal(/* level */ 0);
78   }
79   return Tensor();
80 }
81 
82 void copy_range(variable_list& out, IndexRange range, const Tensor& t) {
83   TORCH_CHECK(range.second <= out.size());
84   TORCH_CHECK(
85       range.second - range.first == 1, "inconsistent range for Tensor output");
86   out[range.first] = t;
87 }
88 
89 void copy_range(variable_list& out, IndexRange range, at::ArrayRef<Tensor> t) {
90   TORCH_CHECK(range.second <= out.size());
91   TORCH_CHECK(
92       range.second - range.first == t.size(),
93       "inconsistent range for TensorList output");
94   std::copy(
95       t.begin(), t.end(), out.begin() + static_cast<int64_t>(range.first));
96 }
97 
98 Tensor copysign_tensor_self_backward(
99     const Tensor& grad,
100     const Tensor& self,
101     const Tensor& result) {
102   auto ratio = result / self;
103   ratio.masked_fill_(self == 0, 0);
104   return grad * ratio;
105 }
106 
107 template <typename T>
108 T not_implemented_base(const char* name, const char* reason) {
109   std::string msg =
110       c10::str("the derivative for '", name, "' is not implemented.");
111   if (reason[0] != '\0') {
112     msg = c10::str(msg, " ", reason);
113   };
114   TORCH_CHECK_NOT_IMPLEMENTED(false, msg);
115 }
116 
117 Tensor not_implemented(const char* name, const char* reason) {
118   return not_implemented_base<Tensor>(name, reason);
119 }
120 
121 std::vector<Tensor> not_implemented_list(const char* name, const char* reason) {
122   return not_implemented_base<std::vector<Tensor>>(name, reason);
123 }
124 
125 Tensor maybe_multiply(const Tensor& t, const Scalar& s) {
126   bool is_one = false;
127   if (s.isFloatingPoint()) {
128     is_one = s.toSymFloat() == 1;
129   } else if (s.isIntegral(true)) {
130     is_one = s.toSymInt() == 1;
131   }
132 
133   if (is_one) {
134     return t;
135   } else {
136     return t * s;
137   }
138 }
139 
140 int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim) {
141   int64_t size = 1;
142   if (sizes.empty()) {
143     return 1;
144   }
145   for (auto d : dim) {
146     d = at::maybe_wrap_dim(d, static_cast<int64_t>(sizes.size()));
147     size *= sizes[d];
148   }
149   return size;
150 }
151 
152 static c10::SymInt _safe_size(c10::SymIntArrayRef sizes, c10::IntArrayRef dim) {
153   c10::SymInt size = 1;
154   if (sizes.empty()) {
155     return 1;
156   }
157   for (auto d : dim) {
158     d = at::maybe_wrap_dim(d, static_cast<int64_t>(sizes.size()));
159     size *= sizes[d];
160   }
161   return size;
162 }
163 
164 Tensor handle_r_to_c(ScalarType self_st, Tensor gradient_result) {
165   if (!at::isComplexType(self_st) && gradient_result.is_complex()) {
166     // R -> C
167     return at::real(gradient_result);
168   }
169   return gradient_result;
170 }
171 
172 static Tensor handle_r_to_c(const Tensor& self, Tensor gradient_result) {
173   if (!self.is_complex() && gradient_result.is_complex()) {
174     // R -> C
175     return at::real(gradient_result);
176   }
177   return gradient_result;
178 }
179 
180 Tensor restore_reduced_dims(
181     const Tensor& output,
182     IntArrayRef dims,
183     bool keepdim) {
184   if (keepdim) {
185     return output;
186   }
187   auto total_dims = output.dim() + dims.size();
188   std::vector<c10::SymInt> target_shape(total_dims, 0);
189   for (int64_t i : dims) {
190     if (i < 0) {
191       i = static_cast<int64_t>(total_dims) + i;
192     }
193     target_shape[i] = 1;
194   }
195   int64_t j = 0;
196   for (const c10::SymInt& i : output.sym_sizes()) {
197     while (target_shape[j] > 0)
198       j++;
199     target_shape[j++] = i;
200   }
201   return output.reshape_symint(target_shape);
202 }
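// (Added illustrative sketch, not from the original source.) A minimal usage
// example, assuming a hypothetical input of shape [2, 3, 4] reduced over dim 1:
//
//   auto x = at::randn({2, 3, 4});
//   auto s = x.sum({1}, /*keepdim=*/false);           // shape [2, 4]
//   auto r = restore_reduced_dims(s, {1},
//                                 /*keepdim=*/false); // shape [2, 1, 4]
//   auto g = r.expand({2, 3, 4});                     // broadcasts back to the input shape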
203 
204 Tensor scale_grad_by_count(
205     const Tensor& grad,
206     const Tensor& mask,
207     IntArrayRef dims) {
208   return (grad / mask.sum(dims, true)) * mask;
209 }
210 
211 Tensor amaxamin_jvp(
212     const Tensor& x,
213     const Tensor& dx,
214     const Tensor& result,
215     IntArrayRef dim,
216     bool keepdim) {
217   auto mask = x == restore_reduced_dims(result, dim, keepdim);
218   return at::where(mask, dx, 0.).sum(dim, keepdim) / mask.sum(dim, keepdim);
219 }
220 
221 std::tuple<Tensor, Tensor> _euclidean_dist_backward(
222     const Tensor& grad,
223     const Tensor& x1,
224     const Tensor& x2,
225     const Tensor& res) {
226   if (!grad.defined()) {
227     return std::tuple<Tensor, Tensor>(Tensor(), Tensor());
228   }
229   // handle case at 0 where we return a subgradient containing 0
230   Tensor ratio = grad / res;
231   ratio.masked_fill_(res == 0, 0);
232   return std::tuple<Tensor, Tensor>{
233       x1 * ratio.sum(-1, true) - ratio.matmul(x2),
234       x2 * ratio.sum(-2, false).unsqueeze(-1) - ratio.mT().matmul(x1)};
235 }
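// (Added derivation note, not from the original source.) With
// res_{ij} = ||x1_i - x2_j||, we have d res_{ij} / d x1_i = (x1_i - x2_j) / res_{ij}.
// Writing ratio_{ij} = grad_{ij} / res_{ij}, the chain rule gives
//   grad_x1_i = sum_j ratio_{ij} * (x1_i - x2_j)
//             = x1_i * sum_j ratio_{ij} - sum_j ratio_{ij} * x2_j,
// which is exactly `x1 * ratio.sum(-1, true) - ratio.matmul(x2)` above; the x2
// gradient is the symmetric expression with the roles of x1 and x2 swapped.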
236 
237 Tensor norm_backward(
238     const Tensor& grad,
239     const Tensor& self,
240     const std::optional<Scalar>& p_,
241     const Tensor& norm) {
242   return norm_backward(grad, self, p_, norm, {}, true);
243 }
244 
245 Tensor norm_backward(
246     Tensor grad,
247     const Tensor& self,
248     const std::optional<Scalar>& p_,
249     Tensor norm,
250     IntArrayRef dim,
251     bool keepdim) {
252   // NB: We mask fill the NaNs in the output to be zero but still do float
253   // division
254   //     by zero, which ASAN complains about. One way to appease ASAN is to fill
255   //     the problematic values with something arbitrary before the division,
256   //     but we decide not to due to the perf hit. Instead we just silence ASAN
257   //     where necessary
258   size_t ndim = self.dim();
259   double p = p_.value_or(2.0).toDouble();
260   Tensor self_scaled;
261   Tensor scale_v;
262 
263   if (!keepdim && self.dim() != 0) {
264     grad = unsqueeze_multiple(grad, dim, ndim);
265     norm = unsqueeze_multiple(norm, dim, ndim);
266   }
267 
268   if (p == 0.0) {
269     return {};
270   } else if (p == 1.0) {
271     return self.sgn() * grad;
272   } else if (p == 2.0) {
273     return grad * (self / norm).masked_fill_(norm == 0, 0);
274   } else if (std::isinf(p)) {
275     // Derivative of amax(abs(self), dim, keepdim) but respecting nans
276     // We create a mask of `argmax`: it's argmax if self.abs() == norm or it's
277     // NaN
278     auto self_abs = self.abs();
279     auto mask = self_abs.eq(norm).logical_or(self_abs.isnan());
280     return self.sgn() * ((grad / mask.sum(dim, true)) * mask);
281   } else if (p < 1.0) {
282     self_scaled =
283         self.sgn() * self.abs().pow_(p - 1).masked_fill_(self == 0, 0);
284     return self_scaled * grad * norm.pow(1 - p);
285   } else if (p < 2.0) {
286     self_scaled = self.sgn() * self.abs().pow_(p - 1);
287     scale_v = grad / norm.pow(p - 1);
288     scale_v.masked_fill_(norm == 0, 0);
289     return self_scaled * scale_v;
290   } else {
291     self_scaled = self * self.abs().pow_(p - 2);
292     scale_v = grad / norm.pow(p - 1);
293     scale_v.masked_fill_(norm == 0, 0);
294     return self_scaled * scale_v;
295   }
296 }
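// (Added derivation note, not from the original source.) For real input and
// p >= 1 the branches above implement
//   d ||x||_p / d x_i = sgn(x_i) * |x_i|^(p - 1) / ||x||_p^(p - 1),
// i.e. the generic branch computes self * |self|^(p - 2) * grad / norm^(p - 1),
// with p == 1, p == 2 and p == inf handled separately for numerical reasons.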
297 
298 // See norm_backward above for a note on ignoring the sanitizer
299 Tensor norm_jvp(
300     const Tensor& self_p,
301     const Tensor& self_t,
302     const std::optional<Scalar>& p_,
303     Tensor norm,
304     IntArrayRef dim,
305     bool keepdim) {
306   // NB: currently norm_jvp is also reused for dist's jvp (which has two
307   // differentiable inputs)
308   //     but self_t still cannot be a ZT because that would require both self_t
309   //     and other_t to be ZT
310   TORCH_INTERNAL_ASSERT(!self_t._is_zerotensor());
311   size_t ndim = self_p.dim(); // composite compliance?
312   double p = p_.value_or(2.0).toDouble();
313 
314   if (p == 0.0) {
315     return at::zeros_like(norm);
316   } else if (p == 1.0) {
317     auto result = self_p.sgn();
318     result = areAnyTensorSubclassLike({self_t}) ? result.mul(self_t.conj())
319                                                 : result.mul_(self_t.conj());
320     result = at::real(result);
321     return result.sum(dim, keepdim);
322   } else if (p == 2.0) {
323     auto result = self_p.mul(self_t.conj());
324     result = at::real(result);
325     result = result.sum(dim, keepdim);
326     return result.div_(norm).masked_fill_(norm == 0, 0);
327   } else if (std::isinf(p)) {
328     if (!keepdim && self_p.dim() != 0) {
329       norm = unsqueeze_multiple(norm, dim, ndim);
330     }
331     const auto self_isnan = self_p.isnan();
332     const auto norm_isnan = norm.isnan();
333     const auto& self_and_norm_isnan = areAnyTensorSubclassLike({norm})
334         ? self_isnan.logical_and(norm_isnan)
335         : self_isnan.logical_and_(norm_isnan);
336     const auto is_eq_max =
337         (self_p.abs() == norm).logical_or_(self_and_norm_isnan).type_as(norm);
338     auto nb_max = is_eq_max.count_nonzero(dim);
339     if (self_p.dim() != 0) {
340       nb_max = unsqueeze_multiple(nb_max, dim, ndim);
341     }
342     return (at::real(self_p.sgn() * self_t.conj()) * is_eq_max / nb_max)
343         .sum(dim, keepdim);
344   } else if (p < 1.0) {
345     auto sumpow_t = (self_p.abs().pow_(p - 1).masked_fill_(self_p == 0, 0) *
346                      at::real(self_p.sgn() * self_t.conj()))
347                         .sum(dim, keepdim);
348     return sumpow_t * norm.pow(1 - p);
349   } else if (p < 2.0) {
350     auto sumpow_t =
351         (self_p.abs().pow_(p - 1) * at::real(self_p.sgn() * self_t.conj()))
352             .sum(dim, keepdim);
353     auto out = sumpow_t / norm.pow(p - 1);
354     return out.masked_fill_(norm == 0, 0);
355   } else {
356     auto sumpow_t =
357         (self_p.abs().pow_(p - 2) * at::real(self_p * self_t.conj()))
358             .sum(dim, keepdim);
359     auto out = sumpow_t / norm.pow(p - 1);
360     return out.masked_fill_(norm == 0, 0);
361   }
362 }
363 
364 Tensor norm_jvp(
365     const Tensor& self_p,
366     const Tensor& self_t,
367     const std::optional<Scalar>& p_,
368     Tensor norm) {
369   return norm_jvp(self_p, self_t, p_, std::move(norm), {}, true);
370 }
371 
372 Tensor _nested_from_padded_backward(
373     const Tensor& grad,
374     const Tensor& input,
375     bool do_transform_0213) {
376   if (do_transform_0213) {
377     auto new_sizes = {
378         input.size(0), input.size(2), (input.size(1) * input.size(3))};
379     auto out = grad.to_padded_tensor(0, new_sizes);
380     auto expand_last_dim_size = {
381         input.size(0), input.size(2), input.size(1), input.size(3)};
382     return out.view(expand_last_dim_size).permute({0, 2, 1, 3});
383   }
384   return grad.to_padded_tensor(0, input.sizes());
385 }
386 
387 std::tuple<Tensor, Tensor, Tensor> linear_double_backward(
388     const variable_list& grads,
389     const Tensor& self,
390     const Tensor& grad_output,
391     const Tensor& weight) {
392   if (!grad_output.defined()) {
393     return std::make_tuple(Tensor(), Tensor(), Tensor());
394   }
395 
396   Tensor grad_self, grad_grad_output, grad_weight;
397 
398   if (grads[1].defined()) {
399     grad_self =
400         (grad_output.dim() == 1 ? grad_output.unsqueeze(0) : grad_output)
401             .matmul(grads[1]);
402     if (grad_output.dim() == 1) {
403       grad_self = grad_self.squeeze(0);
404     }
405   }
406   if (grads[0].defined()) {
407     grad_weight =
408         (grad_output.dim() == 1 ? grad_output.unsqueeze(1) : grad_output.mT())
409             .matmul(grads[0].dim() == 1 ? grads[0].unsqueeze(0) : grads[0]);
410   }
411 
412   if (grads[0].defined() || grads[1].defined() || grads[2].defined()) {
413     grad_grad_output = at::zeros_like(grad_output);
414     if (grad_output.dim() == 1) {
415       grad_grad_output = grad_grad_output.unsqueeze(0);
416     }
417   }
418 
419   if (grads[0].defined()) {
420     grad_grad_output = grad_grad_output +
421         (grads[0].dim() == 1 ? grads[0].unsqueeze(0) : grads[0])
422             .matmul(weight.mT());
423   }
424   if (grads[1].defined()) {
425     grad_grad_output = grad_grad_output +
426         (self.dim() == 1 ? self.unsqueeze(0) : self).matmul(grads[1].mT());
427   }
428   if (grads[2].defined()) {
429     grad_grad_output = grad_grad_output + grads[2];
430   }
431   if (grad_grad_output.defined() && grad_output.dim() == 1) {
432     grad_grad_output = grad_grad_output.squeeze(0);
433   }
434 
435   return std::make_tuple(
436       std::move(grad_self),
437       std::move(grad_grad_output),
438       std::move(grad_weight));
439 }
440 
441 Tensor linalg_vector_norm_jvp(
442     const Tensor& self_p,
443     const Tensor& self_t,
444     const Scalar& scalar_ord,
445     Tensor norm,
446     const at::OptionalIntArrayRef& opt_dim,
447     bool keepdim) {
448   // No need to handle the dtype arg as it's handled via broadcasting in the
449   // function
450   auto dim = opt_dim.value_or(IntArrayRef({}));
451   return norm_jvp(self_p, self_t, scalar_ord, std::move(norm), dim, keepdim);
452 }
453 
454 Tensor linalg_vector_norm_backward(
455     Tensor grad,
456     const Tensor& self,
457     const Scalar& scalar_ord,
458     Tensor norm,
459     const at::OptionalIntArrayRef& opt_dim,
460     bool keepdim) {
461   // No need to handle the dtype arg as it's handled via broadcasting in the
462   // function
463   auto dim = opt_dim.value_or(IntArrayRef({}));
464   return norm_backward(
465       std::move(grad), self, scalar_ord, std::move(norm), dim, keepdim);
466 }
467 
468 Tensor pow_backward(Tensor grad, const Tensor& self, const Scalar& exponent) {
469   if (exponent.equal(0.0)) {
470     return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
471   } else {
472     auto grad_lambda = [&](auto exp) {
473       return grad * (exp * self.pow(exp - 1)).conj();
474     };
475     Tensor out = (exponent.isComplex())
476         ? grad_lambda(exponent.toComplexDouble())
477         : grad_lambda(exponent.toDouble());
478     return handle_r_to_c(self, std::move(out));
479   }
480 }
481 
482 Tensor pow_backward_self(
483     const Tensor& grad,
484     const Tensor& self,
485     const Tensor& exponent) {
486   auto out = at::where(
487       exponent == 0.0,
488       at::zeros({}, grad.options()),
489       grad * (exponent * self.pow(exponent - 1)).conj());
490   return handle_r_to_c(self, std::move(out));
491 }
492 
493 // Caveats:
494 // We define d(a^b)/db at a = 0 and b < 0 to be -inf. This is due to
495 // d(a^b)/db -> -inf for a fixed b as a -> +0
496 // Currently, tensorflow defines d(a^b)/db = nan for a = 0 and b < 0.
497 //
498 // We define d(a^b)/db = 0 for a = 0 and b = 0 by continuity as
499 // d(a^b)/db = 0 for a > 0 and b -> +0.
500 // Currently, tensorflow agrees with us.
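// (Added clarification.) The formula implemented below is d(a^b)/db = a^b * ln(a):
// as a -> +0 with b < 0 fixed, a^b -> +inf and ln(a) -> -inf, giving the -inf
// convention above, while entries with a = 0 and b >= 0 are masked to zero
// (matching the b > 0 limit and the b = 0 convention).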
501 Tensor pow_backward_exponent(
502     const Tensor& grad,
503     const Tensor& self,
504     const Tensor& exponent,
505     const Tensor& result) {
506   Tensor cond;
507   if (exponent.is_complex()) {
508     auto is_real_exp =
509         at::logical_and(at::imag(exponent) == 0, at::real(exponent) >= 0);
510     cond = at::logical_and(self == 0, is_real_exp);
511   } else {
512     cond = at::logical_and(self == 0, exponent >= 0);
513   }
514   auto promoted_dtype = at::result_type(self, exponent);
515   // `.to()` is a no-op if the dtype is the same.
516   auto self_ = self.to(promoted_dtype);
517 
518   auto out =
519       grad *
520       at::where(
521           cond, at::zeros({}, grad.options()), (result * self_.log()).conj());
522   return handle_r_to_c(exponent, std::move(out));
523 }
524 
525 Tensor pow_backward_exponent(
526     const Tensor& grad,
527     const Scalar& base,
528     const Tensor& exponent,
529     const Tensor& result) {
530   auto grad_lambda = [](const Tensor& a, const Scalar& b) {
531     return (a * b.log()).conj();
532   };
533   auto base_ = exponent.is_complex() && !base.isComplex()
534       ? base.toComplexDouble()
535       : base;
536   if (base.equal(0.0)) {
537     auto cond = [](auto exp) {
538       if (exp.is_complex()) {
539         return at::logical_and(at::imag(exp) == 0, at::real(exp) >= 0);
540       } else {
541         return exp >= 0;
542       }
543     };
544     auto out = grad *
545         at::where(cond(exponent),
546                   at::zeros({}, grad.options()),
547                   grad_lambda(result, base_));
548     return handle_r_to_c(exponent, std::move(out));
549   } else {
550     auto out = grad * grad_lambda(result, base_);
551     return handle_r_to_c(exponent, std::move(out));
552   }
553 }
554 
555 Tensor angle_backward(const Tensor& grad, const Tensor& self) {
556   if (self.is_complex()) {
557     return at::where(
558         self == 0.0,
559         at::zeros({}, self.options()),
560         grad * self / self.abs().pow(2) *
561             Scalar(c10::complex<double>{0.0, 1.0}));
562   } else {
563     return at::zeros_like(self, at::MemoryFormat::Preserve);
564   }
565 }
566 
567 Tensor mvlgamma_backward(const Tensor& grad, const Tensor& self, int64_t p) {
568   Tensor args = at::arange(
569       -static_cast<double>(p) / 2. + 0.5,
570       0.5,
571       0.5,
572       // use strided here regardless of self's layout; useful for e.g. NJT
573       self.options().layout(c10::kStrided));
574   args = args.add(self.unsqueeze(-1));
575   return grad * args.digamma_().sum(-1);
576 }
577 
578 Tensor sgn_backward(const Tensor& x, const Tensor& gx, const Tensor& sgn) {
579   if (x.is_complex()) {
580     auto abs = x.abs();
581     return ((gx - (sgn * sgn) * gx.conj()) / (2. * abs))
582         .masked_fill_(abs == 0., 0.);
583   } else {
584     return at::_efficientzerotensor(sgn.sizes(), sgn.options());
585   }
586 }
587 
588 Tensor masked_fill_backward(const Tensor& grad, const Tensor& mask) {
589   // masked_select does not work well with functorch, as its shape is
590   // data-dependent
591   return areAnyTensorSubclassLike({grad, mask})
592       ? at::where(mask, grad, 0).sum()
593       : grad.masked_select(mask).sum();
594 }
595 
596 template <typename T>
597 Tensor mul_tensor_backward(const Tensor& grad, T other, ScalarType self_st) {
598   auto out = grad * other.conj();
599   return handle_r_to_c(self_st, std::move(out));
600 }
601 template Tensor mul_tensor_backward(const Tensor&, Tensor, ScalarType);
602 template Tensor mul_tensor_backward(const Tensor&, Scalar, ScalarType);
603 
604 template <typename T>
605 Tensor div_tensor_self_backward(
606     const Tensor& grad,
607     T other,
608     ScalarType self_st,
609     const std::optional<c10::string_view>& rounding_mode) {
610   if (rounding_mode.has_value()) {
611     return at::zeros_like(grad, grad.options().dtype(self_st));
612   }
613 
614   auto result = grad / other.conj();
615   return handle_r_to_c(self_st, std::move(result));
616 }
617 template Tensor div_tensor_self_backward(
618     const Tensor&,
619     Tensor,
620     ScalarType,
621     const std::optional<c10::string_view>&);
622 template Tensor div_tensor_self_backward(
623     const Tensor&,
624     Scalar,
625     ScalarType,
626     const std::optional<c10::string_view>&);
627 
628 template <typename T>
629 Tensor div_tensor_self_backward(
630     const Tensor& grad,
631     T other,
632     ScalarType self_st) {
633   return div_tensor_self_backward(
634       grad, std::move(other), self_st, std::nullopt);
635 }
636 template Tensor div_tensor_self_backward(const Tensor&, Tensor, ScalarType);
637 template Tensor div_tensor_self_backward(const Tensor&, Scalar, ScalarType);
638 
639 Tensor div_tensor_other_backward(
640     const Tensor& grad,
641     const Tensor& self,
642     const Tensor& other,
643     const std::optional<c10::string_view>& rounding_mode) {
644   if (rounding_mode.has_value()) {
645     return at::zeros_like(grad, grad.options().dtype(other.scalar_type()));
646   }
647 
648   auto result = -grad * ((self / other) / other).conj();
649   return handle_r_to_c(other, std::move(result));
650 }
651 
652 Tensor div_tensor_other_backward(
653     const Tensor& grad,
654     const Tensor& self,
655     const Tensor& other) {
656   return div_tensor_other_backward(grad, self, other, std::nullopt);
657 }
658 
659 Tensor permute_backwards(const Tensor& grad, IntArrayRef fwd_dims) {
660   // invert the permutation
661   auto ndims = fwd_dims.size();
662   std::vector<int64_t> dims(ndims);
663   for (const auto i : c10::irange(ndims)) {
664     dims[at::maybe_wrap_dim(fwd_dims[i], static_cast<int64_t>(ndims))] =
665         static_cast<int64_t>(i);
666   }
667   return grad.permute(dims);
668 }
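// (Added illustrative example, not from the original source.) Permuting with
// dims {2, 0, 1} in the forward pass is undone by permuting the gradient with
// the inverse permutation {1, 2, 0} computed above:
//
//   auto x = at::randn({2, 3, 4});
//   auto y = x.permute({2, 0, 1});              // shape [4, 2, 3]
//   auto gx = permute_backwards(at::ones_like(y), {2, 0, 1});
//   // gx == ones.permute({1, 2, 0}), shape [2, 3, 4]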
669 
670 Tensor rad2deg_backward(const Tensor& grad) {
671   constexpr double M_180_PI =
672       57.295779513082320876798154814105170332405472466564;
673   return at::mul(grad, Scalar(M_180_PI));
674 }
675 
676 Tensor deg2rad_backward(const Tensor& grad) {
677   constexpr double M_PI_180 =
678       0.017453292519943295769236907684886127134428718885417;
679   return at::mul(grad, Scalar(M_PI_180));
680 }
681 
682 Tensor unsqueeze_multiple(
683     const Tensor& t,
684     OptionalIntArrayRef opt_dim,
685     size_t n_dims) {
686   if (opt_dim.has_value()) {
687     IntArrayRef dim = opt_dim.value();
688     auto dim_size = dim.size();
689     // Optimisation for two common cases
690     if (dim_size == 0) {
691       return t;
692     } else if (dim_size == 1) {
693       return t.unsqueeze(dim[0]);
694     }
695   }
696   auto dims_to_unsqueeze = at::dim_list_to_bitset(opt_dim, n_dims);
697   Tensor res = t;
698   for (const auto i : c10::irange(n_dims)) {
699     if (dims_to_unsqueeze[i]) {
700       res = res.unsqueeze(static_cast<int64_t>(i));
701     }
702   }
703   return res;
704 }
705 
706 Tensor sum_backward(
707     const Tensor& grad,
708     c10::SymIntArrayRef sizes,
709     OptionalIntArrayRef opt_dims,
710     bool keepdim) {
711   if (!keepdim && !sizes.empty()) {
712     if (opt_dims.has_value() && !opt_dims.value().empty()) {
713       return unsqueeze_multiple(grad, opt_dims, sizes.size())
714           .expand_symint(sizes);
715     }
716   }
717   return grad.expand_symint(sizes);
718 }
719 
720 Tensor sum_backward(
721     const Tensor& grad,
722     c10::SymIntArrayRef sizes,
723     c10::IntArrayRef dims,
724     bool keepdim) {
725   if (!keepdim && !sizes.empty() && !dims.empty()) {
726     // we are only using `keepdim=true` path for SymInts for now
727     TORCH_CHECK_NOT_IMPLEMENTED(
728         false,
729         "Only the keepdim=true path is implemented to support symints in autograd");
730   } else {
731     return grad.expand_symint(sizes);
732   }
733 }
734 
735 Tensor nansum_backward(
736     const Tensor& grad,
737     const Tensor& self,
738     at::OptionalIntArrayRef dims,
739     bool keepdim) {
740   return sum_backward(grad, self.sym_sizes(), dims, keepdim) *
741       self.isnan().logical_not();
742 }
743 
744 Tensor mean_backward(
745     const Tensor& grad,
746     c10::SymIntArrayRef shape,
747     OptionalIntArrayRef opt_dim,
748     c10::SymInt numel,
749     bool keepdim) {
750   bool is_all_reduce = !opt_dim.has_value() || opt_dim.value().empty();
751   auto n =
752       is_all_reduce ? std::move(numel) : _safe_size(shape, opt_dim.value());
753   return sum_backward(grad, shape, opt_dim, keepdim) / std::move(n);
754 }
755 
756 std::vector<c10::SymInt> reverse_list_symint(const c10::SymIntArrayRef list) {
757   auto result = std::vector<c10::SymInt>();
758   result.reserve(list.size());
759   for (auto iter = list.rbegin(); iter != list.rend(); iter++) {
760     result.push_back(*iter);
761   }
762   return result;
763 }
764 
765 std::vector<int64_t> reverse_list(const IntArrayRef list) {
766   auto result = std::vector<int64_t>();
767   result.reserve(list.size());
768   for (auto iter = list.rbegin(); iter != list.rend(); iter++) {
769     result.push_back(*iter);
770   }
771   return result;
772 }
773 
774 Tensor prod_safe_zeros_backward(
775     const Tensor& grad,
776     const Tensor& inp,
777     int64_t dim) {
778   if (inp.sym_numel() == 0) {
779     // When input has a zero sized dimension (empty tensor),
780     // we don't need to actually compute the grads.
781     // So we just reshape `grad` as `input`.
782     return grad.expand_as(inp);
783   }
784 
785   if (inp.sym_size(dim) == 1) {
786     return grad;
787   }
788 
789   auto ones_size = inp.sym_sizes().vec();
790   ones_size[dim] = 1;
791   Tensor ones = at::ones_symint(ones_size, grad.options());
792   Tensor exclusive_normal_nocp =
793       at::cat({ones, inp.narrow_symint(dim, 0, inp.sym_size(dim) - 1)}, dim);
794   Tensor exclusive_normal = exclusive_normal_nocp.cumprod(dim);
795 
796   Tensor narrow_reverse =
797       inp.narrow_symint(dim, 1, inp.sym_size(dim) - 1).flip(dim);
798   Tensor exclusive_reverse_nocp =
799       at::cat({std::move(ones), std::move(narrow_reverse)}, dim);
800   Tensor exclusive_reverse = exclusive_reverse_nocp.cumprod(dim).flip(dim);
801 
802   return grad * (exclusive_normal * exclusive_reverse).conj();
803 }
804 
805 // note that the gradient for prod is equivalent to:
806 // cumprod(exclusive, normal) * cumprod(exclusive, reverse), e.g.:
807 // input:                        [    a,     b,     c]
808 // cumprod(exclusive, normal):   [1    ,     a, a * b]
809 // cumprod(exclusive, reverse):  [b * c,     c,     1]
810 // product:                      [b * c, a * c, a * b]
811 // and this is safe under input with 0s.
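// (Added worked example.) For input [2, 3, 4]:
//   cumprod(exclusive, normal)  = [ 1, 2, 6]
//   cumprod(exclusive, reverse) = [12, 4, 1]
//   product                     = [12, 8, 6]
// which equals d(prod)/d(input) = 24 / [2, 3, 4] here, but stays finite when
// some of the inputs are zero.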
812 Tensor prod_backward(
813     const Tensor& grad,
814     const Tensor& input,
815     const Tensor& result) {
816   if (input.dim() == 0) {
817     return grad;
818   }
819   if (input.is_meta() || isTensorSubclassLike(input)) {
820     // For Composite Compliance, always take the safer (and slower) path
821     return prod_safe_zeros_backward(grad, input.contiguous().view(-1), 0)
822         .view_as(input);
823   }
824   Tensor zero_idx = (input == 0).nonzero();
825   if (zero_idx.sym_numel() == 0) {
826     return grad * (result / input).conj();
827   } else if (!at::GradMode::is_enabled() && zero_idx.sym_size(0) > 1) {
828     return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
829   } else {
830     return prod_safe_zeros_backward(grad, input.contiguous().view(-1), 0)
831         .view_as(input);
832   }
833 }
834 
835 Tensor prod_backward(
836     Tensor grad,
837     const Tensor& input,
838     Tensor result,
839     int64_t dim,
840     bool keepdim) {
841   if (input.dim() == 0) {
842     return grad;
843   }
844   dim = at::maybe_wrap_dim(dim, static_cast<int64_t>(input.sym_sizes().size()));
845   if (!keepdim) {
846     // `prod` reduces the dimension at `dim`,
847     // so, unsqueeze `grad` and `result` at dim.
848     grad = grad.unsqueeze(dim);
849     result = result.unsqueeze(dim);
850   }
851   if (input.is_meta() || isTensorSubclassLike(input)) {
852     // For Composite Compliance, always take the safer (and slower) path
853     return prod_safe_zeros_backward(grad, input, dim);
854   }
855 
856   Tensor zero_mask = (input == 0);
857   Tensor slice_zero_count = zero_mask.sum(dim, true);
858   int64_t total_zeros = slice_zero_count.sum().item<int64_t>();
859   if (total_zeros == 0) {
860     return grad * (result / input).conj();
861   } else {
862     return prod_safe_zeros_backward(grad, input, dim);
863   }
864 }
865 
866 template <typename solve_f>
867 static Tensor generic_solve_jvp(
868     solve_f solve,
869     const Tensor& X,
870     const Tensor& A,
871     const Tensor& dA,
872     const Tensor& dB) {
873   auto is_vector_case = at::native::linalg_solve_is_vector_rhs(dA, dB);
874   auto dA_contrib =
875       is_vector_case ? dA.matmul(X.unsqueeze(-1)).squeeze(-1) : dA.matmul(X);
876   // In general,
877   // dX = solve(A, dB - dA_contrib), but this behavior is different for
878   // lu_solve. Refer to lu_solve_jvp for more details on this.
879   return solve(A, dB, dA_contrib);
880 }
881 
882 Tensor cumsum_backward(const Tensor& grad, int64_t dim) {
883   // Trivial case
884   if (grad.sym_numel() <= 1 || grad.sym_size(dim) == 1) {
885     return grad;
886   }
887   return grad.flip(dim).cumsum(dim).flip(dim);
888 }
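// (Added note, not from the original source.) The flip/cumsum/flip above is a
// reverse cumulative sum: since output_i = sum_{j <= i} input_j, we have
// d L / d input_j = sum_{i >= j} grad_i, e.g. grad [g0, g1, g2] maps to
// [g0 + g1 + g2, g1 + g2, g2].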
889 
890 Tensor logsumexp_backward(
891     Tensor grad,
892     const Tensor& self,
893     Tensor result,
894     IntArrayRef dim,
895     bool keepdim) {
896   if (!keepdim && self.dim() != 0) {
897     grad = unsqueeze_multiple(grad, dim, self.sym_sizes().size());
898     result = unsqueeze_multiple(result, dim, self.sym_sizes().size());
899   }
900   return grad * (self - result).exp().conj();
901 }
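// (Added note.) This implements d logsumexp(x) / d x_i = exp(x_i - logsumexp(x)),
// i.e. each gradient entry is scaled by the softmax of `self` along `dim`
// (conjugated to also cover complex inputs).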
902 
903 Tensor logcumsumexp_backward(
904     Tensor grad,
905     const Tensor& self,
906     const Tensor& result,
907     int64_t dim) {
908   if (grad.dim() == 0 || grad.sym_numel() == 0) {
909     return grad;
910   }
911 
912   // Reference: https://github.com/tensorflow/tensorflow/blob/
913   // 2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863
914 
915   auto scalar_min = AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(
916       at::ScalarType::BFloat16,
917       at::typeMetaToScalarType(grad.dtype()),
918       "logcumsumexp_backward",
919       []() { return c10::Scalar(std::numeric_limits<scalar_t>::lowest()); });
920 
921   auto reverse_logcumsumexp = [dim](auto x) {
922     return at::flip(at::logcumsumexp(at::flip(x, {dim}), dim), {dim});
923   };
924 
925   if (!at::is_complex(grad)) {
926     auto grad_min = at::scalar_tensor(scalar_min, grad.options());
927     auto log_abs_grad = grad.abs().log();
928     auto log_grad_positive = at::where(grad > 0, log_abs_grad, grad_min);
929     auto log_grad_negative = at::where(grad < 0, log_abs_grad, grad_min);
930 
931     auto output_pos =
932         (reverse_logcumsumexp(log_grad_positive - result) + self).exp();
933     auto output_neg =
934         (reverse_logcumsumexp(log_grad_negative - result) + self).exp();
935 
936     return output_pos - output_neg;
937   } else {
938     // no trick separating the positive and negative required
939     auto log_grad = grad.conj().log();
940     auto output = (reverse_logcumsumexp(log_grad - result) + self).exp();
941     return output.conj();
942   }
943 }
944 
945 Tensor logcumsumexp_jvp(
946     const Tensor& self_p,
947     const Tensor& self_t,
948     int64_t dim) {
949   // Mostly taken from logsumexp_jvp
950 
951   // NB: for simplicity, we recompute some values that can be reused from
952   // forward
953   auto self_p_exp = [&self_p, dim]() {
954     if (!at::is_complex(self_p)) {
955       return (self_p - std::get<0>(at::max(self_p, dim, true)))
956           .exp(); // Use the exp-normalize trick
957     } else {
958       // at::max doesn't support complex128
959       return self_p.exp();
960     }
961   }();
962 
963   auto cumsumexp_p = self_p_exp.cumsum(dim);
964 
965   TORCH_INTERNAL_ASSERT(!self_t._is_zerotensor())
966 
967   constexpr double eps = 1e-13;
968 
969   if (areAnyTensorSubclassLike({self_p, self_t})) {
970     auto result = (self_p_exp * self_t).cumsum(dim);
971     result /= cumsumexp_p.add_(eps);
972     return result;
973   } else {
974     self_p_exp *= self_t;
975     auto cumsumexp_t = self_p_exp.cumsum(dim);
976     return cumsumexp_t /= cumsumexp_p.add_(eps);
977   }
978 }
979 
980 Tensor unbind_backward(const variable_list& grads, int64_t dim) {
981   c10::SymIntArrayRef sizes;
982   at::TensorOptions o;
983   for (const auto& v : grads) {
984     if (v.defined()) {
985       sizes = v.sym_sizes();
986       o = static_cast<Tensor>(v).options();
987       break;
988     }
989   }
990   auto grads_tensors = fmap(grads, [&](const Variable& v) {
991     return (
992         v.defined() ? static_cast<Tensor>(v)
993                     : at::zeros({}, o).expand_symint(sizes));
994   });
995   return at::stack(grads_tensors, dim);
996 }
997 
998 Tensor unbind_backward_nested(
999     const variable_list& grads,
1000     const Tensor& nt_sizes,
1001     int64_t dim,
1002     const at::TensorOptions& options) {
1003   std::vector<Tensor> grads_tensors;
1004   for (int64_t i : c10::irange(static_cast<int64_t>(grads.size()))) {
1005     if (grads[i].defined()) {
1006       grads_tensors.push_back(static_cast<Tensor>(grads[i]));
1007     } else {
1008       const auto component_size = nt_sizes[i].contiguous();
1009       const c10::IntArrayRef grad_size(
1010           component_size.data_ptr<int64_t>(), component_size.size(0));
1011       grads_tensors.push_back(at::zeros(grad_size, options));
1012     }
1013   }
1014 
1015   return at::_nested_tensor_from_tensor_list(grads_tensors);
1016 }
1017 
1018 Tensor unbind_backward_nested_jagged(
1019     const variable_list& grads,
1020     const Tensor& self,
1021     int64_t dim) {
1022   TORCH_INTERNAL_ASSERT(
1023       dim == 0, "unbind_backward_nested_jagged() only supports dim=0")
1024   auto grad_nt = at::zeros_like(self);
1025   auto unbound_grads = grad_nt.unbind();
1026   for (int64_t i : c10::irange(static_cast<int64_t>(grads.size()))) {
1027     if (grads[i].defined()) {
1028       unbound_grads[i].copy_(static_cast<Tensor>(grads[i]));
1029     }
1030   }
1031 
1032   return grad_nt;
1033 }
1034 
1035 Tensor unsqueeze_to(const Tensor& self, c10::SymIntArrayRef sym_sizes) {
1036   auto result = self;
1037 
1038   auto nDims = sym_sizes.size();
1039   for (const auto dim : c10::irange(nDims)) {
1040     if (sym_sizes[dim] == 1) {
1041       result = result.unsqueeze(static_cast<int64_t>(dim));
1042     }
1043   }
1044   return result;
1045 }
1046 
1047 Tensor unsqueeze_to(
1048     const Tensor& self,
1049     IntArrayRef dims,
1050     c10::SymIntArrayRef sym_sizes) {
1051   const auto ndim = sym_sizes.size();
1052   auto mask = at::dim_list_to_bitset(dims, ndim);
1053 
1054   Tensor result = self;
1055   for (const auto d : c10::irange(ndim)) {
1056     if (mask.test(d) && sym_sizes[d] == 1) {
1057       result = result.unsqueeze(static_cast<int64_t>(d));
1058     }
1059   }
1060   return result;
1061 }
1062 
1063 Tensor unsqueeze_to(
1064     const Tensor& self,
1065     int64_t dim,
1066     c10::SymIntArrayRef sym_sizes) {
1067   return unsqueeze_to(self, IntArrayRef{dim}, sym_sizes);
1068 }
1069 
1070 std::vector<Tensor> cat_tensors_backward(
1071     const Tensor& grad,
1072     const std::vector<std::vector<c10::SymInt>>& sizes,
1073     const std::vector<ScalarType>& dtypes,
1074     int64_t dim) {
1075   std::vector<Tensor> grad_inputs(sizes.size());
1076   if (!grad.defined()) {
1077     return grad_inputs;
1078   }
1079   dim = at::legacy_cat_wrap_dim_symint(dim, sizes);
1080   c10::SymInt accumulate = 0;
1081 
1082   Tensor grad_;
1083   bool grad_is_complex = grad.is_complex();
1084   if (grad_is_complex) {
1085     grad_ = at::real(grad);
1086   }
1087   for (const auto i : c10::irange(sizes.size())) {
1088     Tensor grad_val;
1089     if (!at::isComplexType(dtypes[i]) && grad_is_complex) {
1090       // R -> C
1091       grad_val = grad_;
1092     } else {
1093       grad_val = grad;
1094     }
1095     auto& shape = sizes[i];
1096     // If the input was an empty tensor, gradInput should be an empty tensor.
1097     if (shape.size() == 1) {
1098       if (TORCH_GUARD_SIZE_OBLIVIOUS(shape[0].sym_eq(0))) {
1099         grad_inputs[i] = at::zeros({0}, grad_val.options());
1100         continue;
1101       }
1102     }
1103     const auto& size = shape[dim];
1104     accumulate += size;
1105     grad_inputs[i] = grad_val.narrow_symint(dim, accumulate - size, size);
1106   }
1107   return grad_inputs;
1108 }
1109 
1110 std::vector<Tensor> stack_tensors_backward(
1111     const Tensor& grad,
1112     int64_t dim,
1113     const std::vector<ScalarType>& dtypes) {
1114   std::vector<Tensor> grad_inputs(dtypes.size());
1115   if (!grad.defined()) {
1116     return grad_inputs;
1117   }
1118   bool grad_is_complex = grad.is_complex();
1119   for (const auto i : c10::irange(dtypes.size())) {
1120     auto gr = grad.select(dim, static_cast<int64_t>(i));
1121     if (grad_is_complex && !at::isComplexType(dtypes[i])) {
1122       gr = at::real(gr);
1123     }
1124     grad_inputs[i] = gr;
1125   }
1126   return grad_inputs;
1127 }
1128 
1129 std::vector<Tensor> block_diag_backward(
1130     const Tensor& grad,
1131     const std::vector<std::vector<int64_t>>& sizes,
1132     const std::vector<ScalarType>& dtypes) {
1133   std::vector<Tensor> grad_inputs(sizes.size());
1134   if (!grad.defined()) {
1135     return grad_inputs;
1136   }
1137   Tensor real_view_of_grad;
1138   bool grad_is_complex = grad.is_complex();
1139   if (grad_is_complex) {
1140     real_view_of_grad = at::real(grad);
1141   }
1142 
1143   int64_t cur_dim0 = 0;
1144   int64_t cur_dim1 = 0;
1145 
1146   for (const auto i : c10::irange(sizes.size())) {
1147     // R -> C
1148     Tensor grad_val = (!at::isComplexType(dtypes[i]) && grad_is_complex)
1149         ? real_view_of_grad
1150         : grad;
1151 
1152     auto& shape = sizes[i];
1153     // If input was empty tensor, gradInput should be empty tensor.
1154     // If the input was an empty tensor, gradInput should be an empty tensor.
1155       grad_inputs[i] = at::zeros({0}, grad_val.options());
1156       continue;
1157     }
1158     // 0d case
1159     int64_t dim0 = 1;
1160     int64_t dim1 = 1;
1161     // 2d case
1162     if (shape.size() == 2) {
1163       dim0 = shape[0];
1164       dim1 = shape[1];
1165       // 1d case
1166     } else if (shape.size() == 1) {
1167       dim1 = shape[0];
1168     }
1169     auto slice = grad_val.slice(0, cur_dim0, cur_dim0 + dim0)
1170                      .slice(1, cur_dim1, cur_dim1 + dim1);
1171     if (shape.size() == 1) {
1172       slice = slice.squeeze(-1);
1173     } else if (shape.empty()) {
1174       slice = slice.squeeze(-1).squeeze(-1);
1175     }
1176     grad_inputs[i] = slice;
1177     cur_dim0 += dim0;
1178     cur_dim1 += dim1;
1179   }
1180   return grad_inputs;
1181 }
1182 
1183 Tensor clamp_backward(
1184     const Tensor& grad,
1185     const Tensor& self,
1186     const std::optional<Scalar>& min,
1187     const std::optional<Scalar>& max) {
1188   // clamp: gradients not defined on min and max, so we return the subgradient 1
1189   // for these cases.
1190   if (max && min) {
1191     auto zero = at::scalar_tensor(0., grad.options());
1192     return where((self >= *min).logical_and_(self <= *max), grad, zero);
1193   } else if (min) {
1194     auto zero = at::scalar_tensor(0., grad.options());
1195     return where(self >= *min, grad, zero);
1196   } else if (max) {
1197     auto zero = at::scalar_tensor(0., grad.options());
1198     return where(self <= *max, grad, zero);
1199   } else {
1200     return grad;
1201   }
1202 }
1203 
1204 Tensor clamp_backward(
1205     const Tensor& grad,
1206     const Tensor& self,
1207     const Tensor& min,
1208     const Tensor& max) {
1209   // clamp: gradients not defined on min and max, so we return the subgradient 1
1210   // for these cases.
1211   if (max.defined() && min.defined()) {
1212     auto zero = at::scalar_tensor(0., grad.options());
1213     const auto self_ge_min = self >= min;
1214     const auto self_le_max = self <= max;
1215     const auto& pred = areAnyTensorSubclassLike({self, min, max})
1216         ? self_ge_min.logical_and(self_le_max)
1217         : self_ge_min.logical_and_(self_le_max);
1218     return where(pred, grad, zero);
1219   } else if (min.defined()) {
1220     auto zero = at::scalar_tensor(0., grad.options());
1221     return where(self >= min, grad, zero);
1222   } else if (max.defined()) {
1223     auto zero = at::scalar_tensor(0., grad.options());
1224     return where(self <= max, grad, zero);
1225   } else {
1226     return grad;
1227   }
1228 }
1229 
1230 std::tuple<at::Tensor, at::Tensor> clamp_backward_min_max(
1231     const Tensor& grad,
1232     const Tensor& self,
1233     const Tensor& min,
1234     const Tensor& max,
1235     const std::array<bool, 2>& grad_input_mask) {
1236   // If min > max, min has no gradient
1237   std::tuple<at::Tensor, at::Tensor> ret;
1238   if (!grad.defined()) {
1239     return ret;
1240   }
1241 
1242   auto zero = at::scalar_tensor(0., grad.options());
1243   if (max.defined() && min.defined()) {
1244     if (grad_input_mask[0]) {
1245       const auto self_lt_min = self < min;
1246       const auto min_lt_max = min < max;
1247       const auto& pred = areAnyTensorSubclassLike({self, min, max})
1248           ? self_lt_min.logical_and(min_lt_max)
1249           : self_lt_min.logical_and_(min_lt_max);
1250       std::get<0>(ret) = where(pred, grad, zero);
1251     }
1252     if (grad_input_mask[1]) {
1253       const auto self_gt_max = self > max;
1254       const auto max_lt_min = max < min;
1255       const auto& pred = areAnyTensorSubclassLike({self, min, max})
1256           ? self_gt_max.logical_or(max_lt_min)
1257           : self_gt_max.logical_or_(max_lt_min);
1258       std::get<1>(ret) = where(pred, grad, zero);
1259     }
1260   } else if (min.defined() && grad_input_mask[0]) {
1261     std::get<0>(ret) = where(self < min, grad, zero);
1262   } else if (max.defined() && grad_input_mask[1]) {
1263     std::get<1>(ret) = where(self > max, grad, zero);
1264   }
1265   return ret;
1266 }
1267 
1268 at::Tensor clamp_jvp(
1269     const Tensor& self_p,
1270     const Tensor& self_t,
1271     const Tensor& min_p,
1272     const Tensor& min_t,
1273     const Tensor& max_p,
1274     const Tensor& max_t) {
1275   if (min_p.defined() && max_p.defined()) {
1276     return where(
1277         min_p > max_p,
1278         max_t,
1279         where(self_p < min_p, min_t, where(self_p > max_p, max_t, self_t)));
1280   } else if (min_p.defined()) {
1281     return where(self_p > min_p, self_t, min_t);
1282   } else if (max_p.defined()) {
1283     return where(self_p < max_p, self_t, max_t);
1284   } else {
1285     return self_t;
1286   }
1287 }
1288 
1289 Tensor convolution_jvp(
1290     const Tensor& input_p,
1291     const Tensor& input_t,
1292     const Tensor& weight_p,
1293     const Tensor& weight_t,
1294     const Tensor& bias_p,
1295     const Tensor& bias_t,
1296     at::SymIntArrayRef stride,
1297     at::SymIntArrayRef padding,
1298     at::SymIntArrayRef dilation,
1299     bool transposed,
1300     at::SymIntArrayRef output_padding,
1301     const c10::SymInt& groups) {
1302   auto bias_t_opt =
1303       bias_t.defined() ? std::optional<at::Tensor>(bias_t) : std::nullopt;
1304   return (
1305       at::convolution_symint(
1306           input_t,
1307           weight_p,
1308           std::nullopt,
1309           stride,
1310           padding,
1311           dilation,
1312           transposed,
1313           output_padding,
1314           groups) +
1315       at::convolution_symint(
1316           input_p,
1317           weight_t,
1318           bias_t_opt,
1319           stride,
1320           padding,
1321           dilation,
1322           transposed,
1323           output_padding,
1324           groups));
1325 }
1326 
1327 Tensor _convolution_jvp(
1328     const Tensor& input_p,
1329     const Tensor& input_t,
1330     const Tensor& weight_p,
1331     const Tensor& weight_t,
1332     const Tensor& bias_p,
1333     const Tensor& bias_t,
1334     at::SymIntArrayRef stride,
1335     at::SymIntArrayRef padding,
1336     at::SymIntArrayRef dilation,
1337     bool transposed,
1338     at::SymIntArrayRef output_padding,
1339     const c10::SymInt& groups,
1340     bool benchmark,
1341     bool deterministic,
1342     bool cudnn_enabled,
1343     bool allow_tf32) {
1344   auto bias_t_opt =
1345       bias_t.defined() ? std::optional<at::Tensor>(bias_t) : std::nullopt;
1346   return (
1347       at::_convolution_symint(
1348           input_t,
1349           weight_p,
1350           std::nullopt,
1351           stride,
1352           padding,
1353           dilation,
1354           transposed,
1355           output_padding,
1356           groups,
1357           benchmark,
1358           deterministic,
1359           cudnn_enabled,
1360           allow_tf32) +
1361       at::_convolution_symint(
1362           input_p,
1363           weight_t,
1364           bias_t_opt,
1365           stride,
1366           padding,
1367           dilation,
1368           transposed,
1369           output_padding,
1370           groups,
1371           benchmark,
1372           deterministic,
1373           cudnn_enabled,
1374           allow_tf32));
1375 }
1376 
1377 Tensor convolution_backward_jvp_grad_bias(
1378     const Tensor& grad_out_t,
1379     const Tensor& grad_bias) {
1380   if (!grad_bias.defined()) {
1381     return Tensor();
1382   }
1383   int64_t dim = grad_out_t.dim() - 2;
1384   if (dim == 1) {
1385     // Cannot pass initializer list due to overload ambiguity
1386     auto dimlist = std::vector<int64_t>{0, 2};
1387     return grad_out_t.sum(dimlist);
1388   } else if (dim == 2) {
1389     return grad_out_t.sum({0, 2, 3});
1390   } else if (dim == 3) {
1391     return grad_out_t.sum({0, 2, 3, 4});
1392   } else {
1393     TORCH_INTERNAL_ASSERT(
1394         false,
1395         "convolution_backward_jvp_grad_bias expected dim of grad_out_t to be 3, 4, or 5, but got: ",
1396         grad_out_t.dim());
1397   }
1398 }
1399 
1400 // This function is used by load_derivatives.py to replace tensor.strides()
1401 // calls that appear in derivative formulas. If the tensor has requires_grad
1402 // set, this function returns its strides or an empty array if the tensor
1403 // is sparse. If requires_grad is not set, an empty array is returned since
1404 // there will be no backward pass. There has one special case, if input is
1405 // there will be no backward pass. There is one special case: if the input is
1406 // an MKLDNN tensor and has requires_grad set, we just return an empty array; the
1407 // reason is that an MKLDNN tensor is an opaque tensor which has no stride info.
1408 // This function only supports the case where `input` is the tensor whose
1409 // single derivative is being calculated.
1410 //
1411 // This function does not support `self` derivatives for inplace functions.
1412 //
1413 // Args:
1414 //  input              Tensor to call .strides() on
1415 //  input_name         Name of `input` tensor, from derivative formula
1416 at::SymIntArrayRef strides_or_error(
1417     const Tensor& input,
1418     c10::string_view const& input_name) {
1419   // TODO: Ideally, this function would never be called if requires_grad is
1420   // not set. Once codegen is updated to avoid the call, we can remove this
1421   // check.
1422   if (input.requires_grad()) {
1423     if (input.is_mkldnn())
1424       return {};
1425     if (input.is_sparse() || at::sparse_csr::is_sparse_compressed(input))
1426       return {};
1427     return input.sym_strides();
1428   } else {
1429     return {};
1430   }
1431 }
1432 
1433 Tensor mm_mat1_backward(
1434     const Tensor& grad,
1435     const Tensor& mat2,
1436     at::SymIntArrayRef mat1_sizes,
1437     at::SymIntArrayRef mat1_strides,
1438     c10::Layout mat1_layout,
1439     const Scalar& alpha) {
1440   if (grad.layout() == c10::kStrided && mat2.layout() == c10::kStrided &&
1441       mat1_layout == c10::kStrided) {
1442     // if input was column-major, return grad in column-major order for efficiency
1443     if (mat1_strides[0] == 1 && mat1_strides[1] == mat1_sizes[0]) {
1444       return maybe_multiply(mat2.conj().mm(grad.t()).t(), alpha.conj());
1445     }
1446   }
1447 
1448   // General fallback, should work for any layout
1449   return maybe_multiply(grad.mm(mat2.t().conj()), alpha.conj());
1450 }
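// A minimal sketch of the formula above: for C = mat1 @ mat2 with upstream
// gradient gC, grad_mat1 = alpha.conj() * gC @ mat2^H. The column-major branch
// computes the numerically equivalent (mat2.conj() @ gC^T)^T so that the
// result comes out in mat1's (column-major) memory order.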
1451 
1452 Tensor mm_mat2_backward(
1453     const Tensor& grad,
1454     const Tensor& mat1,
1455     at::SymIntArrayRef mat2_sizes,
1456     at::SymIntArrayRef mat2_strides,
1457     c10::Layout mat2_layout,
1458     const Scalar& alpha) {
1459   if (grad.layout() == c10::kStrided && mat1.layout() == c10::kStrided &&
1460       mat2_layout == c10::kStrided) {
1461     // if input was column-major, return grad in column-major order for efficiency
1462     if (mat2_strides[0] == 1 && mat2_strides[1] == mat2_sizes[0]) {
1463       return maybe_multiply(grad.t().mm(mat1.conj()).t(), alpha.conj());
1464     }
1465   }
1466 
1467   // General fallback, should work for any layout
1468   return maybe_multiply(mat1.t().conj().mm(grad), alpha.conj());
1469 }
1470 
1471 Tensor mm_mat1_sparse_backward(
1472     const Tensor& grad,
1473     const Tensor& mat1,
1474     const Tensor& mat2,
1475     const Scalar& alpha) {
1476   if (grad.layout() == c10::kStrided && mat2.layout() == c10::kStrided &&
1477       mat1.is_sparse()) {
1478     auto sparse = mat1.coalesce();
1479     Tensor grad_sparse = maybe_multiply(grad.mm(mat2.conj().t()), alpha);
1480     return grad_sparse.sparse_mask(sparse);
1481   } else if (
1482       grad.layout() == c10::kStrided && mat2.layout() == c10::kStrided &&
1483       mat1.is_sparse_csr()) {
1484     // zero must have mat1's sparsity pattern:
1485     auto zero = mat1.clone();
1486     zero.values().zero_();
1487     return at::sparse_sampled_addmm(zero, grad, mat2.mH(), 1.0, alpha);
1488   } else if (
1489       grad.layout() == c10::kStrided && mat2.layout() == c10::kStrided &&
1490       mat1.layout() == c10::kStrided) {
1491     return maybe_multiply(grad.mm(mat2.mH()), alpha);
1492   }
1493   TORCH_CHECK(
1494       false,
1495       "sparse_addmm_sparse_backward: unsupported combination of layouts",
1496       ", grad: ",
1497       grad.layout(),
1498       ", mat1: ",
1499       mat1.layout(),
1500       ", mat2: ",
1501       mat2.layout());
1502 }
1503 
1504 static Tensor sparse_mask_like_grad(
1505     const Tensor& x,
1506     const Tensor& gx,
1507     bool accumulate_matches) {
1508   if (x.is_coalesced() && gx.is_coalesced()) {
1509     if (x._nnz() >= gx._nnz()) {
1510       // search into x is faster
1511       return gx._sparse_mask_projection(x, accumulate_matches);
1512     } else {
1513       // search into gx is faster
1514       return gx.sparse_mask(x);
1515     }
1516   } else if (x.is_coalesced()) {
1517     return gx.sparse_mask(x);
1518   } else if (gx.is_coalesced()) {
1519     return gx._sparse_mask_projection(x, accumulate_matches);
1520   } else {
1521     if (x._nnz() >= gx._nnz()) {
1522       // gx.coalesce() is likely faster
1523       return gx.coalesce()._sparse_mask_projection(x, accumulate_matches);
1524     } else {
1525       // x.coalesce() is likely faster
1526       return gx.sparse_mask(x.coalesce());
1527     }
1528   }
1529 }
1530 
1531 std::tuple<Tensor, Tensor, Tensor> sparse_sampled_addmm_backward(
1532     const Tensor& grad,
1533     const Tensor& self,
1534     const std::optional<Tensor>& mat1,
1535     const std::optional<Tensor>& mat2,
1536     const Scalar& alpha,
1537     const Scalar& beta,
1538     const std::array<bool, 3>& grad_input_mask) {
1539   if (!grad.defined()) {
1540     return std::make_tuple(Tensor{}, Tensor{}, Tensor{});
1541   }
1542 
1543   const auto grad_projected = grad.sparse_mask(self);
1544   const auto self_requires_grad = grad_input_mask[0];
1545   const auto mat1_requires_grad = grad_input_mask[1];
1546   const auto mat2_requires_grad = grad_input_mask[2];
1547   return std::make_tuple(
1548       self_requires_grad ? maybe_multiply(grad, beta.conj()) : Tensor{},
1549       mat1_requires_grad
1550           // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
1551           ? maybe_multiply(grad_projected.mm(mat2->mH()), alpha.conj())
1552           : Tensor{},
1553       mat2_requires_grad
1554           // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
1555           ? maybe_multiply(mat1->mH().mm(grad_projected), alpha.conj())
1556           : Tensor{});
1557 }
1558 
1559 Tensor sparse_mask_backward(
1560     const Tensor& grad,
1561     const Tensor& mask,
1562     const c10::Layout self_layout) {
1563   // NOTE: sparse_mask accumulates matches, so the backward step has to
1564   // accumulate as well.
1565   const auto self_grad =
1566       sparse_mask_like_grad(mask, grad, /*accumulate_matches=*/true);
1567   return self_layout == at::kStrided ? self_grad.to_dense() : self_grad;
1568 }
1569 
1570 Tensor sparse_sparse_matmul_backward(
1571     const Tensor& grad,
1572     const Tensor& a,
1573     const Tensor& b,
1574     int64_t grad_order) {
1575   /*
1576   To implement the backward algorithm for sparse matrix-matrix matmul (SPMM) we
1577   can start from the following definition for dense tensors:
1578 
1579   c = a @ b
1580       then
1581   a_grad = c_grad @ b^H
1582   b_grad = a^H @ c_grad
1583 
1584   So for sparse matrices we can use the following definition:
1585 
1586   if grad_order == 0:
1587       a_grad = sparse_matrix_mask(c_grad @ b^H, mask=a)
1588   else:
1589       b_grad = sparse_matrix_mask(a^H @ c_grad, mask=b)
1590   */
1591   TORCH_CHECK(
1592       grad_order == 0 || grad_order == 1,
1593       ": grad_order not in [0, 1] at sparse_sparse_matmul_backward function");
1594 
1595   // NOTE: _sparse_sparse_matmul returns a coalesced gradient,
1596   //   hence there is no need to accumulate matches.
1597   if (grad_order == 0) {
1598     auto a_grad = _sparse_sparse_matmul(grad, b.conj().t());
1599     return sparse_mask_like_grad(a, a_grad, /*accumulate_matches=*/false);
1600   }
1601   auto b_grad = _sparse_sparse_matmul(a.conj().t(), grad);
1602   return sparse_mask_like_grad(b, b_grad, /*accumulate_matches=*/false);
1603 }
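// Illustrative usage (hypothetical Python, assuming torch.sparse.mm dispatches
// to _sparse_sparse_matmul for two sparse COO inputs):
//   a = torch.randn(3, 4).to_sparse().requires_grad_()
//   b = torch.randn(4, 5).to_sparse().requires_grad_()
//   torch.sparse.mm(a, b).sum().backward()
//   # a.grad is masked to a's sparsity pattern, b.grad to b's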
1604 
1605 Tensor renorm_backward(
1606     const Tensor& grad,
1607     const Tensor& self,
1608     const Scalar& p,
1609     int64_t dim,
1610     const Scalar& maxnorm) {
1611   auto n = self.dim();
1612   dim = c10::maybe_wrap_dim(dim, n);
1613   auto reduce_dims = at::DimVector(n);
1614   std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
1615   reduce_dims.erase(reduce_dims.begin() + dim);
1616 
1617   auto acc_type =
1618       at::toAccumulateType(self.scalar_type(), self.device().type());
1619   auto norm = at::linalg_vector_norm(
1620       self, p, reduce_dims, /*keepdim=*/true, /*dtype=*/acc_type);
1621 
1622   const auto real_acc_type = c10::toRealValueType(acc_type);
1623   auto grad_output = (self.conj() * grad);
1624   // vector_norm output is real, so grad_output must also be real
1625   if (real_acc_type != acc_type) {
1626     grad_output = at::real(grad_output);
1627   }
1628   grad_output =
1629       grad_output.sum(reduce_dims, /*keepdim=*/true, /*dtype=*/real_acc_type);
1630   auto nb = norm_backward(
1631       std::move(grad_output), self, p, norm, reduce_dims, /*keepdim=*/true);
1632 
1633   auto invnorm = (norm + 1e-7).reciprocal();
1634   auto grad_norm = maxnorm * invnorm * (grad - invnorm * nb);
1635   return at::where(norm > maxnorm, grad_norm.to(grad.scalar_type()), grad);
1636 }
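// Sketch of the expression above: renorm rescales each slice along `dim` whose
// p-norm exceeds maxnorm by factor = maxnorm / (norm + 1e-7), so for those
// slices the gradient is factor * (grad - invnorm * nb), where nb is the VJP
// contribution coming from the norm's dependence on self; slices already
// within maxnorm pass grad through unchanged.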
1637 
1638 Tensor renorm_jvp(
1639     const Tensor& self_p,
1640     const Tensor& self_t,
1641     const Scalar& p,
1642     int64_t dim,
1643     const Scalar& maxnorm) {
1644   auto self_sizes = self_p.sizes();
1645   dim = at::maybe_wrap_dim(dim, static_cast<int64_t>(self_sizes.size()));
1646 
1647   at::DimVector reduce_dims(self_sizes.size());
1648   std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
1649   reduce_dims.erase(reduce_dims.begin() + dim);
1650 
1651   // For cuda half, calculate norm in float precision then cast
1652   // normalization factor to half
1653   auto dtype = self_p.scalar_type();
1654   auto acc_type = at::toAccumulateType(dtype, /*is_cuda=*/true);
1655   Tensor norm = [&self_p, &p, &reduce_dims, acc_type, dtype]() {
1656     if (acc_type != dtype) {
1657       return at::linalg_vector_norm(
1658           self_p,
1659           p.toDouble(),
1660           reduce_dims,
1661           /*keepdim=*/true,
1662           /*dtype=*/acc_type);
1663     } else {
1664       return at::linalg_vector_norm(
1665           self_p,
1666           p.toDouble(),
1667           reduce_dims,
1668           /*keepdim=*/true);
1669     }
1670   }();
1671 
1672   auto double_maxnorm = maxnorm.toDouble();
1673   auto invnorm = (norm + 1e-7).reciprocal();
1674   auto factor = invnorm * double_maxnorm;
1675 
1676   return where(
1677       norm > double_maxnorm,
1678       factor *
1679           (self_t -
1680            self_p * invnorm *
1681                norm_jvp(
1682                    self_p, self_t, p, norm, reduce_dims, /*keepdim=*/true)),
1683       self_t);
1684 }
1685 
1686 Tensor repeat_backward(
1687     Tensor grad,
1688     c10::SymIntArrayRef repeats,
1689     c10::SymIntArrayRef input_shape) {
1690   auto find_iter = std::find(repeats.cbegin(), repeats.cend(), 0);
1691   if (find_iter != repeats.cend()) {
1692     return at::zeros_symint(input_shape, grad.options());
1693   }
1694   const auto input_dims = input_shape.size();
1695   auto num_unsqueezed = grad.dim() - input_dims;
1696   for (const auto i : c10::irange(num_unsqueezed)) {
1697     (void)i; // Suppress unused variable warning
1698     grad = grad.sum(0, false);
1699   }
1700 
1701   at::SymDimVector grad_size;
1702   at::DimVector sum_dims;
1703   for (const auto dim : c10::irange(input_dims)) {
1704     const auto& repeat = repeats[dim + num_unsqueezed];
1705     // Reshape gradient (repeat > 1)
1706     // Index:      [..., dim    , ...]    [..., dim   ,  dim+1        , ...]
1707     // Shape: From [..., dimsize, ...] to [..., repeat, dimsize/repeat, ...]
1708     // The gradient tensor at 'dim' is reshaped to 'repeat' times of input
1709     // tensor. Then, sum up gradients over repeated tensors along 'dim', and
1710     // reduce shape from 'repeat * dimsize/repeat' to 'dimsize/repeat'
1711     // ('input_dimsize'). Example:
1712     //        Size(3, 2)                                      Size(6, 2)
1713     //                                                      [[v1_0, v1_1],
1714     //                                                       [v1_2, v1_3],
1715     //        [[v0, v1],               repeat(2, 1)          [v1_4, v1_5],
1716     //         [v2, v3],              ------------->         [v2_0, v2_1],
1717     //         [v4, v5]]                                     [v2_2, v2_3],
1718     //                                                       [v2_4, v2_5]]
1719     //
1720     //    input grad (3, 2)      reshape (2, 3, 2)         output grad (6, 2)
1721     //                            [[[g1_0, g1_1],            [[g1_0, g1_1],
1722     //                              [g1_2, g1_3],             [g1_2, g1_3],
1723     // [[g1_0+g2_0, g1_1+g2_1],     [g1_4, g1_5]],            [g1_4, g1_5],
1724     //  [g1_2+g2_2, g1_3+g2_3],     [g2_0, g2_1],            [[g2_0, g2_1],
1725     //  [g1_4+g2_4, g1_5+g2_5]]     [g2_2, g2_3],             [g2_2, g2_3],
1726     //                              [g2_4, g2_5]]             [g2_4, g2_5]]]
1727     //
1728     // If gradient tensor is reshaped to [..., dimsize/repeat, repeat, ...] and
1729     // then sum over 'dim+1'. The gradient for input is not correctly aligned
1730     // with input. Example:
1731     //  input grad (3, 2)        reshape (3, 2, 2)        output grad (6, 2)
1732     //                           [[[g1_0, g1_1],           [[g1_0, g1_1],
1733     //                             [g1_2, g1_3]],           [g1_2, g1_3],
1734     // [[g1_0+g1_2, g1_1+g1_3],   [[g1_4, g1_5],            [g1_4, g1_5],
1735     //  [g1_4+g2_0, g1_5+g2_1],    [g2_0, g2_1]],           [g2_0, g2_1],
1736     //  [g2_2+g2_4, g2_3+g2_5]]   [[g2_2, g2_3],            [g2_2, g2_3],
1737     //                             [g2_4, g2_5]]]           [g2_4, g2_5]]
1738     if (repeat != 1) {
1739       grad_size.push_back(repeat);
1740       sum_dims.push_back(static_cast<int64_t>(grad_size.size() - 1));
1741     }
1742     // No need to reshape the gradient into (repeat, input_shape[dim]) when
1743     // repeat == 1
1744     grad_size.push_back(input_shape[dim]);
1745   }
1746   // One-time Reshape & Sum
1747   // Reshape gradient to grad_size:
1748   //   1. If repeat equals 1, append the input size at that dimension,
1749   //   2. If repeat is larger than 1, append both repeat and input size at that
1750   //   dimension.
1751   // Sum over all "repeat" dimensions from sum_dims:
1752   // Example:
1753   // Input Size         (2,    3,    4,    5)
1754   // repeat             [4,    1,    9,    3]
1755   // output/grad Size   (8,    3,    36,   15)
1756   // grad_size          [4, 2,    3, 9, 4, 3, 5]
1757   // sum_dims           [0,          3,    5]
1758 
1759   // When every repeat is 1, sum_dims is empty; summing over an empty dim list
1760   // would reduce the whole grad tensor to a scalar instead of keeping the
1761   // original dimensions, so skip the reshape/sum in that case.
1762   if (!sum_dims.empty()) {
1763     grad = grad.reshape_symint(grad_size);
1764     grad = grad.sum(sum_dims);
1765   }
1766   return grad;
1767 }
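// Illustrative example (hypothetical values):
//   x = torch.randn(2, 3, requires_grad=True)
//   y = x.repeat(4, 1, 9)        # y.shape == (4, 2, 27)
//   y.sum().backward()           # x.grad.shape == (2, 3), every entry == 36
// Each input element appears in 4 * 1 * 9 = 36 copies of itself in y, and this
// backward sums the 36 corresponding gradient entries back into it.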
1768 
1769 // p1m == 1 - p
1770 Tensor _fused_dropout_backward(
1771     const Tensor& grad,
1772     const Tensor& mask,
1773     double p1m) {
1774   if (grad.requires_grad()) {
1775     // Use autograd-friendly backward if double backward is required
1776     return grad * (mask.type_as(grad) * (1. / p1m));
1777   } else {
1778     return at::_masked_scale(grad, mask, 1. / p1m);
1779   }
1780 }
1781 
1782 // scale == (1 / (1 - prob))
1783 Tensor infinitely_differentiable_native_dropout_backward(
1784     const Tensor& grad,
1785     const Tensor& mask,
1786     double scale) {
1787   return grad * (mask.type_as(grad) * scale);
1788 }
1789 
1790 Tensor native_dropout_double_backward(
1791     const Tensor& ggI,
1792     const Tensor& grad,
1793     const Tensor& mask,
1794     double scale) {
1795   return ggI.type_as(grad) * (mask.type_as(grad) * scale);
1796 }
1797 
1798 Tensor evenly_distribute_backward(
1799     const Tensor& grad,
1800     const Tensor& input,
1801     const Tensor& value) {
1802   bool any_tensor_subclass_like =
1803       areAnyTensorSubclassLike({grad, input, value});
1804   if (any_tensor_subclass_like || input.is_cuda()) {
1805     const auto input_isnan = input.isnan();
1806     const auto value_isnan = value.isnan();
1807     const auto& input_and_value_isnan = any_tensor_subclass_like
1808         ? input_isnan.logical_and(value_isnan)
1809         : input_isnan.logical_and_(value_isnan);
1810     const auto mask = (input == value).logical_or_(input_and_value_isnan);
1811     return mask * (grad / mask.sum());
1812   } else {
1813     auto mask = value.isnan().item<bool>() ? input.isnan() : input == value;
1814     return grad.new_zeros(input.sizes(), input.options())
1815         .masked_fill_(mask, grad / mask.sum());
1816   }
1817 }
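// Illustrative example (hypothetical values): for input = [1., 3., 3.] and
// value = input.max() (= 3.), the mask selects two elements, so each receives
// grad / 2. This is how the gradient is spread evenly among tied max/min (and
// NaN) locations for full reductions that return a single value.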
1818 
1819 Tensor evenly_read_jvp(
1820     const Tensor& fw_grad,
1821     const Tensor& input,
1822     const Tensor& value) {
1823   auto mask = (input == value);
1824   auto count = mask.sum();
1825   auto grad_output = fw_grad / count;
1826   return at::sum(mask * grad_output);
1827 }
1828 
1829 Tensor var_backward(
1830     Tensor grad,
1831     const Tensor& self,
1832     at::OptionalIntArrayRef dim_opt,
1833     const std::optional<at::Scalar>& correction_opt,
1834     bool keepdim) {
1835   const auto correction = correction_opt.value_or(1).toSymFloat();
1836   if (self.dim() == 0 || !dim_opt.has_value()) {
1837     const auto dof = c10::SymFloat(self.sym_numel()) - correction;
1838     if (dof <= 0) {
1839       // when n == correction, 2 / (n - correction) is infinity
1840       // when self == self.mean(), we return NaN because infinity * 0 = NaN
1841       // otherwise, we return infinity because infinity * c = infinity, for all
1842       // c > 0
1843       return grad *
1844           at::where(
1845                  self == self.mean(),
1846                  std::numeric_limits<double>::quiet_NaN(),
1847                  std::numeric_limits<double>::infinity());
1848     } else {
1849       return (c10::SymFloat(2.0) / dof) * grad * (self - self.mean());
1850     }
1851   }
1852   auto dim = dim_opt.value();
1853   if (!keepdim && self.dim() > 1) {
1854     grad = unsqueeze_multiple(grad, dim, self.sym_sizes().size());
1855   }
1856   const c10::SymFloat rnumel(_safe_size(self.sym_sizes(), dim));
1857   return (c10::SymFloat(2.0) / (rnumel - correction)) * grad *
1858       (self - self.mean(dim, /*keepdim=*/true));
1859 }
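// Sketch of the formula above (regular case dof > 0): with dof = N - correction,
//   var = sum_i (self_i - mean)^2 / dof, so d var / d self_i = (2 / dof) * (self_i - mean),
// hence the VJP (2 / dof) * grad * (self - self.mean(dim, keepdim=True)).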
1860 
1861 Tensor std_backward(
1862     const Tensor& result,
1863     const Tensor& grad,
1864     const Tensor& self,
1865     at::OptionalIntArrayRef dim,
1866     const std::optional<c10::Scalar>& correction_opt,
1867     bool keepdim) {
1868   auto grad_var = (grad / (result * 2)).masked_fill_(result == 0, 0);
1869   return var_backward(std::move(grad_var), self, dim, correction_opt, keepdim);
1870 }
1871 
1872 Tensor var_mean_backward(
1873     const Tensor& gvar,
1874     const Tensor& gmean,
1875     const Tensor& self,
1876     at::OptionalIntArrayRef dim_opt,
1877     const std::optional<c10::Scalar>& correction_opt,
1878     bool keepdim) {
1879   Tensor gself;
1880   if (gvar.defined()) {
1881     gself = var_backward(gvar, self, dim_opt, correction_opt, keepdim);
1882   }
1883   if (gmean.defined()) {
1884     auto aux = mean_backward(
1885         gmean,
1886         self.sym_sizes(),
1887         dim_opt.value_or(IntArrayRef({})),
1888         self.sym_numel(),
1889         keepdim);
1890     gself = gself.defined() ? gself + aux : std::move(aux);
1891   }
1892   return gself;
1893 }
1894 
1895 Tensor std_mean_backward(
1896     const Tensor& gstd,
1897     const Tensor& gmean,
1898     const Tensor& self,
1899     const Tensor& std,
1900     at::OptionalIntArrayRef dim_opt,
1901     const std::optional<c10::Scalar>& correction_opt,
1902     bool keepdim) {
1903   Tensor gself;
1904   if (gstd.defined()) {
1905     gself = std_backward(std, gstd, self, dim_opt, correction_opt, keepdim);
1906   }
1907   if (gmean.defined()) {
1908     auto aux = mean_backward(
1909         gmean,
1910         self.sym_sizes(),
1911         dim_opt.value_or(IntArrayRef({})),
1912         self.sym_numel(),
1913         keepdim);
1914     gself = gself.defined() ? gself + aux : std::move(aux);
1915   }
1916   return gself;
1917 }
1918 
1919 Tensor cholesky_jvp(const Tensor& dA, const Tensor& L, bool upper) {
1920   at::NoTF32Guard disable_tf32;
1921   // Let A = LL^H
1922   // dA = dLL^H + L(dL)^H
1923   // L^{-1}dA(L^{-H}) = L^{-1}dL + (L^{-1}dL)^H
1924   //               = sym(L^{-1}dL)
1925   // where sym(X) = X + X^H
1926   // A short computation gives that the inverse of sym is given by
1927   // \pi(X) = X.tril() - 0.5*diag(X)
1928   // so
1929   // dL = L\pi(L^{-1}dA(L^{-H}))
1930 
1931   // Precondition: dA is symmetric/Hermitian
1932   auto L_ = upper ? L.mH() : L;
1933   auto dL = at::linalg_solve_triangular(L_, dA, /*upper=*/false, /*left=*/true);
1934   dL = at::linalg_solve_triangular(L_.mH(), dL, /*upper=*/true, /*left=*/false);
1935   dL = dL.tril() - dL.diagonal(0, -2, -1).mul(0.5).diag_embed();
1936   dL = L_.matmul(dL);
1937   return upper ? dL.mH() : std::move(dL);
1938 }
1939 
1940 Tensor cholesky_backward(const Tensor& gL, bool upper, const Tensor& L) {
1941   at::NoTF32Guard disable_tf32;
1942   // From cholesky_jvp we have that
1943   // dL = L\pi(L^{-1}dA(L^-H))
1944   //
1945   // Let gL be the projection into the lower-triangular gradient wrt L. Taking
1946   // adjoints we have gA = L^{-H}\pi^*((L^HgL).tril())L^{-1} where \pi^*(X) =
1947   // 0.5 * (X + X^H - diag(X)) The only non-standard point of this derivation is
1948   // noting that the adjoint to multiplying on the left by a lower triangular
1949   // matrix L is multiplying by L^H and then projecting back to the lower
1950   // triangular matrices (hence the .tril() projection) Note that the gradient
1951   // is symmetric and not triangular.
1952   auto L_ = upper ? L.mH() : L;
1953   auto gL_ = upper ? gL.mH() : gL;
1954 
1955   // Nb. We don't need to compute gL_ = gL.tril() as
1956   // tril(L^H gL) = tril(L^H (triu(gL, 1) + tril(gL)))
1957   //              = tril(L^H tril(gL)) + tril(L^H triu(gL, 1))
1958   //              = tril(L^H tril(gL))
1959   // since tril(L^H triu(gL, 1)) = 0, as L^H triu(gL, 1) is upper triangular
1960   auto gA = L_.mH().matmul(gL_).tril();
1961   // Equivalent to 0.5 * (gA + gA^H - diag(gA))
1962   gA = 0.5 * (gA + gA.tril(-1).mH());
1963   gA = at::linalg_solve_triangular(L_.mH(), gA, /*upper=*/true, /*left=*/true);
1964   gA = at::linalg_solve_triangular(L_, gA, /*upper=*/false, /*left=*/false);
1965   return gA;
1966 }
1967 
1968 Tensor cholesky_inverse_backward(
1969     const Tensor& grad,
1970     const Tensor& L,
1971     bool upper,
1972     const Tensor& inverse) {
1973   at::NoTF32Guard disable_tf32;
1974   Tensor grad_L;
1975   if (grad.defined()) {
1976     Tensor common_term = grad + grad.mH();
1977     common_term = at::matmul(inverse, at::matmul(common_term, inverse));
1978     if (upper) {
1979       grad_L = -at::matmul(L, common_term);
1980     } else {
1981       grad_L = -at::matmul(common_term, L);
1982     }
1983   }
1984 
1985   return grad_L;
1986 }
1987 
1988 // If X = (L L^H)^{-1} with L lower-triangular with a real positive diagonal,
1989 // then dX = K^H + K, where
1990 // K =  L^{-H} dL^{-1} [dL^{-1} = -L^{-1} dL L^{-1}]
1991 //   = -L^{-H} L^{-1} dL L^{-1} [L^{-H} L^{-1} = X]
1992 //   = -X dL L^{-1} [X = X^H = L^{-H} L^{-1} = L^{-1} L^{-H}]
1993 //   = -X dL X L^{H}.
1994 // If X = (U^H U)^{-1} with U upper-triangular with a real positive diagonal,
1995 // then K becomes
1996 // K = -X dU^H X U
1997 Tensor cholesky_inverse_jvp(
1998     const Tensor& F,
1999     const Tensor& dF,
2000     const Tensor& X,
2001     bool upper) {
2002   at::NoTF32Guard disable_tf32;
2003   const auto CF = upper ? F : F.mH();
2004   const auto dCF = upper ? dF.mH() : dF;
2005   const auto partial_dX = -X.matmul(dCF).matmul(X).matmul(CF);
2006   return partial_dX + partial_dX.mH();
2007 }
2008 
2009 // The formula for forward AD is adapted from
2010 //
2011 // Golub, Gene H., and Victor Pereyra. "The Differentiation of Pseudo-Inverses
2012 // and Nonlinear Least Squares Problems Whose Variables Separate." SIAM Journal
2013 // on Numerical Analysis 10(2). (1973). 413-432. doi: 10.1137/0710036
2014 //
2015 // We present a short derivation below:
2016 // Let Ap := pinv(A), then Ap is the unique matrix such that
2017 //
2018 // Ap A Ap = Ap [1]
2019 // A Ap A = A   [2]
2020 //
2021 // By differentiating [1] we get:
2022 //
2023 // dAp = dAp A Ap + Ap dA Ap + Ap A dAp [3]
2024 //
2025 // In the rhs of [3] the products involving dAp could be expressed as products
2026 // of Ap^i, A^j, dA^k with i, j, k in {1, H}, where X^H = X.mH(). To prove that,
2027 // note (A Ap)^H = A Ap and (Ap A)^H = Ap A, which could be shown by taking the
2028 // product between the SVD decompositions of A and Ap. Consider the
2029 // conjugate-transposed [2]: (A Ap A)^H = A^H (A Ap) = A^H. By differentiating
2030 // it we get: dA^H A Ap + A^H dA Ap + A^H A dAp = dA^H. By multiplying from the
2031 // left by Ap^H and using Ap^H A^H = (A Ap)^H = A Ap: Ap^H dA^H A Ap + A Ap dA
2032 // Ap + A Ap A dAp = Ap^H dA^H. By multiplying from the left by Ap and by
2033 // applying [1] and [2] repeatedly until impossible we get: Ap Ap^H dA^H A Ap +
2034 // Ap dA Ap + Ap A dAp = Ap Ap^H dA^H. By rearranging the terms:
2035 //
2036 // Ap A dAp = -Ap dA Ap + Ap Ap^H dA^H (I - A Ap) [4],
2037 // which is one of the summands in [3].
2038 //
2039 // Similar, by differentiating the transpose-conjugated [2] written differently,
2040 // i.e. (A Ap A)^H = Ap A A^H = A^H we will get an expression for dAp A Ap,
2041 // which is
2042 //
2043 // dAp A Ap = -Ap dA Ap + (I - Ap A) dA^H Ap^H Ap [5].
2044 //
2045 // By plugging in [4] and [5] into [3] we get the forward AD formula for pinv:
2046 //
2047 // dAp = -Ap dA Ap + (I - Ap A) dA^H Ap^H Ap + Ap Ap^H dA^H (I - A Ap).
2048 Tensor pinv_jvp(const Tensor& A, const Tensor& pinvA, const Tensor& dA) {
2049   at::NoTF32Guard disable_tf32;
2050   auto m = A.size(-2);
2051   auto n = A.size(-1);
2052   auto dAh = dA.mH();
2053   auto pinvAh = pinvA.mH();
2054   // optimization to produce matrices of the smallest dimension
2055   if (m <= n) {
2056     auto K = pinvAh.matmul(dAh);
2057     return pinvA.matmul(K - K.mH() - K.matmul(A.matmul(pinvA))) +
2058         (dAh - pinvA.matmul(A.matmul(dAh))).matmul(pinvAh.matmul(pinvA));
2059   } else {
2060     auto K = pinvA.matmul(dA);
2061     auto Kh = K.mH();
2062     return (Kh - K - pinvA.matmul(A).matmul(Kh)).matmul(pinvA) +
2063         (pinvA.matmul(pinvAh)).matmul(dAh - (dAh.matmul(A)).matmul(pinvA));
2064   }
2065 }
2066 
2067 Tensor pinv_backward(const Tensor& grad, const Tensor& pinvA, const Tensor& A) {
2068   at::NoTF32Guard disable_tf32;
2069   auto m = A.sym_size(-2);
2070   auto n = A.sym_size(-1);
2071   auto pinvAh = pinvA.mH();
2072   auto gradh = grad.mH();
2073   // optimization to produce matrices of the smallest dimension
2074   if (m <= n) {
2075     auto K = gradh.matmul(pinvA);
2076     auto KpinvAh = K.matmul(pinvAh);
2077     return -(pinvA.matmul(K)).mH() + KpinvAh -
2078         (A.matmul(pinvA)).matmul(KpinvAh) +
2079         (pinvAh.matmul(pinvA)).matmul(gradh - K.matmul(A));
2080   } else {
2081     auto K = pinvA.matmul(gradh);
2082     auto pinvAhK = pinvAh.matmul(K);
2083     return -(K.matmul(pinvA)).mH() +
2084         (gradh - A.matmul(K)).matmul(pinvA).matmul(pinvAh) + pinvAhK -
2085         pinvAhK.matmul(pinvA).matmul(A);
2086   }
2087 }
2088 
2089 Tensor chunk_backward_nested(
2090     const std::vector<torch::autograd::Variable>& grads,
2091     const Tensor& self,
2092     int64_t chunks,
2093     int64_t dim) {
2094   TORCH_INTERNAL_ASSERT(
2095       self.layout() == c10::kJagged,
2096       "Nested Strided Tensor doesn't support chunk backward.")
2097   dim = at::maybe_wrap_dim(dim, self.dim());
2098   TORCH_INTERNAL_ASSERT(
2099       dim != 0, "Nested Tensor doesn't support chunk backward on dim=0 yet.")
2100   Tensor ret = at::zeros_like(self);
2101   std::vector<Tensor> rets = at::chunk(ret, chunks, dim);
2102   for (const auto j : c10::irange(grads.size())) {
2103     if (grads[j].defined()) {
2104       rets[j].copy_(grads[j]);
2105     }
2106   }
2107   return ret;
2108 }
2109 
2110 Tensor split_with_sizes_backward(
2111     const std::vector<torch::autograd::Variable>& grads,
2112     c10::SymIntArrayRef split_sizes,
2113     int64_t dim,
2114     c10::SymIntArrayRef sizes,
2115     const at::TensorOptions& options) {
2116   dim = at::maybe_wrap_dim(dim, static_cast<int64_t>(sizes.size()));
2117 
2118   // it's possible some of the grads are not defined (they represent tensors
2119   // of all 0s). Since at::cat can't handle those, let's define them
2120   std::vector<Tensor> grads_all_defined(grads.size());
2121   for (const auto j : c10::irange(grads.size())) {
2122     if (grads[j].defined()) {
2123       grads_all_defined[j] = grads[j];
2124     } else {
2125       const auto& length = split_sizes[j];
2126       auto grad_size = sizes.vec();
2127       grad_size[dim] = length;
2128       grads_all_defined[j] = at::zeros_symint(grad_size, options);
2129     }
2130   }
2131 
2132   auto ret = at::cat(grads_all_defined, dim);
2133   return ret;
2134 }
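// Illustrative scenario (hypothetical values) in which undefined grads occur:
//   x = torch.randn(5, requires_grad=True)
//   a, b = torch.split(x, [2, 3])
//   a.sum().backward()
// Only `a` receives a gradient, so the missing gradient for `b` is
// materialized above as zeros of shape (3,) before concatenating back to (5,).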
2135 
2136 Tensor _nested_split_with_sizes_backward(
2137     const std::vector<torch::autograd::Variable>& grads,
2138     c10::SymIntArrayRef split_sizes,
2139     int64_t dim,
2140     const Tensor& nt_sizes,
2141     const at::TensorOptions& options) {
2142   // add 1 to account for batch dim
2143   dim = at::maybe_wrap_dim(dim, static_cast<int64_t>(nt_sizes.size(1)) + 1);
2144   // it's possible some of the grads are not defined (they represent tensors
2145   // of all 0s). Since at::cat can't handle those, let's define them
2146   std::vector<Tensor> grads_all_defined;
2147   for (int64_t i : c10::irange(static_cast<int64_t>(grads.size()))) {
2148     if (grads[i].defined()) {
2149       grads_all_defined.push_back(static_cast<Tensor>(grads[i]));
2150     } else {
2151       const auto& length = split_sizes[i].guard_int(__FILE__, __LINE__);
2152       auto nt_split_size = nt_sizes.clone();
2153       auto nt_split_size_ptr = nt_split_size.data_ptr<int64_t>();
2154       for (int64_t j : c10::irange(static_cast<int64_t>(nt_sizes.size(0)))) {
2155         // subtract 1 to account for batch dim
2156         nt_split_size_ptr
2157             [j * static_cast<int64_t>(nt_sizes.size(1)) + (dim - 1)] = length;
2158       }
2159       Tensor zeros_buffer = at::zeros(
2160           {at::native::get_numel_from_nested_size_tensor(nt_split_size)},
2161           options);
2162       auto nt_split_grad = at::native::wrap_buffer(zeros_buffer, nt_split_size);
2163       grads_all_defined.push_back(nt_split_grad);
2164     }
2165   }
2166 
2167   auto ret = at::cat(grads_all_defined, dim);
2168   return ret;
2169 }
2170 
2171 Tensor split_backward(
2172     const std::vector<torch::autograd::Variable>& grads,
2173     const c10::SymInt& split_size,
2174     int64_t dim,
2175     c10::SymIntArrayRef sym_sizes,
2176     const at::TensorOptions& options) {
2177   dim = at::maybe_wrap_dim(dim, static_cast<int64_t>(sym_sizes.size()));
2178   const auto& dim_size = sym_sizes[dim];
2179   auto num_splits = grads.size();
2180   std::vector<c10::SymInt> split_sizes(num_splits, split_size);
2181   split_sizes[num_splits - 1] =
2182       split_size - (split_size * num_splits - dim_size);
2183   return split_with_sizes_backward(grads, split_sizes, dim, sym_sizes, options);
2184 }
2185 
2186 Tensor max_pool_double_backward(
2187     const Tensor& grad,
2188     const Tensor& indices,
2189     int dim) {
2190   AT_ASSERT(indices.dim() >= dim);
2191   // handle non-empty inputs
2192   if (indices.sym_numel() != 0) {
2193     auto size = indices.sym_sizes().slice(0, indices.dim() - dim).vec();
2194     size.emplace_back(-1);
2195     auto indices_view = indices.view_symint(size);
2196     const auto memory_format = indices.suggest_memory_format();
2197     return grad.contiguous(memory_format)
2198         .view_symint(size)
2199         .gather(-1, indices_view)
2200         .view_symint(indices.sym_sizes());
2201   }
2202   // handle empty inputs
2203   else {
2204     return at::empty_like(indices, grad.options());
2205   }
2206 }
2207 
2208 Tensor error_for_max_pool2d_double_backward() { // This is mps-only.
2209   TORCH_CHECK(
2210       false,
2211       "max_pool2d with `return_indices=False` is not infinitely differentiable.",
2212       " If you want to calculate higher order derivatives, e.g. second order,",
2213       " set `return_indices=True`.");
2214   return Tensor();
2215 }
2216 
2217 Tensor glu_double_backward(
2218     const Tensor& grad,
2219     const Tensor& grad_output,
2220     const Tensor& input,
2221     int64_t dim) {
2222   auto& gO = grad_output;
2223   auto input_size = input.size(dim) / 2;
2224   auto first_half = input.narrow(dim, 0, input_size);
2225   auto second_half = input.narrow(dim, input_size, input_size);
2226   auto sig_second_half = second_half.sigmoid();
2227   auto one_sub_sig_second_half = 1 - sig_second_half;
2228   auto sig_one_sub_sig = sig_second_half * one_sub_sig_second_half;
2229 
2230   auto ggI_first_half = grad.narrow(dim, 0, input_size);
2231   auto ggI_second_half = grad.narrow(dim, input_size, input_size);
2232   auto ggI_second_half_times_first_half = ggI_second_half * first_half;
2233 
2234   auto gI_first_half = ggI_second_half * gO * sig_one_sub_sig;
2235   auto second_order_sh = sig_one_sub_sig * one_sub_sig_second_half -
2236       sig_second_half * sig_one_sub_sig;
2237   auto gI_second_half =
2238       ggI_second_half_times_first_half * gO * second_order_sh +
2239       ggI_first_half * gO * sig_one_sub_sig;
2240   return at::cat({std::move(gI_first_half), std::move(gI_second_half)}, dim);
2241 }
2242 
2243 Tensor glu_double_backward_grad_output(
2244     const Tensor& grad,
2245     const Tensor& input,
2246     int64_t dim) {
2247   if (dim < 0)
2248     dim += input.dim();
2249   auto sizes = input.sizes().vec();
2250   sizes[dim] /= 2;
2251   auto tmp = grad * glu_backward(at::ones(sizes, input.options()), input, dim);
2252   return tmp.narrow(dim, 0, sizes[dim]) +
2253       tmp.narrow(dim, sizes[dim], sizes[dim]);
2254 }
2255 
2256 Tensor infinitely_differentiable_silu_backward(
2257     const Tensor& grad_output,
2258     const Tensor& input) {
2259   const Tensor sigmoid = input.sigmoid();
2260   return grad_output * sigmoid * (1.0 + input * (1.0 - sigmoid));
2261 }
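// Sketch of the formula above: for silu(x) = x * sigmoid(x),
//   d silu / dx = sigmoid(x) * (1 + x * (1 - sigmoid(x))),
// written entirely in differentiable ops so that it can itself be
// differentiated again for double backward.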
2262 
2263 Tensor infinitely_differentiable_mish_backward(
2264     const Tensor& grad_output,
2265     const Tensor& input) {
2266   const Tensor sigmoid = input.sigmoid();
2267   const Tensor softplus = input.exp().log1p();
2268   const Tensor tanh_softplus = softplus.tanh();
2269   return grad_output *
2270       (tanh_softplus + input * sigmoid * (1.0 - tanh_softplus * tanh_softplus));
2271 }
2272 
2273 Tensor infinitely_differentiable_logit_backward(
2274     const Tensor& grad,
2275     const Tensor& self,
2276     std::optional<double> eps) {
2277   if (eps) {
2278     const double lo = eps.value();
2279     const double hi = 1.0 - lo;
2280     return at::where(
2281         at::logical_and(self >= lo, self <= hi),
2282         grad / (self * (1.0 - self)),
2283         at::zeros({}, self.options()));
2284   } else {
2285     return at::where(
2286         at::logical_and(self >= 0.0, self <= 1.0),
2287         grad / (self * (1.0 - self)),
2288         at::empty({}, self.options())
2289             .fill_(std::numeric_limits<double>::quiet_NaN()));
2290   }
2291 }
2292 
2293 Tensor binary_cross_entropy_target_backward(
2294     const Tensor& grad,
2295     const Tensor& self,
2296     const Tensor& target,
2297     const std::optional<Tensor>& weight,
2298     int64_t reduction) {
2299   auto grad_target = at::logit(self).neg_();
2300 
2301   if (!areAnyTensorSubclassLike({grad})) {
2302     grad_target.mul_(grad);
2303   } else {
2304     grad_target = grad_target * grad;
2305   }
2306 
2307   if (isDefined(weight)) {
2308     // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2309     if (!isTensorSubclassLike(weight.value())) {
2310       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2311       grad_target.mul_(weight.value());
2312     } else {
2313       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2314       grad_target = grad_target * weight.value();
2315     }
2316   }
2317 
2318   if (reduction == at::Reduction::Mean) {
2319     grad_target.div_(target.sym_numel());
2320   }
2321 
2322   return grad_target;
2323 }
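// Sketch of the formula above: with
//   loss = -w * (target * log(self) + (1 - target) * log(1 - self)),
// d loss / d target = -w * (log(self) - log(1 - self)) = -w * logit(self),
// hence grad_target = -logit(self) * grad, times weight, divided by the number
// of elements for the 'mean' reduction.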
2324 
2325 Tensor binary_cross_entropy_double_backward_target(
2326     const Tensor& grad,
2327     const Tensor& grad_output,
2328     const Tensor& self,
2329     const Tensor& target,
2330     const std::optional<Tensor>& weight,
2331     int64_t reduction) {
2332   auto res = -grad * grad_output;
2333 
2334   if (isDefined(weight)) {
2335     // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2336     res = isTensorSubclassLike(weight.value())
2337         // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2338         ? res.mul(weight.value())
2339         // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2340         : res.mul_(weight.value());
2341   }
2342 
2343   auto neg_self = 1 - self;
2344   auto denom =
2345       isTensorSubclassLike(self) ? neg_self.mul(self) : neg_self.mul_(self);
2346   {
2347     at::NoGradGuard guard;
2348     // Default eps in binary_cross_entropy for ALL dtypes
2349     // TODO: probably change this to a dtype-dependent value
2350     double eps = 1e-12;
2351     denom.clamp_min_(eps);
2352   }
2353 
2354   res = isTensorSubclassLike(denom) ? res.div(denom) : res.div_(denom);
2355 
2356   if (reduction == at::Reduction::Mean) {
2357     res.div_(target.sym_numel());
2358   }
2359 
2360   return res;
2361 }
2362 
2363 Tensor binary_cross_entropy_with_logits_backward(
2364     const Tensor& grad,
2365     const Tensor& input,
2366     const Tensor& target,
2367     const std::optional<Tensor>& weight,
2368     const std::optional<Tensor>& pos_weight,
2369     int64_t reduction) {
2370   // Trivial case
2371   if (grad._is_zerotensor()) {
2372     return at::_efficientzerotensor(input.sizes(), input.options());
2373   }
2374 
2375   // -w * [ pos * y * (1 - sigmoid(x)) - (1 - y) * sigmoid(x) ] * grad
2376 
2377   // If there are subclassed tensors use the out of place version
2378   Tensor grad_input;
2379   if (isDefined(pos_weight)) {
2380     // pos_weight might need to be broadcasted, thus mul(target) is not inplace.
2381     // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2382     auto t = pos_weight->mul(target);
2383     grad_input = at::areAnyTensorSubclassLike({input, target}) ||
2384             at::GradMode::is_enabled()
2385         ? t.add(1).sub(target).mul(input.sigmoid()).sub(t)
2386         : t.add(1).sub_(target).mul_(input.sigmoid()).sub_(t);
2387   } else {
2388     grad_input = at::areAnyTensorSubclassLike({input, target}) ||
2389             at::GradMode::is_enabled()
2390         ? input.sigmoid().sub(target)
2391         : input.sigmoid().sub_(target);
2392   }
2393 
2394   if (at::isTensorSubclassLike(grad) || at::GradMode::is_enabled()) {
2395     grad_input = grad_input.mul(grad);
2396   } else {
2397     grad_input.mul_(grad);
2398   }
2399 
2400   if (isDefined(weight)) {
2401     // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2402     if (at::isTensorSubclassLike(*weight) || at::GradMode::is_enabled()) {
2403       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2404       grad_input = grad_input.mul(*weight);
2405     } else {
2406       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2407       grad_input.mul_(*weight);
2408     }
2409   }
2410 
2411   if (reduction == at::Reduction::Mean) {
2412     grad_input.div_(input.sym_numel());
2413   }
2414 
2415   return grad_input;
2416 }
2417 
2418 Tensor binary_cross_entropy_with_logits_target_backward(
2419     const Tensor& grad_output,
2420     const Tensor& self,
2421     const Tensor& target,
2422     const std::optional<Tensor>& weight,
2423     const std::optional<Tensor>& pos_weight,
2424     int64_t reduction) {
2425   if (grad_output._is_zerotensor()) {
2426     return at::_efficientzerotensor(target.sizes(), target.options());
2427   }
2428 
2429   Tensor grad_target;
2430   if (isDefined(pos_weight)) {
2431     // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2432     if (areAnyTensorSubclassLike({*pos_weight, grad_output})) {
2433       grad_target = at::log_sigmoid(-self)
2434                         // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2435                         .sub(at::log_sigmoid(self).mul(*pos_weight))
2436                         .mul(grad_output);
2437     } else {
2438       grad_target = at::log_sigmoid(-self)
2439                         // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2440                         .sub_(at::log_sigmoid(self).mul_(*pos_weight))
2441                         .mul_(grad_output);
2442     }
2443   } else {
2444     grad_target = -self * grad_output;
2445   }
2446 
2447   if (isDefined(weight)) {
2448     // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2449     if (at::isTensorSubclassLike(*weight)) {
2450       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2451       grad_target = grad_target.mul(*weight);
2452     } else {
2453       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2454       grad_target.mul_(*weight);
2455     }
2456   }
2457 
2458   if (reduction == at::Reduction::Mean) {
2459     grad_target.div_(target.sym_numel());
2460   }
2461 
2462   return grad_target;
2463 }
2464 
2465 Tensor log_sigmoid_double_backward(const Tensor& grad, const Tensor& input) {
2466   auto z = input.sigmoid();
2467   return grad * (z - 1) * z;
2468 }
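// Sketch: d/dx log(sigmoid(x)) = 1 - sigmoid(x), so the second derivative is
// -sigmoid(x) * (1 - sigmoid(x)), i.e. (z - 1) * z with z = sigmoid(x).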
2469 
2470 Tensor softmax_double_backward(
2471     const Tensor& grad,
2472     const Tensor& grad_output,
2473     int dim,
2474     const Tensor& output) {
2475   return grad_output * grad - (output * grad_output).sum(dim, true) * grad -
2476       grad_output * (output * grad).sum(dim, true);
2477 }
2478 
2479 // NOTE: [How to write vmap-compatible backward formulas]
2480 //
2481 // See NOTE: [vmap-incompatible in-place operations] for what it means for an
2482 // in-place operation to be incompatible with vmap.
2483 //
2484 // If an in-place operation used in a backward formula is vmap-incompatible,
2485 // then as developers we have the following options:
2486 //
2487 // - If the in-place operation directly followed the creation of a tensor with
2488 //   a factory function like at::zeros(...), we should replace the factory with
2489 //   a corresponding grad.new_zeros(...) call. The grad.new_zeros(...) call
2490 //   propagates the batch dims to the resulting tensor.
2491 //   For example:
2492 //     Before: at::zeros(input.sizes(), grad.options()).copy_(grad)
2493 //     After:  grad.new_zeros(input.sizes()).copy_(grad)
2494 //
2495 // - If the in-place operation followed some sequence of operations, and we
2496 //   want to be able to vmap over the backward formula as-is (this is
2497 //   usually the case for simple (<15 LOC) backward formulas), then use
2498 //   areAnyTensorSubclassLike to guard the operation. For example:
2499 //             c = a * b
2500 //     Before: c.mul_(grad)
2501 //     After:  c = !areAnyTensorSubclassLike({c, grad}) ? c.mul_(grad) : c *
2502 //     grad
2503 //
2504 // - If we don't want to vmap directly over the backward formula (e.g., if the
2505 //   backward formula is too complicated or has a lot of vmap-incompatible
2506 //   operations), then register the backward formula as an operator and
2507 //   eventually write a batching rule for it.
2508 
2509 Tensor binary_cross_entropy_double_backward(
2510     const Tensor& grad_output,
2511     const Tensor& grad,
2512     const Tensor& input,
2513     const Tensor& target,
2514     const std::optional<Tensor>& weight,
2515     int64_t reduction) {
2516   auto eps = 1e-12;
2517   auto inp_pl_eps = input + eps;
2518   auto one_m_inp_pl_eps = 1 - input + eps;
2519   // gradient wrt input
2520   auto gI = (input * input - 2 * input * target + target) /
2521       (inp_pl_eps.pow(2) * one_m_inp_pl_eps.pow(2));
2522   if (!areAnyTensorSubclassLike({gI, grad})) {
2523     gI *= (grad * grad_output);
2524   } else {
2525     gI = gI * (grad * grad_output);
2526   }
2527 
2528   if (isDefined(weight)) {
2529     // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2530     if (!isTensorSubclassLike(*weight)) {
2531       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2532       gI *= *weight;
2533     } else {
2534       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2535       gI = gI.mul(*weight);
2536     }
2537   }
2538   if (reduction == at::Reduction::Mean) {
2539     return gI / input.sym_numel();
2540   }
2541 
2542   return gI;
2543 }
2544 
2545 Tensor binary_cross_entropy_double_backward_grad_output(
2546     const Tensor& grad,
2547     const Tensor& input,
2548     const Tensor& target,
2549     const std::optional<Tensor>& weight,
2550     int64_t reduction) {
2551   auto eps = 1e-12;
2552   // gradient wrt grad_output
2553   auto ggO = (input - target) / ((input + eps) * (1 - input + eps));
2554   if (!areAnyTensorSubclassLike({ggO, grad})) {
2555     ggO *= grad;
2556   } else {
2557     ggO = ggO * grad;
2558   }
2559 
2560   if (isDefined(weight)) {
2561     // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2562     if (!isTensorSubclassLike(*weight)) {
2563       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2564       ggO *= *weight;
2565     } else {
2566       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
2567       ggO = ggO.mul(*weight);
2568     }
2569   }
2570   if (reduction == at::Reduction::Mean) {
2571     return ggO / input.sym_numel();
2572   }
2573   return ggO;
2574 }
2575 
2576 Tensor smooth_l1_loss_double_backward(
2577     const Tensor& grad,
2578     const Tensor& input,
2579     const Tensor& target,
2580     int64_t reduction,
2581     double beta) {
2582   // special case to protect against a divide-by-zero.
2583   if (beta == 0) {
2584     return at::zeros(grad.sizes(), grad.options());
2585   }
2586   auto d = (input - target).abs();
2587   auto grad_input = grad * (d < beta).type_as(grad) / beta;
2588   if (reduction == at::Reduction::Mean) {
2589     grad_input /= input.sym_numel();
2590   }
2591   return grad_input;
2592 }
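// Sketch: the first derivative of smooth_l1 wrt input is (input - target) / beta
// for |input - target| < beta and +/-1 outside, so the second derivative is
// 1 / beta in the quadratic region and 0 elsewhere, as computed above.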
2593 
2594 Tensor huber_loss_double_backward(
2595     const Tensor& grad,
2596     const Tensor& input,
2597     const Tensor& target,
2598     int64_t reduction,
2599     double delta) {
2600   auto d = (input - target).abs();
2601   auto grad_input = grad * (d < delta);
2602   if (reduction == at::Reduction::Mean) {
2603     grad_input /= input.sym_numel();
2604   }
2605   return grad_input;
2606 }
2607 
2608 Tensor huber_loss_double_backward_grad_output(
2609     const Tensor& grad,
2610     const Tensor& grad_output,
2611     const Tensor& input,
2612     const Tensor& target,
2613     int64_t reduction,
2614     double delta) {
2615   if (reduction == at::Reduction::None) {
2616     return huber_loss_backward(grad, input, target, reduction, delta);
2617   }
2618   auto r = huber_loss_backward(
2619       ones_like(grad_output), input, target, reduction, delta);
2620   return (r * grad).sum();
2621 }
2622 
2623 Tensor mse_loss_double_backward(
2624     const Tensor& grad,
2625     const Tensor& input,
2626     int64_t reduction) {
2627   auto grad_input = 2 * grad;
2628   if (reduction == at::Reduction::Mean) {
2629     grad_input /= input.sym_numel();
2630   }
2631   return grad_input;
2632 }
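// Sketch: mse = (input - target)^2 has constant second derivative 2 wrt input,
// hence the 2 * grad above (divided by numel for the 'mean' reduction).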
2633 
2634 Tensor soft_margin_loss_double_backward(
2635     const Tensor& grad,
2636     const Tensor& input,
2637     const Tensor& target,
2638     int64_t reduction) {
2639   auto z = (input * -target).exp();
2640   auto zplus1 = z + 1;
2641   auto grad_input = grad * (target * target) * z / (zplus1 * zplus1);
2642   if (reduction == at::Reduction::Mean) {
2643     grad_input /= input.sym_numel();
2644   }
2645   return grad_input;
2646 }
2647 
2648 Tensor soft_margin_loss_double_backward_grad_output(
2649     const Tensor& grad,
2650     const Tensor& grad_output,
2651     const Tensor& input,
2652     const Tensor& target,
2653     int64_t reduction) {
2654   if (reduction == at::Reduction::None) {
2655     return soft_margin_loss_backward(grad, input, target, reduction);
2656   }
2657   auto r = soft_margin_loss_backward(
2658       ones_like(grad_output), input, target, reduction);
2659   return (r * grad).sum();
2660 }
2661 
2662 Tensor softplus_double_backward(
2663     const Tensor& grad,
2664     const Tensor& input,
2665     const Scalar& beta,
2666     const Scalar& threshold) {
2667   auto x = (input * beta);
2668   return sigmoid_backward(grad, x.sigmoid()) * (x < threshold).type_as(grad) *
2669       beta;
2670 }
2671 
2672 // NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
2673 //
2674 // `storage_offset` is ignored for simplicity in this note. If you just want the
2675 // full algorithm without explanation, scroll down to bottom of this note.
2676 //
2677 // Implementing the backward of as_strided is tricky because you have to deal
2678 // with mappings that map one memory location to multiple indices, i.e., the
2679 // output tensor has multiple indices pointing to **overlapping** memory
2680 // addresses. This can happen in all sorts of weird cases. For example,
2681 //
2682 //   x = torch.randn(15)
2683 //   x.as_strided([3, 3], [1, 0])  # "expand" case
2684 //   x.as_strided([3, 3], [2, 1])  # "size too large" case
2685 //   x.as_strided([3, 2], [3, 6])  # res[2, 0] points to 2*3 + 0*6 = 6
2686 //                                 # res[0, 1] points to 0*3 + 1*6 = 6
2687 //
2688 // Here is the general strategy we apply in implementing as_strided backward:
2689 //   0. ??? (optimization step. we will talk about this later)
2690 //   1. Create some underlying flattened tensor as if it is the base tensor
2691 //      representing the contiguous memory storage for both input and output.
2692 //   2. Use the output geometry to scatter (or index_add) the gradients into
2693 //      this storage tensor.
2694 //   3. ??? (fix for input tensor with overlapping memory. we will talk about
2695 //           this later)
2696 //   4. Return the as_strided view of the storage tensor using input geometry.
2697 //
2698 // In step (2), if the output tensor doesn't have overlapping memory, we can
2699 // safely scatter (`storage.as_strided(output_geometry).copy_(grad)`);
2700 // otherwise, we must use `index_add` as gradients at different indices may need
2701 // to be summed to a single location.
2702 //
2703 // For example, in this case:
2704 //
2705 //   x = torch.randn(3)
2706 //   y = x.as_strided([3, 3], [1, 0])  # "expand" case
2707 //                                     # size   [ 3, 3]
2708 //                                     # stride [ 1, 0]
2709 //   y.backward()  # step (1): contiguous storage tensor `s` of size 3, which
2710 //                             is large enough to be used as underlying storage
2711 //                             for `x` and `y`.
2712 //                               s = [ 0, 0, 0]
2713 //                 # step (2): since `y` has overlapping memory, index_add grad
2714 //                             into `s` based on `y`'s geometry, i.e.,
2715 //                             s[i * y.stride(0) + j * y.stride(1)] += gy[i, j].
2716 //                               s = [ 3, 3, 3]
2717 //                 # step (4): as_strided view `s` using `x`'s geometry
2718 //                               s = [ 3, 3, 3]
2719 //                               grad_input = s.as_strided(x.size(), x.stride())
2720 //                                          = s.as_strided([3], [1])
2721 //                                          = [ 3, 3, 3]
2722 //
2723 // This is exactly what we would get if using `expand`. However, here the input
2724 // tensor doesn't have overlapping memory. If it does, we must add an extra step
2725 // before (4). Considering this case:
2726 //
2727 //   t = torch.randn(3)
2728 //   x = t.expand(3, 3)            # input with overlapping memory
2729 //                                 # size   [3, 3]
2730 //                                 # stride [0, 1]
2731 //   y = x.as_strided([1], [1])    # contiguous output
2732 //                                 # size   [1]
2733 //                                 # stride [1]
2734 //   y.backward()  # step (1): contiguous storage tensor `s` of size 3, which
2735 //                             is large enough to be used as underlying storage
2736 //                             for `x` and `y`.
2737 //                               s = [ 0, 0, 0]
2738 //                 # step (2): scatter grad into `s` based on `y`'s geometry
2739 //                               s = [ 1, 0, 0]
2740 //                 # step (4): as_strided view `s` using `x`'s geometry
2741 //                               s = [ 1, 0, 0]
2742 //                               grad_input = s.as_strided(x.size(), x.stride())
2743 //                                          = s.as_strided([3, 3], [0, 1])
2744 //                                          = [[ 1, 0, 0],
2745 //                                             [ 1, 0, 0],
2746 //                                             [ 1, 0, 0]]
2747 // Is this result correct?
2748 //
2749 // The `x.as_strided([1], [1])` call is obviously equivalent to
2750 // `x[(0,) * x.dim()].view(1)` for any `x`. But autograd through the second
2751 // gives gradient `[ [ 1, 0, 0], [ 0, 0, 0], [ 0, 0, 0]]`. For this specific
2752 // case, indexing `x` at any index in the first column is also equivalent, and
2753 // yields a gradient of shape `[3 x 3]` containing eight 0's and one 1. There is
2754 // an `x.size(1)`-times difference between these gradients computed from other
2755 // PyTorch ops and the gradient we got from as_strided.
2756 //
2757 // You might conclude that the gradients from as_strided are wrong. However,
2758 // let's first see why they are actually reasonable. Consider the pointwise
2759 // perturbations by `delta` anywhere in the first column of `x`. It will lead to
2760 // a `delta` change in the same memory location, and then `y` will change by
2761 // `delta`. So one can say the gradient should be exactly 1 at every entry of
2762 // the first column, as given by our above procedure.
2763 //
2764 // The numerical gradients computed above only match these analytical results
2765 // because strides and memory locations are considered in the
2766 // forward pass, i.e., this op (including both forward and backward) is
2767 // layout-aware.
2768 //
2769 // However, in PyTorch, most (probably all) other ops (forward and backward) are
2770 // layout-agnostic. E.g.,
2771 //
2772 //   t = torch.randn(1)
2773 //   x = t.expand(2)
2774 //   y = x.sum()
2775 //   y.backward()
2776 //
2777 // Layout-agnostic autograd (as it is currently in PyTorch) will give you
2778 //
2779 //   gy = 1
2780 //   gx = [ 1, 1]  # SumBackward:    torch.ones_like(x)
2781 //   gt = [ 2]     # ExpandBackward: gx.sum()
2782 //
2783 // Note that `gx = [ 1, 1]`. However, if you perturb any value in `x` by `delta`
2784 // (the other will also change by `delta`), `y` will change by `2 * delta`. So
2785 // the gradients, if strides are taken into consideration, should be 2.
2786 //
2787 // Layout-aware autograd should give you
2788 //
2789 //   gy = 1
2790 //   gx = [ 2, 2]  # Because the backward considers the fact that the input `x`
2791 //                 # is already expanded.
2792 //   gt = [ 2]     # Layout-aware backward of expand is just a slicing because
2793 //                 # the previous backward should have already taken care of
2794 //                 # strides and made sure that gradients are the same along the
2795 //                 # expanded dimension.
2796 //
2797 // As shown above, these two types are not compatible. Therefore, we must either
2798 // make as_strided layout-agnostic, or make all other ops layout-aware.
2799 //
2800 // It is difficult to support layout-aware autograd (at least in the current
2801 // codebase structure), because it would mean
2802 //   1. storing tensor geometries of every input tensor for backward
2803 //   2. depending on input geometry, the gradient computed from backward changes
2804 //   3. ideally enforcing gradient of T to always have same strides as T
2805 // (although layout-aware and layout-agnostic autograd only differ when it comes to overlapping memory)
2806 //
2807 // Therefore, we must formulate `as_strided` in a layout-agnostic way, i.e.,
2808 // giving the same output regardless of the input layout. We consider
2809 // `input.stride()` as a separate independent fixed argument `input_stride`.
2810 // Then, `as_strided(input, size, stride)` can be thought of as:
2811 //   1. "Scatter" each value of `input` into a "storage" using storage location
2812 //      computed from the value's index in `input`, `input.size()` and
2813 //      `input_stride`, but if N values end up in the same location, the value
2814 //      is average of those N values (they will be the same value anyways).
2815 //
2816 //      Formal description:
2817 //        Denote the set of all input indices that point to the same storage
2818 //        location `storage[n]` as `S(n)`, i.e.,
2819 //
2820 //            S(n) = { index : <index, input_stride> == n, index is valid given
2821 //            input.size() },
2822 //
2823 //        where `<x, y>` is the dot product between `x` and `y`.
2824 //
2825 //        Then, the process is:
2826 //
2827 //            storage[n] = Avg { S(n) }
2828 //
2829 //        Note that all values in `S(n)` are the same (they point to the same
2830 //        memory location anyway), so this step doesn't change anything, but it
2831 //        effectively removes the dependency on the layout of `input`. I.e., the
2832 //        result holds fixed regardless of the layout of `input`, as long as
2833 //        `input_stride` is fixed. (A small worked `S(n)` example follows this list.)
2834 //
2835 //      NOTE: for forward pass, we can equivalently simply select any one of
2836 //            `S(n)` as `storage[n]`. However, considering this as an average
2837 //            operation makes backward easier (so all values in set
2838 //            `{ grad_input[i] : i in S(n) }` are the same, and it can use the
2839 //            same geometry as input).
2840 //   2. As usual, return the as_strided view of `storage` using required output
2841 //      `size` and `stride`.
2842 //
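//      A small worked example of `S(n)` (the expanded input from the second
//      example above): for input.size() = [3, 3] and input_stride = [0, 1],
//      index (i, j) maps to storage location <(i, j), (0, 1)> = j, so
//          S(0) = { (0, 0), (1, 0), (2, 0) },  S(1) = { (0, 1), (1, 1), (2, 1) },
//          S(2) = { (0, 2), (1, 2), (2, 2) },
//      and |S(n)| = 3 for every n. Step (3) of the backward described next
//      therefore divides each storage location by 3.
//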
2843 // To backward through this layout-agnostic version, we simply add the following
2844 // step:
2845 //   .... (scatter gradients into the storage tensor using output geometry)
2846 //   3. For all storage location n, `storage[n] /= |S(n)|`.
2847 //   .... (return as_strided view of the storage tensor using input geometry)
2848 //
2849 // Finally, we note that these general operations are expensive, so we apply the
2850 // following optimizations:
2851 //   Add step (0): For each output dimension `d` with output stride 0, sum the
2852 //                 gradients along dimension `d` (don't keepdim), and remove
2853 //                 dimension `d` from output size and stride.
2854 //                 (An optimization for "expand" cases so we may avoid step (3))
2855 //   Only apply step (3) when the input tensor has overlapping memory.
2856 //
2857 // FULL ALGORITHM:
2858 //   0. For each output dimension `d` with output stride 0, sum the gradients
2859 //       along dimension `d` (don't keepdim), and remove dimension `d` from
2860 //       output size and stride.
2861 //   1. Create some underlying flattened tensor as if it is the base tensor
2862 //      representing the contiguous memory storage for both input and output.
2863 //   2. Use the output geometry to scatter (or index_add) the gradients into
2864 //      this storage tensor `storage`.
2865 //   3. If input tensor has overlapping memory,
2866 //      For all storage location `i`, `storage[i] /= N(i)`, where `N(i)` is the
2867 //      number of indices in input geometry pointing to the same storage
2868 //      location `i` (i.e., `|S(i)|` in equations above).
2869 //   4. Return the as_strided view of the storage tensor using input geometry.
2870 //
2871 // See NOTE [ Detecting Memory Overlap Within A Strided Tensor ] on how to
2872 // roughly detect overlapping memory.
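//
// For illustration only, here is a rough Python sketch of the FULL ALGORITHM
// above (assuming storage offsets of 0 and no zero-sized dimensions; the name
// `as_strided_backward_sketch` is made up for this note, the real
// implementation is `as_strided_backward` further below):
//
//   import torch
//
//   def as_strided_backward_sketch(grad, inp_size, inp_stride, out_size, out_stride):
//       out_size, out_stride = list(out_size), list(out_stride)
//       # Step (0): drop size-1 dims, sum grad over expanded (stride-0) dims
//       for d in reversed(range(len(out_size))):
//           if out_size[d] == 1:
//               grad = grad.squeeze(d)
//           elif out_stride[d] == 0:
//               grad = grad.sum(d)
//           else:
//               continue
//           del out_size[d], out_stride[d]
//       # Step (1): flat "storage" tensor big enough for both geometries
//       n = 1 + max(sum((s - 1) * st for s, st in zip(inp_size, inp_stride)),
//                   sum((s - 1) * st for s, st in zip(out_size, out_stride)))
//       storage = grad.new_zeros(n)
//       idx = torch.arange(n)
//       # Step (2): scatter (index_add) grad into storage using output geometry
//       out_idx = idx.as_strided(out_size, out_stride).reshape(-1)
//       storage.index_add_(0, out_idx, grad.reshape(-1))
//       # Step (3): divide storage[i] by |S(i)|, its multiplicity in the *input*
//       #           geometry (only has an effect when the input overlaps)
//       inp_idx = idx.as_strided(inp_size, inp_stride).reshape(-1)
//       count = torch.zeros(n).index_add_(0, inp_idx, torch.ones(inp_idx.numel()))
//       storage = storage / count.clamp_min(1)  # real code: only when input may overlap
//       # Step (4): view the storage tensor with the input geometry
//       return storage.as_strided(inp_size, inp_stride)
//
//   # e.g. the "expand" case from the first example in the note:
//   # as_strided_backward_sketch(torch.ones(3, 3), [3], [1], [3, 3], [1, 0])
//   # returns tensor([3., 3., 3.])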
2873 
2874 // NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
2875 //
2876 // Checking memory overlap within a strided tensor is the special case of
2877 // detecting memory overlap of two strided tensors, where the two tensors start
2878 // at the same memory address. The latter is HARD (see #8212).
2879 //
2880 // But even this special case isn't simple. This note describes a check for an
2881 // even more constrained simple case where we can be certain that there is no
2882 // overlap.
2883 //
2884 // The checking algorithm can be described as:
2885 //   0. Return [ pass check ] if any dimension has size 0
2886 //   1. Ignore all dimensions that have size 1
2887 //   2. If no remaining dimensions, return [ pass check ]
2888 //   3. Sort the remaining dimensions by stride in decreasing order
2889 //   4. Check that for each dimension k,
2890 //
2891 //           stride[k] > \sum_{ i > k } (size[i] - 1) * stride[i]
2892 //
2893 //      That is equivalent to, after reordering the dimensions so strides are
2894 //      in decreasing order, checking that stride of each dimension is larger
2895 //      than the maximum memory offset in a slice at that dimension.
2896 //
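//
// For illustration, a small Python sketch of the check above (it mirrors
// `_maybe_overlapping_memory` further below; sorting ascending and scanning
// from the smallest stride up is equivalent to steps (3)~(4) as stated):
//
//   def maybe_overlapping(sizes, strides):
//       if any(s == 0 for s in sizes):
//           return False                      # step (0): no elements, pass
//       dims = [(st, sz) for st, sz in zip(strides, sizes) if sz != 1]  # step (1)
//       if not dims:                          # step (2)
//           return False
//       dims.sort()                           # step (3), ascending by stride
//       max_index_in_slice = 0
//       for st, sz in dims:                   # step (4)
//           if st <= max_index_in_slice:
//               return True                   # cannot certify "no overlap"
//           max_index_in_slice += st * (sz - 1)
//       return False                          # invariant holds: no overlap
//
//   # maybe_overlapping([3, 3], [3, 1]) == False   (contiguous)
//   # maybe_overlapping([3, 3], [1, 0]) == True    ("expand" case)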
2897 // Obviously this check passes for contiguous tensors (the dimensions will be
2898 // already sorted with LHS = stride[0] = \prod_{i > 0} size[i] being exactly 1
2899 // larger than RHS). Similarly, the check passes for tensors contiguous in all
2900 // but the last dimension, with LHS = stride[0] = stride[-1] * \prod_{i > 0}
2901 // size[i] being exactly stride[-1] larger than RHS. (*)
2902 //
2903 // We will show that these view operations, including all our view operations
2904 // *except for* general as_strided and unfold, also preserve this invariant:
2905 //
2906 //  alias:      Obviously preserves
2907 //
2908 //  expand:     All changed dimensions are removed in step (1)
2909 //
2910 //  view:       Consider the input dimensions as grouped into consecutive
2911 //              dimension "blocks", where dimensions are contiguous in each one.
2912 //              view only works when the output dimensions can also be
2913 //              grouped into the same consecutive blocks of same ordering.
2914 //
2915 //              NB: this means that the number of elements and stride of the
2916 //                  last dimension in each block is the same in input and
2917 //                  output. (**)
2918 //
2919 //              Notation:
2920 //                Consider a single such block B,
2921 //                    ... B_prev[-1]], [ B[0], ..., B[i], ..., B[k] = B[-1] ], [
2922 //                    B_next[0], ...
2923 //                                start--^^^^                  ^^^^^^^^^^^^--end
2924 //                Each B[i] denotes a dimension index such that B[i] = B[0] + i.
2925 //
2926 //              We first show that if a tensor (i.e., the input) satisfies the
2927 //              invariant, after sorting, the dimensions within each block
2928 //              still remain consecutive. (***)
2929 //
2930 //                After removing dimensions of size 1, the dimensions within a
2931 //                block are already sorted by strides in descending order. So
2932 //                sorting all dimensions will not change the relative ordering
2933 //                among them.
2934 //
2935 //                Assume that some block B is not consecutive after sorting,
2936 //                i.e., there exists a dimension d (not in B) between B[0] and B[-1] in
2937 //                sorted order.
2938 //
2939 //                By (*), we know that
2940 //                       stride[B[0]]
2941 //                    =  \sum_{i > 0}   (size[B[i]] - 1) * stride[B[i]] + stride[B[-1]]
2942 //                    <  \sum_{i > 0}   (size[B[i]] - 1) * stride[B[i]] + stride[d]
2943 //                    <= \sum_{i > 0}   (size[B[i]] - 1) * stride[B[i]] + (size[d] - 1) * stride[d]
2944 //                    <= \sum_{j > B[0]} (size[j]    - 1) * stride[j],
2947 //
2948 //                where the first <   comes from sorting and
2949 //                      the second <= comes from the fact that dimension d
2950 //                                               exists after step (1) and
2951 //                                               thus must have size greater
2952 //                                               than 1
2953 //                      the third  <= comes from the fact that each term in
2954 //                                               the sum is non-negative
2955 //
2956 //                Then we have a contradiction, as this means the invariant is
2957 //                not satisfied at B[0]. So the original proposition is true.
2958 //
2959 //              Now that we established the above claim (***), we consider the
2960 //              view operation as first sorting the dimensions (i.e., blocks),
2961 //              applying the original view (since it only cares about dimensions
2962 //              being consecutive and contiguous within each block), and then undoing
2963 //              the sort.
2964 //
2965 //              Consider a single block B in the output,
2966 //                  ... ], [ B[0], ..., B[i], ..., B[k] = B[-1] ], [ ...
2967 //                    start--^^^^                  ^^^^^^^^^^^^--end
2968 //
2969 //              By (*), we know that for all i
2970 //                  stride[B[i]] = stride[B[-1]] +
2971 //                                \sum_{j=i+1}^{k} (size[B[j]] - 1) *
2972 //                                stride[B[j]]
2973 //
2974 //              Then the invariant is obviously satisfied at every dimension
2975 //              in this block if it is satisfied at dimension B[-1]. It only
2976 //              remains to show that it is satisfied at the last dimension in
2977 //              each block.
2978 //
2979 //              Since the same blocks are present in both input and output
2980 //              with the same ordering, we will abuse the notation in the
2981 //              following statements.
2982 //
2983 //              By (*), we know that the following holds for both input and
2984 //              output, for any block B:
2985 //                    \sum_{i > B[-1]} (size[i] - 1) * stride[i]
2986 //                  = \sum_{block B' after B} (\prod_{j in B'} size[j] - 1) * stride[B'[-1]]
2987 //                  = \sum_{block B' after B} (numel(B') - 1) * stride[B'[-1]].
2988 //
2989 //              By (**), we know that numel(B') and stride[B'[-1]] are the same
2990 //              in input and output, so this quantity is too. So both
2992 //                  \sum_{i > B[-1]} (size[i] - 1) * stride[i]
2993 //              and
2994 //                  stride[B[-1]]
2995 //              are the same in input and output.
2996 //
2997 //              These two quantities are exactly the LHS and RHS of the
2998 //              invariant inequality. Since by assumption the invariant is
2999 //              satisfied in input at B[-1], it is also satisfied in output at
3000 //              B[-1]. This concludes the proof.
3001 //
3002 //  squeeze:    Special case of view
3003 //
3004 //  unsqueeze:  Special case of view
3005 //
3006 //  slice:      Consider slicing dimension i with step = k >= 1.
3007 //
3008 //              Let stride' and size' be the output strides and sizes. We have
3009 //
3010 //                  stride'[i] = k * stride[i]
3011 //                  size'[i] <= floor(size[i] / k)
3012 //
3013 //              If size'[i] = 1, the invariant is obviously satisfied, as we are
3014 //              just removing a dimension (after step (1)).
3015 //
3016 //              Assume size'[i] > 1.
3017 //
3018 //              By assumption, the invariant is satisfied at every dimension
3019 //              in input.
3020 //
3021 //              For any dimension j, if stride[j] > stride[i], we have
3022 //                  stride'[j] =  stride[j]
3023 //                             >  (size[i] - 1) * stride[i]
3024 //                             =  (size[i] / k * k - 1) * k * stride[i] / k
3025 //                             =  (size[i] / k - 1 / k) * stride'[i]
3026 //                             >= (size'[i]    - 1 / k) * stride'[i]
3027 //                             >= stride'[i].
3028 //
3029 //              If stride[j] < stride[i], we have
3030 //                  stride'[j] = stride[j] < stride[i] <= stride'[i].
3031 //
3032 //              So the sorting order remains unchanged after slice.
3033 //
3034 //              Since
3035 //                     (size'[i] - 1) * stride'[i]
3036 //                  =  (floor(size[i] / k) - 1) * k * stride[i]
3037 //                  <= (size[i] / k - 1) * k * stride[i]
3038 //                  =  (size[i] - k) * stride[i]
3039 //                  <= (size[i] - 1) * stride[i],
3040 //              the term from this dimension i in the invariant inequality at
3041 //              other dimensions can only decrease after slice. So the
3042 //              invariant is preserved.
3043 //
3044 //  narrow:     Special case of slice
3045 //
3046 //  select:     narrow + squeeze
3047 //
3048 //  permute:    Sorting makes permutation of dimensions irrelevant
3049 //
3050 //  transpose:  Sorting makes swapping dimensions irrelevant
3051 //
3052 //  diagonal:   Effectively merging two dimensions i and j into a new
3053 //              dimension k s.t.
3054 //                  stride'[k] =  stride[i] + stride[j]
3055 //                  size'[k]   <= min(size[i], size[j]),
3056 //              where stride and size are on the input, and stride' and size'
3057 //              are on the output.
3058 //
3059 //              Assuming that size[i] > 1 and size[j] > 1. If any has size 1,
3060 //              then this is unsqueeze on that dimension.
3061 //
3062 //              WLOG, say stride[j] >= stride[i].
3063 //
3064 //              Each dimension d in input with stride[d] > stride[j] has
3065 //                  stride'[d] =  stride[d]
3066 //                             >  (size[i] - 1) * stride[i]
3067 //                                  + (size[j] - 1) * stride[j]
3068 //                             >= stride[i] + stride[j]
3069 //                             =  stride'[k].
3070 //              So, considering the sorted dimensions, this is effectively
3071 //              removing i, and replacing j with k.
3072 //
3073 //              For dimensions d with stride[i] < stride[d] < stride[j], the
3074 //              term from dimension i is removed in the invariant inequality.
3075 //              For dimensions d with stride[d] > stride[j], we have
3076 //                     (size'[k] - 1) * stride'[k]
3077 //                  <= (min(size[i], size[j]) - 1) * (stride[i] + stride[j])
3078 //                  <= (size[i] - 1) * stride[i] + (size[j] - 1) * stride[j],
3079 //              so the term from i and j in the invariant can only decrease.
3080 //
3081 //              So this transformation generally relaxes the constraint, and
3082 //              thus the invariant is preserved.
3083 
3084 // This implements steps (2)~(4) of the algorithm in
3085 // NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
3086 // Helper for as_strided_backward
_maybe_overlapping_memory(c10::SymIntArrayRef sizes,c10::SymIntArrayRef strides)3087 static inline bool _maybe_overlapping_memory(
3088     c10::SymIntArrayRef sizes,
3089     c10::SymIntArrayRef strides) {
3090   if (!sizes.empty()) {
3091     std::vector<std::size_t> argsort(sizes.size());
3092     std::iota(argsort.begin(), argsort.end(), 0);
3093     std::sort(
3094         argsort.begin(), argsort.end(), [&](std::size_t i, std::size_t j) {
3095           return strides[i] < strides[j];
3096         });
3097 
3098     c10::SymInt max_index_in_slice = 0;
3099     for (auto i : argsort) {
3100       const auto& stride_ = strides[i];
3101       if (stride_ <= max_index_in_slice) {
3102         return true;
3103       }
3104       max_index_in_slice += stride_ * (sizes[i] - 1);
3105     }
3106   }
3107   return false;
3108 }
3109 
3110 // Returns the minimum storage size needed to contain a tensor of sizes,
3111 // strides, and storage_offset. Helper for as_strided_backward.
_min_storage_size(c10::SymIntArrayRef sizes,c10::SymIntArrayRef strides,c10::SymInt storage_offset)3112 static inline c10::SymInt _min_storage_size(
3113     c10::SymIntArrayRef sizes,
3114     c10::SymIntArrayRef strides,
3115     c10::SymInt storage_offset) {
3116   c10::SymInt storage_size = storage_offset + 1;
3117   auto dim = sizes.size();
3118   for (const auto i : c10::irange(dim)) {
3119     const auto& size_i = sizes[i];
3120     if (size_i == 0) {
3121       return storage_offset;
3122     }
3123     storage_size += (size_i - 1) * strides[i];
3124   }
3125   return storage_size;
3126 }
3127 
3128 // See NOTE [ as_strided Backward and layout-aware/agnostic autograd ] for
3129 // explanation
as_strided_backward(Tensor grad,const TensorGeometry & input_geometry,c10::SymIntArrayRef sym_sizes,c10::SymIntArrayRef sym_strides,const std::optional<c10::SymInt> & sym_storage_offset_)3130 Tensor as_strided_backward(
3131     Tensor grad,
3132     const TensorGeometry& input_geometry,
3133     c10::SymIntArrayRef sym_sizes,
3134     c10::SymIntArrayRef sym_strides,
3135     const std::optional<c10::SymInt>& sym_storage_offset_) {
3136   // For output geometry,
3137   //   check for size 0 dimensions,
3138   //   skip size 1 dimensions,
3139   //   reduce grad on expanded dims (stride=0, size>1)
3140   // Step (0)     for the algorithm in NOTE [ as_strided Backward and
3141   //              layout-aware/agnostic autograd ], and
3142   // Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A
3143   //              Strided Tensor ], on output geometry
3144   auto sym_storage_offset =
3145       sym_storage_offset_.value_or(input_geometry.sym_storage_offset());
3146   auto odim = grad.dim();
3147   std::vector<c10::SymInt> out_sizes_, out_strides_;
3148   out_sizes_.reserve(odim);
3149   out_strides_.reserve(odim);
3150   for (int64_t i = odim - 1; i >= 0; i--) {
3151     const auto& size_i = sym_sizes[i];
3152     const auto& stride_i = sym_strides[i];
3153     if (size_i == 0) {
3154       return at::zeros_symint(input_geometry.sym_sizes(), grad.options());
3155     } else if (size_i == 1) {
3156       grad = grad.squeeze(i);
3157     } else if (stride_i == 0) {
3158       grad = grad.sum(i, false);
3159     } else {
3160       out_sizes_.insert(out_sizes_.begin(), size_i);
3161       out_strides_.insert(out_strides_.begin(), stride_i);
3162     }
3163   }
3164   // Step (2)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A
3165   // Strided Tensor ]
3166   //              on output geometry
3167   auto out_maybe_overlap = _maybe_overlapping_memory(out_sizes_, out_strides_);
3168 
3169   // For input geometry,
3170   //   check for size 0 dimensions,
3171   //   skip size 1 dimensions,
3172   // Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A
3173   // Strided Tensor ]
3174   //              on input geometry
3175   auto idim = input_geometry.dim();
3176   auto inp_sizes = input_geometry.sym_sizes(),
3177        inp_strides = input_geometry.sym_strides();
3178   std::vector<c10::SymInt> inp_sizes_, inp_strides_;
3179   inp_sizes_.reserve(idim);
3180   inp_strides_.reserve(idim);
3181   for (int64_t i = idim - 1; i >= 0; i--) {
3182     const auto& size_i = inp_sizes[i];
3183     const auto& stride_i = inp_strides[i];
3184     if (size_i == 0) {
3185       return at::zeros_symint(input_geometry.sym_sizes(), grad.options());
3186     } else if (size_i != 1) {
3187       inp_sizes_.insert(inp_sizes_.begin(), size_i);
3188       inp_strides_.insert(inp_strides_.begin(), stride_i);
3189     }
3190   }
3191   // Step (2)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A
3192   // Strided Tensor ]
3193   //              on input geometry
3194   auto inp_maybe_overlap = _maybe_overlapping_memory(inp_sizes_, inp_strides_);
3195 
3196   // Rest of this function implements
3197   // Step (1)~(4) for the algorithm in NOTE [ as_strided Backward and
3198   // layout-aware/agnostic autograd ]
3199   // TODO: Raise if not all output values are visible in input geometry.
3200   //       Technically speaking, if you treat those values as constants, not
3201   //       raising is fine, and mathematically correct. However, these values
3202   //       really are contained in some base tensor, and by treating them as
3203   //       constants we are ignoring this tight dependency. Therefore, it is
3204   //       more sensible to raise here.
3205 
3206   // Step (1): create underlying tensor as "storage"
3207   auto shared_offset =
3208       // TODO: symint-ify. Do we need a min() and max() for SymInts?
3209       input_geometry.sym_storage_offset().min(sym_storage_offset);
3210   auto inp_effective_offset =
3211       input_geometry.sym_storage_offset() - shared_offset;
3212   auto out_effective_offset = sym_storage_offset - shared_offset;
3213   auto base_size1 =
3214       _min_storage_size(inp_sizes_, inp_strides_, inp_effective_offset);
3215   auto base_size2 =
3216       _min_storage_size(out_sizes_, out_strides_, out_effective_offset);
3217   auto base_size = base_size1.max(base_size2);
3218   auto storage = grad.new_zeros_symint(c10::SymIntArrayRef(base_size));
3219 
3220   // prepare indices tensor if we will do index_add_ later
3221   std::optional<at::Tensor> flatten_full_indices;
3222   if (inp_maybe_overlap || out_maybe_overlap) {
3223     flatten_full_indices =
3224         // TODO: should we symint-ify arange? Need SymScalar.
3225         at::arange(
3226             0,
3227             base_size.guard_int(__FILE__, __LINE__),
3228             grad.options().dtype(at::kLong));
3229   }
3230 
3231   // Step (2): use output geometry to scatter gradients into storage
3232   if (out_maybe_overlap) {
3233     // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
3234     auto out_indices = flatten_full_indices->as_strided_symint(
3235         out_sizes_, out_strides_, out_effective_offset);
3236     storage.index_add_(0, out_indices.reshape(-1), grad.reshape(-1));
3237   } else {
3238     // assume that new tensors have 0 storage offset
3239     storage.as_strided_symint(out_sizes_, out_strides_, out_effective_offset)
3240         .copy_(grad);
3241   }
3242 
3243   // Step (3): if input tensor has overlapping memory, divide scattered gradient
3244   //           at storage[i] by the number of times i shows up in input geometry
3245   if (inp_maybe_overlap) {
3246     auto count = at::zeros_like(storage, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
3247     auto inp_indices =
3248         // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
3249         flatten_full_indices
3250             ->as_strided_symint(inp_sizes_, inp_strides_, inp_effective_offset)
3251             .reshape(-1);
3252     count.index_add_(
3253         0, inp_indices, at::ones({1}, grad.options()).expand_as(inp_indices));
3254     storage.div_(count); // this will give nan outside visible range
3255   }
3256   // Step (4): return as_strided view of the storage tensor with input geometry
3257   return storage.as_strided_symint(
3258       inp_sizes, inp_strides, inp_effective_offset);
3259 }
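
// A PyTorch-level illustration of the function above, using the second example
// from NOTE [ as_strided Backward and layout-aware/agnostic autograd ] (the
// expected values follow from the algorithm described there; shown here for
// illustration, not as a guaranteed printout):
//
//   import torch
//   t = torch.randn(3, requires_grad=True)
//   x = t.expand(3, 3)            # input with overlapping memory
//   y = x.as_strided([1], [1])    # views the same element as t[0:1]
//   y.backward(torch.ones_like(y))
//   # the gradient w.r.t. x is three rows of [1/3, 0, 0] (step (3) divides by
//   # |S(i)| = 3), and summing over the expanded dim gives
//   # t.grad == tensor([1., 0., 0.])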
3260 
as_strided_scatter_backward(const Tensor & grad,const TensorGeometry & input_geometry,const TensorGeometry & src_geometry,c10::SymIntArrayRef sizes,c10::SymIntArrayRef strides,std::optional<c10::SymInt> storage_offset)3261 Tensor as_strided_scatter_backward(
3262     const Tensor& grad,
3263     const TensorGeometry& input_geometry,
3264     const TensorGeometry& src_geometry,
3265     c10::SymIntArrayRef sizes,
3266     c10::SymIntArrayRef strides,
3267     std::optional<c10::SymInt> storage_offset) {
3268   // Note [as_strided_scatter backward support]
3269   // as_strided_scatter handling for autograd is a beast, and is non-trivial to
3270   // implement for arbitrarily strided inputs. Most uses for as_strided with
3271   // functionalization only care about the contiguous case anyway, So for now
3272   // functionalization only care about the contiguous case anyway, so for now
3273   // this is not implemented. When autograd is in use, we ban non-contiguous
3274   // inputs, so we can assume that the input was a contiguous tensor. Also, we'll
3275   auto grad_ = grad.contiguous();
3276   auto grad_slice = grad_.as_strided_symint(sizes, strides, storage_offset);
3277   auto result_buffer = grad_.new_zeros_symint(input_geometry.sym_sizes());
3278   auto result = result_buffer.as_strided_symint(
3279       input_geometry.sym_sizes(), input_geometry.sym_strides());
3280   auto result_slice = result_buffer.as_strided_symint(
3281       sizes, strides, std::move(storage_offset));
3282   result_slice.copy_(grad_slice);
3283   return result;
3284 }
3285 
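// For z = atan2(self, other): dz/dself = other / (self^2 + other^2) and
// dz/dother = -self / (self^2 + other^2); the two outputs below are the
// incoming grad times these factors.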
atan2_backward(const Tensor & grad,const Tensor & self,const Tensor & other,std::array<bool,2> output_mask)3286 std::tuple<Tensor, Tensor> atan2_backward(
3287     const Tensor& grad,
3288     const Tensor& self,
3289     const Tensor& other,
3290     std::array<bool, 2> output_mask) {
3291   if (!grad.defined()) {
3292     return std::tuple<Tensor, Tensor>{Tensor(), Tensor()};
3293   }
3294   auto recip = (self * self + other * other).reciprocal();
3295   return std::tuple<Tensor, Tensor>{
3296       output_mask[0] ? grad * other * recip : Tensor(),
3297       output_mask[1] ? grad * -self * recip : Tensor()};
3298 }
3299 
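// For the exact (erf) formulation, gelu(x) = x * Phi(x) with Phi the standard
// normal CDF, so gelu'(x) = Phi(x) + x * phi(x) and, since phi'(x) = -x * phi(x),
// gelu''(x) = 2 * phi(x) - x^2 * phi(x). That matches
// `dgrad_dInput = 2 * pdf - input_sq * pdf` in the non-tanh branch below; the
// tanh branch instead differentiates the tanh approximation term by term.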
gelu_double_backward(const Tensor & ggI,const Tensor & gO,const Tensor & input,c10::string_view approximate)3300 Tensor gelu_double_backward(
3301     const Tensor& ggI,
3302     const Tensor& gO,
3303     const Tensor& input,
3304     c10::string_view approximate) {
3305   // if (at::native::get_gelutype_enum(approximate) ==
3306   // at::native::GeluType::Tanh) {
3307   if (approximate == "tanh") {
3308     constexpr auto kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
3309     constexpr auto kKappa = 0.044715;
3310 
3311     auto inner = kBeta * (input + kKappa * pow(input, 3));
3312     auto tanh_inner = tanh(inner);
3313     auto sech_inner = 1 / cosh(inner);
3314 
3315     auto f = 0.5 * input;
3316     auto g = 1 - tanh_inner * tanh_inner;
3317     auto h = kBeta * (1 + 3 * kKappa * input * input);
3318 
3319     auto f_prime_gh = 0.5 * g * h;
3320 
3321     auto g_prime = (2 * sech_inner) * (-sech_inner * tanh_inner) * h;
3322     auto g_prime_fh = f * h * g_prime;
3323 
3324     auto h_prime = 6 * kKappa * input * kBeta;
3325     auto h_prime_fg = f * g * h_prime;
3326 
3327     // left_derivative = f_prime_gh
3328     // right_derivative = f_prime_gh + g_prime_fh + h_prime_fg
3329     // dgrad_dX = left_derivative + right_derivative
3330     auto gI = ggI * gO * (2 * f_prime_gh + g_prime_fh + h_prime_fg);
3331     return gI;
3332   } else {
3333     constexpr auto kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5;
3334     auto input_sq = input * input;
3335     auto pdf = kBeta * at::exp(-0.5 * input_sq);
3336     auto dgrad_dInput = 2 * pdf - input_sq * pdf;
3337     auto gI = ggI * gO * dgrad_dInput;
3338     return gI;
3339   }
3340 }
3341 
elu_double_backward(const Tensor & grad,const Tensor & grad_output,const Scalar & alpha,const Scalar & scale,const Scalar & input_scale,bool is_result,const Tensor & self_or_result)3342 Tensor elu_double_backward(
3343     const Tensor& grad,
3344     const Tensor& grad_output,
3345     const Scalar& alpha,
3346     const Scalar& scale,
3347     const Scalar& input_scale,
3348     bool is_result,
3349     const Tensor& self_or_result) {
3350   if (is_result) {
3351     return grad * grad_output * input_scale *
3352         (self_or_result < 0).type_as(grad);
3353   } else {
3354     return at::elu_backward(
3355                grad * grad_output * input_scale,
3356                alpha,
3357                scale,
3358                input_scale,
3359                is_result,
3360                self_or_result) *
3361         (self_or_result < 0).type_as(grad);
3362   }
3363 }
3364 
slice_backward_wrapper(const at::Tensor & grad,const c10::SymIntArrayRef & input_sizes,int64_t dim,std::optional<c10::SymInt> start,std::optional<c10::SymInt> end,c10::SymInt step)3365 Tensor slice_backward_wrapper(
3366     const at::Tensor& grad,
3367     const c10::SymIntArrayRef& input_sizes,
3368     int64_t dim,
3369     std::optional<c10::SymInt> start,
3370     std::optional<c10::SymInt> end,
3371     c10::SymInt step) {
3372   auto start_val = start.has_value() ? start.value() : 0;
3373   auto end_val = end.has_value() ? end.value() : INT64_MAX;
3374 
3375   return slice_backward_symint(
3376       grad,
3377       input_sizes,
3378       dim,
3379       std::move(start_val),
3380       std::move(end_val),
3381       std::move(step));
3382 }
3383 
linalg_svd_jvp(const Tensor & dA,const Tensor & U_,const Tensor & S,const Tensor & Vh_,const bool full_matrices)3384 std::tuple<Tensor, Tensor, Tensor> linalg_svd_jvp(
3385     const Tensor& dA,
3386     const Tensor& U_,
3387     const Tensor& S,
3388     const Tensor& Vh_,
3389     const bool full_matrices) {
3390   at::NoTF32Guard disable_tf32;
3391   // See svd_backward for the derivation
3392   // With sym(X) = X + X^H, we implement
3393   // dU = U (sym(dX S) / E + i Im(diag(dX)) / (2S))
3394   // if m > n
3395   //   dU = [dU for m == n] + (I_m - UU^H) dA V S^{-1}
3396   // dS = Re(diag(dP))
3397   // dV = V (sym(S dX) / E - i Im(diag(dX)) / (2S))
3398   // if m < n
3399   //   dV = [dV for m == n] + (I_n - VV^H) (dA)^H U S^{-1}
3400   // dVh = dV^H
3401   // with dP = U^H dA V
3402   //      dX = dP - dS
3403   //      E_{jk} = S_k^2 - S_j^2 if j != k
3404   //               1             otherwise
3405 
3406   // Checks compute_uv=true
3407   TORCH_INTERNAL_ASSERT(U_.dim() >= 2 && Vh_.dim() >= 2);
3408 
3409   const auto is_complex = dA.is_complex();
3410   const auto m = dA.size(-2);
3411   const auto n = dA.size(-1);
3412   const auto k = S.size(-1);
3413 
3414   const auto U = full_matrices ? U_.narrow(-1, 0, k) : U_;
3415   const auto Vh = full_matrices ? Vh_.narrow(-2, 0, k) : Vh_;
3416   const auto V = Vh.mH();
3417 
3418   // dP = U^H dA V
3419   auto dP = m >= n ? at::matmul(U.mH(), at::matmul(dA, V))
3420                    : at::matmul(at::matmul(U.mH(), dA), V);
3421 
3422   auto dS =
3423       is_complex ? at::real(dP.diagonal(0, -2, -1)) : dP.diagonal(0, -2, -1);
3424 
3425   // dX = dP - dS
3426   dP = dP - dS.diag_embed();
3427 
3428   auto E = [&S] {
3429     const auto S2 = S * S;
3430     auto ret = S2.unsqueeze(-2) - S2.unsqueeze(-1);
3431   // Any number a != 0 would do, as we are just going to use it to compute 0 / a
3432     // later on
3433     ret.diagonal(0, -2, -1).fill_(1);
3434     return ret;
3435   }();
3436 
3437   const auto sym = [](const Tensor& X) { return X + X.mH(); };
3438 
3439   // diag(dP) / (2S)
3440   auto diagdP2S = is_complex ? dP.diagonal(0, -2, -1).div(2. * S) : Tensor{};
3441 
3442   // dU = U (sym(dP S) / E + i Im(diag(dP)) / (2S))
3443   auto dU = [&] {
3444     auto dUaux = sym(dP * S.unsqueeze(-2)) / E;
3445     if (is_complex) {
3446       dUaux = dUaux + diagdP2S.diag_embed();
3447     }
3448     return at::matmul(U, dUaux);
3449   }();
3450   if (m > n) {
3451     // dU += (I_m - UU^H) dA V S^{-1}
3452     const auto dAVSinv = at::matmul(dA, V / S.unsqueeze(-2));
3453     dU = dU + dAVSinv - at::matmul(U, at::matmul(U.mH(), dAVSinv));
3454 
3455     // To "fix" the full_matrices case (the full_matrices case should not be
3456     // differentiable...)
3457     if (full_matrices) {
3458       auto shape = dU.sizes().vec();
3459       shape.end()[-1] = m - n;
3460       dU = at::cat({dU, dU.new_zeros(shape)}, /*dim=*/-1);
3461     }
3462   }
3463 
3464   // dVh = (-sym(S dP) / E + i Im(diag(dP)) / (2S)) Vh
3465   // Perf: We negate the S as it's the smallest tensor in the equation
3466   auto dVh = [&] {
3467     auto dVhaux = sym(dP * (-S).unsqueeze(-1)) / E;
3468     if (is_complex) {
3469       dVhaux = dVhaux + diagdP2S.diag_embed();
3470     }
3471     return at::matmul(dVhaux, Vh);
3472   }();
3473   if (m < n) {
3474     // dVh += S^{-1} U^H dA (I_n - VV^H)
3475     const auto UHdASinv = at::matmul(U.mH() / S.unsqueeze(-1), dA);
3476     dVh = dVh + UHdASinv - at::matmul(at::matmul(UHdASinv, V), Vh);
3477 
3478     // To "fix" the full_matrices case (the full_matrices case should not be
3479     // differentiable...)
3480     if (full_matrices) {
3481       auto shape = dVh.sizes().vec();
3482       shape.end()[-2] = n - m;
3483       dVh = at::cat({dVh, dVh.new_zeros(shape)}, /*dim=*/-2);
3484     }
3485   }
3486 
3487   return std::make_tuple(std::move(dU), std::move(dS), std::move(dVh));
3488 }
3489 
svd_backward(const Tensor & gU,const Tensor & gS,const Tensor & gVh,const Tensor & U,const Tensor & S,const Tensor & Vh)3490 Tensor svd_backward(
3491     const Tensor& gU,
3492     const Tensor& gS,
3493     const Tensor& gVh,
3494     const Tensor& U,
3495     const Tensor& S,
3496     const Tensor& Vh) {
3497   at::NoTF32Guard disable_tf32;
3498   // Throughout both the real and complex case we assume A has distinct singular
3499   // values. Furthermore, if A is rectangular or complex, we assume it's
3500   // full-rank.
3501   //
3502   //
3503   // The real case (A \in R)
3504   // See e.g. https://j-towns.github.io/papers/svd-derivative.pdf
3505   //
3506   // Denote by skew(X) = X - X^T, and by A o B the coordinatewise product, then
3507   // if m == n
3508   //   gA = U [(skew(U^T gU) / E)S + S(skew(V^T gV) / E) + I o gS ]V^T
3509   // where E_{jk} = S_k^2 - S_j^2 if j != k and 1 otherwise
3510   //
3511   // if m > n
3512   //   gA = [term in m == n] + (I_m - UU^T)gU S^{-1} V^T
3513   // if m < n
3514   //   gA = [term in m == n] + U S^{-1} (gV)^T (I_n - VV^T)
3515   //
3516   //
3517   // The complex case (A \in C)
3518   // This one is trickier because the svd is not locally unique.
3519   // Denote L = diag(e^{i\theta_k}), then we have that if A = USV^H, then (UL,
3520   // S, VL) is another valid SVD decomposition of A as A = ULS(VL)^H =
3521   // ULSL^{-1}V^H = USV^H, since L, S and L^{-1} commute, since they are all
3522   // diagonal.
3523   //
3524   // Assume wlog that n >= k in what follows, as otherwise we could reason about
3525   // A^H. Denote by St_k(C^n) = {A \in C^{n,k} | A^H A = I_k} the complex
3526   // Stiefel manifold. What this invariance means is that the svd decomposition
3527   // is not a map svd: C^{n x k} -> St_k(C^n) x R^k x St_k(C^k) (where St_k(C^k)
3528   // is simply the unitary group U(k)) but a map svd: C^{n x k} -> M x R^k where
3529   // M is the manifold given by quotienting St_k(C^n) x U(k) by the action (U,
3530   // V) -> (UL, VL) with L as above. Note that M is a manifold, because the
3531   // action is free and proper (as U(1)^k \iso (S^1)^k is compact). For this
3532   // reason, pi : St_k(C^n) x U(k) -> M forms a principal bundle.
3533   //
3534   // To think about M, consider the case k = 1. Then, we have the bundle
3535   // pi : St_1(C^n) x U(1) -> M
3536   // now, St_1(C^n) are just vectors of norm 1 in C^n. That's exactly the sphere
3537   // of dimension 2n-1 in C^n \iso R^{2n}: S^{2n-1} = { z \in C^n | z^H z = 1 }.
3538   // Then, in this case, we're quotienting out U(1) completely, so we get that
3539   // pi : S^{2n-1} x U(1) -> CP(n-1)
3540   // where CP(n-1) is the complex projective space of dimension n-1.
3541   // In other words, M is just the complex projective space, and pi is (pretty
3542   // similar to) the usual principal bundle from S^{2n-1} to CP(n-1). The case k
3543   // > 1 is the same, but requiring a linear independence condition between the
3544   // > 1 is the same, but requires a linear independence condition between the
3545   //
3546   // Note that this is a U(1)^k-bundle. In plain words, this means that the
3547   // fibres of this bundle, i.e. pi^{-1}(x) for x \in M are isomorphic to U(1) x
3548   // ... x U(1). This is obvious as, if pi(U,V) = x, pi^{-1}(x) = {(U
3549   // diag(e^{i\theta}), V diag(e^{i\theta})) | \theta \in R^k}
3550   //            = {(U diag(z), V diag(z)) | z \in U(1)^k}
3551   // since U(1) = {z \in C | |z| = 1}.
3552   //
3553   // The big issue here is that M with its induced metric is not locally
3554   // isometric to St_k(C^n) x U(k). [The why is rather technical, but you can
3555   // see that the horizontal distribution is not involutive, and hence not
3556   // integrable, by Frobenius' theorem] What this means in plain words is
3557   // that, no matter how we choose to return the U and V from the SVD, we won't
3558   // be able to simply differentiate wrt. U and V and call it a day. An example
3559   // of a case where we can do this is when performing an eigendecomposition on
3560   // a real matrix that happens to have real eigendecomposition. In this case,
3561   // even though you can rescale the eigenvectors by any real number, you can
3562   // choose them of norm 1 and call it a day. In the eigenvector case, we are
3563   // using that you can isometrically embed S^{n-1} into R^n. In the svd case,
3564   // we need to work with the "quotient manifold" M explicitly, which is
3565   // slightly more technically challenging.
3566   //
3567   // Since the columns of U and V are not uniquely defined, but are
3568   // representatives of certain equivalence classes which represent elements of
3569   // M, the user must not depend on the particular representative that we return
3570   // from the SVD. In particular, if the loss function depends on U or V, it
3571   // must be invariant under the transformation (U, V) -> (UL, VL) with L =
3572   // diag(e^{i\theta}), for every \theta \in R^k. In more geometrical terms,
3573   // this means that the loss function should be constant on the fibres, or, in
3574   // other words, the gradient along the fibres should be zero. We may see this
3575   // by checking that the gradients as element in the tangent space T_{(U,
3576   // V)}(St(n,k) x U(k)) are normal to the fibres. Differentiating the map (U,
3577   // V) -> (UL, VL), we see that the space tangent to the fibres is given by
3578   // Vert_{(U, V)}(St(n,k) x U(k)) = { i[U, V]diag(\theta) | \theta in R^k}
3579   // where [U, V] denotes the vertical concatenation of U and V to form an (n+k,
3580   // k) matrix. Then, solving <i[U,V]diag(\theta), [S, T]> = 0 for two matrices
3581   // S, T \in T_{(U, V)}(St(n,k) x U(k)) where <A, B> = Re tr(A^H B) is the
3582   // canonical (real) inner product in C^{n x k} we get that the function is
3583   // invariant under action of U(1)^k iff Im(diag(U^H gU + V^H gV)) = 0
3584   //
3585   // Using this in the derivation for the forward AD, one sees that, with the
3586   // notation from those notes and writing sym(X) = X + X^H, the forward AD for
3587   // SVD in the complex case is given by
3588   // dU = U (sym(dX S) / E + i Im(diag(dX)) / (2S)), and if m > n
3589   //   dU = [dU for m == n] + (I_m - UU^H) dA V S^{-1}
3590   // dS = Re(diag(dP))
3591   // dV = V (sym(S dX) / E - i Im(diag(dX)) / (2S))
3592   // if m < n
3593   //   dV = [dV for m == n] + (I_n - VV^H) (dA)^H U S^{-1}
3594   // dVh = dV^H
3595   // with dP = U^H dA V
3596   //      dX = dP - dS
3597   //      E_{jk} = S_k^2 - S_j^2 if j != k
3598   //               1             otherwise
3599   //
3600   // Similarly, writing skew(X) = X - X^H
3601   // the adjoint wrt. the canonical metric is given by
3602   // if m == n
3603   //   gA = U [(skew(U^H gU) / E) S + i Im(diag(U^H gU)) / S + S (skew(V^H gV)
3604   //   / E) + I o gS] V^H
3605   // if m > n
3606   //   gA = [term in m == n] + (I_m - UU^H)gU S^{-1} V^H
3607   // if m < n
3608   //   gA = [term in m == n] + U S^{-1} (gV)^H (I_n - VV^H)
3609   // where we have used that Im(diag(U^H gU)) = - Im(diag(V^H gV)) to group the
3610   // diagonal imaginary terms into one that just depends on U^H gU.
3611 
3612   // Checks compute_uv=true
3613   TORCH_INTERNAL_ASSERT(U.dim() >= 2 && Vh.dim() >= 2);
3614 
3615   // Trivial case
3616   if (!gS.defined() && !gU.defined() && !gVh.defined()) {
3617     return {};
3618   }
3619 
3620   const auto m = U.sym_size(-2);
3621   const auto n = Vh.sym_size(-1);
3622 
3623   // Optimisation for svdvals: gA = U @ diag(gS) @ Vh
3624   if (!gU.defined() && !gVh.defined()) {
3625     return m >= n ? at::matmul(U, gS.unsqueeze(-1) * Vh)
3626                   : at::matmul(U * gS.unsqueeze(-2), Vh);
3627   }
3628   // At this point, at least one of gU, gVh is defined
3629 
3630   const bool is_complex = U.is_complex();
3631   const auto skew = [](const Tensor& A) { return A - A.mH(); };
3632   const auto UhgU = gU.defined() ? skew(at::matmul(U.mH(), gU)) : Tensor{};
3633   const auto VhgV = gVh.defined() ? skew(at::matmul(Vh, gVh.mH())) : Tensor{};
3634 
3635   // Check for the invariance of the loss function, i.e.
3636   // Im(diag(U^H gU)) + Im(diag(V^H gV)) = 0
3637   if (is_complex) {
3638     const auto imdiag_UhgU =
3639         gU.defined() ? at::imag(UhgU.diagonal(0, -2, -1)) : at::zeros_like(S);
3640     const auto imdiag_VhgV =
3641         gVh.defined() ? at::imag(VhgV.diagonal(0, -2, -1)) : at::zeros_like(S);
3642     // Rather lax atol and rtol, as we don't want false positives
3643     TORCH_CHECK(
3644         at::allclose(imdiag_UhgU, -imdiag_VhgV, /*rtol=*/1e-2, /*atol=*/1e-2),
3645         "svd_backward: The singular vectors in the complex case are specified up to multiplication "
3646         "by e^{i phi}. The specified loss function depends on this phase term, making "
3647         "it ill-defined.");
3648   }
3649 
3650   // gA = ((U^H gU) / E) S + S ((V^H gV) / E)
3651   //        + I o (gS + diag(U^H gU) / (2 * S))
3652   Tensor gA = [&] {
3653     // ret holds everything but the diagonal of gA
3654     auto ret = [&] {
3655       const auto E = [&S] {
3656         const auto S2 = S * S;
3657         auto ret = S2.unsqueeze(-2) - S2.unsqueeze(-1);
3658         // Any number a != 0 would do, as we are just going to use it to compute 0
3659         // / a later on
3660         ret.diagonal(0, -2, -1).fill_(1);
3661         return ret;
3662       }();
3663 
3664       if (gU.defined()) {
3665         if (gVh.defined()) {
3666           return (UhgU * S.unsqueeze(-2) + S.unsqueeze(-1) * VhgV) / E;
3667         } else {
3668           return (UhgU / E) * S.unsqueeze(-2);
3669         }
3670       } else { // gVh.defined();
3671         return S.unsqueeze(-1) * (VhgV / E);
3672       }
3673     }();
3674     // Fill the diagonal
3675     if (gS.defined()) {
3676       ret = ret + gS.diag_embed();
3677     }
3678     if (is_complex && gU.defined() && gVh.defined()) {
3679       ret = ret + (UhgU.diagonal(0, -2, -1) / (2. * S)).diag_embed();
3680     }
3681     return ret;
3682   }();
3683 
3684   if (m > n && gU.defined()) {
3685     // gA = [UgA + (I_m - UU^H)gU S^{-1}]V^H
3686     gA = at::matmul(U, gA);
3687     const auto gUSinv = gU / S.unsqueeze(-2);
3688     gA = gA + gUSinv - at::matmul(U, at::matmul(U.mH(), gUSinv));
3689     gA = at::matmul(gA, Vh);
3690   } else if (m < n && gVh.defined()) {
3691     //   gA = U[gA V^H + S^{-1} (gV)^H (I_n - VV^H)]
3692     gA = at::matmul(gA, Vh);
3693     const auto SinvgVh = gVh / S.unsqueeze(-1);
3694     gA = gA + SinvgVh - at::matmul(at::matmul(SinvgVh, Vh.mH()), Vh);
3695     gA = at::matmul(U, gA);
3696   } else {
3697     // gA = U gA V^H
3698     gA = m >= n ? at::matmul(U, at::matmul(gA, Vh))
3699                 : at::matmul(at::matmul(U, gA), Vh);
3700   }
3701 
3702   return gA;
3703 }
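
// Note on usage: as the invariance check above enforces, a loss fed through the
// complex SVD must be invariant under (U, V) -> (U diag(z), V diag(z)) with
// |z_i| = 1. For example, losses built from U @ diag(S) @ Vh or from absolute
// values of entries of U and V are fine, while a loss that depends on the phase
// of the columns of U or V may trigger the error above.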
3704 
linalg_eig_backward(const Tensor & gL,const Tensor & gV,const Tensor & L,const Tensor & V,const bool is_hermitian,const bool symeig_eigenvectors)3705 Tensor linalg_eig_backward(
3706     const Tensor& gL,
3707     const Tensor& gV,
3708     const Tensor& L,
3709     const Tensor& V,
3710     const bool is_hermitian,
3711     const bool symeig_eigenvectors) {
3712   at::NoTF32Guard disable_tf32;
3713   // https://arxiv.org/pdf/1701.00392.pdf Eq 4.77
3714   // For A = VLV^{-1}, denoting the gradients gA, gV and gL, we have
3715   // gA = V^{-H} (diag_embed(gL) + (V^H gV - V^H V diag(real(V^H gV))) / E*) V^H
3716   // Where:
3717   //   - E_{ij} = L_j - L_i if i != j
3718   //              1         otherwise
3719   //   - diag_embed takes a vector into a diagonal matrix
3720   //   - diag zeroes out elements outside of the diagonal
3721 
3722   // Note: the term '-V^HV diag(real(V^H gV))' comes from the fact that the
3723   // eigenvalue decomposition is returned with eigenvectors normalized to have
3724   // norm one.
3725 
3726   // Note: The Hermitian case is a simplification of this formula using that
3727   // V^{-1} = V^H and that L is real
3728 
3729   // This check can only be triggered in the backward of torch.symeig
3730   TORCH_CHECK(
3731       symeig_eigenvectors,
3732       "linalg_eig_backward: torch.symeig(A, eigenvectors=False) is not differentiable. ",
3733       "Use torch.linalg.eigvalsh(A) instead.");
3734 
3735   // Trivial case
3736   if (!gL.defined() && !gV.defined()) {
3737     return {};
3738   }
3739 
3740   // Shortcut for linalg.eigvals/eigvalsh
3741   // Compute V^-H gL V^H
3742   if (!gV.defined()) {
3743     if (is_hermitian) {
3744       return at::matmul(V * gL.unsqueeze(-2), V.mH());
3745     } else {
3746       return at::linalg_solve(V.mH(), gL.unsqueeze(-1) * V.mH());
3747     }
3748   }
3749   auto VhgV = at::matmul(V.mH(), gV);
3750   const auto diag_VhgV = VhgV.diagonal(0, -2, -1);
3751 
3752   if (V.is_complex() && !at::isTensorSubclassLike(diag_VhgV)) {
3753     // Check invariance of the loss function wrt the transformation
3754     // V -> V * e^{i\phi} for an arbitrary phi in RR^n
3755     const auto imdiag_VhgV = at::imag(diag_VhgV);
3756     TORCH_CHECK(
3757         at::allclose(
3758             imdiag_VhgV,
3759             at::zeros_like(imdiag_VhgV),
3760             /*rtol=*/1e-2,
3761             /*atol=*/1e-2),
3762         is_hermitian ? "linalg_eigh_backward" : "linalg_eig_backward",
3763         ": The eigenvectors in the complex case are specified up to multiplication ",
3764         "by e^{i phi}. The specified loss function depends on this quantity, so it is ill-defined.");
3765   }
3766 
3767   if (is_hermitian) {
3768     // Project onto the tangent space at the identity of U(n), that is, the
3769     // skew-Hermitian matrices
3770     VhgV = 0.5 * (VhgV - VhgV.mH());
3771   } else {
3772     // Project onto the tangent space at V^H V of complex matrices with columns
3773     // of norm 1
3774     VhgV = VhgV - at::matmul(V.mH(), V * at::real(diag_VhgV).unsqueeze(-2));
3775   }
3776 
3777   auto gA = [&, VhgV = std::move(VhgV)] {
3778     auto Econj = [&L] {
3779       auto Lconj = L.conj();
3780       auto ret = Lconj.unsqueeze(-2) - Lconj.unsqueeze(-1);
3781       ret.diagonal(0, -2, -1).fill_(1.);
3782       return ret;
3783     }();
3784 
3785     auto ret = VhgV.div_(Econj);
3786 
3787     if (gL.defined()) {
3788       // For CompositeCompliance, if `gL` is subclass but `ret`
3789       // is a regular Tensor, then use out-of-place version of diagonal
3790       // copy aka `diagonal_scatter`.
3791       if (at::isTensorSubclassLike(gL)) {
3792         ret = ret.diagonal_scatter(gL, 0, -2, -1);
3793       } else {
3794         ret.diagonal(0, -2, -1).copy_(gL);
3795       }
3796     }
3797     return ret;
3798   }();
3799 
3800   // Conjugate by V^{-H}
3801   if (is_hermitian) {
3802     return at::matmul(V, at::matmul(gA, V.mH()));
3803   } else {
3804     return at::linalg_solve(V.mH(), at::matmul(gA, V.mH()));
3805   }
3806 }
3807 
linalg_eig_jvp(const Tensor & dA,const Tensor & L,const Tensor & V,const bool is_hermitian)3808 std::tuple<Tensor, Tensor> linalg_eig_jvp(
3809     const Tensor& dA,
3810     const Tensor& L,
3811     const Tensor& V,
3812     const bool is_hermitian) {
3813   at::NoTF32Guard disable_tf32;
3814   // https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
3815   // see also https://arxiv.org/pdf/1701.00392.pdf Eqs. (4.60) and (4.63)
3816   // Note that neither of the formulas in these pdfs are correct, as they do not
3817   // assume that the eigenvectors are of unit norm. As such, they are missing
3818   // the diagonal term in dV:  dL = diag(dP),  dV = dX - V Re(diag(V^H dX)),
3819   // where dP = V^{-1} dA V,  dX = V ((dP - diag(dP)) / E),  and
3820   //       E_{ij} = L_j - L_i if i != j, and 1 otherwise.
3821 
3822   // Precondition: if is_hermitian == true, then dA is Hermitian
3823   const auto to_complex = [](const Tensor& A) {
3824     return A.to(c10::toComplexType(A.scalar_type()));
3825   };
3826 
3827   const auto dP = is_hermitian
3828       ? at::matmul(at::matmul(V.mH(), dA), V)
3829       : at::linalg_solve(V, at::matmul(to_complex(dA), V));
3830   auto dL = is_hermitian && dA.is_complex() ? at::real(dP.diagonal(0, -2, -1))
3831                                             : dP.diagonal(0, -2, -1);
3832   auto dV = [&dP, &V, &L, is_hermitian] {
3833     auto dX = [&] {
3834       auto ret = dP / (L.unsqueeze(-2) - L.unsqueeze(-1));
3835       ret.diagonal(0, -2, -1).zero_();
3836       ret = at::matmul(V, ret);
3837       return ret;
3838     }();
3839 
3840     if (is_hermitian) {
3841       return dX;
3842     } else {
3843       return dX -
3844           V *
3845           at::real(at::matmul(V.mH(), dX).diagonal(0, -2, -1)).unsqueeze(-2);
3846     }
3847   }();
3848   return std::make_pair(std::move(dL), std::move(dV));
3849 }
3850 
linalg_lstsq_jvp(const Tensor & A,const Tensor & B,const Tensor & dA,const Tensor & dB)3851 Tensor linalg_lstsq_jvp(
3852     const Tensor& A,
3853     const Tensor& B,
3854     const Tensor& dA,
3855     const Tensor& dB) {
3856   at::NoTF32Guard disable_tf32;
3857   auto pinvA = at::linalg_pinv(A);
3858   auto dpinvA = pinv_jvp(A, pinvA, dA);
3859   auto dX = dpinvA.matmul(B) + pinvA.matmul(dB);
3860   return dX;
3861 }
3862 
linalg_lstsq_backward(const Tensor & gX_,const Tensor & A,const Tensor & B_,const std::array<bool,2> & grad_input_mask)3863 std::tuple<Tensor, Tensor> linalg_lstsq_backward(
3864     const Tensor& gX_,
3865     const Tensor& A,
3866     const Tensor& B_,
3867     const std::array<bool, 2>& grad_input_mask) {
3868   at::NoTF32Guard disable_tf32;
3869   auto A_requires_grad = grad_input_mask[0];
3870   auto B_requires_grad = grad_input_mask[1];
3871   if (!gX_.defined() || (!A_requires_grad && !B_requires_grad)) {
3872     return {};
3873   }
3874 
3875   const bool vector_case = at::native::linalg_solve_is_vector_rhs(A, B_);
3876   const auto vector_to_matrix = [vector_case](const Tensor& X) {
3877     return vector_case ? X.unsqueeze(-1) : X;
3878   };
3879   const auto matrix_to_vector = [vector_case](const Tensor& X) {
3880     return vector_case ? X.squeeze(-1) : X;
3881   };
3882 
3883   auto gX = vector_to_matrix(gX_);
3884   auto B = vector_to_matrix(B_);
3885   Tensor pinvA = at::linalg_pinv(A);
3886   Tensor A_grad, B_grad;
3887   if (A_requires_grad) {
3888     auto pinvA_grad = gX.matmul(B.mH());
3889     A_grad = pinv_backward(pinvA_grad, pinvA, A);
3890   }
3891 
3892   if (B_requires_grad) {
3893     // Equivalent to
3894     // B_grad = std::get<0>(at::linalg_lstsq(A.mH(), gX, rcond, driver));
3895     // but we avoid this approach as `gelsy` is non-deterministic
3896     B_grad = matrix_to_vector(pinvA.mH().matmul(gX));
3897   }
3898 
3899   return std::make_tuple(A_grad, B_grad);
3900 }
3901 
linalg_qr_jvp(const Tensor & dA,const Tensor & Q,const Tensor & R,const c10::string_view mode)3902 std::tuple<Tensor, Tensor> linalg_qr_jvp(
3903     const Tensor& dA,
3904     const Tensor& Q,
3905     const Tensor& R,
3906     const c10::string_view mode) {
3907   // dA = dQR + QdR
3908   //
3909   // Case m >= n
3910   // We can put dQ in terms of dR
3911   // dQ = dAR^{-1} - QdRR^{-1}
3912   // Then we have
3913   // Q^H dA R^{-1} = Q^HdQ + dRR^{-1}
3914   // where Q^HdQ is skew Hermitian and dRR^{-1} is upper triangular
3915   // Define sym(X) = X + X^H
3916   // sym(dRR^{-1}) = sym(Q^H dA R^{-1})
3917   // and define syminv(X) = triu(X) - 0.5 * diag(X) the inverse of
3918   // sym : Triu(k, diag \in \mathbb{R}) -> Her(k) to give
3919   // dR = syminv(sym(Q^H dA R^{-1}))R
3920   //
3921   // Case m < n
3922   // Put dR as a function of dQ
3923   // dR = Q^H dA - Q^H dQ R
3924   // Let X_1 be the main m x m submatrix of a matrix X \in C^{m x n}
3925   // Q^H dA_1 R_1^{-1} = Q^H dQ + dR_1 R_1^{-1}
3926   // Define trilIm(X) = X.tril(-1) + i * Im diag(X)
3927   // trilIm(Q^H dQ) = trilIm(Q^H dA_1 R_1^{-1})
3928   // and define trilIminv(X) = X - X^H - i*Im diag(X). This is the inverse of
3929   // trilIm : Skew_C(m) -> Tril(m, imaginary diag)
3930   // Note that it is just the inverse when the inputs are skew-Hermitian, not
3931   // necessarily when the inputs are arbitrary matrices. We then get dQ = Q
3932   // trilImInv(trilIm(Q^H dA_1 R_1^{-1}))
3933   at::NoTF32Guard disable_tf32;
3934 
3935   auto [compute_q, reduced] = at::native::_parse_qr_mode(mode);
3936 
3937   TORCH_CHECK(
3938       compute_q,
3939       "The derivative of linalg.qr depends on Q, which is not computed when "
3940       "mode='r'. Please use linalg.qr(A, mode='reduced') if you are "
3941       "going to differentiate through linalg.qr.");
3942   auto m = dA.size(-2);
3943   auto n = dA.size(-1);
3944 
3945   TORCH_CHECK(
3946       reduced || m <= n,
3947       "The QR decomposition is not differentiable when "
3948       "mode='complete' and nrows > ncols.");
3949   if (m >= n) {
3950     const auto sym = [](const Tensor& X) { return X + X.mH(); };
3951     const auto syminv = [](const Tensor& X) {
3952       auto ret = X.triu();
3953       ret.diagonal(0, -2, -1).mul_(0.5);
3954       return ret;
3955     };
3956     auto dARinv =
3957         at::linalg_solve_triangular(R, dA, /*upper=*/true, /*left=*/false);
3958     auto dR = syminv(sym(Q.mH().matmul(dARinv)));
3959     auto dQ = dARinv - Q.matmul(dR);
3960     dR = dR.matmul(R);
3961     return std::make_tuple(std::move(dQ), std::move(dR));
3962   } else {
3963     const auto trilim = [](const Tensor& X) {
3964       if (X.is_complex()) {
3965         auto ret = X.tril();
3966         at::real(ret.diagonal(0, -2, -1)).zero_();
3967         return ret;
3968       } else {
3969         return X.tril(-1);
3970       }
3971     };
3972     const auto triliminv = [](const Tensor& X) {
3973       if (X.is_complex()) {
3974         auto ret = X - X.mH();
3975         ret.diagonal(0, -2, -1).mul_(0.5);
3976         return ret;
3977       } else {
3978         return X - X.mT();
3979       }
3980     };
3981 
3982     auto QHdA = Q.mH().matmul(dA);
3983     auto QHdA1Rinv = at::linalg_solve_triangular(
3984         R.narrow(-1, 0, m),
3985         QHdA.narrow(-1, 0, m),
3986         /*upper=*/true,
3987         /*left=*/false);
3988     auto dQ = triliminv(trilim(QHdA1Rinv));
3989     auto dR = QHdA - dQ.matmul(R);
3990     dQ = Q.matmul(dQ);
3991     return std::make_tuple(std::move(dQ), std::move(dR));
3992   }
3993 }
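// Illustrative sketch (hypothetical helper, not referenced anywhere): the
// tangents returned above satisfy the differentiated factorization
// dA = dQ R + Q dR by construction, which a small numerical check can
// confirm. Shapes and tolerances are assumptions.
[[maybe_unused]] static void linalg_qr_jvp_sketch() {
  const auto opts = at::TensorOptions().dtype(at::kDouble);
  const auto A = at::randn({5, 3}, opts);
  const auto dA = at::randn({5, 3}, opts);
  auto [Q, R] = at::linalg_qr(A, "reduced");
  auto [dQ, dR] = linalg_qr_jvp(dA, Q, R, "reduced");
  TORCH_INTERNAL_ASSERT(
      at::allclose(at::matmul(dQ, R) + at::matmul(Q, dR), dA, 1e-6, 1e-6));
}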
3994 
3995 Tensor linalg_qr_backward(
3996     const Tensor& gQ,
3997     const Tensor& gR,
3998     const Tensor& Q,
3999     const Tensor& R,
4000     const c10::string_view mode) {
4001   // Nb. We won't be too formal below, as writing this proof formally is a pain
4002   // We'll link here a formal writing of all this at some point in the future
4003   //
4004   // Case m >= n
4005   // dQ = dAR^{-1} - Qsyminv(sym(Q^H dA R^{-1}))
4006   // dR = syminv(sym(Q^H dA R^{-1}))R
4007   //
4008   // With the notation from the JVP formula, the only two computations that
4009   // we need are syminv*(R) = 0.5 * (R.triu() + R.triu()^H - Re diag(R))
4010   // and sym*(X) = 2 * X. Using these, after a few simplifications we get
4011   // that gA = (gQ + syminvadj(triu(gR R^H - Q^H gQ))) R^{-H}
4012   //
4013   // Case m < n
4014   // dR = Q^H dA - Q^H dQ R
4015   // dQ = Q trilImInv(trilIm(Q^H dA_1 R_1^{-1}))
4016   //
4017   // In this case trilIm*(X) = X (it's the trivial embedding)
4018   // while trilImInv*(X) = tril(Y) - 0.5 * diag(Y)
4019   // with Y = X - X^H
4020   //
4021   // We also have that if pi(X) = X_1 denotes the map that
4022   // projects X \in C^{m x n} onto its leading m x m submatrix, then
4023   // pi*(X) = cat(X, 0_{m,n-m}, dim=-1)
4024   //
4025   // Using this, we get that
4026   // gA = QgR + pi*(Q trilImInv*(Q^H gQ - gR R^H)R_1^{-H})
4027   at::NoTF32Guard disable_tf32;
4028 
4029   auto [compute_q, reduced] = at::native::_parse_qr_mode(mode);
4030 
4031   TORCH_CHECK(
4032       compute_q,
4033       "The derivative of linalg.qr depends on Q, which is not computed when "
4034       "mode='r'. Please use linalg.qr(A, mode='reduced') if you are "
4035       "going to differentiate through linalg.qr.");
4036 
4037   auto m = Q.sym_size(-2);
4038   auto n = R.sym_size(-1);
4039 
4040   TORCH_CHECK(
4041       reduced || m <= n,
4042       "The QR decomposition is not differentiable when "
4043       "mode='complete' and nrows > ncols.");
4044 
4045   if (!gQ.defined() && !gR.defined()) {
4046     return {};
4047   }
4048 
4049   Tensor gA;
4050   if (gQ.defined()) {
4051     if (gR.defined()) {
4052       gA = gR.matmul(R.mH()) - Q.mH().matmul(gQ);
4053     } else {
4054       gA = -Q.mH().matmul(gQ);
4055     }
4056   } else {
4057     gA = gR.matmul(R.mH());
4058   }
4059   if (m >= n) {
4060     const auto syminvadj = [](const Tensor& X) {
4061       auto ret = X + X.mH();
4062       at::real(ret.diagonal(0, -2, -1)).mul_(0.5);
4063       return ret;
4064     };
4065     gA = Q.matmul(syminvadj(gA.triu()));
4066     if (gQ.defined()) {
4067       gA = gA + gQ;
4068     }
4069     gA = at::linalg_solve_triangular(
4070         R.mH(), gA, /*upper*/ false, /*left*/ false);
4071     return gA;
4072   } else {
4073     auto trilImInvAdjSkew = [](const Tensor& X) {
4074       auto ret = (X - X.mH()).tril();
4075       if (X.is_complex()) {
4076         at::imag(ret.diagonal(0, -2, -1)).mul_(0.5);
4077       }
4078       return ret;
4079     };
4080     gA = Q.matmul(trilImInvAdjSkew(-gA));
4081     gA = at::linalg_solve_triangular(
4082         R.narrow_symint(-1, 0, m).mH(), gA, /*upper*/ false, /*left*/ false);
4083     auto shape = R.sym_sizes().vec();
4084     shape.end()[-1] = n - m;
4085     gA = at::cat({gA, gA.new_zeros_symint(shape)}, /*dim=*/-1);
4086     if (gR.defined()) {
4087       gA = gA + Q.matmul(gR);
4088     }
4089     return gA;
4090   }
4091 }
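// Illustrative sketch (hypothetical helper, not referenced anywhere): the
// backward above was derived as the adjoint of linalg_qr_jvp, so for real
// inputs and arbitrary cotangents it should satisfy
// <gQ, dQ> + <gR, dR> == <gA, dA>. A minimal numerical pairing check:
[[maybe_unused]] static void linalg_qr_backward_sketch() {
  const auto opts = at::TensorOptions().dtype(at::kDouble);
  const auto A = at::randn({5, 3}, opts);
  const auto dA = at::randn({5, 3}, opts);
  auto [Q, R] = at::linalg_qr(A, "reduced");
  auto [dQ, dR] = linalg_qr_jvp(dA, Q, R, "reduced");
  const auto gQ = at::randn({5, 3}, opts);
  const auto gR = at::randn({3, 3}, opts);
  const auto gA = linalg_qr_backward(gQ, gR, Q, R, "reduced");
  const auto lhs = (gQ * dQ).sum() + (gR * dR).sum();
  TORCH_INTERNAL_ASSERT(at::allclose(lhs, (gA * dA).sum(), 1e-6, 1e-6));
}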
4092 
4093 // Based on:
4094 //
4095 // Mathias, Roy.
4096 // A Chain Rule for Matrix Functions and Applications.
4097 // SIAM J. Matrix Anal. Appl. 17 (1996): 610-620.
4098 
4099 template <typename func_t>
4100 Tensor differential_analytic_matrix_function(
4101     const Tensor& self,
4102     const Tensor& grad,
4103     const func_t& matrix_function,
4104     const bool adjoint // Choose between forward (adjoint=false) or backward AD
4105                        // (adjoint=true)
4106 ) {
4107   // Given an analytic matrix function, this computes the differential (forward
4108   // AD) or the adjoint of the differential (backward AD)
4109   auto A = adjoint ? self.transpose(-2, -1).conj() : self;
4110   auto meta_grad_sizes = A.sym_sizes().vec();
4111   meta_grad_sizes[A.dim() - 2] *= 2;
4112   meta_grad_sizes[A.dim() - 1] *= 2;
4113 
4114   auto n = A.sym_size(-1);
4115   Tensor meta_grad;
4116   // For Composite Compliance, we can't copy a Subclass into a Regular Tensor,
4117   // so we use out-of-place ops with equivalent output.
4118   // NOTE: We can't use `new_zeros` directly as both `A` and `grad` can
4119   // be Tensor Subclasses and we don't want to make assumptions about which
4120   // one to choose for creating the output buffer,
4121   // e.g. if both are BatchedTensors at different levels.
4122   if (areAnyTensorSubclassLike({A, grad})) {
4123     meta_grad = at::cat(
4124         {at::cat({A, grad}, -1),
4125          at::cat({at::zeros_like(A), std::move(A)}, -1)},
4126         -2);
4127   } else {
4128     meta_grad = at::zeros_symint(meta_grad_sizes, grad.options());
4129     meta_grad.narrow_symint(-2, 0, n).narrow_symint(-1, 0, n).copy_(A);
4130     meta_grad.narrow_symint(-2, n, n).narrow_symint(-1, n, n).copy_(A);
4131     meta_grad.narrow_symint(-2, 0, n).narrow_symint(-1, n, n).copy_(grad);
4132   }
4133 
4134   return matrix_function(meta_grad).narrow_symint(-2, 0, n).narrow_symint(
4135       -1, n, n);
4136 }
4137 
4138 Tensor linalg_matrix_exp_differential(
4139     const Tensor& self,
4140     const Tensor& grad,
4141     bool adjoint) {
4142   at::NoTF32Guard disable_tf32;
4143 
4144   return differential_analytic_matrix_function(
4145       self, grad, at::linalg_matrix_exp, /* adjoint */ adjoint);
4146 }
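// Illustrative sketch (hypothetical helper, not referenced anywhere): a
// finite-difference check of the Mathias block-matrix trick as it is used
// above for the matrix exponential. The scaling of A and the tolerances are
// assumptions to keep the finite difference well behaved.
[[maybe_unused]] static void linalg_matrix_exp_differential_sketch() {
  const auto opts = at::TensorOptions().dtype(at::kDouble);
  const auto A = 0.1 * at::randn({4, 4}, opts);
  const auto E = at::randn({4, 4}, opts);
  const double eps = 1e-6;
  const auto fd =
      (at::linalg_matrix_exp(A + eps * E) - at::linalg_matrix_exp(A)) / eps;
  const auto dexp = linalg_matrix_exp_differential(A, E, /*adjoint=*/false);
  TORCH_INTERNAL_ASSERT(at::allclose(fd, dexp, 1e-3, 1e-3));
}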
4147 
4148 template <typename F1, typename F2, typename... Ts>
4149 Tensor masked_fmap(
4150     const Tensor& mask,
4151     const F1& f1,
4152     const F2& f2,
4153     const Tensor& t,
4154     const Ts&... ts) {
4155   // This function takes two functions f1 and f2 and a (variadic) list of
4156   // tensors, and creates a new tensor of the same shape as the first element of
4157   // the list of tensors by applying the function f1 to the tensors for which
4158   // the mask is true and f2 to the tensors for which the mask is false. This
4159   // function is used when we have a formula that works for, say, all
4160   // non-singular inputs and another one for when the inputs are singular. See
4161   // for example det_backward
4162 
4163   // Precondition for the n == 0 case to make sense
4164   TORCH_INTERNAL_ASSERT(t.sym_numel() != 0);
4165   auto t_masked = t.index({mask});
4166   auto n = t_masked.sym_numel();
4167   if (n == t.sym_numel()) {
4168     return f1(t, ts...);
4169   } else if (n == 0) {
4170     return f2(t, ts...);
4171   } else {
4172     // Equivalent to
4173     // ret = torch.empty_like(t)
4174     // ret[mask] = f1(t1[mask], ..., tn[mask])
4175     // ret[~mask] = f2(t1[~mask], ..., tn[~mask])
4176     auto not_mask = mask.logical_not();
4177     return at::empty_like(t)
4178         .index_put_({mask}, f1(t_masked, ts.index({mask})...))
4179         .index_put_(
4180             {not_mask}, f2(t.index({not_mask}), ts.index({not_mask})...));
4181   }
4182 }
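// Illustrative usage sketch of masked_fmap (hypothetical helper, kept only as
// an illustration): apply abs() where the mask is true and the identity
// elsewhere; when the mask selects exactly the negative entries the result
// must agree with t.abs() in every branch of masked_fmap.
[[maybe_unused]] static void masked_fmap_sketch() {
  const auto t = at::randn({6}, at::TensorOptions().dtype(at::kDouble));
  const auto mask = t < 0;
  const auto out = masked_fmap(
      mask,
      [](const Tensor& x) { return x.abs(); },
      [](const Tensor& x) { return x; },
      t);
  TORCH_INTERNAL_ASSERT(at::allclose(out, t.abs()));
}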
4183 
4184 Tensor linalg_det_jvp(
4185     const Tensor& dA,
4186     const Tensor& det,
4187     const Tensor& LU,
4188     const Tensor& pivots,
4189     const bool use_A_T) {
4190   // (d det)_A(E) = tr(A^{-1}E)*det
4191   // We use that the determinant is C^1 to approximate the gradient of singular
4192   // inputs. Since we never differentiate over forward AD, we don't need to deal
4193   // with further gradients, as we do in grad_backward
4194   auto eps = at::native::_get_epsilon(c10::toRealValueType(LU.scalar_type()));
4195   auto LU_ =
4196       LU + at::diag_embed(at::where(LU.diagonal(0, -2, -1) == 0., eps, 0.));
4197   auto AinvE =
4198       at::linalg_lu_solve(LU_, pivots, dA, /*left=*/true, /*adjoint=*/use_A_T);
4199   return AinvE.diagonal(0, -2, -1).sum(-1) * det;
4200 }
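// Illustrative sketch (hypothetical helper, not referenced anywhere): a
// finite-difference check of the formula (d det)_A(E) = tr(A^{-1}E) * det
// used above, assuming a comfortably non-singular A.
[[maybe_unused]] static void linalg_det_jvp_sketch() {
  const auto opts = at::TensorOptions().dtype(at::kDouble);
  const auto A = at::randn({4, 4}, opts);
  const auto dA = at::randn({4, 4}, opts);
  const double eps = 1e-6;
  const auto fd = (at::linalg_det(A + eps * dA) - at::linalg_det(A)) / eps;
  auto [LU, pivots] = at::linalg_lu_factor(A);
  const auto jvp =
      linalg_det_jvp(dA, at::linalg_det(A), LU, pivots, /*use_A_T=*/false);
  TORCH_INTERNAL_ASSERT(at::allclose(fd, jvp, 1e-3, 1e-3));
}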
4201 
4202 Tensor linalg_det_backward(
4203     const Tensor& grad,
4204     const Tensor& det,
4205     const Tensor& A,
4206     const Tensor& LU,
4207     const Tensor& pivots) {
4208   at::NoTF32Guard disable_tf32;
4209   // A.numel() == 0 necessary for the singular case
4210   if (!grad.defined() || A.sym_numel() == 0) {
4211     return {};
4212   }
4213 
4214   // The gradient G is the matrix solving
4215   // A.mH G = det(A).conj() * grad * I
4216   auto d_diag = grad * det.conj();
4217   // Optimisation: make it F-transposed, as that is what lu_solve expects
4218   auto d = at::diag_embed(d_diag.unsqueeze(-1).expand_as(pivots)).mT();
4219   auto eps = at::native::_get_epsilon(c10::toRealValueType(LU.scalar_type()));
4220 
4221   // Optimisation if we are not going to compute higher-order gradients
4222   if (!at::GradMode::is_enabled()) {
4223     // The formula is given by the solution of AX = det.conj() * det * I when A
4224     // is invertible. Since det is C^1, if A is not invertible we can apply a
4225     // perturbation to the LU decomposition and use the resulting matrix as a
4226     // non-singular approximation
4227     auto LU_ =
4228         LU + at::diag_embed(at::where(LU.diagonal(0, -2, -1) == 0., eps, 0.));
4229     auto use_A_T = A.is_contiguous() && !A.is_complex();
4230     return at::linalg_lu_solve(
4231         LU_, pivots, d, /*left=*/true, /*adjoint=*/!use_A_T);
4232   } else {
4233     // If we want to compute higher-order gradients, we need to recompute the
4234     // LU decomposition so that autograd computes the correct gradients wrt
4235     // to A (cf. solve_backward)
4236     auto non_singular =
4237         [](const Tensor& A, const Tensor& d, const Tensor& /*grad*/) {
4238           return at::linalg_solve(A.mH(), d);
4239         };
4240 
4241     // The derivative may then be computed explicitly by noting that the
4242     // gradient of the determinant is given in terms of the
4243     // adjugate of a matrix. The adjugate of a singular matrix may be computed
4244     // as per https://nhigham.com/2020/06/16/what-is-the-adjugate-of-a-matrix/
4245     auto singular = [](const Tensor& A,
4246                        const Tensor& /*d*/,
4247                        const Tensor& grad) {
4248       auto [U, S, Vh] = at::linalg_svd(A);
4249       auto alpha = (at::linalg_det(U) * at::linalg_det(Vh)).conj() * grad;
4250       auto D = prod_safe_zeros_backward(alpha.unsqueeze(-1), S, S.dim() - 1);
4251       return (U * D.unsqueeze(-2)).matmul(Vh);
4252     };
4253 
4254     // We could use the singular formula for all inputs but we try to filter out
4255     // some inputs via the masking, as computing an SVD is about 100 times
4256     // slower than computing an lu_solve on GPU
4257     // For tensor subclasses, we can't call masked_fmap as it calls
4258     // index({mask}) which needs to call item to compute the number of elements
4259     // in the result.
4260 
4261     if (areAnyTensorSubclassLike({A, d, grad})) {
4262       return singular(A, d, grad);
4263     } else {
4264       return masked_fmap(
4265           det.abs() < 100. * eps, singular, non_singular, A, d, grad);
4266     }
4267   }
4268 }
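// Illustrative sketch (hypothetical helper, not referenced anywhere): for a
// non-singular real A, Jacobi's formula gives the gradient
// grad * det * A^{-H}, which is what the non-singular path above computes
// (conj() is a no-op for the real dtype used here).
[[maybe_unused]] static void linalg_det_backward_sketch() {
  const auto opts = at::TensorOptions().dtype(at::kDouble);
  const auto A = at::randn({4, 4}, opts);
  const auto det = at::linalg_det(A);
  auto [LU, pivots] = at::linalg_lu_factor(A);
  const auto grad = at::randn({}, opts);
  const auto gA = linalg_det_backward(grad, det, A, LU, pivots);
  const auto expected = grad * det * at::linalg_inv(A).mH();
  TORCH_INTERNAL_ASSERT(at::allclose(gA, expected, 1e-6, 1e-6));
}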
4269 
4270 std::tuple<Tensor, Tensor> slogdet_jvp(
4271     const Tensor& LU,
4272     const Tensor& pivots,
4273     const Tensor& dA,
4274     const Tensor& sign,
4275     const bool use_A_T) {
4276   // No need to handle the singular case separately as we do in det since
4277   // this function is not differentiable on singular matrices
4278   auto trAinvE = at::linalg_lu_solve(LU, pivots, dA, /*left*/ true, use_A_T)
4279                      .diagonal(0, -2, -1)
4280                      .sum(-1);
4281   if (LU.is_complex()) {
4282     auto i = c10::complex<double>{0.0, 1.0};
4283     return std::make_tuple(at::imag(trAinvE) * (i * sign), at::real(trAinvE));
4284   } else {
4285     return std::make_tuple(
4286         at::_efficientzerotensor(sign.sizes(), sign.options()), trAinvE);
4287   }
4288 }
4289 
4290 Tensor slogdet_backward(
4291     const Tensor& grad_sign,
4292     const Tensor& grad_logabsdet,
4293     const Tensor& A,
4294     const Tensor& signdet,
4295     const Tensor& LU,
4296     const Tensor& pivots) {
4297   // We compute the complex case, as the real case follows from it
4298   // Forward AD
4299   // d (logabsdet)_A(E) = Re(tr(A^{-1}E))
4300   // d (signdet)_A(E) = sgn * Im(tr(A^{-1}E)) * i
4301   // So
4302   // d (logabsdet)*_A(g) = gA^{-H}
4303   // Now, to compute the adjoint of d(signdet), note that
4304   // Re(z * Im(w)) = Re(-Re(z)iw)
4305   // So, let g \in C,
4306   // <g, d(signdet)_A(E)> = Re(g.conj() * sgn * i * Im(A^{-1}E))
4307   //                      = Re(Re(g.conj() * sgn * i) * -i * A^{-1}E)
4308   //                      = Re(Im(g.conj() * sgn) * i * A^{-1}E)
4309   //                      = <Im(g.conj() * sgn) * -i * A^{-H}, E>
4310   // As such,
4311   // (d slogabs)*_A(g_sign, g_abs) = (g_abs - g_sign.conj() * sgn) * A^{-H}
4312 
4313   if (!grad_sign.defined() && !grad_logabsdet.defined()) {
4314     return {};
4315   }
4316 
4317   auto is_complex = A.is_complex();
4318 
4319   // In the real case grad_sign is always zero
4320   if (!is_complex && !grad_logabsdet.defined()) {
4321     return {};
4322   }
4323 
4324   auto g = grad_logabsdet;
4325   if (is_complex) {
4326     if (grad_sign.defined()) {
4327       auto i = c10::complex<double>{0.0, 1.0};
4328       if (g.defined()) {
4329         g = g - i * at::imag(grad_sign.conj() * signdet);
4330       } else {
4331         g = -i * at::imag(grad_sign.conj() * signdet);
4332       }
4333     } else {
4334       // Cast to complex explicitly
4335       g = g.to(A.scalar_type());
4336     }
4337   }
4338 
4339   // No need to handle the singular case separately here (as we do in det)
4340   // since this function is not differentiable on singular matrices
4341   // Optimisation: make it F-transposed, as that is what lu_solve expects
4342   auto d = at::diag_embed(g.unsqueeze(-1).expand_as(pivots)).mT();
4343   if (!at::GradMode::is_enabled()) {
4344     auto use_A_T = A.is_contiguous() && !A.is_complex();
4345     return at::linalg_lu_solve(
4346         LU, pivots, d, /*left=*/true, /*adjoint=*/!use_A_T);
4347   } else {
4348     // If we want to compute further gradients, we need to recompute the LU
4349     // decomposition so that autograd computes the correct gradients wrt to A
4350     // (cf. solve_backward)
4351     return at::linalg_solve(A.mH(), d);
4352   }
4353 }
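// Illustrative sketch (hypothetical helper, not referenced anywhere): in the
// real case grad_sign does not contribute (see the comment above), so the
// gradient of logabsdet reduces to grad_logabsdet * A^{-H}.
[[maybe_unused]] static void slogdet_backward_sketch() {
  const auto opts = at::TensorOptions().dtype(at::kDouble);
  const auto A = at::randn({4, 4}, opts);
  const auto signdet = std::get<0>(at::linalg_slogdet(A));
  auto [LU, pivots] = at::linalg_lu_factor(A);
  const auto g_abs = at::randn({}, opts);
  const auto gA = slogdet_backward(Tensor{}, g_abs, A, signdet, LU, pivots);
  TORCH_INTERNAL_ASSERT(
      at::allclose(gA, g_abs * at::linalg_inv(A).mH(), 1e-6, 1e-6));
}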
4354 
4355 // Reference:
4356 // https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
4357 // Sec. 2.3.1 Matrix inverse product
4358 std::tuple<Tensor, Tensor> triangular_solve_backward(
4359     const Tensor& grad_x,
4360     const Tensor& grad_m,
4361     const Tensor& b,
4362     const Tensor& a,
4363     const Tensor& x,
4364     const bool upper,
4365     const bool transpose,
4366     const bool unitriangular,
4367     std::array<bool, 2> output_mask) {
4368   at::NoTF32Guard disable_tf32;
4369   Tensor grad_b, grad_a;
4370   if (grad_x.defined() || grad_m.defined()) {
4371     if (grad_x.defined()) {
4372       grad_b = std::get<0>(
4373           grad_x.triangular_solve(a.conj(), upper, !transpose, unitriangular));
4374       if (output_mask[1]) {
4375         grad_a =
4376             transpose ? -x.conj().matmul(grad_b.mT()) : -grad_b.matmul(x.mH());
4377         if (upper) {
4378           grad_a = grad_a.triu((int)unitriangular);
4379         } else {
4380           grad_a = grad_a.tril(-((int)unitriangular));
4381         }
4382       }
4383     }
4384     if (!grad_a.defined()) {
4385       grad_a = at::zeros({1}, a.options()).expand_as(a);
4386     }
4387     if (!grad_b.defined()) {
4388       grad_b = at::zeros({1}, b.options()).expand_as(b);
4389     }
4390     if (output_mask[1] && grad_m.defined()) {
4391       grad_a = grad_a.add(grad_m);
4392     }
4393   }
4394   return std::tuple<Tensor, Tensor>{grad_b, grad_a};
4395 }
4396 
4397 Tensor triangular_solve_jvp(
4398     const Tensor& X,
4399     const Tensor& A,
4400     const Tensor& dA,
4401     const Tensor& dB,
4402     const bool upper,
4403     const bool transpose,
4404     const bool unitriangular) {
4405   return generic_solve_jvp(
4406       [&](const Tensor& A, const Tensor& dB, const Tensor& dA_contrib) {
4407         return std::get<0>(at::triangular_solve(
4408             dB - dA_contrib, A, upper, transpose, unitriangular));
4409       },
4410       X,
4411       A,
4412       dA,
4413       dB);
4414 }
4415 
4416 Tensor linalg_solve_triangular_forward_AD(
4417     const Tensor& A_t,
4418     const Tensor& B_t,
4419     const Tensor& A,
4420     const Tensor& X,
4421     const bool upper,
4422     const bool left,
4423     const bool unitriangular) {
4424   at::NoTF32Guard disable_tf32;
4425   // The forward AD formula (for left = true) is A^{-1}(B_t - A_tX)
4426   // For the derivation see:
4427   // [Note: Forward / Backward AD solve_triangular]
4428   const Tensor proj_A_t = upper ? A_t.triu(static_cast<int>(unitriangular))
4429                                 : A_t.tril(-static_cast<int>(unitriangular));
4430   const Tensor X_t =
4431       B_t - (left ? at::matmul(proj_A_t, X) : at::matmul(X, proj_A_t));
4432   return at::linalg_solve_triangular(A, X_t, upper, left, unitriangular);
4433 }
4434 
4435 std::tuple<Tensor, Tensor> linalg_solve_triangular_backward(
4436     const Tensor& grad,
4437     const Tensor& A,
4438     const Tensor& X,
4439     const bool upper,
4440     const bool left,
4441     const bool unitriangular,
4442     std::array<bool, 2> output_mask) {
4443   at::NoTF32Guard disable_tf32;
4444   const bool A_requires_grad = output_mask[0];
4445   const bool B_requires_grad = output_mask[1];
4446   // [Note: Forward / Backward AD solve_triangular]
4447   // Assume left=true for simplicity.
4448   // Remark: A solver computes A^{-1}B
4449   //
4450   // Forward AD:
4451   // If f(A) = A^{-1}, differentiating the equation A^{-1}A = I_n gives
4452   // (df)_A(E) = -A^{-1}EA^{-1}
4453   // As such, if g(A,B) = A^{-1}B,
4454   // (dg)_(A,B)(E_A, E_B) = -A^{-1}E_AA^{-1}B + A^{-1}E_B
4455   //                      = A^{-1}(E_B - E_AX)
4456 
4457   // Backward AD:
4458   // Denoting the gradients by G_A, G_B, we solve above to give
4459   // G_B = A^{-H}G_X
4460   // G_A = -A^{-H}G_XX^H = -G_B X^H
4461   //
4462   // Note that you don't need to store B for forward nor backward
4463   //
4464   // These formulas work for a general solver of linear equations.
4465   // Let's prove now that when A is triangular, G_A is the projection onto the
4466   // triangular matrices of the formula above, i.e. simply taking triu (resp.
4467   // tril) in the formula above. This is because, since the triangular matrices
4468   // form a vector space, the tangent space at any point is itself the space of
4469   // triangular matrices. The result follows from a reasoning similar to that
4470   // at the end of [Note: eigh backward]. Something similar happens for
4471   // `unitriangular`, only that in this case the tangent space is the set of
4472   // lower-triangular matrices with zeros on the diagonal.
4473 
4474   if (!grad.defined() || (!A_requires_grad && !B_requires_grad)) {
4475     return std::make_tuple(Tensor{}, Tensor{});
4476   }
4477   // We always need to compute G_B
4478   const Tensor A_H = A.mH();
4479   const Tensor G_B =
4480       at::linalg_solve_triangular(A_H, grad, !upper, left, unitriangular);
4481 
4482   if (A_requires_grad) {
4483     const Tensor X_H = X.mH();
4484     Tensor G_A = left ? -at::matmul(G_B, X_H) : -at::matmul(X_H, G_B);
4485     G_A = upper ? G_A.triu(static_cast<int>(unitriangular))
4486                 : G_A.tril(-static_cast<int>(unitriangular));
4487     return std::make_tuple(G_A, B_requires_grad ? G_B : Tensor{});
4488   } else {
4489     return std::make_tuple(Tensor{}, G_B);
4490   }
4491 }
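// Illustrative sketch (hypothetical helper, not referenced anywhere) pairing
// the forward and backward formulas from the note above for real inputs:
// <grad, X_t> == <G_A, A_t> + <G_B, B_t>, with the projection of A_t onto
// the triangular tangent space happening consistently on both sides. The
// well-conditioned test matrix and tolerances are assumptions.
[[maybe_unused]] static void linalg_solve_triangular_grad_sketch() {
  const auto opts = at::TensorOptions().dtype(at::kDouble);
  const auto A = at::randn({3, 3}, opts).triu() + 3. * at::eye(3, opts);
  const auto B = at::randn({3, 2}, opts);
  const auto A_t = at::randn({3, 3}, opts);
  const auto B_t = at::randn({3, 2}, opts);
  const auto X = at::linalg_solve_triangular(A, B, /*upper=*/true);
  const auto X_t = linalg_solve_triangular_forward_AD(
      A_t, B_t, A, X, /*upper=*/true, /*left=*/true, /*unitriangular=*/false);
  const auto grad = at::randn({3, 2}, opts);
  auto [G_A, G_B] = linalg_solve_triangular_backward(
      grad, A, X, /*upper=*/true, /*left=*/true, /*unitriangular=*/false,
      {true, true});
  const auto lhs = (grad * X_t).sum();
  const auto rhs = (G_A * A_t).sum() + (G_B * B_t).sum();
  TORCH_INTERNAL_ASSERT(at::allclose(lhs, rhs, 1e-6, 1e-6));
}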
4492 
4493 std::tuple<Tensor, Tensor> cholesky_solve_backward(
4494     const Tensor& grad_x,
4495     const Tensor& self,
4496     const Tensor& input2,
4497     const Tensor& result,
4498     const bool upper,
4499     std::array<bool, 2> output_mask) {
4500   at::NoTF32Guard disable_tf32;
4501   Tensor grad_self, grad_input2;
4502   if (grad_x.defined()) {
4503     grad_self = grad_x.cholesky_solve(input2, /*upper=*/upper);
4504 
4505     if (output_mask[1]) {
4506       Tensor common_term = at::matmul(grad_self, result.mH());
4507       common_term = common_term + common_term.mH();
4508 
4509       if (upper) {
4510         grad_input2 = -at::matmul(input2, common_term);
4511       } else {
4512         grad_input2 = -at::matmul(common_term, input2);
4513       }
4514     }
4515   }
4516   return std::tuple<Tensor, Tensor>{grad_self, grad_input2};
4517 }
4518 
4519 Tensor cholesky_solve_jvp(
4520     const Tensor& X,
4521     const Tensor& U,
4522     const Tensor& dU,
4523     const Tensor& dB,
4524     const bool upper) {
4525   at::NoTF32Guard disable_tf32;
4526   auto dK = upper ? dU.mH().matmul(U) : dU.matmul(U.mH());
4527   auto dA = dK + dK.mH();
4528   return generic_solve_jvp(
4529       [&](const Tensor& A, const Tensor& dB, const Tensor& dA_contrib) {
4530         return at::cholesky_solve(dB - dA_contrib, A, upper);
4531       },
4532       X,
4533       /*A=*/U,
4534       dA,
4535       dB);
4536 }
4537 
4538 Tensor fft_c2r_backward(
4539     const Tensor& grad,
4540     IntArrayRef dim,
4541     int64_t normalization) {
4542   // Forward is C2R (onesided)
4543   // Think of onesided C2R irfft as
4544   //    1. fill the other half by conjugate symmetry
4545   //    2. inverse C2C ifft
4546   //    3. discard the complex dimension
4547   // So backward is
4548   //    1. R2C rfft (essentially add dummy complex dimension, and dft)
4549   //    2. accumulate gradient by conjugate symmetry
4550   //       since rfft results follow conjugate symmetry, we only need to
4551   //       double some entries from onesided rfft results, i.e., the ones with
4552   //       their reflected indices also landing out of the onesided range. So
4553   //       consider the index of last dim:
4554   //           i.   idx = 0.
4555   //                Reflected to (N - 0) % N = 0. Not doubled.
4556   //           ii.  0 < idx < floor(N/2) (last).
4557   //                N > N - idx > ceil(N/2), so the reflected index lands
4558   //                outside the onesided range. Doubled.
4559   //           iii. idx = floor(N/2) = N/2 (last) when N even.
4560   //                Reflected to (N - N/2) % N = N/2. Not doubled.
4561   //           iv.  idx = floor(N/2) = (N-1)/2 (last) when N odd.
4562   //                Reflected to (N - (N-1)/2) % N = (N+1)/2. Doubled.
4563   //       Therefore, needs to double
4564   //           idx = 1, 2, ..., N/2 - 1     when N even
4565   //           idx = 1, 2, ..., (N-1)/2     when N odd
4566   //       that is
4567   //           idx = 1, 2, ..., N - (floor(N/2) + 1)
4568   //               = 1, 2, ..., N - onesided_length
4569   auto gI = at::_fft_r2c(grad, dim, normalization, /*onesided=*/true);
4570 
4571   auto double_length = grad.sym_size(dim.back()) - gI.sym_size(dim.back());
4572   if (double_length > 0) { // also covers case when signal size is zero
4573     gI.narrow_symint(dim.back(), 1, double_length).mul_(2);
4574   }
4575   return gI;
4576 }
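// Worked instance of the doubling rule above (illustrative only): for N = 5
// the onesided rfft has floor(5/2) + 1 = 3 entries and double_length =
// 5 - 3 = 2, so idx = 1, 2 are doubled; for N = 4 the onesided length is 3
// and double_length = 1, so only idx = 1 is doubled, since idx = 2 = N/2
// reflects onto itself.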
4577 
4578 Tensor fft_r2c_backward(
4579     const Tensor& grad,
4580     at::IntArrayRef dim,
4581     int64_t normalization,
4582     bool onesided,
4583     const c10::SymInt& last_dim_size) {
4584   if (!onesided) {
4585     return at::real(at::_fft_c2c(grad, dim, normalization, /*forward=*/false));
4586   }
4587 
4588   // Forward is R2C (onesided)
4589   // Think of onesided R2C rfft as
4590   //     1. view as complex numbers (fill complex dim with zeros)
4591   //     2. C2C fft
4592   //     3. discard half of results
4593   // So backward is
4594   //     1. fill the other half with zeros (with `zero_grad_shape` below)
4595   //        (C2C ifft only takes twosided inputs so we need to fill here)
4596   //     2. inverse C2C ifft
4597   //     3. discard the complex dim
4598   auto half_sizes = grad.sym_sizes();
4599   std::vector<c10::SymInt> new_grad_shape(half_sizes.begin(), half_sizes.end());
4600   const auto last_dim =
4601       at::maybe_wrap_dim(dim.back(), static_cast<int64_t>(half_sizes.size()));
4602   new_grad_shape[last_dim] = last_dim_size;
4603 
4604   const auto zero_length = last_dim_size - grad.sym_size(dim.back());
4605   auto complex_full_grad =
4606       zero_length > 0 ? grad.new_zeros_symint(new_grad_shape) : grad;
4607   if (zero_length > 0) {
4608     complex_full_grad.slice_symint(last_dim, 0, half_sizes[last_dim])
4609         .copy_(grad);
4610   }
4611   return at::real(
4612       at::_fft_c2c(complex_full_grad, dim, normalization, /*forward=*/false));
4613 }
4614 
4615 // Helper for batchnorm_double_backward
4616 static Tensor sum_exclude_dim1(const Tensor& to_sum, bool keepdim = true) {
4617   auto r = to_sum.sum(0, keepdim);
4618   int64_t start_point_exclusive = keepdim ? 1 : 0;
4619   for (int64_t dim = r.dim() - 1; dim > start_point_exclusive; dim--) {
4620     r = r.sum(dim, keepdim);
4621   }
4622   return r;
4623 }
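// Illustrative shape sketch of the helper above (hypothetical checker, not
// referenced anywhere): for an (N, C, H, W) input it reduces over every dim
// except dim 1, matching a chained keepdim sum, so the result broadcasts
// back against the input.
[[maybe_unused]] static void sum_exclude_dim1_sketch() {
  const auto x =
      at::randn({2, 3, 4, 5}, at::TensorOptions().dtype(at::kDouble));
  const auto s = sum_exclude_dim1(x);
  TORCH_INTERNAL_ASSERT(s.dim() == 4 && s.size(1) == 3 && s.numel() == 3);
  TORCH_INTERNAL_ASSERT(
      at::allclose(s, x.sum(0, true).sum(2, true).sum(3, true)));
}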
4624 
4625 // Helper for batchnorm_double_backward
4626 // similar to expand_as below, but doesn't do the expand_as; operates as if
4627 // reductions were done with keepdim=True
4628 static Tensor unsqueeze_dim1(const Tensor& src, const Tensor& target) {
4629   auto src_expanded = src;
4630   while (src_expanded.sizes().size() < target.sizes().size() - 1) {
4631     src_expanded = src_expanded.unsqueeze(1);
4632   }
4633   if (src_expanded.sizes().size() == target.sizes().size() - 1) {
4634     src_expanded = src_expanded.unsqueeze(0);
4635   }
4636   return src_expanded;
4637 }
4638 
4639 // Helper for batchnorm_double_backward
4640 // because gamma/ggG/ggB are 1-dimensional and represent dim==1, we can't
4641 // do a straight expansion because it won't follow the broadcasting rules.
4642 static Tensor expand_as_dim1(const Tensor& src, const Tensor& target) {
4643   auto src_expanded = src;
4644   while (src_expanded.sizes().size() < target.sizes().size() - 1) {
4645     src_expanded = src_expanded.unsqueeze(1);
4646   }
4647   return src_expanded.expand_as(target);
4648 }
4649 
4650 std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(
4651     const Tensor& input,
4652     const std::optional<Tensor>& gamma,
4653     const Tensor& ggI,
4654     const Tensor& ggG,
4655     const Tensor& ggB,
4656     const Tensor& gO,
4657     const std::optional<Tensor>& running_mean,
4658     const std::optional<Tensor>& running_var,
4659     bool training,
4660     double eps,
4661     const std::optional<Tensor>& save_mean,
4662     const std::optional<Tensor>& save_invstd,
4663     std::array<bool, 3> output_mask) {
4664   bool affine = isDefined(gamma);
4665   // TODO: Do we have a ScalarOrTensor type?  Would such a thing exist?
4666   Tensor gamma_expanded;
4667   Tensor ggG_expanded, ggB_expanded;
4668   if (affine) {
4669     // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
4670     gamma_expanded = expand_as_dim1(*gamma, input);
4671     if (ggG.defined()) {
4672       ggG_expanded = expand_as_dim1(ggG, input);
4673     }
4674     if (ggB.defined()) {
4675       ggB_expanded = expand_as_dim1(ggB, input);
4676     }
4677   } else {
4678     gamma_expanded = at::ones({}, input.options());
4679   }
4680 
4681   // define some terms we will reuse
4682   auto M = input.size(0);
4683   for (auto s : input.sizes().slice(2)) {
4684     M *= s;
4685   }
4686   // for half inputs, save_mean, save_invstd are float (ideally, we would cast
4687   // everything else, but not now)
4688   auto mu = unsqueeze_dim1(
4689       training ? toNonOptTensor(save_mean).to(input.scalar_type())
4690                : toNonOptTensor(running_mean),
4691       input);
4692   auto input_sub_mu = input - mu;
4693   auto sigma2_eps_neg_1_2 = unsqueeze_dim1(
4694       training ? toNonOptTensor(save_invstd).to(input.scalar_type())
4695                : toNonOptTensor(running_var).add(Scalar(eps)).pow(-0.5),
4696       input);
4697   auto sigma2_eps_neg_1 = sigma2_eps_neg_1_2.pow(2);
4698   auto sigma2_eps_neg_3_2 = sigma2_eps_neg_1_2.pow(3);
4699 
4700   // calculate gI
4701   auto input_mu_sigma2_neg_3_2 = input_sub_mu * sigma2_eps_neg_3_2;
4702   auto gOinmu_sum = sum_exclude_dim1(gO * input_sub_mu);
4703   auto gO_sum = sum_exclude_dim1(gO);
4704 
4705   Tensor gI;
4706   if (ggI.defined() && training) {
4707     auto ggI_sum = sum_exclude_dim1(ggI);
4708     auto ggIinmu_sum = sum_exclude_dim1(ggI * input_sub_mu);
4709     auto all_sub = ((ggI_sum * gO_sum).div_(M))
4710                        .sub_(sum_exclude_dim1(gO * ggI))
4711                        .add_((sigma2_eps_neg_1 * gOinmu_sum * ggIinmu_sum)
4712                                  .mul_(3. / static_cast<double>(M)));
4713     auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(M);
4714     auto gI_1t =
4715         (ggIinmu_sum * sigma2_eps_neg_3_2).div_(M) * (gO_sum.div(M) - gO);
4716     auto gI_2t =
4717         (gOinmu_sum * sigma2_eps_neg_3_2).div_(M) * (ggI_sum.div(M) - ggI);
4718     gI = gamma_expanded * (gI_0t.add_(gI_1t).add_(gI_2t));
4719   }
4720 
4721   // add contribution of gamma term to gI
4722   Tensor gI_G_term;
4723   if (affine && ggG.defined()) {
4724     if (training) {
4725       auto t0 = gO * sigma2_eps_neg_1_2;
4726       auto t1 = (sigma2_eps_neg_1_2 * gO_sum).div_(-M);
4727       auto t2 = (input_mu_sigma2_neg_3_2 * sum_exclude_dim1(gO * input_sub_mu))
4728                     .div_(-M);
4729       gI_G_term = ggG_expanded * (t0.add_(t1).add_(t2));
4730       gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term;
4731     } else {
4732       gI_G_term = ggG_expanded * sigma2_eps_neg_1_2 * gO;
4733       gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term;
4734     }
4735   }
4736 
4737   // this is the first backward's grad_input
4738   auto first_back_grad_input = [&](const Tensor& gO,
4739                                    const Tensor& gamma) -> Tensor {
4740     auto h0 = (gamma * sigma2_eps_neg_1_2).div_(M);
4741     auto h1 = (M * gO)
4742                   .sub_(sum_exclude_dim1(gO))
4743                   .sub_(
4744                       input_sub_mu.mul(sigma2_eps_neg_1) *
4745                       sum_exclude_dim1(gO * input_sub_mu));
4746     return h0 * h1;
4747   };
4748 
4749   // calculate gG
4750   Tensor gG;
4751   if (affine && ggI.defined()) {
4752     if (training) {
4753       // gG is just the first backwards with the gamma term removed (then shaped
4754       // properly)
4755       gG = ggI *
4756           first_back_grad_input(gO, at::ones({}, sigma2_eps_neg_1_2.options()));
4757       gG = sum_exclude_dim1(gG, false);
4758     } else {
4759       gG = sum_exclude_dim1(ggI * gO * sigma2_eps_neg_1_2, false);
4760     }
4761   }
4762 
4763   // calculate ggO
4764   Tensor ggO;
4765   // contribution of input term
4766   if (ggI.defined()) {
4767     if (training) {
4768       ggO = first_back_grad_input(ggI, gamma_expanded);
4769     } else {
4770       ggO = ggI * sigma2_eps_neg_1_2 * gamma_expanded;
4771     }
4772   }
4773   if (ggG.defined()) {
4774     auto ggO_G_term = ggG_expanded * input_sub_mu * sigma2_eps_neg_1_2;
4775     ggO = ggO.defined() ? ggO.add_(ggO_G_term) : ggO_G_term;
4776   }
4777   if (ggB.defined()) {
4778     auto ggO_B_term = std::move(ggB_expanded);
4779     ggO = ggO.defined() ? ggO.add_(ggO_B_term) : ggO_B_term;
4780   }
4781 
4782   if (output_mask[1] && !gG.defined()) {
4783     AT_ASSERTM(affine, "gamma should always be defined when it requires grad");
4784   }
4785 
4786   return std::tuple<Tensor, Tensor, Tensor>{gI, gG, ggO};
4787 }
4788 
4789 std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
4790     const Tensor& input_t,
4791     const std::optional<Tensor>& gamma,
4792     const Tensor& ggI,
4793     const Tensor& ggG,
4794     const Tensor& ggB,
4795     const Tensor& gO_t,
4796     const Tensor& save_mean_t,
4797     const Tensor& save_invstd_t,
4798     c10::SymIntArrayRef normalized_shape,
4799     std::array<bool, 3> output_mask) {
4800   const auto normalized_ndim = normalized_shape.size();
4801   const auto input_shape = input_t.sizes();
4802   const auto input_ndim = input_t.dim();
4803   const auto axis = input_ndim - normalized_ndim;
4804   const int64_t M =
4805       c10::multiply_integers(input_shape.cbegin(), input_shape.cbegin() + axis);
4806   const int64_t N =
4807       c10::multiply_integers(input_shape.cbegin() + axis, input_shape.cend());
4808   // printf("M: %ld, N: %ld", M, N);
4809 
4810   auto input = input_t.reshape({M, N});
4811   auto gO = gO_t.reshape({M, N});
4812   auto save_mean = save_mean_t.reshape({M, 1});
4813   auto save_invstd = save_invstd_t.reshape({M, 1});
4814 
4815   bool affine = isDefined(gamma);
4816   Tensor gamma_expanded;
4817   Tensor ggG_expanded, ggB_expanded;
4818   if (affine) {
4819     // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
4820     gamma_expanded = gamma->reshape({1, N});
4821     if (ggG.defined()) {
4822       ggG_expanded = ggG.reshape({1, N});
4823     }
4824     if (ggB.defined()) {
4825       ggB_expanded = ggB.reshape({1, N});
4826     }
4827   } else {
4828     gamma_expanded = at::ones({1}, input.options());
4829   }
4830 
4831   Tensor ggI_expanded;
4832   if (ggI.defined()) {
4833     ggI_expanded = ggI.reshape({M, N});
4834   }
4835 
4836   // for half inputs, save_mean, save_invstd are float
4837   // (ideally, we would cast everything else, but not now)
4838   auto mu = save_mean.to(input.scalar_type());
4839   auto input_sub_mu = input - mu;
4840   auto sigma2_eps_neg_1_2 = save_invstd.to(input.scalar_type());
4841   auto sigma2_eps_neg_1 = sigma2_eps_neg_1_2.pow(2);
4842   auto sigma2_eps_neg_3_2 = sigma2_eps_neg_1_2.pow(3);
4843 
4844   Tensor gI;
4845   // calculate gI
4846   auto input_mu_sigma2_neg_3_2 = input_sub_mu * sigma2_eps_neg_3_2;
4847 
4848   if (ggI.defined()) {
4849     auto gxhat = gO * gamma_expanded;
4850     auto gxhat_mu_sum = (gxhat * input_sub_mu).sum(1, true);
4851     auto gxhat_sum = gxhat.sum(1, true);
4852 
4853     auto ggI_sum = ggI_expanded.sum(1, true);
4854     auto ggI_mu_sum = (ggI_expanded * input_sub_mu).sum(1, true);
4855 
4856     auto all_sub = ((ggI_sum * gxhat_sum).div_(N))
4857                        .sub_((ggI_expanded * gxhat).sum(1, true))
4858                        .add_((sigma2_eps_neg_1 * gxhat_mu_sum * ggI_mu_sum)
4859                                  .mul_(3. / static_cast<double>(N)));
4860     auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(N);
4861     auto gI_1t =
4862         (ggI_mu_sum * sigma2_eps_neg_3_2).div_(N) * (gxhat_sum.div(N) - gxhat);
4863     auto gI_2t = (gxhat_mu_sum * sigma2_eps_neg_3_2).div_(N) *
4864         (ggI_sum.div(N) - ggI_expanded);
4865 
4866     gI = (gI_0t.add_(gI_1t).add_(gI_2t));
4867   }
4868 
4869   // add contribution of gamma term to gI
4870   if (affine && ggG.defined()) {
4871     auto t0 = gO * ggG_expanded * sigma2_eps_neg_1_2;
4872     auto t1 = (sigma2_eps_neg_1_2 * (gO * ggG_expanded).sum(1, true)).div_(-N);
4873     auto t2 = (input_mu_sigma2_neg_3_2 *
4874                (gO * ggG_expanded * input_sub_mu).sum(1, true))
4875                   .div_(-N);
4876     auto gI_G_term = t0.add_(t1).add_(t2);
4877     gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term;
4878   }
4879 
4880   if (gI.defined()) {
4881     // printf("=== computing gI\n");
4882     gI = gI.reshape_as(input_t);
4883   }
4884 
4885   // this is the grad_input for the first backward function
4886   auto first_bwd_fn_grad_input = [&](const Tensor& gO_local,
4887                                      const Tensor& gamma_local) -> Tensor {
4888     auto h0 = (gamma_local * sigma2_eps_neg_1_2).div_(N);
4889     auto h1 = (N * gO_local)
4890                   .sub_(gO_local.sum(1, true))
4891                   .sub_(
4892                       input_sub_mu.mul(sigma2_eps_neg_1) *
4893                       (gO_local * input_sub_mu).sum(1, true));
4894     return h0 * h1;
4895   };
4896 
4897   // calculate gG
4898   Tensor gG;
4899   if (affine && ggI.defined()) {
4900     gG = first_bwd_fn_grad_input(
4901         ggI_expanded, at::ones({}, sigma2_eps_neg_1_2.options()));
4902     gG = (gO * gG).sum(0);
4903     // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
4904     gG = gG.reshape_as(*gamma);
4905   }
4906 
4907   // calculate ggO
4908   Tensor ggO;
4909   // contribution of input term
4910   if (ggI.defined()) {
4911     ggO = first_bwd_fn_grad_input(ggI_expanded, gamma_expanded);
4912   }
4913   if (ggG.defined()) {
4914     auto ggO_G_term = ggG_expanded * input_sub_mu * sigma2_eps_neg_1_2;
4915     ggO = ggO.defined() ? ggO.add_(ggO_G_term) : ggO_G_term;
4916   }
4917   if (ggB.defined()) {
4918     auto ggO_B_term = std::move(ggB_expanded);
4919     ggO = ggO.defined() ? ggO.add_(ggO_B_term) : ggO_B_term;
4920   }
4921   if (ggO.defined()) {
4922     ggO = ggO.expand({M, N}).reshape_as(input_t);
4923   }
4924 
4925   if (output_mask[1] && !gG.defined()) {
4926     AT_ASSERTM(affine, "gamma should always be defined when it requires grad");
4927   }
4928 
4929   return std::tuple<Tensor, Tensor, Tensor>{gI, gG, ggO};
4930 }
4931 
4932 std::tuple<Tensor, Tensor, Tensor>
4933 infinitely_differentiable_native_group_norm_backward(
4934     const Tensor& dY,
4935     const Tensor& dmean,
4936     const Tensor& drstd,
4937     const Tensor& X,
4938     const Tensor& mean,
4939     const Tensor& rstd,
4940     const std::optional<Tensor>& gamma,
4941     c10::SymInt N,
4942     const c10::SymInt& C,
4943     c10::SymInt HxW,
4944     int64_t group,
4945     double eps,
4946     std::array<bool, 3> grad_input_mask) {
4947   const int64_t G = group;
4948   const auto D = C / G;
4949   c10::SymFloat s = c10::SymFloat(1.0) / c10::SymFloat(D * HxW);
4950   Tensor dX;
4951   Tensor dgamma;
4952   Tensor dbeta;
4953   const Tensor X_tensor = X.reshape_symint({N, G, D, HxW});
4954   const Tensor mean_tensor = mean.reshape_symint({N, G, 1, 1});
4955   const Tensor rstd_tensor = rstd.reshape_symint({N, G, 1, 1});
4956   Tensor dY_tensor;
4957   Tensor ds;
4958   Tensor db;
4959   if (dY.defined()) {
4960     dY_tensor = dY.reshape_symint({N, G, D, std::move(HxW)});
4961     ds = (dY_tensor * X_tensor).sum(3).unsqueeze_(-1);
4962     db = dY_tensor.sum(3).unsqueeze_(-1);
4963   }
4964   if (grad_input_mask[0]) {
4965     Tensor gamma_tensor;
4966     if (isDefined(gamma)) {
4967       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
4968       gamma_tensor = gamma->reshape_symint({1, G, D, 1});
4969     }
4970     const Tensor var =
4971         ((rstd_tensor * rstd_tensor).reciprocal_() - eps).clamp_min(0);
4972     const Tensor rstd_cube = rstd_tensor * rstd_tensor * rstd_tensor;
4973     Tensor dvar;
4974     if (drstd.defined()) {
4975       dvar = -0.5 * rstd_cube * drstd.view_symint({N, G, 1, 1});
4976     }
4977     if (dY.defined()) {
4978       const Tensor a =
4979           isDefined(gamma) ? rstd_tensor * gamma_tensor : rstd_tensor;
4980       Tensor b = (isDefined(gamma) ? (ds * gamma_tensor).sum(2) : ds.sum(2))
4981                      .unsqueeze_(-2);
4982       Tensor c = (isDefined(gamma) ? (db * gamma_tensor).sum(2) : db.sum(2))
4983                      .unsqueeze_(-2);
4984       b = (c * mean_tensor - b) * rstd_cube * s;
4985       c = -b * mean_tensor - c * rstd_tensor * std::move(s);
4986       dX = a * dY_tensor + b * X_tensor + c;
4987       if (dmean.defined() && drstd.defined()) {
4988         dX += var_mean_backward(
4989             dvar,
4990             dmean.view_symint({std::move(N), G, 1, 1}),
4991             X_tensor,
4992             IntArrayRef{2, 3},
4993             0,
4994             true);
4995       }
4996       dX = dX.reshape_as(X);
4997     } else if (dmean.defined() && drstd.defined()) {
4998       dX = var_mean_backward(
4999                dvar,
5000                dmean.view_symint({std::move(N), G, 1, 1}),
5001                X_tensor,
5002                IntArrayRef{2, 3},
5003                0,
5004                true)
5005                .reshape_as(X);
5006     }
5007   }
5008   if (grad_input_mask[1] && dY.defined()) {
5009     dgamma = ((ds - db * mean_tensor) * rstd_tensor)
5010                  .sum(0)
5011                  .reshape_as(toNonOptTensor(gamma));
5012   }
5013   if (grad_input_mask[2] && dY.defined()) {
5014     dbeta = db.sum(0).reshape_as(toNonOptTensor(gamma));
5015   }
5016 
5017   return std::make_tuple(dX, dgamma, dbeta);
5018 }
5019 
5020 std::tuple<Tensor, Tensor, Tensor> _trilinear_backward(
5021     const Tensor& grad_out,
5022     const std::optional<Tensor>& i1,
5023     const std::optional<Tensor>& i2,
5024     const std::optional<Tensor>& i3,
5025     IntArrayRef expand1,
5026     IntArrayRef expand2,
5027     IntArrayRef expand3,
5028     IntArrayRef sumdim,
5029     std::array<bool, 3> grad_mask) {
5030   Tensor grad_i1, grad_i2, grad_i3;
5031   if (grad_out.defined()) {
5032     if (grad_mask[0])
5033       grad_i1 =
5034           // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
5035           at::_trilinear(grad_out, *i2, *i3, sumdim, expand2, expand3, expand1);
5036     if (grad_mask[1])
5037       grad_i2 =
5038           // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
5039           at::_trilinear(*i1, grad_out, *i3, expand1, sumdim, expand3, expand2);
5040     if (grad_mask[2])
5041       grad_i3 =
5042           // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
5043           at::_trilinear(*i1, *i2, grad_out, expand1, expand2, sumdim, expand3);
5044   }
5045   return std::tuple<Tensor, Tensor, Tensor>(grad_i1, grad_i2, grad_i3);
5046 }
5047 
5048 Tensor log1p_backward(const Tensor& grad, const Tensor& self) {
5049   // We must conditionally initialize this using to_dense if self is sparse,
5050   // since sparse addition is not supported without an exact shape match
5051   Tensor self_p1_conj;
5052   if (self.layout() == c10::kSparse || self.layout() == c10::kSparseCsr ||
5053       self.layout() == c10::kSparseCsc || self.layout() == c10::kSparseBsr ||
5054       self.layout() == c10::kSparseBsc) {
5055     // The warning only applies to the sparsity of self, dense grad is never
5056     // materialized so if self is strided and grad is sparse nothing unexpected
5057     // happens memory wise
5058     TORCH_WARN(
5059         "log1p_backward: received self with sparse layout, but backward requires materialization of a dense tensor with this shape");
5060     self_p1_conj = (self.to_dense() + 1).conj();
5061   } else {
5062     // Although calling self.to_dense() would just return self when it has
5063     // strided layout, that would break functorch tests.
5064     self_p1_conj = (self + 1).conj();
5065   }
5066   if (grad.layout() == c10::kSparse || grad.layout() == c10::kSparseCsr ||
5067       grad.layout() == c10::kSparseCsc || grad.layout() == c10::kSparseBsr ||
5068       grad.layout() == c10::kSparseBsc) {
5069     // If grad is sparse we can't divide by the n-d (self + 1).conj(), so we
5070     // must multiply by the reciprocal instead; the layout of grad is
5071     // preserved, which is important for gradcheck
5072     return grad * self_p1_conj.reciprocal_();
5073   }
5074   return grad / self_p1_conj;
5075 }
5076 
5077 Tensor sinc_backward(const Tensor& grad, const Tensor& self) {
5078   auto self_pi = self * M_PI;
5079   auto self_squared_pi = self * self * M_PI;
5080   auto out = grad *
5081       ((self_pi * self_pi.cos() - self_pi.sin()) / self_squared_pi).conj();
5082   return at::where(self_squared_pi == 0.0, at::zeros({}, grad.options()), out);
5083 }
5084 
5085 // Because the backward of pad(input, pads) is just pad(grad_output, [-p for p
5086 // in pads])
5087 Tensor constant_pad_nd_backward(const Tensor& grad, c10::SymIntArrayRef pad) {
5088   auto negated_pad = pad.vec();
5089   std::transform(
5090       negated_pad.cbegin(),
5091       negated_pad.cend(),
5092       negated_pad.begin(),
5093       // NOLINTNEXTLINE(modernize-use-transparent-functors)
5094       std::negate<c10::SymInt>());
5095   return at::constant_pad_nd_symint(grad, negated_pad, 0);
5096 }
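// Illustrative round-trip sketch of the negated-pads rule above (hypothetical
// helper, not referenced anywhere): padding the last dim of a (1, 3) tensor
// by (2, 3) grows it to (1, 8), and the backward crops the incoming gradient
// back to the original shape, acting as the identity on the unpadded region.
[[maybe_unused]] static void constant_pad_nd_backward_sketch() {
  const auto x = at::randn({1, 3}, at::TensorOptions().dtype(at::kDouble));
  const std::vector<c10::SymInt> pads = {2, 3};
  const auto y = at::constant_pad_nd_symint(x, pads, 0);
  const auto gy = at::ones_like(y);
  const auto gx = constant_pad_nd_backward(gy, pads);
  TORCH_INTERNAL_ASSERT(gx.sizes().equals(x.sizes()));
  TORCH_INTERNAL_ASSERT(at::allclose(gx, at::ones_like(x)));
}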
5097 
5098 Tensor embedding_dense_double_backward_symint(
5099     const Tensor& grad,
5100     const Tensor& indices,
5101     const c10::SymInt& padding_idx) {
5102   // since first backward takes care of scaling by frequency,
5103   // we don't need to worry about it here.
5104   auto gg_weight = grad.index_select(0, indices.reshape(-1));
5105 
5106   // reshape gradient as per the shape of indices
5107   auto size = indices.sizes().vec();
5108   size.push_back(-1);
5109 
5110   if (padding_idx >= 0) {
5111     gg_weight.masked_fill_((indices == padding_idx).reshape({-1, 1}), 0);
5112   }
5113   return gg_weight.view(size);
5114 }
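// Illustrative shape/selection sketch of the helper above (hypothetical
// checker, not referenced anywhere): with padding_idx = -1 (no padding row)
// the result is simply grad gathered row-wise by indices and reshaped to
// indices.sizes() + (embedding_dim,).
[[maybe_unused]] static void embedding_dense_double_backward_sketch() {
  const auto opts = at::TensorOptions().dtype(at::kDouble);
  const auto grad = at::randn({10, 4}, opts); // (num_weights, embedding_dim)
  const auto indices = at::randint(10, {3});
  const auto gg = embedding_dense_double_backward_symint(
      grad, indices, /*padding_idx=*/-1);
  TORCH_INTERNAL_ASSERT(gg.sizes().equals({3, 4}));
  TORCH_INTERNAL_ASSERT(at::allclose(gg, grad.index_select(0, indices)));
}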
5115 
5116 Tensor index_backward(
5117     Tensor zeros_like_self,
5118     const torch::List<std::optional<Tensor>>& indices,
5119     const Tensor& grad) {
5120   return (areAnyTensorSubclassLike({zeros_like_self, grad}) ||
5121           areAnyOptionalTensorSubclassLike(indices))
5122       ? zeros_like_self.index_put(indices, grad, true)
5123       : at::_index_put_impl_(zeros_like_self, indices, grad, true, true);
5124 }
5125 
5126 Tensor _cudnn_ctc_loss_backward(
5127     const Tensor& grad_out,
5128     const Tensor& loss,
5129     const Tensor& raw_grad,
5130     bool zero_infinity) {
5131   if (zero_infinity) {
5132     return at::where(
5133         loss.unsqueeze(0).unsqueeze(2) == 0,
5134         at::zeros({}, raw_grad.options()),
5135         raw_grad * grad_out.unsqueeze(0).unsqueeze(2));
5136   } else {
5137     return raw_grad * grad_out.unsqueeze(0).unsqueeze(2);
5138   }
5139 }
5140 
5141 bool any_variable_defined(const variable_list& variables) {
5142   for (const auto& variable : variables) {
5143     if (variable.defined()) {
5144       return true;
5145     }
5146   }
5147   return false;
5148 }
5149 
5150 // Derivations for the householder_product.backward method.
5151 //
5152 // Given a sequence of vectors v_1, ..., v_n and a sequence of scalars tau_1,
5153 // ..., tau_k, the torch.linalg.householder_product computes the first n columns
5154 // of the following product: Q = (I - tau_1 v_1 v_1^H) ... (I - tau_k v_k
5155 // v_k^H). Let
5156 //     H_i(sigma) := I - sigma v_i v_i^H, so Q = (H_1(tau_1) ...
5157 //     H_k(tau_k))[:, :k]; H_i_minus = H_1(tau_1) ... H_{i - 1}(tau_{i - 1}),
5158 //     with H_1_minus := I; H_i_plus = H_{i + 1}(tau_{i + 1}) ... H_k(tau_k)
5159 //     with H_k_plus := I;
5160 //
5161 // Forward AD:
5162 // dQ = sum_{i = 1}^k H_i_minus (-dtau_i v_i v_i^H - tau_i dv_i v_i^H - tau_i
5163 // v_i dv_i^H) H_i_plus.
5164 //
5165 // Backward AD:
5166 // Tr(Q_grad^H dQ) = sum_{i = 1}^k Tr(H_i_plus Q_grad^H H_i_minus (-dtau_i v_i
5167 // v_i^H - tau_i dv_i v_i^H - tau_i v_i dv_i^H)). Let K_i := H_i_plus Q_grad^H
5168 // H_i_minus, then the gradients are v_i_grad = (-tau_i v_i^H K_i)^H - tau_i K_i
5169 // v_i, tau_i_grad = Tr(-v_i^H K_i v_i).conj(). NOTE: the algorithm ignores
5170 // that only n columns of Q are observed, so there is no need to recompute Q
5171 // to full completion.
5172 //
5173 // Note that K_{i + 1} = H_{i + 1}^{-1} K_i H_i, so we can compute v_i_grad,
5174 // tau_i_grad one by one by just efficiently updating K_i if that is possible.
5175 // Multiplying with H_i from the right could be done with matrix-vector
5176 // products, but what about the inverse H_{i + 1}^{-1} and does it even exist?
5177 // Luckily, under some assumptions, H_{i + 1}^{-1} exists and admits a
5178 // representation as H_i(sigma_i) for some sigma_i, so the left update can
5179 // also be done with matrix-vector and not matrix-matrix products.
5180 //
5181 // Let H(tau) := I - tau v v^H.
5182 // H(tau) has eigenvalues 1 with multiplicity (m - 1) with eigenvectors
5183 // orthogonal to v, and an eigenvalue (1 - tau ||v||^2) with the corresponding
5184 // eigenvector v / ||v||. If (1 - tau ||v||^2) != 0, H(tau) is invertible. If (1
5185 // - tau ||v||^2) != 0, then with sigma defined as sigma := tau / (||v||^2 tau -
5186 // 1) we get that H(tau) H(sigma) = H(sigma) H(tau) = I, so H(sigma) is the
5187 // inverse of H(tau).
5188 //
5189 // WARNING: the algorithm below assumes that H_i(tau_i) are all invertible, so
5190 // it expects that (1 - tau_i ||v_i||^2) != 0 for all i.
5191 // We would like to point out that if there is H_i(tau_i) which is not
5192 // invertible, the householder_product is still differentiable! We will not be
5193 // able to compute K_i efficiently in such cases, however, as evaluating each
5194 // K_i would amount to calls to ORGQR to be able to compute H_i_plus.
5195 
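// Illustrative numerical sketch of the inverse claim above (hypothetical
// helper, not referenced anywhere): with sigma = tau / (||v||^2 tau - 1),
// H(sigma) is the inverse of H(tau) whenever 1 - tau ||v||^2 != 0. The fixed
// tau value is an assumption chosen to stay away from that singularity.
[[maybe_unused]] static void householder_inverse_sketch() {
  const auto opts = at::TensorOptions().dtype(at::kDouble);
  const auto v = at::randn({5, 1}, opts);
  const auto tau = at::scalar_tensor(0.5, opts);
  const auto I = at::eye(5, opts);
  const auto H = [&](const Tensor& t) { return I - t * v.matmul(v.mH()); };
  const auto vnorm2 = v.mH().matmul(v); // (1, 1), broadcasts below
  const auto sigma = tau / (vnorm2 * tau - 1);
  TORCH_INTERNAL_ASSERT(at::allclose(H(tau).matmul(H(sigma)), I, 1e-6, 1e-6));
}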
5196 // This function computes either the product between
5197 // (I - tau u v^H) and K (in-place or not) with `condition_with_I = true`, or
5198 // between
5199 // (-tau u v^H) and K (out-of-place only) with `condition_with_I = false`.
5200 // Parameter `left` controls whether the matrix K is multiplied from the left or
5201 // from the right.
5202 // Additionally, when the computation is done in-place, we exploit that the
5203 // first `k` coordinates of `u_full/v_full` are zeros.
5204 static Tensor apply_simple_transformation(
5205     const c10::SymInt& m,
5206     const c10::SymInt& k,
5207     const Tensor& u_full,
5208     const Tensor& v_full,
5209     const Tensor& t,
5210     Tensor& K,
5211     bool modify_K_in_place = true,
5212     bool condition_with_I = true,
5213     bool left = true) {
5214   // we assume u_full is a vector of dimension (..., m, 1), t is a scalar of
5215   // dimension (..., 1)
5216 
5217   // TODO: matrix-vector products in the code below are dispatched to
5218   // matrix-matrix products. We either need to extend matmul to support batched
5219   // matrix-vector products, or implement a batched variant of mv. We could
5220   // enable mv for inputs which are not batched, but it is not done to eliminate
5221   // the code duplication.
5222 
5223   // returns (I - t u v^H) K or -t u v^H K
5224   if (left) {
5225     if (modify_K_in_place) {
5226       auto v = u_full.narrow_symint(-2, k, m - k);
5227       auto u = v_full.narrow_symint(-2, k, m - k)
5228                    .mH()
5229                    .matmul(K.narrow_symint(-2, k, m - k));
5230       K.narrow_symint(-2, k, m - k).sub_((t.unsqueeze(-1) * v) * u);
5231       return K;
5232     } else {
5233       auto transformation = (t.unsqueeze(-1) * u_full) * v_full.mH().matmul(K);
5234       return condition_with_I ? K - transformation : -transformation;
5235     }
5236   }
5237   // returns K (I - t u v^H) or -K t u v^H
5238   else {
5239     if (modify_K_in_place) {
5240       auto v = u_full.narrow_symint(-2, k, m - k);
5241       auto u =
5242           K.narrow_symint(-1, k, m - k)
5243               .matmul(t.unsqueeze(-1) * v_full.narrow_symint(-2, k, m - k));
5244       K.narrow_symint(-1, k, m - k).sub_(u * v.mH());
5245       return K;
5246     } else {
5247       auto transformation = K.matmul(t.unsqueeze(-1) * u_full) * v_full.mH();
5248       return condition_with_I ? K - transformation : -transformation;
5249     }
5250   }
5251 };
5252 
householder_product_backward(const Tensor & grad,const Tensor & result,const Tensor & input_,const Tensor & tau,const bool flip_order)5253 std::tuple<Tensor, Tensor> householder_product_backward(
5254     const Tensor& grad,
5255     const Tensor& result,
5256     const Tensor& input_,
5257     const Tensor& tau,
5258     const bool flip_order) {
5259   // NOTE on `flip_order`: when flip_order is true,
5260   // the algorithm below reverses the processing direction from
5261   // range(k) to range(k - 1, -1, -1) in the main loop, and left/right
5262   // Householder projection applications get flipped.
5263   // The comments below about the algorithmic details assume flip_order = false.
5264   if (!grad.defined() || input_.sym_numel() == 0 || tau.sym_numel() == 0) {
5265     return std::tuple<Tensor, Tensor>(Tensor(), Tensor());
5266   }
5267   auto m = input_.sym_size(-2);
5268   // guard_int is due to irange calls below
5269   auto k = tau.sym_size(-1).guard_int(__FILE__, __LINE__);
5270 
5271   // forward operates only over the lower triangular part with the assumption
5272   // that the diagonal of input is filled with 1s.
5273   auto input = input_.tril(-1);
5274   input.diagonal(0, -2, -1).fill_(1.0);
5275 
5276   // compute sigma such that
5277   // H(sigma_i) == H(tau_i)^{-1}.
5278   // If the input to householder_product comes from GEQRF,
5279   // we will never encounter ||v_i||^2 tau_i == 1, so H(tau_i) will always be
5280   // invertible. This follows from the documentation
5281   // https://www.netlib.org/lapack/lug/node128.html, and tau always satisfying
5282   // the condition |tau|^2 ||v||^2 == 2 * Re(tau).
5283   auto input_first_k_cols = input.narrow(-1, 0, k);
5284   auto input_first_k_cols_norm_squared =
5285       (input_first_k_cols * input_first_k_cols.conj()).sum(-2);
5286   auto sigma = tau / (tau * input_first_k_cols_norm_squared - 1.0);
5287 
5288   auto K = result.matmul(grad.mH());
5289 
5290   // The algorithm updates K by multiplying it from the left/right with
5291   // Householder reflectors. If only a single backward is run, we modify K
5292   // in-place and exploit triangularity of the input. With higher order
5293   // derivatives we cannot rewrite the storage of K, hence we use much less
5294   // efficient out-of-place methods.
5295   //
5296   // if only first-order derivative is expected, we can modify K in-place for
5297   // better performance
5298   bool modify_K_in_place = !at::GradMode::is_enabled();
5299 
5300   // This method exploits that at the k-th iteration only the elements
5301   // v_k[k:] of the vector v_k can be non-zero.
5302   auto update_grad = [&m](
5303                          int64_t k,
5304                          const Tensor& v_full,
5305                          const Tensor& t,
5306                          const Tensor& K) -> std::tuple<Tensor, Tensor> {
5307     // v_full is a vector of dimension (..., m, 1), t is a scalar of dimension
5308     // (..., 1)
5309     auto v = v_full.narrow_symint(-2, k, m - k);
5310     auto vHK = v.mH().matmul(K.narrow_symint(-2, k, m - k));
5311     auto Kv = K.narrow_symint(-1, k, m - k).matmul(v);
5312     auto t_unsqueezed = t.unsqueeze(-1);
5313     auto v_grad = (-t_unsqueezed * vHK).conj().squeeze(-2) -
5314         (t_unsqueezed * Kv).squeeze(-1);
5315     auto tau_grad = -(vHK.narrow_symint(-1, k, m - k).matmul(v)).conj();
5316     return std::make_tuple(v_grad.unsqueeze(-1), tau_grad.squeeze(-1));
5317   };
5318 
5319   auto apply_householder_reflector = [m, modify_K_in_place](
5320                                          int64_t k,
5321                                          const Tensor& v_full,
5322                                          const Tensor& t,
5323                                          Tensor& K,
5324                                          bool left = true) -> Tensor {
5325     return apply_simple_transformation(
5326         m,
5327         k,
5328         v_full,
5329         v_full,
5330         t,
5331         K,
5332         modify_K_in_place,
5333         /*condition_with_I=*/true,
5334         left);
5335   };
5336 
5337   const auto flip_i = [flip_order, k](int64_t i) -> int64_t {
5338     return !flip_order ? i : k - i - 1;
5339   };
5340   const auto next_i = [flip_order](int64_t i) -> int64_t {
5341     return !flip_order ? ++i : --i;
5342   };
5343   const auto apply_left = !flip_order;
5344 
5345   // K <- H_0^{-1} @ K
5346   const auto zero_idx = flip_i(0);
5347   K = apply_householder_reflector(
5348       zero_idx,
5349       input.narrow(-1, zero_idx, 1),
5350       sigma.narrow(-1, zero_idx, 1),
5351       K,
5352       /*left=*/apply_left);
5353 
5354   Tensor input_grad, tau_grad;
5355   // For Composite Compliance, we can't copy a Subclass into a Regular Tensor,
5356   // so we use out-of-place ops with equivalent output.
5357   // NOTE: We can't use `new_zeros` directly as `input`, 'tau' or `grad` can
5358   // be Tensor Subclass and we don't want to make assumption about which
5359   // one to choose for creating output buffer.
5360   // eg. if both are BatchedTensor at different level.
5361   if (areAnyTensorSubclassLike({input, tau, K})) {
5362     // size k + 1 if input_grads also holds a block of zeros for the inactive part of input.
5363     auto input_grads = std::vector<Tensor>(k < input.sym_size(-1) ? k + 1 : k);
5364     auto tau_grads = std::vector<Tensor>(k);
5365 
5366     for (const auto i_idx : c10::irange(k)) {
5367       auto i = flip_i(i_idx);
5368       // NOTE: narrow will unsqueeze(-1)
5369       auto v_i = input.narrow(-1, i, 1);
5370       auto t_i = tau.narrow(-1, i, 1);
5371 
5372       std::tie(input_grads[i], tau_grads[i]) = update_grad(i, v_i, t_i, K);
5373 
5374       // K <- H_{i + 1}^{-1} @ K @ H_i
5375       if (i != flip_i(k - 1)) {
5376         auto i_next = next_i(i);
5377         auto v_i_next = input.narrow(-1, i_next, 1);
5378         auto s_i_next = sigma.narrow(-1, i_next, 1);
5379         K = apply_householder_reflector(
5380             i_next, v_i_next, s_i_next, K, /*left=*/apply_left);
5381         K = apply_householder_reflector(i, v_i, t_i, K, /*left=*/!apply_left);
5382       }
5383     }
5384 
5385     // Only first k columns are active in forward.
5386     // zero gradients for the inactive input.
5387     if (k < input.sym_size(-1)) {
5388       auto zero_grad_shape =
5389           at::SymDimVector(input_.sym_sizes().slice(0, input_.dim() - 1));
5390       zero_grad_shape.push_back(input.sym_size(-1) - k);
5391       auto zero_grad = at::zeros_symint(zero_grad_shape, input_.options());
5392       input_grads[k] = zero_grad;
5393     }
5394 
5395     input_grad = at::cat(input_grads, -1);
5396     tau_grad = at::cat(tau_grads, -1);
5397   } else {
5398     input_grad = at::zeros_like(input_);
5399     tau_grad = at::zeros_like(tau);
5400     for (const auto i_idx : c10::irange(k)) {
5401       auto i = flip_i(i_idx);
5402       // NOTE: narrow will unsqueeze(-1)
5403       auto v_i = input.narrow(-1, i, 1);
5404       auto t_i = tau.narrow(-1, i, 1);
5405 
5406       auto [v_i_grad, tau_i_grad] = update_grad(i, v_i, t_i, K);
5407       input_grad.select(-1, i).copy_(v_i_grad.squeeze(-1));
5408       tau_grad.select(-1, i).copy_(tau_i_grad.squeeze(-1));
5409 
5410       // K <- H_{i + 1}^{-1} @ K @ H_i
5411       if (i != flip_i(k - 1)) {
5412         auto i_next = next_i(i);
5413         auto v_i_next = input.narrow(-1, i_next, 1);
5414         auto s_i_next = sigma.narrow(-1, i_next, 1);
5415         K = apply_householder_reflector(
5416             i_next, v_i_next, s_i_next, K, /*left=*/apply_left);
5417         K = apply_householder_reflector(i, v_i, t_i, K, /*left=*/!apply_left);
5418       }
5419     }
5420   }
5421 
5422   // forward operates only over the lower-triangular part of the input
5423   // excluding the main diagonal, hence the gradient is also lower-triangular.
5424   input_grad.tril_(-1);
5425 
5426   return std::make_tuple(input_grad, tau_grad);
5427 }
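// A minimal end-to-end sketch that exercises the backward above (illustrative
// only, not used by this file; assumes the public ATen/autograd C++ API):
//
//   auto A = at::randn({5, 3}, at::kDouble);
//   auto [reflectors, tau] = at::geqrf(A);
//   reflectors.requires_grad_(true);
//   tau.requires_grad_(true);
//   auto Q = at::linalg_householder_product(reflectors, tau);
//   Q.sum().backward();
//   // reflectors.grad() and tau.grad() should now hold (input_grad, tau_grad)
//   // as computed by householder_product_backward.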
5428 
5429 // We refer to the derivations described above the method
5430 // `apply_simple_transformation`
householder_product_jvp(const Tensor & dV_,const Tensor & dtau,const Tensor & prod,const Tensor & V_,const Tensor & tau)5431 Tensor householder_product_jvp(
5432     const Tensor& dV_,
5433     const Tensor& dtau,
5434     const Tensor& prod,
5435     const Tensor& V_,
5436     const Tensor& tau) {
5437   auto m = V_.sym_size(-2);
5438   auto k = tau.size(-1);
5439 
5440   // forward operates only over the lower triangular part with the assumption
5441   // that the diagonal of input is filled with 1s.
5442   auto V = V_.tril(-1);
5443   V.diagonal(0, -2, -1).fill_(1.0);
5444   auto dV = dV_.tril(-1);
5445 
5446   // compute sigma such that
5447   // H(sigma_i) == H(tau_i)^{-1}.
5448   // If the input to householder_product comes from GEQRF,
5449   // we will never encounter ||v_i||^2 tau_i == 1, so H(tau_i) will always be
5450   // invertible. This follows from the documentation
5451   // https://www.netlib.org/lapack/lug/node128.html, and tau always satisfying
5452   // the condition |tau|^2 ||v||^2 == 2 * Re(tau).
5453   auto V_first_k_cols = V.narrow(-1, 0, k);
5454   auto V_first_k_cols_norm_squared =
5455       (V_first_k_cols * V_first_k_cols.conj()).sum(-2);
5456   auto sigma = tau / (tau * V_first_k_cols_norm_squared - 1.0);
5457 
5458   auto apply_householder_reflector = [m](const Tensor& v_full,
5459                                          const Tensor& t,
5460                                          Tensor& K,
5461                                          bool left = true) -> Tensor {
5462     return apply_simple_transformation(
5463         // setting `modify_K_in_place = true` causes CUDA memory leaks in OpInfo
5464         // tests of forward AD; for that reason we ignore `k` and pass zero
5465         m,
5466         /*k=*/0,
5467         v_full,
5468         v_full,
5469         t,
5470         K,
5471         /*modify_K_in_place=*/false,
5472         /*condition_with_I=*/true,
5473         left);
5474   };
5475 
5476   // computes (-t u v^H) K
5477   auto apply_simple_product = [m](const Tensor& u_full,
5478                                   const Tensor& v_full,
5479                                   const Tensor& t,
5480                                   Tensor& K) -> Tensor {
5481     return apply_simple_transformation(
5482         // since `modify_K_in_place = false`, we can ignore `k` and pass an
5483         // arbitrary value
5484         m,
5485         /*k=*/0,
5486         u_full,
5487         v_full,
5488         t,
5489         K,
5490         /*modify_K_in_place=*/false,
5491         /*condition_with_I=*/false,
5492         /*left=*/true);
5493   };
5494 
5495   auto H_plus = prod.detach().clone();
5496   IntArrayRef batch_vector_shape(V.sizes().data(), V.dim() - 1);
5497   auto H_minus =
5498       at::diag_embed(at::ones({1}, V.options()).expand(batch_vector_shape));
5499 
5500   auto dprod = at::zeros_like(prod);
5501   for (const auto i : c10::irange(k)) {
5502     auto v_i = V.narrow(-1, i, 1);
5503     auto dv_i = dV.narrow(-1, i, 1);
5504     auto tau_i = tau.narrow(-1, i, 1);
5505     auto dtau_i = dtau.narrow(-1, i, 1);
5506     auto sigma_i = sigma.narrow(-1, i, 1);
5507 
5508     H_plus = apply_householder_reflector(v_i, sigma_i, H_plus, /*left=*/true);
5509 
5510     // `H_minus_dH_i_H_plus` = H_1 * ... * H_{i-1} dH_i * H_{i+1} * ...
5511     auto H_minus_dH_i_H_plus = H_minus.matmul(
5512         apply_simple_product(v_i, v_i, dtau_i, H_plus) +
5513         apply_simple_product(dv_i, v_i, tau_i, H_plus) +
5514         apply_simple_product(v_i, dv_i, tau_i, H_plus));
5515     // For Composite Compliance, if `intermediate` is a Tensor-Subclass,
5516     // we use out-of-place variant of add.
5517     if (at::isTensorSubclassLike(H_minus_dH_i_H_plus)) {
5518       dprod = dprod.add(H_minus_dH_i_H_plus);
5519     } else {
5520       dprod.add_(H_minus_dH_i_H_plus);
5521     }
5522 
5523     H_minus = apply_householder_reflector(v_i, tau_i, H_minus, /*left=*/false);
5524   }
5525 
5526   return dprod;
5527 }
5528 
ormqr_backward(const Tensor & grad,const Tensor & result,const Tensor & self,const Tensor & tau,const Tensor & other,bool left,bool transpose,std::array<bool,3> grad_output_mask)5529 std::tuple<Tensor, Tensor, Tensor> ormqr_backward(
5530     const Tensor& grad,
5531     const Tensor& result,
5532     const Tensor& self,
5533     const Tensor& tau,
5534     const Tensor& other,
5535     bool left,
5536     bool transpose,
5537     std::array<bool, 3> grad_output_mask) {
5538   Tensor self_grad, tau_grad, other_grad;
5539 
5540   if (!grad.defined()) {
5541     return std::make_tuple(self_grad, tau_grad, other_grad);
5542   }
5543 
5544   const auto self_requires_grad = grad_output_mask[0];
5545   const auto tau_requires_grad = grad_output_mask[1];
5546   const auto other_requires_grad = grad_output_mask[2];
5547 
5548   if (other_requires_grad) {
5549     other_grad = at::ormqr(self, tau, grad, left, !transpose);
5550   }
5551   if (self_requires_grad || tau_requires_grad) {
5552     if (left ^ transpose) {
5553       // Assume left = true and transpose = false. The case with
5554       // left = false and transpose = true is very much similar with just
5555       // transposed arguments passed into householder_product_backward.
5556       // Ormqr computes B = H_1 * ... * H_k * A.
5557       // The sensitivity wrt H_i is given by (see notes in
5558       // householder_product_backward) Tr(H_i_plus B B_grad^H H_i_minus dH_i),
5559       // so, since householder_product_backward respects `for i in range(k)`, we
5560       // could reuse householder_product_backward with
5561       // householder_product_backward.grad = grad and
5562       // householder_product_backward.result = result.
5563       const auto hpb_grad = !transpose ? grad : grad.mH();
5564       const auto hpb_result = !transpose ? result : result.mH();
5565       std::tie(self_grad, tau_grad) =
5566           householder_product_backward(hpb_grad, hpb_result, self, tau);
5567     } else {
5568       // Assuming left = false and transpose = false. The case with
5569       // left = true and transpose = true is very much similar with just
5570       // transposed arguments passed into householder_product_backward.
5571       // In this case Ormqr computes B = A * H_1 * ... * H_k and the sensitivity
5572       // wrt H_i becomes Tr(H_i_plus B_grad^H B H_i_minus dH_i).
5573       // We could see that the role of `grad` and `result` in
5574       // householder_product_backward gets "swapped" and "transposed" and that
5575       // in order to compute H_k_grad efficiently we would need to compute grads
5576       // in reversed order (`for i in range(k - 1, -1, -1)`). Hence we reuse
5577       // householder_product_backward with householder_product_backward.grad =
5578       // result.mH, householder_product_backward.result = grad.mH,
5579       // householder_product_backward.flip_order = true.
5580       const auto hpb_grad = !transpose ? result.mH() : result;
5581       const auto hpb_result = !transpose ? grad.mH() : grad;
5582       std::tie(self_grad, tau_grad) = householder_product_backward(
5583           hpb_grad, hpb_result, self, tau, /*flip_order=*/true);
5584     }
5585   }
5586 
5587   return std::make_tuple(self_grad, tau_grad, other_grad);
5588 }
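// Sketch of a call that should reach ormqr_backward (illustrative only; assumes
// the public ATen/autograd C++ API):
//
//   auto [reflectors, tau] = at::geqrf(at::randn({5, 3}, at::kDouble));
//   reflectors.requires_grad_(true);
//   tau.requires_grad_(true);
//   auto other = at::randn({5, 2}, at::kDouble).requires_grad_(true);
//   auto out = at::ormqr(reflectors, tau, other, /*left=*/true, /*transpose=*/false);
//   out.sum().backward();
//   // other.grad() comes from the at::ormqr call above, while reflectors.grad()
//   // and tau.grad() reuse householder_product_backward.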
5589 
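// polar(abs, angle) = abs * exp(i * angle) =: result. Differentiating,
//   d(result) = exp(i * angle) d(abs) + i * result d(angle),
// and using the real inner product <g, d> = Re(conj(g) * d) on C we get
//   abs_grad   = Re(conj(grad) * sgn(result))   (exp(i * angle) = sgn(result) for abs > 0),
//   angle_grad = Re(conj(grad) * i * result),
// which is what polar_backward below computes.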
polar_backward(const Tensor & grad,const Tensor & result)5590 std::tuple<Tensor, Tensor> polar_backward(
5591     const Tensor& grad,
5592     const Tensor& result) {
5593   Tensor grad_abs, grad_angle;
5594   if (grad.defined()) {
5595     auto grad_conj = grad.conj();
5596     grad_abs = at::real(grad_conj * at::sgn(result));
5597     auto result_mul_1_j = result * Scalar(c10::complex<double>{0.0, 1.0});
5598     grad_angle = at::real(grad_conj * result_mul_1_j);
5599   }
5600   return std::make_tuple(grad_abs, grad_angle);
5601 }
5602 
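// i1_backward and i1e_backward below rely on the identities
//   d/dx I1(x)  = I0(x) - I1(x) / x,
//   d/dx i1e(x) = i0e(x) - (sgn(x) + 1/x) * i1e(x),  where i1e(x) = exp(-|x|) * I1(x),
// with the removable singularity at x = 0 (where both gradients tend to 0.5)
// handled explicitly via `where`.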
i1_backward(const Tensor & grad,const Tensor & self,const Tensor & result)5603 Tensor i1_backward(
5604     const Tensor& grad,
5605     const Tensor& self,
5606     const Tensor& result) {
5607   return AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "i1_backward", [&]() {
5608     // For x = 0, the correct gradient is 0.5,
5609     // however due to floating point computation we get NaN.
5610     // So we manually update gradient for x=0
5611     auto eps = std::numeric_limits<scalar_t>::epsilon();
5612     auto self_is_not_tiny = self.abs() > eps;
5613 
5614     // The following `where` is needed as `where` computes gradients,
5615     // even for the part which didn't affect the output.
5616     // Look at https://github.com/pytorch/pytorch/issues/52248
5617     // Update if and when this is fixed.
5618     auto safe_self =
5619         at::where(self_is_not_tiny, self, at::full({}, eps, self.options()));
5620     auto gradx = (safe_self.i0() - (result * safe_self.reciprocal()));
5621     return grad *
5622         at::where(self_is_not_tiny, gradx, at::full({}, 0.5, self.options()));
5623   });
5624 }
5625 
i1e_backward(const Tensor & grad,const Tensor & self,const Tensor & result)5626 Tensor i1e_backward(
5627     const Tensor& grad,
5628     const Tensor& self,
5629     const Tensor& result) {
5630   return AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "i1e_backward", [&]() {
5631     // For x = 0, the correct gradient is 0.5,
5632     // however due to floating point computation we get NaN.
5633     // So we manually update gradient for x=0
5634     auto eps = std::numeric_limits<scalar_t>::epsilon();
5635     auto self_is_not_tiny = self.abs() > eps;
5636 
5637     // The following `where` is needed as `where` computes gradients,
5638     // even for the part which didn't affect the output.
5639     // Look at https://github.com/pytorch/pytorch/issues/52248
5640     // Update if and when this is fixed.
5641     auto safe_self =
5642         at::where(self_is_not_tiny, self, at::full({}, eps, self.options()));
5643     auto gradx =
5644         (at::special_i0e(safe_self) -
5645          result * (safe_self.sgn() + safe_self.reciprocal()));
5646     return grad *
5647         at::where(self_is_not_tiny, gradx, at::full({}, 0.5, self.options()));
5648   });
5649 }
5650 
5651 // lu_solve is a map (LU, P, B) -> (PLU)^{-1} B,
5652 // where LU = L + U - I and P is a permutation matrix, and is fixed.
5653 //
5654 // Let 1 = ones_like(LU),
5655 // 1_U = 1.triu(),
5656 // 1_L = 1.tril(-1)
5657 // * := the Hadamard (element-wise) product
5658 //
5659 // Forward AD:
5660 //
5661 // Let X := U^{-1} L^{-1} P^T B be the output of the function.
5662 // Also, the LU input of the function could be represented as
5663 // LU = (L - I) + U.
5664 //
5665 // Differentiating LU = L + U - I produces:
5666 // dLU = dL + dU.
5667 // Noting that dL and dU are lower- and upper-triangular, respectively,
5668 // and that the diagonal of L is never explicitly exposed, so
5669 // diag(dL) = 0, it follows
5670 // dL = dLU * 1_L,
5671 // dU = dLU * 1_U.
5672 //
5673 // Differentiating X = U^{-1} L^{-1} P^T B produces:
5674 // dX = dU^{-1} L^{-1} P^T B + U^{-1} dL^{-1} P^T B + U^{-1} L^{-1} P^T dB
5675 // Note that for any invertible matrix A we have A A^{-1} = I, hence
5676 // dA A^{-1} + A dA^{-1} = 0 => dA^{-1} = -A^{-1} dA A^{-1}.
5677 // Inserting it back into the definition of dX gives:
5678 // dX = -U^{-1} dU U^{-1} L^{-1} P^T B - U^{-1} L^{-1} dL L^{-1} P^T B + U^{-1} L^{-1} P^T dB
5679 //    = -U^{-1} dU X - U^{-1} L^{-1} dL U X + U^{-1} L^{-1} P^T dB
5680 //
5681 // Backward AD:
5682 //
5683 // Using the definition of dL, dU from above:
5684 // Tr(L_grad^H dL) + Tr(U_grad^H dU)
5685 //     = Tr(L_grad^H (dLU * 1_L)) + Tr(U_grad^H (dLU * 1_U))
5686 //     = [using Tr(A (B * C)) = Tr((A * B^T) C)]
5687 //     = Tr((L_grad^H * 1_L^T) dLU) + Tr((U_grad^H * 1_U^T) dLU),
5688 //
5689 // hence
5690 // LU_grad = L_grad * 1_L + U_grad * 1_U (!!!)
5691 //
5692 // Then, transposing the formula for dX above we get:
5693 // B_grad = P L^{-H} U^{-H} X_grad = lu_solve(X_grad, LU_data, LU_pivots, /*adjoint=*/true),
5694 // U_grad = -U^{-H} X_grad X^H,  L_grad = L^{-H} U_grad U^H.
5695 // After inserting U_grad and L_grad into (!!!) we get the value for LU_grad.
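// Illustrative sketch of an op differentiated with the formulas above (assumes
// the public ATen/autograd C++ API; not used by this file):
//
//   auto A = at::randn({4, 4}, at::kDouble);
//   auto B = at::randn({4, 2}, at::kDouble).requires_grad_(true);
//   auto [LU, pivots] = at::linalg_lu_factor(A);
//   LU.requires_grad_(true);
//   auto X = at::linalg_lu_solve(LU, pivots, B);
//   X.sum().backward();
//   // B.grad() should follow B_grad above, and LU.grad() should follow LU_grad
//   // (computed by linalg_lu_solve_LU below).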
5696 
linalg_lu_solve_LU(const Tensor & gX,const Tensor & LU,const Tensor & pivots,const Tensor & X,const bool left,const bool adjoint)5697 Tensor linalg_lu_solve_LU(
5698     const Tensor& gX,
5699     const Tensor& LU,
5700     const Tensor& pivots,
5701     const Tensor& X,
5702     const bool left,
5703     const bool adjoint) {
5704   // From linalg_lu_solve_jvp we have that:
5705   // left = True, adjoint = True: A^HX = B
5706   // left = True, adjoint = False: AX = B
5707   // left = False, adjoint = True: AX^H = B^H
5708   // left = False, adjoint = False: A^HX^H = B^H
5709   // let op_1(A) = A^H or op_1(A) = A according to the list above
5710   // same with op_2(X) and op_3(B)
5711   // We have that letting S = lu_solve(LU, pivots, dB, left, adjoint)
5712   // the JVP formula reads
5713   // if left != adjoint:
5714   //   dX = op_2(-U^{-1}(dU + L^{-1}dL U)op_2(X)) + S
5715   // else:
5716   //   dX = op_2(op_1(-op_3(X)^H P(LdUU^{-1} + dL)L^{-1} P^T)) + S
5717   // So computing the adjoint of this operation, using an auxiliary variable gR,
5718   // we get that if left != adjoint:
5719   //   gR = U^{-H}op_2(-gX)op_2(X)^H
5720   //   gU = gR.triu()
5721   //   gL = (L^{-H} gR U^H).tril(-1)
5722   // else:
5723   //   gR = -P^T op_3(X)op_1(op_2(gX))PL^{-H}
5724   //   gL = gR.tril(-1)
5725   //   gU = (L^H gR U^{-H}).triu()
5726   // gLU = gL + gU
5727 
5728   at::NoTF32Guard disable_tf32;
5729   auto [P, L, U] = at::lu_unpack(
5730       LU, pivots, /*unpack_data=*/true, /*unpack_pivots=*/left == adjoint);
5731   // TODO Optimise the order of the operations to avoid operating on large
5732   // tensors unnecessarily
5733   //      The logic should be: if n < k == left then multiply the gX and X first
5734   //      (as it's done now). Otherwise multiply them last
5735   if (left != adjoint) {
5736     // gR = U^{-H}op_2(-gX)op_2(X)^H
5737     auto gR = at::linalg_solve_triangular(
5738         U.mH(),
5739         -(left ? gX : gX.mH()).matmul(left ? X.mH() : X),
5740         /*upper*/ false);
5741     // gL = (L^{-H} gR U^H).tril(-1)
5742     auto gL = at::linalg_solve_triangular(
5743                   L.mH(),
5744                   gR.matmul(U.mH()),
5745                   /*upper*/ true,
5746                   /*left*/ true,
5747                   /*unitriangular*/ true)
5748                   .tril(-1);
5749
5750     return gL + gR.triu();
5751   } else {
5752     // gR = -P^T op_3(X)op_1(op_2(gX))P
5753     auto gR =
5754         -P.mT().matmul(left ? X : X.mH()).matmul(left ? gX.mH() : gX).matmul(P);
5755     // gR = gR L^{-H}
5756     gR = at::linalg_solve_triangular(
5757         L.mH(), gR, /*upper*/ true, /*left*/ false, /*unitriangular*/ true);
5758     // gU = (L^H gR U^{-H}).triu()
5759     auto gU = at::linalg_solve_triangular(
5760                   U.mH(), L.mH().matmul(gR), /*upper*/ false, /*left*/ false)
5761                   .triu();
5762     return gR.tril(-1) + gU;
5763   }
5764 }
5765 
linalg_lu_solve_jvp(const Tensor & X,const Tensor & LU,const Tensor & pivots,const Tensor & dLU,const Tensor & dB,const bool left,const bool adjoint)5766 Tensor linalg_lu_solve_jvp(
5767     const Tensor& X,
5768     const Tensor& LU,
5769     const Tensor& pivots,
5770     const Tensor& dLU,
5771     const Tensor& dB,
5772     const bool left,
5773     const bool adjoint) {
5774   // We write the derivation in terms of some adjoint operations, as otherwise
5775   // we would need to write down 4 different proofs with 4 different
5776   // implementations and it'd be painful to derive and maintain Below, we just
5777   // use that X -> X^H is linear, so it commutes with the derivative The
5778   // derivation follows by differentiating op_1(PLU)op_2(X) = op_3(B)
5779 
5780   // left = True, adjoint = True: A^HX = B
5781   // left = True, adjoint = False: AX = B
5782   // left = False, adjoint = True: AX^H = B^H
5783   // left = False, adjoint = False: A^HX^H = B^H
5784   // let op_1(A) = A^H or op_1(A) = A according to the list above
5785   // same with op_2(X) and op_3(B)
5786   // We have that letting S = lu_solve(LU, pivots, dB, left, adjoint)
5787   // the JVP formula reads
5788   // dX = op_2(op_1(-U^{-1}(dUU^{-1} + L^{-1}dL)L^{-1} P^T)op_3(B)) + S
5789 
5790   at::NoTF32Guard disable_tf32;
5791   auto S = at::linalg_lu_solve(LU, pivots, dB, left, adjoint);
5792   if (left != adjoint) {
5793     // We see that when left != adjoint, op_1(A) = A, and we can substitute
5794     // A^{-1}op_3(B) by op_2(X):  dX = op_2(-U^{-1}(dU + L^{-1}dL U)op_2(X)) + S
5795     // Let R = -U^{-1}(dU + L^{-1}dL U)
5796     auto R = at::linalg_solve_triangular(
5797         LU,
5798         dLU.tril(-1),
5799         /*upper*/ false,
5800         /*left*/ true,
5801         /*unitriangular*/ true);
5802     auto U = LU.triu();
5803     R = -at::linalg_solve_triangular(
5804         U, dLU.triu() + R.matmul(U), /*upper*/ true);
5805     // dX = op_2(R op_2(X)) + S
5806     return (left ? R.matmul(X) : X.matmul(R.mH())) + S;
5807   } else {
5808     // We see that when left == adjoint, op_1(A) = A^H
5809     // dX = op_2(op_1(-op_3(B)^H U^{-1}(dUU^{-1} + L^{-1}dL)L^{-1} P^T)) + S
5810     // Now, note that whenever adjoint == left, we have that
5811     // op_3(B)^H A^{-1} = op_3(X)^H
5812     // We can then rewrite the formula above in terms of X as
5813     // dX = op_2(op_1(-op_3(X)^H P(LdUU^{-1} + dL)L^{-1} P^T)) + S
5814     auto [P, L, U] = at::lu_unpack(LU, pivots);
5815     // Compute V = op_3(X)^H
5816     auto V = left ? X.mH() : X;
5817     // Compute the inner parens LdUU^{-1} + dL
5818     auto R = at::linalg_solve_triangular(
5819                  U, L.matmul(dLU.triu()), /*upper*/ true, /*left*/ false) +
5820         dLU.tril(-1);
5821     // dX = op_2(op_1(-op_3(X)^H PRL^{-1} P^T)) + S
5822     R = at::linalg_solve_triangular(
5823             L,
5824             -V.matmul(P).matmul(R),
5825             /*upper*/ false,
5826             /*left*/ false,
5827             /*unitriangular*/ true)
5828             .matmul(P.mT());
5829     // dX = op_2(R^H) + S
5830     return (left ? R.mH() : std::move(R)) + S;
5831   }
5832 }
5833 
linalg_solve_jvp(const Tensor & dA,const Tensor & dB,const Tensor & X,const Tensor & LU,const Tensor & pivots,const bool left,const bool use_A_T)5834 Tensor linalg_solve_jvp(
5835     const Tensor& dA,
5836     const Tensor& dB,
5837     const Tensor& X,
5838     const Tensor& LU,
5839     const Tensor& pivots,
5840     const bool left,
5841     const bool use_A_T) {
5842   at::NoTF32Guard disable_tf32;
5843   // For left=True (left=False is analogous)
5844   // dX = A^{-1}(dB - dAX)
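  // (Differentiating A X = B gives dA X + A dX = dB, hence dX = A^{-1}(dB - dA X);
  //  for left=False, X A = B gives dX = (dB - X dA) A^{-1}.)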
5845 
5846   // [NumPy compat] Case where the rhs is a vector.
5847   // We denote with an underscore vectors that have been converted to matrices
5848   // by `unsqueeze(-1)`
5849   const bool vector_case = at::native::linalg_solve_is_vector_rhs(LU, X);
5850   const auto vector_to_matrix = [vector_case](const Tensor& X) {
5851     return vector_case ? X.unsqueeze(-1) : X;
5852   };
5853   const auto matrix_to_vector = [vector_case](const Tensor& X) {
5854     return vector_case ? X.squeeze(-1) : X;
5855   };
5856 
5857   // This case is disallowed in the primal operation as A.shape = (*, 1, 1)
5858   TORCH_INTERNAL_ASSERT(left || !vector_case);
5859 
5860   auto X_ = vector_to_matrix(X);
5861   auto dB_ = vector_to_matrix(dB);
5862   auto R_ = left ? dA.matmul(X_) : X_.matmul(dA);
5863   auto dX_ =
5864       at::linalg_lu_solve(LU, pivots, dB_ - R_, left, /*adjoint*/ use_A_T);
5865   return matrix_to_vector(dX_);
5866 }
5867 
linalg_solve_backward(const Tensor & gX,const Tensor & X,const Tensor & A,const Tensor & LU,const Tensor & pivots,const bool left,const bool B_requires_grad)5868 std::tuple<Tensor, Tensor> linalg_solve_backward(
5869     const Tensor& gX,
5870     const Tensor& X,
5871     const Tensor& A,
5872     const Tensor& LU,
5873     const Tensor& pivots,
5874     const bool left,
5875     const bool B_requires_grad) {
5876   // for X = A^{-1}B
5877   // gB = A^{-H}gX
5878   // gA = -gB X^H
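  // (From dX = A^{-1}(dB - dA X): Re Tr(gX^H dX) = Re Tr(gB^H dB) + Re Tr(gA^H dA)
  //  with gB = A^{-H} gX and gA = -A^{-H} gX X^H = -gB X^H, using the cyclic
  //  property of the trace; the left=False case is analogous.)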
5879   at::NoTF32Guard disable_tf32;
5880   const auto A_requires_grad = A.requires_grad();
5881   if (!gX.defined() || (!A_requires_grad && !B_requires_grad)) {
5882     return {};
5883   }
5884 
5885   // [NumPy compat] Case where the rhs is a vector.
5886   // We denote with an underscore vectors that have been converted to matrices
5887   // by `unsqueeze(-1)`
5888   const bool vector_case = at::native::linalg_solve_is_vector_rhs(LU, X);
5889   const auto vector_to_matrix = [vector_case](const Tensor& X) {
5890     return vector_case ? X.unsqueeze(-1) : X;
5891   };
5892   const auto matrix_to_vector = [vector_case](const Tensor& X) {
5893     return vector_case ? X.squeeze(-1) : X;
5894   };
5895 
5896   // If the user is going to compute higher order gradients, then we need to
5897   // recompute the LU and the pivots
5898   Tensor gB_;
5899   if (at::GradMode::is_enabled()) {
5900     gB_ = at::linalg_solve(A.mH(), vector_to_matrix(gX), left);
5901   } else {
5902     const auto use_A_T = A.is_contiguous() && !A.is_complex();
5903     gB_ = at::linalg_lu_solve(
5904         LU, pivots, vector_to_matrix(gX), left, /*adjoint*/ !use_A_T);
5905   }
5906 
5907   Tensor gA_;
5908   if (A_requires_grad) {
5909     auto X_ = vector_to_matrix(X);
5910     gA_ = left ? -gB_.matmul(X_.mH()) : -X_.mH().matmul(gB_);
5911   }
5912   return std::make_tuple(
5913       A_requires_grad ? std::move(gA_) : Tensor{},
5914       B_requires_grad ? matrix_to_vector(gB_) : Tensor{});
5915 }
5916 
solve_jvp(const Tensor & X,const Tensor & A,const Tensor & dA,const Tensor & dB)5917 Tensor solve_jvp(
5918     const Tensor& X,
5919     const Tensor& A,
5920     const Tensor& dA,
5921     const Tensor& dB) {
5922   return generic_solve_jvp(
5923       [](const Tensor& A, const Tensor& dB, const Tensor& dA_contrib) {
5924         return at::linalg_solve(A, dB - dA_contrib);
5925       },
5926       X,
5927       A,
5928       dA,
5929       dB);
5930 }
5931 
lu_unpack_backward(const Tensor & L_grad,const Tensor & U_grad,const c10::SymInt & m,const c10::SymInt & n)5932 Tensor lu_unpack_backward(
5933     const Tensor& L_grad,
5934     const Tensor& U_grad,
5935     const c10::SymInt& m,
5936     const c10::SymInt& n) {
5937   if (!L_grad.defined() && !U_grad.defined()) {
5938     return {};
5939   }
5940   const auto k = std::min(m, n);
5941 
5942   // Getters for the principal and complementary part of the matrices
5943   const auto get_L1 = [m, k](const Tensor& L) {
5944     return m == k ? L.tril(-1) : L.narrow_symint(-2, 0, k).tril(-1);
5945   };
5946   const auto get_L2 = [m, k](const Tensor& L) {
5947     return L.narrow_symint(-2, k, m - k);
5948   };
5949   const auto get_U1 = [n, k](const Tensor& U) {
5950     return n == k ? U.triu() : U.narrow_symint(-1, 0, k).triu();
5951   };
5952   const auto get_U2 = [n, k](const Tensor& U) {
5953     return U.narrow_symint(-1, k, n - k);
5954   };
5955 
5956   if (L_grad.defined()) {
5957     if (U_grad.defined()) {
5958       if (m == n) {
5959         return L_grad.tril(-1) + U_grad.triu();
5960       } else {
5961         auto A1_grad = get_L1(L_grad) + get_U1(U_grad);
5962         auto A2_grad = m > n ? get_L2(L_grad) : get_U2(U_grad);
5963         const auto dim = m > n ? -2 : -1;
5964         return at::cat({std::move(A1_grad), std::move(A2_grad)}, /*dim=*/dim);
5965       }
5966     } else {
5967       if (m >= n) {
5968         return L_grad.tril(-1);
5969       } else {
5970         auto size = L_grad.sym_sizes().vec();
5971         size.end()[-1] = n - m;
5972         return at::cat(
5973             {L_grad.tril(-1), at::zeros_symint(size, L_grad.options())},
5974             /*dim=*/-1);
5975       }
5976     }
5977   } else {
5978     if (n >= m) {
5979       return U_grad.triu();
5980     } else {
5981       auto size = U_grad.sym_sizes().vec();
5982       size.end()[-2] = m - n;
5983       return at::cat(
5984           {U_grad.triu(), at::zeros_symint(size, U_grad.options())},
5985           /*dim=*/-2);
5986     }
5987   }
5988 }
5989 
cat_jvp(const at::ITensorListRef & tensors,int64_t dim)5990 Tensor cat_jvp(const at::ITensorListRef& tensors, int64_t dim) {
5991   Tensor out_fw_grad;
5992 
5993   auto materialized = tensors.materialize();
5994   auto any_defined = false;
5995   for (const Tensor& t : materialized) {
5996     any_defined |= isFwGradDefined(t);
5997   }
5998 
5999   if (any_defined) {
6000     std::vector<Tensor> fw_grads;
6001 
6002     for (const Tensor& t : materialized) {
6003       fw_grads.push_back(
6004           isFwGradDefined(t)
6005               ? t._fw_grad(/*level*/ 0)
6006               : at::_efficientzerotensor(t.sizes(), t.options()));
6007     }
6008 
6009     out_fw_grad = at::cat(fw_grads, dim);
6010   }
6011 
6012   return out_fw_grad;
6013 }
6014 
block_diag_jvp(at::TensorList tensors)6015 Tensor block_diag_jvp(at::TensorList tensors) {
6016   Tensor out_fw_grad;
6017 
6018   auto any_defined = false;
6019   for (const auto& t : tensors) {
6020     any_defined |= isFwGradDefined(t);
6021   }
6022 
6023   if (any_defined) {
6024     std::vector<Tensor> fw_grads;
6025     fw_grads.reserve(tensors.size());
6026 
6027     for (const auto& t : tensors) {
6028       fw_grads.push_back(
6029           isFwGradDefined(t)
6030               ? t._fw_grad(/*level*/ 0)
6031               : at::_efficientzerotensor(t.sizes(), t.options()));
6032     }
6033 
6034     out_fw_grad = at::block_diag(fw_grads);
6035   }
6036 
6037   return out_fw_grad;
6038 }
6039 
stack_jvp(at::TensorList tensors,int64_t dim)6040 Tensor stack_jvp(at::TensorList tensors, int64_t dim) {
6041   // Basically a copy of cat_jvp above
6042   // TODO: consolidate with the logic of cat_jvp
6043   Tensor out_fw_grad;
6044 
6045   auto any_defined = false;
6046   for (const auto& t : tensors) {
6047     any_defined |= isFwGradDefined(t);
6048   }
6049 
6050   if (any_defined) {
6051     std::vector<Tensor> fw_grads;
6052 
6053     for (auto& t : tensors) {
6054       fw_grads.push_back(
6055           isFwGradDefined(t)
6056               ? t._fw_grad(/*level*/ 0)
6057               : at::_efficientzerotensor(t.sizes(), t.options()));
6058     }
6059     out_fw_grad = at::stack(fw_grads, dim);
6060   }
6061   return out_fw_grad;
6062 }
6063 
cumprod_jvp(const Tensor & self_t,const Tensor & self_p,const Tensor & result,int dim)6064 Tensor cumprod_jvp(
6065     const Tensor& self_t,
6066     const Tensor& self_p,
6067     const Tensor& result,
6068     int dim) {
6069   // Generic formula when no 0. is involved
6070   Tensor gradient = (self_t / self_p).cumsum(dim) * result;
6071 
6072   // Note that we have to use at::where below as we are removing nans
6073 
6074   if (self_p.dim() == 0) {
6075     gradient.masked_fill_(self_p.eq(0), self_t);
6076     return gradient;
6077   } else {
6078     // For input (a, 0, b, 0, c) with gradients (t0, t1, t2, t3, t4)
6079     // The output of cumprod is (a, 0, 0, 0, 0)
6080     // The gradient we want to compute is (t0, a*t1, a*b*t1, 0, 0)
6081     // We do this by:
6082     // Get a mask of all zeros (0, 1, 0, 1, 0)
6083     auto mask_zeros = self_p.eq(0);
6084     // Get a mask of the first zero for each dim (0, 1, 0, 0, 0)
6085     auto mask_first_zero = mask_zeros.logical_and(mask_zeros.cumsum(dim).eq(1));
6086 
6087     // Get the new grad value that should be used after any zero happened:
6088     // (X, a*t1, a*b*t1, 0, 0) = cumprod((a, t1, b, 0, c))
6089     auto new_grad = at::where(mask_first_zero, self_t, self_p).cumprod(dim);
6090 
6091     // Get a mask of everything after the first zero: (0, 1, 1, 1, 1)
6092     auto mask_after_first_zero = mask_first_zero.cumsum(dim);
6093 
6094     // Do the final replacement
6095     return at::where(
6096         mask_after_first_zero.to(ScalarType::Bool), new_grad, gradient);
6097   }
6098 }
6099 
6100 // Helper for {batch,layer,group}_norms below
6101 // Computes the jvp for `1 / input.std(dims, keepdim)`
_invstd_jvp(const Tensor & input_p,const Tensor & input_t,const Tensor & mean_p,const Tensor & invstd_p,IntArrayRef dims,int64_t numel,bool keepdim)6102 static Tensor _invstd_jvp(
6103     const Tensor& input_p,
6104     const Tensor& input_t,
6105     const Tensor& mean_p,
6106     const Tensor& invstd_p,
6107     IntArrayRef dims,
6108     int64_t numel,
6109     bool keepdim) {
6110   Tensor invstd_t;
6111   if (areAnyTensorSubclassLike({input_t, input_p, mean_p, invstd_p}) ||
6112       input_t._is_zerotensor()) {
6113     invstd_t = -invstd_p.pow(3) * (input_t - input_t.mean(dims, true)) *
6114         (input_p - mean_p);
6115   } else {
6116     invstd_t = input_t - input_t.mean(dims, true);
6117     invstd_t *= input_p - mean_p;
6118     invstd_t *= -invstd_p.pow(3);
6119   }
6120   invstd_t = invstd_t.sum(dims, keepdim);
6121   invstd_t /= numel;
6122   return invstd_t;
6123 }
6124 
6125 // Helper for {batch,layer,group}_norms below only
6126 // Computes the jvp for `(input - input.mean(dims)) * input.invstd(dims)`
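// i.e. d[(input - mean) * invstd]
//          = (d input - d mean) * invstd + (input - mean) * d invstd,
// with d mean = E[d input] over `dims` and d invstd given by _invstd_jvp above.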
_norm_jvp(const Tensor & input_p,const Tensor & input_t,const Tensor & mean_p,const Tensor & invstd_p,IntArrayRef dims,int64_t numel)6127 static Tensor _norm_jvp(
6128     const Tensor& input_p,
6129     const Tensor& input_t,
6130     const Tensor& mean_p,
6131     const Tensor& invstd_p,
6132     IntArrayRef dims,
6133     int64_t numel) {
6134   auto invstd_t =
6135       _invstd_jvp(input_p, input_t, mean_p, invstd_p, dims, numel, true);
6136   Tensor result_t;
6137   if (areAnyTensorSubclassLike({input_t, input_p, mean_p, invstd_p}) ||
6138       input_t._is_zerotensor()) {
6139     result_t = (input_t - input_t.mean(dims, true)) * invstd_p +
6140         (input_p - mean_p) * invstd_t;
6141   } else {
6142     result_t = input_t - input_t.mean(dims, true);
6143     result_t *= invstd_p;
6144     auto temp = input_p - mean_p;
6145     temp *= invstd_t;
6146     result_t += temp;
6147   }
6148   return result_t;
6149 }
6150 
6151 // Helper for {batch,layer,group}_norms below only
6152 // Computes the jvp for `input * weight + bias` where weight and bias may be
6153 // undefined. Possibly modifies the input in-place.
_affine_jvp(const std::optional<Tensor> & input_p,Tensor & input_t,const Tensor & weight_p,const Tensor & weight_t,const Tensor & bias_t)6154 static Tensor _affine_jvp(
6155     const std::optional<Tensor>& input_p,
6156     Tensor& input_t,
6157     const Tensor& weight_p,
6158     const Tensor& weight_t,
6159     const Tensor& bias_t) {
6160   // We allow input_p to be optional because if weight_p isn't defined,
6161   // it may be possible to avoid computing input_p
6162   TORCH_INTERNAL_ASSERT(input_p.has_value() == weight_p.defined());
6163   if (weight_p.defined()) {
6164     if (areAnyTensorSubclassLike(
6165             // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
6166             {input_p.value(), input_t, weight_p, weight_t}) ||
6167         input_t._is_zerotensor() || weight_t._is_zerotensor()) {
6168       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
6169       input_t = input_t * weight_p + input_p.value() * weight_t;
6170     } else {
6171       input_t *= weight_p;
6172       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
6173       auto temp = input_p.value();
6174       temp *= weight_t;
6175       input_t += temp;
6176     }
6177   }
6178   if (bias_t.defined()) {
6179     if (areAnyTensorSubclassLike({input_t, bias_t}) ||
6180         input_t._is_zerotensor()) {
6181       input_t = input_t + bias_t;
6182     } else {
6183       input_t += bias_t;
6184     }
6185   }
6186   return input_t;
6187 }
6188 
batch_norm_jvp(const Tensor & input_p,const Tensor & input_t,const Tensor & weight_p,const Tensor & weight_t,const Tensor & bias_p,const Tensor & bias_t,const std::optional<Tensor> & running_mean,const std::optional<Tensor> & running_var,const Tensor & saved_mean,const Tensor & saved_invstd,bool train,double eps)6189 Tensor batch_norm_jvp(
6190     const Tensor& input_p,
6191     const Tensor& input_t,
6192     const Tensor& weight_p,
6193     const Tensor& weight_t,
6194     const Tensor& bias_p,
6195     const Tensor& bias_t,
6196     const std::optional<Tensor>& running_mean,
6197     const std::optional<Tensor>& running_var,
6198     const Tensor& saved_mean,
6199     const Tensor& saved_invstd,
6200     bool train,
6201     double eps) {
6202   auto dims = std::vector<int64_t>{};
6203   auto view_size = input_t.sizes().vec();
6204   int64_t numel = 1;
6205   for (const auto dim : c10::irange(view_size.size())) {
6206     if (dim != 1) {
6207       numel *= input_t.size(static_cast<int64_t>(dim));
6208       view_size[dim] = 1;
6209       dims.push_back(static_cast<int64_t>(dim));
6210     }
6211   }
6212   Tensor mean_p;
6213   Tensor invstd_p;
6214   Tensor result_t;
6215   if (train) {
6216     mean_p = saved_mean.view(view_size);
6217     invstd_p = saved_invstd.view(view_size);
6218     result_t = _norm_jvp(input_p, input_t, mean_p, invstd_p, dims, numel);
6219   } else {
6220     TORCH_INTERNAL_ASSERT(
6221         running_mean.has_value() && running_var.has_value(),
6222         "Expect running_mean and running_var to have value when train=false");
6223     TORCH_CHECK(
6224         !running_mean.value()._fw_grad(/*level=*/0).defined() &&
6225             !running_var.value()._fw_grad(/*level=*/0).defined(),
6226         "batch_norm is not differentiable wrt running_mean and running_var, they cannot have forward grad defined");
6227     mean_p = running_mean.value().view(view_size);
6228     invstd_p =
6229         (1 / at::sqrt(running_var.value() + at::Scalar(eps))).view(view_size);
6230     result_t = input_t * invstd_p;
6231   }
6232 
6233   std::optional<Tensor> result_p = weight_p.defined()
6234       ? std::optional<Tensor>((input_p - mean_p) * invstd_p)
6235       : std::nullopt;
6236   return _affine_jvp(
6237       result_p,
6238       result_t,
6239       weight_p.defined() ? weight_p.view(view_size) : weight_p,
6240       weight_t.defined() ? weight_t.view(view_size) : weight_t,
6241       bias_t.defined() ? bias_t.view(view_size) : bias_t);
6242 }
6243 
layer_norm_jvp(const Tensor & input_p,const Tensor & input_t,const Tensor & weight_p,const Tensor & weight_t,const Tensor & bias_p,const Tensor & bias_t,const Tensor & saved_mean,const Tensor & saved_invstd,c10::SymIntArrayRef normalized_shape)6244 Tensor layer_norm_jvp(
6245     const Tensor& input_p,
6246     const Tensor& input_t,
6247     const Tensor& weight_p,
6248     const Tensor& weight_t,
6249     const Tensor& bias_p,
6250     const Tensor& bias_t,
6251     const Tensor& saved_mean,
6252     const Tensor& saved_invstd,
6253     c10::SymIntArrayRef normalized_shape) {
6254   auto dims = std::vector<int64_t>{};
6255   auto view_size = input_t.sizes().vec();
6256   auto view_size_affine = input_t.sizes().vec();
6257 
6258   int64_t numel = 1;
6259   for (const auto i : c10::irange(view_size.size())) {
6260     if (i < view_size.size() - normalized_shape.size()) {
6261       view_size_affine[i] = 1;
6262     } else {
6263       numel *= input_t.size(static_cast<int64_t>(i));
6264       view_size[i] = 1;
6265       dims.push_back(static_cast<int64_t>(i));
6266     }
6267   }
6268   auto mean_p = saved_mean.view(view_size);
6269   auto invstd_p = saved_invstd.view(view_size);
6270   auto result_t = _norm_jvp(input_p, input_t, mean_p, invstd_p, dims, numel);
6271 
6272   std::optional<Tensor> result_p = weight_p.defined()
6273       ? std::optional<Tensor>((input_p - mean_p) * invstd_p)
6274       : std::nullopt;
6275   return _affine_jvp(
6276       result_p,
6277       result_t,
6278       weight_p.defined() ? weight_p.view(view_size_affine) : weight_p,
6279       weight_t.defined() ? weight_t.view(view_size_affine) : weight_t,
6280       bias_t.defined() ? bias_t.view(view_size_affine) : bias_t);
6281 }
6282 
group_norm_jvp(const Tensor & input_p,const Tensor & input_t,const Tensor & weight_p,const Tensor & weight_t,const Tensor & bias_p,const Tensor & bias_t,const Tensor & saved_mean,const Tensor & saved_invstd,int64_t groups)6283 Tensor group_norm_jvp(
6284     const Tensor& input_p,
6285     const Tensor& input_t,
6286     const Tensor& weight_p,
6287     const Tensor& weight_t,
6288     const Tensor& bias_p,
6289     const Tensor& bias_t,
6290     const Tensor& saved_mean,
6291     const Tensor& saved_invstd,
6292     int64_t groups) {
6293   auto input_shape = input_p.sizes();
6294   int64_t N = input_p.size(0);
6295   int64_t C = input_p.size(1);
6296 
6297   auto input_t_reshaped = input_t.view({1, N * groups, N ? -1 : 1});
6298   auto input_p_reshaped = input_p.view({1, N * groups, N ? -1 : 1});
6299 
6300   auto result_t = batch_norm_jvp(
6301                       input_p_reshaped,
6302                       input_t_reshaped,
6303                       /*weight_p=*/{},
6304                       /*weight_t=*/{},
6305                       /*bias_p=*/{},
6306                       /*bias_t=*/{},
6307                       /*running_mean=*/{},
6308                       /*running_var=*/{},
6309                       saved_mean,
6310                       saved_invstd,
6311                       /*train=*/true,
6312                       /*eps=*/0)
6313                       .view(input_shape);
6314 
6315   std::optional<Tensor> result_p = std::nullopt;
6316   if (weight_p.defined()) {
6317     std::vector<int64_t> view_size(input_t_reshaped.dim(), 1);
6318     view_size[1] = input_t_reshaped.size(1);
6319     result_p = ((input_p_reshaped - saved_mean.view(view_size)) *
6320                 saved_invstd.view(view_size))
6321                    .view(input_shape);
6322   }
6323   std::vector<int64_t> affine_param_shape(input_p.dim(), 1);
6324   affine_param_shape[1] = C;
6325 
6326   return _affine_jvp(
6327       result_p,
6328       result_t,
6329       weight_p.defined() ? weight_p.view(affine_param_shape) : weight_p,
6330       weight_t.defined() ? weight_t.view(affine_param_shape) : weight_t,
6331       bias_t.defined() ? bias_t.view(affine_param_shape) : bias_t);
6332 }
6333 
group_norm_mean_jvp(const Tensor & input_t,const Tensor & mean_p,int64_t groups)6334 Tensor group_norm_mean_jvp(
6335     const Tensor& input_t,
6336     const Tensor& mean_p,
6337     int64_t groups) {
6338   int64_t N = input_t.size(0);
6339   std::array<int64_t, 3> view_shape = {1, N * groups, N ? -1 : 1};
6340   auto input_t_reshaped = input_t.view(view_shape);
6341   return input_t_reshaped.mean({2}, false).view_as(mean_p);
6342 }
6343 
group_norm_invstd_jvp(const Tensor & input_p,const Tensor & input_t,const Tensor & mean_p,const Tensor & invstd_p,int64_t groups)6344 Tensor group_norm_invstd_jvp(
6345     const Tensor& input_p,
6346     const Tensor& input_t,
6347     const Tensor& mean_p,
6348     const Tensor& invstd_p,
6349     int64_t groups) {
6350   int64_t N = input_p.size(0);
6351 
6352   std::vector<int64_t> view_shape = {1, N * groups, N ? -1 : 1};
6353 
6354   auto input_t_reshaped = input_t.view(view_shape);
6355   auto input_p_reshaped = input_p.view(view_shape);
6356 
6357   return _invstd_jvp(
6358              input_t_reshaped,
6359              input_p_reshaped,
6360              mean_p.view(view_shape),
6361              invstd_p.view(view_shape),
6362              /*dims=*/{2},
6363              /*numel=*/input_t_reshaped.size(2),
6364              /*keepdim=*/false)
6365       .view_as(invstd_p);
6366 }
6367 
gather_with_keepdimed_indices(const Tensor & input,int64_t dim,const Tensor & indices,bool keepdim)6368 Tensor gather_with_keepdimed_indices(
6369     const Tensor& input,
6370     int64_t dim,
6371     const Tensor& indices,
6372     bool keepdim) {
6373   auto full_indices = indices;
6374   if (!keepdim) {
6375     full_indices = indices.unsqueeze(dim);
6376   }
6377   auto out_fw_grad = at::gather(input, dim, full_indices);
6378   if (!keepdim) {
6379     out_fw_grad = out_fw_grad.squeeze(dim);
6380   }
6381 
6382   return out_fw_grad;
6383 }
6384 
6385 // Let A in \C^{m \times n}, then its pivoted LU decomposition is
6386 // A = P L U, where P is a permutation matrix.
6387 //
6388 // Useful notation:
6389 // Let o denote the elementwise, or Hadamard, product.
6390 // k := min(m, n)
6391 // 1 := ones(k, k),
6392 // 1_U = 1.triu();
6393 // 1_L = 1 - 1_U (note the diagonal is zero)
6394 // For a matrix A, A^H := A.mH()
6395 //
6396 // Below we derive the backward algorithm for the case when m <= n.
6397 // The case m > n could be obtained using the same idea.
6398 // Since we assume m <= n, the LU decomposition of A could be written as
6399 // A = (A1 | A2) = P L (U1 | U2) where A1, U1 in \C^{m \times m}, A2, U2 in
6400 // \C^{m \times (n - m)}
6401 //
6402 // Forward AD:
6403 //
6404 // dA = P dL U + P L dU => [left-multiply P^T]
6405 // (P^T dA1 | P^T dA2) = (dL U1 + L dU1 | dL U2 + L dU2) (*)
6406 // From (*):
6407 // P^T dA1 = dL U1 + L dU1 => [left-multiply by L^{-1}, right-multiply by
6408 // U1^{-1}] L^{-1} P^T dA1 U1^{-1} = L^{-1} dL + dU1 U1^{-1} (**). Note, L is
6409 // lower-triangular, and so is its inverse, hence L^{-1} dL is lower-triangular.
6410 // Also, since the diagonal of L (all ones) is never exposed explicitly (packed
6411 // representation), the diagonal of dL is zero, and hence diag(L^{-1} dL) = 0.
6412 // Assuming that U1 is full-rank, similarly, dU1 U1^{-1} is upper-triangular.
6413 // Combining these observations we conclude:
6414 //
6415 // L^{-1} dL = (L^{-1} P^T dA1 U1^{-1}) o 1_L,
6416 // dU1 U1^{-1} = (L^{-1} P^T dA1 U1^{-1}) o 1_U.
6417 //
6418 // Hence,
6419 // dL = L [(L^{-1} P^T dA1 U1^{-1}) o 1_L],
6420 // dU1 = [(L^{-1} P^T dA1 U1^{-1}) o 1_U] U1.
6421 // As for dU2, from (*) it follows
6422 // P^T dA2 = dL U2 + L dU2 =>
6423 // dU2 = L^{-1} (P^T dA2 - dL U2).
6424 //
6425 // Backward AD:
6426 //
6427 // The following equality comes very handy:
6428 // Tr(A (B o C)) = Tr((A o B^T) C) (!)
6429 // or in other words, given that X -> B o X is a pointwise operation
6430 // its Jacobian is diagonal, so its differential is self-adjoint
6431 // <A, B o C> = <A o B, C>
6432 //
6433 // Tr(A_grad^H dA) = Tr(L_grad^H dL) + Tr(U_grad^H dU), then
6434 //
6435 // Tr(L_grad^H dL) = Tr(L_grad^H L [(L^{-1} P^T dA1 U1^{-1}) o 1_L])
6436 //                 = [using (!)]
6437 //                 = Tr((L_grad^H L o 1_L^T) L^{-1} P^T dA1 U1^{-1})
6438 //                 = [using the cyclic property of Tr]
6439 //                 = Tr(U1^{-1} (L_grad^H L o 1_L^T) L^{-1} P^T dA1)
6440 //
6441 // Similar, using (!) and the cyclic property of the trace operator:
6442 // Tr(U_grad^H dU) = Tr(U1_grad^H dU1) + Tr(U2_grad^H dU2)
6443 //                 = Tr(U1^{-1} (U1 U1_grad^H o 1_U^T) L^{-1} P^T dA1)
6444 //                   + Tr(U2_grad^H L^{-1} P^T dA2)
6445 //                   - Tr(U1^{-1} (U2 U2_grad^H o 1_L^T) L^{-1} P^T dA1)
6446 //
6447 // By combining the matrices to the left from dA1 and dA2 and then applying
6448 // conjugate transposition, we finally arrive at:
6449 //
6450 // A1_grad = P L^{-H} [L^H L_grad o 1_L + U1_grad U1^H o 1_U - U2_grad U2^H o 1_L] U1^{-H},
6451 // A2_grad = P L^{-H} U2_grad
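// A small sketch exercising the function below (illustrative only; assumes the
// public ATen/autograd C++ API):
//
//   auto A = at::randn({3, 5}, at::kDouble).requires_grad_(true);  // wide case, m < n
//   auto [P, L, U] = at::linalg_lu(A);
//   (L.sum() + U.sum()).backward();
//   // A.grad() should be produced by linalg_lu_backward.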
linalg_lu_backward(const Tensor & L_grad,const Tensor & U_grad,const Tensor & P,const Tensor & L,const Tensor & U,const bool pivot)6452 Tensor linalg_lu_backward(
6453     const Tensor& L_grad,
6454     const Tensor& U_grad,
6455     const Tensor& P,
6456     const Tensor& L,
6457     const Tensor& U,
6458     const bool pivot) {
6459   at::NoTF32Guard disable_tf32;
6460   // Return early if there's nothing to do
6461   if (!L_grad.defined() && !U_grad.defined()) {
6462     return {};
6463   }
6464 
6465   // L.shape == (..., m, k)
6466   // U.shape == (..., k, n)
6467   auto m = L.sym_size(-2);
6468   auto n = U.sym_size(-1);
6469   auto k = std::min(m, n);
6470 
6471   if (m == n) {
6472     // A_grad = P L^{-H} [L^H L_grad o 1_L + U_grad U^H o 1_U] U^{-H},
6473     auto A_grad = L_grad.defined() ? L.mH().matmul(L_grad).tril(-1) : Tensor{};
6474     if (U_grad.defined()) {
6475       A_grad = A_grad.defined() ? A_grad + U_grad.matmul(U.mH()).triu()
6476                                 : U_grad.matmul(U.mH()).triu();
6477     }
6478     A_grad = at::linalg_solve_triangular(
6479         U.mH(),
6480         A_grad,
6481         /*upper=*/false,
6482         /*left=*/false);
6483     A_grad = at::linalg_solve_triangular(
6484         L.mH(),
6485         A_grad,
6486         /*upper=*/true,
6487         /*left=*/true,
6488         /*unitriangular=*/true);
6489 
6490     return pivot ? P.matmul(A_grad) : std::move(A_grad);
6491   } else if (m < n) {
6492     // Wide case
6493     // A1_grad = P L^{-H} [U1_grad o 1_U + ((L^H L_grad - U_grad U^H) o 1_L) U1^{-H}]
6494     // A2_grad = P L^{-H} U2_grad
6495     const auto get_U1 = [n, k](const Tensor& U) {
6496       return n == k ? U : U.narrow_symint(-1, 0, k);
6497     };
6498     const auto get_U2 = [n, k](const Tensor& U) {
6499       return U.narrow_symint(-1, k, n - k);
6500     };
6501 
6502     auto A_grad = L_grad.defined() ? L.mH().matmul(L_grad) : Tensor{};
6503     if (U_grad.defined()) {
6504       A_grad = A_grad.defined() ? A_grad - U_grad.triu().matmul(U.mH())
6505                                 : -U_grad.triu().matmul(U.mH());
6506     }
6507     A_grad = at::linalg_solve_triangular(
6508         get_U1(U).mH(),
6509         A_grad.tril(-1),
6510         /*upper=*/false,
6511         /*left=*/false);
6512 
6513     if (U_grad.defined()) {
6514       A_grad =
6515           at::cat({A_grad + get_U1(U_grad).triu(), get_U2(U_grad)}, /*dim=*/-1);
6516     }
6517 
6518     A_grad = at::linalg_solve_triangular(
6519         L.mH(),
6520         A_grad,
6521         /*upper=*/true,
6522         /*left=*/true,
6523         /*unitriangular=*/true);
6524 
6525     if (!U_grad.defined()) {
6526       A_grad = at::cat({A_grad, at::zeros_like(get_U2(U))}, /*dim=*/-1);
6527     }
6528     if (pivot) {
6529       A_grad = P.matmul(A_grad);
6530     }
6531     return A_grad;
6532   } else {
6533     // Tall case
6534     // A1_grad = P [L1_grad o 1_L + L1^{-H} ((U_grad U^H - L^H L_grad) o 1_U)] U^{-H}
6535     // A2_grad = P L2_grad U^{-H}
6536 
6537     const auto get_L1 = [m, k](const Tensor& L) {
6538       return m == k ? L : L.narrow_symint(-2, 0, k);
6539     };
6540     const auto get_L2 = [m, k](const Tensor& L) {
6541       return L.narrow_symint(-2, k, m - k);
6542     };
6543 
6544     auto A_grad = U_grad.defined() ? U_grad.matmul(U.mH()) : Tensor{};
6545     if (L_grad.defined()) {
6546       A_grad = A_grad.defined() ? A_grad - L.mH().matmul(L_grad.tril(-1))
6547                                 : -L.mH().matmul(L_grad.tril(-1));
6548     }
6549     A_grad = at::linalg_solve_triangular(
6550         get_L1(L).mH(),
6551         A_grad.triu(),
6552         /*upper=*/true,
6553         /*left=*/true,
6554         /*unitriangular=*/true);
6555 
6556     if (L_grad.defined()) {
6557       A_grad = at::cat(
6558           {A_grad + get_L1(L_grad).tril(-1), get_L2(L_grad)}, /*dim=*/-2);
6559     }
6560 
6561     A_grad = at::linalg_solve_triangular(
6562         U.mH(),
6563         A_grad,
6564         /*upper=*/false,
6565         /*left=*/false);
6566 
6567     if (!L_grad.defined()) {
6568       A_grad = at::cat({A_grad, at::zeros_like(get_L2(L))}, /*dim=*/-2);
6569     }
6570     if (pivot) {
6571       A_grad = P.matmul(A_grad);
6572     }
6573     return A_grad;
6574   }
6575 }
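
// A minimal usage sketch for the backward above, assuming the standard
// at::linalg_lu factorization API (illustrative only):
//
//   at::Tensor A = at::randn({3, 3}, at::kDouble);
//   auto [P, L, U] = at::linalg_lu(A, /*pivot=*/true);
//   at::Tensor L_grad = at::ones_like(L);
//   at::Tensor U_grad = at::ones_like(U);
//   at::Tensor A_grad =
//       linalg_lu_backward(L_grad, U_grad, P, L, U, /*pivot=*/true);
//   // A_grad has the same shape as A; only the strictly lower part of L_grad
//   // and the upper part of U_grad contribute to it.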
6576 
6577 Tensor lu_factor_ex_backward(
6578     const Tensor& grad,
6579     const Tensor& LU,
6580     const Tensor& pivs,
6581     const bool pivot) {
6582   auto [P, L, U] =
6583       at::lu_unpack(LU, pivs, /*unpack_data=*/true, /*unpack_pivots*/ pivot);
6584 
6585   // L.shape == (..., m, k)
6586   // U.shape == (..., k, n)
6587   const auto m = LU.sym_size(-2);
6588   const auto n = LU.sym_size(-1);
6589   const auto k = std::min(m, n);
6590   const auto L_grad = grad.narrow_symint(-1, 0, k);
6591   const auto U_grad = grad.narrow_symint(-2, 0, k);
6592   return linalg_lu_backward(
6593       /*L_grad=*/L_grad, /*U_grad=*/U_grad, P, L, U, pivot);
6594 }
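
// Shape bookkeeping for the splitting above (k = min(m, n), sketch):
//   grad has the shape of LU, i.e. (..., m, n)
//   grad.narrow(-1, 0, k) has shape (..., m, k) and is passed as L_grad
//   grad.narrow(-2, 0, k) has shape (..., k, n) and is passed as U_grad
// The two views overlap in the leading (..., k, k) block, but
// linalg_lu_backward only consumes the strictly lower part of L_grad and the
// upper part of U_grad, so in effect the overlap is not double counted.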
6595 
6596 // This function is based on the forward AD derivations outlined
6597 // in the description to the linalg_lu_backward function.
6598 std::tuple<Tensor, Tensor> linalg_lu_jvp(
6599     const Tensor& dA,
6600     const Tensor& P,
6601     const Tensor& L,
6602     const Tensor& U,
6603     const bool pivot) {
6604   at::NoTF32Guard disable_tf32;
6605 
6606   auto m = dA.size(-2);
6607   auto n = dA.size(-1);
6608   auto k = std::min(m, n);
6609 
6610   auto PdA = pivot ? P.mT().matmul(dA) : dA;
6611 
6612   // similar to the backward implementation, we also consider block structures
6613   // such as: for a matrix A of size m x n we decompose it as A = (A1 | A2) with
6614   // A1 of size m x m if m <= n and A = (A1^T | A2^T)^T with A1 of size n x n if
6615   // m > n.
6616   auto PdA1 = PdA.narrow(-2, 0, k).narrow(-1, 0, k);
6617   auto L1 = L.narrow(-2, 0, k).narrow(-1, 0, k);
6618   auto U1 = U.narrow(-2, 0, k).narrow(-1, 0, k);
6619 
6620   // We form the matrix dK = L1^{-1} PdA1 U1^{-1} using two triangular solves;
6621   // ideally the second one would be done in place (see the TODO below)
6622   auto dK = at::linalg_solve_triangular(
6623       L1, PdA1, /*upper=*/false, /*left=*/true, /*unitriangular*/ true);
6624 
6625   // TODO We should be able to do this in-place. At the moment it raises:
6626   //  RuntimeError: linalg_solve_triangular(): functions with out=...
6627   //  arguments don't support automatic differentiation, but one of the
6628   //  arguments requires grad.
6629 
6630   //  at::linalg_solve_triangular_out(dK, U1, dK, /*upper=*/true,
6631   //  /*left=*/false);
6632   dK = at::linalg_solve_triangular(U1, dK, /*upper=*/true, /*left=*/false);
6633 
6634   auto dL1 = L1.matmul(dK.tril(-1));
6635   auto dU1 = dK.triu().matmul(U1);
6636 
6637   if (m == n) {
6638     return std::make_tuple(std::move(dL1), std::move(dU1));
6639   } else if (m < n) {
6640     // we only need to update dU2 defined as
6641     // dU2 := L1^{-1} PdA2 - dK.tril(-1) U2
6642     const auto PdA2 = PdA.narrow(-1, k, n - k);
6643     const auto U2 = U.narrow(-1, k, n - k);
6644     auto dU2 =
6645         at::linalg_solve_triangular(
6646             L1, PdA2, /*upper=*/false, /*left=*/true, /*unitriangular*/ true) -
6647         dK.tril(-1).matmul(U2);
6648     return std::make_tuple(
6649         std::move(dL1), at::cat({std::move(dU1), std::move(dU2)}, /*dim=*/-1));
6650   } else {
6651     // we only need to update dL2 defined as
6652     // dL2 := PdA2 U^{-1} - L2 dK.triu()
6653     const auto PdA2 = PdA.narrow(-2, k, m - k);
6654     const auto L2 = L.narrow(-2, k, m - k);
6655     auto dL2 =
6656         at::linalg_solve_triangular(U1, PdA2, /*upper=*/true, /*left=*/false) -
6657         L2.matmul(dK.triu());
6658     return std::make_tuple(
6659         at::cat({std::move(dL1), std::move(dL2)}, /*dim=*/-2), std::move(dU1));
6660   }
6661 }
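
// Derivation sketch for the leading square block above: from P^T A = L U,
//   P^T dA = dL U + L dU
//   =>  L1^{-1} (P^T dA) U1^{-1} = L1^{-1} dL + dU U1^{-1} =: dK
// L1^{-1} dL is strictly lower triangular (L has a unit diagonal) and
// dU U1^{-1} is upper triangular, so the two pieces are recovered as
//   dL = L1 * tril(dK, -1)   and   dU = triu(dK) * U1,
// which is exactly what dL1 / dU1 compute.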
6662 
6663 Tensor lu_factor_ex_jvp(
6664     const Tensor& dA,
6665     const Tensor& LU,
6666     const Tensor& pivs,
6667     const bool pivot) {
6668   auto [P, L, U] =
6669       at::lu_unpack(LU, pivs, /*unpack_data=*/true, /*unpack_pivots=*/pivot);
6670   auto [dL, dU] = linalg_lu_jvp(dA, P, L, U, pivot);
6671 
6672   auto m = dA.size(-2);
6673   auto n = dA.size(-1);
6674   if (m >= n) {
6675     dL.narrow(-2, 0, n).add_(dU);
6676     return dL;
6677   } else {
6678     dU.narrow(-1, 0, m).add_(dL);
6679     return dU;
6680   }
6681 }
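
// Packing sketch for the tangent of the compact LU representation: dL1 is
// strictly lower triangular and dU1 is upper triangular, so adding them into
// the leading k x k block never mixes entries. For m >= n the packed result is
//
//   dLU = [ tril(dL1, -1) + dU1 ]   <- top n x n block
//         [         dL2         ]   <- remaining (m - n) rows
//
// and symmetrically for m < n, with the extra columns coming from dU2.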
6682 
6683 Tensor logsumexp_jvp(
6684     const Tensor& self_p,
6685     const Tensor& self_t,
6686     IntArrayRef dim,
6687     bool keepdim) {
6688   // NB: for simplicity, we recompute some values that can be reused from
6689   // forward
6690   auto self_p_exp = [&self_p, &dim]() {
6691     if (self_p.sym_numel() > 0) {
6692       // Use only the real part for complex tensors
6693       return (self_p - at::amax(at::real(self_p), dim, true))
6694           .exp(); // Use the exp-normalize trick
6695     } else {
6696       // amax fails if numel() == 0, in which case it doesn't matter anyway
6697       return self_p.exp();
6698     }
6699   }();
6700 
6701   auto sumexp_p = self_p_exp.sum(dim, keepdim);
6702 
6703   // NB: it's OK for logsumexp_jvp to be reused for formulas like
6704   // softmax/log_softmax
6705   //     that only have one differentiable input, because that means self_t is
6706   //     never a ZeroTensor
6707   TORCH_INTERNAL_ASSERT(!self_t._is_zerotensor())
6708   if (areAnyTensorSubclassLike({self_p, self_t})) {
6709     auto result = (self_p_exp * self_t).sum(dim, keepdim);
6710     result /= sumexp_p;
6711     return result;
6712   } else {
6713     self_p_exp *= self_t;
6714     auto sumexp_t = self_p_exp.sum(dim, keepdim);
6715     return sumexp_t /= sumexp_p;
6716   }
6717 }
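
// Formula being computed above (sketch): with s = logsumexp(x, dim),
//   ds = sum_i softmax(x)_i * dx_i
//      = sum_i exp(x_i - c) * dx_i / sum_j exp(x_j - c)   for any shift c.
// Choosing c = amax(real(x), dim) (the exp-normalize trick) keeps the
// intermediate exp() finite; the shift cancels in the ratio.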
6718 
6719 Tensor safe_logsumexp_jvp(
6720     const Tensor& self_p,
6721     const Tensor& self_t,
6722     IntArrayRef dim,
6723     bool keepdim) {
6724   auto lse_jvp = logsumexp_jvp(self_p, self_t, dim, keepdim);
6725   const auto neg_inf = at::scalar_tensor(
6726       -std::numeric_limits<float>::infinity(),
6727       at::TensorOptions().dtype(lse_jvp.dtype()).device(lse_jvp.device()));
6728   const auto masked = self_p.eq(neg_inf);
6729   const auto masked_rows = all(masked, dim, true);
6730   const auto zero = at::scalar_tensor(
6731       0.0, at::TensorOptions().dtype(lse_jvp.dtype()).device(lse_jvp.device()));
6732   return at::where(masked_rows, zero, lse_jvp);
6733 }
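
// Why the extra masking (sketch): if every entry of self_p along `dim` is
// -inf, then logsumexp is -inf and the unguarded JVP above hits
// (-inf) - (-inf) = NaN inside the exp-normalize step. For those fully-masked
// rows the output is a constant -inf, so their tangent is defined to be 0 here.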
6734 
6735 Tensor warn_backwards(const Tensor& grad_output) {
6736   TORCH_WARN("Warn from backward");
6737   return grad_output;
6738 }
6739 
6740 // This function only exists because cuDNN does not support bias gradient
6741 // computation and it's not easy to slice a std::tuple to return only grad_input
6742 // / grad_weight from convolution_backward. It will be removed when the
6743 // cudnn_convolution and cudnn_convolution_transpose go away.
6744 std::tuple<Tensor, Tensor> _cudnn_convolution_backward(
6745     const at::Tensor& self,
6746     const at::Tensor& grad_output,
6747     const at::Tensor& weight,
6748     at::SymIntArrayRef padding,
6749     at::SymIntArrayRef output_padding,
6750     at::SymIntArrayRef stride,
6751     at::SymIntArrayRef dilation,
6752     bool transposed,
6753     c10::SymInt groups,
6754     ::std::array<bool, 2> output_mask) {
6755   if (!grad_output.defined()) {
6756     return std::tuple<Tensor, Tensor>();
6757   }
6758 
6759   // Just call the general backward and ignore the bias gradient part.
6760   std::tuple<Tensor, Tensor, Tensor> grad_inputs =
6761       at::convolution_backward_symint(
6762           grad_output,
6763           self,
6764           weight,
6765           std::nullopt,
6766           stride,
6767           padding,
6768           dilation,
6769           transposed,
6770           output_padding,
6771           std::move(groups),
6772           {output_mask[0], output_mask[1], false});
6773   std::tuple<Tensor, Tensor> result =
6774       std::make_tuple(std::get<0>(grad_inputs), std::get<1>(grad_inputs));
6775   return result;
6776 }
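
// Mapping sketch: output_mask here is {grad_input?, grad_weight?}; the call
// above forwards {output_mask[0], output_mask[1], /*bias*/ false} to
// convolution_backward and drops std::get<2>() (the bias gradient) from the
// returned triple.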
6777 
6778 Tensor scatter_reduce_jvp(
6779     const Tensor& self_p,
6780     const Tensor& self_t,
6781     int dim,
6782     const Tensor& index,
6783     const Tensor& src_p,
6784     const Tensor& src_t,
6785     c10::string_view reduce,
6786     bool include_self,
6787     const Tensor& result) {
6788   if (reduce == "sum" || reduce == "mean") {
6789     // The function is linear
6790     return at::scatter_reduce(self_t, dim, index, src_t, reduce, include_self);
6791     //  auto mask = x == restore_reduced_dims(result, dim, keepdim);
6792     //  return at::where(mask, dx, 0.).sum(dim, keepdim) / mask.sum(dim,
6793     //  keepdim);
6794   } else if (reduce == "amin" || reduce == "amax") {
6795     auto gather_result = at::gather(result, dim, index);
6796     auto mask_self = self_p == result;
6797     auto mask_src = src_p == gather_result;
6798     auto masked_src_t = at::where(mask_src, src_t, 0.);
6799     auto div =
6800         mask_self.to(self_t.dtype())
6801             .scatter_reduce(
6802                 dim, index, mask_src.to(self_t.dtype()), "sum", include_self);
6803     return at::where(mask_self, self_t, 0.)
6804         .scatter_reduce(dim, index, masked_src_t, "sum", include_self)
6805         .div(div);
6806   } else {
6807     // Not implemented
6808     return Tensor{};
6809   }
6810 }
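
// amin/amax JVP sketch: every element whose value ties with the reduced
// extremum contributes its tangent, and the sum is divided by the number of
// ties, i.e. for a given output position
//   d result = (sum of tangents of tying self/src entries) / (#ties),
// mirroring the even-split convention used in scatter_reduce_backward below.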
6811 
6812 std::tuple<Tensor, Tensor> scatter_reduce_backward(
6813     const Tensor& grad,
6814     const Tensor& self,
6815     int dim,
6816     const Tensor& index,
6817     const Tensor& src,
6818     c10::string_view reduce,
6819     bool include_self,
6820     const Tensor& result) {
6821   Tensor grad_self, grad_src;
6822 
6823   // FIXME: complex gradients not handled correctly
6824   // For now this is ok as scatter_reduce isn't added to the whitelist
6825   // in tools/autograd/gen_variable_type.py
6826 
6827   if (!grad.defined()) {
6828     return std::make_tuple(grad_self, grad_src);
6829   }
6830 
6831   if (reduce == "sum") {
6832     grad_self = grad;
6833     grad_src = grad.gather(dim, index);
6834   } else if (reduce == "prod") {
6835     // Explicitly compute exclusive prod for elements in self/src that are 0
6836     Tensor masked_self = self.masked_fill(self == 0, 1);
6837     Tensor masked_self_result =
6838         masked_self.scatter_reduce(dim, index, src, reduce, include_self);
6839     grad_self = grad * masked_self_result / masked_self;
6840     Tensor src_zero = src == 0;
6841     Tensor src_num_zeros =
6842         zeros_like(self)
6843             .scatter_add(dim, index, src_zero.to(self.dtype()))
6844             .gather(dim, index);
6845     Tensor src_single_zero = bitwise_and(src_zero, src_num_zeros == 1);
6846     // For src positions with src_single_zero, grad * result.gather(dim,index) /
6847     // src.masked_fill(src_zero, 1) would incorrectly propagate zeros as the
6848     // gradient
6849     Tensor masked_src = src.masked_fill(src_single_zero, 1);
6850     Tensor masked_src_result =
6851         self.scatter_reduce(dim, index, masked_src, reduce, include_self);
6852     Tensor grad_src1 = where(
6853         src_single_zero,
6854         (grad * masked_src_result).gather(dim, index),
6855         (grad * result).gather(dim, index) / src.masked_fill(src_zero, 1));
6856     // GradMode::is_enabled() - adding the autograd Node is a no-op if autograd
6857     // is disabled; this also avoids having the item() call in the usual case.
6858     if (GradMode::is_enabled() && (src_num_zeros > 1).any().item<bool>()) {
6859       auto node = std::make_shared<DelayedError>(
6860           "scatter_reduce(): Double backward is unsupported for src when >1 zeros in src are scattered to the same position in self",
6861           /* num inputs */ 1);
6862       auto result = node->apply({std::move(grad_src1)});
6863       grad_src = result[0];
6864     } else {
6865       grad_src = grad_src1;
6866     }
6867   } else if (reduce == "mean") {
6868     Tensor N = include_self ? ones_like(grad) : zeros_like(grad);
6869     N = N.scatter_add(dim, index, ones_like(src));
6870     N.masked_fill_(N == 0, 1);
6871     grad_self = grad / N;
6872     Tensor N_src = N.gather(dim, index);
6873     grad_src = grad.gather(dim, index) / N_src;
6874   } else if (reduce == "amax" || reduce == "amin") {
6875     // Evenly distribute gradient when there are multiple max/mins
6876     Tensor value = result.gather(dim, index);
6877     Tensor self_is_result = (self == result).to(self.scalar_type());
6878     Tensor src_is_result = (src == value).to(self.scalar_type());
6879     Tensor N_to_distribute =
6880         self_is_result.scatter_add(dim, index, src_is_result);
6881     Tensor grad_distributed = grad / N_to_distribute;
6882     grad_self = (self == result) * grad_distributed;
6883     grad_src = (src == value) * grad_distributed.gather(dim, index);
6884   } else {
6885     AT_ERROR(
6886         "Expected 'reduce' to be one of 'sum', 'prod', 'mean', 'amax', 'amin' but got ",
6887         reduce,
6888         ".");
6889   }
6890 
6891   if (!include_self) {
6892     grad_self = grad_self.scatter(dim, index, 0);
6893   }
6894 
6895   return std::make_tuple(grad_self, grad_src);
6896 }
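
// "prod" zero-handling sketch: the gradient w.r.t. one element is the product
// of all the other elements reduced into the same bucket. When an element is
// the single zero of its bucket, result / element would be 0 / 0, so the
// exclusive product is recomputed with that zero masked to 1. Buckets with two
// or more zeros give gradient 0 for their zero entries, and double backward
// for that case is explicitly blocked via DelayedError above.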
6897 
6898 Tensor _to_copy_backward(
6899     const Tensor& grad_,
6900     const c10::TensorOptions& self_options) {
6901   // Handle R->C copies without raising a warning
6902   const auto self_type = self_options.dtype().toScalarType();
6903   auto grad = c10::MaybeOwned<at::Tensor>::borrowed(grad_);
6904   if (!c10::isComplexType(self_type) && grad->is_complex()) {
6905     grad = c10::MaybeOwned<at::Tensor>::owned(at::real(grad_));
6906   }
6907 
6908   return grad->to(self_options, /*non_blocking=*/false, /*copy=*/false);
6909 }
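
// R -> C sketch: if the forward copied a real tensor into a complex one, the
// incoming gradient is complex; taking at::real() is the usual convention for
// differentiating through an R -> C embedding, and it avoids the
// "discards the imaginary part" warning a plain .to() would emit.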
6910 
6911 std::tuple<Tensor, Tensor> index_reduce_backward(
6912     const Tensor& grad,
6913     const Tensor& self,
6914     int dim,
6915     const Tensor& index,
6916     const Tensor& source,
6917     c10::string_view reduce,
6918     bool include_self,
6919     const Tensor& result) {
6920   Tensor grad_self, grad_src;
6921 
6922   // FIXME: index_add's backward formula has a special case for source.dim == 0
6923   // but this case seems to throw the error "IndexError: dimension specified as
6924   // 0 but tensor has no dimensions" look into whether this case is reachable
6925   // and should be covered here
6926 
6927   if (!grad.defined()) {
6928     return std::make_tuple(grad_self, grad_src);
6929   }
6930 
6931   if (reduce == "prod") {
6932     Tensor masked_self = self.masked_fill(self == 0, 1);
6933     Tensor masked_self_result =
6934         masked_self.index_reduce(dim, index, source, reduce, include_self);
6935     grad_self = grad * masked_self_result / masked_self;
6936     Tensor src_zero = source == 0;
6937     Tensor src_num_zeros = zeros_like(self)
6938                                .index_add(dim, index, src_zero.to(self.dtype()))
6939                                .index_select(dim, index);
6940     Tensor src_single_zero = bitwise_and(src_zero, src_num_zeros == 1);
6941     // For src positions with src_single_zero, (grad *
6942     // result).index_select(dim,index) / source.masked_fill(src_zero, 1) would
6943     // incorrectly propagate zeros as the gradient
6944     Tensor masked_src = source.masked_fill(src_single_zero, 1);
6945     Tensor masked_src_result =
6946         self.index_reduce(dim, index, masked_src, reduce, include_self);
6947     Tensor grad_src1 = where(
6948         src_single_zero,
6949         (grad * masked_src_result).index_select(dim, index),
6950         (grad * result).index_select(dim, index) /
6951             source.masked_fill(src_zero, 1));
6952     // GradMode::is_enabled() - adding the autograd Node is a no-op if autograd
6953     // is disabled; this also avoids having the item() call in the usual case
6954     if (GradMode::is_enabled() && (src_num_zeros > 1).any().item<bool>()) {
6955       auto node = std::make_shared<DelayedError>(
6956           "index_reduce(): Double backward is unsupported for source when >1 zeros in source are scattered to the same position in self",
6957           /* num inputs */ 1);
6958       auto result = node->apply({std::move(grad_src1)});
6959       grad_src = result[0];
6960     } else {
6961       grad_src = grad_src1;
6962     }
6963   } else if (reduce == "mean") {
6964     Tensor N = include_self ? ones_like(grad) : zeros_like(grad);
6965     N = N.index_add(dim, index, ones_like(source));
6966     N.masked_fill_(N == 0, 1);
6967     grad_self = grad / N;
6968     Tensor N_src = N.index_select(dim, index);
6969     grad_src = grad.index_select(dim, index) / N_src;
6970   } else if (reduce == "amax" || reduce == "amin") {
6971     Tensor value = result.index_select(dim, index);
6972     Tensor self_is_result = (self == result).to(self.scalar_type());
6973     Tensor source_is_result = (source == value).to(self.scalar_type());
6974     Tensor N_to_distribute =
6975         self_is_result.index_add(dim, index, source_is_result);
6976     Tensor grad_distributed = grad / N_to_distribute;
6977     grad_self = self_is_result * grad_distributed;
6978     grad_src = source_is_result * grad_distributed.index_select(dim, index);
6979   } else {
6980     AT_ERROR(
6981         "Expected 'reduce' to be one of 'prod', 'amax', 'amin' or 'mean' but got ",
6982         reduce,
6983         ".");
6984   }
6985 
6986   if (!include_self) {
6987     grad_self = grad_self.index_fill(dim, index, 0);
6988   }
6989 
6990   return std::make_tuple(grad_self, grad_src);
6991 }
6992 
6993 Tensor take_backward(
6994     const Tensor& grad,
6995     const Tensor& self,
6996     const Tensor& indices) {
6997   Tensor grad_self = at::zeros_like(self);
6998   // For Composite Compliance,
6999   // if `grad` and `indices` are CCT but `grad_self` is not
7000   // then we use the out-of-place variant of `put`.
7001   if (areAnyTensorSubclassLike({grad, indices})) {
7002     return grad_self.put(indices, grad, true);
7003   }
7004   return grad_self.put_(indices, grad, true);
7005 }
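
// Accumulation sketch: at::take indexes the flattened input, so the backward
// scatters grad back to those flat positions. E.g. (illustrative) with
// self.numel() == 4 and indices = {0, 3, 0}, position 0 of grad_self receives
// grad[0] + grad[2] and position 3 receives grad[1]; accumulate=true in
// put/put_ is what makes the repeated index add instead of overwrite.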
7006 
7007 Tensor to_sparse_backward(
7008     const Tensor& grad,
7009     const c10::Layout self_layout,
7010     const c10::OptionalArrayRef<c10::SymInt>& self_blocksize) {
7011   // Path for strided and nested
7012   if (self_layout == c10::kStrided) {
7013     return grad.to_dense();
7014   } else {
7015     OptionalIntArrayRef blocksize = std::nullopt;
7016     if (self_blocksize.has_value()) {
7017       blocksize = c10::asIntArrayRefSlowOpt(*self_blocksize);
7018     }
7019     return grad.to_sparse(self_layout, blocksize);
7020   }
7021 }
7022 
7023 std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor>
7024 mkldnn_rnn_layer_differentiable_backward(
7025     const Tensor& input,
7026     const Tensor& weight0,
7027     const Tensor& weight1,
7028     const Tensor& weight2,
7029     const Tensor& weight3,
7030     const Tensor& hx_,
7031     const Tensor& cx_tmp,
7032     const Tensor& output,
7033     const Tensor& hy_,
7034     const Tensor& cy_,
7035     const std::optional<Tensor>& grad_output_r_opt,
7036     const std::optional<Tensor>& grad_hy_r_opt,
7037     const std::optional<Tensor>& grad_cy_r_opt,
7038     bool reverse,
7039     int64_t mode,
7040     int64_t hidden_size,
7041     int64_t num_layers,
7042     bool has_biases,
7043     bool train,
7044     bool bidirectional,
7045     at::IntArrayRef batch_sizes,
7046     bool batch_first,
7047     const at::Tensor& workspace) {
7048   const Tensor& grad_output_r =
7049       c10::value_or_else(grad_output_r_opt, [] { return Tensor(); });
7050   const Tensor& grad_hy_r =
7051       c10::value_or_else(grad_hy_r_opt, [] { return Tensor(); });
7052   const Tensor& grad_cy_r =
7053       c10::value_or_else(grad_cy_r_opt, [] { return Tensor(); });
7054   if (!grad_output_r.defined() && !grad_hy_r.defined() &&
7055       !grad_cy_r.defined()) {
7056     return std::make_tuple(
7057         Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor());
7058   }
7059   auto grad_output = grad_output_r.defined()
7060       ? grad_output_r.contiguous()
7061       : at::zeros_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
7062   auto grad_hy = grad_hy_r.defined()
7063       ? grad_hy_r.contiguous()
7064       : at::zeros_like(hx_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
7065   auto grad_cy = cx_tmp.defined()
7066       ? (grad_cy_r.defined()
7067              ? grad_cy_r.contiguous()
7068              : at::zeros_like(cx_tmp, LEGACY_CONTIGUOUS_MEMORY_FORMAT))
7069       : grad_cy_r.contiguous();
7070   Tensor bias_ih, bias_hh;
7071   if (has_biases) {
7072     bias_ih = weight2;
7073     bias_hh = weight3;
7074   } else {
7075     bias_ih = at::zeros(
7076         {4 /* num_bias_gates of LSTM */ * hidden_size}, weight0.options());
7077     bias_hh = at::zeros(
7078         {4 /* num_bias_gates of LSTM */ * hidden_size}, weight0.options());
7079   }
7080   const auto& input_ = input;
7081   auto hx_prev = hx_;
7082   auto cx_prev = cx_tmp;
7083 
7084   // Re-calculate the gates and hidden states of this layer; they are used
7085   // in the backward pass below.
7086   int64_t seq_length = input.size(0);
7087   std::vector<std::tuple<Tensor, Tensor, Tensor, Tensor>> layer_gates(
7088       seq_length);
7089   std::vector<std::tuple<Tensor, Tensor>> layer_states(seq_length + 1);
7090   layer_states[0] = std::make_tuple(hx_, cx_tmp);
7091   for (int64_t seq = 1; seq < seq_length + 1; seq++) {
7092     auto hx = hx_prev;
7093     auto cx = cx_prev;
7094     auto x_index = reverse ? seq_length - seq : seq - 1;
7095     auto gate = at::linear(input_[x_index], weight0, bias_ih)
7096                     .add_(at::linear(hx, weight1, bias_hh));
7097     auto chunked_gates = gate.unsafe_chunk(4, 1);
7098     auto i = chunked_gates[0].sigmoid_();
7099     auto f = chunked_gates[1].sigmoid_();
7100     auto g = chunked_gates[2].tanh_();
7101     auto o = chunked_gates[3].sigmoid_();
7102     layer_gates[x_index] = std::make_tuple(i, f, g, o);
7103     auto cy = (f * cx).add(i * g);
7104     auto hy = o * cy.tanh();
7105     layer_states[seq] = std::make_tuple(hy, cy);
7106     hx_prev = hy;
7107     cx_prev = cy;
7108   }
7109 
7110   Tensor dx, dWx, dWh, db, db_, dprev_h, dprev_c, dWh_, dWx_;
7111   Tensor new_grad_hy, d1, dgp, dip, dfp, dop, do_, dg, df, di, da;
7112   std::vector<at::Tensor> layer_dx(seq_length);
7113   for (int64_t seq = seq_length - 1; seq >= 0; seq--) {
7114     int64_t x_index = reverse ? seq_length - seq - 1 : seq;
7115     auto i = std::get<0>(layer_gates[x_index]);
7116     auto f = std::get<1>(layer_gates[x_index]);
7117     auto g = std::get<2>(layer_gates[x_index]);
7118     auto o = std::get<3>(layer_gates[x_index]);
7119     auto hy = std::get<0>(layer_states[seq + 1]);
7120     auto cy = std::get<1>(layer_states[seq + 1]);
7121     auto hx = std::get<0>(layer_states[seq]);
7122     auto cx = std::get<1>(layer_states[seq]);
7123     new_grad_hy = grad_output[x_index].add(grad_hy);
7124     d1 = grad_cy.add(new_grad_hy * o * (1 - cy.tanh() * cy.tanh()));
7125     dgp = d1 * i;
7126     dip = d1 * g;
7127     dprev_c = d1 * f;
7128     dfp = d1 * cx;
7129     dop = new_grad_hy * cy.tanh();
7130     do_ = dop * o * (1 - o);
7131     dg = dgp * (1 - g * g);
7132     df = dfp * f * (1 - f);
7133     di = dip * i * (1 - i);
7134     da = at::cat({di, df, dg, do_}, 1);
7135     db_ = at::sum(da, 0);
7136     dx = at::matmul(da, weight0);
7137     dx = at::unsqueeze(dx, 0);
7138     dprev_h = at::matmul(da, weight1);
7139     dWx_ = at::matmul(da.transpose(0, 1), input_[x_index]);
7140     dWh_ = at::matmul(da.transpose(0, 1), hx);
7141     if (seq == seq_length - 1) {
7142       db = db_;
7143       dWx = dWx_;
7144       dWh = dWh_;
7145     } else {
7146       db += db_;
7147       dWx += dWx_;
7148       dWh += dWh_;
7149     }
7150     layer_dx[x_index] = dx;
7151     grad_hy = dprev_h;
7152     grad_cy = dprev_c;
7153   }
7154 
7155   auto cat_layer_dx = at::cat(layer_dx, 0);
7156   return std::make_tuple(cat_layer_dx, dWx, dWh, db, db, dprev_h, dprev_c);
7157 }
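
// LSTM cell recomputed in the forward loop above (per time step, sketch):
//   a = W_ih x_t + b_ih + W_hh h_{t-1} + b_hh,   a = [i~, f~, g~, o~]
//   i = sigmoid(i~), f = sigmoid(f~), g = tanh(g~), o = sigmoid(o~)
//   c_t = f * c_{t-1} + i * g,   h_t = o * tanh(c_t)
// The backward loop then walks the time steps in reverse and applies the chain
// rule per step, e.g.
//   dc_t = (carried cell grad) + dh_t * o * (1 - tanh(c_t)^2)
//   do~  = dh_t * tanh(c_t) * o * (1 - o)
//   dg~  = dc_t * i * (1 - g^2),  etc.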
7158 
7159 Tensor values_backward(const Tensor& grad, const Tensor& self) {
7160   Tensor grad_self;
7161   if (grad.defined()) {
7162     if (self.layout() == c10::kSparse) {
7163       return at::_sparse_coo_tensor_unsafe_symint(
7164           self.indices(),
7165           grad,
7166           self.sym_sizes(),
7167           self.options(),
7168           /*is_coalesced=*/true);
7169     } else if (at::sparse_csr::is_sparse_compressed(self)) {
7170       auto [compressed_indices, plain_indices] =
7171           at::sparse_csr::getCompressedPlainIndices(self);
7172       return at::_sparse_compressed_tensor_unsafe_symint(
7173           compressed_indices,
7174           plain_indices,
7175           grad,
7176           self.sym_sizes(),
7177           self.options());
7178     } else {
7179       TORCH_CHECK_NOT_IMPLEMENTED(
7180           false,
7181           "values backward with respect to self with layout ",
7182           self.layout());
7183     }
7184   }
7185   return grad_self;
7186 }
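
// Sketch: the gradient of sparse_tensor.values() re-attaches `grad` to the
// original sparsity pattern, i.e. it builds a tensor with the same indices
// (COO) or compressed/plain indices (CSR/CSC/BSR/BSC) as `self` and `grad` as
// its values, matching self's shape and options.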
7187 
7188 } // namespace torch::autograd::generated::details
7189