1# Owner(s): ["module: dynamo"] 2 3import torch 4import torch._dynamo.test_case 5 6 7def fn_creator(): 8 var1 = 1 9 10 def fn(x): 11 x = x + 1 12 var2 = 1 13 torch._dynamo.graph_break() 14 x = x + var1 15 16 def inner_fn(): 17 return var2 18 19 return x 20 21 return fn 22 23 24class ResumeFunctionTests(torch._dynamo.test_case.TestCase): 25 def test_freevars(self): 26 fn = fn_creator() 27 opt_fn = torch.compile(fn, backend="eager") 28 opt_fn(torch.randn(10)) 29 codes = [v for k, v in list(globals().items()) if k.startswith("__resume_at")] 30 self.assertEqual(len(codes), 1) 31 # co_freevars of resume functions, are sorted concatenation of the original function's co_freevars and co_cellvars 32 self.assertEqual(codes[0].co_freevars, ("var1", "var2")) 33 34 35if __name__ == "__main__": 36 from torch._dynamo.test_case import run_tests 37 38 run_tests() 39