xref: /aosp_15_r20/external/pytorch/test/test_testing.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: tests"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport collections
4*da0073e9SAndroid Build Coastguard Workerimport doctest
5*da0073e9SAndroid Build Coastguard Workerimport functools
6*da0073e9SAndroid Build Coastguard Workerimport importlib
7*da0073e9SAndroid Build Coastguard Workerimport inspect
8*da0073e9SAndroid Build Coastguard Workerimport itertools
9*da0073e9SAndroid Build Coastguard Workerimport math
10*da0073e9SAndroid Build Coastguard Workerimport os
11*da0073e9SAndroid Build Coastguard Workerimport re
12*da0073e9SAndroid Build Coastguard Workerimport subprocess
13*da0073e9SAndroid Build Coastguard Workerimport sys
14*da0073e9SAndroid Build Coastguard Workerimport unittest.mock
15*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Callable, Iterator, List, Tuple
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Workerimport torch
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import make_tensor
20*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import \
21*da0073e9SAndroid Build Coastguard Worker    (IS_FBCODE, IS_JETSON, IS_MACOS, IS_SANDCASTLE, IS_WINDOWS, TestCase, run_tests, slowTest,
22*da0073e9SAndroid Build Coastguard Worker     parametrize, subtest, instantiate_parametrized_tests, dtype_name, TEST_WITH_ROCM, decorateIf)
23*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import \
24*da0073e9SAndroid Build Coastguard Worker    (PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes,
25*da0073e9SAndroid Build Coastguard Worker     get_device_type_test_bases, instantiate_device_type_tests, onlyCPU, onlyCUDA, onlyNativeDeviceTypes,
26*da0073e9SAndroid Build Coastguard Worker     deviceCountAtLeast, ops, expectedFailureMeta, OpDTypes)
27*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_methods_invocations import op_db
28*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal import opinfo
29*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_dtype import all_types_and_complex_and, floating_types
30*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_modules import modules, module_db, ModuleInfo
31*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.opinfo.core import SampleInput, DecorateInfo, OpInfo
32*da0073e9SAndroid Build Coastguard Workerimport operator
33*da0073e9SAndroid Build Coastguard Worker
34*da0073e9SAndroid Build Coastguard Worker# For testing TestCase methods and torch.testing functions
35*da0073e9SAndroid Build Coastguard Workerclass TestTesting(TestCase):
36*da0073e9SAndroid Build Coastguard Worker    # Ensure that assertEqual handles numpy arrays properly
37*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.bool, torch.half))
38*da0073e9SAndroid Build Coastguard Worker    def test_assertEqual_numpy(self, device, dtype):
39*da0073e9SAndroid Build Coastguard Worker        S = 10
40*da0073e9SAndroid Build Coastguard Worker        test_sizes = [
41*da0073e9SAndroid Build Coastguard Worker            (),
42*da0073e9SAndroid Build Coastguard Worker            (0,),
43*da0073e9SAndroid Build Coastguard Worker            (S,),
44*da0073e9SAndroid Build Coastguard Worker            (S, S),
45*da0073e9SAndroid Build Coastguard Worker            (0, S),
46*da0073e9SAndroid Build Coastguard Worker            (S, 0)]
47*da0073e9SAndroid Build Coastguard Worker        for test_size in test_sizes:
48*da0073e9SAndroid Build Coastguard Worker            a = make_tensor(test_size, dtype=dtype, device=device, low=-5, high=5)
49*da0073e9SAndroid Build Coastguard Worker            a_n = a.cpu().numpy()
50*da0073e9SAndroid Build Coastguard Worker            msg = f'size: {test_size}'
51*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a_n, a, rtol=0, atol=0, msg=msg)
52*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a, a_n, rtol=0, atol=0, msg=msg)
53*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a_n, a_n, rtol=0, atol=0, msg=msg)
54*da0073e9SAndroid Build Coastguard Worker
55*da0073e9SAndroid Build Coastguard Worker    def test_assertEqual_longMessage(self):
56*da0073e9SAndroid Build Coastguard Worker        actual = "actual"
57*da0073e9SAndroid Build Coastguard Worker        expected = "expected"
58*da0073e9SAndroid Build Coastguard Worker
59*da0073e9SAndroid Build Coastguard Worker        long_message = self.longMessage
60*da0073e9SAndroid Build Coastguard Worker        try:
61*da0073e9SAndroid Build Coastguard Worker            # Capture the default error message by forcing TestCase.longMessage = False
62*da0073e9SAndroid Build Coastguard Worker            self.longMessage = False
63*da0073e9SAndroid Build Coastguard Worker            try:
64*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, expected)
65*da0073e9SAndroid Build Coastguard Worker            except AssertionError as error:
66*da0073e9SAndroid Build Coastguard Worker                default_msg = str(error)
67*da0073e9SAndroid Build Coastguard Worker            else:
68*da0073e9SAndroid Build Coastguard Worker                raise AssertionError("AssertionError not raised")
69*da0073e9SAndroid Build Coastguard Worker
70*da0073e9SAndroid Build Coastguard Worker            self.longMessage = True
71*da0073e9SAndroid Build Coastguard Worker            extra_msg = "sentinel"
72*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape(f"{default_msg}\n{extra_msg}")):
73*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, expected, msg=extra_msg)
74*da0073e9SAndroid Build Coastguard Worker        finally:
75*da0073e9SAndroid Build Coastguard Worker            self.longMessage = long_message
76*da0073e9SAndroid Build Coastguard Worker
77*da0073e9SAndroid Build Coastguard Worker    def _isclose_helper(self, tests, device, dtype, equal_nan, atol=1e-08, rtol=1e-05):
78*da0073e9SAndroid Build Coastguard Worker        for test in tests:
79*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor((test[0],), device=device, dtype=dtype)
80*da0073e9SAndroid Build Coastguard Worker            b = torch.tensor((test[1],), device=device, dtype=dtype)
81*da0073e9SAndroid Build Coastguard Worker
82*da0073e9SAndroid Build Coastguard Worker            actual = torch.isclose(a, b, equal_nan=equal_nan, atol=atol, rtol=rtol)
83*da0073e9SAndroid Build Coastguard Worker            expected = test[2]
84*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual.item(), expected)
85*da0073e9SAndroid Build Coastguard Worker
86*da0073e9SAndroid Build Coastguard Worker    def test_isclose_bool(self, device):
87*da0073e9SAndroid Build Coastguard Worker        tests = (
88*da0073e9SAndroid Build Coastguard Worker            (True, True, True),
89*da0073e9SAndroid Build Coastguard Worker            (False, False, True),
90*da0073e9SAndroid Build Coastguard Worker            (True, False, False),
91*da0073e9SAndroid Build Coastguard Worker            (False, True, False),
92*da0073e9SAndroid Build Coastguard Worker        )
93*da0073e9SAndroid Build Coastguard Worker
94*da0073e9SAndroid Build Coastguard Worker        self._isclose_helper(tests, device, torch.bool, False)
95*da0073e9SAndroid Build Coastguard Worker
96*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.uint8,
97*da0073e9SAndroid Build Coastguard Worker            torch.int8, torch.int16, torch.int32, torch.int64)
98*da0073e9SAndroid Build Coastguard Worker    def test_isclose_integer(self, device, dtype):
99*da0073e9SAndroid Build Coastguard Worker        tests = (
100*da0073e9SAndroid Build Coastguard Worker            (0, 0, True),
101*da0073e9SAndroid Build Coastguard Worker            (0, 1, False),
102*da0073e9SAndroid Build Coastguard Worker            (1, 0, False),
103*da0073e9SAndroid Build Coastguard Worker        )
104*da0073e9SAndroid Build Coastguard Worker
105*da0073e9SAndroid Build Coastguard Worker        self._isclose_helper(tests, device, dtype, False)
106*da0073e9SAndroid Build Coastguard Worker
107*da0073e9SAndroid Build Coastguard Worker        # atol and rtol tests
108*da0073e9SAndroid Build Coastguard Worker        tests = [
109*da0073e9SAndroid Build Coastguard Worker            (0, 1, True),
110*da0073e9SAndroid Build Coastguard Worker            (1, 0, False),
111*da0073e9SAndroid Build Coastguard Worker            (1, 3, True),
112*da0073e9SAndroid Build Coastguard Worker        ]
113*da0073e9SAndroid Build Coastguard Worker
114*da0073e9SAndroid Build Coastguard Worker        self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)
115*da0073e9SAndroid Build Coastguard Worker
116*da0073e9SAndroid Build Coastguard Worker        if dtype is torch.uint8:
117*da0073e9SAndroid Build Coastguard Worker            tests = [
118*da0073e9SAndroid Build Coastguard Worker                (-1, 1, False),
119*da0073e9SAndroid Build Coastguard Worker                (1, -1, False)
120*da0073e9SAndroid Build Coastguard Worker            ]
121*da0073e9SAndroid Build Coastguard Worker        else:
122*da0073e9SAndroid Build Coastguard Worker            tests = [
123*da0073e9SAndroid Build Coastguard Worker                (-1, 1, True),
124*da0073e9SAndroid Build Coastguard Worker                (1, -1, True)
125*da0073e9SAndroid Build Coastguard Worker            ]
126*da0073e9SAndroid Build Coastguard Worker
127*da0073e9SAndroid Build Coastguard Worker        self._isclose_helper(tests, device, dtype, False, atol=1.5, rtol=.5)
128*da0073e9SAndroid Build Coastguard Worker
129*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
130*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float16, torch.float32, torch.float64)
131*da0073e9SAndroid Build Coastguard Worker    def test_isclose_float(self, device, dtype):
132*da0073e9SAndroid Build Coastguard Worker        tests = (
133*da0073e9SAndroid Build Coastguard Worker            (0, 0, True),
134*da0073e9SAndroid Build Coastguard Worker            (0, -1, False),
135*da0073e9SAndroid Build Coastguard Worker            (float('inf'), float('inf'), True),
136*da0073e9SAndroid Build Coastguard Worker            (-float('inf'), float('inf'), False),
137*da0073e9SAndroid Build Coastguard Worker            (float('inf'), float('nan'), False),
138*da0073e9SAndroid Build Coastguard Worker            (float('nan'), float('nan'), False),
139*da0073e9SAndroid Build Coastguard Worker            (0, float('nan'), False),
140*da0073e9SAndroid Build Coastguard Worker            (1, 1, True),
141*da0073e9SAndroid Build Coastguard Worker        )
142*da0073e9SAndroid Build Coastguard Worker
143*da0073e9SAndroid Build Coastguard Worker        self._isclose_helper(tests, device, dtype, False)
144*da0073e9SAndroid Build Coastguard Worker
145*da0073e9SAndroid Build Coastguard Worker        # atol and rtol tests
146*da0073e9SAndroid Build Coastguard Worker        eps = 1e-2 if dtype is torch.half else 1e-6
147*da0073e9SAndroid Build Coastguard Worker        tests = (
148*da0073e9SAndroid Build Coastguard Worker            (0, 1, True),
149*da0073e9SAndroid Build Coastguard Worker            (0, 1 + eps, False),
150*da0073e9SAndroid Build Coastguard Worker            (1, 0, False),
151*da0073e9SAndroid Build Coastguard Worker            (1, 3, True),
152*da0073e9SAndroid Build Coastguard Worker            (1 - eps, 3, False),
153*da0073e9SAndroid Build Coastguard Worker            (-.25, .5, True),
154*da0073e9SAndroid Build Coastguard Worker            (-.25 - eps, .5, False),
155*da0073e9SAndroid Build Coastguard Worker            (.25, -.5, True),
156*da0073e9SAndroid Build Coastguard Worker            (.25 + eps, -.5, False),
157*da0073e9SAndroid Build Coastguard Worker        )
158*da0073e9SAndroid Build Coastguard Worker
159*da0073e9SAndroid Build Coastguard Worker        self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)
160*da0073e9SAndroid Build Coastguard Worker
161*da0073e9SAndroid Build Coastguard Worker        # equal_nan = True tests
162*da0073e9SAndroid Build Coastguard Worker        tests = (
163*da0073e9SAndroid Build Coastguard Worker            (0, float('nan'), False),
164*da0073e9SAndroid Build Coastguard Worker            (float('inf'), float('nan'), False),
165*da0073e9SAndroid Build Coastguard Worker            (float('nan'), float('nan'), True),
166*da0073e9SAndroid Build Coastguard Worker        )
167*da0073e9SAndroid Build Coastguard Worker
168*da0073e9SAndroid Build Coastguard Worker        self._isclose_helper(tests, device, dtype, True)
169*da0073e9SAndroid Build Coastguard Worker
170*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_SANDCASTLE, "Skipping because doesn't work on sandcastle")
171*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.complex64, torch.complex128)
172*da0073e9SAndroid Build Coastguard Worker    def test_isclose_complex(self, device, dtype):
173*da0073e9SAndroid Build Coastguard Worker        tests = (
174*da0073e9SAndroid Build Coastguard Worker            (complex(1, 1), complex(1, 1 + 1e-8), True),
175*da0073e9SAndroid Build Coastguard Worker            (complex(0, 1), complex(1, 1), False),
176*da0073e9SAndroid Build Coastguard Worker            (complex(1, 1), complex(1, 0), False),
177*da0073e9SAndroid Build Coastguard Worker            (complex(1, 1), complex(1, float('nan')), False),
178*da0073e9SAndroid Build Coastguard Worker            (complex(1, float('nan')), complex(1, float('nan')), False),
179*da0073e9SAndroid Build Coastguard Worker            (complex(1, 1), complex(1, float('inf')), False),
180*da0073e9SAndroid Build Coastguard Worker            (complex(float('inf'), 1), complex(1, float('inf')), False),
181*da0073e9SAndroid Build Coastguard Worker            (complex(-float('inf'), 1), complex(1, float('inf')), False),
182*da0073e9SAndroid Build Coastguard Worker            (complex(-float('inf'), 1), complex(float('inf'), 1), False),
183*da0073e9SAndroid Build Coastguard Worker            (complex(float('inf'), 1), complex(float('inf'), 1), True),
184*da0073e9SAndroid Build Coastguard Worker            (complex(float('inf'), 1), complex(float('inf'), 1 + 1e-4), False),
185*da0073e9SAndroid Build Coastguard Worker        )
186*da0073e9SAndroid Build Coastguard Worker
187*da0073e9SAndroid Build Coastguard Worker        self._isclose_helper(tests, device, dtype, False)
188*da0073e9SAndroid Build Coastguard Worker
189*da0073e9SAndroid Build Coastguard Worker        # atol and rtol tests
190*da0073e9SAndroid Build Coastguard Worker
191*da0073e9SAndroid Build Coastguard Worker        # atol and rtol tests
192*da0073e9SAndroid Build Coastguard Worker        eps = 1e-6
193*da0073e9SAndroid Build Coastguard Worker        tests = (
194*da0073e9SAndroid Build Coastguard Worker            # Complex versions of float tests (real part)
195*da0073e9SAndroid Build Coastguard Worker            (complex(0, 0), complex(1, 0), True),
196*da0073e9SAndroid Build Coastguard Worker            (complex(0, 0), complex(1 + eps, 0), False),
197*da0073e9SAndroid Build Coastguard Worker            (complex(1, 0), complex(0, 0), False),
198*da0073e9SAndroid Build Coastguard Worker            (complex(1, 0), complex(3, 0), True),
199*da0073e9SAndroid Build Coastguard Worker            (complex(1 - eps, 0), complex(3, 0), False),
200*da0073e9SAndroid Build Coastguard Worker            (complex(-.25, 0), complex(.5, 0), True),
201*da0073e9SAndroid Build Coastguard Worker            (complex(-.25 - eps, 0), complex(.5, 0), False),
202*da0073e9SAndroid Build Coastguard Worker            (complex(.25, 0), complex(-.5, 0), True),
203*da0073e9SAndroid Build Coastguard Worker            (complex(.25 + eps, 0), complex(-.5, 0), False),
204*da0073e9SAndroid Build Coastguard Worker            # Complex versions of float tests (imaginary part)
205*da0073e9SAndroid Build Coastguard Worker            (complex(0, 0), complex(0, 1), True),
206*da0073e9SAndroid Build Coastguard Worker            (complex(0, 0), complex(0, 1 + eps), False),
207*da0073e9SAndroid Build Coastguard Worker            (complex(0, 1), complex(0, 0), False),
208*da0073e9SAndroid Build Coastguard Worker            (complex(0, 1), complex(0, 3), True),
209*da0073e9SAndroid Build Coastguard Worker            (complex(0, 1 - eps), complex(0, 3), False),
210*da0073e9SAndroid Build Coastguard Worker            (complex(0, -.25), complex(0, .5), True),
211*da0073e9SAndroid Build Coastguard Worker            (complex(0, -.25 - eps), complex(0, .5), False),
212*da0073e9SAndroid Build Coastguard Worker            (complex(0, .25), complex(0, -.5), True),
213*da0073e9SAndroid Build Coastguard Worker            (complex(0, .25 + eps), complex(0, -.5), False),
214*da0073e9SAndroid Build Coastguard Worker        )
215*da0073e9SAndroid Build Coastguard Worker
216*da0073e9SAndroid Build Coastguard Worker        self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)
217*da0073e9SAndroid Build Coastguard Worker
218*da0073e9SAndroid Build Coastguard Worker        # atol and rtol tests for isclose
219*da0073e9SAndroid Build Coastguard Worker        tests = (
220*da0073e9SAndroid Build Coastguard Worker            # Complex-specific tests
221*da0073e9SAndroid Build Coastguard Worker            (complex(1, -1), complex(-1, 1), False),
222*da0073e9SAndroid Build Coastguard Worker            (complex(1, -1), complex(2, -2), True),
223*da0073e9SAndroid Build Coastguard Worker            (complex(-math.sqrt(2), math.sqrt(2)),
224*da0073e9SAndroid Build Coastguard Worker             complex(-math.sqrt(.5), math.sqrt(.5)), True),
225*da0073e9SAndroid Build Coastguard Worker            (complex(-math.sqrt(2), math.sqrt(2)),
226*da0073e9SAndroid Build Coastguard Worker             complex(-math.sqrt(.501), math.sqrt(.499)), False),
227*da0073e9SAndroid Build Coastguard Worker            (complex(2, 4), complex(1., 8.8523607), True),
228*da0073e9SAndroid Build Coastguard Worker            (complex(2, 4), complex(1., 8.8523607 + eps), False),
229*da0073e9SAndroid Build Coastguard Worker            (complex(1, 99), complex(4, 100), True),
230*da0073e9SAndroid Build Coastguard Worker        )
231*da0073e9SAndroid Build Coastguard Worker        self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)
232*da0073e9SAndroid Build Coastguard Worker
233*da0073e9SAndroid Build Coastguard Worker        # equal_nan = True tests
234*da0073e9SAndroid Build Coastguard Worker        tests = (
235*da0073e9SAndroid Build Coastguard Worker            (complex(1, 1), complex(1, float('nan')), False),
236*da0073e9SAndroid Build Coastguard Worker            (complex(1, 1), complex(float('nan'), 1), False),
237*da0073e9SAndroid Build Coastguard Worker            (complex(float('nan'), 1), complex(float('nan'), 1), True),
238*da0073e9SAndroid Build Coastguard Worker            (complex(float('nan'), 1), complex(1, float('nan')), True),
239*da0073e9SAndroid Build Coastguard Worker            (complex(float('nan'), float('nan')), complex(float('nan'), float('nan')), True),
240*da0073e9SAndroid Build Coastguard Worker        )
241*da0073e9SAndroid Build Coastguard Worker        self._isclose_helper(tests, device, dtype, True)
242*da0073e9SAndroid Build Coastguard Worker
243*da0073e9SAndroid Build Coastguard Worker    # Tests that isclose with rtol or atol values less than zero throws a
244*da0073e9SAndroid Build Coastguard Worker    #   RuntimeError
245*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bool, torch.uint8,
246*da0073e9SAndroid Build Coastguard Worker            torch.int8, torch.int16, torch.int32, torch.int64,
247*da0073e9SAndroid Build Coastguard Worker            torch.float16, torch.float32, torch.float64)
248*da0073e9SAndroid Build Coastguard Worker    def test_isclose_atol_rtol_greater_than_zero(self, device, dtype):
249*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor((1,), device=device, dtype=dtype)
250*da0073e9SAndroid Build Coastguard Worker
251*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
252*da0073e9SAndroid Build Coastguard Worker            torch.isclose(t, t, atol=-1, rtol=1)
253*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
254*da0073e9SAndroid Build Coastguard Worker            torch.isclose(t, t, atol=1, rtol=-1)
255*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
256*da0073e9SAndroid Build Coastguard Worker            torch.isclose(t, t, atol=-1, rtol=-1)
257*da0073e9SAndroid Build Coastguard Worker
258*da0073e9SAndroid Build Coastguard Worker    def test_isclose_equality_shortcut(self):
259*da0073e9SAndroid Build Coastguard Worker        # For values >= 2**53, integers differing by 1 can no longer differentiated by torch.float64 or lower precision
260*da0073e9SAndroid Build Coastguard Worker        # floating point dtypes. Thus, even with rtol == 0 and atol == 0, these tensors would be considered close if
261*da0073e9SAndroid Build Coastguard Worker        # they were not compared as integers.
262*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(2 ** 53, dtype=torch.int64)
263*da0073e9SAndroid Build Coastguard Worker        b = a + 1
264*da0073e9SAndroid Build Coastguard Worker
265*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.isclose(a, b, rtol=0, atol=0))
266*da0073e9SAndroid Build Coastguard Worker
267*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float16, torch.float32, torch.float64, torch.complex64, torch.complex128)
268*da0073e9SAndroid Build Coastguard Worker    def test_isclose_nan_equality_shortcut(self, device, dtype):
269*da0073e9SAndroid Build Coastguard Worker        if dtype.is_floating_point:
270*da0073e9SAndroid Build Coastguard Worker            a = b = torch.nan
271*da0073e9SAndroid Build Coastguard Worker        else:
272*da0073e9SAndroid Build Coastguard Worker            a = complex(torch.nan, 0)
273*da0073e9SAndroid Build Coastguard Worker            b = complex(0, torch.nan)
274*da0073e9SAndroid Build Coastguard Worker
275*da0073e9SAndroid Build Coastguard Worker        expected = True
276*da0073e9SAndroid Build Coastguard Worker        tests = [(a, b, expected)]
277*da0073e9SAndroid Build Coastguard Worker
278*da0073e9SAndroid Build Coastguard Worker        self._isclose_helper(tests, device, dtype, equal_nan=True, rtol=0, atol=0)
279*da0073e9SAndroid Build Coastguard Worker
280*da0073e9SAndroid Build Coastguard Worker    # The following tests (test_cuda_assert_*) are added to ensure test suite terminates early
281*da0073e9SAndroid Build Coastguard Worker    # when CUDA assert was thrown. Because all subsequent test will fail if that happens.
282*da0073e9SAndroid Build Coastguard Worker    # These tests are slow because it spawn another process to run test suite.
283*da0073e9SAndroid Build Coastguard Worker    # See: https://github.com/pytorch/pytorch/issues/49019
284*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts")
285*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
286*da0073e9SAndroid Build Coastguard Worker    @slowTest
287*da0073e9SAndroid Build Coastguard Worker    def test_cuda_assert_should_stop_common_utils_test_suite(self, device):
288*da0073e9SAndroid Build Coastguard Worker        # test to ensure common_utils.py override has early termination for CUDA.
289*da0073e9SAndroid Build Coastguard Worker        stderr = TestCase.runWithPytorchAPIUsageStderr("""\
290*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3
291*da0073e9SAndroid Build Coastguard Worker
292*da0073e9SAndroid Build Coastguard Workerimport torch
293*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (TestCase, run_tests, slowTest)
294*da0073e9SAndroid Build Coastguard Worker
295*da0073e9SAndroid Build Coastguard Workerclass TestThatContainsCUDAAssertFailure(TestCase):
296*da0073e9SAndroid Build Coastguard Worker
297*da0073e9SAndroid Build Coastguard Worker    @slowTest
298*da0073e9SAndroid Build Coastguard Worker    def test_throw_unrecoverable_cuda_exception(self):
299*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(10, device='cuda')
300*da0073e9SAndroid Build Coastguard Worker        # cause unrecoverable CUDA exception, recoverable on CPU
301*da0073e9SAndroid Build Coastguard Worker        y = x[torch.tensor([25])].cpu()
302*da0073e9SAndroid Build Coastguard Worker
303*da0073e9SAndroid Build Coastguard Worker    @slowTest
304*da0073e9SAndroid Build Coastguard Worker    def test_trivial_passing_test_case_on_cpu_cuda(self):
305*da0073e9SAndroid Build Coastguard Worker        x1 = torch.tensor([0., 1.], device='cuda')
306*da0073e9SAndroid Build Coastguard Worker        x2 = torch.tensor([0., 1.], device='cpu')
307*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x1, x2)
308*da0073e9SAndroid Build Coastguard Worker
309*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__':
310*da0073e9SAndroid Build Coastguard Worker    run_tests()
311*da0073e9SAndroid Build Coastguard Worker""")
312*da0073e9SAndroid Build Coastguard Worker        # should capture CUDA error
313*da0073e9SAndroid Build Coastguard Worker        self.assertIn('CUDA error: device-side assert triggered', stderr)
314*da0073e9SAndroid Build Coastguard Worker        # should run only 1 test because it throws unrecoverable error.
315*da0073e9SAndroid Build Coastguard Worker        self.assertIn('errors=1', stderr)
316*da0073e9SAndroid Build Coastguard Worker
317*da0073e9SAndroid Build Coastguard Worker
318*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts")
319*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
320*da0073e9SAndroid Build Coastguard Worker    @slowTest
321*da0073e9SAndroid Build Coastguard Worker    def test_cuda_assert_should_stop_common_device_type_test_suite(self, device):
322*da0073e9SAndroid Build Coastguard Worker        # test to ensure common_device_type.py override has early termination for CUDA.
323*da0073e9SAndroid Build Coastguard Worker        stderr = TestCase.runWithPytorchAPIUsageStderr("""\
324*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3
325*da0073e9SAndroid Build Coastguard Worker
326*da0073e9SAndroid Build Coastguard Workerimport torch
327*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (TestCase, run_tests, slowTest)
328*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import instantiate_device_type_tests
329*da0073e9SAndroid Build Coastguard Worker
330*da0073e9SAndroid Build Coastguard Workerclass TestThatContainsCUDAAssertFailure(TestCase):
331*da0073e9SAndroid Build Coastguard Worker
332*da0073e9SAndroid Build Coastguard Worker    @slowTest
333*da0073e9SAndroid Build Coastguard Worker    def test_throw_unrecoverable_cuda_exception(self, device):
334*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(10, device=device)
335*da0073e9SAndroid Build Coastguard Worker        # cause unrecoverable CUDA exception, recoverable on CPU
336*da0073e9SAndroid Build Coastguard Worker        y = x[torch.tensor([25])].cpu()
337*da0073e9SAndroid Build Coastguard Worker
338*da0073e9SAndroid Build Coastguard Worker    @slowTest
339*da0073e9SAndroid Build Coastguard Worker    def test_trivial_passing_test_case_on_cpu_cuda(self, device):
340*da0073e9SAndroid Build Coastguard Worker        x1 = torch.tensor([0., 1.], device=device)
341*da0073e9SAndroid Build Coastguard Worker        x2 = torch.tensor([0., 1.], device='cpu')
342*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x1, x2)
343*da0073e9SAndroid Build Coastguard Worker
344*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(
345*da0073e9SAndroid Build Coastguard Worker    TestThatContainsCUDAAssertFailure,
346*da0073e9SAndroid Build Coastguard Worker    globals(),
347*da0073e9SAndroid Build Coastguard Worker    only_for='cuda'
348*da0073e9SAndroid Build Coastguard Worker)
349*da0073e9SAndroid Build Coastguard Worker
350*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__':
351*da0073e9SAndroid Build Coastguard Worker    run_tests()
352*da0073e9SAndroid Build Coastguard Worker""")
353*da0073e9SAndroid Build Coastguard Worker        # should capture CUDA error
354*da0073e9SAndroid Build Coastguard Worker        self.assertIn('CUDA error: device-side assert triggered', stderr)
355*da0073e9SAndroid Build Coastguard Worker        # should run only 1 test because it throws unrecoverable error.
356*da0073e9SAndroid Build Coastguard Worker        self.assertIn('errors=1', stderr)
357*da0073e9SAndroid Build Coastguard Worker
358*da0073e9SAndroid Build Coastguard Worker
359*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts")
360*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
361*da0073e9SAndroid Build Coastguard Worker    @slowTest
362*da0073e9SAndroid Build Coastguard Worker    def test_cuda_assert_should_not_stop_common_distributed_test_suite(self, device):
363*da0073e9SAndroid Build Coastguard Worker        # test to ensure common_distributed.py override should not early terminate CUDA.
364*da0073e9SAndroid Build Coastguard Worker        stderr = TestCase.runWithPytorchAPIUsageStderr("""\
365*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3
366*da0073e9SAndroid Build Coastguard Worker
367*da0073e9SAndroid Build Coastguard Workerimport torch
368*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (run_tests, slowTest)
369*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import instantiate_device_type_tests
370*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_distributed import MultiProcessTestCase
371*da0073e9SAndroid Build Coastguard Worker
372*da0073e9SAndroid Build Coastguard Workerclass TestThatContainsCUDAAssertFailure(MultiProcessTestCase):
373*da0073e9SAndroid Build Coastguard Worker
374*da0073e9SAndroid Build Coastguard Worker    @slowTest
375*da0073e9SAndroid Build Coastguard Worker    def test_throw_unrecoverable_cuda_exception(self, device):
376*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(10, device=device)
377*da0073e9SAndroid Build Coastguard Worker        # cause unrecoverable CUDA exception, recoverable on CPU
378*da0073e9SAndroid Build Coastguard Worker        y = x[torch.tensor([25])].cpu()
379*da0073e9SAndroid Build Coastguard Worker
380*da0073e9SAndroid Build Coastguard Worker    @slowTest
381*da0073e9SAndroid Build Coastguard Worker    def test_trivial_passing_test_case_on_cpu_cuda(self, device):
382*da0073e9SAndroid Build Coastguard Worker        x1 = torch.tensor([0., 1.], device=device)
383*da0073e9SAndroid Build Coastguard Worker        x2 = torch.tensor([0., 1.], device='cpu')
384*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x1, x2)
385*da0073e9SAndroid Build Coastguard Worker
386*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(
387*da0073e9SAndroid Build Coastguard Worker    TestThatContainsCUDAAssertFailure,
388*da0073e9SAndroid Build Coastguard Worker    globals(),
389*da0073e9SAndroid Build Coastguard Worker    only_for='cuda'
390*da0073e9SAndroid Build Coastguard Worker)
391*da0073e9SAndroid Build Coastguard Worker
392*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__':
393*da0073e9SAndroid Build Coastguard Worker    run_tests()
394*da0073e9SAndroid Build Coastguard Worker""")
395*da0073e9SAndroid Build Coastguard Worker        # we are currently disabling CUDA early termination for distributed tests.
396*da0073e9SAndroid Build Coastguard Worker        self.assertIn('errors=2', stderr)
397*da0073e9SAndroid Build Coastguard Worker
398*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMeta  # This is only supported for CPU and CUDA
399*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
400*da0073e9SAndroid Build Coastguard Worker    def test_get_supported_dtypes(self, device):
401*da0073e9SAndroid Build Coastguard Worker        # Test the `get_supported_dtypes` helper function.
402*da0073e9SAndroid Build Coastguard Worker        # We acquire the dtypes for few Ops dynamically and verify them against
403*da0073e9SAndroid Build Coastguard Worker        # the correct statically described values.
404*da0073e9SAndroid Build Coastguard Worker        ops_to_test = list(filter(lambda op: op.name in ['atan2', 'topk', 'xlogy'], op_db))
405*da0073e9SAndroid Build Coastguard Worker
406*da0073e9SAndroid Build Coastguard Worker        for op in ops_to_test:
407*da0073e9SAndroid Build Coastguard Worker            dynamic_dtypes = opinfo.utils.get_supported_dtypes(op, op.sample_inputs_func, self.device_type)
408*da0073e9SAndroid Build Coastguard Worker            dynamic_dispatch = opinfo.utils.dtypes_dispatch_hint(dynamic_dtypes)
409*da0073e9SAndroid Build Coastguard Worker            if self.device_type == 'cpu':
410*da0073e9SAndroid Build Coastguard Worker                dtypes = op.dtypes
411*da0073e9SAndroid Build Coastguard Worker            else:  # device_type ='cuda'
412*da0073e9SAndroid Build Coastguard Worker                dtypes = op.dtypesIfCUDA
413*da0073e9SAndroid Build Coastguard Worker
414*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(set(dtypes) == set(dynamic_dtypes))
415*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(set(dtypes) == set(dynamic_dispatch.dispatch_fn()))
416*da0073e9SAndroid Build Coastguard Worker
417*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
418*da0073e9SAndroid Build Coastguard Worker    @ops(
419*da0073e9SAndroid Build Coastguard Worker        [
420*da0073e9SAndroid Build Coastguard Worker            op
421*da0073e9SAndroid Build Coastguard Worker            for op in op_db
422*da0073e9SAndroid Build Coastguard Worker            if len(
423*da0073e9SAndroid Build Coastguard Worker                op.supported_dtypes("cpu").symmetric_difference(
424*da0073e9SAndroid Build Coastguard Worker                    op.supported_dtypes("cuda")
425*da0073e9SAndroid Build Coastguard Worker                )
426*da0073e9SAndroid Build Coastguard Worker            )
427*da0073e9SAndroid Build Coastguard Worker            > 0
428*da0073e9SAndroid Build Coastguard Worker        ][:1],
429*da0073e9SAndroid Build Coastguard Worker        dtypes=OpDTypes.none,
430*da0073e9SAndroid Build Coastguard Worker    )
431*da0073e9SAndroid Build Coastguard Worker    def test_supported_dtypes(self, device, op):
432*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(op.supported_dtypes("cpu"), op.supported_dtypes("cuda"))
433*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(op.supported_dtypes("cuda"), op.supported_dtypes("cuda:0"))
434*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
435*da0073e9SAndroid Build Coastguard Worker            op.supported_dtypes(torch.device("cuda")),
436*da0073e9SAndroid Build Coastguard Worker            op.supported_dtypes(torch.device("cuda", index=1)),
437*da0073e9SAndroid Build Coastguard Worker        )
438*da0073e9SAndroid Build Coastguard Worker
439*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestTesting, globals())
440*da0073e9SAndroid Build Coastguard Worker
441*da0073e9SAndroid Build Coastguard Worker
442*da0073e9SAndroid Build Coastguard Workerclass TestFrameworkUtils(TestCase):
443*da0073e9SAndroid Build Coastguard Worker
444*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "Skipping because doesn't work for windows")
445*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_SANDCASTLE, "Skipping because doesn't work on sandcastle")
446*da0073e9SAndroid Build Coastguard Worker    def test_filtering_env_var(self):
447*da0073e9SAndroid Build Coastguard Worker        # Test environment variable selected device type test generator.
448*da0073e9SAndroid Build Coastguard Worker        test_filter_file_template = """\
449*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3
450*da0073e9SAndroid Build Coastguard Worker
451*da0073e9SAndroid Build Coastguard Workerimport torch
452*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (TestCase, run_tests)
453*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import instantiate_device_type_tests
454*da0073e9SAndroid Build Coastguard Worker
455*da0073e9SAndroid Build Coastguard Workerclass TestEnvironmentVariable(TestCase):
456*da0073e9SAndroid Build Coastguard Worker
457*da0073e9SAndroid Build Coastguard Worker    def test_trivial_passing_test(self, device):
458*da0073e9SAndroid Build Coastguard Worker        x1 = torch.tensor([0., 1.], device=device)
459*da0073e9SAndroid Build Coastguard Worker        x2 = torch.tensor([0., 1.], device='cpu')
460*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x1, x2)
461*da0073e9SAndroid Build Coastguard Worker
462*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(
463*da0073e9SAndroid Build Coastguard Worker    TestEnvironmentVariable,
464*da0073e9SAndroid Build Coastguard Worker    globals(),
465*da0073e9SAndroid Build Coastguard Worker)
466*da0073e9SAndroid Build Coastguard Worker
467*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__':
468*da0073e9SAndroid Build Coastguard Worker    run_tests()
469*da0073e9SAndroid Build Coastguard Worker"""
470*da0073e9SAndroid Build Coastguard Worker        test_bases_count = len(get_device_type_test_bases())
471*da0073e9SAndroid Build Coastguard Worker        # Test without setting env var should run everything.
472*da0073e9SAndroid Build Coastguard Worker        env = dict(os.environ)
473*da0073e9SAndroid Build Coastguard Worker        for k in ['CI', PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY]:
474*da0073e9SAndroid Build Coastguard Worker            if k in env.keys():
475*da0073e9SAndroid Build Coastguard Worker                del env[k]
476*da0073e9SAndroid Build Coastguard Worker        _, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env)
477*da0073e9SAndroid Build Coastguard Worker        self.assertIn(f'Ran {test_bases_count} test', stderr.decode('ascii'))
478*da0073e9SAndroid Build Coastguard Worker
479*da0073e9SAndroid Build Coastguard Worker        # Test with setting only_for should only run 1 test.
480*da0073e9SAndroid Build Coastguard Worker        env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY] = 'cpu'
481*da0073e9SAndroid Build Coastguard Worker        _, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env)
482*da0073e9SAndroid Build Coastguard Worker        self.assertIn('Ran 1 test', stderr.decode('ascii'))
483*da0073e9SAndroid Build Coastguard Worker
484*da0073e9SAndroid Build Coastguard Worker        # Test with setting except_for should run 1 less device type from default.
485*da0073e9SAndroid Build Coastguard Worker        del env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY]
486*da0073e9SAndroid Build Coastguard Worker        env[PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY] = 'cpu'
487*da0073e9SAndroid Build Coastguard Worker        _, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env)
488*da0073e9SAndroid Build Coastguard Worker        self.assertIn(f'Ran {test_bases_count-1} test', stderr.decode('ascii'))
489*da0073e9SAndroid Build Coastguard Worker
490*da0073e9SAndroid Build Coastguard Worker        # Test with setting both should throw exception
491*da0073e9SAndroid Build Coastguard Worker        env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY] = 'cpu'
492*da0073e9SAndroid Build Coastguard Worker        _, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env)
493*da0073e9SAndroid Build Coastguard Worker        self.assertNotIn('OK', stderr.decode('ascii'))
494*da0073e9SAndroid Build Coastguard Worker
495*da0073e9SAndroid Build Coastguard Worker
496*da0073e9SAndroid Build Coastguard Workerdef make_assert_close_inputs(actual: Any, expected: Any) -> List[Tuple[Any, Any]]:
497*da0073e9SAndroid Build Coastguard Worker    """Makes inputs for :func:`torch.testing.assert_close` functions based on two examples.
498*da0073e9SAndroid Build Coastguard Worker
499*da0073e9SAndroid Build Coastguard Worker    Args:
500*da0073e9SAndroid Build Coastguard Worker        actual (Any): Actual input.
501*da0073e9SAndroid Build Coastguard Worker        expected (Any): Expected input.
502*da0073e9SAndroid Build Coastguard Worker
503*da0073e9SAndroid Build Coastguard Worker    Returns:
504*da0073e9SAndroid Build Coastguard Worker        List[Tuple[Any, Any]]: Pair of example inputs, as well as the example inputs wrapped in sequences
505*da0073e9SAndroid Build Coastguard Worker        (:class:`tuple`, :class:`list`), and mappings (:class:`dict`, :class:`~collections.OrderedDict`).
506*da0073e9SAndroid Build Coastguard Worker    """
507*da0073e9SAndroid Build Coastguard Worker    return [
508*da0073e9SAndroid Build Coastguard Worker        (actual, expected),
509*da0073e9SAndroid Build Coastguard Worker        # tuple vs. tuple
510*da0073e9SAndroid Build Coastguard Worker        ((actual,), (expected,)),
511*da0073e9SAndroid Build Coastguard Worker        # list vs. list
512*da0073e9SAndroid Build Coastguard Worker        ([actual], [expected]),
513*da0073e9SAndroid Build Coastguard Worker        # tuple vs. list
514*da0073e9SAndroid Build Coastguard Worker        ((actual,), [expected]),
515*da0073e9SAndroid Build Coastguard Worker        # dict vs. dict
516*da0073e9SAndroid Build Coastguard Worker        ({"t": actual}, {"t": expected}),
517*da0073e9SAndroid Build Coastguard Worker        # OrderedDict vs. OrderedDict
518*da0073e9SAndroid Build Coastguard Worker        (collections.OrderedDict([("t", actual)]), collections.OrderedDict([("t", expected)])),
519*da0073e9SAndroid Build Coastguard Worker        # dict vs. OrderedDict
520*da0073e9SAndroid Build Coastguard Worker        ({"t": actual}, collections.OrderedDict([("t", expected)])),
521*da0073e9SAndroid Build Coastguard Worker        # list of tuples vs. tuple of lists
522*da0073e9SAndroid Build Coastguard Worker        ([(actual,)], ([expected],)),
523*da0073e9SAndroid Build Coastguard Worker        # list of dicts vs. tuple of OrderedDicts
524*da0073e9SAndroid Build Coastguard Worker        ([{"t": actual}], (collections.OrderedDict([("t", expected)]),)),
525*da0073e9SAndroid Build Coastguard Worker        # dict of lists vs. OrderedDict of tuples
526*da0073e9SAndroid Build Coastguard Worker        ({"t": [actual]}, collections.OrderedDict([("t", (expected,))])),
527*da0073e9SAndroid Build Coastguard Worker    ]
528*da0073e9SAndroid Build Coastguard Worker
529*da0073e9SAndroid Build Coastguard Worker
530*da0073e9SAndroid Build Coastguard Workerdef assert_close_with_inputs(actual: Any, expected: Any) -> Iterator[Callable]:
531*da0073e9SAndroid Build Coastguard Worker    """Yields :func:`torch.testing.assert_close` with predefined positional inputs based on two examples.
532*da0073e9SAndroid Build Coastguard Worker
533*da0073e9SAndroid Build Coastguard Worker    .. note::
534*da0073e9SAndroid Build Coastguard Worker
535*da0073e9SAndroid Build Coastguard Worker        Every test that does not test for a specific input should iterate over this to maximize the coverage.
536*da0073e9SAndroid Build Coastguard Worker
537*da0073e9SAndroid Build Coastguard Worker    Args:
538*da0073e9SAndroid Build Coastguard Worker        actual (Any): Actual input.
539*da0073e9SAndroid Build Coastguard Worker        expected (Any): Expected input.
540*da0073e9SAndroid Build Coastguard Worker
541*da0073e9SAndroid Build Coastguard Worker    Yields:
542*da0073e9SAndroid Build Coastguard Worker        Callable: :func:`torch.testing.assert_close` with predefined positional inputs.
543*da0073e9SAndroid Build Coastguard Worker    """
544*da0073e9SAndroid Build Coastguard Worker    for inputs in make_assert_close_inputs(actual, expected):
545*da0073e9SAndroid Build Coastguard Worker        yield functools.partial(torch.testing.assert_close, *inputs)
546*da0073e9SAndroid Build Coastguard Worker
547*da0073e9SAndroid Build Coastguard Worker
548*da0073e9SAndroid Build Coastguard Workerclass TestAssertClose(TestCase):
549*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_types_subclasses(self):
550*da0073e9SAndroid Build Coastguard Worker        actual = torch.rand(())
551*da0073e9SAndroid Build Coastguard Worker        expected = torch.nn.Parameter(actual)
552*da0073e9SAndroid Build Coastguard Worker
553*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
554*da0073e9SAndroid Build Coastguard Worker            fn()
555*da0073e9SAndroid Build Coastguard Worker
556*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_types_type_equality(self):
557*da0073e9SAndroid Build Coastguard Worker        actual = torch.empty(())
558*da0073e9SAndroid Build Coastguard Worker        expected = torch.nn.Parameter(actual)
559*da0073e9SAndroid Build Coastguard Worker
560*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
561*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(TypeError, str(type(expected))):
562*da0073e9SAndroid Build Coastguard Worker                fn(allow_subclasses=False)
563*da0073e9SAndroid Build Coastguard Worker
564*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_types(self):
565*da0073e9SAndroid Build Coastguard Worker        actual = torch.empty(2)
566*da0073e9SAndroid Build Coastguard Worker        expected = actual.numpy()
567*da0073e9SAndroid Build Coastguard Worker
568*da0073e9SAndroid Build Coastguard Worker        for fn, allow_subclasses in itertools.product(assert_close_with_inputs(actual, expected), (True, False)):
569*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(TypeError, str(type(expected))):
570*da0073e9SAndroid Build Coastguard Worker                fn(allow_subclasses=allow_subclasses)
571*da0073e9SAndroid Build Coastguard Worker
572*da0073e9SAndroid Build Coastguard Worker    def test_unknown_type(self):
573*da0073e9SAndroid Build Coastguard Worker        actual = "0"
574*da0073e9SAndroid Build Coastguard Worker        expected = "0"
575*da0073e9SAndroid Build Coastguard Worker
576*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
577*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(TypeError, str(type(actual))):
578*da0073e9SAndroid Build Coastguard Worker                fn()
579*da0073e9SAndroid Build Coastguard Worker
580*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_shape(self):
581*da0073e9SAndroid Build Coastguard Worker        actual = torch.empty(())
582*da0073e9SAndroid Build Coastguard Worker        expected = actual.clone().reshape((1,))
583*da0073e9SAndroid Build Coastguard Worker
584*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
585*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, "shape"):
586*da0073e9SAndroid Build Coastguard Worker                fn()
587*da0073e9SAndroid Build Coastguard Worker
588*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not torch.backends.mkldnn.is_available(), reason="MKLDNN is not available.")
589*da0073e9SAndroid Build Coastguard Worker    def test_unknown_layout(self):
590*da0073e9SAndroid Build Coastguard Worker        actual = torch.empty((2, 2))
591*da0073e9SAndroid Build Coastguard Worker        expected = actual.to_mkldnn()
592*da0073e9SAndroid Build Coastguard Worker
593*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
594*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(ValueError, "layout"):
595*da0073e9SAndroid Build Coastguard Worker                fn()
596*da0073e9SAndroid Build Coastguard Worker
597*da0073e9SAndroid Build Coastguard Worker    def test_meta(self):
598*da0073e9SAndroid Build Coastguard Worker        actual = torch.empty((2, 2), device="meta")
599*da0073e9SAndroid Build Coastguard Worker        expected = torch.empty((2, 2), device="meta")
600*da0073e9SAndroid Build Coastguard Worker
601*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
602*da0073e9SAndroid Build Coastguard Worker            fn()
603*da0073e9SAndroid Build Coastguard Worker
604*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_layout(self):
605*da0073e9SAndroid Build Coastguard Worker        strided = torch.empty((2, 2))
606*da0073e9SAndroid Build Coastguard Worker        sparse_coo = strided.to_sparse()
607*da0073e9SAndroid Build Coastguard Worker        sparse_csr = strided.to_sparse_csr()
608*da0073e9SAndroid Build Coastguard Worker
609*da0073e9SAndroid Build Coastguard Worker        for actual, expected in itertools.combinations((strided, sparse_coo, sparse_csr), 2):
610*da0073e9SAndroid Build Coastguard Worker            for fn in assert_close_with_inputs(actual, expected):
611*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(AssertionError, "layout"):
612*da0073e9SAndroid Build Coastguard Worker                    fn()
613*da0073e9SAndroid Build Coastguard Worker
614*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_layout_no_check(self):
615*da0073e9SAndroid Build Coastguard Worker        strided = torch.randn((2, 2))
616*da0073e9SAndroid Build Coastguard Worker        sparse_coo = strided.to_sparse()
617*da0073e9SAndroid Build Coastguard Worker        sparse_csr = strided.to_sparse_csr()
618*da0073e9SAndroid Build Coastguard Worker
619*da0073e9SAndroid Build Coastguard Worker        for actual, expected in itertools.combinations((strided, sparse_coo, sparse_csr), 2):
620*da0073e9SAndroid Build Coastguard Worker            for fn in assert_close_with_inputs(actual, expected):
621*da0073e9SAndroid Build Coastguard Worker                fn(check_layout=False)
622*da0073e9SAndroid Build Coastguard Worker
623*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_dtype(self):
624*da0073e9SAndroid Build Coastguard Worker        actual = torch.empty((), dtype=torch.float)
625*da0073e9SAndroid Build Coastguard Worker        expected = actual.clone().to(torch.int)
626*da0073e9SAndroid Build Coastguard Worker
627*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
628*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, "dtype"):
629*da0073e9SAndroid Build Coastguard Worker                fn()
630*da0073e9SAndroid Build Coastguard Worker
631*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_dtype_no_check(self):
632*da0073e9SAndroid Build Coastguard Worker        actual = torch.ones((), dtype=torch.float)
633*da0073e9SAndroid Build Coastguard Worker        expected = actual.clone().to(torch.int)
634*da0073e9SAndroid Build Coastguard Worker
635*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
636*da0073e9SAndroid Build Coastguard Worker            fn(check_dtype=False)
637*da0073e9SAndroid Build Coastguard Worker
638*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_stride(self):
639*da0073e9SAndroid Build Coastguard Worker        actual = torch.empty((2, 2))
640*da0073e9SAndroid Build Coastguard Worker        expected = torch.as_strided(actual.clone().t().contiguous(), actual.shape, actual.stride()[::-1])
641*da0073e9SAndroid Build Coastguard Worker
642*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
643*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, "stride"):
644*da0073e9SAndroid Build Coastguard Worker                fn(check_stride=True)
645*da0073e9SAndroid Build Coastguard Worker
646*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_stride_no_check(self):
647*da0073e9SAndroid Build Coastguard Worker        actual = torch.rand((2, 2))
648*da0073e9SAndroid Build Coastguard Worker        expected = torch.as_strided(actual.clone().t().contiguous(), actual.shape, actual.stride()[::-1])
649*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
650*da0073e9SAndroid Build Coastguard Worker            fn()
651*da0073e9SAndroid Build Coastguard Worker
652*da0073e9SAndroid Build Coastguard Worker    def test_only_rtol(self):
653*da0073e9SAndroid Build Coastguard Worker        actual = torch.empty(())
654*da0073e9SAndroid Build Coastguard Worker        expected = actual.clone()
655*da0073e9SAndroid Build Coastguard Worker
656*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
657*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(ValueError):
658*da0073e9SAndroid Build Coastguard Worker                fn(rtol=0.0)
659*da0073e9SAndroid Build Coastguard Worker
660*da0073e9SAndroid Build Coastguard Worker    def test_only_atol(self):
661*da0073e9SAndroid Build Coastguard Worker        actual = torch.empty(())
662*da0073e9SAndroid Build Coastguard Worker        expected = actual.clone()
663*da0073e9SAndroid Build Coastguard Worker
664*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
665*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(ValueError):
666*da0073e9SAndroid Build Coastguard Worker                fn(atol=0.0)
667*da0073e9SAndroid Build Coastguard Worker
668*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_values(self):
669*da0073e9SAndroid Build Coastguard Worker        actual = torch.tensor(1)
670*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor(2)
671*da0073e9SAndroid Build Coastguard Worker
672*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
673*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(AssertionError):
674*da0073e9SAndroid Build Coastguard Worker                fn()
675*da0073e9SAndroid Build Coastguard Worker
676*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_values_rtol(self):
677*da0073e9SAndroid Build Coastguard Worker        eps = 1e-3
678*da0073e9SAndroid Build Coastguard Worker        actual = torch.tensor(1.0)
679*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor(1.0 + eps)
680*da0073e9SAndroid Build Coastguard Worker
681*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
682*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(AssertionError):
683*da0073e9SAndroid Build Coastguard Worker                fn(rtol=eps / 2, atol=0.0)
684*da0073e9SAndroid Build Coastguard Worker
685*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_values_atol(self):
686*da0073e9SAndroid Build Coastguard Worker        eps = 1e-3
687*da0073e9SAndroid Build Coastguard Worker        actual = torch.tensor(0.0)
688*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor(eps)
689*da0073e9SAndroid Build Coastguard Worker
690*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
691*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(AssertionError):
692*da0073e9SAndroid Build Coastguard Worker                fn(rtol=0.0, atol=eps / 2)
693*da0073e9SAndroid Build Coastguard Worker
694*da0073e9SAndroid Build Coastguard Worker    def test_matching(self):
695*da0073e9SAndroid Build Coastguard Worker        actual = torch.tensor(1.0)
696*da0073e9SAndroid Build Coastguard Worker        expected = actual.clone()
697*da0073e9SAndroid Build Coastguard Worker
698*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(actual, expected)
699*da0073e9SAndroid Build Coastguard Worker
700*da0073e9SAndroid Build Coastguard Worker    def test_matching_rtol(self):
701*da0073e9SAndroid Build Coastguard Worker        eps = 1e-3
702*da0073e9SAndroid Build Coastguard Worker        actual = torch.tensor(1.0)
703*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor(1.0 + eps)
704*da0073e9SAndroid Build Coastguard Worker
705*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
706*da0073e9SAndroid Build Coastguard Worker            fn(rtol=eps * 2, atol=0.0)
707*da0073e9SAndroid Build Coastguard Worker
708*da0073e9SAndroid Build Coastguard Worker    def test_matching_atol(self):
709*da0073e9SAndroid Build Coastguard Worker        eps = 1e-3
710*da0073e9SAndroid Build Coastguard Worker        actual = torch.tensor(0.0)
711*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor(eps)
712*da0073e9SAndroid Build Coastguard Worker
713*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
714*da0073e9SAndroid Build Coastguard Worker            fn(rtol=0.0, atol=eps * 2)
715*da0073e9SAndroid Build Coastguard Worker
716*da0073e9SAndroid Build Coastguard Worker    # TODO: the code that this test was designed for was removed in https://github.com/pytorch/pytorch/pull/56058
717*da0073e9SAndroid Build Coastguard Worker    #  We need to check if this test is still needed or if this behavior is now enabled by default.
718*da0073e9SAndroid Build Coastguard Worker    def test_matching_conjugate_bit(self):
719*da0073e9SAndroid Build Coastguard Worker        actual = torch.tensor(complex(1, 1)).conj()
720*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor(complex(1, -1))
721*da0073e9SAndroid Build Coastguard Worker
722*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
723*da0073e9SAndroid Build Coastguard Worker            fn()
724*da0073e9SAndroid Build Coastguard Worker
725*da0073e9SAndroid Build Coastguard Worker    def test_matching_nan(self):
726*da0073e9SAndroid Build Coastguard Worker        nan = float("NaN")
727*da0073e9SAndroid Build Coastguard Worker
728*da0073e9SAndroid Build Coastguard Worker        tests = (
729*da0073e9SAndroid Build Coastguard Worker            (nan, nan),
730*da0073e9SAndroid Build Coastguard Worker            (complex(nan, 0), complex(0, nan)),
731*da0073e9SAndroid Build Coastguard Worker            (complex(nan, nan), complex(nan, 0)),
732*da0073e9SAndroid Build Coastguard Worker            (complex(nan, nan), complex(nan, nan)),
733*da0073e9SAndroid Build Coastguard Worker        )
734*da0073e9SAndroid Build Coastguard Worker
735*da0073e9SAndroid Build Coastguard Worker        for actual, expected in tests:
736*da0073e9SAndroid Build Coastguard Worker            for fn in assert_close_with_inputs(actual, expected):
737*da0073e9SAndroid Build Coastguard Worker                with self.assertRaises(AssertionError):
738*da0073e9SAndroid Build Coastguard Worker                    fn()
739*da0073e9SAndroid Build Coastguard Worker
740*da0073e9SAndroid Build Coastguard Worker    def test_matching_nan_with_equal_nan(self):
741*da0073e9SAndroid Build Coastguard Worker        nan = float("NaN")
742*da0073e9SAndroid Build Coastguard Worker
743*da0073e9SAndroid Build Coastguard Worker        tests = (
744*da0073e9SAndroid Build Coastguard Worker            (nan, nan),
745*da0073e9SAndroid Build Coastguard Worker            (complex(nan, 0), complex(0, nan)),
746*da0073e9SAndroid Build Coastguard Worker            (complex(nan, nan), complex(nan, 0)),
747*da0073e9SAndroid Build Coastguard Worker            (complex(nan, nan), complex(nan, nan)),
748*da0073e9SAndroid Build Coastguard Worker        )
749*da0073e9SAndroid Build Coastguard Worker
750*da0073e9SAndroid Build Coastguard Worker        for actual, expected in tests:
751*da0073e9SAndroid Build Coastguard Worker            for fn in assert_close_with_inputs(actual, expected):
752*da0073e9SAndroid Build Coastguard Worker                fn(equal_nan=True)
753*da0073e9SAndroid Build Coastguard Worker
754*da0073e9SAndroid Build Coastguard Worker    def test_numpy(self):
755*da0073e9SAndroid Build Coastguard Worker        tensor = torch.rand(2, 2, dtype=torch.float32)
756*da0073e9SAndroid Build Coastguard Worker        actual = tensor.numpy()
757*da0073e9SAndroid Build Coastguard Worker        expected = actual.copy()
758*da0073e9SAndroid Build Coastguard Worker
759*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
760*da0073e9SAndroid Build Coastguard Worker            fn()
761*da0073e9SAndroid Build Coastguard Worker
762*da0073e9SAndroid Build Coastguard Worker    def test_scalar(self):
763*da0073e9SAndroid Build Coastguard Worker        number = torch.randint(10, size=()).item()
764*da0073e9SAndroid Build Coastguard Worker        for actual, expected in itertools.product((int(number), float(number), complex(number)), repeat=2):
765*da0073e9SAndroid Build Coastguard Worker            check_dtype = type(actual) is type(expected)
766*da0073e9SAndroid Build Coastguard Worker
767*da0073e9SAndroid Build Coastguard Worker            for fn in assert_close_with_inputs(actual, expected):
768*da0073e9SAndroid Build Coastguard Worker                fn(check_dtype=check_dtype)
769*da0073e9SAndroid Build Coastguard Worker
770*da0073e9SAndroid Build Coastguard Worker    def test_bool(self):
771*da0073e9SAndroid Build Coastguard Worker        actual = torch.tensor([True, False])
772*da0073e9SAndroid Build Coastguard Worker        expected = actual.clone()
773*da0073e9SAndroid Build Coastguard Worker
774*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
775*da0073e9SAndroid Build Coastguard Worker            fn()
776*da0073e9SAndroid Build Coastguard Worker
777*da0073e9SAndroid Build Coastguard Worker    def test_none(self):
778*da0073e9SAndroid Build Coastguard Worker        actual = expected = None
779*da0073e9SAndroid Build Coastguard Worker
780*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
781*da0073e9SAndroid Build Coastguard Worker            fn()
782*da0073e9SAndroid Build Coastguard Worker
783*da0073e9SAndroid Build Coastguard Worker    def test_none_mismatch(self):
784*da0073e9SAndroid Build Coastguard Worker        expected = None
785*da0073e9SAndroid Build Coastguard Worker
786*da0073e9SAndroid Build Coastguard Worker        for actual in (False, 0, torch.nan, torch.tensor(torch.nan)):
787*da0073e9SAndroid Build Coastguard Worker            for fn in assert_close_with_inputs(actual, expected):
788*da0073e9SAndroid Build Coastguard Worker                with self.assertRaises(AssertionError):
789*da0073e9SAndroid Build Coastguard Worker                    fn()
790*da0073e9SAndroid Build Coastguard Worker
791*da0073e9SAndroid Build Coastguard Worker
792*da0073e9SAndroid Build Coastguard Worker    def test_docstring_examples(self):
793*da0073e9SAndroid Build Coastguard Worker        finder = doctest.DocTestFinder(verbose=False)
794*da0073e9SAndroid Build Coastguard Worker        runner = doctest.DocTestRunner(verbose=False, optionflags=doctest.NORMALIZE_WHITESPACE)
795*da0073e9SAndroid Build Coastguard Worker        globs = dict(torch=torch)
796*da0073e9SAndroid Build Coastguard Worker        doctests = finder.find(torch.testing.assert_close, globs=globs)[0]
797*da0073e9SAndroid Build Coastguard Worker        failures = []
798*da0073e9SAndroid Build Coastguard Worker        runner.run(doctests, out=lambda report: failures.append(report))
799*da0073e9SAndroid Build Coastguard Worker        if failures:
800*da0073e9SAndroid Build Coastguard Worker            raise AssertionError(f"Doctest found {len(failures)} failures:\n\n" + "\n".join(failures))
801*da0073e9SAndroid Build Coastguard Worker
802*da0073e9SAndroid Build Coastguard Worker    def test_default_tolerance_selection_mismatching_dtypes(self):
803*da0073e9SAndroid Build Coastguard Worker        # If the default tolerances where selected based on the promoted dtype, i.e. float64,
804*da0073e9SAndroid Build Coastguard Worker        # these tensors wouldn't be considered close.
805*da0073e9SAndroid Build Coastguard Worker        actual = torch.tensor(0.99, dtype=torch.bfloat16)
806*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor(1.0, dtype=torch.float64)
807*da0073e9SAndroid Build Coastguard Worker
808*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
809*da0073e9SAndroid Build Coastguard Worker            fn(check_dtype=False)
810*da0073e9SAndroid Build Coastguard Worker
811*da0073e9SAndroid Build Coastguard Worker    class UnexpectedException(Exception):
812*da0073e9SAndroid Build Coastguard Worker        """The only purpose of this exception is to test ``assert_close``'s handling of unexpected exceptions. Thus,
813*da0073e9SAndroid Build Coastguard Worker        the test should mock a component to raise this instead of the regular behavior. We avoid using a builtin
814*da0073e9SAndroid Build Coastguard Worker        exception here to avoid triggering possible handling of them.
815*da0073e9SAndroid Build Coastguard Worker        """
816*da0073e9SAndroid Build Coastguard Worker
817*da0073e9SAndroid Build Coastguard Worker    @unittest.mock.patch("torch.testing._comparison.TensorLikePair.__init__", side_effect=UnexpectedException)
818*da0073e9SAndroid Build Coastguard Worker    def test_unexpected_error_originate(self, _):
819*da0073e9SAndroid Build Coastguard Worker        actual = torch.tensor(1.0)
820*da0073e9SAndroid Build Coastguard Worker        expected = actual.clone()
821*da0073e9SAndroid Build Coastguard Worker
822*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "unexpected exception"):
823*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(actual, expected)
824*da0073e9SAndroid Build Coastguard Worker
825*da0073e9SAndroid Build Coastguard Worker    @unittest.mock.patch("torch.testing._comparison.TensorLikePair.compare", side_effect=UnexpectedException)
826*da0073e9SAndroid Build Coastguard Worker    def test_unexpected_error_compare(self, _):
827*da0073e9SAndroid Build Coastguard Worker        actual = torch.tensor(1.0)
828*da0073e9SAndroid Build Coastguard Worker        expected = actual.clone()
829*da0073e9SAndroid Build Coastguard Worker
830*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "unexpected exception"):
831*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(actual, expected)
832*da0073e9SAndroid Build Coastguard Worker
833*da0073e9SAndroid Build Coastguard Worker
834*da0073e9SAndroid Build Coastguard Worker
835*da0073e9SAndroid Build Coastguard Worker
836*da0073e9SAndroid Build Coastguard Workerclass TestAssertCloseMultiDevice(TestCase):
837*da0073e9SAndroid Build Coastguard Worker    @deviceCountAtLeast(1)
838*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_device(self, devices):
839*da0073e9SAndroid Build Coastguard Worker        for actual_device, expected_device in itertools.permutations(("cpu", *devices), 2):
840*da0073e9SAndroid Build Coastguard Worker            actual = torch.empty((), device=actual_device)
841*da0073e9SAndroid Build Coastguard Worker            expected = actual.clone().to(expected_device)
842*da0073e9SAndroid Build Coastguard Worker            for fn in assert_close_with_inputs(actual, expected):
843*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(AssertionError, "device"):
844*da0073e9SAndroid Build Coastguard Worker                    fn()
845*da0073e9SAndroid Build Coastguard Worker
846*da0073e9SAndroid Build Coastguard Worker    @deviceCountAtLeast(1)
847*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_device_no_check(self, devices):
848*da0073e9SAndroid Build Coastguard Worker        for actual_device, expected_device in itertools.permutations(("cpu", *devices), 2):
849*da0073e9SAndroid Build Coastguard Worker            actual = torch.rand((), device=actual_device)
850*da0073e9SAndroid Build Coastguard Worker            expected = actual.clone().to(expected_device)
851*da0073e9SAndroid Build Coastguard Worker            for fn in assert_close_with_inputs(actual, expected):
852*da0073e9SAndroid Build Coastguard Worker                fn(check_device=False)
853*da0073e9SAndroid Build Coastguard Worker
854*da0073e9SAndroid Build Coastguard Worker
855*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestAssertCloseMultiDevice, globals(), only_for="cuda")
856*da0073e9SAndroid Build Coastguard Worker
857*da0073e9SAndroid Build Coastguard Worker
858*da0073e9SAndroid Build Coastguard Workerclass TestAssertCloseErrorMessage(TestCase):
859*da0073e9SAndroid Build Coastguard Worker    def test_identifier_tensor_likes(self):
860*da0073e9SAndroid Build Coastguard Worker        actual = torch.tensor([1, 2, 3, 4])
861*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor([1, 2, 5, 6])
862*da0073e9SAndroid Build Coastguard Worker
863*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
864*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape("Tensor-likes")):
865*da0073e9SAndroid Build Coastguard Worker                fn()
866*da0073e9SAndroid Build Coastguard Worker
867*da0073e9SAndroid Build Coastguard Worker    def test_identifier_scalars(self):
868*da0073e9SAndroid Build Coastguard Worker        actual = 3
869*da0073e9SAndroid Build Coastguard Worker        expected = 5
870*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
871*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape("Scalars")):
872*da0073e9SAndroid Build Coastguard Worker                fn()
873*da0073e9SAndroid Build Coastguard Worker
874*da0073e9SAndroid Build Coastguard Worker    def test_not_equal(self):
875*da0073e9SAndroid Build Coastguard Worker        actual = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
876*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor([1, 2, 5, 6], dtype=torch.float32)
877*da0073e9SAndroid Build Coastguard Worker
878*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
879*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape("not equal")):
880*da0073e9SAndroid Build Coastguard Worker                fn(rtol=0.0, atol=0.0)
881*da0073e9SAndroid Build Coastguard Worker
882*da0073e9SAndroid Build Coastguard Worker    def test_not_close(self):
883*da0073e9SAndroid Build Coastguard Worker        actual = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
884*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor([1, 2, 5, 6], dtype=torch.float32)
885*da0073e9SAndroid Build Coastguard Worker
886*da0073e9SAndroid Build Coastguard Worker        for fn, (rtol, atol) in itertools.product(
887*da0073e9SAndroid Build Coastguard Worker            assert_close_with_inputs(actual, expected), ((1.3e-6, 0.0), (0.0, 1e-5), (1.3e-6, 1e-5))
888*da0073e9SAndroid Build Coastguard Worker        ):
889*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape("not close")):
890*da0073e9SAndroid Build Coastguard Worker                fn(rtol=rtol, atol=atol)
891*da0073e9SAndroid Build Coastguard Worker
892*da0073e9SAndroid Build Coastguard Worker    def test_mismatched_elements(self):
893*da0073e9SAndroid Build Coastguard Worker        actual = torch.tensor([1, 2, 3, 4])
894*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor([1, 2, 5, 6])
895*da0073e9SAndroid Build Coastguard Worker
896*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
897*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape("Mismatched elements: 2 / 4 (50.0%)")):
898*da0073e9SAndroid Build Coastguard Worker                fn()
899*da0073e9SAndroid Build Coastguard Worker
900*da0073e9SAndroid Build Coastguard Worker    def test_abs_diff(self):
901*da0073e9SAndroid Build Coastguard Worker        actual = torch.tensor([[1, 2], [3, 4]])
902*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor([[1, 2], [5, 4]])
903*da0073e9SAndroid Build Coastguard Worker
904*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
905*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape("Greatest absolute difference: 2 at index (1, 0)")):
906*da0073e9SAndroid Build Coastguard Worker                fn()
907*da0073e9SAndroid Build Coastguard Worker
908*da0073e9SAndroid Build Coastguard Worker    def test_abs_diff_scalar(self):
909*da0073e9SAndroid Build Coastguard Worker        actual = 3
910*da0073e9SAndroid Build Coastguard Worker        expected = 5
911*da0073e9SAndroid Build Coastguard Worker
912*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
913*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape("Absolute difference: 2")):
914*da0073e9SAndroid Build Coastguard Worker                fn()
915*da0073e9SAndroid Build Coastguard Worker
916*da0073e9SAndroid Build Coastguard Worker    def test_rel_diff(self):
917*da0073e9SAndroid Build Coastguard Worker        actual = torch.tensor([[1, 2], [3, 4]])
918*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor([[1, 4], [3, 4]])
919*da0073e9SAndroid Build Coastguard Worker
920*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
921*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape("Greatest relative difference: 0.5 at index (0, 1)")):
922*da0073e9SAndroid Build Coastguard Worker                fn()
923*da0073e9SAndroid Build Coastguard Worker
924*da0073e9SAndroid Build Coastguard Worker    def test_rel_diff_scalar(self):
925*da0073e9SAndroid Build Coastguard Worker        actual = 2
926*da0073e9SAndroid Build Coastguard Worker        expected = 4
927*da0073e9SAndroid Build Coastguard Worker
928*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
929*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape("Relative difference: 0.5")):
930*da0073e9SAndroid Build Coastguard Worker                fn()
931*da0073e9SAndroid Build Coastguard Worker
932*da0073e9SAndroid Build Coastguard Worker    def test_zero_div_zero(self):
933*da0073e9SAndroid Build Coastguard Worker        actual = torch.tensor([1.0, 0.0])
934*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor([2.0, 0.0])
935*da0073e9SAndroid Build Coastguard Worker
936*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
937*da0073e9SAndroid Build Coastguard Worker            # Although it looks complicated, this regex just makes sure that the word 'nan' is not part of the error
938*da0073e9SAndroid Build Coastguard Worker            # message. That would happen if the 0 / 0 is used for the mismatch computation although it matches.
939*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, "((?!nan).)*"):
940*da0073e9SAndroid Build Coastguard Worker                fn()
941*da0073e9SAndroid Build Coastguard Worker
942*da0073e9SAndroid Build Coastguard Worker    def test_rtol(self):
943*da0073e9SAndroid Build Coastguard Worker        rtol = 1e-3
944*da0073e9SAndroid Build Coastguard Worker
945*da0073e9SAndroid Build Coastguard Worker        actual = torch.tensor((1, 2))
946*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor((2, 2))
947*da0073e9SAndroid Build Coastguard Worker
948*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
949*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape(f"(up to {rtol} allowed)")):
950*da0073e9SAndroid Build Coastguard Worker                fn(rtol=rtol, atol=0.0)
951*da0073e9SAndroid Build Coastguard Worker
952*da0073e9SAndroid Build Coastguard Worker    def test_atol(self):
953*da0073e9SAndroid Build Coastguard Worker        atol = 1e-3
954*da0073e9SAndroid Build Coastguard Worker
955*da0073e9SAndroid Build Coastguard Worker        actual = torch.tensor((1, 2))
956*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor((2, 2))
957*da0073e9SAndroid Build Coastguard Worker
958*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
959*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape(f"(up to {atol} allowed)")):
960*da0073e9SAndroid Build Coastguard Worker                fn(rtol=0.0, atol=atol)
961*da0073e9SAndroid Build Coastguard Worker
962*da0073e9SAndroid Build Coastguard Worker    def test_msg_str(self):
963*da0073e9SAndroid Build Coastguard Worker        msg = "Custom error message!"
964*da0073e9SAndroid Build Coastguard Worker
965*da0073e9SAndroid Build Coastguard Worker        actual = torch.tensor(1)
966*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor(2)
967*da0073e9SAndroid Build Coastguard Worker
968*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
969*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, msg):
970*da0073e9SAndroid Build Coastguard Worker                fn(msg=msg)
971*da0073e9SAndroid Build Coastguard Worker
972*da0073e9SAndroid Build Coastguard Worker    def test_msg_callable(self):
973*da0073e9SAndroid Build Coastguard Worker        msg = "Custom error message"
974*da0073e9SAndroid Build Coastguard Worker
975*da0073e9SAndroid Build Coastguard Worker        actual = torch.tensor(1)
976*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor(2)
977*da0073e9SAndroid Build Coastguard Worker
978*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
979*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, msg):
980*da0073e9SAndroid Build Coastguard Worker                fn(msg=lambda _: msg)
981*da0073e9SAndroid Build Coastguard Worker
982*da0073e9SAndroid Build Coastguard Worker
983*da0073e9SAndroid Build Coastguard Workerclass TestAssertCloseContainer(TestCase):
984*da0073e9SAndroid Build Coastguard Worker    def test_sequence_mismatching_len(self):
985*da0073e9SAndroid Build Coastguard Worker        actual = (torch.empty(()),)
986*da0073e9SAndroid Build Coastguard Worker        expected = ()
987*da0073e9SAndroid Build Coastguard Worker
988*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AssertionError):
989*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(actual, expected)
990*da0073e9SAndroid Build Coastguard Worker
991*da0073e9SAndroid Build Coastguard Worker    def test_sequence_mismatching_values_msg(self):
992*da0073e9SAndroid Build Coastguard Worker        t1 = torch.tensor(1)
993*da0073e9SAndroid Build Coastguard Worker        t2 = torch.tensor(2)
994*da0073e9SAndroid Build Coastguard Worker
995*da0073e9SAndroid Build Coastguard Worker        actual = (t1, t1)
996*da0073e9SAndroid Build Coastguard Worker        expected = (t1, t2)
997*da0073e9SAndroid Build Coastguard Worker
998*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(AssertionError, re.escape("item [1]")):
999*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(actual, expected)
1000*da0073e9SAndroid Build Coastguard Worker
1001*da0073e9SAndroid Build Coastguard Worker    def test_mapping_mismatching_keys(self):
1002*da0073e9SAndroid Build Coastguard Worker        actual = {"a": torch.empty(())}
1003*da0073e9SAndroid Build Coastguard Worker        expected = {}
1004*da0073e9SAndroid Build Coastguard Worker
1005*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AssertionError):
1006*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(actual, expected)
1007*da0073e9SAndroid Build Coastguard Worker
1008*da0073e9SAndroid Build Coastguard Worker    def test_mapping_mismatching_values_msg(self):
1009*da0073e9SAndroid Build Coastguard Worker        t1 = torch.tensor(1)
1010*da0073e9SAndroid Build Coastguard Worker        t2 = torch.tensor(2)
1011*da0073e9SAndroid Build Coastguard Worker
1012*da0073e9SAndroid Build Coastguard Worker        actual = {"a": t1, "b": t1}
1013*da0073e9SAndroid Build Coastguard Worker        expected = {"a": t1, "b": t2}
1014*da0073e9SAndroid Build Coastguard Worker
1015*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(AssertionError, re.escape("item ['b']")):
1016*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(actual, expected)
1017*da0073e9SAndroid Build Coastguard Worker
1018*da0073e9SAndroid Build Coastguard Worker
1019*da0073e9SAndroid Build Coastguard Workerclass TestAssertCloseSparseCOO(TestCase):
1020*da0073e9SAndroid Build Coastguard Worker    def test_matching_coalesced(self):
1021*da0073e9SAndroid Build Coastguard Worker        indices = (
1022*da0073e9SAndroid Build Coastguard Worker            (0, 1),
1023*da0073e9SAndroid Build Coastguard Worker            (1, 0),
1024*da0073e9SAndroid Build Coastguard Worker        )
1025*da0073e9SAndroid Build Coastguard Worker        values = (1, 2)
1026*da0073e9SAndroid Build Coastguard Worker        actual = torch.sparse_coo_tensor(indices, values, size=(2, 2)).coalesce()
1027*da0073e9SAndroid Build Coastguard Worker        expected = actual.clone()
1028*da0073e9SAndroid Build Coastguard Worker
1029*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1030*da0073e9SAndroid Build Coastguard Worker            fn()
1031*da0073e9SAndroid Build Coastguard Worker
1032*da0073e9SAndroid Build Coastguard Worker    def test_matching_uncoalesced(self):
1033*da0073e9SAndroid Build Coastguard Worker        indices = (
1034*da0073e9SAndroid Build Coastguard Worker            (0, 1),
1035*da0073e9SAndroid Build Coastguard Worker            (1, 0),
1036*da0073e9SAndroid Build Coastguard Worker        )
1037*da0073e9SAndroid Build Coastguard Worker        values = (1, 2)
1038*da0073e9SAndroid Build Coastguard Worker        actual = torch.sparse_coo_tensor(indices, values, size=(2, 2))
1039*da0073e9SAndroid Build Coastguard Worker        expected = actual.clone()
1040*da0073e9SAndroid Build Coastguard Worker
1041*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1042*da0073e9SAndroid Build Coastguard Worker            fn()
1043*da0073e9SAndroid Build Coastguard Worker
1044*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_sparse_dims(self):
1045*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(2, 3, 4)
1046*da0073e9SAndroid Build Coastguard Worker        actual = t.to_sparse()
1047*da0073e9SAndroid Build Coastguard Worker        expected = t.to_sparse(2)
1048*da0073e9SAndroid Build Coastguard Worker
1049*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1050*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape("number of sparse dimensions in sparse COO tensors")):
1051*da0073e9SAndroid Build Coastguard Worker                fn()
1052*da0073e9SAndroid Build Coastguard Worker
1053*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_nnz(self):
1054*da0073e9SAndroid Build Coastguard Worker        actual_indices = (
1055*da0073e9SAndroid Build Coastguard Worker            (0, 1),
1056*da0073e9SAndroid Build Coastguard Worker            (1, 0),
1057*da0073e9SAndroid Build Coastguard Worker        )
1058*da0073e9SAndroid Build Coastguard Worker        actual_values = (1, 2)
1059*da0073e9SAndroid Build Coastguard Worker        actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2))
1060*da0073e9SAndroid Build Coastguard Worker
1061*da0073e9SAndroid Build Coastguard Worker        expected_indices = (
1062*da0073e9SAndroid Build Coastguard Worker            (0, 1, 1,),
1063*da0073e9SAndroid Build Coastguard Worker            (1, 0, 0,),
1064*da0073e9SAndroid Build Coastguard Worker        )
1065*da0073e9SAndroid Build Coastguard Worker        expected_values = (1, 1, 1)
1066*da0073e9SAndroid Build Coastguard Worker        expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2))
1067*da0073e9SAndroid Build Coastguard Worker
1068*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1069*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape("number of specified values in sparse COO tensors")):
1070*da0073e9SAndroid Build Coastguard Worker                fn()
1071*da0073e9SAndroid Build Coastguard Worker
1072*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_indices_msg(self):
1073*da0073e9SAndroid Build Coastguard Worker        actual_indices = (
1074*da0073e9SAndroid Build Coastguard Worker            (0, 1),
1075*da0073e9SAndroid Build Coastguard Worker            (1, 0),
1076*da0073e9SAndroid Build Coastguard Worker        )
1077*da0073e9SAndroid Build Coastguard Worker        actual_values = (1, 2)
1078*da0073e9SAndroid Build Coastguard Worker        actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2))
1079*da0073e9SAndroid Build Coastguard Worker
1080*da0073e9SAndroid Build Coastguard Worker        expected_indices = (
1081*da0073e9SAndroid Build Coastguard Worker            (0, 1),
1082*da0073e9SAndroid Build Coastguard Worker            (1, 1),
1083*da0073e9SAndroid Build Coastguard Worker        )
1084*da0073e9SAndroid Build Coastguard Worker        expected_values = (1, 2)
1085*da0073e9SAndroid Build Coastguard Worker        expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2))
1086*da0073e9SAndroid Build Coastguard Worker
1087*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1088*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape("Sparse COO indices")):
1089*da0073e9SAndroid Build Coastguard Worker                fn()
1090*da0073e9SAndroid Build Coastguard Worker
1091*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_values_msg(self):
1092*da0073e9SAndroid Build Coastguard Worker        actual_indices = (
1093*da0073e9SAndroid Build Coastguard Worker            (0, 1),
1094*da0073e9SAndroid Build Coastguard Worker            (1, 0),
1095*da0073e9SAndroid Build Coastguard Worker        )
1096*da0073e9SAndroid Build Coastguard Worker        actual_values = (1, 2)
1097*da0073e9SAndroid Build Coastguard Worker        actual = torch.sparse_coo_tensor(actual_indices, actual_values, size=(2, 2))
1098*da0073e9SAndroid Build Coastguard Worker
1099*da0073e9SAndroid Build Coastguard Worker        expected_indices = (
1100*da0073e9SAndroid Build Coastguard Worker            (0, 1),
1101*da0073e9SAndroid Build Coastguard Worker            (1, 0),
1102*da0073e9SAndroid Build Coastguard Worker        )
1103*da0073e9SAndroid Build Coastguard Worker        expected_values = (1, 3)
1104*da0073e9SAndroid Build Coastguard Worker        expected = torch.sparse_coo_tensor(expected_indices, expected_values, size=(2, 2))
1105*da0073e9SAndroid Build Coastguard Worker
1106*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1107*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape("Sparse COO values")):
1108*da0073e9SAndroid Build Coastguard Worker                fn()
1109*da0073e9SAndroid Build Coastguard Worker
1110*da0073e9SAndroid Build Coastguard Worker
1111*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support CSR testing")
1112*da0073e9SAndroid Build Coastguard Workerclass TestAssertCloseSparseCSR(TestCase):
1113*da0073e9SAndroid Build Coastguard Worker    def test_matching(self):
1114*da0073e9SAndroid Build Coastguard Worker        crow_indices = (0, 1, 2)
1115*da0073e9SAndroid Build Coastguard Worker        col_indices = (1, 0)
1116*da0073e9SAndroid Build Coastguard Worker        values = (1, 2)
1117*da0073e9SAndroid Build Coastguard Worker        actual = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(2, 2))
1118*da0073e9SAndroid Build Coastguard Worker        expected = actual.clone()
1119*da0073e9SAndroid Build Coastguard Worker
1120*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1121*da0073e9SAndroid Build Coastguard Worker            fn()
1122*da0073e9SAndroid Build Coastguard Worker
1123*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_crow_indices_msg(self):
1124*da0073e9SAndroid Build Coastguard Worker        actual_crow_indices = (0, 1, 2)
1125*da0073e9SAndroid Build Coastguard Worker        actual_col_indices = (0, 1)
1126*da0073e9SAndroid Build Coastguard Worker        actual_values = (1, 2)
1127*da0073e9SAndroid Build Coastguard Worker        actual = torch.sparse_csr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2))
1128*da0073e9SAndroid Build Coastguard Worker
1129*da0073e9SAndroid Build Coastguard Worker        expected_crow_indices = (0, 2, 2)
1130*da0073e9SAndroid Build Coastguard Worker        expected_col_indices = actual_col_indices
1131*da0073e9SAndroid Build Coastguard Worker        expected_values = actual_values
1132*da0073e9SAndroid Build Coastguard Worker        expected = torch.sparse_csr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2))
1133*da0073e9SAndroid Build Coastguard Worker
1134*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1135*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSR crow_indices")):
1136*da0073e9SAndroid Build Coastguard Worker                fn()
1137*da0073e9SAndroid Build Coastguard Worker
1138*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_col_indices_msg(self):
1139*da0073e9SAndroid Build Coastguard Worker        actual_crow_indices = (0, 1, 2)
1140*da0073e9SAndroid Build Coastguard Worker        actual_col_indices = (1, 0)
1141*da0073e9SAndroid Build Coastguard Worker        actual_values = (1, 2)
1142*da0073e9SAndroid Build Coastguard Worker        actual = torch.sparse_csr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2))
1143*da0073e9SAndroid Build Coastguard Worker
1144*da0073e9SAndroid Build Coastguard Worker        expected_crow_indices = actual_crow_indices
1145*da0073e9SAndroid Build Coastguard Worker        expected_col_indices = (1, 1)
1146*da0073e9SAndroid Build Coastguard Worker        expected_values = actual_values
1147*da0073e9SAndroid Build Coastguard Worker        expected = torch.sparse_csr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2))
1148*da0073e9SAndroid Build Coastguard Worker
1149*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1150*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSR col_indices")):
1151*da0073e9SAndroid Build Coastguard Worker                fn()
1152*da0073e9SAndroid Build Coastguard Worker
1153*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_values_msg(self):
1154*da0073e9SAndroid Build Coastguard Worker        actual_crow_indices = (0, 1, 2)
1155*da0073e9SAndroid Build Coastguard Worker        actual_col_indices = (1, 0)
1156*da0073e9SAndroid Build Coastguard Worker        actual_values = (1, 2)
1157*da0073e9SAndroid Build Coastguard Worker        actual = torch.sparse_csr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2))
1158*da0073e9SAndroid Build Coastguard Worker
1159*da0073e9SAndroid Build Coastguard Worker        expected_crow_indices = actual_crow_indices
1160*da0073e9SAndroid Build Coastguard Worker        expected_col_indices = actual_col_indices
1161*da0073e9SAndroid Build Coastguard Worker        expected_values = (1, 3)
1162*da0073e9SAndroid Build Coastguard Worker        expected = torch.sparse_csr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2))
1163*da0073e9SAndroid Build Coastguard Worker
1164*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1165*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSR values")):
1166*da0073e9SAndroid Build Coastguard Worker                fn()
1167*da0073e9SAndroid Build Coastguard Worker
1168*da0073e9SAndroid Build Coastguard Worker
1169*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support CSC testing")
1170*da0073e9SAndroid Build Coastguard Workerclass TestAssertCloseSparseCSC(TestCase):
1171*da0073e9SAndroid Build Coastguard Worker    def test_matching(self):
1172*da0073e9SAndroid Build Coastguard Worker        ccol_indices = (0, 1, 2)
1173*da0073e9SAndroid Build Coastguard Worker        row_indices = (1, 0)
1174*da0073e9SAndroid Build Coastguard Worker        values = (1, 2)
1175*da0073e9SAndroid Build Coastguard Worker        actual = torch.sparse_csc_tensor(ccol_indices, row_indices, values, size=(2, 2))
1176*da0073e9SAndroid Build Coastguard Worker        expected = actual.clone()
1177*da0073e9SAndroid Build Coastguard Worker
1178*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1179*da0073e9SAndroid Build Coastguard Worker            fn()
1180*da0073e9SAndroid Build Coastguard Worker
1181*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_ccol_indices_msg(self):
1182*da0073e9SAndroid Build Coastguard Worker        actual_ccol_indices = (0, 1, 2)
1183*da0073e9SAndroid Build Coastguard Worker        actual_row_indices = (0, 1)
1184*da0073e9SAndroid Build Coastguard Worker        actual_values = (1, 2)
1185*da0073e9SAndroid Build Coastguard Worker        actual = torch.sparse_csc_tensor(actual_ccol_indices, actual_row_indices, actual_values, size=(2, 2))
1186*da0073e9SAndroid Build Coastguard Worker
1187*da0073e9SAndroid Build Coastguard Worker        expected_ccol_indices = (0, 2, 2)
1188*da0073e9SAndroid Build Coastguard Worker        expected_row_indices = actual_row_indices
1189*da0073e9SAndroid Build Coastguard Worker        expected_values = actual_values
1190*da0073e9SAndroid Build Coastguard Worker        expected = torch.sparse_csc_tensor(expected_ccol_indices, expected_row_indices, expected_values, size=(2, 2))
1191*da0073e9SAndroid Build Coastguard Worker
1192*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1193*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSC ccol_indices")):
1194*da0073e9SAndroid Build Coastguard Worker                fn()
1195*da0073e9SAndroid Build Coastguard Worker
1196*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_row_indices_msg(self):
1197*da0073e9SAndroid Build Coastguard Worker        actual_ccol_indices = (0, 1, 2)
1198*da0073e9SAndroid Build Coastguard Worker        actual_row_indices = (1, 0)
1199*da0073e9SAndroid Build Coastguard Worker        actual_values = (1, 2)
1200*da0073e9SAndroid Build Coastguard Worker        actual = torch.sparse_csc_tensor(actual_ccol_indices, actual_row_indices, actual_values, size=(2, 2))
1201*da0073e9SAndroid Build Coastguard Worker
1202*da0073e9SAndroid Build Coastguard Worker        expected_ccol_indices = actual_ccol_indices
1203*da0073e9SAndroid Build Coastguard Worker        expected_row_indices = (1, 1)
1204*da0073e9SAndroid Build Coastguard Worker        expected_values = actual_values
1205*da0073e9SAndroid Build Coastguard Worker        expected = torch.sparse_csc_tensor(expected_ccol_indices, expected_row_indices, expected_values, size=(2, 2))
1206*da0073e9SAndroid Build Coastguard Worker
1207*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1208*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSC row_indices")):
1209*da0073e9SAndroid Build Coastguard Worker                fn()
1210*da0073e9SAndroid Build Coastguard Worker
1211*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_values_msg(self):
1212*da0073e9SAndroid Build Coastguard Worker        actual_ccol_indices = (0, 1, 2)
1213*da0073e9SAndroid Build Coastguard Worker        actual_row_indices = (1, 0)
1214*da0073e9SAndroid Build Coastguard Worker        actual_values = (1, 2)
1215*da0073e9SAndroid Build Coastguard Worker        actual = torch.sparse_csc_tensor(actual_ccol_indices, actual_row_indices, actual_values, size=(2, 2))
1216*da0073e9SAndroid Build Coastguard Worker
1217*da0073e9SAndroid Build Coastguard Worker        expected_ccol_indices = actual_ccol_indices
1218*da0073e9SAndroid Build Coastguard Worker        expected_row_indices = actual_row_indices
1219*da0073e9SAndroid Build Coastguard Worker        expected_values = (1, 3)
1220*da0073e9SAndroid Build Coastguard Worker        expected = torch.sparse_csc_tensor(expected_ccol_indices, expected_row_indices, expected_values, size=(2, 2))
1221*da0073e9SAndroid Build Coastguard Worker
1222*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1223*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSC values")):
1224*da0073e9SAndroid Build Coastguard Worker                fn()
1225*da0073e9SAndroid Build Coastguard Worker
1226*da0073e9SAndroid Build Coastguard Worker
1227*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support BSR testing")
1228*da0073e9SAndroid Build Coastguard Workerclass TestAssertCloseSparseBSR(TestCase):
1229*da0073e9SAndroid Build Coastguard Worker    def test_matching(self):
1230*da0073e9SAndroid Build Coastguard Worker        crow_indices = (0, 1, 2)
1231*da0073e9SAndroid Build Coastguard Worker        col_indices = (1, 0)
1232*da0073e9SAndroid Build Coastguard Worker        values = ([[1]], [[2]])
1233*da0073e9SAndroid Build Coastguard Worker        actual = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(2, 2))
1234*da0073e9SAndroid Build Coastguard Worker        expected = actual.clone()
1235*da0073e9SAndroid Build Coastguard Worker
1236*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1237*da0073e9SAndroid Build Coastguard Worker            fn()
1238*da0073e9SAndroid Build Coastguard Worker
1239*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_crow_indices_msg(self):
1240*da0073e9SAndroid Build Coastguard Worker        actual_crow_indices = (0, 1, 2)
1241*da0073e9SAndroid Build Coastguard Worker        actual_col_indices = (0, 1)
1242*da0073e9SAndroid Build Coastguard Worker        actual_values = ([[1]], [[2]])
1243*da0073e9SAndroid Build Coastguard Worker        actual = torch.sparse_bsr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2))
1244*da0073e9SAndroid Build Coastguard Worker
1245*da0073e9SAndroid Build Coastguard Worker        expected_crow_indices = (0, 2, 2)
1246*da0073e9SAndroid Build Coastguard Worker        expected_col_indices = actual_col_indices
1247*da0073e9SAndroid Build Coastguard Worker        expected_values = actual_values
1248*da0073e9SAndroid Build Coastguard Worker        expected = torch.sparse_bsr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2))
1249*da0073e9SAndroid Build Coastguard Worker
1250*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1251*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape("Sparse BSR crow_indices")):
1252*da0073e9SAndroid Build Coastguard Worker                fn()
1253*da0073e9SAndroid Build Coastguard Worker
1254*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_col_indices_msg(self):
1255*da0073e9SAndroid Build Coastguard Worker        actual_crow_indices = (0, 1, 2)
1256*da0073e9SAndroid Build Coastguard Worker        actual_col_indices = (1, 0)
1257*da0073e9SAndroid Build Coastguard Worker        actual_values = ([[1]], [[2]])
1258*da0073e9SAndroid Build Coastguard Worker        actual = torch.sparse_bsr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2))
1259*da0073e9SAndroid Build Coastguard Worker
1260*da0073e9SAndroid Build Coastguard Worker        expected_crow_indices = actual_crow_indices
1261*da0073e9SAndroid Build Coastguard Worker        expected_col_indices = (1, 1)
1262*da0073e9SAndroid Build Coastguard Worker        expected_values = actual_values
1263*da0073e9SAndroid Build Coastguard Worker        expected = torch.sparse_bsr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2))
1264*da0073e9SAndroid Build Coastguard Worker
1265*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1266*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape("Sparse BSR col_indices")):
1267*da0073e9SAndroid Build Coastguard Worker                fn()
1268*da0073e9SAndroid Build Coastguard Worker
1269*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_values_msg(self):
1270*da0073e9SAndroid Build Coastguard Worker        actual_crow_indices = (0, 1, 2)
1271*da0073e9SAndroid Build Coastguard Worker        actual_col_indices = (1, 0)
1272*da0073e9SAndroid Build Coastguard Worker        actual_values = ([[1]], [[2]])
1273*da0073e9SAndroid Build Coastguard Worker        actual = torch.sparse_bsr_tensor(actual_crow_indices, actual_col_indices, actual_values, size=(2, 2))
1274*da0073e9SAndroid Build Coastguard Worker
1275*da0073e9SAndroid Build Coastguard Worker        expected_crow_indices = actual_crow_indices
1276*da0073e9SAndroid Build Coastguard Worker        expected_col_indices = actual_col_indices
1277*da0073e9SAndroid Build Coastguard Worker        expected_values = ([[1]], [[3]])
1278*da0073e9SAndroid Build Coastguard Worker        expected = torch.sparse_bsr_tensor(expected_crow_indices, expected_col_indices, expected_values, size=(2, 2))
1279*da0073e9SAndroid Build Coastguard Worker
1280*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1281*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape("Sparse BSR values")):
1282*da0073e9SAndroid Build Coastguard Worker                fn()
1283*da0073e9SAndroid Build Coastguard Worker
1284*da0073e9SAndroid Build Coastguard Worker
1285*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support BSC testing")
1286*da0073e9SAndroid Build Coastguard Workerclass TestAssertCloseSparseBSC(TestCase):
1287*da0073e9SAndroid Build Coastguard Worker    def test_matching(self):
1288*da0073e9SAndroid Build Coastguard Worker        ccol_indices = (0, 1, 2)
1289*da0073e9SAndroid Build Coastguard Worker        row_indices = (1, 0)
1290*da0073e9SAndroid Build Coastguard Worker        values = ([[1]], [[2]])
1291*da0073e9SAndroid Build Coastguard Worker        actual = torch.sparse_bsc_tensor(ccol_indices, row_indices, values, size=(2, 2))
1292*da0073e9SAndroid Build Coastguard Worker        expected = actual.clone()
1293*da0073e9SAndroid Build Coastguard Worker
1294*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1295*da0073e9SAndroid Build Coastguard Worker            fn()
1296*da0073e9SAndroid Build Coastguard Worker
1297*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_ccol_indices_msg(self):
1298*da0073e9SAndroid Build Coastguard Worker        actual_ccol_indices = (0, 1, 2)
1299*da0073e9SAndroid Build Coastguard Worker        actual_row_indices = (0, 1)
1300*da0073e9SAndroid Build Coastguard Worker        actual_values = ([[1]], [[2]])
1301*da0073e9SAndroid Build Coastguard Worker        actual = torch.sparse_bsc_tensor(actual_ccol_indices, actual_row_indices, actual_values, size=(2, 2))
1302*da0073e9SAndroid Build Coastguard Worker
1303*da0073e9SAndroid Build Coastguard Worker        expected_ccol_indices = (0, 2, 2)
1304*da0073e9SAndroid Build Coastguard Worker        expected_row_indices = actual_row_indices
1305*da0073e9SAndroid Build Coastguard Worker        expected_values = actual_values
1306*da0073e9SAndroid Build Coastguard Worker        expected = torch.sparse_bsc_tensor(expected_ccol_indices, expected_row_indices, expected_values, size=(2, 2))
1307*da0073e9SAndroid Build Coastguard Worker
1308*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1309*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape("Sparse BSC ccol_indices")):
1310*da0073e9SAndroid Build Coastguard Worker                fn()
1311*da0073e9SAndroid Build Coastguard Worker
1312*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_row_indices_msg(self):
1313*da0073e9SAndroid Build Coastguard Worker        actual_ccol_indices = (0, 1, 2)
1314*da0073e9SAndroid Build Coastguard Worker        actual_row_indices = (1, 0)
1315*da0073e9SAndroid Build Coastguard Worker        actual_values = ([[1]], [[2]])
1316*da0073e9SAndroid Build Coastguard Worker        actual = torch.sparse_bsc_tensor(actual_ccol_indices, actual_row_indices, actual_values, size=(2, 2))
1317*da0073e9SAndroid Build Coastguard Worker
1318*da0073e9SAndroid Build Coastguard Worker        expected_ccol_indices = actual_ccol_indices
1319*da0073e9SAndroid Build Coastguard Worker        expected_row_indices = (1, 1)
1320*da0073e9SAndroid Build Coastguard Worker        expected_values = actual_values
1321*da0073e9SAndroid Build Coastguard Worker        expected = torch.sparse_bsc_tensor(expected_ccol_indices, expected_row_indices, expected_values, size=(2, 2))
1322*da0073e9SAndroid Build Coastguard Worker
1323*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1324*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape("Sparse BSC row_indices")):
1325*da0073e9SAndroid Build Coastguard Worker                fn()
1326*da0073e9SAndroid Build Coastguard Worker
1327*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_values_msg(self):
1328*da0073e9SAndroid Build Coastguard Worker        actual_ccol_indices = (0, 1, 2)
1329*da0073e9SAndroid Build Coastguard Worker        actual_row_indices = (1, 0)
1330*da0073e9SAndroid Build Coastguard Worker        actual_values = ([[1]], [[2]])
1331*da0073e9SAndroid Build Coastguard Worker        actual = torch.sparse_bsc_tensor(actual_ccol_indices, actual_row_indices, actual_values, size=(2, 2))
1332*da0073e9SAndroid Build Coastguard Worker
1333*da0073e9SAndroid Build Coastguard Worker        expected_ccol_indices = actual_ccol_indices
1334*da0073e9SAndroid Build Coastguard Worker        expected_row_indices = actual_row_indices
1335*da0073e9SAndroid Build Coastguard Worker        expected_values = ([[1]], [[3]])
1336*da0073e9SAndroid Build Coastguard Worker        expected = torch.sparse_bsc_tensor(expected_ccol_indices, expected_row_indices, expected_values, size=(2, 2))
1337*da0073e9SAndroid Build Coastguard Worker
1338*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1339*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, re.escape("Sparse BSC values")):
1340*da0073e9SAndroid Build Coastguard Worker                fn()
1341*da0073e9SAndroid Build Coastguard Worker
1342*da0073e9SAndroid Build Coastguard Worker
1343*da0073e9SAndroid Build Coastguard Workerclass TestAssertCloseQuantized(TestCase):
1344*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_is_quantized(self):
1345*da0073e9SAndroid Build Coastguard Worker        actual = torch.tensor(1.0)
1346*da0073e9SAndroid Build Coastguard Worker        expected = torch.quantize_per_tensor(actual, scale=1.0, zero_point=0, dtype=torch.qint32)
1347*da0073e9SAndroid Build Coastguard Worker
1348*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1349*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, "is_quantized"):
1350*da0073e9SAndroid Build Coastguard Worker                fn()
1351*da0073e9SAndroid Build Coastguard Worker
1352*da0073e9SAndroid Build Coastguard Worker    def test_mismatching_qscheme(self):
1353*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor((1.0,))
1354*da0073e9SAndroid Build Coastguard Worker        actual = torch.quantize_per_tensor(t, scale=1.0, zero_point=0, dtype=torch.qint32)
1355*da0073e9SAndroid Build Coastguard Worker        expected = torch.quantize_per_channel(
1356*da0073e9SAndroid Build Coastguard Worker            t,
1357*da0073e9SAndroid Build Coastguard Worker            scales=torch.tensor((1.0,)),
1358*da0073e9SAndroid Build Coastguard Worker            zero_points=torch.tensor((0,)),
1359*da0073e9SAndroid Build Coastguard Worker            axis=0,
1360*da0073e9SAndroid Build Coastguard Worker            dtype=torch.qint32,
1361*da0073e9SAndroid Build Coastguard Worker        )
1362*da0073e9SAndroid Build Coastguard Worker
1363*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1364*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, "qscheme"):
1365*da0073e9SAndroid Build Coastguard Worker                fn()
1366*da0073e9SAndroid Build Coastguard Worker
1367*da0073e9SAndroid Build Coastguard Worker    def test_matching_per_tensor(self):
1368*da0073e9SAndroid Build Coastguard Worker        actual = torch.quantize_per_tensor(torch.tensor(1.0), scale=1.0, zero_point=0, dtype=torch.qint32)
1369*da0073e9SAndroid Build Coastguard Worker        expected = actual.clone()
1370*da0073e9SAndroid Build Coastguard Worker
1371*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1372*da0073e9SAndroid Build Coastguard Worker            fn()
1373*da0073e9SAndroid Build Coastguard Worker
1374*da0073e9SAndroid Build Coastguard Worker    def test_matching_per_channel(self):
1375*da0073e9SAndroid Build Coastguard Worker        actual = torch.quantize_per_channel(
1376*da0073e9SAndroid Build Coastguard Worker            torch.tensor((1.0,)),
1377*da0073e9SAndroid Build Coastguard Worker            scales=torch.tensor((1.0,)),
1378*da0073e9SAndroid Build Coastguard Worker            zero_points=torch.tensor((0,)),
1379*da0073e9SAndroid Build Coastguard Worker            axis=0,
1380*da0073e9SAndroid Build Coastguard Worker            dtype=torch.qint32,
1381*da0073e9SAndroid Build Coastguard Worker        )
1382*da0073e9SAndroid Build Coastguard Worker        expected = actual.clone()
1383*da0073e9SAndroid Build Coastguard Worker
1384*da0073e9SAndroid Build Coastguard Worker        for fn in assert_close_with_inputs(actual, expected):
1385*da0073e9SAndroid Build Coastguard Worker            fn()
1386*da0073e9SAndroid Build Coastguard Worker
1387*da0073e9SAndroid Build Coastguard Worker
1388*da0073e9SAndroid Build Coastguard Workerclass TestMakeTensor(TestCase):
1389*da0073e9SAndroid Build Coastguard Worker    supported_dtypes = dtypes(
1390*da0073e9SAndroid Build Coastguard Worker        torch.bool,
1391*da0073e9SAndroid Build Coastguard Worker        torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64,
1392*da0073e9SAndroid Build Coastguard Worker        torch.float16, torch.bfloat16, torch.float32, torch.float64,
1393*da0073e9SAndroid Build Coastguard Worker        torch.complex32, torch.complex64, torch.complex128,
1394*da0073e9SAndroid Build Coastguard Worker    )
1395*da0073e9SAndroid Build Coastguard Worker
1396*da0073e9SAndroid Build Coastguard Worker    @supported_dtypes
1397*da0073e9SAndroid Build Coastguard Worker    @parametrize("shape", [(), (0,), (1,), (1, 1), (2,), (2, 3), (8, 16, 32)])
1398*da0073e9SAndroid Build Coastguard Worker    @parametrize("splat_shape", [False, True])
1399*da0073e9SAndroid Build Coastguard Worker    def test_smoke(self, dtype, device, shape, splat_shape):
1400*da0073e9SAndroid Build Coastguard Worker        t = torch.testing.make_tensor(*shape if splat_shape else shape, dtype=dtype, device=device)
1401*da0073e9SAndroid Build Coastguard Worker
1402*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(t, torch.Tensor)
1403*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.shape, shape)
1404*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.dtype, dtype)
1405*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.device, torch.device(device))
1406*da0073e9SAndroid Build Coastguard Worker
1407*da0073e9SAndroid Build Coastguard Worker    @supported_dtypes
1408*da0073e9SAndroid Build Coastguard Worker    @parametrize("requires_grad", [False, True])
1409*da0073e9SAndroid Build Coastguard Worker    def test_requires_grad(self, dtype, device, requires_grad):
1410*da0073e9SAndroid Build Coastguard Worker        make_tensor = functools.partial(
1411*da0073e9SAndroid Build Coastguard Worker            torch.testing.make_tensor,
1412*da0073e9SAndroid Build Coastguard Worker            dtype=dtype,
1413*da0073e9SAndroid Build Coastguard Worker            device=device,
1414*da0073e9SAndroid Build Coastguard Worker            requires_grad=requires_grad,
1415*da0073e9SAndroid Build Coastguard Worker        )
1416*da0073e9SAndroid Build Coastguard Worker
1417*da0073e9SAndroid Build Coastguard Worker        if not requires_grad or dtype.is_floating_point or dtype.is_complex:
1418*da0073e9SAndroid Build Coastguard Worker            t = make_tensor()
1419*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t.requires_grad, requires_grad)
1420*da0073e9SAndroid Build Coastguard Worker        else:
1421*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
1422*da0073e9SAndroid Build Coastguard Worker                    ValueError, "`requires_grad=True` is not supported for boolean and integral dtypes"
1423*da0073e9SAndroid Build Coastguard Worker            ):
1424*da0073e9SAndroid Build Coastguard Worker                make_tensor()
1425*da0073e9SAndroid Build Coastguard Worker
1426*da0073e9SAndroid Build Coastguard Worker    @supported_dtypes
1427*da0073e9SAndroid Build Coastguard Worker    @parametrize("noncontiguous", [False, True])
1428*da0073e9SAndroid Build Coastguard Worker    @parametrize("shape", [(), (0,), (1,), (1, 1), (2,), (2, 3), (8, 16, 32)])
1429*da0073e9SAndroid Build Coastguard Worker    def test_noncontiguous(self, dtype, device, noncontiguous, shape):
1430*da0073e9SAndroid Build Coastguard Worker        numel = functools.reduce(operator.mul, shape, 1)
1431*da0073e9SAndroid Build Coastguard Worker
1432*da0073e9SAndroid Build Coastguard Worker        t = torch.testing.make_tensor(shape, dtype=dtype, device=device, noncontiguous=noncontiguous)
1433*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.is_contiguous(), not noncontiguous or numel < 2)
1434*da0073e9SAndroid Build Coastguard Worker
1435*da0073e9SAndroid Build Coastguard Worker    @supported_dtypes
1436*da0073e9SAndroid Build Coastguard Worker    @parametrize(
1437*da0073e9SAndroid Build Coastguard Worker        "memory_format_and_shape",
1438*da0073e9SAndroid Build Coastguard Worker        [
1439*da0073e9SAndroid Build Coastguard Worker            (None, (2, 3, 4)),
1440*da0073e9SAndroid Build Coastguard Worker            (torch.contiguous_format, (2, 3, 4)),
1441*da0073e9SAndroid Build Coastguard Worker            (torch.channels_last, (2, 3, 4, 5)),
1442*da0073e9SAndroid Build Coastguard Worker            (torch.channels_last_3d, (2, 3, 4, 5, 6)),
1443*da0073e9SAndroid Build Coastguard Worker            (torch.preserve_format, (2, 3, 4)),
1444*da0073e9SAndroid Build Coastguard Worker        ],
1445*da0073e9SAndroid Build Coastguard Worker    )
1446*da0073e9SAndroid Build Coastguard Worker    def test_memory_format(self, dtype, device, memory_format_and_shape):
1447*da0073e9SAndroid Build Coastguard Worker        memory_format, shape = memory_format_and_shape
1448*da0073e9SAndroid Build Coastguard Worker
1449*da0073e9SAndroid Build Coastguard Worker        t = torch.testing.make_tensor(shape, dtype=dtype, device=device, memory_format=memory_format)
1450*da0073e9SAndroid Build Coastguard Worker
1451*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
1452*da0073e9SAndroid Build Coastguard Worker            t.is_contiguous(memory_format=torch.contiguous_format if memory_format is None else memory_format)
1453*da0073e9SAndroid Build Coastguard Worker        )
1454*da0073e9SAndroid Build Coastguard Worker
1455*da0073e9SAndroid Build Coastguard Worker    @supported_dtypes
1456*da0073e9SAndroid Build Coastguard Worker    def test_noncontiguous_memory_format(self, dtype, device):
1457*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "`noncontiguous` and `memory_format` are mutually exclusive"):
1458*da0073e9SAndroid Build Coastguard Worker            torch.testing.make_tensor(
1459*da0073e9SAndroid Build Coastguard Worker                (2, 3, 4, 5),
1460*da0073e9SAndroid Build Coastguard Worker                dtype=dtype,
1461*da0073e9SAndroid Build Coastguard Worker                device=device,
1462*da0073e9SAndroid Build Coastguard Worker                noncontiguous=True,
1463*da0073e9SAndroid Build Coastguard Worker                memory_format=torch.channels_last,
1464*da0073e9SAndroid Build Coastguard Worker            )
1465*da0073e9SAndroid Build Coastguard Worker
1466*da0073e9SAndroid Build Coastguard Worker    @supported_dtypes
1467*da0073e9SAndroid Build Coastguard Worker    def test_exclude_zero(self, dtype, device):
1468*da0073e9SAndroid Build Coastguard Worker        t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, exclude_zero=True, low=-1, high=2)
1469*da0073e9SAndroid Build Coastguard Worker
1470*da0073e9SAndroid Build Coastguard Worker        self.assertTrue((t != 0).all())
1471*da0073e9SAndroid Build Coastguard Worker
1472*da0073e9SAndroid Build Coastguard Worker    @supported_dtypes
1473*da0073e9SAndroid Build Coastguard Worker    def test_low_high_smoke(self, dtype, device):
1474*da0073e9SAndroid Build Coastguard Worker        low_inclusive, high_exclusive = 0, 2
1475*da0073e9SAndroid Build Coastguard Worker
1476*da0073e9SAndroid Build Coastguard Worker        t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, low=low_inclusive, high=high_exclusive)
1477*da0073e9SAndroid Build Coastguard Worker        if dtype.is_complex:
1478*da0073e9SAndroid Build Coastguard Worker            t = torch.view_as_real(t)
1479*da0073e9SAndroid Build Coastguard Worker
1480*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(((t >= low_inclusive) & (t < high_exclusive)).all())
1481*da0073e9SAndroid Build Coastguard Worker
1482*da0073e9SAndroid Build Coastguard Worker    @supported_dtypes
1483*da0073e9SAndroid Build Coastguard Worker    def test_low_high_default_smoke(self, dtype, device):
1484*da0073e9SAndroid Build Coastguard Worker        low_inclusive, high_exclusive = {
1485*da0073e9SAndroid Build Coastguard Worker            torch.bool: (0, 2),
1486*da0073e9SAndroid Build Coastguard Worker            torch.uint8: (0, 10),
1487*da0073e9SAndroid Build Coastguard Worker            **dict.fromkeys([torch.int8, torch.int16, torch.int32, torch.int64], (-9, 10)),
1488*da0073e9SAndroid Build Coastguard Worker        }.get(dtype, (-9, 9))
1489*da0073e9SAndroid Build Coastguard Worker
1490*da0073e9SAndroid Build Coastguard Worker        t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, low=low_inclusive, high=high_exclusive)
1491*da0073e9SAndroid Build Coastguard Worker        if dtype.is_complex:
1492*da0073e9SAndroid Build Coastguard Worker            t = torch.view_as_real(t)
1493*da0073e9SAndroid Build Coastguard Worker
1494*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(((t >= low_inclusive) & (t < high_exclusive)).all())
1495*da0073e9SAndroid Build Coastguard Worker
1496*da0073e9SAndroid Build Coastguard Worker    @parametrize("low_high", [(0, 0), (1, 0), (0, -1)])
1497*da0073e9SAndroid Build Coastguard Worker    @parametrize("value_types", list(itertools.product([int, float], repeat=2)))
1498*da0073e9SAndroid Build Coastguard Worker    @supported_dtypes
1499*da0073e9SAndroid Build Coastguard Worker    def test_low_ge_high(self, dtype, device, low_high, value_types):
1500*da0073e9SAndroid Build Coastguard Worker        low, high = (value_type(value) for value, value_type in zip(low_high, value_types))
1501*da0073e9SAndroid Build Coastguard Worker
1502*da0073e9SAndroid Build Coastguard Worker        if low == high and (dtype.is_floating_point or dtype.is_complex):
1503*da0073e9SAndroid Build Coastguard Worker            with self.assertWarnsRegex(
1504*da0073e9SAndroid Build Coastguard Worker                    FutureWarning,
1505*da0073e9SAndroid Build Coastguard Worker                    "Passing `low==high` to `torch.testing.make_tensor` for floating or complex types is deprecated",
1506*da0073e9SAndroid Build Coastguard Worker            ):
1507*da0073e9SAndroid Build Coastguard Worker                t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, low=low, high=high)
1508*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t, torch.full_like(t, complex(low, low) if dtype.is_complex else low))
1509*da0073e9SAndroid Build Coastguard Worker        else:
1510*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(ValueError, "`low` must be less than `high`"):
1511*da0073e9SAndroid Build Coastguard Worker                torch.testing.make_tensor(dtype=dtype, device=device, low=low, high=high)
1512*da0073e9SAndroid Build Coastguard Worker
1513*da0073e9SAndroid Build Coastguard Worker    @supported_dtypes
1514*da0073e9SAndroid Build Coastguard Worker    @parametrize("low_high", [(None, torch.nan), (torch.nan, None), (torch.nan, torch.nan)])
1515*da0073e9SAndroid Build Coastguard Worker    def test_low_high_nan(self, dtype, device, low_high):
1516*da0073e9SAndroid Build Coastguard Worker        low, high = low_high
1517*da0073e9SAndroid Build Coastguard Worker
1518*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "`low` and `high` cannot be NaN"):
1519*da0073e9SAndroid Build Coastguard Worker            torch.testing.make_tensor(dtype=dtype, device=device, low=low, high=high)
1520*da0073e9SAndroid Build Coastguard Worker
1521*da0073e9SAndroid Build Coastguard Worker    @supported_dtypes
1522*da0073e9SAndroid Build Coastguard Worker    def test_low_high_outside_valid_range(self, dtype, device):
1523*da0073e9SAndroid Build Coastguard Worker        make_tensor = functools.partial(torch.testing.make_tensor, dtype=dtype, device=device)
1524*da0073e9SAndroid Build Coastguard Worker
1525*da0073e9SAndroid Build Coastguard Worker        def get_dtype_limits(dtype):
1526*da0073e9SAndroid Build Coastguard Worker            if dtype is torch.bool:
1527*da0073e9SAndroid Build Coastguard Worker                return 0, 1
1528*da0073e9SAndroid Build Coastguard Worker
1529*da0073e9SAndroid Build Coastguard Worker            info = (torch.finfo if dtype.is_floating_point or dtype.is_complex else torch.iinfo)(dtype)
1530*da0073e9SAndroid Build Coastguard Worker            # We are using integer bounds here, because otherwise it would be impossible to pass `low` and `high`
1531*da0073e9SAndroid Build Coastguard Worker            # outside their valid range. Python uses 64bit floating point numbers and thus trying to do something like
1532*da0073e9SAndroid Build Coastguard Worker            # `torch.ffinfo(torch.float64)max * 2` will always result in `inf`. On the flipside, Pythons `int` is
1533*da0073e9SAndroid Build Coastguard Worker            # unbounded.
1534*da0073e9SAndroid Build Coastguard Worker            return int(info.min), int(info.max)
1535*da0073e9SAndroid Build Coastguard Worker
1536*da0073e9SAndroid Build Coastguard Worker        lowest_inclusive, highest_inclusive = get_dtype_limits(dtype)
1537*da0073e9SAndroid Build Coastguard Worker
1538*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, ""):
1539*da0073e9SAndroid Build Coastguard Worker            low, high = (-2, -1) if lowest_inclusive == 0 else (lowest_inclusive * 4, lowest_inclusive * 2)
1540*da0073e9SAndroid Build Coastguard Worker            make_tensor(low=low, high=high)
1541*da0073e9SAndroid Build Coastguard Worker
1542*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, ""):
1543*da0073e9SAndroid Build Coastguard Worker            make_tensor(low=highest_inclusive * 2, high=highest_inclusive * 4)
1544*da0073e9SAndroid Build Coastguard Worker
1545*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
1546*da0073e9SAndroid Build Coastguard Worker    def test_low_high_boolean_integral1(self, dtype, device):
1547*da0073e9SAndroid Build Coastguard Worker        shape = (10_000,)
1548*da0073e9SAndroid Build Coastguard Worker        eps = 1e-4
1549*da0073e9SAndroid Build Coastguard Worker
1550*da0073e9SAndroid Build Coastguard Worker        actual = torch.testing.make_tensor(shape, dtype=dtype, device=device, low=-(1 - eps), high=1 - eps)
1551*da0073e9SAndroid Build Coastguard Worker        expected = torch.zeros(shape, dtype=dtype, device=device)
1552*da0073e9SAndroid Build Coastguard Worker
1553*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(actual, expected)
1554*da0073e9SAndroid Build Coastguard Worker
1555*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
1556*da0073e9SAndroid Build Coastguard Worker    def test_low_high_boolean_integral2(self, dtype, device):
1557*da0073e9SAndroid Build Coastguard Worker        shape = (10_000,)
1558*da0073e9SAndroid Build Coastguard Worker        if dtype is torch.bool:
1559*da0073e9SAndroid Build Coastguard Worker            low = 1
1560*da0073e9SAndroid Build Coastguard Worker        elif dtype is torch.int64:
1561*da0073e9SAndroid Build Coastguard Worker            # Due to its internals, `make_tensor` is not able to sample `torch.iinfo(torch.int64).max`
1562*da0073e9SAndroid Build Coastguard Worker            low = torch.iinfo(dtype).max - 1
1563*da0073e9SAndroid Build Coastguard Worker        else:
1564*da0073e9SAndroid Build Coastguard Worker            low = torch.iinfo(dtype).max
1565*da0073e9SAndroid Build Coastguard Worker        high = low + 1
1566*da0073e9SAndroid Build Coastguard Worker
1567*da0073e9SAndroid Build Coastguard Worker        actual = torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high)
1568*da0073e9SAndroid Build Coastguard Worker        expected = torch.full(shape, low, dtype=dtype, device=device)
1569*da0073e9SAndroid Build Coastguard Worker
1570*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(actual, expected)
1571*da0073e9SAndroid Build Coastguard Worker
1572*da0073e9SAndroid Build Coastguard Worker
1573*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestMakeTensor, globals())
1574*da0073e9SAndroid Build Coastguard Worker
1575*da0073e9SAndroid Build Coastguard Worker
1576*da0073e9SAndroid Build Coastguard Workerdef _get_test_names_for_test_class(test_cls):
1577*da0073e9SAndroid Build Coastguard Worker    """ Convenience function to get all test names for a given test class. """
1578*da0073e9SAndroid Build Coastguard Worker    test_names = [f'{test_cls.__name__}.{key}' for key in test_cls.__dict__
1579*da0073e9SAndroid Build Coastguard Worker                  if key.startswith('test_')]
1580*da0073e9SAndroid Build Coastguard Worker    return sorted(test_names)
1581*da0073e9SAndroid Build Coastguard Worker
1582*da0073e9SAndroid Build Coastguard Worker
1583*da0073e9SAndroid Build Coastguard Workerdef _get_test_funcs_for_test_class(test_cls):
1584*da0073e9SAndroid Build Coastguard Worker    """ Convenience function to get all (test function, parametrized_name) pairs for a given test class. """
1585*da0073e9SAndroid Build Coastguard Worker    test_funcs = [(getattr(test_cls, key), key) for key in test_cls.__dict__ if key.startswith('test_')]
1586*da0073e9SAndroid Build Coastguard Worker    return test_funcs
1587*da0073e9SAndroid Build Coastguard Worker
1588*da0073e9SAndroid Build Coastguard Worker
1589*da0073e9SAndroid Build Coastguard Workerclass TestTestParametrization(TestCase):
1590*da0073e9SAndroid Build Coastguard Worker    def test_default_names(self):
1591*da0073e9SAndroid Build Coastguard Worker
1592*da0073e9SAndroid Build Coastguard Worker        class TestParametrized(TestCase):
1593*da0073e9SAndroid Build Coastguard Worker            @parametrize("x", range(5))
1594*da0073e9SAndroid Build Coastguard Worker            def test_default_names(self, x):
1595*da0073e9SAndroid Build Coastguard Worker                pass
1596*da0073e9SAndroid Build Coastguard Worker
1597*da0073e9SAndroid Build Coastguard Worker            @parametrize("x,y", [(1, 2), (2, 3), (3, 4)])
1598*da0073e9SAndroid Build Coastguard Worker            def test_two_things_default_names(self, x, y):
1599*da0073e9SAndroid Build Coastguard Worker                pass
1600*da0073e9SAndroid Build Coastguard Worker
1601*da0073e9SAndroid Build Coastguard Worker        instantiate_parametrized_tests(TestParametrized)
1602*da0073e9SAndroid Build Coastguard Worker
1603*da0073e9SAndroid Build Coastguard Worker        expected_test_names = [
1604*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_default_names_x_0',
1605*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_default_names_x_1',
1606*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_default_names_x_2',
1607*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_default_names_x_3',
1608*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_default_names_x_4',
1609*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_two_things_default_names_x_1_y_2',
1610*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_two_things_default_names_x_2_y_3',
1611*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_two_things_default_names_x_3_y_4',
1612*da0073e9SAndroid Build Coastguard Worker        ]
1613*da0073e9SAndroid Build Coastguard Worker        test_names = _get_test_names_for_test_class(TestParametrized)
1614*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_test_names, test_names)
1615*da0073e9SAndroid Build Coastguard Worker
1616*da0073e9SAndroid Build Coastguard Worker    def test_name_fn(self):
1617*da0073e9SAndroid Build Coastguard Worker
1618*da0073e9SAndroid Build Coastguard Worker        class TestParametrized(TestCase):
1619*da0073e9SAndroid Build Coastguard Worker            @parametrize("bias", [False, True], name_fn=lambda b: 'bias' if b else 'no_bias')
1620*da0073e9SAndroid Build Coastguard Worker            def test_custom_names(self, bias):
1621*da0073e9SAndroid Build Coastguard Worker                pass
1622*da0073e9SAndroid Build Coastguard Worker
1623*da0073e9SAndroid Build Coastguard Worker            @parametrize("x", [1, 2], name_fn=str)
1624*da0073e9SAndroid Build Coastguard Worker            @parametrize("y", [3, 4], name_fn=str)
1625*da0073e9SAndroid Build Coastguard Worker            @parametrize("z", [5, 6], name_fn=str)
1626*da0073e9SAndroid Build Coastguard Worker            def test_three_things_composition_custom_names(self, x, y, z):
1627*da0073e9SAndroid Build Coastguard Worker                pass
1628*da0073e9SAndroid Build Coastguard Worker
1629*da0073e9SAndroid Build Coastguard Worker            @parametrize("x,y", [(1, 2), (1, 3), (1, 4)], name_fn=lambda x, y: f'{x}__{y}')
1630*da0073e9SAndroid Build Coastguard Worker            def test_two_things_custom_names_alternate(self, x, y):
1631*da0073e9SAndroid Build Coastguard Worker                pass
1632*da0073e9SAndroid Build Coastguard Worker
1633*da0073e9SAndroid Build Coastguard Worker        instantiate_parametrized_tests(TestParametrized)
1634*da0073e9SAndroid Build Coastguard Worker
1635*da0073e9SAndroid Build Coastguard Worker        expected_test_names = [
1636*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_custom_names_bias',
1637*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_custom_names_no_bias',
1638*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_three_things_composition_custom_names_1_3_5',
1639*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_three_things_composition_custom_names_1_3_6',
1640*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_three_things_composition_custom_names_1_4_5',
1641*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_three_things_composition_custom_names_1_4_6',
1642*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_three_things_composition_custom_names_2_3_5',
1643*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_three_things_composition_custom_names_2_3_6',
1644*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_three_things_composition_custom_names_2_4_5',
1645*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_three_things_composition_custom_names_2_4_6',
1646*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_two_things_custom_names_alternate_1__2',
1647*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_two_things_custom_names_alternate_1__3',
1648*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_two_things_custom_names_alternate_1__4',
1649*da0073e9SAndroid Build Coastguard Worker        ]
1650*da0073e9SAndroid Build Coastguard Worker        test_names = _get_test_names_for_test_class(TestParametrized)
1651*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_test_names, test_names)
1652*da0073e9SAndroid Build Coastguard Worker
1653*da0073e9SAndroid Build Coastguard Worker    def test_subtest_names(self):
1654*da0073e9SAndroid Build Coastguard Worker
1655*da0073e9SAndroid Build Coastguard Worker        class TestParametrized(TestCase):
1656*da0073e9SAndroid Build Coastguard Worker            @parametrize("bias", [subtest(True, name='bias'),
1657*da0073e9SAndroid Build Coastguard Worker                                  subtest(False, name='no_bias')])
1658*da0073e9SAndroid Build Coastguard Worker            def test_custom_names(self, bias):
1659*da0073e9SAndroid Build Coastguard Worker                pass
1660*da0073e9SAndroid Build Coastguard Worker
1661*da0073e9SAndroid Build Coastguard Worker            @parametrize("x,y", [subtest((1, 2), name='double'),
1662*da0073e9SAndroid Build Coastguard Worker                                 subtest((1, 3), name='triple'),
1663*da0073e9SAndroid Build Coastguard Worker                                 subtest((1, 4), name='quadruple')])
1664*da0073e9SAndroid Build Coastguard Worker            def test_two_things_custom_names(self, x, y):
1665*da0073e9SAndroid Build Coastguard Worker                pass
1666*da0073e9SAndroid Build Coastguard Worker
1667*da0073e9SAndroid Build Coastguard Worker        instantiate_parametrized_tests(TestParametrized)
1668*da0073e9SAndroid Build Coastguard Worker
1669*da0073e9SAndroid Build Coastguard Worker        expected_test_names = [
1670*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_custom_names_bias',
1671*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_custom_names_no_bias',
1672*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_two_things_custom_names_double',
1673*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_two_things_custom_names_quadruple',
1674*da0073e9SAndroid Build Coastguard Worker            'TestParametrized.test_two_things_custom_names_triple',
1675*da0073e9SAndroid Build Coastguard Worker        ]
1676*da0073e9SAndroid Build Coastguard Worker        test_names = _get_test_names_for_test_class(TestParametrized)
1677*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_test_names, test_names)
1678*da0073e9SAndroid Build Coastguard Worker
1679*da0073e9SAndroid Build Coastguard Worker    def test_apply_param_specific_decorators(self):
1680*da0073e9SAndroid Build Coastguard Worker        # Test that decorators can be applied on a per-param basis.
1681*da0073e9SAndroid Build Coastguard Worker
1682*da0073e9SAndroid Build Coastguard Worker        def test_dec(func):
1683*da0073e9SAndroid Build Coastguard Worker            func._decorator_applied = True
1684*da0073e9SAndroid Build Coastguard Worker            return func
1685*da0073e9SAndroid Build Coastguard Worker
1686*da0073e9SAndroid Build Coastguard Worker        class TestParametrized(TestCase):
1687*da0073e9SAndroid Build Coastguard Worker            @parametrize("x", [subtest(1, name='one'),
1688*da0073e9SAndroid Build Coastguard Worker                               subtest(2, name='two', decorators=[test_dec]),
1689*da0073e9SAndroid Build Coastguard Worker                               subtest(3, name='three')])
1690*da0073e9SAndroid Build Coastguard Worker            def test_param(self, x):
1691*da0073e9SAndroid Build Coastguard Worker                pass
1692*da0073e9SAndroid Build Coastguard Worker
1693*da0073e9SAndroid Build Coastguard Worker        instantiate_parametrized_tests(TestParametrized)
1694*da0073e9SAndroid Build Coastguard Worker
1695*da0073e9SAndroid Build Coastguard Worker        for test_func, name in _get_test_funcs_for_test_class(TestParametrized):
1696*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(hasattr(test_func, '_decorator_applied'), name == 'test_param_two')
1697*da0073e9SAndroid Build Coastguard Worker
1698*da0073e9SAndroid Build Coastguard Worker    def test_compose_param_specific_decorators(self):
1699*da0073e9SAndroid Build Coastguard Worker        # Test that multiple per-param decorators compose correctly.
1700*da0073e9SAndroid Build Coastguard Worker
1701*da0073e9SAndroid Build Coastguard Worker        def test_dec(func):
1702*da0073e9SAndroid Build Coastguard Worker            func._decorator_applied = True
1703*da0073e9SAndroid Build Coastguard Worker            return func
1704*da0073e9SAndroid Build Coastguard Worker
1705*da0073e9SAndroid Build Coastguard Worker        class TestParametrized(TestCase):
1706*da0073e9SAndroid Build Coastguard Worker            @parametrize("x", [subtest(1),
1707*da0073e9SAndroid Build Coastguard Worker                               subtest(2, decorators=[test_dec]),
1708*da0073e9SAndroid Build Coastguard Worker                               subtest(3)])
1709*da0073e9SAndroid Build Coastguard Worker            @parametrize("y", [subtest(False, decorators=[test_dec]),
1710*da0073e9SAndroid Build Coastguard Worker                               subtest(True)])
1711*da0073e9SAndroid Build Coastguard Worker            def test_param(self, x, y):
1712*da0073e9SAndroid Build Coastguard Worker                pass
1713*da0073e9SAndroid Build Coastguard Worker
1714*da0073e9SAndroid Build Coastguard Worker        instantiate_parametrized_tests(TestParametrized)
1715*da0073e9SAndroid Build Coastguard Worker
1716*da0073e9SAndroid Build Coastguard Worker        for test_func, name in _get_test_funcs_for_test_class(TestParametrized):
1717*da0073e9SAndroid Build Coastguard Worker            # Decorator should be applied whenever either x == 2 or y == False.
1718*da0073e9SAndroid Build Coastguard Worker            should_apply = ('x_2' in name) or ('y_False' in name)
1719*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply)
1720*da0073e9SAndroid Build Coastguard Worker
1721*da0073e9SAndroid Build Coastguard Worker    def test_modules_decorator_misuse_error(self):
1722*da0073e9SAndroid Build Coastguard Worker        # Test that @modules errors out when used with instantiate_parametrized_tests().
1723*da0073e9SAndroid Build Coastguard Worker
1724*da0073e9SAndroid Build Coastguard Worker        class TestParametrized(TestCase):
1725*da0073e9SAndroid Build Coastguard Worker            @modules(module_db)
1726*da0073e9SAndroid Build Coastguard Worker            def test_modules(self, module_info):
1727*da0073e9SAndroid Build Coastguard Worker                pass
1728*da0073e9SAndroid Build Coastguard Worker
1729*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'intended to be used in a device-specific context'):
1730*da0073e9SAndroid Build Coastguard Worker            instantiate_parametrized_tests(TestParametrized)
1731*da0073e9SAndroid Build Coastguard Worker
1732*da0073e9SAndroid Build Coastguard Worker    def test_ops_decorator_misuse_error(self):
1733*da0073e9SAndroid Build Coastguard Worker        # Test that @ops errors out when used with instantiate_parametrized_tests().
1734*da0073e9SAndroid Build Coastguard Worker
1735*da0073e9SAndroid Build Coastguard Worker        class TestParametrized(TestCase):
1736*da0073e9SAndroid Build Coastguard Worker            @ops(op_db)
1737*da0073e9SAndroid Build Coastguard Worker            def test_ops(self, module_info):
1738*da0073e9SAndroid Build Coastguard Worker                pass
1739*da0073e9SAndroid Build Coastguard Worker
1740*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'intended to be used in a device-specific context'):
1741*da0073e9SAndroid Build Coastguard Worker            instantiate_parametrized_tests(TestParametrized)
1742*da0073e9SAndroid Build Coastguard Worker
1743*da0073e9SAndroid Build Coastguard Worker    def test_multiple_handling_of_same_param_error(self):
1744*da0073e9SAndroid Build Coastguard Worker        # Test that multiple decorators handling the same param errors out.
1745*da0073e9SAndroid Build Coastguard Worker
1746*da0073e9SAndroid Build Coastguard Worker        class TestParametrized(TestCase):
1747*da0073e9SAndroid Build Coastguard Worker            @parametrize("x", range(3))
1748*da0073e9SAndroid Build Coastguard Worker            @parametrize("x", range(5))
1749*da0073e9SAndroid Build Coastguard Worker            def test_param(self, x):
1750*da0073e9SAndroid Build Coastguard Worker                pass
1751*da0073e9SAndroid Build Coastguard Worker
1752*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'multiple parametrization decorators'):
1753*da0073e9SAndroid Build Coastguard Worker            instantiate_parametrized_tests(TestParametrized)
1754*da0073e9SAndroid Build Coastguard Worker
1755*da0073e9SAndroid Build Coastguard Worker    @parametrize("x", [1, subtest(2, decorators=[unittest.expectedFailure]), 3])
1756*da0073e9SAndroid Build Coastguard Worker    def test_subtest_expected_failure(self, x):
1757*da0073e9SAndroid Build Coastguard Worker        if x == 2:
1758*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError('Boom')
1759*da0073e9SAndroid Build Coastguard Worker
1760*da0073e9SAndroid Build Coastguard Worker    @parametrize("x", [subtest(1, decorators=[unittest.expectedFailure]), 2, 3])
1761*da0073e9SAndroid Build Coastguard Worker    @parametrize("y", [4, 5, subtest(6, decorators=[unittest.expectedFailure])])
1762*da0073e9SAndroid Build Coastguard Worker    def test_two_things_subtest_expected_failure(self, x, y):
1763*da0073e9SAndroid Build Coastguard Worker        if x == 1 or y == 6:
1764*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError('Boom')
1765*da0073e9SAndroid Build Coastguard Worker
1766*da0073e9SAndroid Build Coastguard Worker
1767*da0073e9SAndroid Build Coastguard Workerclass TestTestParametrizationDeviceType(TestCase):
1768*da0073e9SAndroid Build Coastguard Worker    def test_unparametrized_names(self, device):
1769*da0073e9SAndroid Build Coastguard Worker        # This test exists to protect against regressions in device / dtype test naming
1770*da0073e9SAndroid Build Coastguard Worker        # due to parametrization logic.
1771*da0073e9SAndroid Build Coastguard Worker
1772*da0073e9SAndroid Build Coastguard Worker        device = self.device_type
1773*da0073e9SAndroid Build Coastguard Worker
1774*da0073e9SAndroid Build Coastguard Worker        class TestParametrized(TestCase):
1775*da0073e9SAndroid Build Coastguard Worker            def test_device_specific(self, device):
1776*da0073e9SAndroid Build Coastguard Worker                pass
1777*da0073e9SAndroid Build Coastguard Worker
1778*da0073e9SAndroid Build Coastguard Worker            @dtypes(torch.float32, torch.float64)
1779*da0073e9SAndroid Build Coastguard Worker            def test_device_dtype_specific(self, device, dtype):
1780*da0073e9SAndroid Build Coastguard Worker                pass
1781*da0073e9SAndroid Build Coastguard Worker
1782*da0073e9SAndroid Build Coastguard Worker        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1783*da0073e9SAndroid Build Coastguard Worker
1784*da0073e9SAndroid Build Coastguard Worker        device_cls = locals()[f'TestParametrized{device.upper()}']
1785*da0073e9SAndroid Build Coastguard Worker        expected_test_names = [name.format(device_cls.__name__, device) for name in (
1786*da0073e9SAndroid Build Coastguard Worker            '{}.test_device_dtype_specific_{}_float32',
1787*da0073e9SAndroid Build Coastguard Worker            '{}.test_device_dtype_specific_{}_float64',
1788*da0073e9SAndroid Build Coastguard Worker            '{}.test_device_specific_{}')
1789*da0073e9SAndroid Build Coastguard Worker        ]
1790*da0073e9SAndroid Build Coastguard Worker        test_names = _get_test_names_for_test_class(device_cls)
1791*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_test_names, test_names)
1792*da0073e9SAndroid Build Coastguard Worker
1793*da0073e9SAndroid Build Coastguard Worker    def test_empty_param_names(self, device):
1794*da0073e9SAndroid Build Coastguard Worker        # If no param names are passed, ensure things still work without parametrization.
1795*da0073e9SAndroid Build Coastguard Worker        device = self.device_type
1796*da0073e9SAndroid Build Coastguard Worker
1797*da0073e9SAndroid Build Coastguard Worker        class TestParametrized(TestCase):
1798*da0073e9SAndroid Build Coastguard Worker            @parametrize("", [])
1799*da0073e9SAndroid Build Coastguard Worker            def test_foo(self, device):
1800*da0073e9SAndroid Build Coastguard Worker                pass
1801*da0073e9SAndroid Build Coastguard Worker
1802*da0073e9SAndroid Build Coastguard Worker            @parametrize("", range(5))
1803*da0073e9SAndroid Build Coastguard Worker            def test_bar(self, device):
1804*da0073e9SAndroid Build Coastguard Worker                pass
1805*da0073e9SAndroid Build Coastguard Worker
1806*da0073e9SAndroid Build Coastguard Worker        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1807*da0073e9SAndroid Build Coastguard Worker
1808*da0073e9SAndroid Build Coastguard Worker        device_cls = locals()[f'TestParametrized{device.upper()}']
1809*da0073e9SAndroid Build Coastguard Worker        expected_test_names = [name.format(device_cls.__name__, device) for name in (
1810*da0073e9SAndroid Build Coastguard Worker            '{}.test_bar_{}',
1811*da0073e9SAndroid Build Coastguard Worker            '{}.test_foo_{}')
1812*da0073e9SAndroid Build Coastguard Worker        ]
1813*da0073e9SAndroid Build Coastguard Worker        test_names = _get_test_names_for_test_class(device_cls)
1814*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_test_names, test_names)
1815*da0073e9SAndroid Build Coastguard Worker
1816*da0073e9SAndroid Build Coastguard Worker    def test_empty_param_list(self, device):
1817*da0073e9SAndroid Build Coastguard Worker        # If no param values are passed, ensure a helpful error message is thrown.
1818*da0073e9SAndroid Build Coastguard Worker        # In the wild, this could indicate reuse of an exhausted generator.
1819*da0073e9SAndroid Build Coastguard Worker        device = self.device_type
1820*da0073e9SAndroid Build Coastguard Worker
1821*da0073e9SAndroid Build Coastguard Worker        generator = (a for a in range(5))
1822*da0073e9SAndroid Build Coastguard Worker
1823*da0073e9SAndroid Build Coastguard Worker        class TestParametrized(TestCase):
1824*da0073e9SAndroid Build Coastguard Worker            @parametrize("x", generator)
1825*da0073e9SAndroid Build Coastguard Worker            def test_foo(self, device, x):
1826*da0073e9SAndroid Build Coastguard Worker                pass
1827*da0073e9SAndroid Build Coastguard Worker
1828*da0073e9SAndroid Build Coastguard Worker            # Reuse generator from first test function.
1829*da0073e9SAndroid Build Coastguard Worker            @parametrize("y", generator)
1830*da0073e9SAndroid Build Coastguard Worker            def test_bar(self, device, y):
1831*da0073e9SAndroid Build Coastguard Worker                pass
1832*da0073e9SAndroid Build Coastguard Worker
1833*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, 'An empty arg_values was passed'):
1834*da0073e9SAndroid Build Coastguard Worker            instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1835*da0073e9SAndroid Build Coastguard Worker
1836*da0073e9SAndroid Build Coastguard Worker    def test_default_names(self, device):
1837*da0073e9SAndroid Build Coastguard Worker        device = self.device_type
1838*da0073e9SAndroid Build Coastguard Worker
1839*da0073e9SAndroid Build Coastguard Worker        class TestParametrized(TestCase):
1840*da0073e9SAndroid Build Coastguard Worker            @parametrize("x", range(5))
1841*da0073e9SAndroid Build Coastguard Worker            def test_default_names(self, device, x):
1842*da0073e9SAndroid Build Coastguard Worker                pass
1843*da0073e9SAndroid Build Coastguard Worker
1844*da0073e9SAndroid Build Coastguard Worker            @parametrize("x,y", [(1, 2), (2, 3), (3, 4)])
1845*da0073e9SAndroid Build Coastguard Worker            def test_two_things_default_names(self, device, x, y):
1846*da0073e9SAndroid Build Coastguard Worker                pass
1847*da0073e9SAndroid Build Coastguard Worker
1848*da0073e9SAndroid Build Coastguard Worker
1849*da0073e9SAndroid Build Coastguard Worker        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1850*da0073e9SAndroid Build Coastguard Worker
1851*da0073e9SAndroid Build Coastguard Worker        device_cls = locals()[f'TestParametrized{device.upper()}']
1852*da0073e9SAndroid Build Coastguard Worker        expected_test_names = [name.format(device_cls.__name__, device) for name in (
1853*da0073e9SAndroid Build Coastguard Worker            '{}.test_default_names_x_0_{}',
1854*da0073e9SAndroid Build Coastguard Worker            '{}.test_default_names_x_1_{}',
1855*da0073e9SAndroid Build Coastguard Worker            '{}.test_default_names_x_2_{}',
1856*da0073e9SAndroid Build Coastguard Worker            '{}.test_default_names_x_3_{}',
1857*da0073e9SAndroid Build Coastguard Worker            '{}.test_default_names_x_4_{}',
1858*da0073e9SAndroid Build Coastguard Worker            '{}.test_two_things_default_names_x_1_y_2_{}',
1859*da0073e9SAndroid Build Coastguard Worker            '{}.test_two_things_default_names_x_2_y_3_{}',
1860*da0073e9SAndroid Build Coastguard Worker            '{}.test_two_things_default_names_x_3_y_4_{}')
1861*da0073e9SAndroid Build Coastguard Worker        ]
1862*da0073e9SAndroid Build Coastguard Worker        test_names = _get_test_names_for_test_class(device_cls)
1863*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_test_names, test_names)
1864*da0073e9SAndroid Build Coastguard Worker
1865*da0073e9SAndroid Build Coastguard Worker    def test_default_name_non_primitive(self, device):
1866*da0073e9SAndroid Build Coastguard Worker        device = self.device_type
1867*da0073e9SAndroid Build Coastguard Worker
1868*da0073e9SAndroid Build Coastguard Worker        class TestParametrized(TestCase):
1869*da0073e9SAndroid Build Coastguard Worker            @parametrize("x", [1, .5, "foo", object()])
1870*da0073e9SAndroid Build Coastguard Worker            def test_default_names(self, device, x):
1871*da0073e9SAndroid Build Coastguard Worker                pass
1872*da0073e9SAndroid Build Coastguard Worker
1873*da0073e9SAndroid Build Coastguard Worker            @parametrize("x,y", [(1, object()), (object(), .5), (object(), object())])
1874*da0073e9SAndroid Build Coastguard Worker            def test_two_things_default_names(self, device, x, y):
1875*da0073e9SAndroid Build Coastguard Worker                pass
1876*da0073e9SAndroid Build Coastguard Worker
1877*da0073e9SAndroid Build Coastguard Worker        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1878*da0073e9SAndroid Build Coastguard Worker
1879*da0073e9SAndroid Build Coastguard Worker        device_cls = locals()[f'TestParametrized{device.upper()}']
1880*da0073e9SAndroid Build Coastguard Worker        expected_test_names = sorted(name.format(device_cls.__name__, device) for name in (
1881*da0073e9SAndroid Build Coastguard Worker            '{}.test_default_names_x_1_{}',
1882*da0073e9SAndroid Build Coastguard Worker            '{}.test_default_names_x_0_5_{}',
1883*da0073e9SAndroid Build Coastguard Worker            '{}.test_default_names_x_foo_{}',
1884*da0073e9SAndroid Build Coastguard Worker            '{}.test_default_names_x3_{}',
1885*da0073e9SAndroid Build Coastguard Worker            '{}.test_two_things_default_names_x_1_y0_{}',
1886*da0073e9SAndroid Build Coastguard Worker            '{}.test_two_things_default_names_x1_y_0_5_{}',
1887*da0073e9SAndroid Build Coastguard Worker            '{}.test_two_things_default_names_x2_y2_{}')
1888*da0073e9SAndroid Build Coastguard Worker        )
1889*da0073e9SAndroid Build Coastguard Worker        test_names = _get_test_names_for_test_class(device_cls)
1890*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_test_names, test_names)
1891*da0073e9SAndroid Build Coastguard Worker
1892*da0073e9SAndroid Build Coastguard Worker    def test_name_fn(self, device):
1893*da0073e9SAndroid Build Coastguard Worker        device = self.device_type
1894*da0073e9SAndroid Build Coastguard Worker
1895*da0073e9SAndroid Build Coastguard Worker        class TestParametrized(TestCase):
1896*da0073e9SAndroid Build Coastguard Worker            @parametrize("bias", [False, True], name_fn=lambda b: 'bias' if b else 'no_bias')
1897*da0073e9SAndroid Build Coastguard Worker            def test_custom_names(self, device, bias):
1898*da0073e9SAndroid Build Coastguard Worker                pass
1899*da0073e9SAndroid Build Coastguard Worker
1900*da0073e9SAndroid Build Coastguard Worker            @parametrize("x", [1, 2], name_fn=str)
1901*da0073e9SAndroid Build Coastguard Worker            @parametrize("y", [3, 4], name_fn=str)
1902*da0073e9SAndroid Build Coastguard Worker            @parametrize("z", [5, 6], name_fn=str)
1903*da0073e9SAndroid Build Coastguard Worker            def test_three_things_composition_custom_names(self, device, x, y, z):
1904*da0073e9SAndroid Build Coastguard Worker                pass
1905*da0073e9SAndroid Build Coastguard Worker
1906*da0073e9SAndroid Build Coastguard Worker            @parametrize("x,y", [(1, 2), (1, 3), (1, 4)], name_fn=lambda x, y: f'{x}__{y}')
1907*da0073e9SAndroid Build Coastguard Worker            def test_two_things_custom_names_alternate(self, device, x, y):
1908*da0073e9SAndroid Build Coastguard Worker                pass
1909*da0073e9SAndroid Build Coastguard Worker
1910*da0073e9SAndroid Build Coastguard Worker        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1911*da0073e9SAndroid Build Coastguard Worker
1912*da0073e9SAndroid Build Coastguard Worker        device_cls = locals()[f'TestParametrized{device.upper()}']
1913*da0073e9SAndroid Build Coastguard Worker        expected_test_names = [name.format(device_cls.__name__, device) for name in (
1914*da0073e9SAndroid Build Coastguard Worker            '{}.test_custom_names_bias_{}',
1915*da0073e9SAndroid Build Coastguard Worker            '{}.test_custom_names_no_bias_{}',
1916*da0073e9SAndroid Build Coastguard Worker            '{}.test_three_things_composition_custom_names_1_3_5_{}',
1917*da0073e9SAndroid Build Coastguard Worker            '{}.test_three_things_composition_custom_names_1_3_6_{}',
1918*da0073e9SAndroid Build Coastguard Worker            '{}.test_three_things_composition_custom_names_1_4_5_{}',
1919*da0073e9SAndroid Build Coastguard Worker            '{}.test_three_things_composition_custom_names_1_4_6_{}',
1920*da0073e9SAndroid Build Coastguard Worker            '{}.test_three_things_composition_custom_names_2_3_5_{}',
1921*da0073e9SAndroid Build Coastguard Worker            '{}.test_three_things_composition_custom_names_2_3_6_{}',
1922*da0073e9SAndroid Build Coastguard Worker            '{}.test_three_things_composition_custom_names_2_4_5_{}',
1923*da0073e9SAndroid Build Coastguard Worker            '{}.test_three_things_composition_custom_names_2_4_6_{}',
1924*da0073e9SAndroid Build Coastguard Worker            '{}.test_two_things_custom_names_alternate_1__2_{}',
1925*da0073e9SAndroid Build Coastguard Worker            '{}.test_two_things_custom_names_alternate_1__3_{}',
1926*da0073e9SAndroid Build Coastguard Worker            '{}.test_two_things_custom_names_alternate_1__4_{}')
1927*da0073e9SAndroid Build Coastguard Worker        ]
1928*da0073e9SAndroid Build Coastguard Worker        test_names = _get_test_names_for_test_class(device_cls)
1929*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_test_names, test_names)
1930*da0073e9SAndroid Build Coastguard Worker
1931*da0073e9SAndroid Build Coastguard Worker    def test_subtest_names(self, device):
1932*da0073e9SAndroid Build Coastguard Worker        device = self.device_type
1933*da0073e9SAndroid Build Coastguard Worker
1934*da0073e9SAndroid Build Coastguard Worker        class TestParametrized(TestCase):
1935*da0073e9SAndroid Build Coastguard Worker            @parametrize("bias", [subtest(True, name='bias'),
1936*da0073e9SAndroid Build Coastguard Worker                                  subtest(False, name='no_bias')])
1937*da0073e9SAndroid Build Coastguard Worker            def test_custom_names(self, device, bias):
1938*da0073e9SAndroid Build Coastguard Worker                pass
1939*da0073e9SAndroid Build Coastguard Worker
1940*da0073e9SAndroid Build Coastguard Worker            @parametrize("x,y", [subtest((1, 2), name='double'),
1941*da0073e9SAndroid Build Coastguard Worker                                 subtest((1, 3), name='triple'),
1942*da0073e9SAndroid Build Coastguard Worker                                 subtest((1, 4), name='quadruple')])
1943*da0073e9SAndroid Build Coastguard Worker            def test_two_things_custom_names(self, device, x, y):
1944*da0073e9SAndroid Build Coastguard Worker                pass
1945*da0073e9SAndroid Build Coastguard Worker
1946*da0073e9SAndroid Build Coastguard Worker        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1947*da0073e9SAndroid Build Coastguard Worker
1948*da0073e9SAndroid Build Coastguard Worker        device_cls = locals()[f'TestParametrized{device.upper()}']
1949*da0073e9SAndroid Build Coastguard Worker        expected_test_names = [name.format(device_cls.__name__, device) for name in (
1950*da0073e9SAndroid Build Coastguard Worker            '{}.test_custom_names_bias_{}',
1951*da0073e9SAndroid Build Coastguard Worker            '{}.test_custom_names_no_bias_{}',
1952*da0073e9SAndroid Build Coastguard Worker            '{}.test_two_things_custom_names_double_{}',
1953*da0073e9SAndroid Build Coastguard Worker            '{}.test_two_things_custom_names_quadruple_{}',
1954*da0073e9SAndroid Build Coastguard Worker            '{}.test_two_things_custom_names_triple_{}')
1955*da0073e9SAndroid Build Coastguard Worker        ]
1956*da0073e9SAndroid Build Coastguard Worker        test_names = _get_test_names_for_test_class(device_cls)
1957*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_test_names, test_names)
1958*da0073e9SAndroid Build Coastguard Worker
1959*da0073e9SAndroid Build Coastguard Worker    def test_ops_composition_names(self, device):
1960*da0073e9SAndroid Build Coastguard Worker        device = self.device_type
1961*da0073e9SAndroid Build Coastguard Worker
1962*da0073e9SAndroid Build Coastguard Worker        class TestParametrized(TestCase):
1963*da0073e9SAndroid Build Coastguard Worker            @ops(op_db)
1964*da0073e9SAndroid Build Coastguard Worker            @parametrize("flag", [False, True], lambda f: 'flag_enabled' if f else 'flag_disabled')
1965*da0073e9SAndroid Build Coastguard Worker            def test_op_parametrized(self, device, dtype, op, flag):
1966*da0073e9SAndroid Build Coastguard Worker                pass
1967*da0073e9SAndroid Build Coastguard Worker
1968*da0073e9SAndroid Build Coastguard Worker        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1969*da0073e9SAndroid Build Coastguard Worker
1970*da0073e9SAndroid Build Coastguard Worker        device_cls = locals()[f'TestParametrized{device.upper()}']
1971*da0073e9SAndroid Build Coastguard Worker        expected_test_names = []
1972*da0073e9SAndroid Build Coastguard Worker        for op in op_db:
1973*da0073e9SAndroid Build Coastguard Worker            for dtype in op.supported_dtypes(torch.device(device).type):
1974*da0073e9SAndroid Build Coastguard Worker                for flag_part in ('flag_disabled', 'flag_enabled'):
1975*da0073e9SAndroid Build Coastguard Worker                    expected_name = f'{device_cls.__name__}.test_op_parametrized_{op.formatted_name}_{flag_part}_{device}_{dtype_name(dtype)}'  # noqa: B950
1976*da0073e9SAndroid Build Coastguard Worker                    expected_test_names.append(expected_name)
1977*da0073e9SAndroid Build Coastguard Worker
1978*da0073e9SAndroid Build Coastguard Worker        test_names = _get_test_names_for_test_class(device_cls)
1979*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(sorted(expected_test_names), sorted(test_names))
1980*da0073e9SAndroid Build Coastguard Worker
1981*da0073e9SAndroid Build Coastguard Worker    def test_modules_composition_names(self, device):
1982*da0073e9SAndroid Build Coastguard Worker        device = self.device_type
1983*da0073e9SAndroid Build Coastguard Worker
1984*da0073e9SAndroid Build Coastguard Worker        class TestParametrized(TestCase):
1985*da0073e9SAndroid Build Coastguard Worker            @modules(module_db)
1986*da0073e9SAndroid Build Coastguard Worker            @parametrize("flag", [False, True], lambda f: 'flag_enabled' if f else 'flag_disabled')
1987*da0073e9SAndroid Build Coastguard Worker            def test_module_parametrized(self, device, dtype, module_info, training, flag):
1988*da0073e9SAndroid Build Coastguard Worker                pass
1989*da0073e9SAndroid Build Coastguard Worker
1990*da0073e9SAndroid Build Coastguard Worker        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
1991*da0073e9SAndroid Build Coastguard Worker
1992*da0073e9SAndroid Build Coastguard Worker        device_cls = locals()[f'TestParametrized{device.upper()}']
1993*da0073e9SAndroid Build Coastguard Worker        expected_test_names = []
1994*da0073e9SAndroid Build Coastguard Worker        for module_info in module_db:
1995*da0073e9SAndroid Build Coastguard Worker            for dtype in module_info.dtypes:
1996*da0073e9SAndroid Build Coastguard Worker                for flag_part in ('flag_disabled', 'flag_enabled'):
1997*da0073e9SAndroid Build Coastguard Worker                    expected_train_modes = (
1998*da0073e9SAndroid Build Coastguard Worker                        ['train_mode', 'eval_mode'] if module_info.train_and_eval_differ else [''])
1999*da0073e9SAndroid Build Coastguard Worker                    for training_part in expected_train_modes:
2000*da0073e9SAndroid Build Coastguard Worker                        expected_name = '{}.test_module_parametrized_{}{}_{}_{}_{}'.format(
2001*da0073e9SAndroid Build Coastguard Worker                            device_cls.__name__, module_info.formatted_name,
2002*da0073e9SAndroid Build Coastguard Worker                            '_' + training_part if len(training_part) > 0 else '',
2003*da0073e9SAndroid Build Coastguard Worker                            flag_part, device, dtype_name(dtype))
2004*da0073e9SAndroid Build Coastguard Worker                        expected_test_names.append(expected_name)
2005*da0073e9SAndroid Build Coastguard Worker
2006*da0073e9SAndroid Build Coastguard Worker        test_names = _get_test_names_for_test_class(device_cls)
2007*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(sorted(expected_test_names), sorted(test_names))
2008*da0073e9SAndroid Build Coastguard Worker
2009*da0073e9SAndroid Build Coastguard Worker    def test_ops_decorator_applies_op_and_param_specific_decorators(self, device):
2010*da0073e9SAndroid Build Coastguard Worker        # Test that decorators can be applied on a per-op / per-param basis.
2011*da0073e9SAndroid Build Coastguard Worker
2012*da0073e9SAndroid Build Coastguard Worker        # Create a test op, OpInfo entry, and decorator to apply.
2013*da0073e9SAndroid Build Coastguard Worker        def test_op(x):
2014*da0073e9SAndroid Build Coastguard Worker            return -x
2015*da0073e9SAndroid Build Coastguard Worker
2016*da0073e9SAndroid Build Coastguard Worker        def test_dec(func):
2017*da0073e9SAndroid Build Coastguard Worker            func._decorator_applied = True
2018*da0073e9SAndroid Build Coastguard Worker            return func
2019*da0073e9SAndroid Build Coastguard Worker
2020*da0073e9SAndroid Build Coastguard Worker        test_op_info = OpInfo(
2021*da0073e9SAndroid Build Coastguard Worker            'test_op',
2022*da0073e9SAndroid Build Coastguard Worker            op=test_op,
2023*da0073e9SAndroid Build Coastguard Worker            dtypes=floating_types(),
2024*da0073e9SAndroid Build Coastguard Worker            sample_inputs_func=lambda _: [],
2025*da0073e9SAndroid Build Coastguard Worker            decorators=[
2026*da0073e9SAndroid Build Coastguard Worker                DecorateInfo(test_dec, 'TestParametrized', 'test_op_param',
2027*da0073e9SAndroid Build Coastguard Worker                             device_type='cpu', dtypes=[torch.float64],
2028*da0073e9SAndroid Build Coastguard Worker                             active_if=lambda p: p['x'] == 2)
2029*da0073e9SAndroid Build Coastguard Worker            ])
2030*da0073e9SAndroid Build Coastguard Worker
2031*da0073e9SAndroid Build Coastguard Worker        class TestParametrized(TestCase):
2032*da0073e9SAndroid Build Coastguard Worker            @ops(op_db + [test_op_info])
2033*da0073e9SAndroid Build Coastguard Worker            @parametrize("x", [2, 3])
2034*da0073e9SAndroid Build Coastguard Worker            def test_op_param(self, device, dtype, op, x):
2035*da0073e9SAndroid Build Coastguard Worker                pass
2036*da0073e9SAndroid Build Coastguard Worker
2037*da0073e9SAndroid Build Coastguard Worker            @ops(op_db + [test_op_info])
2038*da0073e9SAndroid Build Coastguard Worker            @parametrize("y", [
2039*da0073e9SAndroid Build Coastguard Worker                subtest(4),
2040*da0073e9SAndroid Build Coastguard Worker                subtest(5, decorators=[test_dec])])
2041*da0073e9SAndroid Build Coastguard Worker            def test_other(self, device, dtype, op, y):
2042*da0073e9SAndroid Build Coastguard Worker                pass
2043*da0073e9SAndroid Build Coastguard Worker
2044*da0073e9SAndroid Build Coastguard Worker            @decorateIf(test_dec, lambda p: p['dtype'] == torch.int16)
2045*da0073e9SAndroid Build Coastguard Worker            @ops(op_db)
2046*da0073e9SAndroid Build Coastguard Worker            def test_three(self, device, dtype, op):
2047*da0073e9SAndroid Build Coastguard Worker                pass
2048*da0073e9SAndroid Build Coastguard Worker
2049*da0073e9SAndroid Build Coastguard Worker        device = self.device_type
2050*da0073e9SAndroid Build Coastguard Worker        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
2051*da0073e9SAndroid Build Coastguard Worker        device_cls = locals()[f'TestParametrized{device.upper()}']
2052*da0073e9SAndroid Build Coastguard Worker
2053*da0073e9SAndroid Build Coastguard Worker        for test_func, name in _get_test_funcs_for_test_class(device_cls):
2054*da0073e9SAndroid Build Coastguard Worker            should_apply = (name == 'test_op_param_test_op_x_2_cpu_float64' or
2055*da0073e9SAndroid Build Coastguard Worker                            ('test_other' in name and 'y_5' in name) or
2056*da0073e9SAndroid Build Coastguard Worker                            ('test_three' in name and name.endswith('_int16')))
2057*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply)
2058*da0073e9SAndroid Build Coastguard Worker
2059*da0073e9SAndroid Build Coastguard Worker    def test_modules_decorator_applies_module_and_param_specific_decorators(self, device):
2060*da0073e9SAndroid Build Coastguard Worker        # Test that decorators can be applied on a per-module / per-param basis.
2061*da0073e9SAndroid Build Coastguard Worker
2062*da0073e9SAndroid Build Coastguard Worker        # Create a test module, ModuleInfo entry, and decorator to apply.
2063*da0073e9SAndroid Build Coastguard Worker        class TestModule(torch.nn.Module):
2064*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2065*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2066*da0073e9SAndroid Build Coastguard Worker                self.x = torch.nn.Parameter(torch.randn(3))
2067*da0073e9SAndroid Build Coastguard Worker
2068*da0073e9SAndroid Build Coastguard Worker            def forward(self, y):
2069*da0073e9SAndroid Build Coastguard Worker                return self.x + y
2070*da0073e9SAndroid Build Coastguard Worker
2071*da0073e9SAndroid Build Coastguard Worker        def test_dec(func):
2072*da0073e9SAndroid Build Coastguard Worker            func._decorator_applied = True
2073*da0073e9SAndroid Build Coastguard Worker            return func
2074*da0073e9SAndroid Build Coastguard Worker
2075*da0073e9SAndroid Build Coastguard Worker        test_module_info = ModuleInfo(
2076*da0073e9SAndroid Build Coastguard Worker            TestModule,
2077*da0073e9SAndroid Build Coastguard Worker            module_inputs_func=lambda _: [],
2078*da0073e9SAndroid Build Coastguard Worker            decorators=[
2079*da0073e9SAndroid Build Coastguard Worker                DecorateInfo(test_dec, 'TestParametrized', 'test_module_param',
2080*da0073e9SAndroid Build Coastguard Worker                             device_type='cpu', dtypes=[torch.float64],
2081*da0073e9SAndroid Build Coastguard Worker                             active_if=lambda p: p['x'] == 2)
2082*da0073e9SAndroid Build Coastguard Worker            ])
2083*da0073e9SAndroid Build Coastguard Worker
2084*da0073e9SAndroid Build Coastguard Worker        class TestParametrized(TestCase):
2085*da0073e9SAndroid Build Coastguard Worker            @modules(module_db + [test_module_info])
2086*da0073e9SAndroid Build Coastguard Worker            @parametrize("x", [2, 3])
2087*da0073e9SAndroid Build Coastguard Worker            def test_module_param(self, device, dtype, module_info, training, x):
2088*da0073e9SAndroid Build Coastguard Worker                pass
2089*da0073e9SAndroid Build Coastguard Worker
2090*da0073e9SAndroid Build Coastguard Worker            @modules(module_db + [test_module_info])
2091*da0073e9SAndroid Build Coastguard Worker            @parametrize("y", [
2092*da0073e9SAndroid Build Coastguard Worker                subtest(4),
2093*da0073e9SAndroid Build Coastguard Worker                subtest(5, decorators=[test_dec])])
2094*da0073e9SAndroid Build Coastguard Worker            def test_other(self, device, dtype, module_info, training, y):
2095*da0073e9SAndroid Build Coastguard Worker                pass
2096*da0073e9SAndroid Build Coastguard Worker
2097*da0073e9SAndroid Build Coastguard Worker            @decorateIf(test_dec, lambda p: p['dtype'] == torch.float64)
2098*da0073e9SAndroid Build Coastguard Worker            @modules(module_db)
2099*da0073e9SAndroid Build Coastguard Worker            def test_three(self, device, dtype, module_info):
2100*da0073e9SAndroid Build Coastguard Worker                pass
2101*da0073e9SAndroid Build Coastguard Worker
2102*da0073e9SAndroid Build Coastguard Worker        device = self.device_type
2103*da0073e9SAndroid Build Coastguard Worker        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
2104*da0073e9SAndroid Build Coastguard Worker        device_cls = locals()[f'TestParametrized{device.upper()}']
2105*da0073e9SAndroid Build Coastguard Worker
2106*da0073e9SAndroid Build Coastguard Worker        for test_func, name in _get_test_funcs_for_test_class(device_cls):
2107*da0073e9SAndroid Build Coastguard Worker            should_apply = (name == 'test_module_param_TestModule_x_2_cpu_float64' or
2108*da0073e9SAndroid Build Coastguard Worker                            ('test_other' in name and 'y_5' in name) or
2109*da0073e9SAndroid Build Coastguard Worker                            ('test_three' in name and name.endswith('float64')))
2110*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply)
2111*da0073e9SAndroid Build Coastguard Worker
2112*da0073e9SAndroid Build Coastguard Worker    def test_param_specific_decoration(self, device):
2113*da0073e9SAndroid Build Coastguard Worker
2114*da0073e9SAndroid Build Coastguard Worker        def test_dec(func):
2115*da0073e9SAndroid Build Coastguard Worker            func._decorator_applied = True
2116*da0073e9SAndroid Build Coastguard Worker            return func
2117*da0073e9SAndroid Build Coastguard Worker
2118*da0073e9SAndroid Build Coastguard Worker        class TestParametrized(TestCase):
2119*da0073e9SAndroid Build Coastguard Worker            @decorateIf(test_dec, lambda params: params["x"] == 1 and params["y"])
2120*da0073e9SAndroid Build Coastguard Worker            @parametrize("x", range(5))
2121*da0073e9SAndroid Build Coastguard Worker            @parametrize("y", [False, True])
2122*da0073e9SAndroid Build Coastguard Worker            def test_param(self, x, y):
2123*da0073e9SAndroid Build Coastguard Worker                pass
2124*da0073e9SAndroid Build Coastguard Worker
2125*da0073e9SAndroid Build Coastguard Worker        device = self.device_type
2126*da0073e9SAndroid Build Coastguard Worker        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
2127*da0073e9SAndroid Build Coastguard Worker        device_cls = locals()[f'TestParametrized{device.upper()}']
2128*da0073e9SAndroid Build Coastguard Worker
2129*da0073e9SAndroid Build Coastguard Worker        for test_func, name in _get_test_funcs_for_test_class(device_cls):
2130*da0073e9SAndroid Build Coastguard Worker            should_apply = ('test_param_x_1_y_True' in name)
2131*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply)
2132*da0073e9SAndroid Build Coastguard Worker
2133*da0073e9SAndroid Build Coastguard Worker    def test_dtypes_composition_valid(self, device):
2134*da0073e9SAndroid Build Coastguard Worker        # Test checks that @parametrize and @dtypes compose as expected when @parametrize
2135*da0073e9SAndroid Build Coastguard Worker        # doesn't set dtype.
2136*da0073e9SAndroid Build Coastguard Worker
2137*da0073e9SAndroid Build Coastguard Worker        device = self.device_type
2138*da0073e9SAndroid Build Coastguard Worker
2139*da0073e9SAndroid Build Coastguard Worker        class TestParametrized(TestCase):
2140*da0073e9SAndroid Build Coastguard Worker            @dtypes(torch.float32, torch.float64)
2141*da0073e9SAndroid Build Coastguard Worker            @parametrize("x", range(3))
2142*da0073e9SAndroid Build Coastguard Worker            def test_parametrized(self, x, dtype):
2143*da0073e9SAndroid Build Coastguard Worker                pass
2144*da0073e9SAndroid Build Coastguard Worker
2145*da0073e9SAndroid Build Coastguard Worker        instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
2146*da0073e9SAndroid Build Coastguard Worker
2147*da0073e9SAndroid Build Coastguard Worker        device_cls = locals()[f'TestParametrized{device.upper()}']
2148*da0073e9SAndroid Build Coastguard Worker        expected_test_names = [name.format(device_cls.__name__, device) for name in (
2149*da0073e9SAndroid Build Coastguard Worker            '{}.test_parametrized_x_0_{}_float32',
2150*da0073e9SAndroid Build Coastguard Worker            '{}.test_parametrized_x_0_{}_float64',
2151*da0073e9SAndroid Build Coastguard Worker            '{}.test_parametrized_x_1_{}_float32',
2152*da0073e9SAndroid Build Coastguard Worker            '{}.test_parametrized_x_1_{}_float64',
2153*da0073e9SAndroid Build Coastguard Worker            '{}.test_parametrized_x_2_{}_float32',
2154*da0073e9SAndroid Build Coastguard Worker            '{}.test_parametrized_x_2_{}_float64')
2155*da0073e9SAndroid Build Coastguard Worker        ]
2156*da0073e9SAndroid Build Coastguard Worker        test_names = _get_test_names_for_test_class(device_cls)
2157*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(sorted(expected_test_names), sorted(test_names))
2158*da0073e9SAndroid Build Coastguard Worker
2159*da0073e9SAndroid Build Coastguard Worker    def test_dtypes_composition_invalid(self, device):
2160*da0073e9SAndroid Build Coastguard Worker        # Test checks that @dtypes cannot be composed with parametrization decorators when they
2161*da0073e9SAndroid Build Coastguard Worker        # also try to set dtype.
2162*da0073e9SAndroid Build Coastguard Worker
2163*da0073e9SAndroid Build Coastguard Worker        device = self.device_type
2164*da0073e9SAndroid Build Coastguard Worker
2165*da0073e9SAndroid Build Coastguard Worker        class TestParametrized(TestCase):
2166*da0073e9SAndroid Build Coastguard Worker            @dtypes(torch.float32, torch.float64)
2167*da0073e9SAndroid Build Coastguard Worker            @parametrize("dtype", [torch.int32, torch.int64])
2168*da0073e9SAndroid Build Coastguard Worker            def test_parametrized(self, dtype):
2169*da0073e9SAndroid Build Coastguard Worker                pass
2170*da0073e9SAndroid Build Coastguard Worker
2171*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "handled multiple times"):
2172*da0073e9SAndroid Build Coastguard Worker            instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
2173*da0073e9SAndroid Build Coastguard Worker
2174*da0073e9SAndroid Build Coastguard Worker        # Verify proper error behavior with @ops + @dtypes, as both try to set dtype.
2175*da0073e9SAndroid Build Coastguard Worker
2176*da0073e9SAndroid Build Coastguard Worker        class TestParametrized(TestCase):
2177*da0073e9SAndroid Build Coastguard Worker            @dtypes(torch.float32, torch.float64)
2178*da0073e9SAndroid Build Coastguard Worker            @ops(op_db)
2179*da0073e9SAndroid Build Coastguard Worker            def test_parametrized(self, op, dtype):
2180*da0073e9SAndroid Build Coastguard Worker                pass
2181*da0073e9SAndroid Build Coastguard Worker
2182*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "handled multiple times"):
2183*da0073e9SAndroid Build Coastguard Worker            instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
2184*da0073e9SAndroid Build Coastguard Worker
2185*da0073e9SAndroid Build Coastguard Worker    def test_multiple_handling_of_same_param_error(self, device):
2186*da0073e9SAndroid Build Coastguard Worker        # Test that multiple decorators handling the same param errors out.
2187*da0073e9SAndroid Build Coastguard Worker        # Both @modules and @ops handle the dtype param.
2188*da0073e9SAndroid Build Coastguard Worker
2189*da0073e9SAndroid Build Coastguard Worker        class TestParametrized(TestCase):
2190*da0073e9SAndroid Build Coastguard Worker            @ops(op_db)
2191*da0073e9SAndroid Build Coastguard Worker            @modules(module_db)
2192*da0073e9SAndroid Build Coastguard Worker            def test_param(self, device, dtype, op, module_info, training):
2193*da0073e9SAndroid Build Coastguard Worker                pass
2194*da0073e9SAndroid Build Coastguard Worker
2195*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "handled multiple times"):
2196*da0073e9SAndroid Build Coastguard Worker            instantiate_device_type_tests(TestParametrized, locals(), only_for=device)
2197*da0073e9SAndroid Build Coastguard Worker
2198*da0073e9SAndroid Build Coastguard Worker    @parametrize("x", [1, subtest(2, decorators=[unittest.expectedFailure]), 3])
2199*da0073e9SAndroid Build Coastguard Worker    def test_subtest_expected_failure(self, device, x):
2200*da0073e9SAndroid Build Coastguard Worker        if x == 2:
2201*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError('Boom')
2202*da0073e9SAndroid Build Coastguard Worker
2203*da0073e9SAndroid Build Coastguard Worker    @parametrize("x", [subtest(1, decorators=[unittest.expectedFailure]), 2, 3])
2204*da0073e9SAndroid Build Coastguard Worker    @parametrize("y", [4, 5, subtest(6, decorators=[unittest.expectedFailure])])
2205*da0073e9SAndroid Build Coastguard Worker    def test_two_things_subtest_expected_failure(self, device, x, y):
2206*da0073e9SAndroid Build Coastguard Worker        if x == 1 or y == 6:
2207*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError('Boom')
2208*da0073e9SAndroid Build Coastguard Worker
2209*da0073e9SAndroid Build Coastguard Worker
2210*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestTestParametrization)
2211*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestTestParametrizationDeviceType, globals())
2212*da0073e9SAndroid Build Coastguard Worker
2213*da0073e9SAndroid Build Coastguard Worker
2214*da0073e9SAndroid Build Coastguard Workerclass TestImports(TestCase):
2215*da0073e9SAndroid Build Coastguard Worker    @classmethod
2216*da0073e9SAndroid Build Coastguard Worker    def _check_python_output(cls, program) -> str:
2217*da0073e9SAndroid Build Coastguard Worker        return subprocess.check_output(
2218*da0073e9SAndroid Build Coastguard Worker            [sys.executable, "-W", "always", "-c", program],
2219*da0073e9SAndroid Build Coastguard Worker            stderr=subprocess.STDOUT,
2220*da0073e9SAndroid Build Coastguard Worker            # On Windows, opening the subprocess with the default CWD makes `import torch`
2221*da0073e9SAndroid Build Coastguard Worker            # fail, so just set CWD to this script's directory
2222*da0073e9SAndroid Build Coastguard Worker            cwd=os.path.dirname(os.path.realpath(__file__)),).decode("utf-8")
2223*da0073e9SAndroid Build Coastguard Worker
2224*da0073e9SAndroid Build Coastguard Worker    def test_circular_dependencies(self) -> None:
2225*da0073e9SAndroid Build Coastguard Worker        """ Checks that all modules inside torch can be imported
2226*da0073e9SAndroid Build Coastguard Worker        Prevents regression reported in https://github.com/pytorch/pytorch/issues/77441 """
2227*da0073e9SAndroid Build Coastguard Worker        ignored_modules = ["torch.utils.tensorboard",  # deps on tensorboard
2228*da0073e9SAndroid Build Coastguard Worker                           "torch.distributed.elastic.rendezvous",  # depps on etcd
2229*da0073e9SAndroid Build Coastguard Worker                           "torch.backends._coreml",  # depends on pycoreml
2230*da0073e9SAndroid Build Coastguard Worker                           "torch.contrib.",  # something weird
2231*da0073e9SAndroid Build Coastguard Worker                           "torch.testing._internal.distributed.",  # just fails
2232*da0073e9SAndroid Build Coastguard Worker                           "torch.ao.pruning._experimental.",  # depends on pytorch_lightning, not user-facing
2233*da0073e9SAndroid Build Coastguard Worker                           "torch.onnx._internal",  # depends on onnx-script
2234*da0073e9SAndroid Build Coastguard Worker                           "torch._inductor.runtime.triton_helpers",  # depends on triton
2235*da0073e9SAndroid Build Coastguard Worker                           "torch._inductor.codegen.cuda",  # depends on cutlass
2236*da0073e9SAndroid Build Coastguard Worker                           ]
2237*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/issues/77801
2238*da0073e9SAndroid Build Coastguard Worker        if not sys.version_info >= (3, 9):
2239*da0073e9SAndroid Build Coastguard Worker            ignored_modules.append("torch.utils.benchmark")
2240*da0073e9SAndroid Build Coastguard Worker        if IS_WINDOWS or IS_MACOS or IS_JETSON:
2241*da0073e9SAndroid Build Coastguard Worker            # Distributed should be importable on Windows(except nn.api.), but not on Mac
2242*da0073e9SAndroid Build Coastguard Worker            if IS_MACOS or IS_JETSON:
2243*da0073e9SAndroid Build Coastguard Worker                ignored_modules.append("torch.distributed.")
2244*da0073e9SAndroid Build Coastguard Worker            else:
2245*da0073e9SAndroid Build Coastguard Worker                ignored_modules.append("torch.distributed.nn.api.")
2246*da0073e9SAndroid Build Coastguard Worker                ignored_modules.append("torch.distributed.optim.")
2247*da0073e9SAndroid Build Coastguard Worker                ignored_modules.append("torch.distributed.rpc.")
2248*da0073e9SAndroid Build Coastguard Worker            ignored_modules.append("torch.testing._internal.dist_utils")
2249*da0073e9SAndroid Build Coastguard Worker            # And these both end up with transitive dependencies on distributed
2250*da0073e9SAndroid Build Coastguard Worker            ignored_modules.append("torch.nn.parallel._replicated_tensor_ddp_interop")
2251*da0073e9SAndroid Build Coastguard Worker            ignored_modules.append("torch.testing._internal.common_fsdp")
2252*da0073e9SAndroid Build Coastguard Worker            ignored_modules.append("torch.testing._internal.common_distributed")
2253*da0073e9SAndroid Build Coastguard Worker
2254*da0073e9SAndroid Build Coastguard Worker        torch_dir = os.path.dirname(torch.__file__)
2255*da0073e9SAndroid Build Coastguard Worker        for base, folders, files in os.walk(torch_dir):
2256*da0073e9SAndroid Build Coastguard Worker            prefix = os.path.relpath(base, os.path.dirname(torch_dir)).replace(os.path.sep, ".")
2257*da0073e9SAndroid Build Coastguard Worker            for f in files:
2258*da0073e9SAndroid Build Coastguard Worker                if not f.endswith(".py"):
2259*da0073e9SAndroid Build Coastguard Worker                    continue
2260*da0073e9SAndroid Build Coastguard Worker                mod_name = f"{prefix}.{f[:-3]}" if f != "__init__.py" else prefix
2261*da0073e9SAndroid Build Coastguard Worker                # Do not attempt to import executable modules
2262*da0073e9SAndroid Build Coastguard Worker                if f == "__main__.py":
2263*da0073e9SAndroid Build Coastguard Worker                    continue
2264*da0073e9SAndroid Build Coastguard Worker                if any(mod_name.startswith(x) for x in ignored_modules):
2265*da0073e9SAndroid Build Coastguard Worker                    continue
2266*da0073e9SAndroid Build Coastguard Worker                try:
2267*da0073e9SAndroid Build Coastguard Worker                    mod = importlib.import_module(mod_name)
2268*da0073e9SAndroid Build Coastguard Worker                except Exception as e:
2269*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError(f"Failed to import {mod_name}: {e}") from e
2270*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(inspect.ismodule(mod))
2271*da0073e9SAndroid Build Coastguard Worker
2272*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "TODO enable on Windows")
2273*da0073e9SAndroid Build Coastguard Worker    def test_lazy_imports_are_lazy(self) -> None:
2274*da0073e9SAndroid Build Coastguard Worker        out = self._check_python_output("import sys;import torch;print(all(x not in sys.modules for x in torch._lazy_modules))")
2275*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.strip(), "True")
2276*da0073e9SAndroid Build Coastguard Worker
2277*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "importing torch+CUDA on CPU results in warning")
2278*da0073e9SAndroid Build Coastguard Worker    def test_no_warning_on_import(self) -> None:
2279*da0073e9SAndroid Build Coastguard Worker        out = self._check_python_output("import torch")
2280*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, "")
2281*da0073e9SAndroid Build Coastguard Worker
2282*da0073e9SAndroid Build Coastguard Worker    def test_not_import_sympy(self) -> None:
2283*da0073e9SAndroid Build Coastguard Worker        out = self._check_python_output("import torch;import sys;print('sympy' not in sys.modules)")
2284*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.strip(), "True",
2285*da0073e9SAndroid Build Coastguard Worker                         "PyTorch should not depend on SymPy at import time as importing SymPy is *very* slow.\n"
2286*da0073e9SAndroid Build Coastguard Worker                         "See the beginning of the following blog post for how to profile and find which file is importing sympy:\n"
2287*da0073e9SAndroid Build Coastguard Worker                         "https://dev-discuss.pytorch.org/t/delving-into-what-happens-when-you-import-torch/1589\n\n"
2288*da0073e9SAndroid Build Coastguard Worker                         "If you hit this error, you may want to:\n"
2289*da0073e9SAndroid Build Coastguard Worker                         "  - Refactor your code to avoid depending on sympy files you may not need to depend\n"
2290*da0073e9SAndroid Build Coastguard Worker                         "  - Use TYPE_CHECKING if you are using sympy + strings if you are using sympy on type annotations\n"
2291*da0073e9SAndroid Build Coastguard Worker                         "  - Import things that depend on SymPy locally")
2292*da0073e9SAndroid Build Coastguard Worker
2293*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "importing torch+CUDA on CPU results in warning")
2294*da0073e9SAndroid Build Coastguard Worker    @parametrize('path', ['torch', 'functorch'])
2295*da0073e9SAndroid Build Coastguard Worker    def test_no_mutate_global_logging_on_import(self, path) -> None:
2296*da0073e9SAndroid Build Coastguard Worker        # Calling logging.basicConfig, among other things, modifies the global
2297*da0073e9SAndroid Build Coastguard Worker        # logging state. It is not OK to modify the global logging state on
2298*da0073e9SAndroid Build Coastguard Worker        # `import torch` (or other submodules we own) because users do not expect it.
2299*da0073e9SAndroid Build Coastguard Worker        expected = 'abcdefghijklmnopqrstuvwxyz'
2300*da0073e9SAndroid Build Coastguard Worker        commands = [
2301*da0073e9SAndroid Build Coastguard Worker            'import logging',
2302*da0073e9SAndroid Build Coastguard Worker            f'import {path}',
2303*da0073e9SAndroid Build Coastguard Worker            '_logger = logging.getLogger("torch_test_testing")',
2304*da0073e9SAndroid Build Coastguard Worker            'logging.root.addHandler(logging.StreamHandler())',
2305*da0073e9SAndroid Build Coastguard Worker            'logging.root.setLevel(logging.INFO)',
2306*da0073e9SAndroid Build Coastguard Worker            f'_logger.info("{expected}")'
2307*da0073e9SAndroid Build Coastguard Worker        ]
2308*da0073e9SAndroid Build Coastguard Worker        out = self._check_python_output("; ".join(commands))
2309*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.strip(), expected)
2310*da0073e9SAndroid Build Coastguard Worker
2311*da0073e9SAndroid Build Coastguard Workerclass TestOpInfos(TestCase):
2312*da0073e9SAndroid Build Coastguard Worker    def test_sample_input(self) -> None:
2313*da0073e9SAndroid Build Coastguard Worker        a, b, c, d, e = (object() for _ in range(5))
2314*da0073e9SAndroid Build Coastguard Worker
2315*da0073e9SAndroid Build Coastguard Worker        # Construction with natural syntax
2316*da0073e9SAndroid Build Coastguard Worker        s = SampleInput(a, b, c, d=d, e=e)
2317*da0073e9SAndroid Build Coastguard Worker        assert s.input is a
2318*da0073e9SAndroid Build Coastguard Worker        assert s.args == (b, c)
2319*da0073e9SAndroid Build Coastguard Worker        assert s.kwargs == dict(d=d, e=e)
2320*da0073e9SAndroid Build Coastguard Worker
2321*da0073e9SAndroid Build Coastguard Worker        # Construction with explicit args and kwargs
2322*da0073e9SAndroid Build Coastguard Worker        s = SampleInput(a, args=(b,), kwargs=dict(c=c, d=d, e=e))
2323*da0073e9SAndroid Build Coastguard Worker        assert s.input is a
2324*da0073e9SAndroid Build Coastguard Worker        assert s.args == (b,)
2325*da0073e9SAndroid Build Coastguard Worker        assert s.kwargs == dict(c=c, d=d, e=e)
2326*da0073e9SAndroid Build Coastguard Worker
2327*da0073e9SAndroid Build Coastguard Worker        # Construction with a mixed form will error
2328*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AssertionError):
2329*da0073e9SAndroid Build Coastguard Worker            s = SampleInput(a, b, c, args=(d, e))
2330*da0073e9SAndroid Build Coastguard Worker
2331*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AssertionError):
2332*da0073e9SAndroid Build Coastguard Worker            s = SampleInput(a, b, c, kwargs=dict(d=d, e=e))
2333*da0073e9SAndroid Build Coastguard Worker
2334*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AssertionError):
2335*da0073e9SAndroid Build Coastguard Worker            s = SampleInput(a, args=(b, c), d=d, e=e)
2336*da0073e9SAndroid Build Coastguard Worker
2337*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AssertionError):
2338*da0073e9SAndroid Build Coastguard Worker            s = SampleInput(a, b, c=c, kwargs=dict(d=d, e=e))
2339*da0073e9SAndroid Build Coastguard Worker
2340*da0073e9SAndroid Build Coastguard Worker        # Mixing metadata into "natural" construction will error
2341*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AssertionError):
2342*da0073e9SAndroid Build Coastguard Worker            s = SampleInput(a, b, name="foo")
2343*da0073e9SAndroid Build Coastguard Worker
2344*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AssertionError):
2345*da0073e9SAndroid Build Coastguard Worker            s = SampleInput(a, b, output_process_fn_grad=lambda x: x)
2346*da0073e9SAndroid Build Coastguard Worker
2347*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AssertionError):
2348*da0073e9SAndroid Build Coastguard Worker            s = SampleInput(a, b, broadcasts_input=True)
2349*da0073e9SAndroid Build Coastguard Worker
2350*da0073e9SAndroid Build Coastguard Worker        # But when only input is given, metadata is allowed for backward
2351*da0073e9SAndroid Build Coastguard Worker        # compatibility
2352*da0073e9SAndroid Build Coastguard Worker        s = SampleInput(a, broadcasts_input=True)
2353*da0073e9SAndroid Build Coastguard Worker        assert s.input is a
2354*da0073e9SAndroid Build Coastguard Worker        assert s.broadcasts_input
2355*da0073e9SAndroid Build Coastguard Worker
2356*da0073e9SAndroid Build Coastguard Worker    def test_sample_input_metadata(self) -> None:
2357*da0073e9SAndroid Build Coastguard Worker        a, b = (object() for _ in range(2))
2358*da0073e9SAndroid Build Coastguard Worker        s1 = SampleInput(a, b=b)
2359*da0073e9SAndroid Build Coastguard Worker        self.assertIs(s1.output_process_fn_grad(None), None)
2360*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(s1.broadcasts_input)
2361*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(s1.name, "")
2362*da0073e9SAndroid Build Coastguard Worker
2363*da0073e9SAndroid Build Coastguard Worker        s2 = s1.with_metadata(
2364*da0073e9SAndroid Build Coastguard Worker            output_process_fn_grad=lambda x: a,
2365*da0073e9SAndroid Build Coastguard Worker            broadcasts_input=True,
2366*da0073e9SAndroid Build Coastguard Worker            name="foo",
2367*da0073e9SAndroid Build Coastguard Worker        )
2368*da0073e9SAndroid Build Coastguard Worker        self.assertIs(s1, s2)
2369*da0073e9SAndroid Build Coastguard Worker        self.assertIs(s2.output_process_fn_grad(None), a)
2370*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s2.broadcasts_input)
2371*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(s2.name, "foo")
2372*da0073e9SAndroid Build Coastguard Worker
2373*da0073e9SAndroid Build Coastguard Worker
2374*da0073e9SAndroid Build Coastguard Worker# Tests that validate the various sample generating functions on each OpInfo.
2375*da0073e9SAndroid Build Coastguard Workerclass TestOpInfoSampleFunctions(TestCase):
2376*da0073e9SAndroid Build Coastguard Worker
2377*da0073e9SAndroid Build Coastguard Worker    @ops(op_db, dtypes=OpDTypes.any_one)
2378*da0073e9SAndroid Build Coastguard Worker    def test_opinfo_sample_generators(self, device, dtype, op):
2379*da0073e9SAndroid Build Coastguard Worker        # Test op.sample_inputs doesn't generate multiple samples when called
2380*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs(device, dtype)
2381*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(samples, Iterator)
2382*da0073e9SAndroid Build Coastguard Worker
2383*da0073e9SAndroid Build Coastguard Worker    @ops([op for op in op_db if op.reference_inputs_func is not None], dtypes=OpDTypes.any_one)
2384*da0073e9SAndroid Build Coastguard Worker    def test_opinfo_reference_generators(self, device, dtype, op):
2385*da0073e9SAndroid Build Coastguard Worker        # Test op.reference_inputs doesn't generate multiple samples when called
2386*da0073e9SAndroid Build Coastguard Worker        samples = op.reference_inputs(device, dtype)
2387*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(samples, Iterator)
2388*da0073e9SAndroid Build Coastguard Worker
2389*da0073e9SAndroid Build Coastguard Worker    @ops([op for op in op_db if op.error_inputs_func is not None], dtypes=OpDTypes.none)
2390*da0073e9SAndroid Build Coastguard Worker    def test_opinfo_error_generators(self, device, op):
2391*da0073e9SAndroid Build Coastguard Worker        # Test op.error_inputs doesn't generate multiple inputs when called
2392*da0073e9SAndroid Build Coastguard Worker        samples = op.error_inputs(device)
2393*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(samples, Iterator)
2394*da0073e9SAndroid Build Coastguard Worker
2395*da0073e9SAndroid Build Coastguard Worker
2396*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestOpInfoSampleFunctions, globals())
2397*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestImports)
2398*da0073e9SAndroid Build Coastguard Worker
2399*da0073e9SAndroid Build Coastguard Worker
2400*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__':
2401*da0073e9SAndroid Build Coastguard Worker    run_tests()
2402