1 #pragma once
2
3 #include <ATen/ATen.h>
4
5 namespace torch {
6 namespace linalg {
7
8 #ifndef DOXYGEN_SHOULD_SKIP_THIS
9 namespace detail {
10
cholesky(const Tensor & self)11 inline Tensor cholesky(const Tensor& self) {
12 return torch::linalg_cholesky(self);
13 }
14
cholesky_out(Tensor & result,const Tensor & self)15 inline Tensor cholesky_out(Tensor& result, const Tensor& self) {
16 return torch::linalg_cholesky_out(result, self);
17 }
18
det(const Tensor & self)19 inline Tensor det(const Tensor& self) {
20 return torch::linalg_det(self);
21 }
22
slogdet(const Tensor & input)23 inline std::tuple<Tensor, Tensor> slogdet(const Tensor& input) {
24 return torch::linalg_slogdet(input);
25 }
26
slogdet_out(Tensor & sign,Tensor & logabsdet,const Tensor & input)27 inline std::tuple<Tensor&, Tensor&> slogdet_out(
28 Tensor& sign,
29 Tensor& logabsdet,
30 const Tensor& input) {
31 return torch::linalg_slogdet_out(sign, logabsdet, input);
32 }
33
eig(const Tensor & self)34 inline std::tuple<Tensor, Tensor> eig(const Tensor& self) {
35 return torch::linalg_eig(self);
36 }
37
eig_out(Tensor & eigvals,Tensor & eigvecs,const Tensor & self)38 inline std::tuple<Tensor&, Tensor&> eig_out(
39 Tensor& eigvals,
40 Tensor& eigvecs,
41 const Tensor& self) {
42 return torch::linalg_eig_out(eigvals, eigvecs, self);
43 }
44
eigvals(const Tensor & self)45 inline Tensor eigvals(const Tensor& self) {
46 return torch::linalg_eigvals(self);
47 }
48
eigvals_out(Tensor & result,const Tensor & self)49 inline Tensor& eigvals_out(Tensor& result, const Tensor& self) {
50 return torch::linalg_eigvals_out(result, self);
51 }
52
eigh(const Tensor & self,c10::string_view uplo)53 inline std::tuple<Tensor, Tensor> eigh(
54 const Tensor& self,
55 c10::string_view uplo) {
56 return torch::linalg_eigh(self, uplo);
57 }
58
eigh_out(Tensor & eigvals,Tensor & eigvecs,const Tensor & self,c10::string_view uplo)59 inline std::tuple<Tensor&, Tensor&> eigh_out(
60 Tensor& eigvals,
61 Tensor& eigvecs,
62 const Tensor& self,
63 c10::string_view uplo) {
64 return torch::linalg_eigh_out(eigvals, eigvecs, self, uplo);
65 }
66
eigvalsh(const Tensor & self,c10::string_view uplo)67 inline Tensor eigvalsh(const Tensor& self, c10::string_view uplo) {
68 return torch::linalg_eigvalsh(self, uplo);
69 }
70
eigvalsh_out(Tensor & result,const Tensor & self,c10::string_view uplo)71 inline Tensor& eigvalsh_out(
72 Tensor& result,
73 const Tensor& self,
74 c10::string_view uplo) {
75 return torch::linalg_eigvalsh_out(result, self, uplo);
76 }
77
householder_product(const Tensor & input,const Tensor & tau)78 inline Tensor householder_product(const Tensor& input, const Tensor& tau) {
79 return torch::linalg_householder_product(input, tau);
80 }
81
householder_product_out(Tensor & result,const Tensor & input,const Tensor & tau)82 inline Tensor& householder_product_out(
83 Tensor& result,
84 const Tensor& input,
85 const Tensor& tau) {
86 return torch::linalg_householder_product_out(result, input, tau);
87 }
88
lu_factor(const Tensor & self,const bool pivot)89 inline std::tuple<Tensor, Tensor> lu_factor(
90 const Tensor& self,
91 const bool pivot) {
92 return torch::linalg_lu_factor(self, pivot);
93 }
94
lu_factor_out(Tensor & LU,Tensor & pivots,const Tensor & self,const bool pivot)95 inline std::tuple<Tensor&, Tensor&> lu_factor_out(
96 Tensor& LU,
97 Tensor& pivots,
98 const Tensor& self,
99 const bool pivot) {
100 return torch::linalg_lu_factor_out(LU, pivots, self, pivot);
101 }
102
lu(const Tensor & self,const bool pivot)103 inline std::tuple<Tensor, Tensor, Tensor> lu(
104 const Tensor& self,
105 const bool pivot) {
106 return torch::linalg_lu(self, pivot);
107 }
108
lu_out(Tensor & P,Tensor & L,Tensor & U,const Tensor & self,const bool pivot)109 inline std::tuple<Tensor&, Tensor&, Tensor&> lu_out(
110 Tensor& P,
111 Tensor& L,
112 Tensor& U,
113 const Tensor& self,
114 const bool pivot) {
115 return torch::linalg_lu_out(P, L, U, self, pivot);
116 }
117
lstsq(const Tensor & self,const Tensor & b,std::optional<double> cond,std::optional<c10::string_view> driver)118 inline std::tuple<Tensor, Tensor, Tensor, Tensor> lstsq(
119 const Tensor& self,
120 const Tensor& b,
121 std::optional<double> cond,
122 std::optional<c10::string_view> driver) {
123 return torch::linalg_lstsq(self, b, cond, driver);
124 }
125
matrix_exp(const Tensor & self)126 inline Tensor matrix_exp(const Tensor& self) {
127 return torch::linalg_matrix_exp(self);
128 }
129
norm(const Tensor & self,const std::optional<Scalar> & opt_ord,OptionalIntArrayRef opt_dim,bool keepdim,std::optional<ScalarType> opt_dtype)130 inline Tensor norm(
131 const Tensor& self,
132 const std::optional<Scalar>& opt_ord,
133 OptionalIntArrayRef opt_dim,
134 bool keepdim,
135 std::optional<ScalarType> opt_dtype) {
136 return torch::linalg_norm(self, opt_ord, opt_dim, keepdim, opt_dtype);
137 }
138
norm(const Tensor & self,c10::string_view ord,OptionalIntArrayRef opt_dim,bool keepdim,std::optional<ScalarType> opt_dtype)139 inline Tensor norm(
140 const Tensor& self,
141 c10::string_view ord,
142 OptionalIntArrayRef opt_dim,
143 bool keepdim,
144 std::optional<ScalarType> opt_dtype) {
145 return torch::linalg_norm(self, ord, opt_dim, keepdim, opt_dtype);
146 }
147
norm_out(Tensor & result,const Tensor & self,const std::optional<Scalar> & opt_ord,OptionalIntArrayRef opt_dim,bool keepdim,std::optional<ScalarType> opt_dtype)148 inline Tensor& norm_out(
149 Tensor& result,
150 const Tensor& self,
151 const std::optional<Scalar>& opt_ord,
152 OptionalIntArrayRef opt_dim,
153 bool keepdim,
154 std::optional<ScalarType> opt_dtype) {
155 return torch::linalg_norm_out(
156 result, self, opt_ord, opt_dim, keepdim, opt_dtype);
157 }
158
norm_out(Tensor & result,const Tensor & self,c10::string_view ord,OptionalIntArrayRef opt_dim,bool keepdim,std::optional<ScalarType> opt_dtype)159 inline Tensor& norm_out(
160 Tensor& result,
161 const Tensor& self,
162 c10::string_view ord,
163 OptionalIntArrayRef opt_dim,
164 bool keepdim,
165 std::optional<ScalarType> opt_dtype) {
166 return torch::linalg_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype);
167 }
168
vector_norm(const Tensor & self,Scalar ord,OptionalIntArrayRef opt_dim,bool keepdim,std::optional<ScalarType> opt_dtype)169 inline Tensor vector_norm(
170 const Tensor& self,
171 Scalar ord,
172 OptionalIntArrayRef opt_dim,
173 bool keepdim,
174 std::optional<ScalarType> opt_dtype) {
175 return torch::linalg_vector_norm(self, ord, opt_dim, keepdim, opt_dtype);
176 }
177
vector_norm_out(Tensor & result,const Tensor & self,Scalar ord,OptionalIntArrayRef opt_dim,bool keepdim,std::optional<ScalarType> opt_dtype)178 inline Tensor& vector_norm_out(
179 Tensor& result,
180 const Tensor& self,
181 Scalar ord,
182 OptionalIntArrayRef opt_dim,
183 bool keepdim,
184 std::optional<ScalarType> opt_dtype) {
185 return torch::linalg_vector_norm_out(
186 result, self, ord, opt_dim, keepdim, opt_dtype);
187 }
188
matrix_norm(const Tensor & self,const Scalar & ord,IntArrayRef dim,bool keepdim,std::optional<ScalarType> dtype)189 inline Tensor matrix_norm(
190 const Tensor& self,
191 const Scalar& ord,
192 IntArrayRef dim,
193 bool keepdim,
194 std::optional<ScalarType> dtype) {
195 return torch::linalg_matrix_norm(self, ord, dim, keepdim, dtype);
196 }
197
matrix_norm_out(const Tensor & self,const Scalar & ord,IntArrayRef dim,bool keepdim,std::optional<ScalarType> dtype,Tensor & result)198 inline Tensor& matrix_norm_out(
199 const Tensor& self,
200 const Scalar& ord,
201 IntArrayRef dim,
202 bool keepdim,
203 std::optional<ScalarType> dtype,
204 Tensor& result) {
205 return torch::linalg_matrix_norm_out(result, self, ord, dim, keepdim, dtype);
206 }
207
matrix_norm(const Tensor & self,std::string ord,IntArrayRef dim,bool keepdim,std::optional<ScalarType> dtype)208 inline Tensor matrix_norm(
209 const Tensor& self,
210 std::string ord,
211 IntArrayRef dim,
212 bool keepdim,
213 std::optional<ScalarType> dtype) {
214 return torch::linalg_matrix_norm(self, ord, dim, keepdim, dtype);
215 }
216
matrix_norm_out(const Tensor & self,std::string ord,IntArrayRef dim,bool keepdim,std::optional<ScalarType> dtype,Tensor & result)217 inline Tensor& matrix_norm_out(
218 const Tensor& self,
219 std::string ord,
220 IntArrayRef dim,
221 bool keepdim,
222 std::optional<ScalarType> dtype,
223 Tensor& result) {
224 return torch::linalg_matrix_norm_out(result, self, ord, dim, keepdim, dtype);
225 }
226
matrix_power(const Tensor & self,int64_t n)227 inline Tensor matrix_power(const Tensor& self, int64_t n) {
228 return torch::linalg_matrix_power(self, n);
229 }
230
matrix_power_out(const Tensor & self,int64_t n,Tensor & result)231 inline Tensor& matrix_power_out(const Tensor& self, int64_t n, Tensor& result) {
232 return torch::linalg_matrix_power_out(result, self, n);
233 }
234
matrix_rank(const Tensor & input,double tol,bool hermitian)235 inline Tensor matrix_rank(const Tensor& input, double tol, bool hermitian) {
236 return torch::linalg_matrix_rank(input, tol, hermitian);
237 }
238
matrix_rank(const Tensor & input,const Tensor & tol,bool hermitian)239 inline Tensor matrix_rank(
240 const Tensor& input,
241 const Tensor& tol,
242 bool hermitian) {
243 return torch::linalg_matrix_rank(input, tol, hermitian);
244 }
245
matrix_rank(const Tensor & input,std::optional<double> atol,std::optional<double> rtol,bool hermitian)246 inline Tensor matrix_rank(
247 const Tensor& input,
248 std::optional<double> atol,
249 std::optional<double> rtol,
250 bool hermitian) {
251 return torch::linalg_matrix_rank(input, atol, rtol, hermitian);
252 }
253
matrix_rank(const Tensor & input,const std::optional<Tensor> & atol,const std::optional<Tensor> & rtol,bool hermitian)254 inline Tensor matrix_rank(
255 const Tensor& input,
256 const std::optional<Tensor>& atol,
257 const std::optional<Tensor>& rtol,
258 bool hermitian) {
259 return torch::linalg_matrix_rank(input, atol, rtol, hermitian);
260 }
261
matrix_rank_out(Tensor & result,const Tensor & input,double tol,bool hermitian)262 inline Tensor& matrix_rank_out(
263 Tensor& result,
264 const Tensor& input,
265 double tol,
266 bool hermitian) {
267 return torch::linalg_matrix_rank_out(result, input, tol, hermitian);
268 }
269
matrix_rank_out(Tensor & result,const Tensor & input,const Tensor & tol,bool hermitian)270 inline Tensor& matrix_rank_out(
271 Tensor& result,
272 const Tensor& input,
273 const Tensor& tol,
274 bool hermitian) {
275 return torch::linalg_matrix_rank_out(result, input, tol, hermitian);
276 }
277
matrix_rank_out(Tensor & result,const Tensor & input,std::optional<double> atol,std::optional<double> rtol,bool hermitian)278 inline Tensor& matrix_rank_out(
279 Tensor& result,
280 const Tensor& input,
281 std::optional<double> atol,
282 std::optional<double> rtol,
283 bool hermitian) {
284 return torch::linalg_matrix_rank_out(result, input, atol, rtol, hermitian);
285 }
286
matrix_rank_out(Tensor & result,const Tensor & input,const std::optional<Tensor> & atol,const std::optional<Tensor> & rtol,bool hermitian)287 inline Tensor& matrix_rank_out(
288 Tensor& result,
289 const Tensor& input,
290 const std::optional<Tensor>& atol,
291 const std::optional<Tensor>& rtol,
292 bool hermitian) {
293 return torch::linalg_matrix_rank_out(result, input, atol, rtol, hermitian);
294 }
295
multi_dot(TensorList tensors)296 inline Tensor multi_dot(TensorList tensors) {
297 return torch::linalg_multi_dot(tensors);
298 }
299
multi_dot_out(TensorList tensors,Tensor & result)300 inline Tensor& multi_dot_out(TensorList tensors, Tensor& result) {
301 return torch::linalg_multi_dot_out(result, tensors);
302 }
303
pinv(const Tensor & input,double rcond,bool hermitian)304 inline Tensor pinv(const Tensor& input, double rcond, bool hermitian) {
305 return torch::linalg_pinv(input, rcond, hermitian);
306 }
307
pinv_out(Tensor & result,const Tensor & input,double rcond,bool hermitian)308 inline Tensor& pinv_out(
309 Tensor& result,
310 const Tensor& input,
311 double rcond,
312 bool hermitian) {
313 return torch::linalg_pinv_out(result, input, rcond, hermitian);
314 }
315
qr(const Tensor & input,c10::string_view mode)316 inline std::tuple<Tensor, Tensor> qr(
317 const Tensor& input,
318 c10::string_view mode) {
319 return torch::linalg_qr(input, mode);
320 }
321
qr_out(Tensor & Q,Tensor & R,const Tensor & input,c10::string_view mode)322 inline std::tuple<Tensor&, Tensor&> qr_out(
323 Tensor& Q,
324 Tensor& R,
325 const Tensor& input,
326 c10::string_view mode) {
327 return torch::linalg_qr_out(Q, R, input, mode);
328 }
329
solve_ex(const Tensor & input,const Tensor & other,bool left,bool check_errors)330 inline std::tuple<Tensor, Tensor> solve_ex(
331 const Tensor& input,
332 const Tensor& other,
333 bool left,
334 bool check_errors) {
335 return torch::linalg_solve_ex(input, other, left, check_errors);
336 }
337
solve_ex_out(Tensor & result,Tensor & info,const Tensor & input,const Tensor & other,bool left,bool check_errors)338 inline std::tuple<Tensor&, Tensor&> solve_ex_out(
339 Tensor& result,
340 Tensor& info,
341 const Tensor& input,
342 const Tensor& other,
343 bool left,
344 bool check_errors) {
345 return torch::linalg_solve_ex_out(
346 result, info, input, other, left, check_errors);
347 }
348
solve(const Tensor & input,const Tensor & other,bool left)349 inline Tensor solve(const Tensor& input, const Tensor& other, bool left) {
350 return torch::linalg_solve(input, other, left);
351 }
352
solve_out(Tensor & result,const Tensor & input,const Tensor & other,bool left)353 inline Tensor& solve_out(
354 Tensor& result,
355 const Tensor& input,
356 const Tensor& other,
357 bool left) {
358 return torch::linalg_solve_out(result, input, other, left);
359 }
360
solve_triangular(const Tensor & input,const Tensor & other,bool upper,bool left,bool unitriangular)361 inline Tensor solve_triangular(
362 const Tensor& input,
363 const Tensor& other,
364 bool upper,
365 bool left,
366 bool unitriangular) {
367 return torch::linalg_solve_triangular(
368 input, other, upper, left, unitriangular);
369 }
370
solve_triangular_out(Tensor & result,const Tensor & input,const Tensor & other,bool upper,bool left,bool unitriangular)371 inline Tensor& solve_triangular_out(
372 Tensor& result,
373 const Tensor& input,
374 const Tensor& other,
375 bool upper,
376 bool left,
377 bool unitriangular) {
378 return torch::linalg_solve_triangular_out(
379 result, input, other, upper, left, unitriangular);
380 }
381
svd(const Tensor & input,bool full_matrices,std::optional<c10::string_view> driver)382 inline std::tuple<Tensor, Tensor, Tensor> svd(
383 const Tensor& input,
384 bool full_matrices,
385 std::optional<c10::string_view> driver) {
386 return torch::linalg_svd(input, full_matrices, driver);
387 }
388
svd_out(Tensor & U,Tensor & S,Tensor & Vh,const Tensor & input,bool full_matrices,std::optional<c10::string_view> driver)389 inline std::tuple<Tensor&, Tensor&, Tensor&> svd_out(
390 Tensor& U,
391 Tensor& S,
392 Tensor& Vh,
393 const Tensor& input,
394 bool full_matrices,
395 std::optional<c10::string_view> driver) {
396 return torch::linalg_svd_out(U, S, Vh, input, full_matrices, driver);
397 }
398
svdvals(const Tensor & input,std::optional<c10::string_view> driver)399 inline Tensor svdvals(
400 const Tensor& input,
401 std::optional<c10::string_view> driver) {
402 return torch::linalg_svdvals(input, driver);
403 }
404
svdvals_out(Tensor & result,const Tensor & input,std::optional<c10::string_view> driver)405 inline Tensor& svdvals_out(
406 Tensor& result,
407 const Tensor& input,
408 std::optional<c10::string_view> driver) {
409 return torch::linalg_svdvals_out(result, input, driver);
410 }
411
tensorinv(const Tensor & self,int64_t ind)412 inline Tensor tensorinv(const Tensor& self, int64_t ind) {
413 return torch::linalg_tensorinv(self, ind);
414 }
415
tensorinv_out(Tensor & result,const Tensor & self,int64_t ind)416 inline Tensor& tensorinv_out(Tensor& result, const Tensor& self, int64_t ind) {
417 return torch::linalg_tensorinv_out(result, self, ind);
418 }
419
tensorsolve(const Tensor & self,const Tensor & other,OptionalIntArrayRef dims)420 inline Tensor tensorsolve(
421 const Tensor& self,
422 const Tensor& other,
423 OptionalIntArrayRef dims) {
424 return torch::linalg_tensorsolve(self, other, dims);
425 }
426
tensorsolve_out(Tensor & result,const Tensor & self,const Tensor & other,OptionalIntArrayRef dims)427 inline Tensor& tensorsolve_out(
428 Tensor& result,
429 const Tensor& self,
430 const Tensor& other,
431 OptionalIntArrayRef dims) {
432 return torch::linalg_tensorsolve_out(result, self, other, dims);
433 }
434
inv(const Tensor & input)435 inline Tensor inv(const Tensor& input) {
436 return torch::linalg_inv(input);
437 }
438
inv_out(Tensor & result,const Tensor & input)439 inline Tensor& inv_out(Tensor& result, const Tensor& input) {
440 return torch::linalg_inv_out(result, input);
441 }
442
443 } // namespace detail
444 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
445
446 /// Cholesky decomposition
447 ///
448 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.cholesky
449 ///
450 /// Example:
451 /// ```
452 /// auto A = torch::randn({4, 4});
453 /// auto A = torch::matmul(A, A.t());
454 /// auto L = torch::linalg::cholesky(A);
455 /// assert(torch::allclose(torch::matmul(L, L.t()), A));
456 /// ```
cholesky(const Tensor & self)457 inline Tensor cholesky(const Tensor& self) {
458 return detail::cholesky(self);
459 }
460
cholesky_out(Tensor & result,const Tensor & self)461 inline Tensor cholesky_out(Tensor& result, const Tensor& self) {
462 return detail::cholesky_out(result, self);
463 }
464
465 // C10_DEPRECATED_MESSAGE("linalg_det is deprecated, use det instead.")
linalg_det(const Tensor & self)466 inline Tensor linalg_det(const Tensor& self) {
467 return detail::det(self);
468 }
469
470 /// See the documentation of torch.linalg.det
det(const Tensor & self)471 inline Tensor det(const Tensor& self) {
472 return detail::det(self);
473 }
474
475 /// Computes the sign and (natural) logarithm of the determinant
476 ///
477 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.slogdet
slogdet(const Tensor & input)478 inline std::tuple<Tensor, Tensor> slogdet(const Tensor& input) {
479 return detail::slogdet(input);
480 }
481
slogdet_out(Tensor & sign,Tensor & logabsdet,const Tensor & input)482 inline std::tuple<Tensor&, Tensor&> slogdet_out(
483 Tensor& sign,
484 Tensor& logabsdet,
485 const Tensor& input) {
486 return detail::slogdet_out(sign, logabsdet, input);
487 }
488
489 /// Computes eigenvalues and eigenvectors of non-symmetric/non-hermitian
490 /// matrices
491 ///
492 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.eig
eig(const Tensor & self)493 inline std::tuple<Tensor, Tensor> eig(const Tensor& self) {
494 return detail::eig(self);
495 }
496
eig_out(Tensor & eigvals,Tensor & eigvecs,const Tensor & self)497 inline std::tuple<Tensor&, Tensor&> eig_out(
498 Tensor& eigvals,
499 Tensor& eigvecs,
500 const Tensor& self) {
501 return detail::eig_out(eigvals, eigvecs, self);
502 }
503
504 /// Computes eigenvalues of non-symmetric/non-hermitian matrices
505 ///
506 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.eigvals
eigvals(const Tensor & self)507 inline Tensor eigvals(const Tensor& self) {
508 return detail::eigvals(self);
509 }
510
eigvals_out(Tensor & result,const Tensor & self)511 inline Tensor& eigvals_out(Tensor& result, const Tensor& self) {
512 return detail::eigvals_out(result, self);
513 }
514
515 /// Computes eigenvalues and eigenvectors
516 ///
517 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.eigh
eigh(const Tensor & self,c10::string_view uplo)518 inline std::tuple<Tensor, Tensor> eigh(
519 const Tensor& self,
520 c10::string_view uplo) {
521 return detail::eigh(self, uplo);
522 }
523
eigh_out(Tensor & eigvals,Tensor & eigvecs,const Tensor & self,c10::string_view uplo)524 inline std::tuple<Tensor&, Tensor&> eigh_out(
525 Tensor& eigvals,
526 Tensor& eigvecs,
527 const Tensor& self,
528 c10::string_view uplo) {
529 return detail::eigh_out(eigvals, eigvecs, self, uplo);
530 }
531
532 /// Computes eigenvalues
533 ///
534 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.eigvalsh
eigvalsh(const Tensor & self,c10::string_view uplo)535 inline Tensor eigvalsh(const Tensor& self, c10::string_view uplo) {
536 return detail::eigvalsh(self, uplo);
537 }
538
eigvalsh_out(Tensor & result,const Tensor & self,c10::string_view uplo)539 inline Tensor& eigvalsh_out(
540 Tensor& result,
541 const Tensor& self,
542 c10::string_view uplo) {
543 return detail::eigvalsh_out(result, self, uplo);
544 }
545
546 /// Computes the product of Householder matrices
547 ///
548 /// See
549 /// https://pytorch.org/docs/main/linalg.html#torch.linalg.householder_product
householder_product(const Tensor & input,const Tensor & tau)550 inline Tensor householder_product(const Tensor& input, const Tensor& tau) {
551 return detail::householder_product(input, tau);
552 }
553
householder_product_out(Tensor & result,const Tensor & input,const Tensor & tau)554 inline Tensor& householder_product_out(
555 Tensor& result,
556 const Tensor& input,
557 const Tensor& tau) {
558 return detail::householder_product_out(result, input, tau);
559 }
560
lstsq(const Tensor & self,const Tensor & b,std::optional<double> cond,std::optional<c10::string_view> driver)561 inline std::tuple<Tensor, Tensor, Tensor, Tensor> lstsq(
562 const Tensor& self,
563 const Tensor& b,
564 std::optional<double> cond,
565 std::optional<c10::string_view> driver) {
566 return detail::lstsq(self, b, cond, driver);
567 }
568
569 /// Computes the matrix exponential
570 ///
571 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.matrix_exp
matrix_exp(const Tensor & input)572 inline Tensor matrix_exp(const Tensor& input) {
573 return detail::matrix_exp(input);
574 }
575
576 // C10_DEPRECATED_MESSAGE("linalg_norm is deprecated, use norm instead.")
linalg_norm(const Tensor & self,const std::optional<Scalar> & opt_ord,OptionalIntArrayRef opt_dim,bool keepdim,std::optional<ScalarType> opt_dtype)577 inline Tensor linalg_norm(
578 const Tensor& self,
579 const std::optional<Scalar>& opt_ord,
580 OptionalIntArrayRef opt_dim,
581 bool keepdim,
582 std::optional<ScalarType> opt_dtype) {
583 return detail::norm(self, opt_ord, opt_dim, keepdim, opt_dtype);
584 }
585
586 // C10_DEPRECATED_MESSAGE("linalg_norm is deprecated, use norm instead.")
linalg_norm(const Tensor & self,c10::string_view ord,OptionalIntArrayRef opt_dim,bool keepdim,std::optional<ScalarType> opt_dtype)587 inline Tensor linalg_norm(
588 const Tensor& self,
589 c10::string_view ord,
590 OptionalIntArrayRef opt_dim,
591 bool keepdim,
592 std::optional<ScalarType> opt_dtype) {
593 return detail::norm(self, ord, opt_dim, keepdim, opt_dtype);
594 }
595
596 // C10_DEPRECATED_MESSAGE("linalg_norm_out is deprecated, use norm_out
597 // instead.")
linalg_norm_out(Tensor & result,const Tensor & self,const std::optional<Scalar> & opt_ord,OptionalIntArrayRef opt_dim,bool keepdim,std::optional<ScalarType> opt_dtype)598 inline Tensor& linalg_norm_out(
599 Tensor& result,
600 const Tensor& self,
601 const std::optional<Scalar>& opt_ord,
602 OptionalIntArrayRef opt_dim,
603 bool keepdim,
604 std::optional<ScalarType> opt_dtype) {
605 return detail::norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype);
606 }
607
608 // C10_DEPRECATED_MESSAGE("linalg_norm_out is deprecated, use norm_out
609 // instead.")
linalg_norm_out(Tensor & result,const Tensor & self,c10::string_view ord,OptionalIntArrayRef opt_dim,bool keepdim,std::optional<ScalarType> opt_dtype)610 inline Tensor& linalg_norm_out(
611 Tensor& result,
612 const Tensor& self,
613 c10::string_view ord,
614 OptionalIntArrayRef opt_dim,
615 bool keepdim,
616 std::optional<ScalarType> opt_dtype) {
617 return detail::norm_out(result, self, ord, opt_dim, keepdim, opt_dtype);
618 }
619
620 /// Computes the LU factorization with partial pivoting
621 ///
622 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.lu_factor
623 inline std::tuple<Tensor, Tensor> lu_factor(
624 const Tensor& input,
625 const bool pivot = true) {
626 return detail::lu_factor(input, pivot);
627 }
628
629 inline std::tuple<Tensor&, Tensor&> lu_factor_out(
630 Tensor& LU,
631 Tensor& pivots,
632 const Tensor& self,
633 const bool pivot = true) {
634 return detail::lu_factor_out(LU, pivots, self, pivot);
635 }
636
637 /// Computes the LU factorization with partial pivoting
638 ///
639 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.lu
640 inline std::tuple<Tensor, Tensor, Tensor> lu(
641 const Tensor& input,
642 const bool pivot = true) {
643 return detail::lu(input, pivot);
644 }
645
646 inline std::tuple<Tensor&, Tensor&, Tensor&> lu_out(
647 Tensor& P,
648 Tensor& L,
649 Tensor& U,
650 const Tensor& self,
651 const bool pivot = true) {
652 return detail::lu_out(P, L, U, self, pivot);
653 }
654
/// Vector or matrix norm, via `torch::linalg_norm`.
///
/// `ord` comes either as an optional numeric `Scalar` or as a string. All
/// other arguments (`opt_dim`, `keepdim`, `opt_dtype`) are forwarded
/// unchanged.
///
/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.norm
inline Tensor norm(
    const Tensor& self,
    const std::optional<Scalar>& opt_ord,
    OptionalIntArrayRef opt_dim,
    bool keepdim,
    std::optional<ScalarType> opt_dtype) {
  return detail::norm(self, opt_ord, opt_dim, keepdim, opt_dtype);
}

/// String-`ord` overload of `norm`.
inline Tensor norm(
    const Tensor& self,
    std::string ord,
    OptionalIntArrayRef opt_dim,
    bool keepdim,
    std::optional<ScalarType> opt_dtype) {
  return detail::norm(self, ord, opt_dim, keepdim, opt_dtype);
}

/// Out-variant of `norm` (numeric `ord`): writes into `result` and returns a
/// reference to it.
inline Tensor& norm_out(
    Tensor& result,
    const Tensor& self,
    const std::optional<Scalar>& opt_ord,
    OptionalIntArrayRef opt_dim,
    bool keepdim,
    std::optional<ScalarType> opt_dtype) {
  return detail::norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype);
}

/// Out-variant of `norm` (string `ord`): writes into `result` and returns a
/// reference to it.
inline Tensor& norm_out(
    Tensor& result,
    const Tensor& self,
    std::string ord,
    OptionalIntArrayRef opt_dim,
    bool keepdim,
    std::optional<ScalarType> opt_dtype) {
  return detail::norm_out(result, self, ord, opt_dim, keepdim, opt_dtype);
}
692
693 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.vector_norm
vector_norm(const Tensor & self,Scalar ord,OptionalIntArrayRef opt_dim,bool keepdim,std::optional<ScalarType> opt_dtype)694 inline Tensor vector_norm(
695 const Tensor& self,
696 Scalar ord,
697 OptionalIntArrayRef opt_dim,
698 bool keepdim,
699 std::optional<ScalarType> opt_dtype) {
700 return detail::vector_norm(self, ord, opt_dim, keepdim, opt_dtype);
701 }
702
vector_norm_out(Tensor & result,const Tensor & self,Scalar ord,OptionalIntArrayRef opt_dim,bool keepdim,std::optional<ScalarType> opt_dtype)703 inline Tensor& vector_norm_out(
704 Tensor& result,
705 const Tensor& self,
706 Scalar ord,
707 OptionalIntArrayRef opt_dim,
708 bool keepdim,
709 std::optional<ScalarType> opt_dtype) {
710 return detail::vector_norm_out(
711 result, self, ord, opt_dim, keepdim, opt_dtype);
712 }
713
714 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.matrix_norm
matrix_norm(const Tensor & self,const Scalar & ord,IntArrayRef dim,bool keepdim,std::optional<ScalarType> dtype)715 inline Tensor matrix_norm(
716 const Tensor& self,
717 const Scalar& ord,
718 IntArrayRef dim,
719 bool keepdim,
720 std::optional<ScalarType> dtype) {
721 return detail::matrix_norm(self, ord, dim, keepdim, dtype);
722 }
723
matrix_norm_out(const Tensor & self,const Scalar & ord,IntArrayRef dim,bool keepdim,std::optional<ScalarType> dtype,Tensor & result)724 inline Tensor& matrix_norm_out(
725 const Tensor& self,
726 const Scalar& ord,
727 IntArrayRef dim,
728 bool keepdim,
729 std::optional<ScalarType> dtype,
730 Tensor& result) {
731 return detail::matrix_norm_out(self, ord, dim, keepdim, dtype, result);
732 }
733
matrix_norm(const Tensor & self,std::string ord,IntArrayRef dim,bool keepdim,std::optional<ScalarType> dtype)734 inline Tensor matrix_norm(
735 const Tensor& self,
736 std::string ord,
737 IntArrayRef dim,
738 bool keepdim,
739 std::optional<ScalarType> dtype) {
740 return detail::matrix_norm(self, ord, dim, keepdim, dtype);
741 }
742
matrix_norm_out(const Tensor & self,std::string ord,IntArrayRef dim,bool keepdim,std::optional<ScalarType> dtype,Tensor & result)743 inline Tensor& matrix_norm_out(
744 const Tensor& self,
745 std::string ord,
746 IntArrayRef dim,
747 bool keepdim,
748 std::optional<ScalarType> dtype,
749 Tensor& result) {
750 return detail::matrix_norm_out(self, ord, dim, keepdim, dtype, result);
751 }
752
753 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.matrix_power
matrix_power(const Tensor & self,int64_t n)754 inline Tensor matrix_power(const Tensor& self, int64_t n) {
755 return detail::matrix_power(self, n);
756 }
757
matrix_power_out(const Tensor & self,int64_t n,Tensor & result)758 inline Tensor& matrix_power_out(const Tensor& self, int64_t n, Tensor& result) {
759 return detail::matrix_power_out(self, n, result);
760 }
761
762 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.matrix_rank
matrix_rank(const Tensor & input,double tol,bool hermitian)763 inline Tensor matrix_rank(const Tensor& input, double tol, bool hermitian) {
764 return detail::matrix_rank(input, tol, hermitian);
765 }
766
matrix_rank(const Tensor & input,const Tensor & tol,bool hermitian)767 inline Tensor matrix_rank(
768 const Tensor& input,
769 const Tensor& tol,
770 bool hermitian) {
771 return detail::matrix_rank(input, tol, hermitian);
772 }
773
matrix_rank(const Tensor & input,std::optional<double> atol,std::optional<double> rtol,bool hermitian)774 inline Tensor matrix_rank(
775 const Tensor& input,
776 std::optional<double> atol,
777 std::optional<double> rtol,
778 bool hermitian) {
779 return detail::matrix_rank(input, atol, rtol, hermitian);
780 }
781
matrix_rank(const Tensor & input,const std::optional<Tensor> & atol,const std::optional<Tensor> & rtol,bool hermitian)782 inline Tensor matrix_rank(
783 const Tensor& input,
784 const std::optional<Tensor>& atol,
785 const std::optional<Tensor>& rtol,
786 bool hermitian) {
787 return detail::matrix_rank(input, atol, rtol, hermitian);
788 }
789
matrix_rank_out(Tensor & result,const Tensor & input,double tol,bool hermitian)790 inline Tensor& matrix_rank_out(
791 Tensor& result,
792 const Tensor& input,
793 double tol,
794 bool hermitian) {
795 return detail::matrix_rank_out(result, input, tol, hermitian);
796 }
797
matrix_rank_out(Tensor & result,const Tensor & input,const Tensor & tol,bool hermitian)798 inline Tensor& matrix_rank_out(
799 Tensor& result,
800 const Tensor& input,
801 const Tensor& tol,
802 bool hermitian) {
803 return detail::matrix_rank_out(result, input, tol, hermitian);
804 }
805
matrix_rank_out(Tensor & result,const Tensor & input,std::optional<double> atol,std::optional<double> rtol,bool hermitian)806 inline Tensor& matrix_rank_out(
807 Tensor& result,
808 const Tensor& input,
809 std::optional<double> atol,
810 std::optional<double> rtol,
811 bool hermitian) {
812 return detail::matrix_rank_out(result, input, atol, rtol, hermitian);
813 }
814
matrix_rank_out(Tensor & result,const Tensor & input,const std::optional<Tensor> & atol,const std::optional<Tensor> & rtol,bool hermitian)815 inline Tensor& matrix_rank_out(
816 Tensor& result,
817 const Tensor& input,
818 const std::optional<Tensor>& atol,
819 const std::optional<Tensor>& rtol,
820 bool hermitian) {
821 return detail::matrix_rank_out(result, input, atol, rtol, hermitian);
822 }
823
824 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.multi_dot
multi_dot(TensorList tensors)825 inline Tensor multi_dot(TensorList tensors) {
826 return detail::multi_dot(tensors);
827 }
828
multi_dot_out(TensorList tensors,Tensor & result)829 inline Tensor& multi_dot_out(TensorList tensors, Tensor& result) {
830 return detail::multi_dot_out(tensors, result);
831 }
832
833 /// Computes the pseudo-inverse
834 ///
835 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.pinv
836 inline Tensor pinv(
837 const Tensor& input,
838 double rcond = 1e-15,
839 bool hermitian = false) {
840 return detail::pinv(input, rcond, hermitian);
841 }
842
843 inline Tensor& pinv_out(
844 Tensor& result,
845 const Tensor& input,
846 double rcond = 1e-15,
847 bool hermitian = false) {
848 return detail::pinv_out(result, input, rcond, hermitian);
849 }
850
851 /// Computes the QR decomposition
852 ///
853 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.qr
854 inline std::tuple<Tensor, Tensor> qr(
855 const Tensor& input,
856 c10::string_view mode = "reduced") {
857 // C++17 Change the initialisation to "reduced"sv
858 // Same for qr_out
859 return detail::qr(input, mode);
860 }
861
862 inline std::tuple<Tensor&, Tensor&> qr_out(
863 Tensor& Q,
864 Tensor& R,
865 const Tensor& input,
866 c10::string_view mode = "reduced") {
867 return detail::qr_out(Q, R, input, mode);
868 }
869
870 /// Computes the LDL decomposition
871 ///
872 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.ldl_factor_ex
ldl_factor_ex(const Tensor & input,bool hermitian,bool check_errors)873 inline std::tuple<Tensor, Tensor, Tensor> ldl_factor_ex(
874 const Tensor& input,
875 bool hermitian,
876 bool check_errors) {
877 return torch::linalg_ldl_factor_ex(input, hermitian, check_errors);
878 }
879
ldl_factor_ex_out(Tensor & LD,Tensor & pivots,Tensor & info,const Tensor & input,bool hermitian,bool check_errors)880 inline std::tuple<Tensor&, Tensor&, Tensor&> ldl_factor_ex_out(
881 Tensor& LD,
882 Tensor& pivots,
883 Tensor& info,
884 const Tensor& input,
885 bool hermitian,
886 bool check_errors) {
887 return torch::linalg_ldl_factor_ex_out(
888 LD, pivots, info, input, hermitian, check_errors);
889 }
890
891 /// Solve a system of linear equations using the LDL decomposition
892 ///
893 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.ldl_solve
ldl_solve(const Tensor & LD,const Tensor & pivots,const Tensor & B,bool hermitian)894 inline Tensor ldl_solve(
895 const Tensor& LD,
896 const Tensor& pivots,
897 const Tensor& B,
898 bool hermitian) {
899 return torch::linalg_ldl_solve(LD, pivots, B, hermitian);
900 }
901
ldl_solve_out(Tensor & result,const Tensor & LD,const Tensor & pivots,const Tensor & B,bool hermitian)902 inline Tensor& ldl_solve_out(
903 Tensor& result,
904 const Tensor& LD,
905 const Tensor& pivots,
906 const Tensor& B,
907 bool hermitian) {
908 return torch::linalg_ldl_solve_out(result, LD, pivots, B, hermitian);
909 }
910
911 /// Solves a system linear system AX = B
912 ///
913 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.solve_ex
solve_ex(const Tensor & input,const Tensor & other,bool left,bool check_errors)914 inline std::tuple<Tensor, Tensor> solve_ex(
915 const Tensor& input,
916 const Tensor& other,
917 bool left,
918 bool check_errors) {
919 return detail::solve_ex(input, other, left, check_errors);
920 }
921
solve_ex_out(Tensor & result,Tensor & info,const Tensor & input,const Tensor & other,bool left,bool check_errors)922 inline std::tuple<Tensor&, Tensor&> solve_ex_out(
923 Tensor& result,
924 Tensor& info,
925 const Tensor& input,
926 const Tensor& other,
927 bool left,
928 bool check_errors) {
929 return detail::solve_ex_out(result, info, input, other, left, check_errors);
930 }
931
932 /// Computes a tensor `x` such that `matmul(input, x) = other`.
933 ///
934 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.solve
solve(const Tensor & input,const Tensor & other,bool left)935 inline Tensor solve(const Tensor& input, const Tensor& other, bool left) {
936 return detail::solve(input, other, left);
937 }
938
solve_out(Tensor & result,const Tensor & input,const Tensor & other,bool left)939 inline Tensor& solve_out(
940 Tensor& result,
941 const Tensor& input,
942 const Tensor& other,
943 bool left) {
944 return detail::solve_out(result, input, other, left);
945 }
946
947 /// Computes a solution of a linear system AX = B for input = A and other = B
948 /// whenever A is square upper or lower triangular and does not have zeros in
949 /// the diagonal
950 ///
951 /// See
952 /// https://pytorch.org/docs/main/linalg.html#torch.linalg.solve_triangular
solve_triangular(const Tensor & input,const Tensor & other,bool upper,bool left,bool unitriangular)953 inline Tensor solve_triangular(
954 const Tensor& input,
955 const Tensor& other,
956 bool upper,
957 bool left,
958 bool unitriangular) {
959 return detail::solve_triangular(input, other, upper, left, unitriangular);
960 }
961
solve_triangular_out(Tensor & result,const Tensor & input,const Tensor & other,bool upper,bool left,bool unitriangular)962 inline Tensor& solve_triangular_out(
963 Tensor& result,
964 const Tensor& input,
965 const Tensor& other,
966 bool upper,
967 bool left,
968 bool unitriangular) {
969 return detail::solve_triangular_out(
970 result, input, other, upper, left, unitriangular);
971 }
972
973 /// Computes the singular values and singular vectors
974 ///
975 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.svd
svd(const Tensor & input,bool full_matrices,std::optional<c10::string_view> driver)976 inline std::tuple<Tensor, Tensor, Tensor> svd(
977 const Tensor& input,
978 bool full_matrices,
979 std::optional<c10::string_view> driver) {
980 return detail::svd(input, full_matrices, driver);
981 }
982
svd_out(Tensor & U,Tensor & S,Tensor & Vh,const Tensor & input,bool full_matrices,std::optional<c10::string_view> driver)983 inline std::tuple<Tensor&, Tensor&, Tensor&> svd_out(
984 Tensor& U,
985 Tensor& S,
986 Tensor& Vh,
987 const Tensor& input,
988 bool full_matrices,
989 std::optional<c10::string_view> driver) {
990 return detail::svd_out(U, S, Vh, input, full_matrices, driver);
991 }
992
993 /// Computes the singular values
994 ///
995 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.svdvals
svdvals(const Tensor & input,std::optional<c10::string_view> driver)996 inline Tensor svdvals(
997 const Tensor& input,
998 std::optional<c10::string_view> driver) {
999 return detail::svdvals(input, driver);
1000 }
1001
svdvals_out(Tensor & result,const Tensor & input,std::optional<c10::string_view> driver)1002 inline Tensor& svdvals_out(
1003 Tensor& result,
1004 const Tensor& input,
1005 std::optional<c10::string_view> driver) {
1006 return detail::svdvals_out(result, input, driver);
1007 }
1008
1009 /// Computes the inverse of a tensor
1010 ///
1011 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.tensorinv
1012 ///
1013 /// Example:
1014 /// ```
1015 /// auto a = torch::eye(4*6).reshape({4, 6, 8, 3});
1016 /// int64_t ind = 2;
1017 /// auto ainv = torch::linalg::tensorinv(a, ind);
1018 /// ```
tensorinv(const Tensor & self,int64_t ind)1019 inline Tensor tensorinv(const Tensor& self, int64_t ind) {
1020 return detail::tensorinv(self, ind);
1021 }
1022
tensorinv_out(Tensor & result,const Tensor & self,int64_t ind)1023 inline Tensor& tensorinv_out(Tensor& result, const Tensor& self, int64_t ind) {
1024 return detail::tensorinv_out(result, self, ind);
1025 }
1026
1027 /// Computes a tensor `x` such that `tensordot(input, x, dims=x.dim()) = other`.
1028 ///
1029 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.tensorsolve
1030 ///
1031 /// Example:
1032 /// ```
1033 /// auto a = torch::eye(2*3*4).reshape({2*3, 4, 2, 3, 4});
1034 /// auto b = torch::randn(2*3, 4);
1035 /// auto x = torch::linalg::tensorsolve(a, b);
1036 /// ```
tensorsolve(const Tensor & input,const Tensor & other,OptionalIntArrayRef dims)1037 inline Tensor tensorsolve(
1038 const Tensor& input,
1039 const Tensor& other,
1040 OptionalIntArrayRef dims) {
1041 return detail::tensorsolve(input, other, dims);
1042 }
1043
tensorsolve_out(Tensor & result,const Tensor & input,const Tensor & other,OptionalIntArrayRef dims)1044 inline Tensor& tensorsolve_out(
1045 Tensor& result,
1046 const Tensor& input,
1047 const Tensor& other,
1048 OptionalIntArrayRef dims) {
1049 return detail::tensorsolve_out(result, input, other, dims);
1050 }
1051
1052 /// Computes a tensor `inverse_input` such that `dot(input, inverse_input) =
1053 /// eye(input.size(0))`.
1054 ///
1055 /// See https://pytorch.org/docs/main/linalg.html#torch.linalg.inv
inv(const Tensor & input)1056 inline Tensor inv(const Tensor& input) {
1057 return detail::inv(input);
1058 }
1059
inv_out(Tensor & result,const Tensor & input)1060 inline Tensor& inv_out(Tensor& result, const Tensor& input) {
1061 return detail::inv_out(result, input);
1062 }
1063
1064 } // namespace linalg
1065 } // namespace torch
1066