1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"] 2*da0073e9SAndroid Build Coastguard Workerimport torch 3*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo import utils 4*da0073e9SAndroid Build Coastguard Workerfrom torch._inductor.test_case import TestCase 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Workerclass TestUtils(TestCase): 8*da0073e9SAndroid Build Coastguard Worker def test_nan(self): 9*da0073e9SAndroid Build Coastguard Worker a = torch.Tensor([float("nan")]) 10*da0073e9SAndroid Build Coastguard Worker b = torch.Tensor([float("nan")]) 11*da0073e9SAndroid Build Coastguard Worker fp64_ref = torch.DoubleTensor([5.0]) 12*da0073e9SAndroid Build Coastguard Worker res = utils.same(a, b, fp64_ref=fp64_ref, equal_nan=True) 13*da0073e9SAndroid Build Coastguard Worker self.assertTrue(res) 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Worker def test_larger_multiplier_for_smaller_tensor(self): 16*da0073e9SAndroid Build Coastguard Worker """ 17*da0073e9SAndroid Build Coastguard Worker Tensor numel between (10, 500] 18*da0073e9SAndroid Build Coastguard Worker """ 19*da0073e9SAndroid Build Coastguard Worker N = 100 20*da0073e9SAndroid Build Coastguard Worker fp64_ref = torch.full([N], 0.0, dtype=torch.double) 21*da0073e9SAndroid Build Coastguard Worker a = torch.full([N], 1.0) 22*da0073e9SAndroid Build Coastguard Worker tol = 4 * 1e-2 23*da0073e9SAndroid Build Coastguard Worker self.assertTrue(utils.same(a, a * 2, fp64_ref=fp64_ref, tol=tol)) 24*da0073e9SAndroid Build Coastguard Worker self.assertFalse(utils.same(a, a * 4, fp64_ref=fp64_ref, tol=tol)) 25*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 26*da0073e9SAndroid Build Coastguard Worker utils.same( 27*da0073e9SAndroid Build Coastguard Worker a, 28*da0073e9SAndroid Build Coastguard Worker a * 4, 29*da0073e9SAndroid Build Coastguard Worker fp64_ref=fp64_ref, 30*da0073e9SAndroid Build Coastguard Worker use_larger_multiplier_for_smaller_tensor=True, 31*da0073e9SAndroid Build Coastguard Worker tol=tol, 32*da0073e9SAndroid Build Coastguard Worker ) 33*da0073e9SAndroid Build Coastguard Worker ) 34*da0073e9SAndroid Build Coastguard Worker self.assertFalse( 35*da0073e9SAndroid Build Coastguard Worker utils.same( 36*da0073e9SAndroid Build Coastguard Worker a, 37*da0073e9SAndroid Build Coastguard Worker a * 6, 38*da0073e9SAndroid Build Coastguard Worker fp64_ref=fp64_ref, 39*da0073e9SAndroid Build Coastguard Worker use_larger_multiplier_for_smaller_tensor=True, 40*da0073e9SAndroid Build Coastguard Worker tol=tol, 41*da0073e9SAndroid Build Coastguard Worker ) 42*da0073e9SAndroid Build Coastguard Worker ) 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Worker def test_larger_multiplier_for_even_smaller_tensor(self): 45*da0073e9SAndroid Build Coastguard Worker """ 46*da0073e9SAndroid Build Coastguard Worker Tesnor numel <=10 47*da0073e9SAndroid Build Coastguard Worker """ 48*da0073e9SAndroid Build Coastguard Worker fp64_ref = torch.DoubleTensor([0.0]) 49*da0073e9SAndroid Build Coastguard Worker a = torch.Tensor([1.0]) 50*da0073e9SAndroid Build Coastguard Worker tol = 4 * 1e-2 51*da0073e9SAndroid Build Coastguard Worker self.assertTrue(utils.same(a, a * 2, fp64_ref=fp64_ref, tol=tol)) 52*da0073e9SAndroid Build Coastguard Worker self.assertFalse(utils.same(a, a * 7, fp64_ref=fp64_ref, tol=tol)) 53*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 54*da0073e9SAndroid Build Coastguard Worker utils.same( 55*da0073e9SAndroid Build Coastguard Worker a, 56*da0073e9SAndroid Build Coastguard Worker a * 7, 57*da0073e9SAndroid Build Coastguard Worker fp64_ref=fp64_ref, 58*da0073e9SAndroid Build Coastguard Worker use_larger_multiplier_for_smaller_tensor=True, 59*da0073e9SAndroid Build Coastguard Worker tol=tol, 60*da0073e9SAndroid Build Coastguard Worker ) 61*da0073e9SAndroid Build Coastguard Worker ) 62*da0073e9SAndroid Build Coastguard Worker self.assertFalse( 63*da0073e9SAndroid Build Coastguard Worker utils.same( 64*da0073e9SAndroid Build Coastguard Worker a, 65*da0073e9SAndroid Build Coastguard Worker a * 20, 66*da0073e9SAndroid Build Coastguard Worker fp64_ref=fp64_ref, 67*da0073e9SAndroid Build Coastguard Worker use_larger_multiplier_for_smaller_tensor=True, 68*da0073e9SAndroid Build Coastguard Worker tol=tol, 69*da0073e9SAndroid Build Coastguard Worker ) 70*da0073e9SAndroid Build Coastguard Worker ) 71*da0073e9SAndroid Build Coastguard Worker 72*da0073e9SAndroid Build Coastguard Worker 73*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 74*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.test_case import run_tests 75*da0073e9SAndroid Build Coastguard Worker 76*da0073e9SAndroid Build Coastguard Worker run_tests() 77