1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: tests"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport collections 4*da0073e9SAndroid Build Coastguard Workerimport doctest 5*da0073e9SAndroid Build Coastguard Workerimport functools 6*da0073e9SAndroid Build Coastguard Workerimport importlib 7*da0073e9SAndroid Build Coastguard Workerimport inspect 8*da0073e9SAndroid Build Coastguard Workerimport itertools 9*da0073e9SAndroid Build Coastguard Workerimport math 10*da0073e9SAndroid Build Coastguard Workerimport os 11*da0073e9SAndroid Build Coastguard Workerimport re 12*da0073e9SAndroid Build Coastguard Workerimport subprocess 13*da0073e9SAndroid Build Coastguard Workerimport sys 14*da0073e9SAndroid Build Coastguard Workerimport unittest.mock 15*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Callable, Iterator, List, Tuple 16*da0073e9SAndroid Build Coastguard Worker 17*da0073e9SAndroid Build Coastguard Workerimport torch 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import make_tensor 20*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import \ 21*da0073e9SAndroid Build Coastguard Worker (IS_FBCODE, IS_JETSON, IS_MACOS, IS_SANDCASTLE, IS_WINDOWS, TestCase, run_tests, slowTest, 22*da0073e9SAndroid Build Coastguard Worker parametrize, subtest, instantiate_parametrized_tests, dtype_name, TEST_WITH_ROCM, decorateIf) 23*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import \ 24*da0073e9SAndroid Build Coastguard Worker (PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes, 25*da0073e9SAndroid Build Coastguard Worker get_device_type_test_bases, instantiate_device_type_tests, onlyCPU, onlyCUDA, onlyNativeDeviceTypes, 26*da0073e9SAndroid Build Coastguard Worker deviceCountAtLeast, ops, expectedFailureMeta, OpDTypes) 27*da0073e9SAndroid Build 
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal import opinfo
from torch.testing._internal.common_dtype import all_types_and_complex_and, floating_types
from torch.testing._internal.common_modules import modules, module_db, ModuleInfo
from torch.testing._internal.opinfo.core import SampleInput, DecorateInfo, OpInfo
import operator

# For testing TestCase methods and torch.testing functions
class TestTesting(TestCase):
    # assertEqual must compare a numpy array and the tensor it came from
    # exactly, whether the array is the first argument, the second, or both.
    @dtypes(*all_types_and_complex_and(torch.bool, torch.half))
    def test_assertEqual_numpy(self, device, dtype):
        S = 10
        shapes = ((), (0,), (S,), (S, S), (0, S), (S, 0))
        for shape in shapes:
            tensor = make_tensor(shape, dtype=dtype, device=device, low=-5, high=5)
            ndarray = tensor.cpu().numpy()
            msg = f'size: {shape}'
            for lhs, rhs in ((ndarray, tensor), (tensor, ndarray), (ndarray, ndarray)):
                self.assertEqual(lhs, rhs, rtol=0, atol=0, msg=msg)

    def test_assertEqual_longMessage(self):
        actual = "actual"
        expected = "expected"

        saved_long_message = self.longMessage
        try:
            # With longMessage disabled, the failure text is only the default
            # message; capture it so it can be matched below.
            self.longMessage = False
            try:
                self.assertEqual(actual, expected)
            except AssertionError as error:
                default_msg = str(error)
            else:
                raise AssertionError("AssertionError not raised")

            # With longMessage enabled, the user-supplied msg is expected on a
            # new line directly after the default message.
            self.longMessage = True
            extra_msg = "sentinel"
            pattern = re.escape(f"{default_msg}\n{extra_msg}")
            with self.assertRaisesRegex(AssertionError, pattern):
                self.assertEqual(actual, expected, msg=extra_msg)
        finally:
            # Always restore the caller's setting, even if an assertion failed.
            self.longMessage = saved_long_message

    def _isclose_helper(self, tests, device, dtype, equal_nan, atol=1e-08, rtol=1e-05):
        """Check ``torch.isclose`` against a table of expected results.

        Each entry of ``tests`` is ``(lhs, rhs, expected)``. ``lhs`` and ``rhs``
        are wrapped in one-element tensors of ``dtype`` on ``device``, and the
        single boolean produced by ``torch.isclose`` (called with ``equal_nan``,
        ``atol`` and ``rtol``) must equal ``expected``.
        """
        def as_tensor(value):
            return torch.tensor((value,), device=device, dtype=dtype)

        for lhs, rhs, expected in tests:
            result = torch.isclose(as_tensor(lhs), as_tensor(rhs),
                                   equal_nan=equal_nan, atol=atol, rtol=rtol)
            self.assertEqual(result.item(), expected)

    def test_isclose_bool(self, device):
        # With default tolerances, equal bools are close and unequal ones are not.
        cases = (
            (True, True, True),
            (False, False, True),
            (True, False, False),
            (False, True, False),
        )
        self._isclose_helper(cases, device, torch.bool, False)

    @dtypes(torch.uint8,
            torch.int8, torch.int16, torch.int32, torch.int64)
    def test_isclose_integer(self, device, dtype):
        # Default tolerances: only identical integers are close.
        exact_cases = (
            (0, 0, True),
            (0, 1, False),
            (1, 0, False),
        )
        self._isclose_helper(exact_cases, device, dtype, False)

        # atol=.5, rtol=.5: the allowed difference scales with the second
        # argument, so (0, 1) is close while (1, 0) is not.
        tolerance_cases = [
            (0, 1, True),
            (1, 0, False),
            (1, 3, True),
        ]
        self._isclose_helper(tolerance_cases, device, dtype, False, atol=.5, rtol=.5)

        # Sign-crossing pairs are expected to be close for the signed types
        # but not for uint8, which cannot hold -1 as a negative value.
        sign_close = dtype is not torch.uint8
        sign_cases = [
            (-1, 1, sign_close),
            (1, -1, sign_close),
        ]
        self._isclose_helper(sign_cases, device, dtype, False, atol=1.5, rtol=.5)

    @onlyNativeDeviceTypes
    @dtypes(torch.float16, torch.float32, torch.float64)
    def test_isclose_float(self, device, dtype):
        # Default tolerances with equal_nan=False, including inf and nan inputs.
        special_cases = (
            (0, 0, True),
            (0, -1, False),
            (float('inf'), float('inf'), True),
            (-float('inf'), float('inf'), False),
            (float('inf'), float('nan'), False),
            (float('nan'), float('nan'), False),
            (0, float('nan'), False),
            (1, 1, True),
        )
        self._isclose_helper(special_cases, device, dtype, False)

        # atol=.5, rtol=.5: every "True" pair sits exactly on the tolerance
        # boundary, and nudging it by eps pushes it just outside. float16 uses
        # a larger eps, presumably because a 1e-6 nudge is lost to rounding.
        eps = 1e-2 if dtype is torch.half else 1e-6
        boundary_cases = (
            (0, 1, True),
            (0, 1 + eps, False),
            (1, 0, False),
            (1, 3, True),
            (1 - eps, 3, False),
            (-.25, .5, True),
            (-.25 - eps, .5, False),
            (.25, -.5, True),
            (.25 + eps, -.5, False),
        )
        self._isclose_helper(boundary_cases, device, dtype, False, atol=.5, rtol=.5)

        # equal_nan=True: nan is close to nan and to nothing else.
        nan_cases = (
            (0, float('nan'), False),
            (float('inf'), float('nan'), False),
            (float('nan'), float('nan'), True),
        )
        self._isclose_helper(nan_cases, device, dtype, True)

@unittest.skipIf(IS_SANDCASTLE, "Skipping because doesn't work on sandcastle") 171*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.complex64, torch.complex128) 172*da0073e9SAndroid Build Coastguard Worker def test_isclose_complex(self, device, dtype): 173*da0073e9SAndroid Build Coastguard Worker tests = ( 174*da0073e9SAndroid Build Coastguard Worker (complex(1, 1), complex(1, 1 + 1e-8), True), 175*da0073e9SAndroid Build Coastguard Worker (complex(0, 1), complex(1, 1), False), 176*da0073e9SAndroid Build Coastguard Worker (complex(1, 1), complex(1, 0), False), 177*da0073e9SAndroid Build Coastguard Worker (complex(1, 1), complex(1, float('nan')), False), 178*da0073e9SAndroid Build Coastguard Worker (complex(1, float('nan')), complex(1, float('nan')), False), 179*da0073e9SAndroid Build Coastguard Worker (complex(1, 1), complex(1, float('inf')), False), 180*da0073e9SAndroid Build Coastguard Worker (complex(float('inf'), 1), complex(1, float('inf')), False), 181*da0073e9SAndroid Build Coastguard Worker (complex(-float('inf'), 1), complex(1, float('inf')), False), 182*da0073e9SAndroid Build Coastguard Worker (complex(-float('inf'), 1), complex(float('inf'), 1), False), 183*da0073e9SAndroid Build Coastguard Worker (complex(float('inf'), 1), complex(float('inf'), 1), True), 184*da0073e9SAndroid Build Coastguard Worker (complex(float('inf'), 1), complex(float('inf'), 1 + 1e-4), False), 185*da0073e9SAndroid Build Coastguard Worker ) 186*da0073e9SAndroid Build Coastguard Worker 187*da0073e9SAndroid Build Coastguard Worker self._isclose_helper(tests, device, dtype, False) 188*da0073e9SAndroid Build Coastguard Worker 189*da0073e9SAndroid Build Coastguard Worker # atol and rtol tests 190*da0073e9SAndroid Build Coastguard Worker 191*da0073e9SAndroid Build Coastguard Worker # atol and rtol tests 192*da0073e9SAndroid Build Coastguard Worker eps = 1e-6 193*da0073e9SAndroid Build Coastguard Worker tests = ( 194*da0073e9SAndroid Build Coastguard Worker # Complex versions of 
float tests (real part) 195*da0073e9SAndroid Build Coastguard Worker (complex(0, 0), complex(1, 0), True), 196*da0073e9SAndroid Build Coastguard Worker (complex(0, 0), complex(1 + eps, 0), False), 197*da0073e9SAndroid Build Coastguard Worker (complex(1, 0), complex(0, 0), False), 198*da0073e9SAndroid Build Coastguard Worker (complex(1, 0), complex(3, 0), True), 199*da0073e9SAndroid Build Coastguard Worker (complex(1 - eps, 0), complex(3, 0), False), 200*da0073e9SAndroid Build Coastguard Worker (complex(-.25, 0), complex(.5, 0), True), 201*da0073e9SAndroid Build Coastguard Worker (complex(-.25 - eps, 0), complex(.5, 0), False), 202*da0073e9SAndroid Build Coastguard Worker (complex(.25, 0), complex(-.5, 0), True), 203*da0073e9SAndroid Build Coastguard Worker (complex(.25 + eps, 0), complex(-.5, 0), False), 204*da0073e9SAndroid Build Coastguard Worker # Complex versions of float tests (imaginary part) 205*da0073e9SAndroid Build Coastguard Worker (complex(0, 0), complex(0, 1), True), 206*da0073e9SAndroid Build Coastguard Worker (complex(0, 0), complex(0, 1 + eps), False), 207*da0073e9SAndroid Build Coastguard Worker (complex(0, 1), complex(0, 0), False), 208*da0073e9SAndroid Build Coastguard Worker (complex(0, 1), complex(0, 3), True), 209*da0073e9SAndroid Build Coastguard Worker (complex(0, 1 - eps), complex(0, 3), False), 210*da0073e9SAndroid Build Coastguard Worker (complex(0, -.25), complex(0, .5), True), 211*da0073e9SAndroid Build Coastguard Worker (complex(0, -.25 - eps), complex(0, .5), False), 212*da0073e9SAndroid Build Coastguard Worker (complex(0, .25), complex(0, -.5), True), 213*da0073e9SAndroid Build Coastguard Worker (complex(0, .25 + eps), complex(0, -.5), False), 214*da0073e9SAndroid Build Coastguard Worker ) 215*da0073e9SAndroid Build Coastguard Worker 216*da0073e9SAndroid Build Coastguard Worker self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5) 217*da0073e9SAndroid Build Coastguard Worker 218*da0073e9SAndroid Build Coastguard Worker 
# atol and rtol tests for isclose 219*da0073e9SAndroid Build Coastguard Worker tests = ( 220*da0073e9SAndroid Build Coastguard Worker # Complex-specific tests 221*da0073e9SAndroid Build Coastguard Worker (complex(1, -1), complex(-1, 1), False), 222*da0073e9SAndroid Build Coastguard Worker (complex(1, -1), complex(2, -2), True), 223*da0073e9SAndroid Build Coastguard Worker (complex(-math.sqrt(2), math.sqrt(2)), 224*da0073e9SAndroid Build Coastguard Worker complex(-math.sqrt(.5), math.sqrt(.5)), True), 225*da0073e9SAndroid Build Coastguard Worker (complex(-math.sqrt(2), math.sqrt(2)), 226*da0073e9SAndroid Build Coastguard Worker complex(-math.sqrt(.501), math.sqrt(.499)), False), 227*da0073e9SAndroid Build Coastguard Worker (complex(2, 4), complex(1., 8.8523607), True), 228*da0073e9SAndroid Build Coastguard Worker (complex(2, 4), complex(1., 8.8523607 + eps), False), 229*da0073e9SAndroid Build Coastguard Worker (complex(1, 99), complex(4, 100), True), 230*da0073e9SAndroid Build Coastguard Worker ) 231*da0073e9SAndroid Build Coastguard Worker self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5) 232*da0073e9SAndroid Build Coastguard Worker 233*da0073e9SAndroid Build Coastguard Worker # equal_nan = True tests 234*da0073e9SAndroid Build Coastguard Worker tests = ( 235*da0073e9SAndroid Build Coastguard Worker (complex(1, 1), complex(1, float('nan')), False), 236*da0073e9SAndroid Build Coastguard Worker (complex(1, 1), complex(float('nan'), 1), False), 237*da0073e9SAndroid Build Coastguard Worker (complex(float('nan'), 1), complex(float('nan'), 1), True), 238*da0073e9SAndroid Build Coastguard Worker (complex(float('nan'), 1), complex(1, float('nan')), True), 239*da0073e9SAndroid Build Coastguard Worker (complex(float('nan'), float('nan')), complex(float('nan'), float('nan')), True), 240*da0073e9SAndroid Build Coastguard Worker ) 241*da0073e9SAndroid Build Coastguard Worker self._isclose_helper(tests, device, dtype, True) 242*da0073e9SAndroid Build Coastguard 
Worker 243*da0073e9SAndroid Build Coastguard Worker # Tests that isclose with rtol or atol values less than zero throws a 244*da0073e9SAndroid Build Coastguard Worker # RuntimeError 245*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.bool, torch.uint8, 246*da0073e9SAndroid Build Coastguard Worker torch.int8, torch.int16, torch.int32, torch.int64, 247*da0073e9SAndroid Build Coastguard Worker torch.float16, torch.float32, torch.float64) 248*da0073e9SAndroid Build Coastguard Worker def test_isclose_atol_rtol_greater_than_zero(self, device, dtype): 249*da0073e9SAndroid Build Coastguard Worker t = torch.tensor((1,), device=device, dtype=dtype) 250*da0073e9SAndroid Build Coastguard Worker 251*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 252*da0073e9SAndroid Build Coastguard Worker torch.isclose(t, t, atol=-1, rtol=1) 253*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 254*da0073e9SAndroid Build Coastguard Worker torch.isclose(t, t, atol=1, rtol=-1) 255*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 256*da0073e9SAndroid Build Coastguard Worker torch.isclose(t, t, atol=-1, rtol=-1) 257*da0073e9SAndroid Build Coastguard Worker 258*da0073e9SAndroid Build Coastguard Worker def test_isclose_equality_shortcut(self): 259*da0073e9SAndroid Build Coastguard Worker # For values >= 2**53, integers differing by 1 can no longer differentiated by torch.float64 or lower precision 260*da0073e9SAndroid Build Coastguard Worker # floating point dtypes. Thus, even with rtol == 0 and atol == 0, these tensors would be considered close if 261*da0073e9SAndroid Build Coastguard Worker # they were not compared as integers. 
262*da0073e9SAndroid Build Coastguard Worker a = torch.tensor(2 ** 53, dtype=torch.int64) 263*da0073e9SAndroid Build Coastguard Worker b = a + 1 264*da0073e9SAndroid Build Coastguard Worker 265*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.isclose(a, b, rtol=0, atol=0)) 266*da0073e9SAndroid Build Coastguard Worker 267*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float16, torch.float32, torch.float64, torch.complex64, torch.complex128) 268*da0073e9SAndroid Build Coastguard Worker def test_isclose_nan_equality_shortcut(self, device, dtype): 269*da0073e9SAndroid Build Coastguard Worker if dtype.is_floating_point: 270*da0073e9SAndroid Build Coastguard Worker a = b = torch.nan 271*da0073e9SAndroid Build Coastguard Worker else: 272*da0073e9SAndroid Build Coastguard Worker a = complex(torch.nan, 0) 273*da0073e9SAndroid Build Coastguard Worker b = complex(0, torch.nan) 274*da0073e9SAndroid Build Coastguard Worker 275*da0073e9SAndroid Build Coastguard Worker expected = True 276*da0073e9SAndroid Build Coastguard Worker tests = [(a, b, expected)] 277*da0073e9SAndroid Build Coastguard Worker 278*da0073e9SAndroid Build Coastguard Worker self._isclose_helper(tests, device, dtype, equal_nan=True, rtol=0, atol=0) 279*da0073e9SAndroid Build Coastguard Worker 280*da0073e9SAndroid Build Coastguard Worker # The following tests (test_cuda_assert_*) are added to ensure test suite terminates early 281*da0073e9SAndroid Build Coastguard Worker # when CUDA assert was thrown. Because all subsequent test will fail if that happens. 282*da0073e9SAndroid Build Coastguard Worker # These tests are slow because it spawn another process to run test suite. 
283*da0073e9SAndroid Build Coastguard Worker # See: https://github.com/pytorch/pytorch/issues/49019 284*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts") 285*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 286*da0073e9SAndroid Build Coastguard Worker @slowTest 287*da0073e9SAndroid Build Coastguard Worker def test_cuda_assert_should_stop_common_utils_test_suite(self, device): 288*da0073e9SAndroid Build Coastguard Worker # test to ensure common_utils.py override has early termination for CUDA. 289*da0073e9SAndroid Build Coastguard Worker stderr = TestCase.runWithPytorchAPIUsageStderr("""\ 290*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3 291*da0073e9SAndroid Build Coastguard Worker 292*da0073e9SAndroid Build Coastguard Workerimport torch 293*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (TestCase, run_tests, slowTest) 294*da0073e9SAndroid Build Coastguard Worker 295*da0073e9SAndroid Build Coastguard Workerclass TestThatContainsCUDAAssertFailure(TestCase): 296*da0073e9SAndroid Build Coastguard Worker 297*da0073e9SAndroid Build Coastguard Worker @slowTest 298*da0073e9SAndroid Build Coastguard Worker def test_throw_unrecoverable_cuda_exception(self): 299*da0073e9SAndroid Build Coastguard Worker x = torch.rand(10, device='cuda') 300*da0073e9SAndroid Build Coastguard Worker # cause unrecoverable CUDA exception, recoverable on CPU 301*da0073e9SAndroid Build Coastguard Worker y = x[torch.tensor([25])].cpu() 302*da0073e9SAndroid Build Coastguard Worker 303*da0073e9SAndroid Build Coastguard Worker @slowTest 304*da0073e9SAndroid Build Coastguard Worker def test_trivial_passing_test_case_on_cpu_cuda(self): 305*da0073e9SAndroid Build Coastguard Worker x1 = torch.tensor([0., 1.], device='cuda') 306*da0073e9SAndroid Build Coastguard Worker x2 = torch.tensor([0., 1.], device='cpu') 307*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x1, x2) 
308*da0073e9SAndroid Build Coastguard Worker 309*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__': 310*da0073e9SAndroid Build Coastguard Worker run_tests() 311*da0073e9SAndroid Build Coastguard Worker""") 312*da0073e9SAndroid Build Coastguard Worker # should capture CUDA error 313*da0073e9SAndroid Build Coastguard Worker self.assertIn('CUDA error: device-side assert triggered', stderr) 314*da0073e9SAndroid Build Coastguard Worker # should run only 1 test because it throws unrecoverable error. 315*da0073e9SAndroid Build Coastguard Worker self.assertIn('errors=1', stderr) 316*da0073e9SAndroid Build Coastguard Worker 317*da0073e9SAndroid Build Coastguard Worker 318*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts") 319*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 320*da0073e9SAndroid Build Coastguard Worker @slowTest 321*da0073e9SAndroid Build Coastguard Worker def test_cuda_assert_should_stop_common_device_type_test_suite(self, device): 322*da0073e9SAndroid Build Coastguard Worker # test to ensure common_device_type.py override has early termination for CUDA. 
323*da0073e9SAndroid Build Coastguard Worker stderr = TestCase.runWithPytorchAPIUsageStderr("""\ 324*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3 325*da0073e9SAndroid Build Coastguard Worker 326*da0073e9SAndroid Build Coastguard Workerimport torch 327*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (TestCase, run_tests, slowTest) 328*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import instantiate_device_type_tests 329*da0073e9SAndroid Build Coastguard Worker 330*da0073e9SAndroid Build Coastguard Workerclass TestThatContainsCUDAAssertFailure(TestCase): 331*da0073e9SAndroid Build Coastguard Worker 332*da0073e9SAndroid Build Coastguard Worker @slowTest 333*da0073e9SAndroid Build Coastguard Worker def test_throw_unrecoverable_cuda_exception(self, device): 334*da0073e9SAndroid Build Coastguard Worker x = torch.rand(10, device=device) 335*da0073e9SAndroid Build Coastguard Worker # cause unrecoverable CUDA exception, recoverable on CPU 336*da0073e9SAndroid Build Coastguard Worker y = x[torch.tensor([25])].cpu() 337*da0073e9SAndroid Build Coastguard Worker 338*da0073e9SAndroid Build Coastguard Worker @slowTest 339*da0073e9SAndroid Build Coastguard Worker def test_trivial_passing_test_case_on_cpu_cuda(self, device): 340*da0073e9SAndroid Build Coastguard Worker x1 = torch.tensor([0., 1.], device=device) 341*da0073e9SAndroid Build Coastguard Worker x2 = torch.tensor([0., 1.], device='cpu') 342*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x1, x2) 343*da0073e9SAndroid Build Coastguard Worker 344*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests( 345*da0073e9SAndroid Build Coastguard Worker TestThatContainsCUDAAssertFailure, 346*da0073e9SAndroid Build Coastguard Worker globals(), 347*da0073e9SAndroid Build Coastguard Worker only_for='cuda' 348*da0073e9SAndroid Build Coastguard Worker) 349*da0073e9SAndroid Build Coastguard Worker 
350*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__': 351*da0073e9SAndroid Build Coastguard Worker run_tests() 352*da0073e9SAndroid Build Coastguard Worker""") 353*da0073e9SAndroid Build Coastguard Worker # should capture CUDA error 354*da0073e9SAndroid Build Coastguard Worker self.assertIn('CUDA error: device-side assert triggered', stderr) 355*da0073e9SAndroid Build Coastguard Worker # should run only 1 test because it throws unrecoverable error. 356*da0073e9SAndroid Build Coastguard Worker self.assertIn('errors=1', stderr) 357*da0073e9SAndroid Build Coastguard Worker 358*da0073e9SAndroid Build Coastguard Worker 359*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts") 360*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 361*da0073e9SAndroid Build Coastguard Worker @slowTest 362*da0073e9SAndroid Build Coastguard Worker def test_cuda_assert_should_not_stop_common_distributed_test_suite(self, device): 363*da0073e9SAndroid Build Coastguard Worker # test to ensure common_distributed.py override should not early terminate CUDA. 
364*da0073e9SAndroid Build Coastguard Worker stderr = TestCase.runWithPytorchAPIUsageStderr("""\ 365*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3 366*da0073e9SAndroid Build Coastguard Worker 367*da0073e9SAndroid Build Coastguard Workerimport torch 368*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (run_tests, slowTest) 369*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import instantiate_device_type_tests 370*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_distributed import MultiProcessTestCase 371*da0073e9SAndroid Build Coastguard Worker 372*da0073e9SAndroid Build Coastguard Workerclass TestThatContainsCUDAAssertFailure(MultiProcessTestCase): 373*da0073e9SAndroid Build Coastguard Worker 374*da0073e9SAndroid Build Coastguard Worker @slowTest 375*da0073e9SAndroid Build Coastguard Worker def test_throw_unrecoverable_cuda_exception(self, device): 376*da0073e9SAndroid Build Coastguard Worker x = torch.rand(10, device=device) 377*da0073e9SAndroid Build Coastguard Worker # cause unrecoverable CUDA exception, recoverable on CPU 378*da0073e9SAndroid Build Coastguard Worker y = x[torch.tensor([25])].cpu() 379*da0073e9SAndroid Build Coastguard Worker 380*da0073e9SAndroid Build Coastguard Worker @slowTest 381*da0073e9SAndroid Build Coastguard Worker def test_trivial_passing_test_case_on_cpu_cuda(self, device): 382*da0073e9SAndroid Build Coastguard Worker x1 = torch.tensor([0., 1.], device=device) 383*da0073e9SAndroid Build Coastguard Worker x2 = torch.tensor([0., 1.], device='cpu') 384*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x1, x2) 385*da0073e9SAndroid Build Coastguard Worker 386*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests( 387*da0073e9SAndroid Build Coastguard Worker TestThatContainsCUDAAssertFailure, 388*da0073e9SAndroid Build Coastguard Worker globals(), 389*da0073e9SAndroid Build Coastguard Worker 
only_for='cuda' 390*da0073e9SAndroid Build Coastguard Worker) 391*da0073e9SAndroid Build Coastguard Worker 392*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__': 393*da0073e9SAndroid Build Coastguard Worker run_tests() 394*da0073e9SAndroid Build Coastguard Worker""") 395*da0073e9SAndroid Build Coastguard Worker # we are currently disabling CUDA early termination for distributed tests. 396*da0073e9SAndroid Build Coastguard Worker self.assertIn('errors=2', stderr) 397*da0073e9SAndroid Build Coastguard Worker 398*da0073e9SAndroid Build Coastguard Worker @expectedFailureMeta # This is only supported for CPU and CUDA 399*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 400*da0073e9SAndroid Build Coastguard Worker def test_get_supported_dtypes(self, device): 401*da0073e9SAndroid Build Coastguard Worker # Test the `get_supported_dtypes` helper function. 402*da0073e9SAndroid Build Coastguard Worker # We acquire the dtypes for few Ops dynamically and verify them against 403*da0073e9SAndroid Build Coastguard Worker # the correct statically described values. 
404*da0073e9SAndroid Build Coastguard Worker ops_to_test = list(filter(lambda op: op.name in ['atan2', 'topk', 'xlogy'], op_db)) 405*da0073e9SAndroid Build Coastguard Worker 406*da0073e9SAndroid Build Coastguard Worker for op in ops_to_test: 407*da0073e9SAndroid Build Coastguard Worker dynamic_dtypes = opinfo.utils.get_supported_dtypes(op, op.sample_inputs_func, self.device_type) 408*da0073e9SAndroid Build Coastguard Worker dynamic_dispatch = opinfo.utils.dtypes_dispatch_hint(dynamic_dtypes) 409*da0073e9SAndroid Build Coastguard Worker if self.device_type == 'cpu': 410*da0073e9SAndroid Build Coastguard Worker dtypes = op.dtypes 411*da0073e9SAndroid Build Coastguard Worker else: # device_type ='cuda' 412*da0073e9SAndroid Build Coastguard Worker dtypes = op.dtypesIfCUDA 413*da0073e9SAndroid Build Coastguard Worker 414*da0073e9SAndroid Build Coastguard Worker self.assertTrue(set(dtypes) == set(dynamic_dtypes)) 415*da0073e9SAndroid Build Coastguard Worker self.assertTrue(set(dtypes) == set(dynamic_dispatch.dispatch_fn())) 416*da0073e9SAndroid Build Coastguard Worker 417*da0073e9SAndroid Build Coastguard Worker @onlyCPU 418*da0073e9SAndroid Build Coastguard Worker @ops( 419*da0073e9SAndroid Build Coastguard Worker [ 420*da0073e9SAndroid Build Coastguard Worker op 421*da0073e9SAndroid Build Coastguard Worker for op in op_db 422*da0073e9SAndroid Build Coastguard Worker if len( 423*da0073e9SAndroid Build Coastguard Worker op.supported_dtypes("cpu").symmetric_difference( 424*da0073e9SAndroid Build Coastguard Worker op.supported_dtypes("cuda") 425*da0073e9SAndroid Build Coastguard Worker ) 426*da0073e9SAndroid Build Coastguard Worker ) 427*da0073e9SAndroid Build Coastguard Worker > 0 428*da0073e9SAndroid Build Coastguard Worker ][:1], 429*da0073e9SAndroid Build Coastguard Worker dtypes=OpDTypes.none, 430*da0073e9SAndroid Build Coastguard Worker ) 431*da0073e9SAndroid Build Coastguard Worker def test_supported_dtypes(self, device, op): 432*da0073e9SAndroid Build Coastguard 
Worker self.assertNotEqual(op.supported_dtypes("cpu"), op.supported_dtypes("cuda")) 433*da0073e9SAndroid Build Coastguard Worker self.assertEqual(op.supported_dtypes("cuda"), op.supported_dtypes("cuda:0")) 434*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 435*da0073e9SAndroid Build Coastguard Worker op.supported_dtypes(torch.device("cuda")), 436*da0073e9SAndroid Build Coastguard Worker op.supported_dtypes(torch.device("cuda", index=1)), 437*da0073e9SAndroid Build Coastguard Worker ) 438*da0073e9SAndroid Build Coastguard Worker 439*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestTesting, globals()) 440*da0073e9SAndroid Build Coastguard Worker 441*da0073e9SAndroid Build Coastguard Worker 442*da0073e9SAndroid Build Coastguard Workerclass TestFrameworkUtils(TestCase): 443*da0073e9SAndroid Build Coastguard Worker 444*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_WINDOWS, "Skipping because doesn't work for windows") 445*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_SANDCASTLE, "Skipping because doesn't work on sandcastle") 446*da0073e9SAndroid Build Coastguard Worker def test_filtering_env_var(self): 447*da0073e9SAndroid Build Coastguard Worker # Test environment variable selected device type test generator. 
448*da0073e9SAndroid Build Coastguard Worker test_filter_file_template = """\ 449*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3 450*da0073e9SAndroid Build Coastguard Worker 451*da0073e9SAndroid Build Coastguard Workerimport torch 452*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (TestCase, run_tests) 453*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import instantiate_device_type_tests 454*da0073e9SAndroid Build Coastguard Worker 455*da0073e9SAndroid Build Coastguard Workerclass TestEnvironmentVariable(TestCase): 456*da0073e9SAndroid Build Coastguard Worker 457*da0073e9SAndroid Build Coastguard Worker def test_trivial_passing_test(self, device): 458*da0073e9SAndroid Build Coastguard Worker x1 = torch.tensor([0., 1.], device=device) 459*da0073e9SAndroid Build Coastguard Worker x2 = torch.tensor([0., 1.], device='cpu') 460*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x1, x2) 461*da0073e9SAndroid Build Coastguard Worker 462*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests( 463*da0073e9SAndroid Build Coastguard Worker TestEnvironmentVariable, 464*da0073e9SAndroid Build Coastguard Worker globals(), 465*da0073e9SAndroid Build Coastguard Worker) 466*da0073e9SAndroid Build Coastguard Worker 467*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__': 468*da0073e9SAndroid Build Coastguard Worker run_tests() 469*da0073e9SAndroid Build Coastguard Worker""" 470*da0073e9SAndroid Build Coastguard Worker test_bases_count = len(get_device_type_test_bases()) 471*da0073e9SAndroid Build Coastguard Worker # Test without setting env var should run everything. 
472*da0073e9SAndroid Build Coastguard Worker env = dict(os.environ) 473*da0073e9SAndroid Build Coastguard Worker for k in ['CI', PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY]: 474*da0073e9SAndroid Build Coastguard Worker if k in env.keys(): 475*da0073e9SAndroid Build Coastguard Worker del env[k] 476*da0073e9SAndroid Build Coastguard Worker _, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env) 477*da0073e9SAndroid Build Coastguard Worker self.assertIn(f'Ran {test_bases_count} test', stderr.decode('ascii')) 478*da0073e9SAndroid Build Coastguard Worker 479*da0073e9SAndroid Build Coastguard Worker # Test with setting only_for should only run 1 test. 480*da0073e9SAndroid Build Coastguard Worker env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY] = 'cpu' 481*da0073e9SAndroid Build Coastguard Worker _, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env) 482*da0073e9SAndroid Build Coastguard Worker self.assertIn('Ran 1 test', stderr.decode('ascii')) 483*da0073e9SAndroid Build Coastguard Worker 484*da0073e9SAndroid Build Coastguard Worker # Test with setting except_for should run 1 less device type from default. 
485*da0073e9SAndroid Build Coastguard Worker del env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY] 486*da0073e9SAndroid Build Coastguard Worker env[PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY] = 'cpu' 487*da0073e9SAndroid Build Coastguard Worker _, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env) 488*da0073e9SAndroid Build Coastguard Worker self.assertIn(f'Ran {test_bases_count-1} test', stderr.decode('ascii')) 489*da0073e9SAndroid Build Coastguard Worker 490*da0073e9SAndroid Build Coastguard Worker # Test with setting both should throw exception 491*da0073e9SAndroid Build Coastguard Worker env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY] = 'cpu' 492*da0073e9SAndroid Build Coastguard Worker _, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env) 493*da0073e9SAndroid Build Coastguard Worker self.assertNotIn('OK', stderr.decode('ascii')) 494*da0073e9SAndroid Build Coastguard Worker 495*da0073e9SAndroid Build Coastguard Worker 496*da0073e9SAndroid Build Coastguard Workerdef make_assert_close_inputs(actual: Any, expected: Any) -> List[Tuple[Any, Any]]: 497*da0073e9SAndroid Build Coastguard Worker """Makes inputs for :func:`torch.testing.assert_close` functions based on two examples. 498*da0073e9SAndroid Build Coastguard Worker 499*da0073e9SAndroid Build Coastguard Worker Args: 500*da0073e9SAndroid Build Coastguard Worker actual (Any): Actual input. 501*da0073e9SAndroid Build Coastguard Worker expected (Any): Expected input. 502*da0073e9SAndroid Build Coastguard Worker 503*da0073e9SAndroid Build Coastguard Worker Returns: 504*da0073e9SAndroid Build Coastguard Worker List[Tuple[Any, Any]]: Pair of example inputs, as well as the example inputs wrapped in sequences 505*da0073e9SAndroid Build Coastguard Worker (:class:`tuple`, :class:`list`), and mappings (:class:`dict`, :class:`~collections.OrderedDict`). 
506*da0073e9SAndroid Build Coastguard Worker """ 507*da0073e9SAndroid Build Coastguard Worker return [ 508*da0073e9SAndroid Build Coastguard Worker (actual, expected), 509*da0073e9SAndroid Build Coastguard Worker # tuple vs. tuple 510*da0073e9SAndroid Build Coastguard Worker ((actual,), (expected,)), 511*da0073e9SAndroid Build Coastguard Worker # list vs. list 512*da0073e9SAndroid Build Coastguard Worker ([actual], [expected]), 513*da0073e9SAndroid Build Coastguard Worker # tuple vs. list 514*da0073e9SAndroid Build Coastguard Worker ((actual,), [expected]), 515*da0073e9SAndroid Build Coastguard Worker # dict vs. dict 516*da0073e9SAndroid Build Coastguard Worker ({"t": actual}, {"t": expected}), 517*da0073e9SAndroid Build Coastguard Worker # OrderedDict vs. OrderedDict 518*da0073e9SAndroid Build Coastguard Worker (collections.OrderedDict([("t", actual)]), collections.OrderedDict([("t", expected)])), 519*da0073e9SAndroid Build Coastguard Worker # dict vs. OrderedDict 520*da0073e9SAndroid Build Coastguard Worker ({"t": actual}, collections.OrderedDict([("t", expected)])), 521*da0073e9SAndroid Build Coastguard Worker # list of tuples vs. tuple of lists 522*da0073e9SAndroid Build Coastguard Worker ([(actual,)], ([expected],)), 523*da0073e9SAndroid Build Coastguard Worker # list of dicts vs. tuple of OrderedDicts 524*da0073e9SAndroid Build Coastguard Worker ([{"t": actual}], (collections.OrderedDict([("t", expected)]),)), 525*da0073e9SAndroid Build Coastguard Worker # dict of lists vs. 
OrderedDict of tuples 526*da0073e9SAndroid Build Coastguard Worker ({"t": [actual]}, collections.OrderedDict([("t", (expected,))])), 527*da0073e9SAndroid Build Coastguard Worker ] 528*da0073e9SAndroid Build Coastguard Worker 529*da0073e9SAndroid Build Coastguard Worker 530*da0073e9SAndroid Build Coastguard Workerdef assert_close_with_inputs(actual: Any, expected: Any) -> Iterator[Callable]: 531*da0073e9SAndroid Build Coastguard Worker """Yields :func:`torch.testing.assert_close` with predefined positional inputs based on two examples. 532*da0073e9SAndroid Build Coastguard Worker 533*da0073e9SAndroid Build Coastguard Worker .. note:: 534*da0073e9SAndroid Build Coastguard Worker 535*da0073e9SAndroid Build Coastguard Worker Every test that does not test for a specific input should iterate over this to maximize the coverage. 536*da0073e9SAndroid Build Coastguard Worker 537*da0073e9SAndroid Build Coastguard Worker Args: 538*da0073e9SAndroid Build Coastguard Worker actual (Any): Actual input. 539*da0073e9SAndroid Build Coastguard Worker expected (Any): Expected input. 540*da0073e9SAndroid Build Coastguard Worker 541*da0073e9SAndroid Build Coastguard Worker Yields: 542*da0073e9SAndroid Build Coastguard Worker Callable: :func:`torch.testing.assert_close` with predefined positional inputs. 
543*da0073e9SAndroid Build Coastguard Worker """ 544*da0073e9SAndroid Build Coastguard Worker for inputs in make_assert_close_inputs(actual, expected): 545*da0073e9SAndroid Build Coastguard Worker yield functools.partial(torch.testing.assert_close, *inputs) 546*da0073e9SAndroid Build Coastguard Worker 547*da0073e9SAndroid Build Coastguard Worker 548*da0073e9SAndroid Build Coastguard Workerclass TestAssertClose(TestCase): 549*da0073e9SAndroid Build Coastguard Worker def test_mismatching_types_subclasses(self): 550*da0073e9SAndroid Build Coastguard Worker actual = torch.rand(()) 551*da0073e9SAndroid Build Coastguard Worker expected = torch.nn.Parameter(actual) 552*da0073e9SAndroid Build Coastguard Worker 553*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 554*da0073e9SAndroid Build Coastguard Worker fn() 555*da0073e9SAndroid Build Coastguard Worker 556*da0073e9SAndroid Build Coastguard Worker def test_mismatching_types_type_equality(self): 557*da0073e9SAndroid Build Coastguard Worker actual = torch.empty(()) 558*da0073e9SAndroid Build Coastguard Worker expected = torch.nn.Parameter(actual) 559*da0073e9SAndroid Build Coastguard Worker 560*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 561*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, str(type(expected))): 562*da0073e9SAndroid Build Coastguard Worker fn(allow_subclasses=False) 563*da0073e9SAndroid Build Coastguard Worker 564*da0073e9SAndroid Build Coastguard Worker def test_mismatching_types(self): 565*da0073e9SAndroid Build Coastguard Worker actual = torch.empty(2) 566*da0073e9SAndroid Build Coastguard Worker expected = actual.numpy() 567*da0073e9SAndroid Build Coastguard Worker 568*da0073e9SAndroid Build Coastguard Worker for fn, allow_subclasses in itertools.product(assert_close_with_inputs(actual, expected), (True, False)): 569*da0073e9SAndroid Build Coastguard Worker with 
self.assertRaisesRegex(TypeError, str(type(expected))): 570*da0073e9SAndroid Build Coastguard Worker fn(allow_subclasses=allow_subclasses) 571*da0073e9SAndroid Build Coastguard Worker 572*da0073e9SAndroid Build Coastguard Worker def test_unknown_type(self): 573*da0073e9SAndroid Build Coastguard Worker actual = "0" 574*da0073e9SAndroid Build Coastguard Worker expected = "0" 575*da0073e9SAndroid Build Coastguard Worker 576*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 577*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, str(type(actual))): 578*da0073e9SAndroid Build Coastguard Worker fn() 579*da0073e9SAndroid Build Coastguard Worker 580*da0073e9SAndroid Build Coastguard Worker def test_mismatching_shape(self): 581*da0073e9SAndroid Build Coastguard Worker actual = torch.empty(()) 582*da0073e9SAndroid Build Coastguard Worker expected = actual.clone().reshape((1,)) 583*da0073e9SAndroid Build Coastguard Worker 584*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 585*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, "shape"): 586*da0073e9SAndroid Build Coastguard Worker fn() 587*da0073e9SAndroid Build Coastguard Worker 588*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not torch.backends.mkldnn.is_available(), reason="MKLDNN is not available.") 589*da0073e9SAndroid Build Coastguard Worker def test_unknown_layout(self): 590*da0073e9SAndroid Build Coastguard Worker actual = torch.empty((2, 2)) 591*da0073e9SAndroid Build Coastguard Worker expected = actual.to_mkldnn() 592*da0073e9SAndroid Build Coastguard Worker 593*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 594*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "layout"): 595*da0073e9SAndroid Build Coastguard Worker fn() 596*da0073e9SAndroid Build Coastguard Worker 
597*da0073e9SAndroid Build Coastguard Worker def test_meta(self): 598*da0073e9SAndroid Build Coastguard Worker actual = torch.empty((2, 2), device="meta") 599*da0073e9SAndroid Build Coastguard Worker expected = torch.empty((2, 2), device="meta") 600*da0073e9SAndroid Build Coastguard Worker 601*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 602*da0073e9SAndroid Build Coastguard Worker fn() 603*da0073e9SAndroid Build Coastguard Worker 604*da0073e9SAndroid Build Coastguard Worker def test_mismatching_layout(self): 605*da0073e9SAndroid Build Coastguard Worker strided = torch.empty((2, 2)) 606*da0073e9SAndroid Build Coastguard Worker sparse_coo = strided.to_sparse() 607*da0073e9SAndroid Build Coastguard Worker sparse_csr = strided.to_sparse_csr() 608*da0073e9SAndroid Build Coastguard Worker 609*da0073e9SAndroid Build Coastguard Worker for actual, expected in itertools.combinations((strided, sparse_coo, sparse_csr), 2): 610*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 611*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, "layout"): 612*da0073e9SAndroid Build Coastguard Worker fn() 613*da0073e9SAndroid Build Coastguard Worker 614*da0073e9SAndroid Build Coastguard Worker def test_mismatching_layout_no_check(self): 615*da0073e9SAndroid Build Coastguard Worker strided = torch.randn((2, 2)) 616*da0073e9SAndroid Build Coastguard Worker sparse_coo = strided.to_sparse() 617*da0073e9SAndroid Build Coastguard Worker sparse_csr = strided.to_sparse_csr() 618*da0073e9SAndroid Build Coastguard Worker 619*da0073e9SAndroid Build Coastguard Worker for actual, expected in itertools.combinations((strided, sparse_coo, sparse_csr), 2): 620*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 621*da0073e9SAndroid Build Coastguard Worker fn(check_layout=False) 622*da0073e9SAndroid Build Coastguard Worker 
623*da0073e9SAndroid Build Coastguard Worker def test_mismatching_dtype(self): 624*da0073e9SAndroid Build Coastguard Worker actual = torch.empty((), dtype=torch.float) 625*da0073e9SAndroid Build Coastguard Worker expected = actual.clone().to(torch.int) 626*da0073e9SAndroid Build Coastguard Worker 627*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 628*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, "dtype"): 629*da0073e9SAndroid Build Coastguard Worker fn() 630*da0073e9SAndroid Build Coastguard Worker 631*da0073e9SAndroid Build Coastguard Worker def test_mismatching_dtype_no_check(self): 632*da0073e9SAndroid Build Coastguard Worker actual = torch.ones((), dtype=torch.float) 633*da0073e9SAndroid Build Coastguard Worker expected = actual.clone().to(torch.int) 634*da0073e9SAndroid Build Coastguard Worker 635*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 636*da0073e9SAndroid Build Coastguard Worker fn(check_dtype=False) 637*da0073e9SAndroid Build Coastguard Worker 638*da0073e9SAndroid Build Coastguard Worker def test_mismatching_stride(self): 639*da0073e9SAndroid Build Coastguard Worker actual = torch.empty((2, 2)) 640*da0073e9SAndroid Build Coastguard Worker expected = torch.as_strided(actual.clone().t().contiguous(), actual.shape, actual.stride()[::-1]) 641*da0073e9SAndroid Build Coastguard Worker 642*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 643*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, "stride"): 644*da0073e9SAndroid Build Coastguard Worker fn(check_stride=True) 645*da0073e9SAndroid Build Coastguard Worker 646*da0073e9SAndroid Build Coastguard Worker def test_mismatching_stride_no_check(self): 647*da0073e9SAndroid Build Coastguard Worker actual = torch.rand((2, 2)) 648*da0073e9SAndroid Build Coastguard Worker expected = 
torch.as_strided(actual.clone().t().contiguous(), actual.shape, actual.stride()[::-1]) 649*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 650*da0073e9SAndroid Build Coastguard Worker fn() 651*da0073e9SAndroid Build Coastguard Worker 652*da0073e9SAndroid Build Coastguard Worker def test_only_rtol(self): 653*da0073e9SAndroid Build Coastguard Worker actual = torch.empty(()) 654*da0073e9SAndroid Build Coastguard Worker expected = actual.clone() 655*da0073e9SAndroid Build Coastguard Worker 656*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 657*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 658*da0073e9SAndroid Build Coastguard Worker fn(rtol=0.0) 659*da0073e9SAndroid Build Coastguard Worker 660*da0073e9SAndroid Build Coastguard Worker def test_only_atol(self): 661*da0073e9SAndroid Build Coastguard Worker actual = torch.empty(()) 662*da0073e9SAndroid Build Coastguard Worker expected = actual.clone() 663*da0073e9SAndroid Build Coastguard Worker 664*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 665*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 666*da0073e9SAndroid Build Coastguard Worker fn(atol=0.0) 667*da0073e9SAndroid Build Coastguard Worker 668*da0073e9SAndroid Build Coastguard Worker def test_mismatching_values(self): 669*da0073e9SAndroid Build Coastguard Worker actual = torch.tensor(1) 670*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor(2) 671*da0073e9SAndroid Build Coastguard Worker 672*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 673*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AssertionError): 674*da0073e9SAndroid Build Coastguard Worker fn() 675*da0073e9SAndroid Build Coastguard Worker 676*da0073e9SAndroid Build Coastguard Worker def test_mismatching_values_rtol(self): 
677*da0073e9SAndroid Build Coastguard Worker eps = 1e-3 678*da0073e9SAndroid Build Coastguard Worker actual = torch.tensor(1.0) 679*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor(1.0 + eps) 680*da0073e9SAndroid Build Coastguard Worker 681*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 682*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AssertionError): 683*da0073e9SAndroid Build Coastguard Worker fn(rtol=eps / 2, atol=0.0) 684*da0073e9SAndroid Build Coastguard Worker 685*da0073e9SAndroid Build Coastguard Worker def test_mismatching_values_atol(self): 686*da0073e9SAndroid Build Coastguard Worker eps = 1e-3 687*da0073e9SAndroid Build Coastguard Worker actual = torch.tensor(0.0) 688*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor(eps) 689*da0073e9SAndroid Build Coastguard Worker 690*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 691*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AssertionError): 692*da0073e9SAndroid Build Coastguard Worker fn(rtol=0.0, atol=eps / 2) 693*da0073e9SAndroid Build Coastguard Worker 694*da0073e9SAndroid Build Coastguard Worker def test_matching(self): 695*da0073e9SAndroid Build Coastguard Worker actual = torch.tensor(1.0) 696*da0073e9SAndroid Build Coastguard Worker expected = actual.clone() 697*da0073e9SAndroid Build Coastguard Worker 698*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(actual, expected) 699*da0073e9SAndroid Build Coastguard Worker 700*da0073e9SAndroid Build Coastguard Worker def test_matching_rtol(self): 701*da0073e9SAndroid Build Coastguard Worker eps = 1e-3 702*da0073e9SAndroid Build Coastguard Worker actual = torch.tensor(1.0) 703*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor(1.0 + eps) 704*da0073e9SAndroid Build Coastguard Worker 705*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, 
expected): 706*da0073e9SAndroid Build Coastguard Worker fn(rtol=eps * 2, atol=0.0) 707*da0073e9SAndroid Build Coastguard Worker 708*da0073e9SAndroid Build Coastguard Worker def test_matching_atol(self): 709*da0073e9SAndroid Build Coastguard Worker eps = 1e-3 710*da0073e9SAndroid Build Coastguard Worker actual = torch.tensor(0.0) 711*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor(eps) 712*da0073e9SAndroid Build Coastguard Worker 713*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 714*da0073e9SAndroid Build Coastguard Worker fn(rtol=0.0, atol=eps * 2) 715*da0073e9SAndroid Build Coastguard Worker 716*da0073e9SAndroid Build Coastguard Worker # TODO: the code that this test was designed for was removed in https://github.com/pytorch/pytorch/pull/56058 717*da0073e9SAndroid Build Coastguard Worker # We need to check if this test is still needed or if this behavior is now enabled by default. 718*da0073e9SAndroid Build Coastguard Worker def test_matching_conjugate_bit(self): 719*da0073e9SAndroid Build Coastguard Worker actual = torch.tensor(complex(1, 1)).conj() 720*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor(complex(1, -1)) 721*da0073e9SAndroid Build Coastguard Worker 722*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 723*da0073e9SAndroid Build Coastguard Worker fn() 724*da0073e9SAndroid Build Coastguard Worker 725*da0073e9SAndroid Build Coastguard Worker def test_matching_nan(self): 726*da0073e9SAndroid Build Coastguard Worker nan = float("NaN") 727*da0073e9SAndroid Build Coastguard Worker 728*da0073e9SAndroid Build Coastguard Worker tests = ( 729*da0073e9SAndroid Build Coastguard Worker (nan, nan), 730*da0073e9SAndroid Build Coastguard Worker (complex(nan, 0), complex(0, nan)), 731*da0073e9SAndroid Build Coastguard Worker (complex(nan, nan), complex(nan, 0)), 732*da0073e9SAndroid Build Coastguard Worker (complex(nan, nan), complex(nan, 
nan)), 733*da0073e9SAndroid Build Coastguard Worker ) 734*da0073e9SAndroid Build Coastguard Worker 735*da0073e9SAndroid Build Coastguard Worker for actual, expected in tests: 736*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 737*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AssertionError): 738*da0073e9SAndroid Build Coastguard Worker fn() 739*da0073e9SAndroid Build Coastguard Worker 740*da0073e9SAndroid Build Coastguard Worker def test_matching_nan_with_equal_nan(self): 741*da0073e9SAndroid Build Coastguard Worker nan = float("NaN") 742*da0073e9SAndroid Build Coastguard Worker 743*da0073e9SAndroid Build Coastguard Worker tests = ( 744*da0073e9SAndroid Build Coastguard Worker (nan, nan), 745*da0073e9SAndroid Build Coastguard Worker (complex(nan, 0), complex(0, nan)), 746*da0073e9SAndroid Build Coastguard Worker (complex(nan, nan), complex(nan, 0)), 747*da0073e9SAndroid Build Coastguard Worker (complex(nan, nan), complex(nan, nan)), 748*da0073e9SAndroid Build Coastguard Worker ) 749*da0073e9SAndroid Build Coastguard Worker 750*da0073e9SAndroid Build Coastguard Worker for actual, expected in tests: 751*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 752*da0073e9SAndroid Build Coastguard Worker fn(equal_nan=True) 753*da0073e9SAndroid Build Coastguard Worker 754*da0073e9SAndroid Build Coastguard Worker def test_numpy(self): 755*da0073e9SAndroid Build Coastguard Worker tensor = torch.rand(2, 2, dtype=torch.float32) 756*da0073e9SAndroid Build Coastguard Worker actual = tensor.numpy() 757*da0073e9SAndroid Build Coastguard Worker expected = actual.copy() 758*da0073e9SAndroid Build Coastguard Worker 759*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 760*da0073e9SAndroid Build Coastguard Worker fn() 761*da0073e9SAndroid Build Coastguard Worker 762*da0073e9SAndroid Build Coastguard Worker def test_scalar(self): 
763*da0073e9SAndroid Build Coastguard Worker number = torch.randint(10, size=()).item() 764*da0073e9SAndroid Build Coastguard Worker for actual, expected in itertools.product((int(number), float(number), complex(number)), repeat=2): 765*da0073e9SAndroid Build Coastguard Worker check_dtype = type(actual) is type(expected) 766*da0073e9SAndroid Build Coastguard Worker 767*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 768*da0073e9SAndroid Build Coastguard Worker fn(check_dtype=check_dtype) 769*da0073e9SAndroid Build Coastguard Worker 770*da0073e9SAndroid Build Coastguard Worker def test_bool(self): 771*da0073e9SAndroid Build Coastguard Worker actual = torch.tensor([True, False]) 772*da0073e9SAndroid Build Coastguard Worker expected = actual.clone() 773*da0073e9SAndroid Build Coastguard Worker 774*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 775*da0073e9SAndroid Build Coastguard Worker fn() 776*da0073e9SAndroid Build Coastguard Worker 777*da0073e9SAndroid Build Coastguard Worker def test_none(self): 778*da0073e9SAndroid Build Coastguard Worker actual = expected = None 779*da0073e9SAndroid Build Coastguard Worker 780*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 781*da0073e9SAndroid Build Coastguard Worker fn() 782*da0073e9SAndroid Build Coastguard Worker 783*da0073e9SAndroid Build Coastguard Worker def test_none_mismatch(self): 784*da0073e9SAndroid Build Coastguard Worker expected = None 785*da0073e9SAndroid Build Coastguard Worker 786*da0073e9SAndroid Build Coastguard Worker for actual in (False, 0, torch.nan, torch.tensor(torch.nan)): 787*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 788*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AssertionError): 789*da0073e9SAndroid Build Coastguard Worker fn() 790*da0073e9SAndroid Build Coastguard Worker 
# TestAssertClose (continued): doctests, default tolerances and unexpected-error handling.
    def test_docstring_examples(self):
        # Runs the doctest examples found in torch.testing.assert_close's docstring and fails with
        # every collected report if any example fails.
        finder = doctest.DocTestFinder(verbose=False)
        runner = doctest.DocTestRunner(verbose=False, optionflags=doctest.NORMALIZE_WHITESPACE)
        globs = dict(torch=torch)
        # NOTE(review): only the first DocTest returned by the finder is run — presumably the
        # function's own docstring; confirm if nested objects ever carry examples.
        doctests = finder.find(torch.testing.assert_close, globs=globs)[0]
        failures = []
        # `out` receives each failure report as a string instead of printing it.
        runner.run(doctests, out=lambda report: failures.append(report))
        if failures:
            raise AssertionError(f"Doctest found {len(failures)} failures:\n\n" + "\n".join(failures))

    def test_default_tolerance_selection_mismatching_dtypes(self):
        # If the default tolerances were selected based on the promoted dtype, i.e. float64,
        # these tensors wouldn't be considered close.
        actual = torch.tensor(0.99, dtype=torch.bfloat16)
        expected = torch.tensor(1.0, dtype=torch.float64)

        for fn in assert_close_with_inputs(actual, expected):
            fn(check_dtype=False)

    class UnexpectedException(Exception):
        """The only purpose of this exception is to test ``assert_close``'s handling of unexpected exceptions. Thus,
        the test should mock a component to raise this instead of the regular behavior. We avoid using a builtin
        exception here to avoid triggering possible handling of them.
        """

    # The patch below makes TensorLikePair construction raise UnexpectedException; assert_close is
    # expected to surface it as a RuntimeError mentioning "unexpected exception". The decorator refers
    # to UnexpectedException by its class-body name, so it must stay below the class definition.
    # mock.patch passes the created mock as the extra positional argument `_`.
    @unittest.mock.patch("torch.testing._comparison.TensorLikePair.__init__", side_effect=UnexpectedException)
    def test_unexpected_error_originate(self, _):
        actual = torch.tensor(1.0)
        expected = actual.clone()

        with self.assertRaisesRegex(RuntimeError, "unexpected exception"):
            torch.testing.assert_close(actual, expected)

    # Same as above, but the failure is injected into TensorLikePair.compare instead of its constructor.
    @unittest.mock.patch("torch.testing._comparison.TensorLikePair.compare", side_effect=UnexpectedException)
    def test_unexpected_error_compare(self, _):
        actual = torch.tensor(1.0)
        expected = actual.clone()

        with self.assertRaisesRegex(RuntimeError, "unexpected exception"):
            torch.testing.assert_close(actual, expected)


# Cross-device comparisons; instantiated only for CUDA below.
class TestAssertCloseMultiDevice(TestCase):
    # `devices` is the device list supplied by @deviceCountAtLeast(1); every ordered pair of
    # "cpu" and those devices must be reported as a device mismatch.
    @deviceCountAtLeast(1)
    def test_mismatching_device(self, devices):
        for actual_device, expected_device in itertools.permutations(("cpu", *devices), 2):
            actual = torch.empty((), device=actual_device)
            expected = actual.clone().to(expected_device)
            for fn in assert_close_with_inputs(actual, expected):
                with self.assertRaisesRegex(AssertionError, "device"):
                    fn()

    # Same device pairs, but equal values pass once the device check is disabled.
    @deviceCountAtLeast(1)
    def test_mismatching_device_no_check(self, devices):
        for actual_device, expected_device in itertools.permutations(("cpu", *devices), 2):
            actual = torch.rand((), device=actual_device)
            expected = actual.clone().to(expected_device)
            for fn in assert_close_with_inputs(actual, expected):
                fn(check_device=False)


instantiate_device_type_tests(TestAssertCloseMultiDevice, globals(), only_for="cuda")


class TestAssertCloseErrorMessage(TestCase):
    def test_identifier_tensor_likes(self):
        actual = torch.tensor([1, 2, 3, 4])
        expected = torch.tensor([1, 
2, 5, 6]) 862*da0073e9SAndroid Build Coastguard Worker 863*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 864*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape("Tensor-likes")): 865*da0073e9SAndroid Build Coastguard Worker fn() 866*da0073e9SAndroid Build Coastguard Worker 867*da0073e9SAndroid Build Coastguard Worker def test_identifier_scalars(self): 868*da0073e9SAndroid Build Coastguard Worker actual = 3 869*da0073e9SAndroid Build Coastguard Worker expected = 5 870*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 871*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape("Scalars")): 872*da0073e9SAndroid Build Coastguard Worker fn() 873*da0073e9SAndroid Build Coastguard Worker 874*da0073e9SAndroid Build Coastguard Worker def test_not_equal(self): 875*da0073e9SAndroid Build Coastguard Worker actual = torch.tensor([1, 2, 3, 4], dtype=torch.float32) 876*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor([1, 2, 5, 6], dtype=torch.float32) 877*da0073e9SAndroid Build Coastguard Worker 878*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 879*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape("not equal")): 880*da0073e9SAndroid Build Coastguard Worker fn(rtol=0.0, atol=0.0) 881*da0073e9SAndroid Build Coastguard Worker 882*da0073e9SAndroid Build Coastguard Worker def test_not_close(self): 883*da0073e9SAndroid Build Coastguard Worker actual = torch.tensor([1, 2, 3, 4], dtype=torch.float32) 884*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor([1, 2, 5, 6], dtype=torch.float32) 885*da0073e9SAndroid Build Coastguard Worker 886*da0073e9SAndroid Build Coastguard Worker for fn, (rtol, atol) in itertools.product( 887*da0073e9SAndroid Build Coastguard Worker 
assert_close_with_inputs(actual, expected), ((1.3e-6, 0.0), (0.0, 1e-5), (1.3e-6, 1e-5)) 888*da0073e9SAndroid Build Coastguard Worker ): 889*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape("not close")): 890*da0073e9SAndroid Build Coastguard Worker fn(rtol=rtol, atol=atol) 891*da0073e9SAndroid Build Coastguard Worker 892*da0073e9SAndroid Build Coastguard Worker def test_mismatched_elements(self): 893*da0073e9SAndroid Build Coastguard Worker actual = torch.tensor([1, 2, 3, 4]) 894*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor([1, 2, 5, 6]) 895*da0073e9SAndroid Build Coastguard Worker 896*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 897*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape("Mismatched elements: 2 / 4 (50.0%)")): 898*da0073e9SAndroid Build Coastguard Worker fn() 899*da0073e9SAndroid Build Coastguard Worker 900*da0073e9SAndroid Build Coastguard Worker def test_abs_diff(self): 901*da0073e9SAndroid Build Coastguard Worker actual = torch.tensor([[1, 2], [3, 4]]) 902*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor([[1, 2], [5, 4]]) 903*da0073e9SAndroid Build Coastguard Worker 904*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 905*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape("Greatest absolute difference: 2 at index (1, 0)")): 906*da0073e9SAndroid Build Coastguard Worker fn() 907*da0073e9SAndroid Build Coastguard Worker 908*da0073e9SAndroid Build Coastguard Worker def test_abs_diff_scalar(self): 909*da0073e9SAndroid Build Coastguard Worker actual = 3 910*da0073e9SAndroid Build Coastguard Worker expected = 5 911*da0073e9SAndroid Build Coastguard Worker 912*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 913*da0073e9SAndroid Build Coastguard 
Worker with self.assertRaisesRegex(AssertionError, re.escape("Absolute difference: 2")): 914*da0073e9SAndroid Build Coastguard Worker fn() 915*da0073e9SAndroid Build Coastguard Worker 916*da0073e9SAndroid Build Coastguard Worker def test_rel_diff(self): 917*da0073e9SAndroid Build Coastguard Worker actual = torch.tensor([[1, 2], [3, 4]]) 918*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor([[1, 4], [3, 4]]) 919*da0073e9SAndroid Build Coastguard Worker 920*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 921*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape("Greatest relative difference: 0.5 at index (0, 1)")): 922*da0073e9SAndroid Build Coastguard Worker fn() 923*da0073e9SAndroid Build Coastguard Worker 924*da0073e9SAndroid Build Coastguard Worker def test_rel_diff_scalar(self): 925*da0073e9SAndroid Build Coastguard Worker actual = 2 926*da0073e9SAndroid Build Coastguard Worker expected = 4 927*da0073e9SAndroid Build Coastguard Worker 928*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 929*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape("Relative difference: 0.5")): 930*da0073e9SAndroid Build Coastguard Worker fn() 931*da0073e9SAndroid Build Coastguard Worker 932*da0073e9SAndroid Build Coastguard Worker def test_zero_div_zero(self): 933*da0073e9SAndroid Build Coastguard Worker actual = torch.tensor([1.0, 0.0]) 934*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor([2.0, 0.0]) 935*da0073e9SAndroid Build Coastguard Worker 936*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 937*da0073e9SAndroid Build Coastguard Worker # Although it looks complicated, this regex just makes sure that the word 'nan' is not part of the error 938*da0073e9SAndroid Build Coastguard Worker # message. 
That would happen if the 0 / 0 is used for the mismatch computation although it matches. 939*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, "((?!nan).)*"): 940*da0073e9SAndroid Build Coastguard Worker fn() 941*da0073e9SAndroid Build Coastguard Worker 942*da0073e9SAndroid Build Coastguard Worker def test_rtol(self): 943*da0073e9SAndroid Build Coastguard Worker rtol = 1e-3 944*da0073e9SAndroid Build Coastguard Worker 945*da0073e9SAndroid Build Coastguard Worker actual = torch.tensor((1, 2)) 946*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor((2, 2)) 947*da0073e9SAndroid Build Coastguard Worker 948*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 949*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape(f"(up to {rtol} allowed)")): 950*da0073e9SAndroid Build Coastguard Worker fn(rtol=rtol, atol=0.0) 951*da0073e9SAndroid Build Coastguard Worker 952*da0073e9SAndroid Build Coastguard Worker def test_atol(self): 953*da0073e9SAndroid Build Coastguard Worker atol = 1e-3 954*da0073e9SAndroid Build Coastguard Worker 955*da0073e9SAndroid Build Coastguard Worker actual = torch.tensor((1, 2)) 956*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor((2, 2)) 957*da0073e9SAndroid Build Coastguard Worker 958*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 959*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape(f"(up to {atol} allowed)")): 960*da0073e9SAndroid Build Coastguard Worker fn(rtol=0.0, atol=atol) 961*da0073e9SAndroid Build Coastguard Worker 962*da0073e9SAndroid Build Coastguard Worker def test_msg_str(self): 963*da0073e9SAndroid Build Coastguard Worker msg = "Custom error message!" 
964*da0073e9SAndroid Build Coastguard Worker 965*da0073e9SAndroid Build Coastguard Worker actual = torch.tensor(1) 966*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor(2) 967*da0073e9SAndroid Build Coastguard Worker 968*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 969*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, msg): 970*da0073e9SAndroid Build Coastguard Worker fn(msg=msg) 971*da0073e9SAndroid Build Coastguard Worker 972*da0073e9SAndroid Build Coastguard Worker def test_msg_callable(self): 973*da0073e9SAndroid Build Coastguard Worker msg = "Custom error message" 974*da0073e9SAndroid Build Coastguard Worker 975*da0073e9SAndroid Build Coastguard Worker actual = torch.tensor(1) 976*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor(2) 977*da0073e9SAndroid Build Coastguard Worker 978*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 979*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, msg): 980*da0073e9SAndroid Build Coastguard Worker fn(msg=lambda _: msg) 981*da0073e9SAndroid Build Coastguard Worker 982*da0073e9SAndroid Build Coastguard Worker 983*da0073e9SAndroid Build Coastguard Workerclass TestAssertCloseContainer(TestCase): 984*da0073e9SAndroid Build Coastguard Worker def test_sequence_mismatching_len(self): 985*da0073e9SAndroid Build Coastguard Worker actual = (torch.empty(()),) 986*da0073e9SAndroid Build Coastguard Worker expected = () 987*da0073e9SAndroid Build Coastguard Worker 988*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AssertionError): 989*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(actual, expected) 990*da0073e9SAndroid Build Coastguard Worker 991*da0073e9SAndroid Build Coastguard Worker def test_sequence_mismatching_values_msg(self): 992*da0073e9SAndroid Build Coastguard Worker t1 = torch.tensor(1) 
993*da0073e9SAndroid Build Coastguard Worker t2 = torch.tensor(2) 994*da0073e9SAndroid Build Coastguard Worker 995*da0073e9SAndroid Build Coastguard Worker actual = (t1, t1) 996*da0073e9SAndroid Build Coastguard Worker expected = (t1, t2) 997*da0073e9SAndroid Build Coastguard Worker 998*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape("item [1]")): 999*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(actual, expected) 1000*da0073e9SAndroid Build Coastguard Worker 1001*da0073e9SAndroid Build Coastguard Worker def test_mapping_mismatching_keys(self): 1002*da0073e9SAndroid Build Coastguard Worker actual = {"a": torch.empty(())} 1003*da0073e9SAndroid Build Coastguard Worker expected = {} 1004*da0073e9SAndroid Build Coastguard Worker 1005*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AssertionError): 1006*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(actual, expected) 1007*da0073e9SAndroid Build Coastguard Worker 1008*da0073e9SAndroid Build Coastguard Worker def test_mapping_mismatching_values_msg(self): 1009*da0073e9SAndroid Build Coastguard Worker t1 = torch.tensor(1) 1010*da0073e9SAndroid Build Coastguard Worker t2 = torch.tensor(2) 1011*da0073e9SAndroid Build Coastguard Worker 1012*da0073e9SAndroid Build Coastguard Worker actual = {"a": t1, "b": t1} 1013*da0073e9SAndroid Build Coastguard Worker expected = {"a": t1, "b": t2} 1014*da0073e9SAndroid Build Coastguard Worker 1015*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape("item ['b']")): 1016*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(actual, expected) 1017*da0073e9SAndroid Build Coastguard Worker 1018*da0073e9SAndroid Build Coastguard Worker 1019*da0073e9SAndroid Build Coastguard Workerclass TestAssertCloseSparseCOO(TestCase): 1020*da0073e9SAndroid Build Coastguard Worker def test_matching_coalesced(self): 1021*da0073e9SAndroid Build 
Coastguard Worker indices = ( 1022*da0073e9SAndroid Build Coastguard Worker (0, 1), 1023*da0073e9SAndroid Build Coastguard Worker (1, 0), 1024*da0073e9SAndroid Build Coastguard Worker ) 1025*da0073e9SAndroid Build Coastguard Worker values = (1, 2) 1026*da0073e9SAndroid Build Coastguard Worker actual = torch.sparse_coo_tensor(indices, values, size=(2, 2)).coalesce() 1027*da0073e9SAndroid Build Coastguard Worker expected = actual.clone() 1028*da0073e9SAndroid Build Coastguard Worker 1029*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 1030*da0073e9SAndroid Build Coastguard Worker fn() 1031*da0073e9SAndroid Build Coastguard Worker 1032*da0073e9SAndroid Build Coastguard Worker def test_matching_uncoalesced(self): 1033*da0073e9SAndroid Build Coastguard Worker indices = ( 1034*da0073e9SAndroid Build Coastguard Worker (0, 1), 1035*da0073e9SAndroid Build Coastguard Worker (1, 0), 1036*da0073e9SAndroid Build Coastguard Worker ) 1037*da0073e9SAndroid Build Coastguard Worker values = (1, 2) 1038*da0073e9SAndroid Build Coastguard Worker actual = torch.sparse_coo_tensor(indices, values, size=(2, 2)) 1039*da0073e9SAndroid Build Coastguard Worker expected = actual.clone() 1040*da0073e9SAndroid Build Coastguard Worker 1041*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 1042*da0073e9SAndroid Build Coastguard Worker fn() 1043*da0073e9SAndroid Build Coastguard Worker 1044*da0073e9SAndroid Build Coastguard Worker def test_mismatching_sparse_dims(self): 1045*da0073e9SAndroid Build Coastguard Worker t = torch.randn(2, 3, 4) 1046*da0073e9SAndroid Build Coastguard Worker actual = t.to_sparse() 1047*da0073e9SAndroid Build Coastguard Worker expected = t.to_sparse(2) 1048*da0073e9SAndroid Build Coastguard Worker 1049*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 1050*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, 
re.escape("number of sparse dimensions in sparse COO tensors")): 1051*da0073e9SAndroid Build Coastguard Worker fn() 1052*da0073e9SAndroid Build Coastguard Worker 1053*da0073e9SAndroid Build Coastguard Worker def test_mismatching_nnz(self): 1054*da0073e9SAndroid Build Coastguard Worker actual_indices = ( 1055*da0073e9SAndroid Build Coastguard Worker (0, 1), 1056*da0073e9SAndroid Build Coastguard Worker (1, 0), 1057*da0073e9SAndroid Build Coastguard Worker ) 1058*da0073e9SAndroid Build Coastguard Worker actual_values = (1, 2) 1059*da0073e9SAndroid Build Coastguard Worker actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2)) 1060*da0073e9SAndroid Build Coastguard Worker 1061*da0073e9SAndroid Build Coastguard Worker expected_indices = ( 1062*da0073e9SAndroid Build Coastguard Worker (0, 1, 1,), 1063*da0073e9SAndroid Build Coastguard Worker (1, 0, 0,), 1064*da0073e9SAndroid Build Coastguard Worker ) 1065*da0073e9SAndroid Build Coastguard Worker expected_values = (1, 1, 1) 1066*da0073e9SAndroid Build Coastguard Worker expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2)) 1067*da0073e9SAndroid Build Coastguard Worker 1068*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 1069*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape("number of specified values in sparse COO tensors")): 1070*da0073e9SAndroid Build Coastguard Worker fn() 1071*da0073e9SAndroid Build Coastguard Worker 1072*da0073e9SAndroid Build Coastguard Worker def test_mismatching_indices_msg(self): 1073*da0073e9SAndroid Build Coastguard Worker actual_indices = ( 1074*da0073e9SAndroid Build Coastguard Worker (0, 1), 1075*da0073e9SAndroid Build Coastguard Worker (1, 0), 1076*da0073e9SAndroid Build Coastguard Worker ) 1077*da0073e9SAndroid Build Coastguard Worker actual_values = (1, 2) 1078*da0073e9SAndroid Build Coastguard Worker actual = 
torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2)) 1079*da0073e9SAndroid Build Coastguard Worker 1080*da0073e9SAndroid Build Coastguard Worker expected_indices = ( 1081*da0073e9SAndroid Build Coastguard Worker (0, 1), 1082*da0073e9SAndroid Build Coastguard Worker (1, 1), 1083*da0073e9SAndroid Build Coastguard Worker ) 1084*da0073e9SAndroid Build Coastguard Worker expected_values = (1, 2) 1085*da0073e9SAndroid Build Coastguard Worker expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2)) 1086*da0073e9SAndroid Build Coastguard Worker 1087*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 1088*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape("Sparse COO indices")): 1089*da0073e9SAndroid Build Coastguard Worker fn() 1090*da0073e9SAndroid Build Coastguard Worker 1091*da0073e9SAndroid Build Coastguard Worker def test_mismatching_values_msg(self): 1092*da0073e9SAndroid Build Coastguard Worker actual_indices = ( 1093*da0073e9SAndroid Build Coastguard Worker (0, 1), 1094*da0073e9SAndroid Build Coastguard Worker (1, 0), 1095*da0073e9SAndroid Build Coastguard Worker ) 1096*da0073e9SAndroid Build Coastguard Worker actual_values = (1, 2) 1097*da0073e9SAndroid Build Coastguard Worker actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2)) 1098*da0073e9SAndroid Build Coastguard Worker 1099*da0073e9SAndroid Build Coastguard Worker expected_indices = ( 1100*da0073e9SAndroid Build Coastguard Worker (0, 1), 1101*da0073e9SAndroid Build Coastguard Worker (1, 0), 1102*da0073e9SAndroid Build Coastguard Worker ) 1103*da0073e9SAndroid Build Coastguard Worker expected_values = (1, 3) 1104*da0073e9SAndroid Build Coastguard Worker expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2)) 1105*da0073e9SAndroid Build Coastguard Worker 1106*da0073e9SAndroid Build Coastguard Worker for fn in 
assert_close_with_inputs(actual, expected): 1107*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape("Sparse COO values")): 1108*da0073e9SAndroid Build Coastguard Worker fn() 1109*da0073e9SAndroid Build Coastguard Worker 1110*da0073e9SAndroid Build Coastguard Worker 1111*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support CSR testing") 1112*da0073e9SAndroid Build Coastguard Workerclass TestAssertCloseSparseCSR(TestCase): 1113*da0073e9SAndroid Build Coastguard Worker def test_matching(self): 1114*da0073e9SAndroid Build Coastguard Worker crow_indices = (0, 1, 2) 1115*da0073e9SAndroid Build Coastguard Worker col_indices = (1, 0) 1116*da0073e9SAndroid Build Coastguard Worker values = (1, 2) 1117*da0073e9SAndroid Build Coastguard Worker actual = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(2, 2)) 1118*da0073e9SAndroid Build Coastguard Worker expected = actual.clone() 1119*da0073e9SAndroid Build Coastguard Worker 1120*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 1121*da0073e9SAndroid Build Coastguard Worker fn() 1122*da0073e9SAndroid Build Coastguard Worker 1123*da0073e9SAndroid Build Coastguard Worker def test_mismatching_crow_indices_msg(self): 1124*da0073e9SAndroid Build Coastguard Worker actual_crow_indices = (0, 1, 2) 1125*da0073e9SAndroid Build Coastguard Worker actual_col_indices = (0, 1) 1126*da0073e9SAndroid Build Coastguard Worker actual_values = (1, 2) 1127*da0073e9SAndroid Build Coastguard Worker actual = torch.sparse_csr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2)) 1128*da0073e9SAndroid Build Coastguard Worker 1129*da0073e9SAndroid Build Coastguard Worker expected_crow_indices = (0, 2, 2) 1130*da0073e9SAndroid Build Coastguard Worker expected_col_indices = actual_col_indices 1131*da0073e9SAndroid Build Coastguard Worker expected_values = actual_values 
1132*da0073e9SAndroid Build Coastguard Worker expected = torch.sparse_csr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2)) 1133*da0073e9SAndroid Build Coastguard Worker 1134*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 1135*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSR crow_indices")): 1136*da0073e9SAndroid Build Coastguard Worker fn() 1137*da0073e9SAndroid Build Coastguard Worker 1138*da0073e9SAndroid Build Coastguard Worker def test_mismatching_col_indices_msg(self): 1139*da0073e9SAndroid Build Coastguard Worker actual_crow_indices = (0, 1, 2) 1140*da0073e9SAndroid Build Coastguard Worker actual_col_indices = (1, 0) 1141*da0073e9SAndroid Build Coastguard Worker actual_values = (1, 2) 1142*da0073e9SAndroid Build Coastguard Worker actual = torch.sparse_csr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2)) 1143*da0073e9SAndroid Build Coastguard Worker 1144*da0073e9SAndroid Build Coastguard Worker expected_crow_indices = actual_crow_indices 1145*da0073e9SAndroid Build Coastguard Worker expected_col_indices = (1, 1) 1146*da0073e9SAndroid Build Coastguard Worker expected_values = actual_values 1147*da0073e9SAndroid Build Coastguard Worker expected = torch.sparse_csr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2)) 1148*da0073e9SAndroid Build Coastguard Worker 1149*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 1150*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSR col_indices")): 1151*da0073e9SAndroid Build Coastguard Worker fn() 1152*da0073e9SAndroid Build Coastguard Worker 1153*da0073e9SAndroid Build Coastguard Worker def test_mismatching_values_msg(self): 1154*da0073e9SAndroid Build Coastguard Worker actual_crow_indices = (0, 1, 2) 1155*da0073e9SAndroid Build 
Coastguard Worker actual_col_indices = (1, 0) 1156*da0073e9SAndroid Build Coastguard Worker actual_values = (1, 2) 1157*da0073e9SAndroid Build Coastguard Worker actual = torch.sparse_csr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2)) 1158*da0073e9SAndroid Build Coastguard Worker 1159*da0073e9SAndroid Build Coastguard Worker expected_crow_indices = actual_crow_indices 1160*da0073e9SAndroid Build Coastguard Worker expected_col_indices = actual_col_indices 1161*da0073e9SAndroid Build Coastguard Worker expected_values = (1, 3) 1162*da0073e9SAndroid Build Coastguard Worker expected = torch.sparse_csr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2)) 1163*da0073e9SAndroid Build Coastguard Worker 1164*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 1165*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSR values")): 1166*da0073e9SAndroid Build Coastguard Worker fn() 1167*da0073e9SAndroid Build Coastguard Worker 1168*da0073e9SAndroid Build Coastguard Worker 1169*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support CSC testing") 1170*da0073e9SAndroid Build Coastguard Workerclass TestAssertCloseSparseCSC(TestCase): 1171*da0073e9SAndroid Build Coastguard Worker def test_matching(self): 1172*da0073e9SAndroid Build Coastguard Worker ccol_indices = (0, 1, 2) 1173*da0073e9SAndroid Build Coastguard Worker row_indices = (1, 0) 1174*da0073e9SAndroid Build Coastguard Worker values = (1, 2) 1175*da0073e9SAndroid Build Coastguard Worker actual = torch.sparse_csc_tensor(ccol_indices, row_indices, values, size=(2, 2)) 1176*da0073e9SAndroid Build Coastguard Worker expected = actual.clone() 1177*da0073e9SAndroid Build Coastguard Worker 1178*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 1179*da0073e9SAndroid Build 
Coastguard Worker fn() 1180*da0073e9SAndroid Build Coastguard Worker 1181*da0073e9SAndroid Build Coastguard Worker def test_mismatching_ccol_indices_msg(self): 1182*da0073e9SAndroid Build Coastguard Worker actual_ccol_indices = (0, 1, 2) 1183*da0073e9SAndroid Build Coastguard Worker actual_row_indices = (0, 1) 1184*da0073e9SAndroid Build Coastguard Worker actual_values = (1, 2) 1185*da0073e9SAndroid Build Coastguard Worker actual = torch.sparse_csc_tensor(actual_ccol_indices, actual_row_indices, actual_values, size=(2, 2)) 1186*da0073e9SAndroid Build Coastguard Worker 1187*da0073e9SAndroid Build Coastguard Worker expected_ccol_indices = (0, 2, 2) 1188*da0073e9SAndroid Build Coastguard Worker expected_row_indices = actual_row_indices 1189*da0073e9SAndroid Build Coastguard Worker expected_values = actual_values 1190*da0073e9SAndroid Build Coastguard Worker expected = torch.sparse_csc_tensor(expected_ccol_indices, expected_row_indices, expected_values, size=(2, 2)) 1191*da0073e9SAndroid Build Coastguard Worker 1192*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 1193*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSC ccol_indices")): 1194*da0073e9SAndroid Build Coastguard Worker fn() 1195*da0073e9SAndroid Build Coastguard Worker 1196*da0073e9SAndroid Build Coastguard Worker def test_mismatching_row_indices_msg(self): 1197*da0073e9SAndroid Build Coastguard Worker actual_ccol_indices = (0, 1, 2) 1198*da0073e9SAndroid Build Coastguard Worker actual_row_indices = (1, 0) 1199*da0073e9SAndroid Build Coastguard Worker actual_values = (1, 2) 1200*da0073e9SAndroid Build Coastguard Worker actual = torch.sparse_csc_tensor(actual_ccol_indices, actual_row_indices, actual_values, size=(2, 2)) 1201*da0073e9SAndroid Build Coastguard Worker 1202*da0073e9SAndroid Build Coastguard Worker expected_ccol_indices = actual_ccol_indices 1203*da0073e9SAndroid Build Coastguard Worker 
expected_row_indices = (1, 1) 1204*da0073e9SAndroid Build Coastguard Worker expected_values = actual_values 1205*da0073e9SAndroid Build Coastguard Worker expected = torch.sparse_csc_tensor(expected_ccol_indices, expected_row_indices, expected_values, size=(2, 2)) 1206*da0073e9SAndroid Build Coastguard Worker 1207*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 1208*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSC row_indices")): 1209*da0073e9SAndroid Build Coastguard Worker fn() 1210*da0073e9SAndroid Build Coastguard Worker 1211*da0073e9SAndroid Build Coastguard Worker def test_mismatching_values_msg(self): 1212*da0073e9SAndroid Build Coastguard Worker actual_ccol_indices = (0, 1, 2) 1213*da0073e9SAndroid Build Coastguard Worker actual_row_indices = (1, 0) 1214*da0073e9SAndroid Build Coastguard Worker actual_values = (1, 2) 1215*da0073e9SAndroid Build Coastguard Worker actual = torch.sparse_csc_tensor(actual_ccol_indices, actual_row_indices, actual_values, size=(2, 2)) 1216*da0073e9SAndroid Build Coastguard Worker 1217*da0073e9SAndroid Build Coastguard Worker expected_ccol_indices = actual_ccol_indices 1218*da0073e9SAndroid Build Coastguard Worker expected_row_indices = actual_row_indices 1219*da0073e9SAndroid Build Coastguard Worker expected_values = (1, 3) 1220*da0073e9SAndroid Build Coastguard Worker expected = torch.sparse_csc_tensor(expected_ccol_indices, expected_row_indices, expected_values, size=(2, 2)) 1221*da0073e9SAndroid Build Coastguard Worker 1222*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 1223*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSC values")): 1224*da0073e9SAndroid Build Coastguard Worker fn() 1225*da0073e9SAndroid Build Coastguard Worker 1226*da0073e9SAndroid Build Coastguard Worker 1227*da0073e9SAndroid Build Coastguard 
Worker@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support BSR testing") 1228*da0073e9SAndroid Build Coastguard Workerclass TestAssertCloseSparseBSR(TestCase): 1229*da0073e9SAndroid Build Coastguard Worker def test_matching(self): 1230*da0073e9SAndroid Build Coastguard Worker crow_indices = (0, 1, 2) 1231*da0073e9SAndroid Build Coastguard Worker col_indices = (1, 0) 1232*da0073e9SAndroid Build Coastguard Worker values = ([[1]], [[2]]) 1233*da0073e9SAndroid Build Coastguard Worker actual = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(2, 2)) 1234*da0073e9SAndroid Build Coastguard Worker expected = actual.clone() 1235*da0073e9SAndroid Build Coastguard Worker 1236*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 1237*da0073e9SAndroid Build Coastguard Worker fn() 1238*da0073e9SAndroid Build Coastguard Worker 1239*da0073e9SAndroid Build Coastguard Worker def test_mismatching_crow_indices_msg(self): 1240*da0073e9SAndroid Build Coastguard Worker actual_crow_indices = (0, 1, 2) 1241*da0073e9SAndroid Build Coastguard Worker actual_col_indices = (0, 1) 1242*da0073e9SAndroid Build Coastguard Worker actual_values = ([[1]], [[2]]) 1243*da0073e9SAndroid Build Coastguard Worker actual = torch.sparse_bsr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2)) 1244*da0073e9SAndroid Build Coastguard Worker 1245*da0073e9SAndroid Build Coastguard Worker expected_crow_indices = (0, 2, 2) 1246*da0073e9SAndroid Build Coastguard Worker expected_col_indices = actual_col_indices 1247*da0073e9SAndroid Build Coastguard Worker expected_values = actual_values 1248*da0073e9SAndroid Build Coastguard Worker expected = torch.sparse_bsr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2)) 1249*da0073e9SAndroid Build Coastguard Worker 1250*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 1251*da0073e9SAndroid Build 
Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape("Sparse BSR crow_indices")): 1252*da0073e9SAndroid Build Coastguard Worker fn() 1253*da0073e9SAndroid Build Coastguard Worker 1254*da0073e9SAndroid Build Coastguard Worker def test_mismatching_col_indices_msg(self): 1255*da0073e9SAndroid Build Coastguard Worker actual_crow_indices = (0, 1, 2) 1256*da0073e9SAndroid Build Coastguard Worker actual_col_indices = (1, 0) 1257*da0073e9SAndroid Build Coastguard Worker actual_values = ([[1]], [[2]]) 1258*da0073e9SAndroid Build Coastguard Worker actual = torch.sparse_bsr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2)) 1259*da0073e9SAndroid Build Coastguard Worker 1260*da0073e9SAndroid Build Coastguard Worker expected_crow_indices = actual_crow_indices 1261*da0073e9SAndroid Build Coastguard Worker expected_col_indices = (1, 1) 1262*da0073e9SAndroid Build Coastguard Worker expected_values = actual_values 1263*da0073e9SAndroid Build Coastguard Worker expected = torch.sparse_bsr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2)) 1264*da0073e9SAndroid Build Coastguard Worker 1265*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 1266*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape("Sparse BSR col_indices")): 1267*da0073e9SAndroid Build Coastguard Worker fn() 1268*da0073e9SAndroid Build Coastguard Worker 1269*da0073e9SAndroid Build Coastguard Worker def test_mismatching_values_msg(self): 1270*da0073e9SAndroid Build Coastguard Worker actual_crow_indices = (0, 1, 2) 1271*da0073e9SAndroid Build Coastguard Worker actual_col_indices = (1, 0) 1272*da0073e9SAndroid Build Coastguard Worker actual_values = ([[1]], [[2]]) 1273*da0073e9SAndroid Build Coastguard Worker actual = torch.sparse_bsr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2)) 1274*da0073e9SAndroid Build Coastguard Worker 
1275*da0073e9SAndroid Build Coastguard Worker expected_crow_indices = actual_crow_indices 1276*da0073e9SAndroid Build Coastguard Worker expected_col_indices = actual_col_indices 1277*da0073e9SAndroid Build Coastguard Worker expected_values = ([[1]], [[3]]) 1278*da0073e9SAndroid Build Coastguard Worker expected = torch.sparse_bsr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2)) 1279*da0073e9SAndroid Build Coastguard Worker 1280*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 1281*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape("Sparse BSR values")): 1282*da0073e9SAndroid Build Coastguard Worker fn() 1283*da0073e9SAndroid Build Coastguard Worker 1284*da0073e9SAndroid Build Coastguard Worker 1285*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support BSC testing") 1286*da0073e9SAndroid Build Coastguard Workerclass TestAssertCloseSparseBSC(TestCase): 1287*da0073e9SAndroid Build Coastguard Worker def test_matching(self): 1288*da0073e9SAndroid Build Coastguard Worker ccol_indices = (0, 1, 2) 1289*da0073e9SAndroid Build Coastguard Worker row_indices = (1, 0) 1290*da0073e9SAndroid Build Coastguard Worker values = ([[1]], [[2]]) 1291*da0073e9SAndroid Build Coastguard Worker actual = torch.sparse_bsc_tensor(ccol_indices, row_indices, values, size=(2, 2)) 1292*da0073e9SAndroid Build Coastguard Worker expected = actual.clone() 1293*da0073e9SAndroid Build Coastguard Worker 1294*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 1295*da0073e9SAndroid Build Coastguard Worker fn() 1296*da0073e9SAndroid Build Coastguard Worker 1297*da0073e9SAndroid Build Coastguard Worker def test_mismatching_ccol_indices_msg(self): 1298*da0073e9SAndroid Build Coastguard Worker actual_ccol_indices = (0, 1, 2) 1299*da0073e9SAndroid Build Coastguard Worker 
actual_row_indices = (0, 1) 1300*da0073e9SAndroid Build Coastguard Worker actual_values = ([[1]], [[2]]) 1301*da0073e9SAndroid Build Coastguard Worker actual = torch.sparse_bsc_tensor(actual_ccol_indices, actual_row_indices, actual_values, size=(2, 2)) 1302*da0073e9SAndroid Build Coastguard Worker 1303*da0073e9SAndroid Build Coastguard Worker expected_ccol_indices = (0, 2, 2) 1304*da0073e9SAndroid Build Coastguard Worker expected_row_indices = actual_row_indices 1305*da0073e9SAndroid Build Coastguard Worker expected_values = actual_values 1306*da0073e9SAndroid Build Coastguard Worker expected = torch.sparse_bsc_tensor(expected_ccol_indices, expected_row_indices, expected_values, size=(2, 2)) 1307*da0073e9SAndroid Build Coastguard Worker 1308*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 1309*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape("Sparse BSC ccol_indices")): 1310*da0073e9SAndroid Build Coastguard Worker fn() 1311*da0073e9SAndroid Build Coastguard Worker 1312*da0073e9SAndroid Build Coastguard Worker def test_mismatching_row_indices_msg(self): 1313*da0073e9SAndroid Build Coastguard Worker actual_ccol_indices = (0, 1, 2) 1314*da0073e9SAndroid Build Coastguard Worker actual_row_indices = (1, 0) 1315*da0073e9SAndroid Build Coastguard Worker actual_values = ([[1]], [[2]]) 1316*da0073e9SAndroid Build Coastguard Worker actual = torch.sparse_bsc_tensor(actual_ccol_indices, actual_row_indices, actual_values, size=(2, 2)) 1317*da0073e9SAndroid Build Coastguard Worker 1318*da0073e9SAndroid Build Coastguard Worker expected_ccol_indices = actual_ccol_indices 1319*da0073e9SAndroid Build Coastguard Worker expected_row_indices = (1, 1) 1320*da0073e9SAndroid Build Coastguard Worker expected_values = actual_values 1321*da0073e9SAndroid Build Coastguard Worker expected = torch.sparse_bsc_tensor(expected_ccol_indices, expected_row_indices, expected_values, size=(2, 2)) 
1322*da0073e9SAndroid Build Coastguard Worker 1323*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 1324*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape("Sparse BSC row_indices")): 1325*da0073e9SAndroid Build Coastguard Worker fn() 1326*da0073e9SAndroid Build Coastguard Worker 1327*da0073e9SAndroid Build Coastguard Worker def test_mismatching_values_msg(self): 1328*da0073e9SAndroid Build Coastguard Worker actual_ccol_indices = (0, 1, 2) 1329*da0073e9SAndroid Build Coastguard Worker actual_row_indices = (1, 0) 1330*da0073e9SAndroid Build Coastguard Worker actual_values = ([[1]], [[2]]) 1331*da0073e9SAndroid Build Coastguard Worker actual = torch.sparse_bsc_tensor(actual_ccol_indices, actual_row_indices, actual_values, size=(2, 2)) 1332*da0073e9SAndroid Build Coastguard Worker 1333*da0073e9SAndroid Build Coastguard Worker expected_ccol_indices = actual_ccol_indices 1334*da0073e9SAndroid Build Coastguard Worker expected_row_indices = actual_row_indices 1335*da0073e9SAndroid Build Coastguard Worker expected_values = ([[1]], [[3]]) 1336*da0073e9SAndroid Build Coastguard Worker expected = torch.sparse_bsc_tensor(expected_ccol_indices, expected_row_indices, expected_values, size=(2, 2)) 1337*da0073e9SAndroid Build Coastguard Worker 1338*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 1339*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, re.escape("Sparse BSC values")): 1340*da0073e9SAndroid Build Coastguard Worker fn() 1341*da0073e9SAndroid Build Coastguard Worker 1342*da0073e9SAndroid Build Coastguard Worker 1343*da0073e9SAndroid Build Coastguard Workerclass TestAssertCloseQuantized(TestCase): 1344*da0073e9SAndroid Build Coastguard Worker def test_mismatching_is_quantized(self): 1345*da0073e9SAndroid Build Coastguard Worker actual = torch.tensor(1.0) 1346*da0073e9SAndroid Build Coastguard Worker 
expected = torch.quantize_per_tensor(actual, scale=1.0, zero_point=0, dtype=torch.qint32) 1347*da0073e9SAndroid Build Coastguard Worker 1348*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 1349*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, "is_quantized"): 1350*da0073e9SAndroid Build Coastguard Worker fn() 1351*da0073e9SAndroid Build Coastguard Worker 1352*da0073e9SAndroid Build Coastguard Worker def test_mismatching_qscheme(self): 1353*da0073e9SAndroid Build Coastguard Worker t = torch.tensor((1.0,)) 1354*da0073e9SAndroid Build Coastguard Worker actual = torch.quantize_per_tensor(t, scale=1.0, zero_point=0, dtype=torch.qint32) 1355*da0073e9SAndroid Build Coastguard Worker expected = torch.quantize_per_channel( 1356*da0073e9SAndroid Build Coastguard Worker t, 1357*da0073e9SAndroid Build Coastguard Worker scales=torch.tensor((1.0,)), 1358*da0073e9SAndroid Build Coastguard Worker zero_points=torch.tensor((0,)), 1359*da0073e9SAndroid Build Coastguard Worker axis=0, 1360*da0073e9SAndroid Build Coastguard Worker dtype=torch.qint32, 1361*da0073e9SAndroid Build Coastguard Worker ) 1362*da0073e9SAndroid Build Coastguard Worker 1363*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 1364*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, "qscheme"): 1365*da0073e9SAndroid Build Coastguard Worker fn() 1366*da0073e9SAndroid Build Coastguard Worker 1367*da0073e9SAndroid Build Coastguard Worker def test_matching_per_tensor(self): 1368*da0073e9SAndroid Build Coastguard Worker actual = torch.quantize_per_tensor(torch.tensor(1.0), scale=1.0, zero_point=0, dtype=torch.qint32) 1369*da0073e9SAndroid Build Coastguard Worker expected = actual.clone() 1370*da0073e9SAndroid Build Coastguard Worker 1371*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 1372*da0073e9SAndroid Build 
Coastguard Worker fn() 1373*da0073e9SAndroid Build Coastguard Worker 1374*da0073e9SAndroid Build Coastguard Worker def test_matching_per_channel(self): 1375*da0073e9SAndroid Build Coastguard Worker actual = torch.quantize_per_channel( 1376*da0073e9SAndroid Build Coastguard Worker torch.tensor((1.0,)), 1377*da0073e9SAndroid Build Coastguard Worker scales=torch.tensor((1.0,)), 1378*da0073e9SAndroid Build Coastguard Worker zero_points=torch.tensor((0,)), 1379*da0073e9SAndroid Build Coastguard Worker axis=0, 1380*da0073e9SAndroid Build Coastguard Worker dtype=torch.qint32, 1381*da0073e9SAndroid Build Coastguard Worker ) 1382*da0073e9SAndroid Build Coastguard Worker expected = actual.clone() 1383*da0073e9SAndroid Build Coastguard Worker 1384*da0073e9SAndroid Build Coastguard Worker for fn in assert_close_with_inputs(actual, expected): 1385*da0073e9SAndroid Build Coastguard Worker fn() 1386*da0073e9SAndroid Build Coastguard Worker 1387*da0073e9SAndroid Build Coastguard Worker 1388*da0073e9SAndroid Build Coastguard Workerclass TestMakeTensor(TestCase): 1389*da0073e9SAndroid Build Coastguard Worker supported_dtypes = dtypes( 1390*da0073e9SAndroid Build Coastguard Worker torch.bool, 1391*da0073e9SAndroid Build Coastguard Worker torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, 1392*da0073e9SAndroid Build Coastguard Worker torch.float16, torch.bfloat16, torch.float32, torch.float64, 1393*da0073e9SAndroid Build Coastguard Worker torch.complex32, torch.complex64, torch.complex128, 1394*da0073e9SAndroid Build Coastguard Worker ) 1395*da0073e9SAndroid Build Coastguard Worker 1396*da0073e9SAndroid Build Coastguard Worker @supported_dtypes 1397*da0073e9SAndroid Build Coastguard Worker @parametrize("shape", [(), (0,), (1,), (1, 1), (2,), (2, 3), (8, 16, 32)]) 1398*da0073e9SAndroid Build Coastguard Worker @parametrize("splat_shape", [False, True]) 1399*da0073e9SAndroid Build Coastguard Worker def test_smoke(self, dtype, device, shape, splat_shape): 
1400*da0073e9SAndroid Build Coastguard Worker t = torch.testing.make_tensor(*shape if splat_shape else shape, dtype=dtype, device=device) 1401*da0073e9SAndroid Build Coastguard Worker 1402*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(t, torch.Tensor) 1403*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.shape, shape) 1404*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.dtype, dtype) 1405*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.device, torch.device(device)) 1406*da0073e9SAndroid Build Coastguard Worker 1407*da0073e9SAndroid Build Coastguard Worker @supported_dtypes 1408*da0073e9SAndroid Build Coastguard Worker @parametrize("requires_grad", [False, True]) 1409*da0073e9SAndroid Build Coastguard Worker def test_requires_grad(self, dtype, device, requires_grad): 1410*da0073e9SAndroid Build Coastguard Worker make_tensor = functools.partial( 1411*da0073e9SAndroid Build Coastguard Worker torch.testing.make_tensor, 1412*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 1413*da0073e9SAndroid Build Coastguard Worker device=device, 1414*da0073e9SAndroid Build Coastguard Worker requires_grad=requires_grad, 1415*da0073e9SAndroid Build Coastguard Worker ) 1416*da0073e9SAndroid Build Coastguard Worker 1417*da0073e9SAndroid Build Coastguard Worker if not requires_grad or dtype.is_floating_point or dtype.is_complex: 1418*da0073e9SAndroid Build Coastguard Worker t = make_tensor() 1419*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.requires_grad, requires_grad) 1420*da0073e9SAndroid Build Coastguard Worker else: 1421*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1422*da0073e9SAndroid Build Coastguard Worker ValueError, "`requires_grad=True` is not supported for boolean and integral dtypes" 1423*da0073e9SAndroid Build Coastguard Worker ): 1424*da0073e9SAndroid Build Coastguard Worker make_tensor() 1425*da0073e9SAndroid Build Coastguard Worker 1426*da0073e9SAndroid Build Coastguard 
Worker @supported_dtypes 1427*da0073e9SAndroid Build Coastguard Worker @parametrize("noncontiguous", [False, True]) 1428*da0073e9SAndroid Build Coastguard Worker @parametrize("shape", [(), (0,), (1,), (1, 1), (2,), (2, 3), (8, 16, 32)]) 1429*da0073e9SAndroid Build Coastguard Worker def test_noncontiguous(self, dtype, device, noncontiguous, shape): 1430*da0073e9SAndroid Build Coastguard Worker numel = functools.reduce(operator.mul, shape, 1) 1431*da0073e9SAndroid Build Coastguard Worker 1432*da0073e9SAndroid Build Coastguard Worker t = torch.testing.make_tensor(shape, dtype=dtype, device=device, noncontiguous=noncontiguous) 1433*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.is_contiguous(), not noncontiguous or numel < 2) 1434*da0073e9SAndroid Build Coastguard Worker 1435*da0073e9SAndroid Build Coastguard Worker @supported_dtypes 1436*da0073e9SAndroid Build Coastguard Worker @parametrize( 1437*da0073e9SAndroid Build Coastguard Worker "memory_format_and_shape", 1438*da0073e9SAndroid Build Coastguard Worker [ 1439*da0073e9SAndroid Build Coastguard Worker (None, (2, 3, 4)), 1440*da0073e9SAndroid Build Coastguard Worker (torch.contiguous_format, (2, 3, 4)), 1441*da0073e9SAndroid Build Coastguard Worker (torch.channels_last, (2, 3, 4, 5)), 1442*da0073e9SAndroid Build Coastguard Worker (torch.channels_last_3d, (2, 3, 4, 5, 6)), 1443*da0073e9SAndroid Build Coastguard Worker (torch.preserve_format, (2, 3, 4)), 1444*da0073e9SAndroid Build Coastguard Worker ], 1445*da0073e9SAndroid Build Coastguard Worker ) 1446*da0073e9SAndroid Build Coastguard Worker def test_memory_format(self, dtype, device, memory_format_and_shape): 1447*da0073e9SAndroid Build Coastguard Worker memory_format, shape = memory_format_and_shape 1448*da0073e9SAndroid Build Coastguard Worker 1449*da0073e9SAndroid Build Coastguard Worker t = torch.testing.make_tensor(shape, dtype=dtype, device=device, memory_format=memory_format) 1450*da0073e9SAndroid Build Coastguard Worker 1451*da0073e9SAndroid 
Build Coastguard Worker self.assertTrue( 1452*da0073e9SAndroid Build Coastguard Worker t.is_contiguous(memory_format=torch.contiguous_format if memory_format is None else memory_format) 1453*da0073e9SAndroid Build Coastguard Worker ) 1454*da0073e9SAndroid Build Coastguard Worker 1455*da0073e9SAndroid Build Coastguard Worker @supported_dtypes 1456*da0073e9SAndroid Build Coastguard Worker def test_noncontiguous_memory_format(self, dtype, device): 1457*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "`noncontiguous` and `memory_format` are mutually exclusive"): 1458*da0073e9SAndroid Build Coastguard Worker torch.testing.make_tensor( 1459*da0073e9SAndroid Build Coastguard Worker (2, 3, 4, 5), 1460*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 1461*da0073e9SAndroid Build Coastguard Worker device=device, 1462*da0073e9SAndroid Build Coastguard Worker noncontiguous=True, 1463*da0073e9SAndroid Build Coastguard Worker memory_format=torch.channels_last, 1464*da0073e9SAndroid Build Coastguard Worker ) 1465*da0073e9SAndroid Build Coastguard Worker 1466*da0073e9SAndroid Build Coastguard Worker @supported_dtypes 1467*da0073e9SAndroid Build Coastguard Worker def test_exclude_zero(self, dtype, device): 1468*da0073e9SAndroid Build Coastguard Worker t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, exclude_zero=True, low=-1, high=2) 1469*da0073e9SAndroid Build Coastguard Worker 1470*da0073e9SAndroid Build Coastguard Worker self.assertTrue((t != 0).all()) 1471*da0073e9SAndroid Build Coastguard Worker 1472*da0073e9SAndroid Build Coastguard Worker @supported_dtypes 1473*da0073e9SAndroid Build Coastguard Worker def test_low_high_smoke(self, dtype, device): 1474*da0073e9SAndroid Build Coastguard Worker low_inclusive, high_exclusive = 0, 2 1475*da0073e9SAndroid Build Coastguard Worker 1476*da0073e9SAndroid Build Coastguard Worker t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, low=low_inclusive, high=high_exclusive) 
1477*da0073e9SAndroid Build Coastguard Worker if dtype.is_complex: 1478*da0073e9SAndroid Build Coastguard Worker t = torch.view_as_real(t) 1479*da0073e9SAndroid Build Coastguard Worker 1480*da0073e9SAndroid Build Coastguard Worker self.assertTrue(((t >= low_inclusive) & (t < high_exclusive)).all()) 1481*da0073e9SAndroid Build Coastguard Worker 1482*da0073e9SAndroid Build Coastguard Worker @supported_dtypes 1483*da0073e9SAndroid Build Coastguard Worker def test_low_high_default_smoke(self, dtype, device): 1484*da0073e9SAndroid Build Coastguard Worker low_inclusive, high_exclusive = { 1485*da0073e9SAndroid Build Coastguard Worker torch.bool: (0, 2), 1486*da0073e9SAndroid Build Coastguard Worker torch.uint8: (0, 10), 1487*da0073e9SAndroid Build Coastguard Worker **dict.fromkeys([torch.int8, torch.int16, torch.int32, torch.int64], (-9, 10)), 1488*da0073e9SAndroid Build Coastguard Worker }.get(dtype, (-9, 9)) 1489*da0073e9SAndroid Build Coastguard Worker 1490*da0073e9SAndroid Build Coastguard Worker t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, low=low_inclusive, high=high_exclusive) 1491*da0073e9SAndroid Build Coastguard Worker if dtype.is_complex: 1492*da0073e9SAndroid Build Coastguard Worker t = torch.view_as_real(t) 1493*da0073e9SAndroid Build Coastguard Worker 1494*da0073e9SAndroid Build Coastguard Worker self.assertTrue(((t >= low_inclusive) & (t < high_exclusive)).all()) 1495*da0073e9SAndroid Build Coastguard Worker 1496*da0073e9SAndroid Build Coastguard Worker @parametrize("low_high", [(0, 0), (1, 0), (0, -1)]) 1497*da0073e9SAndroid Build Coastguard Worker @parametrize("value_types", list(itertools.product([int, float], repeat=2))) 1498*da0073e9SAndroid Build Coastguard Worker @supported_dtypes 1499*da0073e9SAndroid Build Coastguard Worker def test_low_ge_high(self, dtype, device, low_high, value_types): 1500*da0073e9SAndroid Build Coastguard Worker low, high = (value_type(value) for value, value_type in zip(low_high, value_types)) 
1501*da0073e9SAndroid Build Coastguard Worker 1502*da0073e9SAndroid Build Coastguard Worker if low == high and (dtype.is_floating_point or dtype.is_complex): 1503*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex( 1504*da0073e9SAndroid Build Coastguard Worker FutureWarning, 1505*da0073e9SAndroid Build Coastguard Worker "Passing `low==high` to `torch.testing.make_tensor` for floating or complex types is deprecated", 1506*da0073e9SAndroid Build Coastguard Worker ): 1507*da0073e9SAndroid Build Coastguard Worker t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, low=low, high=high) 1508*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t, torch.full_like(t, complex(low, low) if dtype.is_complex else low)) 1509*da0073e9SAndroid Build Coastguard Worker else: 1510*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "`low` must be less than `high`"): 1511*da0073e9SAndroid Build Coastguard Worker torch.testing.make_tensor(dtype=dtype, device=device, low=low, high=high) 1512*da0073e9SAndroid Build Coastguard Worker 1513*da0073e9SAndroid Build Coastguard Worker @supported_dtypes 1514*da0073e9SAndroid Build Coastguard Worker @parametrize("low_high", [(None, torch.nan), (torch.nan, None), (torch.nan, torch.nan)]) 1515*da0073e9SAndroid Build Coastguard Worker def test_low_high_nan(self, dtype, device, low_high): 1516*da0073e9SAndroid Build Coastguard Worker low, high = low_high 1517*da0073e9SAndroid Build Coastguard Worker 1518*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "`low` and `high` cannot be NaN"): 1519*da0073e9SAndroid Build Coastguard Worker torch.testing.make_tensor(dtype=dtype, device=device, low=low, high=high) 1520*da0073e9SAndroid Build Coastguard Worker 1521*da0073e9SAndroid Build Coastguard Worker @supported_dtypes 1522*da0073e9SAndroid Build Coastguard Worker def test_low_high_outside_valid_range(self, dtype, device): 1523*da0073e9SAndroid Build Coastguard 
Worker make_tensor = functools.partial(torch.testing.make_tensor, dtype=dtype, device=device) 1524*da0073e9SAndroid Build Coastguard Worker 1525*da0073e9SAndroid Build Coastguard Worker def get_dtype_limits(dtype): 1526*da0073e9SAndroid Build Coastguard Worker if dtype is torch.bool: 1527*da0073e9SAndroid Build Coastguard Worker return 0, 1 1528*da0073e9SAndroid Build Coastguard Worker 1529*da0073e9SAndroid Build Coastguard Worker info = (torch.finfo if dtype.is_floating_point or dtype.is_complex else torch.iinfo)(dtype) 1530*da0073e9SAndroid Build Coastguard Worker # We are using integer bounds here, because otherwise it would be impossible to pass `low` and `high` 1531*da0073e9SAndroid Build Coastguard Worker # outside their valid range. Python uses 64bit floating point numbers and thus trying to do something like 1532*da0073e9SAndroid Build Coastguard Worker # `torch.ffinfo(torch.float64)max * 2` will always result in `inf`. On the flipside, Pythons `int` is 1533*da0073e9SAndroid Build Coastguard Worker # unbounded. 
1534*da0073e9SAndroid Build Coastguard Worker return int(info.min), int(info.max) 1535*da0073e9SAndroid Build Coastguard Worker 1536*da0073e9SAndroid Build Coastguard Worker lowest_inclusive, highest_inclusive = get_dtype_limits(dtype) 1537*da0073e9SAndroid Build Coastguard Worker 1538*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, ""): 1539*da0073e9SAndroid Build Coastguard Worker low, high = (-2, -1) if lowest_inclusive == 0 else (lowest_inclusive * 4, lowest_inclusive * 2) 1540*da0073e9SAndroid Build Coastguard Worker make_tensor(low=low, high=high) 1541*da0073e9SAndroid Build Coastguard Worker 1542*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, ""): 1543*da0073e9SAndroid Build Coastguard Worker make_tensor(low=highest_inclusive * 2, high=highest_inclusive * 4) 1544*da0073e9SAndroid Build Coastguard Worker 1545*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) 1546*da0073e9SAndroid Build Coastguard Worker def test_low_high_boolean_integral1(self, dtype, device): 1547*da0073e9SAndroid Build Coastguard Worker shape = (10_000,) 1548*da0073e9SAndroid Build Coastguard Worker eps = 1e-4 1549*da0073e9SAndroid Build Coastguard Worker 1550*da0073e9SAndroid Build Coastguard Worker actual = torch.testing.make_tensor(shape, dtype=dtype, device=device, low=-(1 - eps), high=1 - eps) 1551*da0073e9SAndroid Build Coastguard Worker expected = torch.zeros(shape, dtype=dtype, device=device) 1552*da0073e9SAndroid Build Coastguard Worker 1553*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(actual, expected) 1554*da0073e9SAndroid Build Coastguard Worker 1555*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) 1556*da0073e9SAndroid Build Coastguard Worker def test_low_high_boolean_integral2(self, dtype, device): 1557*da0073e9SAndroid Build Coastguard 
Worker shape = (10_000,) 1558*da0073e9SAndroid Build Coastguard Worker if dtype is torch.bool: 1559*da0073e9SAndroid Build Coastguard Worker low = 1 1560*da0073e9SAndroid Build Coastguard Worker elif dtype is torch.int64: 1561*da0073e9SAndroid Build Coastguard Worker # Due to its internals, `make_tensor` is not able to sample `torch.iinfo(torch.int64).max` 1562*da0073e9SAndroid Build Coastguard Worker low = torch.iinfo(dtype).max - 1 1563*da0073e9SAndroid Build Coastguard Worker else: 1564*da0073e9SAndroid Build Coastguard Worker low = torch.iinfo(dtype).max 1565*da0073e9SAndroid Build Coastguard Worker high = low + 1 1566*da0073e9SAndroid Build Coastguard Worker 1567*da0073e9SAndroid Build Coastguard Worker actual = torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high) 1568*da0073e9SAndroid Build Coastguard Worker expected = torch.full(shape, low, dtype=dtype, device=device) 1569*da0073e9SAndroid Build Coastguard Worker 1570*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(actual, expected) 1571*da0073e9SAndroid Build Coastguard Worker 1572*da0073e9SAndroid Build Coastguard Worker 1573*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestMakeTensor, globals()) 1574*da0073e9SAndroid Build Coastguard Worker 1575*da0073e9SAndroid Build Coastguard Worker 1576*da0073e9SAndroid Build Coastguard Workerdef _get_test_names_for_test_class(test_cls): 1577*da0073e9SAndroid Build Coastguard Worker """ Convenience function to get all test names for a given test class. 
""" 1578*da0073e9SAndroid Build Coastguard Worker test_names = [f'{test_cls.__name__}.{key}' for key in test_cls.__dict__ 1579*da0073e9SAndroid Build Coastguard Worker if key.startswith('test_')] 1580*da0073e9SAndroid Build Coastguard Worker return sorted(test_names) 1581*da0073e9SAndroid Build Coastguard Worker 1582*da0073e9SAndroid Build Coastguard Worker 1583*da0073e9SAndroid Build Coastguard Workerdef _get_test_funcs_for_test_class(test_cls): 1584*da0073e9SAndroid Build Coastguard Worker """ Convenience function to get all (test function, parametrized_name) pairs for a given test class. """ 1585*da0073e9SAndroid Build Coastguard Worker test_funcs = [(getattr(test_cls, key), key) for key in test_cls.__dict__ if key.startswith('test_')] 1586*da0073e9SAndroid Build Coastguard Worker return test_funcs 1587*da0073e9SAndroid Build Coastguard Worker 1588*da0073e9SAndroid Build Coastguard Worker 1589*da0073e9SAndroid Build Coastguard Workerclass TestTestParametrization(TestCase): 1590*da0073e9SAndroid Build Coastguard Worker def test_default_names(self): 1591*da0073e9SAndroid Build Coastguard Worker 1592*da0073e9SAndroid Build Coastguard Worker class TestParametrized(TestCase): 1593*da0073e9SAndroid Build Coastguard Worker @parametrize("x", range(5)) 1594*da0073e9SAndroid Build Coastguard Worker def test_default_names(self, x): 1595*da0073e9SAndroid Build Coastguard Worker pass 1596*da0073e9SAndroid Build Coastguard Worker 1597*da0073e9SAndroid Build Coastguard Worker @parametrize("x,y", [(1, 2), (2, 3), (3, 4)]) 1598*da0073e9SAndroid Build Coastguard Worker def test_two_things_default_names(self, x, y): 1599*da0073e9SAndroid Build Coastguard Worker pass 1600*da0073e9SAndroid Build Coastguard Worker 1601*da0073e9SAndroid Build Coastguard Worker instantiate_parametrized_tests(TestParametrized) 1602*da0073e9SAndroid Build Coastguard Worker 1603*da0073e9SAndroid Build Coastguard Worker expected_test_names = [ 1604*da0073e9SAndroid Build Coastguard Worker 
'TestParametrized.test_default_names_x_0', 1605*da0073e9SAndroid Build Coastguard Worker 'TestParametrized.test_default_names_x_1', 1606*da0073e9SAndroid Build Coastguard Worker 'TestParametrized.test_default_names_x_2', 1607*da0073e9SAndroid Build Coastguard Worker 'TestParametrized.test_default_names_x_3', 1608*da0073e9SAndroid Build Coastguard Worker 'TestParametrized.test_default_names_x_4', 1609*da0073e9SAndroid Build Coastguard Worker 'TestParametrized.test_two_things_default_names_x_1_y_2', 1610*da0073e9SAndroid Build Coastguard Worker 'TestParametrized.test_two_things_default_names_x_2_y_3', 1611*da0073e9SAndroid Build Coastguard Worker 'TestParametrized.test_two_things_default_names_x_3_y_4', 1612*da0073e9SAndroid Build Coastguard Worker ] 1613*da0073e9SAndroid Build Coastguard Worker test_names = _get_test_names_for_test_class(TestParametrized) 1614*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_test_names, test_names) 1615*da0073e9SAndroid Build Coastguard Worker 1616*da0073e9SAndroid Build Coastguard Worker def test_name_fn(self): 1617*da0073e9SAndroid Build Coastguard Worker 1618*da0073e9SAndroid Build Coastguard Worker class TestParametrized(TestCase): 1619*da0073e9SAndroid Build Coastguard Worker @parametrize("bias", [False, True], name_fn=lambda b: 'bias' if b else 'no_bias') 1620*da0073e9SAndroid Build Coastguard Worker def test_custom_names(self, bias): 1621*da0073e9SAndroid Build Coastguard Worker pass 1622*da0073e9SAndroid Build Coastguard Worker 1623*da0073e9SAndroid Build Coastguard Worker @parametrize("x", [1, 2], name_fn=str) 1624*da0073e9SAndroid Build Coastguard Worker @parametrize("y", [3, 4], name_fn=str) 1625*da0073e9SAndroid Build Coastguard Worker @parametrize("z", [5, 6], name_fn=str) 1626*da0073e9SAndroid Build Coastguard Worker def test_three_things_composition_custom_names(self, x, y, z): 1627*da0073e9SAndroid Build Coastguard Worker pass 1628*da0073e9SAndroid Build Coastguard Worker 1629*da0073e9SAndroid Build 
Coastguard Worker @parametrize("x,y", [(1, 2), (1, 3), (1, 4)], name_fn=lambda x, y: f'{x}__{y}') 1630*da0073e9SAndroid Build Coastguard Worker def test_two_things_custom_names_alternate(self, x, y): 1631*da0073e9SAndroid Build Coastguard Worker pass 1632*da0073e9SAndroid Build Coastguard Worker 1633*da0073e9SAndroid Build Coastguard Worker instantiate_parametrized_tests(TestParametrized) 1634*da0073e9SAndroid Build Coastguard Worker 1635*da0073e9SAndroid Build Coastguard Worker expected_test_names = [ 1636*da0073e9SAndroid Build Coastguard Worker 'TestParametrized.test_custom_names_bias', 1637*da0073e9SAndroid Build Coastguard Worker 'TestParametrized.test_custom_names_no_bias', 1638*da0073e9SAndroid Build Coastguard Worker 'TestParametrized.test_three_things_composition_custom_names_1_3_5', 1639*da0073e9SAndroid Build Coastguard Worker 'TestParametrized.test_three_things_composition_custom_names_1_3_6', 1640*da0073e9SAndroid Build Coastguard Worker 'TestParametrized.test_three_things_composition_custom_names_1_4_5', 1641*da0073e9SAndroid Build Coastguard Worker 'TestParametrized.test_three_things_composition_custom_names_1_4_6', 1642*da0073e9SAndroid Build Coastguard Worker 'TestParametrized.test_three_things_composition_custom_names_2_3_5', 1643*da0073e9SAndroid Build Coastguard Worker 'TestParametrized.test_three_things_composition_custom_names_2_3_6', 1644*da0073e9SAndroid Build Coastguard Worker 'TestParametrized.test_three_things_composition_custom_names_2_4_5', 1645*da0073e9SAndroid Build Coastguard Worker 'TestParametrized.test_three_things_composition_custom_names_2_4_6', 1646*da0073e9SAndroid Build Coastguard Worker 'TestParametrized.test_two_things_custom_names_alternate_1__2', 1647*da0073e9SAndroid Build Coastguard Worker 'TestParametrized.test_two_things_custom_names_alternate_1__3', 1648*da0073e9SAndroid Build Coastguard Worker 'TestParametrized.test_two_things_custom_names_alternate_1__4', 1649*da0073e9SAndroid Build Coastguard Worker ] 
1650*da0073e9SAndroid Build Coastguard Worker test_names = _get_test_names_for_test_class(TestParametrized) 1651*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_test_names, test_names) 1652*da0073e9SAndroid Build Coastguard Worker 1653*da0073e9SAndroid Build Coastguard Worker def test_subtest_names(self): 1654*da0073e9SAndroid Build Coastguard Worker 1655*da0073e9SAndroid Build Coastguard Worker class TestParametrized(TestCase): 1656*da0073e9SAndroid Build Coastguard Worker @parametrize("bias", [subtest(True, name='bias'), 1657*da0073e9SAndroid Build Coastguard Worker subtest(False, name='no_bias')]) 1658*da0073e9SAndroid Build Coastguard Worker def test_custom_names(self, bias): 1659*da0073e9SAndroid Build Coastguard Worker pass 1660*da0073e9SAndroid Build Coastguard Worker 1661*da0073e9SAndroid Build Coastguard Worker @parametrize("x,y", [subtest((1, 2), name='double'), 1662*da0073e9SAndroid Build Coastguard Worker subtest((1, 3), name='triple'), 1663*da0073e9SAndroid Build Coastguard Worker subtest((1, 4), name='quadruple')]) 1664*da0073e9SAndroid Build Coastguard Worker def test_two_things_custom_names(self, x, y): 1665*da0073e9SAndroid Build Coastguard Worker pass 1666*da0073e9SAndroid Build Coastguard Worker 1667*da0073e9SAndroid Build Coastguard Worker instantiate_parametrized_tests(TestParametrized) 1668*da0073e9SAndroid Build Coastguard Worker 1669*da0073e9SAndroid Build Coastguard Worker expected_test_names = [ 1670*da0073e9SAndroid Build Coastguard Worker 'TestParametrized.test_custom_names_bias', 1671*da0073e9SAndroid Build Coastguard Worker 'TestParametrized.test_custom_names_no_bias', 1672*da0073e9SAndroid Build Coastguard Worker 'TestParametrized.test_two_things_custom_names_double', 1673*da0073e9SAndroid Build Coastguard Worker 'TestParametrized.test_two_things_custom_names_quadruple', 1674*da0073e9SAndroid Build Coastguard Worker 'TestParametrized.test_two_things_custom_names_triple', 1675*da0073e9SAndroid Build Coastguard Worker 
] 1676*da0073e9SAndroid Build Coastguard Worker test_names = _get_test_names_for_test_class(TestParametrized) 1677*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_test_names, test_names) 1678*da0073e9SAndroid Build Coastguard Worker 1679*da0073e9SAndroid Build Coastguard Worker def test_apply_param_specific_decorators(self): 1680*da0073e9SAndroid Build Coastguard Worker # Test that decorators can be applied on a per-param basis. 1681*da0073e9SAndroid Build Coastguard Worker 1682*da0073e9SAndroid Build Coastguard Worker def test_dec(func): 1683*da0073e9SAndroid Build Coastguard Worker func._decorator_applied = True 1684*da0073e9SAndroid Build Coastguard Worker return func 1685*da0073e9SAndroid Build Coastguard Worker 1686*da0073e9SAndroid Build Coastguard Worker class TestParametrized(TestCase): 1687*da0073e9SAndroid Build Coastguard Worker @parametrize("x", [subtest(1, name='one'), 1688*da0073e9SAndroid Build Coastguard Worker subtest(2, name='two', decorators=[test_dec]), 1689*da0073e9SAndroid Build Coastguard Worker subtest(3, name='three')]) 1690*da0073e9SAndroid Build Coastguard Worker def test_param(self, x): 1691*da0073e9SAndroid Build Coastguard Worker pass 1692*da0073e9SAndroid Build Coastguard Worker 1693*da0073e9SAndroid Build Coastguard Worker instantiate_parametrized_tests(TestParametrized) 1694*da0073e9SAndroid Build Coastguard Worker 1695*da0073e9SAndroid Build Coastguard Worker for test_func, name in _get_test_funcs_for_test_class(TestParametrized): 1696*da0073e9SAndroid Build Coastguard Worker self.assertEqual(hasattr(test_func, '_decorator_applied'), name == 'test_param_two') 1697*da0073e9SAndroid Build Coastguard Worker 1698*da0073e9SAndroid Build Coastguard Worker def test_compose_param_specific_decorators(self): 1699*da0073e9SAndroid Build Coastguard Worker # Test that multiple per-param decorators compose correctly. 
1700*da0073e9SAndroid Build Coastguard Worker 1701*da0073e9SAndroid Build Coastguard Worker def test_dec(func): 1702*da0073e9SAndroid Build Coastguard Worker func._decorator_applied = True 1703*da0073e9SAndroid Build Coastguard Worker return func 1704*da0073e9SAndroid Build Coastguard Worker 1705*da0073e9SAndroid Build Coastguard Worker class TestParametrized(TestCase): 1706*da0073e9SAndroid Build Coastguard Worker @parametrize("x", [subtest(1), 1707*da0073e9SAndroid Build Coastguard Worker subtest(2, decorators=[test_dec]), 1708*da0073e9SAndroid Build Coastguard Worker subtest(3)]) 1709*da0073e9SAndroid Build Coastguard Worker @parametrize("y", [subtest(False, decorators=[test_dec]), 1710*da0073e9SAndroid Build Coastguard Worker subtest(True)]) 1711*da0073e9SAndroid Build Coastguard Worker def test_param(self, x, y): 1712*da0073e9SAndroid Build Coastguard Worker pass 1713*da0073e9SAndroid Build Coastguard Worker 1714*da0073e9SAndroid Build Coastguard Worker instantiate_parametrized_tests(TestParametrized) 1715*da0073e9SAndroid Build Coastguard Worker 1716*da0073e9SAndroid Build Coastguard Worker for test_func, name in _get_test_funcs_for_test_class(TestParametrized): 1717*da0073e9SAndroid Build Coastguard Worker # Decorator should be applied whenever either x == 2 or y == False. 1718*da0073e9SAndroid Build Coastguard Worker should_apply = ('x_2' in name) or ('y_False' in name) 1719*da0073e9SAndroid Build Coastguard Worker self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply) 1720*da0073e9SAndroid Build Coastguard Worker 1721*da0073e9SAndroid Build Coastguard Worker def test_modules_decorator_misuse_error(self): 1722*da0073e9SAndroid Build Coastguard Worker # Test that @modules errors out when used with instantiate_parametrized_tests(). 
1723*da0073e9SAndroid Build Coastguard Worker 1724*da0073e9SAndroid Build Coastguard Worker class TestParametrized(TestCase): 1725*da0073e9SAndroid Build Coastguard Worker @modules(module_db) 1726*da0073e9SAndroid Build Coastguard Worker def test_modules(self, module_info): 1727*da0073e9SAndroid Build Coastguard Worker pass 1728*da0073e9SAndroid Build Coastguard Worker 1729*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'intended to be used in a device-specific context'): 1730*da0073e9SAndroid Build Coastguard Worker instantiate_parametrized_tests(TestParametrized) 1731*da0073e9SAndroid Build Coastguard Worker 1732*da0073e9SAndroid Build Coastguard Worker def test_ops_decorator_misuse_error(self): 1733*da0073e9SAndroid Build Coastguard Worker # Test that @ops errors out when used with instantiate_parametrized_tests(). 1734*da0073e9SAndroid Build Coastguard Worker 1735*da0073e9SAndroid Build Coastguard Worker class TestParametrized(TestCase): 1736*da0073e9SAndroid Build Coastguard Worker @ops(op_db) 1737*da0073e9SAndroid Build Coastguard Worker def test_ops(self, module_info): 1738*da0073e9SAndroid Build Coastguard Worker pass 1739*da0073e9SAndroid Build Coastguard Worker 1740*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'intended to be used in a device-specific context'): 1741*da0073e9SAndroid Build Coastguard Worker instantiate_parametrized_tests(TestParametrized) 1742*da0073e9SAndroid Build Coastguard Worker 1743*da0073e9SAndroid Build Coastguard Worker def test_multiple_handling_of_same_param_error(self): 1744*da0073e9SAndroid Build Coastguard Worker # Test that multiple decorators handling the same param errors out. 
1745*da0073e9SAndroid Build Coastguard Worker 1746*da0073e9SAndroid Build Coastguard Worker class TestParametrized(TestCase): 1747*da0073e9SAndroid Build Coastguard Worker @parametrize("x", range(3)) 1748*da0073e9SAndroid Build Coastguard Worker @parametrize("x", range(5)) 1749*da0073e9SAndroid Build Coastguard Worker def test_param(self, x): 1750*da0073e9SAndroid Build Coastguard Worker pass 1751*da0073e9SAndroid Build Coastguard Worker 1752*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'multiple parametrization decorators'): 1753*da0073e9SAndroid Build Coastguard Worker instantiate_parametrized_tests(TestParametrized) 1754*da0073e9SAndroid Build Coastguard Worker 1755*da0073e9SAndroid Build Coastguard Worker @parametrize("x", [1, subtest(2, decorators=[unittest.expectedFailure]), 3]) 1756*da0073e9SAndroid Build Coastguard Worker def test_subtest_expected_failure(self, x): 1757*da0073e9SAndroid Build Coastguard Worker if x == 2: 1758*da0073e9SAndroid Build Coastguard Worker raise RuntimeError('Boom') 1759*da0073e9SAndroid Build Coastguard Worker 1760*da0073e9SAndroid Build Coastguard Worker @parametrize("x", [subtest(1, decorators=[unittest.expectedFailure]), 2, 3]) 1761*da0073e9SAndroid Build Coastguard Worker @parametrize("y", [4, 5, subtest(6, decorators=[unittest.expectedFailure])]) 1762*da0073e9SAndroid Build Coastguard Worker def test_two_things_subtest_expected_failure(self, x, y): 1763*da0073e9SAndroid Build Coastguard Worker if x == 1 or y == 6: 1764*da0073e9SAndroid Build Coastguard Worker raise RuntimeError('Boom') 1765*da0073e9SAndroid Build Coastguard Worker 1766*da0073e9SAndroid Build Coastguard Worker 1767*da0073e9SAndroid Build Coastguard Workerclass TestTestParametrizationDeviceType(TestCase): 1768*da0073e9SAndroid Build Coastguard Worker def test_unparametrized_names(self, device): 1769*da0073e9SAndroid Build Coastguard Worker # This test exists to protect against regressions in device / dtype test naming 
1770*da0073e9SAndroid Build Coastguard Worker # due to parametrization logic. 1771*da0073e9SAndroid Build Coastguard Worker 1772*da0073e9SAndroid Build Coastguard Worker device = self.device_type 1773*da0073e9SAndroid Build Coastguard Worker 1774*da0073e9SAndroid Build Coastguard Worker class TestParametrized(TestCase): 1775*da0073e9SAndroid Build Coastguard Worker def test_device_specific(self, device): 1776*da0073e9SAndroid Build Coastguard Worker pass 1777*da0073e9SAndroid Build Coastguard Worker 1778*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32, torch.float64) 1779*da0073e9SAndroid Build Coastguard Worker def test_device_dtype_specific(self, device, dtype): 1780*da0073e9SAndroid Build Coastguard Worker pass 1781*da0073e9SAndroid Build Coastguard Worker 1782*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 1783*da0073e9SAndroid Build Coastguard Worker 1784*da0073e9SAndroid Build Coastguard Worker device_cls = locals()[f'TestParametrized{device.upper()}'] 1785*da0073e9SAndroid Build Coastguard Worker expected_test_names = [name.format(device_cls.__name__, device) for name in ( 1786*da0073e9SAndroid Build Coastguard Worker '{}.test_device_dtype_specific_{}_float32', 1787*da0073e9SAndroid Build Coastguard Worker '{}.test_device_dtype_specific_{}_float64', 1788*da0073e9SAndroid Build Coastguard Worker '{}.test_device_specific_{}') 1789*da0073e9SAndroid Build Coastguard Worker ] 1790*da0073e9SAndroid Build Coastguard Worker test_names = _get_test_names_for_test_class(device_cls) 1791*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_test_names, test_names) 1792*da0073e9SAndroid Build Coastguard Worker 1793*da0073e9SAndroid Build Coastguard Worker def test_empty_param_names(self, device): 1794*da0073e9SAndroid Build Coastguard Worker # If no param names are passed, ensure things still work without parametrization. 
1795*da0073e9SAndroid Build Coastguard Worker device = self.device_type 1796*da0073e9SAndroid Build Coastguard Worker 1797*da0073e9SAndroid Build Coastguard Worker class TestParametrized(TestCase): 1798*da0073e9SAndroid Build Coastguard Worker @parametrize("", []) 1799*da0073e9SAndroid Build Coastguard Worker def test_foo(self, device): 1800*da0073e9SAndroid Build Coastguard Worker pass 1801*da0073e9SAndroid Build Coastguard Worker 1802*da0073e9SAndroid Build Coastguard Worker @parametrize("", range(5)) 1803*da0073e9SAndroid Build Coastguard Worker def test_bar(self, device): 1804*da0073e9SAndroid Build Coastguard Worker pass 1805*da0073e9SAndroid Build Coastguard Worker 1806*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 1807*da0073e9SAndroid Build Coastguard Worker 1808*da0073e9SAndroid Build Coastguard Worker device_cls = locals()[f'TestParametrized{device.upper()}'] 1809*da0073e9SAndroid Build Coastguard Worker expected_test_names = [name.format(device_cls.__name__, device) for name in ( 1810*da0073e9SAndroid Build Coastguard Worker '{}.test_bar_{}', 1811*da0073e9SAndroid Build Coastguard Worker '{}.test_foo_{}') 1812*da0073e9SAndroid Build Coastguard Worker ] 1813*da0073e9SAndroid Build Coastguard Worker test_names = _get_test_names_for_test_class(device_cls) 1814*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_test_names, test_names) 1815*da0073e9SAndroid Build Coastguard Worker 1816*da0073e9SAndroid Build Coastguard Worker def test_empty_param_list(self, device): 1817*da0073e9SAndroid Build Coastguard Worker # If no param values are passed, ensure a helpful error message is thrown. 1818*da0073e9SAndroid Build Coastguard Worker # In the wild, this could indicate reuse of an exhausted generator. 
1819*da0073e9SAndroid Build Coastguard Worker device = self.device_type 1820*da0073e9SAndroid Build Coastguard Worker 1821*da0073e9SAndroid Build Coastguard Worker generator = (a for a in range(5)) 1822*da0073e9SAndroid Build Coastguard Worker 1823*da0073e9SAndroid Build Coastguard Worker class TestParametrized(TestCase): 1824*da0073e9SAndroid Build Coastguard Worker @parametrize("x", generator) 1825*da0073e9SAndroid Build Coastguard Worker def test_foo(self, device, x): 1826*da0073e9SAndroid Build Coastguard Worker pass 1827*da0073e9SAndroid Build Coastguard Worker 1828*da0073e9SAndroid Build Coastguard Worker # Reuse generator from first test function. 1829*da0073e9SAndroid Build Coastguard Worker @parametrize("y", generator) 1830*da0073e9SAndroid Build Coastguard Worker def test_bar(self, device, y): 1831*da0073e9SAndroid Build Coastguard Worker pass 1832*da0073e9SAndroid Build Coastguard Worker 1833*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, 'An empty arg_values was passed'): 1834*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 1835*da0073e9SAndroid Build Coastguard Worker 1836*da0073e9SAndroid Build Coastguard Worker def test_default_names(self, device): 1837*da0073e9SAndroid Build Coastguard Worker device = self.device_type 1838*da0073e9SAndroid Build Coastguard Worker 1839*da0073e9SAndroid Build Coastguard Worker class TestParametrized(TestCase): 1840*da0073e9SAndroid Build Coastguard Worker @parametrize("x", range(5)) 1841*da0073e9SAndroid Build Coastguard Worker def test_default_names(self, device, x): 1842*da0073e9SAndroid Build Coastguard Worker pass 1843*da0073e9SAndroid Build Coastguard Worker 1844*da0073e9SAndroid Build Coastguard Worker @parametrize("x,y", [(1, 2), (2, 3), (3, 4)]) 1845*da0073e9SAndroid Build Coastguard Worker def test_two_things_default_names(self, device, x, y): 1846*da0073e9SAndroid Build Coastguard Worker pass 
1847*da0073e9SAndroid Build Coastguard Worker 1848*da0073e9SAndroid Build Coastguard Worker 1849*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 1850*da0073e9SAndroid Build Coastguard Worker 1851*da0073e9SAndroid Build Coastguard Worker device_cls = locals()[f'TestParametrized{device.upper()}'] 1852*da0073e9SAndroid Build Coastguard Worker expected_test_names = [name.format(device_cls.__name__, device) for name in ( 1853*da0073e9SAndroid Build Coastguard Worker '{}.test_default_names_x_0_{}', 1854*da0073e9SAndroid Build Coastguard Worker '{}.test_default_names_x_1_{}', 1855*da0073e9SAndroid Build Coastguard Worker '{}.test_default_names_x_2_{}', 1856*da0073e9SAndroid Build Coastguard Worker '{}.test_default_names_x_3_{}', 1857*da0073e9SAndroid Build Coastguard Worker '{}.test_default_names_x_4_{}', 1858*da0073e9SAndroid Build Coastguard Worker '{}.test_two_things_default_names_x_1_y_2_{}', 1859*da0073e9SAndroid Build Coastguard Worker '{}.test_two_things_default_names_x_2_y_3_{}', 1860*da0073e9SAndroid Build Coastguard Worker '{}.test_two_things_default_names_x_3_y_4_{}') 1861*da0073e9SAndroid Build Coastguard Worker ] 1862*da0073e9SAndroid Build Coastguard Worker test_names = _get_test_names_for_test_class(device_cls) 1863*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_test_names, test_names) 1864*da0073e9SAndroid Build Coastguard Worker 1865*da0073e9SAndroid Build Coastguard Worker def test_default_name_non_primitive(self, device): 1866*da0073e9SAndroid Build Coastguard Worker device = self.device_type 1867*da0073e9SAndroid Build Coastguard Worker 1868*da0073e9SAndroid Build Coastguard Worker class TestParametrized(TestCase): 1869*da0073e9SAndroid Build Coastguard Worker @parametrize("x", [1, .5, "foo", object()]) 1870*da0073e9SAndroid Build Coastguard Worker def test_default_names(self, device, x): 1871*da0073e9SAndroid Build Coastguard Worker pass 1872*da0073e9SAndroid Build 
Coastguard Worker 1873*da0073e9SAndroid Build Coastguard Worker @parametrize("x,y", [(1, object()), (object(), .5), (object(), object())]) 1874*da0073e9SAndroid Build Coastguard Worker def test_two_things_default_names(self, device, x, y): 1875*da0073e9SAndroid Build Coastguard Worker pass 1876*da0073e9SAndroid Build Coastguard Worker 1877*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 1878*da0073e9SAndroid Build Coastguard Worker 1879*da0073e9SAndroid Build Coastguard Worker device_cls = locals()[f'TestParametrized{device.upper()}'] 1880*da0073e9SAndroid Build Coastguard Worker expected_test_names = sorted(name.format(device_cls.__name__, device) for name in ( 1881*da0073e9SAndroid Build Coastguard Worker '{}.test_default_names_x_1_{}', 1882*da0073e9SAndroid Build Coastguard Worker '{}.test_default_names_x_0_5_{}', 1883*da0073e9SAndroid Build Coastguard Worker '{}.test_default_names_x_foo_{}', 1884*da0073e9SAndroid Build Coastguard Worker '{}.test_default_names_x3_{}', 1885*da0073e9SAndroid Build Coastguard Worker '{}.test_two_things_default_names_x_1_y0_{}', 1886*da0073e9SAndroid Build Coastguard Worker '{}.test_two_things_default_names_x1_y_0_5_{}', 1887*da0073e9SAndroid Build Coastguard Worker '{}.test_two_things_default_names_x2_y2_{}') 1888*da0073e9SAndroid Build Coastguard Worker ) 1889*da0073e9SAndroid Build Coastguard Worker test_names = _get_test_names_for_test_class(device_cls) 1890*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_test_names, test_names) 1891*da0073e9SAndroid Build Coastguard Worker 1892*da0073e9SAndroid Build Coastguard Worker def test_name_fn(self, device): 1893*da0073e9SAndroid Build Coastguard Worker device = self.device_type 1894*da0073e9SAndroid Build Coastguard Worker 1895*da0073e9SAndroid Build Coastguard Worker class TestParametrized(TestCase): 1896*da0073e9SAndroid Build Coastguard Worker @parametrize("bias", [False, True], name_fn=lambda b: 
'bias' if b else 'no_bias') 1897*da0073e9SAndroid Build Coastguard Worker def test_custom_names(self, device, bias): 1898*da0073e9SAndroid Build Coastguard Worker pass 1899*da0073e9SAndroid Build Coastguard Worker 1900*da0073e9SAndroid Build Coastguard Worker @parametrize("x", [1, 2], name_fn=str) 1901*da0073e9SAndroid Build Coastguard Worker @parametrize("y", [3, 4], name_fn=str) 1902*da0073e9SAndroid Build Coastguard Worker @parametrize("z", [5, 6], name_fn=str) 1903*da0073e9SAndroid Build Coastguard Worker def test_three_things_composition_custom_names(self, device, x, y, z): 1904*da0073e9SAndroid Build Coastguard Worker pass 1905*da0073e9SAndroid Build Coastguard Worker 1906*da0073e9SAndroid Build Coastguard Worker @parametrize("x,y", [(1, 2), (1, 3), (1, 4)], name_fn=lambda x, y: f'{x}__{y}') 1907*da0073e9SAndroid Build Coastguard Worker def test_two_things_custom_names_alternate(self, device, x, y): 1908*da0073e9SAndroid Build Coastguard Worker pass 1909*da0073e9SAndroid Build Coastguard Worker 1910*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 1911*da0073e9SAndroid Build Coastguard Worker 1912*da0073e9SAndroid Build Coastguard Worker device_cls = locals()[f'TestParametrized{device.upper()}'] 1913*da0073e9SAndroid Build Coastguard Worker expected_test_names = [name.format(device_cls.__name__, device) for name in ( 1914*da0073e9SAndroid Build Coastguard Worker '{}.test_custom_names_bias_{}', 1915*da0073e9SAndroid Build Coastguard Worker '{}.test_custom_names_no_bias_{}', 1916*da0073e9SAndroid Build Coastguard Worker '{}.test_three_things_composition_custom_names_1_3_5_{}', 1917*da0073e9SAndroid Build Coastguard Worker '{}.test_three_things_composition_custom_names_1_3_6_{}', 1918*da0073e9SAndroid Build Coastguard Worker '{}.test_three_things_composition_custom_names_1_4_5_{}', 1919*da0073e9SAndroid Build Coastguard Worker '{}.test_three_things_composition_custom_names_1_4_6_{}', 
1920*da0073e9SAndroid Build Coastguard Worker '{}.test_three_things_composition_custom_names_2_3_5_{}', 1921*da0073e9SAndroid Build Coastguard Worker '{}.test_three_things_composition_custom_names_2_3_6_{}', 1922*da0073e9SAndroid Build Coastguard Worker '{}.test_three_things_composition_custom_names_2_4_5_{}', 1923*da0073e9SAndroid Build Coastguard Worker '{}.test_three_things_composition_custom_names_2_4_6_{}', 1924*da0073e9SAndroid Build Coastguard Worker '{}.test_two_things_custom_names_alternate_1__2_{}', 1925*da0073e9SAndroid Build Coastguard Worker '{}.test_two_things_custom_names_alternate_1__3_{}', 1926*da0073e9SAndroid Build Coastguard Worker '{}.test_two_things_custom_names_alternate_1__4_{}') 1927*da0073e9SAndroid Build Coastguard Worker ] 1928*da0073e9SAndroid Build Coastguard Worker test_names = _get_test_names_for_test_class(device_cls) 1929*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_test_names, test_names) 1930*da0073e9SAndroid Build Coastguard Worker 1931*da0073e9SAndroid Build Coastguard Worker def test_subtest_names(self, device): 1932*da0073e9SAndroid Build Coastguard Worker device = self.device_type 1933*da0073e9SAndroid Build Coastguard Worker 1934*da0073e9SAndroid Build Coastguard Worker class TestParametrized(TestCase): 1935*da0073e9SAndroid Build Coastguard Worker @parametrize("bias", [subtest(True, name='bias'), 1936*da0073e9SAndroid Build Coastguard Worker subtest(False, name='no_bias')]) 1937*da0073e9SAndroid Build Coastguard Worker def test_custom_names(self, device, bias): 1938*da0073e9SAndroid Build Coastguard Worker pass 1939*da0073e9SAndroid Build Coastguard Worker 1940*da0073e9SAndroid Build Coastguard Worker @parametrize("x,y", [subtest((1, 2), name='double'), 1941*da0073e9SAndroid Build Coastguard Worker subtest((1, 3), name='triple'), 1942*da0073e9SAndroid Build Coastguard Worker subtest((1, 4), name='quadruple')]) 1943*da0073e9SAndroid Build Coastguard Worker def test_two_things_custom_names(self, device, 
x, y): 1944*da0073e9SAndroid Build Coastguard Worker pass 1945*da0073e9SAndroid Build Coastguard Worker 1946*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 1947*da0073e9SAndroid Build Coastguard Worker 1948*da0073e9SAndroid Build Coastguard Worker device_cls = locals()[f'TestParametrized{device.upper()}'] 1949*da0073e9SAndroid Build Coastguard Worker expected_test_names = [name.format(device_cls.__name__, device) for name in ( 1950*da0073e9SAndroid Build Coastguard Worker '{}.test_custom_names_bias_{}', 1951*da0073e9SAndroid Build Coastguard Worker '{}.test_custom_names_no_bias_{}', 1952*da0073e9SAndroid Build Coastguard Worker '{}.test_two_things_custom_names_double_{}', 1953*da0073e9SAndroid Build Coastguard Worker '{}.test_two_things_custom_names_quadruple_{}', 1954*da0073e9SAndroid Build Coastguard Worker '{}.test_two_things_custom_names_triple_{}') 1955*da0073e9SAndroid Build Coastguard Worker ] 1956*da0073e9SAndroid Build Coastguard Worker test_names = _get_test_names_for_test_class(device_cls) 1957*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_test_names, test_names) 1958*da0073e9SAndroid Build Coastguard Worker 1959*da0073e9SAndroid Build Coastguard Worker def test_ops_composition_names(self, device): 1960*da0073e9SAndroid Build Coastguard Worker device = self.device_type 1961*da0073e9SAndroid Build Coastguard Worker 1962*da0073e9SAndroid Build Coastguard Worker class TestParametrized(TestCase): 1963*da0073e9SAndroid Build Coastguard Worker @ops(op_db) 1964*da0073e9SAndroid Build Coastguard Worker @parametrize("flag", [False, True], lambda f: 'flag_enabled' if f else 'flag_disabled') 1965*da0073e9SAndroid Build Coastguard Worker def test_op_parametrized(self, device, dtype, op, flag): 1966*da0073e9SAndroid Build Coastguard Worker pass 1967*da0073e9SAndroid Build Coastguard Worker 1968*da0073e9SAndroid Build Coastguard Worker 
instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 1969*da0073e9SAndroid Build Coastguard Worker 1970*da0073e9SAndroid Build Coastguard Worker device_cls = locals()[f'TestParametrized{device.upper()}'] 1971*da0073e9SAndroid Build Coastguard Worker expected_test_names = [] 1972*da0073e9SAndroid Build Coastguard Worker for op in op_db: 1973*da0073e9SAndroid Build Coastguard Worker for dtype in op.supported_dtypes(torch.device(device).type): 1974*da0073e9SAndroid Build Coastguard Worker for flag_part in ('flag_disabled', 'flag_enabled'): 1975*da0073e9SAndroid Build Coastguard Worker expected_name = f'{device_cls.__name__}.test_op_parametrized_{op.formatted_name}_{flag_part}_{device}_{dtype_name(dtype)}' # noqa: B950 1976*da0073e9SAndroid Build Coastguard Worker expected_test_names.append(expected_name) 1977*da0073e9SAndroid Build Coastguard Worker 1978*da0073e9SAndroid Build Coastguard Worker test_names = _get_test_names_for_test_class(device_cls) 1979*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sorted(expected_test_names), sorted(test_names)) 1980*da0073e9SAndroid Build Coastguard Worker 1981*da0073e9SAndroid Build Coastguard Worker def test_modules_composition_names(self, device): 1982*da0073e9SAndroid Build Coastguard Worker device = self.device_type 1983*da0073e9SAndroid Build Coastguard Worker 1984*da0073e9SAndroid Build Coastguard Worker class TestParametrized(TestCase): 1985*da0073e9SAndroid Build Coastguard Worker @modules(module_db) 1986*da0073e9SAndroid Build Coastguard Worker @parametrize("flag", [False, True], lambda f: 'flag_enabled' if f else 'flag_disabled') 1987*da0073e9SAndroid Build Coastguard Worker def test_module_parametrized(self, device, dtype, module_info, training, flag): 1988*da0073e9SAndroid Build Coastguard Worker pass 1989*da0073e9SAndroid Build Coastguard Worker 1990*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 
1991*da0073e9SAndroid Build Coastguard Worker 1992*da0073e9SAndroid Build Coastguard Worker device_cls = locals()[f'TestParametrized{device.upper()}'] 1993*da0073e9SAndroid Build Coastguard Worker expected_test_names = [] 1994*da0073e9SAndroid Build Coastguard Worker for module_info in module_db: 1995*da0073e9SAndroid Build Coastguard Worker for dtype in module_info.dtypes: 1996*da0073e9SAndroid Build Coastguard Worker for flag_part in ('flag_disabled', 'flag_enabled'): 1997*da0073e9SAndroid Build Coastguard Worker expected_train_modes = ( 1998*da0073e9SAndroid Build Coastguard Worker ['train_mode', 'eval_mode'] if module_info.train_and_eval_differ else ['']) 1999*da0073e9SAndroid Build Coastguard Worker for training_part in expected_train_modes: 2000*da0073e9SAndroid Build Coastguard Worker expected_name = '{}.test_module_parametrized_{}{}_{}_{}_{}'.format( 2001*da0073e9SAndroid Build Coastguard Worker device_cls.__name__, module_info.formatted_name, 2002*da0073e9SAndroid Build Coastguard Worker '_' + training_part if len(training_part) > 0 else '', 2003*da0073e9SAndroid Build Coastguard Worker flag_part, device, dtype_name(dtype)) 2004*da0073e9SAndroid Build Coastguard Worker expected_test_names.append(expected_name) 2005*da0073e9SAndroid Build Coastguard Worker 2006*da0073e9SAndroid Build Coastguard Worker test_names = _get_test_names_for_test_class(device_cls) 2007*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sorted(expected_test_names), sorted(test_names)) 2008*da0073e9SAndroid Build Coastguard Worker 2009*da0073e9SAndroid Build Coastguard Worker def test_ops_decorator_applies_op_and_param_specific_decorators(self, device): 2010*da0073e9SAndroid Build Coastguard Worker # Test that decorators can be applied on a per-op / per-param basis. 2011*da0073e9SAndroid Build Coastguard Worker 2012*da0073e9SAndroid Build Coastguard Worker # Create a test op, OpInfo entry, and decorator to apply. 
2013*da0073e9SAndroid Build Coastguard Worker def test_op(x): 2014*da0073e9SAndroid Build Coastguard Worker return -x 2015*da0073e9SAndroid Build Coastguard Worker 2016*da0073e9SAndroid Build Coastguard Worker def test_dec(func): 2017*da0073e9SAndroid Build Coastguard Worker func._decorator_applied = True 2018*da0073e9SAndroid Build Coastguard Worker return func 2019*da0073e9SAndroid Build Coastguard Worker 2020*da0073e9SAndroid Build Coastguard Worker test_op_info = OpInfo( 2021*da0073e9SAndroid Build Coastguard Worker 'test_op', 2022*da0073e9SAndroid Build Coastguard Worker op=test_op, 2023*da0073e9SAndroid Build Coastguard Worker dtypes=floating_types(), 2024*da0073e9SAndroid Build Coastguard Worker sample_inputs_func=lambda _: [], 2025*da0073e9SAndroid Build Coastguard Worker decorators=[ 2026*da0073e9SAndroid Build Coastguard Worker DecorateInfo(test_dec, 'TestParametrized', 'test_op_param', 2027*da0073e9SAndroid Build Coastguard Worker device_type='cpu', dtypes=[torch.float64], 2028*da0073e9SAndroid Build Coastguard Worker active_if=lambda p: p['x'] == 2) 2029*da0073e9SAndroid Build Coastguard Worker ]) 2030*da0073e9SAndroid Build Coastguard Worker 2031*da0073e9SAndroid Build Coastguard Worker class TestParametrized(TestCase): 2032*da0073e9SAndroid Build Coastguard Worker @ops(op_db + [test_op_info]) 2033*da0073e9SAndroid Build Coastguard Worker @parametrize("x", [2, 3]) 2034*da0073e9SAndroid Build Coastguard Worker def test_op_param(self, device, dtype, op, x): 2035*da0073e9SAndroid Build Coastguard Worker pass 2036*da0073e9SAndroid Build Coastguard Worker 2037*da0073e9SAndroid Build Coastguard Worker @ops(op_db + [test_op_info]) 2038*da0073e9SAndroid Build Coastguard Worker @parametrize("y", [ 2039*da0073e9SAndroid Build Coastguard Worker subtest(4), 2040*da0073e9SAndroid Build Coastguard Worker subtest(5, decorators=[test_dec])]) 2041*da0073e9SAndroid Build Coastguard Worker def test_other(self, device, dtype, op, y): 2042*da0073e9SAndroid Build Coastguard 
Worker pass 2043*da0073e9SAndroid Build Coastguard Worker 2044*da0073e9SAndroid Build Coastguard Worker @decorateIf(test_dec, lambda p: p['dtype'] == torch.int16) 2045*da0073e9SAndroid Build Coastguard Worker @ops(op_db) 2046*da0073e9SAndroid Build Coastguard Worker def test_three(self, device, dtype, op): 2047*da0073e9SAndroid Build Coastguard Worker pass 2048*da0073e9SAndroid Build Coastguard Worker 2049*da0073e9SAndroid Build Coastguard Worker device = self.device_type 2050*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 2051*da0073e9SAndroid Build Coastguard Worker device_cls = locals()[f'TestParametrized{device.upper()}'] 2052*da0073e9SAndroid Build Coastguard Worker 2053*da0073e9SAndroid Build Coastguard Worker for test_func, name in _get_test_funcs_for_test_class(device_cls): 2054*da0073e9SAndroid Build Coastguard Worker should_apply = (name == 'test_op_param_test_op_x_2_cpu_float64' or 2055*da0073e9SAndroid Build Coastguard Worker ('test_other' in name and 'y_5' in name) or 2056*da0073e9SAndroid Build Coastguard Worker ('test_three' in name and name.endswith('_int16'))) 2057*da0073e9SAndroid Build Coastguard Worker self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply) 2058*da0073e9SAndroid Build Coastguard Worker 2059*da0073e9SAndroid Build Coastguard Worker def test_modules_decorator_applies_module_and_param_specific_decorators(self, device): 2060*da0073e9SAndroid Build Coastguard Worker # Test that decorators can be applied on a per-module / per-param basis. 2061*da0073e9SAndroid Build Coastguard Worker 2062*da0073e9SAndroid Build Coastguard Worker # Create a test module, ModuleInfo entry, and decorator to apply. 
2063*da0073e9SAndroid Build Coastguard Worker class TestModule(torch.nn.Module): 2064*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2065*da0073e9SAndroid Build Coastguard Worker super().__init__() 2066*da0073e9SAndroid Build Coastguard Worker self.x = torch.nn.Parameter(torch.randn(3)) 2067*da0073e9SAndroid Build Coastguard Worker 2068*da0073e9SAndroid Build Coastguard Worker def forward(self, y): 2069*da0073e9SAndroid Build Coastguard Worker return self.x + y 2070*da0073e9SAndroid Build Coastguard Worker 2071*da0073e9SAndroid Build Coastguard Worker def test_dec(func): 2072*da0073e9SAndroid Build Coastguard Worker func._decorator_applied = True 2073*da0073e9SAndroid Build Coastguard Worker return func 2074*da0073e9SAndroid Build Coastguard Worker 2075*da0073e9SAndroid Build Coastguard Worker test_module_info = ModuleInfo( 2076*da0073e9SAndroid Build Coastguard Worker TestModule, 2077*da0073e9SAndroid Build Coastguard Worker module_inputs_func=lambda _: [], 2078*da0073e9SAndroid Build Coastguard Worker decorators=[ 2079*da0073e9SAndroid Build Coastguard Worker DecorateInfo(test_dec, 'TestParametrized', 'test_module_param', 2080*da0073e9SAndroid Build Coastguard Worker device_type='cpu', dtypes=[torch.float64], 2081*da0073e9SAndroid Build Coastguard Worker active_if=lambda p: p['x'] == 2) 2082*da0073e9SAndroid Build Coastguard Worker ]) 2083*da0073e9SAndroid Build Coastguard Worker 2084*da0073e9SAndroid Build Coastguard Worker class TestParametrized(TestCase): 2085*da0073e9SAndroid Build Coastguard Worker @modules(module_db + [test_module_info]) 2086*da0073e9SAndroid Build Coastguard Worker @parametrize("x", [2, 3]) 2087*da0073e9SAndroid Build Coastguard Worker def test_module_param(self, device, dtype, module_info, training, x): 2088*da0073e9SAndroid Build Coastguard Worker pass 2089*da0073e9SAndroid Build Coastguard Worker 2090*da0073e9SAndroid Build Coastguard Worker @modules(module_db + [test_module_info]) 2091*da0073e9SAndroid Build 
Coastguard Worker @parametrize("y", [ 2092*da0073e9SAndroid Build Coastguard Worker subtest(4), 2093*da0073e9SAndroid Build Coastguard Worker subtest(5, decorators=[test_dec])]) 2094*da0073e9SAndroid Build Coastguard Worker def test_other(self, device, dtype, module_info, training, y): 2095*da0073e9SAndroid Build Coastguard Worker pass 2096*da0073e9SAndroid Build Coastguard Worker 2097*da0073e9SAndroid Build Coastguard Worker @decorateIf(test_dec, lambda p: p['dtype'] == torch.float64) 2098*da0073e9SAndroid Build Coastguard Worker @modules(module_db) 2099*da0073e9SAndroid Build Coastguard Worker def test_three(self, device, dtype, module_info): 2100*da0073e9SAndroid Build Coastguard Worker pass 2101*da0073e9SAndroid Build Coastguard Worker 2102*da0073e9SAndroid Build Coastguard Worker device = self.device_type 2103*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 2104*da0073e9SAndroid Build Coastguard Worker device_cls = locals()[f'TestParametrized{device.upper()}'] 2105*da0073e9SAndroid Build Coastguard Worker 2106*da0073e9SAndroid Build Coastguard Worker for test_func, name in _get_test_funcs_for_test_class(device_cls): 2107*da0073e9SAndroid Build Coastguard Worker should_apply = (name == 'test_module_param_TestModule_x_2_cpu_float64' or 2108*da0073e9SAndroid Build Coastguard Worker ('test_other' in name and 'y_5' in name) or 2109*da0073e9SAndroid Build Coastguard Worker ('test_three' in name and name.endswith('float64'))) 2110*da0073e9SAndroid Build Coastguard Worker self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply) 2111*da0073e9SAndroid Build Coastguard Worker 2112*da0073e9SAndroid Build Coastguard Worker def test_param_specific_decoration(self, device): 2113*da0073e9SAndroid Build Coastguard Worker 2114*da0073e9SAndroid Build Coastguard Worker def test_dec(func): 2115*da0073e9SAndroid Build Coastguard Worker func._decorator_applied = True 2116*da0073e9SAndroid Build 
Coastguard Worker return func 2117*da0073e9SAndroid Build Coastguard Worker 2118*da0073e9SAndroid Build Coastguard Worker class TestParametrized(TestCase): 2119*da0073e9SAndroid Build Coastguard Worker @decorateIf(test_dec, lambda params: params["x"] == 1 and params["y"]) 2120*da0073e9SAndroid Build Coastguard Worker @parametrize("x", range(5)) 2121*da0073e9SAndroid Build Coastguard Worker @parametrize("y", [False, True]) 2122*da0073e9SAndroid Build Coastguard Worker def test_param(self, x, y): 2123*da0073e9SAndroid Build Coastguard Worker pass 2124*da0073e9SAndroid Build Coastguard Worker 2125*da0073e9SAndroid Build Coastguard Worker device = self.device_type 2126*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 2127*da0073e9SAndroid Build Coastguard Worker device_cls = locals()[f'TestParametrized{device.upper()}'] 2128*da0073e9SAndroid Build Coastguard Worker 2129*da0073e9SAndroid Build Coastguard Worker for test_func, name in _get_test_funcs_for_test_class(device_cls): 2130*da0073e9SAndroid Build Coastguard Worker should_apply = ('test_param_x_1_y_True' in name) 2131*da0073e9SAndroid Build Coastguard Worker self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply) 2132*da0073e9SAndroid Build Coastguard Worker 2133*da0073e9SAndroid Build Coastguard Worker def test_dtypes_composition_valid(self, device): 2134*da0073e9SAndroid Build Coastguard Worker # Test checks that @parametrize and @dtypes compose as expected when @parametrize 2135*da0073e9SAndroid Build Coastguard Worker # doesn't set dtype. 
2136*da0073e9SAndroid Build Coastguard Worker 2137*da0073e9SAndroid Build Coastguard Worker device = self.device_type 2138*da0073e9SAndroid Build Coastguard Worker 2139*da0073e9SAndroid Build Coastguard Worker class TestParametrized(TestCase): 2140*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32, torch.float64) 2141*da0073e9SAndroid Build Coastguard Worker @parametrize("x", range(3)) 2142*da0073e9SAndroid Build Coastguard Worker def test_parametrized(self, x, dtype): 2143*da0073e9SAndroid Build Coastguard Worker pass 2144*da0073e9SAndroid Build Coastguard Worker 2145*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 2146*da0073e9SAndroid Build Coastguard Worker 2147*da0073e9SAndroid Build Coastguard Worker device_cls = locals()[f'TestParametrized{device.upper()}'] 2148*da0073e9SAndroid Build Coastguard Worker expected_test_names = [name.format(device_cls.__name__, device) for name in ( 2149*da0073e9SAndroid Build Coastguard Worker '{}.test_parametrized_x_0_{}_float32', 2150*da0073e9SAndroid Build Coastguard Worker '{}.test_parametrized_x_0_{}_float64', 2151*da0073e9SAndroid Build Coastguard Worker '{}.test_parametrized_x_1_{}_float32', 2152*da0073e9SAndroid Build Coastguard Worker '{}.test_parametrized_x_1_{}_float64', 2153*da0073e9SAndroid Build Coastguard Worker '{}.test_parametrized_x_2_{}_float32', 2154*da0073e9SAndroid Build Coastguard Worker '{}.test_parametrized_x_2_{}_float64') 2155*da0073e9SAndroid Build Coastguard Worker ] 2156*da0073e9SAndroid Build Coastguard Worker test_names = _get_test_names_for_test_class(device_cls) 2157*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sorted(expected_test_names), sorted(test_names)) 2158*da0073e9SAndroid Build Coastguard Worker 2159*da0073e9SAndroid Build Coastguard Worker def test_dtypes_composition_invalid(self, device): 2160*da0073e9SAndroid Build Coastguard Worker # Test checks that @dtypes cannot be composed with 
parametrization decorators when they 2161*da0073e9SAndroid Build Coastguard Worker # also try to set dtype. 2162*da0073e9SAndroid Build Coastguard Worker 2163*da0073e9SAndroid Build Coastguard Worker device = self.device_type 2164*da0073e9SAndroid Build Coastguard Worker 2165*da0073e9SAndroid Build Coastguard Worker class TestParametrized(TestCase): 2166*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32, torch.float64) 2167*da0073e9SAndroid Build Coastguard Worker @parametrize("dtype", [torch.int32, torch.int64]) 2168*da0073e9SAndroid Build Coastguard Worker def test_parametrized(self, dtype): 2169*da0073e9SAndroid Build Coastguard Worker pass 2170*da0073e9SAndroid Build Coastguard Worker 2171*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "handled multiple times"): 2172*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 2173*da0073e9SAndroid Build Coastguard Worker 2174*da0073e9SAndroid Build Coastguard Worker # Verify proper error behavior with @ops + @dtypes, as both try to set dtype. 
2175*da0073e9SAndroid Build Coastguard Worker 2176*da0073e9SAndroid Build Coastguard Worker class TestParametrized(TestCase): 2177*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32, torch.float64) 2178*da0073e9SAndroid Build Coastguard Worker @ops(op_db) 2179*da0073e9SAndroid Build Coastguard Worker def test_parametrized(self, op, dtype): 2180*da0073e9SAndroid Build Coastguard Worker pass 2181*da0073e9SAndroid Build Coastguard Worker 2182*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "handled multiple times"): 2183*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 2184*da0073e9SAndroid Build Coastguard Worker 2185*da0073e9SAndroid Build Coastguard Worker def test_multiple_handling_of_same_param_error(self, device): 2186*da0073e9SAndroid Build Coastguard Worker # Test that multiple decorators handling the same param errors out. 2187*da0073e9SAndroid Build Coastguard Worker # Both @modules and @ops handle the dtype param. 
2188*da0073e9SAndroid Build Coastguard Worker 2189*da0073e9SAndroid Build Coastguard Worker class TestParametrized(TestCase): 2190*da0073e9SAndroid Build Coastguard Worker @ops(op_db) 2191*da0073e9SAndroid Build Coastguard Worker @modules(module_db) 2192*da0073e9SAndroid Build Coastguard Worker def test_param(self, device, dtype, op, module_info, training): 2193*da0073e9SAndroid Build Coastguard Worker pass 2194*da0073e9SAndroid Build Coastguard Worker 2195*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "handled multiple times"): 2196*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests(TestParametrized, locals(), only_for=device) 2197*da0073e9SAndroid Build Coastguard Worker 2198*da0073e9SAndroid Build Coastguard Worker @parametrize("x", [1, subtest(2, decorators=[unittest.expectedFailure]), 3]) 2199*da0073e9SAndroid Build Coastguard Worker def test_subtest_expected_failure(self, device, x): 2200*da0073e9SAndroid Build Coastguard Worker if x == 2: 2201*da0073e9SAndroid Build Coastguard Worker raise RuntimeError('Boom') 2202*da0073e9SAndroid Build Coastguard Worker 2203*da0073e9SAndroid Build Coastguard Worker @parametrize("x", [subtest(1, decorators=[unittest.expectedFailure]), 2, 3]) 2204*da0073e9SAndroid Build Coastguard Worker @parametrize("y", [4, 5, subtest(6, decorators=[unittest.expectedFailure])]) 2205*da0073e9SAndroid Build Coastguard Worker def test_two_things_subtest_expected_failure(self, device, x, y): 2206*da0073e9SAndroid Build Coastguard Worker if x == 1 or y == 6: 2207*da0073e9SAndroid Build Coastguard Worker raise RuntimeError('Boom') 2208*da0073e9SAndroid Build Coastguard Worker 2209*da0073e9SAndroid Build Coastguard Worker 2210*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestTestParametrization) 2211*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestTestParametrizationDeviceType, globals()) 2212*da0073e9SAndroid Build Coastguard Worker 
class TestImports(TestCase):
    # Import-time behaviour checks. Most of them run a snippet in a fresh
    # interpreter via _check_python_output.
    @classmethod
    def _check_python_output(cls, program) -> str:
        """Run ``program`` via ``python -W always -c`` and return its output.

        stderr is merged into stdout, so emitted warnings appear in the result.
        subprocess.check_output raises CalledProcessError on a non-zero exit.
        """
        return subprocess.check_output(
            [sys.executable, "-W", "always", "-c", program],
            stderr=subprocess.STDOUT,
            # On Windows, opening the subprocess with the default CWD makes `import torch`
            # fail, so just set CWD to this script's directory
            cwd=os.path.dirname(os.path.realpath(__file__)),).decode("utf-8")

    def test_circular_dependencies(self) -> None:
        """ Checks that all modules inside torch can be imported
        Prevents regression reported in https://github.com/pytorch/pytorch/issues/77441 """
        ignored_modules = ["torch.utils.tensorboard",  # deps on tensorboard
                           "torch.distributed.elastic.rendezvous",  # deps on etcd
                           "torch.backends._coreml",  # depends on pycoreml
                           "torch.contrib.",  # something weird
                           "torch.testing._internal.distributed.",  # just fails
                           "torch.ao.pruning._experimental.",  # depends on pytorch_lightning, not user-facing
                           "torch.onnx._internal",  # depends on onnx-script
                           "torch._inductor.runtime.triton_helpers",  # depends on triton
                           "torch._inductor.codegen.cuda",  # depends on cutlass
                           ]
        # See https://github.com/pytorch/pytorch/issues/77801
        if sys.version_info < (3, 9):
            ignored_modules.append("torch.utils.benchmark")
        if IS_WINDOWS or IS_MACOS or IS_JETSON:
            # Distributed should be importable on Windows (except nn.api., optim. and rpc.),
            # but not at all on Mac or Jetson.
            if IS_MACOS or IS_JETSON:
                ignored_modules.append("torch.distributed.")
            else:
                ignored_modules.extend([
                    "torch.distributed.nn.api.",
                    "torch.distributed.optim.",
                    "torch.distributed.rpc.",
                ])
            ignored_modules.append("torch.testing._internal.dist_utils")
            # And these end up with transitive dependencies on distributed
            ignored_modules.extend([
                "torch.nn.parallel._replicated_tensor_ddp_interop",
                "torch.testing._internal.common_fsdp",
                "torch.testing._internal.common_distributed",
            ])

        # str.startswith accepts a tuple of prefixes; build it once, outside the walk.
        ignored_prefixes = tuple(ignored_modules)
        torch_dir = os.path.dirname(torch.__file__)
        for base, _folders, files in os.walk(torch_dir):
            # Dotted package path of `base`, relative to the directory containing torch/.
            prefix = os.path.relpath(base, os.path.dirname(torch_dir)).replace(os.path.sep, ".")
            for f in files:
                # Only Python sources; do not attempt to import executable modules.
                if not f.endswith(".py") or f == "__main__.py":
                    continue
                mod_name = f"{prefix}.{f[:-3]}" if f != "__init__.py" else prefix
                if mod_name.startswith(ignored_prefixes):
                    continue
                try:
                    mod = importlib.import_module(mod_name)
                except Exception as e:
                    raise RuntimeError(f"Failed to import {mod_name}: {e}") from e
                self.assertTrue(inspect.ismodule(mod))

    @unittest.skipIf(IS_WINDOWS, "TODO enable on Windows")
    def test_lazy_imports_are_lazy(self) -> None:
        # None of torch._lazy_modules may be loaded as a side effect of `import torch`.
        out = self._check_python_output("import sys;import torch;print(all(x not in sys.modules for x in torch._lazy_modules))")
        self.assertEqual(out.strip(), "True")

    @unittest.skipIf(IS_WINDOWS, "importing torch+CUDA on CPU results in warning")
    def test_no_warning_on_import(self) -> None:
        # With `-W always`, any warning raised during import would show up in `out`.
        out = self._check_python_output("import torch")
        self.assertEqual(out, "")

    def test_not_import_sympy(self) -> None:
        out = self._check_python_output("import torch;import sys;print('sympy' not in sys.modules)")
        self.assertEqual(out.strip(), "True",
                         "PyTorch should not depend on SymPy at import time as importing SymPy is *very* slow.\n"
                         "See the beginning of the following blog post for how to profile and find which file is importing sympy:\n"
                         "https://dev-discuss.pytorch.org/t/delving-into-what-happens-when-you-import-torch/1589\n\n"
                         "If you hit this error, you may want to:\n"
                         "  - Refactor your code to avoid depending on sympy files you may not need to depend\n"
                         "  - Use TYPE_CHECKING if you are using sympy + strings if you are using sympy on type annotations\n"
                         "  - Import things that depend on SymPy locally")

    @unittest.skipIf(IS_WINDOWS, "importing torch+CUDA on CPU results in warning")
    @parametrize('path', ['torch', 'functorch'])
    def test_no_mutate_global_logging_on_import(self, path) -> None:
        # Calling logging.basicConfig, among other things, modifies the global
        # logging state. It is not OK to modify the global logging state on
        # `import torch` (or other submodules we own) because users do not expect it.
        expected = 'abcdefghijklmnopqrstuvwxyz'
        commands = [
            'import logging',
            f'import {path}',
            '_logger = logging.getLogger("torch_test_testing")',
            'logging.root.addHandler(logging.StreamHandler())',
            'logging.root.setLevel(logging.INFO)',
            f'_logger.info("{expected}")'
        ]
        out = self._check_python_output("; ".join(commands))
        self.assertEqual(out.strip(), expected)


class TestOpInfos(TestCase):
    # unittest assertion methods are used rather than bare `assert`: bare asserts
    # are removed under `python -O`, which would make these checks vacuous.
    def test_sample_input(self) -> None:
        a, b, c, d, e = (object() for _ in range(5))

        # Construction with natural syntax
        s = SampleInput(a, b, c, d=d, e=e)
        self.assertIs(s.input, a)
        self.assertTupleEqual(s.args, (b, c))
        self.assertDictEqual(s.kwargs, dict(d=d, e=e))

        # Construction with explicit args and kwargs
        s = SampleInput(a, args=(b,), kwargs=dict(c=c, d=d, e=e))
        self.assertIs(s.input, a)
        self.assertTupleEqual(s.args, (b,))
        self.assertDictEqual(s.kwargs, dict(c=c, d=d, e=e))

        # Construction with a mixed form will error
        with self.assertRaises(AssertionError):
            SampleInput(a, b, c, args=(d, e))

        with self.assertRaises(AssertionError):
            SampleInput(a, b, c, kwargs=dict(d=d, e=e))

        with self.assertRaises(AssertionError):
            SampleInput(a, args=(b, c), d=d, e=e)

        with self.assertRaises(AssertionError):
            SampleInput(a, b, c=c, kwargs=dict(d=d, e=e))

        # Mixing metadata into "natural" construction will error
        with self.assertRaises(AssertionError):
            SampleInput(a, b, name="foo")

        with self.assertRaises(AssertionError):
            SampleInput(a, b, output_process_fn_grad=lambda x: x)

        with self.assertRaises(AssertionError):
            SampleInput(a, b, broadcasts_input=True)

        # But when only input is given, metadata is allowed for backward
        # compatibility
        s = SampleInput(a, broadcasts_input=True)
        self.assertIs(s.input, a)
        self.assertTrue(s.broadcasts_input)

    def test_sample_input_metadata(self) -> None:
        a, b = (object() for _ in range(2))
        s1 = SampleInput(a, b=b)
        self.assertIs(s1.output_process_fn_grad(None), None)
        self.assertFalse(s1.broadcasts_input)
        self.assertEqual(s1.name, "")

        # with_metadata updates the sample in place and returns the same object.
        s2 = s1.with_metadata(
            output_process_fn_grad=lambda x: a,
            broadcasts_input=True,
            name="foo",
        )
        self.assertIs(s1, s2)
        self.assertIs(s2.output_process_fn_grad(None), a)
        self.assertTrue(s2.broadcasts_input)
        self.assertEqual(s2.name, "foo")


# Tests that validate the various sample generating functions on each OpInfo.
class TestOpInfoSampleFunctions(TestCase):
    """Checks the return type of each OpInfo's sample/reference/error input generators."""

    @ops(op_db, dtypes=OpDTypes.any_one)
    def test_opinfo_sample_generators(self, device, dtype, op):
        # op.sample_inputs must return an Iterator (e.g. a generator), not a
        # materialized sequence such as a list of samples.
        samples = op.sample_inputs(device, dtype)
        self.assertIsInstance(samples, Iterator)

    @ops([op for op in op_db if op.reference_inputs_func is not None], dtypes=OpDTypes.any_one)
    def test_opinfo_reference_generators(self, device, dtype, op):
        # Same contract for op.reference_inputs, checked only for ops that define
        # a reference_inputs_func.
        samples = op.reference_inputs(device, dtype)
        self.assertIsInstance(samples, Iterator)

    @ops([op for op in op_db if op.error_inputs_func is not None], dtypes=OpDTypes.none)
    def test_opinfo_error_generators(self, device, op):
        # Same contract for op.error_inputs, which takes no dtype; checked only for
        # ops that define an error_inputs_func.
        samples = op.error_inputs(device)
        self.assertIsInstance(samples, Iterator)


instantiate_device_type_tests(TestOpInfoSampleFunctions, globals())
# Expand TestImports' @parametrize-decorated methods (such as
# test_no_mutate_global_logging_on_import) into concrete test methods.
instantiate_parametrized_tests(TestImports)


# Script entry point: hand control to the shared PyTorch test runner.
if __name__ == '__main__':
    run_tests()