# Owner(s): ["module: dynamo"]

import logging
import unittest

import torch
import torch._dynamo
import torch._dynamo.config
import torch._dynamo.test_case
from torch._dynamo.comptime import comptime
from torch._dynamo.exc import Unsupported
from torch.testing._internal.common_device_type import skipIf
from torch.testing._internal.common_utils import (
    IS_FBCODE,
    munge_exc,
    skipIfWindows,
    TEST_Z3,
)
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test


# Expected-output tests for the exceptions and log records produced when
# torch.compile fails or graph-breaks.  Messages are compared after munging
# (line numbers appear as "N"), so the text of the source lines quoted in the
# expected tracebacks must stay exactly as written in the test bodies below.
class ExcTests(LoggingTestCase):
    maxDiff = None

    def test_unsupported_real_stack(self):
        # exercise Unsupported constructor and augment_exc_message
        # With fullgraph=True the graph break surfaces as Unsupported, and the
        # message is expected to list the user frames fn001 -> fn002.
        def fn002(x):
            torch._dynamo.graph_break()

        def fn001(x):
            x = x + 1
            fn002(x)

        self.assertExpectedInlineMunged(
            Unsupported,
            lambda: torch.compile(fn001, backend="eager", fullgraph=True)(
                torch.randn(1)
            ),
            """\
'skip function graph_break in file _dynamo/decorators.py'

from user code:
   File "test_exc.py", line N, in fn001
    fn002(x)
   File "test_exc.py", line N, in fn002
    torch._dynamo.graph_break()""",
        )

    @torch._dynamo.config.patch(verbose=True, suppress_errors=True)
    @make_logging_test()
    @unittest.skipIf(IS_FBCODE, "stack trace slightly different in fbcode")
    def test_internal_error_suppress_errors(self, records):
        # With suppress_errors=True the AssertionError raised at compile time
        # does not propagate; the test checks the "WON'T CONVERT" log record.
        def fn001(x):
            def f(ctx):
                raise AssertionError

            comptime(f)

        torch.compile(fn001, backend="eager")(torch.randn(1))

        record = self.getRecord(records, "WON'T CONVERT")

        self.assertExpectedInline(
            munge_exc(record.getMessage()),
            """\
WON'T CONVERT fn001 test_exc.py line N
========== TorchDynamo Stack Trace ==========
Traceback (most recent call last):
  File "test_exc.py", line N, in f
    raise AssertionError
AssertionError:

from user code:
   File "test_exc.py", line N, in fn001
    comptime(f)


========== The above exception occurred while processing the following code ==========

  File "test_exc.py", line N, in test_internal_error_suppress_errors
    torch.compile(fn001, backend="eager")(torch.randn(1))
  File "test_exc.py", line N, in fn001
    comptime(f)

==========""",
        )

    @make_logging_test()
    def test_not_implemented_error(self, records):
        # A NotImplementedError raised at compile time is expected to be
        # reported in the "WON'T CONVERT" record as InternalTorchDynamoError.
        def fn001(x):
            def f(ctx):
                raise NotImplementedError

            # Ensure graph break is not possible
            for _ in range(3):
                comptime(f)

        torch.compile(fn001, backend="eager")(torch.randn(1))

        record = self.getRecord(records, "WON'T CONVERT")

        self.assertExpectedInline(
            munge_exc(record.getMessage()),
            """\
WON'T CONVERT fn001 test_exc.py line N
due to:
Traceback (most recent call last):
  File "test_exc.py", line N, in f
    raise NotImplementedError
torch._dynamo.exc.InternalTorchDynamoError: NotImplementedError:

from user code:
   File "test_exc.py", line N, in fn001
    comptime(f)""",
        )

    @torch._dynamo.config.patch(inject_BUILD_SET_unimplemented_TESTING_ONLY=True)
    @make_logging_test(dynamo=logging.DEBUG)
    def test_unsupported_error(self, records):
        # Builds a set literal while the testing-only BUILD_SET injection
        # config is enabled, then looks up a "Graph break:" log record.
        def fn001(x):
            return {1, 2}

        torch.compile(fn001, backend="eager")(torch.randn(1))

        # TODO: There is no graph break log!  This is because the graph break
        # logging is not in a centralized location; unsupported
        # instruction bypasses it
        self.getRecord(records, "Graph break:")

    @torch._dynamo.config.patch(suppress_errors=False)
    def test_internal_error_no_suppress(self):
        # With suppress_errors=False the AssertionError propagates to the
        # caller, with the user-code frames appended to its message.
        def fn001(x):
            # NB: avoid decorator, as 3.11 changed the line number attributed
            # in this situation
            def f(ctx):
                raise AssertionError

            comptime(f)

        # NB: OK for user code to be truncated here, because the regular
        # exception backtrace has the rest of the crumbs
        self.assertExpectedInlineMunged(
            AssertionError,
            lambda: torch.compile(fn001, backend="eager")(torch.randn(1)),
            """\


from user code:
   File "test_exc.py", line N, in fn001
    comptime(f)""",
        )

    @make_logging_test(graph_breaks=True)
    def test_graph_break_log(self, records):
        # Checks the text of the "Graph break:" record emitted when
        # graph_breaks logging is enabled.
        def fn002(x):
            x = x + 1
            torch._dynamo.graph_break()
            x = x + 1
            return x

        def fn001(x):
            return fn002(x)

        torch.compile(fn001, backend="eager")(torch.randn(1))

        record = self.getRecord(records, "Graph break:")

        # TODO: This should also report the enclosing frames; need to plumb
        # frame object to it
        self.assertExpectedInline(
            munge_exc(record.getMessage()),
            """\
Graph break: from user code at:
  File "test_exc.py", line N, in fn001
    return fn002(x)
  File "test_exc.py", line N, in fn002
    torch._dynamo.graph_break()
""",  # noqa: B950
        )

    @torch._dynamo.config.patch(suppress_errors=False)
    def test_backend_suppress_line(self):
        # A failure inside the backend compiler must be reported as
        # BackendCompilerFailed without a user-code line attribution.
        def fn001(x):
            x = torch.relu(x)
            return x + 1

        # Do NOT let this get attributed to x + 1
        self.assertExpectedInlineMunged(
            torch._dynamo.exc.BackendCompilerFailed,
            lambda: torch.compile(fn001, backend="relu_compile_error_TESTING_ONLY")(
                torch.randn(1)
            ),
            """\
backend='relu_compile_error_TESTING_ONLY' raised:
ReluCompileError:""",
        )

    @skipIf(not TEST_Z3, "z3 not installed")
    @torch._dynamo.config.patch(
        assume_static_by_default=False,
        suppress_errors=False,
    )
    @torch.fx.experimental._config.patch(
        inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True,
        translation_validation=True,
        translation_validation_no_bisect=True,
    )
    @skipIfWindows(
        # The skip reason is the (truncated) assertion failure observed on
        # Windows.  It must be ONE string: writing it as `'...' != '...'`
        # compares two constants and passes the bool True as the message.
        # Newlines are kept escaped so the reason stays on a single line.
        msg=(
            'AssertionError: "tran[551 chars]s1 s2 s3) s0)\\n'
            "  ==> (<= (+ s1 s2) (+ s0 (* -1[511 chars][0])' != "
            "'tran[551 chars]s1 s2) (+ s0 (* -1 s3)))\\n"
            '  ==> (<= (+ s1 s2) [483 chars][0])"'
        )
    )
    def test_trigger_on_error(self):
        # Translation validation without bisection: the injected equality
        # flip is expected to raise ValidationException with the full report.
        from torch.fx.experimental.validator import ValidationException

        @torch.compile
        def fn(x, shape):
            return x.split(shape)

        self.assertExpectedInlineMunged(
            ValidationException,
            lambda: fn(torch.randn(20), (5, 10, 5)),
            """\
translation validation failed.

Model:
  ==> L['shape'][0]: 0
  ==> L['shape'][1]: 1
  ==> L['shape'][2]: 1
  ==> L['x'].size()[0]: 3
  ==> L['x'].storage_offset(): 0
  ==> L['x'].stride()[0]: 1
  ==> s0: 3
  ==> s1: 0
  ==> s2: 1
  ==> s3: 1

Assertions:
  ==> (== 0 L['x'].storage_offset())
  ==> (== 1 L['x'].stride()[0])
  ==> (== L['shape'][0] s1)
  ==> (== L['shape'][1] s2)
  ==> (== L['shape'][2] s3)
  ==> (== L['x'].size()[0] s0)
  ==> (> s0 1)
  ==> (True)

Target Expressions:
  ==> (!= (+ s1 s2 s3) s0)
  ==> (<= (+ s1 s2 s3) s0)
  ==> (<= (+ s1 s2) (+ s0 (* -1 s3)))
  ==> (<= (+ s1 s2) s0)
  ==> (<= 0 s1)
  ==> (<= 0 s2)
  ==> (<= 0 s3)
  ==> (<= 2 s0)
  ==> (<= s1 (+ s0 (* -1 s2)))
  ==> (== 0 L['x'].storage_offset())
  ==> (== 1 L['x'].stride()[0])
  ==> (== L['shape'][0] s1)
  ==> (== L['shape'][1] s2)
  ==> (== L['shape'][2] s3)
  ==> (== L['x'].size()[0] s0)
  ==> (> s0 0)
  ==> (>= 0 s1)
  ==> (And (<= (+ s1 s2) s0) (<= (* -1 s0) (+ s1 s2)))

Failed Source Expressions:
  ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
        )

    @skipIf(not TEST_Z3, "z3 not installed")
    @torch._dynamo.config.patch(
        assume_static_by_default=False,
        suppress_errors=False,
    )
    @torch.fx.experimental._config.patch(
        inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True,
        translation_validation=True,
    )
    def test_trigger_bisect_on_error(self):
        # Same scenario as the non-bisect test but without
        # translation_validation_no_bisect: BisectValidationException is
        # expected, naming the expression and node where validation failed.
        from torch.fx.experimental.validator import BisectValidationException

        @torch.compile
        def fn(x, shape):
            return x.split(shape)

        self.assertExpectedInlineMunged(
            BisectValidationException,
            lambda: fn(torch.randn(20), (5, 10, 5)),
            """\
translation validation failed when evaluating: Eq(s1 + s2 + s3, s0)

Failure occurred while running node:
    %split : [num_users=3] = call_method[target=split](args = (%l_x_, (%l_shape_0_, %l_shape_1_, %l_shape_2_)), kwargs = {})

Model:
  ==> L['shape'][0]: 1
  ==> L['shape'][1]: 1
  ==> L['shape'][2]: 0
  ==> L['x'].size()[0]: 3
  ==> L['x'].storage_offset(): 0
  ==> L['x'].stride()[0]: 1
  ==> s0: 3
  ==> s1: 1
  ==> s2: 1
  ==> s3: 0

Assertions:
  ==> (== 0 L['x'].storage_offset())
  ==> (== 1 L['x'].stride()[0])
  ==> (== L['shape'][0] s1)
  ==> (== L['shape'][1] s2)
  ==> (== L['shape'][2] s3)
  ==> (== L['x'].size()[0] s0)
  ==> (> s0 1)

Target Expressions:
  ==> (!= (+ s1 s2 s3) s0)
  ==> (<= 0 s1)
  ==> (<= 0 s2)
  ==> (<= 0 s3)
  ==> (<= 2 s0)
  ==> (== 0 L['x'].storage_offset())
  ==> (== 1 L['x'].stride()[0])
  ==> (== L['shape'][0] s1)
  ==> (== L['shape'][1] s2)
  ==> (== L['shape'][2] s3)
  ==> (== L['x'].size()[0] s0)
  ==> (> s0 0)

Failed Source Expressions:
  ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()