1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"] 2*da0073e9SAndroid Build Coastguard Workerimport torch 3*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case 4*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing 5*da0073e9SAndroid Build Coastguard Workerimport torch.onnx.operators 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Workerdef fn(a, b): 9*da0073e9SAndroid Build Coastguard Worker return a + b * 0.67 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Workerclass InteropTests(torch._dynamo.test_case.TestCase): 13*da0073e9SAndroid Build Coastguard Worker def _common(self, fn): 14*da0073e9SAndroid Build Coastguard Worker inputs = [torch.randn(10), torch.randn(10)] 15*da0073e9SAndroid Build Coastguard Worker ref = fn(*inputs) 16*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 17*da0073e9SAndroid Build Coastguard Worker res = opt_fn(*inputs) 18*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Worker def test_fx_fn(self): 21*da0073e9SAndroid Build Coastguard Worker fx_fn = torch.fx.symbolic_trace(fn) 22*da0073e9SAndroid Build Coastguard Worker self._common(lambda a, b: fx_fn(a, b) + 1) 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Worker def test_script_fn(self): 25*da0073e9SAndroid Build Coastguard Worker script_fn = torch.jit.script(fn) 26*da0073e9SAndroid Build Coastguard Worker self._common(lambda a, b: script_fn(a, b) + 1) 27*da0073e9SAndroid Build Coastguard Worker 28*da0073e9SAndroid Build Coastguard Worker def test_trace_fn(self): 29*da0073e9SAndroid Build Coastguard Worker trace_fn = torch.jit.trace(fn, [torch.zeros(10), torch.zeros(10)]) 30*da0073e9SAndroid Build 
Coastguard Worker self._common(lambda a, b: trace_fn(a, b) + 1) 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Worker def test_vmap_in_graph(self): 33*da0073e9SAndroid Build Coastguard Worker from functools import wraps 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Worker from torch._dynamo import allow_in_graph 36*da0073e9SAndroid Build Coastguard Worker 37*da0073e9SAndroid Build Coastguard Worker def traceable(f): 38*da0073e9SAndroid Build Coastguard Worker f = allow_in_graph(f) 39*da0073e9SAndroid Build Coastguard Worker 40*da0073e9SAndroid Build Coastguard Worker @wraps(f) 41*da0073e9SAndroid Build Coastguard Worker def wrapper(*args, **kwargs): 42*da0073e9SAndroid Build Coastguard Worker return f(*args, **kwargs) 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Worker return wrapper 45*da0073e9SAndroid Build Coastguard Worker 46*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 47*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 5, 3) 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker def fn(x): 50*da0073e9SAndroid Build Coastguard Worker return torch.vmap(torch.Tensor.t)(x) 51*da0073e9SAndroid Build Coastguard Worker 52*da0073e9SAndroid Build Coastguard Worker fn_opt = torch.compile(fn, backend=cnts, fullgraph=True) 53*da0073e9SAndroid Build Coastguard Worker fn_opt_traceable = torch.compile(traceable(fn), backend=cnts, fullgraph=True) 54*da0073e9SAndroid Build Coastguard Worker 55*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x), fn_opt(x)) 56*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 57*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn_opt(x), fn_opt_traceable(x)) 58*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 59*da0073e9SAndroid Build Coastguard Worker 60*da0073e9SAndroid Build 
if __name__ == "__main__":
    # Run this module's tests through Dynamo's test-case runner.
    from torch._dynamo.test_case import run_tests

    run_tests()