# mypy: allow-untyped-defs
import getpass
import inspect
import os
import re
import sys
import tempfile
from os.path import abspath, dirname
from typing import Any, Callable, Dict, Optional, Set, Type, TYPE_CHECKING, Union

import torch


def is_fbcode():
    return not hasattr(torch.version, "git_version")


# To configure logging for dynamo, aot, and inductor,
# use the following API in the torch._logging module:
# torch._logging.set_logs(dynamo=<level>, aot=<level>, inductor=<level>)
# or use the environment variable TORCH_LOGS="dynamo,aot,inductor" (prefix a name with + for higher verbosity).
# See this design doc for more detailed info:
# Design doc: https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit#
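
# For example, a minimal sketch of the two entry points described above
# (levels come from the stdlib logging module; "my_script.py" is a hypothetical
# entry point):
#
#   import logging
#   import torch
#
#   torch._logging.set_logs(dynamo=logging.DEBUG, aot=logging.INFO, inductor=logging.INFO)
#
#   # or, equivalently, from the shell:
#   #   TORCH_LOGS="+dynamo,aot,inductor" python my_script.py
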
# the name of a file to write the logs to
# [@compile_ignored: debug]
log_file_name: Optional[str] = None

# [@compile_ignored: debug] Verbose will print full stack traces on warnings and errors
verbose = os.environ.get("TORCHDYNAMO_VERBOSE", "0") == "1"

# [@compile_ignored: runtime_behaviour] verify the correctness of optimized backend
verify_correctness = False

# need this many ops to create an FX graph
minimum_call_count = 1

# turn on/off DCE pass
dead_code_elimination = True

# Disable compilation (for a function) when the cache reaches this size.
# This controls the maximum number of cache entries with a guard on the same ID_MATCH'd
# object. It also controls the maximum size of cache entries if they don't have
# any ID_MATCH'd guards.
# [@compile_ignored: runtime_behaviour]
cache_size_limit = 8

# [@compile_ignored: runtime_behaviour] safeguarding to prevent horrible recomps
accumulated_cache_size_limit = 256

# [@compile_ignored: runtime_behaviour] skip tracing recursively if cache limit is hit
skip_code_recursive_on_cache_limit_hit = True
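
# Illustrative sketch of how the limit can be hit (toy function; exact recompile
# behavior also depends on the dynamic-shape settings below):
#
#   import torch
#
#   @torch.compile
#   def f(x, mode: str):
#       return x + 1 if mode == "add" else x - 1
#
#   x = torch.randn(4)
#   for i in range(32):
#       f(x, f"mode_{i}")  # each new constant string may add a cache entry; once a
#                          # frame exceeds cache_size_limit, dynamo falls back to eager
#                          # (and, with skip_code_recursive_on_cache_limit_hit, skips
#                          # the frame recursively)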

# Whether or not to specialize on int inputs.  This only has an effect with
# dynamic_shapes; when dynamic_shapes is False, we ALWAYS specialize on int
# inputs.  Note that assume_static_by_default will also cause ints to get
# specialized, so this is mostly useful for export, where we want inputs
# to be dynamic, but accesses to ints should NOT get promoted into inputs.
specialize_int = False

# Whether or not to specialize on float inputs.  Dynamo will always promote
# float inputs into Tensor inputs, but at the moment, backends inconsistently
# support codegen on float (this is to be fixed).
specialize_float = True

# legacy config, does nothing now!
dynamic_shapes = True

use_lazy_graph_module = (
    os.environ.get("TORCH_COMPILE_USE_LAZY_GRAPH_MODULE", "1") == "1"
)

# This is a temporary flag, which changes the behavior of dynamic_shapes=True.
# When assume_static_by_default is True, we only allocate symbols for shapes marked dynamic via mark_dynamic.
# NOTE - this flag can be removed once we can run dynamic_shapes=False w/ the mark_dynamic API
# see [Note - on the state of mark_dynamic]
assume_static_by_default = True

# This flag changes how dynamic_shapes=True works, and is meant to be used in conjunction
# with assume_static_by_default=True.
# With this flag enabled, we always compile a frame as fully static the first time; if we then
# fail any guards due to wobbles in shape, we recompile with *all* of the wobbled shapes marked dynamic.
automatic_dynamic_shapes = True
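
# Illustrative sketch of the interplay with mark_dynamic (hypothetical fn):
#
#   import torch
#   import torch._dynamo
#
#   def fn(x):
#       return x * 2
#
#   x = torch.randn(8, 3)
#   torch._dynamo.mark_dynamic(x, 0)  # allocate a symbol for dim 0 even though
#   torch.compile(fn)(x)              # assume_static_by_default=True
#
#   # Without mark_dynamic, the first compile is fully static; if a later call
#   # changes a shape, automatic_dynamic_shapes recompiles with that dim dynamic.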

# This flag changes how the shapes of parameters are treated.
# If it is set to True, the shapes of torch.nn.Parameter are always assumed to be static,
# while the shapes of ordinary torch.Tensor inputs may still be treated as dynamic.
# If it is set to False, the shapes of torch.nn.Parameter are treated like those of
# torch.Tensor and may also become dynamic.
force_parameter_static_shapes = True

# This flag controls whether the shapes of an nn.Module's properties are always assumed to be static.
# If the flag is set to True, the shapes of an nn.Module's properties are assumed to be static.
# If the flag is set to False, the shapes of an nn.Module's properties can be dynamic.
force_nn_module_property_static_shapes = True

# Typically, if you mark_dynamic a dimension, we will error if the dimension
# actually ended up getting specialized.  This knob changes the behavior so
# that we don't error at all.  This is helpful for our CI where I'm using a
# heuristic to mark batch dimensions as dynamic and the heuristic may get it
# wrong.
allow_ignore_mark_dynamic = False

# Set this to False to assume nn.Module contents are immutable (similar assumption as freezing)
guard_nn_modules = True

# Uses CPython internal dictionary tags to detect mutation. There is some
# overlap between guard_nn_modules_using_dict_tags and the guard_nn_modules flag.
# guard_nn_modules unspecializes the nn module instance and adds a guard for each
# relevant member of the nn module. On the other hand,
# guard_nn_modules_using_dict_tags specializes on each nn module instance but
# uses low-overhead dict version matching to detect mutations, obviating the
# need to guard on members of the nn module. With
# guard_nn_modules_using_dict_tags, guard_nn_modules is not really required,
# but it is kept around for debugging and for discussing unspecializing nn module
# variables.
# TODO(janimesh, voz): Remove both of these flags (or at least guard_nn_modules)
# once we have reached stability for guard_nn_modules_using_dict_tags.
guard_nn_modules_using_dict_tags = True

# This feature doesn't really work.  We offer this flag for experimental
# purposes / if you want to help us build out support.
#
# torchdynamo has limited support for tensor subclasses that implement
# __torch_function__; see [Note: __torch_function__] in torch_function.py.
# Our current support is limited to tensor subclasses
# that DO NOT store metadata on the tensor (in general, dynamo does not
# support Python code that stores extra attributes on tensors at present).
# If your tensor subclass purely changes function call behavior via
# __torch_function__, you can allow torchdynamo to trace into it by
# adding it to traceable_tensor_subclasses.  We don't do any safety checks,
# so it is up to you to ensure that your subclass is well behaved.  See also
# https://github.com/pytorch/torchdynamo/issues/1948
#
# We do NOT currently support __torch_dispatch__.  The implementation is
# currently buggy; the main showstopper for nontrivial use is
# https://github.com/pytorch/torchdynamo/issues/1952
traceable_tensor_subclasses: Set[Type[Any]] = set()
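
# Illustrative sketch of opting a (hypothetical) well-behaved subclass in:
#
#   import torch
#   import torch._dynamo
#
#   class NoisyTensor(torch.Tensor):
#       # only changes call behavior; stores no extra metadata on the tensor
#       @classmethod
#       def __torch_function__(cls, func, types, args=(), kwargs=None):
#           return super().__torch_function__(func, types, args, kwargs or {})
#
#   torch._dynamo.config.traceable_tensor_subclasses.add(NoisyTensor)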

# Suppress errors in torch._dynamo.optimize, instead forcing a fallback to eager.
# This is a good way to get your model to work one way or another, but you may
# lose optimization opportunities this way.  Devs, if your benchmark model is failing
# this way, you should figure out why instead of suppressing it.
suppress_errors = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False))

# Record and write an execution record of the current frame to a file
# if an exception is encountered
# [@compile_ignored: debug]
replay_record_enabled = os.environ.get("TORCH_COMPILE_REPLAY_RECORD", "0") == "1"

# Rewrite assert statements in Python with torch._assert
rewrite_assert_with_torch_assert = True

# Disable dynamo
disable = os.environ.get("TORCH_COMPILE_DISABLE", False)

# [@compile_ignored: runtime_behaviour] Get a cprofile trace of Dynamo
cprofile = os.environ.get("TORCH_COMPILE_CPROFILE", False)

# legacy config, does nothing now!
skipfiles_inline_module_allowlist: Dict[Any, Any] = {}

# If a string representing a PyTorch module is in this ignorelist,
# the `allowed_functions.is_allowed` function will not consider it
# when creating a list of PyTorch functions that will appear in
# FX IR.
allowed_functions_module_string_ignorelist = {
    "torch.distributions",
    "torch.testing",
    "torch._refs",
    "torch._prims",
    "torch._decomp",
}

# Debug flag to try the minifier at different stages. Possible values are {None, "aot", "dynamo"}
# None - Minifier is switched off
# dynamo - Runs minifier on the TorchDynamo-produced graphs, if compilation fails
# aot - Runs minifier on the AOT Autograd-produced graphs, if compilation fails
# [@compile_ignored: debug]
repro_after = os.environ.get("TORCHDYNAMO_REPRO_AFTER", None)

# Compiler compilation debug info
# 1: Dumps the original graph out to repro.py if compilation fails.
# 2: Dumps a minifier_launcher.py if compilation fails.
# 3: Always dumps a minifier_launcher.py. Good for segfaults.
# 4: Dumps a minifier_launcher.py if an accuracy check fails.
# [@compile_ignored: debug]
repro_level = int(os.environ.get("TORCHDYNAMO_REPRO_LEVEL", 2))
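
# Illustrative sketch of turning the minifier on from Python before compiling
# (the same settings can also be supplied via the environment variables above):
#
#   import torch._dynamo.config as dynamo_config
#
#   dynamo_config.repro_after = "dynamo"  # minify TorchDynamo-produced graphs
#   dynamo_config.repro_level = 4         # dump minifier_launcher.py on accuracy failures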

# By default, we try to detect accuracy failure by running both forward
# and backward of a torchdynamo-produced graph (if you are using repro_after
# 'dynamo').  This setting forces us to only test the forward graph and
# not the backward graph.  This can be helpful if you're trying to debug
# an inference-only problem, but the minifier seems to be choking on the
# backward step.
# TODO: Detect this situation automatically so the user doesn't need
# to manually configure this
# [@compile_ignored: debug]
repro_forward_only = os.environ.get("TORCHDYNAMO_REPRO_FORWARD_ONLY") == "1"

# The tolerance we should use when testing whether a compiled graph
# has diverged enough to be treated as an accuracy failure
# [@compile_ignored: debug]
repro_tolerance = 1e-3


# Whether to ignore non-floating point values when checking accuracy.
# Checking accuracy of non-floating point values such as boolean tensors
# can lead to false positives.
# [@compile_ignored: debug]
repro_ignore_non_fp = os.environ.get("TORCHDYNAMO_REPRO_IGNORE_NON_FP") == "1"

# If True, when testing if two models are the same, we will test them against
# a third fp64 reference and only report a problem if the RMSE relative to the
# fp64 is greater.  However, this will use more memory; you may disable this
# if memory usage is too high.
# [@compile_ignored: runtime_behaviour]
same_two_models_use_fp64 = True

# Not all backends support scalars. Some calls on torch.Tensor (like .item()) return a scalar type.
# When this flag is set to False, we introduce a graph break instead of capturing.
# This requires dynamic_shapes to be True.
capture_scalar_outputs = os.environ.get("TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS") == "1"

# Not all backends support operators that have dynamic output shape (e.g.,
# nonzero, unique).  When this flag is set to False, we introduce a graph
# break instead of capturing.  This requires dynamic_shapes to be True.
# If you set this to True, you probably also want capture_scalar_outputs
# (these are separated for historical reasons).
capture_dynamic_output_shape_ops = (
    os.environ.get("TORCHDYNAMO_CAPTURE_DYNAMIC_OUTPUT_SHAPE_OPS", "0") == "1"
)
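
# Illustrative sketch: without these flags, the calls below graph-break; with
# them, dynamo can capture the ops as (unbacked) symbolic values (hypothetical fn):
#
#   import torch
#   import torch._dynamo.config as dynamo_config
#
#   dynamo_config.capture_scalar_outputs = True
#   dynamo_config.capture_dynamic_output_shape_ops = True
#
#   @torch.compile(fullgraph=True)
#   def fn(x):
#       n = x.sum().item()      # scalar output
#       idx = torch.nonzero(x)  # dynamic output shape
#       return n, idx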

# hybrid backed unbacked symints
prefer_deferred_runtime_asserts_over_guards = False

# For complex dynamic shapes guards that we're unable to specify with dynamo/export's
# range constraints + dims + derived dims language, we raise constraint violation
# errors or specialize by default. If set to True, this flag avoids crashing/specialization,
# and allows complex guards as runtime assertions in the graph.
allow_complex_guards_as_runtime_asserts = False

# By default, dynamo will treat all ints as backed SymInts, which means (1) it
# will wait to see the int change over multiple runs before generalizing and
# (2) it will still always 0/1 specialize an int.  When true, this knob
# forces dynamo to treat _length_per_key and _offset_per_key on
# KeyedJaggedTensor from torchrec as size-like unbacked SymInts, so that
# they (1) generalize immediately and (2) unsoundly never compare equal to
# 0/1.  This is not on by default as AOTAutograd/Inductor cannot currently
# compile this code; however, this can be useful for export.
force_unspec_int_unbacked_size_like_on_torchrec_kjt = False

# Should almost always be true in prod. Setting this to False relaxes the requirement
# that cond's true_fn and false_fn produce code with identical guards.
enforce_cond_guards_match = True

# Specify how to optimize a compiled DDP module. The flag accepts a boolean
# value or a string. There are 4 modes.
# 1. "ddp_optimizer" (or True): with "ddp_optimizer", Dynamo will automatically
# split the model graph into pieces to match DDP bucket sizes to allow DDP
# comm/compute overlap.
# 2. "python_reducer" (experimental): this optimization requires the usage
# of compiled_autograd. With "python_reducer", DDP will disable the C++ reducer
# and use the Python reducer to allow compiled_autograd to trace the
# communication and allow comm/compute overlap without graph-breaks.
# 3. "python_reducer_without_compiled_forward" (experimental): this mode is
# similar to "python_reducer". One should only use this optimization mode
# when compiled_autograd is used but the DDP module is not compiled.
# 4. "no_optimization" (or False): Dynamo won't split the model graph, nor
# will the Python reducer be used. With this mode, there will be no graph-breaks
# and the original DDP C++ reducer will be used. There will be no comm/compute
# overlap. This mode CANNOT be used with compiled_autograd.
# Note that, to avoid breaking existing usage, mode 1 and mode 4 can be
# specified with a boolean value: True means ddp_optimizer and False means
# no_optimization.
optimize_ddp: Union[bool, str] = True
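
# Illustrative sketch of selecting a mode (hypothetical model; process-group
# and DDP initialization elided):
#
#   import torch
#   import torch._dynamo.config as dynamo_config
#
#   dynamo_config.optimize_ddp = "python_reducer"  # or True / False / one of the modes above
#   ddp_model = torch.nn.parallel.DistributedDataParallel(model)
#   compiled_model = torch.compile(ddp_model)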

# By default, Dynamo emits runtime asserts (e.g. torch._check, torch._check_is_size) in the graph.
# In some cases those asserts could be performance costly,
# e.g. torch._check(tensor[0].item() > 2) for a tensor on CUDA will require a CUDA sync.
# Setting this to True keeps the checks as hints for the symbolic shapes engine,
# but they will not be emitted in the graph.
do_not_emit_runtime_asserts: bool = (
    os.environ.get("TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS", "0") == "1"
)

_ddp_optimization_mode = [
    "ddp_optimizer",
    "python_reducer",  # experimental mode
    "python_reducer_without_compiled_forward",  # experimental mode
    "no_optimization",
]


def _get_optimize_ddp_mode():
    m = sys.modules[__name__]
    if isinstance(m.optimize_ddp, bool):
        if m.optimize_ddp:
            mode = "ddp_optimizer"
        else:
            mode = "no_optimization"
    elif isinstance(m.optimize_ddp, str):
        mode = m.optimize_ddp
    else:
        raise ValueError(f"Invalid type, {type(optimize_ddp)=}")

    assert mode in m._ddp_optimization_mode, f"Invalid mode {mode=}"
    return mode


# Skip tracing the torchrec files added to trace_rules.FBCODE_SKIP_DIRS
skip_torchrec = True


# No longer used
optimize_ddp_lazy_compile = False

# Whether to skip guarding on FSDP-managed modules
skip_fsdp_guards = True
# Whether to apply torch._dynamo.disable() to FSDP2 hooks.
# Defaults to True. If Traceable FSDP2 is used, set this to False.
skip_fsdp_hooks = True

# Make dynamo skip guarding on hooks on nn modules.
# Note: unsafe: if your model actually has hooks and you remove them, or doesn't have them and you add them,
# dynamo will not notice and will execute whichever version you first compiled.
skip_nnmodule_hook_guards = True
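
# Illustrative sketch of the hazard described above (hypothetical module):
#
#   import torch
#
#   mod = torch.nn.Linear(4, 4)
#   compiled = torch.compile(mod)
#   compiled(torch.randn(2, 4))                       # compiled without any hooks
#   mod.register_forward_hook(lambda m, i, o: o * 2)  # hook added afterwards: with hook
#   compiled(torch.randn(2, 4))                       # guards skipped, dynamo may keep
#                                                     # running the hook-free version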

# If True, raises an exception if TorchDynamo is called with a context manager
raise_on_ctx_manager_usage = True

# If True, raise when aot autograd is unsafe to use
raise_on_unsafe_aot_autograd = False

# If true, error if you torch.jit.trace over a dynamo-optimized function.
# If false, silently suppress dynamo.
error_on_nested_jit_trace = True

# If true, error with a better message if we symbolically trace over a
# dynamo-optimized function. If false, silently suppress dynamo.
error_on_nested_fx_trace = True

# Disables graph breaking on RNNs. YMMV with backends.
allow_rnn = False

# If true, enables a feature that captures PyTorch sparsity in the
# exported FX graph. This flag should become the default eventually
# and be removed, but currently provides a way to fall back to the old
# graph-breaking behavior.
capture_sparse_compute = False if is_fbcode() else True

# If true, error if we try to compile a function that has
# been compiled before (i.e., on a recompile).
# [@compile_ignored: runtime_behaviour]
error_on_recompile = False

# [@compile_ignored: debug] Whether to report any guard failures (deprecated: does not do anything)
report_guard_failures = True

# [@compile_ignored: debug] root folder of the project
base_dir = dirname(dirname(dirname(abspath(__file__))))

# Trace through NumPy, or graph break
trace_numpy = True

# Default NumPy dtypes when tracing with torch.compile
# We default to 64 bits. For efficiency, one may want to change these to float32
numpy_default_float = "float64"
numpy_default_complex = "complex128"
numpy_default_int = "int64"

# Use NumPy's PRNG if True, PyTorch's otherwise
use_numpy_random_stream = False
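
# Illustrative sketch of NumPy tracing under torch.compile (hypothetical fn):
#
#   import numpy as np
#   import torch
#
#   @torch.compile
#   def fn(x):
#       a = x.numpy()  # traced rather than graph-breaking when trace_numpy=True
#       return torch.from_numpy(np.sin(a) + 1.0)
#
#   fn(torch.randn(4))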

# Use C++ guard manager
enable_cpp_guard_manager = os.environ.get("TORCHDYNAMO_CPP_GUARD_MANAGER", "1") == "1"

# Inline inbuilt nn modules
inline_inbuilt_nn_modules = not is_fbcode()

# When set, total compile time instruction count is recorded using
# torch._dynamo.utils.CompileTimeInstructionCounter.
record_compile_time_instruction_count = False


def default_debug_dir_root():
    # [@compile_ignored: debug]
    DEBUG_DIR_VAR_NAME = "TORCH_COMPILE_DEBUG_DIR"
    if DEBUG_DIR_VAR_NAME in os.environ:
        return os.path.join(os.environ[DEBUG_DIR_VAR_NAME], "torch_compile_debug")
    elif is_fbcode():
        return os.path.join(
            tempfile.gettempdir(), getpass.getuser(), "torch_compile_debug"
        )
    else:
        return os.path.join(os.getcwd(), "torch_compile_debug")


# [@compile_ignored: debug]
debug_dir_root = default_debug_dir_root()

# [@compile_ignored: debug]
_save_config_ignore = {
    "repro_after",
    "repro_level",
    # workaround: "cannot pickle PyCapsule"
    "constant_functions",
    # workaround: "cannot pickle module"
    "skipfiles_inline_module_allowlist",
}

# For backend="cudagraphs": whether mutations on inputs are sent to the cudagraph
# backend or replayed in the aot_autograd epilogue. The default is False because
# mutation on inputs can prevent cudagraphing.
cudagraph_backend_keep_input_mutation = False

# enable cudagraph support for mutated inputs from prior cudagraph pool
cudagraph_backend_support_input_mutation = False

# When True, only ops that have the torch.Tag.pt2_compliant tag
# will be allowed into the graph; all other ops will be disallowed
# and will fall back to eager-mode PyTorch. Useful to ensure
# correctness of custom ops.
only_allow_pt2_compliant_ops = False

capture_autograd_function = True

# enable/disable dynamo tracing for `torch.func` transforms
capture_func_transforms = True

# Whether to log Dynamo compilation metrics into log files (for OSS) and Scuba tables (for fbcode).
log_compilation_metrics = True

# A set of logging functions which will be reordered to the end of graph breaks,
# allowing dynamo to construct larger graphs. Note that there are some
# limitations to this, such as how it does not correctly print objects that were
# mutated after the print statement.
reorderable_logging_functions: Set[Callable[[Any], None]] = set()
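
# Illustrative sketch: letting dynamo reorder plain print calls past the graph
# breaks they would otherwise cause (subject to the limitations noted above):
#
#   import torch
#   import torch._dynamo.config as dynamo_config
#
#   dynamo_config.reorderable_logging_functions.add(print)
#
#   @torch.compile
#   def fn(x):
#       print("about to add")  # reordered to run after the graph instead of breaking it
#       return x + 1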

# simulates what would happen if we didn't have support for BUILD_SET opcode,
# used for testing
inject_BUILD_SET_unimplemented_TESTING_ONLY = False

_autograd_backward_strict_mode_banned_ops = [
    "stride",
    "requires_grad",
    "storage_offset",
    "layout",
    "data",
]

_autograd_backward_strict_mode_banned_ops.extend(
    [name for name, _ in inspect.getmembers(torch.Tensor) if re.match(r"^is_.*", name)]
)

# Enables caching of dispatches to fake tensors.
fake_tensor_cache_enabled = (
    os.environ.get("TORCH_FAKE_TENSOR_DISPATCH_CACHE", "1") == "1"
)

# Enables cross checking between the fake tensor cache and dispatch.
fake_tensor_cache_crosscheck_enabled = (
    os.environ.get("TORCH_FAKE_TENSOR_DISPATCH_CACHE_CROSSCHECK", "0") == "1"
)

# Enables the Compiled Autograd engine to trace .backward() calls made under torch.compile().
# Note: AOT Autograd will still trace joint graphs.
compiled_autograd = False
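
# Illustrative sketch (hypothetical model; with the flag on, the .backward()
# call below can be traced by the Compiled Autograd engine):
#
#   import torch
#   import torch._dynamo.config as dynamo_config
#
#   dynamo_config.compiled_autograd = True
#   model = torch.nn.Linear(4, 4)
#   out = torch.compile(model)(torch.randn(2, 4))
#   out.sum().backward()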

# Enables use of collectives *during* compilation to synchronize behavior
# across ranks.  Today, this is used solely to modify automatic_dynamic_shapes
# behavior, making it so that we infer whether an input is dynamic by
# inspecting whether or not its size varies across ranks.  Because
# this synchronization uses collectives, all ranks must run compilation at
# the same time; ranks must not diverge with graph breaks.  This can be most
# reliably achieved by ensuring PT2 is only run on SPMD programs.  If this
# invariant is violated, you will likely deadlock NCCL and encounter a
# NCCL timeout.
enable_compiler_collectives = os.environ.get("TORCH_COMPILER_COLLECTIVES", "0") == "1"

if TYPE_CHECKING:
    from torch.utils._config_typing import *  # noqa: F401, F403

    def _make_closure_patcher(**changes):
        ...


from torch.utils._config_module import install_config_module


install_config_module(sys.modules[__name__])