1# mypy: allow-untyped-defs 2import copy 3import functools 4import heapq 5import itertools 6import logging 7import math 8import operator 9import os 10from collections import defaultdict 11from dataclasses import dataclass, replace 12from typing import Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union 13 14import torch 15import torch._inductor.inductor_prims 16import torch.fx as fx 17import torch.utils._pytree as pytree 18from torch.fx.experimental._backward_state import BackwardState 19from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types 20from torch.fx.experimental.sym_node import magic_methods, method_to_operator 21from torch.fx.experimental.symbolic_shapes import ( 22 find_symbol_binding_fx_nodes, 23 free_symbols, 24 hint_int, 25 is_symbol_binding_fx_node, 26) 27from torch.fx.passes import graph_drawer 28from torch.utils.checkpoint import CheckpointPolicy 29 30from . import config 31from ._aot_autograd.logging_utils import get_aot_graph_name 32from ._aot_autograd.utils import is_with_effects 33from .compile_utils import fx_graph_cse, get_aten_target 34 35 36if TYPE_CHECKING: 37 import sympy 38 39 40AOT_PARTITIONER_DEBUG = config.debug_partitioner 41log = logging.getLogger(__name__) 42 43aten = torch.ops.aten 44prims = torch.ops.prims 45 46 47@dataclass 48class OpTypes: 49 """Class for keeping track of different operator categories""" 50 51 fusible_ops: Set[Callable] 52 compute_intensive_ops: Set[Callable] 53 random_ops: Set[Callable] 54 view_ops: Set[Callable] 55 recomputable_ops: Set[Callable] 56 57 def is_fusible(self, node: fx.Node): 58 return get_aten_target(node) in self.fusible_ops 59 60 def is_compute_intensive(self, node: fx.Node): 61 return get_aten_target(node) in self.compute_intensive_ops 62 63 def is_random(self, node: fx.Node): 64 return get_aten_target(node) in self.random_ops 65 66 def is_view(self, node: fx.Node): 67 return get_aten_target(node) in self.view_ops 68 69 def is_recomputable(self, node: fx.Node): 70 return get_aten_target(node) in self.recomputable_ops 71 72 73@dataclass 74class NodeInfo: 75 # Be careful about iterating over these explicitly, as their order may not 76 # be deterministic 77 inputs: List[fx.Node] 78 _required_fw_nodes: Set[fx.Node] 79 required_bw_nodes: Set[fx.Node] 80 unclaimed_nodes: Set[fx.Node] 81 fw_order: Dict[fx.Node, int] 82 83 @functools.cached_property 84 def required_fw_nodes(self) -> List[fx.Node]: 85 return sorted( 86 (n for n in self._required_fw_nodes), key=lambda n: self.fw_order[n] 87 ) 88 89 def is_required_fw(self, n: fx.Node) -> bool: 90 return n in self._required_fw_nodes 91 92 def is_required_bw(self, n: fx.Node) -> bool: 93 return n in self.required_bw_nodes 94 95 def is_unclaimed(self, n: fx.Node) -> bool: 96 return n in self.unclaimed_nodes 97 98 def get_fw_order(self, n: fx.Node) -> int: 99 assert n in self._required_fw_nodes, f"Node {n} not in fw nodes!" 100 return self.fw_order[n] 101 102 103@dataclass 104class MinCutOptions: 105 ban_if_used_far_apart: bool 106 ban_if_long_fusible_chains: bool 107 ban_if_materialized_backward: bool 108 ban_if_not_in_allowlist: bool 109 ban_if_reduction: bool 110 111 112def must_recompute(node: fx.Node) -> bool: 113 return node.meta.get("recompute", None) in [ 114 CheckpointPolicy.MUST_RECOMPUTE, 115 CheckpointPolicy.PREFER_RECOMPUTE, 116 ] 117 118 119def has_recomputable_ops(fx_g: fx.GraphModule) -> bool: 120 found = False 121 for node in fx_g.graph.nodes: 122 if must_recompute(node): 123 return True 124 return False 125 126 127def has_recomputable_rng_ops(fx_g: fx.GraphModule) -> bool: 128 for node in fx_g.graph.nodes: 129 if ( 130 must_recompute(node) 131 and hasattr(node.target, "tags") 132 and torch.Tag.nondeterministic_seeded in node.target.tags 133 ): 134 return True 135 return False 136 137 138def sym_node_size(node: fx.Node) -> int: 139 if isinstance(node.meta["val"], (torch.SymInt, torch.SymBool)): 140 return 1 141 assert isinstance(node.meta["val"], torch.SymFloat) 142 return 4 143 144 145class InvalidNodeBase: 146 def __repr__(self): 147 return "Invalid Node" 148 149 150InvalidNode = InvalidNodeBase() 151 152 153def _extract_graph_with_inputs_outputs( 154 joint_graph: fx.Graph, 155 inputs: List[fx.Node], 156 outputs: List[fx.Node], 157 subgraph: Optional[str] = None, 158) -> fx.Graph: 159 """ 160 Given a graph, extracts out a subgraph that takes the specified nodes as 161 inputs and returns the specified outputs. 162 163 This includes specifying non-placeholder nodes as inputs. 164 165 The general strategy is to initialize all inputs with proxies as we 166 encounter them, and trace through the graph, only keeping values which take 167 in valid proxies. Then, all dead code is eliminated. 168 """ 169 new_graph = fx.Graph() 170 env = {} 171 172 # Add new placeholder nodes in the order specified by the inputs 173 for node in inputs: 174 new_node = new_graph.placeholder(node.name) 175 # Can't use node_copy here as we may be turning previous call_function into placeholders 176 new_node.meta = node.meta 177 env[node] = new_node 178 179 for node in joint_graph.nodes: 180 if _must_be_in_backward(node) and subgraph != "backward": 181 env[node] = InvalidNode # type: ignore[assignment] 182 continue 183 184 if node in env: 185 # Node must be one of our inputs. (Any member of env which wasn't an 186 # input to start must have been created by this loop and won't be in 187 # joint_graph.nodes). 188 continue 189 elif node.op == "placeholder": 190 env[node] = InvalidNode # type: ignore[assignment] 191 elif node.op == "call_function": 192 all_args = pytree.arg_tree_leaves(*node.args, **node.kwargs) 193 all_args = [ 194 isinstance(env[x], InvalidNodeBase) 195 for x in all_args 196 if isinstance(x, fx.Node) 197 ] 198 if any(all_args): 199 env[node] = InvalidNode # type: ignore[assignment] 200 continue 201 env[node] = new_graph.node_copy(node, lambda x: env[x]) 202 elif node.op == "get_attr": 203 env[node] = new_graph.node_copy(node, lambda x: env[x]) 204 elif node.op == "output": 205 pass 206 output_values = [] 207 for x in outputs: 208 if isinstance(x, fx.Node): 209 if x not in env: 210 raise RuntimeError(f"Node {x} couldn't be found in env") 211 assert not isinstance( 212 env[x], InvalidNodeBase 213 ), f"Node {x} was invalid, but is output" 214 output_values.append(env[x]) 215 else: 216 output_values.append(x) 217 new_graph.output(tuple(output_values)) 218 219 new_graph.eliminate_dead_code() 220 new_graph.lint() 221 return new_graph 222 223 224def _is_primal(node: fx.Node) -> bool: 225 return ( 226 node.op == "placeholder" 227 and "tangents" not in str(node.target) 228 and not _is_bwd_seed_offset(node) 229 and not _is_fwd_seed_offset(node) 230 ) 231 232 233def _is_tangent(node: fx.Node) -> bool: 234 return node.op == "placeholder" and "tangents" in str(node.target) 235 236 237def _is_bwd_seed_offset(node: fx.Node) -> bool: 238 return node.op == "placeholder" and ( 239 "bwd_seed" in str(node.target) or "bwd_base_offset" in str(node.target) 240 ) 241 242 243def _is_fwd_seed_offset(node: fx.Node) -> bool: 244 return node.op == "placeholder" and ( 245 "fwd_seed" in str(node.target) or "fwd_base_offset" in str(node.target) 246 ) 247 248 249def _is_backward_state(node: fx.Node) -> bool: 250 return node.op == "placeholder" and isinstance(node.meta.get("val"), BackwardState) 251 252 253def _has_tag_is_backward(node: fx.Node) -> bool: 254 return node.meta.get("partitioner_tag", None) == "is_backward" 255 256 257def _has_tag_must_be_in_backward(node: fx.Node) -> bool: 258 return node.meta.get("partitioner_tag", None) == "must_be_in_backward" 259 260 261def _must_be_in_backward(node: fx.Node) -> bool: 262 return _has_tag_must_be_in_backward(node) or ( 263 _has_tag_is_backward(node) and is_with_effects(node) 264 ) 265 266 267def _extract_fwd_bwd_outputs( 268 joint_module: fx.GraphModule, *, num_fwd_outputs 269) -> Tuple[List[fx.Node], List[fx.Node]]: 270 outputs = pytree.arg_tree_leaves( 271 *(node.args for node in joint_module.graph.find_nodes(op="output")) 272 ) 273 fwd_outputs = outputs[:num_fwd_outputs] 274 bwd_outputs = outputs[num_fwd_outputs:] 275 return fwd_outputs, bwd_outputs 276 277 278def _remove_by_name(saved_values: List[fx.Node], name: str): 279 for saved_value in saved_values: 280 if saved_value.name == name: 281 saved_values.remove(saved_value) 282 break 283 284 285def _extract_fwd_bwd_modules( 286 joint_module: fx.GraphModule, 287 saved_values: List[fx.Node], 288 saved_sym_nodes: List[fx.Node], 289 *, 290 num_fwd_outputs: int, 291) -> Tuple[fx.GraphModule, fx.GraphModule]: 292 fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( 293 joint_module, num_fwd_outputs=num_fwd_outputs 294 ) 295 placeholders = joint_module.graph.find_nodes(op="placeholder") 296 primal_inputs = [*filter(_is_primal, placeholders)] 297 tangent_inputs = [*filter(_is_tangent, placeholders)] 298 fwd_seed_offset_inputs = [*filter(_is_fwd_seed_offset, placeholders)] 299 bwd_seed_offset_inputs = [*filter(_is_bwd_seed_offset, placeholders)] 300 backward_state_inputs = [*filter(_is_backward_state, placeholders)] 301 302 bwd_graph = _extract_graph_with_inputs_outputs( 303 joint_module.graph, 304 saved_sym_nodes + saved_values + tangent_inputs + bwd_seed_offset_inputs, 305 bwd_outputs, 306 "backward", 307 ) 308 309 for node in bwd_graph.find_nodes(op="placeholder"): 310 # This is to filter out saved values that don't actually end up being used by the backwards pass 311 if not node.users: 312 _remove_by_name(saved_values, node.name) 313 _remove_by_name(saved_sym_nodes, node.name) 314 elif _is_backward_state(node): 315 # BackwardState is saved directly 316 _remove_by_name(saved_values, node.name) 317 assert backward_state_inputs 318 319 # Now that we have the finalized list of saved values, we need to ensure 320 # we propagate all symbols which are referenced by backwards inputs. 321 # These are not directly used in the graph but are required for downstream 322 # sizevar assignment 323 saved_symbols: Set[sympy.Symbol] = set() 324 saved_sym_nodes_binding = [] 325 saved_sym_nodes_derived = [] 326 327 # Some symbols may already be bound in the directly saved_sym_nodes, 328 # keep track of them so we don't re-bind them 329 for node in saved_sym_nodes: 330 symbol = is_symbol_binding_fx_node(node) 331 if symbol: 332 saved_symbols.add(symbol) 333 saved_sym_nodes_binding.append(node) 334 else: 335 saved_sym_nodes_derived.append(node) 336 337 # Now go through all of the prospective backward inputs and track any 338 # other symbols we need to bind 339 symbol_bindings = find_symbol_binding_fx_nodes(joint_module.graph) 340 for node in itertools.chain(saved_sym_nodes_derived, saved_values, tangent_inputs): 341 if "val" not in node.meta: 342 continue 343 new_symbols = free_symbols(node.meta["val"]) - saved_symbols 344 # NB: Deterministic order please! 345 for s in sorted(new_symbols, key=lambda s: s.name): 346 # NB: For well formed graphs, the symbol should always be present, 347 # but we also have ways to produce ill-formed graphs, e.g., direct 348 # make_fx usages, so don't choke in this case 349 if s not in symbol_bindings: 350 continue 351 saved_sym_nodes_binding.append(symbol_bindings[s]) 352 saved_symbols |= new_symbols 353 354 # Update saved_sym_nodes that are now reordered to have all bindings at 355 # front. This can also be used later on to figure out the position of saved 356 # sym nodes in the output of fwd graph. 357 saved_sym_nodes.clear() 358 saved_sym_nodes.extend(saved_sym_nodes_binding + saved_sym_nodes_derived) 359 360 # Now, we re-generate the fwd/bwd graphs. 361 # NB: This might increase compilation time, but I doubt it matters 362 fwd_graph = _extract_graph_with_inputs_outputs( 363 joint_module.graph, 364 primal_inputs + fwd_seed_offset_inputs, 365 fwd_outputs + saved_values + saved_sym_nodes, 366 "forward", 367 ) 368 bwd_graph = _extract_graph_with_inputs_outputs( 369 joint_module.graph, 370 saved_sym_nodes 371 + saved_values 372 + tangent_inputs 373 + bwd_seed_offset_inputs 374 + backward_state_inputs, 375 bwd_outputs, 376 "backward", 377 ) 378 379 fwd_module = fx._lazy_graph_module._make_graph_module(joint_module, fwd_graph) 380 bwd_module = fx._lazy_graph_module._make_graph_module(joint_module, bwd_graph) 381 return fwd_module, bwd_module 382 383 384def default_partition( 385 joint_module: fx.GraphModule, _joint_inputs, *, num_fwd_outputs 386) -> Tuple[fx.GraphModule, fx.GraphModule]: 387 """ 388 Partitions the :attr:`joint_module` in a manner that closely resembles the 389 behavior observed in the original ``.forward()`` and ``.backward()`` of the 390 callable, i.e., the resulting forward graph contains those operators that 391 are executed in the original ``.forward()`` callable passed to 392 :func:`aot_function`. 393 394 The default partitioner collects the operators that are between the forward 395 inputs and the forward outputs. This helps in finding the tensors which have 396 to be stashed for the backward pass. These stashed tensors become the output 397 of the generated forward graph. The remaining operators are then placed in 398 the backward graph. 399 400 .. warning:: 401 This API is experimental and likely to change. 402 403 Args: 404 joint_module(fx.GraphModule): The joint forward and backward graph. This 405 is the result of AOT Autograd tracing. 406 407 Returns: 408 Returns the generated forward and backward Fx graph modules. 409 """ 410 if has_recomputable_ops(joint_module): 411 return min_cut_rematerialization_partition( 412 joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs 413 ) 414 primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) 415 fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes)) 416 inputs = primal_inputs + fwd_seed_offset_inputs 417 fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( 418 joint_module, num_fwd_outputs=num_fwd_outputs 419 ) 420 forward_only_graph = _extract_graph_with_inputs_outputs( 421 joint_module.graph, inputs, fwd_outputs, "forward" 422 ) 423 forward_node_names = { 424 node.name for node in forward_only_graph.nodes if node.op != "output" 425 } 426 saved_values = [] 427 saved_sym_nodes = [] 428 429 for node in joint_module.graph.nodes: 430 if node.name not in forward_node_names: 431 continue 432 if is_sym_node(node): 433 # Symints must be kept separate from tensors so that PythonFunction only calls 434 # save_for_backward on tensors and stashes symints in autograd .ctx 435 saved_sym_nodes.append(node) 436 elif "tensor_meta" not in node.meta and node.op == "call_function": 437 # Since we can't save tuple of tensor values, we need to flatten out what we're saving 438 users = node.users 439 assert all(user.target == operator.getitem for user in users) 440 saved_values.extend(users) 441 else: 442 backward_usages = [ 443 n for n in node.users if n.name not in forward_node_names 444 ] 445 if "tensor_meta" in node.meta and all( 446 is_sym_node(n) for n in backward_usages 447 ): 448 # If we have a tensor in the forward, where only its sizes/strides are needed in the backward, 449 # and not the actual tensor data, 450 # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor. 451 # 452 # Note that saving the tensor could also cause compilation problems: 453 # If the user mutated an input in the forward and uses its sizes/strides in the backward, 454 # then we would be obligated to clone the input before saving it to appease autograd. 455 # (This is how we originally found this bug). 456 saved_sym_nodes.extend(backward_usages) 457 else: 458 saved_values.append(node) 459 saved_values = list(dict.fromkeys(saved_values).keys()) 460 saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys()) 461 462 return _extract_fwd_bwd_modules( 463 joint_module, 464 saved_values, 465 saved_sym_nodes=saved_sym_nodes, 466 num_fwd_outputs=num_fwd_outputs, 467 ) 468 469 470INT_INF = int(1e6) 471 472 473def _tensor_nbytes(numel: int, dtype) -> int: 474 return numel * dtype.itemsize 475 476 477def _size_of(node: fx.Node) -> int: 478 def object_nbytes(x) -> int: 479 if not isinstance(x, torch.Tensor): 480 return 0 481 return _tensor_nbytes(hint_int(x.numel(), fallback=4096), x.dtype) 482 483 if "val" in node.meta: 484 val = node.meta["val"] 485 if isinstance(val, py_sym_types): 486 return 1 487 # NB: The fallback values here are meaningless, maybe we should respect 488 # torch._inductor.config.unbacked_symint_fallback (but this is a 489 # layering violation) 490 elif isinstance(val, (list, tuple)): 491 return sum(object_nbytes(n) for n in val) 492 elif isinstance(val, dict): 493 return sum(object_nbytes(n) for _, n in val.items()) 494 elif isinstance(val, torch.Tensor): 495 return object_nbytes(val) 496 497 raise RuntimeError(f"Unknown metadata type {type(val)} on node {node}") 498 if node.op == "get_attr": 499 return 0 500 raise RuntimeError( 501 f"Node {node} didn't have `val` metadata; we should always have `val` metadata on the nodes." 502 ) 503 504 505# Used for some investigative purposes 506def _count_ops(graph: fx.Graph): 507 from collections import defaultdict 508 509 cnt: Dict[str, int] = defaultdict(int) 510 for node in graph.nodes: 511 if node.op == "call_function": 512 cnt[node.target.__name__] += 1 513 print(sorted(cnt.items(), key=lambda x: x[1], reverse=True)) 514 515 516@functools.lru_cache(None) 517def pointwise_ops(): 518 ops = [] 519 for attr_name in dir(torch.ops.aten): 520 opoverloadpacket = getattr(torch.ops.aten, attr_name) 521 if not isinstance(opoverloadpacket, torch._ops.OpOverloadPacket): 522 continue 523 524 for overload in opoverloadpacket.overloads(): 525 op_overload = getattr(opoverloadpacket, overload) 526 if torch.Tag.pointwise in op_overload.tags: 527 # currently aot autograd uses packet not overload 528 ops.append(opoverloadpacket) 529 break 530 531 return ops 532 533 534def sort_depths(args, depth_map: Dict[fx.Node, int]) -> List[Tuple[fx.Node, int]]: 535 arg_depths = { 536 arg: depth_map[arg] for arg in args if isinstance(arg, torch.fx.node.Node) 537 } 538 return sorted(arg_depths.items(), key=lambda x: x[1], reverse=True) 539 540 541def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule: 542 """ 543 This pass finds the first bwd node in the graph (by looking at users of 544 tangents) and then reorders the graph by walking from this node to all the 545 way to the end of the graph. At each op in this traveral, we insert this op 546 in a new graph and try to bring only the relevant subgraph from the other 547 non-bwd edges relevant for this op. This closely mimics the behavior of 548 autograd engine. 549 550 Why is this pass required in the first place? 551 552 This is an artifact of how partitioners work today. The starting point of 553 partitioner is a joint graph, which is fwd and then bwd graph. In the case 554 of checkpointing, we keep portions of fwd graph in their original place in 555 the joint graph, while obtaining a bwd graph. As a result, the resulting bwd 556 graph has copies of recomputed fwd subgraphs followed by the original bwd 557 graph. If we run this naively, this leads to bad memory footprint, because 558 the fwd subgraphs are live for way longer duration than necessary. This pass 559 reorders the operations such that we prioritize the ops for the original bwd 560 graph while only realizing those ops from the fwd graph that are necessary 561 at any given point in the graph. 562 """ 563 564 new_graph = fx.Graph() 565 env: Dict[fx.Node, fx.Node] = {} 566 567 # Add new placeholder nodes in the order specified by the inputs 568 for node in gm.graph.find_nodes(op="placeholder"): 569 env[node] = new_graph.node_copy(node, lambda x: env[x]) 570 571 order = {} 572 for idx, node in enumerate(gm.graph.nodes): 573 order[node] = idx 574 575 def insert_node_in_graph(node): 576 cur_nodes = [node] 577 insertable_nodes = set() 578 while len(cur_nodes) > 0: 579 node = cur_nodes.pop() 580 if node in insertable_nodes or node in env: 581 continue 582 insertable_nodes.add(node) 583 584 # Bias traversal towards the nodes that have higher depth - prioritizes 585 # critical path first. 586 cur_nodes += node.all_input_nodes 587 588 insertable_nodes = sorted(insertable_nodes, key=lambda n: order[n]) 589 for node in insertable_nodes: 590 env[node] = new_graph.node_copy(node, lambda x: env[x]) 591 592 # Find first bwd node in the graph 593 tangent_inputs = list(filter(_is_tangent, gm.graph.nodes)) 594 first_node_in_bwd = None 595 minimum_order = math.inf 596 for tangent in tangent_inputs: 597 for user in tangent.users: 598 if order[user] < minimum_order: 599 minimum_order = order[user] 600 first_node_in_bwd = user 601 602 # If gradInp does not depend upon gradOut, we may not find any nodes in the "backwards pass" 603 if first_node_in_bwd is None: 604 return gm 605 606 # Build the graph op-by-op by starting from the node all the way to the end 607 for node in list(gm.graph.nodes)[order[first_node_in_bwd] :]: 608 insert_node_in_graph(node) 609 610 # The output node is already built by the traversal. 611 new_gm = torch.fx.GraphModule(gm, new_graph) 612 return new_gm 613 614 615def functionalize_rng_ops( 616 joint_module: fx.GraphModule, 617 fw_module: fx.GraphModule, 618 bw_module: fx.GraphModule, 619 num_sym_nodes: int, 620) -> Tuple[fx.GraphModule, fx.GraphModule]: 621 # During user-driven activation checkpointing, we have to ensure that a rng 622 # op in fwd yields the same output as the recomputed rng op in the bwd. To 623 # do this, we use functionalize wrappers to wrap the random ops and share 624 # rng state between the fwd and bwd graphs. 625 626 # There are 3 main steps to do this 627 # Step 1 - Construct a mapping of rng node between the fwd and its counterpart in bwd. 628 # Step 2 - Modify the fwd pass such that 629 # 1) Replace rand with run_and_save_rng_state wrapper 630 # 2) Replace the users of the original op with the output[1] of this op. 631 # 3) Collect all the rng_state - output[0] of each op, and make them 632 # output nodes. Special care needs to be taken here because fwd outputs 633 # has symints at the very end. 634 # Step 3 - Modify the bwd pass such that 635 # 1) Add the input nodes just before the tangents for the stashed rng states 636 # 2) Replace rand with run_with_save_rng_state wrappers 637 # 3) Use the stashed states as inputs to these ops 638 639 # Unique id to generate name 640 uid = itertools.count() 641 642 def get_rng_ops(gmod): 643 random_nodes = {} 644 for node in gmod.graph.nodes: 645 if ( 646 node.op == "call_function" 647 and hasattr(node.target, "tags") 648 and torch.Tag.nondeterministic_seeded in node.target.tags 649 ): 650 random_nodes[node.name] = node 651 return random_nodes 652 653 def get_device(node): 654 """ 655 Check the example value of the node outputs to find the device type. 656 """ 657 if "val" not in node.meta: 658 return None 659 660 candidates = node.meta["val"] 661 if not isinstance(candidates, tuple): 662 candidates = (candidates,) 663 664 for candidate in candidates: 665 if isinstance(candidate, torch.Tensor): 666 if candidate.device.type == "cuda": 667 return "cuda" 668 669 return "cpu" 670 671 def get_sample_rng_state(device): 672 if device == "cuda": 673 return torch.cuda.get_rng_state() 674 return torch.get_rng_state() 675 676 # Step 1 - Construct a mapping of rng node between the fwd and its counterpart in bwd. 677 joint_graph_rng_ops = get_rng_ops(joint_module) 678 fw_graph_rng_ops = get_rng_ops(fw_module) 679 bw_graph_rng_ops = get_rng_ops(bw_module) 680 recomputable_rng_ops_map = {} 681 for node in joint_module.graph.nodes: 682 if ( 683 must_recompute(node) 684 and hasattr(node.target, "tags") 685 and torch.Tag.nondeterministic_seeded in node.target.tags 686 ): 687 base_node = joint_graph_rng_ops[node.name] 688 fw_node = fw_graph_rng_ops[node.name] 689 bw_node = bw_graph_rng_ops[node.name] 690 recomputable_rng_ops_map[base_node] = {"fwd": fw_node, "bwd": bw_node} 691 692 run_and_save_rng = torch._prims.rng_prims.run_and_save_rng_state 693 run_with_rng_state = torch._prims.rng_prims.run_with_rng_state 694 bw_tangent_start_node = None 695 for node in bw_module.graph.find_nodes(op="placeholder"): 696 if "tangent" in node.name: 697 bw_tangent_start_node = node 698 break 699 if bw_tangent_start_node is None: 700 raise RuntimeError( 701 "Couldn't find tangent node in graph inputs. This is unexpected, please file a bug if you see this" 702 ) 703 704 fw_rng_state_outputs = [] 705 for base_node, node_pair in recomputable_rng_ops_map.items(): 706 # Step 2 - Modify the fwd pass such that 707 fw_node = node_pair["fwd"] 708 bw_node = node_pair["bwd"] 709 fw_graph = fw_module.graph 710 with fw_graph.inserting_before(fw_node): 711 functional_fw_node = fw_graph.create_node( 712 "call_function", 713 run_and_save_rng, 714 args=(fw_node.target, *fw_node.args), 715 kwargs=fw_node.kwargs, 716 ) 717 state = fw_graph.create_node( 718 "call_function", 719 operator.getitem, 720 args=(functional_fw_node, 0), 721 kwargs={}, 722 ) 723 rng_output = fw_graph.create_node( 724 "call_function", 725 operator.getitem, 726 args=( 727 functional_fw_node, 728 1, 729 ), 730 kwargs={}, 731 ) 732 fw_node.replace_all_uses_with(rng_output) 733 fw_graph.erase_node(fw_node) 734 fw_rng_state_outputs.append(state) 735 736 # Step 3 - Modify the bwd pass such that 737 bw_graph = bw_module.graph 738 with bw_graph.inserting_before(bw_tangent_start_node): 739 state_name = f"rng_state_output_{next(uid)}" 740 bw_rng_state_node = bw_graph.placeholder(state_name) 741 bw_rng_state_node.meta["val"] = get_sample_rng_state(get_device(fw_node)) 742 743 with bw_graph.inserting_before(bw_node): 744 rng_output = bw_graph.create_node( 745 "call_function", 746 run_with_rng_state, 747 args=(bw_rng_state_node, bw_node.target, *bw_node.args), 748 kwargs=bw_node.kwargs, 749 ) 750 751 bw_node.replace_all_uses_with(rng_output) 752 bw_graph.erase_node(bw_node) 753 754 # Add the rng states in the output of the fwd graph. AOT Autograd assumes 755 # that symints are at the end of forward graph outputs. So, insert the new 756 # rng states accordingly. 757 fw_output_node = next(iter(fw_module.graph.find_nodes(op="output"))) 758 fw_outputs = fw_output_node.args[0] 759 sym_node_start_idx = len(fw_outputs) - num_sym_nodes 760 outputs = ( 761 fw_outputs[:sym_node_start_idx] 762 + tuple(fw_rng_state_outputs) 763 + fw_outputs[sym_node_start_idx:] 764 ) 765 fw_module.graph.output(outputs) 766 fw_module.graph.erase_node(fw_output_node) 767 fw_module.recompile() 768 bw_module.recompile() 769 return fw_module, bw_module 770 771 772def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: 773 """ 774 If there are two consecutive checkpointed blocks with no operator in 775 between, we would still want to stash the tensor at the boundary of 776 checkpointed blocks. The following pass makes the last output node 777 non-recomputable to allow for that. 778 """ 779 for node in joint_module.graph.nodes: 780 if must_recompute(node): 781 for user in node.users: 782 if ( 783 must_recompute(user) 784 and user.meta["ac_graph_id"] > node.meta["ac_graph_id"] 785 ): 786 node.meta["recompute"] = CheckpointPolicy.MUST_SAVE 787 return joint_module 788 789 790def solve_min_cut( 791 joint_graph: fx.Graph, 792 node_info: NodeInfo, 793 min_cut_options: MinCutOptions, 794 dont_ban=None, 795): 796 if dont_ban is None: 797 dont_ban = set() 798 op_types = get_default_op_list() 799 800 if AOT_PARTITIONER_DEBUG: 801 joint_module_ops = { 802 str(node.target._overloadpacket) 803 for node in joint_graph.nodes 804 if node.op == "call_function" and hasattr(node.target, "_overloadpacket") 805 } 806 ops_ignored = joint_module_ops - {str(i) for i in op_types.recomputable_ops} 807 print("Ops banned from re-materialization: ", ops_ignored) 808 print() 809 810 def can_fuse_into_auto_functionalized(a, b): 811 if b.target != torch.ops.higher_order.auto_functionalized: 812 return False 813 mutable_op = b.args[0] 814 ( 815 mutable_arg_names, 816 _, 817 ) = torch._higher_order_ops.auto_functionalize.get_mutable_args(mutable_op) 818 for name in mutable_arg_names: 819 arg = b.kwargs[name] 820 if a is arg: 821 return True 822 if isinstance(arg, list): 823 if a in arg: 824 return True 825 return False 826 827 def can_fuse_into_triton_kernel_wrapper_functional(a, b): 828 if b.target != torch.ops.higher_order.triton_kernel_wrapper_functional: 829 return False 830 mutable_arg_names = b.kwargs["tensors_to_clone"] 831 for name in mutable_arg_names: 832 arg = b.kwargs["kwargs"][name] 833 if a is arg: 834 return True 835 return False 836 837 def is_fusible(a, b): 838 # We can perform "memory fusion" into a cat, but cat cannot be a 839 # producer to a fusion 840 if get_aten_target(b) == aten.cat: 841 return True 842 if can_fuse_into_auto_functionalized(a, b): 843 return True 844 if can_fuse_into_triton_kernel_wrapper_functional(a, b): 845 return True 846 return op_types.is_fusible(a) and op_types.is_fusible(b) 847 848 try: 849 import networkx as nx 850 except ImportError as e: 851 raise RuntimeError( 852 "Need networkx installed to perform smart recomputation " "heuristics" 853 ) from e 854 855 def is_materialized_backwards(node): 856 if op_types.is_view(node): 857 return False 858 cur_nodes = {node} 859 while len(cur_nodes) > 0: 860 cur = cur_nodes.pop() 861 for user in cur.users: 862 if not node_info.is_required_fw(user) and not is_fusible(cur, user): 863 return True 864 if op_types.is_view(user): 865 cur_nodes.add(user) 866 867 return False 868 869 def should_ban_recomputation(node): 870 if node.op != "call_function": 871 return False 872 if node.target == operator.getitem: 873 return False 874 if node.meta.get("recompute", None) == CheckpointPolicy.MUST_SAVE: 875 return True 876 if config.recompute_views and op_types.is_view(node): 877 return False 878 if node.target in [aten.lift_fresh_copy.default, aten.lift_fresh.default]: 879 return False 880 881 if min_cut_options.ban_if_not_in_allowlist: 882 if not op_types.is_recomputable(node): 883 return True 884 else: 885 if op_types.is_random(node) or op_types.is_compute_intensive(node): 886 return True 887 888 # If a node *must* be materialized in the backwards pass, then we 889 # should never recompute it. This is a pretty subtle point. In 890 # general, the assumption we make is that recomputing a node in the 891 # backwards pass is "free". However, if a node must be materialized 892 # in the backwards pass, then recomputing it is never free. 893 if min_cut_options.ban_if_materialized_backward and is_materialized_backwards( 894 node 895 ): 896 log.info("materialized backwards: %s %s", node, tuple(node.users)) 897 return True 898 899 # Arbitrary hack that sometimes seems to help things. The above 900 # modification appears to have made this heuristic a lot less critical 901 # for performance. 902 # NB: As of PR #121692, this hack no longer seems necessary. 903 if node.dist_from_bw < 1000 and node.dist_from_bw > config.max_dist_from_bw: 904 return True 905 906 # If the output of an op is 4x smaller (arbitrary choice), 907 # then we don't allow recomputation. The idea here is that for 908 # things like reductions, saving the output of the reduction is very 909 # cheap/small, and it makes sure we don't do things like recompute 910 # normalizations in the backwards. 911 if min_cut_options.ban_if_reduction: 912 input_tensors_size = sum( 913 _size_of(i) for i in node.args if isinstance(i, fx.Node) 914 ) 915 output_size = _size_of(node) 916 return output_size * 4 < input_tensors_size 917 return False 918 919 def is_materialized(node): 920 if node.op == "placeholder": 921 return True 922 923 return not all(is_fusible(node, user) for user in node.users) 924 925 def get_node_weight(node) -> float: 926 mem_sz = _size_of(node) 927 if config.recompute_views and op_types.is_view(node): 928 # If `config.recompute_views=True`, we don't save views. This is generally 929 # a good idea since views are free to recompute, and it makes it a bit simpler 930 # to analyze. 931 # NB: If they're not free to recompute (e.g. nested tensors)... I 932 # think we should modify checks for view_ops to `is_view` and check 933 # that. Basically, with nested tensors, `aten.view` is not a "view 934 # op". 935 return math.inf 936 937 if isinstance(node.meta["val"], py_sym_types): 938 # We never want to save symfloats 939 if not isinstance(node.meta["val"], torch.SymInt): 940 return INT_INF 941 942 # Heuristic to bias towards nodes closer to the backwards pass 943 # Complete guess about current value 944 mem_sz = int(mem_sz * (1.1 ** max(min(node.dist_from_bw, 100), 1))) 945 if is_materialized(node): 946 return mem_sz 947 else: 948 return mem_sz * 2 949 950 nx_graph = nx.DiGraph() 951 banned_nodes = set() 952 953 def ban_recomputation_if_allowed(node): 954 if op_types.is_view(node): 955 return False 956 if node in dont_ban: 957 return False 958 # This bans recomputation of the node unless we've been forced not to by 959 # user annotation 960 if must_recompute(node): 961 return False 962 963 if "val" in node.meta and isinstance(node.meta["val"], torch.SymFloat): 964 return False 965 966 banned_nodes.add(node) 967 # A node will only ever be recomputed if there is a path from an 968 # ancestor of this node to the backwards path through this node that 969 # doesn't go through any saved value. If this node is saved, then that 970 # condition is not possible. 971 nx_graph.add_edge("source", node.name + "_in", capacity=math.inf) 972 return True 973 974 for node in joint_graph.nodes: 975 if node.op == "output": 976 continue 977 978 if node in node_info.required_bw_nodes: 979 if node not in node_info.inputs: 980 nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf) 981 continue 982 # If someone saves a input for backward as-is and backward 983 # returns that tensor as-is as a grad input, then the node x would 984 # be both a required_bw_node and an input. In this case we 985 # (1) connect x_in to to the source, (2) x_out to the sink, and 986 # (3) assign the proper weight to the x_in-x_out edge, so that 987 # x would be part of cut nodes. A case where this happens is if 988 # NestedTensor saves a offset tensor as part of the singleton int 989 # in sizes. 990 nx_graph.add_edge(node.name + "_out", "sink", capacity=math.inf) 991 992 if must_recompute(node): 993 # If user explicitly says they want to recompute a node, we honor it 994 # by adding an inf-capacity edge from X_in to the sink. 995 # This way, X_in node is guaranteed to be part of the subgraph that contains "sink" 996 # after the cut, thus guaranteeing that X op will be recomputed. 997 nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf) 998 continue 999 1000 if _is_primal(node) or _is_fwd_seed_offset(node): 1001 ban_recomputation_if_allowed(node) 1002 1003 # If a node can't be recomputed (too expensive or involves randomness), 1004 # we prevent it from being recomputed by adding an inf edge to the source 1005 # We only need to ban nodes in the fw pass, as those are the only ones that would be recomputed. 1006 if node_info.is_required_fw(node) and should_ban_recomputation(node): 1007 ban_recomputation_if_allowed(node) 1008 1009 # Checks if a node is actually a tuple. Can be simplified to just an isinstance check if we always use faketensors. 1010 is_non_tensor_node = ( 1011 "val" not in node.meta and "tensor_meta" not in node.meta 1012 ) or ("val" in node.meta and not isinstance(node.meta["val"], torch.Tensor)) 1013 1014 if is_sym_node(node): 1015 weight = float(sym_node_size(node)) 1016 elif is_non_tensor_node: 1017 weight = ( 1018 0.0 if isinstance(node.meta.get("val"), BackwardState) else math.inf 1019 ) 1020 else: 1021 weight = get_node_weight(node) 1022 # Creates the weights on the "node" edge 1023 nx_graph.add_edge(node.name + "_in", node.name + "_out", capacity=weight) 1024 for user in node.users: 1025 nx_graph.add_edge(node.name + "_out", user.name + "_in", capacity=math.inf) 1026 1027 # todo(chilli): This is the most questionable of the 3 heuristics for banning recompute. 1028 # Some example models to look at where this helps perf: poolformer_m36, 1029 # mixer_b16_224, cait_m36_384 1030 1031 # The "rough" idea here is that if you have some node that is used by both a 1032 # node nearby downstream as well as a node far downstream, if we recompute 1033 # both of the downstream nodes, we're unlikely to be able to fuse both 1034 # downstream nodes together. 1035 1036 # Thus, we shouldn't aim to recompute far downstream nodes that depend on 1037 # this node. That intuition of "far downstream" is captured by whether 1038 # there's an unfusible op along the chain somewhere 1039 1040 # It could probably be improved by properly analyzing what's going on in the 1041 # backwards pass instead of only relying on whether it's unfusible in the 1042 # forwards. 1043 1044 def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: 1045 """ 1046 Finds the first unfusible node in the chain of nodes starting from 1047 `start_nodes` and returns its position. 1048 """ 1049 sorted_nodes: List[Tuple[int, fx.Node, bool]] = [] 1050 for n in start_nodes: 1051 heapq.heappush(sorted_nodes, (node_info.get_fw_order(n), n, True)) 1052 1053 while len(sorted_nodes) > 0: 1054 _, node, node_is_fusible = heapq.heappop(sorted_nodes) 1055 if not node_is_fusible: 1056 return node_info.get_fw_order(node) 1057 for user in node.users: 1058 if node_info.is_required_fw(user): 1059 if node_info.get_fw_order(user) > max_range: 1060 continue 1061 heapq.heappush( 1062 sorted_nodes, 1063 (node_info.get_fw_order(user), user, is_fusible(node, user)), 1064 ) 1065 return max_range 1066 1067 if min_cut_options.ban_if_used_far_apart: 1068 for used_node in node_info.required_fw_nodes: 1069 orders = [ 1070 node_info.get_fw_order(user) 1071 for user in used_node.users 1072 if node_info.is_required_fw(user) 1073 ] 1074 fw_users = [ 1075 user for user in used_node.users if node_info.is_required_fw(user) 1076 ] 1077 if len(orders) > 0: 1078 first_unfusible_use = find_first_unfusible(fw_users, max(orders)) 1079 for user in tuple(used_node.users): 1080 if ( 1081 node_info.is_required_fw(user) 1082 and node_info.get_fw_order(user) > first_unfusible_use 1083 and is_fusible(used_node, user) 1084 ): 1085 if user in banned_nodes: 1086 continue 1087 log.info( 1088 "used above/below fusible %s:(%s) -> %s -> %s:(%s)", 1089 used_node, 1090 node_info.get_fw_order(used_node), 1091 first_unfusible_use, 1092 user, 1093 node_info.get_fw_order(user), 1094 ) 1095 ban_recomputation_if_allowed(user) 1096 1097 # This heuristic is fairly straightforward. The idea is that although it is 1098 # cheap to recompute bandwidth-bound ops, we don't want to end up in a situation 1099 # where we have a long chain of pointwise ops from the beginning to the end 1100 # of the model (like say, residual connections) 1101 1102 # todo: I'm not totally sure why this heuristic matters. It's possible that this is 1103 # working around Inductor fusion decisions, or that it's a patch over 1104 # suboptimal partitioning decisions 1105 1106 # Some models it improves perf on are cait_m36_384, mixer_b16_224, poolformer_m36 1107 1108 if min_cut_options.ban_if_long_fusible_chains: 1109 visited = set() 1110 for start_node in joint_graph.nodes: 1111 if not node_info.is_required_fw(start_node): 1112 continue 1113 fusible = [(node_info.get_fw_order(start_node), start_node)] 1114 start_order = node_info.get_fw_order(start_node) 1115 while len(fusible) > 0: 1116 _, cur = heapq.heappop(fusible) 1117 if cur in visited: 1118 continue 1119 visited.add(cur) 1120 # 100 is arbitrary choice to try and prevent degenerate cases 1121 if ( 1122 node_info.get_fw_order(cur) > start_order + 100 1123 and len(fusible) == 0 1124 ): 1125 log.info( 1126 "too long %s %s %s %s", 1127 cur, 1128 start_node, 1129 node_info.get_fw_order(cur), 1130 node_info.get_fw_order(start_node), 1131 ) 1132 ban_recomputation_if_allowed(cur) 1133 break 1134 1135 for user in cur.users: 1136 if ( 1137 node_info.is_required_fw(user) 1138 and is_fusible(cur, user) 1139 and user not in banned_nodes 1140 ): 1141 heapq.heappush(fusible, (node_info.get_fw_order(user), user)) 1142 1143 try: 1144 cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink") 1145 except Exception: 1146 print("Failed to compute min-cut on following graph:") 1147 print("\n".join(nx.readwrite.edgelist.generate_edgelist(nx_graph))) 1148 visualize_min_cut_graph(nx_graph) 1149 raise 1150 1151 reachable, non_reachable = partition 1152 cutset: Set[Tuple[str, str]] = set() 1153 for u, nbrs in ((n, nx_graph[n]) for n in reachable): 1154 cutset.update((u, v) for v in nbrs if v in non_reachable) 1155 1156 cut_nodes = set() 1157 for node_in, node_out in cutset: 1158 assert node_in[:-3] == node_out[:-4] 1159 node_name = node_in[:-3] 1160 cut_nodes.add(node_name) 1161 1162 name_to_node = get_name_to_node(joint_graph) 1163 # To make this stuff deterministic 1164 node_idx = {node: idx for idx, node in enumerate(joint_graph.nodes)} 1165 saved_values = sorted( 1166 (name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x] 1167 ) 1168 return saved_values, banned_nodes 1169 1170 1171def visualize_min_cut_graph(nx_graph): 1172 import networkx as nx 1173 import pydot 1174 1175 dot_format = nx.nx_pydot.to_pydot(nx_graph).to_string() 1176 dot_graph = pydot.graph_from_dot_data(dot_format)[0] 1177 for edge in dot_graph.get_edges(): 1178 weight = nx_graph[edge.get_source()][edge.get_destination()]["capacity"] 1179 # Set edge label to weight 1180 edge.set_label(str(weight)) 1181 # Color edges with weight 'inf' as red 1182 if weight == float("inf"): 1183 edge.set_color("red") 1184 print("Visualizing the failed graph to min_cut_failed.svg") 1185 dot_graph.write_svg("min_cut_failed.svg") 1186 1187 1188def get_default_op_list() -> OpTypes: 1189 default_recomputable_ops: List[Callable] = [ 1190 aten.add, 1191 aten.sub, 1192 aten.div, 1193 aten.atan2, 1194 aten.mul, 1195 aten.max, 1196 aten.min, 1197 aten.pow, 1198 aten.remainder, 1199 aten.fmod, 1200 aten.__and__, 1201 aten.__or__, 1202 aten.__xor__, 1203 aten.__lshift__, 1204 aten.__rshift__, 1205 aten.eq, 1206 aten.ne, 1207 aten.ge, 1208 aten.gt, 1209 aten.le, 1210 aten.lt, 1211 aten.abs, 1212 aten.bitwise_not, 1213 aten.ceil, 1214 aten.floor, 1215 aten.frac, 1216 aten.neg, 1217 aten.relu, 1218 aten.round, 1219 aten.silu, 1220 aten.trunc, 1221 aten.log, 1222 aten.log10, 1223 aten.log1p, 1224 aten.log2, 1225 aten.lgamma, 1226 aten.exp, 1227 aten.expm1, 1228 aten.erf, 1229 aten.erfc, 1230 aten.cos, 1231 aten.acos, 1232 aten.cosh, 1233 aten.sin, 1234 aten.asin, 1235 aten.sinh, 1236 aten.tan, 1237 aten.atan, 1238 aten.tanh, 1239 aten.atanh, 1240 aten.sqrt, 1241 aten.rsqrt, 1242 aten.reciprocal, 1243 aten.sigmoid, 1244 aten.softplus, 1245 aten.threshold, 1246 aten.threshold_backward, 1247 aten.clamp, 1248 aten.where, 1249 aten.lerp, 1250 aten.addcmul, 1251 aten.gelu, 1252 aten.gelu_backward, 1253 aten.sum, 1254 aten.mean, 1255 aten._grad_sum_to_size, 1256 aten.sum_to_size, 1257 aten.amax, 1258 aten.to, 1259 aten.type_as, 1260 operator.getitem, 1261 aten.squeeze, 1262 aten.unsqueeze, 1263 aten.rsub, 1264 aten._to_copy, 1265 ] # noqa: E501,B950 1266 recomputable_view_ops = [aten.squeeze, aten.unsqueeze, aten.alias] 1267 recomputable_view_ops += [ 1268 aten.view, 1269 aten.slice, 1270 aten.t, 1271 prims.broadcast_in_dim, 1272 aten.expand, 1273 aten.as_strided, 1274 aten.permute, 1275 ] 1276 view_ops = recomputable_view_ops 1277 default_recomputable_ops += [ 1278 prims.div, 1279 prims.convert_element_type, 1280 aten.clone, 1281 aten._to_copy, 1282 aten.full_like, 1283 prims.var, 1284 prims.sum, 1285 aten.var, 1286 aten.std, 1287 prims.broadcast_in_dim, 1288 aten.select, 1289 aten._unsafe_view, 1290 aten.view, 1291 aten.expand, 1292 aten.slice, 1293 aten.reshape, 1294 aten.broadcast_tensors, 1295 aten.scalar_tensor, 1296 aten.ones, 1297 aten.new_zeros, 1298 aten.lift_fresh_copy, 1299 aten.arange, 1300 aten.triu, 1301 aten.var_mean, 1302 aten.isinf, 1303 aten.any, 1304 aten.full, 1305 aten.as_strided, 1306 aten.zeros, 1307 aten.empty, 1308 aten.empty_like, 1309 aten.argmax, 1310 aten.maximum, 1311 prims.iota, 1312 prims._low_memory_max_pool2d_offsets_to_indices, 1313 ] # noqa: E501,B950 1314 # Natalia said that we should allow recomputing indexing :) 1315 default_recomputable_ops += [aten.index, aten.gather] 1316 default_recomputable_ops += view_ops 1317 1318 default_recomputable_ops += pointwise_ops() 1319 1320 default_recomputable_ops += [ 1321 aten.zeros_like, 1322 ] 1323 1324 default_recomputable_ops += [method_to_operator(m) for m in magic_methods] 1325 recomputable_ops = set(default_recomputable_ops) 1326 1327 random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like] 1328 compute_intensive_ops = [ 1329 aten.mm, 1330 aten.convolution, 1331 aten.convolution_backward, 1332 aten.bmm, 1333 aten.addmm, 1334 aten._scaled_dot_product_flash_attention, 1335 aten._scaled_dot_product_efficient_attention, 1336 aten._flash_attention_forward, 1337 aten._efficient_attention_forward, 1338 aten.upsample_bilinear2d, 1339 aten._scaled_mm, 1340 ] # noqa: E501,B950 1341 1342 fusible_ops = recomputable_ops | set(random_ops) 1343 return OpTypes( 1344 set(fusible_ops), 1345 set(compute_intensive_ops), 1346 set(random_ops), 1347 set(view_ops), 1348 set(recomputable_ops), 1349 ) 1350 1351 1352def get_name_to_node(graph: fx.Graph): 1353 name_to_node = {} 1354 for node in graph.nodes: 1355 name_to_node[node.name] = node 1356 return name_to_node 1357 1358 1359def greedy_knapsack( 1360 memory: List[float], runtimes: List[float], max_memory: float 1361) -> Tuple[float, List[int], List[int]]: 1362 n = len(runtimes) 1363 items = list(range(n)) 1364 1365 # Sort items based on the ratio of runtime to memory in descending order 1366 items = sorted(items, key=lambda i: runtimes[i] / memory[i], reverse=True) 1367 1368 total_memory = 0.0 1369 total_runtime = 0.0 1370 items_to_save = [] 1371 items_to_allow_recomputing = [] 1372 1373 for i in items: 1374 if total_memory + memory[i] <= max_memory: 1375 total_memory += memory[i] 1376 total_runtime += runtimes[i] 1377 items_to_save.append(i) 1378 else: 1379 items_to_allow_recomputing.append(i) 1380 return total_runtime, items_to_save, items_to_allow_recomputing 1381 1382 1383def ilp_knapsack( 1384 memory: List[float], runtimes: List[float], max_memory: float 1385) -> Tuple[float, List[int], List[int]]: 1386 import numpy as np 1387 1388 try: 1389 from scipy.optimize import Bounds, LinearConstraint, milp 1390 except ImportError: 1391 raise RuntimeError( 1392 "To use the ILP for memory budget checkpointing you need to install scipy" 1393 ) from None 1394 1395 np_memory = np.array(memory) 1396 np_runtimes = np.array(runtimes) 1397 c = -np_runtimes # type: ignore[operator] 1398 1399 memory_constraint = LinearConstraint(A=np_memory, ub=np.array(max_memory)) 1400 constraints = [memory_constraint] 1401 1402 integrality = np.ones_like(c) 1403 res = milp( 1404 c=c, constraints=constraints, integrality=integrality, bounds=Bounds(0, 1) 1405 ) 1406 if not res.success: 1407 raise RuntimeError("Somehow scipy solving failed") 1408 1409 items_to_save = [] 1410 items_to_allow_recomputing = [] 1411 for idx, i in enumerate(res.x): 1412 if i == 1: 1413 items_to_save.append(idx) 1414 else: 1415 items_to_allow_recomputing.append(idx) 1416 return -res.fun, items_to_save, items_to_allow_recomputing 1417 1418 1419def dp_knapsack( 1420 memory: List[float], runtimes: List[float], max_memory: float 1421) -> Tuple[float, List[int], List[int]]: 1422 # Scaling factor to convert floating point weights to integers 1423 S = 10000 1424 1425 # Quantize the memory weights 1426 quantized_memory = torch.tensor( 1427 [int(round(m * S)) for m in memory], dtype=torch.long, device="cpu" 1428 ) 1429 runtimes = torch.tensor(runtimes, dtype=torch.float32, device="cpu") 1430 1431 # Quantized pseudopolynomial DP for 0-1 Knapsack 1432 quantized_max_memory = int(round(max_memory * S)) 1433 1434 n = len(memory) 1435 1436 # Initialize the DP table 1437 # TODO(chilli): I think if needed, this memory can be optimized with sliding 1438 # window trick + Hirschberg trick: 1439 # https://codeforces.com/blog/entry/47247?#comment-316200 1440 dp = torch.zeros( 1441 (n + 1, quantized_max_memory + 1), dtype=torch.float32, device="cpu" 1442 ) 1443 1444 for i in range(1, n + 1): 1445 current_memory = quantized_memory[i - 1] 1446 current_runtime = runtimes[i - 1] 1447 1448 # Copy the previous row 1449 dp[i, :] = dp[i - 1, :] 1450 1451 # Update dp[i, j] for all j >= current_memory 1452 if current_memory == 0: 1453 dp[i, :] = dp[i - 1, :] + current_runtime 1454 else: 1455 dp[i, current_memory:] = torch.maximum( 1456 dp[i - 1, current_memory:], 1457 dp[i - 1, :-current_memory] + current_runtime, 1458 ) 1459 1460 # Backtrack to find the items included in the knapsack 1461 saved_items = [] 1462 recomputable_items = [] 1463 j: int = quantized_max_memory 1464 for i in range(n, 0, -1): 1465 if dp[i][j] != dp[i - 1][j]: 1466 saved_items.append(i - 1) # Include this item (indexing from 0) 1467 j -= int(quantized_memory[i - 1].item()) 1468 else: 1469 recomputable_items.append(i - 1) 1470 1471 saved_items.reverse() # To get items in the order they were added 1472 1473 # The maximum runtime that can be achieved within the max_memory constraint 1474 max_runtime = dp[n][quantized_max_memory].item() 1475 1476 return max_runtime, saved_items, recomputable_items 1477 1478 1479def _optimize_runtime_with_given_memory( 1480 memory: List[float], 1481 runtimes: List[float], 1482 max_memory: float, 1483) -> Tuple[float, List[int], List[int]]: 1484 SOLVER = config.activation_memory_budget_solver 1485 if SOLVER == "greedy": 1486 return greedy_knapsack(memory, runtimes, max_memory) 1487 elif SOLVER == "ilp": 1488 return ilp_knapsack(memory, runtimes, max_memory) 1489 elif SOLVER == "dp": 1490 return dp_knapsack(memory, runtimes, max_memory) 1491 else: 1492 raise RuntimeError(f"Not aware of memory budget knapsack solver: {SOLVER}") 1493 1494 1495from torch.utils._mode_utils import no_dispatch 1496 1497 1498def estimate_runtime(node): 1499 RUNTIME_MODE = config.activation_memory_budget_runtime_estimator 1500 1501 def materialize_arg(x): 1502 if isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.Tensor): 1503 shape = list(x.meta["val"].shape) 1504 1505 def realize_symbol(d): 1506 return hint_int(d, fallback=4096) 1507 1508 shape = [realize_symbol(s) for s in shape] 1509 return x.meta["val"].new_empty_strided( 1510 shape, stride=x.meta["tensor_meta"].stride 1511 ) 1512 elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymInt): 1513 return hint_int(x.meta["val"], fallback=4096) 1514 elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymFloat): 1515 return 1.0 1516 elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymBool): 1517 return True 1518 else: 1519 return x 1520 1521 if RUNTIME_MODE == "testing": 1522 return 1 1523 1524 elif RUNTIME_MODE == "profile": 1525 with no_dispatch(): 1526 from torch._inductor.runtime.benchmarking import benchmarker 1527 1528 args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs)) 1529 ms = benchmarker.benchmark_gpu(lambda: node.target(*args, **kwargs)) 1530 return ms 1531 1532 elif RUNTIME_MODE == "flops": 1533 # todo(chilli): Normalize this to also return ms 1534 from torch.utils.flop_counter import FlopCounterMode 1535 1536 args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs)) 1537 with FlopCounterMode(display=False) as mode: 1538 node.target(*args, **kwargs) 1539 counted_flops = mode.get_total_flops() 1540 return max(counted_flops, 1) 1541 else: 1542 raise RuntimeError(f"Not aware of runtime estimator: {RUNTIME_MODE}") 1543 1544 1545def choose_saved_values_set( 1546 joint_graph: fx.Graph, node_info: NodeInfo, memory_budget=1 1547) -> List[fx.Node]: 1548 if memory_budget > 1 or memory_budget < 0: 1549 raise RuntimeError( 1550 f"The valid ranges for memory budget are 0 <= m <= 1. The provided value is {memory_budget}" 1551 ) 1552 min_cut_options = MinCutOptions( 1553 ban_if_used_far_apart=config.ban_recompute_used_far_apart, 1554 ban_if_long_fusible_chains=config.ban_recompute_long_fusible_chains, 1555 ban_if_materialized_backward=config.ban_recompute_materialized_backward, 1556 ban_if_not_in_allowlist=config.ban_recompute_not_in_allowlist, 1557 ban_if_reduction=config.ban_recompute_reductions, 1558 ) 1559 1560 if config.aggressive_recomputation: 1561 min_cut_options = replace( 1562 min_cut_options, 1563 ban_if_used_far_apart=False, 1564 ban_if_long_fusible_chains=False, 1565 ban_if_materialized_backward=False, 1566 ban_if_not_in_allowlist=False, 1567 ) 1568 if memory_budget == 0: 1569 return node_info.inputs 1570 1571 runtime_optimized_saved_values, _ = solve_min_cut( 1572 joint_graph, 1573 node_info, 1574 min_cut_options, 1575 ) 1576 # return runtime_optimized_saved_values 1577 if memory_budget == 1: 1578 return runtime_optimized_saved_values 1579 1580 def estimate_activations_size(saved_values: List[fx.Node]) -> float: 1581 return sum(map(_size_of, saved_values)) / 1e9 1582 1583 min_act_size = estimate_activations_size(node_info.inputs) 1584 max_act_size = estimate_activations_size(runtime_optimized_saved_values) 1585 # The optimized choice is smaller than the inputs anyways 1586 if max_act_size <= min_act_size: 1587 return runtime_optimized_saved_values 1588 1589 def get_normalized_size(sz): 1590 return (sz / 1e9) / (max_act_size - min_act_size) 1591 1592 def get_mem_ratio(activations: List[fx.Node]): 1593 return (estimate_activations_size(activations) - min_act_size) / ( 1594 max_act_size - min_act_size 1595 ) 1596 1597 more_aggressive_options = replace( 1598 min_cut_options, 1599 ban_if_used_far_apart=False, 1600 ban_if_long_fusible_chains=False, 1601 ban_if_materialized_backward=False, 1602 ) 1603 more_aggressive_saved_values, _ = solve_min_cut( 1604 joint_graph, node_info, more_aggressive_options 1605 ) 1606 if get_mem_ratio(more_aggressive_saved_values) < memory_budget: 1607 return more_aggressive_saved_values 1608 1609 aggressive_options = replace( 1610 more_aggressive_options, 1611 ban_if_not_in_allowlist=False, 1612 ) 1613 aggressive_recomputation_saved_values, banned_nodes = solve_min_cut( 1614 joint_graph, node_info, aggressive_options 1615 ) 1616 1617 if get_mem_ratio(aggressive_recomputation_saved_values) < memory_budget: 1618 return aggressive_recomputation_saved_values 1619 1620 from torch._inductor.fx_utils import get_node_storage 1621 1622 input_storages = {get_node_storage(node) for node in node_info.inputs} 1623 1624 def get_recomputable_banned_nodes(banned_nodes: List[fx.Node]) -> List[fx.Node]: 1625 return [ 1626 i 1627 for i in banned_nodes 1628 if ( 1629 # Only allow recomputing nodes that are actually required for BW 1630 i.dist_from_bw < int(1e9) # type: ignore[attr-defined] 1631 and get_node_storage(i) not in input_storages 1632 ) 1633 ] 1634 1635 recomputable_banned_nodes = get_recomputable_banned_nodes(banned_nodes) 1636 1637 # default: runtime_optimized_saved_values 1638 # more aggressive: more_aggressive_saved_values 1639 # full aggressive: aggressive_recomputation_saved_values 1640 1641 all_recomputable_banned_nodes = sorted( 1642 recomputable_banned_nodes, key=_size_of, reverse=True 1643 ) 1644 if len(all_recomputable_banned_nodes) == 0: 1645 return node_info.inputs 1646 memories_banned_nodes = [ 1647 get_normalized_size(_size_of(i)) for i in all_recomputable_banned_nodes 1648 ] 1649 runtimes_banned_nodes = [ 1650 estimate_runtime(node) for node in all_recomputable_banned_nodes 1651 ] 1652 from torch.utils._mode_utils import no_dispatch 1653 1654 def get_saved_values_knapsack(memory_budget): 1655 with no_dispatch(): 1656 ( 1657 expected_runtime, 1658 saved_node_idxs, 1659 recomputable_node_idxs, 1660 ) = _optimize_runtime_with_given_memory( 1661 memories_banned_nodes, runtimes_banned_nodes, max(memory_budget, 0) 1662 ) 1663 dont_ban = set() 1664 for idx in recomputable_node_idxs: 1665 dont_ban.add(all_recomputable_banned_nodes[idx]) 1666 assert dont_ban.issubset(all_recomputable_banned_nodes) 1667 1668 saved_values, _ = solve_min_cut( 1669 joint_graph, 1670 node_info, 1671 aggressive_options, 1672 dont_ban, 1673 ) 1674 return saved_values, expected_runtime 1675 1676 if config.visualize_memory_budget_pareto: 1677 options = [] 1678 for sweep_memory_budget in range(100, -1, -5): 1679 saved_values, expected_runtime = get_saved_values_knapsack( 1680 sweep_memory_budget / 100 1681 ) 1682 options.append( 1683 ( 1684 sweep_memory_budget, 1685 sum(runtimes_banned_nodes) - expected_runtime, 1686 get_mem_ratio(saved_values), 1687 ) 1688 ) 1689 1690 import matplotlib.pyplot as plt 1691 1692 x_values = [item[2] for item in options] 1693 y_values = [item[1] for item in options] 1694 1695 # Plotting the values with updated axis labels and chart title 1696 plt.figure(figsize=(10, 6)) 1697 plt.plot(x_values, y_values, marker="o") 1698 1699 # Adding labels for each point 1700 for i, txt in enumerate(x_values): 1701 plt.annotate( 1702 f"{txt:.2f}", 1703 (txt, y_values[i]), 1704 textcoords="offset points", 1705 xytext=(0, 10), 1706 ha="center", 1707 ) 1708 1709 plt.xlabel("Memory Budget") 1710 plt.ylabel("Runtime of Recomputed Components") 1711 plt.title("Pareto Frontier of Memory Budget vs. Recomputation Runtime") 1712 plt.grid(True) 1713 fig = plt.gcf() 1714 plt.show() 1715 fig_name = f"memory_budget_pareto_{get_aot_graph_name()}.png" 1716 fig.savefig(fig_name) 1717 log.warning("Generated Pareto frontier curve at %s", fig_name) 1718 1719 # todo(chilli): Estimated doesn't align exactly with actual - actual is 1720 # usually less memory than estimated. i'm guessing (actually quite 1721 # unsure about this) that's because estimated is just only including 1722 # tensors we actually banned from recompute, but there may be other 1723 # tensors that we choose to save. 1724 1725 return get_saved_values_knapsack(memory_budget=memory_budget)[0] 1726 1727 1728def min_cut_rematerialization_partition( 1729 joint_module: fx.GraphModule, 1730 _joint_inputs, 1731 compiler="inductor", 1732 *, 1733 num_fwd_outputs, 1734) -> Tuple[fx.GraphModule, fx.GraphModule]: 1735 """ 1736 Partitions the joint graph such that the backward recomputes the forward. 1737 Recomputing helps in trading off memory bandwidth with computation. 1738 1739 To create the fwd and bwd graph, we copy the joint graph, manually set the 1740 outputs to just original forward or backward outputs. And then we run the 1741 resulting graphs through dead code elimination. 1742 1743 .. warning:: 1744 This API is experimental and likely to change. 1745 1746 Args: 1747 joint_module(fx.GraphModule): The joint forward and backward graph. This 1748 is the result of AOT Autograd tracing. 1749 _joint_inputs: The inputs to the joint graph. This is unused. 1750 compiler: This option determines the default set of recomputable ops. 1751 Currently, there are two options: ``nvfuser`` and ``inductor``. 1752 recomputable_ops: This is an optional set of recomputable ops. If this 1753 is not None, then this set of ops will be used instead of the 1754 default set of ops. 1755 num_fwd_outputs: The number of outputs from the forward graph. 1756 1757 Returns: 1758 Returns the generated forward and backward Fx graph modules. 1759 """ 1760 1761 joint_module.graph.eliminate_dead_code() 1762 joint_module.recompile() 1763 1764 fx_g = joint_module.graph 1765 1766 # add the CSE pass 1767 if config.cse: 1768 cse_graph = fx_graph_cse(fx_g) 1769 joint_module.graph = cse_graph 1770 joint_graph = joint_module.graph 1771 1772 graph_has_recomputable_ops = has_recomputable_ops(joint_module) 1773 graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module) 1774 if graph_has_recomputable_ops: 1775 joint_module = cleanup_recompute_tags(joint_module) 1776 1777 def classify_nodes(joint_module): 1778 name_to_node = get_name_to_node(joint_module.graph) 1779 required_bw_nodes = set() 1780 for node in joint_module.graph.nodes: 1781 if node.op == "placeholder" and "tangents" in node.target: 1782 required_bw_nodes.add(node) 1783 elif _must_be_in_backward(node): 1784 required_bw_nodes.add(node) 1785 1786 if node in required_bw_nodes: 1787 for user in node.users: 1788 required_bw_nodes.add(user) 1789 1790 primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) 1791 fwd_seed_offset_inputs = list( 1792 filter(_is_fwd_seed_offset, joint_module.graph.nodes) 1793 ) 1794 inputs = primal_inputs + fwd_seed_offset_inputs 1795 fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( 1796 joint_module, num_fwd_outputs=num_fwd_outputs 1797 ) 1798 required_bw_nodes.update( 1799 o for o in bwd_outputs if o is not None and o.op != "output" 1800 ) 1801 forward_only_graph = _extract_graph_with_inputs_outputs( 1802 joint_module.graph, inputs, fwd_outputs, "forward" 1803 ) 1804 required_fw_nodes: Set[fx.Node] = { 1805 name_to_node[node.name] 1806 for node in forward_only_graph.nodes 1807 if node.op != "output" 1808 } 1809 unclaimed_nodes = { 1810 node 1811 for node in joint_module.graph.nodes 1812 if node not in required_fw_nodes and node not in required_bw_nodes 1813 } 1814 fw_cnt = 0 1815 fw_order = {} 1816 for node in joint_module.graph.nodes: 1817 if node in required_fw_nodes: 1818 fw_order[node] = fw_cnt 1819 fw_cnt += 1 1820 return NodeInfo( 1821 inputs, required_fw_nodes, required_bw_nodes, unclaimed_nodes, fw_order 1822 ) 1823 1824 node_info = classify_nodes(joint_module) 1825 1826 # networkx blows up on graphs with no required backward nodes 1827 # Since there's nothing to partition anyway, and the default partitioner can "handle" 1828 # this case, send our graph over to the default partitioner. 1829 if len(node_info.required_bw_nodes) == 0: 1830 return default_partition( 1831 joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs 1832 ) 1833 1834 for node in reversed(joint_module.graph.nodes): 1835 if node.op == "output": 1836 node.dist_from_bw = int(1e9) 1837 elif not node_info.is_required_fw(node): 1838 node.dist_from_bw = 0 1839 else: 1840 node.dist_from_bw = int(1e9) 1841 for user in node.users: 1842 node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1) 1843 1844 memory_budget = config.activation_memory_budget 1845 for node in joint_graph.nodes: 1846 if isinstance(node.meta.get("memory_budget", None), float): 1847 memory_budget = node.meta["memory_budget"] 1848 break 1849 # print("Memory Budget: ", memory_budget) 1850 saved_values = choose_saved_values_set( 1851 joint_graph, node_info, memory_budget=memory_budget 1852 ) 1853 # save_for_backward on tensors and stashes symints in autograd .ctx 1854 saved_sym_nodes = list(filter(is_sym_node, saved_values)) 1855 saved_values = list(filter(lambda n: not is_sym_node(n), saved_values)) 1856 1857 # NB: saved_sym_nodes will be mutated to reflect the actual saved symbols 1858 fw_module, bw_module = _extract_fwd_bwd_modules( 1859 joint_module, 1860 saved_values, 1861 saved_sym_nodes=saved_sym_nodes, 1862 num_fwd_outputs=num_fwd_outputs, 1863 ) 1864 1865 if graph_has_recomputable_ops: 1866 if graph_has_recomputable_rng_ops: 1867 fw_module, bw_module = functionalize_rng_ops( 1868 joint_module, fw_module, bw_module, len(saved_sym_nodes) 1869 ) 1870 bw_module = reordering_to_mimic_autograd_engine(bw_module) 1871 1872 if AOT_PARTITIONER_DEBUG: 1873 from torch._inductor.fx_utils import get_node_storage 1874 1875 storages = {get_node_storage(node) for node in saved_values} 1876 print( 1877 "Theoretical Activations Stored: ", 1878 sum(_size_of(i) for i in saved_values) / 1e9, 1879 ) 1880 sorted_sizes = sorted([(_size_of(i), str(i)) for i in saved_values]) 1881 fw_module_nodes = { 1882 node.name for node in fw_module.graph.nodes if node.op == "call_function" 1883 } 1884 bw_module_nodes = { 1885 node.name for node in bw_module.graph.nodes if node.op == "call_function" 1886 } 1887 remat_nodes = fw_module_nodes & bw_module_nodes 1888 1889 counts: Dict[str, int] = defaultdict(int) 1890 for node in fw_module.graph.nodes: 1891 if node.name in remat_nodes and hasattr(node.target, "_overloadpacket"): 1892 counts[str(node.target._overloadpacket)] += 1 1893 print( 1894 f"# remat/fw/bw: {len(remat_nodes)}/{len(fw_module_nodes)}/{len(bw_module_nodes)}" 1895 ) 1896 print( 1897 "Count of Ops Rematerialized: ", 1898 sorted(counts.items(), key=lambda x: x[1], reverse=True), 1899 ) 1900 return fw_module, bw_module 1901 1902 1903def draw_graph( 1904 traced: torch.fx.GraphModule, 1905 fname: str, 1906 figname: str = "fx_graph", 1907 clear_meta: bool = True, 1908 prog: Optional[Union[str, List[str]]] = None, 1909 parse_stack_trace: bool = False, 1910 dot_graph_shape: Optional[str] = None, 1911) -> None: 1912 if clear_meta: 1913 new_graph = copy.deepcopy(traced.graph) 1914 traced = fx.GraphModule(traced, new_graph) 1915 for node in traced.graph.nodes: 1916 node.meta = {} 1917 base, ext = os.path.splitext(fname) 1918 if not ext: 1919 ext = "." + config.torch_compile_graph_format 1920 print(f"Writing FX graph to file: {base}{ext}") 1921 g = graph_drawer.FxGraphDrawer( 1922 traced, 1923 figname, 1924 parse_stack_trace=parse_stack_trace, 1925 dot_graph_shape=dot_graph_shape, 1926 ) 1927 x = g.get_main_dot_graph() 1928 write_method = getattr(x, "write_" + ext.lstrip(".")) 1929 fname = f"{base}{ext}" 1930 if prog is None: 1931 write_method(fname) 1932 else: 1933 write_method(fname, prog=prog) 1934