1# mypy: ignore-errors 2 3""" 4Utility function to facilitate testing. 5 6""" 7import contextlib 8import gc 9import operator 10import os 11import platform 12import pprint 13import re 14import shutil 15import sys 16import warnings 17from functools import wraps 18from io import StringIO 19from tempfile import mkdtemp, mkstemp 20from warnings import WarningMessage 21 22import torch._numpy as np 23from torch._numpy import arange, asarray as asanyarray, empty, float32, intp, ndarray 24 25 26__all__ = [ 27 "assert_equal", 28 "assert_almost_equal", 29 "assert_approx_equal", 30 "assert_array_equal", 31 "assert_array_less", 32 "assert_string_equal", 33 "assert_", 34 "assert_array_almost_equal", 35 "build_err_msg", 36 "decorate_methods", 37 "print_assert_equal", 38 "verbose", 39 "assert_", 40 "assert_array_almost_equal_nulp", 41 "assert_raises_regex", 42 "assert_array_max_ulp", 43 "assert_warns", 44 "assert_no_warnings", 45 "assert_allclose", 46 "IgnoreException", 47 "clear_and_catch_warnings", 48 "temppath", 49 "tempdir", 50 "IS_PYPY", 51 "HAS_REFCOUNT", 52 "IS_WASM", 53 "suppress_warnings", 54 "assert_array_compare", 55 "assert_no_gc_cycles", 56 "break_cycles", 57 "IS_PYSTON", 58] 59 60 61verbose = 0 62 63IS_WASM = platform.machine() in ["wasm32", "wasm64"] 64IS_PYPY = sys.implementation.name == "pypy" 65IS_PYSTON = hasattr(sys, "pyston_version_info") 66HAS_REFCOUNT = getattr(sys, "getrefcount", None) is not None and not IS_PYSTON 67 68 69def assert_(val, msg=""): 70 """ 71 Assert that works in release mode. 72 Accepts callable msg to allow deferring evaluation until failure. 73 74 The Python built-in ``assert`` does not work when executing code in 75 optimized mode (the ``-O`` flag) - no byte-code is generated for it. 76 77 For documentation on usage, refer to the Python documentation. 78 79 """ 80 __tracebackhide__ = True # Hide traceback for py.test 81 if not val: 82 try: 83 smsg = msg() 84 except TypeError: 85 smsg = msg 86 raise AssertionError(smsg) 87 88 89def gisnan(x): 90 return np.isnan(x) 91 92 93def gisfinite(x): 94 return np.isfinite(x) 95 96 97def gisinf(x): 98 return np.isinf(x) 99 100 101def build_err_msg( 102 arrays, 103 err_msg, 104 header="Items are not equal:", 105 verbose=True, 106 names=("ACTUAL", "DESIRED"), 107 precision=8, 108): 109 msg = ["\n" + header] 110 if err_msg: 111 if err_msg.find("\n") == -1 and len(err_msg) < 79 - len(header): 112 msg = [msg[0] + " " + err_msg] 113 else: 114 msg.append(err_msg) 115 if verbose: 116 for i, a in enumerate(arrays): 117 if isinstance(a, ndarray): 118 # precision argument is only needed if the objects are ndarrays 119 # r_func = partial(array_repr, precision=precision) 120 r_func = ndarray.__repr__ 121 else: 122 r_func = repr 123 124 try: 125 r = r_func(a) 126 except Exception as exc: 127 r = f"[repr failed for <{type(a).__name__}>: {exc}]" 128 if r.count("\n") > 3: 129 r = "\n".join(r.splitlines()[:3]) 130 r += "..." 131 msg.append(f" {names[i]}: {r}") 132 return "\n".join(msg) 133 134 135def assert_equal(actual, desired, err_msg="", verbose=True): 136 """ 137 Raises an AssertionError if two objects are not equal. 138 139 Given two objects (scalars, lists, tuples, dictionaries or numpy arrays), 140 check that all elements of these objects are equal. An exception is raised 141 at the first conflicting values. 142 143 When one of `actual` and `desired` is a scalar and the other is array_like, 144 the function checks that each element of the array_like object is equal to 145 the scalar. 146 147 This function handles NaN comparisons as if NaN was a "normal" number. 148 That is, AssertionError is not raised if both objects have NaNs in the same 149 positions. This is in contrast to the IEEE standard on NaNs, which says 150 that NaN compared to anything must return False. 151 152 Parameters 153 ---------- 154 actual : array_like 155 The object to check. 156 desired : array_like 157 The expected object. 158 err_msg : str, optional 159 The error message to be printed in case of failure. 160 verbose : bool, optional 161 If True, the conflicting values are appended to the error message. 162 163 Raises 164 ------ 165 AssertionError 166 If actual and desired are not equal. 167 168 Examples 169 -------- 170 >>> np.testing.assert_equal([4,5], [4,6]) 171 Traceback (most recent call last): 172 ... 173 AssertionError: 174 Items are not equal: 175 item=1 176 ACTUAL: 5 177 DESIRED: 6 178 179 The following comparison does not raise an exception. There are NaNs 180 in the inputs, but they are in the same positions. 181 182 >>> np.testing.assert_equal(np.array([1.0, 2.0, np.nan]), [1, 2, np.nan]) 183 184 """ 185 __tracebackhide__ = True # Hide traceback for py.test 186 187 num_nones = sum([actual is None, desired is None]) 188 if num_nones == 1: 189 raise AssertionError(f"Not equal: {actual} != {desired}") 190 elif num_nones == 2: 191 return True 192 # else, carry on 193 194 if isinstance(actual, np.DType) or isinstance(desired, np.DType): 195 result = actual == desired 196 if not result: 197 raise AssertionError(f"Not equal: {actual} != {desired}") 198 else: 199 return True 200 201 if isinstance(desired, str) and isinstance(actual, str): 202 assert actual == desired 203 return 204 205 if isinstance(desired, dict): 206 if not isinstance(actual, dict): 207 raise AssertionError(repr(type(actual))) 208 assert_equal(len(actual), len(desired), err_msg, verbose) 209 for k in desired.keys(): 210 if k not in actual: 211 raise AssertionError(repr(k)) 212 assert_equal(actual[k], desired[k], f"key={k!r}\n{err_msg}", verbose) 213 return 214 if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)): 215 assert_equal(len(actual), len(desired), err_msg, verbose) 216 for k in range(len(desired)): 217 assert_equal(actual[k], desired[k], f"item={k!r}\n{err_msg}", verbose) 218 return 219 220 from torch._numpy import imag, iscomplexobj, isscalar, ndarray, real, signbit 221 222 if isinstance(actual, ndarray) or isinstance(desired, ndarray): 223 return assert_array_equal(actual, desired, err_msg, verbose) 224 msg = build_err_msg([actual, desired], err_msg, verbose=verbose) 225 226 # Handle complex numbers: separate into real/imag to handle 227 # nan/inf/negative zero correctly 228 # XXX: catch ValueError for subclasses of ndarray where iscomplex fail 229 try: 230 usecomplex = iscomplexobj(actual) or iscomplexobj(desired) 231 except (ValueError, TypeError): 232 usecomplex = False 233 234 if usecomplex: 235 if iscomplexobj(actual): 236 actualr = real(actual) 237 actuali = imag(actual) 238 else: 239 actualr = actual 240 actuali = 0 241 if iscomplexobj(desired): 242 desiredr = real(desired) 243 desiredi = imag(desired) 244 else: 245 desiredr = desired 246 desiredi = 0 247 try: 248 assert_equal(actualr, desiredr) 249 assert_equal(actuali, desiredi) 250 except AssertionError: 251 raise AssertionError(msg) # noqa: B904 252 253 # isscalar test to check cases such as [np.nan] != np.nan 254 if isscalar(desired) != isscalar(actual): 255 raise AssertionError(msg) 256 257 # Inf/nan/negative zero handling 258 try: 259 isdesnan = gisnan(desired) 260 isactnan = gisnan(actual) 261 if isdesnan and isactnan: 262 return # both nan, so equal 263 264 # handle signed zero specially for floats 265 array_actual = np.asarray(actual) 266 array_desired = np.asarray(desired) 267 268 if desired == 0 and actual == 0: 269 if not signbit(desired) == signbit(actual): 270 raise AssertionError(msg) 271 272 except (TypeError, ValueError, NotImplementedError): 273 pass 274 275 try: 276 # Explicitly use __eq__ for comparison, gh-2552 277 if not (desired == actual): 278 raise AssertionError(msg) 279 280 except (DeprecationWarning, FutureWarning) as e: 281 # this handles the case when the two types are not even comparable 282 if "elementwise == comparison" in e.args[0]: 283 raise AssertionError(msg) # noqa: B904 284 else: 285 raise 286 287 288def print_assert_equal(test_string, actual, desired): 289 """ 290 Test if two objects are equal, and print an error message if test fails. 291 292 The test is performed with ``actual == desired``. 293 294 Parameters 295 ---------- 296 test_string : str 297 The message supplied to AssertionError. 298 actual : object 299 The object to test for equality against `desired`. 300 desired : object 301 The expected result. 302 303 Examples 304 -------- 305 >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 1]) # doctest: +SKIP 306 >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 2]) # doctest: +SKIP 307 Traceback (most recent call last): 308 ... 309 AssertionError: Test XYZ of func xyz failed 310 ACTUAL: 311 [0, 1] 312 DESIRED: 313 [0, 2] 314 315 """ 316 __tracebackhide__ = True # Hide traceback for py.test 317 import pprint 318 319 if not (actual == desired): 320 msg = StringIO() 321 msg.write(test_string) 322 msg.write(" failed\nACTUAL: \n") 323 pprint.pprint(actual, msg) 324 msg.write("DESIRED: \n") 325 pprint.pprint(desired, msg) 326 raise AssertionError(msg.getvalue()) 327 328 329def assert_almost_equal(actual, desired, decimal=7, err_msg="", verbose=True): 330 """ 331 Raises an AssertionError if two items are not equal up to desired 332 precision. 333 334 .. note:: It is recommended to use one of `assert_allclose`, 335 `assert_array_almost_equal_nulp` or `assert_array_max_ulp` 336 instead of this function for more consistent floating point 337 comparisons. 338 339 The test verifies that the elements of `actual` and `desired` satisfy. 340 341 ``abs(desired-actual) < float64(1.5 * 10**(-decimal))`` 342 343 That is a looser test than originally documented, but agrees with what the 344 actual implementation in `assert_array_almost_equal` did up to rounding 345 vagaries. An exception is raised at conflicting values. For ndarrays this 346 delegates to assert_array_almost_equal 347 348 Parameters 349 ---------- 350 actual : array_like 351 The object to check. 352 desired : array_like 353 The expected object. 354 decimal : int, optional 355 Desired precision, default is 7. 356 err_msg : str, optional 357 The error message to be printed in case of failure. 358 verbose : bool, optional 359 If True, the conflicting values are appended to the error message. 360 361 Raises 362 ------ 363 AssertionError 364 If actual and desired are not equal up to specified precision. 365 366 See Also 367 -------- 368 assert_allclose: Compare two array_like objects for equality with desired 369 relative and/or absolute precision. 370 assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal 371 372 Examples 373 -------- 374 >>> from torch._numpy.testing import assert_almost_equal 375 >>> assert_almost_equal(2.3333333333333, 2.33333334) 376 >>> assert_almost_equal(2.3333333333333, 2.33333334, decimal=10) 377 Traceback (most recent call last): 378 ... 379 AssertionError: 380 Arrays are not almost equal to 10 decimals 381 ACTUAL: 2.3333333333333 382 DESIRED: 2.33333334 383 384 >>> assert_almost_equal(np.array([1.0,2.3333333333333]), 385 ... np.array([1.0,2.33333334]), decimal=9) 386 Traceback (most recent call last): 387 ... 388 AssertionError: 389 Arrays are not almost equal to 9 decimals 390 <BLANKLINE> 391 Mismatched elements: 1 / 2 (50%) 392 Max absolute difference: 6.666699636781459e-09 393 Max relative difference: 2.8571569790287484e-09 394 x: torch.ndarray([1.0000, 2.3333], dtype=float64) 395 y: torch.ndarray([1.0000, 2.3333], dtype=float64) 396 397 """ 398 __tracebackhide__ = True # Hide traceback for py.test 399 from torch._numpy import imag, iscomplexobj, ndarray, real 400 401 # Handle complex numbers: separate into real/imag to handle 402 # nan/inf/negative zero correctly 403 # XXX: catch ValueError for subclasses of ndarray where iscomplex fail 404 try: 405 usecomplex = iscomplexobj(actual) or iscomplexobj(desired) 406 except ValueError: 407 usecomplex = False 408 409 def _build_err_msg(): 410 header = "Arrays are not almost equal to %d decimals" % decimal 411 return build_err_msg([actual, desired], err_msg, verbose=verbose, header=header) 412 413 if usecomplex: 414 if iscomplexobj(actual): 415 actualr = real(actual) 416 actuali = imag(actual) 417 else: 418 actualr = actual 419 actuali = 0 420 if iscomplexobj(desired): 421 desiredr = real(desired) 422 desiredi = imag(desired) 423 else: 424 desiredr = desired 425 desiredi = 0 426 try: 427 assert_almost_equal(actualr, desiredr, decimal=decimal) 428 assert_almost_equal(actuali, desiredi, decimal=decimal) 429 except AssertionError: 430 raise AssertionError(_build_err_msg()) # noqa: B904 431 432 if isinstance(actual, (ndarray, tuple, list)) or isinstance( 433 desired, (ndarray, tuple, list) 434 ): 435 return assert_array_almost_equal(actual, desired, decimal, err_msg) 436 try: 437 # If one of desired/actual is not finite, handle it specially here: 438 # check that both are nan if any is a nan, and test for equality 439 # otherwise 440 if not (gisfinite(desired) and gisfinite(actual)): 441 if gisnan(desired) or gisnan(actual): 442 if not (gisnan(desired) and gisnan(actual)): 443 raise AssertionError(_build_err_msg()) 444 else: 445 if not desired == actual: 446 raise AssertionError(_build_err_msg()) 447 return 448 except (NotImplementedError, TypeError): 449 pass 450 if abs(desired - actual) >= np.float64(1.5 * 10.0 ** (-decimal)): 451 raise AssertionError(_build_err_msg()) 452 453 454def assert_approx_equal(actual, desired, significant=7, err_msg="", verbose=True): 455 """ 456 Raises an AssertionError if two items are not equal up to significant 457 digits. 458 459 .. note:: It is recommended to use one of `assert_allclose`, 460 `assert_array_almost_equal_nulp` or `assert_array_max_ulp` 461 instead of this function for more consistent floating point 462 comparisons. 463 464 Given two numbers, check that they are approximately equal. 465 Approximately equal is defined as the number of significant digits 466 that agree. 467 468 Parameters 469 ---------- 470 actual : scalar 471 The object to check. 472 desired : scalar 473 The expected object. 474 significant : int, optional 475 Desired precision, default is 7. 476 err_msg : str, optional 477 The error message to be printed in case of failure. 478 verbose : bool, optional 479 If True, the conflicting values are appended to the error message. 480 481 Raises 482 ------ 483 AssertionError 484 If actual and desired are not equal up to specified precision. 485 486 See Also 487 -------- 488 assert_allclose: Compare two array_like objects for equality with desired 489 relative and/or absolute precision. 490 assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal 491 492 Examples 493 -------- 494 >>> np.testing.assert_approx_equal(0.12345677777777e-20, 0.1234567e-20) # doctest: +SKIP 495 >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345671e-20, # doctest: +SKIP 496 ... significant=8) 497 >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345672e-20, # doctest: +SKIP 498 ... significant=8) 499 Traceback (most recent call last): 500 ... 501 AssertionError: 502 Items are not equal to 8 significant digits: 503 ACTUAL: 1.234567e-21 504 DESIRED: 1.2345672e-21 505 506 the evaluated condition that raises the exception is 507 508 >>> abs(0.12345670e-20/1e-21 - 0.12345672e-20/1e-21) >= 10**-(8-1) 509 True 510 511 """ 512 __tracebackhide__ = True # Hide traceback for py.test 513 import numpy as np 514 515 (actual, desired) = map(float, (actual, desired)) 516 if desired == actual: 517 return 518 # Normalized the numbers to be in range (-10.0,10.0) 519 # scale = float(pow(10,math.floor(math.log10(0.5*(abs(desired)+abs(actual)))))) 520 scale = 0.5 * (np.abs(desired) + np.abs(actual)) 521 scale = np.power(10, np.floor(np.log10(scale))) 522 try: 523 sc_desired = desired / scale 524 except ZeroDivisionError: 525 sc_desired = 0.0 526 try: 527 sc_actual = actual / scale 528 except ZeroDivisionError: 529 sc_actual = 0.0 530 msg = build_err_msg( 531 [actual, desired], 532 err_msg, 533 header="Items are not equal to %d significant digits:" % significant, 534 verbose=verbose, 535 ) 536 try: 537 # If one of desired/actual is not finite, handle it specially here: 538 # check that both are nan if any is a nan, and test for equality 539 # otherwise 540 if not (gisfinite(desired) and gisfinite(actual)): 541 if gisnan(desired) or gisnan(actual): 542 if not (gisnan(desired) and gisnan(actual)): 543 raise AssertionError(msg) 544 else: 545 if not desired == actual: 546 raise AssertionError(msg) 547 return 548 except (TypeError, NotImplementedError): 549 pass 550 if np.abs(sc_desired - sc_actual) >= np.power(10.0, -(significant - 1)): 551 raise AssertionError(msg) 552 553 554def assert_array_compare( 555 comparison, 556 x, 557 y, 558 err_msg="", 559 verbose=True, 560 header="", 561 precision=6, 562 equal_nan=True, 563 equal_inf=True, 564 *, 565 strict=False, 566): 567 __tracebackhide__ = True # Hide traceback for py.test 568 from torch._numpy import all, array, asarray, bool_, inf, isnan, max 569 570 x = asarray(x) 571 y = asarray(y) 572 573 def array2string(a): 574 return str(a) 575 576 # original array for output formatting 577 ox, oy = x, y 578 579 def func_assert_same_pos(x, y, func=isnan, hasval="nan"): 580 """Handling nan/inf. 581 582 Combine results of running func on x and y, checking that they are True 583 at the same locations. 584 585 """ 586 __tracebackhide__ = True # Hide traceback for py.test 587 x_id = func(x) 588 y_id = func(y) 589 # We include work-arounds here to handle three types of slightly 590 # pathological ndarray subclasses: 591 # (1) all() on `masked` array scalars can return masked arrays, so we 592 # use != True 593 # (2) __eq__ on some ndarray subclasses returns Python booleans 594 # instead of element-wise comparisons, so we cast to bool_() and 595 # use isinstance(..., bool) checks 596 # (3) subclasses with bare-bones __array_function__ implementations may 597 # not implement np.all(), so favor using the .all() method 598 # We are not committed to supporting such subclasses, but it's nice to 599 # support them if possible. 600 if (x_id == y_id).all().item() is not True: 601 msg = build_err_msg( 602 [x, y], 603 err_msg + f"\nx and y {hasval} location mismatch:", 604 verbose=verbose, 605 header=header, 606 names=("x", "y"), 607 precision=precision, 608 ) 609 raise AssertionError(msg) 610 # If there is a scalar, then here we know the array has the same 611 # flag as it everywhere, so we should return the scalar flag. 612 if isinstance(x_id, bool) or x_id.ndim == 0: 613 return bool_(x_id) 614 elif isinstance(y_id, bool) or y_id.ndim == 0: 615 return bool_(y_id) 616 else: 617 return y_id 618 619 try: 620 if strict: 621 cond = x.shape == y.shape and x.dtype == y.dtype 622 else: 623 cond = (x.shape == () or y.shape == ()) or x.shape == y.shape 624 if not cond: 625 if x.shape != y.shape: 626 reason = f"\n(shapes {x.shape}, {y.shape} mismatch)" 627 else: 628 reason = f"\n(dtypes {x.dtype}, {y.dtype} mismatch)" 629 msg = build_err_msg( 630 [x, y], 631 err_msg + reason, 632 verbose=verbose, 633 header=header, 634 names=("x", "y"), 635 precision=precision, 636 ) 637 raise AssertionError(msg) 638 639 flagged = bool_(False) 640 641 if equal_nan: 642 flagged = func_assert_same_pos(x, y, func=isnan, hasval="nan") 643 644 if equal_inf: 645 flagged |= func_assert_same_pos( 646 x, y, func=lambda xy: xy == +inf, hasval="+inf" 647 ) 648 flagged |= func_assert_same_pos( 649 x, y, func=lambda xy: xy == -inf, hasval="-inf" 650 ) 651 652 if flagged.ndim > 0: 653 x, y = x[~flagged], y[~flagged] 654 # Only do the comparison if actual values are left 655 if x.size == 0: 656 return 657 elif flagged: 658 # no sense doing comparison if everything is flagged. 659 return 660 661 val = comparison(x, y) 662 663 if isinstance(val, bool): 664 cond = val 665 reduced = array([val]) 666 else: 667 reduced = val.ravel() 668 cond = reduced.all() 669 670 # The below comparison is a hack to ensure that fully masked 671 # results, for which val.ravel().all() returns np.ma.masked, 672 # do not trigger a failure (np.ma.masked != True evaluates as 673 # np.ma.masked, which is falsy). 674 if not cond: 675 n_mismatch = reduced.size - int(reduced.sum(dtype=intp)) 676 n_elements = flagged.size if flagged.ndim != 0 else reduced.size 677 percent_mismatch = 100 * n_mismatch / n_elements 678 remarks = [ 679 f"Mismatched elements: {n_mismatch} / {n_elements} ({percent_mismatch:.3g}%)" 680 ] 681 682 # with errstate(all='ignore'): 683 # ignore errors for non-numeric types 684 with contextlib.suppress(TypeError, RuntimeError): 685 error = abs(x - y) 686 if np.issubdtype(x.dtype, np.unsignedinteger): 687 error2 = abs(y - x) 688 np.minimum(error, error2, out=error) 689 max_abs_error = max(error) 690 remarks.append( 691 "Max absolute difference: " + array2string(max_abs_error.item()) 692 ) 693 694 # note: this definition of relative error matches that one 695 # used by assert_allclose (found in np.isclose) 696 # Filter values where the divisor would be zero 697 nonzero = bool_(y != 0) 698 if all(~nonzero): 699 max_rel_error = array(inf) 700 else: 701 max_rel_error = max(error[nonzero] / abs(y[nonzero])) 702 remarks.append( 703 "Max relative difference: " + array2string(max_rel_error.item()) 704 ) 705 706 err_msg += "\n" + "\n".join(remarks) 707 msg = build_err_msg( 708 [ox, oy], 709 err_msg, 710 verbose=verbose, 711 header=header, 712 names=("x", "y"), 713 precision=precision, 714 ) 715 raise AssertionError(msg) 716 except ValueError: 717 import traceback 718 719 efmt = traceback.format_exc() 720 header = f"error during assertion:\n\n{efmt}\n\n{header}" 721 722 msg = build_err_msg( 723 [x, y], 724 err_msg, 725 verbose=verbose, 726 header=header, 727 names=("x", "y"), 728 precision=precision, 729 ) 730 raise ValueError(msg) # noqa: B904 731 732 733def assert_array_equal(x, y, err_msg="", verbose=True, *, strict=False): 734 """ 735 Raises an AssertionError if two array_like objects are not equal. 736 737 Given two array_like objects, check that the shape is equal and all 738 elements of these objects are equal (but see the Notes for the special 739 handling of a scalar). An exception is raised at shape mismatch or 740 conflicting values. In contrast to the standard usage in numpy, NaNs 741 are compared like numbers, no assertion is raised if both objects have 742 NaNs in the same positions. 743 744 The usual caution for verifying equality with floating point numbers is 745 advised. 746 747 Parameters 748 ---------- 749 x : array_like 750 The actual object to check. 751 y : array_like 752 The desired, expected object. 753 err_msg : str, optional 754 The error message to be printed in case of failure. 755 verbose : bool, optional 756 If True, the conflicting values are appended to the error message. 757 strict : bool, optional 758 If True, raise an AssertionError when either the shape or the data 759 type of the array_like objects does not match. The special 760 handling for scalars mentioned in the Notes section is disabled. 761 762 Raises 763 ------ 764 AssertionError 765 If actual and desired objects are not equal. 766 767 See Also 768 -------- 769 assert_allclose: Compare two array_like objects for equality with desired 770 relative and/or absolute precision. 771 assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal 772 773 Notes 774 ----- 775 When one of `x` and `y` is a scalar and the other is array_like, the 776 function checks that each element of the array_like object is equal to 777 the scalar. This behaviour can be disabled with the `strict` parameter. 778 779 Examples 780 -------- 781 The first assert does not raise an exception: 782 783 >>> np.testing.assert_array_equal([1.0,2.33333,np.nan], 784 ... [np.exp(0),2.33333, np.nan]) 785 786 Use `assert_allclose` or one of the nulp (number of floating point values) 787 functions for these cases instead: 788 789 >>> np.testing.assert_allclose([1.0,np.pi,np.nan], 790 ... [1, np.sqrt(np.pi)**2, np.nan], 791 ... rtol=1e-10, atol=0) 792 793 As mentioned in the Notes section, `assert_array_equal` has special 794 handling for scalars. Here the test checks that each value in `x` is 3: 795 796 >>> x = np.full((2, 5), fill_value=3) 797 >>> np.testing.assert_array_equal(x, 3) 798 799 Use `strict` to raise an AssertionError when comparing a scalar with an 800 array: 801 802 >>> np.testing.assert_array_equal(x, 3, strict=True) 803 Traceback (most recent call last): 804 ... 805 AssertionError: 806 Arrays are not equal 807 <BLANKLINE> 808 (shapes (2, 5), () mismatch) 809 x: torch.ndarray([[3, 3, 3, 3, 3], 810 [3, 3, 3, 3, 3]]) 811 y: torch.ndarray(3) 812 813 The `strict` parameter also ensures that the array data types match: 814 815 >>> x = np.array([2, 2, 2]) 816 >>> y = np.array([2., 2., 2.], dtype=np.float32) 817 >>> np.testing.assert_array_equal(x, y, strict=True) 818 Traceback (most recent call last): 819 ... 820 AssertionError: 821 Arrays are not equal 822 <BLANKLINE> 823 (dtypes dtype("int64"), dtype("float32") mismatch) 824 x: torch.ndarray([2, 2, 2]) 825 y: torch.ndarray([2., 2., 2.]) 826 """ 827 __tracebackhide__ = True # Hide traceback for py.test 828 assert_array_compare( 829 operator.__eq__, 830 x, 831 y, 832 err_msg=err_msg, 833 verbose=verbose, 834 header="Arrays are not equal", 835 strict=strict, 836 ) 837 838 839def assert_array_almost_equal(x, y, decimal=6, err_msg="", verbose=True): 840 """ 841 Raises an AssertionError if two objects are not equal up to desired 842 precision. 843 844 .. note:: It is recommended to use one of `assert_allclose`, 845 `assert_array_almost_equal_nulp` or `assert_array_max_ulp` 846 instead of this function for more consistent floating point 847 comparisons. 848 849 The test verifies identical shapes and that the elements of ``actual`` and 850 ``desired`` satisfy. 851 852 ``abs(desired-actual) < 1.5 * 10**(-decimal)`` 853 854 That is a looser test than originally documented, but agrees with what the 855 actual implementation did up to rounding vagaries. An exception is raised 856 at shape mismatch or conflicting values. In contrast to the standard usage 857 in numpy, NaNs are compared like numbers, no assertion is raised if both 858 objects have NaNs in the same positions. 859 860 Parameters 861 ---------- 862 x : array_like 863 The actual object to check. 864 y : array_like 865 The desired, expected object. 866 decimal : int, optional 867 Desired precision, default is 6. 868 err_msg : str, optional 869 The error message to be printed in case of failure. 870 verbose : bool, optional 871 If True, the conflicting values are appended to the error message. 872 873 Raises 874 ------ 875 AssertionError 876 If actual and desired are not equal up to specified precision. 877 878 See Also 879 -------- 880 assert_allclose: Compare two array_like objects for equality with desired 881 relative and/or absolute precision. 882 assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal 883 884 Examples 885 -------- 886 the first assert does not raise an exception 887 888 >>> np.testing.assert_array_almost_equal([1.0,2.333,np.nan], 889 ... [1.0,2.333,np.nan]) 890 891 >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan], 892 ... [1.0,2.33339,np.nan], decimal=5) 893 Traceback (most recent call last): 894 ... 895 AssertionError: 896 Arrays are not almost equal to 5 decimals 897 <BLANKLINE> 898 Mismatched elements: 1 / 3 (33.3%) 899 Max absolute difference: 5.999999999994898e-05 900 Max relative difference: 2.5713661239633743e-05 901 x: torch.ndarray([1.0000, 2.3333, nan], dtype=float64) 902 y: torch.ndarray([1.0000, 2.3334, nan], dtype=float64) 903 904 >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan], 905 ... [1.0,2.33333, 5], decimal=5) 906 Traceback (most recent call last): 907 ... 908 AssertionError: 909 Arrays are not almost equal to 5 decimals 910 <BLANKLINE> 911 x and y nan location mismatch: 912 x: torch.ndarray([1.0000, 2.3333, nan], dtype=float64) 913 y: torch.ndarray([1.0000, 2.3333, 5.0000], dtype=float64) 914 915 """ 916 __tracebackhide__ = True # Hide traceback for py.test 917 from torch._numpy import any as npany, float_, issubdtype, number, result_type 918 919 def compare(x, y): 920 try: 921 if npany(gisinf(x)) or npany(gisinf(y)): 922 xinfid = gisinf(x) 923 yinfid = gisinf(y) 924 if not (xinfid == yinfid).all(): 925 return False 926 # if one item, x and y is +- inf 927 if x.size == y.size == 1: 928 return x == y 929 x = x[~xinfid] 930 y = y[~yinfid] 931 except (TypeError, NotImplementedError): 932 pass 933 934 # make sure y is an inexact type to avoid abs(MIN_INT); will cause 935 # casting of x later. 936 dtype = result_type(y, 1.0) 937 y = asanyarray(y, dtype) 938 z = abs(x - y) 939 940 if not issubdtype(z.dtype, number): 941 z = z.astype(float_) # handle object arrays 942 943 return z < 1.5 * 10.0 ** (-decimal) 944 945 assert_array_compare( 946 compare, 947 x, 948 y, 949 err_msg=err_msg, 950 verbose=verbose, 951 header=("Arrays are not almost equal to %d decimals" % decimal), 952 precision=decimal, 953 ) 954 955 956def assert_array_less(x, y, err_msg="", verbose=True): 957 """ 958 Raises an AssertionError if two array_like objects are not ordered by less 959 than. 960 961 Given two array_like objects, check that the shape is equal and all 962 elements of the first object are strictly smaller than those of the 963 second object. An exception is raised at shape mismatch or incorrectly 964 ordered values. Shape mismatch does not raise if an object has zero 965 dimension. In contrast to the standard usage in numpy, NaNs are 966 compared, no assertion is raised if both objects have NaNs in the same 967 positions. 968 969 970 971 Parameters 972 ---------- 973 x : array_like 974 The smaller object to check. 975 y : array_like 976 The larger object to compare. 977 err_msg : string 978 The error message to be printed in case of failure. 979 verbose : bool 980 If True, the conflicting values are appended to the error message. 981 982 Raises 983 ------ 984 AssertionError 985 If actual and desired objects are not equal. 986 987 See Also 988 -------- 989 assert_array_equal: tests objects for equality 990 assert_array_almost_equal: test objects for equality up to precision 991 992 993 994 Examples 995 -------- 996 >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1.1, 2.0, np.nan]) 997 >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1, 2.0, np.nan]) 998 Traceback (most recent call last): 999 ... 1000 AssertionError: 1001 Arrays are not less-ordered 1002 <BLANKLINE> 1003 Mismatched elements: 1 / 3 (33.3%) 1004 Max absolute difference: 1.0 1005 Max relative difference: 0.5 1006 x: torch.ndarray([1., 1., nan], dtype=float64) 1007 y: torch.ndarray([1., 2., nan], dtype=float64) 1008 1009 >>> np.testing.assert_array_less([1.0, 4.0], 3) 1010 Traceback (most recent call last): 1011 ... 1012 AssertionError: 1013 Arrays are not less-ordered 1014 <BLANKLINE> 1015 Mismatched elements: 1 / 2 (50%) 1016 Max absolute difference: 2.0 1017 Max relative difference: 0.6666666666666666 1018 x: torch.ndarray([1., 4.], dtype=float64) 1019 y: torch.ndarray(3) 1020 1021 >>> np.testing.assert_array_less([1.0, 2.0, 3.0], [4]) 1022 Traceback (most recent call last): 1023 ... 1024 AssertionError: 1025 Arrays are not less-ordered 1026 <BLANKLINE> 1027 (shapes (3,), (1,) mismatch) 1028 x: torch.ndarray([1., 2., 3.], dtype=float64) 1029 y: torch.ndarray([4]) 1030 1031 """ 1032 __tracebackhide__ = True # Hide traceback for py.test 1033 assert_array_compare( 1034 operator.__lt__, 1035 x, 1036 y, 1037 err_msg=err_msg, 1038 verbose=verbose, 1039 header="Arrays are not less-ordered", 1040 equal_inf=False, 1041 ) 1042 1043 1044def assert_string_equal(actual, desired): 1045 """ 1046 Test if two strings are equal. 1047 1048 If the given strings are equal, `assert_string_equal` does nothing. 1049 If they are not equal, an AssertionError is raised, and the diff 1050 between the strings is shown. 1051 1052 Parameters 1053 ---------- 1054 actual : str 1055 The string to test for equality against the expected string. 1056 desired : str 1057 The expected string. 1058 1059 Examples 1060 -------- 1061 >>> np.testing.assert_string_equal('abc', 'abc') # doctest: +SKIP 1062 >>> np.testing.assert_string_equal('abc', 'abcd') # doctest: +SKIP 1063 Traceback (most recent call last): 1064 File "<stdin>", line 1, in <module> 1065 ... 1066 AssertionError: Differences in strings: 1067 - abc+ abcd? + 1068 1069 """ 1070 # delay import of difflib to reduce startup time 1071 __tracebackhide__ = True # Hide traceback for py.test 1072 import difflib 1073 1074 if not isinstance(actual, str): 1075 raise AssertionError(repr(type(actual))) 1076 if not isinstance(desired, str): 1077 raise AssertionError(repr(type(desired))) 1078 if desired == actual: 1079 return 1080 1081 diff = list( 1082 difflib.Differ().compare(actual.splitlines(True), desired.splitlines(True)) 1083 ) 1084 diff_list = [] 1085 while diff: 1086 d1 = diff.pop(0) 1087 if d1.startswith(" "): 1088 continue 1089 if d1.startswith("- "): 1090 l = [d1] 1091 d2 = diff.pop(0) 1092 if d2.startswith("? "): 1093 l.append(d2) 1094 d2 = diff.pop(0) 1095 if not d2.startswith("+ "): 1096 raise AssertionError(repr(d2)) 1097 l.append(d2) 1098 if diff: 1099 d3 = diff.pop(0) 1100 if d3.startswith("? "): 1101 l.append(d3) 1102 else: 1103 diff.insert(0, d3) 1104 if d2[2:] == d1[2:]: 1105 continue 1106 diff_list.extend(l) 1107 continue 1108 raise AssertionError(repr(d1)) 1109 if not diff_list: 1110 return 1111 msg = f"Differences in strings:\n{''.join(diff_list).rstrip()}" 1112 if actual != desired: 1113 raise AssertionError(msg) 1114 1115 1116import unittest 1117 1118 1119class _Dummy(unittest.TestCase): 1120 def nop(self): 1121 pass 1122 1123 1124_d = _Dummy("nop") 1125 1126 1127def assert_raises_regex(exception_class, expected_regexp, *args, **kwargs): 1128 """ 1129 assert_raises_regex(exception_class, expected_regexp, callable, *args, 1130 **kwargs) 1131 assert_raises_regex(exception_class, expected_regexp) 1132 1133 Fail unless an exception of class exception_class and with message that 1134 matches expected_regexp is thrown by callable when invoked with arguments 1135 args and keyword arguments kwargs. 1136 1137 Alternatively, can be used as a context manager like `assert_raises`. 1138 1139 Notes 1140 ----- 1141 .. versionadded:: 1.9.0 1142 1143 """ 1144 __tracebackhide__ = True # Hide traceback for py.test 1145 return _d.assertRaisesRegex(exception_class, expected_regexp, *args, **kwargs) 1146 1147 1148def decorate_methods(cls, decorator, testmatch=None): 1149 """ 1150 Apply a decorator to all methods in a class matching a regular expression. 1151 1152 The given decorator is applied to all public methods of `cls` that are 1153 matched by the regular expression `testmatch` 1154 (``testmatch.search(methodname)``). Methods that are private, i.e. start 1155 with an underscore, are ignored. 1156 1157 Parameters 1158 ---------- 1159 cls : class 1160 Class whose methods to decorate. 1161 decorator : function 1162 Decorator to apply to methods 1163 testmatch : compiled regexp or str, optional 1164 The regular expression. Default value is None, in which case the 1165 nose default (``re.compile(r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep)``) 1166 is used. 1167 If `testmatch` is a string, it is compiled to a regular expression 1168 first. 1169 1170 """ 1171 if testmatch is None: 1172 testmatch = re.compile(rf"(?:^|[\\b_\\.{os.sep}-])[Tt]est") 1173 else: 1174 testmatch = re.compile(testmatch) 1175 cls_attr = cls.__dict__ 1176 1177 # delayed import to reduce startup time 1178 from inspect import isfunction 1179 1180 methods = [_m for _m in cls_attr.values() if isfunction(_m)] 1181 for function in methods: 1182 try: 1183 if hasattr(function, "compat_func_name"): 1184 funcname = function.compat_func_name 1185 else: 1186 funcname = function.__name__ 1187 except AttributeError: 1188 # not a function 1189 continue 1190 if testmatch.search(funcname) and not funcname.startswith("_"): 1191 setattr(cls, funcname, decorator(function)) 1192 return 1193 1194 1195def _assert_valid_refcount(op): 1196 """ 1197 Check that ufuncs don't mishandle refcount of object `1`. 1198 Used in a few regression tests. 1199 """ 1200 if not HAS_REFCOUNT: 1201 return True 1202 1203 import gc 1204 1205 import numpy as np 1206 1207 b = np.arange(100 * 100).reshape(100, 100) 1208 c = b 1209 i = 1 1210 1211 gc.disable() 1212 try: 1213 rc = sys.getrefcount(i) 1214 for j in range(15): 1215 d = op(b, c) 1216 assert_(sys.getrefcount(i) >= rc) 1217 finally: 1218 gc.enable() 1219 del d # for pyflakes 1220 1221 1222def assert_allclose( 1223 actual, 1224 desired, 1225 rtol=1e-7, 1226 atol=0, 1227 equal_nan=True, 1228 err_msg="", 1229 verbose=True, 1230 check_dtype=False, 1231): 1232 """ 1233 Raises an AssertionError if two objects are not equal up to desired 1234 tolerance. 1235 1236 Given two array_like objects, check that their shapes and all elements 1237 are equal (but see the Notes for the special handling of a scalar). An 1238 exception is raised if the shapes mismatch or any values conflict. In 1239 contrast to the standard usage in numpy, NaNs are compared like numbers, 1240 no assertion is raised if both objects have NaNs in the same positions. 1241 1242 The test is equivalent to ``allclose(actual, desired, rtol, atol)`` (note 1243 that ``allclose`` has different default values). It compares the difference 1244 between `actual` and `desired` to ``atol + rtol * abs(desired)``. 1245 1246 .. versionadded:: 1.5.0 1247 1248 Parameters 1249 ---------- 1250 actual : array_like 1251 Array obtained. 1252 desired : array_like 1253 Array desired. 1254 rtol : float, optional 1255 Relative tolerance. 1256 atol : float, optional 1257 Absolute tolerance. 1258 equal_nan : bool, optional. 1259 If True, NaNs will compare equal. 1260 err_msg : str, optional 1261 The error message to be printed in case of failure. 1262 verbose : bool, optional 1263 If True, the conflicting values are appended to the error message. 1264 1265 Raises 1266 ------ 1267 AssertionError 1268 If actual and desired are not equal up to specified precision. 1269 1270 See Also 1271 -------- 1272 assert_array_almost_equal_nulp, assert_array_max_ulp 1273 1274 Notes 1275 ----- 1276 When one of `actual` and `desired` is a scalar and the other is 1277 array_like, the function checks that each element of the array_like 1278 object is equal to the scalar. 1279 1280 Examples 1281 -------- 1282 >>> x = [1e-5, 1e-3, 1e-1] 1283 >>> y = np.arccos(np.cos(x)) 1284 >>> np.testing.assert_allclose(x, y, rtol=1e-5, atol=0) 1285 1286 """ 1287 __tracebackhide__ = True # Hide traceback for py.test 1288 1289 def compare(x, y): 1290 return np.isclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan) 1291 1292 actual, desired = asanyarray(actual), asanyarray(desired) 1293 header = f"Not equal to tolerance rtol={rtol:g}, atol={atol:g}" 1294 1295 if check_dtype: 1296 assert actual.dtype == desired.dtype 1297 1298 assert_array_compare( 1299 compare, 1300 actual, 1301 desired, 1302 err_msg=str(err_msg), 1303 verbose=verbose, 1304 header=header, 1305 equal_nan=equal_nan, 1306 ) 1307 1308 1309def assert_array_almost_equal_nulp(x, y, nulp=1): 1310 """ 1311 Compare two arrays relatively to their spacing. 1312 1313 This is a relatively robust method to compare two arrays whose amplitude 1314 is variable. 1315 1316 Parameters 1317 ---------- 1318 x, y : array_like 1319 Input arrays. 1320 nulp : int, optional 1321 The maximum number of unit in the last place for tolerance (see Notes). 1322 Default is 1. 1323 1324 Returns 1325 ------- 1326 None 1327 1328 Raises 1329 ------ 1330 AssertionError 1331 If the spacing between `x` and `y` for one or more elements is larger 1332 than `nulp`. 1333 1334 See Also 1335 -------- 1336 assert_array_max_ulp : Check that all items of arrays differ in at most 1337 N Units in the Last Place. 1338 spacing : Return the distance between x and the nearest adjacent number. 1339 1340 Notes 1341 ----- 1342 An assertion is raised if the following condition is not met:: 1343 1344 abs(x - y) <= nulp * spacing(maximum(abs(x), abs(y))) 1345 1346 Examples 1347 -------- 1348 >>> x = np.array([1., 1e-10, 1e-20]) 1349 >>> eps = np.finfo(x.dtype).eps 1350 >>> np.testing.assert_array_almost_equal_nulp(x, x*eps/2 + x) # doctest: +SKIP 1351 1352 >>> np.testing.assert_array_almost_equal_nulp(x, x*eps + x) # doctest: +SKIP 1353 Traceback (most recent call last): 1354 ... 1355 AssertionError: X and Y are not equal to 1 ULP (max is 2) 1356 1357 """ 1358 __tracebackhide__ = True # Hide traceback for py.test 1359 import numpy as np 1360 1361 ax = np.abs(x) 1362 ay = np.abs(y) 1363 ref = nulp * np.spacing(np.where(ax > ay, ax, ay)) 1364 if not np.all(np.abs(x - y) <= ref): 1365 if np.iscomplexobj(x) or np.iscomplexobj(y): 1366 msg = "X and Y are not equal to %d ULP" % nulp 1367 else: 1368 max_nulp = np.max(nulp_diff(x, y)) 1369 msg = "X and Y are not equal to %d ULP (max is %g)" % (nulp, max_nulp) 1370 raise AssertionError(msg) 1371 1372 1373def assert_array_max_ulp(a, b, maxulp=1, dtype=None): 1374 """ 1375 Check that all items of arrays differ in at most N Units in the Last Place. 1376 1377 Parameters 1378 ---------- 1379 a, b : array_like 1380 Input arrays to be compared. 1381 maxulp : int, optional 1382 The maximum number of units in the last place that elements of `a` and 1383 `b` can differ. Default is 1. 1384 dtype : dtype, optional 1385 Data-type to convert `a` and `b` to if given. Default is None. 1386 1387 Returns 1388 ------- 1389 ret : ndarray 1390 Array containing number of representable floating point numbers between 1391 items in `a` and `b`. 1392 1393 Raises 1394 ------ 1395 AssertionError 1396 If one or more elements differ by more than `maxulp`. 1397 1398 Notes 1399 ----- 1400 For computing the ULP difference, this API does not differentiate between 1401 various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000 1402 is zero). 1403 1404 See Also 1405 -------- 1406 assert_array_almost_equal_nulp : Compare two arrays relatively to their 1407 spacing. 1408 1409 Examples 1410 -------- 1411 >>> a = np.linspace(0., 1., 100) 1412 >>> res = np.testing.assert_array_max_ulp(a, np.arcsin(np.sin(a))) # doctest: +SKIP 1413 1414 """ 1415 __tracebackhide__ = True # Hide traceback for py.test 1416 import numpy as np 1417 1418 ret = nulp_diff(a, b, dtype) 1419 if not np.all(ret <= maxulp): 1420 raise AssertionError( 1421 f"Arrays are not almost equal up to {maxulp:g} " 1422 f"ULP (max difference is {np.max(ret):g} ULP)" 1423 ) 1424 return ret 1425 1426 1427def nulp_diff(x, y, dtype=None): 1428 """For each item in x and y, return the number of representable floating 1429 points between them. 1430 1431 Parameters 1432 ---------- 1433 x : array_like 1434 first input array 1435 y : array_like 1436 second input array 1437 dtype : dtype, optional 1438 Data-type to convert `x` and `y` to if given. Default is None. 1439 1440 Returns 1441 ------- 1442 nulp : array_like 1443 number of representable floating point numbers between each item in x 1444 and y. 1445 1446 Notes 1447 ----- 1448 For computing the ULP difference, this API does not differentiate between 1449 various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000 1450 is zero). 1451 1452 Examples 1453 -------- 1454 # By definition, epsilon is the smallest number such as 1 + eps != 1, so 1455 # there should be exactly one ULP between 1 and 1 + eps 1456 >>> nulp_diff(1, 1 + np.finfo(x.dtype).eps) # doctest: +SKIP 1457 1.0 1458 """ 1459 import numpy as np 1460 1461 if dtype: 1462 x = np.asarray(x, dtype=dtype) 1463 y = np.asarray(y, dtype=dtype) 1464 else: 1465 x = np.asarray(x) 1466 y = np.asarray(y) 1467 1468 t = np.common_type(x, y) 1469 if np.iscomplexobj(x) or np.iscomplexobj(y): 1470 raise NotImplementedError("_nulp not implemented for complex array") 1471 1472 x = np.array([x], dtype=t) 1473 y = np.array([y], dtype=t) 1474 1475 x[np.isnan(x)] = np.nan 1476 y[np.isnan(y)] = np.nan 1477 1478 if not x.shape == y.shape: 1479 raise ValueError(f"x and y do not have the same shape: {x.shape} - {y.shape}") 1480 1481 def _diff(rx, ry, vdt): 1482 diff = np.asarray(rx - ry, dtype=vdt) 1483 return np.abs(diff) 1484 1485 rx = integer_repr(x) 1486 ry = integer_repr(y) 1487 return _diff(rx, ry, t) 1488 1489 1490def _integer_repr(x, vdt, comp): 1491 # Reinterpret binary representation of the float as sign-magnitude: 1492 # take into account two-complement representation 1493 # See also 1494 # https://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/ 1495 rx = x.view(vdt) 1496 if not (rx.size == 1): 1497 rx[rx < 0] = comp - rx[rx < 0] 1498 else: 1499 if rx < 0: 1500 rx = comp - rx 1501 1502 return rx 1503 1504 1505def integer_repr(x): 1506 """Return the signed-magnitude interpretation of the binary representation 1507 of x.""" 1508 import numpy as np 1509 1510 if x.dtype == np.float16: 1511 return _integer_repr(x, np.int16, np.int16(-(2**15))) 1512 elif x.dtype == np.float32: 1513 return _integer_repr(x, np.int32, np.int32(-(2**31))) 1514 elif x.dtype == np.float64: 1515 return _integer_repr(x, np.int64, np.int64(-(2**63))) 1516 else: 1517 raise ValueError(f"Unsupported dtype {x.dtype}") 1518 1519 1520@contextlib.contextmanager 1521def _assert_warns_context(warning_class, name=None): 1522 __tracebackhide__ = True # Hide traceback for py.test 1523 with suppress_warnings() as sup: 1524 l = sup.record(warning_class) 1525 yield 1526 if not len(l) > 0: 1527 name_str = f" when calling {name}" if name is not None else "" 1528 raise AssertionError("No warning raised" + name_str) 1529 1530 1531def assert_warns(warning_class, *args, **kwargs): 1532 """ 1533 Fail unless the given callable throws the specified warning. 1534 1535 A warning of class warning_class should be thrown by the callable when 1536 invoked with arguments args and keyword arguments kwargs. 1537 If a different type of warning is thrown, it will not be caught. 1538 1539 If called with all arguments other than the warning class omitted, may be 1540 used as a context manager: 1541 1542 with assert_warns(SomeWarning): 1543 do_something() 1544 1545 The ability to be used as a context manager is new in NumPy v1.11.0. 1546 1547 .. versionadded:: 1.4.0 1548 1549 Parameters 1550 ---------- 1551 warning_class : class 1552 The class defining the warning that `func` is expected to throw. 1553 func : callable, optional 1554 Callable to test 1555 *args : Arguments 1556 Arguments for `func`. 1557 **kwargs : Kwargs 1558 Keyword arguments for `func`. 1559 1560 Returns 1561 ------- 1562 The value returned by `func`. 1563 1564 Examples 1565 -------- 1566 >>> import warnings 1567 >>> def deprecated_func(num): 1568 ... warnings.warn("Please upgrade", DeprecationWarning) 1569 ... return num*num 1570 >>> with np.testing.assert_warns(DeprecationWarning): 1571 ... assert deprecated_func(4) == 16 1572 >>> # or passing a func 1573 >>> ret = np.testing.assert_warns(DeprecationWarning, deprecated_func, 4) 1574 >>> assert ret == 16 1575 """ 1576 if not args: 1577 return _assert_warns_context(warning_class) 1578 1579 func = args[0] 1580 args = args[1:] 1581 with _assert_warns_context(warning_class, name=func.__name__): 1582 return func(*args, **kwargs) 1583 1584 1585@contextlib.contextmanager 1586def _assert_no_warnings_context(name=None): 1587 __tracebackhide__ = True # Hide traceback for py.test 1588 with warnings.catch_warnings(record=True) as l: 1589 warnings.simplefilter("always") 1590 yield 1591 if len(l) > 0: 1592 name_str = f" when calling {name}" if name is not None else "" 1593 raise AssertionError(f"Got warnings{name_str}: {l}") 1594 1595 1596def assert_no_warnings(*args, **kwargs): 1597 """ 1598 Fail if the given callable produces any warnings. 1599 1600 If called with all arguments omitted, may be used as a context manager: 1601 1602 with assert_no_warnings(): 1603 do_something() 1604 1605 The ability to be used as a context manager is new in NumPy v1.11.0. 1606 1607 .. versionadded:: 1.7.0 1608 1609 Parameters 1610 ---------- 1611 func : callable 1612 The callable to test. 1613 \\*args : Arguments 1614 Arguments passed to `func`. 1615 \\*\\*kwargs : Kwargs 1616 Keyword arguments passed to `func`. 1617 1618 Returns 1619 ------- 1620 The value returned by `func`. 1621 1622 """ 1623 if not args: 1624 return _assert_no_warnings_context() 1625 1626 func = args[0] 1627 args = args[1:] 1628 with _assert_no_warnings_context(name=func.__name__): 1629 return func(*args, **kwargs) 1630 1631 1632def _gen_alignment_data(dtype=float32, type="binary", max_size=24): 1633 """ 1634 generator producing data with different alignment and offsets 1635 to test simd vectorization 1636 1637 Parameters 1638 ---------- 1639 dtype : dtype 1640 data type to produce 1641 type : string 1642 'unary': create data for unary operations, creates one input 1643 and output array 1644 'binary': create data for unary operations, creates two input 1645 and output array 1646 max_size : integer 1647 maximum size of data to produce 1648 1649 Returns 1650 ------- 1651 if type is 'unary' yields one output, one input array and a message 1652 containing information on the data 1653 if type is 'binary' yields one output array, two input array and a message 1654 containing information on the data 1655 1656 """ 1657 ufmt = "unary offset=(%d, %d), size=%d, dtype=%r, %s" 1658 bfmt = "binary offset=(%d, %d, %d), size=%d, dtype=%r, %s" 1659 for o in range(3): 1660 for s in range(o + 2, max(o + 3, max_size)): 1661 if type == "unary": 1662 1663 def inp(): 1664 return arange(s, dtype=dtype)[o:] 1665 1666 out = empty((s,), dtype=dtype)[o:] 1667 yield out, inp(), ufmt % (o, o, s, dtype, "out of place") 1668 d = inp() 1669 yield d, d, ufmt % (o, o, s, dtype, "in place") 1670 yield out[1:], inp()[:-1], ufmt % ( 1671 o + 1, 1672 o, 1673 s - 1, 1674 dtype, 1675 "out of place", 1676 ) 1677 yield out[:-1], inp()[1:], ufmt % ( 1678 o, 1679 o + 1, 1680 s - 1, 1681 dtype, 1682 "out of place", 1683 ) 1684 yield inp()[:-1], inp()[1:], ufmt % (o, o + 1, s - 1, dtype, "aliased") 1685 yield inp()[1:], inp()[:-1], ufmt % (o + 1, o, s - 1, dtype, "aliased") 1686 if type == "binary": 1687 1688 def inp1(): 1689 return arange(s, dtype=dtype)[o:] 1690 1691 inp2 = inp1 1692 out = empty((s,), dtype=dtype)[o:] 1693 yield out, inp1(), inp2(), bfmt % (o, o, o, s, dtype, "out of place") 1694 d = inp1() 1695 yield d, d, inp2(), bfmt % (o, o, o, s, dtype, "in place1") 1696 d = inp2() 1697 yield d, inp1(), d, bfmt % (o, o, o, s, dtype, "in place2") 1698 yield out[1:], inp1()[:-1], inp2()[:-1], bfmt % ( 1699 o + 1, 1700 o, 1701 o, 1702 s - 1, 1703 dtype, 1704 "out of place", 1705 ) 1706 yield out[:-1], inp1()[1:], inp2()[:-1], bfmt % ( 1707 o, 1708 o + 1, 1709 o, 1710 s - 1, 1711 dtype, 1712 "out of place", 1713 ) 1714 yield out[:-1], inp1()[:-1], inp2()[1:], bfmt % ( 1715 o, 1716 o, 1717 o + 1, 1718 s - 1, 1719 dtype, 1720 "out of place", 1721 ) 1722 yield inp1()[1:], inp1()[:-1], inp2()[:-1], bfmt % ( 1723 o + 1, 1724 o, 1725 o, 1726 s - 1, 1727 dtype, 1728 "aliased", 1729 ) 1730 yield inp1()[:-1], inp1()[1:], inp2()[:-1], bfmt % ( 1731 o, 1732 o + 1, 1733 o, 1734 s - 1, 1735 dtype, 1736 "aliased", 1737 ) 1738 yield inp1()[:-1], inp1()[:-1], inp2()[1:], bfmt % ( 1739 o, 1740 o, 1741 o + 1, 1742 s - 1, 1743 dtype, 1744 "aliased", 1745 ) 1746 1747 1748class IgnoreException(Exception): 1749 "Ignoring this exception due to disabled feature" 1750 1751 1752@contextlib.contextmanager 1753def tempdir(*args, **kwargs): 1754 """Context manager to provide a temporary test folder. 1755 1756 All arguments are passed as this to the underlying tempfile.mkdtemp 1757 function. 1758 1759 """ 1760 tmpdir = mkdtemp(*args, **kwargs) 1761 try: 1762 yield tmpdir 1763 finally: 1764 shutil.rmtree(tmpdir) 1765 1766 1767@contextlib.contextmanager 1768def temppath(*args, **kwargs): 1769 """Context manager for temporary files. 1770 1771 Context manager that returns the path to a closed temporary file. Its 1772 parameters are the same as for tempfile.mkstemp and are passed directly 1773 to that function. The underlying file is removed when the context is 1774 exited, so it should be closed at that time. 1775 1776 Windows does not allow a temporary file to be opened if it is already 1777 open, so the underlying file must be closed after opening before it 1778 can be opened again. 1779 1780 """ 1781 fd, path = mkstemp(*args, **kwargs) 1782 os.close(fd) 1783 try: 1784 yield path 1785 finally: 1786 os.remove(path) 1787 1788 1789class clear_and_catch_warnings(warnings.catch_warnings): 1790 """Context manager that resets warning registry for catching warnings 1791 1792 Warnings can be slippery, because, whenever a warning is triggered, Python 1793 adds a ``__warningregistry__`` member to the *calling* module. This makes 1794 it impossible to retrigger the warning in this module, whatever you put in 1795 the warnings filters. This context manager accepts a sequence of `modules` 1796 as a keyword argument to its constructor and: 1797 1798 * stores and removes any ``__warningregistry__`` entries in given `modules` 1799 on entry; 1800 * resets ``__warningregistry__`` to its previous state on exit. 1801 1802 This makes it possible to trigger any warning afresh inside the context 1803 manager without disturbing the state of warnings outside. 1804 1805 For compatibility with Python 3.0, please consider all arguments to be 1806 keyword-only. 1807 1808 Parameters 1809 ---------- 1810 record : bool, optional 1811 Specifies whether warnings should be captured by a custom 1812 implementation of ``warnings.showwarning()`` and be appended to a list 1813 returned by the context manager. Otherwise None is returned by the 1814 context manager. The objects appended to the list are arguments whose 1815 attributes mirror the arguments to ``showwarning()``. 1816 modules : sequence, optional 1817 Sequence of modules for which to reset warnings registry on entry and 1818 restore on exit. To work correctly, all 'ignore' filters should 1819 filter by one of these modules. 1820 1821 Examples 1822 -------- 1823 >>> import warnings 1824 >>> with np.testing.clear_and_catch_warnings( # doctest: +SKIP 1825 ... modules=[np.core.fromnumeric]): 1826 ... warnings.simplefilter('always') 1827 ... warnings.filterwarnings('ignore', module='np.core.fromnumeric') 1828 ... # do something that raises a warning but ignore those in 1829 ... # np.core.fromnumeric 1830 """ 1831 1832 class_modules = () 1833 1834 def __init__(self, record=False, modules=()): 1835 self.modules = set(modules).union(self.class_modules) 1836 self._warnreg_copies = {} 1837 super().__init__(record=record) 1838 1839 def __enter__(self): 1840 for mod in self.modules: 1841 if hasattr(mod, "__warningregistry__"): 1842 mod_reg = mod.__warningregistry__ 1843 self._warnreg_copies[mod] = mod_reg.copy() 1844 mod_reg.clear() 1845 return super().__enter__() 1846 1847 def __exit__(self, *exc_info): 1848 super().__exit__(*exc_info) 1849 for mod in self.modules: 1850 if hasattr(mod, "__warningregistry__"): 1851 mod.__warningregistry__.clear() 1852 if mod in self._warnreg_copies: 1853 mod.__warningregistry__.update(self._warnreg_copies[mod]) 1854 1855 1856class suppress_warnings: 1857 """ 1858 Context manager and decorator doing much the same as 1859 ``warnings.catch_warnings``. 1860 1861 However, it also provides a filter mechanism to work around 1862 https://bugs.python.org/issue4180. 1863 1864 This bug causes Python before 3.4 to not reliably show warnings again 1865 after they have been ignored once (even within catch_warnings). It 1866 means that no "ignore" filter can be used easily, since following 1867 tests might need to see the warning. Additionally it allows easier 1868 specificity for testing warnings and can be nested. 1869 1870 Parameters 1871 ---------- 1872 forwarding_rule : str, optional 1873 One of "always", "once", "module", or "location". Analogous to 1874 the usual warnings module filter mode, it is useful to reduce 1875 noise mostly on the outmost level. Unsuppressed and unrecorded 1876 warnings will be forwarded based on this rule. Defaults to "always". 1877 "location" is equivalent to the warnings "default", match by exact 1878 location the warning warning originated from. 1879 1880 Notes 1881 ----- 1882 Filters added inside the context manager will be discarded again 1883 when leaving it. Upon entering all filters defined outside a 1884 context will be applied automatically. 1885 1886 When a recording filter is added, matching warnings are stored in the 1887 ``log`` attribute as well as in the list returned by ``record``. 1888 1889 If filters are added and the ``module`` keyword is given, the 1890 warning registry of this module will additionally be cleared when 1891 applying it, entering the context, or exiting it. This could cause 1892 warnings to appear a second time after leaving the context if they 1893 were configured to be printed once (default) and were already 1894 printed before the context was entered. 1895 1896 Nesting this context manager will work as expected when the 1897 forwarding rule is "always" (default). Unfiltered and unrecorded 1898 warnings will be passed out and be matched by the outer level. 1899 On the outmost level they will be printed (or caught by another 1900 warnings context). The forwarding rule argument can modify this 1901 behaviour. 1902 1903 Like ``catch_warnings`` this context manager is not threadsafe. 1904 1905 Examples 1906 -------- 1907 1908 With a context manager:: 1909 1910 with np.testing.suppress_warnings() as sup: 1911 sup.filter(DeprecationWarning, "Some text") 1912 sup.filter(module=np.ma.core) 1913 log = sup.record(FutureWarning, "Does this occur?") 1914 command_giving_warnings() 1915 # The FutureWarning was given once, the filtered warnings were 1916 # ignored. All other warnings abide outside settings (may be 1917 # printed/error) 1918 assert_(len(log) == 1) 1919 assert_(len(sup.log) == 1) # also stored in log attribute 1920 1921 Or as a decorator:: 1922 1923 sup = np.testing.suppress_warnings() 1924 sup.filter(module=np.ma.core) # module must match exactly 1925 @sup 1926 def some_function(): 1927 # do something which causes a warning in np.ma.core 1928 pass 1929 """ 1930 1931 def __init__(self, forwarding_rule="always"): 1932 self._entered = False 1933 1934 # Suppressions are either instance or defined inside one with block: 1935 self._suppressions = [] 1936 1937 if forwarding_rule not in {"always", "module", "once", "location"}: 1938 raise ValueError("unsupported forwarding rule.") 1939 self._forwarding_rule = forwarding_rule 1940 1941 def _clear_registries(self): 1942 if hasattr(warnings, "_filters_mutated"): 1943 # clearing the registry should not be necessary on new pythons, 1944 # instead the filters should be mutated. 1945 warnings._filters_mutated() 1946 return 1947 # Simply clear the registry, this should normally be harmless, 1948 # note that on new pythons it would be invalidated anyway. 1949 for module in self._tmp_modules: 1950 if hasattr(module, "__warningregistry__"): 1951 module.__warningregistry__.clear() 1952 1953 def _filter(self, category=Warning, message="", module=None, record=False): 1954 if record: 1955 record = [] # The log where to store warnings 1956 else: 1957 record = None 1958 if self._entered: 1959 if module is None: 1960 warnings.filterwarnings("always", category=category, message=message) 1961 else: 1962 module_regex = module.__name__.replace(".", r"\.") + "$" 1963 warnings.filterwarnings( 1964 "always", category=category, message=message, module=module_regex 1965 ) 1966 self._tmp_modules.add(module) 1967 self._clear_registries() 1968 1969 self._tmp_suppressions.append( 1970 (category, message, re.compile(message, re.IGNORECASE), module, record) 1971 ) 1972 else: 1973 self._suppressions.append( 1974 (category, message, re.compile(message, re.IGNORECASE), module, record) 1975 ) 1976 1977 return record 1978 1979 def filter(self, category=Warning, message="", module=None): 1980 """ 1981 Add a new suppressing filter or apply it if the state is entered. 1982 1983 Parameters 1984 ---------- 1985 category : class, optional 1986 Warning class to filter 1987 message : string, optional 1988 Regular expression matching the warning message. 1989 module : module, optional 1990 Module to filter for. Note that the module (and its file) 1991 must match exactly and cannot be a submodule. This may make 1992 it unreliable for external modules. 1993 1994 Notes 1995 ----- 1996 When added within a context, filters are only added inside 1997 the context and will be forgotten when the context is exited. 1998 """ 1999 self._filter(category=category, message=message, module=module, record=False) 2000 2001 def record(self, category=Warning, message="", module=None): 2002 """ 2003 Append a new recording filter or apply it if the state is entered. 2004 2005 All warnings matching will be appended to the ``log`` attribute. 2006 2007 Parameters 2008 ---------- 2009 category : class, optional 2010 Warning class to filter 2011 message : string, optional 2012 Regular expression matching the warning message. 2013 module : module, optional 2014 Module to filter for. Note that the module (and its file) 2015 must match exactly and cannot be a submodule. This may make 2016 it unreliable for external modules. 2017 2018 Returns 2019 ------- 2020 log : list 2021 A list which will be filled with all matched warnings. 2022 2023 Notes 2024 ----- 2025 When added within a context, filters are only added inside 2026 the context and will be forgotten when the context is exited. 2027 """ 2028 return self._filter( 2029 category=category, message=message, module=module, record=True 2030 ) 2031 2032 def __enter__(self): 2033 if self._entered: 2034 raise RuntimeError("cannot enter suppress_warnings twice.") 2035 2036 self._orig_show = warnings.showwarning 2037 self._filters = warnings.filters 2038 warnings.filters = self._filters[:] 2039 2040 self._entered = True 2041 self._tmp_suppressions = [] 2042 self._tmp_modules = set() 2043 self._forwarded = set() 2044 2045 self.log = [] # reset global log (no need to keep same list) 2046 2047 for cat, mess, _, mod, log in self._suppressions: 2048 if log is not None: 2049 del log[:] # clear the log 2050 if mod is None: 2051 warnings.filterwarnings("always", category=cat, message=mess) 2052 else: 2053 module_regex = mod.__name__.replace(".", r"\.") + "$" 2054 warnings.filterwarnings( 2055 "always", category=cat, message=mess, module=module_regex 2056 ) 2057 self._tmp_modules.add(mod) 2058 warnings.showwarning = self._showwarning 2059 self._clear_registries() 2060 2061 return self 2062 2063 def __exit__(self, *exc_info): 2064 warnings.showwarning = self._orig_show 2065 warnings.filters = self._filters 2066 self._clear_registries() 2067 self._entered = False 2068 del self._orig_show 2069 del self._filters 2070 2071 def _showwarning( 2072 self, message, category, filename, lineno, *args, use_warnmsg=None, **kwargs 2073 ): 2074 for cat, _, pattern, mod, rec in (self._suppressions + self._tmp_suppressions)[ 2075 ::-1 2076 ]: 2077 if issubclass(category, cat) and pattern.match(message.args[0]) is not None: 2078 if mod is None: 2079 # Message and category match, either recorded or ignored 2080 if rec is not None: 2081 msg = WarningMessage( 2082 message, category, filename, lineno, **kwargs 2083 ) 2084 self.log.append(msg) 2085 rec.append(msg) 2086 return 2087 # Use startswith, because warnings strips the c or o from 2088 # .pyc/.pyo files. 2089 elif mod.__file__.startswith(filename): 2090 # The message and module (filename) match 2091 if rec is not None: 2092 msg = WarningMessage( 2093 message, category, filename, lineno, **kwargs 2094 ) 2095 self.log.append(msg) 2096 rec.append(msg) 2097 return 2098 2099 # There is no filter in place, so pass to the outside handler 2100 # unless we should only pass it once 2101 if self._forwarding_rule == "always": 2102 if use_warnmsg is None: 2103 self._orig_show(message, category, filename, lineno, *args, **kwargs) 2104 else: 2105 self._orig_showmsg(use_warnmsg) 2106 return 2107 2108 if self._forwarding_rule == "once": 2109 signature = (message.args, category) 2110 elif self._forwarding_rule == "module": 2111 signature = (message.args, category, filename) 2112 elif self._forwarding_rule == "location": 2113 signature = (message.args, category, filename, lineno) 2114 2115 if signature in self._forwarded: 2116 return 2117 self._forwarded.add(signature) 2118 if use_warnmsg is None: 2119 self._orig_show(message, category, filename, lineno, *args, **kwargs) 2120 else: 2121 self._orig_showmsg(use_warnmsg) 2122 2123 def __call__(self, func): 2124 """ 2125 Function decorator to apply certain suppressions to a whole 2126 function. 2127 """ 2128 2129 @wraps(func) 2130 def new_func(*args, **kwargs): 2131 with self: 2132 return func(*args, **kwargs) 2133 2134 return new_func 2135 2136 2137@contextlib.contextmanager 2138def _assert_no_gc_cycles_context(name=None): 2139 __tracebackhide__ = True # Hide traceback for py.test 2140 2141 # not meaningful to test if there is no refcounting 2142 if not HAS_REFCOUNT: 2143 yield 2144 return 2145 2146 assert_(gc.isenabled()) 2147 gc.disable() 2148 gc_debug = gc.get_debug() 2149 try: 2150 for i in range(100): 2151 if gc.collect() == 0: 2152 break 2153 else: 2154 raise RuntimeError( 2155 "Unable to fully collect garbage - perhaps a __del__ method " 2156 "is creating more reference cycles?" 2157 ) 2158 2159 gc.set_debug(gc.DEBUG_SAVEALL) 2160 yield 2161 # gc.collect returns the number of unreachable objects in cycles that 2162 # were found -- we are checking that no cycles were created in the context 2163 n_objects_in_cycles = gc.collect() 2164 objects_in_cycles = gc.garbage[:] 2165 finally: 2166 del gc.garbage[:] 2167 gc.set_debug(gc_debug) 2168 gc.enable() 2169 2170 if n_objects_in_cycles: 2171 name_str = f" when calling {name}" if name is not None else "" 2172 raise AssertionError( 2173 "Reference cycles were found{}: {} objects were collected, " 2174 "of which {} are shown below:{}".format( 2175 name_str, 2176 n_objects_in_cycles, 2177 len(objects_in_cycles), 2178 "".join( 2179 "\n {} object with id={}:\n {}".format( 2180 type(o).__name__, 2181 id(o), 2182 pprint.pformat(o).replace("\n", "\n "), 2183 ) 2184 for o in objects_in_cycles 2185 ), 2186 ) 2187 ) 2188 2189 2190def assert_no_gc_cycles(*args, **kwargs): 2191 """ 2192 Fail if the given callable produces any reference cycles. 2193 2194 If called with all arguments omitted, may be used as a context manager: 2195 2196 with assert_no_gc_cycles(): 2197 do_something() 2198 2199 .. versionadded:: 1.15.0 2200 2201 Parameters 2202 ---------- 2203 func : callable 2204 The callable to test. 2205 \\*args : Arguments 2206 Arguments passed to `func`. 2207 \\*\\*kwargs : Kwargs 2208 Keyword arguments passed to `func`. 2209 2210 Returns 2211 ------- 2212 Nothing. The result is deliberately discarded to ensure that all cycles 2213 are found. 2214 2215 """ 2216 if not args: 2217 return _assert_no_gc_cycles_context() 2218 2219 func = args[0] 2220 args = args[1:] 2221 with _assert_no_gc_cycles_context(name=func.__name__): 2222 func(*args, **kwargs) 2223 2224 2225def break_cycles(): 2226 """ 2227 Break reference cycles by calling gc.collect 2228 Objects can call other objects' methods (for instance, another object's 2229 __del__) inside their own __del__. On PyPy, the interpreter only runs 2230 between calls to gc.collect, so multiple calls are needed to completely 2231 release all cycles. 2232 """ 2233 2234 gc.collect() 2235 if IS_PYPY: 2236 # a few more, just to make sure all the finalizers are called 2237 gc.collect() 2238 gc.collect() 2239 gc.collect() 2240 gc.collect() 2241 2242 2243def requires_memory(free_bytes): 2244 """Decorator to skip a test if not enough memory is available""" 2245 import pytest 2246 2247 def decorator(func): 2248 @wraps(func) 2249 def wrapper(*a, **kw): 2250 msg = check_free_memory(free_bytes) 2251 if msg is not None: 2252 pytest.skip(msg) 2253 2254 try: 2255 return func(*a, **kw) 2256 except MemoryError: 2257 # Probably ran out of memory regardless: don't regard as failure 2258 pytest.xfail("MemoryError raised") 2259 2260 return wrapper 2261 2262 return decorator 2263 2264 2265def check_free_memory(free_bytes): 2266 """ 2267 Check whether `free_bytes` amount of memory is currently free. 2268 Returns: None if enough memory available, otherwise error message 2269 """ 2270 env_var = "NPY_AVAILABLE_MEM" 2271 env_value = os.environ.get(env_var) 2272 if env_value is not None: 2273 try: 2274 mem_free = _parse_size(env_value) 2275 except ValueError as exc: 2276 raise ValueError( # noqa: B904 2277 f"Invalid environment variable {env_var}: {exc}" 2278 ) 2279 2280 msg = ( 2281 f"{free_bytes/1e9} GB memory required, but environment variable " 2282 f"NPY_AVAILABLE_MEM={env_value} set" 2283 ) 2284 else: 2285 mem_free = _get_mem_available() 2286 2287 if mem_free is None: 2288 msg = ( 2289 "Could not determine available memory; set NPY_AVAILABLE_MEM " 2290 "environment variable (e.g. NPY_AVAILABLE_MEM=16GB) to run " 2291 "the test." 2292 ) 2293 mem_free = -1 2294 else: 2295 msg = ( 2296 f"{free_bytes/1e9} GB memory required, but {mem_free/1e9} GB available" 2297 ) 2298 2299 return msg if mem_free < free_bytes else None 2300 2301 2302def _parse_size(size_str): 2303 """Convert memory size strings ('12 GB' etc.) to float""" 2304 suffixes = { 2305 "": 1, 2306 "b": 1, 2307 "k": 1000, 2308 "m": 1000**2, 2309 "g": 1000**3, 2310 "t": 1000**4, 2311 "kb": 1000, 2312 "mb": 1000**2, 2313 "gb": 1000**3, 2314 "tb": 1000**4, 2315 "kib": 1024, 2316 "mib": 1024**2, 2317 "gib": 1024**3, 2318 "tib": 1024**4, 2319 } 2320 2321 size_re = re.compile( 2322 r"^\s*(\d+|\d+\.\d+)\s*({})\s*$".format("|".join(suffixes.keys())), 2323 re.IGNORECASE, 2324 ) 2325 2326 m = size_re.match(size_str.lower()) 2327 if not m or m.group(2) not in suffixes: 2328 raise ValueError(f"value {size_str!r} not a valid size") 2329 return int(float(m.group(1)) * suffixes[m.group(2)]) 2330 2331 2332def _get_mem_available(): 2333 """Return available memory in bytes, or None if unknown.""" 2334 try: 2335 import psutil 2336 2337 return psutil.virtual_memory().available 2338 except (ImportError, AttributeError): 2339 pass 2340 2341 if sys.platform.startswith("linux"): 2342 info = {} 2343 with open("/proc/meminfo") as f: 2344 for line in f: 2345 p = line.split() 2346 info[p[0].strip(":").lower()] = int(p[1]) * 1024 2347 2348 if "memavailable" in info: 2349 # Linux >= 3.14 2350 return info["memavailable"] 2351 else: 2352 return info["memfree"] + info["cached"] 2353 2354 return None 2355 2356 2357def _no_tracing(func): 2358 """ 2359 Decorator to temporarily turn off tracing for the duration of a test. 2360 Needed in tests that check refcounting, otherwise the tracing itself 2361 influences the refcounts 2362 """ 2363 if not hasattr(sys, "gettrace"): 2364 return func 2365 else: 2366 2367 @wraps(func) 2368 def wrapper(*args, **kwargs): 2369 original_trace = sys.gettrace() 2370 try: 2371 sys.settrace(None) 2372 return func(*args, **kwargs) 2373 finally: 2374 sys.settrace(original_trace) 2375 2376 return wrapper 2377 2378 2379def _get_glibc_version(): 2380 try: 2381 ver = os.confstr("CS_GNU_LIBC_VERSION").rsplit(" ")[1] 2382 except Exception as inst: 2383 ver = "0.0" 2384 2385 return ver 2386 2387 2388_glibcver = _get_glibc_version() 2389 2390 2391def _glibc_older_than(x): 2392 return _glibcver != "0.0" and _glibcver < x 2393