# Owner(s): ["module: inductor"]
import torch
from torch import _dynamo as dynamo, _inductor as inductor
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import gen_gm_and_inputs
from torch.fx import symbolic_trace
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.inductor_utils import HAS_CPU


class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.a = torch.nn.Linear(10, 10)
        self.b = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(self.a(x))
        x = torch.sigmoid(self.b(x))
        return x


class MyModule2(MyModule):
    def forward(self, x):  # takes a dict of lists
        a, b = x["key"]
        return {"result": super().forward(a) + b}


class MyModule3(MyModule):
    def forward(self, x):
        return (super().forward(x),)


class TestStandaloneInductor(TestCase):
    """
    These tests check that you can call TorchInductor directly without
    going through TorchDynamo.
    """

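    # The pattern under test, in miniature (mirroring the cases below):
    # hand inductor.compile an FX GraphModule (or even a bare nn.Module)
    # plus a list of example inputs, and get back an optimized callable:
    #
    #   gm = symbolic_trace(mod)          # or make_fx / dynamo.export
    #   mod_opt = inductor.compile(gm, [inp])
    #   actual = mod_opt(inp)

    # symbolic_trace builds a torch-level FX graph via Python-level
    # tracing; here the traced module returns a one-element tuple.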
    def test_inductor_via_fx(self):
        mod = MyModule3().eval()
        inp = torch.randn(10)
        correct = mod(inp)
        mod_opt = inductor.compile(symbolic_trace(mod), [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)

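    # Same pipeline, but MyModule returns a bare tensor rather than a
    # tuple, checking that non-tuple outputs survive compilation.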
    def test_inductor_via_fx_tensor_return(self):
        mod = MyModule().eval()
        inp = torch.randn(10)
        correct = mod(inp)
        mod_opt = inductor.compile(symbolic_trace(mod), [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)

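    # MyModule2 takes a dict-of-lists input and returns a dict, so this
    # exercises structured (pytree-style) inputs and outputs end to end.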
    def test_inductor_via_fx_dict_input(self):
        mod = MyModule2().eval()
        inp = {"key": [torch.randn(10), torch.randn(10)]}
        correct = mod(inp)
        mod_opt = inductor.compile(symbolic_trace(mod), [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)

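    # make_fx traces through the dispatcher, so Inductor receives an
    # ATen-level graph rather than the torch-level one symbolic_trace
    # produces.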
    def test_inductor_via_make_fx(self):
        mod = MyModule().eval()
        inp = torch.randn(10)
        correct = mod(inp)
        mod_opt = inductor.compile(make_fx(mod)(inp), [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)

    def test_inductor_via_bare_module(self):
        mod = MyModule3().eval()
        inp = torch.randn(10)
        correct = mod(inp)
        # no FX graph at all (mod must return list/tuple in this case)
        mod_opt = inductor.compile(mod, [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)

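    # dynamo.export traces with TorchDynamo and returns (graph_module,
    # guards); aten_graph=True additionally lowers the graph to ATen ops.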
    def test_inductor_via_export1(self):
        mod = MyModule3().eval()
        inp = torch.randn(10)
        correct = mod(inp)
        gm, guards = dynamo.export(mod, inp, aten_graph=True)
        mod_opt = inductor.compile(gm, [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)

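    # As above but without aten_graph=True, so the exported graph stays at
    # the torch level; also re-checks the dict-of-lists calling convention.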
    def test_inductor_via_export2(self):
        mod = MyModule2().eval()
        inp = {"key": [torch.randn(10), torch.randn(10)]}
        correct = mod(inp)
        gm, guards = dynamo.export(mod, inp)
        mod_opt = inductor.compile(gm, [inp])
        actual = mod_opt(inp)
        self.assertEqual(actual, correct)

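    # aten.native_layer_norm returns three tensors (output, mean, rstd);
    # gen_gm_and_inputs wraps the op call in a minimal GraphModule so the
    # test covers ops with multiple outputs.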
    def test_inductor_via_op_with_multiple_outputs(self):
        x1 = torch.randn((2, 512, 128))
        x2 = [128]
        x3 = torch.randn(128)
        x4 = torch.randn((128,))
        x5 = 1e-6
        mod, inp = gen_gm_and_inputs(
            torch.ops.aten.native_layer_norm.default, (x1, x2, x3, x4, x5), {}
        )
        mod_opt = inductor.compile(mod, inp)
        self.assertEqual(mod(*inp), mod_opt(*inp))


if __name__ == "__main__":
    if HAS_CPU:
        run_tests()