1 #pragma once
2
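// NB: Must be at the top of the file. On MSVC, M_PI and the other math
// constants are only defined if _USE_MATH_DEFINES is set before <cmath> is
// first included; doing it here lets us avoid the deprecated "math.h".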
4 // https://stackoverflow.com/questions/6563810/m-pi-works-with-math-h-but-not-with-cmath-in-visual-studio
5 #ifdef _MSC_VER
6 #ifndef _USE_MATH_DEFINES
7 #define _USE_MATH_DEFINES
8 #endif
9 #include <cmath>
10 #endif
11
12 #include <ATen/ATen.h>
13 #include <torch/csrc/autograd/generated/Functions.h>
14
15 namespace torch::autograd::generated::details {
16
17 extern const char* kCudnnDoubleBackwardMsg;
18
19 // A simple way to imperatively compute index ranges for slots
20 // that have been flattened
21 struct TORCH_API IndexRangeGenerator {
rangeIndexRangeGenerator22 IndexRange range(size_t range_size) {
23 i += range_size;
24 return {i - range_size, i};
25 }
sizeIndexRangeGenerator26 size_t size() {
27 return i;
28 }
29
30 private:
31 size_t i = 0;
32 };
33
34 TORCH_API Tensor toNonOptFwGrad(const std::optional<Tensor>& t);
35 TORCH_API Tensor toNonOptPrimal(const std::optional<Tensor>& t);
36 TORCH_API Tensor toNonOptTensor(const std::optional<Tensor>& t);
37
wrap_opt_if(const Tensor & t,const bool cond)38 TORCH_API inline std::optional<Tensor> wrap_opt_if(
39 const Tensor& t,
40 const bool cond) {
41 using OptTensor = std::optional<Tensor>;
42 return cond ? OptTensor(t) : static_cast<OptTensor>(std::nullopt);
43 }
44
45 TORCH_API Tensor
46 apply_loss_reduction(const Tensor& unreduced, int64_t reduction);
47 TORCH_API bool any_variable_defined(const variable_list& variables);
48 TORCH_API void copy_range(
49 variable_list& out,
50 IndexRange range,
51 const at::Tensor& t);
52 TORCH_API void copy_range(
53 variable_list& out,
54 IndexRange range,
55 at::ArrayRef<at::Tensor> t);
56 TORCH_API at::Tensor copysign_tensor_self_backward(
57 const Tensor& grad,
58 const Tensor& self,
59 const Tensor& result);
60 TORCH_API at::Tensor not_implemented(const char* name, const char* reason = "");
61 TORCH_API std::vector<Tensor> not_implemented_list(
62 const char* name,
63 const char* reason = "");
64 at::Tensor handle_r_to_c(ScalarType self_st, Tensor gradient_result);
65 at::Tensor maybe_multiply(const at::Tensor& t, const at::Scalar& s);
66 int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim);
67 Tensor restore_reduced_dims(
68 const Tensor& output,
69 IntArrayRef dims,
70 bool keepdim);
71 Tensor scale_grad_by_count(
72 const Tensor& grad,
73 const Tensor& mask,
74 IntArrayRef dims);
75 at::Tensor norm_backward(
76 const at::Tensor& grad,
77 const at::Tensor& self,
78 const std::optional<at::Scalar>& p_,
79 const at::Tensor& norm);
80 at::Tensor norm_backward(
81 at::Tensor grad,
82 const at::Tensor& self,
83 const std::optional<at::Scalar>& p_,
84 at::Tensor norm,
85 at::IntArrayRef dim,
86 bool keepdim);
87 Tensor norm_jvp(
88 const Tensor& self_p,
89 const Tensor& self_t,
90 const std::optional<Scalar>& p_,
91 Tensor norm,
92 IntArrayRef dim,
93 bool keepdim);
94 Tensor norm_jvp(
95 const Tensor& grad,
96 const Tensor& self,
97 const std::optional<Scalar>& p_,
98 Tensor norm);
99 Tensor _nested_from_padded_backward(
100 const Tensor& grad,
101 const Tensor& input,
102 const bool do_transform_0213);
103 std::tuple<Tensor, Tensor, Tensor> linear_double_backward(
104 const variable_list& grads,
105 const Tensor& self,
106 const Tensor& grad_output,
107 const Tensor& weight);
108 Tensor linalg_vector_norm_jvp(
109 const Tensor& self_p,
110 const Tensor& self_t,
111 const Scalar& scalar_ord,
112 Tensor norm,
113 const at::OptionalIntArrayRef& opt_dim,
114 bool keepdim);
115 at::Tensor linalg_vector_norm_backward(
116 at::Tensor grad,
117 const at::Tensor& self,
118 const at::Scalar& ord,
119 at::Tensor norm,
120 const at::OptionalIntArrayRef& opt_dim,
121 bool keepdim);
122 at::Tensor pow_backward(
123 at::Tensor grad,
124 const at::Tensor& self,
125 const at::Scalar& exponent_);
126 at::Tensor pow_backward_self(
127 const at::Tensor& grad,
128 const at::Tensor& self,
129 const at::Tensor& exponent);
130 at::Tensor pow_backward_exponent(
131 const at::Tensor& grad,
132 const at::Tensor& self,
133 const at::Tensor& exponent,
134 const at::Tensor& result);
135 at::Tensor pow_backward_exponent(
136 const at::Tensor& grad,
137 const at::Scalar& base,
138 const at::Tensor& exponent,
139 const at::Tensor& result);
140 at::Tensor angle_backward(const at::Tensor& grad, const at::Tensor& self);
141 template <typename T>
142 at::Tensor mul_tensor_backward(const Tensor& grad, T other, ScalarType self_st);
143 template <typename T>
144 at::Tensor div_tensor_self_backward(
145 const Tensor& grad,
146 T other,
147 ScalarType self_st);
148 at::Tensor div_tensor_other_backward(
149 const Tensor& grad,
150 const Tensor& self,
151 const Tensor& other);
152 template <typename T>
153 at::Tensor div_tensor_self_backward(
154 const Tensor& grad,
155 T other,
156 ScalarType self_st,
157 const std::optional<c10::string_view>& rounding_mode);
158 at::Tensor div_tensor_other_backward(
159 const Tensor& grad,
160 const Tensor& self,
161 const Tensor& other,
162 const std::optional<c10::string_view>& rounding_mode);
163 at::Tensor mvlgamma_backward(
164 const at::Tensor& grad,
165 const at::Tensor& self,
166 int64_t p);
167 at::Tensor permute_backwards(const at::Tensor& grad, at::IntArrayRef fwd_dims);
168 at::Tensor rad2deg_backward(const at::Tensor& grad);
169 at::Tensor deg2rad_backward(const at::Tensor& grad);
170 at::Tensor unsqueeze_multiple(
171 const at::Tensor& t,
172 at::OptionalIntArrayRef opt_dim,
173 size_t n_dims);
174 at::Tensor sum_backward(
175 const at::Tensor& grad,
176 at::SymIntArrayRef sizes,
177 at::OptionalIntArrayRef opt_dims,
178 bool keepdim);
179 at::Tensor sum_backward(
180 const at::Tensor& grad,
181 c10::SymIntArrayRef sizes,
182 c10::IntArrayRef dims,
183 bool keepdim);
184 at::Tensor nansum_backward(
185 const at::Tensor& grad,
186 const at::Tensor& self,
187 at::OptionalIntArrayRef dims,
188 bool keepdim);
189 std::vector<int64_t> reverse_list(const at::IntArrayRef list);
190 std::vector<c10::SymInt> reverse_list_symint(const c10::SymIntArrayRef list);
191 at::Tensor reverse_dim(const at::Tensor& t, int64_t dim);
192 at::Tensor prod_safe_zeros_backward(
193 const at::Tensor& grad,
194 const at::Tensor& inp,
195 int64_t dim);
196 at::Tensor prod_backward(
197 const at::Tensor& grad,
198 const at::Tensor& input,
199 const at::Tensor& result);
200 at::Tensor prod_backward(
201 at::Tensor grad,
202 const at::Tensor& input,
203 at::Tensor result,
204 int64_t dim,
205 bool keepdim);
206 at::Tensor solve_jvp(
207 const Tensor& X,
208 const Tensor& A,
209 const Tensor& dA,
210 const Tensor& dB);
211 at::Tensor solve_backward_self(
212 const at::Tensor& grad,
213 const at::Tensor& self,
214 const at::Tensor& A);
215 at::Tensor solve_backward_A(
216 const at::Tensor& grad,
217 const at::Tensor& self,
218 const at::Tensor& A,
219 const at::Tensor& solution);
220 at::Tensor cumsum_backward(const at::Tensor& grad, int64_t dim);
221 at::Tensor logsumexp_backward(
222 at::Tensor grad,
223 const at::Tensor& self,
224 at::Tensor result,
225 at::IntArrayRef dim,
226 bool keepdim);
227 at::Tensor logsumexp_jvp(
228 const at::Tensor& self_p,
229 const at::Tensor& self_t,
230 IntArrayRef dim,
231 bool keepdim);
232 at::Tensor safe_logsumexp_jvp(
233 const at::Tensor& self_p,
234 const at::Tensor& self_t,
235 IntArrayRef dim,
236 bool keepdim);
237 at::Tensor logcumsumexp_backward(
238 at::Tensor grad,
239 const at::Tensor& self,
240 const at::Tensor& result,
241 int64_t dim);
242 at::Tensor logcumsumexp_jvp(
243 const at::Tensor& self_p,
244 const at::Tensor& self_t,
245 int64_t dim);
246 at::Tensor unbind_backward(const variable_list& grads, int64_t dim);
247 at::Tensor unbind_backward_nested(
248 const variable_list& grads,
249 const Tensor& nt_sizes,
250 int64_t dim,
251 const at::TensorOptions& options);
252 at::Tensor unbind_backward_nested_jagged(
253 const variable_list& grads,
254 const Tensor& self,
255 int64_t dim);
256 at::Tensor unsqueeze_to(const at::Tensor& self, c10::SymIntArrayRef sym_sizes);
257 at::Tensor unsqueeze_to(
258 const at::Tensor& self,
259 int64_t dim,
260 c10::SymIntArrayRef sym_sizes);
261 at::Tensor unsqueeze_to(
262 const at::Tensor& self,
263 IntArrayRef dim,
264 c10::SymIntArrayRef sym_sizes);
265 std::vector<at::Tensor> cat_tensors_backward(
266 const at::Tensor& grad,
267 const std::vector<std::vector<c10::SymInt>>& sizes,
268 const std::vector<ScalarType>& dtypes,
269 int64_t dim);
270 std::vector<at::Tensor> stack_tensors_backward(
271 const at::Tensor& grad,
272 int64_t dim,
273 const std::vector<ScalarType>& dtypes);
274 std::vector<at::Tensor> block_diag_backward(
275 const at::Tensor& grad,
276 const std::vector<std::vector<int64_t>>& sizes,
277 const std::vector<ScalarType>& dtypes);
278 at::Tensor clamp_backward(
279 const at::Tensor& grad,
280 const at::Tensor& self,
281 const std::optional<at::Scalar>& min,
282 const std::optional<at::Scalar>& max);
283 at::Tensor clamp_backward(
284 const at::Tensor& grad,
285 const at::Tensor& self,
286 const at::Tensor& min,
287 const at::Tensor& max);
288 std::tuple<at::Tensor, at::Tensor> clamp_backward_min_max(
289 const at::Tensor& grad,
290 const at::Tensor& self,
291 const at::Tensor& min,
292 const at::Tensor& max,
293 const std::array<bool, 2>&);
294 at::Tensor clamp_jvp(
295 const Tensor& self_p,
296 const Tensor& self_t,
297 const Tensor& min_p,
298 const Tensor& min_t,
299 const Tensor& max_p,
300 const Tensor& max_t);
301 at::SymIntArrayRef strides_or_error(
302 const Tensor& input,
303 c10::string_view const& input_name);
304 at::Tensor mm_mat1_backward(
305 const Tensor& grad,
306 const Tensor& mat2,
307 at::SymIntArrayRef mat1_sizes,
308 at::SymIntArrayRef mat1_strides,
309 c10::Layout mat1_layout,
310 const Scalar& alpha);
311 at::Tensor mm_mat2_backward(
312 const at::Tensor& grad,
313 const at::Tensor& mat1,
314 at::SymIntArrayRef sizes,
315 at::SymIntArrayRef strides,
316 c10::Layout layout,
317 const at::Scalar& alpha);
318 at::Tensor mm_mat1_sparse_backward(
319 const at::Tensor& grad,
320 const at::Tensor& mat1,
321 const at::Tensor& mat2,
322 const at::Scalar& alpha);
323 std::tuple<Tensor, Tensor, Tensor> sparse_sampled_addmm_backward(
324 const Tensor& grad,
325 const Tensor& self,
326 const std::optional<Tensor>& mat1,
327 const std::optional<Tensor>& mat2,
328 const Scalar& alpha,
329 const Scalar& beta,
330 const std::array<bool, 3>& grad_input_mask);
331 at::Tensor sparse_mask_backward(
332 const at::Tensor& grad,
333 const at::Tensor& mask,
334 c10::Layout self_layout);
335 at::Tensor sparse_sparse_matmul_backward(
336 const at::Tensor& grad,
337 const at::Tensor& mat1,
338 const at::Tensor& mat2,
339 int64_t grad_order);
340 at::Tensor renorm_backward(
341 const at::Tensor& grad,
342 const at::Tensor& self,
343 const at::Scalar& p,
344 int64_t dim,
345 const at::Scalar& maxnorm);
346 at::Tensor renorm_jvp(
347 const at::Tensor& self_p,
348 const at::Tensor& self_t,
349 const at::Scalar& p,
350 int64_t dim,
351 const at::Scalar& maxnorm);
352 at::Tensor repeat_backward(
353 at::Tensor grad,
354 at::SymIntArrayRef repeats,
355 at::SymIntArrayRef input_shape);
356 at::Tensor _fused_dropout_backward(
357 const at::Tensor& grad,
358 const at::Tensor& mask,
359 double p1m);
360 at::Tensor infinitely_differentiable_native_dropout_backward(
361 const at::Tensor& grad,
362 const at::Tensor& mask,
363 double scale);
364 at::Tensor native_dropout_double_backward(
365 const at::Tensor& ggI,
366 const at::Tensor& grad,
367 const at::Tensor& mask,
368 double scale);
369 at::Tensor evenly_distribute_backward(
370 const at::Tensor& grad,
371 const at::Tensor& input,
372 const at::Tensor& value);
373 Tensor sgn_backward(const Tensor& x, const Tensor& gx, const Tensor& sgn);
374 Tensor masked_fill_backward(const Tensor& grad, const Tensor& mask);
375 at::Tensor var_backward(
376 at::Tensor grad,
377 const at::Tensor& self,
378 at::OptionalIntArrayRef dim,
379 const std::optional<c10::Scalar>& correction,
380 bool keepdim);
381 at::Tensor var_jvp(
382 const at::Tensor& self_t,
383 const at::Tensor& self_p,
384 const at::Tensor& result,
385 at::OptionalIntArrayRef dim_opt,
386 const std::optional<c10::Scalar>& correction,
387 bool keepdim);
388 at::Tensor std_backward(
389 const at::Tensor& result,
390 const at::Tensor& grad,
391 const at::Tensor& self,
392 at::OptionalIntArrayRef dim,
393 const std::optional<c10::Scalar>& correction,
394 bool keepdim);
395 Tensor mean_backward(
396 const Tensor& grad,
397 c10::SymIntArrayRef shape,
398 at::OptionalIntArrayRef opt_dim,
399 c10::SymInt numel,
400 bool keepdim);
401 Tensor var_mean_backward(
402 const Tensor& gvar,
403 const Tensor& gmean,
404 const Tensor& self,
405 at::OptionalIntArrayRef dim_opt,
406 const std::optional<c10::Scalar>& correction,
407 bool keepdim);
408 Tensor std_mean_backward(
409 const Tensor& gstd,
410 const Tensor& gmean,
411 const Tensor& self,
412 const Tensor& std,
413 at::OptionalIntArrayRef dim_opt,
414 const std::optional<c10::Scalar>& correction,
415 bool keepdim);
416 at::Tensor cholesky_backward(
417 const at::Tensor& grad,
418 bool upper,
419 const at::Tensor& L);
420 at::Tensor cholesky_jvp(
421 const at::Tensor& input_tangent,
422 const at::Tensor& L,
423 bool upper);
424 at::Tensor cholesky_inverse_backward(
425 const at::Tensor& grad,
426 const at::Tensor& L,
427 bool upper,
428 const at::Tensor& inverse);
429 at::Tensor cholesky_inverse_jvp(
430 const at::Tensor& F,
431 const at::Tensor& dF,
432 const at::Tensor& X,
433 bool upper);
434 Tensor pinv_jvp(const Tensor& A, const Tensor& pinvA, const Tensor& dA);
435 Tensor pinv_backward(const Tensor& grad, const Tensor& pinvA, const Tensor& A);
436 Tensor chunk_backward_nested(
437 const std::vector<torch::autograd::Variable>& grads,
438 const Tensor& self,
439 int64_t chunks,
440 int64_t dim);
441 at::Tensor split_with_sizes_backward(
442 const std::vector<torch::autograd::Variable>& grads,
443 c10::SymIntArrayRef split_sizes,
444 int64_t dim,
445 c10::SymIntArrayRef sizes,
446 const at::TensorOptions& options);
447 at::Tensor _nested_split_with_sizes_backward(
448 const std::vector<torch::autograd::Variable>& grads,
449 c10::SymIntArrayRef split_sizes,
450 int64_t dim,
451 const Tensor& nt_sizes,
452 const at::TensorOptions& options);
453 at::Tensor split_backward(
454 const std::vector<torch::autograd::Variable>& grads,
455 const c10::SymInt& split_size,
456 int64_t dim,
457 c10::SymIntArrayRef sizes,
458 const at::TensorOptions& options);
459 at::Tensor max_pool_double_backward(
460 const at::Tensor& grad,
461 const at::Tensor& indices,
462 int dim);
463 at::Tensor error_for_max_pool2d_double_backward();
464 at::Tensor glu_double_backward(
465 const at::Tensor& grad,
466 const at::Tensor& grad_output,
467 const at::Tensor& input,
468 int64_t dim);
469 at::Tensor glu_double_backward_grad_output(
470 const at::Tensor& grad,
471 const at::Tensor& input,
472 int64_t dim);
473 at::Tensor infinitely_differentiable_silu_backward(
474 const at::Tensor& grad_output,
475 const at::Tensor& input);
476 at::Tensor infinitely_differentiable_mish_backward(
477 const at::Tensor& grad_output,
478 const at::Tensor& input);
479 Tensor infinitely_differentiable_logit_backward(
480 const Tensor& grad,
481 const Tensor& self,
482 std::optional<double> eps);
483 Tensor binary_cross_entropy_target_backward(
484 const Tensor& grad,
485 const Tensor& self,
486 const Tensor& target,
487 const std::optional<Tensor>& weight,
488 int64_t reduction);
489 Tensor binary_cross_entropy_double_backward_target(
490 const Tensor& grad,
491 const Tensor& grad_output,
492 const Tensor& self,
493 const Tensor& target,
494 const std::optional<Tensor>& weight,
495 int64_t reduction);
496 Tensor binary_cross_entropy_with_logits_backward(
497 const Tensor& grad,
498 const Tensor& input,
499 const Tensor& target,
500 const std::optional<Tensor>& weight_opt,
501 const std::optional<Tensor>& pos_weight_opt,
502 int64_t reduction);
503 at::Tensor binary_cross_entropy_with_logits_target_backward(
504 const at::Tensor& grad_output,
505 const at::Tensor& self,
506 const at::Tensor& target,
507 const std::optional<at::Tensor>& weight,
508 const std::optional<at::Tensor>& pos_weight,
509 int64_t reduction);
510 at::Tensor log_sigmoid_double_backward(
511 const at::Tensor& grad,
512 const at::Tensor& input);
513 at::Tensor softmax_double_backward(
514 const at::Tensor& grad,
515 const at::Tensor& grad_output,
516 int dim,
517 const at::Tensor& output);
518 at::Tensor binary_cross_entropy_double_backward(
519 const at::Tensor& grad_output,
520 const at::Tensor& grad,
521 const at::Tensor& input,
522 const at::Tensor& target,
523 const std::optional<at::Tensor>& weight,
524 int64_t reduction);
525 at::Tensor binary_cross_entropy_double_backward_grad_output(
526 const at::Tensor& grad,
527 const at::Tensor& input,
528 const at::Tensor& target,
529 const std::optional<at::Tensor>& weight,
530 int64_t reduction);
531 at::Tensor smooth_l1_loss_double_backward(
532 const at::Tensor& grad,
533 const at::Tensor& input,
534 const at::Tensor& target,
535 int64_t reduction,
536 double beta);
537 at::Tensor huber_loss_double_backward(
538 const at::Tensor& grad,
539 const at::Tensor& input,
540 const at::Tensor& target,
541 int64_t reduction,
542 double delta);
543 at::Tensor huber_loss_double_backward_grad_output(
544 const at::Tensor& grad,
545 const at::Tensor& grad_output,
546 const at::Tensor& input,
547 const at::Tensor& target,
548 int64_t reduction,
549 double delta);
550 at::Tensor mse_loss_double_backward(
551 const at::Tensor& grad,
552 const at::Tensor& input,
553 int64_t reduction);
554 at::Tensor soft_margin_loss_double_backward(
555 const at::Tensor& grad,
556 const at::Tensor& input,
557 const at::Tensor& target,
558 int64_t reduction);
559 at::Tensor soft_margin_loss_double_backward_grad_output(
560 const at::Tensor& grad,
561 const at::Tensor& grad_output,
562 const at::Tensor& input,
563 const at::Tensor& target,
564 int64_t reduction);
565 at::Tensor softplus_double_backward(
566 const at::Tensor& grad,
567 const at::Tensor& input,
568 const at::Scalar& beta,
569 const at::Scalar& threshold);
570 std::tuple<at::Tensor, at::Tensor> slogdet_jvp(
571 const at::Tensor& LU,
572 const at::Tensor& pivots,
573 const at::Tensor& dA,
574 const at::Tensor& sign,
575 const bool use_A_T);
576 at::Tensor slogdet_backward(
577 const at::Tensor& grad_sign,
578 const at::Tensor& grad_logabsdet,
579 const at::Tensor& A,
580 const at::Tensor& signdet,
581 const at::Tensor& LU,
582 const at::Tensor& pivots);
583 at::Tensor log1p_backward(const at::Tensor& grad, const at::Tensor& self);
584 at::Tensor sinc_backward(const at::Tensor& grad, const at::Tensor& self);
585 at::Tensor sparse_constructor_values_backward(
586 const at::Tensor& sparse_grad_out,
587 const at::Tensor& indices);
588 at::Tensor embedding_dense_double_backward_symint(
589 const at::Tensor& grad,
590 const at::Tensor& indices,
591 const c10::SymInt& padding_idx);
592 at::Tensor index_backward(
593 at::Tensor zeros_like_self,
594 const torch::List<std::optional<Tensor>>& indices,
595 const at::Tensor& grad);
596 at::Tensor _cudnn_ctc_loss_backward(
597 const at::Tensor& grad_out,
598 const at::Tensor& loss,
599 const at::Tensor& raw_grad,
600 bool zero_infinity);
601 at::Tensor elu_double_backward(
602 const Tensor& grad,
603 const Tensor& grad_output,
604 const Scalar& alpha,
605 const Scalar& scale,
606 const Scalar& input_scale,
607 bool is_result,
608 const Tensor& self_or_result);
609
610 Tensor svd_backward(
611 const Tensor& gU,
612 const Tensor& gS,
613 const Tensor& gVh,
614 const Tensor& U,
615 const Tensor& S,
616 const Tensor& Vh);
617
618 std::tuple<Tensor, Tensor, Tensor> linalg_svd_jvp(
619 const Tensor& dA,
620 const Tensor& U,
621 const Tensor& S,
622 const Tensor& Vh,
623 const bool full_matrices);
624 Tensor slice_backward_wrapper(
625 const at::Tensor& grad,
626 const c10::SymIntArrayRef& input_sizes,
627 int64_t dim,
628 std::optional<c10::SymInt> start,
629 std::optional<c10::SymInt> end,
630 c10::SymInt step);
631 std::tuple<Tensor, Tensor> linalg_eig_jvp(
632 const Tensor& dA,
633 const Tensor& L,
634 const Tensor& V,
635 const bool is_hermitian);
636 Tensor linalg_eig_backward(
637 const Tensor& gL,
638 const Tensor& gV,
639 const Tensor& L,
640 const Tensor& V,
641 const bool is_hermitian,
642 const bool symeig_eigenvectors = true);
643 Tensor linalg_lstsq_jvp(
644 const Tensor& A,
645 const Tensor& B,
646 const Tensor& dA,
647 const Tensor& dB);
648 std::tuple<Tensor, Tensor> triangular_solve_backward(
649 const Tensor& grad_x,
650 const Tensor& grad_m,
651 const Tensor& b,
652 const Tensor& a,
653 const Tensor& x,
654 const bool upper,
655 const bool transpose,
656 const bool unitriangular,
657 std::array<bool, 2> output_mask);
658 Tensor triangular_solve_jvp(
659 const Tensor& X,
660 const Tensor& A,
661 const Tensor& dA,
662 const Tensor& dB,
663 const bool upper,
664 const bool transpose,
665 const bool unitriangular);
666 Tensor linalg_solve_triangular_forward_AD(
667 const Tensor& A_t,
668 const Tensor& B_t,
669 const Tensor& A,
670 const Tensor& X,
671 const bool upper,
672 const bool left,
673 const bool unitriangular);
674 std::tuple<Tensor, Tensor> linalg_solve_triangular_backward(
675 const Tensor& grad,
676 const Tensor& A,
677 const Tensor& X,
678 const bool upper,
679 const bool left,
680 const bool unitriangular,
681 std::array<bool, 2> output_mask);
682 std::tuple<Tensor, Tensor, Tensor> _trilinear_backward(
683 const Tensor& grad_out,
684 const std::optional<Tensor>& i1,
685 const std::optional<Tensor>& i2,
686 const std::optional<Tensor>& i3,
687 IntArrayRef expand1,
688 IntArrayRef expand2,
689 IntArrayRef expand3,
690 IntArrayRef sumdim,
691 std::array<bool, 3> grad_mask);
692 std::tuple<Tensor, Tensor> linalg_qr_jvp(
693 const Tensor& dA,
694 const Tensor& Q,
695 const Tensor& R,
696 const c10::string_view mode);
697 Tensor linalg_qr_backward(
698 const Tensor& gQ,
699 const Tensor& gR,
700 const Tensor& Q,
701 const Tensor& R,
702 const c10::string_view mode);
703 Tensor linalg_matrix_exp_differential(
704 const Tensor& self,
705 const Tensor& grad,
706 bool adjoint);
707 std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(
708 const Tensor& input,
709 const std::optional<Tensor>& gamma,
710 const Tensor& ggI,
711 const Tensor& ggG,
712 const Tensor& ggB,
713 const Tensor& gO,
714 const std::optional<Tensor>& running_mean,
715 const std::optional<Tensor>& running_var,
716 bool training,
717 double eps,
718 const std::optional<Tensor>& save_mean,
719 const std::optional<Tensor>& save_invstd,
720 std::array<bool, 3> output_mask);
721 std::tuple<Tensor, Tensor> _euclidean_dist_backward(
722 const Tensor& grad,
723 const Tensor& x1,
724 const Tensor& x2,
725 const Tensor& res);
726 Tensor fft_backward(
727 const Tensor& self,
728 const Tensor& grad,
729 int64_t signal_ndim,
730 bool complex_input,
731 bool complex_output,
732 bool inverse,
733 IntArrayRef checked_signal_sizes,
734 int64_t normalization,
735 bool onesided,
736 IntArrayRef output_sizes);
737 Tensor fft_r2c_backward(
738 const Tensor& grad,
739 at::IntArrayRef dim,
740 int64_t normalization,
741 bool onesided,
742 const c10::SymInt& last_dim_size);
743 Tensor fft_c2r_backward(
744 const Tensor& grad,
745 IntArrayRef dim,
746 int64_t normalization);
747 Tensor constant_pad_nd_backward(const Tensor& grad, c10::SymIntArrayRef pad);
748 std::tuple<Tensor, Tensor> cholesky_solve_backward(
749 const Tensor& grad_x,
750 const Tensor& self,
751 const Tensor& input2,
752 const Tensor& result,
753 const bool upper,
754 std::array<bool, 2> output_mask);
755 Tensor cholesky_solve_jvp(
756 const Tensor& X,
757 const Tensor& U,
758 const Tensor& dU,
759 const Tensor& dB,
760 const bool upper);
761 std::tuple<Tensor, Tensor, Tensor>
762 infinitely_differentiable_native_group_norm_backward(
763 const Tensor& dY,
764 const Tensor& dmean,
765 const Tensor& drstd,
766 const Tensor& X,
767 const Tensor& mean,
768 const Tensor& rstd,
769 const std::optional<Tensor>& gamma,
770 c10::SymInt N,
771 const c10::SymInt& C,
772 c10::SymInt HxW,
773 int64_t group,
774 double eps,
775 std::array<bool, 3> grad_input_mask);
776 Tensor gelu_double_backward(
777 const Tensor& ggI,
778 const Tensor& gO,
779 const Tensor& input,
780 c10::string_view approximate);
781 Tensor as_strided_backward(
782 Tensor grad,
783 const TensorGeometry& input_geometry,
784 c10::SymIntArrayRef sizes,
785 c10::SymIntArrayRef strides,
786 const std::optional<c10::SymInt>& storage_offset_);
787 Tensor as_strided_scatter_backward(
788 const Tensor& grad,
789 const TensorGeometry& input_geometry,
790 const TensorGeometry& src_geometry,
791 c10::SymIntArrayRef sizes,
792 c10::SymIntArrayRef strides,
793 std::optional<c10::SymInt> storage_offset);
794 std::tuple<Tensor, Tensor> atan2_backward(
795 const Tensor& grad,
796 const Tensor& self,
797 const Tensor& other,
798 std::array<bool, 2> output_mask);
799 Tensor amaxamin_jvp(
800 const Tensor& x,
801 const Tensor& dx,
802 const Tensor& result,
803 IntArrayRef dim,
804 bool keepdim);
805 std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
806 const Tensor& input,
807 const std::optional<Tensor>& gamma,
808 const Tensor& ggI,
809 const Tensor& ggG,
810 const Tensor& ggB,
811 const Tensor& gO,
812 const Tensor& save_mean,
813 const Tensor& save_invstd,
814 c10::SymIntArrayRef normalized_shape,
815 std::array<bool, 3> output_mask);
816
817 std::tuple<Tensor, Tensor> householder_product_backward(
818 const Tensor& grad,
819 const Tensor& result,
820 const Tensor& input,
821 const Tensor& tau,
822 const bool flip_order = false);
823 Tensor householder_product_jvp(
824 const Tensor& dV,
825 const Tensor& dtau,
826 const Tensor& prod,
827 const Tensor& V,
828 const Tensor& tau);
829 std::tuple<Tensor, Tensor, Tensor> ormqr_backward(
830 const Tensor& grad,
831 const Tensor& result,
832 const Tensor& self,
833 const Tensor& tau,
834 const Tensor& other,
835 bool left,
836 bool transpose,
837 std::array<bool, 3> grad_output_mask);
838 std::tuple<Tensor, Tensor> polar_backward(
839 const Tensor& grad,
840 const Tensor& result);
841 Tensor i1_backward(
842 const Tensor& grad,
843 const Tensor& self,
844 const Tensor& result);
845 Tensor i1e_backward(
846 const Tensor& grad,
847 const Tensor& self,
848 const Tensor& result);
849 Tensor linalg_lu_solve_LU(
850 const Tensor& grad,
851 const Tensor& LU,
852 const Tensor& pivots,
853 const Tensor& X,
854 const bool left,
855 const bool adjoint);
856 Tensor linalg_lu_solve_jvp(
857 const Tensor& X,
858 const Tensor& LU,
859 const Tensor& pivots,
860 const Tensor& dLU,
861 const Tensor& dB,
862 const bool left,
863 const bool adjoint);
864 std::tuple<Tensor, Tensor> linalg_solve_backward(
865 const Tensor& gX,
866 const Tensor& X,
867 const Tensor& A,
868 const Tensor& LU,
869 const Tensor& pivots,
870 const bool left,
871 const bool B_requires_grad);
872 Tensor linalg_solve_jvp(
873 const Tensor& dA,
874 const Tensor& dB,
875 const Tensor& X,
876 const Tensor& LU,
877 const Tensor& pivots,
878 const bool left,
879 const bool use_A_T);
880 Tensor lu_unpack_backward(
881 const Tensor& L_grad,
882 const Tensor& U_grad,
883 const c10::SymInt& m,
884 const c10::SymInt& n);
885
886 Tensor linalg_det_backward(
887 const Tensor& grad,
888 const Tensor& det,
889 const Tensor& A,
890 const Tensor& LU,
891 const Tensor& pivots);
892 Tensor linalg_det_jvp(
893 const Tensor& dA,
894 const Tensor& det,
895 const Tensor& LU,
896 const Tensor& pivots,
897 const bool use_A_T);
898 std::tuple<Tensor, Tensor> linalg_lstsq_backward(
899 const Tensor& grad,
900 const Tensor& A,
901 const Tensor& B_,
902 const std::array<bool, 2>& grad_input_mask);
903 Tensor linalg_lu_backward(
904 const Tensor& L_grad,
905 const Tensor& U_grad,
906 const Tensor& P,
907 const Tensor& L,
908 const Tensor& U,
909 const bool pivot);
910
911 std::tuple<Tensor, Tensor> linalg_lu_jvp(
912 const Tensor& dA,
913 const Tensor& P,
914 const Tensor& L,
915 const Tensor& U,
916 const bool pivot);
917
918 Tensor lu_factor_ex_backward(
919 const Tensor& grad,
920 const Tensor& LU,
921 const Tensor& pivs,
922 const bool pivot);
923 Tensor lu_factor_ex_jvp(
924 const Tensor& dX,
925 const Tensor& LU,
926 const Tensor& pivs,
927 const bool pivot);
928
929 Tensor batch_norm_jvp(
930 const Tensor& input_p,
931 const Tensor& input_t,
932 const Tensor& weight_p,
933 const Tensor& weight_t,
934 const Tensor& bias_p,
935 const Tensor& bias_t,
936 const std::optional<Tensor>& running_mean,
937 const std::optional<Tensor>& running_var,
938 const Tensor& saved_mean,
939 const Tensor& saved_invstd,
940 bool train,
941 double eps);
942
943 Tensor layer_norm_jvp(
944 const Tensor& input_p,
945 const Tensor& input_t,
946 const Tensor& weight_p,
947 const Tensor& weight_t,
948 const Tensor& bias_p,
949 const Tensor& bias_t,
950 const Tensor& saved_mean,
951 const Tensor& saved_invstd,
952 c10::SymIntArrayRef normalized_shape);
953
954 Tensor group_norm_jvp(
955 const Tensor& input_p,
956 const Tensor& input_t,
957 const Tensor& weight_p,
958 const Tensor& weight_t,
959 const Tensor& bias_p,
960 const Tensor& bias_t,
961 const Tensor& saved_mean,
962 const Tensor& saved_invstd,
963 int64_t groups);
964 Tensor group_norm_mean_jvp(
965 const Tensor& input_t,
966 const Tensor& mean_p,
967 int64_t groups);
968 Tensor group_norm_invstd_jvp(
969 const Tensor& input_p,
970 const Tensor& input_t,
971 const Tensor& mean_p,
972 const Tensor& invstd_p,
973 int64_t groups);
974
975 Tensor convolution_jvp(
976 const Tensor& input_p,
977 const Tensor& input_t,
978 const Tensor& weight_p,
979 const Tensor& weight_t,
980 const Tensor& bias_p,
981 const Tensor& bias_t,
982 at::SymIntArrayRef stride,
983 at::SymIntArrayRef padding,
984 at::SymIntArrayRef dilation,
985 bool transposed,
986 at::SymIntArrayRef output_padding,
987 const c10::SymInt& groups);
988
989 Tensor _convolution_jvp(
990 const Tensor& input_p,
991 const Tensor& input_t,
992 const Tensor& weight_p,
993 const Tensor& weight_t,
994 const Tensor& bias_p,
995 const Tensor& bias_t,
996 at::SymIntArrayRef stride,
997 at::SymIntArrayRef padding,
998 at::SymIntArrayRef dilation,
999 bool transposed,
1000 at::SymIntArrayRef output_padding,
1001 const c10::SymInt& groups,
1002 bool benchmark,
1003 bool deterministic,
1004 bool cudnn_enabled,
1005 bool allow_tf32);
1006
1007 Tensor convolution_backward_jvp_grad_bias(
1008 const Tensor& grad_out_t,
1009 const Tensor& grad_bias);
1010
1011 Tensor cat_jvp(const at::ITensorListRef& tensors, int64_t dim);
1012 Tensor block_diag_jvp(at::TensorList tensors);
1013 Tensor stack_jvp(at::TensorList tensors, int64_t dim);
1014 Tensor cumprod_jvp(
1015 const Tensor& self_t,
1016 const Tensor& self_p,
1017 const Tensor& result,
1018 int dim);
1019 Tensor gather_with_keepdimed_indices(
1020 const Tensor& input,
1021 int64_t dim,
1022 const Tensor& indices,
1023 bool keepdim);
1024 Tensor evenly_read_jvp(
1025 const Tensor& fw_grad,
1026 const Tensor& input,
1027 const Tensor& value);
1028 Tensor warn_backwards(const Tensor& grad_output);
1029
1030 std::tuple<Tensor, Tensor> _cudnn_convolution_backward(
1031 const at::Tensor& self,
1032 const at::Tensor& grad_output,
1033 const at::Tensor& weight,
1034 at::SymIntArrayRef padding,
1035 at::SymIntArrayRef output_padding,
1036 at::SymIntArrayRef stride,
1037 at::SymIntArrayRef dilation,
1038 bool transposed,
1039 c10::SymInt groups,
1040 ::std::array<bool, 2> output_mask);
1041
1042 Tensor scatter_reduce_jvp(
1043 const Tensor& self_p,
1044 const Tensor& self_t,
1045 int dim,
1046 const Tensor& index,
1047 const Tensor& src_p,
1048 const Tensor& src_t,
1049 c10::string_view reduce,
1050 bool include_self,
1051 const Tensor& result);
1052
1053 std::tuple<Tensor, Tensor> scatter_reduce_backward(
1054 const Tensor& grad,
1055 const Tensor& self,
1056 int dim,
1057 const Tensor& index,
1058 const Tensor& src,
1059 c10::string_view reduce,
1060 bool include_self,
1061 const Tensor& result);
1062
1063 Tensor _to_copy_backward(
1064 const Tensor& grad,
1065 const c10::TensorOptions& self_options);
1066
1067 std::tuple<Tensor, Tensor> index_reduce_backward(
1068 const Tensor& grad,
1069 const Tensor& self,
1070 int dim,
1071 const Tensor& index,
1072 const Tensor& source,
1073 c10::string_view reduce,
1074 bool include_self,
1075 const Tensor& result);
1076
1077 Tensor take_backward(
1078 const Tensor& grad,
1079 const Tensor& self,
1080 const Tensor& indices);
1081
1082 Tensor to_sparse_backward(
1083 const Tensor& grad,
1084 const c10::Layout self_layout,
1085 const c10::OptionalArrayRef<c10::SymInt>& self_blocksize);
1086
1087 std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor>
1088 mkldnn_rnn_layer_differentiable_backward(
1089 const Tensor& input,
1090 const Tensor& weight0,
1091 const Tensor& weight1,
1092 const Tensor& weight2,
1093 const Tensor& weight3,
1094 const Tensor& hx_,
1095 const Tensor& cx_tmp,
1096 const Tensor& output,
1097 const Tensor& hy_,
1098 const Tensor& cy_,
1099 const std::optional<Tensor>& grad_output_r_opt,
1100 const std::optional<Tensor>& grad_hy_r_opt,
1101 const std::optional<Tensor>& grad_cy_r_opt,
1102 bool reverse,
1103 int64_t mode,
1104 int64_t hidden_size,
1105 int64_t num_layers,
1106 bool has_biases,
1107 bool train,
1108 bool bidirectional,
1109 at::IntArrayRef batch_sizes,
1110 bool batch_first,
1111 const at::Tensor& workspace);
1112
1113 Tensor values_backward(const Tensor& grad, const Tensor& self);
1114
1115 } // namespace torch::autograd::generated::details
1116