1# mypy: allow-untyped-defs 2import itertools 3import operator 4from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union 5 6import torch 7import torch.nn.functional as F 8from torch import _VF, Tensor 9from torch._C import _add_docstr 10from torch._jit_internal import _overload as overload, boolean_dispatch 11from torch._lowrank import pca_lowrank, svd_lowrank 12from torch.overrides import ( 13 handle_torch_function, 14 has_torch_function, 15 has_torch_function_unary, 16 has_torch_function_variadic, 17) 18 19 20__all__ = [ 21 "atleast_1d", 22 "atleast_2d", 23 "atleast_3d", 24 "align_tensors", 25 "broadcast_shapes", 26 "broadcast_tensors", 27 "cartesian_prod", 28 "block_diag", 29 "cdist", 30 "chain_matmul", 31 "einsum", 32 "istft", 33 "lu", 34 "norm", 35 "meshgrid", 36 "pca_lowrank", 37 "split", 38 "stft", 39 "svd_lowrank", 40 "tensordot", 41 "unique", 42 "unique_consecutive", 43 "unravel_index", 44] 45 46 47def broadcast_tensors(*tensors): 48 r"""broadcast_tensors(*tensors) -> List of Tensors 49 50 Broadcasts the given tensors according to :ref:`broadcasting-semantics`. 51 52 Args: 53 *tensors: any number of tensors of the same type 54 55 .. warning:: 56 57 More than one element of a broadcasted tensor may refer to a single 58 memory location. As a result, in-place operations (especially ones that 59 are vectorized) may result in incorrect behavior. If you need to write 60 to the tensors, please clone them first. 61 62 Example:: 63 64 >>> x = torch.arange(3).view(1, 3) 65 >>> y = torch.arange(2).view(2, 1) 66 >>> a, b = torch.broadcast_tensors(x, y) 67 >>> a.size() 68 torch.Size([2, 3]) 69 >>> a 70 tensor([[0, 1, 2], 71 [0, 1, 2]]) 72 """ 73 # This wrapper exists to support variadic args. 74 if has_torch_function(tensors): 75 return handle_torch_function(broadcast_tensors, tensors, *tensors) 76 return _VF.broadcast_tensors(tensors) # type: ignore[attr-defined] 77 78 79def broadcast_shapes(*shapes): 80 r"""broadcast_shapes(*shapes) -> Size 81 82 Similar to :func:`broadcast_tensors` but for shapes. 83 84 This is equivalent to 85 ``torch.broadcast_tensors(*map(torch.empty, shapes))[0].shape`` 86 but avoids the need create to intermediate tensors. This is useful for 87 broadcasting tensors of common batch shape but different rightmost shape, 88 e.g. to broadcast mean vectors with covariance matrices. 89 90 Example:: 91 92 >>> torch.broadcast_shapes((2,), (3, 1), (1, 1, 1)) 93 torch.Size([1, 3, 2]) 94 95 Args: 96 \*shapes (torch.Size): Shapes of tensors. 97 98 Returns: 99 shape (torch.Size): A shape compatible with all input shapes. 100 101 Raises: 102 RuntimeError: If shapes are incompatible. 103 """ 104 # This wrapper exists to support variadic args. 105 # TODO Move this to C++ once the jit has better support for torch.Size. 106 if not torch.jit.is_tracing(): 107 max_len = 0 108 for shape in shapes: 109 if isinstance(shape, (int, torch.SymInt)): 110 if max_len < 1: 111 max_len = 1 112 elif isinstance(shape, (tuple, list)): 113 s = len(shape) 114 if max_len < s: 115 max_len = s 116 result = [1] * max_len 117 118 from torch.fx.experimental.symbolic_shapes import guard_size_oblivious 119 120 for shape in shapes: 121 if isinstance(shape, (int, torch.SymInt)): 122 shape = (shape,) 123 if isinstance(shape, (tuple, list)): 124 for i in range(-1, -1 - len(shape), -1): 125 if shape[i] < 0: 126 raise RuntimeError( 127 f"Trying to create tensor with negative dimension ({shape[i]}): ({shape[i]})" 128 ) 129 # NB: result is initialized to 1 so this is effectively an 130 # equals one test 131 if guard_size_oblivious(shape[i] == 1) or guard_size_oblivious( 132 shape[i] == result[i] 133 ): 134 continue 135 if result[i] != 1: 136 raise RuntimeError( 137 "Shape mismatch: objects cannot be broadcast to a single shape" 138 ) 139 result[i] = shape[i] 140 else: 141 raise RuntimeError( 142 "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", 143 shape, 144 ) 145 return torch.Size(result) 146 else: 147 # with implementation above, torch.jit.trace hardcodes the sizes which makes subsequent replays fail 148 with torch.no_grad(): 149 scalar = torch.zeros((), device="cpu") 150 tensors = [scalar.expand(shape) for shape in shapes] 151 tensors = broadcast_tensors(*tensors) 152 return tensors[0].shape 153 154 155def split( 156 tensor: Tensor, 157 split_size_or_sections: Union[int, List[int]], 158 dim: int = 0, 159) -> Tuple[Tensor, ...]: 160 r"""Splits the tensor into chunks. Each chunk is a view of the original tensor. 161 162 If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will 163 be split into equally sized chunks (if possible). Last chunk will be smaller if 164 the tensor size along the given dimension :attr:`dim` is not divisible by 165 :attr:`split_size`. 166 167 If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split 168 into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according 169 to :attr:`split_size_or_sections`. 170 171 Args: 172 tensor (Tensor): tensor to split. 173 split_size_or_sections (int) or (list(int)): size of a single chunk or 174 list of sizes for each chunk 175 dim (int): dimension along which to split the tensor. 176 177 Example:: 178 179 >>> a = torch.arange(10).reshape(5, 2) 180 >>> a 181 tensor([[0, 1], 182 [2, 3], 183 [4, 5], 184 [6, 7], 185 [8, 9]]) 186 >>> torch.split(a, 2) 187 (tensor([[0, 1], 188 [2, 3]]), 189 tensor([[4, 5], 190 [6, 7]]), 191 tensor([[8, 9]])) 192 >>> torch.split(a, [1, 4]) 193 (tensor([[0, 1]]), 194 tensor([[2, 3], 195 [4, 5], 196 [6, 7], 197 [8, 9]])) 198 """ 199 if has_torch_function_unary(tensor): 200 return handle_torch_function( 201 split, (tensor,), tensor, split_size_or_sections, dim=dim 202 ) 203 # Overwriting reason: 204 # This dispatches to two ATen functions depending on the type of 205 # split_size_or_sections. The branching code is in _tensor.py, which we 206 # call here. 207 return tensor.split(split_size_or_sections, dim) 208 209 210def einsum(*args: Any) -> Tensor: 211 r"""einsum(equation, *operands) -> Tensor 212 213 Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation 214 based on the Einstein summation convention. 215 216 Einsum allows computing many common multi-dimensional linear algebraic array operations by representing them 217 in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of 218 this format are described below, but the general idea is to label every dimension of the input :attr:`operands` 219 with some subscript and define which subscripts are part of the output. The output is then computed by summing 220 the product of the elements of the :attr:`operands` along the dimensions whose subscripts are not part of the 221 output. For example, matrix multiplication can be computed using einsum as `torch.einsum("ij,jk->ik", A, B)`. 222 Here, j is the summation subscript and i and k the output subscripts (see section below for more details on why). 223 224 Equation: 225 226 The :attr:`equation` string specifies the subscripts (letters in `[a-zA-Z]`) for each dimension of 227 the input :attr:`operands` in the same order as the dimensions, separating subscripts for each operand by a 228 comma (','), e.g. `'ij,jk'` specify subscripts for two 2D operands. The dimensions labeled with the same subscript 229 must be broadcastable, that is, their size must either match or be `1`. The exception is if a subscript is 230 repeated for the same input operand, in which case the dimensions labeled with this subscript for this operand 231 must match in size and the operand will be replaced by its diagonal along these dimensions. The subscripts that 232 appear exactly once in the :attr:`equation` will be part of the output, sorted in increasing alphabetical order. 233 The output is computed by multiplying the input :attr:`operands` element-wise, with their dimensions aligned based 234 on the subscripts, and then summing out the dimensions whose subscripts are not part of the output. 235 236 Optionally, the output subscripts can be explicitly defined by adding an arrow ('->') at the end of the equation 237 followed by the subscripts for the output. For instance, the following equation computes the transpose of a 238 matrix multiplication: 'ij,jk->ki'. The output subscripts must appear at least once for some input operand and 239 at most once for the output. 240 241 Ellipsis ('...') can be used in place of subscripts to broadcast the dimensions covered by the ellipsis. 242 Each input operand may contain at most one ellipsis which will cover the dimensions not covered by subscripts, 243 e.g. for an input operand with 5 dimensions, the ellipsis in the equation `'ab...c'` cover the third and fourth 244 dimensions. The ellipsis does not need to cover the same number of dimensions across the :attr:`operands` but the 245 'shape' of the ellipsis (the size of the dimensions covered by them) must broadcast together. If the output is not 246 explicitly defined with the arrow ('->') notation, the ellipsis will come first in the output (left-most dimensions), 247 before the subscript labels that appear exactly once for the input operands. e.g. the following equation implements 248 batch matrix multiplication `'...ij,...jk'`. 249 250 A few final notes: the equation may contain whitespaces between the different elements (subscripts, ellipsis, 251 arrow and comma) but something like `'. . .'` is not valid. An empty string `''` is valid for scalar operands. 252 253 .. note:: 254 255 ``torch.einsum`` handles ellipsis ('...') differently from NumPy in that it allows dimensions 256 covered by the ellipsis to be summed over, that is, ellipsis are not required to be part of the output. 257 258 .. note:: 259 260 This function uses opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/) to speed up computation or to 261 consume less memory by optimizing contraction order. This optimization occurs when there are at least three 262 inputs, since the order does not matter otherwise. Note that finding _the_ optimal path is an NP-hard problem, 263 thus, opt_einsum relies on different heuristics to achieve near-optimal results. If opt_einsum is not available, 264 the default order is to contract from left to right. 265 266 To bypass this default behavior, add the following line to disable the usage of opt_einsum and skip path 267 calculation: `torch.backends.opt_einsum.enabled = False` 268 269 To specify which strategy you'd like for opt_einsum to compute the contraction path, add the following line: 270 `torch.backends.opt_einsum.strategy = 'auto'`. The default strategy is 'auto', and we also support 'greedy' and 271 'optimal'. Disclaimer that the runtime of 'optimal' is factorial in the number of inputs! See more details in 272 the opt_einsum documentation (https://optimized-einsum.readthedocs.io/en/stable/path_finding.html). 273 274 .. note:: 275 276 As of PyTorch 1.10 :func:`torch.einsum` also supports the sublist format (see examples below). In this format, 277 subscripts for each operand are specified by sublists, list of integers in the range [0, 52). These sublists 278 follow their operands, and an extra sublist can appear at the end of the input to specify the output's 279 subscripts., e.g. `torch.einsum(op1, sublist1, op2, sublist2, ..., [subslist_out])`. Python's `Ellipsis` object 280 may be provided in a sublist to enable broadcasting as described in the Equation section above. 281 282 Args: 283 equation (str): The subscripts for the Einstein summation. 284 operands (List[Tensor]): The tensors to compute the Einstein summation of. 285 286 Examples:: 287 288 >>> # xdoctest: +IGNORE_WANT("non-deterministic") 289 >>> # trace 290 >>> torch.einsum('ii', torch.randn(4, 4)) 291 tensor(-1.2104) 292 293 >>> # xdoctest: +IGNORE_WANT("non-deterministic") 294 >>> # diagonal 295 >>> torch.einsum('ii->i', torch.randn(4, 4)) 296 tensor([-0.1034, 0.7952, -0.2433, 0.4545]) 297 298 >>> # xdoctest: +IGNORE_WANT("non-deterministic") 299 >>> # outer product 300 >>> x = torch.randn(5) 301 >>> y = torch.randn(4) 302 >>> torch.einsum('i,j->ij', x, y) 303 tensor([[ 0.1156, -0.2897, -0.3918, 0.4963], 304 [-0.3744, 0.9381, 1.2685, -1.6070], 305 [ 0.7208, -1.8058, -2.4419, 3.0936], 306 [ 0.1713, -0.4291, -0.5802, 0.7350], 307 [ 0.5704, -1.4290, -1.9323, 2.4480]]) 308 309 >>> # xdoctest: +IGNORE_WANT("non-deterministic") 310 >>> # batch matrix multiplication 311 >>> As = torch.randn(3, 2, 5) 312 >>> Bs = torch.randn(3, 5, 4) 313 >>> torch.einsum('bij,bjk->bik', As, Bs) 314 tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], 315 [-1.6706, -0.8097, -0.8025, -2.1183]], 316 317 [[ 4.2239, 0.3107, -0.5756, -0.2354], 318 [-1.4558, -0.3460, 1.5087, -0.8530]], 319 320 [[ 2.8153, 1.8787, -4.3839, -1.2112], 321 [ 0.3728, -2.1131, 0.0921, 0.8305]]]) 322 323 >>> # xdoctest: +IGNORE_WANT("non-deterministic") 324 >>> # with sublist format and ellipsis 325 >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2]) 326 tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], 327 [-1.6706, -0.8097, -0.8025, -2.1183]], 328 329 [[ 4.2239, 0.3107, -0.5756, -0.2354], 330 [-1.4558, -0.3460, 1.5087, -0.8530]], 331 332 [[ 2.8153, 1.8787, -4.3839, -1.2112], 333 [ 0.3728, -2.1131, 0.0921, 0.8305]]]) 334 335 >>> # batch permute 336 >>> A = torch.randn(2, 3, 4, 5) 337 >>> torch.einsum('...ij->...ji', A).shape 338 torch.Size([2, 3, 5, 4]) 339 340 >>> # equivalent to torch.nn.functional.bilinear 341 >>> A = torch.randn(3, 5, 4) 342 >>> l = torch.randn(2, 5) 343 >>> r = torch.randn(2, 4) 344 >>> torch.einsum('bn,anm,bm->ba', l, A, r) 345 tensor([[-0.3430, -5.2405, 0.4494], 346 [ 0.3311, 5.5201, -3.0356]]) 347 """ 348 import torch.backends.opt_einsum as opt_einsum 349 350 # This wrapper exists to support variadic args. 351 if len(args) < 2: 352 raise ValueError( 353 "einsum(): must specify the equation string and at least one operand, " 354 "or at least one operand and its subscripts list" 355 ) 356 357 equation = None 358 operands = None 359 360 if isinstance(args[0], torch.Tensor): 361 # Convert the subscript list format which is an interleaving of operand and its subscripts 362 # list with an optional output subscripts list at the end (see documentation for more details on this) 363 # to the equation string format by creating the equation string from the subscripts list and grouping the 364 # input operands into a tensorlist (List[Tensor]). 365 def parse_subscript(n: int) -> str: 366 if n == Ellipsis: 367 return "..." 368 if n >= 0 and n < 26: 369 return chr(ord("A") + n) 370 if n >= 26 and n < 52: 371 return chr(ord("a") + n - 26) 372 raise ValueError( 373 "einsum(): subscript in subscript list is not within the valid range [0, 52)" 374 ) 375 376 # Parse subscripts for input operands 377 equation = ",".join("".join(parse_subscript(s) for s in l) for l in args[1::2]) 378 379 # Parse optional output subscripts (provided when the number of arguments is odd) 380 if len(args) % 2 == 1: 381 equation += "->" + "".join(parse_subscript(s) for s in args[-1]) 382 operands = args[:-1:2] 383 else: 384 operands = args[::2] 385 else: 386 equation = args[0] 387 operands = args[1:] 388 389 if has_torch_function(operands): 390 return handle_torch_function(einsum, operands, equation, *operands) 391 392 if len(operands) == 1 and isinstance(operands[0], (list, tuple)): 393 # the old interface of passing the operands as one list argument 394 _operands = operands[0] 395 # recurse incase operands contains value that has torch function 396 # in the original implementation this line is omitted 397 return einsum(equation, *_operands) 398 399 if len(operands) <= 2 or not opt_einsum.enabled: 400 # the path for contracting 0 or 1 time(s) is already optimized 401 # or the user has disabled using opt_einsum 402 return _VF.einsum(equation, operands) # type: ignore[attr-defined] 403 404 path = None 405 if opt_einsum.is_available(): 406 _opt_einsum = opt_einsum.get_opt_einsum() 407 tupled_path = _opt_einsum.contract_path( 408 equation, *operands, optimize=opt_einsum.strategy 409 )[0] 410 # flatten path for dispatching to C++ 411 path = [item for pair in tupled_path for item in pair] 412 return _VF.einsum(equation, operands, path=path) # type: ignore[attr-defined] 413 414 415# This wrapper exists to support variadic args. 416if TYPE_CHECKING: 417 # The JIT doesn't understand Union, so only add type annotation for mypy 418 def meshgrid( 419 *tensors: Union[Tensor, List[Tensor]], indexing: Optional[str] = None 420 ) -> Tuple[Tensor, ...]: 421 return _meshgrid(*tensors, indexing=indexing) 422 423else: 424 425 def meshgrid(*tensors, indexing: Optional[str] = None) -> Tuple[Tensor, ...]: 426 r"""Creates grids of coordinates specified by the 1D inputs in `attr`:tensors. 427 428 This is helpful when you want to visualize data over some 429 range of inputs. See below for a plotting example. 430 431 Given :math:`N` 1D tensors :math:`T_0 \ldots T_{N-1}` as 432 inputs with corresponding sizes :math:`S_0 \ldots S_{N-1}`, 433 this creates :math:`N` N-dimensional tensors :math:`G_0 \ldots 434 G_{N-1}`, each with shape :math:`(S_0, ..., S_{N-1})` where 435 the output :math:`G_i` is constructed by expanding :math:`T_i` 436 to the result shape. 437 438 .. note:: 439 0D inputs are treated equivalently to 1D inputs of a 440 single element. 441 442 .. warning:: 443 `torch.meshgrid(*tensors)` currently has the same behavior 444 as calling `numpy.meshgrid(*arrays, indexing='ij')`. 445 446 In the future `torch.meshgrid` will transition to 447 `indexing='xy'` as the default. 448 449 https://github.com/pytorch/pytorch/issues/50276 tracks 450 this issue with the goal of migrating to NumPy's behavior. 451 452 .. seealso:: 453 454 :func:`torch.cartesian_prod` has the same effect but it 455 collects the data in a tensor of vectors. 456 457 Args: 458 tensors (list of Tensor): list of scalars or 1 dimensional tensors. Scalars will be 459 treated as tensors of size :math:`(1,)` automatically 460 461 indexing: (str, optional): the indexing mode, either "xy" 462 or "ij", defaults to "ij". See warning for future changes. 463 464 If "xy" is selected, the first dimension corresponds 465 to the cardinality of the second input and the second 466 dimension corresponds to the cardinality of the first 467 input. 468 469 If "ij" is selected, the dimensions are in the same 470 order as the cardinality of the inputs. 471 472 Returns: 473 seq (sequence of Tensors): If the input has :math:`N` 474 tensors of size :math:`S_0 \ldots S_{N-1}``, then the 475 output will also have :math:`N` tensors, where each tensor 476 is of shape :math:`(S_0, ..., S_{N-1})`. 477 478 Example:: 479 480 >>> x = torch.tensor([1, 2, 3]) 481 >>> y = torch.tensor([4, 5, 6]) 482 483 Observe the element-wise pairings across the grid, (1, 4), 484 (1, 5), ..., (3, 6). This is the same thing as the 485 cartesian product. 486 >>> grid_x, grid_y = torch.meshgrid(x, y, indexing='ij') 487 >>> grid_x 488 tensor([[1, 1, 1], 489 [2, 2, 2], 490 [3, 3, 3]]) 491 >>> grid_y 492 tensor([[4, 5, 6], 493 [4, 5, 6], 494 [4, 5, 6]]) 495 496 This correspondence can be seen when these grids are 497 stacked properly. 498 >>> torch.equal(torch.cat(tuple(torch.dstack([grid_x, grid_y]))), 499 ... torch.cartesian_prod(x, y)) 500 True 501 502 `torch.meshgrid` is commonly used to produce a grid for 503 plotting. 504 >>> # xdoctest: +REQUIRES(module:matplotlib) 505 >>> # xdoctest: +REQUIRES(env:DOCTEST_SHOW) 506 >>> import matplotlib.pyplot as plt 507 >>> xs = torch.linspace(-5, 5, steps=100) 508 >>> ys = torch.linspace(-5, 5, steps=100) 509 >>> x, y = torch.meshgrid(xs, ys, indexing='xy') 510 >>> z = torch.sin(torch.sqrt(x * x + y * y)) 511 >>> ax = plt.axes(projection='3d') 512 >>> ax.plot_surface(x.numpy(), y.numpy(), z.numpy()) 513 >>> plt.show() 514 515 .. image:: ../_static/img/meshgrid.png 516 :width: 512 517 518 """ 519 return _meshgrid(*tensors, indexing=indexing) 520 521 522def _meshgrid(*tensors, indexing: Optional[str]): 523 if has_torch_function(tensors): 524 return handle_torch_function(meshgrid, tensors, *tensors, indexing=indexing) 525 if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)): 526 # the old interface of passing the operands as one list argument 527 tensors = tensors[0] # type: ignore[assignment] 528 529 # Continue allowing call of old method that takes no indexing 530 # kwarg for forward compatibility reasons. 531 # 532 # Remove this two weeks after landing. 533 kwargs = {} if indexing is None else {"indexing": indexing} 534 return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined] 535 536 537def stft( 538 input: Tensor, 539 n_fft: int, 540 hop_length: Optional[int] = None, 541 win_length: Optional[int] = None, 542 window: Optional[Tensor] = None, 543 center: bool = True, 544 pad_mode: str = "reflect", 545 normalized: bool = False, 546 onesided: Optional[bool] = None, 547 return_complex: Optional[bool] = None, 548) -> Tensor: 549 r"""Short-time Fourier transform (STFT). 550 551 .. warning:: 552 From version 1.8.0, :attr:`return_complex` must always be given 553 explicitly for real inputs and `return_complex=False` has been 554 deprecated. Strongly prefer `return_complex=True` as in a future 555 pytorch release, this function will only return complex tensors. 556 557 Note that :func:`torch.view_as_real` can be used to recover a real 558 tensor with an extra last dimension for real and imaginary components. 559 560 .. warning:: 561 From version 2.1, a warning will be provided if a :attr:`window` is 562 not specified. In a future release, this attribute will be required. 563 Not providing a window currently defaults to using a rectangular window, 564 which may result in undesirable artifacts. Consider using tapered windows, 565 such as :func:`torch.hann_window`. 566 567 The STFT computes the Fourier transform of short overlapping windows of the 568 input. This giving frequency components of the signal as they change over 569 time. The interface of this function is modeled after (but *not* a drop-in 570 replacement for) librosa_ stft function. 571 572 .. _librosa: https://librosa.org/doc/latest/generated/librosa.stft.html 573 574 Ignoring the optional batch dimension, this method computes the following 575 expression: 576 577 .. math:: 578 X[\omega, m] = \sum_{k = 0}^{\text{win\_length-1}}% 579 \text{window}[k]\ \text{input}[m \times \text{hop\_length} + k]\ % 580 \exp\left(- j \frac{2 \pi \cdot \omega k}{\text{n\_fft}}\right), 581 582 where :math:`m` is the index of the sliding window, and :math:`\omega` is 583 the frequency :math:`0 \leq \omega < \text{n\_fft}` for ``onesided=False``, 584 or :math:`0 \leq \omega < \lfloor \text{n\_fft} / 2 \rfloor + 1` for ``onesided=True``. 585 586 * :attr:`input` must be either a 1-D time sequence or a 2-D batch of time 587 sequences. 588 589 * If :attr:`hop_length` is ``None`` (default), it is treated as equal to 590 ``floor(n_fft / 4)``. 591 592 * If :attr:`win_length` is ``None`` (default), it is treated as equal to 593 :attr:`n_fft`. 594 595 * :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from 596 :meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is 597 treated as if having :math:`1` everywhere in the window. If 598 :math:`\text{win\_length} < \text{n\_fft}`, :attr:`window` will be padded on 599 both sides to length :attr:`n_fft` before being applied. 600 601 * If :attr:`center` is ``True`` (default), :attr:`input` will be padded on 602 both sides so that the :math:`t`-th frame is centered at time 603 :math:`t \times \text{hop\_length}`. Otherwise, the :math:`t`-th frame 604 begins at time :math:`t \times \text{hop\_length}`. 605 606 * :attr:`pad_mode` determines the padding method used on :attr:`input` when 607 :attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for 608 all available options. Default is ``"reflect"``. 609 610 * If :attr:`onesided` is ``True`` (default for real input), only values for 611 :math:`\omega` in :math:`\left[0, 1, 2, \dots, \left\lfloor 612 \frac{\text{n\_fft}}{2} \right\rfloor + 1\right]` are returned because 613 the real-to-complex Fourier transform satisfies the conjugate symmetry, 614 i.e., :math:`X[m, \omega] = X[m, \text{n\_fft} - \omega]^*`. 615 Note if the input or window tensors are complex, then :attr:`onesided` 616 output is not possible. 617 618 * If :attr:`normalized` is ``True`` (default is ``False``), the function 619 returns the normalized STFT results, i.e., multiplied by :math:`(\text{frame\_length})^{-0.5}`. 620 621 * If :attr:`return_complex` is ``True`` (default if input is complex), the 622 return is a ``input.dim() + 1`` dimensional complex tensor. If ``False``, 623 the output is a ``input.dim() + 2`` dimensional real tensor where the last 624 dimension represents the real and imaginary components. 625 626 Returns either a complex tensor of size :math:`(* \times N \times T)` if 627 :attr:`return_complex` is true, or a real tensor of size :math:`(* \times N 628 \times T \times 2)`. Where :math:`*` is the optional batch size of 629 :attr:`input`, :math:`N` is the number of frequencies where STFT is applied 630 and :math:`T` is the total number of frames used. 631 632 .. warning:: 633 This function changed signature at version 0.4.1. Calling with the 634 previous signature may cause error or return incorrect result. 635 636 Args: 637 input (Tensor): the input tensor of shape `(B?, L)` where `B?` is an optional 638 batch dimension 639 n_fft (int): size of Fourier transform 640 hop_length (int, optional): the distance between neighboring sliding window 641 frames. Default: ``None`` (treated as equal to ``floor(n_fft / 4)``) 642 win_length (int, optional): the size of window frame and STFT filter. 643 Default: ``None`` (treated as equal to :attr:`n_fft`) 644 window (Tensor, optional): the optional window function. 645 Shape must be 1d and `<= n_fft` 646 Default: ``None`` (treated as window of all :math:`1` s) 647 center (bool, optional): whether to pad :attr:`input` on both sides so 648 that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. 649 Default: ``True`` 650 pad_mode (str, optional): controls the padding method used when 651 :attr:`center` is ``True``. Default: ``"reflect"`` 652 normalized (bool, optional): controls whether to return the normalized STFT results 653 Default: ``False`` 654 onesided (bool, optional): controls whether to return half of results to 655 avoid redundancy for real inputs. 656 Default: ``True`` for real :attr:`input` and :attr:`window`, ``False`` otherwise. 657 return_complex (bool, optional): whether to return a complex tensor, or 658 a real tensor with an extra last dimension for the real and 659 imaginary components. 660 661 .. versionchanged:: 2.0 662 ``return_complex`` is now a required argument for real inputs, 663 as the default is being transitioned to ``True``. 664 665 .. deprecated:: 2.0 666 ``return_complex=False`` is deprecated, instead use ``return_complex=True`` 667 Note that calling :func:`torch.view_as_real` on the output will 668 recover the deprecated output format. 669 670 Returns: 671 Tensor: A tensor containing the STFT result with shape `(B?, N, T, C?)` where 672 - `B?` is an optional batch dimension from the input. 673 - `N` is the number of frequency samples, `(n_fft // 2) + 1` for 674 `onesided=True`, or otherwise `n_fft`. 675 - `T` is the number of frames, `1 + L // hop_length` 676 for `center=True`, or `1 + (L - n_fft) // hop_length` otherwise. 677 - `C?` is an optional length-2 dimension of real and imaginary 678 components, present when `return_complex=False`. 679 680 """ 681 if has_torch_function_unary(input): 682 return handle_torch_function( 683 stft, 684 (input,), 685 input, 686 n_fft, 687 hop_length=hop_length, 688 win_length=win_length, 689 window=window, 690 center=center, 691 pad_mode=pad_mode, 692 normalized=normalized, 693 onesided=onesided, 694 return_complex=return_complex, 695 ) 696 # NOTE: Do not edit. This code will be removed once the forward-compatibility 697 # period is over for PR #73432 698 if center: 699 signal_dim = input.dim() 700 extended_shape = [1] * (3 - signal_dim) + list(input.size()) 701 pad = int(n_fft // 2) 702 input = F.pad(input.view(extended_shape), [pad, pad], pad_mode) 703 input = input.view(input.shape[-signal_dim:]) 704 return _VF.stft( # type: ignore[attr-defined] 705 input, 706 n_fft, 707 hop_length, 708 win_length, 709 window, 710 normalized, 711 onesided, 712 return_complex, 713 ) 714 715 716istft = _add_docstr( 717 torch.istft, 718 "istft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, " 719 "normalized=False, onesided=None, length=None, return_complex=False) -> Tensor:\n" 720 r""" 721Inverse short time Fourier Transform. This is expected to be the inverse of :func:`~torch.stft`. 722 723.. warning:: 724 From version 2.1, a warning will be provided if a :attr:`window` is 725 not specified. In a future release, this attribute will be required. 726 Please provide the same window used in the stft call. 727 728It has the same parameters (+ additional optional parameter of :attr:`length`) and it should return the 729least squares estimation of the original signal. The algorithm will check using the NOLA condition ( 730nonzero overlap). 731 732Important consideration in the parameters :attr:`window` and :attr:`center` so that the envelope 733created by the summation of all the windows is never zero at certain point in time. Specifically, 734:math:`\sum_{t=-\infty}^{\infty} |w|^2[n-t\times hop\_length] \cancel{=} 0`. 735 736Since :func:`~torch.stft` discards elements at the end of the signal if they do not fit in a frame, 737``istft`` may return a shorter signal than the original signal (can occur if :attr:`center` is False 738since the signal isn't padded). If `length` is given in the arguments and is longer than expected, 739``istft`` will pad zeros to the end of the returned signal. 740 741If :attr:`center` is ``True``, then there will be padding e.g. ``'constant'``, ``'reflect'``, etc. 742Left padding can be trimmed off exactly because they can be calculated but right padding cannot be 743calculated without additional information. 744 745Example: Suppose the last window is: 746``[17, 18, 0, 0, 0]`` vs ``[18, 0, 0, 0, 0]`` 747 748The :attr:`n_fft`, :attr:`hop_length`, :attr:`win_length` are all the same which prevents the calculation 749of right padding. These additional values could be zeros or a reflection of the signal so providing 750:attr:`length` could be useful. If :attr:`length` is ``None`` then padding will be aggressively removed 751(some loss of signal). 752 753[1] D. W. Griffin and J. S. Lim, "Signal estimation from modified short-time Fourier transform," 754IEEE Trans. ASSP, vol.32, no.2, pp.236-243, Apr. 1984. 755 756Args: 757 input (Tensor): The input tensor. Expected to be in the format of :func:`~torch.stft`, 758 output. That is a complex tensor of shape `(B?, N, T)` where 759 760 - `B?` is an optional batch dimension 761 - `N` is the number of frequency samples, `(n_fft // 2) + 1` 762 for onesided input, or otherwise `n_fft`. 763 - `T` is the number of frames, `1 + length // hop_length` for centered stft, 764 or `1 + (length - n_fft) // hop_length` otherwise. 765 766 .. versionchanged:: 2.0 767 Real datatype inputs are no longer supported. Input must now have a 768 complex datatype, as returned by ``stft(..., return_complex=True)``. 769 n_fft (int): Size of Fourier transform 770 hop_length (Optional[int]): The distance between neighboring sliding window frames. 771 (Default: ``n_fft // 4``) 772 win_length (Optional[int]): The size of window frame and STFT filter. (Default: ``n_fft``) 773 window (Optional[torch.Tensor]): The optional window function. 774 Shape must be 1d and `<= n_fft` 775 (Default: ``torch.ones(win_length)``) 776 center (bool): Whether :attr:`input` was padded on both sides so that the :math:`t`-th frame is 777 centered at time :math:`t \times \text{hop\_length}`. 778 (Default: ``True``) 779 normalized (bool): Whether the STFT was normalized. (Default: ``False``) 780 onesided (Optional[bool]): Whether the STFT was onesided. 781 (Default: ``True`` if `n_fft != fft_size` in the input size) 782 length (Optional[int]): The amount to trim the signal by (i.e. the 783 original signal length). Defaults to `(T - 1) * hop_length` for 784 centered stft, or `n_fft + (T - 1) * hop_length` otherwise, where `T` 785 is the number of input frames. 786 return_complex (Optional[bool]): 787 Whether the output should be complex, or if the input should be 788 assumed to derive from a real signal and window. 789 Note that this is incompatible with ``onesided=True``. 790 (Default: ``False``) 791 792Returns: 793 Tensor: Least squares estimation of the original signal of shape `(B?, length)` where 794 `B?` is an optional batch dimension from the input tensor. 795""", 796) 797 798 799if TYPE_CHECKING: 800 # These _impl functions return a variable number of tensors as output with 801 # __torch_function__; tuple unpacking is done already rather than being 802 # done by the caller of the _impl function 803 _unique_impl_out = Any 804else: 805 _unique_impl_out = Tuple[Tensor, Tensor, Tensor] 806 807 808def _unique_impl( 809 input: Tensor, 810 sorted: bool = True, 811 return_inverse: bool = False, 812 return_counts: bool = False, 813 dim: Optional[int] = None, 814) -> _unique_impl_out: 815 r"""unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) -> Tuple[Tensor, Tensor, Tensor] 816 817 Returns the unique elements of the input tensor. 818 819 .. note:: This function is different from :func:`torch.unique_consecutive` in the sense that 820 this function also eliminates non-consecutive duplicate values. 821 822 .. note:: Currently in the CUDA implementation and the CPU implementation, 823 `torch.unique` always sort the tensor at the beginning regardless of the `sort` argument. 824 Sorting could be slow, so if your input tensor is already sorted, it is recommended to use 825 :func:`torch.unique_consecutive` which avoids the sorting. 826 827 Args: 828 input (Tensor): the input tensor 829 sorted (bool): Whether to sort the unique elements in ascending order 830 before returning as output. 831 return_inverse (bool): Whether to also return the indices for where 832 elements in the original input ended up in the returned unique list. 833 return_counts (bool): Whether to also return the counts for each unique 834 element. 835 dim (int, optional): the dimension to operate upon. If ``None``, the 836 unique of the flattened input is returned. Otherwise, each of the 837 tensors indexed by the given dimension is treated as one of the 838 elements to apply the unique operation upon. See examples for more 839 details. Default: ``None`` 840 841 Returns: 842 (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing 843 844 - **output** (*Tensor*): the output list of unique scalar elements. 845 - **inverse_indices** (*Tensor*): (optional) if 846 :attr:`return_inverse` is True, there will be an additional 847 returned tensor (same shape as input) representing the indices 848 for where elements in the original input map to in the output; 849 otherwise, this function will only return a single tensor. 850 - **counts** (*Tensor*): (optional) if 851 :attr:`return_counts` is True, there will be an additional 852 returned tensor (same shape as output or output.size(dim), 853 if dim was specified) representing the number of occurrences 854 for each unique value or tensor. 855 856 Example:: 857 858 >>> output = torch.unique(torch.tensor([1, 3, 2, 3], dtype=torch.long)) 859 >>> output 860 tensor([1, 2, 3]) 861 862 >>> output, inverse_indices = torch.unique( 863 ... torch.tensor([1, 3, 2, 3], dtype=torch.long), sorted=True, return_inverse=True) 864 >>> output 865 tensor([1, 2, 3]) 866 >>> inverse_indices 867 tensor([0, 2, 1, 2]) 868 869 >>> output, inverse_indices = torch.unique( 870 ... torch.tensor([[1, 3], [2, 3]], dtype=torch.long), sorted=True, return_inverse=True) 871 >>> output 872 tensor([1, 2, 3]) 873 >>> inverse_indices 874 tensor([[0, 2], 875 [1, 2]]) 876 877 >>> a = torch.tensor([ 878 ... [ 879 ... [1, 1, 0, 0], 880 ... [1, 1, 0, 0], 881 ... [0, 0, 1, 1], 882 ... ], 883 ... [ 884 ... [0, 0, 1, 1], 885 ... [0, 0, 1, 1], 886 ... [1, 1, 1, 1], 887 ... ], 888 ... [ 889 ... [1, 1, 0, 0], 890 ... [1, 1, 0, 0], 891 ... [0, 0, 1, 1], 892 ... ], 893 ... ]) 894 895 >>> # If we call `torch.unique(a, dim=0)`, each of the tensors `a[idx, :, :]` 896 >>> # will be compared. We can see that `a[0, :, :]` and `a[2, :, :]` match 897 >>> # each other, so one of them will be removed. 898 >>> (a[0, :, :] == a[2, :, :]).all() 899 tensor(True) 900 >>> a_unique_dim0 = torch.unique(a, dim=0) 901 >>> a_unique_dim0 902 tensor([[[0, 0, 1, 1], 903 [0, 0, 1, 1], 904 [1, 1, 1, 1]], 905 [[1, 1, 0, 0], 906 [1, 1, 0, 0], 907 [0, 0, 1, 1]]]) 908 909 >>> # Notice which sub-tensors from `a` match with the sub-tensors from 910 >>> # `a_unique_dim0`: 911 >>> (a_unique_dim0[0, :, :] == a[1, :, :]).all() 912 tensor(True) 913 >>> (a_unique_dim0[1, :, :] == a[0, :, :]).all() 914 tensor(True) 915 916 >>> # For `torch.unique(a, dim=1)`, each of the tensors `a[:, idx, :]` are 917 >>> # compared. `a[:, 0, :]` and `a[:, 1, :]` match each other, so one of 918 >>> # them will be removed. 919 >>> (a[:, 0, :] == a[:, 1, :]).all() 920 tensor(True) 921 >>> torch.unique(a, dim=1) 922 tensor([[[0, 0, 1, 1], 923 [1, 1, 0, 0]], 924 [[1, 1, 1, 1], 925 [0, 0, 1, 1]], 926 [[0, 0, 1, 1], 927 [1, 1, 0, 0]]]) 928 929 >>> # For `torch.unique(a, dim=2)`, the tensors `a[:, :, idx]` are compared. 930 >>> # `a[:, :, 0]` and `a[:, :, 1]` match each other. Also, `a[:, :, 2]` and 931 >>> # `a[:, :, 3]` match each other as well. So in this case, two of the 932 >>> # sub-tensors will be removed. 933 >>> (a[:, :, 0] == a[:, :, 1]).all() 934 tensor(True) 935 >>> (a[:, :, 2] == a[:, :, 3]).all() 936 tensor(True) 937 >>> torch.unique(a, dim=2) 938 tensor([[[0, 1], 939 [0, 1], 940 [1, 0]], 941 [[1, 0], 942 [1, 0], 943 [1, 1]], 944 [[0, 1], 945 [0, 1], 946 [1, 0]]]) 947 """ 948 if has_torch_function_unary(input): 949 return handle_torch_function( 950 unique, 951 (input,), 952 input, 953 sorted=sorted, 954 return_inverse=return_inverse, 955 return_counts=return_counts, 956 dim=dim, 957 ) 958 959 if dim is not None: 960 output, inverse_indices, counts = _VF.unique_dim( 961 input, 962 dim, 963 sorted=sorted, 964 return_inverse=return_inverse, 965 return_counts=return_counts, 966 ) 967 else: 968 output, inverse_indices, counts = torch._unique2( 969 input, 970 sorted=sorted, 971 return_inverse=return_inverse, 972 return_counts=return_counts, 973 ) 974 return output, inverse_indices, counts 975 976 977def _unique_consecutive_impl( 978 input: Tensor, 979 return_inverse: bool = False, 980 return_counts: bool = False, 981 dim: Optional[int] = None, 982) -> _unique_impl_out: 983 r"""Eliminates all but the first element from every consecutive group of equivalent elements. 984 985 .. note:: This function is different from :func:`torch.unique` in the sense that this function 986 only eliminates consecutive duplicate values. This semantics is similar to `std::unique` 987 in C++. 988 989 Args: 990 input (Tensor): the input tensor 991 return_inverse (bool): Whether to also return the indices for where 992 elements in the original input ended up in the returned unique list. 993 return_counts (bool): Whether to also return the counts for each unique 994 element. 995 dim (int): the dimension to apply unique. If ``None``, the unique of the 996 flattened input is returned. default: ``None`` 997 998 Returns: 999 (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing 1000 1001 - **output** (*Tensor*): the output list of unique scalar elements. 1002 - **inverse_indices** (*Tensor*): (optional) if 1003 :attr:`return_inverse` is True, there will be an additional 1004 returned tensor (same shape as input) representing the indices 1005 for where elements in the original input map to in the output; 1006 otherwise, this function will only return a single tensor. 1007 - **counts** (*Tensor*): (optional) if 1008 :attr:`return_counts` is True, there will be an additional 1009 returned tensor (same shape as output or output.size(dim), 1010 if dim was specified) representing the number of occurrences 1011 for each unique value or tensor. 1012 1013 Example:: 1014 1015 >>> x = torch.tensor([1, 1, 2, 2, 3, 1, 1, 2]) 1016 >>> output = torch.unique_consecutive(x) 1017 >>> output 1018 tensor([1, 2, 3, 1, 2]) 1019 1020 >>> output, inverse_indices = torch.unique_consecutive(x, return_inverse=True) 1021 >>> output 1022 tensor([1, 2, 3, 1, 2]) 1023 >>> inverse_indices 1024 tensor([0, 0, 1, 1, 2, 3, 3, 4]) 1025 1026 >>> output, counts = torch.unique_consecutive(x, return_counts=True) 1027 >>> output 1028 tensor([1, 2, 3, 1, 2]) 1029 >>> counts 1030 tensor([2, 2, 1, 2, 1]) 1031 """ 1032 if has_torch_function_unary(input): 1033 return handle_torch_function( 1034 unique_consecutive, 1035 (input,), 1036 input, 1037 return_inverse=return_inverse, 1038 return_counts=return_counts, 1039 dim=dim, 1040 ) 1041 output, inverse_indices, counts = _VF.unique_consecutive( # type: ignore[attr-defined] 1042 input, return_inverse=return_inverse, return_counts=return_counts, dim=dim 1043 ) 1044 return output, inverse_indices, counts 1045 1046 1047def _return_counts( 1048 input, 1049 sorted=True, 1050 return_inverse=False, 1051 return_counts=False, 1052 dim=None, 1053): 1054 # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor] 1055 1056 if has_torch_function_unary(input): 1057 return _unique_impl(input, sorted, return_inverse, return_counts, dim) 1058 1059 output, _, counts = _unique_impl(input, sorted, return_inverse, return_counts, dim) 1060 return output, counts 1061 1062 1063def _return_output( 1064 input, 1065 sorted=True, 1066 return_inverse=False, 1067 return_counts=False, 1068 dim=None, 1069): 1070 # type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor 1071 1072 if has_torch_function_unary(input): 1073 return _unique_impl(input, sorted, return_inverse, return_counts, dim) 1074 1075 output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim) 1076 return output 1077 1078 1079def _return_inverse( 1080 input, 1081 sorted=True, 1082 return_inverse=False, 1083 return_counts=False, 1084 dim=None, 1085): 1086 # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor] 1087 1088 if has_torch_function_unary(input): 1089 return _unique_impl(input, sorted, return_inverse, return_counts, dim) 1090 1091 output, inverse_indices, _ = _unique_impl( 1092 input, sorted, return_inverse, return_counts, dim 1093 ) 1094 return output, inverse_indices 1095 1096 1097_return_inverse_false = boolean_dispatch( 1098 arg_name="return_counts", 1099 arg_index=3, 1100 default=False, 1101 if_true=_return_counts, 1102 if_false=_return_output, 1103 module_name=__name__, 1104 func_name="unique", 1105) 1106 1107_return_inverse_true = boolean_dispatch( 1108 arg_name="return_counts", 1109 arg_index=3, 1110 default=False, 1111 if_true=_unique_impl, 1112 if_false=_return_inverse, 1113 module_name=__name__, 1114 func_name="unique", 1115) 1116 1117# The return type of unique depends on `return_inverse`, and `return_counts` so in order to 1118# resolve the output type in TorchScript we need to statically know the value of both parameters 1119 1120unique = boolean_dispatch( 1121 arg_name="return_inverse", 1122 arg_index=2, 1123 default=False, 1124 if_true=_return_inverse_true, 1125 if_false=_return_inverse_false, 1126 module_name=__name__, 1127 func_name="unique", 1128) 1129unique.__doc__ = _unique_impl.__doc__ 1130 1131 1132def _consecutive_return_counts( 1133 input, 1134 return_inverse=False, 1135 return_counts=False, 1136 dim=None, 1137): 1138 # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor] 1139 1140 if has_torch_function_unary(input): 1141 return _unique_consecutive_impl(input, return_inverse, return_counts, dim) 1142 1143 output, _, counts = _unique_consecutive_impl( 1144 input, return_inverse, return_counts, dim 1145 ) 1146 return output, counts 1147 1148 1149def _consecutive_return_output( 1150 input, 1151 return_inverse=False, 1152 return_counts=False, 1153 dim=None, 1154): 1155 # type: (Tensor, bool, bool, Optional[int]) -> Tensor 1156 1157 if has_torch_function_unary(input): 1158 return _unique_consecutive_impl(input, return_inverse, return_counts, dim) 1159 1160 output, _, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim) 1161 return output 1162 1163 1164def _consecutive_return_inverse( 1165 input, 1166 return_inverse=False, 1167 return_counts=False, 1168 dim=None, 1169): 1170 # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor] 1171 1172 if has_torch_function_unary(input): 1173 return _unique_consecutive_impl(input, return_inverse, return_counts, dim) 1174 1175 output, inverse_indices, _ = _unique_consecutive_impl( 1176 input, return_inverse, return_counts, dim 1177 ) 1178 return output, inverse_indices 1179 1180 1181_consecutive_return_inverse_false = boolean_dispatch( 1182 arg_name="return_counts", 1183 arg_index=1, 1184 default=False, 1185 if_true=_consecutive_return_counts, 1186 if_false=_consecutive_return_output, 1187 module_name=__name__, 1188 func_name="unique_consecutive", 1189) 1190 1191_consecutive_return_inverse_true = boolean_dispatch( 1192 arg_name="return_counts", 1193 arg_index=1, 1194 default=False, 1195 if_true=_unique_consecutive_impl, 1196 if_false=_consecutive_return_inverse, 1197 module_name=__name__, 1198 func_name="unique_consecutive", 1199) 1200 1201# The return type of unique depends on `return_inverse`, and `return_counts` so in order to 1202# resolve the output type in TorchScript we need to statically know the value of both parameters 1203 1204unique_consecutive = boolean_dispatch( 1205 arg_name="return_inverse", 1206 arg_index=2, 1207 default=False, 1208 if_true=_consecutive_return_inverse_true, 1209 if_false=_consecutive_return_inverse_false, 1210 module_name=__name__, 1211 func_name="unique_consecutive", 1212) 1213unique_consecutive.__doc__ = _unique_consecutive_impl.__doc__ 1214 1215if TYPE_CHECKING: 1216 pass 1217 # There's no good way to use this type annotation without breaking JIT 1218 # overloads. So leave untyped for mypy for now. 1219else: 1220 1221 @overload 1222 def tensordot( 1223 a, 1224 b, 1225 dims: int = 2, 1226 out: Optional[torch.Tensor] = None, 1227 ): 1228 pass 1229 1230 @overload 1231 def tensordot( # noqa: F811 1232 a, 1233 b, 1234 dims: Tuple[List[int], List[int]], 1235 out: Optional[torch.Tensor] = None, 1236 ): 1237 pass 1238 1239 @overload 1240 def tensordot( # noqa: F811 1241 a, 1242 b, 1243 dims: List[List[int]], 1244 out: Optional[torch.Tensor] = None, 1245 ): 1246 pass 1247 1248 @overload 1249 def tensordot( # noqa: F811 1250 a, 1251 b, 1252 dims: torch.Tensor, 1253 out: Optional[torch.Tensor] = None, 1254 ): 1255 pass 1256 1257 1258def tensordot( # noqa: F811 1259 a, 1260 b, 1261 dims=2, 1262 out: Optional[torch.Tensor] = None, 1263): 1264 r"""Returns a contraction of a and b over multiple dimensions. 1265 1266 :attr:`tensordot` implements a generalized matrix product. 1267 1268 Args: 1269 a (Tensor): Left tensor to contract 1270 b (Tensor): Right tensor to contract 1271 dims (int or Tuple[List[int], List[int]] or List[List[int]] containing two lists or Tensor): number of dimensions to 1272 contract or explicit lists of dimensions for :attr:`a` and 1273 :attr:`b` respectively 1274 1275 When called with a non-negative integer argument :attr:`dims` = :math:`d`, and 1276 the number of dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`, 1277 respectively, :func:`~torch.tensordot` computes 1278 1279 .. math:: 1280 r_{i_0,...,i_{m-d}, i_d,...,i_n} 1281 = \sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \times b_{k_0,...,k_{d-1}, i_d,...,i_n}. 1282 1283 When called with :attr:`dims` of the list form, the given dimensions will be contracted 1284 in place of the last :math:`d` of :attr:`a` and the first :math:`d` of :math:`b`. The sizes 1285 in these dimensions must match, but :func:`~torch.tensordot` will deal with broadcasted 1286 dimensions. 1287 1288 Examples:: 1289 1290 >>> a = torch.arange(60.).reshape(3, 4, 5) 1291 >>> b = torch.arange(24.).reshape(4, 3, 2) 1292 >>> torch.tensordot(a, b, dims=([1, 0], [0, 1])) 1293 tensor([[4400., 4730.], 1294 [4532., 4874.], 1295 [4664., 5018.], 1296 [4796., 5162.], 1297 [4928., 5306.]]) 1298 1299 >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) 1300 >>> a = torch.randn(3, 4, 5, device='cuda') 1301 >>> b = torch.randn(4, 5, 6, device='cuda') 1302 >>> c = torch.tensordot(a, b, dims=2).cpu() 1303 tensor([[ 8.3504, -2.5436, 6.2922, 2.7556, -1.0732, 3.2741], 1304 [ 3.3161, 0.0704, 5.0187, -0.4079, -4.3126, 4.8744], 1305 [ 0.8223, 3.9445, 3.2168, -0.2400, 3.4117, 1.7780]]) 1306 1307 >>> a = torch.randn(3, 5, 4, 6) 1308 >>> b = torch.randn(6, 4, 5, 3) 1309 >>> torch.tensordot(a, b, dims=([2, 1, 3], [1, 2, 0])) 1310 tensor([[ 7.7193, -2.4867, -10.3204], 1311 [ 1.5513, -14.4737, -6.5113], 1312 [ -0.2850, 4.2573, -3.5997]]) 1313 """ 1314 if has_torch_function_variadic(a, b): 1315 return handle_torch_function(tensordot, (a, b), a, b, dims=dims, out=out) 1316 1317 if not isinstance(dims, (tuple, list, torch.Tensor, int, torch.SymInt)): 1318 raise RuntimeError( 1319 "tensordot expects dims to be int or " 1320 + "Tuple[List[int], List[int]] or " 1321 + "List[List[int]] containing two lists, but got " 1322 + f"dims={dims}" 1323 ) 1324 1325 dims_a: List[int] = [] 1326 dims_b: List[int] = [] 1327 1328 if isinstance(dims, (tuple, list)): 1329 dims_a, dims_b = dims 1330 1331 if isinstance(dims, torch.Tensor): 1332 num_elements = dims.numel() 1333 if num_elements > 1: 1334 assert dims.size()[0] == 2 1335 dims_a = torch.jit.annotate(List[int], dims[0].tolist()) 1336 dims_b = torch.jit.annotate(List[int], dims[1].tolist()) 1337 else: 1338 dims_val = int(dims.item()) 1339 if dims_val < 0: 1340 raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}") 1341 dims_a = list(range(-dims_val, 0)) 1342 dims_b = list(range(dims_val)) 1343 1344 if isinstance(dims, (int, torch.SymInt)): 1345 if dims < 0: 1346 raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}") 1347 if dims > min(a.dim(), b.dim()): 1348 raise RuntimeError( 1349 f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}" 1350 ) 1351 dims_a = list(range(-dims, 0)) 1352 dims_b = list(range(dims)) 1353 1354 if out is None: 1355 return _VF.tensordot(a, b, dims_a, dims_b) # type: ignore[attr-defined] 1356 else: 1357 return _VF.tensordot(a, b, dims_a, dims_b, out=out) # type: ignore[attr-defined] 1358 1359 1360def cartesian_prod(*tensors: Tensor) -> Tensor: 1361 """Do cartesian product of the given sequence of tensors. The behavior is similar to 1362 python's `itertools.product`. 1363 1364 Args: 1365 *tensors: any number of 1 dimensional tensors. 1366 1367 Returns: 1368 Tensor: A tensor equivalent to converting all the input tensors into lists, 1369 do `itertools.product` on these lists, and finally convert the resulting list 1370 into tensor. 1371 1372 Example:: 1373 1374 >>> import itertools 1375 >>> a = [1, 2, 3] 1376 >>> b = [4, 5] 1377 >>> list(itertools.product(a, b)) 1378 [(1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5)] 1379 >>> tensor_a = torch.tensor(a) 1380 >>> tensor_b = torch.tensor(b) 1381 >>> torch.cartesian_prod(tensor_a, tensor_b) 1382 tensor([[1, 4], 1383 [1, 5], 1384 [2, 4], 1385 [2, 5], 1386 [3, 4], 1387 [3, 5]]) 1388 """ 1389 # This wrapper exists to support variadic args. 1390 if has_torch_function(tensors): 1391 return handle_torch_function(cartesian_prod, tensors, *tensors) 1392 return _VF.cartesian_prod(tensors) # type: ignore[attr-defined] 1393 1394 1395def block_diag(*tensors): 1396 """Create a block diagonal matrix from provided tensors. 1397 1398 Args: 1399 *tensors: One or more tensors with 0, 1, or 2 dimensions. 1400 1401 Returns: 1402 Tensor: A 2 dimensional tensor with all the input tensors arranged in 1403 order such that their upper left and lower right corners are 1404 diagonally adjacent. All other elements are set to 0. 1405 1406 Example:: 1407 1408 >>> import torch 1409 >>> A = torch.tensor([[0, 1], [1, 0]]) 1410 >>> B = torch.tensor([[3, 4, 5], [6, 7, 8]]) 1411 >>> C = torch.tensor(7) 1412 >>> D = torch.tensor([1, 2, 3]) 1413 >>> E = torch.tensor([[4], [5], [6]]) 1414 >>> torch.block_diag(A, B, C, D, E) 1415 tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0], 1416 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], 1417 [0, 0, 3, 4, 5, 0, 0, 0, 0, 0], 1418 [0, 0, 6, 7, 8, 0, 0, 0, 0, 0], 1419 [0, 0, 0, 0, 0, 7, 0, 0, 0, 0], 1420 [0, 0, 0, 0, 0, 0, 1, 2, 3, 0], 1421 [0, 0, 0, 0, 0, 0, 0, 0, 0, 4], 1422 [0, 0, 0, 0, 0, 0, 0, 0, 0, 5], 1423 [0, 0, 0, 0, 0, 0, 0, 0, 0, 6]]) 1424 """ 1425 # This wrapper exists to support variadic args. 1426 if has_torch_function(tensors): 1427 return handle_torch_function(block_diag, tensors, *tensors) 1428 return torch._C._VariableFunctions.block_diag(tensors) # type: ignore[attr-defined] 1429 1430 1431def cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"): 1432 # type: (Tensor, Tensor, float, str) -> (Tensor) 1433 r"""Computes batched the p-norm distance between each pair of the two collections of row vectors. 1434 1435 Args: 1436 x1 (Tensor): input tensor of shape :math:`B \times P \times M`. 1437 x2 (Tensor): input tensor of shape :math:`B \times R \times M`. 1438 p: p value for the p-norm distance to calculate between each vector pair 1439 :math:`\in [0, \infty]`. 1440 compute_mode: 1441 'use_mm_for_euclid_dist_if_necessary' - will use matrix multiplication approach to calculate 1442 euclidean distance (p = 2) if P > 25 or R > 25 1443 'use_mm_for_euclid_dist' - will always use matrix multiplication approach to calculate 1444 euclidean distance (p = 2) 1445 'donot_use_mm_for_euclid_dist' - will never use matrix multiplication approach to calculate 1446 euclidean distance (p = 2) 1447 Default: use_mm_for_euclid_dist_if_necessary. 1448 1449 If x1 has shape :math:`B \times P \times M` and x2 has shape :math:`B \times R \times M` then the 1450 output will have shape :math:`B \times P \times R`. 1451 1452 This function is equivalent to `scipy.spatial.distance.cdist(input,'minkowski', p=p)` 1453 if :math:`p \in (0, \infty)`. When :math:`p = 0` it is equivalent to 1454 `scipy.spatial.distance.cdist(input, 'hamming') * M`. When :math:`p = \infty`, the closest 1455 scipy function is `scipy.spatial.distance.cdist(xn, lambda x, y: np.abs(x - y).max())`. 1456 1457 Example: 1458 1459 >>> a = torch.tensor([[0.9041, 0.0196], [-0.3108, -2.4423], [-0.4821, 1.059]]) 1460 >>> a 1461 tensor([[ 0.9041, 0.0196], 1462 [-0.3108, -2.4423], 1463 [-0.4821, 1.0590]]) 1464 >>> b = torch.tensor([[-2.1763, -0.4713], [-0.6986, 1.3702]]) 1465 >>> b 1466 tensor([[-2.1763, -0.4713], 1467 [-0.6986, 1.3702]]) 1468 >>> torch.cdist(a, b, p=2) 1469 tensor([[3.1193, 2.0959], 1470 [2.7138, 3.8322], 1471 [2.2830, 0.3791]]) 1472 """ 1473 if has_torch_function_variadic(x1, x2): 1474 return handle_torch_function( 1475 cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode 1476 ) 1477 if compute_mode == "use_mm_for_euclid_dist_if_necessary": 1478 return _VF.cdist(x1, x2, p, None) # type: ignore[attr-defined] 1479 elif compute_mode == "use_mm_for_euclid_dist": 1480 return _VF.cdist(x1, x2, p, 1) # type: ignore[attr-defined] 1481 elif compute_mode == "donot_use_mm_for_euclid_dist": 1482 return _VF.cdist(x1, x2, p, 2) # type: ignore[attr-defined] 1483 else: 1484 raise ValueError(f"{compute_mode} is not a valid value for compute_mode") 1485 1486 1487def atleast_1d(*tensors): 1488 r""" 1489 Returns a 1-dimensional view of each input tensor with zero dimensions. 1490 Input tensors with one or more dimensions are returned as-is. 1491 1492 Args: 1493 input (Tensor or list of Tensors) 1494 1495 Returns: 1496 output (Tensor or tuple of Tensors) 1497 1498 Example:: 1499 1500 >>> x = torch.arange(2) 1501 >>> x 1502 tensor([0, 1]) 1503 >>> torch.atleast_1d(x) 1504 tensor([0, 1]) 1505 >>> x = torch.tensor(1.) 1506 >>> x 1507 tensor(1.) 1508 >>> torch.atleast_1d(x) 1509 tensor([1.]) 1510 >>> x = torch.tensor(0.5) 1511 >>> y = torch.tensor(1.) 1512 >>> torch.atleast_1d((x, y)) 1513 (tensor([0.5000]), tensor([1.])) 1514 """ 1515 # This wrapper exists to support variadic args. 1516 if has_torch_function(tensors): 1517 return handle_torch_function(atleast_1d, tensors, *tensors) 1518 if len(tensors) == 1: 1519 tensors = tensors[0] 1520 return _VF.atleast_1d(tensors) # type: ignore[attr-defined] 1521 1522 1523def atleast_2d(*tensors): 1524 r""" 1525 Returns a 2-dimensional view of each input tensor with zero dimensions. 1526 Input tensors with two or more dimensions are returned as-is. 1527 1528 Args: 1529 input (Tensor or list of Tensors) 1530 1531 Returns: 1532 output (Tensor or tuple of Tensors) 1533 1534 Example:: 1535 1536 >>> x = torch.tensor(1.) 1537 >>> x 1538 tensor(1.) 1539 >>> torch.atleast_2d(x) 1540 tensor([[1.]]) 1541 >>> x = torch.arange(4).view(2, 2) 1542 >>> x 1543 tensor([[0, 1], 1544 [2, 3]]) 1545 >>> torch.atleast_2d(x) 1546 tensor([[0, 1], 1547 [2, 3]]) 1548 >>> x = torch.tensor(0.5) 1549 >>> y = torch.tensor(1.) 1550 >>> torch.atleast_2d((x, y)) 1551 (tensor([[0.5000]]), tensor([[1.]])) 1552 """ 1553 # This wrapper exists to support variadic args. 1554 if has_torch_function(tensors): 1555 return handle_torch_function(atleast_2d, tensors, *tensors) 1556 if len(tensors) == 1: 1557 tensors = tensors[0] 1558 return _VF.atleast_2d(tensors) # type: ignore[attr-defined] 1559 1560 1561def atleast_3d(*tensors): 1562 r""" 1563 Returns a 3-dimensional view of each input tensor with zero dimensions. 1564 Input tensors with three or more dimensions are returned as-is. 1565 1566 Args: 1567 input (Tensor or list of Tensors) 1568 1569 Returns: 1570 output (Tensor or tuple of Tensors) 1571 1572 Example: 1573 1574 >>> x = torch.tensor(0.5) 1575 >>> x 1576 tensor(0.5000) 1577 >>> torch.atleast_3d(x) 1578 tensor([[[0.5000]]]) 1579 >>> y = torch.arange(4).view(2, 2) 1580 >>> y 1581 tensor([[0, 1], 1582 [2, 3]]) 1583 >>> torch.atleast_3d(y) 1584 tensor([[[0], 1585 [1]], 1586 <BLANKLINE> 1587 [[2], 1588 [3]]]) 1589 >>> x = torch.tensor(1).view(1, 1, 1) 1590 >>> x 1591 tensor([[[1]]]) 1592 >>> torch.atleast_3d(x) 1593 tensor([[[1]]]) 1594 >>> x = torch.tensor(0.5) 1595 >>> y = torch.tensor(1.0) 1596 >>> torch.atleast_3d((x, y)) 1597 (tensor([[[0.5000]]]), tensor([[[1.]]])) 1598 """ 1599 # This wrapper exists to support variadic args. 1600 if has_torch_function(tensors): 1601 return handle_torch_function(atleast_3d, tensors, *tensors) 1602 if len(tensors) == 1: 1603 tensors = tensors[0] 1604 return _VF.atleast_3d(tensors) # type: ignore[attr-defined] 1605 1606 1607if TYPE_CHECKING: 1608 pass 1609 # There's no good way to use this type annotation; cannot rename norm() to 1610 # _norm_impl() in a way that doesn't break JIT overloads. So leave untyped 1611 # for mypy for now. 1612 # def norm(input: Tensor, 1613 # p: Optional[Union[str, Number]] = "fro", 1614 # dim: Optional[Union[int, List[int]]] = None, 1615 # keepdim: bool = False, 1616 # out: Optional[Tensor] = None, 1617 # dtype: _dtype = None) -> Tensor: 1618 # return _norm_impl(input, p, dim, keepdim, out, dtype) 1619else: 1620 # TODO: type dim as BroadcastingList when 1621 # https://github.com/pytorch/pytorch/issues/33782 is fixed 1622 @overload 1623 def norm( 1624 input, 1625 p="fro", 1626 dim=None, 1627 keepdim=False, 1628 out=None, 1629 dtype=None, 1630 ): 1631 # type: (Tensor, str, Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor 1632 pass 1633 1634 @overload 1635 def norm( # noqa: F811 1636 input, 1637 p="fro", 1638 dim=None, 1639 keepdim=False, 1640 out=None, 1641 dtype=None, 1642 ): 1643 # type: (Tensor, Optional[number], Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor 1644 pass 1645 1646 @overload 1647 def norm( # noqa: F811 1648 input, 1649 p="fro", 1650 dim=None, 1651 keepdim=False, 1652 out=None, 1653 dtype=None, 1654 ): 1655 # type: (Tensor, Optional[number], Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor 1656 pass 1657 1658 @overload 1659 def norm( # noqa: F811 1660 input, 1661 p="fro", 1662 dim=None, 1663 keepdim=False, 1664 out=None, 1665 dtype=None, 1666 ): 1667 # type: (Tensor, str, Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor 1668 pass 1669 1670 1671def norm( # noqa: F811 1672 input, 1673 p: Optional[Union[float, str]] = "fro", 1674 dim=None, 1675 keepdim=False, 1676 out=None, 1677 dtype=None, 1678): 1679 r"""Returns the matrix norm or vector norm of a given tensor. 1680 1681 .. warning:: 1682 1683 torch.norm is deprecated and may be removed in a future PyTorch release. 1684 Its documentation and behavior may be incorrect, and it is no longer 1685 actively maintained. 1686 1687 Use :func:`torch.linalg.vector_norm` when computing vector norms and 1688 :func:`torch.linalg.matrix_norm` when computing matrix norms. 1689 For a function with a similar behavior as this one see :func:`torch.linalg.norm`. 1690 Note, however, the signature for these functions is slightly different than the 1691 signature for ``torch.norm``. 1692 1693 Args: 1694 input (Tensor): The input tensor. Its data type must be either a floating 1695 point or complex type. For complex inputs, the norm is calculated using the 1696 absolute value of each element. If the input is complex and neither 1697 :attr:`dtype` nor :attr:`out` is specified, the result's data type will 1698 be the corresponding floating point type (e.g. float if :attr:`input` is 1699 complexfloat). 1700 1701 p (int, float, inf, -inf, 'fro', 'nuc', optional): the order of norm. Default: ``'fro'`` 1702 The following norms can be calculated: 1703 1704 ====== ============== ========================== 1705 ord matrix norm vector norm 1706 ====== ============== ========================== 1707 'fro' Frobenius norm -- 1708 'nuc' nuclear norm -- 1709 Number -- sum(abs(x)**ord)**(1./ord) 1710 ====== ============== ========================== 1711 1712 The vector norm can be calculated across any number of dimensions. 1713 The corresponding dimensions of :attr:`input` are flattened into 1714 one dimension, and the norm is calculated on the flattened 1715 dimension. 1716 1717 Frobenius norm produces the same result as ``p=2`` in all cases 1718 except when :attr:`dim` is a list of three or more dims, in which 1719 case Frobenius norm throws an error. 1720 1721 Nuclear norm can only be calculated across exactly two dimensions. 1722 1723 dim (int, tuple of ints, list of ints, optional): 1724 Specifies which dimension or dimensions of :attr:`input` to 1725 calculate the norm across. If :attr:`dim` is ``None``, the norm will 1726 be calculated across all dimensions of :attr:`input`. If the norm 1727 type indicated by :attr:`p` does not support the specified number of 1728 dimensions, an error will occur. 1729 keepdim (bool, optional): whether the output tensors have :attr:`dim` 1730 retained or not. Ignored if :attr:`dim` = ``None`` and 1731 :attr:`out` = ``None``. Default: ``False`` 1732 out (Tensor, optional): the output tensor. Ignored if 1733 :attr:`dim` = ``None`` and :attr:`out` = ``None``. 1734 dtype (:class:`torch.dtype`, optional): the desired data type of 1735 returned tensor. If specified, the input tensor is casted to 1736 :attr:`dtype` while performing the operation. Default: None. 1737 1738 .. note:: 1739 Even though ``p='fro'`` supports any number of dimensions, the true 1740 mathematical definition of Frobenius norm only applies to tensors with 1741 exactly two dimensions. :func:`torch.linalg.matrix_norm` with ``ord='fro'`` 1742 aligns with the mathematical definition, since it can only be applied across 1743 exactly two dimensions. 1744 1745 Example:: 1746 1747 >>> import torch 1748 >>> a = torch.arange(9, dtype= torch.float) - 4 1749 >>> b = a.reshape((3, 3)) 1750 >>> torch.norm(a) 1751 tensor(7.7460) 1752 >>> torch.norm(b) 1753 tensor(7.7460) 1754 >>> torch.norm(a, float('inf')) 1755 tensor(4.) 1756 >>> torch.norm(b, float('inf')) 1757 tensor(4.) 1758 >>> c = torch.tensor([[ 1, 2, 3], [-1, 1, 4]] , dtype=torch.float) 1759 >>> torch.norm(c, dim=0) 1760 tensor([1.4142, 2.2361, 5.0000]) 1761 >>> torch.norm(c, dim=1) 1762 tensor([3.7417, 4.2426]) 1763 >>> torch.norm(c, p=1, dim=1) 1764 tensor([6., 6.]) 1765 >>> d = torch.arange(8, dtype=torch.float).reshape(2, 2, 2) 1766 >>> torch.norm(d, dim=(1, 2)) 1767 tensor([ 3.7417, 11.2250]) 1768 >>> torch.norm(d[0, :, :]), torch.norm(d[1, :, :]) 1769 (tensor(3.7417), tensor(11.2250)) 1770 """ 1771 1772 if has_torch_function_unary(input): 1773 return handle_torch_function( 1774 norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype 1775 ) 1776 1777 # NB. All the repeated code and weird python is to please TorchScript. 1778 # For a more compact implementation see the relevant function in `_refs/__init__.py` 1779 1780 # We don't do this for MPS or sparse tensors 1781 if input.layout == torch.strided and input.device.type in ( 1782 "cpu", 1783 "cuda", 1784 "meta", 1785 torch.utils.backend_registration._privateuse1_backend_name, 1786 ): 1787 if dim is not None: 1788 if isinstance(dim, (int, torch.SymInt)): 1789 _dim = [dim] 1790 else: 1791 _dim = dim 1792 else: 1793 _dim = None # type: ignore[assignment] 1794 1795 if isinstance(p, str): 1796 if p == "fro" and ( 1797 dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2 1798 ): 1799 if out is None: 1800 return torch.linalg.vector_norm( 1801 input, 2, _dim, keepdim, dtype=dtype 1802 ) 1803 else: 1804 return torch.linalg.vector_norm( 1805 input, 2, _dim, keepdim, dtype=dtype, out=out 1806 ) 1807 1808 # Here we either call the nuclear norm, or we call matrix_norm with some arguments 1809 # that will throw an error 1810 if _dim is None: 1811 _dim = list(range(input.ndim)) 1812 if out is None: 1813 return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype) 1814 else: 1815 return torch.linalg.matrix_norm( 1816 input, p, _dim, keepdim, dtype=dtype, out=out 1817 ) 1818 else: 1819 # NB. p should be Union[str, number], not Optional! 1820 _p = 2.0 if p is None else p 1821 if out is None: 1822 return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype) 1823 else: 1824 return torch.linalg.vector_norm( 1825 input, _p, _dim, keepdim, dtype=dtype, out=out 1826 ) 1827 1828 ndim = input.dim() 1829 1830 # catch default case 1831 if dim is None and out is None and dtype is None and p is not None: 1832 if isinstance(p, str): 1833 if p == "fro": 1834 return _VF.frobenius_norm(input, dim=(), keepdim=keepdim) 1835 if not isinstance(p, str): 1836 _dim = list(range(ndim)) 1837 return _VF.norm(input, p, dim=_dim, keepdim=keepdim) # type: ignore[attr-defined] 1838 1839 # TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed 1840 # remove the overloads where dim is an int and replace with BraodcastingList1 1841 # and remove next four lines, replace _dim with dim 1842 if dim is not None: 1843 if isinstance(dim, (int, torch.SymInt)): 1844 _dim = [dim] 1845 else: 1846 _dim = dim 1847 else: 1848 _dim = None # type: ignore[assignment] 1849 1850 if isinstance(p, str): 1851 if p == "fro": 1852 if dtype is not None: 1853 raise ValueError("dtype argument is not supported in frobenius norm") 1854 1855 if _dim is None: 1856 _dim = list(range(ndim)) 1857 if out is None: 1858 return _VF.frobenius_norm(input, _dim, keepdim=keepdim) # type: ignore[arg-type] 1859 else: 1860 return _VF.frobenius_norm(input, _dim, keepdim=keepdim, out=out) # type: ignore[arg-type] 1861 elif p == "nuc": 1862 if dtype is not None: 1863 raise ValueError("dtype argument is not supported in nuclear norm") 1864 if _dim is None: 1865 if out is None: 1866 return _VF.nuclear_norm(input, keepdim=keepdim) # type: ignore[arg-type] 1867 else: 1868 return _VF.nuclear_norm(input, keepdim=keepdim, out=out) # type: ignore[arg-type] 1869 else: 1870 if out is None: 1871 return _VF.nuclear_norm(input, _dim, keepdim=keepdim) # type: ignore[arg-type] 1872 else: 1873 return _VF.nuclear_norm(input, _dim, keepdim=keepdim, out=out) # type: ignore[arg-type] 1874 raise RuntimeError(f"only valid string values are 'fro' and 'nuc', found {p}") 1875 else: 1876 if _dim is None: 1877 _dim = list(range(ndim)) 1878 1879 if out is None: 1880 if dtype is None: 1881 return _VF.norm(input, p, _dim, keepdim=keepdim) # type: ignore[attr-defined] 1882 else: 1883 return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype) # type: ignore[attr-defined] 1884 else: 1885 if dtype is None: 1886 return _VF.norm(input, p, _dim, keepdim=keepdim, out=out) # type: ignore[attr-defined] 1887 else: 1888 return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out) # type: ignore[attr-defined] 1889 1890 1891def unravel_index( 1892 indices: Tensor, 1893 shape: Union[int, Sequence[int], torch.Size], 1894) -> Tuple[Tensor, ...]: 1895 r"""Converts a tensor of flat indices into a tuple of coordinate tensors that 1896 index into an arbitrary tensor of the specified shape. 1897 1898 Args: 1899 indices (Tensor): An integer tensor containing indices into the 1900 flattened version of an arbitrary tensor of shape :attr:`shape`. 1901 All elements must be in the range ``[0, prod(shape) - 1]``. 1902 1903 shape (int, sequence of ints, or torch.Size): The shape of the arbitrary 1904 tensor. All elements must be non-negative. 1905 1906 Returns: 1907 tuple of Tensors: Each ``i``-th tensor in the output corresponds with 1908 dimension ``i`` of :attr:`shape`. Each tensor has the same shape as 1909 ``indices`` and contains one index into dimension ``i`` for each of the 1910 flat indices given by ``indices``. 1911 1912 Example:: 1913 1914 >>> import torch 1915 >>> torch.unravel_index(torch.tensor(4), (3, 2)) 1916 (tensor(2), 1917 tensor(0)) 1918 1919 >>> torch.unravel_index(torch.tensor([4, 1]), (3, 2)) 1920 (tensor([2, 0]), 1921 tensor([0, 1])) 1922 1923 >>> torch.unravel_index(torch.tensor([0, 1, 2, 3, 4, 5]), (3, 2)) 1924 (tensor([0, 0, 1, 1, 2, 2]), 1925 tensor([0, 1, 0, 1, 0, 1])) 1926 1927 >>> torch.unravel_index(torch.tensor([1234, 5678]), (10, 10, 10, 10)) 1928 (tensor([1, 5]), 1929 tensor([2, 6]), 1930 tensor([3, 7]), 1931 tensor([4, 8])) 1932 1933 >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (10, 10, 10, 10)) 1934 (tensor([[1], [5]]), 1935 tensor([[2], [6]]), 1936 tensor([[3], [7]]), 1937 tensor([[4], [8]])) 1938 1939 >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (100, 100)) 1940 (tensor([[12], [56]]), 1941 tensor([[34], [78]])) 1942 """ 1943 if has_torch_function_unary(indices): 1944 return handle_torch_function(unravel_index, (indices,), indices, shape=shape) 1945 res_tensor = _unravel_index(indices, shape) 1946 return res_tensor.unbind(-1) 1947 1948 1949def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor: 1950 torch._check_type( 1951 not indices.is_complex() 1952 and not indices.is_floating_point() 1953 and not indices.dtype == torch.bool, 1954 lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}", 1955 ) 1956 1957 torch._check_type( 1958 isinstance(shape, (int, torch.SymInt, Sequence)), 1959 lambda: f"expected 'shape' to be int or sequence of ints, but got {type(shape)}", 1960 ) 1961 1962 if isinstance(shape, (int, torch.SymInt)): 1963 shape = torch.Size([shape]) 1964 else: 1965 for dim in shape: 1966 torch._check_type( 1967 isinstance(dim, (int, torch.SymInt)), 1968 lambda: f"expected 'shape' sequence to only contain ints, but got {type(dim)}", 1969 ) 1970 shape = torch.Size(shape) 1971 1972 torch._check_value( 1973 all(dim >= 0 for dim in shape), 1974 lambda: f"'shape' cannot have negative values, but got {tuple(shape)}", 1975 ) 1976 1977 coefs = list( 1978 reversed( 1979 list( 1980 itertools.accumulate( 1981 reversed(shape[1:] + torch.Size([1])), func=operator.mul 1982 ) 1983 ) 1984 ) 1985 ) 1986 return indices.unsqueeze(-1).floor_divide( 1987 torch.tensor(coefs, device=indices.device, dtype=torch.int64) 1988 ) % torch.tensor(shape, device=indices.device, dtype=torch.int64) 1989 1990 1991def chain_matmul(*matrices, out=None): 1992 r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed 1993 using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms 1994 of arithmetic operations (`[CLRS]`_). Note that since this is a function to compute the product, :math:`N` 1995 needs to be greater than or equal to 2; if equal to 2 then a trivial matrix-matrix product is returned. 1996 If :math:`N` is 1, then this is a no-op - the original matrix is returned as is. 1997 1998 .. warning:: 1999 2000 :func:`torch.chain_matmul` is deprecated and will be removed in a future PyTorch release. 2001 Use :func:`torch.linalg.multi_dot` instead, which accepts a list of two or more tensors 2002 rather than multiple arguments. 2003 2004 Args: 2005 matrices (Tensors...): a sequence of 2 or more 2-D tensors whose product is to be determined. 2006 out (Tensor, optional): the output tensor. Ignored if :attr:`out` = ``None``. 2007 2008 Returns: 2009 Tensor: if the :math:`i^{th}` tensor was of dimensions :math:`p_{i} \times p_{i + 1}`, then the product 2010 would be of dimensions :math:`p_{1} \times p_{N + 1}`. 2011 2012 Example:: 2013 2014 >>> # xdoctest: +SKIP 2015 >>> # xdoctest: +IGNORE_WANT("non-deterministic") 2016 >>> a = torch.randn(3, 4) 2017 >>> b = torch.randn(4, 5) 2018 >>> c = torch.randn(5, 6) 2019 >>> d = torch.randn(6, 7) 2020 >>> # will raise a deprecation warning 2021 >>> torch.chain_matmul(a, b, c, d) 2022 tensor([[ -2.3375, -3.9790, -4.1119, -6.6577, 9.5609, -11.5095, -3.2614], 2023 [ 21.4038, 3.3378, -8.4982, -5.2457, -10.2561, -2.4684, 2.7163], 2024 [ -0.9647, -5.8917, -2.3213, -5.2284, 12.8615, -12.2816, -2.5095]]) 2025 2026 .. _`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition 2027 """ 2028 # This wrapper exists to support variadic args. 2029 if has_torch_function(matrices): 2030 return handle_torch_function(chain_matmul, matrices, *matrices) 2031 2032 if out is None: 2033 return _VF.chain_matmul(matrices) # type: ignore[attr-defined] 2034 else: 2035 return _VF.chain_matmul(matrices, out=out) # type: ignore[attr-defined] 2036 2037 2038def _lu_impl(A, pivot=True, get_infos=False, out=None): 2039 # type: (Tensor, bool, bool, Any) -> Tuple[Tensor, Tensor, Tensor] 2040 r"""Computes the LU factorization of a matrix or batches of matrices 2041 :attr:`A`. Returns a tuple containing the LU factorization and 2042 pivots of :attr:`A`. Pivoting is done if :attr:`pivot` is set to 2043 ``True``. 2044 2045 .. warning:: 2046 2047 :func:`torch.lu` is deprecated in favor of :func:`torch.linalg.lu_factor` 2048 and :func:`torch.linalg.lu_factor_ex`. :func:`torch.lu` will be removed in a 2049 future PyTorch release. 2050 ``LU, pivots, info = torch.lu(A, compute_pivots)`` should be replaced with 2051 2052 .. code:: python 2053 2054 LU, pivots = torch.linalg.lu_factor(A, compute_pivots) 2055 2056 ``LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)`` should be replaced with 2057 2058 .. code:: python 2059 2060 LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots) 2061 2062 .. note:: 2063 * The returned permutation matrix for every matrix in the batch is 2064 represented by a 1-indexed vector of size ``min(A.shape[-2], A.shape[-1])``. 2065 ``pivots[i] == j`` represents that in the ``i``-th step of the algorithm, 2066 the ``i``-th row was permuted with the ``j-1``-th row. 2067 * LU factorization with :attr:`pivot` = ``False`` is not available 2068 for CPU, and attempting to do so will throw an error. However, 2069 LU factorization with :attr:`pivot` = ``False`` is available for 2070 CUDA. 2071 * This function does not check if the factorization was successful 2072 or not if :attr:`get_infos` is ``True`` since the status of the 2073 factorization is present in the third element of the return tuple. 2074 * In the case of batches of square matrices with size less or equal 2075 to 32 on a CUDA device, the LU factorization is repeated for 2076 singular matrices due to the bug in the MAGMA library 2077 (see magma issue 13). 2078 * ``L``, ``U``, and ``P`` can be derived using :func:`torch.lu_unpack`. 2079 2080 .. warning:: 2081 The gradients of this function will only be finite when :attr:`A` is full rank. 2082 This is because the LU decomposition is just differentiable at full rank matrices. 2083 Furthermore, if :attr:`A` is close to not being full rank, 2084 the gradient will be numerically unstable as it depends on the computation of :math:`L^{-1}` and :math:`U^{-1}`. 2085 2086 Args: 2087 A (Tensor): the tensor to factor of size :math:`(*, m, n)` 2088 pivot (bool, optional): controls whether pivoting is done. Default: ``True`` 2089 get_infos (bool, optional): if set to ``True``, returns an info IntTensor. 2090 Default: ``False`` 2091 out (tuple, optional): optional output tuple. If :attr:`get_infos` is ``True``, 2092 then the elements in the tuple are Tensor, IntTensor, 2093 and IntTensor. If :attr:`get_infos` is ``False``, then the 2094 elements in the tuple are Tensor, IntTensor. Default: ``None`` 2095 2096 Returns: 2097 (Tensor, IntTensor, IntTensor (optional)): A tuple of tensors containing 2098 2099 - **factorization** (*Tensor*): the factorization of size :math:`(*, m, n)` 2100 2101 - **pivots** (*IntTensor*): the pivots of size :math:`(*, \text{min}(m, n))`. 2102 ``pivots`` stores all the intermediate transpositions of rows. 2103 The final permutation ``perm`` could be reconstructed by 2104 applying ``swap(perm[i], perm[pivots[i] - 1])`` for ``i = 0, ..., pivots.size(-1) - 1``, 2105 where ``perm`` is initially the identity permutation of :math:`m` elements 2106 (essentially this is what :func:`torch.lu_unpack` is doing). 2107 2108 - **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of 2109 size :math:`(*)` where non-zero values indicate whether factorization for the matrix or 2110 each minibatch has succeeded or failed 2111 2112 Example:: 2113 2114 >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) 2115 >>> # xdoctest: +IGNORE_WANT("non-deterministic") 2116 >>> A = torch.randn(2, 3, 3) 2117 >>> A_LU, pivots = torch.lu(A) 2118 >>> A_LU 2119 tensor([[[ 1.3506, 2.5558, -0.0816], 2120 [ 0.1684, 1.1551, 0.1940], 2121 [ 0.1193, 0.6189, -0.5497]], 2122 2123 [[ 0.4526, 1.2526, -0.3285], 2124 [-0.7988, 0.7175, -0.9701], 2125 [ 0.2634, -0.9255, -0.3459]]]) 2126 >>> pivots 2127 tensor([[ 3, 3, 3], 2128 [ 3, 3, 3]], dtype=torch.int32) 2129 >>> A_LU, pivots, info = torch.lu(A, get_infos=True) 2130 >>> if info.nonzero().size(0) == 0: 2131 ... print('LU factorization succeeded for all samples!') 2132 LU factorization succeeded for all samples! 2133 """ 2134 # If get_infos is True, then we don't need to check for errors and vice versa 2135 return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos)) 2136 2137 2138if TYPE_CHECKING: 2139 _ListOrSeq = Sequence[Tensor] 2140else: 2141 _ListOrSeq = List[Tensor] 2142 2143 2144def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None: 2145 get_infos_int = 1 if get_infos else 0 2146 if out_len - get_infos_int != 2: 2147 raise TypeError( 2148 f"expected tuple of {2 + int(get_infos)} elements but got {out_len}" 2149 ) 2150 if not isinstance(out, (tuple, list)): 2151 raise TypeError( 2152 f"argument 'out' must be tuple of Tensors, not {type(out).__name__}" 2153 ) 2154 2155 2156def _lu_with_infos(A, pivot=True, get_infos=False, out=None): 2157 # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor] 2158 if has_torch_function_unary(A): 2159 return handle_torch_function( 2160 lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out 2161 ) 2162 result = _lu_impl(A, pivot, get_infos, out) 2163 if out is not None: 2164 _check_list_size(len(out), get_infos, out) 2165 for i in range(len(out)): 2166 out[i].resize_as_(result[i]).copy_(result[i]) 2167 return out 2168 else: 2169 return result # A_LU, pivots, infos 2170 2171 2172def _lu_no_infos(A, pivot=True, get_infos=False, out=None): 2173 # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor] 2174 # need to check for torch_function here so that we exit if 2175 if has_torch_function_unary(A): 2176 return handle_torch_function( 2177 lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out 2178 ) 2179 result = _lu_impl(A, pivot, get_infos, out) 2180 if out is not None: 2181 _check_list_size(len(out), get_infos, out) 2182 for i in range(len(out)): 2183 out[i].resize_as_(result[i]).copy_(result[i]) 2184 return out 2185 else: 2186 return result[0], result[1] # A_LU, pivots 2187 2188 2189# The return type of lu depends on `get_infos`, so in order to resolve the output type 2190# of lu in TorchScript we need to statically know the value of `get_infos` 2191lu = boolean_dispatch( 2192 arg_name="get_infos", 2193 arg_index=2, 2194 default=False, 2195 if_true=_lu_with_infos, 2196 if_false=_lu_no_infos, 2197 module_name=__name__, 2198 func_name="lu", 2199) 2200lu.__doc__ = _lu_impl.__doc__ 2201 2202 2203def align_tensors(*tensors): 2204 raise RuntimeError("`align_tensors` not yet implemented.") 2205