xref: /aosp_15_r20/external/pytorch/test/dynamo/test_compile.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
2
3import inspect
4import io
5import os
6import tempfile
7from unittest.mock import patch
8
9import torch
10from torch._dynamo.test_case import run_tests, TestCase
11from torch._dynamo.testing import CompileCounter
12
13
class ToyModel(torch.nn.Module):
    """Minimal linear + ReLU network used as the compilation target in these tests."""

    def __init__(self) -> None:
        super().__init__()
        # Keep the submodules as separate named attributes so state_dict keys
        # ("linear.weight", "linear.bias") stay stable for the save/load tests.
        self.linear = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Project x through the linear layer, then clamp with ReLU."""
        hidden = self.linear(x)
        return self.relu(hidden)
23
class InPlaceCompilationTests(TestCase):
    """Tests for in-place module compilation (``nn.Module.compile()``) and the
    ``torch._dynamo`` compilation start/end callbacks."""

    def test_compilation(self):
        """Calling ``model.compile()`` should trigger exactly one frame compile."""
        torch._dynamo.reset()
        model = ToyModel()
        cnt = CompileCounter()
        model.compile(backend=cnt)
        x = torch.randn(10, 10)
        model(x)
        self.assertEqual(cnt.frame_count, 1)

    def test_overwrite_call_impl(self):
        """``model.compile()`` installs a compiled call impl on the module."""
        torch._dynamo.reset()
        model = ToyModel()
        # assertIsNone/assertIsNotNone give informative failure messages,
        # unlike assertTrue(x is None) which only reports "False is not true".
        self.assertIsNone(model._compiled_call_impl)
        model.compile()
        self.assertIsNotNone(model._compiled_call_impl)

    def test_save(self):
        """An in-place-compiled module can still be torch.save'd/torch.load'ed whole."""
        torch._dynamo.reset()
        model = ToyModel()
        model.compile()
        model(torch.randn(1, 10))

        with tempfile.TemporaryDirectory() as tmpdirname:
            torch.save(model, os.path.join(tmpdirname, "model.pt"))
            loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
            loaded_model(torch.randn(1, 10))

    def test_state_dict_save(self):
        """Compilation must not change state_dict keys: a fresh (uncompiled)
        module can load the compiled module's state dict."""
        torch._dynamo.reset()
        model = ToyModel()
        model.compile()
        model(torch.randn(1, 10))
        with tempfile.TemporaryDirectory() as tmpdirname:
            torch.save(model.state_dict(), os.path.join(tmpdirname, "model.pt"))
            loaded_model = ToyModel()
            loaded_model.load_state_dict(
                torch.load(os.path.join(tmpdirname, "model.pt"))
            )
            loaded_model(torch.randn(1, 10))

    def test_jit_save(self):
        """An in-place-compiled module remains scriptable and jit-save/loadable."""
        torch._dynamo.reset()
        model = ToyModel()
        model.compile()
        model(torch.randn(1, 10))
        scripted_model = torch.jit.script(model)
        with tempfile.TemporaryDirectory() as tmpdirname:
            torch.jit.save(scripted_model, os.path.join(tmpdirname, "model.pt"))
            loaded_model = torch.jit.load(os.path.join(tmpdirname, "model.pt"))
            loaded_model(torch.randn(1, 10))

    def test_compilation_callback(self):
        """on_compile_start/on_compile_end fire once each around a fullgraph compile."""
        torch._dynamo.reset()

        @torch._dynamo.on_compile_start
        def start_callback():
            print("Compilation started.")

        @torch._dynamo.on_compile_end
        def end_callback():
            print("Compilation ended.")

        mod = ToyModel()
        x = torch.randn(10, 10)

        # Capture stdout so the callbacks' prints can be asserted on.
        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
            opt_mod = torch.compile(backend="eager", fullgraph=True)(mod)
            opt_mod(x)
            printed_output = mock_stdout.getvalue().strip()

        self.assertEqual(printed_output, "Compilation started.\nCompilation ended.")

    def test_compile_eager_options(self):
        """The eager/aot_eager backends accept (and ignore) an options dict."""

        @torch.compile(backend="eager", options={"foo": 2})
        def f(x):
            return x + x

        f(torch.randn(3))

        @torch.compile(backend="aot_eager", options={"foo": 2})
        def g(x):
            return x + x

        g(torch.randn(3))

    def test_compilation_callback_with_graph_break(self):
        """A graph break splits compilation into two frames, so the start/end
        callbacks fire twice each (counter reaches 4)."""
        torch._dynamo.reset()
        counter = 0

        @torch._dynamo.on_compile_start
        def start_callback():
            nonlocal counter
            counter += 1
            print(f"Counter = {counter}")

        @torch._dynamo.on_compile_end
        def end_callback():
            nonlocal counter
            counter += 1
            print(f"Counter = {counter}")

        @torch.compile(backend="eager")
        def fn(x):
            x = x + 1
            torch._dynamo.graph_break()
            return torch.sin(x)

        x = torch.randn(10, 10)

        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
            fn(x)
            printed_output = mock_stdout.getvalue().strip()

        self.assertEqual(
            printed_output, "Counter = 1\nCounter = 2\nCounter = 3\nCounter = 4"
        )
141
142
143# The private variants of the below functions are extensively tested
144# So as long as the signatures match we're good
145class PublicTorchCompilerTests(TestCase):
146    def check_signature(self, public_fn_name, private_fn_name, private_namespace):
147        public_fn = getattr(torch.compiler, public_fn_name)
148        private_fn = getattr(private_namespace, private_fn_name)
149
150        public_sig = inspect.signature(public_fn)
151        private_sig = inspect.signature(private_fn)
152
153        self.assertEqual(
154            public_sig,
155            private_sig,
156            f"Signatures do not match for function {public_fn_name}() \n Public: {public_sig} \n Private: {private_sig}",
157        )
158
159    def test_dynamo_signatures(self):
160        function_names = [
161            "reset",
162            "allow_in_graph",
163            "list_backends",
164            "assume_constant_result",
165            "disable",
166        ]
167
168        for fn_name in function_names:
169            self.check_signature(fn_name, fn_name, torch._dynamo)
170
171
# Entry point: delegate to dynamo's test runner so these tests get the
# dynamo-specific test-case setup/teardown.
if __name__ == "__main__":
    run_tests()
174