xref: /aosp_15_r20/external/pytorch/torch/_functorch/partitioners.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import copy
3import functools
4import heapq
5import itertools
6import logging
7import math
8import operator
9import os
10from collections import defaultdict
11from dataclasses import dataclass, replace
12from typing import Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
13
14import torch
15import torch._inductor.inductor_prims
16import torch.fx as fx
17import torch.utils._pytree as pytree
18from torch.fx.experimental._backward_state import BackwardState
19from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
20from torch.fx.experimental.sym_node import magic_methods, method_to_operator
21from torch.fx.experimental.symbolic_shapes import (
22    find_symbol_binding_fx_nodes,
23    free_symbols,
24    hint_int,
25    is_symbol_binding_fx_node,
26)
27from torch.fx.passes import graph_drawer
28from torch.utils.checkpoint import CheckpointPolicy
29
30from . import config
31from ._aot_autograd.logging_utils import get_aot_graph_name
32from ._aot_autograd.utils import is_with_effects
33from .compile_utils import fx_graph_cse, get_aten_target
34
35
36if TYPE_CHECKING:
37    import sympy
38
39
40AOT_PARTITIONER_DEBUG = config.debug_partitioner
41log = logging.getLogger(__name__)
42
43aten = torch.ops.aten
44prims = torch.ops.prims
45
46
47@dataclass
48class OpTypes:
49    """Class for keeping track of different operator categories"""
50
51    fusible_ops: Set[Callable]
52    compute_intensive_ops: Set[Callable]
53    random_ops: Set[Callable]
54    view_ops: Set[Callable]
55    recomputable_ops: Set[Callable]
56
57    def is_fusible(self, node: fx.Node):
58        return get_aten_target(node) in self.fusible_ops
59
60    def is_compute_intensive(self, node: fx.Node):
61        return get_aten_target(node) in self.compute_intensive_ops
62
63    def is_random(self, node: fx.Node):
64        return get_aten_target(node) in self.random_ops
65
66    def is_view(self, node: fx.Node):
67        return get_aten_target(node) in self.view_ops
68
69    def is_recomputable(self, node: fx.Node):
70        return get_aten_target(node) in self.recomputable_ops
71
72
73@dataclass
74class NodeInfo:
75    # Be careful about iterating over these explicitly, as their order may not
76    # be deterministic
77    inputs: List[fx.Node]
78    _required_fw_nodes: Set[fx.Node]
79    required_bw_nodes: Set[fx.Node]
80    unclaimed_nodes: Set[fx.Node]
81    fw_order: Dict[fx.Node, int]
82
83    @functools.cached_property
84    def required_fw_nodes(self) -> List[fx.Node]:
85        return sorted(self._required_fw_nodes, key=lambda n: self.fw_order[n])
88
89    def is_required_fw(self, n: fx.Node) -> bool:
90        return n in self._required_fw_nodes
91
92    def is_required_bw(self, n: fx.Node) -> bool:
93        return n in self.required_bw_nodes
94
95    def is_unclaimed(self, n: fx.Node) -> bool:
96        return n in self.unclaimed_nodes
97
98    def get_fw_order(self, n: fx.Node) -> int:
99        assert n in self._required_fw_nodes, f"Node {n} not in fw nodes!"
100        return self.fw_order[n]
101
102
103@dataclass
104class MinCutOptions:
105    ban_if_used_far_apart: bool
106    ban_if_long_fusible_chains: bool
107    ban_if_materialized_backward: bool
108    ban_if_not_in_allowlist: bool
109    ban_if_reduction: bool
110
111
112def must_recompute(node: fx.Node) -> bool:
113    return node.meta.get("recompute", None) in [
114        CheckpointPolicy.MUST_RECOMPUTE,
115        CheckpointPolicy.PREFER_RECOMPUTE,
116    ]
117
118
119def has_recomputable_ops(fx_g: fx.GraphModule) -> bool:
121    for node in fx_g.graph.nodes:
122        if must_recompute(node):
123            return True
124    return False
125
126
127def has_recomputable_rng_ops(fx_g: fx.GraphModule) -> bool:
128    for node in fx_g.graph.nodes:
129        if (
130            must_recompute(node)
131            and hasattr(node.target, "tags")
132            and torch.Tag.nondeterministic_seeded in node.target.tags
133        ):
134            return True
135    return False
136
137
138def sym_node_size(node: fx.Node) -> int:
139    if isinstance(node.meta["val"], (torch.SymInt, torch.SymBool)):
140        return 1
141    assert isinstance(node.meta["val"], torch.SymFloat)
142    return 4
143
144
145class InvalidNodeBase:
146    def __repr__(self):
147        return "Invalid Node"
148
149
150InvalidNode = InvalidNodeBase()
151
152
153def _extract_graph_with_inputs_outputs(
154    joint_graph: fx.Graph,
155    inputs: List[fx.Node],
156    outputs: List[fx.Node],
157    subgraph: Optional[str] = None,
158) -> fx.Graph:
159    """
160    Given a graph, extracts out a subgraph that takes the specified nodes as
161    inputs and returns the specified outputs.
162
163    This includes specifying non-placeholder nodes as inputs.
164
165    The general strategy is to initialize all inputs with proxies as we
166    encounter them, and trace through the graph, only keeping values which take
167    in valid proxies. Then, all dead code is eliminated.
168    """
169    new_graph = fx.Graph()
170    env = {}
171
172    # Add new placeholder nodes in the order specified by the inputs
173    for node in inputs:
174        new_node = new_graph.placeholder(node.name)
175        # Can't use node_copy here as we may be turning a previous call_function into a placeholder
176        new_node.meta = node.meta
177        env[node] = new_node
178
179    for node in joint_graph.nodes:
180        if _must_be_in_backward(node) and subgraph != "backward":
181            env[node] = InvalidNode  # type: ignore[assignment]
182            continue
183
184        if node in env:
185            # Node must be one of our inputs. (Any member of env which wasn't an
186            # input to start must have been created by this loop and won't be in
187            # joint_graph.nodes).
188            continue
189        elif node.op == "placeholder":
190            env[node] = InvalidNode  # type: ignore[assignment]
191        elif node.op == "call_function":
192            all_args = pytree.arg_tree_leaves(*node.args, **node.kwargs)
193            all_args = [
194                isinstance(env[x], InvalidNodeBase)
195                for x in all_args
196                if isinstance(x, fx.Node)
197            ]
198            if any(all_args):
199                env[node] = InvalidNode  # type: ignore[assignment]
200                continue
201            env[node] = new_graph.node_copy(node, lambda x: env[x])
202        elif node.op == "get_attr":
203            env[node] = new_graph.node_copy(node, lambda x: env[x])
204        elif node.op == "output":
205            pass
206    output_values = []
207    for x in outputs:
208        if isinstance(x, fx.Node):
209            if x not in env:
210                raise RuntimeError(f"Node {x} couldn't be found in env")
211            assert not isinstance(
212                env[x], InvalidNodeBase
213            ), f"Node {x} was invalid, but is output"
214            output_values.append(env[x])
215        else:
216            output_values.append(x)
217    new_graph.output(tuple(output_values))
218
219    new_graph.eliminate_dead_code()
220    new_graph.lint()
221    return new_graph
222
223
224def _is_primal(node: fx.Node) -> bool:
225    return (
226        node.op == "placeholder"
227        and "tangents" not in str(node.target)
228        and not _is_bwd_seed_offset(node)
229        and not _is_fwd_seed_offset(node)
230    )
231
232
233def _is_tangent(node: fx.Node) -> bool:
234    return node.op == "placeholder" and "tangents" in str(node.target)
235
236
237def _is_bwd_seed_offset(node: fx.Node) -> bool:
238    return node.op == "placeholder" and (
239        "bwd_seed" in str(node.target) or "bwd_base_offset" in str(node.target)
240    )
241
242
243def _is_fwd_seed_offset(node: fx.Node) -> bool:
244    return node.op == "placeholder" and (
245        "fwd_seed" in str(node.target) or "fwd_base_offset" in str(node.target)
246    )
247
248
249def _is_backward_state(node: fx.Node) -> bool:
250    return node.op == "placeholder" and isinstance(node.meta.get("val"), BackwardState)
251
252
253def _has_tag_is_backward(node: fx.Node) -> bool:
254    return node.meta.get("partitioner_tag", None) == "is_backward"
255
256
257def _has_tag_must_be_in_backward(node: fx.Node) -> bool:
258    return node.meta.get("partitioner_tag", None) == "must_be_in_backward"
259
260
261def _must_be_in_backward(node: fx.Node) -> bool:
262    return _has_tag_must_be_in_backward(node) or (
263        _has_tag_is_backward(node) and is_with_effects(node)
264    )
265
266
267def _extract_fwd_bwd_outputs(
268    joint_module: fx.GraphModule, *, num_fwd_outputs
269) -> Tuple[List[fx.Node], List[fx.Node]]:
270    outputs = pytree.arg_tree_leaves(
271        *(node.args for node in joint_module.graph.find_nodes(op="output"))
272    )
273    fwd_outputs = outputs[:num_fwd_outputs]
274    bwd_outputs = outputs[num_fwd_outputs:]
275    return fwd_outputs, bwd_outputs
276
277
278def _remove_by_name(saved_values: List[fx.Node], name: str):
279    for saved_value in saved_values:
280        if saved_value.name == name:
281            saved_values.remove(saved_value)
282            break
283
284
285def _extract_fwd_bwd_modules(
286    joint_module: fx.GraphModule,
287    saved_values: List[fx.Node],
288    saved_sym_nodes: List[fx.Node],
289    *,
290    num_fwd_outputs: int,
291) -> Tuple[fx.GraphModule, fx.GraphModule]:
292    fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(
293        joint_module, num_fwd_outputs=num_fwd_outputs
294    )
295    placeholders = joint_module.graph.find_nodes(op="placeholder")
296    primal_inputs = [*filter(_is_primal, placeholders)]
297    tangent_inputs = [*filter(_is_tangent, placeholders)]
298    fwd_seed_offset_inputs = [*filter(_is_fwd_seed_offset, placeholders)]
299    bwd_seed_offset_inputs = [*filter(_is_bwd_seed_offset, placeholders)]
300    backward_state_inputs = [*filter(_is_backward_state, placeholders)]
301
302    bwd_graph = _extract_graph_with_inputs_outputs(
303        joint_module.graph,
304        saved_sym_nodes + saved_values + tangent_inputs + bwd_seed_offset_inputs,
305        bwd_outputs,
306        "backward",
307    )
308
309    for node in bwd_graph.find_nodes(op="placeholder"):
310        # This is to filter out saved values that don't actually end up being used by the backwards pass
311        if not node.users:
312            _remove_by_name(saved_values, node.name)
313            _remove_by_name(saved_sym_nodes, node.name)
314        elif _is_backward_state(node):
315            # BackwardState is saved directly
316            _remove_by_name(saved_values, node.name)
317            assert backward_state_inputs
318
319    # Now that we have the finalized list of saved values, we need to ensure
320    # we propagate all symbols which are referenced by backwards inputs.
321    # These are not directly used in the graph but are required for downstream
322    # sizevar assignment
323    saved_symbols: Set[sympy.Symbol] = set()
324    saved_sym_nodes_binding = []
325    saved_sym_nodes_derived = []
326
327    # Some symbols may already be bound in the directly saved_sym_nodes,
328    # keep track of them so we don't re-bind them
329    for node in saved_sym_nodes:
330        symbol = is_symbol_binding_fx_node(node)
331        if symbol:
332            saved_symbols.add(symbol)
333            saved_sym_nodes_binding.append(node)
334        else:
335            saved_sym_nodes_derived.append(node)
336
337    # Now go through all of the prospective backward inputs and track any
338    # other symbols we need to bind
339    symbol_bindings = find_symbol_binding_fx_nodes(joint_module.graph)
340    for node in itertools.chain(saved_sym_nodes_derived, saved_values, tangent_inputs):
341        if "val" not in node.meta:
342            continue
343        new_symbols = free_symbols(node.meta["val"]) - saved_symbols
344        # NB: Deterministic order please!
345        for s in sorted(new_symbols, key=lambda s: s.name):
346            # NB: For well formed graphs, the symbol should always be present,
347            # but we also have ways to produce ill-formed graphs, e.g., direct
348            # make_fx usages, so don't choke in this case
349            if s not in symbol_bindings:
350                continue
351            saved_sym_nodes_binding.append(symbol_bindings[s])
352        saved_symbols |= new_symbols
353
354    # Update saved_sym_nodes in place so that all binding nodes come first.
355    # This ordering can also be used later on to figure out the position of
356    # saved sym nodes in the output of the fwd graph.
357    saved_sym_nodes.clear()
358    saved_sym_nodes.extend(saved_sym_nodes_binding + saved_sym_nodes_derived)
359
360    # Now, we re-generate the fwd/bwd graphs.
361    # NB: This might increase compilation time, but I doubt it matters
362    fwd_graph = _extract_graph_with_inputs_outputs(
363        joint_module.graph,
364        primal_inputs + fwd_seed_offset_inputs,
365        fwd_outputs + saved_values + saved_sym_nodes,
366        "forward",
367    )
368    bwd_graph = _extract_graph_with_inputs_outputs(
369        joint_module.graph,
370        saved_sym_nodes
371        + saved_values
372        + tangent_inputs
373        + bwd_seed_offset_inputs
374        + backward_state_inputs,
375        bwd_outputs,
376        "backward",
377    )
378
379    fwd_module = fx._lazy_graph_module._make_graph_module(joint_module, fwd_graph)
380    bwd_module = fx._lazy_graph_module._make_graph_module(joint_module, bwd_graph)
381    return fwd_module, bwd_module
382
383
384def default_partition(
385    joint_module: fx.GraphModule, _joint_inputs, *, num_fwd_outputs
386) -> Tuple[fx.GraphModule, fx.GraphModule]:
387    """
388    Partitions the :attr:`joint_module` in a manner that closely resembles the
389    behavior observed in the original ``.forward()`` and ``.backward()`` of the
390    callable, i.e., the resulting forward graph contains those operators that
391    are executed in the original ``.forward()`` callable passed to
392    :func:`aot_function`.
393
394    The default partitioner collects the operators that are between the forward
395    inputs and the forward outputs. This helps in finding the tensors which have
396    to be stashed for the backward pass. These stashed tensors become the output
397    of the generated forward graph. The remaining operators are then placed in
398    the backward graph.
399
400    .. warning::
401        This API is experimental and likely to change.
402
403    Args:
404        joint_module(fx.GraphModule): The joint forward and backward graph. This
405            is the result of AOT Autograd tracing.
406
407    Returns:
408        Returns the generated forward and backward Fx graph modules.
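
    Example::

        # A minimal sketch (illustrative; the no-op compiler callbacks and the
        # input below are placeholders, not part of this module). aot_function
        # accepts a ``partition_fn`` argument, which is where this partitioner
        # plugs in.
        from torch._functorch.aot_autograd import aot_function

        def fn(x):
            return x.sin().cos()

        compiled_fn = aot_function(
            fn,
            fw_compiler=lambda gm, example_inputs: gm,
            bw_compiler=lambda gm, example_inputs: gm,
            partition_fn=default_partition,
        )
        out = compiled_fn(torch.randn(4, requires_grad=True))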
409    """
410    if has_recomputable_ops(joint_module):
411        return min_cut_rematerialization_partition(
412            joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs
413        )
414    primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
415    fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
416    inputs = primal_inputs + fwd_seed_offset_inputs
417    fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(
418        joint_module, num_fwd_outputs=num_fwd_outputs
419    )
420    forward_only_graph = _extract_graph_with_inputs_outputs(
421        joint_module.graph, inputs, fwd_outputs, "forward"
422    )
423    forward_node_names = {
424        node.name for node in forward_only_graph.nodes if node.op != "output"
425    }
426    saved_values = []
427    saved_sym_nodes = []
428
429    for node in joint_module.graph.nodes:
430        if node.name not in forward_node_names:
431            continue
432        if is_sym_node(node):
433            # Symints must be kept separate from tensors so that PythonFunction only calls
434            # save_for_backward on tensors and stashes symints on the autograd ctx
435            saved_sym_nodes.append(node)
436        elif "tensor_meta" not in node.meta and node.op == "call_function":
437            # Since we can't save a tuple of tensor values, we need to flatten out what we're saving
438            users = node.users
439            assert all(user.target == operator.getitem for user in users)
440            saved_values.extend(users)
441        else:
442            backward_usages = [
443                n for n in node.users if n.name not in forward_node_names
444            ]
445            if "tensor_meta" in node.meta and all(
446                is_sym_node(n) for n in backward_usages
447            ):
448                # If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
449                # and not the actual tensor data,
450                # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
451                #
452                # Note that saving the tensor could also cause compilation problems:
453                # If the user mutated an input in the forward and uses its sizes/strides in the backward,
454                # then we would be obligated to clone the input before saving it to appease autograd.
455                # (This is how we originally found this bug).
456                saved_sym_nodes.extend(backward_usages)
457            else:
458                saved_values.append(node)
459    saved_values = list(dict.fromkeys(saved_values).keys())
460    saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())
461
462    return _extract_fwd_bwd_modules(
463        joint_module,
464        saved_values,
465        saved_sym_nodes=saved_sym_nodes,
466        num_fwd_outputs=num_fwd_outputs,
467    )
468
469
470INT_INF = int(1e6)
471
472
473def _tensor_nbytes(numel: int, dtype) -> int:
474    return numel * dtype.itemsize
475
476
477def _size_of(node: fx.Node) -> int:
478    def object_nbytes(x) -> int:
479        if not isinstance(x, torch.Tensor):
480            return 0
481        return _tensor_nbytes(hint_int(x.numel(), fallback=4096), x.dtype)
482
483    if "val" in node.meta:
484        val = node.meta["val"]
485        if isinstance(val, py_sym_types):
486            return 1
487        # NB: The fallback values here are meaningless, maybe we should respect
488        # torch._inductor.config.unbacked_symint_fallback (but this is a
489        # layering violation)
490        elif isinstance(val, (list, tuple)):
491            return sum(object_nbytes(n) for n in val)
492        elif isinstance(val, dict):
493            return sum(object_nbytes(n) for _, n in val.items())
494        elif isinstance(val, torch.Tensor):
495            return object_nbytes(val)
496
497        raise RuntimeError(f"Unknown metadata type {type(val)} on node {node}")
498    if node.op == "get_attr":
499        return 0
500    raise RuntimeError(
501        f"Node {node} didn't have `val` metadata; we should always have `val` metadata on the nodes."
502    )
503
504
505# Used for some investigative purposes
506def _count_ops(graph: fx.Graph):
508
509    cnt: Dict[str, int] = defaultdict(int)
510    for node in graph.nodes:
511        if node.op == "call_function":
512            cnt[node.target.__name__] += 1
513    print(sorted(cnt.items(), key=lambda x: x[1], reverse=True))
514
515
516@functools.lru_cache(None)
517def pointwise_ops():
518    ops = []
519    for attr_name in dir(torch.ops.aten):
520        opoverloadpacket = getattr(torch.ops.aten, attr_name)
521        if not isinstance(opoverloadpacket, torch._ops.OpOverloadPacket):
522            continue
523
524        for overload in opoverloadpacket.overloads():
525            op_overload = getattr(opoverloadpacket, overload)
526            if torch.Tag.pointwise in op_overload.tags:
527                # currently aot autograd uses the packet, not the overload
528                ops.append(opoverloadpacket)
529                break
530
531    return ops
532
533
534def sort_depths(args, depth_map: Dict[fx.Node, int]) -> List[Tuple[fx.Node, int]]:
535    arg_depths = {
536        arg: depth_map[arg] for arg in args if isinstance(arg, torch.fx.node.Node)
537    }
538    return sorted(arg_depths.items(), key=lambda x: x[1], reverse=True)
539
540
541def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:
542    """
543    This pass finds the first bwd node in the graph (by looking at users of
544    tangents) and then reorders the graph by walking from this node to all the
545    way to the end of the graph. At each op in this traversal, we insert this op
546    into a new graph and pull in only the relevant subgraph from the other
547    non-bwd edges needed by this op. This closely mimics the behavior of the
548    autograd engine.
549
550    Why is this pass required in the first place?
551
552    This is an artifact of how partitioners work today. The starting point of
553    partitioner is a joint graph, which is fwd and then bwd graph. In the case
554    of checkpointing, we keep portions of fwd graph in their original place in
555    the joint graph, while obtaining a bwd graph. As a result, the resulting bwd
556    graph has copies of recomputed fwd subgraphs followed by the original bwd
557    graph. If we run this naively, it leads to a bad memory footprint, because
558    the fwd subgraphs stay live for much longer than necessary. This pass
559    reorders the operations such that we prioritize the ops for the original bwd
560    graph while only realizing those ops from the fwd graph that are necessary
561    at any given point in the graph.
562    """
563
564    new_graph = fx.Graph()
565    env: Dict[fx.Node, fx.Node] = {}
566
567    # Add new placeholder nodes in the order specified by the inputs
568    for node in gm.graph.find_nodes(op="placeholder"):
569        env[node] = new_graph.node_copy(node, lambda x: env[x])
570
571    order = {}
572    for idx, node in enumerate(gm.graph.nodes):
573        order[node] = idx
574
575    def insert_node_in_graph(node):
576        cur_nodes = [node]
577        insertable_nodes = set()
578        while len(cur_nodes) > 0:
579            node = cur_nodes.pop()
580            if node in insertable_nodes or node in env:
581                continue
582            insertable_nodes.add(node)
583
584            # Bias traversal towards the nodes that have higher depth - prioritizes
585            # critical path first.
586            cur_nodes += node.all_input_nodes
587
588        insertable_nodes = sorted(insertable_nodes, key=lambda n: order[n])
589        for node in insertable_nodes:
590            env[node] = new_graph.node_copy(node, lambda x: env[x])
591
592    # Find first bwd node in the graph
593    tangent_inputs = list(filter(_is_tangent, gm.graph.nodes))
594    first_node_in_bwd = None
595    minimum_order = math.inf
596    for tangent in tangent_inputs:
597        for user in tangent.users:
598            if order[user] < minimum_order:
599                minimum_order = order[user]
600                first_node_in_bwd = user
601
602    # If gradInp does not depend upon gradOut, we may not find any nodes in the "backwards pass"
603    if first_node_in_bwd is None:
604        return gm
605
606    # Build the graph op-by-op by starting from the node all the way to the end
607    for node in list(gm.graph.nodes)[order[first_node_in_bwd] :]:
608        insert_node_in_graph(node)
609
610    # The output node is already built by the traversal.
611    new_gm = torch.fx.GraphModule(gm, new_graph)
612    return new_gm
613
614
615def functionalize_rng_ops(
616    joint_module: fx.GraphModule,
617    fw_module: fx.GraphModule,
618    bw_module: fx.GraphModule,
619    num_sym_nodes: int,
620) -> Tuple[fx.GraphModule, fx.GraphModule]:
621    # During user-driven activation checkpointing, we have to ensure that an rng
622    # op in fwd yields the same output as the recomputed rng op in the bwd.  To
623    # do this, we use functionalize wrappers to wrap the random ops and share
624    # rng state between the fwd and bwd graphs.
625
626    # There are 3 main steps to do this
627    # Step 1 - Construct a mapping of rng node between the fwd and its counterpart in bwd.
628    # Step 2 - Modify the fwd pass such that
629    #   1) Replace rand with run_and_save_rng_state wrapper
630    #   2) Replace the users of the original op with the output[1] of this op.
631    #   3) Collect all the rng_state - output[0] of each op, and make them
632    #   output nodes. Special care needs to be taken here because the fwd outputs
633    #   have symints at the very end.
634    # Step 3 - Modify the bwd pass such that
635    #   1) Add the input nodes just before the tangents for the stashed rng states
636    #   2) Replace rand with run_with_rng_state wrappers
637    #   3) Use the stashed states as inputs to these ops
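    #
    # Schematically (an illustrative sketch of the rewrite performed below, not
    # additional API), a fwd rng op
    #     out = rand(shape)
    # becomes, in the fwd graph,
    #     state, out = run_and_save_rng_state(rand, shape)
    # and its recomputed counterpart in the bwd graph becomes
    #     out = run_with_rng_state(state, rand, shape)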
638
639    # Unique id to generate name
640    uid = itertools.count()
641
642    def get_rng_ops(gmod):
643        random_nodes = {}
644        for node in gmod.graph.nodes:
645            if (
646                node.op == "call_function"
647                and hasattr(node.target, "tags")
648                and torch.Tag.nondeterministic_seeded in node.target.tags
649            ):
650                random_nodes[node.name] = node
651        return random_nodes
652
653    def get_device(node):
654        """
655        Check the example value of the node outputs to find the device type.
656        """
657        if "val" not in node.meta:
658            return None
659
660        candidates = node.meta["val"]
661        if not isinstance(candidates, tuple):
662            candidates = (candidates,)
663
664        for candidate in candidates:
665            if isinstance(candidate, torch.Tensor):
666                if candidate.device.type == "cuda":
667                    return "cuda"
668
669        return "cpu"
670
671    def get_sample_rng_state(device):
672        if device == "cuda":
673            return torch.cuda.get_rng_state()
674        return torch.get_rng_state()
675
676    # Step 1 - Construct a mapping of rng node between the fwd and its counterpart in bwd.
677    joint_graph_rng_ops = get_rng_ops(joint_module)
678    fw_graph_rng_ops = get_rng_ops(fw_module)
679    bw_graph_rng_ops = get_rng_ops(bw_module)
680    recomputable_rng_ops_map = {}
681    for node in joint_module.graph.nodes:
682        if (
683            must_recompute(node)
684            and hasattr(node.target, "tags")
685            and torch.Tag.nondeterministic_seeded in node.target.tags
686        ):
687            base_node = joint_graph_rng_ops[node.name]
688            fw_node = fw_graph_rng_ops[node.name]
689            bw_node = bw_graph_rng_ops[node.name]
690            recomputable_rng_ops_map[base_node] = {"fwd": fw_node, "bwd": bw_node}
691
692    run_and_save_rng = torch._prims.rng_prims.run_and_save_rng_state
693    run_with_rng_state = torch._prims.rng_prims.run_with_rng_state
694    bw_tangent_start_node = None
695    for node in bw_module.graph.find_nodes(op="placeholder"):
696        if "tangent" in node.name:
697            bw_tangent_start_node = node
698            break
699    if bw_tangent_start_node is None:
700        raise RuntimeError(
701            "Couldn't find tangent node in graph inputs. This is unexpected, please file a bug if you see this"
702        )
703
704    fw_rng_state_outputs = []
705    for base_node, node_pair in recomputable_rng_ops_map.items():
706        # Step 2 - Modify the fwd pass such that
707        fw_node = node_pair["fwd"]
708        bw_node = node_pair["bwd"]
709        fw_graph = fw_module.graph
710        with fw_graph.inserting_before(fw_node):
711            functional_fw_node = fw_graph.create_node(
712                "call_function",
713                run_and_save_rng,
714                args=(fw_node.target, *fw_node.args),
715                kwargs=fw_node.kwargs,
716            )
717            state = fw_graph.create_node(
718                "call_function",
719                operator.getitem,
720                args=(functional_fw_node, 0),
721                kwargs={},
722            )
723            rng_output = fw_graph.create_node(
724                "call_function",
725                operator.getitem,
726                args=(
727                    functional_fw_node,
728                    1,
729                ),
730                kwargs={},
731            )
732            fw_node.replace_all_uses_with(rng_output)
733            fw_graph.erase_node(fw_node)
734            fw_rng_state_outputs.append(state)
735
736        # Step 3 - Modify the bwd pass such that
737        bw_graph = bw_module.graph
738        with bw_graph.inserting_before(bw_tangent_start_node):
739            state_name = f"rng_state_output_{next(uid)}"
740            bw_rng_state_node = bw_graph.placeholder(state_name)
741            bw_rng_state_node.meta["val"] = get_sample_rng_state(get_device(fw_node))
742
743        with bw_graph.inserting_before(bw_node):
744            rng_output = bw_graph.create_node(
745                "call_function",
746                run_with_rng_state,
747                args=(bw_rng_state_node, bw_node.target, *bw_node.args),
748                kwargs=bw_node.kwargs,
749            )
750
751            bw_node.replace_all_uses_with(rng_output)
752            bw_graph.erase_node(bw_node)
753
754    # Add the rng states in the output of the fwd graph. AOT Autograd assumes
755    # that symints are at the end of forward graph outputs. So, insert the new
756    # rng states accordingly.
757    fw_output_node = next(iter(fw_module.graph.find_nodes(op="output")))
758    fw_outputs = fw_output_node.args[0]
759    sym_node_start_idx = len(fw_outputs) - num_sym_nodes
760    outputs = (
761        fw_outputs[:sym_node_start_idx]
762        + tuple(fw_rng_state_outputs)
763        + fw_outputs[sym_node_start_idx:]
764    )
765    fw_module.graph.output(outputs)
766    fw_module.graph.erase_node(fw_output_node)
767    fw_module.recompile()
768    bw_module.recompile()
769    return fw_module, bw_module
770
771
772def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
773    """
774    If there are two consecutive checkpointed blocks with no operator in
775    between, we would still want to stash the tensor at the boundary of
776    checkpointed blocks. The following pass makes the last output node
777    non-recomputable to allow for that.
778    """
779    for node in joint_module.graph.nodes:
780        if must_recompute(node):
781            for user in node.users:
782                if (
783                    must_recompute(user)
784                    and user.meta["ac_graph_id"] > node.meta["ac_graph_id"]
785                ):
786                    node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
787    return joint_module
788
789
790def solve_min_cut(
791    joint_graph: fx.Graph,
792    node_info: NodeInfo,
793    min_cut_options: MinCutOptions,
794    dont_ban=None,
795):
796    if dont_ban is None:
797        dont_ban = set()
798    op_types = get_default_op_list()
799
800    if AOT_PARTITIONER_DEBUG:
801        joint_module_ops = {
802            str(node.target._overloadpacket)
803            for node in joint_graph.nodes
804            if node.op == "call_function" and hasattr(node.target, "_overloadpacket")
805        }
806        ops_ignored = joint_module_ops - {str(i) for i in op_types.recomputable_ops}
807        print("Ops banned from re-materialization: ", ops_ignored)
808        print()
809
810    def can_fuse_into_auto_functionalized(a, b):
811        if b.target != torch.ops.higher_order.auto_functionalized:
812            return False
813        mutable_op = b.args[0]
814        (
815            mutable_arg_names,
816            _,
817        ) = torch._higher_order_ops.auto_functionalize.get_mutable_args(mutable_op)
818        for name in mutable_arg_names:
819            arg = b.kwargs[name]
820            if a is arg:
821                return True
822            if isinstance(arg, list):
823                if a in arg:
824                    return True
825        return False
826
827    def can_fuse_into_triton_kernel_wrapper_functional(a, b):
828        if b.target != torch.ops.higher_order.triton_kernel_wrapper_functional:
829            return False
830        mutable_arg_names = b.kwargs["tensors_to_clone"]
831        for name in mutable_arg_names:
832            arg = b.kwargs["kwargs"][name]
833            if a is arg:
834                return True
835        return False
836
837    def is_fusible(a, b):
838        # We can perform "memory fusion" into a cat, but cat cannot be a
839        # producer to a fusion
840        if get_aten_target(b) == aten.cat:
841            return True
842        if can_fuse_into_auto_functionalized(a, b):
843            return True
844        if can_fuse_into_triton_kernel_wrapper_functional(a, b):
845            return True
846        return op_types.is_fusible(a) and op_types.is_fusible(b)
847
848    try:
849        import networkx as nx
850    except ImportError as e:
851        raise RuntimeError(
852            "Need networkx installed to perform smart recomputation " "heuristics"
853        ) from e
854
855    def is_materialized_backwards(node):
856        if op_types.is_view(node):
857            return False
858        cur_nodes = {node}
859        while len(cur_nodes) > 0:
860            cur = cur_nodes.pop()
861            for user in cur.users:
862                if not node_info.is_required_fw(user) and not is_fusible(cur, user):
863                    return True
864                if op_types.is_view(user):
865                    cur_nodes.add(user)
866
867        return False
868
869    def should_ban_recomputation(node):
870        if node.op != "call_function":
871            return False
872        if node.target == operator.getitem:
873            return False
874        if node.meta.get("recompute", None) == CheckpointPolicy.MUST_SAVE:
875            return True
876        if config.recompute_views and op_types.is_view(node):
877            return False
878        if node.target in [aten.lift_fresh_copy.default, aten.lift_fresh.default]:
879            return False
880
881        if min_cut_options.ban_if_not_in_allowlist:
882            if not op_types.is_recomputable(node):
883                return True
884        else:
885            if op_types.is_random(node) or op_types.is_compute_intensive(node):
886                return True
887
888        # If a node *must* be materialized in the backwards pass, then we
889        # should never recompute it. This is a pretty subtle point.  In
890        # general, the assumption we make is that recomputing a node in the
891        # backwards pass is "free". However, if a node must be materialized
892        # in the backwards pass, then recomputing it is never free.
893        if min_cut_options.ban_if_materialized_backward and is_materialized_backwards(
894            node
895        ):
896            log.info("materialized backwards: %s %s", node, tuple(node.users))
897            return True
898
899        # Arbitrary hack that sometimes seems to help things. The above
900        # modification appears to have made this heuristic a lot less critical
901        # for performance.
902        # NB: As of PR #121692, this hack no longer seems necessary.
903        if node.dist_from_bw < 1000 and node.dist_from_bw > config.max_dist_from_bw:
904            return True
905
906        # If the output of an op is 4x smaller (arbitrary choice),
907        # then we don't allow recomputation. The idea here is that for
908        # things like reductions, saving the output of the reduction is very
909        # cheap/small, and it makes sure we don't do things like recompute
910        # normalizations in the backwards.
911        if min_cut_options.ban_if_reduction:
912            input_tensors_size = sum(
913                _size_of(i) for i in node.args if isinstance(i, fx.Node)
914            )
915            output_size = _size_of(node)
916            return output_size * 4 < input_tensors_size
917        return False
918
919    def is_materialized(node):
920        if node.op == "placeholder":
921            return True
922
923        return not all(is_fusible(node, user) for user in node.users)
924
925    def get_node_weight(node) -> float:
926        mem_sz = _size_of(node)
927        if config.recompute_views and op_types.is_view(node):
928            # If `config.recompute_views=True`, we don't save views. This is generally
929            # a good idea since views are free to recompute, and it makes it a bit simpler
930            # to analyze.
931            # NB: If they're not free to recompute (e.g. nested tensors)... I
932            # think we should modify checks for view_ops to `is_view` and check
933            # that. Basically, with nested tensors, `aten.view` is not a "view
934            # op".
935            return math.inf
936
937        if isinstance(node.meta["val"], py_sym_types):
938            # We never want to save symfloats
939            if not isinstance(node.meta["val"], torch.SymInt):
940                return INT_INF
941
942        # Heuristic to bias towards nodes closer to the backwards pass
943        # Complete guess about current value
944        mem_sz = int(mem_sz * (1.1 ** max(min(node.dist_from_bw, 100), 1)))
945        if is_materialized(node):
946            return mem_sz
947        else:
948            return mem_sz * 2
949
950    nx_graph = nx.DiGraph()
951    banned_nodes = set()
952
953    def ban_recomputation_if_allowed(node):
954        if op_types.is_view(node):
955            return False
956        if node in dont_ban:
957            return False
958        # This bans recomputation of the node unless we've been forced not to by
959        # user annotation
960        if must_recompute(node):
961            return False
962
963        if "val" in node.meta and isinstance(node.meta["val"], torch.SymFloat):
964            return False
965
966        banned_nodes.add(node)
967        # A node will only ever be recomputed if there is a path from an
968        # ancestor of this node to the backwards path through this node that
969        # doesn't go through any saved value. If this node is saved, then that
970        # condition is not possible.
971        nx_graph.add_edge("source", node.name + "_in", capacity=math.inf)
972        return True
973
974    for node in joint_graph.nodes:
975        if node.op == "output":
976            continue
977
978        if node in node_info.required_bw_nodes:
979            if node not in node_info.inputs:
980                nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf)
981                continue
982            # If someone saves an input for backward as-is and backward
983            # returns that tensor as-is as a grad input, then the node x would
984            # be both a required_bw_node and an input. In this case we
985            # (1) connect x_in to the source, (2) x_out to the sink, and
986            # (3) assign the proper weight to the x_in-x_out edge, so that
987            # x would be part of the cut nodes. A case where this happens is if
988            # NestedTensor saves an offset tensor as part of the singleton int
989            # in sizes.
990            nx_graph.add_edge(node.name + "_out", "sink", capacity=math.inf)
991
992        if must_recompute(node):
993            # If user explicitly says they want to recompute a node, we honor it
994            # by adding an inf-capacity edge from X_in to the sink.
995            # This way, X_in node is guaranteed to be part of the subgraph that contains "sink"
996            # after the cut, thus guaranteeing that X op will be recomputed.
997            nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf)
998            continue
999
1000        if _is_primal(node) or _is_fwd_seed_offset(node):
1001            ban_recomputation_if_allowed(node)
1002
1003        # If a node can't be recomputed (too expensive or involves randomness),
1004        # we prevent it from being recomputed by adding an infinite-capacity edge from the source.
1005        # We only need to ban nodes in the fw pass, as those are the only ones that would be recomputed.
1006        if node_info.is_required_fw(node) and should_ban_recomputation(node):
1007            ban_recomputation_if_allowed(node)
1008
1009        # Checks if a node is actually a tuple. Can be simplified to just an isinstance check if we always use faketensors.
1010        is_non_tensor_node = (
1011            "val" not in node.meta and "tensor_meta" not in node.meta
1012        ) or ("val" in node.meta and not isinstance(node.meta["val"], torch.Tensor))
1013
1014        if is_sym_node(node):
1015            weight = float(sym_node_size(node))
1016        elif is_non_tensor_node:
1017            weight = (
1018                0.0 if isinstance(node.meta.get("val"), BackwardState) else math.inf
1019            )
1020        else:
1021            weight = get_node_weight(node)
1022        # Creates the weights on the "node" edge
1023        nx_graph.add_edge(node.name + "_in", node.name + "_out", capacity=weight)
1024        for user in node.users:
1025            nx_graph.add_edge(node.name + "_out", user.name + "_in", capacity=math.inf)
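    # NB: The edges above form the standard node-splitting gadget for a
    # vertex-capacitated min cut: saving a tensor x corresponds to cutting its
    # finite-capacity x_in -> x_out edge, while all producer/consumer edges are
    # infinite, so the minimum cut can only choose which tensors to save.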
1026
1027    # todo(chilli): This is the most questionable of the 3 heuristics for banning recompute.
1028    # Some example models to look at where this helps perf: poolformer_m36,
1029    # mixer_b16_224, cait_m36_384
1030
1031    # The "rough" idea here is that if you have some node that is used by both a
1032    # node nearby downstream as well as a node far downstream, if we recompute
1033    # both of the downstream nodes, we're unlikely to be able to fuse both
1034    # downstream nodes together.
1035
1036    # Thus, we shouldn't aim to recompute far downstream nodes that depend on
1037    # this node. That intuition of "far downstream" is captured by whether
1038    # there's an unfusible op along the chain somewhere
1039
1040    # It could probably be improved by properly analyzing what's going on in the
1041    # backwards pass instead of only relying on whether it's unfusible in the
1042    # forwards.
1043
1044    def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int:
1045        """
1046        Finds the first unfusible node in the chain of nodes starting from
1047        `start_nodes` and returns its position.
1048        """
1049        sorted_nodes: List[Tuple[int, fx.Node, bool]] = []
1050        for n in start_nodes:
1051            heapq.heappush(sorted_nodes, (node_info.get_fw_order(n), n, True))
1052
1053        while len(sorted_nodes) > 0:
1054            _, node, node_is_fusible = heapq.heappop(sorted_nodes)
1055            if not node_is_fusible:
1056                return node_info.get_fw_order(node)
1057            for user in node.users:
1058                if node_info.is_required_fw(user):
1059                    if node_info.get_fw_order(user) > max_range:
1060                        continue
1061                    heapq.heappush(
1062                        sorted_nodes,
1063                        (node_info.get_fw_order(user), user, is_fusible(node, user)),
1064                    )
1065        return max_range
1066
1067    if min_cut_options.ban_if_used_far_apart:
1068        for used_node in node_info.required_fw_nodes:
1069            orders = [
1070                node_info.get_fw_order(user)
1071                for user in used_node.users
1072                if node_info.is_required_fw(user)
1073            ]
1074            fw_users = [
1075                user for user in used_node.users if node_info.is_required_fw(user)
1076            ]
1077            if len(orders) > 0:
1078                first_unfusible_use = find_first_unfusible(fw_users, max(orders))
1079                for user in tuple(used_node.users):
1080                    if (
1081                        node_info.is_required_fw(user)
1082                        and node_info.get_fw_order(user) > first_unfusible_use
1083                        and is_fusible(used_node, user)
1084                    ):
1085                        if user in banned_nodes:
1086                            continue
1087                        log.info(
1088                            "used above/below fusible %s:(%s) -> %s -> %s:(%s)",
1089                            used_node,
1090                            node_info.get_fw_order(used_node),
1091                            first_unfusible_use,
1092                            user,
1093                            node_info.get_fw_order(user),
1094                        )
1095                        ban_recomputation_if_allowed(user)
1096
1097    # This heuristic is fairly straightforward. The idea is that although it is
1098    # cheap to recompute bandwidth-bound ops, we don't want to end up in a situation
1099    # where we have a long chain of pointwise ops from the beginning to the end
1100    # of the model (like say, residual connections)
1101
1102    # todo: I'm not totally sure why this heuristic matters. It's possible that this is
1103    # working around Inductor fusion decisions, or that it's a patch over
1104    # suboptimal partitioning decisions
1105
1106    # Some models it improves perf on are cait_m36_384, mixer_b16_224, poolformer_m36
1107
1108    if min_cut_options.ban_if_long_fusible_chains:
1109        visited = set()
1110        for start_node in joint_graph.nodes:
1111            if not node_info.is_required_fw(start_node):
1112                continue
1113            fusible = [(node_info.get_fw_order(start_node), start_node)]
1114            start_order = node_info.get_fw_order(start_node)
1115            while len(fusible) > 0:
1116                _, cur = heapq.heappop(fusible)
1117                if cur in visited:
1118                    continue
1119                visited.add(cur)
1120                # 100 is an arbitrary choice to try and prevent degenerate cases
1121                if (
1122                    node_info.get_fw_order(cur) > start_order + 100
1123                    and len(fusible) == 0
1124                ):
1125                    log.info(
1126                        "too long %s %s %s %s",
1127                        cur,
1128                        start_node,
1129                        node_info.get_fw_order(cur),
1130                        node_info.get_fw_order(start_node),
1131                    )
1132                    ban_recomputation_if_allowed(cur)
1133                    break
1134
1135                for user in cur.users:
1136                    if (
1137                        node_info.is_required_fw(user)
1138                        and is_fusible(cur, user)
1139                        and user not in banned_nodes
1140                    ):
1141                        heapq.heappush(fusible, (node_info.get_fw_order(user), user))
1142
1143    try:
1144        cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink")
1145    except Exception:
1146        print("Failed to compute min-cut on following graph:")
1147        print("\n".join(nx.readwrite.edgelist.generate_edgelist(nx_graph)))
1148        visualize_min_cut_graph(nx_graph)
1149        raise
1150
1151    reachable, non_reachable = partition
1152    cutset: Set[Tuple[str, str]] = set()
1153    for u, nbrs in ((n, nx_graph[n]) for n in reachable):
1154        cutset.update((u, v) for v in nbrs if v in non_reachable)
1155
1156    cut_nodes = set()
1157    for node_in, node_out in cutset:
1158        assert node_in[:-3] == node_out[:-4]
1159        node_name = node_in[:-3]
1160        cut_nodes.add(node_name)
1161
1162    name_to_node = get_name_to_node(joint_graph)
1163    # To make this stuff deterministic
1164    node_idx = {node: idx for idx, node in enumerate(joint_graph.nodes)}
1165    saved_values = sorted(
1166        (name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x]
1167    )
1168    return saved_values, banned_nodes
1169
1170
1171def visualize_min_cut_graph(nx_graph):
1172    import networkx as nx
1173    import pydot
1174
1175    dot_format = nx.nx_pydot.to_pydot(nx_graph).to_string()
1176    dot_graph = pydot.graph_from_dot_data(dot_format)[0]
1177    for edge in dot_graph.get_edges():
1178        weight = nx_graph[edge.get_source()][edge.get_destination()]["capacity"]
1179        # Set edge label to weight
1180        edge.set_label(str(weight))
1181        # Color edges with weight 'inf' as red
1182        if weight == float("inf"):
1183            edge.set_color("red")
1184    print("Visualizing the failed graph to min_cut_failed.svg")
1185    dot_graph.write_svg("min_cut_failed.svg")
1186
1187
1188def get_default_op_list() -> OpTypes:
1189    default_recomputable_ops: List[Callable] = [
1190        aten.add,
1191        aten.sub,
1192        aten.div,
1193        aten.atan2,
1194        aten.mul,
1195        aten.max,
1196        aten.min,
1197        aten.pow,
1198        aten.remainder,
1199        aten.fmod,
1200        aten.__and__,
1201        aten.__or__,
1202        aten.__xor__,
1203        aten.__lshift__,
1204        aten.__rshift__,
1205        aten.eq,
1206        aten.ne,
1207        aten.ge,
1208        aten.gt,
1209        aten.le,
1210        aten.lt,
1211        aten.abs,
1212        aten.bitwise_not,
1213        aten.ceil,
1214        aten.floor,
1215        aten.frac,
1216        aten.neg,
1217        aten.relu,
1218        aten.round,
1219        aten.silu,
1220        aten.trunc,
1221        aten.log,
1222        aten.log10,
1223        aten.log1p,
1224        aten.log2,
1225        aten.lgamma,
1226        aten.exp,
1227        aten.expm1,
1228        aten.erf,
1229        aten.erfc,
1230        aten.cos,
1231        aten.acos,
1232        aten.cosh,
1233        aten.sin,
1234        aten.asin,
1235        aten.sinh,
1236        aten.tan,
1237        aten.atan,
1238        aten.tanh,
1239        aten.atanh,
1240        aten.sqrt,
1241        aten.rsqrt,
1242        aten.reciprocal,
1243        aten.sigmoid,
1244        aten.softplus,
1245        aten.threshold,
1246        aten.threshold_backward,
1247        aten.clamp,
1248        aten.where,
1249        aten.lerp,
1250        aten.addcmul,
1251        aten.gelu,
1252        aten.gelu_backward,
1253        aten.sum,
1254        aten.mean,
1255        aten._grad_sum_to_size,
1256        aten.sum_to_size,
1257        aten.amax,
1258        aten.to,
1259        aten.type_as,
1260        operator.getitem,
1261        aten.squeeze,
1262        aten.unsqueeze,
1263        aten.rsub,
1264        aten._to_copy,
1265    ]  # noqa: E501,B950
1266    recomputable_view_ops = [aten.squeeze, aten.unsqueeze, aten.alias]
1267    recomputable_view_ops += [
1268        aten.view,
1269        aten.slice,
1270        aten.t,
1271        prims.broadcast_in_dim,
1272        aten.expand,
1273        aten.as_strided,
1274        aten.permute,
1275    ]
1276    view_ops = recomputable_view_ops
1277    default_recomputable_ops += [
1278        prims.div,
1279        prims.convert_element_type,
1280        aten.clone,
1281        aten._to_copy,
1282        aten.full_like,
1283        prims.var,
1284        prims.sum,
1285        aten.var,
1286        aten.std,
1287        prims.broadcast_in_dim,
1288        aten.select,
1289        aten._unsafe_view,
1290        aten.view,
1291        aten.expand,
1292        aten.slice,
1293        aten.reshape,
1294        aten.broadcast_tensors,
1295        aten.scalar_tensor,
1296        aten.ones,
1297        aten.new_zeros,
1298        aten.lift_fresh_copy,
1299        aten.arange,
1300        aten.triu,
1301        aten.var_mean,
1302        aten.isinf,
1303        aten.any,
1304        aten.full,
1305        aten.as_strided,
1306        aten.zeros,
1307        aten.empty,
1308        aten.empty_like,
1309        aten.argmax,
1310        aten.maximum,
1311        prims.iota,
1312        prims._low_memory_max_pool2d_offsets_to_indices,
1313    ]  # noqa: E501,B950
1314    # Natalia said that we should allow recomputing indexing :)
1315    default_recomputable_ops += [aten.index, aten.gather]
1316    default_recomputable_ops += view_ops
1317
1318    default_recomputable_ops += pointwise_ops()
1319
1320    default_recomputable_ops += [
1321        aten.zeros_like,
1322    ]
1323
1324    default_recomputable_ops += [method_to_operator(m) for m in magic_methods]
1325    recomputable_ops = set(default_recomputable_ops)
1326
1327    random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like]
1328    compute_intensive_ops = [
1329        aten.mm,
1330        aten.convolution,
1331        aten.convolution_backward,
1332        aten.bmm,
1333        aten.addmm,
1334        aten._scaled_dot_product_flash_attention,
1335        aten._scaled_dot_product_efficient_attention,
1336        aten._flash_attention_forward,
1337        aten._efficient_attention_forward,
1338        aten.upsample_bilinear2d,
1339        aten._scaled_mm,
1340    ]  # noqa: E501,B950
1341
1342    fusible_ops = recomputable_ops | set(random_ops)
1343    return OpTypes(
1344        set(fusible_ops),
1345        set(compute_intensive_ops),
1346        set(random_ops),
1347        set(view_ops),
1348        set(recomputable_ops),
1349    )
1350
1351
1352def get_name_to_node(graph: fx.Graph):
1353    name_to_node = {}
1354    for node in graph.nodes:
1355        name_to_node[node.name] = node
1356    return name_to_node
1357
1358
1359def greedy_knapsack(
1360    memory: List[float], runtimes: List[float], max_memory: float
1361) -> Tuple[float, List[int], List[int]]:
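    # Greedy 0-1 knapsack: save the items with the highest runtime-to-memory
    # ratio first, until the memory budget is exhausted.
    #
    # Illustrative, hand-worked example (not from the original source):
    #     greedy_knapsack([2.0, 1.0, 3.0], [4.0, 1.0, 2.0], max_memory=3.0)
    #     -> (5.0, [0, 1], [2])   # items 0 and 1 fit; item 2 gets recomputed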
1362    n = len(runtimes)
1363    items = list(range(n))
1364
1365    # Sort items based on the ratio of runtime to memory in descending order
1366    items = sorted(items, key=lambda i: runtimes[i] / memory[i], reverse=True)
1367
1368    total_memory = 0.0
1369    total_runtime = 0.0
1370    items_to_save = []
1371    items_to_allow_recomputing = []
1372
1373    for i in items:
1374        if total_memory + memory[i] <= max_memory:
1375            total_memory += memory[i]
1376            total_runtime += runtimes[i]
1377            items_to_save.append(i)
1378        else:
1379            items_to_allow_recomputing.append(i)
1380    return total_runtime, items_to_save, items_to_allow_recomputing
1381
1382
1383def ilp_knapsack(
1384    memory: List[float], runtimes: List[float], max_memory: float
1385) -> Tuple[float, List[int], List[int]]:
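    # The ILP solved below: choose x_i in {0, 1} to maximize
    #     sum_i runtimes[i] * x_i   subject to   sum_i memory[i] * x_i <= max_memory.
    # scipy's milp() minimizes, hence the negated objective c = -runtimes.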
1386    import numpy as np
1387
1388    try:
1389        from scipy.optimize import Bounds, LinearConstraint, milp
1390    except ImportError:
1391        raise RuntimeError(
1392            "To use the ILP for memory budget checkpointing you need to install scipy"
1393        ) from None
1394
1395    np_memory = np.array(memory)
1396    np_runtimes = np.array(runtimes)
1397    c = -np_runtimes  # type: ignore[operator]
1398
1399    memory_constraint = LinearConstraint(A=np_memory, ub=np.array(max_memory))
1400    constraints = [memory_constraint]
1401
1402    integrality = np.ones_like(c)
1403    res = milp(
1404        c=c, constraints=constraints, integrality=integrality, bounds=Bounds(0, 1)
1405    )
1406    if not res.success:
1407        raise RuntimeError("Somehow scipy solving failed")
1408
1409    items_to_save = []
1410    items_to_allow_recomputing = []
1411    for idx, i in enumerate(res.x):
1412        if i == 1:
1413            items_to_save.append(idx)
1414        else:
1415            items_to_allow_recomputing.append(idx)
1416    return -res.fun, items_to_save, items_to_allow_recomputing
1417
1418
1419def dp_knapsack(
1420    memory: List[float], runtimes: List[float], max_memory: float
1421) -> Tuple[float, List[int], List[int]]:
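    # Classic 0-1 knapsack recurrence over quantized memory budgets j:
    #     dp[i][j] = max(dp[i - 1][j], dp[i - 1][j - mem[i]] + runtime[i])
    # The torch.maximum call below evaluates this recurrence for an entire row
    # of budgets at once.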
1422    # Scaling factor to convert floating point weights to integers
1423    S = 10000
1424
1425    # Quantize the memory weights
1426    quantized_memory = torch.tensor(
1427        [int(round(m * S)) for m in memory], dtype=torch.long, device="cpu"
1428    )
1429    runtimes = torch.tensor(runtimes, dtype=torch.float32, device="cpu")
1430
1431    # Quantized pseudopolynomial DP for 0-1 Knapsack
1432    quantized_max_memory = int(round(max_memory * S))
1433
1434    n = len(memory)
1435
1436    # Initialize the DP table
1437    # TODO(chilli): I think if needed, this memory can be optimized with sliding
1438    # window trick + Hirschberg trick:
1439    # https://codeforces.com/blog/entry/47247?#comment-316200
1440    dp = torch.zeros(
1441        (n + 1, quantized_max_memory + 1), dtype=torch.float32, device="cpu"
1442    )
1443
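    # dp[i][j] = maximum total runtime achievable using the first i items with
    # at most j quantized units of memory.  Row update (vectorized over j):
    #   dp[i][j] = max(dp[i-1][j], dp[i-1][j - w_i] + r_i)   for j >= w_i
    # where w_i = quantized_memory[i-1] and r_i = runtimes[i-1].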
1444    for i in range(1, n + 1):
1445        current_memory = quantized_memory[i - 1]
1446        current_runtime = runtimes[i - 1]
1447
1448        # Copy the previous row
1449        dp[i, :] = dp[i - 1, :]
1450
1451        # Update dp[i, j] for all j >= current_memory
1452        if current_memory == 0:
1453            dp[i, :] = dp[i - 1, :] + current_runtime
1454        else:
1455            dp[i, current_memory:] = torch.maximum(
1456                dp[i - 1, current_memory:],
1457                dp[i - 1, :-current_memory] + current_runtime,
1458            )
1459
1460    # Backtrack to find the items included in the knapsack
1461    saved_items = []
1462    recomputable_items = []
1463    j: int = quantized_max_memory
1464    for i in range(n, 0, -1):
1465        if dp[i][j] != dp[i - 1][j]:
1466            saved_items.append(i - 1)  # Include this item (indexing from 0)
1467            j -= int(quantized_memory[i - 1].item())
1468        else:
1469            recomputable_items.append(i - 1)
1470
1471    saved_items.reverse()  # To get items in the order they were added
1472
1473    # The maximum runtime that can be achieved within the max_memory constraint
1474    max_runtime = dp[n][quantized_max_memory].item()
1475
1476    return max_runtime, saved_items, recomputable_items
1477
1478
1479def _optimize_runtime_with_given_memory(
1480    memory: List[float],
1481    runtimes: List[float],
1482    max_memory: float,
1483) -> Tuple[float, List[int], List[int]]:
1484    SOLVER = config.activation_memory_budget_solver
1485    if SOLVER == "greedy":
1486        return greedy_knapsack(memory, runtimes, max_memory)
1487    elif SOLVER == "ilp":
1488        return ilp_knapsack(memory, runtimes, max_memory)
1489    elif SOLVER == "dp":
1490        return dp_knapsack(memory, runtimes, max_memory)
1491    else:
1492        raise RuntimeError(f"Unknown memory budget knapsack solver: {SOLVER}")
1493
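# Illustrative usage (a sketch): the solver is selected via the functorch
# config module, e.g.
#   torch._functorch.config.activation_memory_budget_solver = "dp"
# "greedy", "ilp", and "dp" are the values dispatched above.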
1494
1495from torch.utils._mode_utils import no_dispatch
1496
1497
1498def estimate_runtime(node):
1499    RUNTIME_MODE = config.activation_memory_budget_runtime_estimator
1500
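    # materialize_arg turns FX node metadata into concrete values so the op
    # can actually be executed for profiling / FLOP counting: tensors become
    # empty strided tensors of the hinted shape (unbacked symbolic dims fall
    # back to 4096), SymInts become hinted ints, and SymFloats/SymBools become
    # placeholder constants.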
1501    def materialize_arg(x):
1502        if isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.Tensor):
1503            shape = list(x.meta["val"].shape)
1504
1505            def realize_symbol(d):
1506                return hint_int(d, fallback=4096)
1507
1508            shape = [realize_symbol(s) for s in shape]
1509            return x.meta["val"].new_empty_strided(
1510                shape, stride=x.meta["tensor_meta"].stride
1511            )
1512        elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymInt):
1513            return hint_int(x.meta["val"], fallback=4096)
1514        elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymFloat):
1515            return 1.0
1516        elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymBool):
1517            return True
1518        else:
1519            return x
1520
1521    if RUNTIME_MODE == "testing":
1522        return 1
1523
1524    elif RUNTIME_MODE == "profile":
1525        with no_dispatch():
1526            from torch._inductor.runtime.benchmarking import benchmarker
1527
1528            args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs))
1529            ms = benchmarker.benchmark_gpu(lambda: node.target(*args, **kwargs))
1530            return ms
1531
1532    elif RUNTIME_MODE == "flops":
1533        # todo(chilli): Normalize this to also return ms
1534        from torch.utils.flop_counter import FlopCounterMode
1535
1536        args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs))
1537        with FlopCounterMode(display=False) as mode:
1538            node.target(*args, **kwargs)
1539        counted_flops = mode.get_total_flops()
1540        return max(counted_flops, 1)
1541    else:
1542        raise RuntimeError(f"Unknown runtime estimator: {RUNTIME_MODE}")
1543
1544
1545def choose_saved_values_set(
1546    joint_graph: fx.Graph, node_info: NodeInfo, memory_budget=1
1547) -> List[fx.Node]:
1548    if memory_budget > 1 or memory_budget < 0:
1549        raise RuntimeError(
1550            f"The valid range for memory budget is 0 <= m <= 1. The provided value is {memory_budget}"
1551        )
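    # Overall strategy:
    #   * memory_budget == 0: save only the graph inputs (recompute everything).
    #   * memory_budget == 1: save the runtime-optimal min-cut solution.
    #   * 0 < memory_budget < 1: if the runtime-optimal solution exceeds the
    #     budget, progressively relax the recompute bans (more aggressive
    #     min-cut options); if even the most aggressive solution is still over
    #     budget, run a knapsack over the banned-but-recomputable nodes to
    #     decide which of them to recompute, then re-solve the min-cut with
    #     those nodes un-banned.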
1552    min_cut_options = MinCutOptions(
1553        ban_if_used_far_apart=config.ban_recompute_used_far_apart,
1554        ban_if_long_fusible_chains=config.ban_recompute_long_fusible_chains,
1555        ban_if_materialized_backward=config.ban_recompute_materialized_backward,
1556        ban_if_not_in_allowlist=config.ban_recompute_not_in_allowlist,
1557        ban_if_reduction=config.ban_recompute_reductions,
1558    )
1559
1560    if config.aggressive_recomputation:
1561        min_cut_options = replace(
1562            min_cut_options,
1563            ban_if_used_far_apart=False,
1564            ban_if_long_fusible_chains=False,
1565            ban_if_materialized_backward=False,
1566            ban_if_not_in_allowlist=False,
1567        )
1568    if memory_budget == 0:
1569        return node_info.inputs
1570
1571    runtime_optimized_saved_values, _ = solve_min_cut(
1572        joint_graph,
1573        node_info,
1574        min_cut_options,
1575    )
1577    if memory_budget == 1:
1578        return runtime_optimized_saved_values
1579
1580    def estimate_activations_size(saved_values: List[fx.Node]) -> float:
1581        return sum(map(_size_of, saved_values)) / 1e9
1582
1583    min_act_size = estimate_activations_size(node_info.inputs)
1584    max_act_size = estimate_activations_size(runtime_optimized_saved_values)
1585    # The runtime-optimized choice is already no larger than the inputs, so use it
1586    if max_act_size <= min_act_size:
1587        return runtime_optimized_saved_values
1588
1589    def get_normalized_size(sz):
1590        return (sz / 1e9) / (max_act_size - min_act_size)
1591
1592    def get_mem_ratio(activations: List[fx.Node]):
1593        return (estimate_activations_size(activations) - min_act_size) / (
1594            max_act_size - min_act_size
1595        )
1596
1597    more_aggressive_options = replace(
1598        min_cut_options,
1599        ban_if_used_far_apart=False,
1600        ban_if_long_fusible_chains=False,
1601        ban_if_materialized_backward=False,
1602    )
1603    more_aggressive_saved_values, _ = solve_min_cut(
1604        joint_graph, node_info, more_aggressive_options
1605    )
1606    if get_mem_ratio(more_aggressive_saved_values) < memory_budget:
1607        return more_aggressive_saved_values
1608
1609    aggressive_options = replace(
1610        more_aggressive_options,
1611        ban_if_not_in_allowlist=False,
1612    )
1613    aggressive_recomputation_saved_values, banned_nodes = solve_min_cut(
1614        joint_graph, node_info, aggressive_options
1615    )
1616
1617    if get_mem_ratio(aggressive_recomputation_saved_values) < memory_budget:
1618        return aggressive_recomputation_saved_values
1619
1620    from torch._inductor.fx_utils import get_node_storage
1621
1622    input_storages = {get_node_storage(node) for node in node_info.inputs}
1623
1624    def get_recomputable_banned_nodes(banned_nodes: List[fx.Node]) -> List[fx.Node]:
1625        return [
1626            i
1627            for i in banned_nodes
1628            if (
1629                # Only allow recomputing nodes that are actually required for BW
1630                i.dist_from_bw < int(1e9)  # type: ignore[attr-defined]
1631                and get_node_storage(i) not in input_storages
1632            )
1633        ]
1634
1635    recomputable_banned_nodes = get_recomputable_banned_nodes(banned_nodes)
1636
1637    # default: runtime_optimized_saved_values
1638    # more aggressive: more_aggressive_saved_values
1639    # full aggressive: aggressive_recomputation_saved_values
1640
1641    all_recomputable_banned_nodes = sorted(
1642        recomputable_banned_nodes, key=_size_of, reverse=True
1643    )
1644    if len(all_recomputable_banned_nodes) == 0:
1645        return node_info.inputs
1646    memories_banned_nodes = [
1647        get_normalized_size(_size_of(i)) for i in all_recomputable_banned_nodes
1648    ]
1649    runtimes_banned_nodes = [
1650        estimate_runtime(node) for node in all_recomputable_banned_nodes
1651    ]
1653
1654    def get_saved_values_knapsack(memory_budget):
1655        with no_dispatch():
1656            (
1657                expected_runtime,
1658                saved_node_idxs,
1659                recomputable_node_idxs,
1660            ) = _optimize_runtime_with_given_memory(
1661                memories_banned_nodes, runtimes_banned_nodes, max(memory_budget, 0)
1662            )
1663        dont_ban = set()
1664        for idx in recomputable_node_idxs:
1665            dont_ban.add(all_recomputable_banned_nodes[idx])
1666        assert dont_ban.issubset(all_recomputable_banned_nodes)
1667
1668        saved_values, _ = solve_min_cut(
1669            joint_graph,
1670            node_info,
1671            aggressive_options,
1672            dont_ban,
1673        )
1674        return saved_values, expected_runtime
1675
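    # When config.visualize_memory_budget_pareto is set, sweep the memory
    # budget from 100% down to 0% in 5% steps and plot the resulting
    # memory-ratio vs. recomputation-runtime trade-off curve, saved as a
    # memory_budget_pareto_*.png file.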
1676    if config.visualize_memory_budget_pareto:
1677        options = []
1678        for sweep_memory_budget in range(100, -1, -5):
1679            saved_values, expected_runtime = get_saved_values_knapsack(
1680                sweep_memory_budget / 100
1681            )
1682            options.append(
1683                (
1684                    sweep_memory_budget,
1685                    sum(runtimes_banned_nodes) - expected_runtime,
1686                    get_mem_ratio(saved_values),
1687                )
1688            )
1689
1690        import matplotlib.pyplot as plt
1691
1692        x_values = [item[2] for item in options]
1693        y_values = [item[1] for item in options]
1694
1695        # Plotting the values with updated axis labels and chart title
1696        plt.figure(figsize=(10, 6))
1697        plt.plot(x_values, y_values, marker="o")
1698
1699        # Adding labels for each point
1700        for i, txt in enumerate(x_values):
1701            plt.annotate(
1702                f"{txt:.2f}",
1703                (txt, y_values[i]),
1704                textcoords="offset points",
1705                xytext=(0, 10),
1706                ha="center",
1707            )
1708
1709        plt.xlabel("Memory Budget")
1710        plt.ylabel("Runtime of Recomputed Components")
1711        plt.title("Pareto Frontier of Memory Budget vs. Recomputation Runtime")
1712        plt.grid(True)
1713        fig = plt.gcf()
1714        plt.show()
1715        fig_name = f"memory_budget_pareto_{get_aot_graph_name()}.png"
1716        fig.savefig(fig_name)
1717        log.warning("Generated Pareto frontier curve at %s", fig_name)
1718
1719    # todo(chilli): The estimated memory doesn't align exactly with the
1720    # actual memory - actual usage is usually lower than the estimate. A
1721    # likely (but unverified) explanation is that the estimate only counts
1722    # tensors we explicitly banned from recompute, while there may be other
1723    # tensors that we end up choosing to save.
1724
1725    return get_saved_values_knapsack(memory_budget=memory_budget)[0]
1726
1727
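# Illustrative usage sketch (assuming the ``functorch.compile`` AOT Autograd
# wrappers; not executed here): this partitioner is typically supplied as the
# ``partition_fn`` of an AOT-compiled function, e.g.
#
#   from functorch.compile import aot_function, min_cut_rematerialization_partition
#
#   def fn(x):
#       return x.cos().cos()
#
#   compiled_fn = aot_function(
#       fn,
#       fw_compiler=lambda gm, example_inputs: gm,  # identity "compilers" for illustration
#       bw_compiler=lambda gm, example_inputs: gm,
#       partition_fn=min_cut_rematerialization_partition,
#   )
#   out = compiled_fn(torch.randn(4, requires_grad=True))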
1728def min_cut_rematerialization_partition(
1729    joint_module: fx.GraphModule,
1730    _joint_inputs,
1731    compiler="inductor",
1732    *,
1733    num_fwd_outputs,
1734) -> Tuple[fx.GraphModule, fx.GraphModule]:
1735    """
1736    Partitions the joint graph such that the backward recomputes the forward.
1737    Recomputing helps in trading off memory bandwidth with computation.
1738
1739    To create the fwd and bwd graphs, we copy the joint graph, manually set
1740    the outputs to just the original forward or backward outputs, and then run
1741    the resulting graphs through dead code elimination.
1742
1743    .. warning::
1744        This API is experimental and likely to change.
1745
1746    Args:
1747        joint_module(fx.GraphModule): The joint forward and backward graph. This
1748            is the result of AOT Autograd tracing.
1749        _joint_inputs: The inputs to the joint graph. This is unused.
1750        compiler: This option determines the default set of recomputable ops.
1751            Currently, there are two options: ``nvfuser`` and ``inductor``.
1755        num_fwd_outputs: The number of outputs from the forward graph.
1756
1757    Returns:
1758        Returns the generated forward and backward Fx graph modules.
1759    """
1760
1761    joint_module.graph.eliminate_dead_code()
1762    joint_module.recompile()
1763
1764    fx_g = joint_module.graph
1765
1766    # Apply common subexpression elimination (CSE) to the joint graph
1767    if config.cse:
1768        cse_graph = fx_graph_cse(fx_g)
1769        joint_module.graph = cse_graph
1770    joint_graph = joint_module.graph
1771
1772    graph_has_recomputable_ops = has_recomputable_ops(joint_module)
1773    graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
1774    if graph_has_recomputable_ops:
1775        joint_module = cleanup_recompute_tags(joint_module)
1776
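    # classify_nodes splits the joint graph into three disjoint sets:
    #   * required_bw_nodes: tangent placeholders, nodes that must live in the
    #     backward, their (transitive) users, and the backward outputs.
    #   * required_fw_nodes: nodes reachable in a forward-only extraction of
    #     the graph (inputs -> forward outputs).
    #   * unclaimed_nodes: everything else.
    # It also records fw_order, the topological position of each required
    # forward node.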
1777    def classify_nodes(joint_module):
1778        name_to_node = get_name_to_node(joint_module.graph)
1779        required_bw_nodes = set()
1780        for node in joint_module.graph.nodes:
1781            if node.op == "placeholder" and "tangents" in node.target:
1782                required_bw_nodes.add(node)
1783            elif _must_be_in_backward(node):
1784                required_bw_nodes.add(node)
1785
1786            if node in required_bw_nodes:
1787                for user in node.users:
1788                    required_bw_nodes.add(user)
1789
1790        primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
1791        fwd_seed_offset_inputs = list(
1792            filter(_is_fwd_seed_offset, joint_module.graph.nodes)
1793        )
1794        inputs = primal_inputs + fwd_seed_offset_inputs
1795        fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(
1796            joint_module, num_fwd_outputs=num_fwd_outputs
1797        )
1798        required_bw_nodes.update(
1799            o for o in bwd_outputs if o is not None and o.op != "output"
1800        )
1801        forward_only_graph = _extract_graph_with_inputs_outputs(
1802            joint_module.graph, inputs, fwd_outputs, "forward"
1803        )
1804        required_fw_nodes: Set[fx.Node] = {
1805            name_to_node[node.name]
1806            for node in forward_only_graph.nodes
1807            if node.op != "output"
1808        }
1809        unclaimed_nodes = {
1810            node
1811            for node in joint_module.graph.nodes
1812            if node not in required_fw_nodes and node not in required_bw_nodes
1813        }
1814        fw_cnt = 0
1815        fw_order = {}
1816        for node in joint_module.graph.nodes:
1817            if node in required_fw_nodes:
1818                fw_order[node] = fw_cnt
1819                fw_cnt += 1
1820        return NodeInfo(
1821            inputs, required_fw_nodes, required_bw_nodes, unclaimed_nodes, fw_order
1822        )
1823
1824    node_info = classify_nodes(joint_module)
1825
1826    # networkx blows up on graphs with no required backward nodes
1827    # Since there's nothing to partition anyway, and the default partitioner can "handle"
1828    # this case, send our graph over to the default partitioner.
1829    if len(node_info.required_bw_nodes) == 0:
1830        return default_partition(
1831            joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs
1832        )
1833
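    # Annotate each node with dist_from_bw: 0 for nodes not required in the
    # forward (i.e. backward/unclaimed nodes), a large sentinel for output
    # nodes, and, for required-forward nodes, the minimum number of user hops
    # to reach a non-forward node.  This attribute is consumed by the
    # recompute heuristics (e.g. the dist_from_bw check in
    # choose_saved_values_set above).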
1834    for node in reversed(joint_module.graph.nodes):
1835        if node.op == "output":
1836            node.dist_from_bw = int(1e9)
1837        elif not node_info.is_required_fw(node):
1838            node.dist_from_bw = 0
1839        else:
1840            node.dist_from_bw = int(1e9)
1841            for user in node.users:
1842                node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1)
1843
1844    memory_budget = config.activation_memory_budget
1845    for node in joint_graph.nodes:
1846        if isinstance(node.meta.get("memory_budget", None), float):
1847            memory_budget = node.meta["memory_budget"]
1848            break
1850    saved_values = choose_saved_values_set(
1851        joint_graph, node_info, memory_budget=memory_budget
1852    )
1853    # We save_for_backward the tensors and stash the symints in the autograd .ctx
1854    saved_sym_nodes = list(filter(is_sym_node, saved_values))
1855    saved_values = list(filter(lambda n: not is_sym_node(n), saved_values))
1856
1857    # NB: saved_sym_nodes will be mutated to reflect the actual saved symbols
1858    fw_module, bw_module = _extract_fwd_bwd_modules(
1859        joint_module,
1860        saved_values,
1861        saved_sym_nodes=saved_sym_nodes,
1862        num_fwd_outputs=num_fwd_outputs,
1863    )
1864
1865    if graph_has_recomputable_ops:
1866        if graph_has_recomputable_rng_ops:
1867            fw_module, bw_module = functionalize_rng_ops(
1868                joint_module, fw_module, bw_module, len(saved_sym_nodes)
1869            )
1870    bw_module = reordering_to_mimic_autograd_engine(bw_module)
1871
1872    if AOT_PARTITIONER_DEBUG:
1873        from torch._inductor.fx_utils import get_node_storage
1874
1875        storages = {get_node_storage(node) for node in saved_values}
1876        print(
1877            "Theoretical Activations Stored: ",
1878            sum(_size_of(i) for i in saved_values) / 1e9,
1879        )
1880        sorted_sizes = sorted([(_size_of(i), str(i)) for i in saved_values])
1881        fw_module_nodes = {
1882            node.name for node in fw_module.graph.nodes if node.op == "call_function"
1883        }
1884        bw_module_nodes = {
1885            node.name for node in bw_module.graph.nodes if node.op == "call_function"
1886        }
1887        remat_nodes = fw_module_nodes & bw_module_nodes
1888
1889        counts: Dict[str, int] = defaultdict(int)
1890        for node in fw_module.graph.nodes:
1891            if node.name in remat_nodes and hasattr(node.target, "_overloadpacket"):
1892                counts[str(node.target._overloadpacket)] += 1
1893        print(
1894            f"# remat/fw/bw: {len(remat_nodes)}/{len(fw_module_nodes)}/{len(bw_module_nodes)}"
1895        )
1896        print(
1897            "Count of Ops Rematerialized: ",
1898            sorted(counts.items(), key=lambda x: x[1], reverse=True),
1899        )
1900    return fw_module, bw_module
1901
1902
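# Illustrative usage sketch (not executed here): dump an FX graph to disk; the
# output format is taken from the filename extension, falling back to
# config.torch_compile_graph_format when no extension is given, e.g.
#   draw_graph(fw_module, "fw_graph.svg")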
1903def draw_graph(
1904    traced: torch.fx.GraphModule,
1905    fname: str,
1906    figname: str = "fx_graph",
1907    clear_meta: bool = True,
1908    prog: Optional[Union[str, List[str]]] = None,
1909    parse_stack_trace: bool = False,
1910    dot_graph_shape: Optional[str] = None,
1911) -> None:
1912    if clear_meta:
1913        new_graph = copy.deepcopy(traced.graph)
1914        traced = fx.GraphModule(traced, new_graph)
1915        for node in traced.graph.nodes:
1916            node.meta = {}
1917    base, ext = os.path.splitext(fname)
1918    if not ext:
1919        ext = "." + config.torch_compile_graph_format
1920    print(f"Writing FX graph to file: {base}{ext}")
1921    g = graph_drawer.FxGraphDrawer(
1922        traced,
1923        figname,
1924        parse_stack_trace=parse_stack_trace,
1925        dot_graph_shape=dot_graph_shape,
1926    )
1927    x = g.get_main_dot_graph()
1928    write_method = getattr(x, "write_" + ext.lstrip("."))
1929    fname = f"{base}{ext}"
1930    if prog is None:
1931        write_method(fname)
1932    else:
1933        write_method(fname, prog=prog)
1934