xref: /aosp_15_r20/external/pytorch/test/dynamo/test_minifier.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"]
2*da0073e9SAndroid Build Coastguard Workerimport unittest
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo
5*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.test_minifier_common import MinifierTestBase
6*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import skipIfNNModuleInlined
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Worker
# Test decorator: skip the decorated test when no CUDA device is available.
requires_cuda = unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
class MinifierTests(MinifierTestBase):
    """Minifier tests for failures injected by backends after dynamo tracing.

    The ``*_TESTING_ONLY`` dynamo backends used here inject a compile error,
    a runtime error, or an accuracy error. Per the helper
    ``_test_after_dynamo_backend_passes``, they are expected to behave
    correctly on graphs that contain no ``torch.relu``.
    """

    # Test that compile, runtime, and accuracy errors after dynamo can be repro'd (both CPU and CUDA)
    def _test_after_dynamo(self, device, backend, expected_error):
        """Run a sin -> relu -> cos program under ``backend`` through the minifier.

        Args:
            device: Device string the input tensor is moved to ("cpu"/"cuda").
            backend: Name of the dynamo backend that compiles ``inner``.
            expected_error: Expected error name, passed to ``_run_full_test``.
        """
        run_code = f"""\
@torch._dynamo.optimize({backend!r})
def inner(x):
    for _ in range(10):
        x = torch.sin(x)
    x = torch.relu(x)
    for _ in range(10):
        x = torch.cos(x)
    return x

inner(torch.randn(20, 20).to("{device}"))
"""
        self._run_full_test(run_code, "dynamo", expected_error, isolate=False)

    def test_after_dynamo_cpu_compile_error(self):
        self._test_after_dynamo(
            "cpu", "relu_compile_error_TESTING_ONLY", "ReluCompileError"
        )

    def test_after_dynamo_cpu_runtime_error(self):
        self._test_after_dynamo(
            "cpu", "relu_runtime_error_TESTING_ONLY", "ReluRuntimeError"
        )

    def test_after_dynamo_cpu_accuracy_error(self):
        self._test_after_dynamo(
            "cpu", "relu_accuracy_error_TESTING_ONLY", "AccuracyError"
        )

    @requires_cuda
    def test_after_dynamo_cuda_compile_error(self):
        self._test_after_dynamo(
            "cuda", "relu_compile_error_TESTING_ONLY", "ReluCompileError"
        )

    @requires_cuda
    def test_after_dynamo_cuda_runtime_error(self):
        self._test_after_dynamo(
            "cuda", "relu_runtime_error_TESTING_ONLY", "ReluRuntimeError"
        )

    @requires_cuda
    def test_after_dynamo_cuda_accuracy_error(self):
        self._test_after_dynamo(
            "cuda", "relu_accuracy_error_TESTING_ONLY", "AccuracyError"
        )

    def test_after_dynamo_non_leaf_compile_error(self):
        run_code = """\
@torch._dynamo.optimize("non_leaf_compile_error_TESTING_ONLY")
def inner(x):
    return x + 1

inner(torch.randn(20, 20, requires_grad=True) + 1)
"""
        self._run_full_test(
            run_code, "dynamo", "TestingOnlyCompileError", isolate=False
        )

    # Ensure that the testing backends pass when relu is not present.
    def _test_after_dynamo_backend_passes(self, device, backend):
        """Check that ``backend`` compiles and runs a relu-free graph correctly.

        Args:
            device: Device string the input tensor is moved to ("cpu"/"cuda").
            backend: Name of the dynamo backend under test.
        """

        def fn(x):
            for _ in range(10):
                x = torch.sin(x)
            for _ in range(10):
                x = torch.cos(x)
            return x

        inner = torch._dynamo.optimize(backend)(fn)
        x = torch.randn(20, 20).to(device)

        # Running the compiled function catches backends that fail by raising.
        # The accuracy-error backend fails by producing a wrong result rather
        # than by raising (the minifier tests above expect "AccuracyError"),
        # so also compare against eager to make that case meaningful.
        self.assertEqual(inner(x), fn(x))

    def test_after_dynamo_cpu_compile_backend_passes(self):
        self._test_after_dynamo_backend_passes("cpu", "relu_compile_error_TESTING_ONLY")

    def test_after_dynamo_cpu_runtime_backend_passes(self):
        self._test_after_dynamo_backend_passes("cpu", "relu_runtime_error_TESTING_ONLY")

    def test_after_dynamo_cpu_accuracy_backend_passes(self):
        self._test_after_dynamo_backend_passes(
            "cpu", "relu_accuracy_error_TESTING_ONLY"
        )

    @requires_cuda
    def test_after_dynamo_cuda_compile_backend_passes(self):
        self._test_after_dynamo_backend_passes(
            "cuda", "relu_compile_error_TESTING_ONLY"
        )

    @requires_cuda
    def test_after_dynamo_cuda_runtime_backend_passes(self):
        self._test_after_dynamo_backend_passes(
            "cuda", "relu_runtime_error_TESTING_ONLY"
        )

    @requires_cuda
    def test_after_dynamo_cuda_accuracy_backend_passes(self):
        self._test_after_dynamo_backend_passes(
            "cuda", "relu_accuracy_error_TESTING_ONLY"
        )

    # Test that a module with mixed cpu/cuda parts with an error after dynamo can be repro'd
    @skipIfNNModuleInlined()
    @requires_cuda
    def test_cpu_cuda_module_after_dynamo(self):
        backend_name = "relu_compile_error_TESTING_ONLY"
        run_code = f"""\
class CpuCudaModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.m_x = torch.nn.Linear(20, 20).cuda()
        self.m_y = torch.nn.Linear(20, 20)
        self.p_x = torch.nn.Parameter(torch.randn(20, 20).cuda())
        self.p_y = torch.nn.Parameter(torch.randn(20, 20))
        self.b_x = torch.nn.Buffer(torch.ones(20, 20).cuda())
        self.b_y = torch.nn.Buffer(torch.ones(20, 20))

    def forward(self, x, y):
        return self.m_x(x) + self.p_x + self.b_x, self.m_y(y) + self.p_y + self.b_y

mod = CpuCudaModule()

@torch._dynamo.optimize({backend_name!r})
def inner(x1, y1):
    x2 = torch.randn(20, 20).cuda()
    y2 = torch.randn(20, 20)
    x3, y3 = mod(x1 + x2, y1 + y2)
    return torch.relu(x3.cpu() + y3)

inner(torch.randn(20, 20).cuda(), torch.randn(20, 20))
"""

        res = self._run_full_test(run_code, "dynamo", "ReluCompileError", isolate=False)

        # The repro module must keep each parameter/buffer on its original device.
        self.assertExpectedInline(
            res.minifier_module(),
            """\
class Repro(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.G__mod___m_x = Linear(in_features=20, out_features=20, bias=True).cuda()
        self.G__mod___m_y = Linear(in_features=20, out_features=20, bias=True)
        self.register_buffer('G__mod___b_x', torch.randn([20, 20], dtype=torch.float32).cuda())
        self.register_buffer('G__mod___b_y', torch.randn([20, 20], dtype=torch.float32))
        self.G__mod___p_x = torch.nn.Parameter(torch.randn([20, 20], dtype=torch.float32, device="cuda"))
        self.G__mod___p_y = torch.nn.Parameter(torch.randn([20, 20], dtype=torch.float32))

    def forward(self, L_x1_ : torch.Tensor, L_y1_ : torch.Tensor):
        l_x1_ = L_x1_
        l_y1_ = L_y1_
        randn = torch.randn(20, 20)
        x2 = randn.cuda();  randn = None
        y2 = torch.randn(20, 20)
        add = l_x1_ + x2;  l_x1_ = x2 = None
        add_1 = l_y1_ + y2;  l_y1_ = y2 = None
        g__mod___m_x = self.G__mod___m_x(add);  add = None
        g__mod___p_x = self.G__mod___p_x
        add_2 = g__mod___m_x + g__mod___p_x;  g__mod___m_x = g__mod___p_x = None
        g__mod___b_x = self.G__mod___b_x
        x3 = add_2 + g__mod___b_x;  add_2 = g__mod___b_x = None
        g__mod___m_y = self.G__mod___m_y(add_1);  add_1 = None
        g__mod___p_y = self.G__mod___p_y
        add_4 = g__mod___m_y + g__mod___p_y;  g__mod___m_y = g__mod___p_y = None
        g__mod___b_y = self.G__mod___b_y
        y3 = add_4 + g__mod___b_y;  add_4 = g__mod___b_y = None
        cpu = x3.cpu();  x3 = None
        add_6 = cpu + y3;  cpu = y3 = None
        relu = torch.relu(add_6);  add_6 = None
        return (relu,)""",
        )

    # Test if we can actually get a minified graph
    def test_if_graph_minified(self):
        backend_name = "relu_compile_error_TESTING_ONLY"
        run_code = f"""\
@torch._dynamo.optimize({backend_name!r})
def inner(x):
    for _ in range(20):
        x = torch.sin(x)
    x = torch.relu(x)
    for _ in range(20):
        x = torch.cos(x)
    return x

inner(torch.randn(20, 20))
"""

        res = self._run_full_test(run_code, "dynamo", "ReluCompileError", isolate=False)

        # The sin/cos chains should be minified away, leaving only the relu.
        self.assertExpectedInline(
            res.repro_module(),
            """\
class Repro(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x_19):
        x_20 = torch.relu(x_19);  x_19 = None
        return (x_20,)""",
        )
215*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # When executed as a script, hand off to torch._dynamo's test-case runner.
    import torch._dynamo.test_case as dynamo_test_case

    dynamo_test_case.run_tests()
220