# mypy: allow-untyped-decorators
from __future__ import annotations

import collections
import contextlib
import cProfile
import dis
import functools
import itertools
import logging
import os
import pstats
import random
import subprocess
import sys
import threading
import time
import traceback
import typing
import weakref
from pathlib import Path
from types import CodeType, FrameType, FunctionType, ModuleType
from typing import Any, Callable, Dict, List, Optional, Set, TypeVar, Union
from typing_extensions import ParamSpec
from weakref import ReferenceType

import torch
import torch._logging
from torch._C._dynamo.guards import GlobalStateGuard
from torch._dynamo.distributed import get_compile_pg
from torch._dynamo.utils import CompileTimeInstructionCounter
from torch._guards import compile_context, CompileContext, CompileId, tracing
from torch._logging import structured
from torch._utils_internal import (
    compile_time_strobelight_meta,
    justknobs_check,
    maybe_upload_prof_stats_to_manifold,
    signpost_event,
)
from torch.fx._lazy_graph_module import _use_lazy_graph_module
from torch.fx.experimental.symbolic_shapes import (
    ConstraintViolationError,
    GuardOnDataDependentSymNode,
)
from torch.fx.graph_module import _forward_from_src as original_forward_from_src
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils._python_dispatch import (
    _disable_current_modes,
    is_in_torch_dispatch_mode,
)
from torch.utils._traceback import CapturedTraceback, format_traceback_short

from . import config, exc, trace_rules
from .bytecode_analysis import remove_dead_code, remove_pointless_jumps
from .bytecode_transformation import (
    check_inst_exn_tab_entries_valid,
    Instruction,
    is_generator,
    propagate_inst_exn_table_entries,
    transform_code_object,
)
from .cache_size import (
    CacheSizeRelevantForFrame,
    compute_cache_size,
    exceeds_cache_size_limit,
    is_recompilation,
)
from .eval_frame import always_optimize_code_objects, skip_code, TorchPatcher
from .exc import (
    augment_exc_message,
    BackendCompilerFailed,
    CacheLimitExceeded,
    format_error_msg,
    InternalTorchDynamoError,
    SkipCodeRecursiveException,
    TorchRuntimeError,
    UncapturedHigherOrderOpError,
    unimplemented,
    Unsupported,
)
from .guards import (
    CheckFunctionManager,
    get_and_maybe_log_recompilation_reason,
    GuardedCode,
)
from .hooks import Hooks
from .replay_record import ExecutionRecord
from .symbolic_convert import (
    DistributedState,
    InstructionTranslator,
    LocalState,
    SpeculationLog,
)
from .trace_rules import is_numpy
from .utils import (
    CleanupManager,
    CompilationMetrics,
    counters,
    dynamo_timed,
    format_bytecode,
    frame_phase_timing,
    gen_record_file_name,
    get_chromium_event_logger,
    increment_frame,
    is_namedtuple,
    istype,
    LazyString,
    orig_code_map,
    record_compilation_metrics,
    reset_graph_break_dup_checker,
    setup_compile_debug,
    troubleshooting_url,
    write_record_to_file,
)


np: Optional[ModuleType]
try:
    import numpy as np
except ModuleNotFoundError:
    np = None


if typing.TYPE_CHECKING:
    from .backends.registry import CompilerFn
    from .repro.after_dynamo import WrapBackendDebug
    from .types import BytecodeHook, CacheEntry
    from .variables.builder import FrameStateSizeEntry

log = logging.getLogger(__name__)
bytecode_log = torch._logging.getArtifactLogger(__name__, "bytecode")
graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")


compile_lock = threading.RLock()

_T = TypeVar("_T")
_P = ParamSpec("_P")


class TODO_UNKNOWN:
    pass


class Tracker:
    def __init__(self) -> None:
        self.seen: List[ReferenceType[CodeType]] = []
        self.seen_ids: Set[int] = set()

    def add(self, strong_obj: CodeType) -> None:
        idx = id(strong_obj)
        if idx not in self.seen_ids:
            obj = weakref.ref(strong_obj, lambda _: self.seen_ids.remove(idx))
            self.seen.append(obj)
            self.seen_ids.add(idx)

    def __contains__(self, item: CodeType) -> bool:
        return id(item) in self.seen_ids

    def clear(self) -> None:
        self.seen.clear()
        self.seen_ids.clear()


input_codes = Tracker()
output_codes = Tracker()

initial_global_state: Optional[GlobalStateGuard] = None
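
# NOTE (illustrative): `Tracker` holds weak references to code objects, so an
# entry disappears automatically once the code object is garbage collected and
# membership is tested by id(). A hypothetical usage sketch:
#
#     input_codes.add(some_fn.__code__)
#     assert some_fn.__code__ in input_codes
#
# `input_codes` records frames Dynamo has seen, while `output_codes` records
# bytecode Dynamo itself generated; generated code is never re-traced.
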
@functools.wraps(original_forward_from_src)
def fx_forward_from_src_skip_result(
    src: str, globals: Dict[str, Any], co_fields: Optional[Dict[str, str]] = None
) -> FunctionType:
    # We monkey patch FX to prevent an infinite loop of trying to convert
    # our generated code.
    result = original_forward_from_src(src, globals, co_fields)
    skip_code(result.__code__)
    return result


def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
    """
    Decorator to:
        1) Save/restore torch.is_grad_enabled() state
        2) Save/restore python random state
        3) Save/restore torch random state
        4) Monkey patch torch.fx.graph_module._forward_from_src
    """

    @functools.wraps(fn)
    def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        guards = GlobalStateGuard()
        prior_grad_mode = torch.is_grad_enabled()
        # Just in case we get left in a bad dispatch state we want to restore
        # it. This can happen because the dispatch bits aren't a true
        # stack/counter - so we can't just increment/decrement them as we enter
        # and leave.
        with torch._C._PreserveDispatchKeyGuard():
            prior_inference_mode = torch.is_inference_mode_enabled()
            prior_deterministic = torch.are_deterministic_algorithms_enabled()
            prior_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
            py_rng_state = random.getstate()
            torch_rng_state = torch.random.get_rng_state()
            cuda_rng_state = None
            if torch.cuda.is_available():
                cuda_rng_state = torch.cuda.get_rng_state()
            allow_tf32 = torch._C._get_cublas_allow_tf32()
            prior_fwd_from_src = torch.fx.graph_module._forward_from_src
            torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
            cleanup = setup_compile_debug()

            exit_stack = contextlib.ExitStack()
            exit_stack.enter_context(
                torch.fx._symbolic_trace._maybe_revert_all_patches()
            )
            try:
                return fn(*args, **kwargs)
            finally:
                cleanup.close()
                exit_stack.close()
                torch._C._set_grad_enabled(prior_grad_mode)
                torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)
                torch.use_deterministic_algorithms(
                    prior_deterministic, warn_only=prior_warn_only
                )
                random.setstate(py_rng_state)
                torch.random.set_rng_state(torch_rng_state)
                if cuda_rng_state is not None:
                    torch.cuda.set_rng_state(cuda_rng_state)
                torch._C._set_cublas_allow_tf32(allow_tf32)
                torch.fx.graph_module._forward_from_src = prior_fwd_from_src
                assert (
                    guards.check()
                ), f"Global {guards.reason()}state changed while dynamo tracing, please report a bug"

    _fn._torchdynamo_orig_callable = fn  # type: ignore[attr-defined]
    return _fn
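
# NOTE (illustrative): `preserve_global_state` is applied as a decorator (see
# `transform` inside `_compile` below) so that a single bytecode-tracing
# attempt cannot leak changes to grad mode, RNG state, TF32 settings, or the
# patched `_forward_from_src`. A hypothetical sketch:
#
#     @preserve_global_state
#     def trace_once(...):
#         ...  # mutate global torch state freely; it is restored on exit
#
# The trailing `GlobalStateGuard.check()` assertion is a safety net verifying
# that the restoration actually worked.
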
@TorchPatcher.suppress_torch_distributed_warnings
def has_tensor_in_frame(frame: FrameType) -> bool:
    """Check if the frame has torch.* related bits"""
    # Check if the function was decorated using torch._dynamo.optimize
    if frame.f_code in always_optimize_code_objects:
        return True

    # Check if there is global import of torch.*
    for co_name in frame.f_code.co_names:
        if co_name in frame.f_globals:
            obj = frame.f_globals[co_name]
            if isinstance(obj, ModuleType) and (
                obj.__name__.startswith("torch.") or obj is torch
            ):
                return True
            # ... or a global import of numpy.*
            if np and config.trace_numpy and (obj is np or is_numpy(obj)):
                return True

    seen_ids: Dict[int, bool] = {}

    def has_tensor(obj: object) -> bool:
        """Recursively check if the obj has a tensor"""
        obj_id = id(obj)
        if obj_id in seen_ids:
            return seen_ids[obj_id]
        seen_ids[obj_id] = False

        if isinstance(obj, (torch.Tensor, torch.nn.Module)) or (
            istype(obj, type) and issubclass(obj, torch.nn.Module)
        ):
            seen_ids[obj_id] = True
            return seen_ids[obj_id]
        elif (
            config.trace_numpy
            and np
            and (istype(obj, np.ndarray) or isinstance(obj, np.generic))
        ):
            seen_ids[obj_id] = True
            return seen_ids[obj_id]
        elif istype(obj, (list, tuple)):
            seen_ids[obj_id] = any(has_tensor(v) for v in obj)
            return seen_ids[obj_id]
        elif istype(obj, dict):
            # Some packages like pytest can be updated during runtime. So, make a
            # copy of values to avoid issues like "RuntimeError: dictionary
            # changed size during iteration"
            values = list(obj.values())
            seen_ids[obj_id] = any(has_tensor(v) for v in values)
            return seen_ids[obj_id]
        elif istype(obj, (str, int, float, type(None), bool)):
            seen_ids[obj_id] = False
            return seen_ids[obj_id]
        elif is_namedtuple(obj) and hasattr(obj, "_fields"):
            seen_ids[obj_id] = any(has_tensor(getattr(obj, v)) for v in obj._fields)
            return seen_ids[obj_id]
        else:
            # if config.debug:
            #     print(
            #         f"Assuming that object of type {type(obj)} does not have a tensor"
            #     )
            return False

    # Check if the passed arguments are of type Tensor
    for value in frame.f_locals.values():
        if has_tensor(value):
            return True

    log.debug(
        "skipping because no torch.* %s \
            %s %s",
        frame.f_code.co_name,
        frame.f_code.co_filename,
        frame.f_code.co_firstlineno,
    )

    return False


def exception_handler(
    e: Exception,
    code: CodeType,
    frame: Optional[FrameType] = None,
    export: bool = False,
) -> None:
    record_filename = None
    if hasattr(e, "exec_record"):
        record_filename = gen_record_file_name(e, code)
        write_record_to_file(record_filename, e.exec_record)
        e.record_filename = record_filename  # type: ignore[attr-defined]

    augment_exc_message(e, export=export)


FRAME_COUNTER = 0
FRAME_COMPILE_COUNTER: typing.Counter[
    Union[int, FrameStateSizeEntry]
] = collections.Counter()
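
# NOTE (illustrative): `FRAME_COUNTER` assigns a stable id to each distinct
# frame Dynamo compiles (stored in `frame_state["_id"]`), and
# `FRAME_COMPILE_COUNTER` counts how many times that frame has been
# (re)compiled. Together they form the `CompileId` used in Dynamo logging,
# e.g. an id like 0/1 would denote the second compilation of frame 0.
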
def maybe_cprofile(func: Callable[_P, _T]) -> Callable[_P, _T]:
    if config.cprofile:
        return cprofile_wrapper(func)
    return func


def cprofile_wrapper(func: Callable[_P, _T]) -> Callable[_P, _T]:
    @functools.wraps(func)
    def profile_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        trace_id = CompileContext.current_trace_id()
        assert trace_id, "Trace id is None"
        profile_path = Path(
            f"/tmp/{func.__name__}_{str(trace_id).replace('/', '_')}.profile"
        )
        prof = cProfile.Profile()
        prof.enable()
        start_ts = time.time()
        retval = prof.runcall(func, *args, **kwargs)
        profile_latency = time.time() - start_ts
        prof.disable()
        log.warning(
            "### Cprofile for %s trace id [%s] took %.3f seconds ###",
            func.__name__,
            trace_id,
            profile_latency,
        )
        ps = pstats.Stats(prof)
        try:
            prof.dump_stats(profile_path)
        except PermissionError:
            log.exception("Cannot write to %s", profile_path)
        log.warning("Raw profile at %s", profile_path)
        svg_path = profile_path.with_suffix(".svg")
        try:
            gprof2dot_process = subprocess.Popen(
                [
                    "gprof2dot",
                    "-f",
                    "pstats",
                    "--node-label=total-time-percentage",
                    "--node-label=self-time-percentage",
                    "--node-label=total-time",
                    str(profile_path),
                ],
                stdout=subprocess.PIPE,
            )
            subprocess.check_call(
                ["dot", "-Tsvg", "-o", str(svg_path)],
                stdin=gprof2dot_process.stdout,
            )
            log.warning("Generated SVG from profile at %s", svg_path)
        except FileNotFoundError:
            log.warning(
                "Failed to generate SVG from profile -- dumping stats instead. "
                "Try installing gprof2dot and dot for a better visualization"
            )
            ps.sort_stats(pstats.SortKey.TIME).print_stats(20)
            ps.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(20)

        if manifold_link := maybe_upload_prof_stats_to_manifold(
            str(profile_path)
        ):  # fb-only
            torch._logging.trace_structured(
                "link",
                lambda: {"name": "cprofile_manifold_url", "url": manifold_link},
            )
        return retval

    return profile_wrapper
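
# NOTE (illustrative): per-compile profiling is opt-in. With
# `torch._dynamo.config.cprofile = True`, `_compile_inner` runs under cProfile
# and dumps `/tmp/<fn>_<trace id>.profile`; if `gprof2dot` and Graphviz `dot`
# are on PATH, an SVG call graph is rendered next to it, otherwise the top-20
# stats are printed as a fallback.
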
class ConvertFrameAssert:
    def __init__(
        self,
        compiler_fn: CompilerFn,
        one_graph: bool = True,
        export: bool = False,
        export_constraints: Optional[typing.Never] = None,
    ) -> None:
        # assert export_constraints is None
        reset_graph_break_dup_checker()
        self._torchdynamo_orig_callable = compiler_fn
        self._one_graph = one_graph
        self._export = export
        self._export_constraints = export_constraints

    @property
    def _clone_with_backend(self) -> Callable[[CompilerFn], ConvertFrameAssert]:
        return lambda backend: convert_frame_assert(
            backend, self._one_graph, self._export, self._export_constraints
        )

    def __call__(
        self,
        frame: FrameType,
        cache_entry: Optional[CacheEntry],
        hooks: Hooks,
        frame_state: Dict[str, Union[int, FrameStateSizeEntry]],
        *,
        skip: int = 0,
    ) -> Optional[GuardedCode]:
        increment_frame()

        code = frame.f_code

        cache_size = compute_cache_size(frame, cache_entry)
        input_codes.add(code)
        if code in output_codes:
            return None
        if (
            os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION")
            and os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION") != code.co_name
        ):
            return None
        if code.co_name == "<genexpr>" and code.co_filename.endswith(
            (
                "transformers/file_utils.py",
                "transformers/utils/generic.py",
                "diffusers/utils/outputs.py",
            )
        ):
            # not needed, but cleans up torchbench error stats
            return None
        if code.co_name == "__setattr__":
            # setattr could be tricky to handle generally,
            # but also not likely useful to compile - skip the whole frame
            return None
        if code.co_name == "__init__" and code.co_filename.startswith(
            os.path.dirname(torch.optim.__file__)
        ):
            # optimizer support is still incomplete; see
            # test_state_dict in test/dynamo/test_optimizers.py
            return None

        # Check if the frame is generated by an exec builtin call
        # TODO - Running an exec-generated frame seems to propagate f_globals
        # to the next frames.
        if code.co_name == "<module>" and code.co_filename == "<string>":
            return None

        if (
            code.co_name == "<lambda>"
            and code.co_filename == "<string>"
            and not bool(frame.f_builtins)
        ):
            # namedtuple subclass constructor. Empty builtins cause issue with
            # len keyword in LIST_LEN guard.
            return None

        if is_generator(code):
            unimplemented("generator")

        if not has_tensor_in_frame(frame):
            return None

        global initial_global_state
        initial_global_state = GlobalStateGuard()

        global FRAME_COUNTER
        if "_id" not in frame_state:
            frame_state["_id"] = FRAME_COUNTER
            FRAME_COUNTER += 1
        frame_id = frame_state["_id"]
        assert isinstance(frame_id, int)

        frame_compile_id = FRAME_COMPILE_COUNTER[frame_id]
        FRAME_COMPILE_COUNTER[frame_id] += 1

        compile_id = CompileId(frame_id, frame_compile_id)

        signpost_event(
            "dynamo",
            "_convert_frame_assert._compile",
            {
                "co_name": code.co_name,
                "frame_id": frame_id,
                "compile_id": str(compile_id),
                "co_filename": code.co_filename,
                "co_firstlineno": code.co_firstlineno,
                "cache_size": cache_size.num_cache_entries_with_same_id_matched_objs,
                "accumulated_cache_size": cache_size.num_cache_entries,
            },
        )

        return _compile(
            frame.f_code,
            frame.f_globals,
            frame.f_locals,
            frame.f_builtins,
            self._torchdynamo_orig_callable,
            self._one_graph,
            self._export,
            self._export_constraints,
            hooks,
            cache_entry,
            cache_size,
            frame,
            frame_state=frame_state,
            compile_id=compile_id,
            skip=skip + 1,
        )


def convert_frame_assert(
    compiler_fn: CompilerFn,
    one_graph: bool = True,
    export: bool = False,
    export_constraints: Optional[typing.Never] = None,
) -> ConvertFrameAssert:
    """Fully convert a frame into an FX graph"""
    return ConvertFrameAssert(compiler_fn, one_graph, export, export_constraints)


from collections import OrderedDict

from torch.utils.hooks import RemovableHandle


if typing.TYPE_CHECKING:
    from .output_graph import OutputGraph

# we have to use `OrderedDict` to make `RemovableHandle` work.
_bytecode_hooks: Dict[int, BytecodeHook] = OrderedDict()


def register_bytecode_hook(hook: BytecodeHook) -> RemovableHandle:
    """Register hooks for bytecode generated by Dynamo. The hook can do some
    logging, as well as return a new code object to be used. Please refer
    to `BytecodeHook` for the hook signature.
    """
    handle = RemovableHandle(_bytecode_hooks)
    _bytecode_hooks[handle.id] = hook
    return handle
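
# NOTE (illustrative): a bytecode hook receives the original and the
# Dynamo-generated code object; returning None keeps Dynamo's code, while
# returning a code object replaces it. A hypothetical logging-only hook:
#
#     def dump_bytecode(old_code, new_code):
#         print(f"compiled {old_code.co_name}")
#         dis.dis(new_code)
#         return None  # keep the generated code unchanged
#
#     handle = register_bytecode_hook(dump_bytecode)
#     ...
#     handle.remove()
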
def _compile(
    code: CodeType,
    globals: Dict[str, object],
    locals: Dict[str, object],
    builtins: Dict[str, object],
    compiler_fn: CompilerFn,
    one_graph: bool,
    export: bool,
    export_constraints: Optional[typing.Never],
    hooks: Hooks,
    cache_entry: Optional[CacheEntry],
    cache_size: CacheSizeRelevantForFrame,
    frame: Optional[FrameType] = None,
    frame_state: Optional[Dict[str, Union[int, FrameStateSizeEntry]]] = None,
    *,
    compile_id: CompileId,
    skip: int = 0,
) -> Optional[GuardedCode]:
    from torch.fx.experimental.validator import (
        bisect,
        BisectValidationException,
        translation_validation_enabled,
        ValidationException,
    )

    # Only nonlocal defs here please!
    # Time spent compiling this frame before restarting or failing analysis
    dynamo_time_before_restart: float = 0.0
    output: Optional[OutputGraph] = None
    tracer: Optional[InstructionTranslator] = None

    @preserve_global_state
    def transform(
        instructions: List[Instruction], code_options: Dict[str, object]
    ) -> None:
        nonlocal output
        nonlocal tracer
        speculation_log.restart()
        tracer = InstructionTranslator(
            instructions,
            code,
            locals,
            globals,
            builtins,
            code_options,
            compiler_fn,
            one_graph,
            export,
            export_constraints,
            mutated_closure_cell_contents,
            frame_state=frame_state,
            speculation_log=speculation_log,
            distributed_state=distributed_state,
        )

        try:
            with tracing(tracer.output.tracing_context), tracer.set_current_tx():
                tracer.run()
        except exc.UnspecializeRestartAnalysis:
            speculation_log.clear()
            raise
        except (exc.SpeculationRestartAnalysis, exc.SkipFrame):
            raise
        except Exception:
            if translation_validation_enabled():
                bisect(tracer.output.shape_env)
            raise
        finally:
            tracer.output.call_cleanup_hooks()

        output = tracer.output
        assert output is not None
        assert output.output_instructions
        instructions[:] = output.output_instructions
        code_options.update(output.code_options)

        if config.dead_code_elimination:
            propagate_inst_exn_table_entries(instructions)
            check_inst_exn_tab_entries_valid(instructions)
            instructions[:] = remove_pointless_jumps(remove_dead_code(instructions))
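
    # NOTE (illustrative): `transform` is invoked by `transform_code_object`
    # with a mutable instruction list; raising one of the RestartAnalysis
    # exceptions aborts the attempt and `_compile_inner` below retries, with
    # `speculation_log.restart()` rewinding the log so earlier speculation
    # decisions can be replayed on the next attempt.
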
    def compile_inner(
        code: CodeType,
        one_graph: bool,
        hooks: Hooks,
        transform: Callable[[List[Instruction], Dict[str, Any]], Any],
    ) -> Optional[GuardedCode]:
        with dynamo_timed("_compile.compile_inner", phase_name="entire_frame_compile"):
            with CompileTimeInstructionCounter.record():
                return _compile_inner(code, one_graph, hooks, transform)

    @compile_time_strobelight_meta(phase_name="compile_inner")
    @maybe_cprofile
    def _compile_inner(
        code: CodeType,
        one_graph: bool,
        hooks: Hooks,
        transform: Callable[[List[Instruction], Dict[str, Any]], Any],
    ) -> Optional[GuardedCode]:
        nonlocal dynamo_time_before_restart
        last_attempt_start_time = start_time = time.time()

        def log_bytecode(
            prefix: str, name: str, filename: str, line_no: int, code: CodeType
        ) -> None:
            if bytecode_log.isEnabledFor(logging.DEBUG):
                bytecode_log.debug(
                    format_bytecode(prefix, name, filename, line_no, code)
                )

        log_bytecode(
            "ORIGINAL BYTECODE",
            code.co_name,
            code.co_filename,
            code.co_firstlineno,
            code,
        )

        out_code = None
        for attempt in itertools.count():
            CompileContext.get().attempt = attempt
            try:
                out_code = transform_code_object(code, transform)
                break
            except exc.RestartAnalysis as e:
                log.info(
                    "Restarting analysis due to %s",
                    LazyString(format_traceback_short, e.__traceback__),
                )
                # If restart reason is None just log the type of the exception
                restart_reasons.add(e.restart_reason or str(type(e)))
                # We now have a new "last attempt", reset the clock
                last_attempt_start_time = time.time()
                if attempt > 100:
                    unimplemented("100+ RestartAnalysis() calls")
            except exc.SkipFrame as e:
                log.debug(
                    "Skipping frame %s %s \
                    %s %s",
                    e,
                    code.co_name,
                    code.co_filename,
                    code.co_firstlineno,
                )
                if one_graph:
                    log.debug("No graph captured with one_graph=True")
                return None

        assert (
            distributed_state is None or distributed_state.all_states is not None
        ), "compiler collective wasn't run before compilation completed"

        assert out_code is not None
        log_bytecode(
            "MODIFIED BYTECODE",
            code.co_name,
            code.co_filename,
            code.co_firstlineno,
            out_code,
        )

        for hook in _bytecode_hooks.values():
            hook_output = hook(code, out_code)
            if hook_output is not None:
                out_code = hook_output

        orig_code_map[out_code] = code
        output_codes.add(out_code)
        dynamo_time_before_restart = last_attempt_start_time - start_time
        assert output is not None

        # Tests for new code objects.
        # The rationale for these tests can be found in torch/csrc/dynamo/eval_frame.c
        # Only test once the code object is created.
        # They are not tested during runtime.

        def count_args(code: CodeType) -> int:
            import inspect

            return (
                code.co_argcount
                + code.co_kwonlyargcount
                + bool(code.co_flags & inspect.CO_VARARGS)
                + bool(code.co_flags & inspect.CO_VARKEYWORDS)
            )

        assert out_code is not None

        total_argcount_old = count_args(code)
        total_argcount_new = count_args(out_code)
        msg = "arg mismatch: "
        msg += f"old code object has args {code.co_varnames[:total_argcount_old]}, "
        msg += f"new code object has args {out_code.co_varnames[:total_argcount_new]}"
        assert (
            code.co_varnames[:total_argcount_old]
            == out_code.co_varnames[:total_argcount_new]
        ), msg

        msg = "free var mismatch: "
        msg += f"old code object has free var {code.co_freevars}, "
        msg += f"new code object has free var {out_code.co_freevars}"
        assert code.co_freevars == out_code.co_freevars, msg

        msg = "cell var mismatch: "
        msg += f"old code object has cell var {code.co_cellvars}, "
        msg += f"new code object has cell var {out_code.co_cellvars}"
        assert code.co_cellvars == out_code.co_cellvars, msg

        # Skipping Dynamo on a frame without any extracted graph.
        # This does not affect eager functionality. But this is necessary
        # for export for cases where Dynamo-reconstructed bytecode can create
        # new function frames, confusing export in thinking that there
        # are extra graphs now.

        if output.export and output.is_empty_graph():
            return None

        assert output.guards is not None
        CleanupManager.instance[out_code] = output.cleanups
        check_fn = CheckFunctionManager(
            output,
            hooks.guard_fail_fn if hooks else None,
        )

        guarded_code = GuardedCode(out_code, check_fn.check_fn, compile_id)

        if not output.is_empty_graph() and hooks.guard_export_fn is not None:
            # We should not run the guard_export_fn when Dynamo does not
            # generate any graph. This can happen in export when TorchDynamo
            # generated bytecode has some reconstruction logic for mutated
            # variables which can trigger TorchDynamo on the children frames but
            # they are benign and do not generate any new graphs.
            hooks.guard_export_fn(output.guards)

        return guarded_code
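
    # NOTE (illustrative): everything below runs under a fresh CompileContext
    # for this CompileId; it enforces the cache-size limits, emits the
    # "dynamo_start" structured-trace event, and records CompilationMetrics in
    # a finally block whether compilation succeeds or fails.
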
    with _use_lazy_graph_module(config.use_lazy_graph_module), compile_context(
        CompileContext(compile_id)
    ):
        restart_reasons: set[str] = set()
        # This is shared across restarts
        mutated_closure_cell_contents: Set[str] = set()
        speculation_log = SpeculationLog()
        if compile_pg := get_compile_pg():
            distributed_state = DistributedState(compile_pg, LocalState())
        else:
            distributed_state = None
        torch._dynamo.callback_handler.run_start_callbacks()

        # Check recompilations
        recompile_reasons = None
        if is_recompilation(cache_size) and frame:
            recompile_reasons = get_and_maybe_log_recompilation_reason(
                cache_entry, frame
            )

        exceeded, limit_type = exceeds_cache_size_limit(cache_size, compile_id)
        if exceeded:

            def format_func_info(code: CodeType) -> str:
                return f"'{code.co_name}' ({code.co_filename}:{code.co_firstlineno})"

            def format_guard_failures() -> str:
                if not recompile_reasons:
                    return "Unable to find recompilation reasons"
                return recompile_reasons[-1]

            log.warning(
                "torch._dynamo hit config.%s (%s)\n"
                "   function: %s\n"
                "   last reason: %s\n"
                'To log all recompilation reasons, use TORCH_LOGS="recompiles".\n'
                "To diagnose recompilation issues, see %s.",
                limit_type,
                getattr(config, limit_type),
                format_func_info(code),
                format_guard_failures(),
                troubleshooting_url,
            )
            if config.skip_code_recursive_on_cache_limit_hit and justknobs_check(
                "pytorch/compiler:skip_code_recursive_on_cache_limit_hit"
            ):
                raise CacheLimitExceeded(f"{limit_type} reached")
            else:
                # do not recursively skip frames
                unimplemented(f"{limit_type} reached")

        log.debug(
            "torchdynamo start compiling %s %s:%s, stack (elided %s frames):\n%s",
            code.co_name,
            code.co_filename,
            code.co_firstlineno,
            skip + 2,
            # -2: omit current frame, omit contextlib decorator
            "".join(CapturedTraceback.extract(skip=2 + skip).format()),
        )
        # -4: -2 as above, plus trace_structured frames
        #
        # NB: the frame looks like this:
        #
        #   # handled by skip argument
        #   torch/_dynamo/convert_frame.py:1069 in catch_errors
        #   torch/_dynamo/convert_frame.py:910 in _convert_frame
        #   torch/_dynamo/convert_frame.py:464 in _convert_frame_assert
        #   torch/_utils_internal.py:70 in wrapper_function
        #
        #   # 2 current frame and context lib
        #   env/lib/python3.10/contextlib.py:79 in inner
        #   torch/_dynamo/convert_frame.py:776 in _compile
        #
        #   # 2 extra here
        #   torch/_logging/_internal.py:1064 in trace_structured
        #   torch/_dynamo/convert_frame.py:780 in <lambda>
        convert_frame_intern = structured.intern_string(__file__)
        # Initialize the ChromiumEventLogger on start
        chromium_event_log = get_chromium_event_logger()
        chromium_event_log.reset()
        torch._logging.trace_structured(
            "dynamo_start",
            lambda: {
                "stack": list(
                    itertools.takewhile(
                        lambda f: f["filename"] != convert_frame_intern,
                        structured.from_traceback(
                            CapturedTraceback.extract(skip=4 + skip).summary()
                        ),
                    )
                )
                + [
                    {
                        "line": code.co_firstlineno,
                        "name": code.co_name,
                        "filename": structured.intern_string(code.co_filename),
                    }
                ]
            },
        )
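
        # NOTE (illustrative): the failure bookkeeping below distinguishes
        # exception types that are re-raised as-is (Unsupported,
        # BackendCompilerFailed, guard/constraint violations, ...) from
        # everything else, which is rewrapped as InternalTorchDynamoError
        # while preserving the original traceback.
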
        start_time = time.time()
        fail_type: Optional[str] = None
        fail_reason: Optional[str] = None
        fail_user_frame_filename: Optional[str] = None
        fail_user_frame_lineno: Optional[int] = None
        start_possibly_missed_reinplacing_opportunities = torch._dynamo.utils.counters[
            "inductor"
        ]["possibly_missed_reinplacing_opportunities"]
        guarded_code = None
        try:
            guarded_code = compile_inner(code, one_graph, hooks, transform)
            return guarded_code
        except Exception as e:
            fail_type = type(e).__qualname__
            fail_reason = str(e)
            # NB: e's msg is mutated here to add user stack, but we DON'T want
            # that stack in the Scuba logged fail_reason
            exception_handler(e, code, frame, export=export)
            fail_user_frame_filename, fail_user_frame_lineno = exc.get_exc_message(
                e, compile_id
            )
            if isinstance(
                e,
                (
                    Unsupported,
                    TorchRuntimeError,
                    BackendCompilerFailed,
                    AssertionError,
                    ConstraintViolationError,
                    GuardOnDataDependentSymNode,
                    ValidationException,
                    UncapturedHigherOrderOpError,
                    BisectValidationException,
                ),
            ):
                raise
            else:
                # Rewrap for clarity
                raise InternalTorchDynamoError(
                    f"{type(e).__qualname__}: {str(e)}"
                ).with_traceback(e.__traceback__) from None
        finally:
            if tracer:
                tracer.output.local_scope = {}

            from .utils import curr_frame

            frame_key = str(curr_frame)
            if (
                fail_reason is None
                and output is not None
                and frame_key in frame_phase_timing
            ):
                guard_count = len(output.guards)
                shape_env_guard_count = len(output.shape_env.guards)
                graph_op_count = output.count_calls()
                graph_node_count = len(output.graph.nodes)
                graph_input_count = len(output.placeholders)
                entire_frame_compile_time = frame_phase_timing[frame_key].get(
                    "entire_frame_compile", None
                )
                backend_compile_time = frame_phase_timing[frame_key].get(
                    "backend_compile", None
                )
                inductor_compile_time = frame_phase_timing[frame_key].get(
                    "inductor_compile", None
                )
                code_gen_time = frame_phase_timing[frame_key].get("code_gen", None)
                non_compliant_ops = {op.__qualname__ for op in output.non_compliant_ops}
                compliant_custom_ops = {
                    op.__qualname__ for op in output.compliant_custom_ops
                }
                possibly_missed_reinplacing_opportunities = (
                    torch._dynamo.utils.counters["inductor"][
                        "possibly_missed_reinplacing_opportunities"
                    ]
                    - start_possibly_missed_reinplacing_opportunities
                )
            else:
                guard_count = None
                shape_env_guard_count = None
                graph_op_count = None
                graph_node_count = None
                graph_input_count = None
                entire_frame_compile_time = None
                backend_compile_time = None
                inductor_compile_time = None
                code_gen_time = None
                non_compliant_ops = set({})
                compliant_custom_ops = set({})
                restart_reasons = set()
                # If compilation failed, the entire time is wasted
                dynamo_time_before_restart = time.time() - start_time
                possibly_missed_reinplacing_opportunities = None

            metrics = CompilationMetrics(
                str(compile_id),
                frame_key,
                code.co_name,
                code.co_filename,
                code.co_firstlineno,
                cache_size.num_cache_entries_with_same_id_matched_objs,
                cache_size.num_cache_entries,
                guard_count,
                shape_env_guard_count,
                graph_op_count,
                graph_node_count,
                graph_input_count,
                start_time,
                entire_frame_compile_time,
                backend_compile_time,
                inductor_compile_time,
                code_gen_time,
                fail_type,
                fail_reason,
                fail_user_frame_filename,
                fail_user_frame_lineno,
                non_compliant_ops,
                compliant_custom_ops,
                restart_reasons,
                dynamo_time_before_restart,
                guarded_code is not None,
                possibly_missed_reinplacing_opportunities,
            )
            record_compilation_metrics(metrics)
            torch._dynamo.callback_handler.run_end_callbacks()
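
# NOTE (illustrative): `ConvertFrameAssert` above raises on any failure
# ("fully convert or error"), while `ConvertFrame` below wraps it, counts
# frames, and, depending on `config.suppress_errors` and the exception type,
# falls back to running the frame eagerly instead of raising.
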
class ConvertFrame:
    def __init__(self, compiler_fn: CompilerFn, hooks: Hooks) -> None:
        self._torchdynamo_orig_callable = compiler_fn
        self._inner_convert = convert_frame_assert(compiler_fn, one_graph=False)
        self._hooks = hooks

    @property
    def _clone_with_backend(self) -> Callable[[WrapBackendDebug], ConvertFrame]:
        return lambda backend: convert_frame(backend, self._hooks)

    def __call__(
        self,
        frame: FrameType,
        cache_entry: Optional[CacheEntry],
        hooks: Hooks,
        frame_state: Dict[str, Union[int, FrameStateSizeEntry]],
        skip: int = 0,
    ) -> Optional[
        Union[GuardedCode, torch._C._dynamo.eval_frame.SkipCodeRecursiveFlag]
    ]:
        counters["frames"]["total"] += 1
        try:
            result = self._inner_convert(
                frame, cache_entry, hooks, frame_state, skip=skip + 1
            )
            counters["frames"]["ok"] += 1
            return result
        except Exception as e:
            # These two exception types are "soft" failures, in the sense that
            # we know this is due to something we didn't implement all the
            # way, so scare the user less about it. That being said, if you
            # are trying to understand why a graph break happened, it's still
            # important to have this information, so offer it.
            #
            # NB: NotImplementedError used to be on this list, but actually
            # it is impossible for it to reach here, as it is converted into
            # InternalTorchDynamoError. This behavior seemed reasonable
            # to me (ezyang, Aug 2023) so I kept it, but maybe at some point
            # someone wanted these to also get suppressed. If so, you'll
            # need to make these exceptions not get wrapped

            # We intentionally don't want to suppress error here.
            if isinstance(e, UncapturedHigherOrderOpError):
                raise

            soft_fail = isinstance(e, Unsupported)

            # This is a soft failure. In the sense, the code path reaches here
            # when we do not support graph breaks on bytecodes like LOAD_ATTR,
            # BUILD_SET etc. In such case, we can fallback to eager without
            # scaring users.
            if isinstance(e, Unsupported) and graph_break_log.isEnabledFor(
                logging.DEBUG
            ):
                # Log this message in the graph break. Also use the string
                # "skip: " to tell that the whole frame is falling back to
                # eager.
                if hasattr(e, "compile_id"):
                    with compile_context(CompileContext(e.compile_id)):  # type: ignore[attr-defined]
                        user_stack = e.real_stack
                        user_stack_formatted = "".join(
                            traceback.format_list(user_stack)
                        )
                        graph_break_log.debug(
                            "Graph break: skip: from user code at:\n%s",
                            user_stack_formatted,
                            exc_info=True,
                        )

            if not config.suppress_errors and not soft_fail:
                raise

            # Suppress the error.  NB: It's very important to do the
            # suppression logging HERE, where the actual suppression
            # happens. Previously it was somewhere else and so it was
            # possible to accidentally not log at all.
            record_filename = getattr(e, "record_filename", None)
            code = frame.f_code
            error_msg = format_error_msg(e, code, record_filename, frame)

            if soft_fail:
                log.info(error_msg, exc_info=True)
            else:
                log.warning(error_msg, exc_info=True)

            # If we encounter SkipCodeRecursiveException, return skip_code_recursive_flag
            # to signal to Dynamo eval frame to skip the current frame and any recursive calls.
            if isinstance(e, SkipCodeRecursiveException):
                return torch._C._dynamo.eval_frame.skip_code_recursive_flag

            return None


def convert_frame(compiler_fn: CompilerFn, hooks: Hooks) -> ConvertFrame:
    """Try to convert a frame into an FX graph; on error, leave the frame unmodified"""
    return ConvertFrame(compiler_fn, hooks)
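
# NOTE (illustrative): `replay` below is a debugging aid. With
# `torch._dynamo.config.replay_record_enabled = True`, a failing compile can
# write an ExecutionRecord to disk (the path is attached to the exception as
# `record_filename` by `exception_handler` above); `replay(path)` then re-runs
# `_compile` on that record with the eager backend, e.g.:
#
#     torch._dynamo.convert_frame.replay("/path/to/record")
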
# TODO mlazos: add support for same args, or record them
def replay(filename: str) -> None:
    from .backends.debugging import eager

    original_replay_val = config.replay_record_enabled
    config.replay_record_enabled = False
    with open(filename, "rb") as in_file:
        record = ExecutionRecord.load(in_file)
    record.globals = dict(itertools.chain(record.globals.items(), globals().items()))

    try:
        _compile(
            record.code,
            record.globals,
            record.locals,
            record.builtins,
            compiler_fn=eager,
            one_graph=False,
            export=False,
            export_constraints=None,
            hooks=Hooks(),
            cache_size=CacheSizeRelevantForFrame(0, 0),
            cache_entry=None,
            frame=None,
            frame_state={},
            compile_id=CompileId(42, 999),
        )
    finally:
        config.replay_record_enabled = original_replay_val


def first_real_inst_idx(code: CodeType) -> int:
    if sys.version_info < (3, 11):
        return 0
    for inst in dis.get_instructions(code):
        if inst.opname == "RESUME":
            return inst.offset // 2
    raise RuntimeError("RESUME instruction not found in code")


class ConvertFrameProtocol(typing.Protocol):
    def __call__(
        self,
        frame: FrameType,
        cache_entry: Optional[CacheEntry],
        hooks: Hooks,
        frame_state: Dict[str, Union[int, FrameStateSizeEntry]],
        *,
        skip: int = 0,
    ) -> Optional[GuardedCode]:
        ...
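
# NOTE (illustrative): `CatchErrorsWrapper` below is what Dynamo installs as
# the per-frame eval-frame callback. It decides whether to skip a frame
# (already executing, in a skipfile, tracing disabled, or a non-infra torch
# dispatch mode is active), reroutes compilation through DDPOptimizer when an
# active DDP module requests it, and otherwise forwards to the wrapped
# ConvertFrame/ConvertFrameAssert callback under `compile_lock`.
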
class CatchErrorsWrapper:
    def __init__(self, callback: ConvertFrameProtocol, hooks: Hooks) -> None:
        functools.wraps(callback)(self)
        self._torchdynamo_orig_callable = callback
        self.hooks = hooks

    def __call__(
        self,
        frame: FrameType,
        cache_entry: Optional[CacheEntry],
        frame_state: Dict[str, Union[int, FrameStateSizeEntry]],
    ) -> Optional[GuardedCode]:
        assert frame_state is not None

        is_skipfile = trace_rules.check(frame.f_code)
        if sys.version_info >= (3, 13):
            has_started_execution = frame.f_lasti > first_real_inst_idx(frame.f_code)
        else:
            has_started_execution = frame.f_lasti >= first_real_inst_idx(frame.f_code)
        if (
            # TODO: the first condition is not covered by any test
            has_started_execution
            or is_skipfile
            or config.disable
            or (
                is_in_torch_dispatch_mode(include_infra_modes=False)
                and not getattr(self._torchdynamo_orig_callable, "_export", False)
            )
        ):
            if log.isEnabledFor(logging.DEBUG):
                log.debug(
                    "f_lasti=%s, first_real_inst_idx=%s",
                    frame.f_lasti,
                    first_real_inst_idx(frame.f_code),
                )

                if has_started_execution:
                    skip_reason = "traced frame already"
                elif trace_rules.check(frame.f_code):
                    skip_reason = "in skipfiles"
                elif is_in_torch_dispatch_mode(include_infra_modes=False):
                    skip_reason = "non-infra torch dispatch mode present, this is not supported today in torch.compile"
                else:
                    skip_reason = "dynamo tracing is disabled"

                log.debug(
                    "skipping: %s (reason: %s, file: %s)",
                    frame.f_code.co_name,
                    skip_reason,
                    frame.f_code.co_filename,
                )
            return None

        if frame.f_code.co_filename == "<string>" and frame.f_code.co_name == "__new__":
            # namedtuple constructor
            return None
        if config._get_optimize_ddp_mode() == "ddp_optimizer":
            ddp_module = DistributedDataParallel._get_active_ddp_module()
            if ddp_module:
                with compile_lock:
                    from torch._dynamo.backends.distributed import DDPOptimizer

                    ddp_optimizer = DDPOptimizer(
                        bucket_bytes_cap=ddp_module.bucket_bytes_cap,
                        backend_compile_fn=self._torchdynamo_orig_callable._torchdynamo_orig_callable,  # type: ignore[attr-defined]
                    )
                    assert hasattr(
                        self._torchdynamo_orig_callable, "_clone_with_backend"
                    ), "DDPOptimizer only supports callback fns that know how to clone themselves."
                    hijacked_callback = (
                        self._torchdynamo_orig_callable._clone_with_backend(
                            ddp_optimizer.compile_fn,
                        )
                    )
                    return hijacked_callback(
                        frame, cache_entry, self.hooks, frame_state
                    )

        with compile_lock, _disable_current_modes():
            # skip=1: skip this frame
            return self._torchdynamo_orig_callable(
                frame, cache_entry, self.hooks, frame_state, skip=1
            )


def catch_errors_wrapper(
    callback: ConvertFrameProtocol, hooks: Hooks
) -> CatchErrorsWrapper:
    return CatchErrorsWrapper(callback, hooks)
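
# NOTE (illustrative): user code normally reaches this file through
# torch.compile rather than by calling these classes directly. A hypothetical
# end-to-end sketch, assuming the built-in "eager" backend:
#
#     @torch.compile(backend="eager")
#     def f(x):
#         return x.sin() + 1
#
#     f(torch.randn(4))  # each new Python frame is routed through
#                        # CatchErrorsWrapper -> ConvertFrame -> _compile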