# mypy: allow-untyped-defs
import gc
import typing

import torch

from .._utils import _dummy_type


if not hasattr(torch._C, "_CudaStreamBase"):
    # Define dummy base classes
    torch._C.__dict__["_CUDAGraph"] = _dummy_type("_CUDAGraph")
    torch._C.__dict__["_graph_pool_handle"] = _dummy_type("_graph_pool_handle")
    torch._C.__dict__["_cuda_isCurrentStreamCapturing"] = _dummy_type(
        "_cuda_isCurrentStreamCapturing"
    )

from torch._C import (  # noqa: F401
    _cuda_isCurrentStreamCapturing,
    _CUDAGraph,
    _graph_pool_handle,
)


def is_current_stream_capturing():
    r"""Return True if CUDA graph capture is underway on the current CUDA stream, False otherwise.

    If a CUDA context does not exist on the current device, returns False without initializing the context.
    """
    return _cuda_isCurrentStreamCapturing()


# Python shim helps Sphinx process docstrings more reliably.
def graph_pool_handle():
    r"""Return an opaque token representing the id of a graph memory pool.

    See :ref:`Graph memory management<graph-memory-management>`.

    .. warning::
        This API is in beta and may change in future releases.
    """
    return _graph_pool_handle()


# Python shim helps Sphinx process docstrings more reliably.
class CUDAGraph(torch._C._CUDAGraph):
    r"""Wrapper around a CUDA graph.

    .. warning::
        This API is in beta and may change in future releases.
    """

    def __new__(cls):
        return super().__new__(cls)

    def capture_begin(self, pool=None, capture_error_mode="global"):
        r"""Begin capturing CUDA work on the current stream.

        Typically, you shouldn't call ``capture_begin`` yourself.
        Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
        which call ``capture_begin`` internally.

        Arguments:
            pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
                :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
                with the indicated pool. See :ref:`Graph memory management<graph-memory-management>`.
            capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
                Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
                may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
                actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting
                unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_
        """  # noqa: B950
        super().capture_begin(pool=pool, capture_error_mode=capture_error_mode)

    def capture_end(self):
        r"""End CUDA graph capture on the current stream.

        After ``capture_end``, ``replay`` may be called on this instance.

        Typically, you shouldn't call ``capture_end`` yourself.
        Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
        which call ``capture_end`` internally.
        """
        super().capture_end()

    def replay(self):
        r"""Replay the CUDA work captured by this graph."""
        super().replay()

    def reset(self):
        r"""Delete the graph currently held by this instance."""
        super().reset()

    def pool(self):
        r"""Return an opaque token representing the id of this graph's memory pool.

        This id can optionally be passed to another graph's ``capture_begin``,
        which hints the other graph may share the same memory pool.
        """
        return super().pool()

    def enable_debug_mode(self):
        r"""Enable debugging mode for ``CUDAGraph.debug_dump``."""
        return super().enable_debug_mode()

    def debug_dump(self, debug_path):
        r"""Dump the captured graph to ``debug_path``.

        Calls a debugging function to dump the graph if debugging is
        enabled via ``CUDAGraph.enable_debug_mode()``.

        Arguments:
            debug_path (required): Path to dump the graph to.
        """
        return super().debug_dump(debug_path)


class graph:
    r"""Context-manager that captures CUDA work into a :class:`torch.cuda.CUDAGraph` object for later replay.

    See :ref:`CUDA Graphs <cuda-graph-semantics>` for a general introduction,
    detailed use, and constraints.

    Arguments:
        cuda_graph (torch.cuda.CUDAGraph): Graph object used for capture.
        pool (optional): Opaque token (returned by a call to :func:`~torch.cuda.graph_pool_handle()` or
            :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) hinting this graph's capture
            may share memory from the specified pool. See :ref:`Graph memory management<graph-memory-management>`.
        stream (torch.cuda.Stream, optional): If supplied, will be set as the current stream in the context.
            If not supplied, ``graph`` sets its own internal side stream as the current stream in the context.
        capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
            Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
            may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
            actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting
            unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_

    .. note::
        For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture
        used an explicit ``stream`` argument, you should pass the same ``stream`` argument to this capture.

    .. warning::
        This API is in beta and may change in future releases.

    .. _cudaStreamCaptureMode:
        https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
    """  # noqa: B950

    default_capture_stream: typing.Optional["torch.cuda.Stream"] = None

    def __init__(
        self,
        cuda_graph,
        pool=None,
        stream=None,
        capture_error_mode: str = "global",
    ):
        # Lazy-init of default_capture_stream helps avoid circular-import errors.
        # Not thread safe, but graphs already have the general (explicitly documented)
        # restriction that only one capture may be underway at a time in the process.
        if self.__class__.default_capture_stream is None:
            self.__class__.default_capture_stream = torch.cuda.Stream()

        self.pool = () if pool is None else (pool,)
        self.capture_stream = (
            stream if stream is not None else self.__class__.default_capture_stream
        )
        assert self.capture_stream is not None
        self.stream_ctx = torch.cuda.stream(self.capture_stream)
        self.cuda_graph = cuda_graph
        self.capture_error_mode = capture_error_mode

    def __enter__(self):
        # Free as much memory as we can for the graph
        torch.cuda.synchronize()
        gc.collect()
        torch.cuda.empty_cache()

        # Stackoverflow seems comfortable with this pattern
        # https://stackoverflow.com/questions/26635684/calling-enter-and-exit-manually#39172487
        self.stream_ctx.__enter__()

        self.cuda_graph.capture_begin(
            *self.pool, capture_error_mode=self.capture_error_mode
        )

    def __exit__(self, exc_type, exc_value, traceback):
        self.cuda_graph.capture_end()
        self.stream_ctx.__exit__(exc_type, exc_value, traceback)
        # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__()


def make_graphed_callables(
    callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None
):
    r"""Accept callables (functions or :class:`nn.Module<torch.nn.Module>`\ s) and return graphed versions.

    Each graphed callable's forward pass runs its source callable's
    forward CUDA work as a CUDA graph inside a single autograd node.

    The graphed callable's forward pass also appends
    a backward node to the autograd graph. During backward, this node runs the
    callable's backward work as a CUDA graph.

    Therefore, each graphed callable should be a drop-in replacement for its source callable
    in an autograd-enabled training loop.

    See :ref:`Partial-network capture<partial-network-capture>` for detailed use and constraints.

    If you pass a tuple of several callables, their captures will use the same memory pool.
    See :ref:`Graph memory management<graph-memory-management>` for when this is appropriate.

    Arguments:
        callables (torch.nn.Module or Python function, or tuple of these): Callable or callables to graph.
            See :ref:`Graph memory management<graph-memory-management>` for when passing a tuple of callables
            is appropriate. If you pass a tuple of callables, their order in the tuple must be the same order
            they'll run in the live workload.
        sample_args (tuple of Tensors, or tuple of tuples of Tensors): Sample args for each callable.
            If a single callable was passed, ``sample_args`` must be a single tuple of argument Tensors.
            If a tuple of callables was passed, ``sample_args`` must be a tuple of tuples of argument Tensors.
        num_warmup_iters (int): The number of warmup iterations. Currently, ``DistributedDataParallel`` needs
            11 iterations for warm up. Default: ``3``.
        allow_unused_input (bool): If False, specifying inputs that were not used when computing outputs
            (and therefore their grad is always zero) is an error. Defaults to False.
        pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
            :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
            with the indicated pool. See :ref:`Graph memory management<graph-memory-management>`.

    .. note::
        The ``requires_grad`` state of each Tensor in ``sample_args`` must match the state
        that's expected for the corresponding real input in the training loop.

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        ``sample_args`` for each callable must contain only Tensors. Other types are not allowed.

    .. warning::
        Returned callables do not support higher order differentiation (e.g., double backward).

    .. warning::
        In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`, only parameters
        may be trainable. Buffers must have ``requires_grad=False``.

    .. warning::
        After you pass a :class:`torch.nn.Module` through :func:`~make_graphed_callables`,
        you may not add or remove any of that Module's parameters or buffers.

    .. warning::
        :class:`torch.nn.Module`\s passed to :func:`~torch.cuda.make_graphed_callables` must not have module hooks
        registered on them at the time they are passed. However, registering hooks on modules *after* passing them
        through :func:`~torch.cuda.make_graphed_callables` is allowed.

    .. warning::
        When running a graphed callable, you must pass its arguments in the same order and format
        they appeared in that callable's ``sample_args``.

    .. warning::
        Automatic mixed precision is supported in :func:`~torch.cuda.make_graphed_callables` only with
        caching disabled: the ``torch.cuda.amp.autocast()`` context manager must use ``cache_enabled=False``.
    """
    if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
        raise RuntimeError(
            "make_graphed_callables does not support autocast caching. Please set `cache_enabled=False`."
        )

    just_one_callable = False

    if not isinstance(callables, tuple):
        just_one_callable = True
        callables = (callables,)
        sample_args = (sample_args,)

    flatten_sample_args = []

    for c, args in zip(callables, sample_args):
        if isinstance(c, torch.nn.Module):
            assert (
                len(c._backward_hooks) == 0
                and len(c._forward_hooks) == 0
                and len(c._forward_pre_hooks) == 0
            ), (
                "Modules must not have hooks registered at the time they are passed. However, registering hooks "
                + "on modules after passing them through make_graphed_callables is allowed."
            )
            assert all(b.requires_grad is False for b in c.buffers()), (
                "In any :class:`~torch.nn.Module` passed to "
                + ":func:`~make_graphed_callables`, only parameters may be trainable. All buffers must have "
                + "``requires_grad=False``."
            )
        flatten_arg = torch.utils._pytree.arg_tree_leaves(*args)
        flatten_sample_args.append(tuple(flatten_arg))
        assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), (
            "In the beta API, sample_args "
            + "for each callable must contain only Tensors. Other types are not allowed."
        )

    # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
    # passes to forward (ie, its sample_args) AND the module's parameter attributes.
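    # Illustrative example (hypothetical module): for an nn.Linear(4, 4) passed with
    # sample_args = (x,), the full input surface is (x, weight, bias): user args first,
    # then parameters, so a single backward graph can produce grads for both the user
    # inputs and the parameters in one replay.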
    per_callable_len_user_args = [len(args) for args in flatten_sample_args]
    per_callable_module_params = [
        tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
        for c in callables
    ]
    per_callable_static_input_surfaces = [
        flatten_sample_args[i] + per_callable_module_params[i]
        for i in range(len(callables))
    ]

    fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
    bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]

    mempool = graph_pool_handle() if pool is None else pool

    # Warmup
    # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
    # from ending up in any captures.
    torch.cuda.synchronize()
    with torch.cuda.stream(torch.cuda.Stream()):
        for func, args, static_input_surface in zip(
            callables, sample_args, per_callable_static_input_surfaces
        ):
            grad_inputs, outputs, outputs_grad = None, None, None
            for _ in range(num_warmup_iters):
                outputs = torch.utils._pytree.tree_leaves(func(*args))
                outputs_grad = tuple(o for o in outputs if o.requires_grad)
                if len(outputs_grad) > 0:
                    grad_inputs = torch.autograd.grad(
                        outputs=outputs_grad,
                        inputs=tuple(
                            i for i in static_input_surface if i.requires_grad
                        ),
                        grad_outputs=tuple(
                            torch.empty_like(o) for o in outputs if o.requires_grad
                        ),
                        only_inputs=True,
                        allow_unused=allow_unused_input,
                    )
            for v in [outputs, outputs_grad, grad_inputs]:
                del v

    torch.cuda.synchronize()

    # All captures here share a mempool. To avoid replays corrupting each other's memory,
    # the safest approach is to capture all passes in the same order they'll run:
    # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.

    # Capture forward graphs
    per_callable_static_outputs = []
    per_callable_output_unflatten_spec = []
    for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs):
        with torch.cuda.graph(fwd_graph, pool=mempool):
            outputs = func(*args)

        flatten_outputs, spec = torch.utils._pytree.tree_flatten(outputs)
        per_callable_static_outputs.append(tuple(flatten_outputs))
        per_callable_output_unflatten_spec.append(spec)

    # Capture backward graphs in reverse order
    per_callable_static_grad_outputs = []
    per_callable_static_grad_inputs = []
    for static_input_surface, static_outputs, bwd_graph, module_params in zip(
        reversed(per_callable_static_input_surfaces),
        reversed(per_callable_static_outputs),
        reversed(bwd_graphs),
        reversed(per_callable_module_params),
    ):
        # For now, assumes all static_outputs require grad
        # assert all(o.requires_grad for o in static_outputs), "Outputs of graphed callables must require grad."
        static_grad_outputs = tuple(
            torch.empty_like(o) if o.requires_grad else None for o in static_outputs
        )

        outputs_grad = tuple(o for o in static_outputs if o.requires_grad)
        grad_inputs = None
        if len(outputs_grad) > 0:
            with torch.cuda.graph(bwd_graph, pool=mempool):
                grad_inputs = torch.autograd.grad(
                    outputs=outputs_grad,
                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
                    grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
                    only_inputs=True,
                    allow_unused=allow_unused_input,
                )

        # Constructs a tuple suitable for returning from Graphed.backward:
        # Pads out the actually-needed grads with Nones in gradient slots for inputs that don't require grad.
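        # Illustrative example: if static_input_surface is (a, b, c) and only a and c
        # require grad, grad_inputs is (grad_a, grad_c), and the padded result built
        # below is (grad_a, None, grad_c).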
        # I couldn't think of a slick one-liner for this pattern.
        static_grad_inputs = []
        grad_idx = 0
        for arg in static_input_surface:
            if arg.requires_grad and grad_inputs is not None:
                static_grad_inputs.append(grad_inputs[grad_idx])
                grad_idx += 1
            else:
                static_grad_inputs.append(None)  # type: ignore[arg-type]
        static_grad_inputs = tuple(static_grad_inputs)  # type: ignore[assignment]

        per_callable_static_grad_outputs.append(static_grad_outputs)
        per_callable_static_grad_inputs.append(static_grad_inputs)

    # Reverses the most recent two lists
    per_callable_static_grad_outputs.reverse()
    per_callable_static_grad_inputs.reverse()
    # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.

    def make_graphed_autograd_function(
        fwd_graph,
        bwd_graph,
        module_params,
        len_user_args,
        output_unflatten_spec,
        static_input_surface,
        static_outputs,
        static_grad_outputs,
        static_grad_inputs,
    ):
        class Graphed(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *inputs):
                # At this stage, only the user args may (potentially) be new tensors.
                for i in range(len_user_args):
                    if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
                        static_input_surface[i].copy_(inputs[i])
                fwd_graph.replay()
                assert isinstance(static_outputs, tuple)
                return tuple(o.detach() for o in static_outputs)

            @staticmethod
            @torch.autograd.function.once_differentiable
            def backward(ctx, *grads):
                assert len(grads) == len(static_grad_outputs)
                for g, grad in zip(static_grad_outputs, grads):
                    if g is not None:
                        # don't copy if autograd gods have been kind and the
                        # incoming grad is already in the right place
                        if g.data_ptr() != grad.data_ptr():
                            g.copy_(grad)
                bwd_graph.replay()

                # Input args that didn't require grad expect a None gradient.
                assert isinstance(static_grad_inputs, tuple)
                return tuple(
                    b.detach() if b is not None else b for b in static_grad_inputs
                )

        def functionalized(*user_args):
            # Runs the autograd function with inputs == all inputs to the graph that might require grad
            # (explicit user args + module parameters)
            # Assumes module params didn't change since capture.
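            # Illustrative example: for a graphed nn.Linear called as graphed(x),
            # Graphed.apply receives (x, weight, bias). forward() copies only x into the
            # static input surface; the parameters are expected to occupy the same memory
            # as at capture time (in-place updates to them are reflected on replay).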
            flatten_user_args = torch.utils._pytree.arg_tree_leaves(*user_args)
            out = Graphed.apply(*(tuple(flatten_user_args) + module_params))
            return torch.utils._pytree.tree_unflatten(out, output_unflatten_spec)

        return functionalized

    # Put together the final graphed callables
    ret = []
    for i, func in enumerate(callables):
        graphed = make_graphed_autograd_function(
            fwd_graphs[i],
            bwd_graphs[i],
            per_callable_module_params[i],
            per_callable_len_user_args[i],
            per_callable_output_unflatten_spec[i],
            per_callable_static_input_surfaces[i],
            per_callable_static_outputs[i],
            per_callable_static_grad_outputs[i],
            per_callable_static_grad_inputs[i],
        )

        if isinstance(func, torch.nn.Module):

            def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
                def new_fwd(*user_args):
                    # If the module's training-or-eval state matches what we graphed,
                    # run the graph, otherwise run the original forward method
                    if func.training == graph_training_state:
                        return graphed(*user_args)
                    else:
                        return orig_fwd(*user_args)

                return new_fwd

            func.forward = make_graphed_forward(func, func.training, graphed, func.forward)  # type: ignore[assignment]
            ret.append(func)
        else:
            ret.append(graphed)

    if just_one_callable:
        return ret[0]

    return tuple(ret)
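

# Illustrative usage sketch, not part of this module's API: it assumes a CUDA-capable
# device, and the shapes and toy Linear models are arbitrary. It follows the
# warmup / side-stream / capture / replay pattern described in the docstrings above.
if __name__ == "__main__" and torch.cuda.is_available():
    model = torch.nn.Linear(16, 16).cuda()
    static_input = torch.randn(8, 16, device="cuda")

    # Warm up on a side stream so lazy initialization (cudnn benchmarking, allocator
    # setup, etc.) does not end up inside the capture.
    warmup_stream = torch.cuda.Stream()
    warmup_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(warmup_stream):
        for _ in range(3):
            model(static_input)
    torch.cuda.current_stream().wait_stream(warmup_stream)

    # Capture one forward pass. `graph` handles the side stream, memory-pool hint,
    # and the capture_begin/capture_end calls internally.
    g = CUDAGraph()
    with graph(g):
        static_output = model(static_input)

    # Replays rerun the captured kernels on the same memory. To use new data,
    # copy it into the static input tensor before replaying.
    static_input.copy_(torch.randn(8, 16, device="cuda"))
    g.replay()
    print(static_output.sum().item())

    # Partial-network capture: make_graphed_callables returns a drop-in replacement
    # whose forward and backward passes replay captured graphs.
    sample = torch.randn(8, 16, device="cuda", requires_grad=True)
    graphed_model = make_graphed_callables(torch.nn.Linear(16, 16).cuda(), (sample,))
    real_input = torch.randn(8, 16, device="cuda", requires_grad=True)
    graphed_model(real_input).sum().backward()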