# mypy: allow-untyped-defs
import getpass
import inspect
import os
import re
import sys
import tempfile
from os.path import abspath, dirname
from typing import Any, Callable, Dict, Optional, Set, Type, TYPE_CHECKING, Union

import torch


def is_fbcode():
    return not hasattr(torch.version, "git_version")


# To configure logging for dynamo, aot, and inductor,
# use the following API in the torch._logging module:
# torch._logging.set_logs(dynamo=<level>, aot=<level>, inductor=<level>)
# or use the environment variable TORCH_LOGS="dynamo,aot,inductor" (use a "+" prefix to indicate higher verbosity)
# See this design doc for more detailed info:
# Design doc: https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit#

# The name of a file to write the logs to
# [@compile_ignored: debug]
log_file_name: Optional[str] = None

# [@compile_ignored: debug] Verbose will print full stack traces on warnings and errors
verbose = os.environ.get("TORCHDYNAMO_VERBOSE", "0") == "1"

# [@compile_ignored: runtime_behaviour] verify the correctness of optimized backend
verify_correctness = False

# need this many ops to create an FX graph
minimum_call_count = 1

# turn on/off DCE pass
dead_code_elimination = True

# disable (for a function) when cache reaches this size

# controls the maximum number of cache entries with a guard on same ID_MATCH'd
# object. It also controls the maximum size of cache entries if they don't have
# any ID_MATCH'd guards.
# [@compile_ignored: runtime_behaviour]
cache_size_limit = 8

# [@compile_ignored: runtime_behaviour] safeguarding to prevent horrible recomps
accumulated_cache_size_limit = 256

# [@compile_ignored: runtime_behaviour] skip tracing recursively if cache limit is hit
skip_code_recursive_on_cache_limit_hit = True

# Whether or not to specialize on int inputs. This only has an effect with
# dynamic_shapes; when dynamic_shapes is False, we ALWAYS specialize on int
# inputs. Note that assume_static_by_default will also cause ints to get
# specialized, so this is mostly useful for export, where we want inputs
# to be dynamic, but accesses to ints should NOT get promoted into inputs.
specialize_int = False

# Whether or not to specialize on float inputs. Dynamo will always promote
# float inputs into Tensor inputs, but at the moment, backends inconsistently
# support codegen on float (this is to be fixed).
specialize_float = True

# legacy config, does nothing now!
dynamic_shapes = True

use_lazy_graph_module = (
    os.environ.get("TORCH_COMPILE_USE_LAZY_GRAPH_MODULE", "1") == "1"
)

# This is a temporary flag, which changes the behavior of dynamic_shapes=True.
# When assume_static_by_default is True, we only allocate symbols for shapes marked dynamic via mark_dynamic.
# NOTE - this flag can be removed once we can run dynamic_shapes=False w/ the mark_dynamic API
# see [Note - on the state of mark_dynamic]
assume_static_by_default = True

# This flag changes how dynamic_shapes=True works, and is meant to be used in conjunction
# with assume_static_by_default=True.
# With this flag enabled, we always compile a frame as fully static the first time, and, if we fail
# any guards due to wobbles in shape, we recompile with *all* the wobbled shapes marked as dynamic.
automatic_dynamic_shapes = True
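
# Example (a minimal sketch of how the knobs above are typically adjusted;
# the specific values are illustrative, not recommendations):
#
#   import logging
#   import torch
#   import torch._dynamo.config as dynamo_config
#
#   dynamo_config.cache_size_limit = 16           # allow a few more recompiles per code object
#   torch._logging.set_logs(dynamo=logging.INFO)  # see the logging note at the top of this file
#
#   x = torch.randn(8, 32)
#   torch._dynamo.mark_dynamic(x, 0)              # allocate a symbol for dim 0 even with
#                                                 # assume_static_by_default=True
#   torch.compile(lambda t: t * 2)(x)
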
# This flag changes how the shapes of parameters are treated.
# If this flag is set to True, then the shapes of torch.nn.Parameter, as well as of torch.Tensor, are attempted to be dynamic.
# If this flag is set to False, then the shapes of torch.nn.Parameter are assumed to be static,
# while the shapes of torch.Tensor are assumed to be dynamic.
force_parameter_static_shapes = True

# This flag ensures that the shapes of an nn module are always assumed to be static.
# If the flag is set to True, then the shapes of an nn.Module are assumed to be static.
# If the flag is set to False, then the shapes of an nn.Module can be dynamic.
force_nn_module_property_static_shapes = True

# Typically, if you mark_dynamic a dimension, we will error if the dimension
# actually ended up getting specialized. This knob changes the behavior so
# that we don't error at all. This is helpful for our CI, where we use a
# heuristic to mark batch dimensions as dynamic and the heuristic may get it
# wrong.
allow_ignore_mark_dynamic = False

# Set this to False to assume nn.Modules() contents are immutable (similar assumption as freezing)
guard_nn_modules = True

# Uses CPython internal dictionary tags to detect mutation. There is some
# overlap between guard_nn_modules_using_dict_tags and the guard_nn_modules flag.
# guard_nn_modules unspecializes the nn module instance and adds a guard for each
# relevant member of the nn modules. On the other hand,
# guard_nn_modules_using_dict_tags specializes on each nn module instance but
# uses low-overhead dict version matching to detect mutations, obviating the
# need to guard on members of the nn modules. With
# guard_nn_modules_using_dict_tags, guard_nn_modules is not really required,
# but it is kept around for debugging and discussing unspecializing nn module
# variables.
# TODO(janimesh, voz): Remove both of these flags (or at least guard_nn_modules)
# once we have reached stability for guard_nn_modules_using_dict_tags.
guard_nn_modules_using_dict_tags = True

# This feature doesn't really work. We offer this flag for experimental
# purposes / if you want to help us build out support.
#
# torchdynamo has limited support for tensor subclasses that implement
# __torch_function__; see [Note: __torch_function__] in torch_function.py.
# Our current support is limited to tensor subclasses
# that DO NOT store metadata on the tensor (in general, dynamo does not
# support Python code that stores extra attributes on tensors at present).
# If your tensor subclass purely changes function call behavior via
# __torch_function__, you can allow torchdynamo to trace into it by
# adding it to traceable_tensor_subclasses. We don't do any safety checks,
# so it is up to you to ensure that your subclass is well behaved. See also
# https://github.com/pytorch/torchdynamo/issues/1948
#
# We do NOT currently support __torch_dispatch__. The implementation is
# currently buggy; the main showstopper for nontrivial use is
# https://github.com/pytorch/torchdynamo/issues/1952
traceable_tensor_subclasses: Set[Type[Any]] = set()
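
# Example (sketch): a metadata-free __torch_function__ subclass could be
# registered for tracing as shown below. `ScaledTensor` is a hypothetical
# name used only for illustration; remember this feature is experimental.
#
#   import torch
#   import torch._dynamo.config as dynamo_config
#
#   class ScaledTensor(torch.Tensor):
#       @classmethod
#       def __torch_function__(cls, func, types, args=(), kwargs=None):
#           # purely changes call behavior; stores no extra attributes on the tensor
#           return super().__torch_function__(func, types, args, kwargs or {})
#
#   dynamo_config.traceable_tensor_subclasses.add(ScaledTensor)
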
# Suppress errors in torch._dynamo.optimize, instead forcing a fallback to eager.
# This is a good way to get your model to work one way or another, but you may
# lose optimization opportunities this way. Devs, if your benchmark model is failing
# this way, you should figure out why instead of suppressing it.
suppress_errors = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False))

# Record and write an execution record of the current frame to a file
# if an exception is encountered
# [@compile_ignored: debug]
replay_record_enabled = os.environ.get("TORCH_COMPILE_REPLAY_RECORD", "0") == "1"

# Rewrite assert statements in Python with torch._assert
rewrite_assert_with_torch_assert = True

# Disable dynamo
disable = os.environ.get("TORCH_COMPILE_DISABLE", False)

# [@compile_ignored: runtime_behaviour] Get a cprofile trace of Dynamo
cprofile = os.environ.get("TORCH_COMPILE_CPROFILE", False)

# legacy config, does nothing now!
skipfiles_inline_module_allowlist: Dict[Any, Any] = {}

# If a string representing a PyTorch module is in this ignorelist,
# the `allowed_functions.is_allowed` function will not consider it
# when creating a list of PyTorch functions that will appear in
# FX IR.
allowed_functions_module_string_ignorelist = {
    "torch.distributions",
    "torch.testing",
    "torch._refs",
    "torch._prims",
    "torch._decomp",
}

# Debug flag to try the minifier at different stages. Possible values are {None, "aot", "dynamo"}
# None - Minifier is switched off
# dynamo - Runs minifier on the TorchDynamo produced graphs, if compilation fails
# aot - Runs minifier on the AOT Autograd produced graphs, if compilation fails
# [@compile_ignored: debug]
repro_after = os.environ.get("TORCHDYNAMO_REPRO_AFTER", None)

# Compiler compilation debug info
# 1: Dumps the original graph out to repro.py if compilation fails
# 2: Dumps a minifier_launcher.py if compilation fails.
# 3: Always dumps a minifier_launcher.py. Good for segfaults.
# 4: Dumps a minifier_launcher.py if the accuracy check fails.
# [@compile_ignored: debug]
repro_level = int(os.environ.get("TORCHDYNAMO_REPRO_LEVEL", 2))

# By default, we try to detect accuracy failure by running both forward
# and backward of a torchdynamo produced graph (if you are using repro_after
# 'dynamo'). This setting forces us to only test the forward graph and
# not the backward graph. This can be helpful if you're trying to debug
# an inference-only problem, but the minifier seems to be choking on the
# backwards step.
# TODO: Detect this situation automatically so the user doesn't need
# to manually configure this
# [@compile_ignored: debug]
repro_forward_only = os.environ.get("TORCHDYNAMO_REPRO_FORWARD_ONLY") == "1"

# The tolerance we should use when testing if a compiled graph
# has diverged so that we should treat it as an accuracy failure
# [@compile_ignored: debug]
repro_tolerance = 1e-3


# Whether to ignore non-floating point values when checking accuracy.
# Checking accuracy of non-floating point values such as boolean tensors
# can lead to false positives.
# [@compile_ignored: debug]
repro_ignore_non_fp = os.environ.get("TORCHDYNAMO_REPRO_IGNORE_NON_FP") == "1"

# If True, when testing if two models are the same, we will test them against
# a third fp64 reference and only report a problem if the RMSE relative to the
# fp64 reference is greater. However, this will use more memory; you may disable this
# if memory usage is too high.
# [@compile_ignored: runtime_behaviour]
same_two_models_use_fp64 = True
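
# Example (sketch): the repro_* knobs above are usually driven through the
# environment before the failing script runs, e.g. (shell):
#
#   TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4 python train.py
#
# or, since the defaults are read at import time, by setting the config
# attributes from Python before compilation starts:
#
#   import torch._dynamo.config as dynamo_config
#   dynamo_config.repro_after = "aot"   # minify AOT Autograd graphs on failure
#   dynamo_config.repro_level = 4       # dump minifier_launcher.py on accuracy failure
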
# Not all backends support scalars. Some calls on torch.Tensor (like .item()) return a scalar type.
# When this flag is set to False, we introduce a graph break instead of capturing.
# This requires dynamic_shapes to be True.
capture_scalar_outputs = os.environ.get("TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS") == "1"

# Not all backends support operators that have dynamic output shape (e.g.,
# nonzero, unique). When this flag is set to False, we introduce a graph
# break instead of capturing. This requires dynamic_shapes to be True.
# If you set this to True, you probably also want capture_scalar_outputs
# (these are separated for historical reasons).
capture_dynamic_output_shape_ops = (
    os.environ.get("TORCHDYNAMO_CAPTURE_DYNAMIC_OUTPUT_SHAPE_OPS", "0") == "1"
)

# hybrid backed unbacked symints
prefer_deferred_runtime_asserts_over_guards = False

# For complex dynamic shapes guards that we're unable to specify with dynamo/export's
# range constraints + dims + derived dims language, we raise constraint violation
# errors or specialize by default. If set to True, this flag avoids crashing/specialization,
# and allows complex guards as runtime assertions in the graph.
allow_complex_guards_as_runtime_asserts = False

# By default, dynamo will treat all ints as backed SymInts, which means (1) it
# will wait to see the int change over multiple runs before generalizing and
# (2) it will still always 0/1 specialize an int. When true, this knob
# forces dynamo to treat _length_per_key and _offset_per_key on
# KeyedJaggedTensor from torchrec as size-like unbacked SymInts, so that
# they (1) generalize immediately and (2) unsoundly never compare equal to
# 0/1. This is not on by default as AOTAutograd/Inductor cannot currently
# compile this code; however, this can be useful for export.
force_unspec_int_unbacked_size_like_on_torchrec_kjt = False

# Should almost always be true in prod. This relaxes the requirement that cond's true_fn and
# false_fn produce code with identical guards.
enforce_cond_guards_match = True

# Specify how to optimize a compiled DDP module. The flag accepts a boolean
# value or a string. There are 4 modes.
# 1. "ddp_optimizer" (or True): with "ddp_optimizer", Dynamo will automatically
#    split the model graph into pieces to match DDP bucket sizes to allow DDP
#    comm/compute overlap.
# 2. "python_reducer" (experimental): this optimization requires the usage
#    of compiled_autograd. With "python_reducer", DDP will disable the C++ reducer
#    and use the Python reducer to allow compiled_autograd to trace the
#    communication and allow comm/compute overlap without graph-breaks.
# 3. "python_reducer_without_compiled_forward" (experimental): this mode is
#    similar to "python_reducer". One should only use this optimization mode
#    when compiled_autograd is used but the DDP module is not compiled.
# 4. "no_optimization" (or False): Dynamo won't split the model graph, nor
#    will the Python reducer be used. With this mode, there will be no graph-breaks
#    and the original DDP C++ reducer will be used. There will be no comm/compute
#    overlap. This mode CANNOT be used with compiled_autograd.
# Note that to avoid breaking the existing usage, mode 1 and mode 4 can be
# specified with a boolean value: True means "ddp_optimizer" and False means
# "no_optimization".
optimize_ddp: Union[bool, str] = True
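
# Example (a minimal sketch of the capture_scalar_outputs flag above): by
# default a call like `.item()` causes a graph break; with the flag enabled,
# dynamo captures the scalar as an unbacked symbol instead.
#
#   import torch
#   import torch._dynamo.config as dynamo_config
#
#   dynamo_config.capture_scalar_outputs = True
#
#   @torch.compile(fullgraph=True)
#   def f(x):
#       scale = x.sum().item()   # captured instead of graph-breaking
#       return x * scale
#
#   f(torch.tensor([2.0, 3.0]))
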
# By default, Dynamo emits runtime asserts (e.g. torch._check, torch._check_is_size) in the graph.
# In some cases those asserts can be costly at runtime;
# e.g. torch._check(tensor[0].item() > 2) for a tensor on CUDA will require a CUDA sync.
# Setting this to True keeps the asserts as hints for the symbolic shapes engine,
# but they are not emitted in the graph.
do_not_emit_runtime_asserts: bool = (
    os.environ.get("TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS", "0") == "1"
)

_ddp_optimization_mode = [
    "ddp_optimizer",
    "python_reducer",  # experimental mode
    "python_reducer_without_compiled_forward",  # experimental mode
    "no_optimization",
]


def _get_optimize_ddp_mode():
    m = sys.modules[__name__]
    if isinstance(m.optimize_ddp, bool):
        if m.optimize_ddp:
            mode = "ddp_optimizer"
        else:
            mode = "no_optimization"
    elif isinstance(m.optimize_ddp, str):
        mode = m.optimize_ddp
    else:
        raise ValueError(f"Invalid type, {type(m.optimize_ddp)=}")

    assert mode in m._ddp_optimization_mode, f"Invalid mode {mode=}"
    return mode


# Skip tracing the torchrec files added to trace_rules.FBCODE_SKIP_DIRS
skip_torchrec = True


# No longer used
optimize_ddp_lazy_compile = False

# Whether to skip guarding on FSDP-managed modules
skip_fsdp_guards = True
# Whether to apply torch._dynamo.disable() to FSDP2 hooks.
# Defaults to True. If Traceable FSDP2 is used, set this to False.
skip_fsdp_hooks = True

# Make dynamo skip guarding on hooks on nn modules
# Note: unsafe: if your model actually has hooks and you remove them, or doesn't and you add them,
# dynamo will not notice and will execute whichever version you first compiled.
skip_nnmodule_hook_guards = True

# If True, raises an exception if TorchDynamo is called with a context manager
raise_on_ctx_manager_usage = True

# If True, raise when aot autograd is unsafe to use
raise_on_unsafe_aot_autograd = False

# If true, error if you torch.jit.trace over a dynamo-optimized function.
# If false, silently suppress dynamo.
error_on_nested_jit_trace = True

# If true, error with a better message if we symbolically trace over a
# dynamo-optimized function. If false, silently suppress dynamo.
error_on_nested_fx_trace = True

# Disables graph breaking on rnn. YMMV with backends.
allow_rnn = False

# If true, enables a feature that captures PyTorch sparsity in the
# exported FX graph. This flag should become the default eventually
# and be removed, but currently provides a way to fall back to the old
# graph breaking behavior.
capture_sparse_compute = False if is_fbcode() else True

# If true, error if we try to compile a function that has
# been seen before.
# [@compile_ignored: runtime_behaviour]
error_on_recompile = False

# [@compile_ignored: debug] Whether to report any guard failures (deprecated: does not do anything)
report_guard_failures = True

# [@compile_ignored: debug] root folder of the project
base_dir = dirname(dirname(dirname(abspath(__file__))))

# Trace through NumPy or graphbreak
trace_numpy = True
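
# Example (sketch): with trace_numpy=True, dynamo can trace through NumPy calls
# inside a compiled function instead of graph-breaking; the numpy_default_*
# settings below control the default dtypes of the traced arrays.
#
#   import numpy as np
#   import torch
#
#   @torch.compile
#   def f(a: np.ndarray, b: np.ndarray) -> np.ndarray:
#       return np.sin(a) + np.cos(b)   # traced through torch's NumPy support
#
#   f(np.ones(4), np.ones(4))
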
# Default NumPy dtypes when tracing with torch.compile.
# We default to 64 bits. For efficiency, one may want to change these to float32.
numpy_default_float = "float64"
numpy_default_complex = "complex128"
numpy_default_int = "int64"

# use numpy's PRNG if True, pytorch otherwise
use_numpy_random_stream = False

# Use C++ guard manager
enable_cpp_guard_manager = os.environ.get("TORCHDYNAMO_CPP_GUARD_MANAGER", "1") == "1"

# Inline inbuilt nn modules
inline_inbuilt_nn_modules = not is_fbcode()

# When set, the total compile time instruction count is recorded using
# torch._dynamo.utils.CompileTimeInstructionCounter.
record_compile_time_instruction_count = False


def default_debug_dir_root():
    # [@compile_ignored: debug]
    DEBUG_DIR_VAR_NAME = "TORCH_COMPILE_DEBUG_DIR"
    if DEBUG_DIR_VAR_NAME in os.environ:
        return os.path.join(os.environ[DEBUG_DIR_VAR_NAME], "torch_compile_debug")
    elif is_fbcode():
        return os.path.join(
            tempfile.gettempdir(), getpass.getuser(), "torch_compile_debug"
        )
    else:
        return os.path.join(os.getcwd(), "torch_compile_debug")


# [@compile_ignored: debug]
debug_dir_root = default_debug_dir_root()

# [@compile_ignored: debug]
_save_config_ignore = {
    "repro_after",
    "repro_level",
    # workaround: "cannot pickle PyCapsule"
    "constant_functions",
    # workaround: "cannot pickle module"
    "skipfiles_inline_module_allowlist",
}

# For backend="cudagraphs", mutations on inputs will either be sent to the cudagraph
# backend or replayed in the aot_autograd epilogue. The default is False because
# mutation on inputs can prevent cudagraphing.
cudagraph_backend_keep_input_mutation = False

# enable cudagraph support for mutated inputs from prior cudagraph pool
cudagraph_backend_support_input_mutation = False

# When True, only ops that have the torch.Tag.pt2_compliant tag
# will be allowed into the graph; all other ops will be disallowed
# and will fall back to eager-mode PyTorch. Useful to ensure
# correctness of custom ops.
only_allow_pt2_compliant_ops = False

capture_autograd_function = True

# enable/disable dynamo tracing for `torch.func` transforms
capture_func_transforms = True

# Whether to log Dynamo compilation metrics into log files (for OSS) and Scuba tables (for fbcode).
log_compilation_metrics = True

# A set of logging functions which will be reordered to the end of graph breaks,
# allowing dynamo to construct larger graphs. Note that there are some
# limitations to this, such as how it does not correctly print objects that were
# mutated after the print statement. (See the usage sketch further below.)
reorderable_logging_functions: Set[Callable[[Any], None]] = set()

# simulates what would happen if we didn't have support for the BUILD_SET opcode,
# used for testing
inject_BUILD_SET_unimplemented_TESTING_ONLY = False

_autograd_backward_strict_mode_banned_ops = [
    "stride",
    "requires_grad",
    "storage_offset",
    "layout",
    "data",
]

_autograd_backward_strict_mode_banned_ops.extend(
    [name for name, _ in inspect.getmembers(torch.Tensor) if re.match(r"^is_.*", name)]
)

# Enables caching of dispatches to fake tensors.
fake_tensor_cache_enabled = (
    os.environ.get("TORCH_FAKE_TENSOR_DISPATCH_CACHE", "1") == "1"
)
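
# Example (sketch) for reorderable_logging_functions above: registering the
# builtin print so that dynamo reorders it to after the traced region instead
# of graph-breaking on it.
#
#   import torch
#   import torch._dynamo.config as dynamo_config
#
#   dynamo_config.reorderable_logging_functions.add(print)
#
#   @torch.compile
#   def f(x):
#       x = x + 1
#       print("intermediate:", x)   # emitted after the graph runs; no graph break
#       return x * 2
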
# Enables cross checking between the fake tensor cache and dispatch.
fake_tensor_cache_crosscheck_enabled = (
    os.environ.get("TORCH_FAKE_TENSOR_DISPATCH_CACHE_CROSSCHECK", "0") == "1"
)

# Enables the Compiled Autograd engine to trace .backward() calls made under torch.compile().
# Note: AOT Autograd will still trace joint graphs.
compiled_autograd = False

# Enables use of collectives *during* compilation to synchronize behavior
# across ranks. Today, this is used solely to modify automatic_dynamic_shapes
# behavior: we infer whether an input is dynamic by inspecting whether or not
# its size varies across ranks. Because this synchronization uses collectives,
# all ranks must run compilation at the same time; ranks must not diverge with
# graph breaks. This can be most reliably achieved by ensuring PT2 is only run
# on SPMD programs. If this invariant is violated, you will likely deadlock
# NCCL and encounter an NCCL timeout.
enable_compiler_collectives = os.environ.get("TORCH_COMPILER_COLLECTIVES", "0") == "1"

if TYPE_CHECKING:
    from torch.utils._config_typing import *  # noqa: F401, F403

    def _make_closure_patcher(**changes):
        ...


from torch.utils._config_module import install_config_module


install_config_module(sys.modules[__name__])
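
# Example (sketch): install_config_module above turns this module into a config
# object which, among other things, exposes a `patch` helper for temporarily
# overriding entries (commonly used in tests). `fn` and `inputs` below are
# placeholders, not names defined in this module.
#
#   import torch
#
#   with torch._dynamo.config.patch(capture_scalar_outputs=True, verbose=True):
#       torch.compile(fn)(inputs)
#   # previous config values are restored on exit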