# mypy: allow-untyped-defs
import abc
import cmath
import collections.abc
import contextlib
from typing import (
    Any,
    Callable,
    Collection,
    Dict,
    List,
    NoReturn,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
)
from typing_extensions import deprecated

import torch


try:
    import numpy as np

    HAS_NUMPY = True
except ModuleNotFoundError:
    HAS_NUMPY = False
    np = None  # type: ignore[assignment]


class ErrorMeta(Exception):
    """Internal testing exception that carries error metadata."""

    def __init__(
        self, type: Type[Exception], msg: str, *, id: Tuple[Any, ...] = ()
    ) -> None:
        super().__init__(
            "If you are a user and see this message during normal operation "
            "please file an issue at https://github.com/pytorch/pytorch/issues. "
            "If you are a developer and working on the comparison functions, please `raise ErrorMeta.to_error()` "
            "for user facing errors."
        )
        self.type = type
        self.msg = msg
        self.id = id

    def to_error(
        self, msg: Optional[Union[str, Callable[[str], str]]] = None
    ) -> Exception:
        if not isinstance(msg, str):
            generated_msg = self.msg
            if self.id:
                generated_msg += f"\n\nThe failure occurred for item {''.join(str([item]) for item in self.id)}"

            msg = msg(generated_msg) if callable(msg) else generated_msg

        return self.type(msg)


# Some analysis of tolerance by logging tests from test_torch.py can be found in
# https://github.com/pytorch/pytorch/pull/32538.
# {dtype: (rtol, atol)}
_DTYPE_PRECISIONS = {
    torch.float16: (0.001, 1e-5),
    torch.bfloat16: (0.016, 1e-5),
    torch.float32: (1.3e-6, 1e-5),
    torch.float64: (1e-7, 1e-7),
    torch.complex32: (0.001, 1e-5),
    torch.complex64: (1.3e-6, 1e-5),
    torch.complex128: (1e-7, 1e-7),
}
# The default tolerances of torch.float32 are used for quantized dtypes, because quantized tensors are compared in
# their dequantized and floating point representation. For more details see `TensorLikePair._compare_quantized_values`
_DTYPE_PRECISIONS.update(
    dict.fromkeys(
        (torch.quint8, torch.quint2x4, torch.quint4x2, torch.qint8, torch.qint32),
        _DTYPE_PRECISIONS[torch.float32],
    )
)


def default_tolerances(
    *inputs: Union[torch.Tensor, torch.dtype],
    dtype_precisions: Optional[Dict[torch.dtype, Tuple[float, float]]] = None,
) -> Tuple[float, float]:
    """Returns the default relative and absolute testing tolerances for a set of inputs based on the dtype.

    See :func:`assert_close` for a table of the default tolerance for each dtype.

    Returns:
        (Tuple[float, float]): Loosest tolerances of all input dtypes.
    """
    dtypes = []
    for input in inputs:
        if isinstance(input, torch.Tensor):
            dtypes.append(input.dtype)
        elif isinstance(input, torch.dtype):
            dtypes.append(input)
        else:
            raise TypeError(
                f"Expected a torch.Tensor or a torch.dtype, but got {type(input)} instead."
            )
    dtype_precisions = dtype_precisions or _DTYPE_PRECISIONS
    rtols, atols = zip(*[dtype_precisions.get(dtype, (0.0, 0.0)) for dtype in dtypes])
    return max(rtols), max(atols)
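
# Illustrative sketch (not part of the module API): `default_tolerances` resolves
# mixed dtypes to the loosest entry of `_DTYPE_PRECISIONS`, so a float16/float64
# comparison ends up with the float16 tolerances:
#
#   >>> default_tolerances(torch.float16, torch.float64)
#   (0.001, 1e-05)
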

def get_tolerances(
    *inputs: Union[torch.Tensor, torch.dtype],
    rtol: Optional[float],
    atol: Optional[float],
    id: Tuple[Any, ...] = (),
) -> Tuple[float, float]:
    """Gets the relative and absolute tolerances to be used for numeric comparisons.

    If both ``rtol`` and ``atol`` are specified, this is a no-op. If neither is specified, the return value of
    :func:`default_tolerances` is used.

    Raises:
        ErrorMeta: With :class:`ValueError`, if only ``rtol`` or ``atol`` is specified.

    Returns:
        (Tuple[float, float]): Valid relative and absolute tolerances.
    """
    if (rtol is None) ^ (atol is None):
        # We require both tolerances to be omitted or specified, because specifying only one might lead to surprising
        # results. Imagine setting atol=0.0 and the tensors still match because rtol>0.0.
        raise ErrorMeta(
            ValueError,
            f"Both 'rtol' and 'atol' must be either specified or omitted, "
            f"but got no {'rtol' if rtol is None else 'atol'}.",
            id=id,
        )
    elif rtol is not None and atol is not None:
        return rtol, atol
    else:
        return default_tolerances(*inputs)


def _make_mismatch_msg(
    *,
    default_identifier: str,
    identifier: Optional[Union[str, Callable[[str], str]]] = None,
    extra: Optional[str] = None,
    abs_diff: float,
    abs_diff_idx: Optional[Union[int, Tuple[int, ...]]] = None,
    atol: float,
    rel_diff: float,
    rel_diff_idx: Optional[Union[int, Tuple[int, ...]]] = None,
    rtol: float,
) -> str:
    """Makes a mismatch error message for numeric values.

    Args:
        default_identifier (str): Default description of the compared values, e.g. "Tensor-likes".
        identifier (Optional[Union[str, Callable[[str], str]]]): Optional identifier that overrides
            ``default_identifier``. Can be passed as callable in which case it will be called with
            ``default_identifier`` to create the description at runtime.
        extra (Optional[str]): Extra information to be placed after the message header and the mismatch statistics.
        abs_diff (float): Absolute difference.
        abs_diff_idx (Optional[Union[int, Tuple[int, ...]]]): Optional index of the absolute difference.
        atol (float): Allowed absolute tolerance. Will only be added to mismatch statistics if it or ``rtol`` is
            ``> 0``.
        rel_diff (float): Relative difference.
        rel_diff_idx (Optional[Union[int, Tuple[int, ...]]]): Optional index of the relative difference.
        rtol (float): Allowed relative tolerance. Will only be added to mismatch statistics if it or ``atol`` is
            ``> 0``.
    """
    equality = rtol == 0 and atol == 0

    def make_diff_msg(
        *,
        type: str,
        diff: float,
        idx: Optional[Union[int, Tuple[int, ...]]],
        tol: float,
    ) -> str:
        if idx is None:
            msg = f"{type.title()} difference: {diff}"
        else:
            msg = f"Greatest {type} difference: {diff} at index {idx}"
        if not equality:
            msg += f" (up to {tol} allowed)"
        return msg + "\n"

    if identifier is None:
        identifier = default_identifier
    elif callable(identifier):
        identifier = identifier(default_identifier)

    msg = f"{identifier} are not {'equal' if equality else 'close'}!\n\n"

    if extra:
        msg += f"{extra.strip()}\n"

    msg += make_diff_msg(type="absolute", diff=abs_diff, idx=abs_diff_idx, tol=atol)
    msg += make_diff_msg(type="relative", diff=rel_diff, idx=rel_diff_idx, tol=rtol)

    return msg.strip()


def make_scalar_mismatch_msg(
    actual: Union[bool, int, float, complex],
    expected: Union[bool, int, float, complex],
    *,
    rtol: float,
    atol: float,
    identifier: Optional[Union[str, Callable[[str], str]]] = None,
) -> str:
    """Makes a mismatch error message for scalars.

    Args:
        actual (Union[bool, int, float, complex]): Actual scalar.
        expected (Union[bool, int, float, complex]): Expected scalar.
        rtol (float): Relative tolerance.
        atol (float): Absolute tolerance.
        identifier (Optional[Union[str, Callable[[str], str]]]): Optional description for the scalars. Can be passed
            as callable in which case it will be called with the default description to create the description at
            runtime. Defaults to "Scalars".
    """
    abs_diff = abs(actual - expected)
    rel_diff = float("inf") if expected == 0 else abs_diff / abs(expected)
    return _make_mismatch_msg(
        default_identifier="Scalars",
        identifier=identifier,
        extra=f"Expected {expected} but got {actual}.",
        abs_diff=abs_diff,
        atol=atol,
        rel_diff=rel_diff,
        rtol=rtol,
    )
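
# Illustrative sketch (hypothetical values) of the message produced above:
#
#   >>> print(make_scalar_mismatch_msg(3.0, 2.0, rtol=1.3e-06, atol=1e-05))
#   Scalars are not close!
#   <BLANKLINE>
#   Expected 2.0 but got 3.0.
#   Absolute difference: 1.0 (up to 1e-05 allowed)
#   Relative difference: 0.5 (up to 1.3e-06 allowed)
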

def make_tensor_mismatch_msg(
    actual: torch.Tensor,
    expected: torch.Tensor,
    matches: torch.Tensor,
    *,
    rtol: float,
    atol: float,
    identifier: Optional[Union[str, Callable[[str], str]]] = None,
):
    """Makes a mismatch error message for tensors.

    Args:
        actual (torch.Tensor): Actual tensor.
        expected (torch.Tensor): Expected tensor.
        matches (torch.Tensor): Boolean mask of the same shape as ``actual`` and ``expected`` that indicates the
            location of matches.
        rtol (float): Relative tolerance.
        atol (float): Absolute tolerance.
        identifier (Optional[Union[str, Callable[[str], str]]]): Optional description for the tensors. Can be passed
            as callable in which case it will be called with the default description to create the description at
            runtime. Defaults to "Tensor-likes".
    """

    def unravel_flat_index(flat_index: int) -> Tuple[int, ...]:
        if not matches.shape:
            return ()

        inverse_index = []
        for size in matches.shape[::-1]:
            div, mod = divmod(flat_index, size)
            flat_index = div
            inverse_index.append(mod)

        return tuple(inverse_index[::-1])

    number_of_elements = matches.numel()
    total_mismatches = number_of_elements - int(torch.sum(matches))
    extra = (
        f"Mismatched elements: {total_mismatches} / {number_of_elements} "
        f"({total_mismatches / number_of_elements:.1%})"
    )

    actual_flat = actual.flatten()
    expected_flat = expected.flatten()
    matches_flat = matches.flatten()

    if not actual.dtype.is_floating_point and not actual.dtype.is_complex:
        # TODO: Instead of always upcasting to int64, it would be sufficient to cast to the next higher dtype to avoid
        # overflow
        actual_flat = actual_flat.to(torch.int64)
        expected_flat = expected_flat.to(torch.int64)

    abs_diff = torch.abs(actual_flat - expected_flat)
    # Ensure that only mismatches are used for the max_abs_diff computation
    abs_diff[matches_flat] = 0
    max_abs_diff, max_abs_diff_flat_idx = torch.max(abs_diff, 0)

    rel_diff = abs_diff / torch.abs(expected_flat)
    # Ensure that only mismatches are used for the max_rel_diff computation
    rel_diff[matches_flat] = 0
    max_rel_diff, max_rel_diff_flat_idx = torch.max(rel_diff, 0)
    return _make_mismatch_msg(
        default_identifier="Tensor-likes",
        identifier=identifier,
        extra=extra,
        abs_diff=max_abs_diff.item(),
        abs_diff_idx=unravel_flat_index(int(max_abs_diff_flat_idx)),
        atol=atol,
        rel_diff=max_rel_diff.item(),
        rel_diff_idx=unravel_flat_index(int(max_rel_diff_flat_idx)),
        rtol=rtol,
    )
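
# Illustrative sketch (hypothetical values): the tensor message reports the worst
# mismatch together with its multi-dimensional index recovered from the flat index:
#
#   >>> actual = torch.tensor([[1.0, 2.0], [3.0, 6.0]])
#   >>> expected = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
#   >>> matches = torch.isclose(actual, expected, rtol=0, atol=0)
#   >>> print(make_tensor_mismatch_msg(actual, expected, matches, rtol=0, atol=0))
#   Tensor-likes are not equal!
#   <BLANKLINE>
#   Mismatched elements: 1 / 4 (25.0%)
#   Greatest absolute difference: 2.0 at index (1, 1)
#   Greatest relative difference: 0.5 at index (1, 1)
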

class UnsupportedInputs(Exception):  # noqa: B903
    """Exception to be raised during the construction of a :class:`Pair` in case it doesn't support the
    inputs."""


class Pair(abc.ABC):
    """ABC for all comparison pairs to be used in conjunction with :func:`assert_close`.

    Each subclass needs to overwrite :meth:`Pair.compare` that performs the actual comparison.

    Each pair receives **all** options, so select the ones applicable for the subclass and forward the rest to the
    super class. Raising an :class:`UnsupportedInputs` during construction indicates that the pair is not able to
    handle the inputs and the next pair type will be tried.

    All other errors should be raised as :class:`ErrorMeta`. After the instantiation, :meth:`Pair._fail` can
    be used to automatically handle overwriting the message with a user supplied one and id handling.
    """

    def __init__(
        self,
        actual: Any,
        expected: Any,
        *,
        id: Tuple[Any, ...] = (),
        **unknown_parameters: Any,
    ) -> None:
        self.actual = actual
        self.expected = expected
        self.id = id
        self._unknown_parameters = unknown_parameters

    @staticmethod
    def _inputs_not_supported() -> NoReturn:
        raise UnsupportedInputs

    @staticmethod
    def _check_inputs_isinstance(*inputs: Any, cls: Union[Type, Tuple[Type, ...]]):
        """Checks if all inputs are instances of a given class and raises :class:`UnsupportedInputs` otherwise."""
        if not all(isinstance(input, cls) for input in inputs):
            Pair._inputs_not_supported()

    def _fail(
        self, type: Type[Exception], msg: str, *, id: Tuple[Any, ...] = ()
    ) -> NoReturn:
        """Raises an :class:`ErrorMeta` from a given exception type and message and the stored id.

        .. warning::

            If you use this before the ``super().__init__(...)`` call in the constructor, you have to pass the ``id``
            explicitly.
        """
        raise ErrorMeta(type, msg, id=self.id if not id and hasattr(self, "id") else id)

    @abc.abstractmethod
    def compare(self) -> None:
        """Compares the inputs and raises an :class:`ErrorMeta` in case they mismatch."""

    def extra_repr(self) -> Sequence[Union[str, Tuple[str, Any]]]:
        """Returns extra information that will be included in the representation.

        Should be overwritten by all subclasses that use additional options. The representation of the object will only
        be surfaced in case we encounter an unexpected error and thus should help debug the issue. Can be a sequence of
        key-value-pairs or attribute names.
        """
        return []

    def __repr__(self) -> str:
        head = f"{type(self).__name__}("
        tail = ")"
        body = [
            f"    {name}={value!s},"
            for name, value in [
                ("id", self.id),
                ("actual", self.actual),
                ("expected", self.expected),
                *[
                    (extra, getattr(self, extra)) if isinstance(extra, str) else extra
                    for extra in self.extra_repr()
                ],
            ]
        ]
        return "\n".join((head, *body, tail))
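
# Minimal sketch of a custom pair built on the hooks above (the `StrPair` name and
# behavior are hypothetical, not part of this module): unsupported inputs are
# rejected via `_check_inputs_isinstance`, mismatches are reported via `_fail`.
#
#   class StrPair(Pair):
#       """Compares strings case-insensitively."""
#
#       def __init__(self, actual: Any, expected: Any, **other_parameters: Any) -> None:
#           self._check_inputs_isinstance(actual, expected, cls=str)
#           super().__init__(actual, expected, **other_parameters)
#
#       def compare(self) -> None:
#           if self.actual.lower() != self.expected.lower():
#               self._fail(AssertionError, f"{self.actual!r} != {self.expected!r}")
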

class ObjectPair(Pair):
    """Pair for any type of inputs that will be compared with the ``==`` operator.

    .. note::

        Since this will instantiate for any kind of inputs, it should only be used as fallback after all other pairs
        couldn't handle the inputs.

    """

    def compare(self) -> None:
        try:
            equal = self.actual == self.expected
        except Exception as error:
            # We are not using `self._fail` here since we need the exception chaining
            raise ErrorMeta(
                ValueError,
                f"{self.actual} == {self.expected} failed with:\n{error}.",
                id=self.id,
            ) from error

        if not equal:
            self._fail(AssertionError, f"{self.actual} != {self.expected}")


class NonePair(Pair):
    """Pair for ``None`` inputs."""

    def __init__(self, actual: Any, expected: Any, **other_parameters: Any) -> None:
        if not (actual is None or expected is None):
            self._inputs_not_supported()

        super().__init__(actual, expected, **other_parameters)

    def compare(self) -> None:
        if not (self.actual is None and self.expected is None):
            self._fail(
                AssertionError, f"None mismatch: {self.actual} is not {self.expected}"
            )


class BooleanPair(Pair):
    """Pair for :class:`bool` inputs.

    .. note::

        If ``numpy`` is available, also handles :class:`numpy.bool_` inputs.

    """

    def __init__(
        self,
        actual: Any,
        expected: Any,
        *,
        id: Tuple[Any, ...],
        **other_parameters: Any,
    ) -> None:
        actual, expected = self._process_inputs(actual, expected, id=id)
        super().__init__(actual, expected, **other_parameters)

    @property
    def _supported_types(self) -> Tuple[Type, ...]:
        cls: List[Type] = [bool]
        if HAS_NUMPY:
            cls.append(np.bool_)
        return tuple(cls)

    def _process_inputs(
        self, actual: Any, expected: Any, *, id: Tuple[Any, ...]
    ) -> Tuple[bool, bool]:
        self._check_inputs_isinstance(actual, expected, cls=self._supported_types)
        actual, expected = (
            self._to_bool(bool_like, id=id) for bool_like in (actual, expected)
        )
        return actual, expected

    def _to_bool(self, bool_like: Any, *, id: Tuple[Any, ...]) -> bool:
        if isinstance(bool_like, bool):
            return bool_like
        elif HAS_NUMPY and isinstance(bool_like, np.bool_):
            return bool_like.item()
        else:
            raise ErrorMeta(
                TypeError, f"Unknown boolean type {type(bool_like)}.", id=id
            )

    def compare(self) -> None:
        if self.actual is not self.expected:
            self._fail(
                AssertionError,
                f"Booleans mismatch: {self.actual} is not {self.expected}",
            )
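
# Illustrative sketch (hypothetical values, requires NumPy): plain bools and NumPy
# bools are normalized to `bool` before the identity comparison in `compare`:
#
#   >>> pair = BooleanPair(True, np.bool_(True), id=())
#   >>> pair.actual, pair.expected
#   (True, True)
#   >>> pair.compare()  # no error; True vs. False would raise an ErrorMeta(AssertionError)
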

class NumberPair(Pair):
    """Pair for Python number (:class:`int`, :class:`float`, and :class:`complex`) inputs.

    .. note::

        If ``numpy`` is available, also handles :class:`numpy.number` inputs.

    Kwargs:
        rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default
            values based on the type are selected with the below table.
        atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default
            values based on the type are selected with the below table.
        equal_nan (bool): If ``True``, two ``NaN`` values are considered equal. Defaults to ``False``.
        check_dtype (bool): If ``True``, the type of the inputs will be checked for equality. Defaults to ``False``.

    The following table displays the correspondence between Python number type and the ``torch.dtype``'s. See
    :func:`assert_close` for the corresponding tolerances.

    +------------------+-------------------------------+
    | ``type``         | corresponding ``torch.dtype`` |
    +==================+===============================+
    | :class:`int`     | :attr:`~torch.int64`          |
    +------------------+-------------------------------+
    | :class:`float`   | :attr:`~torch.float64`        |
    +------------------+-------------------------------+
    | :class:`complex` | :attr:`~torch.complex128`     |
    +------------------+-------------------------------+
    """

    _TYPE_TO_DTYPE = {
        int: torch.int64,
        float: torch.float64,
        complex: torch.complex128,
    }
    _NUMBER_TYPES = tuple(_TYPE_TO_DTYPE.keys())

    def __init__(
        self,
        actual: Any,
        expected: Any,
        *,
        id: Tuple[Any, ...] = (),
        rtol: Optional[float] = None,
        atol: Optional[float] = None,
        equal_nan: bool = False,
        check_dtype: bool = False,
        **other_parameters: Any,
    ) -> None:
        actual, expected = self._process_inputs(actual, expected, id=id)
        super().__init__(actual, expected, id=id, **other_parameters)

        self.rtol, self.atol = get_tolerances(
            *[self._TYPE_TO_DTYPE[type(input)] for input in (actual, expected)],
            rtol=rtol,
            atol=atol,
            id=id,
        )
        self.equal_nan = equal_nan
        self.check_dtype = check_dtype

    @property
    def _supported_types(self) -> Tuple[Type, ...]:
        cls = list(self._NUMBER_TYPES)
        if HAS_NUMPY:
            cls.append(np.number)
        return tuple(cls)

    def _process_inputs(
        self, actual: Any, expected: Any, *, id: Tuple[Any, ...]
    ) -> Tuple[Union[int, float, complex], Union[int, float, complex]]:
        self._check_inputs_isinstance(actual, expected, cls=self._supported_types)
        actual, expected = (
            self._to_number(number_like, id=id) for number_like in (actual, expected)
        )
        return actual, expected

    def _to_number(
        self, number_like: Any, *, id: Tuple[Any, ...]
    ) -> Union[int, float, complex]:
        if HAS_NUMPY and isinstance(number_like, np.number):
            return number_like.item()
        elif isinstance(number_like, self._NUMBER_TYPES):
            return number_like  # type: ignore[return-value]
        else:
            raise ErrorMeta(
                TypeError, f"Unknown number type {type(number_like)}.", id=id
            )

    def compare(self) -> None:
        if self.check_dtype and type(self.actual) is not type(self.expected):
            self._fail(
                AssertionError,
                f"The (d)types do not match: {type(self.actual)} != {type(self.expected)}.",
            )

        if self.actual == self.expected:
            return

        if self.equal_nan and cmath.isnan(self.actual) and cmath.isnan(self.expected):
            return

        abs_diff = abs(self.actual - self.expected)
        tolerance = self.atol + self.rtol * abs(self.expected)

        if cmath.isfinite(abs_diff) and abs_diff <= tolerance:
            return

        self._fail(
            AssertionError,
            make_scalar_mismatch_msg(
                self.actual, self.expected, rtol=self.rtol, atol=self.atol
            ),
        )

    def extra_repr(self) -> Sequence[str]:
        return (
            "rtol",
            "atol",
            "equal_nan",
            "check_dtype",
        )
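
# Illustrative sketch (hypothetical values): the default tolerances of a NumberPair
# follow the type-to-dtype table above, here float -> torch.float64:
#
#   >>> pair = NumberPair(1.0, 1.0 + 1e-9)
#   >>> pair.rtol, pair.atol
#   (1e-07, 1e-07)
#   >>> pair.compare()  # passes, since 1e-9 <= atol + rtol * |expected|
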

class TensorLikePair(Pair):
    """Pair for :class:`torch.Tensor`-like inputs.

    Kwargs:
        allow_subclasses (bool): If ``True`` (default), inputs of directly related types are allowed. Otherwise type
            equality is required.
        rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default
            values based on the type are selected. See :func:`assert_close` for details.
        atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default
            values based on the type are selected. See :func:`assert_close` for details.
        equal_nan (bool): If ``True``, two ``NaN`` values are considered equal. Defaults to ``False``.
        check_device (bool): If ``True`` (default), asserts that corresponding tensors are on the same
            :attr:`~torch.Tensor.device`. If this check is disabled, tensors on different
            :attr:`~torch.Tensor.device`'s are moved to the CPU before being compared.
        check_dtype (bool): If ``True`` (default), asserts that corresponding tensors have the same ``dtype``. If this
            check is disabled, tensors with different ``dtype``'s are promoted to a common ``dtype`` (according to
            :func:`torch.promote_types`) before being compared.
        check_layout (bool): If ``True`` (default), asserts that corresponding tensors have the same ``layout``. If this
            check is disabled, tensors with different ``layout``'s are converted to strided tensors before being
            compared.
        check_stride (bool): If ``True`` and corresponding tensors are strided, asserts that they have the same stride.
    """

    def __init__(
        self,
        actual: Any,
        expected: Any,
        *,
        id: Tuple[Any, ...] = (),
        allow_subclasses: bool = True,
        rtol: Optional[float] = None,
        atol: Optional[float] = None,
        equal_nan: bool = False,
        check_device: bool = True,
        check_dtype: bool = True,
        check_layout: bool = True,
        check_stride: bool = False,
        **other_parameters: Any,
    ):
        actual, expected = self._process_inputs(
            actual, expected, id=id, allow_subclasses=allow_subclasses
        )
        super().__init__(actual, expected, id=id, **other_parameters)

        self.rtol, self.atol = get_tolerances(
            actual, expected, rtol=rtol, atol=atol, id=self.id
        )
        self.equal_nan = equal_nan
        self.check_device = check_device
        self.check_dtype = check_dtype
        self.check_layout = check_layout
        self.check_stride = check_stride

    def _process_inputs(
        self, actual: Any, expected: Any, *, id: Tuple[Any, ...], allow_subclasses: bool
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        directly_related = isinstance(actual, type(expected)) or isinstance(
            expected, type(actual)
        )
        if not directly_related:
            self._inputs_not_supported()

        if not allow_subclasses and type(actual) is not type(expected):
            self._inputs_not_supported()

        actual, expected = (self._to_tensor(input) for input in (actual, expected))
        for tensor in (actual, expected):
            self._check_supported(tensor, id=id)
        return actual, expected

    def _to_tensor(self, tensor_like: Any) -> torch.Tensor:
        if isinstance(tensor_like, torch.Tensor):
            return tensor_like

        try:
            return torch.as_tensor(tensor_like)
        except Exception:
            self._inputs_not_supported()

    def _check_supported(self, tensor: torch.Tensor, *, id: Tuple[Any, ...]) -> None:
        if tensor.layout not in {
            torch.strided,
            torch.jagged,
            torch.sparse_coo,
            torch.sparse_csr,
            torch.sparse_csc,
            torch.sparse_bsr,
            torch.sparse_bsc,
        }:
            raise ErrorMeta(
                ValueError, f"Unsupported tensor layout {tensor.layout}", id=id
            )

    def compare(self) -> None:
        actual, expected = self.actual, self.expected

        self._compare_attributes(actual, expected)
        if any(input.device.type == "meta" for input in (actual, expected)):
            return

        actual, expected = self._equalize_attributes(actual, expected)
        self._compare_values(actual, expected)

    def _compare_attributes(
        self,
        actual: torch.Tensor,
        expected: torch.Tensor,
    ) -> None:
        """Checks if the attributes of two tensors match.

        Always checks

        - the :attr:`~torch.Tensor.shape`,
        - whether both inputs are quantized or not,
        - and if they use the same quantization scheme.

        Checks for

        - :attr:`~torch.Tensor.layout`,
        - :meth:`~torch.Tensor.stride`,
        - :attr:`~torch.Tensor.device`, and
        - :attr:`~torch.Tensor.dtype`

        are optional and can be disabled through the corresponding ``check_*`` flag during construction of the pair.
        """

        def raise_mismatch_error(
            attribute_name: str, actual_value: Any, expected_value: Any
        ) -> NoReturn:
            self._fail(
                AssertionError,
                f"The values for attribute '{attribute_name}' do not match: {actual_value} != {expected_value}.",
            )

        if actual.shape != expected.shape:
            raise_mismatch_error("shape", actual.shape, expected.shape)

        if actual.is_quantized != expected.is_quantized:
            raise_mismatch_error(
                "is_quantized", actual.is_quantized, expected.is_quantized
            )
        elif actual.is_quantized and actual.qscheme() != expected.qscheme():
            raise_mismatch_error("qscheme()", actual.qscheme(), expected.qscheme())

        if actual.layout != expected.layout:
            if self.check_layout:
                raise_mismatch_error("layout", actual.layout, expected.layout)
        elif (
            actual.layout == torch.strided
            and self.check_stride
            and actual.stride() != expected.stride()
        ):
            raise_mismatch_error("stride()", actual.stride(), expected.stride())

        if self.check_device and actual.device != expected.device:
            raise_mismatch_error("device", actual.device, expected.device)

        if self.check_dtype and actual.dtype != expected.dtype:
            raise_mismatch_error("dtype", actual.dtype, expected.dtype)
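
    # Illustrative sketch (hypothetical values) of the attribute checks above: the
    # shape is always compared, independent of any `check_*` flag.
    #
    #   >>> pair = TensorLikePair(torch.ones(2, 2), torch.ones(2, 3))
    #   >>> pair.compare()
    #   # raises an ErrorMeta carrying an AssertionError about the mismatching
    #   # 'shape' attribute: torch.Size([2, 2]) != torch.Size([2, 3])
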

    def _equalize_attributes(
        self, actual: torch.Tensor, expected: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Equalizes some attributes of two tensors for value comparison.

        If ``actual`` and ``expected`` are ...

        - ... not on the same :attr:`~torch.Tensor.device`, they are moved to CPU memory.
        - ... not of the same ``dtype``, they are promoted to a common ``dtype`` (according to
          :func:`torch.promote_types`).
        - ... not of the same ``layout``, they are converted to strided tensors.

        Args:
            actual (Tensor): Actual tensor.
            expected (Tensor): Expected tensor.

        Returns:
            (Tuple[Tensor, Tensor]): Equalized tensors.
        """
        # The comparison logic uses operators currently not supported by the MPS backends.
        # See https://github.com/pytorch/pytorch/issues/77144 for details.
        # TODO: Remove this conversion as soon as all operations are supported natively by the MPS backend
        if actual.is_mps or expected.is_mps:  # type: ignore[attr-defined]
            actual = actual.cpu()
            expected = expected.cpu()

        if actual.device != expected.device:
            actual = actual.cpu()
            expected = expected.cpu()

        if actual.dtype != expected.dtype:
            actual_dtype = actual.dtype
            expected_dtype = expected.dtype
            # For uint64, this is not sound in general, which is why promote_types doesn't
            # allow it, but for easy testing, we're unlikely to get confused
            # by large uint64 overflowing into negative int64
            if actual_dtype in [torch.uint64, torch.uint32, torch.uint16]:
                actual_dtype = torch.int64
            if expected_dtype in [torch.uint64, torch.uint32, torch.uint16]:
                expected_dtype = torch.int64
            dtype = torch.promote_types(actual_dtype, expected_dtype)
            actual = actual.to(dtype)
            expected = expected.to(dtype)

        if actual.layout != expected.layout:
            # These checks are needed, since Tensor.to_dense() fails on tensors that are already strided
            actual = actual.to_dense() if actual.layout != torch.strided else actual
            expected = (
                expected.to_dense() if expected.layout != torch.strided else expected
            )

        return actual, expected

    def _compare_values(self, actual: torch.Tensor, expected: torch.Tensor) -> None:
        if actual.is_quantized:
            compare_fn = self._compare_quantized_values
        elif actual.is_sparse:
            compare_fn = self._compare_sparse_coo_values
        elif actual.layout in {
            torch.sparse_csr,
            torch.sparse_csc,
            torch.sparse_bsr,
            torch.sparse_bsc,
        }:
            compare_fn = self._compare_sparse_compressed_values
        elif actual.layout == torch.jagged:
            actual, expected = actual.values(), expected.values()
            compare_fn = self._compare_regular_values_close
        else:
            compare_fn = self._compare_regular_values_close

        compare_fn(
            actual, expected, rtol=self.rtol, atol=self.atol, equal_nan=self.equal_nan
        )

    def _compare_quantized_values(
        self,
        actual: torch.Tensor,
        expected: torch.Tensor,
        *,
        rtol: float,
        atol: float,
        equal_nan: bool,
    ) -> None:
        """Compares quantized tensors by comparing the :meth:`~torch.Tensor.dequantize`'d variants for closeness.

        .. note::

            A detailed discussion about why only the dequantized variant is checked for closeness rather than checking
            the individual quantization parameters for closeness and the integer representation for equality can be
            found in https://github.com/pytorch/pytorch/issues/68548.
        """
        return self._compare_regular_values_close(
            actual.dequantize(),
            expected.dequantize(),
            rtol=rtol,
            atol=atol,
            equal_nan=equal_nan,
            identifier=lambda default_identifier: f"Quantized {default_identifier.lower()}",
        )
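
    # Illustrative sketch (hypothetical values) of the dtype equalization performed
    # by `_equalize_attributes` when `check_dtype` is disabled:
    #
    #   >>> a = torch.tensor([1.0], dtype=torch.float32)
    #   >>> b = torch.tensor([1.0], dtype=torch.float64)
    #   >>> pair = TensorLikePair(a, b, check_dtype=False)
    #   >>> eq_a, eq_b = pair._equalize_attributes(a, b)
    #   >>> eq_a.dtype, eq_b.dtype
    #   (torch.float64, torch.float64)
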

    def _compare_sparse_coo_values(
        self,
        actual: torch.Tensor,
        expected: torch.Tensor,
        *,
        rtol: float,
        atol: float,
        equal_nan: bool,
    ) -> None:
        """Compares sparse COO tensors by comparing

        - the number of sparse dimensions,
        - the number of non-zero elements (nnz) for equality,
        - the indices for equality, and
        - the values for closeness.
        """
        if actual.sparse_dim() != expected.sparse_dim():
            self._fail(
                AssertionError,
                (
                    f"The number of sparse dimensions in sparse COO tensors does not match: "
                    f"{actual.sparse_dim()} != {expected.sparse_dim()}"
                ),
            )

        if actual._nnz() != expected._nnz():
            self._fail(
                AssertionError,
                (
                    f"The number of specified values in sparse COO tensors does not match: "
                    f"{actual._nnz()} != {expected._nnz()}"
                ),
            )

        self._compare_regular_values_equal(
            actual._indices(),
            expected._indices(),
            identifier="Sparse COO indices",
        )
        self._compare_regular_values_close(
            actual._values(),
            expected._values(),
            rtol=rtol,
            atol=atol,
            equal_nan=equal_nan,
            identifier="Sparse COO values",
        )

    def _compare_sparse_compressed_values(
        self,
        actual: torch.Tensor,
        expected: torch.Tensor,
        *,
        rtol: float,
        atol: float,
        equal_nan: bool,
    ) -> None:
        """Compares sparse compressed tensors by comparing

        - the number of non-zero elements (nnz) for equality,
        - the plain indices for equality,
        - the compressed indices for equality, and
        - the values for closeness.
        """
        format_name, compressed_indices_method, plain_indices_method = {
            torch.sparse_csr: (
                "CSR",
                torch.Tensor.crow_indices,
                torch.Tensor.col_indices,
            ),
            torch.sparse_csc: (
                "CSC",
                torch.Tensor.ccol_indices,
                torch.Tensor.row_indices,
            ),
            torch.sparse_bsr: (
                "BSR",
                torch.Tensor.crow_indices,
                torch.Tensor.col_indices,
            ),
            torch.sparse_bsc: (
                "BSC",
                torch.Tensor.ccol_indices,
                torch.Tensor.row_indices,
            ),
        }[actual.layout]

        if actual._nnz() != expected._nnz():
            self._fail(
                AssertionError,
                (
                    f"The number of specified values in sparse {format_name} tensors does not match: "
                    f"{actual._nnz()} != {expected._nnz()}"
                ),
            )

        # Compressed and plain indices in the CSR / CSC / BSR / BSC sparse formats can be `torch.int32` _or_
        # `torch.int64`. While the same dtype is enforced for the compressed and plain indices of a single tensor, it
        # can be different between two tensors. Thus, we need to convert them to the same dtype, or the comparison will
        # fail.
        actual_compressed_indices = compressed_indices_method(actual)
        expected_compressed_indices = compressed_indices_method(expected)
        indices_dtype = torch.promote_types(
            actual_compressed_indices.dtype, expected_compressed_indices.dtype
        )

        self._compare_regular_values_equal(
            actual_compressed_indices.to(indices_dtype),
            expected_compressed_indices.to(indices_dtype),
            identifier=f"Sparse {format_name} {compressed_indices_method.__name__}",
        )
        self._compare_regular_values_equal(
            plain_indices_method(actual).to(indices_dtype),
            plain_indices_method(expected).to(indices_dtype),
            identifier=f"Sparse {format_name} {plain_indices_method.__name__}",
        )
        self._compare_regular_values_close(
            actual.values(),
            expected.values(),
            rtol=rtol,
            atol=atol,
            equal_nan=equal_nan,
            identifier=f"Sparse {format_name} values",
        )

    def _compare_regular_values_equal(
        self,
        actual: torch.Tensor,
        expected: torch.Tensor,
        *,
        equal_nan: bool = False,
        identifier: Optional[Union[str, Callable[[str], str]]] = None,
    ) -> None:
        """Checks if the values of two tensors are equal."""
        self._compare_regular_values_close(
            actual, expected, rtol=0, atol=0, equal_nan=equal_nan, identifier=identifier
        )

    def _compare_regular_values_close(
        self,
        actual: torch.Tensor,
        expected: torch.Tensor,
        *,
        rtol: float,
        atol: float,
        equal_nan: bool,
        identifier: Optional[Union[str, Callable[[str], str]]] = None,
    ) -> None:
        """Checks if the values of two tensors are close up to a desired tolerance."""
        matches = torch.isclose(
            actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan
        )
        if torch.all(matches):
            return

        if actual.shape == torch.Size([]):
            msg = make_scalar_mismatch_msg(
                actual.item(),
                expected.item(),
                rtol=rtol,
                atol=atol,
                identifier=identifier,
            )
        else:
            msg = make_tensor_mismatch_msg(
                actual, expected, matches, rtol=rtol, atol=atol, identifier=identifier
            )
        self._fail(AssertionError, msg)

    def extra_repr(self) -> Sequence[str]:
        return (
            "rtol",
            "atol",
            "equal_nan",
            "check_device",
            "check_dtype",
            "check_layout",
            "check_stride",
        )
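
# Illustrative sketch (hypothetical values) of a single tensor pair in isolation;
# `assert_close` below wires this up together with the other pair types:
#
#   >>> pair = TensorLikePair(torch.tensor([1.0, 2.0]), torch.tensor([1.0, 2.5]))
#   >>> pair.compare()
#   # raises an ErrorMeta carrying an AssertionError that reports the greatest
#   # absolute/relative difference at index (1,)
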

def originate_pairs(
    actual: Any,
    expected: Any,
    *,
    pair_types: Sequence[Type[Pair]],
    sequence_types: Tuple[Type, ...] = (collections.abc.Sequence,),
    mapping_types: Tuple[Type, ...] = (collections.abc.Mapping,),
    id: Tuple[Any, ...] = (),
    **options: Any,
) -> List[Pair]:
    """Originates pairs from the individual inputs.

    ``actual`` and ``expected`` can be possibly nested :class:`~collections.abc.Sequence`'s or
    :class:`~collections.abc.Mapping`'s. In this case the pairs are originated by recursing through them.

    Args:
        actual (Any): Actual input.
        expected (Any): Expected input.
        pair_types (Sequence[Type[Pair]]): Sequence of pair types that will be tried to construct with the inputs.
            First successful pair will be used.
        sequence_types (Tuple[Type, ...]): Optional types treated as sequences that will be checked elementwise.
        mapping_types (Tuple[Type, ...]): Optional types treated as mappings that will be checked elementwise.
        id (Tuple[Any, ...]): Optional id of a pair that will be included in an error message.
        **options (Any): Options passed to each pair during construction.

    Raises:
        ErrorMeta: With :class:`AssertionError`, if the inputs are :class:`~collections.abc.Sequence`'s, but their
            length does not match.
        ErrorMeta: With :class:`AssertionError`, if the inputs are :class:`~collections.abc.Mapping`'s, but their sets
            of keys do not match.
        ErrorMeta: With :class:`TypeError`, if no pair is able to handle the inputs.
        ErrorMeta: With any expected exception that happens during the construction of a pair.

    Returns:
        (List[Pair]): Originated pairs.
    """
    # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
    # "a" == "a"[0][0]...
    if (
        isinstance(actual, sequence_types)
        and not isinstance(actual, str)
        and isinstance(expected, sequence_types)
        and not isinstance(expected, str)
    ):
        actual_len = len(actual)
        expected_len = len(expected)
        if actual_len != expected_len:
            raise ErrorMeta(
                AssertionError,
                f"The length of the sequences mismatch: {actual_len} != {expected_len}",
                id=id,
            )

        pairs = []
        for idx in range(actual_len):
            pairs.extend(
                originate_pairs(
                    actual[idx],
                    expected[idx],
                    pair_types=pair_types,
                    sequence_types=sequence_types,
                    mapping_types=mapping_types,
                    id=(*id, idx),
                    **options,
                )
            )
        return pairs

    elif isinstance(actual, mapping_types) and isinstance(expected, mapping_types):
        actual_keys = set(actual.keys())
        expected_keys = set(expected.keys())
        if actual_keys != expected_keys:
            missing_keys = expected_keys - actual_keys
            additional_keys = actual_keys - expected_keys
            raise ErrorMeta(
                AssertionError,
                (
                    f"The keys of the mappings do not match:\n"
                    f"Missing keys in the actual mapping: {sorted(missing_keys)}\n"
                    f"Additional keys in the actual mapping: {sorted(additional_keys)}"
                ),
                id=id,
            )

        keys: Collection = actual_keys
        # Since the origination aborts after the first failure, we try to be deterministic
        with contextlib.suppress(Exception):
            keys = sorted(keys)

        pairs = []
        for key in keys:
            pairs.extend(
                originate_pairs(
                    actual[key],
                    expected[key],
                    pair_types=pair_types,
                    sequence_types=sequence_types,
                    mapping_types=mapping_types,
                    id=(*id, key),
                    **options,
                )
            )
        return pairs

    else:
        for pair_type in pair_types:
            try:
                return [pair_type(actual, expected, id=id, **options)]
            # Raising an `UnsupportedInputs` during origination indicates that the pair type is not able to handle the
            # inputs. Thus, we try the next pair type.
            except UnsupportedInputs:
                continue
            # Raising an `ErrorMeta` during origination is the orderly way to abort and so we simply re-raise it. This
            # is only in a separate branch, because the one below would also except it.
            except ErrorMeta:
                raise
            # Raising any other exception during origination is unexpected and will give some extra information about
            # what happened. If applicable, the exception should be expected in the future.
            except Exception as error:
                raise RuntimeError(
                    f"Originating a {pair_type.__name__}() at item {''.join(str([item]) for item in id)} with\n\n"
                    f"{type(actual).__name__}(): {actual}\n\n"
                    f"and\n\n"
                    f"{type(expected).__name__}(): {expected}\n\n"
                    f"resulted in the unexpected exception above. "
                    f"If you are a user and see this message during normal operation "
                    "please file an issue at https://github.com/pytorch/pytorch/issues. "
                    "If you are a developer and working on the comparison functions, "
                    "please catch the previous error and raise an expressive `ErrorMeta` instead."
                ) from error
        else:
            raise ErrorMeta(
                TypeError,
                f"No comparison pair was able to handle inputs of type {type(actual)} and {type(expected)}.",
                id=id,
            )
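
# Illustrative sketch (hypothetical values) of origination over nested inputs; the
# ids record the path to each leaf:
#
#   >>> pairs = originate_pairs(
#   ...     {"a": [1, 2.0], "b": None},
#   ...     {"a": [1, 2.0], "b": None},
#   ...     pair_types=(NonePair, NumberPair, ObjectPair),
#   ... )
#   >>> [(type(pair).__name__, pair.id) for pair in pairs]
#   [('NumberPair', ('a', 0)), ('NumberPair', ('a', 1)), ('NonePair', ('b',))]
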

def not_close_error_metas(
    actual: Any,
    expected: Any,
    *,
    pair_types: Sequence[Type[Pair]] = (ObjectPair,),
    sequence_types: Tuple[Type, ...] = (collections.abc.Sequence,),
    mapping_types: Tuple[Type, ...] = (collections.abc.Mapping,),
    **options: Any,
) -> List[ErrorMeta]:
    """Checks if the inputs are close and returns an :class:`ErrorMeta` for every mismatch rather than raising.

    ``actual`` and ``expected`` can be possibly nested :class:`~collections.abc.Sequence`'s or
    :class:`~collections.abc.Mapping`'s. In this case the comparison happens elementwise by recursing through them.

    Args:
        actual (Any): Actual input.
        expected (Any): Expected input.
        pair_types (Sequence[Type[Pair]]): Sequence of :class:`Pair` types that will be tried to construct with the
            inputs. First successful pair will be used. Defaults to only using :class:`ObjectPair`.
        sequence_types (Tuple[Type, ...]): Optional types treated as sequences that will be checked elementwise.
        mapping_types (Tuple[Type, ...]): Optional types treated as mappings that will be checked elementwise.
        **options (Any): Options passed to each pair during construction.
    """
    # Hide this function from `pytest`'s traceback
    __tracebackhide__ = True

    try:
        pairs = originate_pairs(
            actual,
            expected,
            pair_types=pair_types,
            sequence_types=sequence_types,
            mapping_types=mapping_types,
            **options,
        )
    except ErrorMeta as error_meta:
        # Explicitly raising from None to hide the internal traceback
        raise error_meta.to_error() from None  # noqa: RSE102

    error_metas: List[ErrorMeta] = []
    for pair in pairs:
        try:
            pair.compare()
        except ErrorMeta as error_meta:
            error_metas.append(error_meta)
        # Raising any exception besides `ErrorMeta` while comparing is unexpected and will give some extra information
        # about what happened. If applicable, the exception should be expected in the future.
        except Exception as error:
            raise RuntimeError(
                f"Comparing\n\n"
                f"{pair}\n\n"
                f"resulted in the unexpected exception above. "
                f"If you are a user and see this message during normal operation "
                "please file an issue at https://github.com/pytorch/pytorch/issues. "
                "If you are a developer and working on the comparison functions, "
                "please catch the previous error and raise an expressive `ErrorMeta` instead."
            ) from error

    # [ErrorMeta Cycles]
    # ErrorMeta objects in this list capture
    # tracebacks that refer to the frame of this function.
    # The local variable `error_metas` refers to the error meta
    # objects, creating a reference cycle. Frames in the traceback
    # would not get freed until cycle collection, leaking cuda memory in tests.
    # We break the cycle by removing the reference to the error_meta objects
    # from this frame as it returns.
    error_metas = [error_metas]
    return error_metas.pop()
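
# Illustrative sketch (hypothetical values): collect the mismatches without raising,
# which is what `assert_close` builds on:
#
#   >>> error_metas = not_close_error_metas(
#   ...     torch.tensor([1.0, 2.0]),
#   ...     torch.tensor([1.0, 3.0]),
#   ...     pair_types=(TensorLikePair,),
#   ... )
#   >>> len(error_metas)
#   1
#   >>> raise error_metas[0].to_error()  # would surface the AssertionError to the user
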

def assert_close(
    actual: Any,
    expected: Any,
    *,
    allow_subclasses: bool = True,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    equal_nan: bool = False,
    check_device: bool = True,
    check_dtype: bool = True,
    check_layout: bool = True,
    check_stride: bool = False,
    msg: Optional[Union[str, Callable[[str], str]]] = None,
):
    r"""Asserts that ``actual`` and ``expected`` are close.

    If ``actual`` and ``expected`` are strided, non-quantized, real-valued, and finite, they are considered close if

    .. math::

        \lvert \text{actual} - \text{expected} \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert \text{expected} \rvert

    Non-finite values (``-inf`` and ``inf``) are only considered close if they are equal. ``NaN``'s are
    only considered equal to each other if ``equal_nan`` is ``True``.

    In addition, they are only considered close if they have the same

    - :attr:`~torch.Tensor.device` (if ``check_device`` is ``True``),
    - ``dtype`` (if ``check_dtype`` is ``True``),
    - ``layout`` (if ``check_layout`` is ``True``), and
    - stride (if ``check_stride`` is ``True``).

    If either ``actual`` or ``expected`` is a meta tensor, only the attribute checks will be performed.

    If ``actual`` and ``expected`` are sparse (either having COO, CSR, CSC, BSR, or BSC layout), their strided members are
    checked individually. Indices, namely ``indices`` for COO, ``crow_indices`` and ``col_indices`` for CSR and BSR,
    or ``ccol_indices`` and ``row_indices`` for CSC and BSC layouts, respectively,
    are always checked for equality whereas the values are checked for closeness according to the definition above.

    If ``actual`` and ``expected`` are quantized, they are considered close if they have the same
    :meth:`~torch.Tensor.qscheme` and the result of :meth:`~torch.Tensor.dequantize` is close according to the
    definition above.

    ``actual`` and ``expected`` can be :class:`~torch.Tensor`'s or any tensor-or-scalar-likes from which
    :class:`torch.Tensor`'s can be constructed with :func:`torch.as_tensor`. Except for Python scalars the input types
    have to be directly related. In addition, ``actual`` and ``expected`` can be :class:`~collections.abc.Sequence`'s
    or :class:`~collections.abc.Mapping`'s in which case they are considered close if their structure matches and all
    their elements are considered close according to the above definition.

    .. note::

        Python scalars are an exception to the type relation requirement, because their :func:`type`, i.e.
        :class:`int`, :class:`float`, and :class:`complex`, is equivalent to the ``dtype`` of a tensor-like. Thus,
        Python scalars of different types can be checked, but require ``check_dtype=False``.

    Args:
        actual (Any): Actual input.
        expected (Any): Expected input.
        allow_subclasses (bool): If ``True`` (default) and except for Python scalars, inputs of directly related types
            are allowed. Otherwise type equality is required.
        rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default
            values based on the :attr:`~torch.Tensor.dtype` are selected with the below table.
        atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default
            values based on the :attr:`~torch.Tensor.dtype` are selected with the below table.
        equal_nan (bool): If ``True``, two ``NaN`` values will be considered equal.
        check_device (bool): If ``True`` (default), asserts that corresponding tensors are on the same
            :attr:`~torch.Tensor.device`. If this check is disabled, tensors on different
            :attr:`~torch.Tensor.device`'s are moved to the CPU before being compared.
        check_dtype (bool): If ``True`` (default), asserts that corresponding tensors have the same ``dtype``. If this
            check is disabled, tensors with different ``dtype``'s are promoted to a common ``dtype`` (according to
            :func:`torch.promote_types`) before being compared.
        check_layout (bool): If ``True`` (default), asserts that corresponding tensors have the same ``layout``. If this
            check is disabled, tensors with different ``layout``'s are converted to strided tensors before being
            compared.
        check_stride (bool): If ``True`` and corresponding tensors are strided, asserts that they have the same stride.
        msg (Optional[Union[str, Callable[[str], str]]]): Optional error message to use in case a failure occurs during
            the comparison. Can also be passed as a callable in which case it will be called with the generated message
            and should return the new message.

    Raises:
        ValueError: If no :class:`torch.Tensor` can be constructed from an input.
        ValueError: If only ``rtol`` or ``atol`` is specified.
        AssertionError: If corresponding inputs are not Python scalars and are not directly related.
        AssertionError: If ``allow_subclasses`` is ``False``, but corresponding inputs are not Python scalars and have
            different types.
        AssertionError: If the inputs are :class:`~collections.abc.Sequence`'s, but their length does not match.
        AssertionError: If the inputs are :class:`~collections.abc.Mapping`'s, but their sets of keys do not match.
        AssertionError: If corresponding tensors do not have the same :attr:`~torch.Tensor.shape`.
        AssertionError: If ``check_layout`` is ``True``, but corresponding tensors do not have the same
            :attr:`~torch.Tensor.layout`.
        AssertionError: If only one of corresponding tensors is quantized.
        AssertionError: If corresponding tensors are quantized, but have different :meth:`~torch.Tensor.qscheme`'s.
        AssertionError: If ``check_device`` is ``True``, but corresponding tensors are not on the same
            :attr:`~torch.Tensor.device`.
        AssertionError: If ``check_dtype`` is ``True``, but corresponding tensors do not have the same ``dtype``.
        AssertionError: If ``check_stride`` is ``True``, but corresponding strided tensors do not have the same stride.
        AssertionError: If the values of corresponding tensors are not close according to the definition above.

    The following table displays the default ``rtol`` and ``atol`` for different ``dtype``'s. In case of mismatching
    ``dtype``'s, the maximum of both tolerances is used.

    +---------------------------+------------+----------+
    | ``dtype``                 | ``rtol``   | ``atol`` |
    +===========================+============+==========+
    | :attr:`~torch.float16`    | ``1e-3``   | ``1e-5`` |
    +---------------------------+------------+----------+
    | :attr:`~torch.bfloat16`   | ``1.6e-2`` | ``1e-5`` |
    +---------------------------+------------+----------+
    | :attr:`~torch.float32`    | ``1.3e-6`` | ``1e-5`` |
    +---------------------------+------------+----------+
    | :attr:`~torch.float64`    | ``1e-7``   | ``1e-7`` |
    +---------------------------+------------+----------+
    | :attr:`~torch.complex32`  | ``1e-3``   | ``1e-5`` |
    +---------------------------+------------+----------+
    | :attr:`~torch.complex64`  | ``1.3e-6`` | ``1e-5`` |
    +---------------------------+------------+----------+
    | :attr:`~torch.complex128` | ``1e-7``   | ``1e-7`` |
    +---------------------------+------------+----------+
    | :attr:`~torch.quint8`     | ``1.3e-6`` | ``1e-5`` |
    +---------------------------+------------+----------+
    | :attr:`~torch.quint2x4`   | ``1.3e-6`` | ``1e-5`` |
    +---------------------------+------------+----------+
    | :attr:`~torch.quint4x2`   | ``1.3e-6`` | ``1e-5`` |
    +---------------------------+------------+----------+
    | :attr:`~torch.qint8`      | ``1.3e-6`` | ``1e-5`` |
    +---------------------------+------------+----------+
    | :attr:`~torch.qint32`     | ``1.3e-6`` | ``1e-5`` |
    +---------------------------+------------+----------+
    | other                     | ``0.0``    | ``0.0``  |
    +---------------------------+------------+----------+

    .. note::

        :func:`~torch.testing.assert_close` is highly configurable with strict default settings. Users are encouraged
        to :func:`~functools.partial` it to fit their use case. For example, if an equality check is needed, one might
        define an ``assert_equal`` that uses zero tolerances for every ``dtype`` by default:

        >>> import functools
        >>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
        >>> assert_equal(1e-9, 1e-10)
        Traceback (most recent call last):
        ...
        AssertionError: Scalars are not equal!
        <BLANKLINE>
        Expected 1e-10 but got 1e-09.
        Absolute difference: 9.000000000000001e-10
        Relative difference: 9.0

    Examples:
        >>> # tensor to tensor comparison
        >>> expected = torch.tensor([1e0, 1e-1, 1e-2])
        >>> actual = torch.acos(torch.cos(expected))
        >>> torch.testing.assert_close(actual, expected)

        >>> # scalar to scalar comparison
        >>> import math
        >>> expected = math.sqrt(2.0)
        >>> actual = 2.0 / math.sqrt(2.0)
        >>> torch.testing.assert_close(actual, expected)

        >>> # numpy array to numpy array comparison
        >>> import numpy as np
        >>> expected = np.array([1e0, 1e-1, 1e-2])
        >>> actual = np.arccos(np.cos(expected))
        >>> torch.testing.assert_close(actual, expected)

        >>> # sequence to sequence comparison
        >>> import numpy as np
        >>> # The types of the sequences do not have to match. They only have to have the same
        >>> # length and their elements have to match.
        >>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)]
        >>> actual = tuple(expected)
        >>> torch.testing.assert_close(actual, expected)

        >>> # mapping to mapping comparison
        >>> from collections import OrderedDict
        >>> import numpy as np
        >>> foo = torch.tensor(1.0)
        >>> bar = 2.0
        >>> baz = np.array(3.0)
        >>> # The types and a possible ordering of mappings do not have to match. They only
        >>> # have to have the same set of keys and their elements have to match.
        >>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)])
        >>> actual = {"baz": baz, "bar": bar, "foo": foo}
        >>> torch.testing.assert_close(actual, expected)

        >>> expected = torch.tensor([1.0, 2.0, 3.0])
        >>> actual = expected.clone()
        >>> # By default, directly related instances can be compared
        >>> torch.testing.assert_close(torch.nn.Parameter(actual), expected)
        >>> # This check can be made more strict with allow_subclasses=False
        >>> torch.testing.assert_close(
        ...     torch.nn.Parameter(actual), expected, allow_subclasses=False
        ... )
        Traceback (most recent call last):
        ...
        TypeError: No comparison pair was able to handle inputs of type
        <class 'torch.nn.parameter.Parameter'> and <class 'torch.Tensor'>.
        >>> # If the inputs are not directly related, they are never considered close
        >>> torch.testing.assert_close(actual.numpy(), expected)
        Traceback (most recent call last):
        ...
        TypeError: No comparison pair was able to handle inputs of type <class 'numpy.ndarray'>
        and <class 'torch.Tensor'>.
        >>> # Exceptions to these rules are Python scalars. They can be checked regardless of
        >>> # their type if check_dtype=False.
        >>> torch.testing.assert_close(1.0, 1, check_dtype=False)

        >>> # NaN != NaN by default.
        >>> expected = torch.tensor(float("Nan"))
        >>> actual = expected.clone()
        >>> torch.testing.assert_close(actual, expected)
        Traceback (most recent call last):
        ...
        AssertionError: Scalars are not close!
        <BLANKLINE>
        Expected nan but got nan.
        Absolute difference: nan (up to 1e-05 allowed)
        Relative difference: nan (up to 1.3e-06 allowed)
        >>> torch.testing.assert_close(actual, expected, equal_nan=True)

        >>> expected = torch.tensor([1.0, 2.0, 3.0])
        >>> actual = torch.tensor([1.0, 4.0, 5.0])
        >>> # The default error message can be overwritten.
        >>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!")
        Traceback (most recent call last):
        ...
        AssertionError: Argh, the tensors are not close!
        >>> # If msg is a callable, it can be used to augment the generated message with
        >>> # extra information
        >>> torch.testing.assert_close(
        ...     actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter"
        ... )
        Traceback (most recent call last):
        ...
        AssertionError: Header
        <BLANKLINE>
        Tensor-likes are not close!
        <BLANKLINE>
        Mismatched elements: 2 / 3 (66.7%)
        Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed)
        Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed)
        <BLANKLINE>
        Footer
    """
    # Hide this function from `pytest`'s traceback
    __tracebackhide__ = True

    error_metas = not_close_error_metas(
        actual,
        expected,
        pair_types=(
            NonePair,
            BooleanPair,
            NumberPair,
            TensorLikePair,
        ),
        allow_subclasses=allow_subclasses,
        rtol=rtol,
        atol=atol,
        equal_nan=equal_nan,
        check_device=check_device,
        check_dtype=check_dtype,
        check_layout=check_layout,
        check_stride=check_stride,
        msg=msg,
    )

    if error_metas:
        # TODO: compose all metas into one AssertionError
        raise error_metas[0].to_error(msg)

@deprecated(
    "`torch.testing.assert_allclose()` is deprecated since 1.12 and will be removed in a future release. "
    "Please use `torch.testing.assert_close()` instead. "
    "You can find detailed upgrade instructions in https://github.com/pytorch/pytorch/issues/61844.",
    category=FutureWarning,
)
def assert_allclose(
    actual: Any,
    expected: Any,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    equal_nan: bool = True,
    msg: str = "",
) -> None:
    """
    .. warning::

        :func:`torch.testing.assert_allclose` is deprecated since ``1.12`` and will be removed in a future release.
        Please use :func:`torch.testing.assert_close` instead. You can find detailed upgrade instructions
        `here <https://github.com/pytorch/pytorch/issues/61844>`_.
    """
    if not isinstance(actual, torch.Tensor):
        actual = torch.tensor(actual)
    if not isinstance(expected, torch.Tensor):
        expected = torch.tensor(expected, dtype=actual.dtype)

    if rtol is None and atol is None:
        rtol, atol = default_tolerances(
            actual,
            expected,
            dtype_precisions={
                torch.float16: (1e-3, 1e-3),
                torch.float32: (1e-4, 1e-5),
                torch.float64: (1e-5, 1e-8),
            },
        )

    torch.testing.assert_close(
        actual,
        expected,
        rtol=rtol,
        atol=atol,
        equal_nan=equal_nan,
        check_device=True,
        check_dtype=False,
        check_stride=False,
        msg=msg or None,
    )
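
# Illustrative migration sketch (hypothetical values): the deprecated call
#
#   assert_allclose(actual, expected)
#
# can usually be replaced by an explicit, dtype-agnostic `assert_close` call that
# mirrors the legacy defaults above:
#
#   torch.testing.assert_close(
#       actual,
#       expected,
#       rtol=1e-4,  # legacy float32 default from `dtype_precisions` above
#       atol=1e-5,
#       equal_nan=True,
#       check_dtype=False,
#       check_stride=False,
#   )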