# mypy: allow-untyped-defs
"""Locally Optimal Block Preconditioned Conjugate Gradient methods."""
# Author: Pearu Peterson
# Created: February 2020

from typing import Dict, Optional, Tuple

import torch
from torch import _linalg_utils as _utils, Tensor
from torch.overrides import handle_torch_function, has_torch_function


__all__ = ["lobpcg"]


def _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U):
    # compute F, such that F_ij = (d_j - d_i)^{-1} for i != j, F_ii = 0
    F = D.unsqueeze(-2) - D.unsqueeze(-1)
    F.diagonal(dim1=-2, dim2=-1).fill_(float("inf"))
    F.pow_(-1)

    # A.grad = U (D.grad + (U^T U.grad * F)) U^T
    Ut = U.mT.contiguous()
    res = torch.matmul(
        U, torch.matmul(torch.diag_embed(D_grad) + torch.matmul(Ut, U_grad) * F, Ut)
    )

    return res


def _polynomial_coefficients_given_roots(roots):
    """
    Given the `roots` of a polynomial, find the polynomial's coefficients.

    If roots = (r_1, ..., r_n), then the method returns
    coefficients (a_0, a_1, ..., a_n (== 1)) so that
    p(x) = (x - r_1) * ... * (x - r_n)
         = x^n + a_{n-1} * x^{n-1} + ... + a_1 * x + a_0

    Note: for better performance this requires writing a low-level kernel
    """
    poly_order = roots.shape[-1]
    poly_coeffs_shape = list(roots.shape)
    # we assume p(x) = x^n + a_{n-1} * x^{n-1} + ... + a_1 * x + a_0,
    # so poly_coeffs = {a_0, ..., a_n, a_{n+1}(== 1)},
    # but we insert one extra coefficient to enable better vectorization below
    poly_coeffs_shape[-1] += 2
    poly_coeffs = roots.new_zeros(poly_coeffs_shape)
    poly_coeffs[..., 0] = 1
    poly_coeffs[..., -1] = 1

    # perform Horner's rule
    for i in range(1, poly_order + 1):
        # note that it is computationally hard to compute backward for this method,
        # because then, given the coefficients, it would require finding the roots and/or
        # calculating the sensitivity based on Vieta's theorem.
        # So the code below tries to circumvent the explicit root finding by a series
        # of operations on memory copies imitating Horner's method.
        # The memory copies are required to construct nodes in the computational graph
        # by exploiting the explicit (not in-place, separate node for each step)
        # recursion of Horner's method.
        # Needs more memory, O(... * k^2), but with only O(... * k^2) complexity.
        poly_coeffs_new = poly_coeffs.clone() if roots.requires_grad else poly_coeffs
        out = poly_coeffs_new.narrow(-1, poly_order - i, i + 1)
        out -= roots.narrow(-1, i - 1, 1) * poly_coeffs.narrow(
            -1, poly_order - i + 1, i + 1
        )
        poly_coeffs = poly_coeffs_new

    return poly_coeffs.narrow(-1, 1, poly_order + 1)
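
# For illustration (a hypothetical example, not exercised by this module):
# given roots = torch.tensor([1.0, 2.0]), the polynomial is
# p(x) = (x - 1) * (x - 2) = x^2 - 3 * x + 2, so the function above returns
# the coefficient tensor [2., -3., 1.], ordered from a_0 up to the leading
# coefficient a_n == 1.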


def _polynomial_value(poly, x, zero_power, transition):
    """
    A generic method for computing poly(x) using Horner's rule.

    Args:
      poly (Tensor): the (possibly batched) 1D Tensor representing
                     polynomial coefficients such that
                     poly[..., :] = (a_0, ..., a_n (== 1)), and
                     poly(x) = poly[..., 0] * zero_power + ... + poly[..., n] * x^n

      x (Tensor): the value (possibly batched) to evaluate the polynomial `poly` at.

      zero_power (Tensor): the representation of `x^0`. It is application-specific.

      transition (Callable): the function that accepts some intermediate result `int_val`,
                     the `x` and a specific polynomial coefficient
                     `poly[..., k]` for some iteration `k`.
                     It basically performs one iteration of Horner's rule
                     defined as `x * int_val + poly[..., k] * zero_power`.
                     Note that `zero_power` is not a parameter,
                     because the step `+ poly[..., k] * zero_power` depends on `x`,
                     whether it is a vector, a matrix, or something else, so this
                     functionality is delegated to the user.
    """

    res = zero_power.clone()
    for k in range(poly.size(-1) - 2, -1, -1):
        res = transition(res, x, poly[..., k])
    return res


def _matrix_polynomial_value(poly, x, zero_power=None):
    """
    Evaluates `poly(x)` for the (batched) matrix input `x`.
    Check out the `_polynomial_value` function for more details.
    """

    # matrix-aware Horner's rule iteration
    def transition(curr_poly_val, x, poly_coeff):
        res = x.matmul(curr_poly_val)
        res.diagonal(dim1=-2, dim2=-1).add_(poly_coeff.unsqueeze(-1))
        return res

    if zero_power is None:
        zero_power = torch.eye(
            x.size(-1), x.size(-1), dtype=x.dtype, device=x.device
        ).view(*([1] * len(list(x.shape[:-2]))), x.size(-1), x.size(-1))

    return _polynomial_value(poly, x, zero_power, transition)


def _vector_polynomial_value(poly, x, zero_power=None):
    """
    Evaluates `poly(x)` for the (batched) vector input `x`.
    Check out the `_polynomial_value` function for more details.
    """

    # vector-aware Horner's rule iteration
    def transition(curr_poly_val, x, poly_coeff):
        res = torch.addcmul(poly_coeff.unsqueeze(-1), x, curr_poly_val)
        return res

    if zero_power is None:
        zero_power = x.new_ones(1).expand(x.shape)

    return _polynomial_value(poly, x, zero_power, transition)
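
# For illustration (hypothetical values, shown only as a comment so nothing
# runs at import time): with poly = torch.tensor([2.0, -3.0, 1.0]), i.e.
# p(x) = x^2 - 3 * x + 2,
#   _vector_polynomial_value(poly, torch.tensor([0.0, 1.0, 3.0]))
# evaluates p element-wise and returns [2., 0., 2.], while
#   _matrix_polynomial_value(poly, A)
# computes A @ A - 3 * A + 2 * I for a square matrix A, both via Horner's rule.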


def _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest):
    # compute a projection operator onto the subspace orthogonal to the
    # columns of U, defined as (I - UU^T)
    Ut = U.mT.contiguous()
    proj_U_ortho = -U.matmul(Ut)
    proj_U_ortho.diagonal(dim1=-2, dim2=-1).add_(1)

    # compute U_ortho, a basis for the orthogonal complement to span(U),
    # by projecting a random [..., m, m - k] matrix onto the subspace
    # orthogonal to the columns of U.
    #
    # fix generator for determinism
    gen = torch.Generator(A.device)

    # orthogonal complement to span(U)
    U_ortho = proj_U_ortho.matmul(
        torch.randn(
            (*A.shape[:-1], A.size(-1) - D.size(-1)),
            dtype=A.dtype,
            device=A.device,
            generator=gen,
        )
    )
    U_ortho_t = U_ortho.mT.contiguous()

    # compute the coefficients of the characteristic polynomial of the tensor D.
    # Note that D is diagonal, so the diagonal elements are exactly the roots
    # of the characteristic polynomial.
    chr_poly_D = _polynomial_coefficients_given_roots(D)

    # the code below finds the explicit solution to the Sylvester equation
    # U_ortho^T A U_ortho dX - dX D = -U_ortho^T A U
    # and incorporates it into the whole gradient stored in the `res` variable.
    #
    # Equivalent to the following naive implementation:
    # res = A.new_zeros(A.shape)
    # p_res = A.new_zeros(*A.shape[:-1], D.size(-1))
    # for k in range(1, chr_poly_D.size(-1)):
    #     p_res.zero_()
    #     for i in range(0, k):
    #         p_res += (A.matrix_power(k - 1 - i) @ U_grad) * D.pow(i).unsqueeze(-2)
    #     res -= chr_poly_D[k] * (U_ortho @ chr_poly_D_at_A.inverse() @ U_ortho_t @ p_res @ U.t())
    #
    # Note that dX is a differential, so the gradient contribution comes from the backward sensitivity
    # Tr(f(U_grad, D_grad, A, U, D)^T dX) = Tr(g(U_grad, A, U, D)^T dA) for some functions f and g,
    # and we need to compute g(U_grad, A, U, D)
    #
    # The naive implementation is based on the paper
    # Hu, Qingxi, and Daizhan Cheng.
    # "The polynomial solution to the Sylvester matrix equation."
    # Applied Mathematics Letters 19.9 (2006): 859-864.
    #
    # We can modify the computation of `p_res` from above in a more efficient way
    # p_res =   U_grad * (chr_poly_D[1] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k)).unsqueeze(-2)
    #       + A U_grad * (chr_poly_D[2] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k - 1)).unsqueeze(-2)
    #       + ...
    #       + A.matrix_power(k - 1) U_grad * chr_poly_D[k]
    # Note that this saves us from redundant matrix products with A (elimination of matrix_power)
    U_grad_projected = U_grad
    series_acc = U_grad_projected.new_zeros(U_grad_projected.shape)
    for k in range(1, chr_poly_D.size(-1)):
        poly_D = _vector_polynomial_value(chr_poly_D[..., k:], D)
        series_acc += U_grad_projected * poly_D.unsqueeze(-2)
        U_grad_projected = A.matmul(U_grad_projected)

    # compute chr_poly_D(A) which essentially is:
    #
    #   chr_poly_D_at_A = A.new_zeros(A.shape)
    #   for k in range(chr_poly_D.size(-1)):
    #       chr_poly_D_at_A += chr_poly_D[k] * A.matrix_power(k)
    #
    # Note, however, that for better performance we use Horner's rule
    chr_poly_D_at_A = _matrix_polynomial_value(chr_poly_D, A)

    # compute the action of `chr_poly_D_at_A` restricted to U_ortho_t
    chr_poly_D_at_A_to_U_ortho = torch.matmul(
        U_ortho_t, torch.matmul(chr_poly_D_at_A, U_ortho)
    )
    # we need to invert `chr_poly_D_at_A_to_U_ortho`, for that we compute its
    # Cholesky decomposition and then use `torch.cholesky_solve` for better stability.
    # Cholesky decomposition requires the input to be positive-definite.
    # Note that `chr_poly_D_at_A_to_U_ortho` is positive-definite if
    # 1. `largest` == False, or
    # 2. `largest` == True and `k` is even
    # under the assumption that `A` has distinct eigenvalues.
    #
    # check if `chr_poly_D_at_A_to_U_ortho` is positive-definite or negative-definite
    chr_poly_D_at_A_to_U_ortho_sign = -1 if (largest and (k % 2 == 1)) else +1
    chr_poly_D_at_A_to_U_ortho_L = torch.linalg.cholesky(
        chr_poly_D_at_A_to_U_ortho_sign * chr_poly_D_at_A_to_U_ortho
    )

    # compute the gradient part in span(U)
    res = _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U)

    # incorporate the Sylvester equation solution into the full gradient;
    # it resides in span(U_ortho)
    res -= U_ortho.matmul(
        chr_poly_D_at_A_to_U_ortho_sign
        * torch.cholesky_solve(
            U_ortho_t.matmul(series_acc), chr_poly_D_at_A_to_U_ortho_L
        )
    ).matmul(Ut)

    return res


def _symeig_backward(D_grad, U_grad, A, D, U, largest):
    # if `U` is square, then its columns span a complete eigenspace
    if U.size(-1) == U.size(-2):
        return _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U)
    else:
        return _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest)


class LOBPCGAutogradFunction(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore[override]
        ctx,
        A: Tensor,
        k: Optional[int] = None,
        B: Optional[Tensor] = None,
        X: Optional[Tensor] = None,
        n: Optional[int] = None,
        iK: Optional[Tensor] = None,
        niter: Optional[int] = None,
        tol: Optional[float] = None,
        largest: Optional[bool] = None,
        method: Optional[str] = None,
        tracker: None = None,
        ortho_iparams: Optional[Dict[str, int]] = None,
        ortho_fparams: Optional[Dict[str, float]] = None,
        ortho_bparams: Optional[Dict[str, bool]] = None,
    ) -> Tuple[Tensor, Tensor]:
        # make sure that the input is contiguous for efficiency.
        # Note: autograd does not support dense gradients for sparse input yet.
        A = A.contiguous() if (not A.is_sparse) else A
        if B is not None:
            B = B.contiguous() if (not B.is_sparse) else B

        D, U = _lobpcg(
            A,
            k,
            B,
            X,
            n,
            iK,
            niter,
            tol,
            largest,
            method,
            tracker,
            ortho_iparams,
            ortho_fparams,
            ortho_bparams,
        )

        ctx.save_for_backward(A, B, D, U)
        ctx.largest = largest

        return D, U

    @staticmethod
    def backward(ctx, D_grad, U_grad):
        A_grad = B_grad = None
        grads = [None] * 14

        A, B, D, U = ctx.saved_tensors
        largest = ctx.largest

        # lobpcg.backward has some limitations. Check for unsupported input.
        if A.is_sparse or (B is not None and B.is_sparse and ctx.needs_input_grad[2]):
            raise ValueError(
                "lobpcg.backward does not support sparse input yet. "
                "Note that lobpcg.forward does though."
            )
        if (
            A.dtype in (torch.complex64, torch.complex128)
            or B is not None
            and B.dtype in (torch.complex64, torch.complex128)
        ):
            raise ValueError(
                "lobpcg.backward does not support complex input yet. "
                "Note that lobpcg.forward does though."
            )
        if B is not None:
            raise ValueError(
                "lobpcg.backward does not support backward with B != I yet."
            )

        if largest is None:
            largest = True

        # symeig backward
        if B is None:
            A_grad = _symeig_backward(D_grad, U_grad, A, D, U, largest)

        # A has index 0
        grads[0] = A_grad
        # B has index 2
        grads[2] = B_grad
        return tuple(grads)


def lobpcg(
    A: Tensor,
    k: Optional[int] = None,
    B: Optional[Tensor] = None,
    X: Optional[Tensor] = None,
    n: Optional[int] = None,
    iK: Optional[Tensor] = None,
    niter: Optional[int] = None,
    tol: Optional[float] = None,
    largest: Optional[bool] = None,
    method: Optional[str] = None,
    tracker: None = None,
    ortho_iparams: Optional[Dict[str, int]] = None,
    ortho_fparams: Optional[Dict[str, float]] = None,
    ortho_bparams: Optional[Dict[str, bool]] = None,
) -> Tuple[Tensor, Tensor]:
    """Find the k largest (or smallest) eigenvalues and the corresponding
    eigenvectors of a symmetric positive definite generalized
    eigenvalue problem using matrix-free LOBPCG methods.

    This function is a front-end to the following LOBPCG algorithms
    selectable via the `method` argument:

      `method="basic"` - the LOBPCG method introduced by Andrew
      Knyazev, see [Knyazev2001]. A less robust method that may fail when
      Cholesky is applied to singular input.

      `method="ortho"` - the LOBPCG method with orthogonal basis
      selection [StathopoulosEtal2002]. A robust method.

    Supported inputs are dense, sparse, and batches of dense matrices.

    .. note:: In general, the basic method spends the least time per
      iteration. However, the robust method converges much faster and
      is more stable. So, the usage of the basic method is generally
      not recommended, but there exist cases where the usage of the
      basic method may be preferred.

    .. warning:: The backward method does not support sparse and complex inputs.
      It works only when `B` is not provided (i.e. `B == None`).
      We are actively working on extensions, and the details of
      the algorithms are going to be published promptly.

    .. warning:: While it is assumed that `A` is symmetric, `A.grad` is not.
      To make sure that `A.grad` is symmetric, so that `A - t * A.grad` is symmetric
      in first-order optimization routines, prior to running `lobpcg`
      we do the following symmetrization map: `A -> (A + A.t()) / 2`.
      The map is performed only when `A` requires gradients.

    Args:

      A (Tensor): the input tensor of size :math:`(*, m, m)`

      B (Tensor, optional): the input tensor of size :math:`(*, m,
                  m)`. When not specified, `B` is interpreted as the
                  identity matrix.

      X (Tensor, optional): the input tensor of size :math:`(*, m, n)`
                  where `k <= n <= m`. When specified, it is used as
                  an initial approximation of eigenvectors. X must be a
                  dense tensor.

      iK (Tensor, optional): the input tensor of size :math:`(*, m,
                  m)`. When specified, it will be used as a preconditioner.

      k (integer, optional): the number of requested
                  eigenpairs. Default is the number of :math:`X`
                  columns (when specified) or `1`.

      n (integer, optional): if :math:`X` is not specified then `n`
                  specifies the size of the generated random
                  approximation of eigenvectors. Default value for `n`
                  is `k`. If :math:`X` is specified, the value of `n`
                  (when specified) must be the number of :math:`X`
                  columns.

      tol (float, optional): residual tolerance for the stopping
                  criterion. Default is `feps ** 0.5` where `feps` is
                  the smallest non-zero floating-point number of the given
                  input tensor `A` data type.

      largest (bool, optional): when True, solve the eigenproblem for
                  the largest eigenvalues. Otherwise, solve the
                  eigenproblem for the smallest eigenvalues. Default is
                  `True`.

      method (str, optional): select the LOBPCG method. See the
                  description of the function above. Default is
                  "ortho".

      niter (int, optional): maximum number of iterations. When
                  reached, the iteration process is hard-stopped and
                  the current approximation of eigenpairs is returned.
                  For an unlimited number of iterations (run until the
                  convergence criteria are met), use `-1`.

      tracker (callable, optional) : a function for tracing the
                  iteration process. When specified, it is called at
                  each iteration step with the LOBPCG instance as an
                  argument. The LOBPCG instance holds the full state of
                  the iteration process in the following attributes:

                    `iparams`, `fparams`, `bparams` - dictionaries of
                    integer, float, and boolean valued input
                    parameters, respectively

                    `ivars`, `fvars`, `bvars`, `tvars` - dictionaries
                    of integer, float, boolean, and Tensor valued
                    iteration variables, respectively.

                    `A`, `B`, `iK` - input Tensor arguments.

                    `E`, `X`, `S`, `R` - iteration Tensor variables.

                  For instance:

                    `ivars["istep"]` - the current iteration step
                    `X` - the current approximation of eigenvectors
                    `E` - the current approximation of eigenvalues
                    `R` - the current residual
                    `ivars["converged_count"]` - the current number of converged eigenpairs
                    `tvars["rerr"]` - the current state of the convergence criteria

                  Note that when `tracker` stores Tensor objects from
                  the LOBPCG instance, it must make copies of these.

                  If `tracker` sets `bvars["force_stop"] = True`, the
                  iteration process will be hard-stopped.

      ortho_iparams, ortho_fparams, ortho_bparams (dict, optional):
                  various parameters passed to the LOBPCG algorithm when
                  using `method="ortho"`.

    Returns:

      E (Tensor): tensor of eigenvalues of size :math:`(*, k)`

      X (Tensor): tensor of eigenvectors of size :math:`(*, m, k)`

    References:

      [Knyazev2001] Andrew V. Knyazev. (2001) Toward the Optimal
      Preconditioned Eigensolver: Locally Optimal Block Preconditioned
      Conjugate Gradient Method. SIAM J. Sci. Comput., 23(2),
      517-541. (25 pages)
      https://epubs.siam.org/doi/abs/10.1137/S1064827500366124

      [StathopoulosEtal2002] Andreas Stathopoulos and Kesheng
      Wu. (2002) A Block Orthogonalization Procedure with Constant
      Synchronization Requirements. SIAM J. Sci. Comput., 23(6),
      2165-2182. (18 pages)
      https://epubs.siam.org/doi/10.1137/S1064827500370883

      [DuerschEtal2018] Jed A. Duersch, Meiyue Shao, Chao Yang, Ming
      Gu. (2018) A Robust and Efficient Implementation of LOBPCG.
      SIAM J. Sci. Comput., 40(5), C655-C676. (22 pages)
      https://epubs.siam.org/doi/abs/10.1137/17M1129830

    """
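
    # A minimal usage sketch (hypothetical shapes and values, kept as a comment
    # so nothing runs at import time):
    #
    #   A = torch.randn(10, 10)
    #   A = A @ A.mT + 10 * torch.eye(10)   # symmetric positive definite
    #   E, X = torch.lobpcg(A, k=2, largest=True)
    #
    # `E` then holds the two largest eigenvalues and `X` the corresponding
    # orthonormal eigenvectors (here `B` defaults to the identity).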

    if not torch.jit.is_scripting():
        tensor_ops = (A, B, X, iK)
        if not set(map(type, tensor_ops)).issubset(
            (torch.Tensor, type(None))
        ) and has_torch_function(tensor_ops):
            return handle_torch_function(
                lobpcg,
                tensor_ops,
                A,
                k=k,
                B=B,
                X=X,
                n=n,
                iK=iK,
                niter=niter,
                tol=tol,
                largest=largest,
                method=method,
                tracker=tracker,
                ortho_iparams=ortho_iparams,
                ortho_fparams=ortho_fparams,
                ortho_bparams=ortho_bparams,
            )

    if not torch._jit_internal.is_scripting():
        if A.requires_grad or (B is not None and B.requires_grad):
            # While it is expected that `A` is symmetric,
            # the `A_grad` might be not. Therefore we perform the trick below,
            # so that `A_grad` becomes symmetric.
            # The symmetrization is important for first-order optimization methods,
            # so that (A - alpha * A_grad) is still a symmetric matrix.
            # The same holds for `B`.
            A_sym = (A + A.mT) / 2
            B_sym = (B + B.mT) / 2 if (B is not None) else None

            return LOBPCGAutogradFunction.apply(
                A_sym,
                k,
                B_sym,
                X,
                n,
                iK,
                niter,
                tol,
                largest,
                method,
                tracker,
                ortho_iparams,
                ortho_fparams,
                ortho_bparams,
            )
    else:
        if A.requires_grad or (B is not None and B.requires_grad):
            raise RuntimeError(
                "Script and require grads is not supported atm. "
                "If you just want to do the forward, use .detach() "
                "on A and B before calling into lobpcg"
            )

    return _lobpcg(
        A,
        k,
        B,
        X,
        n,
        iK,
        niter,
        tol,
        largest,
        method,
        tracker,
        ortho_iparams,
        ortho_fparams,
        ortho_bparams,
    )


def _lobpcg(
    A: Tensor,
    k: Optional[int] = None,
    B: Optional[Tensor] = None,
    X: Optional[Tensor] = None,
    n: Optional[int] = None,
    iK: Optional[Tensor] = None,
    niter: Optional[int] = None,
    tol: Optional[float] = None,
    largest: Optional[bool] = None,
    method: Optional[str] = None,
    tracker: None = None,
    ortho_iparams: Optional[Dict[str, int]] = None,
    ortho_fparams: Optional[Dict[str, float]] = None,
    ortho_bparams: Optional[Dict[str, bool]] = None,
) -> Tuple[Tensor, Tensor]:
    # A must be square:
    assert A.shape[-2] == A.shape[-1], A.shape
    if B is not None:
        # A and B must have the same shapes:
        assert A.shape == B.shape, (A.shape, B.shape)

    dtype = _utils.get_floating_dtype(A)
    device = A.device
    if tol is None:
        feps = {torch.float32: 1.2e-07, torch.float64: 2.23e-16}[dtype]
        tol = feps**0.5

    m = A.shape[-1]
    k = (1 if X is None else X.shape[-1]) if k is None else k
    n = (k if n is None else n) if X is None else X.shape[-1]

    if m < 3 * n:
        raise ValueError(
            f"LOBPCG algorithm is not applicable when the number of A rows (={m})"
            f" is smaller than 3 x the number of requested eigenpairs (={n})"
        )

    method = "ortho" if method is None else method

    iparams = {
        "m": m,
        "n": n,
        "k": k,
        "niter": 1000 if niter is None else niter,
    }

    fparams = {
        "tol": tol,
    }

    bparams = {"largest": True if largest is None else largest}
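
    # For example (hypothetical values), a caller could tighten the
    # re-orthogonalization tolerance used by the "ortho" method with
    #   torch.lobpcg(A, k=2, ortho_fparams={"ortho_tol": 1e-8})
    # Such overrides are merged into `iparams`/`fparams`/`bparams` below before
    # the remaining "ortho" defaults are filled in.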

    if method == "ortho":
        if ortho_iparams is not None:
            iparams.update(ortho_iparams)
        if ortho_fparams is not None:
            fparams.update(ortho_fparams)
        if ortho_bparams is not None:
            bparams.update(ortho_bparams)
        iparams["ortho_i_max"] = iparams.get("ortho_i_max", 3)
        iparams["ortho_j_max"] = iparams.get("ortho_j_max", 3)
        fparams["ortho_tol"] = fparams.get("ortho_tol", tol)
        fparams["ortho_tol_drop"] = fparams.get("ortho_tol_drop", tol)
        fparams["ortho_tol_replace"] = fparams.get("ortho_tol_replace", tol)
        bparams["ortho_use_drop"] = bparams.get("ortho_use_drop", False)

    if not torch.jit.is_scripting():
        LOBPCG.call_tracker = LOBPCG_call_tracker  # type: ignore[method-assign]

    if len(A.shape) > 2:
        N = int(torch.prod(torch.tensor(A.shape[:-2])))
        bA = A.reshape((N,) + A.shape[-2:])
        bB = B.reshape((N,) + A.shape[-2:]) if B is not None else None
        bX = X.reshape((N,) + X.shape[-2:]) if X is not None else None
        bE = torch.empty((N, k), dtype=dtype, device=device)
        bXret = torch.empty((N, m, k), dtype=dtype, device=device)

        for i in range(N):
            A_ = bA[i]
            B_ = bB[i] if bB is not None else None
            X_ = (
                torch.randn((m, n), dtype=dtype, device=device) if bX is None else bX[i]
            )
            assert len(X_.shape) == 2 and X_.shape == (m, n), (X_.shape, (m, n))
            iparams["batch_index"] = i
            worker = LOBPCG(A_, B_, X_, iK, iparams, fparams, bparams, method, tracker)
            worker.run()
            bE[i] = worker.E[:k]
            bXret[i] = worker.X[:, :k]

        if not torch.jit.is_scripting():
            LOBPCG.call_tracker = LOBPCG_call_tracker_orig  # type: ignore[method-assign]

        return bE.reshape(A.shape[:-2] + (k,)), bXret.reshape(A.shape[:-2] + (m, k))

    X = torch.randn((m, n), dtype=dtype, device=device) if X is None else X
    assert len(X.shape) == 2 and X.shape == (m, n), (X.shape, (m, n))

    worker = LOBPCG(A, B, X, iK, iparams, fparams, bparams, method, tracker)

    worker.run()

    if not torch.jit.is_scripting():
        LOBPCG.call_tracker = LOBPCG_call_tracker_orig  # type: ignore[method-assign]

    return worker.E[:k], worker.X[:, :k]


class LOBPCG:
    """Worker class of LOBPCG methods."""

    def __init__(
        self,
        A: Optional[Tensor],
        B: Optional[Tensor],
        X: Tensor,
        iK: Optional[Tensor],
        iparams: Dict[str, int],
        fparams: Dict[str, float],
        bparams: Dict[str, bool],
        method: str,
        tracker: None,
    ) -> None:
        # constant parameters
        self.A = A
        self.B = B
        self.iK = iK
        self.iparams = iparams
        self.fparams = fparams
        self.bparams = bparams
        self.method = method
        self.tracker = tracker
        m = iparams["m"]
        n = iparams["n"]

        # variable parameters
        self.X = X
        self.E = torch.zeros((n,), dtype=X.dtype, device=X.device)
        self.R = torch.zeros((m, n), dtype=X.dtype, device=X.device)
        self.S = torch.zeros((m, 3 * n), dtype=X.dtype, device=X.device)
        self.tvars: Dict[str, Tensor] = {}
        self.ivars: Dict[str, int] = {"istep": 0}
        self.fvars: Dict[str, float] = {"_": 0.0}
        self.bvars: Dict[str, bool] = {"_": False}

    def __str__(self):
        lines = ["LOBPCG:"]
        lines += [f"  iparams={self.iparams}"]
        lines += [f"  fparams={self.fparams}"]
        lines += [f"  bparams={self.bparams}"]
        lines += [f"  ivars={self.ivars}"]
        lines += [f"  fvars={self.fvars}"]
        lines += [f"  bvars={self.bvars}"]
        lines += [f"  tvars={self.tvars}"]
        lines += [f"  A={self.A}"]
        lines += [f"  B={self.B}"]
        lines += [f"  iK={self.iK}"]
        lines += [f"  X={self.X}"]
        lines += [f"  E={self.E}"]
        r = ""
        for line in lines:
            r += line + "\n"
        return r
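
    # Note on the iteration state (descriptive comment added for clarity): the
    # search-subspace basis `S` has `3 * n` columns and is filled block-wise by
    # the update methods below as S = [X | P | W], where `X` holds the current
    # eigenvector approximations, `P` the extra search directions carried over
    # from the previous step, and `W` the (preconditioned) residual-based
    # directions.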

    def update(self):
        """Set and update iteration variables."""
        if self.ivars["istep"] == 0:
            X_norm = float(torch.norm(self.X))
            iX_norm = X_norm**-1
            A_norm = float(torch.norm(_utils.matmul(self.A, self.X))) * iX_norm
            B_norm = float(torch.norm(_utils.matmul(self.B, self.X))) * iX_norm
            self.fvars["X_norm"] = X_norm
            self.fvars["A_norm"] = A_norm
            self.fvars["B_norm"] = B_norm
            self.ivars["iterations_left"] = self.iparams["niter"]
            self.ivars["converged_count"] = 0
            self.ivars["converged_end"] = 0

        if self.method == "ortho":
            self._update_ortho()
        else:
            self._update_basic()

        self.ivars["iterations_left"] = self.ivars["iterations_left"] - 1
        self.ivars["istep"] = self.ivars["istep"] + 1

    def update_residual(self):
        """Update residual R from A, B, X, E."""
        mm = _utils.matmul
        self.R = mm(self.A, self.X) - mm(self.B, self.X) * self.E

    def update_converged_count(self):
        """Determine the number of converged eigenpairs using the backward-stable
        convergence criterion, see the discussion in Sec 4.3 of [DuerschEtal2018].

        Users may redefine this method for custom convergence criteria.
        """
        # (...) -> int
        prev_count = self.ivars["converged_count"]
        tol = self.fparams["tol"]
        A_norm = self.fvars["A_norm"]
        B_norm = self.fvars["B_norm"]
        E, X, R = self.E, self.X, self.R
        rerr = (
            torch.norm(R, 2, (0,))
            * (torch.norm(X, 2, (0,)) * (A_norm + E[: X.shape[-1]] * B_norm)) ** -1
        )
        converged = rerr.real < tol  # this is a norm so imag is 0.0
        count = 0
        for b in converged:
            if not b:
                # ignore convergence of the following pairs to ensure
                # strict ordering of eigenpairs
                break
            count += 1
        assert (
            count >= prev_count
        ), f"the number of converged eigenpairs (was {prev_count}, got {count}) cannot decrease"
        self.ivars["converged_count"] = count
        self.tvars["rerr"] = rerr
        return count

    def stop_iteration(self):
        """Return True to stop iterations.

        Note that the tracker (if defined) can force-stop iterations by
        setting ``worker.bvars['force_stop'] = True``.
        """
        return (
            self.bvars.get("force_stop", False)
            or self.ivars["iterations_left"] == 0
            or self.ivars["converged_count"] >= self.iparams["k"]
        )

    def run(self):
        """Run LOBPCG iterations.

        Use this method as a template for implementing an LOBPCG
        iteration scheme with a custom tracker that is compatible with
        TorchScript.
        """
        self.update()

        if not torch.jit.is_scripting() and self.tracker is not None:
            self.call_tracker()

        while not self.stop_iteration():
            self.update()

            if not torch.jit.is_scripting() and self.tracker is not None:
                self.call_tracker()

    @torch.jit.unused
    def call_tracker(self):
        """Interface for tracking the iteration process in Python mode.

        Tracking the iteration process is disabled in TorchScript
        mode. In fact, one should specify tracker=None when JIT
        compiling functions using lobpcg.
        """
        # do nothing when in TorchScript mode
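
    # A tracker sketch (hypothetical, for illustration only): any callable that
    # receives the worker can snapshot per-iteration state, e.g.
    #
    #   history = []
    #   def my_tracker(worker):
    #       history.append(worker.tvars["rerr"].clone())  # copy, as required
    #   torch.lobpcg(A, k=2, tracker=my_tracker)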

    # Internal methods

    def _update_basic(self):
        """
        Update or initialize iteration variables when `method == "basic"`.
        """
        mm = torch.matmul
        ns = self.ivars["converged_end"]
        nc = self.ivars["converged_count"]
        n = self.iparams["n"]
        largest = self.bparams["largest"]

        if self.ivars["istep"] == 0:
            Ri = self._get_rayleigh_ritz_transform(self.X)
            M = _utils.qform(_utils.qform(self.A, self.X), Ri)
            E, Z = _utils.symeig(M, largest)
            self.X[:] = mm(self.X, mm(Ri, Z))
            self.E[:] = E
            np = 0
            self.update_residual()
            nc = self.update_converged_count()
            self.S[..., :n] = self.X

            W = _utils.matmul(self.iK, self.R)
            self.ivars["converged_end"] = ns = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W
        else:
            S_ = self.S[:, nc:ns]
            Ri = self._get_rayleigh_ritz_transform(S_)
            M = _utils.qform(_utils.qform(self.A, S_), Ri)
            E_, Z = _utils.symeig(M, largest)
            self.X[:, nc:] = mm(S_, mm(Ri, Z[:, : n - nc]))
            self.E[nc:] = E_[: n - nc]
            P = mm(S_, mm(Ri, Z[:, n : 2 * n - nc]))
            np = P.shape[-1]

            self.update_residual()
            nc = self.update_converged_count()
            self.S[..., :n] = self.X
            self.S[:, n : n + np] = P
            W = _utils.matmul(self.iK, self.R[:, nc:])

            self.ivars["converged_end"] = ns = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W

    def _update_ortho(self):
        """
        Update or initialize iteration variables when `method == "ortho"`.
        """
        mm = torch.matmul
        ns = self.ivars["converged_end"]
        nc = self.ivars["converged_count"]
        n = self.iparams["n"]
        largest = self.bparams["largest"]

        if self.ivars["istep"] == 0:
            Ri = self._get_rayleigh_ritz_transform(self.X)
            M = _utils.qform(_utils.qform(self.A, self.X), Ri)
            E, Z = _utils.symeig(M, largest)
            self.X = mm(self.X, mm(Ri, Z))
            self.update_residual()
            np = 0
            nc = self.update_converged_count()
            self.S[:, :n] = self.X
            W = self._get_ortho(self.R, self.X)
            ns = self.ivars["converged_end"] = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W

        else:
            S_ = self.S[:, nc:ns]
            # Rayleigh-Ritz procedure
            E_, Z = _utils.symeig(_utils.qform(self.A, S_), largest)

            # Update E, X, P
            self.X[:, nc:] = mm(S_, Z[:, : n - nc])
            self.E[nc:] = E_[: n - nc]
            P = mm(S_, mm(Z[:, n - nc :], _utils.basis(Z[: n - nc, n - nc :].mT)))
            np = P.shape[-1]

            # check convergence
            self.update_residual()
            nc = self.update_converged_count()

            # update S
            self.S[:, :n] = self.X
            self.S[:, n : n + np] = P
            W = self._get_ortho(self.R[:, nc:], self.S[:, : n + np])
            ns = self.ivars["converged_end"] = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W

    def _get_rayleigh_ritz_transform(self, S):
        """Return a transformation matrix that is used in the Rayleigh-Ritz
        procedure for reducing a general eigenvalue problem :math:`(S^TAS)
        C = (S^TBS) C E` to a standard eigenvalue problem :math:`(Ri^T
        S^TAS Ri) Z = Z E` where `C = Ri Z`.

        .. note:: In the original Rayleigh-Ritz procedure in
          [DuerschEtal2018], the problem is formulated as follows::

            SAS = S^T A S
            SBS = S^T B S
            D = (<diagonal matrix of SBS>) ** -1/2
            R^T R = Cholesky(D SBS D)
            Ri = D R^-1
            solve symeig problem Ri^T SAS Ri Z = Theta Z
            C = Ri Z

          To reduce the number of matrix products (denoted by empty
          space between matrices), here we introduce element-wise
          products (denoted by symbol `*`) so that the Rayleigh-Ritz
          procedure becomes::

            SAS = S^T A S
            SBS = S^T B S
            d = (<diagonal of SBS>) ** -1/2    # this is a 1-d column vector
            dd = d d^T                         # this is a 2-d matrix
            R^T R = Cholesky(dd * SBS)
            Ri = R^-1 * d                      # broadcasting
            solve symeig problem Ri^T SAS Ri Z = Theta Z
            C = Ri Z

          where `dd` is the 2-d matrix that replaces the matrix products `D M
          D` with one element-wise product `M * dd`; and `d` replaces the
          matrix product `D M` with the element-wise product `M *
          d`. Also, creating the diagonal matrix `D` is avoided.

        Args:
          S (Tensor): the matrix basis for the search subspace, size is
                      :math:`(m, n)`.

        Returns:
          Ri (tensor): upper-triangular transformation matrix of size
                      :math:`(n, n)`.

        """
        B = self.B
        mm = torch.matmul
        SBS = _utils.qform(B, S)
        d_row = SBS.diagonal(0, -2, -1) ** -0.5
        d_col = d_row.reshape(d_row.shape[0], 1)
        # TODO use torch.linalg.cholesky_solve once it is implemented
        R = torch.linalg.cholesky((SBS * d_row) * d_col, upper=True)
        return torch.linalg.solve_triangular(
            R, d_row.diag_embed(), upper=True, left=False
        )

    def _get_svqb(self, U: Tensor, drop: bool, tau: float) -> Tensor:
        """Return B-orthonormal U.

        .. note:: When `drop` is `False` then `svqb` is based on
                  Algorithm 4 from [DuerschPhD2015] that is a slight
                  modification of the corresponding algorithm
                  introduced in [StathopoulosEtal2002].

        Args:

          U (Tensor) : initial approximation, size is (m, n)
          drop (bool) : when True, drop columns whose
                     contribution to `span([U])` is small.
          tau (float) : positive tolerance

        Returns:

          U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`), size
                     is (m, n1), where `n1 = n` if `drop` is `False`,
                     otherwise `n1 <= n`.

        """
        if torch.numel(U) == 0:
            return U
        UBU = _utils.qform(self.B, U)
        d = UBU.diagonal(0, -2, -1)

        # Detect and drop exact zero columns from U. While the test
        # `abs(d) == 0` is unlikely to be True for random data, it is
        # possible to construct input data to lobpcg where it will be
        # True, leading to a failure (notice the `d ** -0.5` operation
        # in the original algorithm). To prevent the failure, we drop
        # the exact zero columns here and then continue with the
        # original algorithm below.
        nz = torch.where(abs(d) != 0.0)
        assert len(nz) == 1, nz
        if len(nz[0]) < len(d):
            U = U[:, nz[0]]
            if torch.numel(U) == 0:
                return U
            UBU = _utils.qform(self.B, U)
            d = UBU.diagonal(0, -2, -1)
            nz = torch.where(abs(d) != 0.0)
            assert len(nz[0]) == len(d)
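
        # Why the rescaling below works (explanatory note added for clarity):
        # with D = diag(d) ** -0.5 the matrix D (U^T B U) D has a unit diagonal;
        # its eigendecomposition Z E Z^T then gives U_new = U D Z E^-0.5 with
        # U_new^T B U_new = E^-0.5 Z^T (D U^T B U D) Z E^-0.5 = I.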

        # The original algorithm 4 from [DuerschPhD2015].
        d_col = (d**-0.5).reshape(d.shape[0], 1)
        DUBUD = (UBU * d_col) * d_col.mT
        E, Z = _utils.symeig(DUBUD)
        t = tau * abs(E).max()
        if drop:
            keep = torch.where(E > t)
            assert len(keep) == 1, keep
            E = E[keep[0]]
            Z = Z[:, keep[0]]
            d_col = d_col[keep[0]]
        else:
            E[(torch.where(E < t))[0]] = t

        return torch.matmul(U * d_col.mT, Z * E**-0.5)

    def _get_ortho(self, U, V):
        """Return B-orthonormal U whose columns are B-orthogonal to V.

        .. note:: When `bparams["ortho_use_drop"] == False` then
                  `_get_ortho` is based on Algorithm 3 from
                  [DuerschPhD2015] that is a slight modification of
                  the corresponding algorithm introduced in
                  [StathopoulosEtal2002]. Otherwise, the method
                  implements Algorithm 6 from [DuerschPhD2015].

        .. note:: If all U columns are B-collinear to V then the
                  returned tensor U will be empty.

        Args:

          U (Tensor) : initial approximation, size is (m, n)
          V (Tensor) : B-orthogonal external basis, size is (m, k)

        Returns:

          U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`)
                       such that :math:`V^T B U = 0`, size is (m, n1),
                       where `n1 = n` if `drop` is `False`, otherwise
                       `n1 <= n`.
        """
        mm = torch.matmul
        mm_B = _utils.matmul
        m = self.iparams["m"]
        tau_ortho = self.fparams["ortho_tol"]
        tau_drop = self.fparams["ortho_tol_drop"]
        tau_replace = self.fparams["ortho_tol_replace"]
        i_max = self.iparams["ortho_i_max"]
        j_max = self.iparams["ortho_j_max"]
        # when use_drop==True, enable dropping U columns that have
        # a small contribution to `span([U, V])`.
        use_drop = self.bparams["ortho_use_drop"]

        # clean up variables from the previous call
        for vkey in list(self.fvars.keys()):
            if vkey.startswith("ortho_") and vkey.endswith("_rerr"):
                self.fvars.pop(vkey)
        self.ivars.pop("ortho_i", 0)
        self.ivars.pop("ortho_j", 0)

        BV_norm = torch.norm(mm_B(self.B, V))
        BU = mm_B(self.B, U)
        VBU = mm(V.mT, BU)
        i = j = 0
        stats = ""
        for i in range(i_max):
            U = U - mm(V, VBU)
            drop = False
            tau_svqb = tau_drop
            for j in range(j_max):
                if use_drop:
                    U = self._get_svqb(U, drop, tau_svqb)
                    drop = True
                    tau_svqb = tau_replace
                else:
                    U = self._get_svqb(U, False, tau_replace)
                if torch.numel(U) == 0:
                    # all initial U columns are B-collinear to V
                    self.ivars["ortho_i"] = i
                    self.ivars["ortho_j"] = j
                    return U
                BU = mm_B(self.B, U)
                UBU = mm(U.mT, BU)
                U_norm = torch.norm(U)
                BU_norm = torch.norm(BU)
                R = UBU - torch.eye(UBU.shape[-1], device=UBU.device, dtype=UBU.dtype)
                R_norm = torch.norm(R)
                # https://github.com/pytorch/pytorch/issues/33810 workaround:
                rerr = float(R_norm) * float(BU_norm * U_norm) ** -1
                vkey = f"ortho_UBUmI_rerr[{i}, {j}]"
                self.fvars[vkey] = rerr
                if rerr < tau_ortho:
                    break
            VBU = mm(V.mT, BU)
            VBU_norm = torch.norm(VBU)
            U_norm = torch.norm(U)
            rerr = float(VBU_norm) * float(BV_norm * U_norm) ** -1
            vkey = f"ortho_VBU_rerr[{i}]"
            self.fvars[vkey] = rerr
            if rerr < tau_ortho:
                break
            if m < U.shape[-1] + V.shape[-1]:
                # TorchScript needs the class var to be assigned to a local to
                # do optional type refinement
                B = self.B
                assert B is not None
                raise ValueError(
                    "Overdetermined shape of U:"
                    f" #B-cols(={B.shape[-1]}) >= #U-cols(={U.shape[-1]}) + #V-cols(={V.shape[-1]}) must hold"
                )
        self.ivars["ortho_i"] = i
        self.ivars["ortho_j"] = j
        return U


# Calling the tracker is separated from the LOBPCG definitions because
# TorchScript does not support user-defined callback arguments:
LOBPCG_call_tracker_orig = LOBPCG.call_tracker


def LOBPCG_call_tracker(self):
    self.tracker(self)