xref: /aosp_15_r20/external/pytorch/torch/_lowrank.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
"""Randomized linear algebra routines (truncated SVD and PCA) for low-rank matrices."""

__all__ = [
    "svd_lowrank",
    "pca_lowrank",
]
4
5from typing import Optional, Tuple
6
7import torch
8from torch import _linalg_utils as _utils, Tensor
9from torch.overrides import handle_torch_function, has_torch_function
10
11
def _matmul_centered(A: Tensor, M: Optional[Tensor], B: Tensor) -> Tensor:
    """Return ``(A - M) @ B``, or ``A @ B`` when ``M`` is None.

    The difference ``A - M`` is never formed: the two products are computed
    separately and subtracted, so no tensor of the size of ``A`` or ``M`` is
    instantiated.
    """
    X = _utils.matmul(A, B)
    if M is not None:
        X = X - _utils.matmul(M, B)
    return X


def get_approximate_basis(
    A: Tensor,
    q: int,
    niter: Optional[int] = 2,
    M: Optional[Tensor] = None,
) -> Tensor:
    """Return tensor :math:`Q` with :math:`q` orthonormal columns such
    that :math:`Q Q^H A` approximates :math:`A`. If :math:`M` is
    specified, then :math:`Q` is such that :math:`Q Q^H (A - M)`
    approximates :math:`A - M`, without instantiating any tensors
    of the size of :math:`A` or :math:`M`.

    .. note:: The implementation is based on the Algorithm 4.4 from
              Halko et al., 2009.

    .. note:: For an adequate approximation of a k-rank matrix
              :math:`A`, where k is not known in advance but could be
              estimated, the number of :math:`Q` columns, q, can be
              chosen according to the following criteria: in general,
              :math:`k <= q <= min(2*k, m, n)`. For large low-rank
              matrices, take :math:`q = k + 5..10`.  If k is
              relatively small compared to :math:`min(m, n)`, choosing
              :math:`q = k + 0..2` may be sufficient.

    .. note:: To obtain repeatable results, reset the seed for the
              pseudorandom number generator.

    Args::
        A (Tensor): the input tensor of size :math:`(*, m, n)`

        q (int): the dimension of subspace spanned by :math:`Q`
                 columns.

        niter (int, optional): the number of subspace iterations to
                               conduct; ``niter`` must be a
                               nonnegative integer. ``None`` is treated
                               as 2. In most cases, the default value 2
                               is more than enough.

        M (Tensor, optional): the input tensor's mean of size
                              :math:`(*, m, n)`.

    Returns::
        Tensor: :math:`Q` with orthonormal columns, of size
        :math:`(*, m, q)` when :math:`q <= min(m, n)` (the reduced QR
        factorizations yield fewer columns if :math:`q` is larger).

    References::
        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <https://arxiv.org/abs/0909.4061>`_).
    """

    niter = 2 if niter is None else niter
    # Complex inputs keep their own dtype; everything else uses the
    # floating dtype chosen by _utils.get_floating_dtype.
    dtype = _utils.get_floating_dtype(A) if not A.is_complex() else A.dtype

    # Gaussian test matrix drawn from the global RNG (hence the seeding note).
    R = torch.randn(A.shape[-1], q, dtype=dtype, device=A.device)

    # The conjugate transposes are loop invariants; form them once.
    AH = A.mH
    MH: Optional[Tensor] = None
    if M is not None:
        MH = M.mH

    # The following code could be made faster using torch.geqrf + torch.ormqr
    # but geqrf is not differentiable

    Q = torch.linalg.qr(_matmul_centered(A, M, R)).Q
    for _ in range(niter):
        # One subspace iteration: multiply by (A - M)^H, then by (A - M),
        # re-orthonormalizing after every product so the basis stays
        # well-conditioned instead of collapsing onto the top singular vector.
        Q = torch.linalg.qr(_matmul_centered(AH, MH, Q)).Q
        Q = torch.linalg.qr(_matmul_centered(A, M, Q)).Q
    return Q
84
85
def svd_lowrank(
    A: Tensor,
    q: Optional[int] = 6,
    niter: Optional[int] = 2,
    M: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Return the singular value decomposition ``(U, S, V)`` of a matrix,
    batches of matrices, or a sparse matrix :math:`A` such that
    :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`. When :math:`M`
    is given, the SVD is computed for the matrix :math:`A - M` instead.

    .. note:: The implementation is based on the Algorithm 5.1 from
              Halko et al., 2009.

    .. note:: For an adequate approximation of a k-rank matrix
              :math:`A`, where k is not known in advance but could be
              estimated, the number of :math:`Q` columns, q, can be
              chosen according to the following criteria: in general,
              :math:`k <= q <= min(2*k, m, n)`. For large low-rank
              matrices, take :math:`q = k + 5..10`.  If k is
              relatively small compared to :math:`min(m, n)`, choosing
              :math:`q = k + 0..2` may be sufficient.

    .. note:: This is a randomized method. To obtain repeatable results,
              set the seed for the pseudorandom number generator.

    .. note:: In general, use the full-rank SVD implementation
              :func:`torch.linalg.svd` for dense matrices due to its 10x
              higher performance characteristics. The low-rank SVD
              will be useful for huge sparse matrices that
              :func:`torch.linalg.svd` cannot handle.

    Args::
        A (Tensor): the input tensor of size :math:`(*, m, n)`

        q (int, optional): a slightly overestimated rank of A; defaults
                           to 6 (``None`` is also treated as 6).

        niter (int, optional): the number of subspace iterations to
                               conduct; niter must be a nonnegative
                               integer, and defaults to 2

        M (Tensor, optional): the input tensor's mean of size
                              :math:`(*, m, n)`, which will be broadcasted
                              to the size of A in this function.

    References::
        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <https://arxiv.org/abs/0909.4061>`_).

    """
    if not torch.jit.is_scripting():
        # Only consult __torch_function__ overrides when at least one
        # argument is something other than a plain Tensor or None.
        dispatch_args = (A, M)
        plain_types = (torch.Tensor, type(None))
        has_foreign_arg = type(A) not in plain_types or type(M) not in plain_types
        if has_foreign_arg and has_torch_function(dispatch_args):
            return handle_torch_function(
                svd_lowrank, dispatch_args, A, q=q, niter=niter, M=M
            )
    return _svd_lowrank(A, q=q, niter=niter, M=M)
148
149
def _svd_lowrank(
    A: Tensor,
    q: Optional[int] = 6,
    niter: Optional[int] = 2,
    M: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Randomized truncated SVD of ``A`` (or of ``A - M`` when ``M`` is given).

    This is the dispatch-free implementation used by ``svd_lowrank`` and
    ``pca_lowrank`` (Algorithm 5.1 in Halko et al., 2009).

    Args:
        A: input of size ``(*, m, n)``.
        q: target subspace dimension; ``None`` is treated as 6.
        niter: number of subspace iterations, forwarded to
            ``get_approximate_basis``.
        M: optional mean, broadcast to ``A.size()`` before use.

    Returns:
        ``(U, S, V)`` with ``A - M ≈ U diag(S) V^H``.
    """
    rank = 6 if q is None else q
    m, n = A.shape[-2:]
    if M is not None:
        M = M.broadcast_to(A.size())

    # The algorithm is run on a tall matrix; for wide inputs work with the
    # conjugate transpose and swap the singular-vector factors at the end.
    wide = m < n
    if wide:
        A = A.mH
        if M is not None:
            M = M.mH

    Q = get_approximate_basis(A, rank, niter=niter, M=M)
    QH = Q.mH
    # Project (A - M) onto the basis: B = Q^H (A - M), computed without
    # forming A - M.
    B = _utils.matmul(QH, A)
    if M is not None:
        B = B - _utils.matmul(QH, M)

    U_small, S, Vh = torch.linalg.svd(B, full_matrices=False)
    U = Q.matmul(U_small)
    V = Vh.mH

    if wide:
        return V, S, U
    return U, S, V
182
183
def pca_lowrank(
    A: Tensor,
    q: Optional[int] = None,
    center: bool = True,
    niter: int = 2,
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Performs linear Principal Component Analysis (PCA) on a low-rank
    matrix, batches of such matrices, or sparse matrix.

    This function returns a tuple ``(U, S, V)`` which is the
    nearly optimal approximation of a singular value decomposition of
    a centered matrix :math:`A` such that :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`

    .. note:: The relation of ``(U, S, V)`` to PCA is as follows:

                - :math:`A` is a data matrix with ``m`` samples and
                  ``n`` features

                - the :math:`V` columns represent the principal directions

                - :math:`S ** 2 / (m - 1)` contains the eigenvalues of
                  :math:`A^H A / (m - 1)` which is the covariance of
                  ``A`` when ``center=True`` is provided.

                - ``matmul(A, V[:, :k])`` projects data to the first k
                  principal components

    .. note:: Different from the standard SVD, the size of returned
              matrices depend on the specified rank and q
              values as follows:

                - :math:`U` is m x q matrix

                - :math:`S` is q-vector

                - :math:`V` is n x q matrix

    .. note:: To obtain repeatable results, reset the seed for the
              pseudorandom number generator.

    Args:

        A (Tensor): the input tensor of size :math:`(*, m, n)`. Sparse
                    inputs must be 2-dimensional when ``center=True``.

        q (int, optional): a slightly overestimated rank of
                           :math:`A`. By default, ``q = min(6, m,
                           n)``. Must satisfy ``0 <= q <= min(m, n)``.

        center (bool, optional): if True, center the input tensor,
                                 otherwise, assume that the input is
                                 centered.

        niter (int, optional): the number of subspace iterations to
                               conduct; niter must be a nonnegative
                               integer, and defaults to 2.

    Raises:
        ValueError: if ``q`` is outside ``[0, min(m, n)]``, if ``niter``
                    is negative, or if a sparse ``A`` is not 2-dimensional
                    when centering is requested.

    References::

        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <https://arxiv.org/abs/0909.4061>`_).

    """

    if not torch.jit.is_scripting():
        if type(A) is not torch.Tensor and has_torch_function((A,)):
            return handle_torch_function(
                pca_lowrank, (A,), A, q=q, center=center, niter=niter
            )

    (m, n) = A.shape[-2:]

    if q is None:
        q = min(6, m, n)
    elif not (q >= 0 and q <= min(m, n)):
        raise ValueError(
            f"q(={q}) must be non-negative integer and not greater than min(m, n)={min(m, n)}"
        )
    if not (niter >= 0):
        raise ValueError(f"niter(={niter}) must be non-negative integer")

    if not center:
        return _svd_lowrank(A, q, niter=niter, M=None)

    if _utils.is_sparse(A):
        if len(A.shape) != 2:
            raise ValueError("pca_lowrank input is expected to be 2-dimensional tensor")
        # Same dtype rule as get_approximate_basis: complex inputs keep their
        # dtype, so the column means (and hence M) are not cast to a real
        # dtype that would drop their imaginary part and mismatch A.
        dtype = A.dtype if A.is_complex() else _utils.get_floating_dtype(A)
        # Column means of A, as a sparse 1-D tensor of length n.
        c = torch.sparse.sum(A, dim=(-2,)) / m
        # Reshape c into a sparse (n, 1) column vector C_t.
        column_indices = c.indices()[0]
        indices = torch.zeros(
            2,
            len(column_indices),
            dtype=column_indices.dtype,
            device=column_indices.device,
        )
        indices[0] = column_indices
        C_t = torch.sparse_coo_tensor(
            indices, c.values(), (n, 1), dtype=dtype, device=A.device
        )

        # (C_t @ ones(1, m))^T is (m, n) with every row equal to the column
        # means. It is passed as M so A itself is never densified.
        ones_1m = torch.ones((1, m), dtype=dtype, device=A.device)
        M = torch.sparse.mm(C_t, ones_1m).mT
        return _svd_lowrank(A, q, niter=niter, M=M)
    else:
        C = A.mean(dim=(-2,), keepdim=True)
        return _svd_lowrank(A - C, q, niter=niter, M=None)
295