xref: /aosp_15_r20/external/pytorch/torch/testing/_comparison.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import abc
3import cmath
4import collections.abc
5import contextlib
6from typing import (
7    Any,
8    Callable,
9    Collection,
10    Dict,
11    List,
12    NoReturn,
13    Optional,
14    Sequence,
15    Tuple,
16    Type,
17    Union,
18)
19from typing_extensions import deprecated
20
21import torch
22
23
24try:
25    import numpy as np
26
27    HAS_NUMPY = True
28except ModuleNotFoundError:
29    HAS_NUMPY = False
30    np = None  # type: ignore[assignment]
31
32
class ErrorMeta(Exception):
    """Internal testing exception that carries the metadata of a user-facing error.

    Comparison code raises this instead of the final exception, so that the message and the id of the failing item
    can still be post-processed. Call :meth:`to_error` to build the exception that is surfaced to the user.
    """

    def __init__(
        self, type: Type[Exception], msg: str, *, id: Tuple[Any, ...] = ()
    ) -> None:
        # The message handed to ``Exception`` is only shown if this internal exception leaks unconverted; the actual
        # payload lives in the attributes set below.
        super().__init__(
            "If you are a user and see this message during normal operation "
            "please file an issue at https://github.com/pytorch/pytorch/issues. "
            "If you are a developer and working on the comparison functions, please `raise ErrorMeta.to_error()` "
            "for user facing errors."
        )
        self.type = type
        self.msg = msg
        self.id = id

    def to_error(
        self, msg: Optional[Union[str, Callable[[str], str]]] = None
    ) -> Exception:
        """Builds (but does not raise) the user-facing exception described by this metadata.

        Args:
            msg (Optional[Union[str, Callable[[str], str]]]): A string is used verbatim as the message. A callable is
                called with the generated message (stored message plus the item id, if any) and its return value is
                used. Anything else, including ``None``, results in the generated message.

        Returns:
            (Exception): An instance of the stored exception type.
        """
        if isinstance(msg, str):
            return self.type(msg)

        generated_msg = self.msg
        if self.id:
            # Renders e.g. id=(1, "a") as "[1]['a']".
            item_path = "".join(str([item]) for item in self.id)
            generated_msg = f"{generated_msg}\n\nThe failure occurred for item {item_path}"

        if callable(msg):
            return self.type(msg(generated_msg))
        return self.type(generated_msg)
60
61
62# Some analysis of tolerance by logging tests from test_torch.py can be found in
63# https://github.com/pytorch/pytorch/pull/32538.
64# {dtype: (rtol, atol)}
65_DTYPE_PRECISIONS = {
66    torch.float16: (0.001, 1e-5),
67    torch.bfloat16: (0.016, 1e-5),
68    torch.float32: (1.3e-6, 1e-5),
69    torch.float64: (1e-7, 1e-7),
70    torch.complex32: (0.001, 1e-5),
71    torch.complex64: (1.3e-6, 1e-5),
72    torch.complex128: (1e-7, 1e-7),
73}
74# The default tolerances of torch.float32 are used for quantized dtypes, because quantized tensors are compared in
75# their dequantized and floating point representation. For more details see `TensorLikePair._compare_quantized_values`
76_DTYPE_PRECISIONS.update(
77    dict.fromkeys(
78        (torch.quint8, torch.quint2x4, torch.quint4x2, torch.qint8, torch.qint32),
79        _DTYPE_PRECISIONS[torch.float32],
80    )
81)
82
83
def default_tolerances(
    *inputs: Union[torch.Tensor, torch.dtype],
    dtype_precisions: Optional[Dict[torch.dtype, Tuple[float, float]]] = None,
) -> Tuple[float, float]:
    """Returns the default relative and absolute testing tolerances for a set of inputs based on the dtype.

    See :func:`assert_close` for a table of the default tolerance for each dtype.

    Args:
        *inputs (Union[torch.Tensor, torch.dtype]): Tensors or dtypes whose default tolerances are combined.
        dtype_precisions (Optional[Dict[torch.dtype, Tuple[float, float]]]): Mapping from dtype to ``(rtol, atol)``.
            Dtypes missing from the mapping get ``(0.0, 0.0)``. If ``None`` (default), the built-in table is used. An
            explicitly passed empty mapping is respected and thus yields ``(0.0, 0.0)``.

    Raises:
        TypeError: If any input is neither a :class:`torch.Tensor` nor a :class:`torch.dtype`.

    Returns:
        (Tuple[float, float]): Loosest ``(rtol, atol)`` of all input dtypes, or ``(0.0, 0.0)`` if no inputs are given.
    """
    dtypes = []
    for input in inputs:
        if isinstance(input, torch.Tensor):
            dtypes.append(input.dtype)
        elif isinstance(input, torch.dtype):
            dtypes.append(input)
        else:
            raise TypeError(
                f"Expected a torch.Tensor or a torch.dtype, but got {type(input)} instead."
            )

    # Only fall back to the built-in table if no mapping was passed at all; `or` would also replace an explicit `{}`.
    if dtype_precisions is None:
        dtype_precisions = _DTYPE_PRECISIONS

    precisions = [dtype_precisions.get(dtype, (0.0, 0.0)) for dtype in dtypes]
    # `default=` keeps the result well-defined for zero inputs instead of failing to unpack an empty `zip`.
    rtol = max((r for r, _ in precisions), default=0.0)
    atol = max((a for _, a in precisions), default=0.0)
    return rtol, atol
108
109
def get_tolerances(
    *inputs: Union[torch.Tensor, torch.dtype],
    rtol: Optional[float],
    atol: Optional[float],
    id: Tuple[Any, ...] = (),
) -> Tuple[float, float]:
    """Gets the relative and absolute tolerances to be used for numeric comparisons.

    If both ``rtol`` and ``atol`` are given, they are returned unchanged. If neither is given, the result of
    :func:`default_tolerances` for ``inputs`` is returned.

    Raises:
        ErrorMeta: With :class:`ValueError`, if only one of ``rtol`` and ``atol`` is given.

    Returns:
        (Tuple[float, float]): ``(rtol, atol)``.
    """
    rtol_given = rtol is not None
    atol_given = atol is not None

    if rtol_given and atol_given:
        return rtol, atol

    if not rtol_given and not atol_given:
        return default_tolerances(*inputs)

    # Exactly one tolerance was given. We require both to be omitted or specified, because specifying only one might
    # lead to surprising results. Imagine setting atol=0.0 and the tensors still match because rtol>0.0.
    missing = "rtol" if rtol is None else "atol"
    raise ErrorMeta(
        ValueError,
        f"Both 'rtol' and 'atol' must be either specified or omitted, "
        f"but got no {missing}.",
        id=id,
    )
140
141
142def _make_mismatch_msg(
143    *,
144    default_identifier: str,
145    identifier: Optional[Union[str, Callable[[str], str]]] = None,
146    extra: Optional[str] = None,
147    abs_diff: float,
148    abs_diff_idx: Optional[Union[int, Tuple[int, ...]]] = None,
149    atol: float,
150    rel_diff: float,
151    rel_diff_idx: Optional[Union[int, Tuple[int, ...]]] = None,
152    rtol: float,
153) -> str:
154    """Makes a mismatch error message for numeric values.
155
156    Args:
157        default_identifier (str): Default description of the compared values, e.g. "Tensor-likes".
158        identifier (Optional[Union[str, Callable[[str], str]]]): Optional identifier that overrides
159            ``default_identifier``. Can be passed as callable in which case it will be called with
160            ``default_identifier`` to create the description at runtime.
161        extra (Optional[str]): Extra information to be placed after the message header and the mismatch statistics.
162        abs_diff (float): Absolute difference.
163        abs_diff_idx (Optional[Union[int, Tuple[int, ...]]]): Optional index of the absolute difference.
164        atol (float): Allowed absolute tolerance. Will only be added to mismatch statistics if it or ``rtol`` are
165            ``> 0``.
166        rel_diff (float): Relative difference.
167        rel_diff_idx (Optional[Union[int, Tuple[int, ...]]]): Optional index of the relative difference.
168        rtol (float): Allowed relative tolerance. Will only be added to mismatch statistics if it or ``atol`` are
169            ``> 0``.
170    """
171    equality = rtol == 0 and atol == 0
172
173    def make_diff_msg(
174        *,
175        type: str,
176        diff: float,
177        idx: Optional[Union[int, Tuple[int, ...]]],
178        tol: float,
179    ) -> str:
180        if idx is None:
181            msg = f"{type.title()} difference: {diff}"
182        else:
183            msg = f"Greatest {type} difference: {diff} at index {idx}"
184        if not equality:
185            msg += f" (up to {tol} allowed)"
186        return msg + "\n"
187
188    if identifier is None:
189        identifier = default_identifier
190    elif callable(identifier):
191        identifier = identifier(default_identifier)
192
193    msg = f"{identifier} are not {'equal' if equality else 'close'}!\n\n"
194
195    if extra:
196        msg += f"{extra.strip()}\n"
197
198    msg += make_diff_msg(type="absolute", diff=abs_diff, idx=abs_diff_idx, tol=atol)
199    msg += make_diff_msg(type="relative", diff=rel_diff, idx=rel_diff_idx, tol=rtol)
200
201    return msg.strip()
202
203
def make_scalar_mismatch_msg(
    actual: Union[bool, int, float, complex],
    expected: Union[bool, int, float, complex],
    *,
    rtol: float,
    atol: float,
    identifier: Optional[Union[str, Callable[[str], str]]] = None,
) -> str:
    """Makes a mismatch error message for scalars.

    The relative difference is reported as ``inf`` if ``expected`` is zero.

    Args:
        actual (Union[bool, int, float, complex]): Actual scalar.
        expected (Union[bool, int, float, complex]): Expected scalar.
        rtol (float): Relative tolerance.
        atol (float): Absolute tolerance.
        identifier (Optional[Union[str, Callable[[str], str]]]): Optional description for the scalars. Can be passed
            as callable in which case it will be called by the default value to create the description at runtime.
            Defaults to "Scalars".
    """
    difference = abs(actual - expected)
    relative_difference = (
        difference / abs(expected) if expected != 0 else float("inf")
    )
    return _make_mismatch_msg(
        default_identifier="Scalars",
        identifier=identifier,
        extra=f"Expected {expected} but got {actual}.",
        abs_diff=difference,
        atol=atol,
        rel_diff=relative_difference,
        rtol=rtol,
    )
234
235
def make_tensor_mismatch_msg(
    actual: torch.Tensor,
    expected: torch.Tensor,
    matches: torch.Tensor,
    *,
    rtol: float,
    atol: float,
    identifier: Optional[Union[str, Callable[[str], str]]] = None,
):
    """Makes a mismatch error message for tensors.

    The message reports how many elements mismatch, and the greatest absolute and relative difference among the
    mismatching elements together with their multi-dimensional index.

    Args:
        actual (torch.Tensor): Actual tensor.
        expected (torch.Tensor): Expected tensor.
        matches (torch.Tensor): Boolean mask of the same shape as ``actual`` and ``expected`` that indicates the
            location of matches.
        rtol (float): Relative tolerance.
        atol (float): Absolute tolerance.
        identifier (Optional[Union[str, Callable[[str], str]]]): Optional description for the tensors. Can be passed
            as callable in which case it will be called by the default value to create the description at runtime.
            Defaults to "Tensor-likes".
    """
    shape = matches.shape

    def unravel_flat_index(flat_index: int) -> Tuple[int, ...]:
        # Turns an index into the flattened tensor back into an index into a tensor of `shape`.
        if not shape:
            return ()
        coordinates = []
        for dim_size in reversed(shape):
            flat_index, coordinate = divmod(flat_index, dim_size)
            coordinates.append(coordinate)
        return tuple(reversed(coordinates))

    numel = matches.numel()
    num_mismatches = numel - int(torch.sum(matches))
    extra = (
        f"Mismatched elements: {num_mismatches} / {numel} "
        f"({num_mismatches / numel:.1%})"
    )

    actual_flat, expected_flat, matches_flat = (
        tensor.flatten() for tensor in (actual, expected, matches)
    )
    if not (actual.dtype.is_floating_point or actual.dtype.is_complex):
        # TODO: Instead of always upcasting to int64, it would be sufficient to cast to the next higher dtype to avoid
        #  overflow
        actual_flat = actual_flat.to(torch.int64)
        expected_flat = expected_flat.to(torch.int64)

    def masked_max(diff: torch.Tensor):
        # Zeroes the matching positions in place, so that only mismatches can be reported as the maximum.
        diff[matches_flat] = 0
        return torch.max(diff, 0)

    abs_diff = torch.abs(actual_flat - expected_flat)
    max_abs_diff, max_abs_diff_flat_idx = masked_max(abs_diff)
    # The relative difference is derived from the already masked absolute difference.
    max_rel_diff, max_rel_diff_flat_idx = masked_max(
        abs_diff / torch.abs(expected_flat)
    )

    return _make_mismatch_msg(
        default_identifier="Tensor-likes",
        identifier=identifier,
        extra=extra,
        abs_diff=max_abs_diff.item(),
        abs_diff_idx=unravel_flat_index(int(max_abs_diff_flat_idx)),
        atol=atol,
        rel_diff=max_rel_diff.item(),
        rel_diff_idx=unravel_flat_index(int(max_rel_diff_flat_idx)),
        rtol=rtol,
    )
308
309
class UnsupportedInputs(Exception):  # noqa: B903
    """Raised while constructing a :class:`Pair` to signal that this pair type cannot handle the given inputs.

    It is not a comparison failure: the next pair type is tried instead.
    """
312
313
class Pair(abc.ABC):
    """ABC for all comparison pairs to be used in conjunction with :func:`assert_equal`.

    Each subclass needs to overwrite :meth:`Pair.compare` that performs the actual comparison.

    Each pair receives **all** options, so select the ones applicable for the subclass and forward the rest to the
    super class. Raising an :class:`UnsupportedInputs` during construction indicates that the pair is not able to
    handle the inputs and the next pair type will be tried.

    All other errors should be raised as :class:`ErrorMeta`. After the instantiation, :meth:`Pair._fail` can be used
    to raise an :class:`ErrorMeta` that automatically carries the id of the pair.
    """

    def __init__(
        self,
        actual: Any,
        expected: Any,
        *,
        id: Tuple[Any, ...] = (),
        **unknown_parameters: Any,
    ) -> None:
        self.actual = actual
        self.expected = expected
        self.id = id
        # Options that none of the subclasses in the constructor chain consumed.
        self._unknown_parameters = unknown_parameters

    @staticmethod
    def _inputs_not_supported() -> NoReturn:
        raise UnsupportedInputs

    @staticmethod
    def _check_inputs_isinstance(*inputs: Any, cls: Union[Type, Tuple[Type, ...]]):
        """Checks if all inputs are instances of a given class and raise :class:`UnsupportedInputs` otherwise."""
        if not all(isinstance(input, cls) for input in inputs):
            Pair._inputs_not_supported()

    def _fail(
        self, type: Type[Exception], msg: str, *, id: Tuple[Any, ...] = ()
    ) -> NoReturn:
        """Raises an :class:`ErrorMeta` from a given exception type and message and the stored id.

        .. warning::

            If you use this before the ``super().__init__(...)`` call in the constructor, you have to pass the ``id``
            explicitly.
        """
        raise ErrorMeta(type, msg, id=self.id if not id and hasattr(self, "id") else id)

    @abc.abstractmethod
    def compare(self) -> None:
        """Compares the inputs and raises an :class:`ErrorMeta` in case they mismatch."""

    def extra_repr(self) -> Sequence[Union[str, Tuple[str, Any]]]:
        """Returns extra information that will be included in the representation.

        Should be overwritten by all subclasses that use additional options. The representation of the object will only
        be surfaced in case we encounter an unexpected error and thus should help debug the issue. Can be a sequence of
        key-value-pairs or attribute names.
        """
        return []

    def __repr__(self) -> str:
        items: List[Tuple[str, Any]] = [
            ("id", self.id),
            ("actual", self.actual),
            ("expected", self.expected),
        ]
        for extra in self.extra_repr():
            # A plain string names an attribute of the pair; anything else is already a (name, value) pair.
            items.append((extra, getattr(self, extra)) if isinstance(extra, str) else extra)
        body = [f"    {name}={value!s}," for name, value in items]
        # The closing parenthesis is joined as a whole line; splatting it as an iterable only worked by accident
        # because it is a single character.
        return "\n".join((f"{type(self).__name__}(", *body, ")"))
391
392
class ObjectPair(Pair):
    """Pair for any type of inputs that will be compared with the `==` operator.

    .. note::

        Since this will instantiate for any kind of inputs, it should only be used as fallback after all other pairs
        couldn't handle the inputs.

    """

    def compare(self) -> None:
        """Compares the inputs with ``==``.

        Raises:
            ErrorMeta: With :class:`ValueError` if evaluating ``actual == expected`` or converting its result to
                :class:`bool` fails, and with :class:`AssertionError` if the inputs are not equal.
        """
        try:
            # The conversion to `bool` happens inside the `try`, since `==` may return an object whose truth value
            # cannot be determined (e.g. an elementwise result). That failure is reported like a failing `==`
            # instead of escaping as an unwrapped exception.
            equal = bool(self.actual == self.expected)
        except Exception as error:
            # We are not using `self._fail` here since we need the exception chaining
            raise ErrorMeta(
                ValueError,
                f"{self.actual} == {self.expected} failed with:\n{error}.",
                id=self.id,
            ) from error

        if not equal:
            self._fail(AssertionError, f"{self.actual} != {self.expected}")
416
417
class NonePair(Pair):
    """Pair for ``None`` inputs.

    The pair takes over as soon as at least one input is ``None``; :meth:`compare` then requires both to be ``None``.
    """

    def __init__(self, actual: Any, expected: Any, **other_parameters: Any) -> None:
        if actual is not None and expected is not None:
            self._inputs_not_supported()
        super().__init__(actual, expected, **other_parameters)

    def compare(self) -> None:
        both_none = self.actual is None and self.expected is None
        if both_none:
            return
        self._fail(
            AssertionError, f"None mismatch: {self.actual} is not {self.expected}"
        )
432
433
class BooleanPair(Pair):
    """Pair for :class:`bool` inputs.

    .. note::

        If ``numpy`` is available, also handles :class:`numpy.bool_` inputs.

    """

    def __init__(
        self,
        actual: Any,
        expected: Any,
        *,
        id: Tuple[Any, ...] = (),
        **other_parameters: Any,
    ) -> None:
        actual, expected = self._process_inputs(actual, expected, id=id)
        # Forward `id` so that mismatches raised through `self._fail` report the failing item, like the other pairs.
        super().__init__(actual, expected, id=id, **other_parameters)

    @property
    def _supported_types(self) -> Tuple[Type, ...]:
        """Types accepted as inputs: :class:`bool` and, if numpy is available, :class:`numpy.bool_`."""
        cls: List[Type] = [bool]
        if HAS_NUMPY:
            cls.append(np.bool_)
        return tuple(cls)

    def _process_inputs(
        self, actual: Any, expected: Any, *, id: Tuple[Any, ...]
    ) -> Tuple[bool, bool]:
        """Checks that both inputs are of a supported type and converts them to Python :class:`bool`'s.

        Raises:
            UnsupportedInputs: If any input is not of a supported type.
        """
        self._check_inputs_isinstance(actual, expected, cls=self._supported_types)
        actual, expected = (
            self._to_bool(bool_like, id=id) for bool_like in (actual, expected)
        )
        return actual, expected

    def _to_bool(self, bool_like: Any, *, id: Tuple[Any, ...]) -> bool:
        """Converts a supported boolean-like input to a Python :class:`bool`.

        Raises:
            ErrorMeta: With :class:`TypeError`, if the input is of an unknown boolean type.
        """
        if isinstance(bool_like, bool):
            return bool_like
        # Without numpy, `np` is None and `np.bool_` would raise an AttributeError.
        elif HAS_NUMPY and isinstance(bool_like, np.bool_):
            return bool_like.item()
        else:
            raise ErrorMeta(
                TypeError, f"Unknown boolean type {type(bool_like)}.", id=id
            )

    def compare(self) -> None:
        # Both inputs are Python bools after `_to_bool`, and bools are singletons, so identity implies equality.
        if self.actual is not self.expected:
            self._fail(
                AssertionError,
                f"Booleans mismatch: {self.actual} is not {self.expected}",
            )
486
487
class NumberPair(Pair):
    """Pair for Python number (:class:`int`, :class:`float`, and :class:`complex`) inputs.

    .. note::

        If ``numpy`` is available, also handles :class:`numpy.number` inputs.

    Kwargs:
        rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default
            values based on the type are selected with the below table.
        atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default
            values based on the type are selected with the below table.
        equal_nan (bool): If ``True``, two ``NaN`` values are considered equal. Defaults to ``False``.
        check_dtype (bool): If ``True``, the type of the inputs will be checked for equality. Defaults to ``False``.

    The following table displays correspondence between Python number type and the ``torch.dtype``'s. See
    :func:`assert_close` for the corresponding tolerances. Subclasses of these types (e.g. :class:`bool`, which
    subclasses :class:`int`) use the entry of their base type.

    +------------------+-------------------------------+
    | ``type``         | corresponding ``torch.dtype`` |
    +==================+===============================+
    | :class:`int`     | :attr:`~torch.int64`          |
    +------------------+-------------------------------+
    | :class:`float`   | :attr:`~torch.float64`        |
    +------------------+-------------------------------+
    | :class:`complex` | :attr:`~torch.complex128`     |
    +------------------+-------------------------------+
    """

    _TYPE_TO_DTYPE = {
        int: torch.int64,
        float: torch.float64,
        complex: torch.complex128,
    }
    _NUMBER_TYPES = tuple(_TYPE_TO_DTYPE.keys())

    def __init__(
        self,
        actual: Any,
        expected: Any,
        *,
        id: Tuple[Any, ...] = (),
        rtol: Optional[float] = None,
        atol: Optional[float] = None,
        equal_nan: bool = False,
        check_dtype: bool = False,
        **other_parameters: Any,
    ) -> None:
        actual, expected = self._process_inputs(actual, expected, id=id)
        super().__init__(actual, expected, id=id, **other_parameters)

        self.rtol, self.atol = get_tolerances(
            *[self._to_dtype(input, id=id) for input in (actual, expected)],
            rtol=rtol,
            atol=atol,
            id=id,
        )
        self.equal_nan = equal_nan
        self.check_dtype = check_dtype

    @property
    def _supported_types(self) -> Tuple[Type, ...]:
        """Types accepted as inputs: the Python number types and, if numpy is available, :class:`numpy.number`."""
        cls = list(self._NUMBER_TYPES)
        if HAS_NUMPY:
            cls.append(np.number)
        return tuple(cls)

    def _process_inputs(
        self, actual: Any, expected: Any, *, id: Tuple[Any, ...]
    ) -> Tuple[Union[int, float, complex], Union[int, float, complex]]:
        """Checks that both inputs are of a supported type and converts them to Python numbers.

        Raises:
            UnsupportedInputs: If any input is not of a supported type.
        """
        self._check_inputs_isinstance(actual, expected, cls=self._supported_types)
        actual, expected = (
            self._to_number(number_like, id=id) for number_like in (actual, expected)
        )
        return actual, expected

    def _to_number(
        self, number_like: Any, *, id: Tuple[Any, ...]
    ) -> Union[int, float, complex]:
        """Converts a supported number-like input to a Python number.

        Raises:
            ErrorMeta: With :class:`TypeError`, if the input is of an unknown number type.
        """
        if HAS_NUMPY and isinstance(number_like, np.number):
            return number_like.item()
        elif isinstance(number_like, self._NUMBER_TYPES):
            return number_like  # type: ignore[return-value]
        else:
            raise ErrorMeta(
                TypeError, f"Unknown number type {type(number_like)}.", id=id
            )

    def _to_dtype(
        self, number: Union[int, float, complex], *, id: Tuple[Any, ...]
    ) -> torch.dtype:
        """Maps a Python number to the ``torch.dtype`` used to select its default tolerances.

        The lookup uses ``isinstance`` instead of the exact type, so subclasses such as :class:`bool` map to the dtype
        of their base type instead of failing with a ``KeyError``.

        Raises:
            ErrorMeta: With :class:`TypeError`, if the number is of none of the known number types.
        """
        for number_type, dtype in self._TYPE_TO_DTYPE.items():
            if isinstance(number, number_type):
                return dtype
        raise ErrorMeta(TypeError, f"Unknown number type {type(number)}.", id=id)

    def compare(self) -> None:
        """Compares the numbers.

        Checks, in this order, the types (only if ``check_dtype``), exact equality, joint ``NaN``'s (only if
        ``equal_nan``), and finally ``abs(actual - expected) <= atol + rtol * abs(expected)`` for finite differences.
        """
        if self.check_dtype and type(self.actual) is not type(self.expected):
            self._fail(
                AssertionError,
                f"The (d)types do not match: {type(self.actual)} != {type(self.expected)}.",
            )

        if self.actual == self.expected:
            return

        if self.equal_nan and cmath.isnan(self.actual) and cmath.isnan(self.expected):
            return

        abs_diff = abs(self.actual - self.expected)
        tolerance = self.atol + self.rtol * abs(self.expected)

        if cmath.isfinite(abs_diff) and abs_diff <= tolerance:
            return

        self._fail(
            AssertionError,
            make_scalar_mismatch_msg(
                self.actual, self.expected, rtol=self.rtol, atol=self.atol
            ),
        )

    def extra_repr(self) -> Sequence[str]:
        return (
            "rtol",
            "atol",
            "equal_nan",
            "check_dtype",
        )
609
610
611class TensorLikePair(Pair):
612    """Pair for :class:`torch.Tensor`-like inputs.
613
614    Kwargs:
        allow_subclasses (bool): If ``True`` (default), inputs of directly related types (one is an instance of the
            other's type) are accepted. If ``False``, both inputs must be of exactly the same type.
        rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default
            values based on the dtype are selected. See :func:`assert_close` for details.
        atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default
            values based on the dtype are selected. See :func:`assert_close` for details.
620        equal_nan (bool): If ``True``, two ``NaN`` values are considered equal. Defaults to ``False``.
621        check_device (bool): If ``True`` (default), asserts that corresponding tensors are on the same
622            :attr:`~torch.Tensor.device`. If this check is disabled, tensors on different
623            :attr:`~torch.Tensor.device`'s are moved to the CPU before being compared.
624        check_dtype (bool): If ``True`` (default), asserts that corresponding tensors have the same ``dtype``. If this
625            check is disabled, tensors with different ``dtype``'s are promoted  to a common ``dtype`` (according to
626            :func:`torch.promote_types`) before being compared.
627        check_layout (bool): If ``True`` (default), asserts that corresponding tensors have the same ``layout``. If this
628            check is disabled, tensors with different ``layout``'s are converted to strided tensors before being
629            compared.
630        check_stride (bool): If ``True`` and corresponding tensors are strided, asserts that they have the same stride.
631    """
632
633    def __init__(
634        self,
635        actual: Any,
636        expected: Any,
637        *,
638        id: Tuple[Any, ...] = (),
639        allow_subclasses: bool = True,
640        rtol: Optional[float] = None,
641        atol: Optional[float] = None,
642        equal_nan: bool = False,
643        check_device: bool = True,
644        check_dtype: bool = True,
645        check_layout: bool = True,
646        check_stride: bool = False,
647        **other_parameters: Any,
648    ):
649        actual, expected = self._process_inputs(
650            actual, expected, id=id, allow_subclasses=allow_subclasses
651        )
652        super().__init__(actual, expected, id=id, **other_parameters)
653
654        self.rtol, self.atol = get_tolerances(
655            actual, expected, rtol=rtol, atol=atol, id=self.id
656        )
657        self.equal_nan = equal_nan
658        self.check_device = check_device
659        self.check_dtype = check_dtype
660        self.check_layout = check_layout
661        self.check_stride = check_stride
662
663    def _process_inputs(
664        self, actual: Any, expected: Any, *, id: Tuple[Any, ...], allow_subclasses: bool
665    ) -> Tuple[torch.Tensor, torch.Tensor]:
666        directly_related = isinstance(actual, type(expected)) or isinstance(
667            expected, type(actual)
668        )
669        if not directly_related:
670            self._inputs_not_supported()
671
672        if not allow_subclasses and type(actual) is not type(expected):
673            self._inputs_not_supported()
674
675        actual, expected = (self._to_tensor(input) for input in (actual, expected))
676        for tensor in (actual, expected):
677            self._check_supported(tensor, id=id)
678        return actual, expected
679
680    def _to_tensor(self, tensor_like: Any) -> torch.Tensor:
681        if isinstance(tensor_like, torch.Tensor):
682            return tensor_like
683
684        try:
685            return torch.as_tensor(tensor_like)
686        except Exception:
687            self._inputs_not_supported()
688
689    def _check_supported(self, tensor: torch.Tensor, *, id: Tuple[Any, ...]) -> None:
690        if tensor.layout not in {
691            torch.strided,
692            torch.jagged,
693            torch.sparse_coo,
694            torch.sparse_csr,
695            torch.sparse_csc,
696            torch.sparse_bsr,
697            torch.sparse_bsc,
698        }:
699            raise ErrorMeta(
700                ValueError, f"Unsupported tensor layout {tensor.layout}", id=id
701            )
702
703    def compare(self) -> None:
704        actual, expected = self.actual, self.expected
705
706        self._compare_attributes(actual, expected)
707        if any(input.device.type == "meta" for input in (actual, expected)):
708            return
709
710        actual, expected = self._equalize_attributes(actual, expected)
711        self._compare_values(actual, expected)
712
713    def _compare_attributes(
714        self,
715        actual: torch.Tensor,
716        expected: torch.Tensor,
717    ) -> None:
718        """Checks if the attributes of two tensors match.
719
720        Always checks
721
722        - the :attr:`~torch.Tensor.shape`,
723        - whether both inputs are quantized or not,
724        - and if they use the same quantization scheme.
725
726        Checks for
727
728        - :attr:`~torch.Tensor.layout`,
729        - :meth:`~torch.Tensor.stride`,
730        - :attr:`~torch.Tensor.device`, and
731        - :attr:`~torch.Tensor.dtype`
732
733        are optional and can be disabled through the corresponding ``check_*`` flag during construction of the pair.
734        """
735
736        def raise_mismatch_error(
737            attribute_name: str, actual_value: Any, expected_value: Any
738        ) -> NoReturn:
739            self._fail(
740                AssertionError,
741                f"The values for attribute '{attribute_name}' do not match: {actual_value} != {expected_value}.",
742            )
743
744        if actual.shape != expected.shape:
745            raise_mismatch_error("shape", actual.shape, expected.shape)
746
747        if actual.is_quantized != expected.is_quantized:
748            raise_mismatch_error(
749                "is_quantized", actual.is_quantized, expected.is_quantized
750            )
751        elif actual.is_quantized and actual.qscheme() != expected.qscheme():
752            raise_mismatch_error("qscheme()", actual.qscheme(), expected.qscheme())
753
754        if actual.layout != expected.layout:
755            if self.check_layout:
756                raise_mismatch_error("layout", actual.layout, expected.layout)
757        elif (
758            actual.layout == torch.strided
759            and self.check_stride
760            and actual.stride() != expected.stride()
761        ):
762            raise_mismatch_error("stride()", actual.stride(), expected.stride())
763
764        if self.check_device and actual.device != expected.device:
765            raise_mismatch_error("device", actual.device, expected.device)
766
767        if self.check_dtype and actual.dtype != expected.dtype:
768            raise_mismatch_error("dtype", actual.dtype, expected.dtype)
769
770    def _equalize_attributes(
771        self, actual: torch.Tensor, expected: torch.Tensor
772    ) -> Tuple[torch.Tensor, torch.Tensor]:
773        """Equalizes some attributes of two tensors for value comparison.
774
775        If ``actual`` and ``expected`` are ...
776
777        - ... not on the same :attr:`~torch.Tensor.device`, they are moved CPU memory.
778        - ... not of the same ``dtype``, they are promoted  to a common ``dtype`` (according to
779            :func:`torch.promote_types`).
780        - ... not of the same ``layout``, they are converted to strided tensors.
781
782        Args:
783            actual (Tensor): Actual tensor.
784            expected (Tensor): Expected tensor.
785
786        Returns:
787            (Tuple[Tensor, Tensor]): Equalized tensors.
788        """
789        # The comparison logic uses operators currently not supported by the MPS backends.
790        #  See https://github.com/pytorch/pytorch/issues/77144 for details.
791        # TODO: Remove this conversion as soon as all operations are supported natively by the MPS backend
792        if actual.is_mps or expected.is_mps:  # type: ignore[attr-defined]
793            actual = actual.cpu()
794            expected = expected.cpu()
795
796        if actual.device != expected.device:
797            actual = actual.cpu()
798            expected = expected.cpu()
799
800        if actual.dtype != expected.dtype:
801            actual_dtype = actual.dtype
802            expected_dtype = expected.dtype
803            # For uint64, this is not sound in general, which is why promote_types doesn't
804            # allow it, but for easy testing, we're unlikely to get confused
805            # by large uint64 overflowing into negative int64
806            if actual_dtype in [torch.uint64, torch.uint32, torch.uint16]:
807                actual_dtype = torch.int64
808            if expected_dtype in [torch.uint64, torch.uint32, torch.uint16]:
809                expected_dtype = torch.int64
810            dtype = torch.promote_types(actual_dtype, expected_dtype)
811            actual = actual.to(dtype)
812            expected = expected.to(dtype)
813
814        if actual.layout != expected.layout:
815            # These checks are needed, since Tensor.to_dense() fails on tensors that are already strided
816            actual = actual.to_dense() if actual.layout != torch.strided else actual
817            expected = (
818                expected.to_dense() if expected.layout != torch.strided else expected
819            )
820
821        return actual, expected
822
823    def _compare_values(self, actual: torch.Tensor, expected: torch.Tensor) -> None:
824        if actual.is_quantized:
825            compare_fn = self._compare_quantized_values
826        elif actual.is_sparse:
827            compare_fn = self._compare_sparse_coo_values
828        elif actual.layout in {
829            torch.sparse_csr,
830            torch.sparse_csc,
831            torch.sparse_bsr,
832            torch.sparse_bsc,
833        }:
834            compare_fn = self._compare_sparse_compressed_values
835        elif actual.layout == torch.jagged:
836            actual, expected = actual.values(), expected.values()
837            compare_fn = self._compare_regular_values_close
838        else:
839            compare_fn = self._compare_regular_values_close
840
841        compare_fn(
842            actual, expected, rtol=self.rtol, atol=self.atol, equal_nan=self.equal_nan
843        )
844
845    def _compare_quantized_values(
846        self,
847        actual: torch.Tensor,
848        expected: torch.Tensor,
849        *,
850        rtol: float,
851        atol: float,
852        equal_nan: bool,
853    ) -> None:
854        """Compares quantized tensors by comparing the :meth:`~torch.Tensor.dequantize`'d variants for closeness.
855
856        .. note::
857
858            A detailed discussion about why only the dequantized variant is checked for closeness rather than checking
859            the individual quantization parameters for closeness and the integer representation for equality can be
860            found in https://github.com/pytorch/pytorch/issues/68548.
861        """
862        return self._compare_regular_values_close(
863            actual.dequantize(),
864            expected.dequantize(),
865            rtol=rtol,
866            atol=atol,
867            equal_nan=equal_nan,
868            identifier=lambda default_identifier: f"Quantized {default_identifier.lower()}",
869        )
870
871    def _compare_sparse_coo_values(
872        self,
873        actual: torch.Tensor,
874        expected: torch.Tensor,
875        *,
876        rtol: float,
877        atol: float,
878        equal_nan: bool,
879    ) -> None:
880        """Compares sparse COO tensors by comparing
881
882        - the number of sparse dimensions,
883        - the number of non-zero elements (nnz) for equality,
884        - the indices for equality, and
885        - the values for closeness.
886        """
887        if actual.sparse_dim() != expected.sparse_dim():
888            self._fail(
889                AssertionError,
890                (
891                    f"The number of sparse dimensions in sparse COO tensors does not match: "
892                    f"{actual.sparse_dim()} != {expected.sparse_dim()}"
893                ),
894            )
895
896        if actual._nnz() != expected._nnz():
897            self._fail(
898                AssertionError,
899                (
900                    f"The number of specified values in sparse COO tensors does not match: "
901                    f"{actual._nnz()} != {expected._nnz()}"
902                ),
903            )
904
905        self._compare_regular_values_equal(
906            actual._indices(),
907            expected._indices(),
908            identifier="Sparse COO indices",
909        )
910        self._compare_regular_values_close(
911            actual._values(),
912            expected._values(),
913            rtol=rtol,
914            atol=atol,
915            equal_nan=equal_nan,
916            identifier="Sparse COO values",
917        )
918
919    def _compare_sparse_compressed_values(
920        self,
921        actual: torch.Tensor,
922        expected: torch.Tensor,
923        *,
924        rtol: float,
925        atol: float,
926        equal_nan: bool,
927    ) -> None:
928        """Compares sparse compressed tensors by comparing
929
930        - the number of non-zero elements (nnz) for equality,
931        - the plain indices for equality,
932        - the compressed indices for equality, and
933        - the values for closeness.
934        """
935        format_name, compressed_indices_method, plain_indices_method = {
936            torch.sparse_csr: (
937                "CSR",
938                torch.Tensor.crow_indices,
939                torch.Tensor.col_indices,
940            ),
941            torch.sparse_csc: (
942                "CSC",
943                torch.Tensor.ccol_indices,
944                torch.Tensor.row_indices,
945            ),
946            torch.sparse_bsr: (
947                "BSR",
948                torch.Tensor.crow_indices,
949                torch.Tensor.col_indices,
950            ),
951            torch.sparse_bsc: (
952                "BSC",
953                torch.Tensor.ccol_indices,
954                torch.Tensor.row_indices,
955            ),
956        }[actual.layout]
957
958        if actual._nnz() != expected._nnz():
959            self._fail(
960                AssertionError,
961                (
962                    f"The number of specified values in sparse {format_name} tensors does not match: "
963                    f"{actual._nnz()} != {expected._nnz()}"
964                ),
965            )
966
967        # Compressed and plain indices in the CSR / CSC / BSR / BSC sparse formates can be `torch.int32` _or_
968        # `torch.int64`. While the same dtype is enforced for the compressed and plain indices of a single tensor, it
969        # can be different between two tensors. Thus, we need to convert them to the same dtype, or the comparison will
970        # fail.
971        actual_compressed_indices = compressed_indices_method(actual)
972        expected_compressed_indices = compressed_indices_method(expected)
973        indices_dtype = torch.promote_types(
974            actual_compressed_indices.dtype, expected_compressed_indices.dtype
975        )
976
977        self._compare_regular_values_equal(
978            actual_compressed_indices.to(indices_dtype),
979            expected_compressed_indices.to(indices_dtype),
980            identifier=f"Sparse {format_name} {compressed_indices_method.__name__}",
981        )
982        self._compare_regular_values_equal(
983            plain_indices_method(actual).to(indices_dtype),
984            plain_indices_method(expected).to(indices_dtype),
985            identifier=f"Sparse {format_name} {plain_indices_method.__name__}",
986        )
987        self._compare_regular_values_close(
988            actual.values(),
989            expected.values(),
990            rtol=rtol,
991            atol=atol,
992            equal_nan=equal_nan,
993            identifier=f"Sparse {format_name} values",
994        )
995
996    def _compare_regular_values_equal(
997        self,
998        actual: torch.Tensor,
999        expected: torch.Tensor,
1000        *,
1001        equal_nan: bool = False,
1002        identifier: Optional[Union[str, Callable[[str], str]]] = None,
1003    ) -> None:
1004        """Checks if the values of two tensors are equal."""
1005        self._compare_regular_values_close(
1006            actual, expected, rtol=0, atol=0, equal_nan=equal_nan, identifier=identifier
1007        )
1008
1009    def _compare_regular_values_close(
1010        self,
1011        actual: torch.Tensor,
1012        expected: torch.Tensor,
1013        *,
1014        rtol: float,
1015        atol: float,
1016        equal_nan: bool,
1017        identifier: Optional[Union[str, Callable[[str], str]]] = None,
1018    ) -> None:
1019        """Checks if the values of two tensors are close up to a desired tolerance."""
1020        matches = torch.isclose(
1021            actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan
1022        )
1023        if torch.all(matches):
1024            return
1025
1026        if actual.shape == torch.Size([]):
1027            msg = make_scalar_mismatch_msg(
1028                actual.item(),
1029                expected.item(),
1030                rtol=rtol,
1031                atol=atol,
1032                identifier=identifier,
1033            )
1034        else:
1035            msg = make_tensor_mismatch_msg(
1036                actual, expected, matches, rtol=rtol, atol=atol, identifier=identifier
1037            )
1038        self._fail(AssertionError, msg)
1039
1040    def extra_repr(self) -> Sequence[str]:
1041        return (
1042            "rtol",
1043            "atol",
1044            "equal_nan",
1045            "check_device",
1046            "check_dtype",
1047            "check_layout",
1048            "check_stride",
1049        )
1050
1051
def _sorted_if_possible(keys: Collection) -> List:
    """Returns ``keys`` as a sorted list, or as a list in their iteration order if they cannot be sorted.

    Mapping keys may be of mutually incomparable types (e.g. ``1`` and ``"a"``), in which case :func:`sorted` raises.
    Sorting is only used to make messages and the origination order deterministic, so falling back to the unsorted
    order is always acceptable.
    """
    try:
        return sorted(keys)
    except Exception:
        return list(keys)


def originate_pairs(
    actual: Any,
    expected: Any,
    *,
    pair_types: Sequence[Type[Pair]],
    sequence_types: Tuple[Type, ...] = (collections.abc.Sequence,),
    mapping_types: Tuple[Type, ...] = (collections.abc.Mapping,),
    id: Tuple[Any, ...] = (),
    **options: Any,
) -> List[Pair]:
    """Originates pairs from the individual inputs.

    ``actual`` and ``expected`` can be possibly nested :class:`~collections.abc.Sequence`'s or
    :class:`~collections.abc.Mapping`'s. In this case the pairs are originated by recursing through them.

    Args:
        actual (Any): Actual input.
        expected (Any): Expected input.
        pair_types (Sequence[Type[Pair]]): Sequence of pair types that will be tried to construct with the inputs.
            First successful pair will be used.
        sequence_types (Tuple[Type, ...]): Optional types treated as sequences that will be checked elementwise.
        mapping_types (Tuple[Type, ...]): Optional types treated as mappings that will be checked elementwise.
        id (Tuple[Any, ...]): Optional id of a pair that will be included in an error message.
        **options (Any): Options passed to each pair during construction.

    Raises:
        ErrorMeta: With :class:`AssertionError`, if the inputs are :class:`~collections.abc.Sequence`'s, but their
            length does not match.
        ErrorMeta: With :class:`AssertionError`, if the inputs are :class:`~collections.abc.Mapping`'s, but their set
            of keys do not match. The reported keys are sorted when they are mutually comparable.
        ErrorMeta: With :class:`TypeError`, if no pair is able to handle the inputs.
        ErrorMeta: With any expected exception that happens during the construction of a pair.
        RuntimeError: If constructing a pair raises an unexpected exception, i.e. anything other than
            :class:`UnsupportedInputs` or :class:`ErrorMeta`.

    Returns:
        (List[Pair]): Originated pairs.
    """
    # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
    # "a" == "a"[0][0]...
    if (
        isinstance(actual, sequence_types)
        and not isinstance(actual, str)
        and isinstance(expected, sequence_types)
        and not isinstance(expected, str)
    ):
        actual_len = len(actual)
        expected_len = len(expected)
        if actual_len != expected_len:
            raise ErrorMeta(
                AssertionError,
                f"The length of the sequences mismatch: {actual_len} != {expected_len}",
                id=id,
            )

        pairs = []
        for idx in range(actual_len):
            pairs.extend(
                originate_pairs(
                    actual[idx],
                    expected[idx],
                    pair_types=pair_types,
                    sequence_types=sequence_types,
                    mapping_types=mapping_types,
                    id=(*id, idx),
                    **options,
                )
            )
        return pairs

    elif isinstance(actual, mapping_types) and isinstance(expected, mapping_types):
        actual_keys = set(actual.keys())
        expected_keys = set(expected.keys())
        if actual_keys != expected_keys:
            missing_keys = expected_keys - actual_keys
            additional_keys = actual_keys - expected_keys
            # The keys are not guaranteed to be mutually comparable, so an unguarded `sorted` here could replace this
            # error with an unrelated `TypeError`.
            raise ErrorMeta(
                AssertionError,
                (
                    f"The keys of the mappings do not match:\n"
                    f"Missing keys in the actual mapping: {_sorted_if_possible(missing_keys)}\n"
                    f"Additional keys in the actual mapping: {_sorted_if_possible(additional_keys)}"
                ),
                id=id,
            )

        # Since the origination aborts after the first failure, we try to be deterministic
        keys = _sorted_if_possible(actual_keys)

        pairs = []
        for key in keys:
            pairs.extend(
                originate_pairs(
                    actual[key],
                    expected[key],
                    pair_types=pair_types,
                    sequence_types=sequence_types,
                    mapping_types=mapping_types,
                    id=(*id, key),
                    **options,
                )
            )
        return pairs

    else:
        for pair_type in pair_types:
            try:
                return [pair_type(actual, expected, id=id, **options)]
            # Raising an `UnsupportedInputs` during origination indicates that the pair type is not able to handle the
            # inputs. Thus, we try the next pair type.
            except UnsupportedInputs:
                continue
            # Raising an `ErrorMeta` during origination is the orderly way to abort and so we simply re-raise it. This
            # is only in a separate branch, because the one below would also except it.
            except ErrorMeta:
                raise
            # Raising any other exception during origination is unexpected and will give some extra information about
            # what happened. If applicable, the exception should be expected in the future.
            except Exception as error:
                raise RuntimeError(
                    f"Originating a {pair_type.__name__}() at item {''.join(str([item]) for item in id)} with\n\n"
                    f"{type(actual).__name__}(): {actual}\n\n"
                    f"and\n\n"
                    f"{type(expected).__name__}(): {expected}\n\n"
                    f"resulted in the unexpected exception above. "
                    f"If you are a user and see this message during normal operation "
                    "please file an issue at https://github.com/pytorch/pytorch/issues. "
                    "If you are a developer and working on the comparison functions, "
                    "please except the previous error and raise an expressive `ErrorMeta` instead."
                ) from error
        else:
            raise ErrorMeta(
                TypeError,
                f"No comparison pair was able to handle inputs of type {type(actual)} and {type(expected)}.",
                id=id,
            )
1188
1189
def not_close_error_metas(
    actual: Any,
    expected: Any,
    *,
    pair_types: Sequence[Type[Pair]] = (ObjectPair,),
    sequence_types: Tuple[Type, ...] = (collections.abc.Sequence,),
    mapping_types: Tuple[Type, ...] = (collections.abc.Mapping,),
    **options: Any,
) -> List[ErrorMeta]:
    """Compares the inputs and returns the failures instead of raising them.

    ``actual`` and ``expected`` can be possibly nested :class:`~collections.abc.Sequence`'s or
    :class:`~collections.abc.Mapping`'s. In this case the comparison happens elementwise by recursing through them.

    First, pairs are originated from the inputs with :func:`originate_pairs`. Afterwards every pair is compared.
    Each :class:`ErrorMeta` raised by a pair's ``compare()`` is collected, and the remaining pairs are still compared.

    Args:
        actual (Any): Actual input.
        expected (Any): Expected input.
        pair_types (Sequence[Type[Pair]]): Sequence of :class:`Pair` types that will be tried to construct with the
            inputs. First successful pair will be used. Defaults to only using :class:`ObjectPair`.
        sequence_types (Tuple[Type, ...]): Optional types treated as sequences that will be checked elementwise.
        mapping_types (Tuple[Type, ...]): Optional types treated as mappings that will be checked elementwise.
        **options (Any): Options passed to each pair during construction.

    Raises:
        Exception: The user-facing error built by :meth:`ErrorMeta.to_error` if originating the pairs fails, e.g. for
            a structure mismatch or for inputs no pair type can handle. Its type is the one carried by the
            :class:`ErrorMeta`.
        RuntimeError: If originating or comparing a pair raises an unexpected exception, i.e. one that is not an
            :class:`ErrorMeta`.

    Returns:
        (List[ErrorMeta]): The error metas of all pairs whose comparison failed, in pair order. Empty if no comparison
            failed.
    """
    # Hide this function from `pytest`'s traceback
    __tracebackhide__ = True

    # Origination aborts at the first failure, so that failure is raised right away as a user-facing error.
    try:
        pairs = originate_pairs(
            actual,
            expected,
            pair_types=pair_types,
            sequence_types=sequence_types,
            mapping_types=mapping_types,
            **options,
        )
    except ErrorMeta as error_meta:
        # Explicitly raising from None to hide the internal traceback
        raise error_meta.to_error() from None  # noqa: RSE102

    # In contrast, comparison failures of individual pairs are collected so that every pair gets compared.
    error_metas: List[ErrorMeta] = []
    for pair in pairs:
        try:
            pair.compare()
        except ErrorMeta as error_meta:
            error_metas.append(error_meta)
        # Raising any exception besides `ErrorMeta` while comparing is unexpected and will give some extra information
        # about what happened. If applicable, the exception should be expected in the future.
        except Exception as error:
            raise RuntimeError(
                f"Comparing\n\n"
                f"{pair}\n\n"
                f"resulted in the unexpected exception above. "
                f"If you are a user and see this message during normal operation "
                "please file an issue at https://github.com/pytorch/pytorch/issues. "
                "If you are a developer and working on the comparison functions, "
                "please except the previous error and raise an expressive `ErrorMeta` instead."
            ) from error

    # [ErrorMeta Cycles]
    # ErrorMeta objects in this list capture
    # tracebacks that refer to the frame of this function.
    # The local variable `error_metas` refers to the error meta
    # objects, creating a reference cycle. Frames in the traceback
    # would not get freed until cycle collection, leaking cuda memory in tests.
    # We break the cycle by removing the reference to the error_meta objects
    # from this frame as it returns.
    # Mechanism: the list is wrapped in a fresh one-element list and popped back out, so the returned list is no
    # longer reachable from this frame's locals (`error_metas` is left bound to the now-empty wrapper).
    error_metas = [error_metas]
    return error_metas.pop()
1258
1259
1260def assert_close(
1261    actual: Any,
1262    expected: Any,
1263    *,
1264    allow_subclasses: bool = True,
1265    rtol: Optional[float] = None,
1266    atol: Optional[float] = None,
1267    equal_nan: bool = False,
1268    check_device: bool = True,
1269    check_dtype: bool = True,
1270    check_layout: bool = True,
1271    check_stride: bool = False,
1272    msg: Optional[Union[str, Callable[[str], str]]] = None,
1273):
1274    r"""Asserts that ``actual`` and ``expected`` are close.
1275
1276    If ``actual`` and ``expected`` are strided, non-quantized, real-valued, and finite, they are considered close if
1277
1278    .. math::
1279
1280        \lvert \text{actual} - \text{expected} \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert \text{expected} \rvert
1281
1282    Non-finite values (``-inf`` and ``inf``) are only considered close if and only if they are equal. ``NaN``'s are
1283    only considered equal to each other if ``equal_nan`` is ``True``.
1284
1285    In addition, they are only considered close if they have the same
1286
1287    - :attr:`~torch.Tensor.device` (if ``check_device`` is ``True``),
1288    - ``dtype`` (if ``check_dtype`` is ``True``),
1289    - ``layout`` (if ``check_layout`` is ``True``), and
1290    - stride (if ``check_stride`` is ``True``).
1291
1292    If either ``actual`` or ``expected`` is a meta tensor, only the attribute checks will be performed.
1293
1294    If ``actual`` and ``expected`` are sparse (either having COO, CSR, CSC, BSR, or BSC layout), their strided members are
1295    checked individually. Indices, namely ``indices`` for COO, ``crow_indices`` and ``col_indices`` for CSR and BSR,
1296    or ``ccol_indices``  and ``row_indices`` for CSC and BSC layouts, respectively,
1297    are always checked for equality whereas the values are checked for closeness according to the definition above.
1298
1299    If ``actual`` and ``expected`` are quantized, they are considered close if they have the same
1300    :meth:`~torch.Tensor.qscheme` and the result of :meth:`~torch.Tensor.dequantize` is close according to the
1301    definition above.
1302
1303    ``actual`` and ``expected`` can be :class:`~torch.Tensor`'s or any tensor-or-scalar-likes from which
1304    :class:`torch.Tensor`'s can be constructed with :func:`torch.as_tensor`. Except for Python scalars the input types
1305    have to be directly related. In addition, ``actual`` and ``expected`` can be :class:`~collections.abc.Sequence`'s
1306    or :class:`~collections.abc.Mapping`'s in which case they are considered close if their structure matches and all
1307    their elements are considered close according to the above definition.
1308
1309    .. note::
1310
1311        Python scalars are an exception to the type relation requirement, because their :func:`type`, i.e.
1312        :class:`int`, :class:`float`, and :class:`complex`, is equivalent to the ``dtype`` of a tensor-like. Thus,
1313        Python scalars of different types can be checked, but require ``check_dtype=False``.
1314
1315    Args:
1316        actual (Any): Actual input.
1317        expected (Any): Expected input.
1318        allow_subclasses (bool): If ``True`` (default) and except for Python scalars, inputs of directly related types
1319            are allowed. Otherwise type equality is required.
1320        rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default
1321            values based on the :attr:`~torch.Tensor.dtype` are selected with the below table.
1322        atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default
1323            values based on the :attr:`~torch.Tensor.dtype` are selected with the below table.
1324        equal_nan (Union[bool, str]): If ``True``, two ``NaN`` values will be considered equal.
1325        check_device (bool): If ``True`` (default), asserts that corresponding tensors are on the same
1326            :attr:`~torch.Tensor.device`. If this check is disabled, tensors on different
1327            :attr:`~torch.Tensor.device`'s are moved to the CPU before being compared.
1328        check_dtype (bool): If ``True`` (default), asserts that corresponding tensors have the same ``dtype``. If this
1329            check is disabled, tensors with different ``dtype``'s are promoted  to a common ``dtype`` (according to
1330            :func:`torch.promote_types`) before being compared.
1331        check_layout (bool): If ``True`` (default), asserts that corresponding tensors have the same ``layout``. If this
1332            check is disabled, tensors with different ``layout``'s are converted to strided tensors before being
1333            compared.
1334        check_stride (bool): If ``True`` and corresponding tensors are strided, asserts that they have the same stride.
1335        msg (Optional[Union[str, Callable[[str], str]]]): Optional error message to use in case a failure occurs during
1336            the comparison. Can also passed as callable in which case it will be called with the generated message and
1337            should return the new message.
1338
1339    Raises:
1340        ValueError: If no :class:`torch.Tensor` can be constructed from an input.
1341        ValueError: If only ``rtol`` or ``atol`` is specified.
1342        AssertionError: If corresponding inputs are not Python scalars and are not directly related.
1343        AssertionError: If ``allow_subclasses`` is ``False``, but corresponding inputs are not Python scalars and have
1344            different types.
1345        AssertionError: If the inputs are :class:`~collections.abc.Sequence`'s, but their length does not match.
1346        AssertionError: If the inputs are :class:`~collections.abc.Mapping`'s, but their set of keys do not match.
1347        AssertionError: If corresponding tensors do not have the same :attr:`~torch.Tensor.shape`.
1348        AssertionError: If ``check_layout`` is ``True``, but corresponding tensors do not have the same
1349            :attr:`~torch.Tensor.layout`.
1350        AssertionError: If only one of corresponding tensors is quantized.
1351        AssertionError: If corresponding tensors are quantized, but have different :meth:`~torch.Tensor.qscheme`'s.
1352        AssertionError: If ``check_device`` is ``True``, but corresponding tensors are not on the same
1353            :attr:`~torch.Tensor.device`.
1354        AssertionError: If ``check_dtype`` is ``True``, but corresponding tensors do not have the same ``dtype``.
1355        AssertionError: If ``check_stride`` is ``True``, but corresponding strided tensors do not have the same stride.
1356        AssertionError: If the values of corresponding tensors are not close according to the definition above.
1357
1358    The following table displays the default ``rtol`` and ``atol`` for different ``dtype``'s. In case of mismatching
1359    ``dtype``'s, the maximum of both tolerances is used.
1360
1361    +---------------------------+------------+----------+
1362    | ``dtype``                 | ``rtol``   | ``atol`` |
1363    +===========================+============+==========+
1364    | :attr:`~torch.float16`    | ``1e-3``   | ``1e-5`` |
1365    +---------------------------+------------+----------+
1366    | :attr:`~torch.bfloat16`   | ``1.6e-2`` | ``1e-5`` |
1367    +---------------------------+------------+----------+
1368    | :attr:`~torch.float32`    | ``1.3e-6`` | ``1e-5`` |
1369    +---------------------------+------------+----------+
1370    | :attr:`~torch.float64`    | ``1e-7``   | ``1e-7`` |
1371    +---------------------------+------------+----------+
1372    | :attr:`~torch.complex32`  | ``1e-3``   | ``1e-5`` |
1373    +---------------------------+------------+----------+
1374    | :attr:`~torch.complex64`  | ``1.3e-6`` | ``1e-5`` |
1375    +---------------------------+------------+----------+
1376    | :attr:`~torch.complex128` | ``1e-7``   | ``1e-7`` |
1377    +---------------------------+------------+----------+
1378    | :attr:`~torch.quint8`     | ``1.3e-6`` | ``1e-5`` |
1379    +---------------------------+------------+----------+
1380    | :attr:`~torch.quint2x4`   | ``1.3e-6`` | ``1e-5`` |
1381    +---------------------------+------------+----------+
1382    | :attr:`~torch.quint4x2`   | ``1.3e-6`` | ``1e-5`` |
1383    +---------------------------+------------+----------+
1384    | :attr:`~torch.qint8`      | ``1.3e-6`` | ``1e-5`` |
1385    +---------------------------+------------+----------+
1386    | :attr:`~torch.qint32`     | ``1.3e-6`` | ``1e-5`` |
1387    +---------------------------+------------+----------+
1388    | other                     | ``0.0``    | ``0.0``  |
1389    +---------------------------+------------+----------+
1390
1391    .. note::
1392
1393        :func:`~torch.testing.assert_close` is highly configurable with strict default settings. Users are encouraged
1394        to :func:`~functools.partial` it to fit their use case. For example, if an equality check is needed, one might
1395        define an ``assert_equal`` that uses zero tolerances for every ``dtype`` by default:
1396
1397        >>> import functools
1398        >>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
1399        >>> assert_equal(1e-9, 1e-10)
1400        Traceback (most recent call last):
1401        ...
1402        AssertionError: Scalars are not equal!
1403        <BLANKLINE>
1404        Expected 1e-10 but got 1e-09.
1405        Absolute difference: 9.000000000000001e-10
1406        Relative difference: 9.0
1407
1408    Examples:
1409        >>> # tensor to tensor comparison
1410        >>> expected = torch.tensor([1e0, 1e-1, 1e-2])
1411        >>> actual = torch.acos(torch.cos(expected))
1412        >>> torch.testing.assert_close(actual, expected)
1413
1414        >>> # scalar to scalar comparison
1415        >>> import math
1416        >>> expected = math.sqrt(2.0)
1417        >>> actual = 2.0 / math.sqrt(2.0)
1418        >>> torch.testing.assert_close(actual, expected)
1419
1420        >>> # numpy array to numpy array comparison
1421        >>> import numpy as np
1422        >>> expected = np.array([1e0, 1e-1, 1e-2])
1423        >>> actual = np.arccos(np.cos(expected))
1424        >>> torch.testing.assert_close(actual, expected)
1425
1426        >>> # sequence to sequence comparison
1427        >>> import numpy as np
1428        >>> # The types of the sequences do not have to match. They only have to have the same
1429        >>> # length and their elements have to match.
1430        >>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)]
1431        >>> actual = tuple(expected)
1432        >>> torch.testing.assert_close(actual, expected)
1433
1434        >>> # mapping to mapping comparison
1435        >>> from collections import OrderedDict
1436        >>> import numpy as np
1437        >>> foo = torch.tensor(1.0)
1438        >>> bar = 2.0
1439        >>> baz = np.array(3.0)
1440        >>> # The types and a possible ordering of mappings do not have to match. They only
1441        >>> # have to have the same set of keys and their elements have to match.
1442        >>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)])
1443        >>> actual = {"baz": baz, "bar": bar, "foo": foo}
1444        >>> torch.testing.assert_close(actual, expected)
1445
1446        >>> expected = torch.tensor([1.0, 2.0, 3.0])
1447        >>> actual = expected.clone()
1448        >>> # By default, directly related instances can be compared
1449        >>> torch.testing.assert_close(torch.nn.Parameter(actual), expected)
1450        >>> # This check can be made more strict with allow_subclasses=False
1451        >>> torch.testing.assert_close(
1452        ...     torch.nn.Parameter(actual), expected, allow_subclasses=False
1453        ... )
1454        Traceback (most recent call last):
1455        ...
1456        TypeError: No comparison pair was able to handle inputs of type
1457        <class 'torch.nn.parameter.Parameter'> and <class 'torch.Tensor'>.
1458        >>> # If the inputs are not directly related, they are never considered close
1459        >>> torch.testing.assert_close(actual.numpy(), expected)
1460        Traceback (most recent call last):
1461        ...
1462        TypeError: No comparison pair was able to handle inputs of type <class 'numpy.ndarray'>
1463        and <class 'torch.Tensor'>.
1464        >>> # Exceptions to these rules are Python scalars. They can be checked regardless of
1465        >>> # their type if check_dtype=False.
1466        >>> torch.testing.assert_close(1.0, 1, check_dtype=False)
1467
1468        >>> # NaN != NaN by default.
1469        >>> expected = torch.tensor(float("Nan"))
1470        >>> actual = expected.clone()
1471        >>> torch.testing.assert_close(actual, expected)
1472        Traceback (most recent call last):
1473        ...
1474        AssertionError: Scalars are not close!
1475        <BLANKLINE>
1476        Expected nan but got nan.
1477        Absolute difference: nan (up to 1e-05 allowed)
1478        Relative difference: nan (up to 1.3e-06 allowed)
1479        >>> torch.testing.assert_close(actual, expected, equal_nan=True)
1480
1481        >>> expected = torch.tensor([1.0, 2.0, 3.0])
1482        >>> actual = torch.tensor([1.0, 4.0, 5.0])
1483        >>> # The default error message can be overwritten.
1484        >>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!")
1485        Traceback (most recent call last):
1486        ...
1487        AssertionError: Argh, the tensors are not close!
1488        >>> # If msg is a callable, it can be used to augment the generated message with
1489        >>> # extra information
1490        >>> torch.testing.assert_close(
1491        ...     actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter"
1492        ... )
1493        Traceback (most recent call last):
1494        ...
1495        AssertionError: Header
1496        <BLANKLINE>
1497        Tensor-likes are not close!
1498        <BLANKLINE>
1499        Mismatched elements: 2 / 3 (66.7%)
1500        Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed)
1501        Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed)
1502        <BLANKLINE>
1503        Footer
1504    """
1505    # Hide this function from `pytest`'s traceback
1506    __tracebackhide__ = True
1507
1508    error_metas = not_close_error_metas(
1509        actual,
1510        expected,
1511        pair_types=(
1512            NonePair,
1513            BooleanPair,
1514            NumberPair,
1515            TensorLikePair,
1516        ),
1517        allow_subclasses=allow_subclasses,
1518        rtol=rtol,
1519        atol=atol,
1520        equal_nan=equal_nan,
1521        check_device=check_device,
1522        check_dtype=check_dtype,
1523        check_layout=check_layout,
1524        check_stride=check_stride,
1525        msg=msg,
1526    )
1527
1528    if error_metas:
1529        # TODO: compose all metas into one AssertionError
1530        raise error_metas[0].to_error(msg)
1531
1532
@deprecated(
    "`torch.testing.assert_allclose()` is deprecated since 1.12 and will be removed in a future release. "
    "Please use `torch.testing.assert_close()` instead. "
    "You can find detailed upgrade instructions in https://github.com/pytorch/pytorch/issues/61844.",
    category=FutureWarning,
)
def assert_allclose(
    actual: Any,
    expected: Any,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    equal_nan: bool = True,
    msg: str = "",
) -> None:
    """
    .. warning::

       :func:`torch.testing.assert_allclose` is deprecated since ``1.12`` and will be removed in a future release.
       Please use :func:`torch.testing.assert_close` instead. You can find detailed upgrade instructions
       `here <https://github.com/pytorch/pytorch/issues/61844>`_.

    Legacy closeness check implemented on top of :func:`torch.testing.assert_close`.

    Args:
        actual: Value to check. Non-tensor inputs are converted with ``torch.tensor``.
        expected: Reference value. Non-tensor inputs are converted with ``torch.tensor``
            using the dtype of (the possibly converted) ``actual``.
        rtol: Relative tolerance. If both ``rtol`` and ``atol`` are omitted, both are
            derived from a legacy per-dtype table. Otherwise they are forwarded unchanged.
        atol: Absolute tolerance. See ``rtol``.
        equal_nan: Forwarded to :func:`torch.testing.assert_close`. Defaults to ``True`` here.
        msg: Custom failure message. An empty string means "use the generated message".

    Raises:
        Whatever :func:`torch.testing.assert_close` raises for the forwarded arguments.
    """
    # Promote non-tensor inputs. ``expected`` inherits ``actual``'s dtype, which is
    # why the dtype check is disabled in the final comparison.
    actual_t = actual if isinstance(actual, torch.Tensor) else torch.tensor(actual)
    expected_t = (
        expected
        if isinstance(expected, torch.Tensor)
        else torch.tensor(expected, dtype=actual_t.dtype)
    )

    # Only when *both* tolerances are missing do we fall back to the legacy table.
    # Its float32 rtol (1e-4) is looser than the one assert_close uses by default.
    if rtol is None and atol is None:
        legacy_precisions = {
            torch.float16: (1e-3, 1e-3),
            torch.float32: (1e-4, 1e-5),
            torch.float64: (1e-5, 1e-8),
        }
        rtol, atol = default_tolerances(
            actual_t, expected_t, dtype_precisions=legacy_precisions
        )

    # Device equality is enforced; dtype and stride differences are tolerated.
    # An empty ``msg`` is mapped to ``None`` so assert_close generates its own message.
    torch.testing.assert_close(
        actual_t,
        expected_t,
        rtol=rtol,
        atol=atol,
        equal_nan=equal_nan,
        check_device=True,
        check_dtype=False,
        check_stride=False,
        msg=None if not msg else msg,
    )
1581