# Owner(s): ["module: dynamo"]
import unittest

import torch._dynamo
from torch._dynamo.test_minifier_common import MinifierTestBase
from torch.testing._internal.common_utils import skipIfNNModuleInlined


requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda")


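# The *_TESTING_ONLY backends used below are debug backends registered by
# torch._dynamo for its own tests: they compile cleanly except in the presence
# of a specific trigger (torch.relu for the relu_* variants), where they raise a
# compile/runtime error or, for the accuracy variant, return wrong results. The
# minifier is then expected to reduce each failing repro down to that trigger.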
class MinifierTests(MinifierTestBase):
    # Test that compile, runtime, and accuracy errors after dynamo can be repro'd (both CPU and CUDA)
    def _test_after_dynamo(self, device, backend, expected_error):
        run_code = f"""\
@torch._dynamo.optimize({backend!r})
def inner(x):
    for _ in range(10):
        x = torch.sin(x)
    x = torch.relu(x)
    for _ in range(10):
        x = torch.cos(x)
    return x

inner(torch.randn(20, 20).to("{device}"))
"""
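        # isolate=False asks MinifierTestBase to run the repro and the minifier
        # in-process rather than in a separate subprocess, which keeps these
        # tests fast.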
        self._run_full_test(run_code, "dynamo", expected_error, isolate=False)

    def test_after_dynamo_cpu_compile_error(self):
        self._test_after_dynamo(
            "cpu", "relu_compile_error_TESTING_ONLY", "ReluCompileError"
        )

    def test_after_dynamo_cpu_runtime_error(self):
        self._test_after_dynamo(
            "cpu", "relu_runtime_error_TESTING_ONLY", "ReluRuntimeError"
        )

    def test_after_dynamo_cpu_accuracy_error(self):
        self._test_after_dynamo(
            "cpu", "relu_accuracy_error_TESTING_ONLY", "AccuracyError"
        )

    @requires_cuda
    def test_after_dynamo_cuda_compile_error(self):
        self._test_after_dynamo(
            "cuda", "relu_compile_error_TESTING_ONLY", "ReluCompileError"
        )

    @requires_cuda
    def test_after_dynamo_cuda_runtime_error(self):
        self._test_after_dynamo(
            "cuda", "relu_runtime_error_TESTING_ONLY", "ReluRuntimeError"
        )

    @requires_cuda
    def test_after_dynamo_cuda_accuracy_error(self):
        self._test_after_dynamo(
            "cuda", "relu_accuracy_error_TESTING_ONLY", "AccuracyError"
        )

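    # The input below, torch.randn(20, 20, requires_grad=True) + 1, is a non-leaf
    # tensor; the non_leaf_compile_error_TESTING_ONLY backend appears to be designed
    # to fail specifically when such non-leaf inputs reach the backend.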
    def test_after_dynamo_non_leaf_compile_error(self):
        run_code = """\
@torch._dynamo.optimize("non_leaf_compile_error_TESTING_ONLY")
def inner(x):
    return x + 1

inner(torch.randn(20, 20, requires_grad=True) + 1)
"""
        self._run_full_test(
            run_code, "dynamo", "TestingOnlyCompileError", isolate=False
        )

    # Ensure that the testing backends pass when relu is not present.
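    # No minifier run is expected here: the compiled function is called directly
    # and should simply succeed, since these testing backends only trip on torch.relu.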
    def _test_after_dynamo_backend_passes(self, device, backend):
        @torch._dynamo.optimize(backend)
        def inner(x):
            for _ in range(10):
                x = torch.sin(x)
            for _ in range(10):
                x = torch.cos(x)
            return x

        inner(torch.randn(20, 20).to(device))

    def test_after_dynamo_cpu_compile_backend_passes(self):
        self._test_after_dynamo_backend_passes("cpu", "relu_compile_error_TESTING_ONLY")

    def test_after_dynamo_cpu_runtime_backend_passes(self):
        self._test_after_dynamo_backend_passes("cpu", "relu_runtime_error_TESTING_ONLY")

    def test_after_dynamo_cpu_accuracy_backend_passes(self):
        self._test_after_dynamo_backend_passes(
            "cpu", "relu_accuracy_error_TESTING_ONLY"
        )

    @requires_cuda
    def test_after_dynamo_cuda_compile_backend_passes(self):
        self._test_after_dynamo_backend_passes(
            "cuda", "relu_compile_error_TESTING_ONLY"
        )

    @requires_cuda
    def test_after_dynamo_cuda_runtime_backend_passes(self):
        self._test_after_dynamo_backend_passes(
            "cuda", "relu_runtime_error_TESTING_ONLY"
        )

    @requires_cuda
    def test_after_dynamo_cuda_accuracy_backend_passes(self):
        self._test_after_dynamo_backend_passes(
            "cuda", "relu_accuracy_error_TESTING_ONLY"
        )

    # Test that an error after dynamo in a module with mixed cpu/cuda parts can be repro'd
    @skipIfNNModuleInlined()
    @requires_cuda
    def test_cpu_cuda_module_after_dynamo(self):
        backend_name = "relu_compile_error_TESTING_ONLY"
        run_code = f"""\
class CpuCudaModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.m_x = torch.nn.Linear(20, 20).cuda()
        self.m_y = torch.nn.Linear(20, 20)
        self.p_x = torch.nn.Parameter(torch.randn(20, 20).cuda())
        self.p_y = torch.nn.Parameter(torch.randn(20, 20))
        self.b_x = torch.nn.Buffer(torch.ones(20, 20).cuda())
        self.b_y = torch.nn.Buffer(torch.ones(20, 20))

    def forward(self, x, y):
        return self.m_x(x) + self.p_x + self.b_x, self.m_y(y) + self.p_y + self.b_y

mod = CpuCudaModule()

@torch._dynamo.optimize({backend_name!r})
def inner(x1, y1):
    x2 = torch.randn(20, 20).cuda()
    y2 = torch.randn(20, 20)
    x3, y3 = mod(x1 + x2, y1 + y2)
    return torch.relu(x3.cpu() + y3)

inner(torch.randn(20, 20).cuda(), torch.randn(20, 20))
"""

        res = self._run_full_test(run_code, "dynamo", "ReluCompileError", isolate=False)

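        # minifier_module() returns the graph module emitted into the minifier
        # launcher script, i.e. the full graph dynamo captured, with the cpu/cuda
        # placement of each submodule, parameter, and buffer preserved.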
        self.assertExpectedInline(
            res.minifier_module(),
            """\
class Repro(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.G__mod___m_x = Linear(in_features=20, out_features=20, bias=True).cuda()
        self.G__mod___m_y = Linear(in_features=20, out_features=20, bias=True)
        self.register_buffer('G__mod___b_x', torch.randn([20, 20], dtype=torch.float32).cuda())
        self.register_buffer('G__mod___b_y', torch.randn([20, 20], dtype=torch.float32))
        self.G__mod___p_x = torch.nn.Parameter(torch.randn([20, 20], dtype=torch.float32, device="cuda"))
        self.G__mod___p_y = torch.nn.Parameter(torch.randn([20, 20], dtype=torch.float32))

    def forward(self, L_x1_ : torch.Tensor, L_y1_ : torch.Tensor):
        l_x1_ = L_x1_
        l_y1_ = L_y1_
        randn = torch.randn(20, 20)
        x2 = randn.cuda();  randn = None
        y2 = torch.randn(20, 20)
        add = l_x1_ + x2;  l_x1_ = x2 = None
        add_1 = l_y1_ + y2;  l_y1_ = y2 = None
        g__mod___m_x = self.G__mod___m_x(add);  add = None
        g__mod___p_x = self.G__mod___p_x
        add_2 = g__mod___m_x + g__mod___p_x;  g__mod___m_x = g__mod___p_x = None
        g__mod___b_x = self.G__mod___b_x
        x3 = add_2 + g__mod___b_x;  add_2 = g__mod___b_x = None
        g__mod___m_y = self.G__mod___m_y(add_1);  add_1 = None
        g__mod___p_y = self.G__mod___p_y
        add_4 = g__mod___m_y + g__mod___p_y;  g__mod___m_y = g__mod___p_y = None
        g__mod___b_y = self.G__mod___b_y
        y3 = add_4 + g__mod___b_y;  add_4 = g__mod___b_y = None
        cpu = x3.cpu();  x3 = None
        add_6 = cpu + y3;  cpu = y3 = None
        relu = torch.relu(add_6);  add_6 = None
        return (relu,)""",
        )

    # Test if we can actually get a minified graph
    def test_if_graph_minified(self):
        backend_name = "relu_compile_error_TESTING_ONLY"
        run_code = f"""\
@torch._dynamo.optimize({backend_name!r})
def inner(x):
    for _ in range(20):
        x = torch.sin(x)
    x = torch.relu(x)
    for _ in range(20):
        x = torch.cos(x)
    return x

inner(torch.randn(20, 20))
"""

        res = self._run_full_test(run_code, "dynamo", "ReluCompileError", isolate=False)

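        # The minified repro should contain only the single relu op that trips the
        # testing backend; the surrounding chains of sin/cos ops should have been
        # stripped away by the minifier.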
        self.assertExpectedInline(
            res.repro_module(),
            """\
class Repro(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x_19):
        x_20 = torch.relu(x_19);  x_19 = None
        return (x_20,)""",
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()