xref: /aosp_15_r20/external/pytorch/torch/_inductor/fx_passes/post_grad.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-decorators
2# mypy: allow-untyped-defs
3import functools
4import itertools
5import logging
6import operator
7from collections import Counter, defaultdict
8from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union
9
10import torch
11import torch._inductor as inductor
12import torch.utils._pytree as pytree
13from torch import fx
14from torch._decomp import register_decomposition
15from torch._dynamo.utils import counters, optimus_scuba_log
16from torch._inductor import comms
17from torch._inductor.virtualized import ops
18from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype
19from torch._utils_internal import upload_graph
20from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
21from torch.fx.passes.graph_transform_observer import GraphTransformObserver
22
23from .. import config, ir, pattern_matcher
24from ..codegen.common import BackendFeature, has_backend_feature
25from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage
26from ..lowering import lowerings as L
27from ..pattern_matcher import (
28    _return_true,
29    Arg,
30    CallFunction,
31    CallFunctionVarArgs,
32    filter_nodes,
33    get_arg_value,
34    get_mutation_region_id,
35    Ignored,
36    init_once_fakemode,
37    KeywordArg,
38    ListOf,
39    Match,
40    MULTIPLE,
41    PatternMatcherPass,
42    register_graph_pattern,
43    stable_topological_sort,
44)
45from ..utils import decode_device, get_gpu_type, is_pointwise_use
46from ..virtualized import V
47from .b2b_gemm import B2B_GEMM_PASS
48from .ddp_fusion import fuse_ddp_communication
49from .group_batch_fusion import group_batch_fusion_passes, POST_GRAD_FUSIONS
50from .micro_pipeline_tp import micro_pipeline_tp_pass
51from .pre_grad import is_same_dict, save_inductor_dict
52from .reinplace import reinplace_inplaceable_ops
53from .split_cat import POST_GRAD_PATTERNS
54
55
56if TYPE_CHECKING:
57    from sympy import Expr
58
59
60log = logging.getLogger(__name__)
61aten = torch.ops.aten
62prims = torch.ops.prims
63
64# pass_patterns[0] is applied first, then [1], then [2]
65pass_patterns = [
66    PatternMatcherPass(),
67    PatternMatcherPass(),
68    PatternMatcherPass(),
69]
70
71
72def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
73    """
74    Passes that run after grad. This is called once on the forwards
75    graph and once on the backwards graph.
76
77    The IR here has been normalized and functionalized.
78    """
79    if config.dce:
80        # dead code elimination has some issues with mutation in inference mode
81        gm.graph.eliminate_dead_code()
82
83    if is_inference and config.reorder_for_locality:
84        reorder_for_locality(gm.graph)
85
86    fake_tensor_updater = FakeTensorUpdater(gm.graph)
87
88    if config.post_grad_custom_pre_pass is not None:
89        with GraphTransformObserver(
90            gm, "post_grad_custom_pre_pass", config.trace.log_url_for_graph_xform
91        ):
92            config.post_grad_custom_pre_pass(gm.graph)
93
94    if config.pattern_matcher:
95        lazy_init()
96        optimus_scuba_log["before_recompile_post_grad"] = upload_graph(gm.graph)
97        group_batch_fusion_passes(gm.graph, pre_grad=False)
98        remove_noop_ops(gm.graph)
99        for patterns in pass_patterns:
100            patterns.apply(gm.graph)  # type: ignore[arg-type]
101        for pass_name in config.post_grad_fusion_options:
102            # skip all patterns for group batch fusions
103            if pass_name in POST_GRAD_FUSIONS:
104                continue
105            pattern_matcher_pass = POST_GRAD_PATTERNS[pass_name]
106            inductor_before_change = save_inductor_dict(
107                [pattern_matcher_pass.pass_name]
108            )
109            pattern_matcher_pass.apply(gm.graph)  # type: ignore[arg-type]
110            if not is_same_dict(counters["inductor"], inductor_before_change):
111                optimus_scuba_log[
112                    f"{pattern_matcher_pass.pass_name}_post_grad"
113                ] = upload_graph(gm.graph)
114        if config.b2b_gemm_pass:
115            B2B_GEMM_PASS.apply(gm.graph)  # type: ignore[arg-type]
116
117    if config._micro_pipeline_tp:
118        micro_pipeline_tp_pass(gm.graph)
119
120    if config._fuse_ddp_communication:
121        fuse_ddp_communication(
122            gm.graph,
123            config._fuse_ddp_communication_passes,
124            config._fuse_ddp_bucket_size,
125        )
126
127    if config.post_grad_custom_post_pass is not None:
128        with GraphTransformObserver(
129            gm, "post_grad_custom_post_pass", config.trace.log_url_for_graph_xform
130        ):
131            config.post_grad_custom_post_pass(gm.graph)
132
133    stable_topological_sort(gm.graph)
134
135    move_constructors_to_gpu(gm.graph)
136
137    fake_tensor_updater.incremental_update()
138
139    # Keep these last, since they introduce mutation. See
140    # ./fx_passes/README.md for a discussion of mutation invariants.
141    reinplace_inplaceable_ops(gm.graph)
142    decompose_auto_functionalized(gm.graph)
143
144    comms.reinplace_fsdp_all_gather(gm.graph)
145
146    gm.recompile()
147    optimus_scuba_log["after_recompile_post_grad"] = upload_graph(gm.graph)
148    gm.graph.lint()
149
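# A user-defined post-grad pass is invoked above simply as a callable on the
# fx.Graph (see the post_grad_custom_pre_pass / post_grad_custom_post_pass hooks
# in post_grad_passes). A minimal sketch of such a hook, assuming the hook is
# set by assigning to the config attribute (the name `my_post_grad_pass` and its
# body are illustrative, not part of this file):
#
#   def my_post_grad_pass(graph: torch.fx.Graph) -> None:
#       for node in graph.nodes:
#           ...  # inspect or rewrite nodes here
#
#   torch._inductor.config.post_grad_custom_post_pass = my_post_grad_pass
#
# The only contract relied on by post_grad_passes is that the config value, if
# set, is called with gm.graph.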
150
151@init_once_fakemode
152def lazy_init():
153    if torch._C._has_mkldnn:
154        from . import decompose_mem_bound_mm  # noqa: F401
155        from .mkldnn_fusion import _mkldnn_fusion_init
156
157        _mkldnn_fusion_init()
158
159
160def reorder_for_locality(graph: torch.fx.Graph):
161    def visit(other_node):
162        if (
163            other_node.op == "call_function"
164            and other_node.target != operator.getitem
165            and all((n in seen_nodes) for n in other_node.users)
166            and get_mutation_region_id(graph, node)
167            == get_mutation_region_id(graph, other_node)
168        ):
169            # move node's producers right before it
170            node.prepend(other_node)
171
172    seen_nodes = set()
173
174    # Only reorder nodes before the first copy_ in the graph.
175    # copy_ will appear at the end of functionalized graphs when there is mutation on inputs,
176    # and this reordering doesn't work well with mutation.
177    first_copy = next(
178        iter(graph.find_nodes(op="call_function", target=torch.ops.aten.copy_.default)),
179        None,
180    )
181    past_mutating_epilogue = first_copy is None
182
183    for node in reversed(graph.nodes):
184        seen_nodes.add(node)
185        if not past_mutating_epilogue:
186            past_mutating_epilogue = node is first_copy
187            continue
188
189        torch.fx.map_arg((node.args, node.kwargs), visit)
190
191
192def register_lowering_pattern(pattern, extra_check=_return_true, pass_number=1):
193    """
194    Register an aten to inductor IR replacement pattern
195    """
196    return pattern_matcher.register_lowering_pattern(
197        pattern, extra_check, pass_dict=pass_patterns[pass_number]
198    )
199
200
201################################################################################
202# Actual patterns below this point.
203# Priority of patterns is:
204#   - later output nodes first
205#   - order patterns are defined in
206################################################################################
207
208
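# mm_plus_mm rewrites add(mm(mat1, mat2), mm(mat3, mat4)) into a single fused
# kernel. The check below verifies the shapes are consistent, i.e. (with
# illustrative names) mat1: [m, k], mat2: [k, n], mat3: [m, k'], mat4: [k', n],
# so that both matmuls produce an [m, n] output that can be added elementwise.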
209def is_valid_mm_plus_mm(match: Match):
210    *b1, m1, k1 = match.kwargs["mat1"].meta.get("tensor_meta").shape
211    *b2, k2, n1 = match.kwargs["mat2"].meta.get("tensor_meta").shape
212    if k1 != k2:
213        return False
214
215    *b1, m2, k3 = match.kwargs["mat3"].meta.get("tensor_meta").shape
216    *b2, k4, n2 = match.kwargs["mat4"].meta.get("tensor_meta").shape
217    if k3 != k4:
218        return False
219
220    if m1 != m2 or n1 != n2:
221        return False
222
223    return True
224
225
226def scatter_upon_const_tensor_extra_check(m):
227    if not config.optimize_scatter_upon_const_tensor:
228        return False
229    full_shape = m.kwargs["shape"]
230    selector = m.kwargs["selector"]
231    dim = m.kwargs["dim"]
232    if dim < 0:
233        dim += len(full_shape)
234
235    selector_ft = selector.meta["val"]
236    assert selector_ft.dim() == len(full_shape)
237
238    for idx, select_sz, full_sz in zip(
239        itertools.count(), selector_ft.shape, full_shape
240    ):
241        if idx == dim:
242            continue
243
244        # TODO: the pattern can be updated to support the case where the index
245        # tensor is shorter, but that will need a more complex condition expression,
246        # especially for multi-dimensional tensors.
247        # Skip it for now.
248        if isinstance(full_sz, fx.Node):
249            full_sz = full_sz.meta["val"]
250        if select_sz < full_sz:
251            return False
252
253    # We could actually support a selector size larger than 1 along dim. It would
254    # be a bit tedious: e.g., load all the index values (there are not many) and
255    # compare them with the position in the tensor to decide what value to return.
256    return selector_ft.size(dim) == 1
257
258
259@register_lowering_pattern(
260    CallFunction(
261        aten.scatter.value,
262        CallFunction(
263            aten.full,
264            KeywordArg("shape"),
265            KeywordArg("background_val"),
266            dtype=KeywordArg("dtype"),
267        ),
268        KeywordArg("dim"),
269        KeywordArg("selector"),
270        KeywordArg("val"),  # scalar value
271    ),
272    extra_check=scatter_upon_const_tensor_extra_check,
273)
274def scatter_upon_const_tensor(
275    match: Match, shape, background_val, dtype, dim, selector, val
276):
277    """
278    Match the pattern of full+scatter into a pointwise.
279
280    TODO: Right now the scatter value must be a scalar. But we could support it
281    when it is a tensor as well.
282    """
283    from torch._inductor import metrics
284
285    metrics.num_matches_for_scatter_upon_const_tensor += 1
286
287    selector_loader = selector.make_loader()
288
289    def inner_fn(idx):
290        selector_idx = list(idx)
291        selector_idx[dim] = 0
292
293        selector = selector_loader(selector_idx)
294        return ops.where(
295            selector == ops.index_expr(idx[dim], torch.int64),
296            ops.constant(val, dtype),
297            ops.constant(background_val, dtype),
298        )
299
300    return ir.Pointwise.create(
301        device=selector.get_device(),
302        dtype=dtype,
303        inner_fn=inner_fn,
304        ranges=shape,
305    )
306
307
308@register_lowering_pattern(
309    CallFunction(
310        aten.add,
311        CallFunction(aten.mm, KeywordArg("mat1"), KeywordArg("mat2")),
312        CallFunction(aten.mm, KeywordArg("mat3"), KeywordArg("mat4")),
313    ),
314    extra_check=is_valid_mm_plus_mm,
315)
316def mm_plus_mm(match: Match, mat1, mat2, mat3, mat4):
317    return inductor.kernel.mm_plus_mm.tuned_mm_plus_mm(mat1, mat2, mat3, mat4)
318
319
320def cuda_and_enabled_mixed_mm(match):
321    return (
322        (config.use_mixed_mm or config.mixed_mm_choice != "default")
323        and getattr(match.kwargs["mat1"].meta.get("val"), "is_cuda", False)
324        and (
325            match.kwargs["mat2_dtype"].itemsize
326            > match.kwargs["mat2"].meta.get("val").dtype.itemsize
327        )
328        and has_backend_feature("cuda", BackendFeature.TRITON_TEMPLATES)
329    )
330
331
332def cuda_and_enabled_mixed_mm_and_not_int8(match):
333    return (
334        cuda_and_enabled_mixed_mm(match)
335        and getattr(match.kwargs["mat1"].meta.get("val"), "is_cuda", False)
336        and getattr(match.kwargs["mat2"].meta.get("val"), "dtype", torch.int8)
337        != torch.int8
338    )  # bitshift numerics in triton and pytorch don't match for torch.int8
339
340
341"""
342    This is intended to be used to unpack a [K, N] int4 tensor from a [K/2, N] uint4x2 tensor
343    (where the int4 and uint4x2 values are represented with int8 and uint8 respectively),
344    where every other row of the int4 tensor is packed with the row above it as:
345    uint4x2[k,n] = (8 + int4[2*k,n]) + ((8 + int4[2*k+1,n]) << 4)
346
347    unpack formulas:
348    int4[2*k,n]   = (uint4x2[k,n] & 0xF) - 8
349    int4[2*k+1,n] = (uint4x2[k,n] >> 4) - 8
350
351    We therefore match on the unpack formula:
352    torch.mm(mat1, torch.cat((mat2 & 0xF, mat2 >> 4), 1).reshape(mat2_mm_shape).to(mat2_dtype).sub(8))
353
354    Note: although the unpack formula in PyTorch and the Triton kernel is designed for a uint8 mat2,
355    the behavior of the kernel matches the PyTorch formula for all dtypes except torch.int8,
356    where the bitwise numerics in Triton do not match those in PyTorch.
357"""
358
359
360@register_lowering_pattern(
361    CallFunction(
362        aten.mm.default,
363        KeywordArg("mat1"),
364        CallFunction(
365            aten.sub.Tensor,
366            CallFunction(
367                prims.convert_element_type.default,
368                CallFunction(
369                    aten.reshape.default,
370                    CallFunction(
371                        aten.cat.default,
372                        ListOf(
373                            CallFunction(
374                                aten.bitwise_and.Scalar,
375                                KeywordArg("mat2"),
376                                0xF,
377                            ),
378                            # CallFunction(
379                            #    aten.__rshift__.Scalar,
380                            #    KeywordArg("mat2"),
381                            #    4,
382                            # ),
383                            True,
384                        ),
385                        1,
386                    ),
387                    KeywordArg("mat2_mm_shape"),
388                ),
389                KeywordArg("mat2_dtype"),
390            ),
391            8,
392        ),
393    ),
394    extra_check=cuda_and_enabled_mixed_mm_and_not_int8,
395)
396def uint4x2_mixed_mm(match: Match, mat1, mat2, mat2_mm_shape, mat2_dtype):
397    return inductor.kernel.unpack_mixed_mm.tuned_uint4x2_mixed_mm(
398        mat1, mat2, mat2_mm_shape, mat2_dtype
399    )
400
401
402"""
403    torch.mm(mat1, mat2.to(mat2_dtype))
404"""
405
406
407@register_lowering_pattern(
408    CallFunction(
409        aten.mm,
410        KeywordArg("mat1"),
411        CallFunction(
412            prims.convert_element_type.default,
413            KeywordArg("mat2"),
414            KeywordArg("mat2_dtype"),
415        ),
416    ),
417    extra_check=cuda_and_enabled_mixed_mm,
418)
419def mixed_mm(match: Match, mat1, mat2, mat2_dtype):
420    return inductor.kernel.mm.tuned_mixed_mm(mat1, mat2, mat2_dtype)
421
422
423@register_graph_pattern(
424    CallFunction(
425        aten.cumsum.default,
426        CallFunction(
427            torch.ops.aten.full.default,
428            KeywordArg("shape"),
429            KeywordArg("fill_value"),
430            dtype=KeywordArg("dtype"),
431            layout=Ignored(),
432            device=KeywordArg("device"),
433            pin_memory=False,
434            _users=MULTIPLE,
435        ),
436        KeywordArg("dim"),
437        _users=MULTIPLE,
438    ),
439    pass_dict=pass_patterns[1],
440)
441def pointless_cumsum_replacement(match: Match, shape, fill_value, device, dtype, dim):
442    """Based on a pattern in OPTForCausalLM"""
443
444    if is_integer_dtype(dtype) or is_boolean_dtype(dtype):
445        # cumsum promotes all integral types to int64
446        dtype = torch.int64
447
448    def repl(*shape):
449        dim_size = shape[dim]
450        idx = torch.arange(1, dim_size + 1, device=device, dtype=dtype)
451
452        inter_shape = [1] * len(shape)
453        inter_shape[dim] = dim_size
454        return (idx * fill_value).view(inter_shape).expand(shape)
455
456    # only replace the output node, not all nodes
457    match.nodes = [match.output_node()]
458    match.replace_by_example(repl, list(shape))
459
460
461def shape_of_mm(a, b):
462    m, _ = a.get_size()
463    _, n = b.get_size()
464    return [m, n]
465
466
467@register_lowering_pattern(
468    CallFunction(aten.cat, ListOf(CallFunction(aten.mm, Arg(), Arg())), Arg()),
469)
470def cat_mm(match, inputs, dim):
471    return cat_tuned_op(match, inputs, dim, op=L[aten.mm], shape_of=shape_of_mm)
472
473
474@register_lowering_pattern(
475    CallFunction(
476        aten.cat, ListOf(CallFunction(aten.addmm, Arg(), Arg(), Arg())), Arg()
477    ),
478)
479def cat_addmm(match, inputs, dim):
480    def shape_of(bias, a, b):
481        m, _ = a.get_size()
482        _, n = b.get_size()
483        return [m, n]
484
485    return cat_tuned_op(match, inputs, dim, op=L[aten.addmm], shape_of=shape_of)
486
487
488def cat_tuned_op(match, inputs, dim, *, op, shape_of):
489    """
490    Memory planning to remove cat. We can't use the stock memory
491    planner since autotuning the matmuls requires knowing the output layout.
492    """
493    if len(inputs) == 1:
494        return op(*inputs[0])
495
496    # TODO(jansel): rewrite this as a bmm?
497    if dim < 0:
498        dim += len(shape_of(*inputs[0]))
499    assert dim in (0, 1)
500    notdim = 1 - dim
501
502    new_size: Optional[Union[List[Expr], List[int]]] = None
503    offsets_start = []
504    offsets_end = []
505
506    # compute output sizes
507    for i in range(len(inputs)):
508        shape = shape_of(*inputs[i])
509        if new_size is None:
510            new_size = shape
511        else:
512            new_size[notdim] = V.graph.sizevars.guard_equals(  # type: ignore[call-overload]
513                shape[notdim], new_size[notdim]
514            )
515            new_size[dim] += shape[dim]
516        offsets_start.append(new_size[dim] - shape[dim])
517        offsets_end.append(new_size[dim])
518
519    assert new_size is not None
520    dtype = functools.reduce(
521        torch.promote_types,
522        [x.get_dtype() for x in itertools.chain.from_iterable(inputs)],
523    )
524    device = inputs[0][0].get_device()
525    kernel = ir.ConcatKernel(
526        name=None,
527        layout=ir.FixedLayout(device, dtype, new_size),
528        inputs=[],
529    )
530    kernel_tensor = ir.TensorBox.create(kernel)
531
532    for i in range(len(inputs)):
533        dst = ir.SliceView.create(kernel_tensor, dim, offsets_start[i], offsets_end[i])
534        src = op(*inputs[i], layout=dst.get_layout()).data.data
535        assert isinstance(src, (ir.ExternKernelOut, ir.TemplateBuffer))
536        src.layout = ir.NonOwningLayout(dst)
537        kernel.inputs.append(src)
538
539    kernel.name = V.graph.register_buffer(kernel)
540    kernel.inputs = ir.ConcatKernel.unwrap_storage(kernel.inputs)
541    V.graph.register_operation(kernel)
542    return kernel_tensor
543
544
545_cat_1 = CallFunction(aten.cat, Arg(), 1, _users=2)
546
547
548@register_lowering_pattern(
549    CallFunction(
550        aten.cat,
551        [
552            _cat_1,
553            CallFunction(
554                aten.slice,
555                _cat_1,
556                1,
557                0,
558                KeywordArg("size"),
559            ),
560        ],
561        1,
562    )
563)
564def cat_slice_cat(match, cat_input, size, dim=1):
565    """
566    This is an example of a more complex pattern where cat_1 is used
567    multiple times inside the pattern.  We fold 2 calls to cat into one.
568
569    Matches:
570        cat_1: f32[1024, 4077] = torch.ops.aten.cat.default([add_26, primals_217], 1)
571        slice_1: f32[1024, 4077] = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
572        slice_2: f32[1024, 19] = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 19)
573        cat_2: f32[1024, 4096] = torch.ops.aten.cat.default([cat_1, slice_2], 1)
574
575
576    Rewrite to:
577        slice_2 = torch.ops.aten.slice.Tensor(add_26, 1, 0, 19)
578        cat_2 = torch.ops.aten.cat.default([add_26, primals_217, slice_2], 1)
579    """
580    first, *rest = cat_input
581    # The optimization is optional, because we can simply not fold the cat.
582    # `size` should be within first.get_size()[dim] for the optimization to be valid.
583    # For a negative `end`, we currently fall back to not optimizing.
584    if size >= 0 and V.graph.sizevars.statically_known_leq(size, first.get_size()[dim]):
585        # fold 2 cats into 1 cat
586        return L[aten.cat](
587            [
588                first,
589                *rest,
590                L[aten.slice](first, dim, 0, size),
591            ],
592            dim,
593        )
594    else:
595        # don't expect to hit this case, just fall back
596        tmp = L[aten.cat](cat_input, dim)
597        return L[aten.cat](
598            [
599                tmp,
600                L[aten.slice](tmp, dim, 0, size),
601            ],
602            dim,
603        )
604
605
606def is_valid_splitwithsizes_cat(match):
607    split_nodes = filter_nodes(match.nodes, aten.split_with_sizes)
608    cat_nodes = filter_nodes(match.nodes, aten.cat)
609    get_item_nodes = filter_nodes(match.nodes, operator.getitem)
610    if len(split_nodes) != 1 or len(cat_nodes) != 1:
611        return False
612    split_node, cat_node = split_nodes[0], cat_nodes[0]
613    # The dim of split and cat should match for passthrough
614    if get_arg_value(split_node, 2, "dim") != get_arg_value(cat_node, 1, "dim"):
615        return False
616    get_item_args = {
617        get_arg_value(get_item_node, 1) for get_item_node in get_item_nodes
618    }
619    assert None not in get_item_args
620    split_sizes = get_arg_value(split_node, 1, "split_sizes")
621    # All parts of split should be included in the cat
622    if get_item_args != set(range(len(split_sizes))):
623        return False
624    # The order of the get_item args should match the order in which cat_node uses them.
625    # For example, if the split node is split_with_sizes(input, [2, 2, 3], 1),
626    # the cat node should be cat([get_item(0), get_item(1), get_item(2)], 1).
627    cat_items_args_order = [
628        get_arg_value(item_node, 1) for item_node in get_arg_value(cat_node, 0)
629    ]
630    if cat_items_args_order != list(range(len(split_sizes))):
631        return False
632
633    return True
634
635
636def same_meta(node1: torch.fx.Node, node2: torch.fx.Node):
637    """True if two nodes have the same metadata"""
638    val1 = node1.meta.get("val")
639    val2 = node2.meta.get("val")
640    return (
641        val1 is not None
642        and val2 is not None
643        and statically_known_true(sym_eq(val1.size(), val2.size()))
644        and val1.layout == val2.layout
645        and val1.dtype == val2.dtype
646        and val1.device == val2.device
647        and (
648            val1.layout != torch.strided
649            or statically_known_true(sym_eq(val1.stride(), val2.stride()))
650        )
651    )
652
653
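# Maps an op target to a (cond, nop_arg) pair: `cond` is a predicate over the
# op's (fake) args that decides whether the call is a no-op, and `nop_arg`
# selects which argument the op is a no-op of (an index, or a callable over args).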
654noop_registry: Dict[Any, Any] = {}
655
656
657def register_noop_decomp(targets, nop_arg=0):
658    def register_fun(cond):
659        register_decomposition(targets, registry=noop_registry, unsafe=True)(
660            (cond, nop_arg)  # type: ignore[arg-type]
661        )
662        return cond
663
664    return register_fun
665
666
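# 2**63 - 1 (INT64_MAX) is the sentinel aten uses for "slice to the end", as in
# aten.slice.Tensor(x, 0, 0, 9223372036854775807); a start-0, end-INT64_MAX,
# step-1 slice covers the whole tensor and is a no-op.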
667@register_noop_decomp(aten.slice)
668def slice_noop(self, dim=0, start=None, end=None, step=1):
669    if start is None or end is None:
670        return False
671    if (
672        statically_known_true(sym_eq(start, 0))
673        and statically_known_true(end >= 2**63 - 1)
674        and statically_known_true(sym_eq(step, 1))
675    ):
676        return True
677    return False
678
679
680@register_noop_decomp(aten.slice_scatter, 1)
681def slice_scatter_noop(self, src, dim=0, start=None, end=None, step=1):
682    if start is None:
683        start = 0
684    if end is None:
685        end = 2**63 - 1
686    if start == 0 and end >= 2**63 - 1 and step == 1:
687        return True
688    return False
689
690
691@register_noop_decomp(aten.repeat)
692def repeat_noop(self, repeats):
693    return all(r == 1 for r in repeats)
694
695
696@register_noop_decomp(aten.constant_pad_nd)
697def constant_pad_nd(x, padding, fill_value=0):
698    return all(p == 0 for p in padding)
699
700
701@register_noop_decomp(torch.ops.prims.convert_element_type)
702def convert_element_type_noop(x, dtype: torch.dtype):
703    return x.dtype == dtype
704
705
706@register_noop_decomp(torch.ops.prims.device_put)
707def device_put_noop(x, device):
708    return x.device == decode_device(device)
709
710
711@register_noop_decomp([aten.ceil, aten.floor, aten.round, aten.trunc])
712def int_noop(x):
713    return is_integer_dtype(x.dtype)
714
715
716@register_noop_decomp([aten.pow])
717def pow_noop(a, b):
718    return isinstance(b, int) and b == 1
719
720
721@register_noop_decomp([aten.cat], lambda args: args[0][0])
722def cat_noop(inputs, dim=0):
723    return len(inputs) == 1
724
725
726@register_noop_decomp(aten.view)
727def view_noop(arg, size):
728    return arg.shape == size
729
730
731# Note, we also always have a check for identical metadata, which is why these
732# are safe
733@register_noop_decomp([aten.copy], nop_arg=1)
734@register_noop_decomp([aten.alias, aten.clone])
735def true_noop(*args, **kwargs):
736    return True
737
738
739def remove_noop_ops(graph: torch.fx.Graph):
740    """
741    Removes operations from the graph that are essentially aten.clone or aten.alias.
742    """
743    inputs = set()
744    input_storages = set()
745    output_storages = set()
746
747    for node in graph.find_nodes(op="placeholder"):
748        inputs.add(node)
749        input_storages.add(get_node_storage(node))
750
751    output_node = next(iter(reversed(graph.nodes)))
752    assert output_node.op == "output"
753    outputs = output_node.args[0]
754    if not isinstance(outputs, (list, tuple)):
755        # nested subgraphs can have singleton outputs
756        outputs = (outputs,)
757    for out in outputs:
758        if isinstance(out, torch.fx.Node):
759            output_storages.add(get_node_storage(out))
760
761    for node in graph.nodes:
762        if node.target in noop_registry:
763            cond, src_index = noop_registry[node.target]
764            if isinstance(src_index, int):
765                src = node.args[src_index]
766            else:
767                src = src_index(node.args)
768            if not isinstance(src, torch.fx.Node):
769                continue
770            # Don't introduce new aliasing between inputs and outputs.
771            # See fx_passes/README.md for a discussion of why this is
772            # necessary.
773            node_storage = get_node_storage(node)
774            src_storage = get_node_storage(src)
775            node_is_view = node_storage == src_storage
776            if (
777                not node_is_view
778                and node_storage in output_storages
779                and (src_storage in input_storages or src_storage in output_storages)
780            ):
781                continue
782
783            # Even if inputs and outputs are expected to alias,
784            # don't make "node is src" true.
785            if (
786                node_is_view
787                and node in output_node.args
788                and (src in inputs or src in output_node.args)
789            ):
790                continue
791
792            is_valid, args, kwargs = get_fake_args_kwargs(node)
793            if not is_valid:
794                continue
795            if same_meta(node, src) and cond(*args, **kwargs):
796                node.replace_all_uses_with(src)
797                graph.erase_node(node)
798
799
800def decompose_auto_functionalized(graph):
801    """Decomposes auto_functionalized and triton_kernel_wrapper_functional
802    nodes into clones and the underlying mutation node.
803
804    We assume that the reinplacing pass runs before this; the reinplacing pass
805    tells us (via rewriting the arguments or .meta to those nodes) which
806    Tensors we should clone and which Tensors are safe to reinplace.
807    """
808    graph_pass = PatternMatcherPass()
809
810    @register_graph_pattern(
811        CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized),
812        pass_dict=graph_pass,
813    )
814    def _(match: Match, *args, **kwargs):
815        from torch._higher_order_ops.auto_functionalize import auto_functionalized_dense
816
817        only_clone_these_tensors = tuple(
818            match.nodes[0].meta.get("only_clone_these_tensors", [])
819        )
820
821        flat_args, spec = pytree.tree_flatten((args, kwargs))
822
823        # NB: we combine (args, kwargs) into flat args for replacing.
824        # This is because replace_by_example uses make_fx, which does not support
825        # tracing a function with kwargs.
826        def decomp(*flat_args):
827            args, kwargs = pytree.tree_unflatten(flat_args, spec)
828            return auto_functionalized_dense(*args, only_clone_these_tensors, **kwargs)
829
830        match.replace_by_example(decomp, flat_args, run_functional_passes=False)
831
832    @register_graph_pattern(
833        CallFunctionVarArgs(torch.ops.higher_order.triton_kernel_wrapper_functional),
834        pass_dict=graph_pass,
835    )
836    def _(match: Match, *args, **kwargs):
837        from torch._higher_order_ops.triton_kernel_wrap import (
838            triton_kernel_wrapper_functional_dense,
839        )
840
841        flat_args, spec = pytree.tree_flatten((args, kwargs))
842
843        # NB: we combine (args, kwargs) into flat args for replacing.
844        # This is because replace_by_example uses make_fx, which does not support
845        # tracing a function with kwargs.
846        def decomp(*flat_args):
847            args, kwargs = pytree.tree_unflatten(flat_args, spec)
848            return (triton_kernel_wrapper_functional_dense(*args, **kwargs),)
849
850        match.replace_by_example(decomp, flat_args, run_functional_passes=False)
851
852    @register_graph_pattern(
853        CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized_v2),
854        pass_dict=graph_pass,
855    )
856    def _(match: Match, *args, **kwargs):
857        from torch._higher_order_ops.auto_functionalize import (
858            auto_functionalized_v2_dense,
859        )
860
861        only_clone_these_bases = tuple(
862            match.nodes[0].meta.get("only_clone_these_tensors", [])
863        )
864
865        flat_args, spec = pytree.tree_flatten((args, kwargs))
866
867        # NB: we combine (args, kwargs) into flat args for replacing.
868        # This is because replace_by_example uses make_fx, which does not support
869        # tracing a function with kwargs.
870        def decomp(*flat_args):
871            args, kwargs = pytree.tree_unflatten(flat_args, spec)
872            return auto_functionalized_v2_dense(*args, only_clone_these_bases, **kwargs)
873
874        match.replace_by_example(decomp, flat_args, run_functional_passes=False)
875
876    graph_pass.apply(graph)
877
878    for node in graph.find_nodes(
879        op="call_function", target=torch.ops.higher_order.auto_functionalized
880    ):
881        raise AssertionError("auto_functionalized was not removed")
882
883    for node in graph.find_nodes(
884        op="call_function", target=torch.ops.higher_order.auto_functionalized_v2
885    ):
886        raise AssertionError("auto_functionalized_v2 was not removed")
887
888    for node in graph.find_nodes(
889        op="call_function",
890        target=torch.ops.higher_order.triton_kernel_wrapper_functional,
891    ):
892        raise AssertionError("triton_kernel_wrapper_functional was not removed")
893
894
895@register_lowering_pattern(
896    CallFunction(
897        aten.cat,
898        ListOf(
899            CallFunction(
900                operator.getitem,
901                CallFunction(
902                    aten.split_with_sizes,
903                    KeywordArg("input_"),
904                    Ignored(),
905                    Ignored(),
906                    _users=MULTIPLE,
907                ),
908                Ignored(),
909            ),
910        ),
911        Ignored(),
912    ),
913    pass_number=2,
914    extra_check=is_valid_splitwithsizes_cat,
915)
916def splitwithsizes_cat_replace(match, input_):
917    return input_
918
919
920def is_valid_cat_splitwithsizes(match):
921    cat_nodes = filter_nodes(match.nodes, aten.cat)
922    split_nodes = filter_nodes(match.nodes, aten.split_with_sizes)
923    if len(split_nodes) != 1 or len(cat_nodes) != 1:
924        return False
925    split_node, cat_node = split_nodes[0], cat_nodes[0]
926
927    # the cat node has other users: can't eliminate
928    if len(cat_node.users) > 1:
929        return False
930
931    # the dim of the cat and split should match
932    dim = get_arg_value(split_node, 2, "dim")
933    if dim != get_arg_value(cat_node, 1, "dim"):
934        return False
935
936    cat_inputs = list(get_arg_value(cat_node, 0))
937    split_sizes = get_arg_value(split_node, 1, "split_sizes")
938    # the number of input tensors in cat and the
939    # length of the split sizes should match
940    if len(cat_inputs) != len(split_sizes):
941        return False
942
943    for cat_input, split_size in zip(cat_inputs, split_sizes):
944        # each cat input tensor's size along dim
945        # should match the corresponding split size
946        if "val" not in cat_input.meta:
947            return False
948        cat_input_size = cat_input.meta["val"].size(dim)
949        if cat_input_size != split_size:
950            return False
951
952    return True
953
954
955@register_lowering_pattern(
956    CallFunction(
957        aten.split_with_sizes,
958        CallFunction(
959            aten.cat,
960            KeywordArg("input_"),
961            Ignored(),
962            _users=MULTIPLE,
963        ),
964        Ignored(),
965        Ignored(),
966    ),
967    pass_number=2,
968    extra_check=is_valid_cat_splitwithsizes,
969)
970def cat_splitwithsizes_replace(match, input_):
971    return input_
972
973
974def view_to_reshape(gm):
975    """
976    Replace view ops in the GraphModule with reshape ops.
977    """
978    for nd in gm.graph.find_nodes(
979        op="call_function", target=torch.ops.aten.view.default
980    ):
981        nd.target = torch.ops.aten.reshape.default
982
983
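# Heuristic: on CUDA, when every use of an addmm's output is pointwise, it can be
# preferable to keep the bias add unfused (mm followed by a separate add), since
# the add can then fuse with the downstream pointwise ops. This predicate gates
# both the unfuse rewrite below and the addmm fusion further down.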
984def should_prefer_unfused_addmm(match):
985    inp = match.kwargs["inp"]
986    if not inp.meta["val"].is_cuda:
987        return False
988
989    output = match.output_node()
990    return all(is_pointwise_use(use) for use in output.users)
991
992
993@register_graph_pattern(
994    CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()),
995    pass_dict=pass_patterns[2],
996    extra_check=should_prefer_unfused_addmm,
997)
998def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp):
999    def repl(inp, x1, x2):
1000        return x1 @ x2 + inp
1001
1002    match.replace_by_example(repl, [inp, mat1, mat2])
1003
1004
1005def is_valid_addmm_fusion(match):
1006    mat1, mat2 = match.args
1007    inp = match.kwargs["inp"]
1008
1009    if not (
1010        isinstance(inp, torch.fx.Node) and isinstance(inp.meta["val"], torch.Tensor)
1011    ):
1012        return False  # Input is a number
1013
1014    in_shape = inp.meta["val"].shape
1015    mm_shape = mat1.meta["val"].shape[0], mat2.meta["val"].shape[1]
1016    matched = is_expandable_to(in_shape, mm_shape)
1017    if not matched:
1018        return False  # Shape mismatch
1019
1020    return not should_prefer_unfused_addmm(match)
1021
1022
1023@register_graph_pattern(
1024    CallFunction(
1025        aten.add,
1026        CallFunction(aten.mm, Arg(), Arg()),
1027        KeywordArg("inp"),
1028    ),
1029    pass_dict=pass_patterns[2],
1030    extra_check=is_valid_addmm_fusion,
1031)
1032@register_graph_pattern(
1033    CallFunction(
1034        aten.add,
1035        KeywordArg("inp"),
1036        CallFunction(aten.mm, Arg(), Arg()),
1037    ),
1038    pass_dict=pass_patterns[2],
1039    extra_check=is_valid_addmm_fusion,
1040)
1041def addmm(match, mat1, mat2, *, inp):
1042    def repl(inp, mat1, mat2):
1043        return aten.addmm(inp, mat1, mat2)
1044
1045    match.replace_by_example(repl, [inp, mat1, mat2])
1046
1047
1048def check_shape_cuda_and_fused_int_mm_mul_enabled(match):
1049    return (
1050        config.force_fuse_int_mm_with_mul
1051        and len(getattr(match.args[2].meta.get("val"), "shape", [])) == 2
1052        and getattr(match.args[2].meta.get("val"), "is_cuda", False)
1053    )
1054
1055
1056@register_lowering_pattern(
1057    CallFunction(
1058        prims.convert_element_type.default,
1059        CallFunction(
1060            aten.mul,
1061            CallFunction(
1062                aten._int_mm,
1063                Arg(),
1064                Arg(),
1065            ),
1066            Arg(),
1067        ),
1068        Arg(),
1069    ),
1070    check_shape_cuda_and_fused_int_mm_mul_enabled,
1071)
1072@register_lowering_pattern(
1073    CallFunction(
1074        aten.mul,
1075        CallFunction(
1076            aten._int_mm,
1077            Arg(),
1078            Arg(),
1079        ),
1080        Arg(),
1081    ),
1082    check_shape_cuda_and_fused_int_mm_mul_enabled,
1083)
1084def fused_int_mm_mul(match: Match, mat1, mat2, mat3, out_dtype=None):
1085    return inductor.kernel.mm.tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype)
1086
1087
1088def is_index_put_and_requires_h2d_sync_for_gpu_value(node):
1089    from torch.fx.operator_schemas import normalize_function
1090
1091    if node.target not in [
1092        torch.ops.aten.index_put.default,
1093        torch.ops.aten.index_put_.default,
1094    ]:
1095        return False
1096    # Inductor falls back to aten.index_put_.
1097    # index_put_ will call nonzero() and perform an H2D sync if
1098    # any of its indices are bool/byte tensors.
1099    # However, it will short-circuit this H2D sync and run masked_fill_
1100    # if the value we are putting is a cpu scalar.
1101    # Therefore, when inductor sees an index_put_ with byte tensor indices,
1102    # it should *not* convert the cpu scalar value into a gpu tensor.
1103    args_, kwargs_ = normalize_function(node.target, node.args, node.kwargs)  # type: ignore[misc]
1104    any_byte_bool_indices = False
1105    indices = args_[1]
1106    for i in indices:
1107        if i is not None and i.meta["val"].dtype in [torch.bool, torch.int8]:
1108            any_byte_bool_indices = True
1109
1110    val = args_[2].meta["val"]
1111    val_is_cpu_scalar = val.device.type == "cpu" and val.numel() == 1
1112    # If both of these conditions hold, then converting the val
1113    # to a gpu tensor will incur an H2D sync when inductor calls aten.index_put_.
1114    return any_byte_bool_indices and val_is_cpu_scalar
1115
1116
1117class ConstructorMoverPass:
1118    def __init__(self, target: str, allow_outputs: bool = False) -> None:
1119        """
1120        Move constructors from cpu to the target_device.
1121
1122        Sweeps through the module, looking for constructor nodes that can be moved
1123        to the target_device.
1124
1125        A constructor node can be moved to the target_device iff all of its users
1126        can also be moved (tested by cannot_be_moved). Otherwise, none of the dependent
1127        constructor nodes will be moved.
1128
1129        - target: target device type
1130        - allow_outputs: allow outputs to be moved
1131        """
1132
1133        self.target = target
1134        self.allow_outputs = allow_outputs
1135
1136        assert isinstance(target, str), (
1137            "target should be a string representing the device type. "
1138            f"Got: {type(target).__name__}"
1139        )
1140
1141    def allow_cpu_device(self, node: fx.Node) -> bool:
1142        """
1143        Returns whether a node that returns a tensor on the target device may have
1144        cpu tensors as input.
1145        """
1146        return node.target in (
1147            torch.ops.aten.index.Tensor,
1148            torch.ops.aten.index_put.default,
1149            torch.ops.aten.index_put_.default,
1150            torch.ops.aten.copy.default,
1151            torch.ops.aten.copy_.default,
1152            torch.ops.aten.slice_scatter.default,
1153        )
1154
1155    def cannot_be_moved(self, node: fx.Node) -> bool:
1156        """
1157        Returns whether a node cannot be moved to the target device.
1158
1159        If this function returns True, the cpu constructors this node depends on
1160        won't be moved to the target device.
1161        """
1162        if node.target == "output":
1163            return not self.allow_outputs
1164
1165        if not (
1166            isinstance(node.target, torch._ops.OpOverload)
1167            and node.target.namespace in ("prims", "aten")
1168        ):
1169            return True
1170        if is_index_put_and_requires_h2d_sync_for_gpu_value(node):
1171            return True
1172
1173        return False
1174
1175    def get_node_device(self, node: fx.Node) -> Optional[torch.device]:
1176        """
1177        Get the device of a node.
1178        """
1179        ten = node.meta.get("val")
1180        return None if not isinstance(ten, torch.Tensor) else ten.device
1181
1182    def get_cpu_indeg_count(self, graph: fx.Graph) -> Dict[fx.Node, int]:
1183        """
1184        Get the number of cpu inputs to a node
1185        """
1186        cpu_indeg: Dict[fx.Node, int] = Counter()
1187
1188        for node in graph.nodes:
1189            cpu_count = 0
1190
1191            def add_cpu_inp(node):
1192                nonlocal cpu_count
1193                device = self.get_node_device(node)
1194                cpu_count += device is not None and device.type == "cpu"
1195
1196            pytree.tree_map_only(fx.Node, add_cpu_inp, (node.args, node.kwargs))
1197
1198            if cpu_count:
1199                cpu_indeg[node] = cpu_count
1200
1201        return cpu_indeg
1202
1203    def __call__(self, graph: fx.Graph) -> None:
1204        target_devices = set()
1205        constructors = []
1206
1207        for node in graph.nodes:
1208            device = self.get_node_device(node)
1209            if device and device.type == self.target:
1210                target_devices.add(device)
1211
1212            if not (
1213                isinstance(node.target, torch._ops.OpOverload)
1214                and node.target.namespace in ("prims", "aten")
1215            ):
1216                continue
1217
1218            if not torch._subclasses.fake_tensor._is_tensor_constructor(node.target):
1219                continue
1220
1221            if node.kwargs.get("device") != torch.device("cpu"):
1222                continue
1223
1224            constructors.append(node)
1225
1226        # not handling multiple target devices initially
1227        if not constructors or len(target_devices) != 1:
1228            return
1229
1230        movable_constructors = self.find_movable_constructors(graph, constructors)
1231
1232        for node in movable_constructors:
1233            kwargs = node.kwargs.copy()
1234            kwargs["device"] = next(iter(target_devices))
1235            node.kwargs = kwargs
1236
1237    def find_movable_constructors(
1238        self, graph: fx.Graph, constructors: List[fx.Node]
1239    ) -> Set[fx.Node]:
1240        """
1241        Starting from the cpu constructors, iterate through the graph and test that all of their
1242        downstream uses can safely be converted to the target device.
1243        """
1244        cpu_indeg: Dict[fx.Node, int] = self.get_cpu_indeg_count(graph)
1245
1246        # which constructors cannot be moved to gpu
1247        cannot_move_to_gpu: Set[fx.Node] = set()
1248
1249        # For any node in the graph, which constructors does it have a dependency on
1250        constructor_dependencies: Dict[fx.Node, Set[fx.Node]] = defaultdict(set)
1251
1252        # if a cpu node has a dependency on two different cpu constructors,
1253        # then if either constructor cannot be moved to gpu, the other cannot as well.
1254        # In this case any node with a dependency on one will have a dependency on the other
1255        equal_constructor_sets: Dict[fx.Node, Set[fx.Node]] = {
1256            c: {c} for c in constructors
1257        }
1258
1259        def make_dependencies_equivalent(
1260            set1: Set[fx.Node], set2: Set[fx.Node]
1261        ) -> Set[fx.Node]:
1262            # could use union find but not worth complexity here
1263            set1.update(set2)
1264            for obj in set1:
1265                equal_constructor_sets[obj] = set1
1266            return set1
1267
1268        queue: List[fx.Node] = list(constructors)
1269
1270        for c in queue:
1271            constructor_dependencies[c].add(c)
1272
1273        while queue:
1274            node = queue.pop()
1275            dependencies = constructor_dependencies[node]
1276
1277            for user in node.users:
1278                if self.cannot_be_moved(user):
1279                    cannot_move_to_gpu.update(dependencies)
1280                    break
1281
1282                # This node is used by an op that takes inputs on multiple devices and
1283                # outputs a gpu tensor. We can convert its cpu input without further changes.
1284                node_device = self.get_node_device(user)
1285                if (
1286                    self.allow_cpu_device(user)
1287                    and node_device
1288                    and node_device.type == self.target
1289                ):
1290                    del cpu_indeg[user]
1291                else:
1292                    # otherwise, we should continue looking at its downstream uses
1293                    cpu_indeg[user] -= 1
1294                    if cpu_indeg[user] == 0:
1295                        del cpu_indeg[user]
1296                        queue.append(user)
1297
1298                unioned_set = make_dependencies_equivalent(
1299                    dependencies, constructor_dependencies[user]
1300                )
1301                constructor_dependencies[user] = unioned_set
1302
1303        for node in cpu_indeg:
1304            if constructor_dependencies[node]:
1305                cannot_move_to_gpu.update(constructor_dependencies[node])
1306
1307        all_cannot_move_to_gpu = cannot_move_to_gpu.copy()
1308        for constructor in cannot_move_to_gpu:
1309            all_cannot_move_to_gpu.update(equal_constructor_sets[constructor])
1310
1311        return set(constructors) - all_cannot_move_to_gpu
1312
1313
1314def move_constructors_to_gpu(graph: fx.Graph) -> None:
1315    """
1316    Moves intermediate tensors that are constructed on the cpu to the gpu when safe.
1317    """
1318    ConstructorMoverPass(get_gpu_type())(graph)
1319