1 #pragma once
2 
3 // NB: Must be at the top of the file to avoid including the deprecated "math.h".
4 // https://stackoverflow.com/questions/6563810/m-pi-works-with-math-h-but-not-with-cmath-in-visual-studio
5 #ifdef _MSC_VER
6 #ifndef _USE_MATH_DEFINES
7 #define _USE_MATH_DEFINES
8 #endif
9 #include <cmath>
10 #endif
11 
12 #include <ATen/ATen.h>
13 #include <torch/csrc/autograd/generated/Functions.h>
14 
15 namespace torch::autograd::generated::details {
16 
17 extern const char* kCudnnDoubleBackwardMsg;
18 
19 // A simple way to imperatively compute index ranges for slots
20 // that have been flattened
21 struct TORCH_API IndexRangeGenerator {
22   IndexRange range(size_t range_size) {
23     i += range_size;
24     return {i - range_size, i};
25   }
26   size_t size() {
27     return i;
28   }
29 
30  private:
31   size_t i = 0;
32 };
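// Usage sketch (illustrative): successive calls to range() hand out
// contiguous, non-overlapping [begin, end) index ranges over the flattened
// slots, and size() reports the running total.
//
//   IndexRangeGenerator gen;
//   auto r0 = gen.range(2); // {0, 2}
//   auto r1 = gen.range(1); // {2, 3}
//   auto r2 = gen.range(3); // {3, 6}
//   // gen.size() == 6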
33 
34 TORCH_API Tensor toNonOptFwGrad(const std::optional<Tensor>& t);
35 TORCH_API Tensor toNonOptPrimal(const std::optional<Tensor>& t);
36 TORCH_API Tensor toNonOptTensor(const std::optional<Tensor>& t);
37 
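// Returns `t` wrapped in an optional when `cond` is true and std::nullopt
// otherwise. Usage sketch (the names mat2 and grad_input_mask below are
// illustrative, not declared in this header): conditionally pass a saved
// tensor into a backward formula only when the corresponding gradient is
// actually required, e.g. wrap_opt_if(mat2, grad_input_mask[0]).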
38 TORCH_API inline std::optional<Tensor> wrap_opt_if(
39     const Tensor& t,
40     const bool cond) {
41   using OptTensor = std::optional<Tensor>;
42   return cond ? OptTensor(t) : static_cast<OptTensor>(std::nullopt);
43 }
44 
45 TORCH_API Tensor
46 apply_loss_reduction(const Tensor& unreduced, int64_t reduction);
47 TORCH_API bool any_variable_defined(const variable_list& variables);
48 TORCH_API void copy_range(
49     variable_list& out,
50     IndexRange range,
51     const at::Tensor& t);
52 TORCH_API void copy_range(
53     variable_list& out,
54     IndexRange range,
55     at::ArrayRef<at::Tensor> t);
56 TORCH_API at::Tensor copysign_tensor_self_backward(
57     const Tensor& grad,
58     const Tensor& self,
59     const Tensor& result);
60 TORCH_API at::Tensor not_implemented(const char* name, const char* reason = "");
61 TORCH_API std::vector<Tensor> not_implemented_list(
62     const char* name,
63     const char* reason = "");
64 at::Tensor handle_r_to_c(ScalarType self_st, Tensor gradient_result);
65 at::Tensor maybe_multiply(const at::Tensor& t, const at::Scalar& s);
66 int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim);
67 Tensor restore_reduced_dims(
68     const Tensor& output,
69     IntArrayRef dims,
70     bool keepdim);
71 Tensor scale_grad_by_count(
72     const Tensor& grad,
73     const Tensor& mask,
74     IntArrayRef dims);
75 at::Tensor norm_backward(
76     const at::Tensor& grad,
77     const at::Tensor& self,
78     const std::optional<at::Scalar>& p_,
79     const at::Tensor& norm);
80 at::Tensor norm_backward(
81     at::Tensor grad,
82     const at::Tensor& self,
83     const std::optional<at::Scalar>& p_,
84     at::Tensor norm,
85     at::IntArrayRef dim,
86     bool keepdim);
87 Tensor norm_jvp(
88     const Tensor& self_p,
89     const Tensor& self_t,
90     const std::optional<Scalar>& p_,
91     Tensor norm,
92     IntArrayRef dim,
93     bool keepdim);
94 Tensor norm_jvp(
95     const Tensor& grad,
96     const Tensor& self,
97     const std::optional<Scalar>& p_,
98     Tensor norm);
99 Tensor _nested_from_padded_backward(
100     const Tensor& grad,
101     const Tensor& input,
102     const bool do_transform_0213);
103 std::tuple<Tensor, Tensor, Tensor> linear_double_backward(
104     const variable_list& grads,
105     const Tensor& self,
106     const Tensor& grad_output,
107     const Tensor& weight);
108 Tensor linalg_vector_norm_jvp(
109     const Tensor& self_p,
110     const Tensor& self_t,
111     const Scalar& scalar_ord,
112     Tensor norm,
113     const at::OptionalIntArrayRef& opt_dim,
114     bool keepdim);
115 at::Tensor linalg_vector_norm_backward(
116     at::Tensor grad,
117     const at::Tensor& self,
118     const at::Scalar& ord,
119     at::Tensor norm,
120     const at::OptionalIntArrayRef& opt_dim,
121     bool keepdim);
122 at::Tensor pow_backward(
123     at::Tensor grad,
124     const at::Tensor& self,
125     const at::Scalar& exponent_);
126 at::Tensor pow_backward_self(
127     const at::Tensor& grad,
128     const at::Tensor& self,
129     const at::Tensor& exponent);
130 at::Tensor pow_backward_exponent(
131     const at::Tensor& grad,
132     const at::Tensor& self,
133     const at::Tensor& exponent,
134     const at::Tensor& result);
135 at::Tensor pow_backward_exponent(
136     const at::Tensor& grad,
137     const at::Scalar& base,
138     const at::Tensor& exponent,
139     const at::Tensor& result);
140 at::Tensor angle_backward(const at::Tensor& grad, const at::Tensor& self);
141 template <typename T>
142 at::Tensor mul_tensor_backward(const Tensor& grad, T other, ScalarType self_st);
143 template <typename T>
144 at::Tensor div_tensor_self_backward(
145     const Tensor& grad,
146     T other,
147     ScalarType self_st);
148 at::Tensor div_tensor_other_backward(
149     const Tensor& grad,
150     const Tensor& self,
151     const Tensor& other);
152 template <typename T>
153 at::Tensor div_tensor_self_backward(
154     const Tensor& grad,
155     T other,
156     ScalarType self_st,
157     const std::optional<c10::string_view>& rounding_mode);
158 at::Tensor div_tensor_other_backward(
159     const Tensor& grad,
160     const Tensor& self,
161     const Tensor& other,
162     const std::optional<c10::string_view>& rounding_mode);
163 at::Tensor mvlgamma_backward(
164     const at::Tensor& grad,
165     const at::Tensor& self,
166     int64_t p);
167 at::Tensor permute_backwards(const at::Tensor& grad, at::IntArrayRef fwd_dims);
168 at::Tensor rad2deg_backward(const at::Tensor& grad);
169 at::Tensor deg2rad_backward(const at::Tensor& grad);
170 at::Tensor unsqueeze_multiple(
171     const at::Tensor& t,
172     at::OptionalIntArrayRef opt_dim,
173     size_t n_dims);
174 at::Tensor sum_backward(
175     const at::Tensor& grad,
176     at::SymIntArrayRef sizes,
177     at::OptionalIntArrayRef opt_dims,
178     bool keepdim);
179 at::Tensor sum_backward(
180     const at::Tensor& grad,
181     c10::SymIntArrayRef sizes,
182     c10::IntArrayRef dims,
183     bool keepdim);
184 at::Tensor nansum_backward(
185     const at::Tensor& grad,
186     const at::Tensor& self,
187     at::OptionalIntArrayRef dims,
188     bool keepdim);
189 std::vector<int64_t> reverse_list(const at::IntArrayRef list);
190 std::vector<c10::SymInt> reverse_list_symint(const c10::SymIntArrayRef list);
191 at::Tensor reverse_dim(const at::Tensor& t, int64_t dim);
192 at::Tensor prod_safe_zeros_backward(
193     const at::Tensor& grad,
194     const at::Tensor& inp,
195     int64_t dim);
196 at::Tensor prod_backward(
197     const at::Tensor& grad,
198     const at::Tensor& input,
199     const at::Tensor& result);
200 at::Tensor prod_backward(
201     at::Tensor grad,
202     const at::Tensor& input,
203     at::Tensor result,
204     int64_t dim,
205     bool keepdim);
206 at::Tensor solve_jvp(
207     const Tensor& X,
208     const Tensor& A,
209     const Tensor& dA,
210     const Tensor& dB);
211 at::Tensor solve_backward_self(
212     const at::Tensor& grad,
213     const at::Tensor& self,
214     const at::Tensor& A);
215 at::Tensor solve_backward_A(
216     const at::Tensor& grad,
217     const at::Tensor& self,
218     const at::Tensor& A,
219     const at::Tensor& solution);
220 at::Tensor cumsum_backward(const at::Tensor& grad, int64_t dim);
221 at::Tensor logsumexp_backward(
222     at::Tensor grad,
223     const at::Tensor& self,
224     at::Tensor result,
225     at::IntArrayRef dim,
226     bool keepdim);
227 at::Tensor logsumexp_jvp(
228     const at::Tensor& self_p,
229     const at::Tensor& self_t,
230     IntArrayRef dim,
231     bool keepdim);
232 at::Tensor safe_logsumexp_jvp(
233     const at::Tensor& self_p,
234     const at::Tensor& self_t,
235     IntArrayRef dim,
236     bool keepdim);
237 at::Tensor logcumsumexp_backward(
238     at::Tensor grad,
239     const at::Tensor& self,
240     const at::Tensor& result,
241     int64_t dim);
242 at::Tensor logcumsumexp_jvp(
243     const at::Tensor& self_p,
244     const at::Tensor& self_t,
245     int64_t dim);
246 at::Tensor unbind_backward(const variable_list& grads, int64_t dim);
247 at::Tensor unbind_backward_nested(
248     const variable_list& grads,
249     const Tensor& nt_sizes,
250     int64_t dim,
251     const at::TensorOptions& options);
252 at::Tensor unbind_backward_nested_jagged(
253     const variable_list& grads,
254     const Tensor& self,
255     int64_t dim);
256 at::Tensor unsqueeze_to(const at::Tensor& self, c10::SymIntArrayRef sym_sizes);
257 at::Tensor unsqueeze_to(
258     const at::Tensor& self,
259     int64_t dim,
260     c10::SymIntArrayRef sym_sizes);
261 at::Tensor unsqueeze_to(
262     const at::Tensor& self,
263     IntArrayRef dim,
264     c10::SymIntArrayRef sym_sizes);
265 std::vector<at::Tensor> cat_tensors_backward(
266     const at::Tensor& grad,
267     const std::vector<std::vector<c10::SymInt>>& sizes,
268     const std::vector<ScalarType>& dtypes,
269     int64_t dim);
270 std::vector<at::Tensor> stack_tensors_backward(
271     const at::Tensor& grad,
272     int64_t dim,
273     const std::vector<ScalarType>& dtypes);
274 std::vector<at::Tensor> block_diag_backward(
275     const at::Tensor& grad,
276     const std::vector<std::vector<int64_t>>& sizes,
277     const std::vector<ScalarType>& dtypes);
278 at::Tensor clamp_backward(
279     const at::Tensor& grad,
280     const at::Tensor& self,
281     const std::optional<at::Scalar>& min,
282     const std::optional<at::Scalar>& max);
283 at::Tensor clamp_backward(
284     const at::Tensor& grad,
285     const at::Tensor& self,
286     const at::Tensor& min,
287     const at::Tensor& max);
288 std::tuple<at::Tensor, at::Tensor> clamp_backward_min_max(
289     const at::Tensor& grad,
290     const at::Tensor& self,
291     const at::Tensor& min,
292     const at::Tensor& max,
293     const std::array<bool, 2>&);
294 at::Tensor clamp_jvp(
295     const Tensor& self_p,
296     const Tensor& self_t,
297     const Tensor& min_p,
298     const Tensor& min_t,
299     const Tensor& max_p,
300     const Tensor& max_t);
301 at::SymIntArrayRef strides_or_error(
302     const Tensor& input,
303     c10::string_view const& input_name);
304 at::Tensor mm_mat1_backward(
305     const Tensor& grad,
306     const Tensor& mat2,
307     at::SymIntArrayRef mat1_sizes,
308     at::SymIntArrayRef mat1_strides,
309     c10::Layout mat1_layout,
310     const Scalar& alpha);
311 at::Tensor mm_mat2_backward(
312     const at::Tensor& grad,
313     const at::Tensor& mat1,
314     at::SymIntArrayRef sizes,
315     at::SymIntArrayRef strides,
316     c10::Layout layout,
317     const at::Scalar& alpha);
318 at::Tensor mm_mat1_sparse_backward(
319     const at::Tensor& grad,
320     const at::Tensor& mat1,
321     const at::Tensor& mat2,
322     const at::Scalar& alpha);
323 std::tuple<Tensor, Tensor, Tensor> sparse_sampled_addmm_backward(
324     const Tensor& grad,
325     const Tensor& self,
326     const std::optional<Tensor>& mat1,
327     const std::optional<Tensor>& mat2,
328     const Scalar& alpha,
329     const Scalar& beta,
330     const std::array<bool, 3>& grad_input_mask);
331 at::Tensor sparse_mask_backward(
332     const at::Tensor& grad,
333     const at::Tensor& mask,
334     c10::Layout self_layout);
335 at::Tensor sparse_sparse_matmul_backward(
336     const at::Tensor& grad,
337     const at::Tensor& mat1,
338     const at::Tensor& mat2,
339     int64_t grad_order);
340 at::Tensor renorm_backward(
341     const at::Tensor& grad,
342     const at::Tensor& self,
343     const at::Scalar& p,
344     int64_t dim,
345     const at::Scalar& maxnorm);
346 at::Tensor renorm_jvp(
347     const at::Tensor& self_p,
348     const at::Tensor& self_t,
349     const at::Scalar& p,
350     int64_t dim,
351     const at::Scalar& maxnorm);
352 at::Tensor repeat_backward(
353     at::Tensor grad,
354     at::SymIntArrayRef repeats,
355     at::SymIntArrayRef input_shape);
356 at::Tensor _fused_dropout_backward(
357     const at::Tensor& grad,
358     const at::Tensor& mask,
359     double p1m);
360 at::Tensor infinitely_differentiable_native_dropout_backward(
361     const at::Tensor& grad,
362     const at::Tensor& mask,
363     double scale);
364 at::Tensor native_dropout_double_backward(
365     const at::Tensor& ggI,
366     const at::Tensor& grad,
367     const at::Tensor& mask,
368     double scale);
369 at::Tensor evenly_distribute_backward(
370     const at::Tensor& grad,
371     const at::Tensor& input,
372     const at::Tensor& value);
373 Tensor sgn_backward(const Tensor& x, const Tensor& gx, const Tensor& sgn);
374 Tensor masked_fill_backward(const Tensor& grad, const Tensor& mask);
375 at::Tensor var_backward(
376     at::Tensor grad,
377     const at::Tensor& self,
378     at::OptionalIntArrayRef dim,
379     const std::optional<c10::Scalar>& correction,
380     bool keepdim);
381 at::Tensor var_jvp(
382     const at::Tensor& self_t,
383     const at::Tensor& self_p,
384     const at::Tensor& result,
385     at::OptionalIntArrayRef dim_opt,
386     const std::optional<c10::Scalar>& correction,
387     bool keepdim);
388 at::Tensor std_backward(
389     const at::Tensor& result,
390     const at::Tensor& grad,
391     const at::Tensor& self,
392     at::OptionalIntArrayRef dim,
393     const std::optional<c10::Scalar>& correction,
394     bool keepdim);
395 Tensor mean_backward(
396     const Tensor& grad,
397     c10::SymIntArrayRef shape,
398     at::OptionalIntArrayRef opt_dim,
399     c10::SymInt numel,
400     bool keepdim);
401 Tensor var_mean_backward(
402     const Tensor& gvar,
403     const Tensor& gmean,
404     const Tensor& self,
405     at::OptionalIntArrayRef dim_opt,
406     const std::optional<c10::Scalar>& correction,
407     bool keepdim);
408 Tensor std_mean_backward(
409     const Tensor& gstd,
410     const Tensor& gmean,
411     const Tensor& self,
412     const Tensor& std,
413     at::OptionalIntArrayRef dim_opt,
414     const std::optional<c10::Scalar>& correction,
415     bool keepdim);
416 at::Tensor cholesky_backward(
417     const at::Tensor& grad,
418     bool upper,
419     const at::Tensor& L);
420 at::Tensor cholesky_jvp(
421     const at::Tensor& input_tangent,
422     const at::Tensor& L,
423     bool upper);
424 at::Tensor cholesky_inverse_backward(
425     const at::Tensor& grad,
426     const at::Tensor& L,
427     bool upper,
428     const at::Tensor& inverse);
429 at::Tensor cholesky_inverse_jvp(
430     const at::Tensor& F,
431     const at::Tensor& dF,
432     const at::Tensor& X,
433     bool upper);
434 Tensor pinv_jvp(const Tensor& A, const Tensor& pinvA, const Tensor& dA);
435 Tensor pinv_backward(const Tensor& grad, const Tensor& pinvA, const Tensor& A);
436 Tensor chunk_backward_nested(
437     const std::vector<torch::autograd::Variable>& grads,
438     const Tensor& self,
439     int64_t chunks,
440     int64_t dim);
441 at::Tensor split_with_sizes_backward(
442     const std::vector<torch::autograd::Variable>& grads,
443     c10::SymIntArrayRef split_sizes,
444     int64_t dim,
445     c10::SymIntArrayRef sizes,
446     const at::TensorOptions& options);
447 at::Tensor _nested_split_with_sizes_backward(
448     const std::vector<torch::autograd::Variable>& grads,
449     c10::SymIntArrayRef split_sizes,
450     int64_t dim,
451     const Tensor& nt_sizes,
452     const at::TensorOptions& options);
453 at::Tensor split_backward(
454     const std::vector<torch::autograd::Variable>& grads,
455     const c10::SymInt& split_size,
456     int64_t dim,
457     c10::SymIntArrayRef sizes,
458     const at::TensorOptions& options);
459 at::Tensor max_pool_double_backward(
460     const at::Tensor& grad,
461     const at::Tensor& indices,
462     int dim);
463 at::Tensor error_for_max_pool2d_double_backward();
464 at::Tensor glu_double_backward(
465     const at::Tensor& grad,
466     const at::Tensor& grad_output,
467     const at::Tensor& input,
468     int64_t dim);
469 at::Tensor glu_double_backward_grad_output(
470     const at::Tensor& grad,
471     const at::Tensor& input,
472     int64_t dim);
473 at::Tensor infinitely_differentiable_silu_backward(
474     const at::Tensor& grad_output,
475     const at::Tensor& input);
476 at::Tensor infinitely_differentiable_mish_backward(
477     const at::Tensor& grad_output,
478     const at::Tensor& input);
479 Tensor infinitely_differentiable_logit_backward(
480     const Tensor& grad,
481     const Tensor& self,
482     std::optional<double> eps);
483 Tensor binary_cross_entropy_target_backward(
484     const Tensor& grad,
485     const Tensor& self,
486     const Tensor& target,
487     const std::optional<Tensor>& weight,
488     int64_t reduction);
489 Tensor binary_cross_entropy_double_backward_target(
490     const Tensor& grad,
491     const Tensor& grad_output,
492     const Tensor& self,
493     const Tensor& target,
494     const std::optional<Tensor>& weight,
495     int64_t reduction);
496 Tensor binary_cross_entropy_with_logits_backward(
497     const Tensor& grad,
498     const Tensor& input,
499     const Tensor& target,
500     const std::optional<Tensor>& weight_opt,
501     const std::optional<Tensor>& pos_weight_opt,
502     int64_t reduction);
503 at::Tensor binary_cross_entropy_with_logits_target_backward(
504     const at::Tensor& grad_output,
505     const at::Tensor& self,
506     const at::Tensor& target,
507     const std::optional<at::Tensor>& weight,
508     const std::optional<at::Tensor>& pos_weight,
509     int64_t reduction);
510 at::Tensor log_sigmoid_double_backward(
511     const at::Tensor& grad,
512     const at::Tensor& input);
513 at::Tensor softmax_double_backward(
514     const at::Tensor& grad,
515     const at::Tensor& grad_output,
516     int dim,
517     const at::Tensor& output);
518 at::Tensor binary_cross_entropy_double_backward(
519     const at::Tensor& grad_output,
520     const at::Tensor& grad,
521     const at::Tensor& input,
522     const at::Tensor& target,
523     const std::optional<at::Tensor>& weight,
524     int64_t reduction);
525 at::Tensor binary_cross_entropy_double_backward_grad_output(
526     const at::Tensor& grad,
527     const at::Tensor& input,
528     const at::Tensor& target,
529     const std::optional<at::Tensor>& weight,
530     int64_t reduction);
531 at::Tensor smooth_l1_loss_double_backward(
532     const at::Tensor& grad,
533     const at::Tensor& input,
534     const at::Tensor& target,
535     int64_t reduction,
536     double beta);
537 at::Tensor huber_loss_double_backward(
538     const at::Tensor& grad,
539     const at::Tensor& input,
540     const at::Tensor& target,
541     int64_t reduction,
542     double delta);
543 at::Tensor huber_loss_double_backward_grad_output(
544     const at::Tensor& grad,
545     const at::Tensor& grad_output,
546     const at::Tensor& input,
547     const at::Tensor& target,
548     int64_t reduction,
549     double delta);
550 at::Tensor mse_loss_double_backward(
551     const at::Tensor& grad,
552     const at::Tensor& input,
553     int64_t reduction);
554 at::Tensor soft_margin_loss_double_backward(
555     const at::Tensor& grad,
556     const at::Tensor& input,
557     const at::Tensor& target,
558     int64_t reduction);
559 at::Tensor soft_margin_loss_double_backward_grad_output(
560     const at::Tensor& grad,
561     const at::Tensor& grad_output,
562     const at::Tensor& input,
563     const at::Tensor& target,
564     int64_t reduction);
565 at::Tensor softplus_double_backward(
566     const at::Tensor& grad,
567     const at::Tensor& input,
568     const at::Scalar& beta,
569     const at::Scalar& threshold);
570 std::tuple<at::Tensor, at::Tensor> slogdet_jvp(
571     const at::Tensor& LU,
572     const at::Tensor& pivots,
573     const at::Tensor& dA,
574     const at::Tensor& sign,
575     const bool use_A_T);
576 at::Tensor slogdet_backward(
577     const at::Tensor& grad_sign,
578     const at::Tensor& grad_logabsdet,
579     const at::Tensor& A,
580     const at::Tensor& signdet,
581     const at::Tensor& LU,
582     const at::Tensor& pivots);
583 at::Tensor log1p_backward(const at::Tensor& grad, const at::Tensor& self);
584 at::Tensor sinc_backward(const at::Tensor& grad, const at::Tensor& self);
585 at::Tensor sparse_constructor_values_backward(
586     const at::Tensor& sparse_grad_out,
587     const at::Tensor& indices);
588 at::Tensor embedding_dense_double_backward_symint(
589     const at::Tensor& grad,
590     const at::Tensor& indices,
591     const c10::SymInt& padding_idx);
592 at::Tensor index_backward(
593     at::Tensor zeros_like_self,
594     const torch::List<std::optional<Tensor>>& indices,
595     const at::Tensor& grad);
596 at::Tensor _cudnn_ctc_loss_backward(
597     const at::Tensor& grad_out,
598     const at::Tensor& loss,
599     const at::Tensor& raw_grad,
600     bool zero_infinity);
601 at::Tensor elu_double_backward(
602     const Tensor& grad,
603     const Tensor& grad_output,
604     const Scalar& alpha,
605     const Scalar& scale,
606     const Scalar& input_scale,
607     bool is_result,
608     const Tensor& self_or_result);
609 
610 Tensor svd_backward(
611     const Tensor& gU,
612     const Tensor& gS,
613     const Tensor& gVh,
614     const Tensor& U,
615     const Tensor& S,
616     const Tensor& Vh);
617 
618 std::tuple<Tensor, Tensor, Tensor> linalg_svd_jvp(
619     const Tensor& dA,
620     const Tensor& U,
621     const Tensor& S,
622     const Tensor& Vh,
623     const bool full_matrices);
624 Tensor slice_backward_wrapper(
625     const at::Tensor& grad,
626     const c10::SymIntArrayRef& input_sizes,
627     int64_t dim,
628     std::optional<c10::SymInt> start,
629     std::optional<c10::SymInt> end,
630     c10::SymInt step);
631 std::tuple<Tensor, Tensor> linalg_eig_jvp(
632     const Tensor& dA,
633     const Tensor& L,
634     const Tensor& V,
635     const bool is_hermitian);
636 Tensor linalg_eig_backward(
637     const Tensor& gL,
638     const Tensor& gV,
639     const Tensor& L,
640     const Tensor& V,
641     const bool is_hermitian,
642     const bool symeig_eigenvectors = true);
643 Tensor linalg_lstsq_jvp(
644     const Tensor& A,
645     const Tensor& B,
646     const Tensor& dA,
647     const Tensor& dB);
648 std::tuple<Tensor, Tensor> triangular_solve_backward(
649     const Tensor& grad_x,
650     const Tensor& grad_m,
651     const Tensor& b,
652     const Tensor& a,
653     const Tensor& x,
654     const bool upper,
655     const bool transpose,
656     const bool unitriangular,
657     std::array<bool, 2> output_mask);
658 Tensor triangular_solve_jvp(
659     const Tensor& X,
660     const Tensor& A,
661     const Tensor& dA,
662     const Tensor& dB,
663     const bool upper,
664     const bool transpose,
665     const bool unitriangular);
666 Tensor linalg_solve_triangular_forward_AD(
667     const Tensor& A_t,
668     const Tensor& B_t,
669     const Tensor& A,
670     const Tensor& X,
671     const bool upper,
672     const bool left,
673     const bool unitriangular);
674 std::tuple<Tensor, Tensor> linalg_solve_triangular_backward(
675     const Tensor& grad,
676     const Tensor& A,
677     const Tensor& X,
678     const bool upper,
679     const bool left,
680     const bool unitriangular,
681     std::array<bool, 2> output_mask);
682 std::tuple<Tensor, Tensor, Tensor> _trilinear_backward(
683     const Tensor& grad_out,
684     const std::optional<Tensor>& i1,
685     const std::optional<Tensor>& i2,
686     const std::optional<Tensor>& i3,
687     IntArrayRef expand1,
688     IntArrayRef expand2,
689     IntArrayRef expand3,
690     IntArrayRef sumdim,
691     std::array<bool, 3> grad_mask);
692 std::tuple<Tensor, Tensor> linalg_qr_jvp(
693     const Tensor& dA,
694     const Tensor& Q,
695     const Tensor& R,
696     const c10::string_view mode);
697 Tensor linalg_qr_backward(
698     const Tensor& gQ,
699     const Tensor& gR,
700     const Tensor& Q,
701     const Tensor& R,
702     const c10::string_view mode);
703 Tensor linalg_matrix_exp_differential(
704     const Tensor& self,
705     const Tensor& grad,
706     bool adjoint);
707 std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(
708     const Tensor& input,
709     const std::optional<Tensor>& gamma,
710     const Tensor& ggI,
711     const Tensor& ggG,
712     const Tensor& ggB,
713     const Tensor& gO,
714     const std::optional<Tensor>& running_mean,
715     const std::optional<Tensor>& running_var,
716     bool training,
717     double eps,
718     const std::optional<Tensor>& save_mean,
719     const std::optional<Tensor>& save_invstd,
720     std::array<bool, 3> output_mask);
721 std::tuple<Tensor, Tensor> _euclidean_dist_backward(
722     const Tensor& grad,
723     const Tensor& x1,
724     const Tensor& x2,
725     const Tensor& res);
726 Tensor fft_backward(
727     const Tensor& self,
728     const Tensor& grad,
729     int64_t signal_ndim,
730     bool complex_input,
731     bool complex_output,
732     bool inverse,
733     IntArrayRef checked_signal_sizes,
734     int64_t normalization,
735     bool onesided,
736     IntArrayRef output_sizes);
737 Tensor fft_r2c_backward(
738     const Tensor& grad,
739     at::IntArrayRef dim,
740     int64_t normalization,
741     bool onesided,
742     const c10::SymInt& last_dim_size);
743 Tensor fft_c2r_backward(
744     const Tensor& grad,
745     IntArrayRef dim,
746     int64_t normalization);
747 Tensor constant_pad_nd_backward(const Tensor& grad, c10::SymIntArrayRef pad);
748 std::tuple<Tensor, Tensor> cholesky_solve_backward(
749     const Tensor& grad_x,
750     const Tensor& self,
751     const Tensor& input2,
752     const Tensor& result,
753     const bool upper,
754     std::array<bool, 2> output_mask);
755 Tensor cholesky_solve_jvp(
756     const Tensor& X,
757     const Tensor& U,
758     const Tensor& dU,
759     const Tensor& dB,
760     const bool upper);
761 std::tuple<Tensor, Tensor, Tensor>
762 infinitely_differentiable_native_group_norm_backward(
763     const Tensor& dY,
764     const Tensor& dmean,
765     const Tensor& drstd,
766     const Tensor& X,
767     const Tensor& mean,
768     const Tensor& rstd,
769     const std::optional<Tensor>& gamma,
770     c10::SymInt N,
771     const c10::SymInt& C,
772     c10::SymInt HxW,
773     int64_t group,
774     double eps,
775     std::array<bool, 3> grad_input_mask);
776 Tensor gelu_double_backward(
777     const Tensor& ggI,
778     const Tensor& gO,
779     const Tensor& input,
780     c10::string_view approximate);
781 Tensor as_strided_backward(
782     Tensor grad,
783     const TensorGeometry& input_geometry,
784     c10::SymIntArrayRef sizes,
785     c10::SymIntArrayRef strides,
786     const std::optional<c10::SymInt>& storage_offset_);
787 Tensor as_strided_scatter_backward(
788     const Tensor& grad,
789     const TensorGeometry& input_geometry,
790     const TensorGeometry& src_geometry,
791     c10::SymIntArrayRef sizes,
792     c10::SymIntArrayRef strides,
793     std::optional<c10::SymInt> storage_offset);
794 std::tuple<Tensor, Tensor> atan2_backward(
795     const Tensor& grad,
796     const Tensor& self,
797     const Tensor& other,
798     std::array<bool, 2> output_mask);
799 Tensor amaxamin_jvp(
800     const Tensor& x,
801     const Tensor& dx,
802     const Tensor& result,
803     IntArrayRef dim,
804     bool keepdim);
805 std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
806     const Tensor& input,
807     const std::optional<Tensor>& gamma,
808     const Tensor& ggI,
809     const Tensor& ggG,
810     const Tensor& ggB,
811     const Tensor& gO,
812     const Tensor& save_mean,
813     const Tensor& save_invstd,
814     c10::SymIntArrayRef normalized_shape,
815     std::array<bool, 3> output_mask);
816 
817 std::tuple<Tensor, Tensor> householder_product_backward(
818     const Tensor& grad,
819     const Tensor& result,
820     const Tensor& input,
821     const Tensor& tau,
822     const bool flip_order = false);
823 Tensor householder_product_jvp(
824     const Tensor& dV,
825     const Tensor& dtau,
826     const Tensor& prod,
827     const Tensor& V,
828     const Tensor& tau);
829 std::tuple<Tensor, Tensor, Tensor> ormqr_backward(
830     const Tensor& grad,
831     const Tensor& result,
832     const Tensor& self,
833     const Tensor& tau,
834     const Tensor& other,
835     bool left,
836     bool transpose,
837     std::array<bool, 3> grad_output_mask);
838 std::tuple<Tensor, Tensor> polar_backward(
839     const Tensor& grad,
840     const Tensor& result);
841 Tensor i1_backward(
842     const Tensor& grad,
843     const Tensor& self,
844     const Tensor& result);
845 Tensor i1e_backward(
846     const Tensor& grad,
847     const Tensor& self,
848     const Tensor& result);
849 Tensor linalg_lu_solve_LU(
850     const Tensor& grad,
851     const Tensor& LU,
852     const Tensor& pivots,
853     const Tensor& X,
854     const bool left,
855     const bool adjoint);
856 Tensor linalg_lu_solve_jvp(
857     const Tensor& X,
858     const Tensor& LU,
859     const Tensor& pivots,
860     const Tensor& dLU,
861     const Tensor& dB,
862     const bool left,
863     const bool adjoint);
864 std::tuple<Tensor, Tensor> linalg_solve_backward(
865     const Tensor& gX,
866     const Tensor& X,
867     const Tensor& A,
868     const Tensor& LU,
869     const Tensor& pivots,
870     const bool left,
871     const bool B_requires_grad);
872 Tensor linalg_solve_jvp(
873     const Tensor& dA,
874     const Tensor& dB,
875     const Tensor& X,
876     const Tensor& LU,
877     const Tensor& pivots,
878     const bool left,
879     const bool use_A_T);
880 Tensor lu_unpack_backward(
881     const Tensor& L_grad,
882     const Tensor& U_grad,
883     const c10::SymInt& m,
884     const c10::SymInt& n);
885 
886 Tensor linalg_det_backward(
887     const Tensor& grad,
888     const Tensor& det,
889     const Tensor& A,
890     const Tensor& LU,
891     const Tensor& pivots);
892 Tensor linalg_det_jvp(
893     const Tensor& dA,
894     const Tensor& det,
895     const Tensor& LU,
896     const Tensor& pivots,
897     const bool use_A_T);
898 std::tuple<Tensor, Tensor> linalg_lstsq_backward(
899     const Tensor& grad,
900     const Tensor& A,
901     const Tensor& B_,
902     const std::array<bool, 2>& grad_input_mask);
903 Tensor linalg_lu_backward(
904     const Tensor& L_grad,
905     const Tensor& U_grad,
906     const Tensor& P,
907     const Tensor& L,
908     const Tensor& U,
909     const bool pivot);
910 
911 std::tuple<Tensor, Tensor> linalg_lu_jvp(
912     const Tensor& dA,
913     const Tensor& P,
914     const Tensor& L,
915     const Tensor& U,
916     const bool pivot);
917 
918 Tensor lu_factor_ex_backward(
919     const Tensor& grad,
920     const Tensor& LU,
921     const Tensor& pivs,
922     const bool pivot);
923 Tensor lu_factor_ex_jvp(
924     const Tensor& dX,
925     const Tensor& LU,
926     const Tensor& pivs,
927     const bool pivot);
928 
929 Tensor batch_norm_jvp(
930     const Tensor& input_p,
931     const Tensor& input_t,
932     const Tensor& weight_p,
933     const Tensor& weight_t,
934     const Tensor& bias_p,
935     const Tensor& bias_t,
936     const std::optional<Tensor>& running_mean,
937     const std::optional<Tensor>& running_var,
938     const Tensor& saved_mean,
939     const Tensor& saved_invstd,
940     bool train,
941     double eps);
942 
943 Tensor layer_norm_jvp(
944     const Tensor& input_p,
945     const Tensor& input_t,
946     const Tensor& weight_p,
947     const Tensor& weight_t,
948     const Tensor& bias_p,
949     const Tensor& bias_t,
950     const Tensor& saved_mean,
951     const Tensor& saved_invstd,
952     c10::SymIntArrayRef normalized_shape);
953 
954 Tensor group_norm_jvp(
955     const Tensor& input_p,
956     const Tensor& input_t,
957     const Tensor& weight_p,
958     const Tensor& weight_t,
959     const Tensor& bias_p,
960     const Tensor& bias_t,
961     const Tensor& saved_mean,
962     const Tensor& saved_invstd,
963     int64_t groups);
964 Tensor group_norm_mean_jvp(
965     const Tensor& input_t,
966     const Tensor& mean_p,
967     int64_t groups);
968 Tensor group_norm_invstd_jvp(
969     const Tensor& input_p,
970     const Tensor& input_t,
971     const Tensor& mean_p,
972     const Tensor& invstd_p,
973     int64_t groups);
974 
975 Tensor convolution_jvp(
976     const Tensor& input_p,
977     const Tensor& input_t,
978     const Tensor& weight_p,
979     const Tensor& weight_t,
980     const Tensor& bias_p,
981     const Tensor& bias_t,
982     at::SymIntArrayRef stride,
983     at::SymIntArrayRef padding,
984     at::SymIntArrayRef dilation,
985     bool transposed,
986     at::SymIntArrayRef output_padding,
987     const c10::SymInt& groups);
988 
989 Tensor _convolution_jvp(
990     const Tensor& input_p,
991     const Tensor& input_t,
992     const Tensor& weight_p,
993     const Tensor& weight_t,
994     const Tensor& bias_p,
995     const Tensor& bias_t,
996     at::SymIntArrayRef stride,
997     at::SymIntArrayRef padding,
998     at::SymIntArrayRef dilation,
999     bool transposed,
1000     at::SymIntArrayRef output_padding,
1001     const c10::SymInt& groups,
1002     bool benchmark,
1003     bool deterministic,
1004     bool cudnn_enabled,
1005     bool allow_tf32);
1006 
1007 Tensor convolution_backward_jvp_grad_bias(
1008     const Tensor& grad_out_t,
1009     const Tensor& grad_bias);
1010 
1011 Tensor cat_jvp(const at::ITensorListRef& tensors, int64_t dim);
1012 Tensor block_diag_jvp(at::TensorList tensors);
1013 Tensor stack_jvp(at::TensorList tensors, int64_t dim);
1014 Tensor cumprod_jvp(
1015     const Tensor& self_t,
1016     const Tensor& self_p,
1017     const Tensor& result,
1018     int dim);
1019 Tensor gather_with_keepdimed_indices(
1020     const Tensor& input,
1021     int64_t dim,
1022     const Tensor& indices,
1023     bool keepdim);
1024 Tensor evenly_read_jvp(
1025     const Tensor& fw_grad,
1026     const Tensor& input,
1027     const Tensor& value);
1028 Tensor warn_backwards(const Tensor& grad_output);
1029 
1030 std::tuple<Tensor, Tensor> _cudnn_convolution_backward(
1031     const at::Tensor& self,
1032     const at::Tensor& grad_output,
1033     const at::Tensor& weight,
1034     at::SymIntArrayRef padding,
1035     at::SymIntArrayRef output_padding,
1036     at::SymIntArrayRef stride,
1037     at::SymIntArrayRef dilation,
1038     bool transposed,
1039     c10::SymInt groups,
1040     ::std::array<bool, 2> output_mask);
1041 
1042 Tensor scatter_reduce_jvp(
1043     const Tensor& self_p,
1044     const Tensor& self_t,
1045     int dim,
1046     const Tensor& index,
1047     const Tensor& src_p,
1048     const Tensor& src_t,
1049     c10::string_view reduce,
1050     bool include_self,
1051     const Tensor& result);
1052 
1053 std::tuple<Tensor, Tensor> scatter_reduce_backward(
1054     const Tensor& grad,
1055     const Tensor& self,
1056     int dim,
1057     const Tensor& index,
1058     const Tensor& src,
1059     c10::string_view reduce,
1060     bool include_self,
1061     const Tensor& result);
1062 
1063 Tensor _to_copy_backward(
1064     const Tensor& grad,
1065     const c10::TensorOptions& self_options);
1066 
1067 std::tuple<Tensor, Tensor> index_reduce_backward(
1068     const Tensor& grad,
1069     const Tensor& self,
1070     int dim,
1071     const Tensor& index,
1072     const Tensor& source,
1073     c10::string_view reduce,
1074     bool include_self,
1075     const Tensor& result);
1076 
1077 Tensor take_backward(
1078     const Tensor& grad,
1079     const Tensor& self,
1080     const Tensor& indices);
1081 
1082 Tensor to_sparse_backward(
1083     const Tensor& grad,
1084     const c10::Layout self_layout,
1085     const c10::OptionalArrayRef<c10::SymInt>& self_blocksize);
1086 
1087 std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor>
1088 mkldnn_rnn_layer_differentiable_backward(
1089     const Tensor& input,
1090     const Tensor& weight0,
1091     const Tensor& weight1,
1092     const Tensor& weight2,
1093     const Tensor& weight3,
1094     const Tensor& hx_,
1095     const Tensor& cx_tmp,
1096     const Tensor& output,
1097     const Tensor& hy_,
1098     const Tensor& cy_,
1099     const std::optional<Tensor>& grad_output_r_opt,
1100     const std::optional<Tensor>& grad_hy_r_opt,
1101     const std::optional<Tensor>& grad_cy_r_opt,
1102     bool reverse,
1103     int64_t mode,
1104     int64_t hidden_size,
1105     int64_t num_layers,
1106     bool has_biases,
1107     bool train,
1108     bool bidirectional,
1109     at::IntArrayRef batch_sizes,
1110     bool batch_first,
1111     const at::Tensor& workspace);
1112 
1113 Tensor values_backward(const Tensor& grad, const Tensor& self);
1114 
1115 } // namespace torch::autograd::generated::details
1116