1# Owner(s): ["module: dynamo"] 2import torch 3import torch._dynamo 4import torch._dynamo.test_case 5 6 7class PreDispatchTests(torch._dynamo.test_case.TestCase): 8 def test_no_grad_simple(self): 9 def f(a): 10 b = a.sin() 11 with torch.no_grad(): 12 c = b.cos() 13 return b * c.sin() 14 15 f_compiled = torch.compile(f, backend="pre_dispatch_eager") 16 17 a_ref = torch.randn(4, requires_grad=True) 18 a_test = a_ref.clone().detach().requires_grad_(True) 19 20 out_ref = f(a_ref) 21 out_test = f_compiled(a_test) 22 self.assertEqual(out_ref, out_test) 23 24 out_ref.sum().backward() 25 out_test.sum().backward() 26 self.assertEqual(a_ref.grad, a_test.grad) 27 28 def test_enable_grad_and_no_grad(self): 29 def f(a): 30 b = a * 2 31 with torch.no_grad(): 32 c = b * 3 33 with torch.enable_grad(): 34 d = c * 4 35 e = d * 5 36 return b + c + d + e 37 38 f_compiled = torch.compile(f, backend="pre_dispatch_eager") 39 40 a_ref = torch.randn(4, requires_grad=True) 41 a_test = a_ref.clone().detach().requires_grad_(True) 42 43 out_ref = f(a_ref) 44 out_test = f_compiled(a_test) 45 self.assertEqual(out_ref, out_test) 46 47 out_ref.sum().backward() 48 out_test.sum().backward() 49 self.assertEqual(a_ref.grad, a_test.grad) 50 51 def test_autocast_simple(self): 52 def f(a): 53 b = a * 2 54 with torch.amp.autocast(device_type="cpu"): 55 c = torch.matmul(b, b) 56 return b + c 57 58 f_compiled = torch.compile(f, backend="pre_dispatch_eager") 59 60 a_ref = torch.randn(4, device="cpu", requires_grad=True) 61 a_test = a_ref.clone().detach().requires_grad_(True) 62 63 out_ref = f(a_ref) 64 out_test = f_compiled(a_test) 65 self.assertEqual(out_ref, out_test) 66 67 out_ref.sum().backward() 68 out_test.sum().backward() 69 self.assertEqual(a_ref.grad, a_test.grad) 70 71 72if __name__ == "__main__": 73 from torch._dynamo.test_case import run_tests 74 75 run_tests() 76