xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/linalg.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 
5 namespace torch {
6 namespace linalg {
7 
8 #ifndef DOXYGEN_SHOULD_SKIP_THIS
9 namespace detail {
10 
cholesky(const Tensor & self)11 inline Tensor cholesky(const Tensor& self) {
12   return torch::linalg_cholesky(self);
13 }
14 
cholesky_out(Tensor & result,const Tensor & self)15 inline Tensor cholesky_out(Tensor& result, const Tensor& self) {
16   return torch::linalg_cholesky_out(result, self);
17 }
18 
det(const Tensor & self)19 inline Tensor det(const Tensor& self) {
20   return torch::linalg_det(self);
21 }
22 
slogdet(const Tensor & input)23 inline std::tuple<Tensor, Tensor> slogdet(const Tensor& input) {
24   return torch::linalg_slogdet(input);
25 }
26 
slogdet_out(Tensor & sign,Tensor & logabsdet,const Tensor & input)27 inline std::tuple<Tensor&, Tensor&> slogdet_out(
28     Tensor& sign,
29     Tensor& logabsdet,
30     const Tensor& input) {
31   return torch::linalg_slogdet_out(sign, logabsdet, input);
32 }
33 
eig(const Tensor & self)34 inline std::tuple<Tensor, Tensor> eig(const Tensor& self) {
35   return torch::linalg_eig(self);
36 }
37 
eig_out(Tensor & eigvals,Tensor & eigvecs,const Tensor & self)38 inline std::tuple<Tensor&, Tensor&> eig_out(
39     Tensor& eigvals,
40     Tensor& eigvecs,
41     const Tensor& self) {
42   return torch::linalg_eig_out(eigvals, eigvecs, self);
43 }
44 
eigvals(const Tensor & self)45 inline Tensor eigvals(const Tensor& self) {
46   return torch::linalg_eigvals(self);
47 }
48 
eigvals_out(Tensor & result,const Tensor & self)49 inline Tensor& eigvals_out(Tensor& result, const Tensor& self) {
50   return torch::linalg_eigvals_out(result, self);
51 }
52 
eigh(const Tensor & self,c10::string_view uplo)53 inline std::tuple<Tensor, Tensor> eigh(
54     const Tensor& self,
55     c10::string_view uplo) {
56   return torch::linalg_eigh(self, uplo);
57 }
58 
eigh_out(Tensor & eigvals,Tensor & eigvecs,const Tensor & self,c10::string_view uplo)59 inline std::tuple<Tensor&, Tensor&> eigh_out(
60     Tensor& eigvals,
61     Tensor& eigvecs,
62     const Tensor& self,
63     c10::string_view uplo) {
64   return torch::linalg_eigh_out(eigvals, eigvecs, self, uplo);
65 }
66 
eigvalsh(const Tensor & self,c10::string_view uplo)67 inline Tensor eigvalsh(const Tensor& self, c10::string_view uplo) {
68   return torch::linalg_eigvalsh(self, uplo);
69 }
70 
eigvalsh_out(Tensor & result,const Tensor & self,c10::string_view uplo)71 inline Tensor& eigvalsh_out(
72     Tensor& result,
73     const Tensor& self,
74     c10::string_view uplo) {
75   return torch::linalg_eigvalsh_out(result, self, uplo);
76 }
77 
householder_product(const Tensor & input,const Tensor & tau)78 inline Tensor householder_product(const Tensor& input, const Tensor& tau) {
79   return torch::linalg_householder_product(input, tau);
80 }
81 
householder_product_out(Tensor & result,const Tensor & input,const Tensor & tau)82 inline Tensor& householder_product_out(
83     Tensor& result,
84     const Tensor& input,
85     const Tensor& tau) {
86   return torch::linalg_householder_product_out(result, input, tau);
87 }
88 
lu_factor(const Tensor & self,const bool pivot)89 inline std::tuple<Tensor, Tensor> lu_factor(
90     const Tensor& self,
91     const bool pivot) {
92   return torch::linalg_lu_factor(self, pivot);
93 }
94 
lu_factor_out(Tensor & LU,Tensor & pivots,const Tensor & self,const bool pivot)95 inline std::tuple<Tensor&, Tensor&> lu_factor_out(
96     Tensor& LU,
97     Tensor& pivots,
98     const Tensor& self,
99     const bool pivot) {
100   return torch::linalg_lu_factor_out(LU, pivots, self, pivot);
101 }
102 
lu(const Tensor & self,const bool pivot)103 inline std::tuple<Tensor, Tensor, Tensor> lu(
104     const Tensor& self,
105     const bool pivot) {
106   return torch::linalg_lu(self, pivot);
107 }
108 
lu_out(Tensor & P,Tensor & L,Tensor & U,const Tensor & self,const bool pivot)109 inline std::tuple<Tensor&, Tensor&, Tensor&> lu_out(
110     Tensor& P,
111     Tensor& L,
112     Tensor& U,
113     const Tensor& self,
114     const bool pivot) {
115   return torch::linalg_lu_out(P, L, U, self, pivot);
116 }
117 
lstsq(const Tensor & self,const Tensor & b,std::optional<double> cond,std::optional<c10::string_view> driver)118 inline std::tuple<Tensor, Tensor, Tensor, Tensor> lstsq(
119     const Tensor& self,
120     const Tensor& b,
121     std::optional<double> cond,
122     std::optional<c10::string_view> driver) {
123   return torch::linalg_lstsq(self, b, cond, driver);
124 }
125 
matrix_exp(const Tensor & self)126 inline Tensor matrix_exp(const Tensor& self) {
127   return torch::linalg_matrix_exp(self);
128 }
129 
norm(const Tensor & self,const std::optional<Scalar> & opt_ord,OptionalIntArrayRef opt_dim,bool keepdim,std::optional<ScalarType> opt_dtype)130 inline Tensor norm(
131     const Tensor& self,
132     const std::optional<Scalar>& opt_ord,
133     OptionalIntArrayRef opt_dim,
134     bool keepdim,
135     std::optional<ScalarType> opt_dtype) {
136   return torch::linalg_norm(self, opt_ord, opt_dim, keepdim, opt_dtype);
137 }
138 
norm(const Tensor & self,c10::string_view ord,OptionalIntArrayRef opt_dim,bool keepdim,std::optional<ScalarType> opt_dtype)139 inline Tensor norm(
140     const Tensor& self,
141     c10::string_view ord,
142     OptionalIntArrayRef opt_dim,
143     bool keepdim,
144     std::optional<ScalarType> opt_dtype) {
145   return torch::linalg_norm(self, ord, opt_dim, keepdim, opt_dtype);
146 }
147 
norm_out(Tensor & result,const Tensor & self,const std::optional<Scalar> & opt_ord,OptionalIntArrayRef opt_dim,bool keepdim,std::optional<ScalarType> opt_dtype)148 inline Tensor& norm_out(
149     Tensor& result,
150     const Tensor& self,
151     const std::optional<Scalar>& opt_ord,
152     OptionalIntArrayRef opt_dim,
153     bool keepdim,
154     std::optional<ScalarType> opt_dtype) {
155   return torch::linalg_norm_out(
156       result, self, opt_ord, opt_dim, keepdim, opt_dtype);
157 }
158 
norm_out(Tensor & result,const Tensor & self,c10::string_view ord,OptionalIntArrayRef opt_dim,bool keepdim,std::optional<ScalarType> opt_dtype)159 inline Tensor& norm_out(
160     Tensor& result,
161     const Tensor& self,
162     c10::string_view ord,
163     OptionalIntArrayRef opt_dim,
164     bool keepdim,
165     std::optional<ScalarType> opt_dtype) {
166   return torch::linalg_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype);
167 }
168 
vector_norm(const Tensor & self,Scalar ord,OptionalIntArrayRef opt_dim,bool keepdim,std::optional<ScalarType> opt_dtype)169 inline Tensor vector_norm(
170     const Tensor& self,
171     Scalar ord,
172     OptionalIntArrayRef opt_dim,
173     bool keepdim,
174     std::optional<ScalarType> opt_dtype) {
175   return torch::linalg_vector_norm(self, ord, opt_dim, keepdim, opt_dtype);
176 }
177 
vector_norm_out(Tensor & result,const Tensor & self,Scalar ord,OptionalIntArrayRef opt_dim,bool keepdim,std::optional<ScalarType> opt_dtype)178 inline Tensor& vector_norm_out(
179     Tensor& result,
180     const Tensor& self,
181     Scalar ord,
182     OptionalIntArrayRef opt_dim,
183     bool keepdim,
184     std::optional<ScalarType> opt_dtype) {
185   return torch::linalg_vector_norm_out(
186       result, self, ord, opt_dim, keepdim, opt_dtype);
187 }
188 
matrix_norm(const Tensor & self,const Scalar & ord,IntArrayRef dim,bool keepdim,std::optional<ScalarType> dtype)189 inline Tensor matrix_norm(
190     const Tensor& self,
191     const Scalar& ord,
192     IntArrayRef dim,
193     bool keepdim,
194     std::optional<ScalarType> dtype) {
195   return torch::linalg_matrix_norm(self, ord, dim, keepdim, dtype);
196 }
197 
matrix_norm_out(const Tensor & self,const Scalar & ord,IntArrayRef dim,bool keepdim,std::optional<ScalarType> dtype,Tensor & result)198 inline Tensor& matrix_norm_out(
199     const Tensor& self,
200     const Scalar& ord,
201     IntArrayRef dim,
202     bool keepdim,
203     std::optional<ScalarType> dtype,
204     Tensor& result) {
205   return torch::linalg_matrix_norm_out(result, self, ord, dim, keepdim, dtype);
206 }
207 
matrix_norm(const Tensor & self,std::string ord,IntArrayRef dim,bool keepdim,std::optional<ScalarType> dtype)208 inline Tensor matrix_norm(
209     const Tensor& self,
210     std::string ord,
211     IntArrayRef dim,
212     bool keepdim,
213     std::optional<ScalarType> dtype) {
214   return torch::linalg_matrix_norm(self, ord, dim, keepdim, dtype);
215 }
216 
matrix_norm_out(const Tensor & self,std::string ord,IntArrayRef dim,bool keepdim,std::optional<ScalarType> dtype,Tensor & result)217 inline Tensor& matrix_norm_out(
218     const Tensor& self,
219     std::string ord,
220     IntArrayRef dim,
221     bool keepdim,
222     std::optional<ScalarType> dtype,
223     Tensor& result) {
224   return torch::linalg_matrix_norm_out(result, self, ord, dim, keepdim, dtype);
225 }
226 
matrix_power(const Tensor & self,int64_t n)227 inline Tensor matrix_power(const Tensor& self, int64_t n) {
228   return torch::linalg_matrix_power(self, n);
229 }
230 
matrix_power_out(const Tensor & self,int64_t n,Tensor & result)231 inline Tensor& matrix_power_out(const Tensor& self, int64_t n, Tensor& result) {
232   return torch::linalg_matrix_power_out(result, self, n);
233 }
234 
matrix_rank(const Tensor & input,double tol,bool hermitian)235 inline Tensor matrix_rank(const Tensor& input, double tol, bool hermitian) {
236   return torch::linalg_matrix_rank(input, tol, hermitian);
237 }
238 
matrix_rank(const Tensor & input,const Tensor & tol,bool hermitian)239 inline Tensor matrix_rank(
240     const Tensor& input,
241     const Tensor& tol,
242     bool hermitian) {
243   return torch::linalg_matrix_rank(input, tol, hermitian);
244 }
245 
matrix_rank(const Tensor & input,std::optional<double> atol,std::optional<double> rtol,bool hermitian)246 inline Tensor matrix_rank(
247     const Tensor& input,
248     std::optional<double> atol,
249     std::optional<double> rtol,
250     bool hermitian) {
251   return torch::linalg_matrix_rank(input, atol, rtol, hermitian);
252 }
253 
matrix_rank(const Tensor & input,const std::optional<Tensor> & atol,const std::optional<Tensor> & rtol,bool hermitian)254 inline Tensor matrix_rank(
255     const Tensor& input,
256     const std::optional<Tensor>& atol,
257     const std::optional<Tensor>& rtol,
258     bool hermitian) {
259   return torch::linalg_matrix_rank(input, atol, rtol, hermitian);
260 }
261 
matrix_rank_out(Tensor & result,const Tensor & input,double tol,bool hermitian)262 inline Tensor& matrix_rank_out(
263     Tensor& result,
264     const Tensor& input,
265     double tol,
266     bool hermitian) {
267   return torch::linalg_matrix_rank_out(result, input, tol, hermitian);
268 }
269 
matrix_rank_out(Tensor & result,const Tensor & input,const Tensor & tol,bool hermitian)270 inline Tensor& matrix_rank_out(
271     Tensor& result,
272     const Tensor& input,
273     const Tensor& tol,
274     bool hermitian) {
275   return torch::linalg_matrix_rank_out(result, input, tol, hermitian);
276 }
277 
matrix_rank_out(Tensor & result,const Tensor & input,std::optional<double> atol,std::optional<double> rtol,bool hermitian)278 inline Tensor& matrix_rank_out(
279     Tensor& result,
280     const Tensor& input,
281     std::optional<double> atol,
282     std::optional<double> rtol,
283     bool hermitian) {
284   return torch::linalg_matrix_rank_out(result, input, atol, rtol, hermitian);
285 }
286 
matrix_rank_out(Tensor & result,const Tensor & input,const std::optional<Tensor> & atol,const std::optional<Tensor> & rtol,bool hermitian)287 inline Tensor& matrix_rank_out(
288     Tensor& result,
289     const Tensor& input,
290     const std::optional<Tensor>& atol,
291     const std::optional<Tensor>& rtol,
292     bool hermitian) {
293   return torch::linalg_matrix_rank_out(result, input, atol, rtol, hermitian);
294 }
295 
multi_dot(TensorList tensors)296 inline Tensor multi_dot(TensorList tensors) {
297   return torch::linalg_multi_dot(tensors);
298 }
299 
multi_dot_out(TensorList tensors,Tensor & result)300 inline Tensor& multi_dot_out(TensorList tensors, Tensor& result) {
301   return torch::linalg_multi_dot_out(result, tensors);
302 }
303 
pinv(const Tensor & input,double rcond,bool hermitian)304 inline Tensor pinv(const Tensor& input, double rcond, bool hermitian) {
305   return torch::linalg_pinv(input, rcond, hermitian);
306 }
307 
pinv_out(Tensor & result,const Tensor & input,double rcond,bool hermitian)308 inline Tensor& pinv_out(
309     Tensor& result,
310     const Tensor& input,
311     double rcond,
312     bool hermitian) {
313   return torch::linalg_pinv_out(result, input, rcond, hermitian);
314 }
315 
qr(const Tensor & input,c10::string_view mode)316 inline std::tuple<Tensor, Tensor> qr(
317     const Tensor& input,
318     c10::string_view mode) {
319   return torch::linalg_qr(input, mode);
320 }
321 
qr_out(Tensor & Q,Tensor & R,const Tensor & input,c10::string_view mode)322 inline std::tuple<Tensor&, Tensor&> qr_out(
323     Tensor& Q,
324     Tensor& R,
325     const Tensor& input,
326     c10::string_view mode) {
327   return torch::linalg_qr_out(Q, R, input, mode);
328 }
329 
solve_ex(const Tensor & input,const Tensor & other,bool left,bool check_errors)330 inline std::tuple<Tensor, Tensor> solve_ex(
331     const Tensor& input,
332     const Tensor& other,
333     bool left,
334     bool check_errors) {
335   return torch::linalg_solve_ex(input, other, left, check_errors);
336 }
337 
solve_ex_out(Tensor & result,Tensor & info,const Tensor & input,const Tensor & other,bool left,bool check_errors)338 inline std::tuple<Tensor&, Tensor&> solve_ex_out(
339     Tensor& result,
340     Tensor& info,
341     const Tensor& input,
342     const Tensor& other,
343     bool left,
344     bool check_errors) {
345   return torch::linalg_solve_ex_out(
346       result, info, input, other, left, check_errors);
347 }
348 
solve(const Tensor & input,const Tensor & other,bool left)349 inline Tensor solve(const Tensor& input, const Tensor& other, bool left) {
350   return torch::linalg_solve(input, other, left);
351 }
352 
solve_out(Tensor & result,const Tensor & input,const Tensor & other,bool left)353 inline Tensor& solve_out(
354     Tensor& result,
355     const Tensor& input,
356     const Tensor& other,
357     bool left) {
358   return torch::linalg_solve_out(result, input, other, left);
359 }
360 
solve_triangular(const Tensor & input,const Tensor & other,bool upper,bool left,bool unitriangular)361 inline Tensor solve_triangular(
362     const Tensor& input,
363     const Tensor& other,
364     bool upper,
365     bool left,
366     bool unitriangular) {
367   return torch::linalg_solve_triangular(
368       input, other, upper, left, unitriangular);
369 }
370 
solve_triangular_out(Tensor & result,const Tensor & input,const Tensor & other,bool upper,bool left,bool unitriangular)371 inline Tensor& solve_triangular_out(
372     Tensor& result,
373     const Tensor& input,
374     const Tensor& other,
375     bool upper,
376     bool left,
377     bool unitriangular) {
378   return torch::linalg_solve_triangular_out(
379       result, input, other, upper, left, unitriangular);
380 }
381 
svd(const Tensor & input,bool full_matrices,std::optional<c10::string_view> driver)382 inline std::tuple<Tensor, Tensor, Tensor> svd(
383     const Tensor& input,
384     bool full_matrices,
385     std::optional<c10::string_view> driver) {
386   return torch::linalg_svd(input, full_matrices, driver);
387 }
388 
svd_out(Tensor & U,Tensor & S,Tensor & Vh,const Tensor & input,bool full_matrices,std::optional<c10::string_view> driver)389 inline std::tuple<Tensor&, Tensor&, Tensor&> svd_out(
390     Tensor& U,
391     Tensor& S,
392     Tensor& Vh,
393     const Tensor& input,
394     bool full_matrices,
395     std::optional<c10::string_view> driver) {
396   return torch::linalg_svd_out(U, S, Vh, input, full_matrices, driver);
397 }
398 
svdvals(const Tensor & input,std::optional<c10::string_view> driver)399 inline Tensor svdvals(
400     const Tensor& input,
401     std::optional<c10::string_view> driver) {
402   return torch::linalg_svdvals(input, driver);
403 }
404 
svdvals_out(Tensor & result,const Tensor & input,std::optional<c10::string_view> driver)405 inline Tensor& svdvals_out(
406     Tensor& result,
407     const Tensor& input,
408     std::optional<c10::string_view> driver) {
409   return torch::linalg_svdvals_out(result, input, driver);
410 }
411 
tensorinv(const Tensor & self,int64_t ind)412 inline Tensor tensorinv(const Tensor& self, int64_t ind) {
413   return torch::linalg_tensorinv(self, ind);
414 }
415 
tensorinv_out(Tensor & result,const Tensor & self,int64_t ind)416 inline Tensor& tensorinv_out(Tensor& result, const Tensor& self, int64_t ind) {
417   return torch::linalg_tensorinv_out(result, self, ind);
418 }
419 
tensorsolve(const Tensor & self,const Tensor & other,OptionalIntArrayRef dims)420 inline Tensor tensorsolve(
421     const Tensor& self,
422     const Tensor& other,
423     OptionalIntArrayRef dims) {
424   return torch::linalg_tensorsolve(self, other, dims);
425 }
426 
tensorsolve_out(Tensor & result,const Tensor & self,const Tensor & other,OptionalIntArrayRef dims)427 inline Tensor& tensorsolve_out(
428     Tensor& result,
429     const Tensor& self,
430     const Tensor& other,
431     OptionalIntArrayRef dims) {
432   return torch::linalg_tensorsolve_out(result, self, other, dims);
433 }
434 
inv(const Tensor & input)435 inline Tensor inv(const Tensor& input) {
436   return torch::linalg_inv(input);
437 }
438 
inv_out(Tensor & result,const Tensor & input)439 inline Tensor& inv_out(Tensor& result, const Tensor& input) {
440   return torch::linalg_inv_out(result, input);
441 }
442 
443 } // namespace detail
444 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
445 
446 /// Cholesky decomposition
447 ///
448 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.cholesky
449 ///
450 /// Example:
451 /// ```
452 /// auto A = torch::randn({4, 4});
453 /// auto A = torch::matmul(A, A.t());
454 /// auto L = torch::linalg::cholesky(A);
455 /// assert(torch::allclose(torch::matmul(L, L.t()), A));
456 /// ```
cholesky(const Tensor & self)457 inline Tensor cholesky(const Tensor& self) {
458   return detail::cholesky(self);
459 }
460 
cholesky_out(Tensor & result,const Tensor & self)461 inline Tensor cholesky_out(Tensor& result, const Tensor& self) {
462   return detail::cholesky_out(result, self);
463 }
464 
465 // C10_DEPRECATED_MESSAGE("linalg_det is deprecated, use det instead.")
linalg_det(const Tensor & self)466 inline Tensor linalg_det(const Tensor& self) {
467   return detail::det(self);
468 }
469 
470 /// See the documentation of torch.linalg.det
det(const Tensor & self)471 inline Tensor det(const Tensor& self) {
472   return detail::det(self);
473 }
474 
475 /// Computes the sign and (natural) logarithm of the determinant
476 ///
477 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.slogdet
slogdet(const Tensor & input)478 inline std::tuple<Tensor, Tensor> slogdet(const Tensor& input) {
479   return detail::slogdet(input);
480 }
481 
slogdet_out(Tensor & sign,Tensor & logabsdet,const Tensor & input)482 inline std::tuple<Tensor&, Tensor&> slogdet_out(
483     Tensor& sign,
484     Tensor& logabsdet,
485     const Tensor& input) {
486   return detail::slogdet_out(sign, logabsdet, input);
487 }
488 
489 /// Computes eigenvalues and eigenvectors of non-symmetric/non-hermitian
490 /// matrices
491 ///
492 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.eig
eig(const Tensor & self)493 inline std::tuple<Tensor, Tensor> eig(const Tensor& self) {
494   return detail::eig(self);
495 }
496 
eig_out(Tensor & eigvals,Tensor & eigvecs,const Tensor & self)497 inline std::tuple<Tensor&, Tensor&> eig_out(
498     Tensor& eigvals,
499     Tensor& eigvecs,
500     const Tensor& self) {
501   return detail::eig_out(eigvals, eigvecs, self);
502 }
503 
504 /// Computes eigenvalues of non-symmetric/non-hermitian matrices
505 ///
506 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.eigvals
eigvals(const Tensor & self)507 inline Tensor eigvals(const Tensor& self) {
508   return detail::eigvals(self);
509 }
510 
eigvals_out(Tensor & result,const Tensor & self)511 inline Tensor& eigvals_out(Tensor& result, const Tensor& self) {
512   return detail::eigvals_out(result, self);
513 }
514 
515 /// Computes eigenvalues and eigenvectors
516 ///
517 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.eigh
eigh(const Tensor & self,c10::string_view uplo)518 inline std::tuple<Tensor, Tensor> eigh(
519     const Tensor& self,
520     c10::string_view uplo) {
521   return detail::eigh(self, uplo);
522 }
523 
eigh_out(Tensor & eigvals,Tensor & eigvecs,const Tensor & self,c10::string_view uplo)524 inline std::tuple<Tensor&, Tensor&> eigh_out(
525     Tensor& eigvals,
526     Tensor& eigvecs,
527     const Tensor& self,
528     c10::string_view uplo) {
529   return detail::eigh_out(eigvals, eigvecs, self, uplo);
530 }
531 
532 /// Computes eigenvalues
533 ///
534 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.eigvalsh
eigvalsh(const Tensor & self,c10::string_view uplo)535 inline Tensor eigvalsh(const Tensor& self, c10::string_view uplo) {
536   return detail::eigvalsh(self, uplo);
537 }
538 
eigvalsh_out(Tensor & result,const Tensor & self,c10::string_view uplo)539 inline Tensor& eigvalsh_out(
540     Tensor& result,
541     const Tensor& self,
542     c10::string_view uplo) {
543   return detail::eigvalsh_out(result, self, uplo);
544 }
545 
546 /// Computes the product of Householder matrices
547 ///
548 /// See
549 /// https://pytorch.org/docs/main/linalg.html#torch.linalg.householder_product
householder_product(const Tensor & input,const Tensor & tau)550 inline Tensor householder_product(const Tensor& input, const Tensor& tau) {
551   return detail::householder_product(input, tau);
552 }
553 
householder_product_out(Tensor & result,const Tensor & input,const Tensor & tau)554 inline Tensor& householder_product_out(
555     Tensor& result,
556     const Tensor& input,
557     const Tensor& tau) {
558   return detail::householder_product_out(result, input, tau);
559 }
560 
lstsq(const Tensor & self,const Tensor & b,std::optional<double> cond,std::optional<c10::string_view> driver)561 inline std::tuple<Tensor, Tensor, Tensor, Tensor> lstsq(
562     const Tensor& self,
563     const Tensor& b,
564     std::optional<double> cond,
565     std::optional<c10::string_view> driver) {
566   return detail::lstsq(self, b, cond, driver);
567 }
568 
569 /// Computes the matrix exponential
570 ///
571 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.matrix_exp
matrix_exp(const Tensor & input)572 inline Tensor matrix_exp(const Tensor& input) {
573   return detail::matrix_exp(input);
574 }
575 
576 // C10_DEPRECATED_MESSAGE("linalg_norm is deprecated, use norm instead.")
linalg_norm(const Tensor & self,const std::optional<Scalar> & opt_ord,OptionalIntArrayRef opt_dim,bool keepdim,std::optional<ScalarType> opt_dtype)577 inline Tensor linalg_norm(
578     const Tensor& self,
579     const std::optional<Scalar>& opt_ord,
580     OptionalIntArrayRef opt_dim,
581     bool keepdim,
582     std::optional<ScalarType> opt_dtype) {
583   return detail::norm(self, opt_ord, opt_dim, keepdim, opt_dtype);
584 }
585 
586 // C10_DEPRECATED_MESSAGE("linalg_norm is deprecated, use norm instead.")
linalg_norm(const Tensor & self,c10::string_view ord,OptionalIntArrayRef opt_dim,bool keepdim,std::optional<ScalarType> opt_dtype)587 inline Tensor linalg_norm(
588     const Tensor& self,
589     c10::string_view ord,
590     OptionalIntArrayRef opt_dim,
591     bool keepdim,
592     std::optional<ScalarType> opt_dtype) {
593   return detail::norm(self, ord, opt_dim, keepdim, opt_dtype);
594 }
595 
596 // C10_DEPRECATED_MESSAGE("linalg_norm_out is deprecated, use norm_out
597 // instead.")
linalg_norm_out(Tensor & result,const Tensor & self,const std::optional<Scalar> & opt_ord,OptionalIntArrayRef opt_dim,bool keepdim,std::optional<ScalarType> opt_dtype)598 inline Tensor& linalg_norm_out(
599     Tensor& result,
600     const Tensor& self,
601     const std::optional<Scalar>& opt_ord,
602     OptionalIntArrayRef opt_dim,
603     bool keepdim,
604     std::optional<ScalarType> opt_dtype) {
605   return detail::norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype);
606 }
607 
608 // C10_DEPRECATED_MESSAGE("linalg_norm_out is deprecated, use norm_out
609 // instead.")
linalg_norm_out(Tensor & result,const Tensor & self,c10::string_view ord,OptionalIntArrayRef opt_dim,bool keepdim,std::optional<ScalarType> opt_dtype)610 inline Tensor& linalg_norm_out(
611     Tensor& result,
612     const Tensor& self,
613     c10::string_view ord,
614     OptionalIntArrayRef opt_dim,
615     bool keepdim,
616     std::optional<ScalarType> opt_dtype) {
617   return detail::norm_out(result, self, ord, opt_dim, keepdim, opt_dtype);
618 }
619 
620 /// Computes the LU factorization with partial pivoting
621 ///
622 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.lu_factor
623 inline std::tuple<Tensor, Tensor> lu_factor(
624     const Tensor& input,
625     const bool pivot = true) {
626   return detail::lu_factor(input, pivot);
627 }
628 
629 inline std::tuple<Tensor&, Tensor&> lu_factor_out(
630     Tensor& LU,
631     Tensor& pivots,
632     const Tensor& self,
633     const bool pivot = true) {
634   return detail::lu_factor_out(LU, pivots, self, pivot);
635 }
636 
637 /// Computes the LU factorization with partial pivoting
638 ///
639 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.lu
640 inline std::tuple<Tensor, Tensor, Tensor> lu(
641     const Tensor& input,
642     const bool pivot = true) {
643   return detail::lu(input, pivot);
644 }
645 
646 inline std::tuple<Tensor&, Tensor&, Tensor&> lu_out(
647     Tensor& P,
648     Tensor& L,
649     Tensor& U,
650     const Tensor& self,
651     const bool pivot = true) {
652   return detail::lu_out(P, L, U, self, pivot);
653 }
654 
/// Vector or matrix norm with a numeric (optional) order `opt_ord`.
///
/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.norm
inline Tensor norm(
    const Tensor& self,
    const std::optional<Scalar>& opt_ord,
    OptionalIntArrayRef opt_dim,
    bool keepdim,
    std::optional<ScalarType> opt_dtype) {
  return ::torch::linalg::detail::norm(
      self, opt_ord, opt_dim, keepdim, opt_dtype);
}

/// Norm whose order is given as a string. `ord` is taken by value and viewed
/// (not copied again) by the detail helper for the duration of the call.
inline Tensor norm(
    const Tensor& self,
    std::string ord,
    OptionalIntArrayRef opt_dim,
    bool keepdim,
    std::optional<ScalarType> opt_dtype) {
  return ::torch::linalg::detail::norm(self, ord, opt_dim, keepdim, opt_dtype);
}

/// Out-variant of the numeric-order norm(); returns the reference produced
/// by torch::linalg_norm_out.
inline Tensor& norm_out(
    Tensor& result,
    const Tensor& self,
    const std::optional<Scalar>& opt_ord,
    OptionalIntArrayRef opt_dim,
    bool keepdim,
    std::optional<ScalarType> opt_dtype) {
  return ::torch::linalg::detail::norm_out(
      result, self, opt_ord, opt_dim, keepdim, opt_dtype);
}

/// Out-variant of the string-order norm().
inline Tensor& norm_out(
    Tensor& result,
    const Tensor& self,
    std::string ord,
    OptionalIntArrayRef opt_dim,
    bool keepdim,
    std::optional<ScalarType> opt_dtype) {
  return ::torch::linalg::detail::norm_out(
      result, self, ord, opt_dim, keepdim, opt_dtype);
}
692 
693 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.vector_norm
vector_norm(const Tensor & self,Scalar ord,OptionalIntArrayRef opt_dim,bool keepdim,std::optional<ScalarType> opt_dtype)694 inline Tensor vector_norm(
695     const Tensor& self,
696     Scalar ord,
697     OptionalIntArrayRef opt_dim,
698     bool keepdim,
699     std::optional<ScalarType> opt_dtype) {
700   return detail::vector_norm(self, ord, opt_dim, keepdim, opt_dtype);
701 }
702 
vector_norm_out(Tensor & result,const Tensor & self,Scalar ord,OptionalIntArrayRef opt_dim,bool keepdim,std::optional<ScalarType> opt_dtype)703 inline Tensor& vector_norm_out(
704     Tensor& result,
705     const Tensor& self,
706     Scalar ord,
707     OptionalIntArrayRef opt_dim,
708     bool keepdim,
709     std::optional<ScalarType> opt_dtype) {
710   return detail::vector_norm_out(
711       result, self, ord, opt_dim, keepdim, opt_dtype);
712 }
713 
714 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.matrix_norm
matrix_norm(const Tensor & self,const Scalar & ord,IntArrayRef dim,bool keepdim,std::optional<ScalarType> dtype)715 inline Tensor matrix_norm(
716     const Tensor& self,
717     const Scalar& ord,
718     IntArrayRef dim,
719     bool keepdim,
720     std::optional<ScalarType> dtype) {
721   return detail::matrix_norm(self, ord, dim, keepdim, dtype);
722 }
723 
matrix_norm_out(const Tensor & self,const Scalar & ord,IntArrayRef dim,bool keepdim,std::optional<ScalarType> dtype,Tensor & result)724 inline Tensor& matrix_norm_out(
725     const Tensor& self,
726     const Scalar& ord,
727     IntArrayRef dim,
728     bool keepdim,
729     std::optional<ScalarType> dtype,
730     Tensor& result) {
731   return detail::matrix_norm_out(self, ord, dim, keepdim, dtype, result);
732 }
733 
matrix_norm(const Tensor & self,std::string ord,IntArrayRef dim,bool keepdim,std::optional<ScalarType> dtype)734 inline Tensor matrix_norm(
735     const Tensor& self,
736     std::string ord,
737     IntArrayRef dim,
738     bool keepdim,
739     std::optional<ScalarType> dtype) {
740   return detail::matrix_norm(self, ord, dim, keepdim, dtype);
741 }
742 
matrix_norm_out(const Tensor & self,std::string ord,IntArrayRef dim,bool keepdim,std::optional<ScalarType> dtype,Tensor & result)743 inline Tensor& matrix_norm_out(
744     const Tensor& self,
745     std::string ord,
746     IntArrayRef dim,
747     bool keepdim,
748     std::optional<ScalarType> dtype,
749     Tensor& result) {
750   return detail::matrix_norm_out(self, ord, dim, keepdim, dtype, result);
751 }
752 
753 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.matrix_power
matrix_power(const Tensor & self,int64_t n)754 inline Tensor matrix_power(const Tensor& self, int64_t n) {
755   return detail::matrix_power(self, n);
756 }
757 
matrix_power_out(const Tensor & self,int64_t n,Tensor & result)758 inline Tensor& matrix_power_out(const Tensor& self, int64_t n, Tensor& result) {
759   return detail::matrix_power_out(self, n, result);
760 }
761 
762 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.matrix_rank
matrix_rank(const Tensor & input,double tol,bool hermitian)763 inline Tensor matrix_rank(const Tensor& input, double tol, bool hermitian) {
764   return detail::matrix_rank(input, tol, hermitian);
765 }
766 
matrix_rank(const Tensor & input,const Tensor & tol,bool hermitian)767 inline Tensor matrix_rank(
768     const Tensor& input,
769     const Tensor& tol,
770     bool hermitian) {
771   return detail::matrix_rank(input, tol, hermitian);
772 }
773 
matrix_rank(const Tensor & input,std::optional<double> atol,std::optional<double> rtol,bool hermitian)774 inline Tensor matrix_rank(
775     const Tensor& input,
776     std::optional<double> atol,
777     std::optional<double> rtol,
778     bool hermitian) {
779   return detail::matrix_rank(input, atol, rtol, hermitian);
780 }
781 
/// `matrix_rank` overload taking optional tensor tolerances `atol` / `rtol`;
/// forwards to `detail::matrix_rank`.
inline auto matrix_rank(
    const Tensor& input,
    const std::optional<Tensor>& atol,
    const std::optional<Tensor>& rtol,
    bool hermitian) -> Tensor {
  return detail::matrix_rank(input, atol, rtol, hermitian);
}
789 
matrix_rank_out(Tensor & result,const Tensor & input,double tol,bool hermitian)790 inline Tensor& matrix_rank_out(
791     Tensor& result,
792     const Tensor& input,
793     double tol,
794     bool hermitian) {
795   return detail::matrix_rank_out(result, input, tol, hermitian);
796 }
797 
matrix_rank_out(Tensor & result,const Tensor & input,const Tensor & tol,bool hermitian)798 inline Tensor& matrix_rank_out(
799     Tensor& result,
800     const Tensor& input,
801     const Tensor& tol,
802     bool hermitian) {
803   return detail::matrix_rank_out(result, input, tol, hermitian);
804 }
805 
matrix_rank_out(Tensor & result,const Tensor & input,std::optional<double> atol,std::optional<double> rtol,bool hermitian)806 inline Tensor& matrix_rank_out(
807     Tensor& result,
808     const Tensor& input,
809     std::optional<double> atol,
810     std::optional<double> rtol,
811     bool hermitian) {
812   return detail::matrix_rank_out(result, input, atol, rtol, hermitian);
813 }
814 
matrix_rank_out(Tensor & result,const Tensor & input,const std::optional<Tensor> & atol,const std::optional<Tensor> & rtol,bool hermitian)815 inline Tensor& matrix_rank_out(
816     Tensor& result,
817     const Tensor& input,
818     const std::optional<Tensor>& atol,
819     const std::optional<Tensor>& rtol,
820     bool hermitian) {
821   return detail::matrix_rank_out(result, input, atol, rtol, hermitian);
822 }
823 
824 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.multi_dot
multi_dot(TensorList tensors)825 inline Tensor multi_dot(TensorList tensors) {
826   return detail::multi_dot(tensors);
827 }
828 
multi_dot_out(TensorList tensors,Tensor & result)829 inline Tensor& multi_dot_out(TensorList tensors, Tensor& result) {
830   return detail::multi_dot_out(tensors, result);
831 }
832 
833 /// Computes the pseudo-inverse
834 ///
835 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.pinv
836 inline Tensor pinv(
837     const Tensor& input,
838     double rcond = 1e-15,
839     bool hermitian = false) {
840   return detail::pinv(input, rcond, hermitian);
841 }
842 
843 inline Tensor& pinv_out(
844     Tensor& result,
845     const Tensor& input,
846     double rcond = 1e-15,
847     bool hermitian = false) {
848   return detail::pinv_out(result, input, rcond, hermitian);
849 }
850 
851 /// Computes the QR decomposition
852 ///
853 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.qr
854 inline std::tuple<Tensor, Tensor> qr(
855     const Tensor& input,
856     c10::string_view mode = "reduced") {
857   // C++17 Change the initialisation to "reduced"sv
858   //       Same for qr_out
859   return detail::qr(input, mode);
860 }
861 
862 inline std::tuple<Tensor&, Tensor&> qr_out(
863     Tensor& Q,
864     Tensor& R,
865     const Tensor& input,
866     c10::string_view mode = "reduced") {
867   return detail::qr_out(Q, R, input, mode);
868 }
869 
870 /// Computes the LDL decomposition
871 ///
872 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.ldl_factor_ex
ldl_factor_ex(const Tensor & input,bool hermitian,bool check_errors)873 inline std::tuple<Tensor, Tensor, Tensor> ldl_factor_ex(
874     const Tensor& input,
875     bool hermitian,
876     bool check_errors) {
877   return torch::linalg_ldl_factor_ex(input, hermitian, check_errors);
878 }
879 
ldl_factor_ex_out(Tensor & LD,Tensor & pivots,Tensor & info,const Tensor & input,bool hermitian,bool check_errors)880 inline std::tuple<Tensor&, Tensor&, Tensor&> ldl_factor_ex_out(
881     Tensor& LD,
882     Tensor& pivots,
883     Tensor& info,
884     const Tensor& input,
885     bool hermitian,
886     bool check_errors) {
887   return torch::linalg_ldl_factor_ex_out(
888       LD, pivots, info, input, hermitian, check_errors);
889 }
890 
891 /// Solve a system of linear equations using the LDL decomposition
892 ///
893 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.ldl_solve
ldl_solve(const Tensor & LD,const Tensor & pivots,const Tensor & B,bool hermitian)894 inline Tensor ldl_solve(
895     const Tensor& LD,
896     const Tensor& pivots,
897     const Tensor& B,
898     bool hermitian) {
899   return torch::linalg_ldl_solve(LD, pivots, B, hermitian);
900 }
901 
ldl_solve_out(Tensor & result,const Tensor & LD,const Tensor & pivots,const Tensor & B,bool hermitian)902 inline Tensor& ldl_solve_out(
903     Tensor& result,
904     const Tensor& LD,
905     const Tensor& pivots,
906     const Tensor& B,
907     bool hermitian) {
908   return torch::linalg_ldl_solve_out(result, LD, pivots, B, hermitian);
909 }
910 
911 /// Solves a system linear system AX = B
912 ///
913 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.solve_ex
solve_ex(const Tensor & input,const Tensor & other,bool left,bool check_errors)914 inline std::tuple<Tensor, Tensor> solve_ex(
915     const Tensor& input,
916     const Tensor& other,
917     bool left,
918     bool check_errors) {
919   return detail::solve_ex(input, other, left, check_errors);
920 }
921 
solve_ex_out(Tensor & result,Tensor & info,const Tensor & input,const Tensor & other,bool left,bool check_errors)922 inline std::tuple<Tensor&, Tensor&> solve_ex_out(
923     Tensor& result,
924     Tensor& info,
925     const Tensor& input,
926     const Tensor& other,
927     bool left,
928     bool check_errors) {
929   return detail::solve_ex_out(result, info, input, other, left, check_errors);
930 }
931 
932 /// Computes a tensor `x` such that `matmul(input, x) = other`.
933 ///
934 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.solve
solve(const Tensor & input,const Tensor & other,bool left)935 inline Tensor solve(const Tensor& input, const Tensor& other, bool left) {
936   return detail::solve(input, other, left);
937 }
938 
solve_out(Tensor & result,const Tensor & input,const Tensor & other,bool left)939 inline Tensor& solve_out(
940     Tensor& result,
941     const Tensor& input,
942     const Tensor& other,
943     bool left) {
944   return detail::solve_out(result, input, other, left);
945 }
946 
947 /// Computes a solution of a linear system AX = B for input = A and other = B
948 /// whenever A is square upper or lower triangular and does not have zeros in
949 /// the diagonal
950 ///
951 /// See
952 /// https://pytorch.org/docs/main/linalg.html#torch.linalg.solve_triangular
solve_triangular(const Tensor & input,const Tensor & other,bool upper,bool left,bool unitriangular)953 inline Tensor solve_triangular(
954     const Tensor& input,
955     const Tensor& other,
956     bool upper,
957     bool left,
958     bool unitriangular) {
959   return detail::solve_triangular(input, other, upper, left, unitriangular);
960 }
961 
solve_triangular_out(Tensor & result,const Tensor & input,const Tensor & other,bool upper,bool left,bool unitriangular)962 inline Tensor& solve_triangular_out(
963     Tensor& result,
964     const Tensor& input,
965     const Tensor& other,
966     bool upper,
967     bool left,
968     bool unitriangular) {
969   return detail::solve_triangular_out(
970       result, input, other, upper, left, unitriangular);
971 }
972 
973 /// Computes the singular values and singular vectors
974 ///
975 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.svd
svd(const Tensor & input,bool full_matrices,std::optional<c10::string_view> driver)976 inline std::tuple<Tensor, Tensor, Tensor> svd(
977     const Tensor& input,
978     bool full_matrices,
979     std::optional<c10::string_view> driver) {
980   return detail::svd(input, full_matrices, driver);
981 }
982 
svd_out(Tensor & U,Tensor & S,Tensor & Vh,const Tensor & input,bool full_matrices,std::optional<c10::string_view> driver)983 inline std::tuple<Tensor&, Tensor&, Tensor&> svd_out(
984     Tensor& U,
985     Tensor& S,
986     Tensor& Vh,
987     const Tensor& input,
988     bool full_matrices,
989     std::optional<c10::string_view> driver) {
990   return detail::svd_out(U, S, Vh, input, full_matrices, driver);
991 }
992 
993 /// Computes the singular values
994 ///
995 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.svdvals
svdvals(const Tensor & input,std::optional<c10::string_view> driver)996 inline Tensor svdvals(
997     const Tensor& input,
998     std::optional<c10::string_view> driver) {
999   return detail::svdvals(input, driver);
1000 }
1001 
svdvals_out(Tensor & result,const Tensor & input,std::optional<c10::string_view> driver)1002 inline Tensor& svdvals_out(
1003     Tensor& result,
1004     const Tensor& input,
1005     std::optional<c10::string_view> driver) {
1006   return detail::svdvals_out(result, input, driver);
1007 }
1008 
1009 /// Computes the inverse of a tensor
1010 ///
1011 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.tensorinv
1012 ///
1013 /// Example:
1014 /// ```
1015 /// auto a = torch::eye(4*6).reshape({4, 6, 8, 3});
1016 /// int64_t ind = 2;
1017 /// auto ainv = torch::linalg::tensorinv(a, ind);
1018 /// ```
tensorinv(const Tensor & self,int64_t ind)1019 inline Tensor tensorinv(const Tensor& self, int64_t ind) {
1020   return detail::tensorinv(self, ind);
1021 }
1022 
tensorinv_out(Tensor & result,const Tensor & self,int64_t ind)1023 inline Tensor& tensorinv_out(Tensor& result, const Tensor& self, int64_t ind) {
1024   return detail::tensorinv_out(result, self, ind);
1025 }
1026 
1027 /// Computes a tensor `x` such that `tensordot(input, x, dims=x.dim()) = other`.
1028 ///
1029 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.tensorsolve
1030 ///
1031 /// Example:
1032 /// ```
1033 /// auto a = torch::eye(2*3*4).reshape({2*3, 4, 2, 3, 4});
1034 /// auto b = torch::randn(2*3, 4);
1035 /// auto x = torch::linalg::tensorsolve(a, b);
1036 /// ```
tensorsolve(const Tensor & input,const Tensor & other,OptionalIntArrayRef dims)1037 inline Tensor tensorsolve(
1038     const Tensor& input,
1039     const Tensor& other,
1040     OptionalIntArrayRef dims) {
1041   return detail::tensorsolve(input, other, dims);
1042 }
1043 
tensorsolve_out(Tensor & result,const Tensor & input,const Tensor & other,OptionalIntArrayRef dims)1044 inline Tensor& tensorsolve_out(
1045     Tensor& result,
1046     const Tensor& input,
1047     const Tensor& other,
1048     OptionalIntArrayRef dims) {
1049   return detail::tensorsolve_out(result, input, other, dims);
1050 }
1051 
1052 /// Computes a tensor `inverse_input` such that `dot(input, inverse_input) =
1053 /// eye(input.size(0))`.
1054 ///
1055 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.inv
inv(const Tensor & input)1056 inline Tensor inv(const Tensor& input) {
1057   return detail::inv(input);
1058 }
1059 
inv_out(Tensor & result,const Tensor & input)1060 inline Tensor& inv_out(Tensor& result, const Tensor& input) {
1061   return detail::inv_out(result, input);
1062 }
1063 
1064 } // namespace linalg
1065 } // namespace torch
1066