1# Owner(s): ["module: dynamo"] 2import torch 3import torch._dynamo.test_case 4import torch._dynamo.testing 5import torch.onnx.operators 6 7 8def fn(a, b): 9 return a + b * 0.67 10 11 12class InteropTests(torch._dynamo.test_case.TestCase): 13 def _common(self, fn): 14 inputs = [torch.randn(10), torch.randn(10)] 15 ref = fn(*inputs) 16 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 17 res = opt_fn(*inputs) 18 self.assertEqual(ref, res) 19 20 def test_fx_fn(self): 21 fx_fn = torch.fx.symbolic_trace(fn) 22 self._common(lambda a, b: fx_fn(a, b) + 1) 23 24 def test_script_fn(self): 25 script_fn = torch.jit.script(fn) 26 self._common(lambda a, b: script_fn(a, b) + 1) 27 28 def test_trace_fn(self): 29 trace_fn = torch.jit.trace(fn, [torch.zeros(10), torch.zeros(10)]) 30 self._common(lambda a, b: trace_fn(a, b) + 1) 31 32 def test_vmap_in_graph(self): 33 from functools import wraps 34 35 from torch._dynamo import allow_in_graph 36 37 def traceable(f): 38 f = allow_in_graph(f) 39 40 @wraps(f) 41 def wrapper(*args, **kwargs): 42 return f(*args, **kwargs) 43 44 return wrapper 45 46 cnts = torch._dynamo.testing.CompileCounter() 47 x = torch.randn(3, 5, 3) 48 49 def fn(x): 50 return torch.vmap(torch.Tensor.t)(x) 51 52 fn_opt = torch.compile(fn, backend=cnts, fullgraph=True) 53 fn_opt_traceable = torch.compile(traceable(fn), backend=cnts, fullgraph=True) 54 55 self.assertEqual(fn(x), fn_opt(x)) 56 self.assertEqual(cnts.frame_count, 1) 57 self.assertEqual(fn_opt(x), fn_opt_traceable(x)) 58 self.assertEqual(cnts.frame_count, 2) 59 60 61if __name__ == "__main__": 62 from torch._dynamo.test_case import run_tests 63 64 run_tests() 65