# mypy: allow-untyped-defs
"""Locally Optimal Block Preconditioned Conjugate Gradient methods."""
# Author: Pearu Peterson
# Created: February 2020

from typing import Dict, Optional, Tuple

import torch
from torch import _linalg_utils as _utils, Tensor
from torch.overrides import handle_torch_function, has_torch_function


__all__ = ["lobpcg"]


16def _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U):
17    # compute F, such that F_ij = (d_j - d_i)^{-1} for i != j, F_ii = 0
18    F = D.unsqueeze(-2) - D.unsqueeze(-1)
19    F.diagonal(dim1=-2, dim2=-1).fill_(float("inf"))
20    F.pow_(-1)
21
22    # A.grad = U (D.grad + (U^T U.grad * F)) U^T
23    Ut = U.mT.contiguous()
24    res = torch.matmul(
25        U, torch.matmul(torch.diag_embed(D_grad) + torch.matmul(Ut, U_grad) * F, Ut)
26    )
27
28    return res
29
30
31def _polynomial_coefficients_given_roots(roots):
32    """
33    Given the `roots` of a polynomial, find the polynomial's coefficients.
34
35    If roots = (r_1, ..., r_n), then the method returns
36    coefficients (a_0, a_1, ..., a_n (== 1)) so that
37    p(x) = (x - r_1) * ... * (x - r_n)
38         = x^n + a_{n-1} * x^{n-1} + ... a_1 * x_1 + a_0
39
40    Note: for better performance requires writing a low-level kernel
41    """
42    poly_order = roots.shape[-1]
43    poly_coeffs_shape = list(roots.shape)
44    # we assume p(x) = x^n + a_{n-1} * x^{n-1} + ... + a_1 * x + a_0,
45    # so poly_coeffs = {a_0, ..., a_n, a_{n+1}(== 1)},
46    # but we insert one extra coefficient to enable better vectorization below
47    poly_coeffs_shape[-1] += 2
48    poly_coeffs = roots.new_zeros(poly_coeffs_shape)
49    poly_coeffs[..., 0] = 1
50    poly_coeffs[..., -1] = 1
51
    # apply Horner's rule
    for i in range(1, poly_order + 1):
        # note that it is computationally hard to compute backward for this method,
        # because given the coefficients it would require finding the roots and/or
        # calculating the sensitivity based on Vieta's theorem.
        # So the code below circumvents explicit root finding by a series
        # of operations on memory copies imitating Horner's method.
        # The memory copies are required to construct nodes in the computational graph
        # by exploiting the explicit (not in-place, separate node for each step)
        # recursion of Horner's method.
        # This needs more memory, O(... * k^2), but keeps the time complexity at O(... * k^2).
        poly_coeffs_new = poly_coeffs.clone() if roots.requires_grad else poly_coeffs
64        out = poly_coeffs_new.narrow(-1, poly_order - i, i + 1)
65        out -= roots.narrow(-1, i - 1, 1) * poly_coeffs.narrow(
66            -1, poly_order - i + 1, i + 1
67        )
68        poly_coeffs = poly_coeffs_new
69
70    return poly_coeffs.narrow(-1, 1, poly_order + 1)
71
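# A small illustrative check (not executed here; assumes float inputs): for
# roots (1, 2, 3) the polynomial is (x - 1)(x - 2)(x - 3) = x^3 - 6x^2 + 11x - 6,
# so one would expect
#
#   _polynomial_coefficients_given_roots(torch.tensor([1.0, 2.0, 3.0]))
#
# to return, up to floating-point error, the coefficients (-6., 11., -6., 1.)
# in the (a_0, ..., a_n) order described in the docstring above.
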
72
73def _polynomial_value(poly, x, zero_power, transition):
74    """
75    A generic method for computing poly(x) using the Horner's rule.
76
77    Args:
78      poly (Tensor): the (possibly batched) 1D Tensor representing
79                     polynomial coefficients such that
80                     poly[..., i] = (a_{i_0}, ..., a{i_n} (==1)), and
81                     poly(x) = poly[..., 0] * zero_power + ... + poly[..., n] * x^n
82
83      x (Tensor): the value (possible batched) to evalate the polynomial `poly` at.
84
85      zero_power (Tensor): the representation of `x^0`. It is application-specific.
86
87      transition (Callable): the function that accepts some intermediate result `int_val`,
88                             the `x` and a specific polynomial coefficient
89                             `poly[..., k]` for some iteration `k`.
90                             It basically performs one iteration of the Horner's rule
91                             defined as `x * int_val + poly[..., k] * zero_power`.
92                             Note that `zero_power` is not a parameter,
93                             because the step `+ poly[..., k] * zero_power` depends on `x`,
94                             whether it is a vector, a matrix, or something else, so this
95                             functionality is delegated to the user.
96    """
97
98    res = zero_power.clone()
99    for k in range(poly.size(-1) - 2, -1, -1):
100        res = transition(res, x, poly[..., k])
101    return res
102
103
104def _matrix_polynomial_value(poly, x, zero_power=None):
105    """
106    Evaluates `poly(x)` for the (batched) matrix input `x`.
107    Check out `_polynomial_value` function for more details.
108    """
109
110    # matrix-aware Horner's rule iteration
111    def transition(curr_poly_val, x, poly_coeff):
112        res = x.matmul(curr_poly_val)
113        res.diagonal(dim1=-2, dim2=-1).add_(poly_coeff.unsqueeze(-1))
114        return res
115
116    if zero_power is None:
117        zero_power = torch.eye(
118            x.size(-1), x.size(-1), dtype=x.dtype, device=x.device
119        ).view(*([1] * len(list(x.shape[:-2]))), x.size(-1), x.size(-1))
120
121    return _polynomial_value(poly, x, zero_power, transition)
122
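# A minimal sketch of the expected behavior (illustrative only, not executed here):
#
#   A = torch.tensor([[2.0, 0.0], [0.0, 3.0]])
#   poly = torch.tensor([1.0, 1.0, 1.0])   # p(A) = I + A + A^2, last coefficient == 1
#   _matrix_polynomial_value(poly, A)
#   # expected: the diagonal matrix diag(1 + 2 + 4, 1 + 3 + 9) = diag(7., 13.)
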
123
124def _vector_polynomial_value(poly, x, zero_power=None):
125    """
126    Evaluates `poly(x)` for the (batched) vector input `x`.
127    Check out `_polynomial_value` function for more details.
128    """
129
130    # vector-aware Horner's rule iteration
131    def transition(curr_poly_val, x, poly_coeff):
132        res = torch.addcmul(poly_coeff.unsqueeze(-1), x, curr_poly_val)
133        return res
134
135    if zero_power is None:
136        zero_power = x.new_ones(1).expand(x.shape)
137
138    return _polynomial_value(poly, x, zero_power, transition)
139
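# A minimal sketch of the expected behavior (illustrative only, not executed here):
#
#   x = torch.tensor([2.0])
#   poly = torch.tensor([1.0, 2.0, 1.0])   # p(x) = 1 + 2*x + x^2, last coefficient == 1
#   _vector_polynomial_value(poly, x)
#   # expected: tensor([9.]) since 1 + 2*2 + 2^2 = 9
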
140
141def _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest):
    # compute the projection operator onto the subspace orthogonal to the
    # columns of U, defined as (I - U U^T)
    Ut = U.mT.contiguous()
    proj_U_ortho = -U.matmul(Ut)
    proj_U_ortho.diagonal(dim1=-2, dim2=-1).add_(1)

    # compute U_ortho, a basis for the orthogonal complement of span(U),
    # by projecting a random [..., m, m - k] matrix onto the subspace
    # orthogonal to the columns of U.
151    #
152    # fix generator for determinism
153    gen = torch.Generator(A.device)
154
155    # orthogonal complement to the span(U)
156    U_ortho = proj_U_ortho.matmul(
157        torch.randn(
158            (*A.shape[:-1], A.size(-1) - D.size(-1)),
159            dtype=A.dtype,
160            device=A.device,
161            generator=gen,
162        )
163    )
164    U_ortho_t = U_ortho.mT.contiguous()
165
166    # compute the coefficients of the characteristic polynomial of the tensor D.
167    # Note that D is diagonal, so the diagonal elements are exactly the roots
168    # of the characteristic polynomial.
169    chr_poly_D = _polynomial_coefficients_given_roots(D)
170
    # the code below finds the explicit solution to the Sylvester equation
    # U_ortho^T A U_ortho dX - dX D = -U_ortho^T A U
    # and incorporates it into the whole gradient stored in the `res` variable.
    #
    # Equivalent to the following naive implementation:
    # res = A.new_zeros(A.shape)
    # p_res = A.new_zeros(*A.shape[:-1], D.size(-1))
    # for k in range(1, chr_poly_D.size(-1)):
    #     p_res.zero_()
    #     for i in range(0, k):
    #         p_res += (A.matrix_power(k - 1 - i) @ U_grad) * D.pow(i).unsqueeze(-2)
    #     res -= chr_poly_D[k] * (U_ortho @ chr_poly_D_at_A.inverse() @ U_ortho_t @ p_res @ U.t())
183    #
184    # Note that dX is a differential, so the gradient contribution comes from the backward sensitivity
185    # Tr(f(U_grad, D_grad, A, U, D)^T dX) = Tr(g(U_grad, A, U, D)^T dA) for some functions f and g,
186    # and we need to compute g(U_grad, A, U, D)
187    #
188    # The naive implementation is based on the paper
189    # Hu, Qingxi, and Daizhan Cheng.
190    # "The polynomial solution to the Sylvester matrix equation."
191    # Applied mathematics letters 19.9 (2006): 859-864.
192    #
    # We can compute `p_res` from above in a more efficient way:
    # p_res =   U_grad * (chr_poly_D[1] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k - 1)).unsqueeze(-2)
    #       + A U_grad * (chr_poly_D[2] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k - 2)).unsqueeze(-2)
    #       + ...
    #       + A.matrix_power(k - 1) U_grad * chr_poly_D[k]
    # Note that this saves us redundant matrix products with A (it eliminates matrix_power)
199    U_grad_projected = U_grad
200    series_acc = U_grad_projected.new_zeros(U_grad_projected.shape)
201    for k in range(1, chr_poly_D.size(-1)):
202        poly_D = _vector_polynomial_value(chr_poly_D[..., k:], D)
203        series_acc += U_grad_projected * poly_D.unsqueeze(-2)
204        U_grad_projected = A.matmul(U_grad_projected)
205
206    # compute chr_poly_D(A) which essentially is:
207    #
208    # chr_poly_D_at_A = A.new_zeros(A.shape)
209    # for k in range(chr_poly_D.size(-1)):
210    #     chr_poly_D_at_A += chr_poly_D[k] * A.matrix_power(k)
211    #
212    # Note, however, for better performance we use the Horner's rule
213    chr_poly_D_at_A = _matrix_polynomial_value(chr_poly_D, A)
214
215    # compute the action of `chr_poly_D_at_A` restricted to U_ortho_t
216    chr_poly_D_at_A_to_U_ortho = torch.matmul(
217        U_ortho_t, torch.matmul(chr_poly_D_at_A, U_ortho)
218    )
    # we need to invert `chr_poly_D_at_A_to_U_ortho`; for that we compute its
220    # Cholesky decomposition and then use `torch.cholesky_solve` for better stability.
221    # Cholesky decomposition requires the input to be positive-definite.
222    # Note that `chr_poly_D_at_A_to_U_ortho` is positive-definite if
223    # 1. `largest` == False, or
224    # 2. `largest` == True and `k` is even
225    # under the assumption that `A` has distinct eigenvalues.
226    #
    # check if `chr_poly_D_at_A_to_U_ortho` is positive-definite or negative-definite;
    # here `k` is the final value of the loop variable from the series above,
    # i.e. `k == D.size(-1)`, the number of computed eigenpairs
    chr_poly_D_at_A_to_U_ortho_sign = -1 if (largest and (k % 2 == 1)) else +1
229    chr_poly_D_at_A_to_U_ortho_L = torch.linalg.cholesky(
230        chr_poly_D_at_A_to_U_ortho_sign * chr_poly_D_at_A_to_U_ortho
231    )
232
233    # compute the gradient part in span(U)
234    res = _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U)
235
236    # incorporate the Sylvester equation solution into the full gradient
237    # it resides in span(U_ortho)
238    res -= U_ortho.matmul(
239        chr_poly_D_at_A_to_U_ortho_sign
240        * torch.cholesky_solve(
241            U_ortho_t.matmul(series_acc), chr_poly_D_at_A_to_U_ortho_L
242        )
243    ).matmul(Ut)
244
245    return res
246
247
248def _symeig_backward(D_grad, U_grad, A, D, U, largest):
    # if `U` is square, then the columns of `U` span a complete eigenspace
250    if U.size(-1) == U.size(-2):
251        return _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U)
252    else:
253        return _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest)
254
255
256class LOBPCGAutogradFunction(torch.autograd.Function):
257    @staticmethod
258    def forward(  # type: ignore[override]
259        ctx,
260        A: Tensor,
261        k: Optional[int] = None,
262        B: Optional[Tensor] = None,
263        X: Optional[Tensor] = None,
264        n: Optional[int] = None,
265        iK: Optional[Tensor] = None,
266        niter: Optional[int] = None,
267        tol: Optional[float] = None,
268        largest: Optional[bool] = None,
269        method: Optional[str] = None,
270        tracker: None = None,
271        ortho_iparams: Optional[Dict[str, int]] = None,
272        ortho_fparams: Optional[Dict[str, float]] = None,
273        ortho_bparams: Optional[Dict[str, bool]] = None,
274    ) -> Tuple[Tensor, Tensor]:
275        # makes sure that input is contiguous for efficiency.
276        # Note: autograd does not support dense gradients for sparse input yet.
277        A = A.contiguous() if (not A.is_sparse) else A
278        if B is not None:
279            B = B.contiguous() if (not B.is_sparse) else B
280
281        D, U = _lobpcg(
282            A,
283            k,
284            B,
285            X,
286            n,
287            iK,
288            niter,
289            tol,
290            largest,
291            method,
292            tracker,
293            ortho_iparams,
294            ortho_fparams,
295            ortho_bparams,
296        )
297
298        ctx.save_for_backward(A, B, D, U)
299        ctx.largest = largest
300
301        return D, U
302
303    @staticmethod
304    def backward(ctx, D_grad, U_grad):
305        A_grad = B_grad = None
306        grads = [None] * 14
307
308        A, B, D, U = ctx.saved_tensors
309        largest = ctx.largest
310
311        # lobpcg.backward has some limitations. Checks for unsupported input
312        if A.is_sparse or (B is not None and B.is_sparse and ctx.needs_input_grad[2]):
313            raise ValueError(
                "lobpcg.backward does not support sparse input yet. "
315                "Note that lobpcg.forward does though."
316            )
317        if (
318            A.dtype in (torch.complex64, torch.complex128)
319            or B is not None
320            and B.dtype in (torch.complex64, torch.complex128)
321        ):
322            raise ValueError(
                "lobpcg.backward does not support complex input yet. "
324                "Note that lobpcg.forward does though."
325            )
326        if B is not None:
327            raise ValueError(
328                "lobpcg.backward does not support backward with B != I yet."
329            )
330
331        if largest is None:
332            largest = True
333
334        # symeig backward
335        if B is None:
336            A_grad = _symeig_backward(D_grad, U_grad, A, D, U, largest)
337
338        # A has index 0
339        grads[0] = A_grad
340        # B has index 2
341        grads[2] = B_grad
342        return tuple(grads)
343
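# A minimal sketch of differentiating through lobpcg (illustrative only; the
# matrix below is just an example of a symmetric positive-definite input):
#
#   A = torch.randn(9, 9, dtype=torch.float64)
#   A = A @ A.mT + torch.eye(9, dtype=torch.float64)
#   A.requires_grad_(True)
#   E, X = torch.lobpcg(A, k=1)
#   E.sum().backward()   # A.grad now holds the (symmetrized) gradient w.r.t. A
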
344
345def lobpcg(
346    A: Tensor,
347    k: Optional[int] = None,
348    B: Optional[Tensor] = None,
349    X: Optional[Tensor] = None,
350    n: Optional[int] = None,
351    iK: Optional[Tensor] = None,
352    niter: Optional[int] = None,
353    tol: Optional[float] = None,
354    largest: Optional[bool] = None,
355    method: Optional[str] = None,
356    tracker: None = None,
357    ortho_iparams: Optional[Dict[str, int]] = None,
358    ortho_fparams: Optional[Dict[str, float]] = None,
359    ortho_bparams: Optional[Dict[str, bool]] = None,
360) -> Tuple[Tensor, Tensor]:
361    """Find the k largest (or smallest) eigenvalues and the corresponding
362    eigenvectors of a symmetric positive definite generalized
363    eigenvalue problem using matrix-free LOBPCG methods.
364
365    This function is a front-end to the following LOBPCG algorithms
366    selectable via `method` argument:
367
      `method="basic"` - the LOBPCG method introduced by Andrew
      Knyazev, see [Knyazev2001]. A less robust method; it may fail when
      Cholesky is applied to singular input.
371
372      `method="ortho"` - the LOBPCG method with orthogonal basis
373      selection [StathopoulosEtal2002]. A robust method.
374
375    Supported inputs are dense, sparse, and batches of dense matrices.
376
    .. note:: In general, the basic method spends the least time per
      iteration. However, the robust method converges much faster and
      is more stable. So, using the basic method is generally not
      recommended, but there exist cases where it may be preferred.
382
383    .. warning:: The backward method does not support sparse and complex inputs.
384      It works only when `B` is not provided (i.e. `B == None`).
385      We are actively working on extensions, and the details of
386      the algorithms are going to be published promptly.
387
388    .. warning:: While it is assumed that `A` is symmetric, `A.grad` is not.
389      To make sure that `A.grad` is symmetric, so that `A - t * A.grad` is symmetric
390      in first-order optimization routines, prior to running `lobpcg`
391      we do the following symmetrization map: `A -> (A + A.t()) / 2`.
      The map is performed only when `A` requires gradients.
393
394    Args:
395
396      A (Tensor): the input tensor of size :math:`(*, m, m)`
397
398      B (Tensor, optional): the input tensor of size :math:`(*, m,
399                  m)`. When not specified, `B` is interpreted as
400                  identity matrix.
401
      X (tensor, optional): the input tensor of size :math:`(*, m, n)`
                  where `k <= n <= m`. When specified, it is used as
                  an initial approximation of the eigenvectors. X must be a
                  dense tensor.
406
      iK (tensor, optional): the input tensor of size :math:`(*, m,
                  m)`. When specified, it will be used as a preconditioner.
409
410      k (integer, optional): the number of requested
411                  eigenpairs. Default is the number of :math:`X`
412                  columns (when specified) or `1`.
413
414      n (integer, optional): if :math:`X` is not specified then `n`
415                  specifies the size of the generated random
416                  approximation of eigenvectors. Default value for `n`
417                  is `k`. If :math:`X` is specified, the value of `n`
418                  (when specified) must be the number of :math:`X`
419                  columns.
420
      tol (float, optional): residual tolerance for the stopping
                 criterion. Default is `feps ** 0.5` where `feps` is
                 the machine epsilon of the data type of the given
                 input tensor `A`.
425
426      largest (bool, optional): when True, solve the eigenproblem for
427                 the largest eigenvalues. Otherwise, solve the
428                 eigenproblem for smallest eigenvalues. Default is
429                 `True`.
430
431      method (str, optional): select LOBPCG method. See the
432                 description of the function above. Default is
433                 "ortho".
434
      niter (int, optional): maximum number of iterations. When
                 reached, the iteration process is hard-stopped and
                 the current approximation of eigenpairs is returned.
                 To iterate until the convergence criterion is met,
                 use `-1`.
440
441      tracker (callable, optional) : a function for tracing the
442                 iteration process. When specified, it is called at
443                 each iteration step with LOBPCG instance as an
444                 argument. The LOBPCG instance holds the full state of
445                 the iteration process in the following attributes:
446
447                   `iparams`, `fparams`, `bparams` - dictionaries of
448                   integer, float, and boolean valued input
449                   parameters, respectively
450
451                   `ivars`, `fvars`, `bvars`, `tvars` - dictionaries
452                   of integer, float, boolean, and Tensor valued
453                   iteration variables, respectively.
454
455                   `A`, `B`, `iK` - input Tensor arguments.
456
457                   `E`, `X`, `S`, `R` - iteration Tensor variables.
458
459                 For instance:
460
461                   `ivars["istep"]` - the current iteration step
462                   `X` - the current approximation of eigenvectors
463                   `E` - the current approximation of eigenvalues
464                   `R` - the current residual
465                   `ivars["converged_count"]` - the current number of converged eigenpairs
466                   `tvars["rerr"]` - the current state of convergence criteria
467
468                 Note that when `tracker` stores Tensor objects from
469                 the LOBPCG instance, it must make copies of these.
470
471                 If `tracker` sets `bvars["force_stop"] = True`, the
472                 iteration process will be hard-stopped.
473
474      ortho_iparams, ortho_fparams, ortho_bparams (dict, optional):
475                 various parameters to LOBPCG algorithm when using
476                 `method="ortho"`.
477
478    Returns:
479
480      E (Tensor): tensor of eigenvalues of size :math:`(*, k)`
481
482      X (Tensor): tensor of eigenvectors of size :math:`(*, m, k)`
483
484    References:
485
486      [Knyazev2001] Andrew V. Knyazev. (2001) Toward the Optimal
487      Preconditioned Eigensolver: Locally Optimal Block Preconditioned
488      Conjugate Gradient Method. SIAM J. Sci. Comput., 23(2),
489      517-541. (25 pages)
490      https://epubs.siam.org/doi/abs/10.1137/S1064827500366124
491
492      [StathopoulosEtal2002] Andreas Stathopoulos and Kesheng
493      Wu. (2002) A Block Orthogonalization Procedure with Constant
494      Synchronization Requirements. SIAM J. Sci. Comput., 23(6),
495      2165-2182. (18 pages)
496      https://epubs.siam.org/doi/10.1137/S1064827500370883
497
498      [DuerschEtal2018] Jed A. Duersch, Meiyue Shao, Chao Yang, Ming
499      Gu. (2018) A Robust and Efficient Implementation of LOBPCG.
500      SIAM J. Sci. Comput., 40(5), C655-C676. (22 pages)
501      https://epubs.siam.org/doi/abs/10.1137/17M1129830
502
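    A minimal usage sketch (illustrative only; the random matrix below is just
    an example of a symmetric positive-definite input)::

        >>> A = torch.randn(50, 50, dtype=torch.float64)
        >>> A = A @ A.mT + torch.eye(50, dtype=torch.float64)  # symmetric positive-definite
        >>> E, X = torch.lobpcg(A, k=3, largest=True)
        >>> # E has shape (3,): the 3 largest eigenvalues
        >>> # X has shape (50, 3): the corresponding eigenvectors

    A sketch of a `tracker` callback that records the residual error at each
    iteration step (illustrative only)::

        >>> errors = []
        >>> def tracker(worker):
        ...     errors.append(worker.tvars["rerr"].clone())
        >>> E, X = torch.lobpcg(A, k=3, tracker=tracker)
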
503    """
504
505    if not torch.jit.is_scripting():
506        tensor_ops = (A, B, X, iK)
507        if not set(map(type, tensor_ops)).issubset(
508            (torch.Tensor, type(None))
509        ) and has_torch_function(tensor_ops):
510            return handle_torch_function(
511                lobpcg,
512                tensor_ops,
513                A,
514                k=k,
515                B=B,
516                X=X,
517                n=n,
518                iK=iK,
519                niter=niter,
520                tol=tol,
521                largest=largest,
522                method=method,
523                tracker=tracker,
524                ortho_iparams=ortho_iparams,
525                ortho_fparams=ortho_fparams,
526                ortho_bparams=ortho_bparams,
527            )
528
529    if not torch._jit_internal.is_scripting():
530        if A.requires_grad or (B is not None and B.requires_grad):
            # While it is expected that `A` is symmetric,
            # `A_grad` might not be. Therefore we perform the trick below,
            # so that `A_grad` becomes symmetric.
            # The symmetrization is important for first-order optimization methods,
            # so that (A - alpha * A_grad) is still a symmetric matrix.
            # The same holds for `B`.
537            A_sym = (A + A.mT) / 2
538            B_sym = (B + B.mT) / 2 if (B is not None) else None
539
540            return LOBPCGAutogradFunction.apply(
541                A_sym,
542                k,
543                B_sym,
544                X,
545                n,
546                iK,
547                niter,
548                tol,
549                largest,
550                method,
551                tracker,
552                ortho_iparams,
553                ortho_fparams,
554                ortho_bparams,
555            )
556    else:
557        if A.requires_grad or (B is not None and B.requires_grad):
            raise RuntimeError(
                "Scripting with gradients required is not supported at the moment. "
                "If you just want to do the forward, use .detach() "
                "on A and B before calling into lobpcg"
            )
563
564    return _lobpcg(
565        A,
566        k,
567        B,
568        X,
569        n,
570        iK,
571        niter,
572        tol,
573        largest,
574        method,
575        tracker,
576        ortho_iparams,
577        ortho_fparams,
578        ortho_bparams,
579    )
580
581
582def _lobpcg(
583    A: Tensor,
584    k: Optional[int] = None,
585    B: Optional[Tensor] = None,
586    X: Optional[Tensor] = None,
587    n: Optional[int] = None,
588    iK: Optional[Tensor] = None,
589    niter: Optional[int] = None,
590    tol: Optional[float] = None,
591    largest: Optional[bool] = None,
592    method: Optional[str] = None,
593    tracker: None = None,
594    ortho_iparams: Optional[Dict[str, int]] = None,
595    ortho_fparams: Optional[Dict[str, float]] = None,
596    ortho_bparams: Optional[Dict[str, bool]] = None,
597) -> Tuple[Tensor, Tensor]:
598    # A must be square:
599    assert A.shape[-2] == A.shape[-1], A.shape
600    if B is not None:
601        # A and B must have the same shapes:
602        assert A.shape == B.shape, (A.shape, B.shape)
603
604    dtype = _utils.get_floating_dtype(A)
605    device = A.device
606    if tol is None:
607        feps = {torch.float32: 1.2e-07, torch.float64: 2.23e-16}[dtype]
608        tol = feps**0.5
609
610    m = A.shape[-1]
611    k = (1 if X is None else X.shape[-1]) if k is None else k
612    n = (k if n is None else n) if X is None else X.shape[-1]
613
614    if m < 3 * n:
        raise ValueError(
            f"LOBPCG algorithm is not applicable when the number of A rows (={m})"
            f" is smaller than 3 x the number of requested eigenpairs (={n})"
        )
619
620    method = "ortho" if method is None else method
621
622    iparams = {
623        "m": m,
624        "n": n,
625        "k": k,
626        "niter": 1000 if niter is None else niter,
627    }
628
629    fparams = {
630        "tol": tol,
631    }
632
633    bparams = {"largest": True if largest is None else largest}
634
635    if method == "ortho":
636        if ortho_iparams is not None:
637            iparams.update(ortho_iparams)
638        if ortho_fparams is not None:
639            fparams.update(ortho_fparams)
640        if ortho_bparams is not None:
641            bparams.update(ortho_bparams)
642        iparams["ortho_i_max"] = iparams.get("ortho_i_max", 3)
643        iparams["ortho_j_max"] = iparams.get("ortho_j_max", 3)
644        fparams["ortho_tol"] = fparams.get("ortho_tol", tol)
645        fparams["ortho_tol_drop"] = fparams.get("ortho_tol_drop", tol)
646        fparams["ortho_tol_replace"] = fparams.get("ortho_tol_replace", tol)
647        bparams["ortho_use_drop"] = bparams.get("ortho_use_drop", False)
648
649    if not torch.jit.is_scripting():
650        LOBPCG.call_tracker = LOBPCG_call_tracker  # type: ignore[method-assign]
651
652    if len(A.shape) > 2:
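        # Batched input: flatten all batch dimensions into a single one and
        # solve each (m, m) problem independently; the results are reshaped
        # back to the original batch shape below.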
653        N = int(torch.prod(torch.tensor(A.shape[:-2])))
654        bA = A.reshape((N,) + A.shape[-2:])
655        bB = B.reshape((N,) + A.shape[-2:]) if B is not None else None
656        bX = X.reshape((N,) + X.shape[-2:]) if X is not None else None
657        bE = torch.empty((N, k), dtype=dtype, device=device)
658        bXret = torch.empty((N, m, k), dtype=dtype, device=device)
659
660        for i in range(N):
661            A_ = bA[i]
662            B_ = bB[i] if bB is not None else None
663            X_ = (
664                torch.randn((m, n), dtype=dtype, device=device) if bX is None else bX[i]
665            )
666            assert len(X_.shape) == 2 and X_.shape == (m, n), (X_.shape, (m, n))
667            iparams["batch_index"] = i
668            worker = LOBPCG(A_, B_, X_, iK, iparams, fparams, bparams, method, tracker)
669            worker.run()
670            bE[i] = worker.E[:k]
671            bXret[i] = worker.X[:, :k]
672
673        if not torch.jit.is_scripting():
674            LOBPCG.call_tracker = LOBPCG_call_tracker_orig  # type: ignore[method-assign]
675
676        return bE.reshape(A.shape[:-2] + (k,)), bXret.reshape(A.shape[:-2] + (m, k))
677
678    X = torch.randn((m, n), dtype=dtype, device=device) if X is None else X
679    assert len(X.shape) == 2 and X.shape == (m, n), (X.shape, (m, n))
680
681    worker = LOBPCG(A, B, X, iK, iparams, fparams, bparams, method, tracker)
682
683    worker.run()
684
685    if not torch.jit.is_scripting():
686        LOBPCG.call_tracker = LOBPCG_call_tracker_orig  # type: ignore[method-assign]
687
688    return worker.E[:k], worker.X[:, :k]
689
690
691class LOBPCG:
692    """Worker class of LOBPCG methods."""
693
694    def __init__(
695        self,
696        A: Optional[Tensor],
697        B: Optional[Tensor],
698        X: Tensor,
699        iK: Optional[Tensor],
700        iparams: Dict[str, int],
701        fparams: Dict[str, float],
702        bparams: Dict[str, bool],
703        method: str,
704        tracker: None,
705    ) -> None:
706        # constant parameters
707        self.A = A
708        self.B = B
709        self.iK = iK
710        self.iparams = iparams
711        self.fparams = fparams
712        self.bparams = bparams
713        self.method = method
714        self.tracker = tracker
715        m = iparams["m"]
716        n = iparams["n"]
717
718        # variable parameters
719        self.X = X
720        self.E = torch.zeros((n,), dtype=X.dtype, device=X.device)
721        self.R = torch.zeros((m, n), dtype=X.dtype, device=X.device)
722        self.S = torch.zeros((m, 3 * n), dtype=X.dtype, device=X.device)
723        self.tvars: Dict[str, Tensor] = {}
724        self.ivars: Dict[str, int] = {"istep": 0}
725        self.fvars: Dict[str, float] = {"_": 0.0}
726        self.bvars: Dict[str, bool] = {"_": False}
727
728    def __str__(self):
        lines = ["LOBPCG:"]
730        lines += [f"  iparams={self.iparams}"]
731        lines += [f"  fparams={self.fparams}"]
732        lines += [f"  bparams={self.bparams}"]
733        lines += [f"  ivars={self.ivars}"]
734        lines += [f"  fvars={self.fvars}"]
735        lines += [f"  bvars={self.bvars}"]
736        lines += [f"  tvars={self.tvars}"]
737        lines += [f"  A={self.A}"]
738        lines += [f"  B={self.B}"]
739        lines += [f"  iK={self.iK}"]
740        lines += [f"  X={self.X}"]
741        lines += [f"  E={self.E}"]
742        r = ""
743        for line in lines:
744            r += line + "\n"
745        return r
746
747    def update(self):
748        """Set and update iteration variables."""
749        if self.ivars["istep"] == 0:
750            X_norm = float(torch.norm(self.X))
751            iX_norm = X_norm**-1
752            A_norm = float(torch.norm(_utils.matmul(self.A, self.X))) * iX_norm
753            B_norm = float(torch.norm(_utils.matmul(self.B, self.X))) * iX_norm
754            self.fvars["X_norm"] = X_norm
755            self.fvars["A_norm"] = A_norm
756            self.fvars["B_norm"] = B_norm
757            self.ivars["iterations_left"] = self.iparams["niter"]
758            self.ivars["converged_count"] = 0
759            self.ivars["converged_end"] = 0
760
761        if self.method == "ortho":
762            self._update_ortho()
763        else:
764            self._update_basic()
765
766        self.ivars["iterations_left"] = self.ivars["iterations_left"] - 1
767        self.ivars["istep"] = self.ivars["istep"] + 1
768
769    def update_residual(self):
770        """Update residual R from A, B, X, E."""
771        mm = _utils.matmul
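        # E (shape (n,)) broadcasts over the last dimension below, so column j
        # of R is A @ X[:, j] - E[j] * (B @ X[:, j])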
772        self.R = mm(self.A, self.X) - mm(self.B, self.X) * self.E
773
774    def update_converged_count(self):
        """Determine the number of converged eigenpairs using the backward-stable
        convergence criterion, see the discussion in Sec 4.3 of [DuerschEtal2018].
777
778        Users may redefine this method for custom convergence criteria.
779        """
780        # (...) -> int
781        prev_count = self.ivars["converged_count"]
782        tol = self.fparams["tol"]
783        A_norm = self.fvars["A_norm"]
784        B_norm = self.fvars["B_norm"]
785        E, X, R = self.E, self.X, self.R
786        rerr = (
787            torch.norm(R, 2, (0,))
788            * (torch.norm(X, 2, (0,)) * (A_norm + E[: X.shape[-1]] * B_norm)) ** -1
789        )
790        converged = rerr.real < tol  # this is a norm so imag is 0.0
791        count = 0
792        for b in converged:
793            if not b:
794                # ignore convergence of following pairs to ensure
795                # strict ordering of eigenpairs
796                break
797            count += 1
798        assert (
799            count >= prev_count
800        ), f"the number of converged eigenpairs (was {prev_count}, got {count}) cannot decrease"
801        self.ivars["converged_count"] = count
802        self.tvars["rerr"] = rerr
803        return count
804
805    def stop_iteration(self):
806        """Return True to stop iterations.
807
808        Note that tracker (if defined) can force-stop iterations by
809        setting ``worker.bvars['force_stop'] = True``.
810        """
811        return (
812            self.bvars.get("force_stop", False)
813            or self.ivars["iterations_left"] == 0
814            or self.ivars["converged_count"] >= self.iparams["k"]
815        )
816
817    def run(self):
818        """Run LOBPCG iterations.
819
        Use this method as a template for implementing the LOBPCG
        iteration scheme with a custom tracker that is compatible with
        TorchScript.
823        """
824        self.update()
825
826        if not torch.jit.is_scripting() and self.tracker is not None:
827            self.call_tracker()
828
829        while not self.stop_iteration():
830            self.update()
831
832            if not torch.jit.is_scripting() and self.tracker is not None:
833                self.call_tracker()
834
835    @torch.jit.unused
836    def call_tracker(self):
837        """Interface for tracking iteration process in Python mode.
838
839        Tracking the iteration process is disabled in TorchScript
840        mode. In fact, one should specify tracker=None when JIT
841        compiling functions using lobpcg.
842        """
843        # do nothing when in TorchScript mode
844
845    # Internal methods
846
847    def _update_basic(self):
848        """
849        Update or initialize iteration variables when `method == "basic"`.
850        """
851        mm = torch.matmul
852        ns = self.ivars["converged_end"]
853        nc = self.ivars["converged_count"]
854        n = self.iparams["n"]
855        largest = self.bparams["largest"]
856
857        if self.ivars["istep"] == 0:
858            Ri = self._get_rayleigh_ritz_transform(self.X)
859            M = _utils.qform(_utils.qform(self.A, self.X), Ri)
860            E, Z = _utils.symeig(M, largest)
861            self.X[:] = mm(self.X, mm(Ri, Z))
862            self.E[:] = E
863            np = 0
864            self.update_residual()
865            nc = self.update_converged_count()
866            self.S[..., :n] = self.X
867
868            W = _utils.matmul(self.iK, self.R)
869            self.ivars["converged_end"] = ns = n + np + W.shape[-1]
870            self.S[:, n + np : ns] = W
871        else:
872            S_ = self.S[:, nc:ns]
873            Ri = self._get_rayleigh_ritz_transform(S_)
874            M = _utils.qform(_utils.qform(self.A, S_), Ri)
875            E_, Z = _utils.symeig(M, largest)
876            self.X[:, nc:] = mm(S_, mm(Ri, Z[:, : n - nc]))
877            self.E[nc:] = E_[: n - nc]
878            P = mm(S_, mm(Ri, Z[:, n : 2 * n - nc]))
879            np = P.shape[-1]
880
881            self.update_residual()
882            nc = self.update_converged_count()
883            self.S[..., :n] = self.X
884            self.S[:, n : n + np] = P
885            W = _utils.matmul(self.iK, self.R[:, nc:])
886
887            self.ivars["converged_end"] = ns = n + np + W.shape[-1]
888            self.S[:, n + np : ns] = W
889
890    def _update_ortho(self):
891        """
892        Update or initialize iteration variables when `method == "ortho"`.
893        """
894        mm = torch.matmul
895        ns = self.ivars["converged_end"]
896        nc = self.ivars["converged_count"]
897        n = self.iparams["n"]
898        largest = self.bparams["largest"]
899
900        if self.ivars["istep"] == 0:
901            Ri = self._get_rayleigh_ritz_transform(self.X)
902            M = _utils.qform(_utils.qform(self.A, self.X), Ri)
903            E, Z = _utils.symeig(M, largest)
904            self.X = mm(self.X, mm(Ri, Z))
905            self.update_residual()
906            np = 0
907            nc = self.update_converged_count()
908            self.S[:, :n] = self.X
909            W = self._get_ortho(self.R, self.X)
910            ns = self.ivars["converged_end"] = n + np + W.shape[-1]
911            self.S[:, n + np : ns] = W
912
913        else:
914            S_ = self.S[:, nc:ns]
915            # Rayleigh-Ritz procedure
916            E_, Z = _utils.symeig(_utils.qform(self.A, S_), largest)
917
918            # Update E, X, P
919            self.X[:, nc:] = mm(S_, Z[:, : n - nc])
920            self.E[nc:] = E_[: n - nc]
921            P = mm(S_, mm(Z[:, n - nc :], _utils.basis(Z[: n - nc, n - nc :].mT)))
922            np = P.shape[-1]
923
924            # check convergence
925            self.update_residual()
926            nc = self.update_converged_count()
927
928            # update S
929            self.S[:, :n] = self.X
930            self.S[:, n : n + np] = P
931            W = self._get_ortho(self.R[:, nc:], self.S[:, : n + np])
932            ns = self.ivars["converged_end"] = n + np + W.shape[-1]
933            self.S[:, n + np : ns] = W
934
935    def _get_rayleigh_ritz_transform(self, S):
        """Return a transformation matrix that is used in the Rayleigh-Ritz
        procedure for reducing a general eigenvalue problem :math:`(S^T A S)
        C = (S^T B S) C E` to a standard eigenvalue problem :math:`(Ri^T
        S^T A S Ri) Z = Z E` where `C = Ri Z`.
940
        .. note:: In the original Rayleigh-Ritz procedure in
          [DuerschEtal2018], the problem is formulated as follows::
943
944            SAS = S^T A S
945            SBS = S^T B S
946            D = (<diagonal matrix of SBS>) ** -1/2
947            R^T R = Cholesky(D SBS D)
948            Ri = D R^-1
949            solve symeig problem Ri^T SAS Ri Z = Theta Z
950            C = Ri Z
951
          To reduce the number of matrix products (denoted by empty
          space between matrices), here we introduce element-wise
          products (denoted by the symbol `*`) so that the Rayleigh-Ritz
          procedure becomes::
956
957            SAS = S^T A S
958            SBS = S^T B S
959            d = (<diagonal of SBS>) ** -1/2    # this is 1-d column vector
960            dd = d d^T                         # this is 2-d matrix
961            R^T R = Cholesky(dd * SBS)
962            Ri = R^-1 * d                      # broadcasting
963            solve symeig problem Ri^T SAS Ri Z = Theta Z
964            C = Ri Z
965
          where `dd` is a 2-d matrix that replaces the matrix product `D M
          D` with one element-wise product `M * dd`; and `d` replaces the
          matrix product `D M` with the element-wise product `M *
          d`. Also, creating the diagonal matrix `D` is avoided.
970
971        Args:
972        S (Tensor): the matrix basis for the search subspace, size is
973                    :math:`(m, n)`.
974
975        Returns:
976        Ri (tensor): upper-triangular transformation matrix of size
977                     :math:`(n, n)`.
978
979        """
980        B = self.B
981        mm = torch.matmul
982        SBS = _utils.qform(B, S)
983        d_row = SBS.diagonal(0, -2, -1) ** -0.5
984        d_col = d_row.reshape(d_row.shape[0], 1)
985        # TODO use torch.linalg.cholesky_solve once it is implemented
986        R = torch.linalg.cholesky((SBS * d_row) * d_col, upper=True)
987        return torch.linalg.solve_triangular(
988            R, d_row.diag_embed(), upper=True, left=False
989        )
990
991    def _get_svqb(self, U: Tensor, drop: bool, tau: float) -> Tensor:
992        """Return B-orthonormal U.
993
        .. note:: When `drop` is `False` then `svqb` is based on
                  Algorithm 4 from [DuerschPhD2015] that is a slight
                  modification of the corresponding algorithm
                  introduced in [StathopoulosEtal2002].
998
999        Args:
1000
          U (Tensor) : initial approximation, size is (m, n)
          drop (bool) : when True, drop columns whose
                     contribution to the `span([U])` is small.
          tau (float) : positive tolerance
1005
1006        Returns:
1007
          U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`), size
                       is (m, n1), where `n1 = n` if `drop` is `False`,
                       otherwise `n1 <= n`.
1011
1012        """
1013        if torch.numel(U) == 0:
1014            return U
1015        UBU = _utils.qform(self.B, U)
1016        d = UBU.diagonal(0, -2, -1)
1017
1018        # Detect and drop exact zero columns from U. While the test
1019        # `abs(d) == 0` is unlikely to be True for random data, it is
1020        # possible to construct input data to lobpcg where it will be
1021        # True leading to a failure (notice the `d ** -0.5` operation
1022        # in the original algorithm). To prevent the failure, we drop
1023        # the exact zero columns here and then continue with the
1024        # original algorithm below.
1025        nz = torch.where(abs(d) != 0.0)
1026        assert len(nz) == 1, nz
1027        if len(nz[0]) < len(d):
1028            U = U[:, nz[0]]
1029            if torch.numel(U) == 0:
1030                return U
1031            UBU = _utils.qform(self.B, U)
1032            d = UBU.diagonal(0, -2, -1)
1033            nz = torch.where(abs(d) != 0.0)
1034            assert len(nz[0]) == len(d)
1035
1036        # The original algorithm 4 from [DuerschPhD2015].
1037        d_col = (d**-0.5).reshape(d.shape[0], 1)
1038        DUBUD = (UBU * d_col) * d_col.mT
1039        E, Z = _utils.symeig(DUBUD)
1040        t = tau * abs(E).max()
1041        if drop:
1042            keep = torch.where(E > t)
1043            assert len(keep) == 1, keep
1044            E = E[keep[0]]
1045            Z = Z[:, keep[0]]
1046            d_col = d_col[keep[0]]
1047        else:
1048            E[(torch.where(E < t))[0]] = t
1049
1050        return torch.matmul(U * d_col.mT, Z * E**-0.5)
1051
1052    def _get_ortho(self, U, V):
        """Return B-orthonormal U with columns that are B-orthogonal to V.
1054
        .. note:: When `bparams["ortho_use_drop"] == False` then
                  `_get_ortho` is based on Algorithm 3 from
                  [DuerschPhD2015] that is a slight modification of
                  the corresponding algorithm introduced in
                  [StathopoulosEtal2002]. Otherwise, the method
                  implements Algorithm 6 from [DuerschPhD2015].
1061
1062        .. note:: If all U columns are B-collinear to V then the
1063                  returned tensor U will be empty.
1064
1065        Args:
1066
1067          U (Tensor) : initial approximation, size is (m, n)
1068          V (Tensor) : B-orthogonal external basis, size is (m, k)
1069
1070        Returns:
1071
          U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`)
                       such that :math:`V^T B U = 0`, size is (m, n1),
                       where `n1 = n` if `drop` is `False`, otherwise
                       `n1 <= n`.
1076        """
1077        mm = torch.matmul
1078        mm_B = _utils.matmul
1079        m = self.iparams["m"]
1080        tau_ortho = self.fparams["ortho_tol"]
1081        tau_drop = self.fparams["ortho_tol_drop"]
1082        tau_replace = self.fparams["ortho_tol_replace"]
1083        i_max = self.iparams["ortho_i_max"]
1084        j_max = self.iparams["ortho_j_max"]
1085        # when use_drop==True, enable dropping U columns that have
1086        # small contribution to the `span([U, V])`.
1087        use_drop = self.bparams["ortho_use_drop"]
1088
1089        # clean up variables from the previous call
1090        for vkey in list(self.fvars.keys()):
1091            if vkey.startswith("ortho_") and vkey.endswith("_rerr"):
1092                self.fvars.pop(vkey)
1093        self.ivars.pop("ortho_i", 0)
1094        self.ivars.pop("ortho_j", 0)
1095
1096        BV_norm = torch.norm(mm_B(self.B, V))
1097        BU = mm_B(self.B, U)
1098        VBU = mm(V.mT, BU)
1099        i = j = 0
1100        stats = ""
1101        for i in range(i_max):
1102            U = U - mm(V, VBU)
1103            drop = False
1104            tau_svqb = tau_drop
1105            for j in range(j_max):
1106                if use_drop:
1107                    U = self._get_svqb(U, drop, tau_svqb)
1108                    drop = True
1109                    tau_svqb = tau_replace
1110                else:
1111                    U = self._get_svqb(U, False, tau_replace)
1112                if torch.numel(U) == 0:
1113                    # all initial U columns are B-collinear to V
1114                    self.ivars["ortho_i"] = i
1115                    self.ivars["ortho_j"] = j
1116                    return U
1117                BU = mm_B(self.B, U)
1118                UBU = mm(U.mT, BU)
1119                U_norm = torch.norm(U)
1120                BU_norm = torch.norm(BU)
1121                R = UBU - torch.eye(UBU.shape[-1], device=UBU.device, dtype=UBU.dtype)
1122                R_norm = torch.norm(R)
1123                # https://github.com/pytorch/pytorch/issues/33810 workaround:
1124                rerr = float(R_norm) * float(BU_norm * U_norm) ** -1
1125                vkey = f"ortho_UBUmI_rerr[{i}, {j}]"
1126                self.fvars[vkey] = rerr
1127                if rerr < tau_ortho:
1128                    break
1129            VBU = mm(V.mT, BU)
1130            VBU_norm = torch.norm(VBU)
1131            U_norm = torch.norm(U)
1132            rerr = float(VBU_norm) * float(BV_norm * U_norm) ** -1
1133            vkey = f"ortho_VBU_rerr[{i}]"
1134            self.fvars[vkey] = rerr
1135            if rerr < tau_ortho:
1136                break
1137            if m < U.shape[-1] + V.shape[-1]:
1138                # TorchScript needs the class var to be assigned to a local to
1139                # do optional type refinement
1140                B = self.B
1141                assert B is not None
1142                raise ValueError(
1143                    "Overdetermined shape of U:"
1144                    f" #B-cols(={B.shape[-1]}) >= #U-cols(={U.shape[-1]}) + #V-cols(={V.shape[-1]}) must hold"
1145                )
1146        self.ivars["ortho_i"] = i
1147        self.ivars["ortho_j"] = j
1148        return U
1149
1150
1151# Calling tracker is separated from LOBPCG definitions because
1152# TorchScript does not support user-defined callback arguments:
1153LOBPCG_call_tracker_orig = LOBPCG.call_tracker
1154
1155
1156def LOBPCG_call_tracker(self):
1157    self.tracker(self)
1158