1# mypy: allow-untyped-defs
2import itertools
3import operator
4from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
5
6import torch
7import torch.nn.functional as F
8from torch import _VF, Tensor
9from torch._C import _add_docstr
10from torch._jit_internal import _overload as overload, boolean_dispatch
11from torch._lowrank import pca_lowrank, svd_lowrank
12from torch.overrides import (
13    handle_torch_function,
14    has_torch_function,
15    has_torch_function_unary,
16    has_torch_function_variadic,
17)
18
19
20__all__ = [
21    "atleast_1d",
22    "atleast_2d",
23    "atleast_3d",
24    "align_tensors",
25    "broadcast_shapes",
26    "broadcast_tensors",
27    "cartesian_prod",
28    "block_diag",
29    "cdist",
30    "chain_matmul",
31    "einsum",
32    "istft",
33    "lu",
34    "norm",
35    "meshgrid",
36    "pca_lowrank",
37    "split",
38    "stft",
39    "svd_lowrank",
40    "tensordot",
41    "unique",
42    "unique_consecutive",
43    "unravel_index",
44]
45
46
47def broadcast_tensors(*tensors):
48    r"""broadcast_tensors(*tensors) -> List of Tensors
49
50    Broadcasts the given tensors according to :ref:`broadcasting-semantics`.
51
52    Args:
53        *tensors: any number of tensors of the same type
54
55    .. warning::
56
57        More than one element of a broadcasted tensor may refer to a single
58        memory location. As a result, in-place operations (especially ones that
59        are vectorized) may result in incorrect behavior. If you need to write
60        to the tensors, please clone them first.
61
62    Example::
63
64        >>> x = torch.arange(3).view(1, 3)
65        >>> y = torch.arange(2).view(2, 1)
66        >>> a, b = torch.broadcast_tensors(x, y)
67        >>> a.size()
68        torch.Size([2, 3])
69        >>> a
70        tensor([[0, 1, 2],
71                [0, 1, 2]])
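
        As the warning above notes, elements of the broadcast result may share
        memory, so clone before writing in place. For example:

        >>> c = a.clone()
        >>> c[0, 0] = 9
        >>> c
        tensor([[9, 1, 2],
                [0, 1, 2]])
        >>> x  # the broadcast inputs are unchanged
        tensor([[0, 1, 2]])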
72    """
73    # This wrapper exists to support variadic args.
74    if has_torch_function(tensors):
75        return handle_torch_function(broadcast_tensors, tensors, *tensors)
76    return _VF.broadcast_tensors(tensors)  # type: ignore[attr-defined]
77
78
79def broadcast_shapes(*shapes):
80    r"""broadcast_shapes(*shapes) -> Size
81
82    Similar to :func:`broadcast_tensors` but for shapes.
83
84    This is equivalent to
85    ``torch.broadcast_tensors(*map(torch.empty, shapes))[0].shape``
    but avoids the need to create intermediate tensors. This is useful for
87    broadcasting tensors of common batch shape but different rightmost shape,
88    e.g. to broadcast mean vectors with covariance matrices.
89
90    Example::
91
92        >>> torch.broadcast_shapes((2,), (3, 1), (1, 1, 1))
93        torch.Size([1, 3, 2])
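        >>> # torch.Size objects and plain tuples can be mixed:
        >>> torch.broadcast_shapes(torch.Size([1, 4]), (3, 1))
        torch.Size([3, 4])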
94
95    Args:
96        \*shapes (torch.Size): Shapes of tensors.
97
98    Returns:
99        shape (torch.Size): A shape compatible with all input shapes.
100
101    Raises:
102        RuntimeError: If shapes are incompatible.
103    """
104    # This wrapper exists to support variadic args.
105    # TODO Move this to C++ once the jit has better support for torch.Size.
106    if not torch.jit.is_tracing():
107        max_len = 0
108        for shape in shapes:
109            if isinstance(shape, (int, torch.SymInt)):
110                if max_len < 1:
111                    max_len = 1
112            elif isinstance(shape, (tuple, list)):
113                s = len(shape)
114                if max_len < s:
115                    max_len = s
116        result = [1] * max_len
117
118        from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
119
120        for shape in shapes:
121            if isinstance(shape, (int, torch.SymInt)):
122                shape = (shape,)
123            if isinstance(shape, (tuple, list)):
124                for i in range(-1, -1 - len(shape), -1):
125                    if shape[i] < 0:
126                        raise RuntimeError(
127                            f"Trying to create tensor with negative dimension ({shape[i]}): ({shape[i]})"
128                        )
129                    # NB: result is initialized to 1 so this is effectively an
130                    # equals one test
131                    if guard_size_oblivious(shape[i] == 1) or guard_size_oblivious(
132                        shape[i] == result[i]
133                    ):
134                        continue
135                    if result[i] != 1:
136                        raise RuntimeError(
137                            "Shape mismatch: objects cannot be broadcast to a single shape"
138                        )
139                    result[i] = shape[i]
140            else:
141                raise RuntimeError(
                    "Input shapes should be an int, a tuple of ints, or a list of ints, got ",
143                    shape,
144                )
145        return torch.Size(result)
146    else:
147        # with implementation above, torch.jit.trace hardcodes the sizes which makes subsequent replays fail
148        with torch.no_grad():
149            scalar = torch.zeros((), device="cpu")
150            tensors = [scalar.expand(shape) for shape in shapes]
151            tensors = broadcast_tensors(*tensors)
152            return tensors[0].shape
153
154
155def split(
156    tensor: Tensor,
157    split_size_or_sections: Union[int, List[int]],
158    dim: int = 0,
159) -> Tuple[Tensor, ...]:
160    r"""Splits the tensor into chunks. Each chunk is a view of the original tensor.
161
162    If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will
163    be split into equally sized chunks (if possible). Last chunk will be smaller if
164    the tensor size along the given dimension :attr:`dim` is not divisible by
165    :attr:`split_size`.
166
167    If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split
168    into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according
169    to :attr:`split_size_or_sections`.
170
171    Args:
172        tensor (Tensor): tensor to split.
173        split_size_or_sections (int) or (list(int)): size of a single chunk or
174            list of sizes for each chunk
175        dim (int): dimension along which to split the tensor.
176
177    Example::
178
179        >>> a = torch.arange(10).reshape(5, 2)
180        >>> a
181        tensor([[0, 1],
182                [2, 3],
183                [4, 5],
184                [6, 7],
185                [8, 9]])
186        >>> torch.split(a, 2)
187        (tensor([[0, 1],
188                 [2, 3]]),
189         tensor([[4, 5],
190                 [6, 7]]),
191         tensor([[8, 9]]))
192        >>> torch.split(a, [1, 4])
193        (tensor([[0, 1]]),
194         tensor([[2, 3],
195                 [4, 5],
196                 [6, 7],
197                 [8, 9]]))
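        >>> # splitting along dim=1 instead yields per-column chunks
        >>> left, right = torch.split(a, 1, dim=1)
        >>> left.shape, right.shape
        (torch.Size([5, 1]), torch.Size([5, 1]))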
198    """
199    if has_torch_function_unary(tensor):
200        return handle_torch_function(
201            split, (tensor,), tensor, split_size_or_sections, dim=dim
202        )
203    # Overwriting reason:
204    # This dispatches to two ATen functions depending on the type of
205    # split_size_or_sections. The branching code is in _tensor.py, which we
206    # call here.
207    return tensor.split(split_size_or_sections, dim)
208
209
210def einsum(*args: Any) -> Tensor:
211    r"""einsum(equation, *operands) -> Tensor
212
213    Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation
214    based on the Einstein summation convention.
215
216    Einsum allows computing many common multi-dimensional linear algebraic array operations by representing them
217    in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of
218    this format are described below, but the general idea is to label every dimension of the input :attr:`operands`
219    with some subscript and define which subscripts are part of the output. The output is then computed by summing
220    the product of the elements of the :attr:`operands` along the dimensions whose subscripts are not part of the
221    output. For example, matrix multiplication can be computed using einsum as `torch.einsum("ij,jk->ik", A, B)`.
222    Here, j is the summation subscript and i and k the output subscripts (see section below for more details on why).
223
224    Equation:
225
226        The :attr:`equation` string specifies the subscripts (letters in `[a-zA-Z]`) for each dimension of
227        the input :attr:`operands` in the same order as the dimensions, separating subscripts for each operand by a
228        comma (','), e.g. `'ij,jk'` specify subscripts for two 2D operands. The dimensions labeled with the same subscript
229        must be broadcastable, that is, their size must either match or be `1`. The exception is if a subscript is
230        repeated for the same input operand, in which case the dimensions labeled with this subscript for this operand
231        must match in size and the operand will be replaced by its diagonal along these dimensions. The subscripts that
232        appear exactly once in the :attr:`equation` will be part of the output, sorted in increasing alphabetical order.
233        The output is computed by multiplying the input :attr:`operands` element-wise, with their dimensions aligned based
234        on the subscripts, and then summing out the dimensions whose subscripts are not part of the output.
235
236        Optionally, the output subscripts can be explicitly defined by adding an arrow ('->') at the end of the equation
237        followed by the subscripts for the output. For instance, the following equation computes the transpose of a
238        matrix multiplication: 'ij,jk->ki'. The output subscripts must appear at least once for some input operand and
239        at most once for the output.
240
241        Ellipsis ('...') can be used in place of subscripts to broadcast the dimensions covered by the ellipsis.
242        Each input operand may contain at most one ellipsis which will cover the dimensions not covered by subscripts,
        e.g. for an input operand with 5 dimensions, the ellipsis in the equation `'ab...c'` covers the third and fourth
244        dimensions. The ellipsis does not need to cover the same number of dimensions across the :attr:`operands` but the
245        'shape' of the ellipsis (the size of the dimensions covered by them) must broadcast together. If the output is not
246        explicitly defined with the arrow ('->') notation, the ellipsis will come first in the output (left-most dimensions),
        before the subscript labels that appear exactly once for the input operands. For example, the following equation implements
248        batch matrix multiplication `'...ij,...jk'`.
249
250        A few final notes: the equation may contain whitespaces between the different elements (subscripts, ellipsis,
251        arrow and comma) but something like `'. . .'` is not valid. An empty string `''` is valid for scalar operands.
252
253    .. note::
254
255        ``torch.einsum`` handles ellipsis ('...') differently from NumPy in that it allows dimensions
256        covered by the ellipsis to be summed over, that is, ellipsis are not required to be part of the output.
257
258    .. note::
259
260        This function uses opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/) to speed up computation or to
261        consume less memory by optimizing contraction order. This optimization occurs when there are at least three
262        inputs, since the order does not matter otherwise. Note that finding _the_ optimal path is an NP-hard problem,
263        thus, opt_einsum relies on different heuristics to achieve near-optimal results. If opt_einsum is not available,
264        the default order is to contract from left to right.
265
266        To bypass this default behavior, add the following line to disable the usage of opt_einsum and skip path
267        calculation: `torch.backends.opt_einsum.enabled = False`
268
269        To specify which strategy you'd like for opt_einsum to compute the contraction path, add the following line:
270        `torch.backends.opt_einsum.strategy = 'auto'`. The default strategy is 'auto', and we also support 'greedy' and
271        'optimal'. Disclaimer that the runtime of 'optimal' is factorial in the number of inputs! See more details in
272        the opt_einsum documentation (https://optimized-einsum.readthedocs.io/en/stable/path_finding.html).
273
274    .. note::
275
276        As of PyTorch 1.10 :func:`torch.einsum` also supports the sublist format (see examples below). In this format,
        subscripts for each operand are specified by sublists, lists of integers in the range [0, 52). These sublists
        follow their operands, and an extra sublist can appear at the end of the input to specify the output's
        subscripts, e.g. `torch.einsum(op1, sublist1, op2, sublist2, ..., [sublist_out])`. Python's `Ellipsis` object
280        may be provided in a sublist to enable broadcasting as described in the Equation section above.
281
282    Args:
283        equation (str): The subscripts for the Einstein summation.
284        operands (List[Tensor]): The tensors to compute the Einstein summation of.
285
286    Examples::
287
288        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
289        >>> # trace
290        >>> torch.einsum('ii', torch.randn(4, 4))
291        tensor(-1.2104)
292
293        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
294        >>> # diagonal
295        >>> torch.einsum('ii->i', torch.randn(4, 4))
296        tensor([-0.1034,  0.7952, -0.2433,  0.4545])
297
298        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
299        >>> # outer product
300        >>> x = torch.randn(5)
301        >>> y = torch.randn(4)
302        >>> torch.einsum('i,j->ij', x, y)
303        tensor([[ 0.1156, -0.2897, -0.3918,  0.4963],
304                [-0.3744,  0.9381,  1.2685, -1.6070],
305                [ 0.7208, -1.8058, -2.4419,  3.0936],
306                [ 0.1713, -0.4291, -0.5802,  0.7350],
307                [ 0.5704, -1.4290, -1.9323,  2.4480]])
308
309        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
310        >>> # batch matrix multiplication
311        >>> As = torch.randn(3, 2, 5)
312        >>> Bs = torch.randn(3, 5, 4)
313        >>> torch.einsum('bij,bjk->bik', As, Bs)
314        tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
315                [-1.6706, -0.8097, -0.8025, -2.1183]],
316
317                [[ 4.2239,  0.3107, -0.5756, -0.2354],
318                [-1.4558, -0.3460,  1.5087, -0.8530]],
319
320                [[ 2.8153,  1.8787, -4.3839, -1.2112],
321                [ 0.3728, -2.1131,  0.0921,  0.8305]]])
322
323        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
324        >>> # with sublist format and ellipsis
325        >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
326        tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
327                [-1.6706, -0.8097, -0.8025, -2.1183]],
328
329                [[ 4.2239,  0.3107, -0.5756, -0.2354],
330                [-1.4558, -0.3460,  1.5087, -0.8530]],
331
332                [[ 2.8153,  1.8787, -4.3839, -1.2112],
333                [ 0.3728, -2.1131,  0.0921,  0.8305]]])
334
335        >>> # batch permute
336        >>> A = torch.randn(2, 3, 4, 5)
337        >>> torch.einsum('...ij->...ji', A).shape
338        torch.Size([2, 3, 5, 4])
339
340        >>> # equivalent to torch.nn.functional.bilinear
341        >>> A = torch.randn(3, 5, 4)
342        >>> l = torch.randn(2, 5)
343        >>> r = torch.randn(2, 4)
344        >>> torch.einsum('bn,anm,bm->ba', l, A, r)
345        tensor([[-0.3430, -5.2405,  0.4494],
346                [ 0.3311,  5.5201, -3.0356]])
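
        >>> # a fully deterministic version of the matrix multiplication
        >>> # example ('ij,jk->ik') from the description above
        >>> A = torch.arange(6.).reshape(2, 3)
        >>> B = torch.arange(12.).reshape(3, 4)
        >>> torch.einsum('ij,jk->ik', A, B)
        tensor([[20., 23., 26., 29.],
                [56., 68., 80., 92.]])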
347    """
348    import torch.backends.opt_einsum as opt_einsum
349
350    # This wrapper exists to support variadic args.
351    if len(args) < 2:
352        raise ValueError(
353            "einsum(): must specify the equation string and at least one operand, "
354            "or at least one operand and its subscripts list"
355        )
356
357    equation = None
358    operands = None
359
360    if isinstance(args[0], torch.Tensor):
361        # Convert the subscript list format which is an interleaving of operand and its subscripts
362        # list with an optional output subscripts list at the end (see documentation for more details on this)
363        # to the equation string format by creating the equation string from the subscripts list and grouping the
364        # input operands into a tensorlist (List[Tensor]).
365        def parse_subscript(n: int) -> str:
366            if n == Ellipsis:
367                return "..."
368            if n >= 0 and n < 26:
369                return chr(ord("A") + n)
370            if n >= 26 and n < 52:
371                return chr(ord("a") + n - 26)
372            raise ValueError(
373                "einsum(): subscript in subscript list is not within the valid range [0, 52)"
374            )
375
376        # Parse subscripts for input operands
377        equation = ",".join("".join(parse_subscript(s) for s in l) for l in args[1::2])
378
379        # Parse optional output subscripts (provided when the number of arguments is odd)
380        if len(args) % 2 == 1:
381            equation += "->" + "".join(parse_subscript(s) for s in args[-1])
382            operands = args[:-1:2]
383        else:
384            operands = args[::2]
385    else:
386        equation = args[0]
387        operands = args[1:]
388
389    if has_torch_function(operands):
390        return handle_torch_function(einsum, operands, equation, *operands)
391
392    if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
393        # the old interface of passing the operands as one list argument
394        _operands = operands[0]
        # recurse in case operands contains a value that has a torch function
396        # in the original implementation this line is omitted
397        return einsum(equation, *_operands)
398
399    if len(operands) <= 2 or not opt_einsum.enabled:
400        # the path for contracting 0 or 1 time(s) is already optimized
401        # or the user has disabled using opt_einsum
402        return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
403
404    path = None
405    if opt_einsum.is_available():
406        _opt_einsum = opt_einsum.get_opt_einsum()
407        tupled_path = _opt_einsum.contract_path(
408            equation, *operands, optimize=opt_einsum.strategy
409        )[0]
410        # flatten path for dispatching to C++
411        path = [item for pair in tupled_path for item in pair]
412    return _VF.einsum(equation, operands, path=path)  # type: ignore[attr-defined]
413
414
415# This wrapper exists to support variadic args.
416if TYPE_CHECKING:
417    # The JIT doesn't understand Union, so only add type annotation for mypy
418    def meshgrid(
419        *tensors: Union[Tensor, List[Tensor]], indexing: Optional[str] = None
420    ) -> Tuple[Tensor, ...]:
421        return _meshgrid(*tensors, indexing=indexing)
422
423else:
424
425    def meshgrid(*tensors, indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
        r"""Creates grids of coordinates specified by the 1D inputs in :attr:`tensors`.
427
428        This is helpful when you want to visualize data over some
429        range of inputs. See below for a plotting example.
430
431        Given :math:`N` 1D tensors :math:`T_0 \ldots T_{N-1}` as
432        inputs with corresponding sizes :math:`S_0 \ldots S_{N-1}`,
433        this creates :math:`N` N-dimensional tensors :math:`G_0 \ldots
434        G_{N-1}`, each with shape :math:`(S_0, ..., S_{N-1})` where
435        the output :math:`G_i` is constructed by expanding :math:`T_i`
436        to the result shape.
437
438        .. note::
439            0D inputs are treated equivalently to 1D inputs of a
440            single element.
441
442        .. warning::
443            `torch.meshgrid(*tensors)` currently has the same behavior
444            as calling `numpy.meshgrid(*arrays, indexing='ij')`.
445
446            In the future `torch.meshgrid` will transition to
447            `indexing='xy'` as the default.
448
449            https://github.com/pytorch/pytorch/issues/50276 tracks
450            this issue with the goal of migrating to NumPy's behavior.
451
452        .. seealso::
453
454            :func:`torch.cartesian_prod` has the same effect but it
455            collects the data in a tensor of vectors.
456
457        Args:
458            tensors (list of Tensor): list of scalars or 1 dimensional tensors. Scalars will be
459                treated as tensors of size :math:`(1,)` automatically
460
            indexing (str, optional): the indexing mode, either "xy"
462                or "ij", defaults to "ij". See warning for future changes.
463
464                If "xy" is selected, the first dimension corresponds
465                to the cardinality of the second input and the second
466                dimension corresponds to the cardinality of the first
467                input.
468
469                If "ij" is selected, the dimensions are in the same
470                order as the cardinality of the inputs.
471
472        Returns:
473            seq (sequence of Tensors): If the input has :math:`N`
            tensors of size :math:`S_0 \ldots S_{N-1}`, then the
475            output will also have :math:`N` tensors, where each tensor
476            is of shape :math:`(S_0, ..., S_{N-1})`.
477
478        Example::
479
480            >>> x = torch.tensor([1, 2, 3])
481            >>> y = torch.tensor([4, 5, 6])
482
483            Observe the element-wise pairings across the grid, (1, 4),
484            (1, 5), ..., (3, 6). This is the same thing as the
485            cartesian product.
486            >>> grid_x, grid_y = torch.meshgrid(x, y, indexing='ij')
487            >>> grid_x
488            tensor([[1, 1, 1],
489                    [2, 2, 2],
490                    [3, 3, 3]])
491            >>> grid_y
492            tensor([[4, 5, 6],
493                    [4, 5, 6],
494                    [4, 5, 6]])
495
496            This correspondence can be seen when these grids are
497            stacked properly.
498            >>> torch.equal(torch.cat(tuple(torch.dstack([grid_x, grid_y]))),
499            ...             torch.cartesian_prod(x, y))
500            True
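
            With `indexing='xy'` the roles of the first two dimensions are
            swapped, so the grids above come out transposed.
            >>> grid_x, grid_y = torch.meshgrid(x, y, indexing='xy')
            >>> grid_x
            tensor([[1, 2, 3],
                    [1, 2, 3],
                    [1, 2, 3]])
            >>> grid_y
            tensor([[4, 4, 4],
                    [5, 5, 5],
                    [6, 6, 6]])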
501
502            `torch.meshgrid` is commonly used to produce a grid for
503            plotting.
504            >>> # xdoctest: +REQUIRES(module:matplotlib)
505            >>> # xdoctest: +REQUIRES(env:DOCTEST_SHOW)
506            >>> import matplotlib.pyplot as plt
507            >>> xs = torch.linspace(-5, 5, steps=100)
508            >>> ys = torch.linspace(-5, 5, steps=100)
509            >>> x, y = torch.meshgrid(xs, ys, indexing='xy')
510            >>> z = torch.sin(torch.sqrt(x * x + y * y))
511            >>> ax = plt.axes(projection='3d')
512            >>> ax.plot_surface(x.numpy(), y.numpy(), z.numpy())
513            >>> plt.show()
514
515        .. image:: ../_static/img/meshgrid.png
516            :width: 512
517
518        """
519        return _meshgrid(*tensors, indexing=indexing)
520
521
522def _meshgrid(*tensors, indexing: Optional[str]):
523    if has_torch_function(tensors):
524        return handle_torch_function(meshgrid, tensors, *tensors, indexing=indexing)
525    if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):
526        # the old interface of passing the operands as one list argument
527        tensors = tensors[0]  # type: ignore[assignment]
528
529    # Continue allowing call of old method that takes no indexing
530    # kwarg for forward compatibility reasons.
531    #
532    # Remove this two weeks after landing.
533    kwargs = {} if indexing is None else {"indexing": indexing}
534    return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
535
536
537def stft(
538    input: Tensor,
539    n_fft: int,
540    hop_length: Optional[int] = None,
541    win_length: Optional[int] = None,
542    window: Optional[Tensor] = None,
543    center: bool = True,
544    pad_mode: str = "reflect",
545    normalized: bool = False,
546    onesided: Optional[bool] = None,
547    return_complex: Optional[bool] = None,
548) -> Tensor:
549    r"""Short-time Fourier transform (STFT).
550
551    .. warning::
552        From version 1.8.0, :attr:`return_complex` must always be given
553        explicitly for real inputs and `return_complex=False` has been
554        deprecated. Strongly prefer `return_complex=True` as in a future
555        pytorch release, this function will only return complex tensors.
556
557        Note that :func:`torch.view_as_real` can be used to recover a real
558        tensor with an extra last dimension for real and imaginary components.
559
560    .. warning::
561        From version 2.1, a warning will be provided if a :attr:`window` is
562        not specified. In a future release, this attribute will be required.
563        Not providing a window currently defaults to using a rectangular window,
564        which may result in undesirable artifacts. Consider using tapered windows,
565        such as :func:`torch.hann_window`.
566
567    The STFT computes the Fourier transform of short overlapping windows of the
    input, giving the frequency components of the signal as they change over
    time. The interface of this function is modeled after (but is *not* a drop-in
    replacement for) the librosa_ stft function.
571
572    .. _librosa: https://librosa.org/doc/latest/generated/librosa.stft.html
573
574    Ignoring the optional batch dimension, this method computes the following
575    expression:
576
577    .. math::
578        X[\omega, m] = \sum_{k = 0}^{\text{win\_length-1}}%
579                            \text{window}[k]\ \text{input}[m \times \text{hop\_length} + k]\ %
580                            \exp\left(- j \frac{2 \pi \cdot \omega k}{\text{n\_fft}}\right),
581
582    where :math:`m` is the index of the sliding window, and :math:`\omega` is
583    the frequency :math:`0 \leq \omega < \text{n\_fft}` for ``onesided=False``,
584    or :math:`0 \leq \omega < \lfloor \text{n\_fft} / 2 \rfloor + 1` for ``onesided=True``.
585
586    * :attr:`input` must be either a 1-D time sequence or a 2-D batch of time
587      sequences.
588
589    * If :attr:`hop_length` is ``None`` (default), it is treated as equal to
590      ``floor(n_fft / 4)``.
591
592    * If :attr:`win_length` is ``None`` (default), it is treated as equal to
593      :attr:`n_fft`.
594
595    * :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from
596      :meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is
597      treated as if having :math:`1` everywhere in the window. If
598      :math:`\text{win\_length} < \text{n\_fft}`, :attr:`window` will be padded on
599      both sides to length :attr:`n_fft` before being applied.
600
601    * If :attr:`center` is ``True`` (default), :attr:`input` will be padded on
602      both sides so that the :math:`t`-th frame is centered at time
603      :math:`t \times \text{hop\_length}`. Otherwise, the :math:`t`-th frame
604      begins at time  :math:`t \times \text{hop\_length}`.
605
606    * :attr:`pad_mode` determines the padding method used on :attr:`input` when
607      :attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for
608      all available options. Default is ``"reflect"``.
609
610    * If :attr:`onesided` is ``True`` (default for real input), only values for
611      :math:`\omega` in :math:`\left[0, 1, 2, \dots, \left\lfloor
      \frac{\text{n\_fft}}{2} \right\rfloor\right]` are returned because
      the real-to-complex Fourier transform satisfies the conjugate symmetry,
      i.e., :math:`X[\omega, m] = X[\text{n\_fft} - \omega, m]^*`.
615      Note if the input or window tensors are complex, then :attr:`onesided`
616      output is not possible.
617
618    * If :attr:`normalized` is ``True`` (default is ``False``), the function
619      returns the normalized STFT results, i.e., multiplied by :math:`(\text{frame\_length})^{-0.5}`.
620
621    * If :attr:`return_complex` is ``True`` (default if input is complex), the
622      return is a ``input.dim() + 1`` dimensional complex tensor. If ``False``,
623      the output is a ``input.dim() + 2`` dimensional real tensor where the last
624      dimension represents the real and imaginary components.
625
626    Returns either a complex tensor of size :math:`(* \times N \times T)` if
627    :attr:`return_complex` is true, or a real tensor of size :math:`(* \times N
    \times T \times 2)`, where :math:`*` is the optional batch size of
629    :attr:`input`, :math:`N` is the number of frequencies where STFT is applied
630    and :math:`T` is the total number of frames used.
631
632    .. warning::
633      This function changed signature at version 0.4.1. Calling with the
634      previous signature may cause error or return incorrect result.
635
636    Args:
637        input (Tensor): the input tensor of shape `(B?, L)` where `B?` is an optional
638            batch dimension
639        n_fft (int): size of Fourier transform
640        hop_length (int, optional): the distance between neighboring sliding window
641            frames. Default: ``None`` (treated as equal to ``floor(n_fft / 4)``)
642        win_length (int, optional): the size of window frame and STFT filter.
643            Default: ``None``  (treated as equal to :attr:`n_fft`)
644        window (Tensor, optional): the optional window function.
645            Shape must be 1d and `<= n_fft`
646            Default: ``None`` (treated as window of all :math:`1` s)
647        center (bool, optional): whether to pad :attr:`input` on both sides so
648            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
649            Default: ``True``
650        pad_mode (str, optional): controls the padding method used when
651            :attr:`center` is ``True``. Default: ``"reflect"``
652        normalized (bool, optional): controls whether to return the normalized STFT results
653             Default: ``False``
654        onesided (bool, optional): controls whether to return half of results to
655            avoid redundancy for real inputs.
656            Default: ``True`` for real :attr:`input` and :attr:`window`, ``False`` otherwise.
657        return_complex (bool, optional): whether to return a complex tensor, or
658            a real tensor with an extra last dimension for the real and
659            imaginary components.
660
661            .. versionchanged:: 2.0
662               ``return_complex`` is now a required argument for real inputs,
663               as the default is being transitioned to ``True``.
664
665            .. deprecated:: 2.0
666               ``return_complex=False`` is deprecated, instead use ``return_complex=True``
667               Note that calling :func:`torch.view_as_real` on the output will
668               recover the deprecated output format.
669
670    Returns:
671        Tensor: A tensor containing the STFT result with shape `(B?, N, T, C?)` where
672           - `B?` is an optional batch dimension from the input.
673           - `N` is the number of frequency samples, `(n_fft // 2) + 1` for
674             `onesided=True`, or otherwise `n_fft`.
675           - `T` is the number of frames, `1 + L // hop_length`
676             for `center=True`, or `1 + (L - n_fft) // hop_length` otherwise.
677           - `C?` is an optional length-2 dimension of real and imaginary
678             components, present when `return_complex=False`.
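
    Example::

        A short real signal with an explicit Hann window; only shapes are
        shown, since the complex values depend on floating point details:

        >>> signal = torch.arange(16, dtype=torch.float)
        >>> window = torch.hann_window(4)
        >>> spec = torch.stft(signal, n_fft=4, hop_length=2, window=window,
        ...                   return_complex=True)
        >>> spec.shape  # (n_fft // 2 + 1) frequencies, 1 + L // hop_length frames
        torch.Size([3, 9])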
679
680    """
681    if has_torch_function_unary(input):
682        return handle_torch_function(
683            stft,
684            (input,),
685            input,
686            n_fft,
687            hop_length=hop_length,
688            win_length=win_length,
689            window=window,
690            center=center,
691            pad_mode=pad_mode,
692            normalized=normalized,
693            onesided=onesided,
694            return_complex=return_complex,
695        )
696    # NOTE: Do not edit. This code will be removed once the forward-compatibility
697    #       period is over for PR #73432
698    if center:
699        signal_dim = input.dim()
700        extended_shape = [1] * (3 - signal_dim) + list(input.size())
701        pad = int(n_fft // 2)
702        input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
703        input = input.view(input.shape[-signal_dim:])
704    return _VF.stft(  # type: ignore[attr-defined]
705        input,
706        n_fft,
707        hop_length,
708        win_length,
709        window,
710        normalized,
711        onesided,
712        return_complex,
713    )
714
715
716istft = _add_docstr(
717    torch.istft,
718    "istft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, "
719    "normalized=False, onesided=None, length=None, return_complex=False) -> Tensor:\n"
720    r"""
721Inverse short time Fourier Transform. This is expected to be the inverse of :func:`~torch.stft`.
722
723.. warning::
724    From version 2.1, a warning will be provided if a :attr:`window` is
725    not specified. In a future release, this attribute will be required.
726    Please provide the same window used in the stft call.
727
It has the same parameters (plus an additional optional parameter :attr:`length`) and it should return the
least squares estimation of the original signal. The algorithm checks that the NOLA condition
(nonzero overlap-add) is satisfied.
731
An important consideration for the parameters :attr:`window` and :attr:`center` is that the envelope
created by the summation of all the windows must never be zero at any point in time. Specifically,
:math:`\sum_{t=-\infty}^{\infty} |w|^2[n-t\times hop\_length] \neq 0`.
735
736Since :func:`~torch.stft` discards elements at the end of the signal if they do not fit in a frame,
737``istft`` may return a shorter signal than the original signal (can occur if :attr:`center` is False
738since the signal isn't padded). If `length` is given in the arguments and is longer than expected,
739``istft`` will pad zeros to the end of the returned signal.
740
741If :attr:`center` is ``True``, then there will be padding e.g. ``'constant'``, ``'reflect'``, etc.
Left padding can be trimmed off exactly because it can be calculated, but right padding cannot be
743calculated without additional information.
744
745Example: Suppose the last window is:
746``[17, 18, 0, 0, 0]`` vs ``[18, 0, 0, 0, 0]``
747
748The :attr:`n_fft`, :attr:`hop_length`, :attr:`win_length` are all the same which prevents the calculation
749of right padding. These additional values could be zeros or a reflection of the signal so providing
750:attr:`length` could be useful. If :attr:`length` is ``None`` then padding will be aggressively removed
751(some loss of signal).
752
753[1] D. W. Griffin and J. S. Lim, "Signal estimation from modified short-time Fourier transform,"
754IEEE Trans. ASSP, vol.32, no.2, pp.236-243, Apr. 1984.
755
756Args:
757    input (Tensor): The input tensor. Expected to be in the format of :func:`~torch.stft`,
758        output. That is a complex tensor of shape `(B?, N, T)` where
759
760        - `B?` is an optional batch dimension
761        - `N` is the number of frequency samples, `(n_fft // 2) + 1`
762          for onesided input, or otherwise `n_fft`.
763        - `T` is the number of frames, `1 + length // hop_length` for centered stft,
764          or `1 + (length - n_fft) // hop_length` otherwise.
765
766        .. versionchanged:: 2.0
767            Real datatype inputs are no longer supported. Input must now have a
768            complex datatype, as returned by ``stft(..., return_complex=True)``.
769    n_fft (int): Size of Fourier transform
770    hop_length (Optional[int]): The distance between neighboring sliding window frames.
771        (Default: ``n_fft // 4``)
772    win_length (Optional[int]): The size of window frame and STFT filter. (Default: ``n_fft``)
773    window (Optional[torch.Tensor]): The optional window function.
774        Shape must be 1d and `<= n_fft`
775        (Default: ``torch.ones(win_length)``)
776    center (bool): Whether :attr:`input` was padded on both sides so that the :math:`t`-th frame is
777        centered at time :math:`t \times \text{hop\_length}`.
778        (Default: ``True``)
779    normalized (bool): Whether the STFT was normalized. (Default: ``False``)
780    onesided (Optional[bool]): Whether the STFT was onesided.
        (Default: ``True`` if the frequency dimension of :attr:`input` does not have size `n_fft`, i.e. the STFT was onesided)
782    length (Optional[int]): The amount to trim the signal by (i.e. the
783        original signal length). Defaults to `(T - 1) * hop_length` for
784        centered stft, or `n_fft + (T - 1) * hop_length` otherwise, where `T`
785        is the number of input frames.
786    return_complex (Optional[bool]):
787        Whether the output should be complex, or if the input should be
788        assumed to derive from a real signal and window.
789        Note that this is incompatible with ``onesided=True``.
790        (Default: ``False``)
791
792Returns:
793    Tensor: Least squares estimation of the original signal of shape `(B?, length)` where
794        `B?` is an optional batch dimension from the input tensor.
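
Example::

    A round-trip sketch: an STFT followed by ``istft`` with the same window
    recovers the original signal up to floating point error.

    >>> signal = torch.sin(torch.linspace(0, 6.28, 64))
    >>> window = torch.hann_window(16)
    >>> spec = torch.stft(signal, n_fft=16, window=window, return_complex=True)
    >>> reconstructed = torch.istft(spec, n_fft=16, window=window, length=64)
    >>> torch.allclose(signal, reconstructed, atol=1e-5)
    True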
795""",
796)
797
798
799if TYPE_CHECKING:
800    # These _impl functions return a variable number of tensors as output with
801    # __torch_function__; tuple unpacking is done already rather than being
802    # done by the caller of the _impl function
803    _unique_impl_out = Any
804else:
805    _unique_impl_out = Tuple[Tensor, Tensor, Tensor]
806
807
808def _unique_impl(
809    input: Tensor,
810    sorted: bool = True,
811    return_inverse: bool = False,
812    return_counts: bool = False,
813    dim: Optional[int] = None,
814) -> _unique_impl_out:
815    r"""unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) -> Tuple[Tensor, Tensor, Tensor]
816
817    Returns the unique elements of the input tensor.
818
819    .. note:: This function is different from :func:`torch.unique_consecutive` in the sense that
820        this function also eliminates non-consecutive duplicate values.
821
822    .. note:: Currently in the CUDA implementation and the CPU implementation,
        `torch.unique` always sorts the tensor at the beginning regardless of the `sorted` argument.
824        Sorting could be slow, so if your input tensor is already sorted, it is recommended to use
825        :func:`torch.unique_consecutive` which avoids the sorting.
826
827    Args:
828        input (Tensor): the input tensor
829        sorted (bool): Whether to sort the unique elements in ascending order
830            before returning as output.
831        return_inverse (bool): Whether to also return the indices for where
832            elements in the original input ended up in the returned unique list.
833        return_counts (bool): Whether to also return the counts for each unique
834            element.
835        dim (int, optional): the dimension to operate upon. If ``None``, the
836            unique of the flattened input is returned. Otherwise, each of the
837            tensors indexed by the given dimension is treated as one of the
838            elements to apply the unique operation upon. See examples for more
839            details. Default: ``None``
840
841    Returns:
842        (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing
843
844            - **output** (*Tensor*): the output list of unique scalar elements.
845            - **inverse_indices** (*Tensor*): (optional) if
846              :attr:`return_inverse` is True, there will be an additional
847              returned tensor (same shape as input) representing the indices
848              for where elements in the original input map to in the output;
849              otherwise, this function will only return a single tensor.
850            - **counts** (*Tensor*): (optional) if
851              :attr:`return_counts` is True, there will be an additional
852              returned tensor (same shape as output or output.size(dim),
853              if dim was specified) representing the number of occurrences
854              for each unique value or tensor.
855
856    Example::
857
858        >>> output = torch.unique(torch.tensor([1, 3, 2, 3], dtype=torch.long))
859        >>> output
860        tensor([1, 2, 3])
861
862        >>> output, inverse_indices = torch.unique(
863        ...     torch.tensor([1, 3, 2, 3], dtype=torch.long), sorted=True, return_inverse=True)
864        >>> output
865        tensor([1, 2, 3])
866        >>> inverse_indices
867        tensor([0, 2, 1, 2])
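
        >>> # return_counts=True additionally reports how often each unique value occurs
        >>> output, counts = torch.unique(
        ...     torch.tensor([1, 3, 2, 3], dtype=torch.long), return_counts=True)
        >>> output
        tensor([1, 2, 3])
        >>> counts
        tensor([1, 1, 2])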
868
869        >>> output, inverse_indices = torch.unique(
870        ...     torch.tensor([[1, 3], [2, 3]], dtype=torch.long), sorted=True, return_inverse=True)
871        >>> output
872        tensor([1, 2, 3])
873        >>> inverse_indices
874        tensor([[0, 2],
875                [1, 2]])
876
877        >>> a = torch.tensor([
878        ...     [
879        ...         [1, 1, 0, 0],
880        ...         [1, 1, 0, 0],
881        ...         [0, 0, 1, 1],
882        ...     ],
883        ...     [
884        ...         [0, 0, 1, 1],
885        ...         [0, 0, 1, 1],
886        ...         [1, 1, 1, 1],
887        ...     ],
888        ...     [
889        ...         [1, 1, 0, 0],
890        ...         [1, 1, 0, 0],
891        ...         [0, 0, 1, 1],
892        ...     ],
893        ... ])
894
895        >>> # If we call `torch.unique(a, dim=0)`, each of the tensors `a[idx, :, :]`
896        >>> # will be compared. We can see that `a[0, :, :]` and `a[2, :, :]` match
897        >>> # each other, so one of them will be removed.
898        >>> (a[0, :, :] == a[2, :, :]).all()
899        tensor(True)
900        >>> a_unique_dim0 = torch.unique(a, dim=0)
901        >>> a_unique_dim0
902        tensor([[[0, 0, 1, 1],
903                 [0, 0, 1, 1],
904                 [1, 1, 1, 1]],
905                [[1, 1, 0, 0],
906                 [1, 1, 0, 0],
907                 [0, 0, 1, 1]]])
908
909        >>> # Notice which sub-tensors from `a` match with the sub-tensors from
910        >>> # `a_unique_dim0`:
911        >>> (a_unique_dim0[0, :, :] == a[1, :, :]).all()
912        tensor(True)
913        >>> (a_unique_dim0[1, :, :] == a[0, :, :]).all()
914        tensor(True)
915
916        >>> # For `torch.unique(a, dim=1)`, each of the tensors `a[:, idx, :]` are
917        >>> # compared. `a[:, 0, :]` and `a[:, 1, :]` match each other, so one of
918        >>> # them will be removed.
919        >>> (a[:, 0, :] == a[:, 1, :]).all()
920        tensor(True)
921        >>> torch.unique(a, dim=1)
922        tensor([[[0, 0, 1, 1],
923                 [1, 1, 0, 0]],
924                [[1, 1, 1, 1],
925                 [0, 0, 1, 1]],
926                [[0, 0, 1, 1],
927                 [1, 1, 0, 0]]])
928
929        >>> # For `torch.unique(a, dim=2)`, the tensors `a[:, :, idx]` are compared.
930        >>> # `a[:, :, 0]` and `a[:, :, 1]` match each other. Also, `a[:, :, 2]` and
931        >>> # `a[:, :, 3]` match each other as well. So in this case, two of the
932        >>> # sub-tensors will be removed.
933        >>> (a[:, :, 0] == a[:, :, 1]).all()
934        tensor(True)
935        >>> (a[:, :, 2] == a[:, :, 3]).all()
936        tensor(True)
937        >>> torch.unique(a, dim=2)
938        tensor([[[0, 1],
939                 [0, 1],
940                 [1, 0]],
941                [[1, 0],
942                 [1, 0],
943                 [1, 1]],
944                [[0, 1],
945                 [0, 1],
946                 [1, 0]]])
947    """
948    if has_torch_function_unary(input):
949        return handle_torch_function(
950            unique,
951            (input,),
952            input,
953            sorted=sorted,
954            return_inverse=return_inverse,
955            return_counts=return_counts,
956            dim=dim,
957        )
958
959    if dim is not None:
960        output, inverse_indices, counts = _VF.unique_dim(
961            input,
962            dim,
963            sorted=sorted,
964            return_inverse=return_inverse,
965            return_counts=return_counts,
966        )
967    else:
968        output, inverse_indices, counts = torch._unique2(
969            input,
970            sorted=sorted,
971            return_inverse=return_inverse,
972            return_counts=return_counts,
973        )
974    return output, inverse_indices, counts
975
976
977def _unique_consecutive_impl(
978    input: Tensor,
979    return_inverse: bool = False,
980    return_counts: bool = False,
981    dim: Optional[int] = None,
982) -> _unique_impl_out:
983    r"""Eliminates all but the first element from every consecutive group of equivalent elements.
984
985    .. note:: This function is different from :func:`torch.unique` in the sense that this function
986        only eliminates consecutive duplicate values. This semantics is similar to `std::unique`
987        in C++.
988
989    Args:
990        input (Tensor): the input tensor
991        return_inverse (bool): Whether to also return the indices for where
992            elements in the original input ended up in the returned unique list.
993        return_counts (bool): Whether to also return the counts for each unique
994            element.
995        dim (int): the dimension to apply unique. If ``None``, the unique of the
996            flattened input is returned. default: ``None``
997
998    Returns:
999        (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing
1000
1001            - **output** (*Tensor*): the output list of unique scalar elements.
1002            - **inverse_indices** (*Tensor*): (optional) if
1003              :attr:`return_inverse` is True, there will be an additional
1004              returned tensor (same shape as input) representing the indices
1005              for where elements in the original input map to in the output;
1006              otherwise, this function will only return a single tensor.
1007            - **counts** (*Tensor*): (optional) if
1008              :attr:`return_counts` is True, there will be an additional
1009              returned tensor (same shape as output or output.size(dim),
1010              if dim was specified) representing the number of occurrences
1011              for each unique value or tensor.
1012
1013    Example::
1014
1015        >>> x = torch.tensor([1, 1, 2, 2, 3, 1, 1, 2])
1016        >>> output = torch.unique_consecutive(x)
1017        >>> output
1018        tensor([1, 2, 3, 1, 2])
1019
1020        >>> output, inverse_indices = torch.unique_consecutive(x, return_inverse=True)
1021        >>> output
1022        tensor([1, 2, 3, 1, 2])
1023        >>> inverse_indices
1024        tensor([0, 0, 1, 1, 2, 3, 3, 4])
1025
1026        >>> output, counts = torch.unique_consecutive(x, return_counts=True)
1027        >>> output
1028        tensor([1, 2, 3, 1, 2])
1029        >>> counts
1030        tensor([2, 2, 1, 2, 1])
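
        >>> # with dim=1 whole columns are compared, and consecutive duplicate
        >>> # columns are collapsed
        >>> y = torch.tensor([[1, 1, 2], [1, 1, 2]])
        >>> torch.unique_consecutive(y, dim=1)
        tensor([[1, 2],
                [1, 2]])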
1031    """
1032    if has_torch_function_unary(input):
1033        return handle_torch_function(
1034            unique_consecutive,
1035            (input,),
1036            input,
1037            return_inverse=return_inverse,
1038            return_counts=return_counts,
1039            dim=dim,
1040        )
1041    output, inverse_indices, counts = _VF.unique_consecutive(  # type: ignore[attr-defined]
1042        input, return_inverse=return_inverse, return_counts=return_counts, dim=dim
1043    )
1044    return output, inverse_indices, counts
1045
1046
1047def _return_counts(
1048    input,
1049    sorted=True,
1050    return_inverse=False,
1051    return_counts=False,
1052    dim=None,
1053):
1054    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
1055
1056    if has_torch_function_unary(input):
1057        return _unique_impl(input, sorted, return_inverse, return_counts, dim)
1058
1059    output, _, counts = _unique_impl(input, sorted, return_inverse, return_counts, dim)
1060    return output, counts
1061
1062
1063def _return_output(
1064    input,
1065    sorted=True,
1066    return_inverse=False,
1067    return_counts=False,
1068    dim=None,
1069):
1070    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor
1071
1072    if has_torch_function_unary(input):
1073        return _unique_impl(input, sorted, return_inverse, return_counts, dim)
1074
1075    output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
1076    return output
1077
1078
1079def _return_inverse(
1080    input,
1081    sorted=True,
1082    return_inverse=False,
1083    return_counts=False,
1084    dim=None,
1085):
1086    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
1087
1088    if has_torch_function_unary(input):
1089        return _unique_impl(input, sorted, return_inverse, return_counts, dim)
1090
1091    output, inverse_indices, _ = _unique_impl(
1092        input, sorted, return_inverse, return_counts, dim
1093    )
1094    return output, inverse_indices
1095
1096
1097_return_inverse_false = boolean_dispatch(
1098    arg_name="return_counts",
1099    arg_index=3,
1100    default=False,
1101    if_true=_return_counts,
1102    if_false=_return_output,
1103    module_name=__name__,
1104    func_name="unique",
1105)
1106
1107_return_inverse_true = boolean_dispatch(
1108    arg_name="return_counts",
1109    arg_index=3,
1110    default=False,
1111    if_true=_unique_impl,
1112    if_false=_return_inverse,
1113    module_name=__name__,
1114    func_name="unique",
1115)
1116
1117# The return type of unique depends on `return_inverse`, and `return_counts` so in order to
1118# resolve the output type in TorchScript we need to statically know the value of both parameters
1119
1120unique = boolean_dispatch(
1121    arg_name="return_inverse",
1122    arg_index=2,
1123    default=False,
1124    if_true=_return_inverse_true,
1125    if_false=_return_inverse_false,
1126    module_name=__name__,
1127    func_name="unique",
1128)
1129unique.__doc__ = _unique_impl.__doc__
1130
1131
1132def _consecutive_return_counts(
1133    input,
1134    return_inverse=False,
1135    return_counts=False,
1136    dim=None,
1137):
1138    # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
1139
1140    if has_torch_function_unary(input):
1141        return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
1142
1143    output, _, counts = _unique_consecutive_impl(
1144        input, return_inverse, return_counts, dim
1145    )
1146    return output, counts
1147
1148
1149def _consecutive_return_output(
1150    input,
1151    return_inverse=False,
1152    return_counts=False,
1153    dim=None,
1154):
1155    # type: (Tensor, bool, bool, Optional[int]) -> Tensor
1156
1157    if has_torch_function_unary(input):
1158        return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
1159
1160    output, _, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
1161    return output
1162
1163
1164def _consecutive_return_inverse(
1165    input,
1166    return_inverse=False,
1167    return_counts=False,
1168    dim=None,
1169):
1170    # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
1171
1172    if has_torch_function_unary(input):
1173        return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
1174
1175    output, inverse_indices, _ = _unique_consecutive_impl(
1176        input, return_inverse, return_counts, dim
1177    )
1178    return output, inverse_indices
1179
1180
1181_consecutive_return_inverse_false = boolean_dispatch(
1182    arg_name="return_counts",
1183    arg_index=1,
1184    default=False,
1185    if_true=_consecutive_return_counts,
1186    if_false=_consecutive_return_output,
1187    module_name=__name__,
1188    func_name="unique_consecutive",
1189)
1190
1191_consecutive_return_inverse_true = boolean_dispatch(
1192    arg_name="return_counts",
1193    arg_index=1,
1194    default=False,
1195    if_true=_unique_consecutive_impl,
1196    if_false=_consecutive_return_inverse,
1197    module_name=__name__,
1198    func_name="unique_consecutive",
1199)
1200
1201# The return type of unique depends on `return_inverse`, and `return_counts` so in order to
1202# resolve the output type in TorchScript we need to statically know the value of both parameters
1203
1204unique_consecutive = boolean_dispatch(
1205    arg_name="return_inverse",
1206    arg_index=2,
1207    default=False,
1208    if_true=_consecutive_return_inverse_true,
1209    if_false=_consecutive_return_inverse_false,
1210    module_name=__name__,
1211    func_name="unique_consecutive",
1212)
1213unique_consecutive.__doc__ = _unique_consecutive_impl.__doc__
1214
1215if TYPE_CHECKING:
1216    pass
1217    # There's no good way to use this type annotation without breaking JIT
1218    # overloads. So leave untyped for mypy for now.
1219else:
1220
1221    @overload
1222    def tensordot(
1223        a,
1224        b,
1225        dims: int = 2,
1226        out: Optional[torch.Tensor] = None,
1227    ):
1228        pass
1229
1230    @overload
1231    def tensordot(  # noqa: F811
1232        a,
1233        b,
1234        dims: Tuple[List[int], List[int]],
1235        out: Optional[torch.Tensor] = None,
1236    ):
1237        pass
1238
1239    @overload
1240    def tensordot(  # noqa: F811
1241        a,
1242        b,
1243        dims: List[List[int]],
1244        out: Optional[torch.Tensor] = None,
1245    ):
1246        pass
1247
1248    @overload
1249    def tensordot(  # noqa: F811
1250        a,
1251        b,
1252        dims: torch.Tensor,
1253        out: Optional[torch.Tensor] = None,
1254    ):
1255        pass
1256
1257
1258def tensordot(  # noqa: F811
1259    a,
1260    b,
1261    dims=2,
1262    out: Optional[torch.Tensor] = None,
1263):
1264    r"""Returns a contraction of a and b over multiple dimensions.
1265
1266    :attr:`tensordot` implements a generalized matrix product.
1267
1268    Args:
1269      a (Tensor): Left tensor to contract
1270      b (Tensor): Right tensor to contract
1271      dims (int or Tuple[List[int], List[int]] or List[List[int]] containing two lists or Tensor): number of dimensions to
1272         contract or explicit lists of dimensions for :attr:`a` and
1273         :attr:`b` respectively
1274
1275    When called with a non-negative integer argument :attr:`dims` = :math:`d`, and
1276    the number of dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`,
1277    respectively, :func:`~torch.tensordot` computes
1278
1279    .. math::
1280        r_{i_0,...,i_{m-d}, i_d,...,i_n}
1281          = \sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \times b_{k_0,...,k_{d-1}, i_d,...,i_n}.
1282
1283    When called with :attr:`dims` of the list form, the given dimensions will be contracted
1284    in place of the last :math:`d` of :attr:`a` and the first :math:`d` of :math:`b`. The sizes
1285    in these dimensions must match, but :func:`~torch.tensordot` will deal with broadcasted
1286    dimensions.
1287
1288    Examples::
1289
1290        >>> a = torch.arange(60.).reshape(3, 4, 5)
1291        >>> b = torch.arange(24.).reshape(4, 3, 2)
1292        >>> torch.tensordot(a, b, dims=([1, 0], [0, 1]))
1293        tensor([[4400., 4730.],
1294                [4532., 4874.],
1295                [4664., 5018.],
1296                [4796., 5162.],
1297                [4928., 5306.]])
1298
1299        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
1300        >>> a = torch.randn(3, 4, 5, device='cuda')
1301        >>> b = torch.randn(4, 5, 6, device='cuda')
1302        >>> c = torch.tensordot(a, b, dims=2).cpu()
1303        tensor([[ 8.3504, -2.5436,  6.2922,  2.7556, -1.0732,  3.2741],
1304                [ 3.3161,  0.0704,  5.0187, -0.4079, -4.3126,  4.8744],
1305                [ 0.8223,  3.9445,  3.2168, -0.2400,  3.4117,  1.7780]])
1306
1307        >>> a = torch.randn(3, 5, 4, 6)
1308        >>> b = torch.randn(6, 4, 5, 3)
1309        >>> torch.tensordot(a, b, dims=([2, 1, 3], [1, 2, 0]))
1310        tensor([[  7.7193,  -2.4867, -10.3204],
1311                [  1.5513, -14.4737,  -6.5113],
1312                [ -0.2850,   4.2573,  -3.5997]])
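
        >>> # dims=1 contracts the last dimension of a with the first dimension
        >>> # of b, i.e. an ordinary matrix product for 2-D inputs
        >>> a = torch.arange(6.).reshape(2, 3)
        >>> b = torch.arange(6.).reshape(3, 2)
        >>> torch.tensordot(a, b, dims=1)
        tensor([[10., 13.],
                [28., 40.]])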
1313    """
1314    if has_torch_function_variadic(a, b):
1315        return handle_torch_function(tensordot, (a, b), a, b, dims=dims, out=out)
1316
1317    if not isinstance(dims, (tuple, list, torch.Tensor, int, torch.SymInt)):
1318        raise RuntimeError(
1319            "tensordot expects dims to be int or "
1320            + "Tuple[List[int], List[int]] or "
1321            + "List[List[int]] containing two lists, but got "
1322            + f"dims={dims}"
1323        )
1324
1325    dims_a: List[int] = []
1326    dims_b: List[int] = []
1327
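    # Normalize every accepted form of `dims` into explicit per-tensor dimension
    # lists: an int (or 0-d tensor) pairs the last `dims` dimensions of `a` with
    # the first `dims` dimensions of `b`, while the list/tuple/2-row-tensor forms
    # are used as the explicit dimension lists for `a` and `b` respectively.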
1328    if isinstance(dims, (tuple, list)):
1329        dims_a, dims_b = dims
1330
1331    if isinstance(dims, torch.Tensor):
1332        num_elements = dims.numel()
1333        if num_elements > 1:
1334            assert dims.size()[0] == 2
1335            dims_a = torch.jit.annotate(List[int], dims[0].tolist())
1336            dims_b = torch.jit.annotate(List[int], dims[1].tolist())
1337        else:
1338            dims_val = int(dims.item())
1339            if dims_val < 0:
1340                raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
1341            dims_a = list(range(-dims_val, 0))
1342            dims_b = list(range(dims_val))
1343
1344    if isinstance(dims, (int, torch.SymInt)):
1345        if dims < 0:
1346            raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
1347        if dims > min(a.dim(), b.dim()):
1348            raise RuntimeError(
1349                f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}"
1350            )
1351        dims_a = list(range(-dims, 0))
1352        dims_b = list(range(dims))
1353
1354    if out is None:
1355        return _VF.tensordot(a, b, dims_a, dims_b)  # type: ignore[attr-defined]
1356    else:
1357        return _VF.tensordot(a, b, dims_a, dims_b, out=out)  # type: ignore[attr-defined]
1358
1359
1360def cartesian_prod(*tensors: Tensor) -> Tensor:
1361    """Do cartesian product of the given sequence of tensors. The behavior is similar to
1362    python's `itertools.product`.
1363
1364    Args:
1365        *tensors: any number of 1 dimensional tensors.
1366
1367    Returns:
1368        Tensor: A tensor equivalent to converting all the input tensors into lists,
1369        doing `itertools.product` on these lists, and finally converting the resulting
1370        list into a tensor.
1371
1372    Example::
1373
1374        >>> import itertools
1375        >>> a = [1, 2, 3]
1376        >>> b = [4, 5]
1377        >>> list(itertools.product(a, b))
1378        [(1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5)]
1379        >>> tensor_a = torch.tensor(a)
1380        >>> tensor_b = torch.tensor(b)
1381        >>> torch.cartesian_prod(tensor_a, tensor_b)
1382        tensor([[1, 4],
1383                [1, 5],
1384                [2, 4],
1385                [2, 5],
1386                [3, 4],
1387                [3, 5]])
1388    """
1389    # This wrapper exists to support variadic args.
1390    if has_torch_function(tensors):
1391        return handle_torch_function(cartesian_prod, tensors, *tensors)
1392    return _VF.cartesian_prod(tensors)  # type: ignore[attr-defined]
1393
1394
1395def block_diag(*tensors):
1396    """Create a block diagonal matrix from provided tensors.
1397
1398    Args:
1399        *tensors: One or more tensors with 0, 1, or 2 dimensions.
1400
1401    Returns:
1402        Tensor: A 2 dimensional tensor with all the input tensors arranged in
1403        order such that their upper left and lower right corners are
1404        diagonally adjacent. All other elements are set to 0.
1405
1406    Example::
1407
1408        >>> import torch
1409        >>> A = torch.tensor([[0, 1], [1, 0]])
1410        >>> B = torch.tensor([[3, 4, 5], [6, 7, 8]])
1411        >>> C = torch.tensor(7)
1412        >>> D = torch.tensor([1, 2, 3])
1413        >>> E = torch.tensor([[4], [5], [6]])
1414        >>> torch.block_diag(A, B, C, D, E)
1415        tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
1416                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
1417                [0, 0, 3, 4, 5, 0, 0, 0, 0, 0],
1418                [0, 0, 6, 7, 8, 0, 0, 0, 0, 0],
1419                [0, 0, 0, 0, 0, 7, 0, 0, 0, 0],
1420                [0, 0, 0, 0, 0, 0, 1, 2, 3, 0],
1421                [0, 0, 0, 0, 0, 0, 0, 0, 0, 4],
1422                [0, 0, 0, 0, 0, 0, 0, 0, 0, 5],
1423                [0, 0, 0, 0, 0, 0, 0, 0, 0, 6]])
1424    """
1425    # This wrapper exists to support variadic args.
1426    if has_torch_function(tensors):
1427        return handle_torch_function(block_diag, tensors, *tensors)
1428    return torch._C._VariableFunctions.block_diag(tensors)  # type: ignore[attr-defined]
1429
1430
1431def cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"):
1432    # type: (Tensor, Tensor, float, str) -> (Tensor)
1433    r"""Computes batched the p-norm distance between each pair of the two collections of row vectors.
1434
1435    Args:
1436        x1 (Tensor): input tensor of shape :math:`B \times P \times M`.
1437        x2 (Tensor): input tensor of shape :math:`B \times R \times M`.
1438        p: p value for the p-norm distance to calculate between each vector pair
1439            :math:`\in [0, \infty]`.
1440        compute_mode:
1441            ``'use_mm_for_euclid_dist_if_necessary'`` - will use matrix multiplication approach to calculate
1442            Euclidean distance (p = 2) if P > 25 or R > 25
1443            ``'use_mm_for_euclid_dist'`` - will always use matrix multiplication approach to calculate
1444            Euclidean distance (p = 2)
1445            ``'donot_use_mm_for_euclid_dist'`` - will never use matrix multiplication approach to calculate
1446            Euclidean distance (p = 2)
1447            Default: ``'use_mm_for_euclid_dist_if_necessary'``.
1448
1449    If x1 has shape :math:`B \times P \times M` and x2 has shape :math:`B \times R \times M` then the
1450    output will have shape :math:`B \times P \times R`.
1451
1452    This function is equivalent to `scipy.spatial.distance.cdist(input, 'minkowski', p=p)`
1453    if :math:`p \in (0, \infty)`. When :math:`p = 0` it is equivalent to
1454    `scipy.spatial.distance.cdist(input, 'hamming') * M`. When :math:`p = \infty`, the closest
1455    scipy function is `scipy.spatial.distance.cdist(xn, lambda x, y: np.abs(x - y).max())`.
1456
1457    Example:
1458
1459        >>> a = torch.tensor([[0.9041, 0.0196], [-0.3108, -2.4423], [-0.4821, 1.059]])
1460        >>> a
1461        tensor([[ 0.9041,  0.0196],
1462                [-0.3108, -2.4423],
1463                [-0.4821,  1.0590]])
1464        >>> b = torch.tensor([[-2.1763, -0.4713], [-0.6986, 1.3702]])
1465        >>> b
1466        tensor([[-2.1763, -0.4713],
1467                [-0.6986,  1.3702]])
1468        >>> torch.cdist(a, b, p=2)
1469        tensor([[3.1193, 2.0959],
1470                [2.7138, 3.8322],
1471                [2.2830, 0.3791]])
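
        >>> # batched inputs of shape B x P x M and B x R x M give a B x P x R result
        >>> x1 = torch.randn(2, 3, 5)
        >>> x2 = torch.randn(2, 4, 5)
        >>> torch.cdist(x1, x2, p=2).shape
        torch.Size([2, 3, 4])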
1472    """
1473    if has_torch_function_variadic(x1, x2):
1474        return handle_torch_function(
1475            cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode
1476        )
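    # Map the string compute_mode onto the integer flag expected by _VF.cdist:
    # None selects the size-based heuristic, 1 forces the matrix-multiplication
    # path, and 2 disables it.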
1477    if compute_mode == "use_mm_for_euclid_dist_if_necessary":
1478        return _VF.cdist(x1, x2, p, None)  # type: ignore[attr-defined]
1479    elif compute_mode == "use_mm_for_euclid_dist":
1480        return _VF.cdist(x1, x2, p, 1)  # type: ignore[attr-defined]
1481    elif compute_mode == "donot_use_mm_for_euclid_dist":
1482        return _VF.cdist(x1, x2, p, 2)  # type: ignore[attr-defined]
1483    else:
1484        raise ValueError(f"{compute_mode} is not a valid value for compute_mode")
1485
1486
1487def atleast_1d(*tensors):
1488    r"""
1489    Returns a 1-dimensional view of each input tensor with zero dimensions.
1490    Input tensors with one or more dimensions are returned as-is.
1491
1492    Args:
1493        input (Tensor or list of Tensors)
1494
1495    Returns:
1496        output (Tensor or tuple of Tensors)
1497
1498    Example::
1499
1500        >>> x = torch.arange(2)
1501        >>> x
1502        tensor([0, 1])
1503        >>> torch.atleast_1d(x)
1504        tensor([0, 1])
1505        >>> x = torch.tensor(1.)
1506        >>> x
1507        tensor(1.)
1508        >>> torch.atleast_1d(x)
1509        tensor([1.])
1510        >>> x = torch.tensor(0.5)
1511        >>> y = torch.tensor(1.)
1512        >>> torch.atleast_1d((x, y))
1513        (tensor([0.5000]), tensor([1.]))
1514    """
1515    # This wrapper exists to support variadic args.
1516    if has_torch_function(tensors):
1517        return handle_torch_function(atleast_1d, tensors, *tensors)
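    # A single positional argument may be either one tensor or one sequence of
    # tensors; unwrap it so both call forms are forwarded in the shape _VF expects.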
1518    if len(tensors) == 1:
1519        tensors = tensors[0]
1520    return _VF.atleast_1d(tensors)  # type: ignore[attr-defined]
1521
1522
1523def atleast_2d(*tensors):
1524    r"""
1525    Returns a 2-dimensional view of each input tensor with fewer than two dimensions.
1526    Input tensors with two or more dimensions are returned as-is.
1527
1528    Args:
1529        input (Tensor or list of Tensors)
1530
1531    Returns:
1532        output (Tensor or tuple of Tensors)
1533
1534    Example::
1535
1536        >>> x = torch.tensor(1.)
1537        >>> x
1538        tensor(1.)
1539        >>> torch.atleast_2d(x)
1540        tensor([[1.]])
1541        >>> x = torch.arange(4).view(2, 2)
1542        >>> x
1543        tensor([[0, 1],
1544                [2, 3]])
1545        >>> torch.atleast_2d(x)
1546        tensor([[0, 1],
1547                [2, 3]])
1548        >>> x = torch.tensor(0.5)
1549        >>> y = torch.tensor(1.)
1550        >>> torch.atleast_2d((x, y))
1551        (tensor([[0.5000]]), tensor([[1.]]))
1552    """
1553    # This wrapper exists to support variadic args.
1554    if has_torch_function(tensors):
1555        return handle_torch_function(atleast_2d, tensors, *tensors)
1556    if len(tensors) == 1:
1557        tensors = tensors[0]
1558    return _VF.atleast_2d(tensors)  # type: ignore[attr-defined]
1559
1560
1561def atleast_3d(*tensors):
1562    r"""
1563    Returns a 3-dimensional view of each input tensor with fewer than three dimensions.
1564    Input tensors with three or more dimensions are returned as-is.
1565
1566    Args:
1567        input (Tensor or list of Tensors)
1568
1569    Returns:
1570        output (Tensor or tuple of Tensors)
1571
1572    Example:
1573
1574        >>> x = torch.tensor(0.5)
1575        >>> x
1576        tensor(0.5000)
1577        >>> torch.atleast_3d(x)
1578        tensor([[[0.5000]]])
1579        >>> y = torch.arange(4).view(2, 2)
1580        >>> y
1581        tensor([[0, 1],
1582                [2, 3]])
1583        >>> torch.atleast_3d(y)
1584        tensor([[[0],
1585                 [1]],
1586                <BLANKLINE>
1587                [[2],
1588                 [3]]])
1589        >>> x = torch.tensor(1).view(1, 1, 1)
1590        >>> x
1591        tensor([[[1]]])
1592        >>> torch.atleast_3d(x)
1593        tensor([[[1]]])
1594        >>> x = torch.tensor(0.5)
1595        >>> y = torch.tensor(1.0)
1596        >>> torch.atleast_3d((x, y))
1597        (tensor([[[0.5000]]]), tensor([[[1.]]]))
1598    """
1599    # This wrapper exists to support variadic args.
1600    if has_torch_function(tensors):
1601        return handle_torch_function(atleast_3d, tensors, *tensors)
1602    if len(tensors) == 1:
1603        tensors = tensors[0]
1604    return _VF.atleast_3d(tensors)  # type: ignore[attr-defined]
1605
1606
1607if TYPE_CHECKING:
1608    pass
1609    # There's no good way to use this type annotation; cannot rename norm() to
1610    # _norm_impl() in a way that doesn't break JIT overloads. So leave untyped
1611    # for mypy for now.
1612    #    def norm(input: Tensor,
1613    #             p: Optional[Union[str, Number]] = "fro",
1614    #             dim: Optional[Union[int, List[int]]] = None,
1615    #             keepdim: bool = False,
1616    #             out: Optional[Tensor] = None,
1617    #             dtype: _dtype = None) -> Tensor:
1618    #        return _norm_impl(input, p, dim, keepdim, out, dtype)
1619else:
1620    # TODO: type dim as BroadcastingList when
1621    # https://github.com/pytorch/pytorch/issues/33782 is fixed
1622    @overload
1623    def norm(
1624        input,
1625        p="fro",
1626        dim=None,
1627        keepdim=False,
1628        out=None,
1629        dtype=None,
1630    ):
1631        # type: (Tensor, str, Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
1632        pass
1633
1634    @overload
1635    def norm(  # noqa: F811
1636        input,
1637        p="fro",
1638        dim=None,
1639        keepdim=False,
1640        out=None,
1641        dtype=None,
1642    ):
1643        # type: (Tensor, Optional[number], Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
1644        pass
1645
1646    @overload
1647    def norm(  # noqa: F811
1648        input,
1649        p="fro",
1650        dim=None,
1651        keepdim=False,
1652        out=None,
1653        dtype=None,
1654    ):
1655        # type: (Tensor, Optional[number], Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
1656        pass
1657
1658    @overload
1659    def norm(  # noqa: F811
1660        input,
1661        p="fro",
1662        dim=None,
1663        keepdim=False,
1664        out=None,
1665        dtype=None,
1666    ):
1667        # type: (Tensor, str, Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
1668        pass
1669
1670
1671def norm(  # noqa: F811
1672    input,
1673    p: Optional[Union[float, str]] = "fro",
1674    dim=None,
1675    keepdim=False,
1676    out=None,
1677    dtype=None,
1678):
1679    r"""Returns the matrix norm or vector norm of a given tensor.
1680
1681    .. warning::
1682
1683        torch.norm is deprecated and may be removed in a future PyTorch release.
1684        Its documentation and behavior may be incorrect, and it is no longer
1685        actively maintained.
1686
1687        Use :func:`torch.linalg.vector_norm` when computing vector norms and
1688        :func:`torch.linalg.matrix_norm` when computing matrix norms.
1689        For a function with a similar behavior as this one see :func:`torch.linalg.norm`.
1690        Note, however, the signature for these functions is slightly different than the
1691        signature for ``torch.norm``.
1692
1693    Args:
1694        input (Tensor): The input tensor. Its data type must be either a floating
1695            point or complex type. For complex inputs, the norm is calculated using the
1696            absolute value of each element. If the input is complex and neither
1697            :attr:`dtype` nor :attr:`out` is specified, the result's data type will
1698            be the corresponding floating point type (e.g. float if :attr:`input` is
1699            complexfloat).
1700
1701        p (int, float, inf, -inf, 'fro', 'nuc', optional): the order of norm. Default: ``'fro'``
1702            The following norms can be calculated:
1703
1704            ======  ==============  ==========================
1705            ord     matrix norm     vector norm
1706            ======  ==============  ==========================
1707            'fro'   Frobenius norm  --
1708            'nuc'   nuclear norm    --
1709            Number  --              sum(abs(x)**ord)**(1./ord)
1710            ======  ==============  ==========================
1711
1712            The vector norm can be calculated across any number of dimensions.
1713            The corresponding dimensions of :attr:`input` are flattened into
1714            one dimension, and the norm is calculated on the flattened
1715            dimension.
1716
1717            Frobenius norm produces the same result as ``p=2`` in all cases
1718            except when :attr:`dim` is a list of three or more dims, in which
1719            case Frobenius norm throws an error.
1720
1721            Nuclear norm can only be calculated across exactly two dimensions.
1722
1723        dim (int, tuple of ints, list of ints, optional):
1724            Specifies which dimension or dimensions of :attr:`input` to
1725            calculate the norm across. If :attr:`dim` is ``None``, the norm will
1726            be calculated across all dimensions of :attr:`input`. If the norm
1727            type indicated by :attr:`p` does not support the specified number of
1728            dimensions, an error will occur.
1729        keepdim (bool, optional): whether the output tensors have :attr:`dim`
1730            retained or not. Ignored if :attr:`dim` = ``None`` and
1731            :attr:`out` = ``None``. Default: ``False``
1732        out (Tensor, optional): the output tensor. Ignored if
1733            :attr:`dim` = ``None`` and :attr:`out` = ``None``.
1734        dtype (:class:`torch.dtype`, optional): the desired data type of
1735            returned tensor. If specified, the input tensor is cast to
1736            :attr:`dtype` while performing the operation. Default: None.
1737
1738    .. note::
1739        Even though ``p='fro'`` supports any number of dimensions, the true
1740        mathematical definition of Frobenius norm only applies to tensors with
1741        exactly two dimensions. :func:`torch.linalg.matrix_norm` with ``ord='fro'``
1742        aligns with the mathematical definition, since it can only be applied across
1743        exactly two dimensions.
1744
1745    Example::
1746
1747        >>> import torch
1748        >>> a = torch.arange(9, dtype=torch.float) - 4
1749        >>> b = a.reshape((3, 3))
1750        >>> torch.norm(a)
1751        tensor(7.7460)
1752        >>> torch.norm(b)
1753        tensor(7.7460)
1754        >>> torch.norm(a, float('inf'))
1755        tensor(4.)
1756        >>> torch.norm(b, float('inf'))
1757        tensor(4.)
1758        >>> c = torch.tensor([[ 1, 2, 3], [-1, 1, 4]], dtype=torch.float)
1759        >>> torch.norm(c, dim=0)
1760        tensor([1.4142, 2.2361, 5.0000])
1761        >>> torch.norm(c, dim=1)
1762        tensor([3.7417, 4.2426])
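        >>> # keepdim=True retains the reduced dimension with size 1
        >>> torch.norm(c, dim=1, keepdim=True)
        tensor([[3.7417],
                [4.2426]])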
1763        >>> torch.norm(c, p=1, dim=1)
1764        tensor([6., 6.])
1765        >>> d = torch.arange(8, dtype=torch.float).reshape(2, 2, 2)
1766        >>> torch.norm(d, dim=(1, 2))
1767        tensor([ 3.7417, 11.2250])
1768        >>> torch.norm(d[0, :, :]), torch.norm(d[1, :, :])
1769        (tensor(3.7417), tensor(11.2250))
1770    """
1771
1772    if has_torch_function_unary(input):
1773        return handle_torch_function(
1774            norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype
1775        )
1776
1777    # NB. All the repeated code and weird python is to please TorchScript.
1778    #     For a more compact implementation see the relevant function in `_refs/__init__.py`
1779
1780    # We don't do this for MPS or sparse tensors
1781    if input.layout == torch.strided and input.device.type in (
1782        "cpu",
1783        "cuda",
1784        "meta",
1785        torch.utils.backend_registration._privateuse1_backend_name,
1786    ):
1787        if dim is not None:
1788            if isinstance(dim, (int, torch.SymInt)):
1789                _dim = [dim]
1790            else:
1791                _dim = dim
1792        else:
1793            _dim = None  # type: ignore[assignment]
1794
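        # Fast path: route the legacy `p` values to their torch.linalg equivalents.
        # 'fro' over at most two dims is the 2-norm of the flattened slice, other
        # string norms go through matrix_norm (which rejects invalid ones), and
        # numeric p (or None) goes through vector_norm.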
1795        if isinstance(p, str):
1796            if p == "fro" and (
1797                dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2
1798            ):
1799                if out is None:
1800                    return torch.linalg.vector_norm(
1801                        input, 2, _dim, keepdim, dtype=dtype
1802                    )
1803                else:
1804                    return torch.linalg.vector_norm(
1805                        input, 2, _dim, keepdim, dtype=dtype, out=out
1806                    )
1807
1808            # Here we either call the nuclear norm, or we call matrix_norm with some arguments
1809            # that will throw an error
1810            if _dim is None:
1811                _dim = list(range(input.ndim))
1812            if out is None:
1813                return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype)
1814            else:
1815                return torch.linalg.matrix_norm(
1816                    input, p, _dim, keepdim, dtype=dtype, out=out
1817                )
1818        else:
1819            # NB. p should be Union[str, number], not Optional!
1820            _p = 2.0 if p is None else p
1821            if out is None:
1822                return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype)
1823            else:
1824                return torch.linalg.vector_norm(
1825                    input, _p, _dim, keepdim, dtype=dtype, out=out
1826                )
1827
1828    ndim = input.dim()
1829
1830    # catch the default case: reduce over all dimensions with no out tensor and no dtype override
1831    if dim is None and out is None and dtype is None and p is not None:
1832        if isinstance(p, str):
1833            if p == "fro":
1834                return _VF.frobenius_norm(input, dim=(), keepdim=keepdim)
1835        if not isinstance(p, str):
1836            _dim = list(range(ndim))
1837            return _VF.norm(input, p, dim=_dim, keepdim=keepdim)  # type: ignore[attr-defined]
1838
1839    # TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed
1840    # remove the overloads where dim is an int and replace with BroadcastingList1
1841    # and remove next four lines, replace _dim with dim
1842    if dim is not None:
1843        if isinstance(dim, (int, torch.SymInt)):
1844            _dim = [dim]
1845        else:
1846            _dim = dim
1847    else:
1848        _dim = None  # type: ignore[assignment]
1849
1850    if isinstance(p, str):
1851        if p == "fro":
1852            if dtype is not None:
1853                raise ValueError("dtype argument is not supported in frobenius norm")
1854
1855            if _dim is None:
1856                _dim = list(range(ndim))
1857            if out is None:
1858                return _VF.frobenius_norm(input, _dim, keepdim=keepdim)  # type: ignore[arg-type]
1859            else:
1860                return _VF.frobenius_norm(input, _dim, keepdim=keepdim, out=out)  # type: ignore[arg-type]
1861        elif p == "nuc":
1862            if dtype is not None:
1863                raise ValueError("dtype argument is not supported in nuclear norm")
1864            if _dim is None:
1865                if out is None:
1866                    return _VF.nuclear_norm(input, keepdim=keepdim)  # type: ignore[arg-type]
1867                else:
1868                    return _VF.nuclear_norm(input, keepdim=keepdim, out=out)  # type: ignore[arg-type]
1869            else:
1870                if out is None:
1871                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim)  # type: ignore[arg-type]
1872                else:
1873                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim, out=out)  # type: ignore[arg-type]
1874        raise RuntimeError(f"only valid string values are 'fro' and 'nuc', found {p}")
1875    else:
1876        if _dim is None:
1877            _dim = list(range(ndim))
1878
1879        if out is None:
1880            if dtype is None:
1881                return _VF.norm(input, p, _dim, keepdim=keepdim)  # type: ignore[attr-defined]
1882            else:
1883                return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype)  # type: ignore[attr-defined]
1884        else:
1885            if dtype is None:
1886                return _VF.norm(input, p, _dim, keepdim=keepdim, out=out)  # type: ignore[attr-defined]
1887            else:
1888                return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out)  # type: ignore[attr-defined]
1889
1890
1891def unravel_index(
1892    indices: Tensor,
1893    shape: Union[int, Sequence[int], torch.Size],
1894) -> Tuple[Tensor, ...]:
1895    r"""Converts a tensor of flat indices into a tuple of coordinate tensors that
1896    index into an arbitrary tensor of the specified shape.
1897
1898    Args:
1899        indices (Tensor): An integer tensor containing indices into the
1900            flattened version of an arbitrary tensor of shape :attr:`shape`.
1901            All elements must be in the range ``[0, prod(shape) - 1]``.
1902
1903        shape (int, sequence of ints, or torch.Size): The shape of the arbitrary
1904            tensor. All elements must be non-negative.
1905
1906    Returns:
1907        tuple of Tensors: Each ``i``-th tensor in the output corresponds with
1908        dimension ``i`` of :attr:`shape`. Each tensor has the same shape as
1909        ``indices`` and contains one index into dimension ``i`` for each of the
1910        flat indices given by ``indices``.
1911
1912    Example::
1913
1914        >>> import torch
1915        >>> torch.unravel_index(torch.tensor(4), (3, 2))
1916        (tensor(2),
1917         tensor(0))
1918
1919        >>> torch.unravel_index(torch.tensor([4, 1]), (3, 2))
1920        (tensor([2, 0]),
1921         tensor([0, 1]))
1922
1923        >>> torch.unravel_index(torch.tensor([0, 1, 2, 3, 4, 5]), (3, 2))
1924        (tensor([0, 0, 1, 1, 2, 2]),
1925         tensor([0, 1, 0, 1, 0, 1]))
1926
1927        >>> torch.unravel_index(torch.tensor([1234, 5678]), (10, 10, 10, 10))
1928        (tensor([1, 5]),
1929         tensor([2, 6]),
1930         tensor([3, 7]),
1931         tensor([4, 8]))
1932
1933        >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (10, 10, 10, 10))
1934        (tensor([[1], [5]]),
1935         tensor([[2], [6]]),
1936         tensor([[3], [7]]),
1937         tensor([[4], [8]]))
1938
1939        >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (100, 100))
1940        (tensor([[12], [56]]),
1941         tensor([[34], [78]]))
1942    """
1943    if has_torch_function_unary(indices):
1944        return handle_torch_function(unravel_index, (indices,), indices, shape=shape)
1945    res_tensor = _unravel_index(indices, shape)
1946    return res_tensor.unbind(-1)
1947
1948
1949def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor:
1950    torch._check_type(
1951        not indices.is_complex()
1952        and not indices.is_floating_point()
1953        and not indices.dtype == torch.bool,
1954        lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}",
1955    )
1956
1957    torch._check_type(
1958        isinstance(shape, (int, torch.SymInt, Sequence)),
1959        lambda: f"expected 'shape' to be int or sequence of ints, but got {type(shape)}",
1960    )
1961
1962    if isinstance(shape, (int, torch.SymInt)):
1963        shape = torch.Size([shape])
1964    else:
1965        for dim in shape:
1966            torch._check_type(
1967                isinstance(dim, (int, torch.SymInt)),
1968                lambda: f"expected 'shape' sequence to only contain ints, but got {type(dim)}",
1969            )
1970        shape = torch.Size(shape)
1971
1972    torch._check_value(
1973        all(dim >= 0 for dim in shape),
1974        lambda: f"'shape' cannot have negative values, but got {tuple(shape)}",
1975    )
1976
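    # Row-major "strides" of `shape`: a reversed cumulative product of
    # shape[1:] + (1,), i.e. (d1*d2*...*d_{n-1}, ..., d_{n-1}, 1). Floor-dividing
    # the flat indices by these and taking the remainder against each dimension
    # size recovers the per-dimension coordinates.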
1977    coefs = list(
1978        reversed(
1979            list(
1980                itertools.accumulate(
1981                    reversed(shape[1:] + torch.Size([1])), func=operator.mul
1982                )
1983            )
1984        )
1985    )
1986    return indices.unsqueeze(-1).floor_divide(
1987        torch.tensor(coefs, device=indices.device, dtype=torch.int64)
1988    ) % torch.tensor(shape, device=indices.device, dtype=torch.int64)
1989
1990
1991def chain_matmul(*matrices, out=None):
1992    r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
1993    using the matrix chain ordering algorithm, which selects the order that incurs the lowest cost in terms
1994    of arithmetic operations (`[CLRS]`_). Note that since this is a function to compute the product, :math:`N`
1995    needs to be greater than or equal to 2; if equal to 2 then a trivial matrix-matrix product is returned.
1996    If :math:`N` is 1, then this is a no-op - the original matrix is returned as is.
1997
1998    .. warning::
1999
2000        :func:`torch.chain_matmul` is deprecated and will be removed in a future PyTorch release.
2001        Use :func:`torch.linalg.multi_dot` instead, which accepts a list of two or more tensors
2002        rather than multiple arguments.
2003
2004    Args:
2005        matrices (Tensors...): a sequence of 2 or more 2-D tensors whose product is to be determined.
2006        out (Tensor, optional): the output tensor. Ignored if :attr:`out` = ``None``.
2007
2008    Returns:
2009        Tensor: if the :math:`i^{th}` tensor was of dimensions :math:`p_{i} \times p_{i + 1}`, then the product
2010        would be of dimensions :math:`p_{1} \times p_{N + 1}`.
2011
2012    Example::
2013
2014        >>> # xdoctest: +SKIP
2015        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
2016        >>> a = torch.randn(3, 4)
2017        >>> b = torch.randn(4, 5)
2018        >>> c = torch.randn(5, 6)
2019        >>> d = torch.randn(6, 7)
2020        >>> # will raise a deprecation warning
2021        >>> torch.chain_matmul(a, b, c, d)
2022        tensor([[ -2.3375,  -3.9790,  -4.1119,  -6.6577,   9.5609, -11.5095,  -3.2614],
2023                [ 21.4038,   3.3378,  -8.4982,  -5.2457, -10.2561,  -2.4684,   2.7163],
2024                [ -0.9647,  -5.8917,  -2.3213,  -5.2284,  12.8615, -12.2816,  -2.5095]])
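        >>> # the same product via the non-deprecated replacement
        >>> product = torch.linalg.multi_dot([a, b, c, d])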
2025
2026    .. _`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition
2027    """
2028    # This wrapper exists to support variadic args.
2029    if has_torch_function(matrices):
2030        return handle_torch_function(chain_matmul, matrices, *matrices)
2031
2032    if out is None:
2033        return _VF.chain_matmul(matrices)  # type: ignore[attr-defined]
2034    else:
2035        return _VF.chain_matmul(matrices, out=out)  # type: ignore[attr-defined]
2036
2037
2038def _lu_impl(A, pivot=True, get_infos=False, out=None):
2039    # type: (Tensor, bool, bool, Any) -> Tuple[Tensor, Tensor, Tensor]
2040    r"""Computes the LU factorization of a matrix or batches of matrices
2041    :attr:`A`. Returns a tuple containing the LU factorization and
2042    pivots of :attr:`A`.  Pivoting is done if :attr:`pivot` is set to
2043    ``True``.
2044
2045    .. warning::
2046
2047        :func:`torch.lu` is deprecated in favor of :func:`torch.linalg.lu_factor`
2048        and :func:`torch.linalg.lu_factor_ex`. :func:`torch.lu` will be removed in a
2049        future PyTorch release.
2050        ``LU, pivots, info = torch.lu(A, compute_pivots)`` should be replaced with
2051
2052        .. code:: python
2053
2054            LU, pivots = torch.linalg.lu_factor(A, compute_pivots)
2055
2056        ``LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)`` should be replaced with
2057
2058        .. code:: python
2059
2060            LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)
2061
2062    .. note::
2063        * The returned permutation matrix for every matrix in the batch is
2064          represented by a 1-indexed vector of size ``min(A.shape[-2], A.shape[-1])``.
2065          ``pivots[i] == j`` represents that in the ``i``-th step of the algorithm,
2066          the ``i``-th row was permuted with the ``j-1``-th row.
2067        * LU factorization with :attr:`pivot` = ``False`` is not available
2068          for CPU, and attempting to do so will throw an error. However,
2069          LU factorization with :attr:`pivot` = ``False`` is available for
2070          CUDA.
2071        * This function does not check if the factorization was successful
2072          or not if :attr:`get_infos` is ``True`` since the status of the
2073          factorization is present in the third element of the return tuple.
2074        * In the case of batches of square matrices with size less than or equal
2075          to 32 on a CUDA device, the LU factorization is repeated for
2076          singular matrices due to the bug in the MAGMA library
2077          (see magma issue 13).
2078        * ``L``, ``U``, and ``P`` can be derived using :func:`torch.lu_unpack`.
2079
2080    .. warning::
2081        The gradients of this function will only be finite when :attr:`A` is full rank.
2082        This is because the LU decomposition is only differentiable at full-rank matrices.
2083        Furthermore, if :attr:`A` is close to not being full rank,
2084        the gradient will be numerically unstable as it depends on the computation of :math:`L^{-1}` and :math:`U^{-1}`.
2085
2086    Args:
2087        A (Tensor): the tensor to factor of size :math:`(*, m, n)`
2088        pivot (bool, optional): controls whether pivoting is done. Default: ``True``
2089        get_infos (bool, optional): if set to ``True``, returns an info IntTensor.
2090                                    Default: ``False``
2091        out (tuple, optional): optional output tuple. If :attr:`get_infos` is ``True``,
2092                               then the elements in the tuple are Tensor, IntTensor,
2093                               and IntTensor. If :attr:`get_infos` is ``False``, then the
2094                               elements in the tuple are Tensor, IntTensor. Default: ``None``
2095
2096    Returns:
2097        (Tensor, IntTensor, IntTensor (optional)): A tuple of tensors containing
2098
2099            - **factorization** (*Tensor*): the factorization of size :math:`(*, m, n)`
2100
2101            - **pivots** (*IntTensor*): the pivots of size :math:`(*, \text{min}(m, n))`.
2102              ``pivots`` stores all the intermediate transpositions of rows.
2103              The final permutation ``perm`` could be reconstructed by
2104              applying ``swap(perm[i], perm[pivots[i] - 1])`` for ``i = 0, ..., pivots.size(-1) - 1``,
2105              where ``perm`` is initially the identity permutation of :math:`m` elements
2106              (essentially this is what :func:`torch.lu_unpack` is doing).
2107
2108            - **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of
2109              size :math:`(*)` where non-zero values indicate whether factorization for the matrix or
2110              each minibatch has succeeded or failed
2111
2112    Example::
2113
2114        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
2115        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
2116        >>> A = torch.randn(2, 3, 3)
2117        >>> A_LU, pivots = torch.lu(A)
2118        >>> A_LU
2119        tensor([[[ 1.3506,  2.5558, -0.0816],
2120                 [ 0.1684,  1.1551,  0.1940],
2121                 [ 0.1193,  0.6189, -0.5497]],
2122
2123                [[ 0.4526,  1.2526, -0.3285],
2124                 [-0.7988,  0.7175, -0.9701],
2125                 [ 0.2634, -0.9255, -0.3459]]])
2126        >>> pivots
2127        tensor([[ 3,  3,  3],
2128                [ 3,  3,  3]], dtype=torch.int32)
2129        >>> A_LU, pivots, info = torch.lu(A, get_infos=True)
2130        >>> if info.nonzero().size(0) == 0:
2131        ...     print('LU factorization succeeded for all samples!')
2132        LU factorization succeeded for all samples!
2133    """
2134    # If get_infos is True, then we don't need to check for errors and vice versa
2135    return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
2136
2137
2138if TYPE_CHECKING:
2139    _ListOrSeq = Sequence[Tensor]
2140else:
2141    _ListOrSeq = List[Tensor]
2142
2143
2144def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None:
2145    get_infos_int = 1 if get_infos else 0
2146    if out_len - get_infos_int != 2:
2147        raise TypeError(
2148            f"expected tuple of {2 + int(get_infos)} elements but got {out_len}"
2149        )
2150    if not isinstance(out, (tuple, list)):
2151        raise TypeError(
2152            f"argument 'out' must be tuple of Tensors, not {type(out).__name__}"
2153        )
2154
2155
2156def _lu_with_infos(A, pivot=True, get_infos=False, out=None):
2157    # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]
2158    if has_torch_function_unary(A):
2159        return handle_torch_function(
2160            lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out
2161        )
2162    result = _lu_impl(A, pivot, get_infos, out)
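    # When an `out` tuple is supplied, copy the freshly computed tensors into it
    # in place (resizing as needed) so the caller's storage is reused.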
2163    if out is not None:
2164        _check_list_size(len(out), get_infos, out)
2165        for i in range(len(out)):
2166            out[i].resize_as_(result[i]).copy_(result[i])
2167        return out
2168    else:
2169        return result  # A_LU, pivots, infos
2170
2171
2172def _lu_no_infos(A, pivot=True, get_infos=False, out=None):
2173    # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
2174    # need to check for torch_function here so that we exit early and dispatch to the override if one is present
2175    if has_torch_function_unary(A):
2176        return handle_torch_function(
2177            lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out
2178        )
2179    result = _lu_impl(A, pivot, get_infos, out)
2180    if out is not None:
2181        _check_list_size(len(out), get_infos, out)
2182        for i in range(len(out)):
2183            out[i].resize_as_(result[i]).copy_(result[i])
2184        return out
2185    else:
2186        return result[0], result[1]  # A_LU, pivots
2187
2188
2189# The return type of lu depends on `get_infos`, so in order to resolve the output type
2190# of lu in TorchScript we need to statically know the value of `get_infos`
2191lu = boolean_dispatch(
2192    arg_name="get_infos",
2193    arg_index=2,
2194    default=False,
2195    if_true=_lu_with_infos,
2196    if_false=_lu_no_infos,
2197    module_name=__name__,
2198    func_name="lu",
2199)
2200lu.__doc__ = _lu_impl.__doc__
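# For example, `LU, pivots = torch.lu(A)` dispatches to _lu_no_infos, while
# `LU, pivots, info = torch.lu(A, get_infos=True)` dispatches to _lu_with_infos.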
2201
2202
2203def align_tensors(*tensors):
2204    raise RuntimeError("`align_tensors` not yet implemented.")
2205