xref: /aosp_15_r20/external/pytorch/test/dynamo/test_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"]
2*da0073e9SAndroid Build Coastguard Workerimport torch
3*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo import utils
4*da0073e9SAndroid Build Coastguard Workerfrom torch._inductor.test_case import TestCase
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerclass TestUtils(TestCase):
8*da0073e9SAndroid Build Coastguard Worker    def test_nan(self):
9*da0073e9SAndroid Build Coastguard Worker        a = torch.Tensor([float("nan")])
10*da0073e9SAndroid Build Coastguard Worker        b = torch.Tensor([float("nan")])
11*da0073e9SAndroid Build Coastguard Worker        fp64_ref = torch.DoubleTensor([5.0])
12*da0073e9SAndroid Build Coastguard Worker        res = utils.same(a, b, fp64_ref=fp64_ref, equal_nan=True)
13*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(res)
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Worker    def test_larger_multiplier_for_smaller_tensor(self):
16*da0073e9SAndroid Build Coastguard Worker        """
17*da0073e9SAndroid Build Coastguard Worker        Tensor numel between (10, 500]
18*da0073e9SAndroid Build Coastguard Worker        """
19*da0073e9SAndroid Build Coastguard Worker        N = 100
20*da0073e9SAndroid Build Coastguard Worker        fp64_ref = torch.full([N], 0.0, dtype=torch.double)
21*da0073e9SAndroid Build Coastguard Worker        a = torch.full([N], 1.0)
22*da0073e9SAndroid Build Coastguard Worker        tol = 4 * 1e-2
23*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(utils.same(a, a * 2, fp64_ref=fp64_ref, tol=tol))
24*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(utils.same(a, a * 4, fp64_ref=fp64_ref, tol=tol))
25*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
26*da0073e9SAndroid Build Coastguard Worker            utils.same(
27*da0073e9SAndroid Build Coastguard Worker                a,
28*da0073e9SAndroid Build Coastguard Worker                a * 4,
29*da0073e9SAndroid Build Coastguard Worker                fp64_ref=fp64_ref,
30*da0073e9SAndroid Build Coastguard Worker                use_larger_multiplier_for_smaller_tensor=True,
31*da0073e9SAndroid Build Coastguard Worker                tol=tol,
32*da0073e9SAndroid Build Coastguard Worker            )
33*da0073e9SAndroid Build Coastguard Worker        )
34*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(
35*da0073e9SAndroid Build Coastguard Worker            utils.same(
36*da0073e9SAndroid Build Coastguard Worker                a,
37*da0073e9SAndroid Build Coastguard Worker                a * 6,
38*da0073e9SAndroid Build Coastguard Worker                fp64_ref=fp64_ref,
39*da0073e9SAndroid Build Coastguard Worker                use_larger_multiplier_for_smaller_tensor=True,
40*da0073e9SAndroid Build Coastguard Worker                tol=tol,
41*da0073e9SAndroid Build Coastguard Worker            )
42*da0073e9SAndroid Build Coastguard Worker        )
43*da0073e9SAndroid Build Coastguard Worker
44*da0073e9SAndroid Build Coastguard Worker    def test_larger_multiplier_for_even_smaller_tensor(self):
45*da0073e9SAndroid Build Coastguard Worker        """
46*da0073e9SAndroid Build Coastguard Worker        Tesnor numel <=10
47*da0073e9SAndroid Build Coastguard Worker        """
48*da0073e9SAndroid Build Coastguard Worker        fp64_ref = torch.DoubleTensor([0.0])
49*da0073e9SAndroid Build Coastguard Worker        a = torch.Tensor([1.0])
50*da0073e9SAndroid Build Coastguard Worker        tol = 4 * 1e-2
51*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(utils.same(a, a * 2, fp64_ref=fp64_ref, tol=tol))
52*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(utils.same(a, a * 7, fp64_ref=fp64_ref, tol=tol))
53*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
54*da0073e9SAndroid Build Coastguard Worker            utils.same(
55*da0073e9SAndroid Build Coastguard Worker                a,
56*da0073e9SAndroid Build Coastguard Worker                a * 7,
57*da0073e9SAndroid Build Coastguard Worker                fp64_ref=fp64_ref,
58*da0073e9SAndroid Build Coastguard Worker                use_larger_multiplier_for_smaller_tensor=True,
59*da0073e9SAndroid Build Coastguard Worker                tol=tol,
60*da0073e9SAndroid Build Coastguard Worker            )
61*da0073e9SAndroid Build Coastguard Worker        )
62*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(
63*da0073e9SAndroid Build Coastguard Worker            utils.same(
64*da0073e9SAndroid Build Coastguard Worker                a,
65*da0073e9SAndroid Build Coastguard Worker                a * 20,
66*da0073e9SAndroid Build Coastguard Worker                fp64_ref=fp64_ref,
67*da0073e9SAndroid Build Coastguard Worker                use_larger_multiplier_for_smaller_tensor=True,
68*da0073e9SAndroid Build Coastguard Worker                tol=tol,
69*da0073e9SAndroid Build Coastguard Worker            )
70*da0073e9SAndroid Build Coastguard Worker        )
71*da0073e9SAndroid Build Coastguard Worker
72*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Delegate to the dynamo test runner when this file is executed directly.
    from torch._dynamo.test_case import run_tests as _run_dynamo_tests

    _run_dynamo_tests()
77