# Owner(s): ["module: dynamo"]
import functools
import logging
import os
import unittest.mock

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch.distributed as dist
from torch._dynamo.testing import empty_line_normalizer, skipIfNotPy311
from torch._dynamo.trace_rules import _as_posix_path
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_utils import (
    find_free_port,
    munge_exc,
    skipIfTorchDynamo,
)
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.logging_utils import (
    LoggingTestCase,
    make_logging_test,
    make_settings_test,
)


requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
# A decorator factory: use it as ``@requires_distributed()`` (note the call).
requires_distributed = functools.partial(
    unittest.skipIf, not dist.is_available(), "requires distributed"
)


def example_fn(a):
    output = a.mul(torch.ones(1000, 1000))
    output = output.add(torch.ones(1000, 1000))
    return output


def dynamo_error_fn(a):
    # Adding a (10, 10) tensor to the (1000, 1000) product cannot broadcast,
    # so tracing this function fails (see test_dynamo_error).
    output = a.mul(torch.ones(1000, 1000))
    output = output.add(torch.ones(10, 10))
    return output


def inductor_error_fn(a):
    output = torch.round(a)
    return output


def inductor_schedule_fn(a):
    output = a.add(torch.ones(1000, 1000, device="cuda"))
    return output


ARGS = (torch.ones(1000, 1000, requires_grad=True),)


def multi_record_test(num_records, **kwargs):
    """Build a logging test asserting an exact number of captured records.

    The generated test compiles ``example_fn`` with inductor, runs it on
    ``ARGS`` and asserts that exactly ``num_records`` records were captured.
    ``kwargs`` are forwarded to ``make_logging_test`` (e.g. ``bytecode=True``
    or ``dynamo=logging.DEBUG``).
    """

    @make_logging_test(**kwargs)
    def fn(self, records):
        fn_opt = torch._dynamo.optimize("inductor")(example_fn)
        fn_opt(*ARGS)
        self.assertEqual(len(records), num_records)

    return fn


def within_range_record_test(num_records_lower, num_records_higher, **kwargs):
    """Like ``multi_record_test``, but accept any record count in the
    inclusive range ``[num_records_lower, num_records_higher]``.
    """

    @make_logging_test(**kwargs)
    def fn(self, records):
        fn_opt = torch._dynamo.optimize("inductor")(example_fn)
        fn_opt(*ARGS)
        self.assertGreaterEqual(len(records), num_records_lower)
        self.assertLessEqual(len(records), num_records_higher)

    return fn


def single_record_test(**kwargs):
    """Shorthand for ``multi_record_test(1, **kwargs)``."""
    return multi_record_test(1, **kwargs)


def _get_torch_handler(logger):
    """Return the first handler on ``logger`` accepted by
    ``torch._logging._internal._is_torch_handler``, or ``None`` if none is.
    """
    return next(
        (
            handler
            for handler in logger.handlers
            if torch._logging._internal._is_torch_handler(handler)
        ),
        None,
    )


class LoggingTests(LoggingTestCase):
    test_bytecode = multi_record_test(2, bytecode=True)
    test_output_code = multi_record_test(2, output_code=True)
    test_aot_graphs = multi_record_test(3, aot_graphs=True)

    @requires_cuda
    @make_logging_test(schedule=True)
    def test_schedule(self, records):
        fn_opt = torch._dynamo.optimize("inductor")(inductor_schedule_fn)
        fn_opt(torch.ones(1000, 1000, device="cuda"))
        self.assertGreater(len(records), 0)
        self.assertLess(len(records), 5)

    @requires_cuda
    @make_logging_test(fusion=True)
    def test_fusion(self, records):
        fn_opt = torch._dynamo.optimize("inductor")(inductor_schedule_fn)
        fn_opt(torch.ones(1000, 1000, device="cuda"))
        self.assertGreater(len(records), 0)
        self.assertLess(len(records), 8)

    @requires_cuda
    @make_logging_test(cudagraphs=True)
    def test_cudagraphs(self, records):
        fn_opt = torch.compile(mode="reduce-overhead")(inductor_schedule_fn)
        fn_opt(torch.ones(1000, 1000, device="cuda"))
        self.assertGreater(len(records), 0)
        self.assertLess(len(records), 8)

    @make_logging_test(recompiles=True)
    def test_recompiles(self, records):
        def fn(x, y):
            return torch.add(x, y)

        fn_opt = torch._dynamo.optimize("inductor")(fn)
        fn_opt(torch.ones(1000, 1000), torch.ones(1000, 1000))
        # Switching the second argument from a tensor to an int is meant to
        # force a recompile, which should produce "recompiles" records.
        fn_opt(torch.ones(1000, 1000), 1)
        self.assertGreater(len(records), 0)

    test_dynamo_debug = within_range_record_test(30, 90, dynamo=logging.DEBUG)
    test_dynamo_info = within_range_record_test(2, 10, dynamo=logging.INFO)

    @skipIfTorchDynamo("too slow")
    @make_logging_test(dynamo=logging.DEBUG)
    def test_dynamo_debug_default_off_artifacts(self, records):
        fn_opt = torch._dynamo.optimize("inductor")(example_fn)
        fn_opt(torch.ones(1000, 1000))
        self.assertEqual(len([r for r in records if ".__bytecode" in r.name]), 0)
        self.assertEqual(len([r for r in records if ".__output_code" in r.name]), 0)

    @make_logging_test()
    def test_dynamo_error(self, records):
        # Whether the failing call propagates or not, only the logged
        # "WON'T CONVERT" record is inspected below.
        try:
            fn_opt = torch._dynamo.optimize("inductor")(dynamo_error_fn)
            fn_opt(*ARGS)
        except Exception:
            pass
        record = self.getRecord(records, "WON'T CONVERT")
        self.assertExpectedInline(
            munge_exc(record.getMessage()),
            """\
WON'T CONVERT dynamo_error_fn test_logging.py line N
due to:
Traceback (most recent call last):
torch._dynamo.exc.TorchRuntimeError: Failed running call_method add(*(FakeTensor(..., size=(1000, 1000), grad_fn=<MulBackward0>), FakeTensor(..., size=(10, 10))), **{}):
Attempting to broadcast a dimension of length 10 at -1! Mismatching argument at index 1 had torch.Size([10, 10]); but expected shape should be broadcastable to [1000, 1000]

from user code:
  File "test_logging.py", line N, in dynamo_error_fn
    output = output.add(torch.ones(10, 10))""",  # noqa: B950
        )

    test_aot = within_range_record_test(2, 6, aot=logging.INFO)
    test_inductor_debug = within_range_record_test(3, 17, inductor=logging.DEBUG)
    test_inductor_info = within_range_record_test(2, 4, inductor=logging.INFO)

    @make_logging_test()
    def test_inductor_error(self, records):
        import torch._inductor.lowering

        def throw(x):
            raise AssertionError

        # Inject an error into every lowering whose op name contains "round".
        dict_entries = {
            op: throw
            for op in list(torch._inductor.lowering.lowerings.keys())
            if "round" in op.__name__
        }

        # The context manager guarantees the patched lowerings are restored
        # even if compilation or the assertion below fails; otherwise the
        # injected error would leak into subsequent tests.
        with unittest.mock.patch.dict(
            torch._inductor.lowering.lowerings, dict_entries
        ):
            try:
                fn_opt = torch._dynamo.optimize("inductor")(inductor_error_fn)
                fn_opt(*ARGS)
            except Exception:
                pass
            record = self.getRecord(records, "WON'T CONVERT")
            self.assertExpectedInline(
                munge_exc(record.getMessage()),
                """\
WON'T CONVERT inductor_error_fn test_logging.py line N
due to:
Traceback (most recent call last):
  File "test_logging.py", line N, in throw
    raise AssertionError
torch._inductor.exc.LoweringException: AssertionError:
  target: aten.round.default
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='primals_1', layout=FixedLayout('cpu', torch.float32, size=[1000, 1000], stride=[1000, 1]))
  ))

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AssertionError:
  target: aten.round.default
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='primals_1', layout=FixedLayout('cpu', torch.float32, size=[1000, 1000], stride=[1000, 1]))
  ))""",
            )

    @requires_distributed()
    @requires_cuda
    @make_logging_test(ddp_graphs=True)
    def test_ddp_graphs(self, records):
        class ToyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layers = torch.nn.Sequential(
                    torch.nn.Linear(1024, 1024),
                    torch.nn.Linear(1024, 1024),
                )

            def forward(self, x):
                return self.layers(x)

        # Scope the rendezvous variables to this test instead of leaving them
        # in os.environ for the rest of the process.
        rendezvous_env = {
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": str(find_free_port()),
        }
        with unittest.mock.patch.dict(os.environ, rendezvous_env):
            dist.init_process_group("gloo", rank=0, world_size=1)
            try:
                ddp_model = torch._dynamo.optimize("inductor")(
                    DDP(ToyModel().to("cuda:0"), device_ids=[0], bucket_cap_mb=4)
                )

                ddp_model(torch.randn(1024, 1024, device="cuda:0"))
            finally:
                # Always tear down the group so a failure here cannot leave
                # a live process group behind for later tests.
                dist.destroy_process_group()

        self.assertEqual(len([r for r in records if "__ddp_graphs" in r.name]), 4)

    # check that logging to a child log of a registered logger
    # does not register it and result in duplicated records
    @make_settings_test("torch._dynamo.output_graph")
    def test_open_registration_with_registered_parent(self, records):
        logger = logging.getLogger("torch._dynamo.output_graph")
        logger.info("hi")
        self.assertEqual(len(records), 1)

    # check logging to a random log that is not a child log of a registered
    # logger registers it and sets handlers properly
    @make_settings_test("torch.utils")
    def test_open_registration(self, records):
        logger = logging.getLogger("torch.utils")
        logger.info("hi")
        self.assertEqual(len(records), 1)

    # same as test_open_registration, but the unregistered log is enabled via
    # the ``modules`` argument of the Python API instead of a settings string
    @make_logging_test(modules={"torch.utils": logging.INFO})
    def test_open_registration_python_api(self, records):
        logger = logging.getLogger("torch.utils")
        logger.info("hi")
        self.assertEqual(len(records), 1)

    @make_logging_test(all=logging.DEBUG, dynamo=logging.INFO)
    def test_all(self, _):
        registry = torch._logging._internal.log_registry

        dynamo_qnames = registry.log_alias_to_log_qnames["dynamo"]
        for logger_qname in registry.get_log_qnames():
            logger = logging.getLogger(logger_qname)

            # if logger_qname is a.b.c and dynamo_qnames contains a.b, it still matches dynamo's INFO setting
            if any(logger_qname.startswith(d) for d in dynamo_qnames):
                self.assertEqual(
                    logger.getEffectiveLevel(),
                    logging.INFO,
                    msg=f"expected {logger_qname} is INFO, got {logging.getLevelName(logger.getEffectiveLevel())}",
                )
            else:
                self.assertEqual(
                    logger.getEffectiveLevel(),
                    logging.DEBUG,
                    msg=f"expected {logger_qname} is DEBUG, got {logging.getLevelName(logger.getEffectiveLevel())}",
                )

    @make_logging_test(graph_breaks=True)
    def test_graph_breaks(self, records):
        @torch._dynamo.optimize("inductor")
        def fn(x):
            torch._dynamo.graph_break()
            return x + 1

        fn(torch.ones(1))

        self.assertEqual(len(records), 1)

    @make_settings_test("torch._dynamo.utils")
    def test_dump_compile_times(self, records):
        fn_opt = torch._dynamo.optimize("inductor")(example_fn)
        fn_opt(torch.ones(1000, 1000))
        # This function runs during exit via atexit.register.
        # We're not actually going to run atexit._run_exit_funcs() here,
        # because it'll destroy state necessary for other tests.
        torch._dynamo.utils.dump_compile_times()
        self.assertEqual(
            len(
                [r for r in records if "TorchDynamo compilation metrics" in str(r.msg)]
            ),
            1,
        )

    @make_logging_test(dynamo=logging.INFO)
    def test_custom_format_exc(self, records):
        dynamo_log = logging.getLogger(torch._dynamo.__name__)
        try:
            raise RuntimeError("foo")
        except RuntimeError:
            dynamo_log.exception("test dynamo")
            dynamo_log.info("with exc", exc_info=True)
        dynamo_log.info("with stack", stack_info=True)
        self.assertEqual(len(records), 3)
        # unfortunately there's no easy way to test the final formatted log other than
        # to ask the dynamo logger's handler to format it.
        handler = _get_torch_handler(dynamo_log)
        self.assertIsNotNone(handler)
        self.assertIn("Traceback", handler.format(records[0]))
        self.assertIn("Traceback", handler.format(records[1]))
        self.assertIn("Stack", handler.format(records[2]))

    @make_logging_test(dynamo=logging.INFO)
    def test_custom_format(self, records):
        dynamo_log = logging.getLogger(torch._dynamo.__name__)
        test_log = torch._logging.getArtifactLogger(
            torch._dynamo.__name__, "custom_format_test_artifact"
        )
        dynamo_log.info("test dynamo")
        test_log.info("custom format")
        self.assertEqual(len(records), 2)
        # unfortunately there's no easy way to test the final formatted log other than
        # to ask the dynamo logger's handler to format it.
        handler = _get_torch_handler(dynamo_log)
        self.assertIsNotNone(handler)
        self.assertIn("I", handler.format(records[0]))
        self.assertEqual("custom format", handler.format(records[1]))

    @make_logging_test(dynamo=logging.INFO)
    def test_multiline_format(self, records):
        dynamo_log = logging.getLogger(torch._dynamo.__name__)
        dynamo_log.info("test\ndynamo")
        dynamo_log.info("%s", "test\ndynamo")
        dynamo_log.info("test\n%s", "test\ndynamo")
        self.assertEqual(len(records), 3)
        # unfortunately there's no easy way to test the final formatted log other than
        # to ask the dynamo logger's handler to format it.
        handler = _get_torch_handler(dynamo_log)
        self.assertIsNotNone(handler)
        for record in records:
            formatted = handler.format(record)
            # every physical line of a multi-line message must carry the prefix
            for line in formatted.splitlines():
                self.assertIn("I", line)

    test_trace_source_simple = within_range_record_test(1, 100, trace_source=True)

    @make_logging_test(trace_source=True)
    def test_trace_source_if_stmt(self, records):
        def fn(x):
            if x.sum() > 0:
                return x * 2
            return x * 3

        fn_opt = torch._dynamo.optimize("eager")(fn)
        fn_opt(torch.ones(3, 3))

        messages = [record.getMessage() for record in records]

        # only the taken branch should have been traced
        self.assertTrue(any("return x * 2" in msg for msg in messages))
        self.assertFalse(any("return x * 3" in msg for msg in messages))

    @make_logging_test(trace_source=True)
    def test_trace_source_nested(self, records):
        def fn1(x):
            x = fn2(x)
            return x * 2

        def fn2(x):
            x = fn3(x)
            return x * 3

        def fn3(x):
            return x * 4

        fn_opt = torch._dynamo.optimize("eager")(fn1)
        fn_opt(torch.ones(3, 3))

        found_x2 = False
        found_x3 = False
        found_x4 = False
        for record in records:
            msg = record.getMessage()
            if "return x * 2" in msg:
                found_x2 = True
                self.assertNotIn("inline depth", msg)
            elif "return x * 3" in msg:
                found_x3 = True
                self.assertIn("inline depth: 1", msg)
            elif "return x * 4" in msg:
                found_x4 = True
                self.assertIn("inline depth: 2", msg)
        self.assertTrue(found_x2)
        self.assertTrue(found_x3)
        self.assertTrue(found_x4)

    @make_logging_test(trace_source=True)
    def test_trace_source_cond(self, records):
        from functorch.experimental.control_flow import cond

        def true_fn(x):
            return x * 2

        def false_fn(x):
            return x * 3

        def inner(pred, x):
            return cond(pred, true_fn, false_fn, [x])

        def outer(pred, x):
            return inner(pred, x)

        fn_opt = torch._dynamo.optimize("eager")(outer)
        fn_opt(torch.tensor(True), torch.ones(3, 3))

        found_x2 = False
        found_x3 = False
        for record in records:
            msg = record.getMessage()
            if "return x * 2" in msg:
                found_x2 = True
                self.assertIn("inline depth: 3", msg)
            if "return x * 3" in msg:
                found_x3 = True
                self.assertIn("inline depth: 3", msg)

        self.assertTrue(found_x2)
        self.assertTrue(found_x3)

    @make_logging_test(trace_source=True)
    def test_trace_source_funcname(self, records):
        # NOTE: list comprehensions are inlined in 3.12, so test with tuples
        def fn1():
            def fn2():
                if True:
                    return tuple(torch.ones(3, 3) for _ in range(5))
                return None

            return fn2()

        fn_opt = torch._dynamo.optimize("eager")(fn1)
        fn_opt()

        found_funcname = any(
            "<genexpr>" in msg and "fn1.fn2" in msg
            for msg in (record.getMessage() for record in records)
        )

        self.assertTrue(found_funcname)

    def test_invalid_artifact_flag(self):
        with self.assertRaises(ValueError):
            torch._logging.set_logs(aot_graphs=5)

    @requires_distributed()
    def test_distributed_rank_logging(self):
        env = dict(os.environ)
        env["TORCH_LOGS"] = "dynamo"
        _, stderr = self.run_process_no_exception(
            """\
import torch.distributed as dist
import logging
from torch.testing._internal.distributed.fake_pg import FakeStore
store = FakeStore()
dist.init_process_group("fake", rank=0, world_size=2, store=store)
dynamo_log = logging.getLogger("torch._dynamo")
dynamo_log.info("woof")
print("arf")
""",
            env=env,
        )
        self.assertIn("[rank0]:", stderr.decode("utf-8"))

    @skipIfNotPy311
    @make_logging_test(trace_call=True)
    def test_trace_call(self, records):
        def fn(x, y):
            return (x * 2) @ (y * 3)

        fn_opt = torch._dynamo.optimize("eager")(fn)
        fn_opt(torch.randn(10, 20), torch.randn(20, 30))

        self.assertEqual(len(records), 3)
        # only get last 2 lines
        messages = [
            "\n".join(record.getMessage().split("\n")[-2:]) for record in records
        ]
        self.assertExpectedInline(
            messages[0],
            """\
            return (x * 2) @ (y * 3)
                    ~~^~~""",
        )
        self.assertExpectedInline(
            messages[1],
            """\
            return (x * 2) @ (y * 3)
                              ~~^~~""",
        )
        self.assertExpectedInline(
            messages[2],
            """\
            return (x * 2) @ (y * 3)
                   ~~~~~~~~^~~~~~~~~""",
        )

    @skipIfNotPy311
    @make_logging_test(trace_call=True)
    def test_trace_call_inline_call(self, records):
        def g(x):
            return x * 2

        def f(x):
            return g(g(x))

        fn_opt = torch._dynamo.optimize("eager")(f)
        fn_opt(torch.randn(3, 3))

        self.assertEqual(len(records), 4)
        messages = [
            "\n".join(record.getMessage().split("\n")[-2:]) for record in records
        ]
        self.assertExpectedInline(
            messages[0],
            """\
            return g(g(x))
                     ~^^^""",
        )
        self.assertExpectedInline(
            messages[1],
            """\
            return x * 2
                   ~~^~~""",
        )
        self.assertExpectedInline(
            messages[2],
            """\
            return g(g(x))
                   ~^^^^^^""",
        )
        self.assertExpectedInline(
            messages[3],
            """\
            return x * 2
                   ~~^~~""",
        )

    @skipIfNotPy311
    @make_logging_test(trace_call=True)
    def test_trace_call_graph_break(self, records):
        def fn(x):
            x = x * 2
            torch._dynamo.graph_break()
            return x * 3

        fn_opt = torch._dynamo.optimize("eager")(fn)
        fn_opt(torch.randn(3, 3))

        self.assertEqual(len(records), 3)
        messages = [
            "\n".join(record.getMessage().split("\n")[-2:]) for record in records
        ]
        self.assertExpectedInline(
            messages[0],
            """\
            x = x * 2
                ~~^~~""",
        )
        self.assertExpectedInline(
            messages[-1],
            """\
            return x * 3
                   ~~^~~""",
        )

    @make_logging_test(guards=True, recompiles=True)
    def test_guards_recompiles(self, records):
        def fn(x, ys, zs):
            return inner(x, ys, zs)

        def inner(x, ys, zs):
            for y, z in zip(ys, zs):
                x += y * z
            return x

        ys = [1.0, 2.0]
        zs = [3.0]
        x = torch.tensor([1.0])

        fn_opt = torch._dynamo.optimize("eager")(fn)
        fn_opt(x, ys, zs)
        fn_opt(x, ys[:1], zs)

        record_str = "\n".join(r.getMessage() for r in records)

        self.assertIn(
            """L['zs'][0] == 3.0""",
            record_str,
        )
        self.assertIn(
            "len(L['ys']) == 2",
            record_str,
        )

    @make_logging_test(cudagraph_static_inputs=True)
    def test_cudagraph_static_inputs(self, records):
        @torch.compile(mode="reduce-overhead")
        def fn(x):
            return x + 1

        x = torch.ones(2, 2)
        torch._dynamo.mark_static_address(x)
        fn(x)
        self.assertGreater(len(records), 0)
        self.assertLess(len(records), 4)

    @skipIfTorchDynamo("too slow")
    @make_logging_test(**torch._logging.DEFAULT_LOGGING)
    def test_default_logging(self, records):
        def fn(a):
            if a.sum() < 0:
                a = torch.sin(a)
            else:
                a = torch.cos(a)
            print("hello")
            return a + 1

        fn_opt = torch._dynamo.optimize("eager")(fn)
        fn_opt(torch.ones(10, 10))
        fn_opt(-torch.ones(10, 5))

        self.assertGreater(len([r for r in records if ".__graph_breaks" in r.name]), 0)
        self.assertGreater(len([r for r in records if ".__recompiles" in r.name]), 0)
        self.assertGreater(len([r for r in records if ".symbolic_shapes" in r.name]), 0)
        self.assertGreater(len([r for r in records if ".__guards" in r.name]), 0)
        self.assertGreater(
            len([r for r in records if "return a + 1" in r.getMessage()]), 0
        )

    def test_logs_out(self):
        import tempfile

        # NamedTemporaryFile opens the file it creates. On Windows the child
        # process launched below cannot access a file that is still open here
        # (PermissionError: [Errno 13] Permission denied), so only take its
        # name and let the `with` close it. delete=False keeps the file on
        # disk; the `finally` below removes it manually.
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            file_path = _as_posix_path(tmp.name)

        try:
            env = dict(os.environ)
            env["TORCH_LOGS"] = "dynamo"
            env["TORCH_LOGS_OUT"] = file_path
            _, stderr = self.run_process_no_exception(
                """\
import torch
@torch.compile(backend="eager")
def fn(a):
    return a.sum()

fn(torch.randn(5))
""",
                env=env,
            )
            # Read as UTF-8 explicitly so Windows does not use a locale codec.
            with open(file_path, encoding="utf-8") as fd:
                lines = fd.read()
        finally:
            # Remove the temp file even if the subprocess or the read failed.
            os.remove(file_path)

        # Normalize line endings: \r\n on Windows vs. \n on posix.
        self.assertEqual(
            empty_line_normalizer(lines),
            empty_line_normalizer(stderr.decode("utf-8")),
        )


# single record tests
exclusions = {
    "bytecode",
    "cudagraphs",
    "output_code",
    "schedule",
    "fusion",
    "overlap",
    "aot_graphs",
    "aot_graphs_effects",
    "post_grad_graphs",
    "compiled_autograd",
    "compiled_autograd_verbose",
    "recompiles",
    "recompiles_verbose",
    "graph_breaks",
    "graph",
    "graph_code",
    "graph_sizes",
    "ddp_graphs",
    "perf_hints",
    "not_implemented",
    "trace_source",
    "trace_call",
    "trace_bytecode",
    "custom_format_test_artifact",
    "onnx",
    "onnx_diagnostics",
    "guards",
    "verbose_guards",
    "sym_node",
    "export",
    "trace_shape_events",
    "cudagraph_static_inputs",
    "benchmarking",
    "loop_ordering",
}
for name in torch._logging._internal.log_registry.artifact_names:
    if name not in exclusions:
        setattr(LoggingTests, f"test_{name}", single_record_test(**{name: True}))

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()