# xref: /aosp_15_r20/external/pytorch/test/dynamo/test_interop.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: dynamo"]
2import torch
3import torch._dynamo.test_case
4import torch._dynamo.testing
5import torch.onnx.operators
6
7
def fn(a, b):
    """Return ``a + b * 0.67``; shared target for the FX, TorchScript and trace tests."""
    scaled = b * 0.67
    return a + scaled
10
11
class InteropTests(torch._dynamo.test_case.TestCase):
    """torch.compile interop tests for FX, TorchScript and torch.vmap callables.

    Every compilation here passes ``fullgraph=True``, so each test also
    requires Dynamo to capture the wrapped callable without a graph break.
    """

    def _common(self, fn):
        """Check that compiling ``fn`` reproduces its eager result.

        ``fn`` is first called directly on two random 10-element tensors.
        It is then compiled with the ``"eager"`` backend and ``fullgraph=True``
        and called again on the same tensors. Both results must match.
        """
        args = (torch.randn(10), torch.randn(10))
        expected = fn(*args)
        compiled = torch.compile(fn, backend="eager", fullgraph=True)
        actual = compiled(*args)
        self.assertEqual(expected, actual)

    def test_fx_fn(self):
        # GraphModule produced by FX symbolic tracing of the module-level ``fn``.
        graph_module = torch.fx.symbolic_trace(fn)

        def call_graph_module(a, b):
            return graph_module(a, b) + 1

        self._common(call_graph_module)

    def test_script_fn(self):
        # TorchScript-compiled (torch.jit.script) version of the module-level ``fn``.
        scripted = torch.jit.script(fn)

        def call_scripted(a, b):
            return scripted(a, b) + 1

        self._common(call_scripted)

    def test_trace_fn(self):
        # TorchScript trace of ``fn``, recorded on zero-filled example inputs.
        example_inputs = [torch.zeros(10), torch.zeros(10)]
        traced = torch.jit.trace(fn, example_inputs)

        def call_traced(a, b):
            return traced(a, b) + 1

        self._common(call_traced)

    def test_vmap_in_graph(self):
        from functools import wraps

        from torch._dynamo import allow_in_graph

        def traceable(f):
            # Register ``f`` with ``allow_in_graph``, then return a plain
            # ``functools.wraps`` forwarder around the registered callable.
            registered = allow_in_graph(f)

            @wraps(registered)
            def forwarder(*args, **kwargs):
                return registered(*args, **kwargs)

            return forwarder

        counter = torch._dynamo.testing.CompileCounter()
        x = torch.randn(3, 5, 3)

        def batched_transpose(t):
            # Apply ``Tensor.t`` to every slice along dim 0 (vmap's default in_dims).
            return torch.vmap(torch.Tensor.t)(t)

        compiled = torch.compile(batched_transpose, backend=counter, fullgraph=True)
        compiled_traceable = torch.compile(
            traceable(batched_transpose), backend=counter, fullgraph=True
        )

        self.assertEqual(batched_transpose(x), compiled(x))
        self.assertEqual(counter.frame_count, 1)
        # After also running the wrapped version, two frames are compiled in total.
        self.assertEqual(compiled(x), compiled_traceable(x))
        self.assertEqual(counter.frame_count, 2)
59
60
if __name__ == "__main__":
    # When executed as a script, hand off to Dynamo's test runner.
    torch._dynamo.test_case.run_tests()
65