# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Global flags for aot autograd
"""
import os
import sys
from typing import TYPE_CHECKING


# Converts torch rng ops to their functional philox rng equivalents. Note that
# we functionalize only CUDA rng ops today.
functionalize_rng_ops = False

# can be useful for debugging if we are incorrectly creating meta fake tensors
fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", "1") != "0"

# Enables optional asserts in hotpath code to check for errors.  If
# you are seeing weird accuracy problems, try turning this on.
# This is currently off by default as it will harm tracing time,
# but it is on by default for aot_eager.
debug_assert = False

debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", "0") != "0"
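
# Usage sketch (not part of this module's behavior; a general note about the
# env-var-driven flags above): os.environ is read once, at import time, so
# variables like FAKE_ALLOW_META or AOT_PARTITIONER_DEBUG must be set before
# torch is imported, e.g. from the shell:
#
#   AOT_PARTITIONER_DEBUG=1 python train.py
#
# or in Python, before the first `import torch`:
#
#   import os
#   os.environ["AOT_PARTITIONER_DEBUG"] = "1"
#   import torch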

# Today, if you are in a situation where there is "false aliasing"
# (e.g. you have a bunch of model parameters that all alias the same underlying buffer),
# our checks for this situation are very slow if these inputs have dynamic shapes.
# This config is set to ensure that there aren't too many aliased inputs in this situation,
# so that we error loudly instead of compiling forever.
# Eventually, we should make these checks faster.
# For now, however, you can simply turn off dynamic shapes by marking your inputs static
# when you run into this situation.
_max_aliased_inputs_with_dynamic_shapes_enabled = 5
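
# A hedged sketch of the workaround mentioned above (assumes a standard
# torch.compile entry point; `model` and `inp` are placeholders):
#
#   import torch
#   compiled = torch.compile(model, dynamic=False)  # turn off dynamic shapes
#   out = compiled(inp)
#
# Alternatively, individual inputs can be marked static (e.g. via
# torch._dynamo.mark_static) rather than disabling dynamic shapes globally.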

static_weight_shapes = True

# Applies CSE to the graph before partitioning
cse = True


enable_autograd_cache = os.environ.get("ENABLE_AOT_AUTOGRAD_CACHE", "0") == "1"

# When AOTAutograd regenerates aliased graph outputs,
# attempt to use functionalization's view-replay logic
# before falling back to the autograd engine's view replay or as_strided.
# This can have some perf implications
# (although for many models this will not matter).
# (1) If you have many view ops chained together, replaying all of them
#     at runtime can have more overhead than a single as_strided call
# (2) If you are doing training, AsStridedBackward is quite slow,
#     and the individual view op backward formulas will likely be faster.
# (3) Some backends like XLA do not support as_strided

# Temporary hack: disable this flag for internal (fbcode) builds
# (needed to fix an internal issue while avoiding bumping the XLA pin).
# Eventually, either default this config to False completely
# once the XLA pin update works,
# or default it to True and fix the relevant bugs.
from torch._inductor.config import is_fbcode


# View replay is currently not compatible with AOTAutogradCache, since
# FunctionalTensors are not serializable. We'll need to make them
# serializable before enabling warm cache with this config turned on.
view_replay_for_aliased_outputs = (not is_fbcode()) and (not enable_autograd_cache)
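
# Illustrative toggle (a sketch, not a recommendation): a backend that cannot
# handle as_strided outputs (e.g. XLA, per note (3) above) could force view
# replay back on, accepting the perf caveats listed above:
#
#   import torch._functorch.config as functorch_config
#   functorch_config.view_replay_for_aliased_outputs = True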

# Restricts the amount of computation AOTAutograd can do.
# NB: We have essentially disabled this heuristic now. However, this is kept
# here for now in case it's useful. Setting it low can artificially reduce the
# amount of recomputation AOTAutograd performs, although not in any kind of
# principled way.
max_dist_from_bw = 1000


# Bans recomputation of nodes that read from nodes far before
# the current node
ban_recompute_used_far_apart = True
# Breaks up long chains of fusible ops, as otherwise we can have an arbitrarily
# long chain of recomputation in the backwards pass.
ban_recompute_long_fusible_chains = True
# Bans recomputation of nodes that must be materialized in the backwards pass
# (i.e. used by a non-fusible node)
ban_recompute_materialized_backward = True
# Bans recomputation of nodes based on an allowlist. Setting it to False
# switches to a denylist instead. The main difference is for operators like
# sort/pool that aren't cheap enough to fuse for free but also aren't
# that expensive.
ban_recompute_not_in_allowlist = True
# Bans recomputation of reductions. This is generally a good idea, as
# the results of reductions are usually very small, but recomputing them
# in a fusion can be expensive.
ban_recompute_reductions = True
# Prevents the partitioner from ever saving views (i.e. always recompute them).
# Generally a good idea since views are free to recompute.
recompute_views = False

# By default, the partitioner is purely trying to optimize for runtime (although
# it should always use less memory than eager).
# This knob lets the partitioner make that tradeoff for you, choosing the
# fastest option that saves less activation memory than the memory budget.
# Specifically, 0.0 corresponds to the activation memory from applying
# activation checkpointing to the full compiled region, and 1.0 corresponds to
# the activation memory from the default runtime-optimized strategy.  So, 0.4
# would result in a strategy that saves 40% of the activations compared to the
# default strategy.
# It solves a 0-1 knapsack to find the minimum recompute necessary to stay below
# the activation memory budget.
# NOTE: This *cannot* be treated as
activation_memory_budget = 1.0
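
# Example (a sketch assuming a typical torch.compile training setup; `model`
# is a placeholder): ask the partitioner for the fastest strategy whose saved
# activations are at most 40% of what the default strategy would save:
#
#   import torch
#   import torch._functorch.config as functorch_config
#   functorch_config.activation_memory_budget = 0.4
#   compiled_model = torch.compile(model)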

# This controls how we estimate the runtime when deciding which operators are
# cheapest to recompute. The 3 options are
# "flops": Bases it off of the flop count provided by torch.utils.flop_counter
# "profile": Benchmarks each operator to come up with a runtime
# "testing": Returns 1 for everything
activation_memory_budget_runtime_estimator = "flops"

# This controls the solver used for the 0-1 knapsack. By default we use a
# quantized DP solution ("dp"). The other options are "greedy" and "ilp"
# (the latter requires scipy).
activation_memory_budget_solver = "dp"
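
# Sketch combining the two knobs above (illustrative values only; "ilp"
# requires scipy, as noted):
#
#   import torch._functorch.config as functorch_config
#   functorch_config.activation_memory_budget_runtime_estimator = "profile"
#   functorch_config.activation_memory_budget_solver = "ilp"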

# This dumps out a png visualization of the expected runtime vs. activation
# memory tradeoffs for all memory budget values from 0 to 1 in increments of
# 0.5. See an example here:
# https://github.com/pytorch/pytorch/pull/126320#discussion_r1625104015
visualize_memory_budget_pareto = (
    os.environ.get("PARTITIONER_MEMORY_BUDGET_PARETO", "0") == "1"
)
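
# To dump the pareto visualization described above, set the env var before
# torch is imported (it is read here at import time), e.g. from the shell:
#
#   PARTITIONER_MEMORY_BUDGET_PARETO=1 python train.py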

# Sets all of the ban_recompute heuristics to False except ban_recompute_reductions.
# Generally, this will probably result in some memory improvement, but at the
# cost of some performance.
aggressive_recomputation = False
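
# Hedged example: trading some runtime for memory by relaxing the recompute
# bans and lowering the memory budget (values are illustrative only):
#
#   import torch._functorch.config as functorch_config
#   functorch_config.aggressive_recomputation = True
#   functorch_config.activation_memory_budget = 0.5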

# Whether FakeTensor.data_ptr() should error.
# This option is independent of AOTAutograd and torch.compile, but our policy
# is to turn it off during torch.compile.
fake_tensor_allow_unsafe_data_ptr_access = True

# Unlifts effect tokens from the inputs/outputs in the traced graph and instead
# inserts make_token/sink_token calls in the graph to create tokens and then
# sink them at the end. Note that this means the graph is no longer functional,
# which may lead to silent errors unless the backend knows how to handle the
# tokens.
unlift_effect_tokens = False

# This mode specifies that we should also keep track of the real
# tensor along with the fake tensor, and do real compute.  While
# this seemingly eliminates the whole point of fake tensors, there are
# two obvious use cases for it:
#
#   1. When users call item()/other data dependent operations,
#      if we propagate_real_tensors we are able to determine what
#      the true value is and keep going.
#
#   2. It can be useful for testing, when you want to see if the fake
#      and real tensors agree with each other.  (Note that there are
#      currently known inaccuracies in how we clone real tensors that
#      would have to be tightened up for this to be useful in this
#      case.)
#
# Note that fake tensors are typically understood to be cheap to store
# indefinitely, so we tend to hold on to them longer than we would
# hold onto the real tensors.  So we also support you explicitly
# deallocating the real tensor associated with a fake tensor, at which
# point we will stop propagating real tensors.
#
# One more thing: when you provide a real tensor to fakeify, we will
# clone it, so that we can safely perform mutations on it if necessary.
# This will increase live memory usage.  This could potentially be
# optimized by using COW.  We also currently do not faithfully
# maintain autograd metadata on the real tensor; this is fine because
# AOTAutograd will only use the fake tensor to determine leafness/etc
# of tensors in question.
fake_tensor_propagate_real_tensors = False
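
# Sketch of use case (1) above (a hedged example; `fn` and `x` are
# placeholders for user code that hits a data-dependent item() call):
#
#   import torch._functorch.config as functorch_config
#   functorch_config.fake_tensor_propagate_real_tensors = True
#   # With the flag on, tracing can consult the real tensor to determine the
#   # true value at `x.sum().item()` and keep going.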

# This controls whether we collect donated buffers. This flag must be set to
# False if a user wants to use retain_graph=True in the backward pass.
donated_buffer = False
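
# Hedged example of the constraint above: if the training loop calls
# loss.backward(retain_graph=True), donated buffers must stay disabled:
#
#   import torch._functorch.config as functorch_config
#   functorch_config.donated_buffer = False  # required with retain_graph=True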

# Controls the default graph output format used by draw_graph.
# Supported formats are defined here: https://graphviz.org/docs/outputs/
torch_compile_graph_format = os.environ.get("TORCH_COMPILE_GRAPH_FORMAT", "svg")


# Error on BypassAOTAutogradCache instead of just a warning
# Used for tests
strict_autograd_cache = False

if TYPE_CHECKING:
    from torch.utils._config_typing import *  # noqa: F401, F403

from torch.utils._config_module import install_config_module


# adds patch, save_config, invalid config checks, etc
install_config_module(sys.modules[__name__])
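
# Usage sketch: after install_config_module runs, this module supports the
# helpers it installs (patch, save_config, etc.), e.g. a temporary override
# via `patch` (`fn` is a placeholder):
#
#   import torch._functorch.config as functorch_config
#   with functorch_config.patch(donated_buffer=False):
#       compiled = torch.compile(fn)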