# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Global flags for aot autograd
"""
import os
import sys
from typing import TYPE_CHECKING


# Converts torch rng ops to their functional philox rng equivalents. Note that
# we functionalize only CUDA rng ops today.
functionalize_rng_ops = False

# can be useful for debugging if we are incorrectly creating meta fake tensors
fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", "1") != "0"

# Enables optional asserts in hotpath code to check for errors. If
# you are seeing weird accuracy problems, try turning this on.
# This is currently off by default as it will harm tracing time,
# but it is on by default for aot_eager.
debug_assert = False

debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", "0") != "0"

# Today, if you are in a situation where there is "false aliasing"
# (e.g. you have a bunch of model parameters that all alias the same underlying buffer),
# our checks for this situation are very slow if these inputs have dynamic shapes.
# This config is set to ensure that there aren't too many aliased inputs in this situation,
# so that we error loudly instead of compiling forever.
# Eventually, we should make these checks faster.
# For now, however, you can simply turn off dynamic shapes by marking your inputs static
# when you run into this situation.
_max_aliased_inputs_with_dynamic_shapes_enabled = 5

static_weight_shapes = True

# Applies CSE to the graph before partitioning
cse = True


enable_autograd_cache = os.environ.get("ENABLE_AOT_AUTOGRAD_CACHE", "0") == "1"

# When AOTAutograd regenerates aliased graph outputs,
# attempt to use functionalization's view-replay logic
# before falling back to the autograd engine's view replay or as_strided.
# This can have some perf implications
# (although for many models this will not matter).
# (1) If you have many view ops chained together, replaying all of them
#     at runtime can have more overhead compared to a single as_strided call
# (2) If you are doing training, AsStridedBackward is quite slow,
#     and the individual view op backward formulas will likely be faster.
# (3) Some backends like XLA do not support as_strided

# Temporary hack: disable this flag for internal
# (needed to fix an internal issue while avoiding bumping XLA pin)
# eventually: either default this config to false completely
# once XLA pin update works,
# or default config to true and fix relevant bugs
from torch._inductor.config import is_fbcode


# View replay is currently not compatible with AOTAutogradCache, since
# FunctionalTensors are not serializable. We'll need to make them
# serializable before enabling warm cache with this config turned on.
view_replay_for_aliased_outputs = (not is_fbcode()) and (not enable_autograd_cache)
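
# Example (a minimal sketch, not part of this module): the env-var-backed flags
# above (FAKE_ALLOW_META, AOT_PARTITIONER_DEBUG, ENABLE_AOT_AUTOGRAD_CACHE) are
# read once, when this module is first imported, so they have to be set in the
# process environment before that happens, e.g. at the very top of the
# entry-point script:
#
#   import os
#   os.environ["ENABLE_AOT_AUTOGRAD_CACHE"] = "1"  # must run before torch
#   os.environ["AOT_PARTITIONER_DEBUG"] = "1"      # (and this module) is imported
#   import torch
#
# Note that enabling the autograd cache this way also disables
# view_replay_for_aliased_outputs (see the expression above).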
# Restricts the amount of computation AOTAutograd can do.
# NB: We have essentially disabled this heuristic now. However, this is kept
# here for now in case it's useful. Setting it low can artificially reduce the
# amount of recomputation AOTAutograd performs, although not in any kind of
# principled way.
max_dist_from_bw = 1000


# Bans recomputation of nodes that read from nodes that are far before
# the current node.
ban_recompute_used_far_apart = True
# Breaks up long chains of fusible ops, as otherwise we can have an arbitrarily
# long chain of recomputation in the backwards pass.
ban_recompute_long_fusible_chains = True
# Bans recomputation of nodes that must be materialized in the backwards pass
# (i.e., used by a non-fusible node).
ban_recompute_materialized_backward = True
# Chooses to ban recomputation of nodes based on an allowlist. Setting it to
# False changes it to use a denylist. The main difference is on operators like
# sort/pool that aren't cheap enough to be fused for free but also aren't
# that expensive.
ban_recompute_not_in_allowlist = True
# Chooses to ban recomputation of reductions. This is generally a good idea, as
# the result of reductions is generally very small but recomputing reductions in
# a fusion can be expensive.
ban_recompute_reductions = True
# Prevents the partitioner from ever saving views (i.e. always recompute them).
# Generally a good idea since views are free to recompute.
recompute_views = False

# By default, the partitioner is purely trying to optimize for runtime (although
# it should always use less memory than eager).
# This knob tells the partitioner to make that tradeoff for you, choosing the
# fastest option that keeps saved activations under the memory budget.
# Specifically, 0.0 corresponds to the activation memory from applying
# activation checkpointing to the full compiled region, and 1.0 corresponds to
# the activation memory from the default runtime-optimized strategy. So, 0.4
# would result in a strategy that saves 40% of the activations compared to the
# default strategy.
# It solves a 0-1 knapsack to find the minimum recompute necessary to stay below
# the activation memory budget.
# NOTE: This *cannot* be treated as
activation_memory_budget = 1.0

# This controls how we estimate the runtime when deciding what the cheapest
# operators to recompute are. The 3 options are:
# "flops": Bases it off of the flop count provided by torch.utils.flop_counter
# "profile": Benchmarks each operator to come up with a runtime
# "testing": Returns 1 for everything
activation_memory_budget_runtime_estimator = "flops"

# This controls the solver used for the 0-1 knapsack. By default we use a
# quantized DP solution ("dp"). The other approaches are "greedy" and "ilp"
# (which has a scipy dependency).
activation_memory_budget_solver = "dp"

# This dumps out a png visualization of the expected runtime vs. activation
# memory tradeoffs for all memory budget values from 0 to 1 in increments of
# 0.5. See an example here:
# https://github.com/pytorch/pytorch/pull/126320#discussion_r1625104015
visualize_memory_budget_pareto = (
    os.environ.get("PARTITIONER_MEMORY_BUDGET_PARETO", "0") == "1"
)

# Sets all of the ban_recompute heuristics to False except ban_recompute_reductions.
# Generally, this will probably result in some memory improvement, but at the
# cost of some performance.
aggressive_recomputation = False
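
# Example (a minimal sketch, not part of this module; `model` and the compile
# setup are assumed): to trade some extra recompute for memory, one might cap
# saved activations at roughly 60% of what the default runtime-optimized
# strategy would keep, with per-op runtimes estimated from flop counts:
#
#   import torch
#   import torch._functorch.config as functorch_config
#
#   functorch_config.activation_memory_budget = 0.6
#   functorch_config.activation_memory_budget_runtime_estimator = "flops"
#   compiled_model = torch.compile(model)  # the partitioner now solves the
#                                          # 0-1 knapsack under the 0.6 budget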
# If FakeTensor.data_ptr() should error.
# This option is independent of AOTAutograd and torch.compile, but our policy
# is to turn it off during torch.compile.
fake_tensor_allow_unsafe_data_ptr_access = True

# Unlifts effect tokens from the inputs/outputs in the traced graph and instead
# inserts make_token/sink_token calls in the graph to create tokens and then
# sink them at the end. Note that this means the graph is no longer functional,
# which may lead to silent errors unless the backend knows how to handle the
# tokens.
unlift_effect_tokens = False

# This mode specifies that we should also keep track of the real
# tensor along with the fake tensor, and do real compute. While
# seemingly this eliminates the whole point of fake tensors, there are
# two obvious use cases for it:
#
# 1. When users call item()/other data dependent operations,
#    if we propagate_real_tensors we are able to determine what
#    the true value is and keep going.
#
# 2. It can be useful for testing, when you want to see if the fake
#    and real tensors agree with each other. (Note that there are
#    currently known inaccuracies in how we clone real tensors that
#    would have to be tightened up for this to be useful in this
#    case.)
#
# Note that fake tensors are typically understood to be cheap to store
# indefinitely, so we tend to hold on to them longer than we would
# hold onto the real tensors. So we also support explicitly
# deallocating the real tensor associated with a fake tensor, at which
# point we will stop propagating real tensors.
#
# One more thing: when you provide a real tensor to fakeify, we will
# clone it, so that we can safely perform mutations on it if necessary.
# This will increase live memory usage. This could potentially be
# optimized by using COW. We also currently do not faithfully
# maintain autograd metadata on the real tensor; this is fine because
# AOTAutograd will only use the fake tensor to determine leafness/etc.
# of the tensors in question.
fake_tensor_propagate_real_tensors = False

# This controls whether we collect donated buffers. This flag must be set to
# False if a user wants to use retain_graph=True for backward.
donated_buffer = False

# Controls the default graph output format used by draw_graph.
# Supported formats are defined here: https://graphviz.org/docs/outputs/
torch_compile_graph_format = os.environ.get("TORCH_COMPILE_GRAPH_FORMAT", "svg")


# Error on BypassAOTAutogradCache instead of just warning.
# Used for tests.
strict_autograd_cache = False

if TYPE_CHECKING:
    from torch.utils._config_typing import *  # noqa: F401, F403

from torch.utils._config_module import install_config_module


# adds patch, save_config, invalid config checks, etc
install_config_module(sys.modules[__name__])
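
# Example (illustrative sketch only, assuming the keyword form of the `patch`
# helper installed above; `fn` and `inp` are placeholders): after
# `install_config_module`, flags can be overridden temporarily instead of being
# mutated globally, e.g. to enable the extra asserts described at the top of
# this file for a single debug run:
#
#   import torch
#   import torch._functorch.config as functorch_config
#
#   with functorch_config.patch(debug_assert=True):
#       out = torch.compile(fn, backend="aot_eager")(inp)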