# xref: /aosp_15_r20/external/pytorch/test/dynamo/test_interop.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: dynamo"]
2*da0073e9SAndroid Build Coastguard Workerimport torch
3*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case
4*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing
5*da0073e9SAndroid Build Coastguard Workerimport torch.onnx.operators
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker
def fn(a, b):
    """Return ``a + b * 0.67``.

    Shared toy function that the tests below hand to FX symbolic tracing,
    ``torch.jit.script`` and ``torch.jit.trace``; it is kept to plain
    arithmetic so every one of those front ends accepts it.
    """
    scaled = b * 0.67
    return a + scaled
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
class InteropTests(torch._dynamo.test_case.TestCase):
    """Check that torch.compile can call into FX, TorchScript and vmap code."""

    def _common(self, fn):
        """Assert compiled ``fn`` matches eager ``fn`` on two random inputs.

        ``fn`` is compiled with ``backend="eager"`` and ``fullgraph=True``
        and is called with two length-10 ``torch.randn`` tensors.
        """
        lhs = torch.randn(10)
        rhs = torch.randn(10)
        expected = fn(lhs, rhs)
        compiled = torch.compile(fn, backend="eager", fullgraph=True)
        actual = compiled(lhs, rhs)
        self.assertEqual(expected, actual)

    def test_fx_fn(self):
        # Call an FX symbolic_trace result from inside the compiled lambda.
        graph_module = torch.fx.symbolic_trace(fn)
        self._common(lambda a, b: graph_module(a, b) + 1)

    def test_script_fn(self):
        # Call a torch.jit.script result from inside the compiled lambda.
        scripted = torch.jit.script(fn)
        self._common(lambda a, b: scripted(a, b) + 1)

    def test_trace_fn(self):
        # Call a torch.jit.trace result (traced on zeros) from the compiled lambda.
        example_inputs = (torch.zeros(10), torch.zeros(10))
        traced = torch.jit.trace(fn, example_inputs)
        self._common(lambda a, b: traced(a, b) + 1)

    def test_vmap_in_graph(self):
        from functools import wraps

        from torch._dynamo import allow_in_graph

        def traceable(func):
            # Mark func with allow_in_graph, then return a thin wrapper that
            # keeps func's metadata via functools.wraps.
            func = allow_in_graph(func)

            @wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            return wrapper

        counter = torch._dynamo.testing.CompileCounter()
        x = torch.randn(3, 5, 3)

        def batched_transpose(x):
            return torch.vmap(torch.Tensor.t)(x)

        # Both compiled callables are built before either runs; building the
        # second one calls traceable(), which applies allow_in_graph.
        compiled_plain = torch.compile(
            batched_transpose, backend=counter, fullgraph=True
        )
        compiled_wrapped = torch.compile(
            traceable(batched_transpose), backend=counter, fullgraph=True
        )

        # First call compiles one frame.
        self.assertEqual(batched_transpose(x), compiled_plain(x))
        self.assertEqual(counter.frame_count, 1)
        # Re-running compiled_plain adds no frame; the wrapped version adds one.
        self.assertEqual(compiled_plain(x), compiled_wrapped(x))
        self.assertEqual(counter.frame_count, 2)
59*da0073e9SAndroid Build Coastguard Worker
60*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Run this module's tests through Dynamo's test-case runner.
    torch._dynamo.test_case.run_tests()
65