xref: /aosp_15_r20/external/pytorch/test/dynamo/test_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
2import torch
3from torch._dynamo import utils
4from torch._inductor.test_case import TestCase
5
6
7class TestUtils(TestCase):
8    def test_nan(self):
9        a = torch.Tensor([float("nan")])
10        b = torch.Tensor([float("nan")])
11        fp64_ref = torch.DoubleTensor([5.0])
12        res = utils.same(a, b, fp64_ref=fp64_ref, equal_nan=True)
13        self.assertTrue(res)
14
15    def test_larger_multiplier_for_smaller_tensor(self):
16        """
17        Tensor numel between (10, 500]
18        """
19        N = 100
20        fp64_ref = torch.full([N], 0.0, dtype=torch.double)
21        a = torch.full([N], 1.0)
22        tol = 4 * 1e-2
23        self.assertTrue(utils.same(a, a * 2, fp64_ref=fp64_ref, tol=tol))
24        self.assertFalse(utils.same(a, a * 4, fp64_ref=fp64_ref, tol=tol))
25        self.assertTrue(
26            utils.same(
27                a,
28                a * 4,
29                fp64_ref=fp64_ref,
30                use_larger_multiplier_for_smaller_tensor=True,
31                tol=tol,
32            )
33        )
34        self.assertFalse(
35            utils.same(
36                a,
37                a * 6,
38                fp64_ref=fp64_ref,
39                use_larger_multiplier_for_smaller_tensor=True,
40                tol=tol,
41            )
42        )
43
44    def test_larger_multiplier_for_even_smaller_tensor(self):
45        """
46        Tesnor numel <=10
47        """
48        fp64_ref = torch.DoubleTensor([0.0])
49        a = torch.Tensor([1.0])
50        tol = 4 * 1e-2
51        self.assertTrue(utils.same(a, a * 2, fp64_ref=fp64_ref, tol=tol))
52        self.assertFalse(utils.same(a, a * 7, fp64_ref=fp64_ref, tol=tol))
53        self.assertTrue(
54            utils.same(
55                a,
56                a * 7,
57                fp64_ref=fp64_ref,
58                use_larger_multiplier_for_smaller_tensor=True,
59                tol=tol,
60            )
61        )
62        self.assertFalse(
63            utils.same(
64                a,
65                a * 20,
66                fp64_ref=fp64_ref,
67                use_larger_multiplier_for_smaller_tensor=True,
68                tol=tol,
69            )
70        )
71
72
if __name__ == "__main__":
    # The dynamo test runner is imported only when this file is executed as a
    # script; importing the module for its tests does not pull it in.
    from torch._dynamo import test_case

    test_case.run_tests()
77