"""Implement various linear algebra algorithms for low rank matrices."""

__all__ = ["svd_lowrank", "pca_lowrank"]

from typing import Optional, Tuple

import torch
from torch import _linalg_utils as _utils, Tensor
from torch.overrides import handle_torch_function, has_torch_function


def get_approximate_basis(
    A: Tensor,
    q: int,
    niter: Optional[int] = 2,
    M: Optional[Tensor] = None,
) -> Tensor:
    """Return tensor :math:`Q` with :math:`q` orthonormal columns such
    that :math:`Q Q^H A` approximates :math:`A`. If :math:`M` is
    specified, then :math:`Q` is such that :math:`Q Q^H (A - M)`
    approximates :math:`A - M`, without instantiating any tensors
    of the size of :math:`A` or :math:`M`.

    .. note:: The implementation is based on the Algorithm 4.4 from
              Halko et al., 2009.

    .. note:: For an adequate approximation of a k-rank matrix
              :math:`A`, where k is not known in advance but could be
              estimated, the number of :math:`Q` columns, q, can be
              chosen according to the following criteria: in general,
              :math:`k <= q <= min(2*k, m, n)`. For large low-rank
              matrices, take :math:`q = k + 5..10`.  If k is
              relatively small compared to :math:`min(m, n)`, choosing
              :math:`q = k + 0..2` may be sufficient.

    .. note:: To obtain repeatable results, reset the seed for the
              pseudorandom number generator

    Args::
        A (Tensor): the input tensor of size :math:`(*, m, n)`

        q (int): the dimension of subspace spanned by :math:`Q`
                 columns.

        niter (int, optional): the number of subspace iterations to
                               conduct; ``niter`` must be a
                               nonnegative integer. In most cases, the
                               default value 2 is more than enough.
                               ``None`` is treated as 2.

        M (Tensor, optional): the input tensor's mean of size
                              :math:`(*, m, n)`.

    Returns::
        Tensor: the :math:`Q` factor of the QR decomposition computed
        in the last iteration step.

    References::
        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <http://arxiv.org/abs/0909.4061>`_).
    """

    niter = 2 if niter is None else niter
    # get_floating_dtype is bypassed for complex input so that the random
    # test matrix keeps the complex dtype of A.
    dtype = _utils.get_floating_dtype(A) if not A.is_complex() else A.dtype
    matmul = _utils.matmul

    # Random test matrix with q columns; it is the only source of randomness.
    R = torch.randn(A.shape[-1], q, dtype=dtype, device=A.device)

    # The following code could be made faster using torch.geqrf + torch.ormqr
    # but geqrf is not differentiable

    # (A - M) @ R is evaluated as A @ R - M @ R so that A - M is never formed.
    X = matmul(A, R)
    if M is not None:
        X = X - matmul(M, R)
    Q = torch.linalg.qr(X).Q
    for _ in range(niter):
        # One subspace iteration: multiply by (A - M)^H, re-orthonormalize,
        # multiply by (A - M), re-orthonormalize.
        X = matmul(A.mH, Q)
        if M is not None:
            X = X - matmul(M.mH, Q)
        Q = torch.linalg.qr(X).Q
        X = matmul(A, Q)
        if M is not None:
            X = X - matmul(M, Q)
        Q = torch.linalg.qr(X).Q
    return Q


def svd_lowrank(
    A: Tensor,
    q: Optional[int] = 6,
    niter: Optional[int] = 2,
    M: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Return the singular value decomposition ``(U, S, V)`` of a matrix,
    batches of matrices, or a sparse matrix :math:`A` such that
    :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`. In case :math:`M` is given, then
    SVD is computed for the matrix :math:`A - M`.

    .. note:: The implementation is based on the Algorithm 5.1 from
              Halko et al., 2009.

    .. note:: For an adequate approximation of a k-rank matrix
              :math:`A`, where k is not known in advance but could be
              estimated, the number of :math:`Q` columns, q, can be
              chosen according to the following criteria: in general,
              :math:`k <= q <= min(2*k, m, n)`. For large low-rank
              matrices, take :math:`q = k + 5..10`.  If k is
              relatively small compared to :math:`min(m, n)`, choosing
              :math:`q = k + 0..2` may be sufficient.

    .. note:: This is a randomized method. To obtain repeatable results,
              set the seed for the pseudorandom number generator

    .. note:: In general, use the full-rank SVD implementation
              :func:`torch.linalg.svd` for dense matrices due to its 10x
              higher performance characteristics. The low-rank SVD
              will be useful for huge sparse matrices that
              :func:`torch.linalg.svd` cannot handle.

    Args::
        A (Tensor): the input tensor of size :math:`(*, m, n)`

        q (int, optional): a slightly overestimated rank of A.
                           ``None`` is treated as 6.

        niter (int, optional): the number of subspace iterations to
                               conduct; niter must be a nonnegative
                               integer, and defaults to 2

        M (Tensor, optional): the input tensor's mean of size
                              :math:`(*, m, n)`, which will be broadcasted
                              to the size of A in this function.

    Returns::
        (Tensor, Tensor, Tensor): the tuple ``(U, S, V)``.

    References::
        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <https://arxiv.org/abs/0909.4061>`_).

    """
    if not torch.jit.is_scripting():
        tensor_ops = (A, M)
        # Dispatch through __torch_function__ only when an argument is
        # something other than a plain Tensor or None.
        if not set(map(type, tensor_ops)).issubset(
            (torch.Tensor, type(None))
        ) and has_torch_function(tensor_ops):
            return handle_torch_function(
                svd_lowrank, tensor_ops, A, q=q, niter=niter, M=M
            )
    return _svd_lowrank(A, q=q, niter=niter, M=M)


def _svd_lowrank(
    A: Tensor,
    q: Optional[int] = 6,
    niter: Optional[int] = 2,
    M: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Compute the low-rank SVD ``(U, S, V)`` of ``A`` (or ``A - M``).

    This is the implementation behind :func:`svd_lowrank` and
    :func:`pca_lowrank`; it performs no ``__torch_function__`` dispatch.
    The parameters have the same meaning as in :func:`svd_lowrank`.
    """
    # Algorithm 5.1 in Halko et al., 2009

    q = 6 if q is None else q
    m, n = A.shape[-2:]
    matmul = _utils.matmul
    if M is not None:
        M = M.broadcast_to(A.size())

    # Assume that A is tall
    if m < n:
        A = A.mH
        if M is not None:
            M = M.mH

    Q = get_approximate_basis(A, q, niter=niter, M=M)
    # B = Q^H (A - M), evaluated term by term so A - M is never formed.
    B = matmul(Q.mH, A)
    if M is not None:
        B = B - matmul(Q.mH, M)
    U, S, Vh = torch.linalg.svd(B, full_matrices=False)
    V = Vh.mH
    U = Q.matmul(U)

    # The factors were computed for A^H when A was wide; swap them back.
    if m < n:
        U, V = V, U

    return U, S, V


def pca_lowrank(
    A: Tensor,
    q: Optional[int] = None,
    center: bool = True,
    niter: int = 2,
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Performs linear Principal Component Analysis (PCA) on a low-rank
    matrix, batches of such matrices, or sparse matrix.

    This function returns a tuple ``(U, S, V)`` which is the
    nearly optimal approximation of a singular value decomposition of
    a centered matrix :math:`A` such that :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`

    .. note:: The relation of ``(U, S, V)`` to PCA is as follows:

                - :math:`A` is a data matrix with ``m`` samples and
                  ``n`` features

                - the :math:`V` columns represent the principal directions

                - :math:`S ** 2 / (m - 1)` contains the eigenvalues of
                  :math:`A^T A / (m - 1)` which is the covariance of
                  ``A`` when ``center=True`` is provided.

                - ``matmul(A, V[:, :k])`` projects data to the first k
                  principal components

    .. note:: Different from the standard SVD, the size of returned
              matrices depends on the specified rank and q
              values as follows:

                - :math:`U` is m x q matrix

                - :math:`S` is q-vector

                - :math:`V` is n x q matrix

    .. note:: To obtain repeatable results, reset the seed for the
              pseudorandom number generator

    Args:

        A (Tensor): the input tensor of size :math:`(*, m, n)`

        q (int, optional): a slightly overestimated rank of
                           :math:`A`. By default, ``q = min(6, m,
                           n)``.

        center (bool, optional): if True, center the input tensor,
                                 otherwise, assume that the input is
                                 centered.

        niter (int, optional): the number of subspace iterations to
                               conduct; niter must be a nonnegative
                               integer, and defaults to 2.

    Returns:

        (Tensor, Tensor, Tensor): the tuple ``(U, S, V)``.

    Raises:

        ValueError: if ``q`` is negative or greater than ``min(m, n)``,
                    if ``niter`` is negative, or if ``center`` is True
                    and ``A`` is a sparse tensor that is not
                    2-dimensional.

    References::

        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <http://arxiv.org/abs/0909.4061>`_).

    """

    if not torch.jit.is_scripting():
        if type(A) is not torch.Tensor and has_torch_function((A,)):
            return handle_torch_function(
                pca_lowrank, (A,), A, q=q, center=center, niter=niter
            )

    (m, n) = A.shape[-2:]

    if q is None:
        q = min(6, m, n)
    elif not (q >= 0 and q <= min(m, n)):
        raise ValueError(
            f"q(={q}) must be non-negative integer and not greater than min(m, n)={min(m, n)}"
        )
    if not (niter >= 0):
        raise ValueError(f"niter(={niter}) must be non-negative integer")

    if not center:
        return _svd_lowrank(A, q, niter=niter, M=None)

    if _utils.is_sparse(A):
        if len(A.shape) != 2:
            raise ValueError("pca_lowrank input is expected to be 2-dimensional tensor")
        # Only the sparse path needs the floating dtype, so it is looked up
        # here rather than before the early returns above.
        dtype = _utils.get_floating_dtype(A)
        # Column means, kept sparse.
        c = torch.sparse.sum(A, dim=(-2,)) / m
        # reshape c
        column_indices = c.indices()[0]
        indices = torch.zeros(
            2,
            len(column_indices),
            dtype=column_indices.dtype,
            device=column_indices.device,
        )
        indices[0] = column_indices
        C_t = torch.sparse_coo_tensor(
            indices, c.values(), (n, 1), dtype=dtype, device=A.device
        )

        # The (n, 1) column means times a (1, m) row of ones, transposed,
        # gives the mean matrix M; A itself is passed on uncentered.
        ones_m1_t = torch.ones(A.shape[:-2] + (1, m), dtype=dtype, device=A.device)
        M = torch.sparse.mm(C_t, ones_m1_t).mT
        return _svd_lowrank(A, q, niter=niter, M=M)
    else:
        C = A.mean(dim=(-2,), keepdim=True)
        return _svd_lowrank(A - C, q, niter=niter, M=None)