# xref: /aosp_15_r20/external/pytorch/test/dynamo/test_nops.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: dynamo"]
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo import eval_frame
from torch._dynamo.hooks import Hooks
# Module-level global read by fn1 and fn2 (not passed as an argument), so the
# functions exercised under with_debug_nops also perform global lookups.
c = 10


def fn1(a, b):
    """Return ``a + b - c``, reading ``c`` from module globals."""
    return a + b - c


def fn2(a, b):
    """Add ``a + b + c`` to ``x`` twice through a closure and return ``x + y``.

    ``modify`` rebinds the enclosing ``x`` via ``nonlocal``, so the result is
    ``2 * (a + b + c) + 1`` (27 for ``fn2(1, 2)`` with ``c == 10``). The
    closure presumably exists to cover cell/free variables under the rewrite.
    """
    x = 0
    y = 1

    def modify():
        nonlocal x
        x += a + b + c

    for _ in range(2):
        modify()

    return x + y


def fn3():
    """Yield 1, then 2 (a generator, presumably to cover generator frames)."""
    yield 1
    yield 2
# Wrapper (used as a decorator and as a plain call below) that runs code
# through dynamo's frame-evaluation hook with
# torch._dynamo.testing.debug_insert_nops as the per-frame callback --
# presumably rewriting each frame's bytecode with inserted NOPs so that
# jump-offset fixups get exercised. Hooks(None, None) supplies no hook
# callbacks (presumably guard-export / guard-fail; confirm in
# torch._dynamo.hooks). _optimize_catch_errors is private (leading
# underscore), so this call may need updating when dynamo internals change.
with_debug_nops = eval_frame._optimize_catch_errors(
    torch._dynamo.testing.debug_insert_nops, Hooks(None, None)
)
# Each test runs plain Python code under with_debug_nops and checks that the
# results match ordinary Python semantics.
class NopTests(torch._dynamo.test_case.TestCase):
    """Check that code behaves unchanged when run under ``with_debug_nops``.

    ``test1``-``test3`` are decorated with ``with_debug_nops`` and call the
    module-level fixtures; their own frames presumably go through the rewrite
    too, so keep their bodies simple. ``test_extended_args`` wraps a generated
    lambda whose body is long enough to presumably need EXTENDED_ARG jumps.
    """

    @with_debug_nops
    def test1(self):
        # Called twice: presumably the second call reuses the rewritten code.
        self.assertEqual(fn1(1, 2), -7)
        self.assertEqual(fn1(1, 2), -7)

    @with_debug_nops
    def test2(self):
        # 2 * (1 + 2 + c) + 1 with c == 10.
        self.assertEqual(fn2(1, 2), 27)
        self.assertEqual(fn2(1, 2), 27)

    @with_debug_nops
    def test3(self):
        t = fn3()
        self.assertEqual(next(t), 1)
        self.assertEqual(next(t), 2)
        self.assertRaises(StopIteration, lambda: next(t))

    def test_extended_args(self):
        # "a+b" repeated 256 times gives a 512-operand sum on each side of the
        # conditional, presumably long enough that the jump over the first
        # branch needs an EXTENDED_ARG prefix. The source is a fixed string
        # built here, so eval() never sees external input.
        repeats = 256
        too_many_adds = "+".join(["a", "b"] * repeats)
        source = (
            f"lambda a, b: ({too_many_adds}+a if a.sum() > 0 else {too_many_adds} - b)"
        )
        fn = with_debug_nops(eval(source))
        b = torch.ones(1)

        # Taken branch: repeats * (a + b) + a = 256 * 2 + 1 = 513.
        a = torch.ones(1)
        self.assertEqual(fn(a, b).sum(), 2 * repeats + 1)

        # Other branch (the jump target): repeats * (a + b) - b = 0 - 1.
        a = -torch.ones(1)
        self.assertEqual(fn(a, b).sum(), -1)
# Entry point when this file is executed directly as a script.
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()