1# Owner(s): ["module: dynamo"] 2import torch 3import torch._dynamo.test_case 4import torch._dynamo.testing 5from torch._dynamo import eval_frame 6from torch._dynamo.hooks import Hooks 7 8 9c = 10 10 11 12def fn1(a, b): 13 return a + b - c 14 15 16def fn2(a, b): 17 x = 0 18 y = 1 19 20 def modify(): 21 nonlocal x 22 x += a + b + c 23 24 for _ in range(2): 25 modify() 26 27 return x + y 28 29 30def fn3(): 31 yield 1 32 yield 2 33 34 35with_debug_nops = eval_frame._optimize_catch_errors( 36 torch._dynamo.testing.debug_insert_nops, Hooks(None, None) 37) 38 39 40class NopTests(torch._dynamo.test_case.TestCase): 41 @with_debug_nops 42 def test1(self): 43 self.assertEqual(fn1(1, 2), -7) 44 self.assertEqual(fn1(1, 2), -7) 45 46 @with_debug_nops 47 def test2(self): 48 self.assertEqual(fn2(1, 2), 27) 49 self.assertEqual(fn2(1, 2), 27) 50 51 @with_debug_nops 52 def test3(self): 53 t = fn3() 54 self.assertEqual(next(t), 1) 55 self.assertEqual(next(t), 2) 56 self.assertRaises(StopIteration, lambda: next(t)) 57 58 def test_extended_args(self): 59 too_many_adds = "+".join(["a", "b"] * 256) 60 source = ( 61 f"lambda a, b: ({too_many_adds}+a if a.sum() > 0 else {too_many_adds} - b)" 62 ) 63 fn = eval(source) 64 a = torch.ones(1) 65 b = torch.ones(1) 66 fn = with_debug_nops(fn) 67 self.assertEqual(fn(a, b).sum(), 513) 68 69 70if __name__ == "__main__": 71 from torch._dynamo.test_case import run_tests 72 73 run_tests() 74