# mypy: allow-untyped-defs
import inspect
import itertools
import logging
from typing import Optional

import torch
from torch._logging import warning_once
from torch._ops import HigherOrderOperator
from torch.types import _dtype
from torch.utils.checkpoint import checkpoint, CheckpointPolicy


log = logging.getLogger(__name__)

uid = itertools.count(1)


# Used for testing the HigherOrderOperator mechanism
class Wrap(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("wrap")

    def __call__(self, func, *args, **kwargs):
        # Dynamo already traces the body of the HigherOrderOp beforehand,
        # so there is no need to trace into it here.
        import torch._dynamo  # noqa: F401
        from torch._dynamo import disable

        @disable
        def wrapper():
            result = func(*args, **kwargs)
            return result

        return wrapper()


wrap = Wrap()


class WrapWithSetGradEnabled(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("wrap_with_set_grad_enabled")

    def __call__(self, enable_grad, wrapped_func, *args, **kwargs):
        # Dynamo already traces the body of the HigherOrderOp beforehand,
        # so there is no need to trace into it here.
        import torch._dynamo  # noqa: F401
        from torch._dynamo import disable

        @disable
        def wrapper():
            with torch.set_grad_enabled(enable_grad):
                return wrapped_func(*args, **kwargs)

        return wrapper()


wrap_with_set_grad_enabled = WrapWithSetGradEnabled()


class WrapWithAutocast(HigherOrderOperator):
    def __init__(self):
        super().__init__("wrap_with_autocast")

    def __call__(
        self,
        device_type: str,
        dtype: Optional[_dtype],
        enabled: bool,
        cache_enabled: Optional[bool],
        wrapped_func,
        *args,
        **kwargs,
    ):
        # Dynamo already traces the body of the HigherOrderOp beforehand,
        # so there is no need to trace into it here.
        import torch._dynamo  # noqa: F401
        from torch._dynamo import disable

        @disable
        def wrapper():
            with torch.autocast(device_type, dtype, enabled, cache_enabled):
                return wrapped_func(*args, **kwargs)

        return wrapper()


wrap_with_autocast = WrapWithAutocast()


class WrapActivationCheckpoint(HigherOrderOperator):
    """
    This operator is used to wrap torch.utils.checkpoint. It prevents
    TorchDynamo from looking into saved tensor hooks and directly passes
    control to AOT Autograd, which is fine with tracing saved tensor hooks. As
    a result of AOT tracing the torch.utils.checkpoint code, we get a backward
    graph with recomputed forward nodes.

    However, we might deprecate this operator soon. The difficulty arises in
    the functionalization of rng ops. Today, there are two different
    functionalizations of rng ops - one in AOT Autograd and the other in
    Inductor - and they are difficult to map to each other. The rng states also
    complicate pattern matching in Inductor. Due to ease of implementation, we
    are currently inclined towards functionalization at the Inductor level,
    which means that duplication/recomputation is done as a compiler pass in
    the partitioners. See TagActivationCheckpoint for more information.
    """

    def __init__(self) -> None:
        super().__init__("wrap_activation_checkpoint")

    def __call__(self, function, *args, **kwargs):
        # use_reentrant is set to False because this op is going to be traced,
        # and we ensure that AOT Autograd traces through the non-reentrant
        # version of checkpointing.
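        # preserve_rng_state is likewise forced to False below so that AOT
        # Autograd does not trace through `torch.random.fork_rng`; see the
        # analogous comment in TagActivationCheckpoint.__call__ for details.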
        import torch.fx.traceback as fx_traceback
        from torch.fx import Interpreter

        kwargs["use_reentrant"] = False
        kwargs["preserve_rng_state"] = False
        # Using an Interpreter allows preservation of metadata through the torch.compile stack.
        with fx_traceback.preserve_node_meta():
            return checkpoint(Interpreter(function).run, *args, **kwargs)


wrap_activation_checkpoint = WrapActivationCheckpoint()


class TagActivationCheckpoint(HigherOrderOperator):
    """
    This operator is supposed to be used only with the torch.compile stack. It
    accepts an FX graph module which needs to be checkpointed, and adds a
    "recomputable" tag to the nodes of the FX graph that should be recomputed.

    The goal is to:
    1. Avoid using Dynamo to trace through saved tensor hooks.
    2. For the selective checkpointing case, let AOTAutograd trace through
       saved tensor hooks, but use special logic with TorchDispatchMode to
       override the usual saved_tensor_hooks fn logic in order to tag the nodes.
    3. Rely on the partitioners to actually duplicate the nodes.
    This sits well in the torch.compile stack, because by the time the graph
    reaches the partitioner, Inductor has already run its functionalization of
    rng ops (by setting a fixed seed for each random op, see
    `replace_random_passes`). Therefore, the duplication of nodes, by design,
    respects the rng states in the forward and the recomputed forward in the
    backward.
    """

    def __init__(self) -> None:
        super().__init__("tag_activation_checkpoint")

    @staticmethod
    def divide_kwargs(kwargs):
        """
        The checkpoint fn can receive kwargs that are mixed between the
        checkpointed fn and the checkpoint fn itself. For example:
        >> def gn(x, y, z=None):
        >>     a = torch.matmul(x, y)
        >>     if z is not None:
        >>         return torch.matmul(a, z)
        >>     return a
        >> def fn(x, y, z):
        >>     return torch.cos(checkpoint(gn, x, y, use_reentrant=False, z=z))
        In the above case, z belongs to the checkpointed function gn, but
        use_reentrant belongs to the checkpoint function. This function splits
        the kwargs into checkpoint_kwargs and gmod_kwargs (or
        checkpointed_fn_kwargs).
        We do sorting to ensure the same graph from run to run, for better
        debuggability. It is not required for correctness.
        """
        ckpt_signature = inspect.signature(checkpoint)
        checkpoint_keys = set()
        for name in ckpt_signature.parameters:
            if name in ("function", "args", "kwargs"):
                continue
            checkpoint_keys.add(name)

        # `preserve_rng_state` is not a regular kwarg
        checkpoint_keys.add("preserve_rng_state")

        checkpoint_kwargs = {
            name: kwargs[name] for name in kwargs.keys() if name in checkpoint_keys
        }
        gmod_kwargs = {
            name: kwargs[name] for name in kwargs.keys() if name not in checkpoint_keys
        }
        return checkpoint_kwargs, gmod_kwargs

    def tag_nodes(self, gmod, is_sac):
        unique_graph_id = next(uid)
        for node in gmod.graph.nodes:
            if node.op in ("call_function", "call_method", "call_module"):
                node.meta["ac_graph_id"] = unique_graph_id
                if is_sac:
                    # For selective checkpointing, we will populate this tag later in _CachingTorchDispatchMode.
                    node.meta["recompute"] = None
                else:
                    # Under vanilla activation checkpointing, all nodes should be recomputed.
                    node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE
        return gmod

    def __call__(self, gmod, *args, **kwargs):
        import torch.fx.traceback as fx_traceback
        from torch.fx import Interpreter

        if "_checkpoint_context_fn" in gmod.meta:
            warning_once(
                log,
                """
Detected that context_fn is passed to torch.utils.checkpoint under torch.compile.
Please make sure the checkpointed region does not contain in-place ops (e.g. torch.relu_).
""",
            )
            # use_reentrant is set to False because this op is going to be traced,
            # and we ensure that AOT Autograd traces through the non-reentrant
            # version of checkpointing.
            kwargs["use_reentrant"] = False
            # preserve_rng_state is set to False because we want to prevent AOTAutograd from tracing through
            # the `torch.random.fork_rng` op (which is not supported yet under CUDA).
            # This doesn't mean that we don't preserve RNG state. Instead, we will always preserve RNG state
            # regardless of this flag (by doing RNG functionalization via `replace_random_passes` in Inductor
            # instead of in AOTAutograd).
            kwargs["preserve_rng_state"] = False
            kwargs["context_fn"] = gmod.meta["_checkpoint_context_fn"]
            # We first tag all nodes as "recompute" in this graph, and then we undo the "recompute" tag
            # for specific nodes in _CachingTorchDispatchMode in torch/utils/checkpoint.py.
            gmod = self.tag_nodes(gmod, is_sac=True)
            # Using an Interpreter allows preservation of metadata through the torch.compile stack.
            with fx_traceback.preserve_node_meta():
                return checkpoint(Interpreter(gmod).run, *args, **kwargs)
        else:
            gmod = self.tag_nodes(gmod, is_sac=False)
            # Using an Interpreter allows preservation of metadata through the torch.compile stack.
            # TODO: We want to use the same `checkpoint(Interpreter(gmod).run, *args, **kwargs)` here
            # as in the `context_fn != None` case, but that depends on in-place op support in
            # TorchDispatchMode + torch.compile. (For details on the in-place op issue, run the
            # `test_compile_selective_checkpoint_inplace_op` unit test.)
            with fx_traceback.preserve_node_meta():
                return Interpreter(gmod).run(*args)


tag_activation_checkpoint = TagActivationCheckpoint()
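

# A minimal, illustrative sketch of how the simpler wrappers above behave when
# called eagerly; it is not part of this module's API. The helper `_double`,
# the example tensors, and the kwargs dict below are assumptions made purely
# for the demo.
if __name__ == "__main__":  # pragma: no cover
    def _double(x):
        return x * 2

    t = torch.ones(3)

    # wrap(func, *args, **kwargs) simply calls func(*args, **kwargs), with
    # Dynamo disabled inside the call.
    assert torch.equal(wrap(_double, t), _double(t))

    # wrap_with_set_grad_enabled runs the callable under the requested grad
    # mode; with enable_grad=False the output does not require grad.
    out = wrap_with_set_grad_enabled(False, _double, t.requires_grad_())
    assert not out.requires_grad

    # wrap_with_autocast runs the callable under torch.autocast with the given
    # (device_type, dtype, enabled, cache_enabled) settings; on CPU with
    # bfloat16 the matmul below is typically executed in bfloat16.
    _ = wrap_with_autocast(
        "cpu", torch.bfloat16, True, True, torch.mm, t.unsqueeze(0), t.unsqueeze(1)
    )

    # divide_kwargs splits a mixed kwargs dict into kwargs meant for
    # torch.utils.checkpoint itself vs. kwargs meant for the checkpointed fn,
    # mirroring the example in its docstring.
    ckpt_kwargs, fn_kwargs = TagActivationCheckpoint.divide_kwargs(
        {"use_reentrant": False, "z": t}
    )
    assert "use_reentrant" in ckpt_kwargs and "z" in fn_kwargs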