# xref: /aosp_15_r20/external/pytorch/test/dynamo/test_resume.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: dynamo"]
2
import torch
import torch._dynamo.test_case
5
6
def fn_creator():
    """Build a closure with one free variable, one cell variable and a graph break.

    The returned ``fn`` reads ``var1`` from this enclosing scope, so ``var1``
    is a free variable of ``fn``. It also defines ``var2``, which the nested
    ``inner_fn`` closes over, so ``var2`` is a cell variable of ``fn``.

    Returns:
        A one-argument function that computes ``x + 1 + var1``.
    """
    var1 = 1

    def fn(x):
        """Add 1 and then ``var1`` to ``x``, with a graph break in between."""
        x = x + 1
        var2 = 1
        # Explicit graph-break request; the statements after this call are
        # what the test in this file expects to end up in a resume function.
        torch._dynamo.graph_break()
        x = x + var1

        # Never called: it exists only so that ``var2`` is captured by a
        # nested function and therefore shows up in ``fn``'s co_cellvars.
        def inner_fn():
            return var2

        return x

    return fn
22
23
class ResumeFunctionTests(torch._dynamo.test_case.TestCase):
    """Tests for the ``__resume_at*`` code objects left in module globals."""

    def test_freevars(self):
        """Check ``co_freevars`` of the single resume code object that is created."""
        fn = fn_creator()
        opt_fn = torch.compile(fn, backend="eager")
        opt_fn(torch.randn(10))
        # Collect every module-level global whose name marks it as a resume
        # entry; the items are snapshotted into a list before filtering.
        codes = [v for k, v in list(globals().items()) if k.startswith("__resume_at")]
        self.assertEqual(len(codes), 1)
        # The co_freevars of a resume function is the sorted concatenation of
        # the original function's co_freevars and co_cellvars.
        self.assertEqual(codes[0].co_freevars, ("var1", "var2"))
33
34
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
39