# xref: /aosp_15_r20/external/pytorch/test/dynamo/test_exc.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
2
3import logging
4import unittest
5
6import torch
7import torch._dynamo
8import torch._dynamo.config
9import torch._dynamo.test_case
10from torch._dynamo.comptime import comptime
11from torch._dynamo.exc import Unsupported
12from torch.testing._internal.common_device_type import skipIf
13from torch.testing._internal.common_utils import (
14    IS_FBCODE,
15    munge_exc,
16    skipIfWindows,
17    TEST_Z3,
18)
19from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
20
21
class ExcTests(LoggingTestCase):
    """Expected-output tests for the exceptions and log records TorchDynamo emits.

    Many expected strings below quote source lines of the nested helper
    functions (``fn001``, ``fn002``, ``f``) and, in one case, of the test body
    itself, so those lines and function names must stay in sync with the
    expectations. The expectations use ``line N`` in place of real line
    numbers (presumably normalized by ``munge_exc``).
    """

    # Show the full diff when an expected inline string does not match.
    maxDiff = None

    def test_unsupported_real_stack(self):
        """With fullgraph=True, a nested graph_break() raises Unsupported.

        The message must carry the user-code stack through both frames.
        """
        # exercise Unsupported constructor and augment_exc_message
        def fn002(x):
            torch._dynamo.graph_break()

        def fn001(x):
            x = x + 1
            fn002(x)

        self.assertExpectedInlineMunged(
            Unsupported,
            lambda: torch.compile(fn001, backend="eager", fullgraph=True)(
                torch.randn(1)
            ),
            """\
'skip function graph_break in file _dynamo/decorators.py'

from user code:
   File "test_exc.py", line N, in fn001
    fn002(x)
  File "test_exc.py", line N, in fn002
    torch._dynamo.graph_break()""",
        )

    @torch._dynamo.config.patch(verbose=True, suppress_errors=True)
    @make_logging_test()
    @unittest.skipIf(IS_FBCODE, "stack trace slightly different in fbcode")
    def test_internal_error_suppress_errors(self, records):
        """With suppress_errors=True, an internal error is logged, not raised.

        The WON'T CONVERT record must contain the dynamo stack trace, the
        user-code frame, and the frames that were being processed.
        """
        def fn001(x):
            def f(ctx):
                raise AssertionError

            comptime(f)

        torch.compile(fn001, backend="eager")(torch.randn(1))

        record = self.getRecord(records, "WON'T CONVERT")

        self.assertExpectedInline(
            munge_exc(record.getMessage()),
            """\
WON'T CONVERT fn001 test_exc.py line N
========== TorchDynamo Stack Trace ==========
Traceback (most recent call last):
  File "test_exc.py", line N, in f
    raise AssertionError
AssertionError:

from user code:
   File "test_exc.py", line N, in fn001
    comptime(f)


========== The above exception occurred while processing the following code ==========

  File "test_exc.py", line N, in test_internal_error_suppress_errors
    torch.compile(fn001, backend="eager")(torch.randn(1))
  File "test_exc.py", line N, in fn001
    comptime(f)

==========""",
        )

    @make_logging_test()
    def test_not_implemented_error(self, records):
        """A NotImplementedError raised inside comptime is logged as WON'T CONVERT.

        The logged message shows it wrapped in InternalTorchDynamoError.
        """
        def fn001(x):
            def f(ctx):
                raise NotImplementedError

            # Ensure graph break is not possible
            for _ in range(3):
                comptime(f)

        torch.compile(fn001, backend="eager")(torch.randn(1))

        record = self.getRecord(records, "WON'T CONVERT")

        self.assertExpectedInline(
            munge_exc(record.getMessage()),
            """\
WON'T CONVERT fn001 test_exc.py line N
due to:
Traceback (most recent call last):
  File "test_exc.py", line N, in f
    raise NotImplementedError
torch._dynamo.exc.InternalTorchDynamoError: NotImplementedError:

from user code:
   File "test_exc.py", line N, in fn001
    comptime(f)""",
        )

    @torch._dynamo.config.patch(inject_BUILD_SET_unimplemented_TESTING_ONLY=True)
    @make_logging_test(dynamo=logging.DEBUG)
    def test_unsupported_error(self, records):
        """Compile a set literal while inject_BUILD_SET_unimplemented_TESTING_ONLY
        is on, then look up the "Graph break:" DEBUG log record."""
        def fn001(x):
            return {1, 2}

        torch.compile(fn001, backend="eager")(torch.randn(1))

        # TODO: There is no graph break log!  This is because the graph break
        # logging is not in a centralized location; unsupported
        # instruction bypasses it
        # NOTE(review): the lookup below asks for a "Graph break:" record,
        # which contradicts the TODO above. If getRecord fails when nothing
        # matches, the TODO is likely stale -- confirm and update one of them.
        self.getRecord(records, "Graph break:")

    @torch._dynamo.config.patch(suppress_errors=False)
    def test_internal_error_no_suppress(self):
        """With suppress_errors=False, an internal AssertionError propagates.

        Its message lists only the user-code frame.
        """
        def fn001(x):
            # NB: avoid decorator, as 3.11 changed the line number attributed
            # in this situation
            def f(ctx):
                raise AssertionError

            comptime(f)

        # NB: OK for user code to be truncated here, because the regular
        # exception backtrace has the rest of the crumbs
        self.assertExpectedInlineMunged(
            AssertionError,
            lambda: torch.compile(fn001, backend="eager")(torch.randn(1)),
            """\


from user code:
   File "test_exc.py", line N, in fn001
    comptime(f)""",
        )

    @make_logging_test(graph_breaks=True)
    def test_graph_break_log(self, records):
        """The graph_breaks log reports the user-code frames leading to a break."""
        def fn002(x):
            x = x + 1
            torch._dynamo.graph_break()
            x = x + 1
            return x

        def fn001(x):
            return fn002(x)

        torch.compile(fn001, backend="eager")(torch.randn(1))

        record = self.getRecord(records, "Graph break:")

        # TODO: This should also report the enclosing frames; need to plumb
        # frame object to it
        self.assertExpectedInline(
            munge_exc(record.getMessage()),
            """\
Graph break: from user code at:
  File "test_exc.py", line N, in fn001
    return fn002(x)
  File "test_exc.py", line N, in fn002
    torch._dynamo.graph_break()
""",  # noqa: B950
        )

    @torch._dynamo.config.patch(suppress_errors=False)
    def test_backend_suppress_line(self):
        """A backend compile failure raises BackendCompilerFailed.

        The message names the backend and the error, with no user-code line.
        """
        def fn001(x):
            x = torch.relu(x)
            return x + 1

        # Do NOT let this get attributed to x + 1
        self.assertExpectedInlineMunged(
            torch._dynamo.exc.BackendCompilerFailed,
            lambda: torch.compile(fn001, backend="relu_compile_error_TESTING_ONLY")(
                torch.randn(1)
            ),
            """\
backend='relu_compile_error_TESTING_ONLY' raised:
ReluCompileError:""",
        )

    @skipIf(not TEST_Z3, "z3 not installed")
    @torch._dynamo.config.patch(
        assume_static_by_default=False,
        suppress_errors=False,
    )
    @torch.fx.experimental._config.patch(
        inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True,
        translation_validation=True,
        translation_validation_no_bisect=True,
    )
    # Keep msg a single string built by implicit literal concatenation.
    # Writing `!=` between the two literals would compare them and pass the
    # bool True as msg instead of the recorded failure text.
    @skipIfWindows(
        msg=(
            'AssertionError: "tran[551 chars]s1 s2 s3) s0)\n  ==> (<= (+ s1 s2) (+ s0 (* -1[511 chars][0])'  # noqa: B950
            " != "
            'tran[551 chars]s1 s2) (+ s0 (* -1 s3)))\n  ==> (<= (+ s1 s2) [483 chars][0])"'
        )
    )
    def test_trigger_on_error(self):
        """Translation validation without bisection raises ValidationException.

        An injected flipped equality in expression evaluation triggers it. The
        message lists the z3 model, assertions, target expressions and the
        failed source expression.
        """
        from torch.fx.experimental.validator import ValidationException

        @torch.compile
        def fn(x, shape):
            return x.split(shape)

        self.assertExpectedInlineMunged(
            ValidationException,
            lambda: fn(torch.randn(20), (5, 10, 5)),
            """\
translation validation failed.

Model:
  ==> L['shape'][0]: 0
  ==> L['shape'][1]: 1
  ==> L['shape'][2]: 1
  ==> L['x'].size()[0]: 3
  ==> L['x'].storage_offset(): 0
  ==> L['x'].stride()[0]: 1
  ==> s0: 3
  ==> s1: 0
  ==> s2: 1
  ==> s3: 1

Assertions:
  ==> (== 0 L['x'].storage_offset())
  ==> (== 1 L['x'].stride()[0])
  ==> (== L['shape'][0] s1)
  ==> (== L['shape'][1] s2)
  ==> (== L['shape'][2] s3)
  ==> (== L['x'].size()[0] s0)
  ==> (> s0 1)
  ==> (True)

Target Expressions:
  ==> (!= (+ s1 s2 s3) s0)
  ==> (<= (+ s1 s2 s3) s0)
  ==> (<= (+ s1 s2) (+ s0 (* -1 s3)))
  ==> (<= (+ s1 s2) s0)
  ==> (<= 0 s1)
  ==> (<= 0 s2)
  ==> (<= 0 s3)
  ==> (<= 2 s0)
  ==> (<= s1 (+ s0 (* -1 s2)))
  ==> (== 0 L['x'].storage_offset())
  ==> (== 1 L['x'].stride()[0])
  ==> (== L['shape'][0] s1)
  ==> (== L['shape'][1] s2)
  ==> (== L['shape'][2] s3)
  ==> (== L['x'].size()[0] s0)
  ==> (> s0 0)
  ==> (>= 0 s1)
  ==> (And (<= (+ s1 s2) s0) (<= (* -1 s0) (+ s1 s2)))

Failed Source Expressions:
  ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
        )

    @skipIf(not TEST_Z3, "z3 not installed")
    @torch._dynamo.config.patch(
        assume_static_by_default=False,
        suppress_errors=False,
    )
    @torch.fx.experimental._config.patch(
        inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True,
        translation_validation=True,
    )
    def test_trigger_bisect_on_error(self):
        """Translation validation with bisection raises BisectValidationException.

        This is the same setup as test_trigger_on_error, but without
        translation_validation_no_bisect. The message additionally names the
        failing expression and the FX node that was running.
        """
        from torch.fx.experimental.validator import BisectValidationException

        @torch.compile
        def fn(x, shape):
            return x.split(shape)

        self.assertExpectedInlineMunged(
            BisectValidationException,
            lambda: fn(torch.randn(20), (5, 10, 5)),
            """\
translation validation failed when evaluating: Eq(s1 + s2 + s3, s0)

Failure occurred while running node:
    %split : [num_users=3] = call_method[target=split](args = (%l_x_, (%l_shape_0_, %l_shape_1_, %l_shape_2_)), kwargs = {})

Model:
  ==> L['shape'][0]: 1
  ==> L['shape'][1]: 1
  ==> L['shape'][2]: 0
  ==> L['x'].size()[0]: 3
  ==> L['x'].storage_offset(): 0
  ==> L['x'].stride()[0]: 1
  ==> s0: 3
  ==> s1: 1
  ==> s2: 1
  ==> s3: 0

Assertions:
  ==> (== 0 L['x'].storage_offset())
  ==> (== 1 L['x'].stride()[0])
  ==> (== L['shape'][0] s1)
  ==> (== L['shape'][1] s2)
  ==> (== L['shape'][2] s3)
  ==> (== L['x'].size()[0] s0)
  ==> (> s0 1)

Target Expressions:
  ==> (!= (+ s1 s2 s3) s0)
  ==> (<= 0 s1)
  ==> (<= 0 s2)
  ==> (<= 0 s3)
  ==> (<= 2 s0)
  ==> (== 0 L['x'].storage_offset())
  ==> (== 1 L['x'].stride()[0])
  ==> (== L['shape'][0] s1)
  ==> (== L['shape'][1] s2)
  ==> (== L['shape'][2] s3)
  ==> (== L['x'].size()[0] s0)
  ==> (> s0 0)

Failed Source Expressions:
  ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
        )
334
335
if __name__ == "__main__":
    # Run directly as a script: hand control to dynamo's test entry point.
    import torch._dynamo.test_case as dynamo_test_case

    dynamo_test_case.run_tests()
340