# mypy: allow-untyped-defs
# mypy: disable-error-code="method-assign"

"""
Functions in this file are responsible for modifying the eval frame
handler at RUNTIME. Therefore, all functions in this file are hot.
Functions that only execute at compile time should be placed
in torch._dynamo.convert_frame.
"""

from __future__ import annotations

import contextlib
import functools
import inspect
import logging
import os
import sys
import textwrap
import traceback
import types
import warnings
import weakref
from enum import Enum
from os.path import dirname, join
from typing import (
    Any,
    Callable,
    Dict,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    TYPE_CHECKING,
    Union,
)
from unittest.mock import patch

import sympy

import torch
import torch.fx
import torch.utils._pytree as pytree
import torch.utils.checkpoint
from torch import _guards

# see discussion at https://github.com/pytorch/pytorch/issues/120699
from torch._C._dynamo.eval_frame import (  # noqa: F401
    reset_code,
    set_guard_error_hook,
    skip_code,
    unsupported,
)
from torch._dispatch.python import enable_python_dispatcher
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch._utils_internal import justknobs_check, log_export_usage
from torch.export.dynamic_shapes import _combine_args, _process_dynamic_shapes
from torch.fx import GraphModule
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import (
    ConstraintViolationError,
    DimDynamic,
    ShapeEnv,
    StatelessSymbolicContext,
)
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo

from . import config, convert_frame, external_utils, trace_rules, utils
from .backends.registry import CompilerFn, lookup_backend
from .code_context import code_context
from .exc import CondOpArgsMismatchError, UserError, UserErrorType
from .hooks import Hooks
from .mutation_guard import install_generation_tagging_init
from .utils import common_constant_types, compile_times


if TYPE_CHECKING:
    from torch._subclasses import fake_tensor

    from .types import CacheEntry, DynamoCallback


log = logging.getLogger(__name__)


always_optimize_code_objects = utils.ExactWeakKeyDictionary()
null_context = contextlib.nullcontext


# See https://github.com/python/typing/pull/240
class Unset(Enum):
    token = 0


cached_backends: Dict[int, CompilerFn] = {}

unset = Unset.token


def _maybe_set_eval_frame(callback: DynamoCallback):
    # A wrapper on set_eval_frame that is guarded by a Justknob.
    # Users can disable TorchDynamo by setting the JK to False.
    from torch._C._dynamo.eval_frame import set_eval_frame

    if not justknobs_check("pytorch/compiler:enable_compiler_set_eval_frame"):
        torch._dynamo.utils.warn_once(
            "Dynamo disabled by Justknob: enable_compiler_set_eval_frame, skipping set_eval_frame"
        )
        return callback
    else:
        return set_eval_frame(callback)


def _reset_guarded_backend_cache():
    global cached_backends
    for backend in cached_backends.values():
        if hasattr(backend, "reset"):
            backend.reset()
    cached_backends.clear()


DONT_WRAP_FILES = {
    # For tracing into fx modules
    inspect.getsourcefile(GraphModule),
    join(dirname(dirname(__file__)), "onnx/_internal/fx/dynamo_graph_extractor.py"),
}


def _debug_get_cache_entry_list(
    code: Union[types.CodeType, Callable[..., Any]]
) -> List[CacheEntry]:
    """
    Given a code object or a callable object, retrieve the cache entries
    stored in this code.
    """
    if callable(code):
        code = code.__code__
    return torch._C._dynamo.eval_frame._debug_get_cache_entry_list(code)


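# Hedged usage sketch (illustrative only, not part of this module's API): once a
# function has been compiled at least once, the cache entries attached to its
# original code object can be inspected with _debug_get_cache_entry_list above.
# The helper name _example_inspect_cache_entries is hypothetical.
def _example_inspect_cache_entries():
    import torch

    def toy(x):
        return x.sin() + 1

    compiled = torch.compile(toy, backend="eager")
    compiled(torch.randn(3))  # trigger one compilation so a cache entry exists
    # Accepts the callable (uses toy.__code__) or the code object directly.
    return _debug_get_cache_entry_list(toy)

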
class OptimizedModule(torch.nn.Module):
    """
    Wraps the original nn.Module object and later patches its
    forward method with the optimized self.forward method.
    """

    _torchdynamo_orig_callable: Callable[..., Any]
    get_compiler_config: Callable[[], Any]

    _opt_mod_attributes = {
        "_orig_mod",
        "dynamo_ctx",
        "_torchdynamo_orig_callable",
        "get_compiler_config",
        "forward",
        "_forward",
        "__dict__",
        "named_children_walk",
    }

    def __init__(self, mod: torch.nn.Module, dynamo_ctx) -> None:
        super().__init__()
        # Installs the params/buffers
        self._orig_mod = mod
        self.dynamo_ctx = dynamo_ctx
        self._initialize()
        self.training = self._orig_mod.training

    def _initialize(self):
        # Do this stuff in constructor to lower overhead slightly
        if isinstance(self.dynamo_ctx, DisableContext):
            # No need to check trace rules
            self.forward = self.dynamo_ctx(self._orig_mod.__call__)
        elif isinstance(self._orig_mod.forward, types.MethodType) and (
            trace_rules.check(self._orig_mod.forward)
            or getattr(self._orig_mod, "_is_fsdp_managed_module", False)
        ):
            # This may be a torch.nn.* instance in trace_rules.py which
            # won't trigger a frame evaluation workaround to add an extra
            # frame we can capture
            self.forward = self.dynamo_ctx(external_utils.wrap_inline(self._orig_mod))
        else:
            # Invoke hooks outside of dynamo, then pick up the inner frame
            self.forward = self.dynamo_ctx(self._orig_mod.__call__)

        if hasattr(self._orig_mod, "_initialize_hook"):
            self._forward = self.forward
            self.forward = self._call_lazy_check

    def __reduce__(self):
        return (self.__class__, (self._orig_mod, self.dynamo_ctx))

    def __getstate__(self):
        state = dict(self.__dict__)
        state.pop("forward", None)
        state.pop("__call__", None)
        return state

    def __setstate__(self, state):
        self.__dict__ = state
        self._initialize()

    @property
    def training(self):
        return self._orig_mod.training

    @training.setter
    def training(self, value):
        try:
            super().__getattr__("_orig_mod")
            self._orig_mod.training = value
        except AttributeError:
            # still initializing
            pass

    def __getattr__(self, name):
        if name == "_orig_mod":
            return self._modules["_orig_mod"]
        return getattr(self._orig_mod, name)

    def __setattr__(self, name, val) -> None:
        # Allow patching over class attributes
        if hasattr(type(self), name):
            return super().__setattr__(name, val)

        if name in OptimizedModule._opt_mod_attributes:
            return super().__setattr__(name, val)
        return setattr(self._orig_mod, name, val)

    def _call_lazy_check(self, *args, **kwargs):
        if hasattr(self._orig_mod, "_initialize_hook"):
            # In the case of a lazy module, we want to run
            # the pre-hooks which initialize it.
            # Afterwards, the lazy module deletes its pre-hooks
            # to avoid treating it as lazy on subsequent recompiles.
            self._orig_mod._infer_parameters(self._orig_mod, args, kwargs)
        return self._forward(*args, **kwargs)

    def __dir__(self):
        orig_mod_attrs = self._orig_mod.__dir__()
        return orig_mod_attrs + [
            attr for attr in super().__dir__() if attr not in orig_mod_attrs
        ]


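# Hedged usage sketch (illustrative only): compiling an nn.Module returns an
# OptimizedModule that stores the original module as `_orig_mod`, forwards
# attribute access to it, and routes calls through the Dynamo context.
# The helper name _example_optimized_module is hypothetical.
def _example_optimized_module():
    import torch

    lin = torch.nn.Linear(4, 4)
    opt_lin = torch.compile(lin, backend="eager")
    assert isinstance(opt_lin, OptimizedModule)
    assert opt_lin._orig_mod is lin
    # Attribute reads fall through to the wrapped module.
    assert opt_lin.weight is lin.weight
    return opt_lin(torch.randn(2, 4))

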
def remove_from_cache(f):
    """
    Make sure f.__code__ is not cached to force a recompile
    """
    if isinstance(f, types.CodeType):
        reset_code(f)
    elif hasattr(f, "__code__"):
        reset_code(f.__code__)
    elif hasattr(getattr(f, "forward", None), "__code__"):
        reset_code(f.forward.__code__)
    else:
        from . import reset  # type: ignore[attr-defined]

        reset()
        log.warning("could not determine __code__ for %s", f)


def nothing():
    pass


def always_false():
    return False


def innermost_fn(fn):
    """
    In case of nesting of _TorchDynamoContext calls, find the innermost
    function. TorchDynamo caches on the fn.__code__ object, so it's necessary to
    find the innermost function to pass on to optimize, run, disable, etc.
    """
    unaltered_fn = fn
    while hasattr(unaltered_fn, "_torchdynamo_orig_callable"):
        unaltered_fn = unaltered_fn._torchdynamo_orig_callable
        assert callable(unaltered_fn)
    return unaltered_fn


def make_set_enable_dynamic(enable: bool):
    assert isinstance(enable, bool)
    if enable:
        # Assume everything is dynamic by default
        return config._make_closure_patcher(assume_static_by_default=False)
    else:
        return config._make_closure_patcher(
            automatic_dynamic_shapes=False, assume_static_by_default=True
        )


class _TorchDynamoContext:
    def __init__(
        self,
        callback: DynamoCallback,
        on_enter=nothing,
        backend_ctx_ctor=null_context,
        patch_fn=nothing,
        first_ctx=False,
        *,
        export=False,
        dynamic=None,
        compiler_config=None,
    ) -> None:
        super().__init__()
        assert callable(callback) or callback is False or callback is None
        self.callback: DynamoCallback = callback
        self._backend_ctx_ctor = backend_ctx_ctor
        self.prior: Union[Unset, DynamoCallback] = unset
        self.first_ctx = first_ctx
        self.export = export
        self._dynamic = dynamic
        self.compiler_config = compiler_config
        self.cleanup_fns: List[Callable[[], Any]] = []
        self.enter_exit_hooks = []
        patch_fn()

        # Save the backends so that we can reset them during torch._dynamo.reset
        backend = innermost_fn(callback)
        cached_backends.setdefault(id(backend), backend)

        if dynamic is not None:
            self.enter_exit_hooks.append(make_set_enable_dynamic(dynamic))

        if on_enter is not nothing:
            # this case is not common
            def call_on_enter():
                on_enter()
                return nothing

            self.enter_exit_hooks.append(call_on_enter)

        if backend_ctx_ctor is not contextlib.nullcontext:
            # this case is not common
            def call_backend_ctx():
                ctx = backend_ctx_ctor()
                ctx.__enter__()
                return functools.partial(ctx.__exit__, None, None, None)

            self.enter_exit_hooks.append(call_backend_ctx)

    def __enter__(self):
        if config.raise_on_ctx_manager_usage:
            raise RuntimeError(
                "torch._dynamo.optimize(...) is used with a context manager. "
                "Please refer to https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html "
                "to use torch._dynamo.optimize(...) as an annotation/decorator. "
            )
        self.cleanup_fns = [enter() for enter in self.enter_exit_hooks]
        self.prior = _maybe_set_eval_frame(self.callback)

    def __exit__(self, exc_type, exc_val, exc_tb):
        assert self.prior is not unset
        _maybe_set_eval_frame(self.prior)
        self.prior = unset
        for cleanup in self.cleanup_fns:
            cleanup()
        self.cleanup_fns.clear()

    def __call__(self, fn):
        # public api for compiler config/options
        def get_compiler_config():
            return self.compiler_config

        fn = innermost_fn(fn)

        # add context containing GraphModule to any GraphModule forward functions
        if isinstance(fn, GraphModule):
            code_context.get_context(fn.forward.__code__)[
                "orig_graphmodule"
            ] = weakref.ref(fn)

        # Optimize the forward method of torch.nn.Module object
        if isinstance(fn, torch.nn.Module):
            mod = fn
            new_mod = OptimizedModule(mod, self)
            # Save the function pointer to find the original callable while nesting
            # of decorators.
            new_mod._torchdynamo_orig_callable = mod.forward

            # when compiling torch.nn.Module,
            # provide public api OptimizedModule.get_compiler_config()
            assert not hasattr(new_mod, "get_compiler_config")
            new_mod.get_compiler_config = get_compiler_config

            return new_mod

        if inspect.isclass(fn):
            # User has wrapped the class with compile/disable decorator. Apply
            # disable to init/call method.
            cls_obj = fn
            cls_obj.__call__ = self(cls_obj.__call__)
            if issubclass(cls_obj, torch.nn.Module):
                # NN module variable tracker directly inlines the _call_impl.
                cls_obj._call_impl = self(cls_obj._call_impl)
            return cls_obj

        assert callable(fn)

        try:
            filename = inspect.getsourcefile(fn)
        except TypeError:
            filename = None
        if (
            (filename is None or trace_rules.check(fn))
            and (
                getattr(fn, "__name__", "")
                not in ["_call_impl", "_wrapped_call_impl", "_lazy_forward"]
            )
            and filename not in DONT_WRAP_FILES
        ):
            # call to a builtin without a frame for us to capture
            fn = external_utils.wrap_inline(fn)

        def do_nothing(*arg, **kwargs):
            pass

        if hasattr(self, "callback"):
            callback = self.callback
        else:
            callback = do_nothing

        is_jit_tracing = torch._C._is_tracing
        is_fx_tracing = torch.fx._symbolic_trace.is_fx_tracing

        @functools.wraps(fn)
        def _fn(*args, **kwargs):
            if is_fx_tracing():
                if config.error_on_nested_fx_trace:
                    raise RuntimeError(
                        "Detected that you are using FX to symbolically trace "
                        "a dynamo-optimized function. This is not supported at the moment."
                    )
                else:
                    return fn(*args, **kwargs)

            if is_jit_tracing():
                if config.error_on_nested_jit_trace:
                    raise RuntimeError(
                        "Detected that you are using torch.jit.trace to trace "
                        "a dynamo-optimized function. This is not supported at the moment."
                    )
                else:
                    return fn(*args, **kwargs)

            cleanups = [enter() for enter in self.enter_exit_hooks]
            prior = _maybe_set_eval_frame(callback)

            # Ensure that if an assertion occurs after graph pushes
            # something onto the DynamicLayerStack then we pop it off (the
            # constructed graph code isn't guarded with try/finally).
            #
            # This used to be a context but putting a `with` here is a noticeable
            # perf regression (#126293)
            saved_dynamic_layer_stack_depth = (
                torch._C._functorch.get_dynamic_layer_stack_depth()
            )

            try:
                return fn(*args, **kwargs)
            finally:
                # Restore the dynamic layer stack depth if necessary.
                torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(
                    saved_dynamic_layer_stack_depth
                )

                _maybe_set_eval_frame(prior)
                for cleanup in cleanups:
                    cleanup()

        # hooks to properly handle inlining
        _fn._torchdynamo_inline = fn  # type: ignore[attr-defined]

        # Save the function pointer to find the original callable while nesting
        # of decorators.
        _fn._torchdynamo_orig_callable = fn  # type: ignore[attr-defined]

        # when compiling user function instead of nn.Module
        # provide public api _fn.get_compiler_config()
        assert not hasattr(_fn, "get_compiler_config")
        _fn.get_compiler_config = get_compiler_config  # type: ignore[attr-defined]

        # If the function is called using torch._dynamo.optimize decorator, we
        # should prevent any type of skipping.
        if callback not in (None, False):
            if not hasattr(fn, "__code__"):
                raise RuntimeError(
                    textwrap.dedent(
                        """

                        torch._dynamo.optimize is called on a non-function object.
                        If this is a callable class, please wrap the relevant code into a function and optimize the
                        wrapper function.

                        >> class CallableClass:
                        >>     def __init__(self) -> None:
                        >>         super().__init__()
                        >>         self.relu = torch.nn.ReLU()
                        >>
                        >>     def __call__(self, x):
                        >>         return self.relu(torch.sin(x))
                        >>
                        >>     def print_hello(self):
                        >>         print("Hello world")
                        >>
                        >> mod = CallableClass()

                        If you want to optimize the __call__ function and other code, wrap that up in a function

                        >> def wrapper_fn(x):
                        >>     y = mod(x)
                        >>     return y.sum()

                        and then optimize the wrapper_fn

                        >> opt_wrapper_fn = torch._dynamo.optimize()(wrapper_fn)
                        """
                    )
                )
            always_optimize_code_objects[fn.__code__] = True

        return _fn


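# Hedged usage sketch (illustrative only), a runnable version of the pattern the
# error message above recommends: a callable class instance has no __code__ of
# its own, so wrap the call in a plain function and optimize that wrapper.
# The helper name _example_wrap_callable_class is hypothetical.
def _example_wrap_callable_class():
    import torch

    class CallableClass:
        def __init__(self) -> None:
            self.relu = torch.nn.ReLU()

        def __call__(self, x):
            return self.relu(torch.sin(x))

    mod = CallableClass()

    def wrapper_fn(x):
        return mod(x).sum()

    opt_wrapper_fn = torch._dynamo.optimize("eager")(wrapper_fn)
    return opt_wrapper_fn(torch.randn(8))

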
class OptimizeContext(_TorchDynamoContext):
    def __init__(
        self,
        callback,
        backend_ctx_ctor,
        first_ctx=False,
        *,
        export=False,
        dynamic=None,
        compiler_config=None,
        rebuild_ctx: Optional[
            Callable[[], Union[OptimizeContext, _NullDecorator]]
        ] = None,
    ) -> None:
        def on_enter():
            install_generation_tagging_init()

        super().__init__(
            callback=callback,
            on_enter=on_enter,
            backend_ctx_ctor=backend_ctx_ctor,
            patch_fn=TorchPatcher.patch,
            first_ctx=first_ctx,
            export=export,
            dynamic=dynamic,
            compiler_config=compiler_config,
        )

        if config.compiled_autograd:

            def call_compiled_autograd():
                assert rebuild_ctx is not None
                compiler_fn = rebuild_ctx()
                ctx = torch._dynamo.compiled_autograd.enable(compiler_fn)
                ctx.__enter__()
                return functools.partial(ctx.__exit__, None, None, None)

            self.enter_exit_hooks.append(call_compiled_autograd)

    def __reduce__(self):
        return (
            self.__class__,
            (self.callback, self._backend_ctx_ctor, self.first_ctx),
            {
                "export": self.export,
                "dynamic": self._dynamic,
                "compiler_config": self.compiler_config,
            },
        )


class RunOnlyContext(_TorchDynamoContext):
    def __init__(self) -> None:
        # cudagraph trees rely on the generation increment
        def on_enter():
            torch._dynamo.mutation_guard.GenerationTracker.generation += 1

        super().__init__(callback=False, on_enter=on_enter)

    def __reduce__(self):
        return (self.__class__, ())


class DisableContext(_TorchDynamoContext):
    def __init__(self) -> None:
        super().__init__(callback=None)

    def __call__(self, fn):
        # Earlier this code was in the base class _TorchDynamoContext. But we
        # moved it here to have better code organization. For disable, we just
        # want the callback to be None. We don't have to check trace_rules or
        # create any wrapper.
        fn = innermost_fn(fn)

        if isinstance(fn, torch.nn.Module):
            mod = fn
            new_mod = OptimizedModule(mod, self)
            new_mod._torchdynamo_orig_callable = mod.forward
            return new_mod

        if inspect.isclass(fn):
            # User has wrapped the class with compile/disable decorator. Apply
            # disable to init/call method.
            cls_obj = fn
            # Disable on init is useful for reconstruction of bytecodes where we
            # want to prevent Dynamo from tracing into the init function. Check
            # test_reconstruction in test_model_output.py.
            cls_obj.__init__ = self(cls_obj.__init__)
            cls_obj.__call__ = self(cls_obj.__call__)
            if issubclass(cls_obj, torch.nn.Module):
                # NN module variable tracker directly inlines the _call_impl. Disable it.
                cls_obj._call_impl = self(cls_obj._call_impl)
            return cls_obj

        assert callable(fn)

        callback = self.callback

        @functools.wraps(fn)
        def _fn(*args, **kwargs):
            prior = _maybe_set_eval_frame(callback)
            try:
                return fn(*args, **kwargs)
            finally:
                _maybe_set_eval_frame(prior)

        _fn._torchdynamo_disable = True  # type: ignore[attr-defined]

        # Save the function pointer to find the original callable while nesting
        # of decorators.
        _fn._torchdynamo_orig_callable = fn  # type: ignore[attr-defined]

        return _fn

    def __reduce__(self):
        return (self.__class__, ())


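# Hedged usage sketch (illustrative only): DisableContext backs the recursive
# form of torch._dynamo.disable; the wrapped function runs eagerly and Dynamo
# skips it even when it is called from inside a compiled region.
# The helper name _example_disable is hypothetical.
def _example_disable():
    import torch

    @torch._dynamo.disable
    def eager_only(x):
        return x + 1

    @torch.compile(backend="eager")
    def fn(x):
        return eager_only(x) * 2

    return fn(torch.randn(4))

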
def _optimize_catch_errors(
    compile_fn,
    hooks: Hooks,
    backend_ctx_ctor=null_context,
    export=False,
    dynamic=None,
    compiler_config=None,
    rebuild_ctx=None,
):
    return OptimizeContext(
        convert_frame.catch_errors_wrapper(compile_fn, hooks),
        backend_ctx_ctor=backend_ctx_ctor,
        first_ctx=True,
        export=export,
        dynamic=dynamic,
        compiler_config=compiler_config,
        rebuild_ctx=rebuild_ctx,
    )


def get_compiler_fn(compiler_fn):
    from .repro.after_dynamo import wrap_backend_debug

    if hasattr(compiler_fn, "compiler_name"):
        compiler_str = compiler_fn.compiler_name
    elif isinstance(compiler_fn, str):
        compiler_str = compiler_fn
    else:
        compiler_str = None
    compiler_fn = lookup_backend(compiler_fn)
    return wrap_backend_debug(compiler_fn, compiler_str)


class _NullDecorator(contextlib.nullcontext):  # type: ignore[type-arg]
    def __call__(self, fn):
        assert callable(fn)
        return fn


def check_if_dynamo_supported():
    if sys.version_info >= (3, 13):
        raise RuntimeError("Python 3.13+ not yet supported for torch.compile")


def is_dynamo_supported():
    try:
        check_if_dynamo_supported()
        return True
    except Exception:
        return False


def check_if_inductor_supported():
    check_if_dynamo_supported()


def is_inductor_supported():
    try:
        check_if_inductor_supported()
        return True
    except Exception:
        return False


def optimize(*args, **kwargs):
    def rebuild_ctx():
        return optimize(*args, **kwargs)

    return _optimize(rebuild_ctx, *args, **kwargs)


def _optimize(
    rebuild_ctx: Callable[[], Union[OptimizeContext, _NullDecorator]],
    backend="inductor",
    *,
    nopython=False,
    guard_export_fn=None,
    guard_fail_fn=None,
    disable=False,
    dynamic=None,
) -> Union[OptimizeContext, _NullDecorator]:
    """
    The main entrypoint of TorchDynamo. Do graph capture and call
    backend() to optimize extracted graphs.

    Args:
        backend: One of two things:
            - Either, a function/callable taking a torch.fx.GraphModule and
            example_inputs and returning a python callable that runs the
            graph faster.
            One can also provide additional context for the backend, like
            torch.jit.fuser("fuser2"), by setting the backend_ctx_ctor attribute.
            See AOTAutogradMemoryEfficientFusionWithContext for the usage.
            - Or, a string backend name in `torch._dynamo.list_backends()`
        nopython: If True, graph breaks will be errors and there will
            be a single whole-program graph.
        disable: If True, turn this decorator into a no-op
        dynamic: If True, upfront compile as dynamic a kernel as possible. If False,
            disable all dynamic shapes support (always specialize). If None, automatically
            detect when sizes vary and generate dynamic kernels upon recompile.

    Example Usage::

        @torch._dynamo.optimize()
        def toy_example(a, b):
            ...
    """
    check_if_dynamo_supported()
    # Note: The hooks object could be global instead of passed around, *however* that would make
    # for a confusing API usage and plumbing story wherein we nest multiple .optimize calls.
    # There is some prior art around this w/r/t nesting, where backend calls are enforced to be
    # the same compiler; however, this feels onerous for callbacks and hooks, and it feels better
    # to give our users an easier to understand UX at the cost of a little more plumbing on our end.
    hooks = Hooks(guard_export_fn=guard_export_fn, guard_fail_fn=guard_fail_fn)
    torch._C._log_api_usage_once("torch._dynamo.optimize")
    if (
        disable
        or os.environ.get("TORCHDYNAMO_DISABLE", "") == "1"
        or (not justknobs_check("pytorch/compiler:enable_dynamo"))
    ):
        return _NullDecorator()

    backend = get_compiler_fn(backend)

    # Find if backend has any extra context manager
    backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context)

    if nopython:
        return optimize_assert(
            backend,
            dynamic=dynamic,
            hooks=hooks,
            rebuild_ctx=rebuild_ctx,
        )
    # The backend function is stashed in the callable returned by
    # _optimize_catch_errors in the field _torchdynamo_orig_callable. This can
    # be used by eval_frame.c to insert a guard on the backend.
    return _optimize_catch_errors(
        convert_frame.convert_frame(backend, hooks=hooks),
        hooks,
        backend_ctx_ctor,
        dynamic=dynamic,
        compiler_config=backend.get_compiler_config()
        if hasattr(backend, "get_compiler_config")
        else None,
        rebuild_ctx=rebuild_ctx,
    )


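# Hedged usage sketch (illustrative only) of the two accepted `backend` forms
# described in the docstring above: a registered backend name, or a callable
# taking (GraphModule, example_inputs) and returning a callable that runs the
# graph. The names _example_optimize_custom_backend and my_backend are hypothetical.
def _example_optimize_custom_backend():
    import torch

    def my_backend(gm: torch.fx.GraphModule, example_inputs):
        gm.graph.print_tabular()  # inspect the captured FX graph
        return gm.forward  # run the captured graph as-is

    @torch._dynamo.optimize(my_backend)  # or: torch._dynamo.optimize("inductor")
    def toy_example(a, b):
        x = a / (torch.abs(a) + 1)
        return x * b

    return toy_example(torch.randn(10), torch.randn(10))

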
# TODO(voz): Consider making "explain" output alongside a run / part of a run
@patch("torch._dynamo.symbolic_convert.explain", True)
def explain(f, *extra_args, **extra_kwargs):
    def inner(*args, **kwargs):
        # TODO(voz): Do we want a decorator for this?
        from . import reset  # type: ignore[attr-defined]

        reset()

        graphs: List[torch.fx.GraphModule] = []
        break_reasons: List[Any] = []
        op_count: int = 0
        ops_per_graph: List[torch.fx.Node] = []
        out_guards: List[_guards.Guard] = []

        def dynamo_graph_accumulating_compiler(
            gm: torch.fx.GraphModule, example_inputs
        ):
            from .backends.debugging import _explain_graph_detail

            nonlocal graphs
            nonlocal op_count
            nonlocal ops_per_graph
            nonlocal break_reasons

            gm, graphs, op_count, ops_per_graph, break_reasons = _explain_graph_detail(
                gm, graphs, op_count, ops_per_graph, break_reasons
            )

            return gm.forward

        def guard_export_print(guards):
            nonlocal out_guards
            out_guards.extend(guards)

        opt_f = optimize(
            dynamo_graph_accumulating_compiler,
            nopython=False,
            guard_export_fn=guard_export_print,
        )(f)
        # TODO(voz): We may have instances of `f` that mutate inputs; we should track side effects and reject them.
        opt_f(*args, **kwargs)

        graph_count = len(graphs)
        graph_break_count = graph_count - 1
        compile_time = compile_times(repr="str")

        # TODO(voz): Do we want a decorator for this?
        reset()
        from .backends.debugging import ExplainOutput

        return ExplainOutput(
            graphs,
            graph_count,
            graph_break_count,
            break_reasons,
            op_count,
            ops_per_graph,
            out_guards,
            compile_time,
        )

    if extra_args or extra_kwargs:
        warnings.warn(
            "explain(f, *args, **kwargs) is deprecated, use explain(f)(*args, **kwargs) instead. "
            "If you don't migrate, we may break your explain call in the future if your user defined kwargs "
            "conflict with future kwargs added to explain(f).",
            FutureWarning,
            stacklevel=2,
        )
        return inner(*extra_args, **extra_kwargs)
    else:
        return inner


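# Hedged usage sketch (illustrative only): explain(f)(*args) runs f once under
# Dynamo with a graph-accumulating backend and returns an ExplainOutput that
# summarizes captured graphs and graph breaks. The helper name _example_explain
# is hypothetical.
def _example_explain():
    import torch

    def fn(x):
        x = x.sin()
        print("this forces a graph break")
        return x.cos()

    explanation = explain(fn)(torch.randn(4))
    return explanation.graph_count, explanation.graph_break_count

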
class FlattenInputOutputSignature(torch.fx.interpreter.Transformer):
    def __init__(
        self,
        m: torch.fx.GraphModule,
        flat_args: Tuple[Any],
        matched_input_elements_positions: List[int],
        flat_results: List[Any],
        matched_output_elements_positions: List[int],
        example_fake_inputs: List[torch.Tensor],
        flat_args_dynamic_dims: List[Set[int]],
        fake_mode: Optional[fake_tensor.FakeTensorMode] = None,
    ) -> None:
        super().__init__(m)

        assert len(flat_args_dynamic_dims) == len(flat_args)
        matched_input_elements_to_fake = {
            val: example_fake_inputs[ix]
            for ix, val in enumerate(matched_input_elements_positions)
        }

        self.new_args = []
        for i in range(0, len(flat_args)):
            arg = super().placeholder(f"arg{i}", (), {})
            if i in matched_input_elements_to_fake:
                arg.node.meta["val"] = matched_input_elements_to_fake[i]
            else:
                # Fill node.meta["val"] with a fake tensor from the input,
                # if it's not found in matched_input_elements_positions
                if fake_mode is not None and isinstance(flat_args[i], torch.Tensor):
                    # TODO(zhxchen17) Also preserve all the user constraints here.
                    arg.node.meta["val"] = fake_mode.from_tensor(
                        flat_args[i],
                        symbolic_context=StatelessSymbolicContext(
                            dynamic_sizes=[
                                DimDynamic.DYNAMIC
                                if d in flat_args_dynamic_dims[i]
                                else DimDynamic.STATIC
                                for d in range(len(flat_args[i].shape))
                            ],
                            constraint_sizes=[None] * len(flat_args[i].shape),
                        ),
                    )
            self.new_args.append(arg)
        self.old_args_gen = (self.new_args[i] for i in matched_input_elements_positions)
        self.matched_output_elements_positions = matched_output_elements_positions
        self.flat_results = flat_results

    def placeholder(self, target, args, kwargs):
        arg = next(self.old_args_gen)
        if "val" in self.current_node.meta:
            arg.node.meta["val"] = self.current_node.meta["val"]
        if "tensor_dict" in self.current_node.meta:
            arg.node.meta["tensor_dict"] = self.current_node.meta["tensor_dict"]
        if "example_value" in self.current_node.meta:
            # NB: intentionally do not use set_example_value
            arg.node.meta["example_value"] = self.current_node.meta["example_value"]
        if "unbacked_bindings" in self.current_node.meta:
            arg.node.meta["unbacked_bindings"] = self.current_node.meta[
                "unbacked_bindings"
            ]
        return arg

    def output(self, target, args, kwargs):
        dynamo_result_flat = args[0]
        lookup = [*dynamo_result_flat, *self.new_args]
        new_results_flat = []
        for i in range(len(self.flat_results)):
            if self.matched_output_elements_positions[i] is not None:
                new_results_flat.append(
                    lookup[self.matched_output_elements_positions[i]]
                )
            else:
                const_val = self.flat_results[i]
                assert isinstance(const_val, tuple(common_constant_types))
                new_results_flat.append(const_val)
        return super().output(target, (new_results_flat,), {})

    def run_node(self, n):
        self.current_node = n
        result_proxy = super().run_node(n)
        if "val" in self.current_node.meta:
            result_proxy.node.meta["val"] = self.current_node.meta["val"]
        if "example_value" in self.current_node.meta:
            # NB: intentionally do not use set_example_value
            result_proxy.node.meta["example_value"] = self.current_node.meta[
                "example_value"
            ]
        if "unbacked_bindings" in self.current_node.meta:
            result_proxy.node.meta["unbacked_bindings"] = self.current_node.meta[
                "unbacked_bindings"
            ]
        if self.current_node.op != "output":
            result_proxy.node._rename(
                getattr(self.current_node, "name", result_proxy.node.name)
            )
        return result_proxy

    def transform(self):
        result_gm = super().transform()
        if "dynamo_flat_name_to_original_fqn" in self.module.meta:
            result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[
                "dynamo_flat_name_to_original_fqn"
            ]
        return result_gm


class ExportResult(NamedTuple):
    graph_module: torch.fx.GraphModule
    guards: _guards.GuardsSet
    # NB: Do not add new fields without overriding __iter__; people are
    # destructuring so it is BC-breaking


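# Hedged usage sketch (illustrative only): export(f)(*args), defined later in
# this file, returns an ExportResult; since it is a NamedTuple it is commonly
# destructured into (graph_module, guards), which is why adding fields would be
# BC-breaking (see the NB above). The helper name _example_export_result is hypothetical.
def _example_export_result():
    import torch

    def fn(x):
        return x.sin() + x.cos()

    graph_module, guards = export(fn)(torch.randn(3))
    return graph_module, guards

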
def check_signature_rewritable(graph):
    input_errors = []
    for node in graph.graph.find_nodes(op="placeholder"):
        assert hasattr(node, "_dynamo_source")
        source = node._dynamo_source
        user_stacks = graph._source_to_user_stacks.get(source)
        if user_stacks is None:
            continue
        assert len(user_stacks) > 0
        # In some cases we may not have a useful stack. Look for a
        # useful stack
        stack = None
        for s in user_stacks:
            if len(s) == 0:
                continue
            stack = s
            break
        if stack is None:
            msg = f"{source.name()}, a closed over free variable"
        else:
            tb = "".join(traceback.format_list(stack))
            extra = ""
            if len(user_stacks) > 1:
                extra = f"(elided {len(user_stacks) - 1} more accesses)"
            msg = f"{source.name()}, accessed at:\n{tb}{extra}"
        # TODO: option to print ALL of the stack traces at once
        input_errors.append(msg)

    if input_errors:
        raise UserError(
            UserErrorType.INVALID_INPUT,
            "Cannot export model which references tensors that are neither "
            "buffers/parameters/constants nor are direct inputs. For each tensor, if you'd "
            "like this tensor to be an explicit input, add it as a dummy argument "
            "to the top-level model definition you are exporting; if you would "
            "like its value to be embedded as an exported constant, wrap its access "
            "in a function marked with @assume_constant_result.\n\n"
            + "\n\n".join(input_errors),
        )


def rewrite_signature(
    f_sig,
    graph,
    fake_mode,
    flat_args,
    in_spec,
    example_fake_inputs,
    graph_captured_input,
    graph_captured_output,
    dynamo_traced_result,
    flat_args_dynamic_dims,
):
    orig_args, orig_kwargs = pytree.tree_unflatten(flat_args, in_spec)

    def check_user_input_output(flat_values, error_type):
        supported_types = [
            torch.Tensor,
            torch.SymInt,
            torch.SymFloat,
            torch.SymBool,
            torch._C.ScriptObject,
        ] + list(common_constant_types)

        def is_supported_type(val):
            return isinstance(val, tuple(supported_types))

        value_type = "input" if error_type == UserErrorType.INVALID_INPUT else "output"
        # We only check that the outputs are not None. Inputs can be None.
        for v in flat_values:
            if not is_supported_type(v):
                if error_type == UserErrorType.INVALID_INPUT and v is None:
                    continue

                raise UserError(
                    error_type,
                    f"It looks like one of the {value_type}s with type `{type(v)}` "
                    "is not supported or pytree-flattenable. \n"
                    f"Exported graphs {value_type}s can only contain the "
                    f"following supported types: {supported_types}. \n"
                    "If you are using a custom class object, "
                    "please register a pytree_flatten/unflatten function "
                    "using `torch.utils._pytree.register_pytree_node` or "
                    "`torch.export.register_dataclass`.",
                )

    check_user_input_output(flat_args, UserErrorType.INVALID_INPUT)
    flat_results_traced, out_spec_traced = pytree.tree_flatten(dynamo_traced_result)
    check_user_input_output(flat_results_traced, UserErrorType.INVALID_OUTPUT)

    def check_optional_input_and_error(f_sig: inspect.Signature):
        # Check if function has optional input.
        for name, param in f_sig.parameters.items():
            if param.default is not inspect.Parameter.empty:
                from torch._dynamo.exc import Unsupported

                log.error(
                    "Parameter %s is optional with a default value of %s",
                    name,
                    param.default,
                )
                raise Unsupported(
                    "Tracing through optional input is not supported yet",
                    case_name="optional_input",
                )

    def produce_matching(debug_type, sources, candidates):
        matched_elements_positions: List[Optional[int]] = []
        dict_of_source_vals = {}
        for i, val in enumerate(sources):
            dict_of_source_vals[id(val)] = i

        for i, val in enumerate(candidates):
            if isinstance(val, tuple(common_constant_types)):
                matched_elements_positions.append(None)
            elif id(val) not in dict_of_source_vals:
                if debug_type == "inputs":
                    check_optional_input_and_error(f_sig)
                raise AssertionError(
                    f"Unexpectedly found a {type(val)} in the {debug_type}.\n"
                    'Please file an issue along with a paste of the logs from TORCH_LOGS="+export"',
                )
            else:
                matched_elements_positions.append(dict_of_source_vals[id(val)])

        return matched_elements_positions

    matched_input_elements_positions = produce_matching(
        "inputs", flat_args, graph_captured_input
    )

    assert graph_captured_output is not None
    matched_output_elements_positions = produce_matching(
        "outputs", list(graph_captured_output) + flat_args, flat_results_traced
    )

    new_graph = FlattenInputOutputSignature(
        graph,
        flat_args,
        matched_input_elements_positions,
        flat_results_traced,
        matched_output_elements_positions,
        example_fake_inputs,
        flat_args_dynamic_dims,
        fake_mode,
    ).transform()

    # Make the dynamo graph have the same input/output spec as the user code
    def argument_names(f_sig, args, kwargs) -> List[str]:
        def signature_to_fullargspec(sig: inspect.Signature):
            # Get a list of Parameter objects from the Signature object
            params = list(sig.parameters.values())
            # Separate positional arguments, keyword-only arguments and varargs/varkw
            args = [
                p.name
                for p in params
                if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
            ]
            kwonlyargs = [
                p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY
            ]
            varargs = next(
                (p.name for p in params if p.kind == inspect.Parameter.VAR_POSITIONAL),
                None,
            )
            varkw = next(
                (p.name for p in params if p.kind == inspect.Parameter.VAR_KEYWORD),
                None,
            )
            # Get default values for positional arguments and keyword-only arguments
            defaults = tuple(
                p.default
                for p in params
                if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
                and p.default is not inspect.Parameter.empty
            )
            kwonlydefaults = {
                p.name: p.default
                for p in params
                if p.kind == inspect.Parameter.KEYWORD_ONLY
                and p.default is not inspect.Parameter.empty
            }
            # Get annotations for parameters and return value
            annotations = {}
            if sig.return_annotation:
                annotations = {"return": sig.return_annotation}
            for parameter in params:
                annotations[parameter.name] = parameter.annotation
            # Return a FullArgSpec object with the extracted attributes
            return inspect.FullArgSpec(
                args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations
            )

        fullargspec = signature_to_fullargspec(f_sig)

        # 1. Map `args` 1-to-1 to positional arguments in original signature.
        input_strs = fullargspec.args[: len(args)]

        if len(args) > len(fullargspec.args):
            # 2. If there are more arguments left in `args`, they map to varargs in original
            # signature. Assign names as {varargs}_0, {varargs}_1, ...
            assert fullargspec.varargs is not None, "More arguments than expected"
            input_strs += [
                f"{fullargspec.varargs}_{i}"
                for i in range(0, len(args) - len(input_strs))
            ]
        elif len(args) < len(fullargspec.args):
            # 3. If there are fewer arguments in `args` than `fullargspec.args`,
            # it implies these are arguments either with default values, or provided in
            # `kwargs`. The former can be safely ignored, because Dynamo.export does not
            # export them as part of the function signature. The latter will be handled
            # in the next step.
            for unprovided_arg in fullargspec.args[
                len(args) : -len(fullargspec.defaults or [])
            ]:
                assert unprovided_arg in kwargs, f"Missing argument {unprovided_arg}"

        # 4. Keyword arguments provided in `kwargs`.
        input_strs += list(kwargs.keys())

        # 5. Keyword-only arguments with default values, if not provided, are not exported
        # as part of the function signature.
        for kwonly_arg in fullargspec.kwonlyargs:
            kwonlydefaults = fullargspec.kwonlydefaults or {}
            assert (
                kwonly_arg in kwargs or kwonly_arg in kwonlydefaults
            ), f"Missing keyword only argument {kwonly_arg}"

        return input_strs

    new_graph.graph._codegen = _PyTreeCodeGen(
        _PyTreeInfo(
            argument_names(f_sig, orig_args, orig_kwargs),
            in_spec,
            out_spec_traced,
        )
    )
    new_graph.recompile()
    return new_graph


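# Hedged sketch (illustrative only) of the Signature -> FullArgSpec mapping that
# argument_names() above performs by hand: positional args, *varargs, keyword-only
# args and their defaults are all recoverable from inspect. The helper name
# _example_signature_to_fullargspec is hypothetical.
def _example_signature_to_fullargspec():
    import inspect

    def f(a, b=1, *args, c, d=2, **kwargs):
        return a, b, args, c, d, kwargs

    sig = inspect.signature(f)
    params = list(sig.parameters.values())
    positional = [
        p.name for p in params if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
    ]
    kwonly = [p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY]
    assert positional == ["a", "b"] and kwonly == ["c", "d"]
    # For plain functions, the stdlib helper produces the same structure.
    return inspect.getfullargspec(f)

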
def export(
    f: Callable[..., Any],
    *extra_args,
    aten_graph: bool = False,
    pre_dispatch: bool = False,
    decomposition_table: Optional[
        Dict[torch._ops.OpOverload, Callable[..., Any]]
    ] = None,
    tracing_mode: str = "symbolic",
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
    assume_static_by_default: bool = False,
    same_signature: bool = True,
    disable_constraint_solver: bool = False,
    prefer_deferred_runtime_asserts_over_guards: bool = False,
    allow_complex_guards_as_runtime_asserts: bool = False,
    _log_export_usage: bool = True,
    **extra_kwargs,
) -> Callable[..., ExportResult]:
    """
    Export an input function f to a format that can be executed outside of PyTorch using the FX graph.

    Args:
        f (callable): A PyTorch function to be exported.

        aten_graph (bool): If True, exports a graph with ATen operators.
        If False, exports a graph with Python operators. Default is False.

        pre_dispatch (bool): If True, exports a graph with ATen operators,
        but before any logic in the PyTorch dispatcher has run.
        This can be useful if you want to apply further transformations on a graph before running it
        through autograd, autocast, or any other functionalities that are integrated into the dispatcher.
        This flag is only valid if aten_graph=True is set.
        Default is False.

        decomposition_table (dict): A dictionary that maps operators to their decomposition functions.
        Required if aten_graph or tracing_mode is specified. Default is None.

        tracing_mode (str): If "symbolic", turn on dynamic shapes support. Default is "symbolic".

        dynamic_shapes:
         An optional argument where the type should either be:
         1) a dict from argument names of ``f`` to their dynamic shape specifications,
         2) a tuple that specifies dynamic shape specifications for each input in original order.
         If you are specifying dynamism on keyword args, you will need to pass them in the order that
         is defined in the original function signature.

         The dynamic shape of a tensor argument can be specified as either
         (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
         not required to include static dimension indices in this dict, but when they are,
         they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
         where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
         are denoted by None. Arguments that are dicts or tuples / lists of tensors are
         recursively specified by using mappings or sequences of contained specifications.

        same_signature (bool): If True, rewrite the returned graph's signature to be the same as f.

        disable_constraint_solver (bool): Whether the dim constraint solver must be disabled.

    Returns:
        A function that, given args and kwargs, returns a tuple of (graph, guards)
        Graph: An FX graph representing the execution of the input PyTorch function with the provided arguments and options.
        Guards: The guards we accumulated during tracing f above

    Raises:
        AssertionError: If decomposition_table is specified without setting aten_graph=True,
        or if graph breaks during tracing in export.

        AssertionError: If Dynamo input and output are not consistent with traced input/output.

    Note - this headerdoc was authored by ChatGPT, with slight modifications by the author.
    """
    if _log_export_usage:
        log_export_usage(event="export.private_api", flags={"_dynamo"})

    # Deal with "local variable referenced before assignment"
    _f = f
    _assume_static_by_default = assume_static_by_default

    def inner(*args, **kwargs):
        combined_args = _combine_args(_f, args, kwargs)
        constraints = _process_dynamic_shapes(combined_args, dynamic_shapes)
        f = _f
        assume_static_by_default = _assume_static_by_default
        check_if_dynamo_supported()
        torch._C._log_api_usage_once("torch._dynamo.export")
        if decomposition_table is not None:
            assert (
                aten_graph
            ), "Specifying a decomposition_table or tracing mode is illegal without setting aten_graph=True"
        if pre_dispatch:
            assert aten_graph, "pre_dispatch=True can only be used when aten_graph=True"
        f = innermost_fn(f)
        call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f
        original_signature = inspect.signature(call_to_inspect)
        graph = None
        out_guards = None
        graph_captured_input = None
        graph_captured_result: Optional[Tuple[torch.Tensor, ...]] = None
        fake_mode = None
        result_traced = None

        def guard_export_print(guards: _guards.GuardsSet):
            nonlocal out_guards
            assert (
                out_guards is None
            ), "whole graph export entails exactly one guard export"
            out_guards = guards

        example_inputs = []

        def dynamo_normalization_capturing_compiler(
            gm: torch.fx.GraphModule, inner_example_inputs
        ):
            nonlocal graph
            assert (
                graph is None
            ), "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph."
            graph = gm

            nonlocal fake_mode, example_inputs
            # NB: do NOT pass inner_example_inputs here, we are detecting the
            # Dynamo allocated fake mode, which should be DISTINCT from a
            # potential outer ambient fake mode which the user provided.
            # example_inputs is always the user specified inputs, so they
            # would have the wrong fake mode attached to them
            fake_mode = _guards.detect_fake_mode()
            example_inputs = inner_example_inputs

            def result_capturing_wrapper(*graph_inputs):
                nonlocal graph_captured_result
                nonlocal graph_captured_input

                graph_captured_input = graph_inputs
                assert graph is not None

                named_parameters = dict(graph.named_parameters(remove_duplicate=False))
                named_buffers = dict(graph.named_buffers(remove_duplicate=False))

                ambient_fake_mode = (
                    _guards.detect_fake_mode(graph_inputs)
                    if _guards.detect_fake_mode(graph_inputs) is not None
                    else fake_mode
                )

                # We reran fake tensor propagation, but we didn't do
                # anything with the resulting unbacked SymInts.  Drop them
                # from the pending list.
                # NB: this is wrong if graph_captured_result has
                # data-dependent output size!
                ignore_fresh_unbacked = null_context()
                if shape_env := ambient_fake_mode.shape_env:
                    ignore_fresh_unbacked = shape_env.ignore_fresh_unbacked_symbols()

                with (
                    ambient_fake_mode
                ), enable_python_dispatcher(), ignore_fresh_unbacked:
                    params_and_buffers = {
                        **named_parameters,
                        **named_buffers,
                    }
                    fake_params_buffers = {}

                    for name, value in params_and_buffers.items():
                        fake_params_buffers[name] = ambient_fake_mode.from_tensor(
                            value, static_shapes=True
                        )

                    fake_graph_inputs = pytree.tree_map(
                        ambient_fake_mode.from_tensor, graph_inputs
                    )
                    graph_captured_result = torch.func.functional_call(
                        graph, fake_params_buffers, fake_graph_inputs
                    )

                return graph_captured_result

            return result_capturing_wrapper

        # Note: This is needed by rewrite_signature. We need to put it before
        # optimize_assert since the user program may mutate the inputs.
        flat_args, in_spec = pytree.tree_flatten((args, kwargs))

        remove_from_cache(f)
        constraint_violation_error = None
        if tracing_mode != "symbolic":
            assume_static_by_default = True
        with config.patch(
            specialize_int=True,
            assume_static_by_default=assume_static_by_default,
            automatic_dynamic_shapes=False,
            capture_dynamic_output_shape_ops=True,
            capture_scalar_outputs=True,
            prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
            allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
        ):
            opt_f = optimize_assert(
                dynamo_normalization_capturing_compiler,
                hooks=Hooks(
                    guard_export_fn=guard_export_print,
                    guard_fail_fn=None,
                ),
                export=True,
                export_constraints=constraints,
            )(f)
            # TODO(voz): We may have instances of `f` that mutate inputs; we should track side effects and reject them.
            try:
                result_traced = opt_f(*args, **kwargs)
            except ConstraintViolationError as e:
                constraint_violation_error = e
        remove_from_cache(f)

        if (
            not disable_constraint_solver
            and (shape_env := getattr(fake_mode, "shape_env", None)) is not None
            and (dim_constraints := shape_env.dim_constraints) is not None
            and not isinstance(
                call_to_inspect, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)
            )
            and not trace_rules.check(call_to_inspect)
        ):
            dim_constraints.solve()
            forced_specializations = dim_constraints.forced_specializations()
            msg = dim_constraints.prettify_results(
                original_signature,
                dynamic_shapes,
                constraint_violation_error,
                forced_specializations,
            )
            if constraint_violation_error:
                constraint_violation_error.args = (
                    constraint_violation_error.args[0] + msg,
                )
            else:
                if forced_specializations:
                    constraint_violation_error = ConstraintViolationError(msg)
                else:
                    log.info(
                        "Summary of dimension constraints:%s",
                        msg,
                    )

            # Error if we have any constraints on static values
            for k in shape_env.var_to_range.keys():
                if isinstance(k, sympy.Integer):
                    constraint_violation_error = ConstraintViolationError(
                        f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
                        "It appears that you're trying to set a constraint on a "
                        f"value which we evaluated to have a static value of {k}. "
                        'Set TORCH_LOGS="+export" for more information.'
                    )
        if constraint_violation_error:
            raise constraint_violation_error

        if graph is None:
            assert (
                same_signature
            ), "Failed to produce a graph during tracing as no tensor operations were found and same_signature is False."
            # If the module does not contain any tensor computation, we create a graph with inputs and outputs.
            # To be consistent with the graph traced by dynamo, `graph` will have only tensor inputs as placeholders
            # and tensor outputs as output nodes. Non-tensor inputs and outputs will be added when rewriting the signature.
            # We will also construct the `example_inputs`, `graph_captured_input`, and `graph_captured_result` corresponding
            # to `graph`.
            example_inputs = []
            graph_captured_input = ()
            graph_captured_result = ()
            fake_mode = torch._subclasses.FakeTensorMode(
                shape_env=ShapeEnv(), export=True
            )
            if out_guards is None:
                out_guards = _guards.GuardsSet()
            assert out_guards is not None  # suppress mypy error
            parameter_names = list(original_signature.parameters.keys())
            fx_graph = torch.fx.Graph()
            for i, name in enumerate(parameter_names):
                if torch.is_tensor(flat_args[i]):
                    node = fx_graph.placeholder(name)
                    node.meta["val"] = fake_mode.from_tensor(
                        flat_args[i], static_shapes=True
                    )
                    graph_captured_input = graph_captured_input + (flat_args[i],)
                    example_inputs.append(flat_args[i])
            fx_graph.output(graph_captured_result)
            module = torch.nn.Module()
            graph = torch.fx.GraphModule(module, fx_graph)
            log.info(
                "Failed to capture a graph during tracing as no tensor operations were found:\n\n%s",
                graph.print_readable(print_output=False, colored=True),
            )
        else:
            assert hasattr(graph, "_source_to_user_stacks")
            assert out_guards is not None, "Failed to produce guards during tracing"
            assert fake_mode is not None

            log.info(
                "Dynamo captured graph:\n\n%s",
                graph.print_readable(print_output=False, colored=True),
            )

            # This check needs to happen before the aten_graph pass below,
            # because placeholder's _source_node attribute is not preserved by make_fx
            if same_signature:
                check_signature_rewritable(graph)

        # NB: This is mostly hitting the cache; Dynamo already converted these
        example_fake_inputs = [fake_mode.from_tensor(t) for t in example_inputs]

        if aten_graph:
            # Running graph with interpreter is needed for propagating the stack_trace
            def graph_with_interpreter(*args):
                with torch.fx.traceback.preserve_node_meta():
                    return torch.fx.Interpreter(graph).run(*args)  # type: ignore[arg-type]

            with unset_fake_temporarily(), enable_python_dispatcher(), fake_mode:
                try:
                    graph = make_fx(
                        graph_with_interpreter,
                        decomposition_table=decomposition_table,
                        tracing_mode="real",
                        _allow_non_fake_inputs=True,
                        pre_dispatch=pre_dispatch,
                        _allow_fake_constant=False,
                    )(*example_fake_inputs)
                except CondOpArgsMismatchError as e:
                    # Wrap the internal error to the user-facing error
                    raise UserError(  # noqa: B904
                        UserErrorType.DYNAMIC_CONTROL_FLOW,
                        str(e),
                        case_name="cond_operands",
                    )

        assert graph is not None
        for node in graph.graph.find_nodes(op="get_attr"):
            if isinstance(getattr(graph, node.target), torch.Tensor):  # type: ignore[arg-type]
                node.meta["val"] = fake_mode.from_tensor(
                    getattr(graph, node.target), static_shapes=True  # type: ignore[arg-type]
                )

        if same_signature:
            flat_args_dynamic_dims = [
                {
                    c.dim
                    for c in (constraints or ())
                    if (
                        c.t_id == id(x)
                        and c.constraint_range.vr.lower != c.constraint_range.vr.upper
                    )
                }
                for x in flat_args
            ]
            graph = rewrite_signature(
                original_signature,
                graph,
                fake_mode,
                flat_args,
                in_spec,
                example_fake_inputs,
                graph_captured_input,
                graph_captured_result,
                result_traced,  # type: ignore[possibly-undefined]
                flat_args_dynamic_dims,
            )
        return ExportResult(graph, out_guards)  # type: ignore[arg-type]

    if extra_args or extra_kwargs:
        warnings.warn(
            "export(f, *args, **kwargs) is deprecated, use export(f)(*args, **kwargs) instead. "
            "If you don't migrate, we may break your export call in the future if your user defined kwargs "
            "conflict with future kwargs added to export(f).",
            FutureWarning,
            stacklevel=2,
        )
        return inner(*extra_args, **extra_kwargs)
    else:
        return inner


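# Hedged usage sketch (illustrative only) of the dynamic_shapes argument
# documented above: mark dim 0 of argument `x` as dynamic with torch.export.Dim
# and leave the other dims static. The helper name _example_export_dynamic_shapes
# is hypothetical.
def _example_export_dynamic_shapes():
    import torch
    from torch.export import Dim

    def fn(x):
        return x * 2

    batch = Dim("batch")
    graph_module, guards = export(fn, dynamic_shapes={"x": {0: batch}})(
        torch.randn(4, 3)
    )
    return graph_module, guards

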
def optimize_assert(
    backend,
    *,
    hooks=Hooks(None, None),
    export=False,
    export_constraints=None,
    dynamic=None,
    rebuild_ctx=None,
):
    """
    The same as `torch._dynamo.optimize(backend, nopython=True)`
    """
    backend = get_compiler_fn(backend)

    # Find if backend has any extra context manager
    backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context)

    return _optimize_catch_errors(
        convert_frame.convert_frame_assert(
            backend, export=export, export_constraints=export_constraints
        ),
        hooks,
        backend_ctx_ctor,
        export=export,
        dynamic=dynamic,
        rebuild_ctx=rebuild_ctx,
    )


class TorchPatcher:
    @staticmethod
    @functools.lru_cache(None)
    def patch():
        # A better way to disable the following would be to decorate the source
        # functions with @torch._disable_dynamo. However, this causes issues
        # with torch.deploy internally.
        from .decorators import disable

        torch.jit.trace = disable(torch.jit.trace)
        torch.jit.trace_module = disable(torch.jit.trace_module)
        torch.jit._get_trace_graph = disable(torch.jit._get_trace_graph)
        torch.fx._symbolic_trace.Tracer.trace = disable(
            torch.fx._symbolic_trace.Tracer.trace
        )
        torch.distributions.Distribution.set_default_validate_args(False)

        from torch.optim import (
            adadelta,
            adagrad,
            adam,
            adamax,
            adamw,
            asgd,
            lbfgs,
            nadam,
            radam,
            rmsprop,
            rprop,
            sgd,
            sparse_adam,
        )

        optimizer_modules = {
            adadelta,
            adagrad,
            adam,
            adamax,
            adamw,
            asgd,
            lbfgs,
            nadam,
            radam,
            rmsprop,
            rprop,
            sgd,
            sparse_adam,
        }

        for opt_mod in optimizer_modules:
            opt_name = opt_mod.__name__.split(".")[-1]
            fused_fn_name = f"_fused_{opt_name}"
            single_tensor_fn_name = f"_single_tensor_{opt_name}"

            if hasattr(opt_mod, fused_fn_name):
                setattr(
                    opt_mod, fused_fn_name, disable(getattr(opt_mod, fused_fn_name))
                )

        optimizer_classes = [
            opt
            for opt in torch.optim.__dict__.values()
            if inspect.isclass(opt) and issubclass(opt, torch.optim.Optimizer)
        ]

        # Note: we don't support sparsity or tracing through backwards
        excluded_optimizer_classes = {
            torch.optim.SparseAdam,
            torch.optim.LBFGS,
        }

        for opt in optimizer_classes:
            if opt in excluded_optimizer_classes:
                opt.step = disable(opt.step)

            if hasattr(opt, "_init_group"):
                opt._init_group = disable(opt._init_group)

    @staticmethod
    def suppress_torch_distributed_warnings(fn):
        def inner_fn(*args, **kwargs):
            warnings.filterwarnings(
                "ignore", category=UserWarning, module="torch.distributed"
            )
            return fn(*args, **kwargs)

        return inner_fn
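

# Hedged usage sketch (illustrative only) for optimize_assert defined above: it
# behaves like optimize(backend, nopython=True), so any graph break raises
# instead of silently falling back to eager. The helper name
# _example_optimize_assert is hypothetical.
def _example_optimize_assert():
    import torch

    @optimize_assert("eager")
    def fn(x):
        return x.sin().cos()

    return fn(torch.randn(4))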