xref: /aosp_15_r20/external/pytorch/torch/_inductor/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import collections
5import contextlib
6import dataclasses
7import enum
8import functools
9import inspect
10import io
11import itertools
12import logging
13import math
14import operator
15import os
16import platform
17import shutil
18import sys
19import tempfile
20import textwrap
21import time
22import unittest
23from datetime import datetime
24from io import StringIO
25from typing import (
26    Any,
27    Callable,
28    Dict,
29    Generic,
30    Iterable,
31    List,
32    NamedTuple,
33    Optional,
34    Protocol,
35    Sequence,
36    Set,
37    TypeVar,
38    Union,
39    ValuesView,
40)
41from typing_extensions import Concatenate, ParamSpec
42from unittest import mock
43
44import sympy
45
46import torch
47
48
49GPU_TYPES = ["cuda", "xpu"]
50
51
52# defined here before importing torch._dynamo to avoid a circular import
53# when get_gpu_type is imported from dynamo
54@functools.lru_cache(None)
55def get_gpu_type():
56    avail_gpus = [x for x in GPU_TYPES if getattr(torch, x).is_available()]
57    assert len(avail_gpus) <= 1
58    gpu_type = "cuda" if len(avail_gpus) == 0 else avail_gpus.pop()
59    return gpu_type
60
61
62from torch._dynamo.device_interface import get_interface_for_device
63from torch._dynamo.utils import detect_fake_mode
64from torch.autograd import DeviceType
65from torch.autograd.profiler_util import EventList
66from torch.fx.passes.graph_transform_observer import GraphTransformObserver
67from torch.fx.passes.shape_prop import ShapeProp
68from torch.utils._sympy.functions import (
69    CeilDiv,
70    CleanDiv,
71    FloorDiv,
72    Identity,
73    ModularIndexing,
74)
75from torch.utils._sympy.symbol import make_symbol, SymT
76from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
77
78from . import config
79from .runtime.runtime_utils import ceildiv as runtime_ceildiv
80
81
82_IS_WINDOWS = sys.platform == "win32"
83
84log = logging.getLogger(__name__)
85
86_T = TypeVar("_T")
87VarRanges = Dict[sympy.Expr, sympy.Expr]
88InputType = Union[torch.Tensor, int]
89
90
91GPU_ALIGN_BYTES = 16
92ALIGNMENT = 16
93
94ALIGN_BYTES = 64
95assert (ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0 and ALIGN_BYTES >= 8, "must be a power of 2 and >= 8"
96
97
98def _align(nbytes):
99    """Round up to the nearest multiple of ALIGN_BYTES"""
100    return (nbytes + ALIGN_BYTES - 1) & -ALIGN_BYTES
101
102
103def _is_aligned(v: sympy.Expr):
104    """v can be statically proven to be a multiple of ALIGN_BYTES"""
105    if isinstance(v, (sympy.Add, sympy.Max)):
106        return all(map(_is_aligned, v.args))
107    return isinstance(v, align) or sympy.gcd(v, ALIGN_BYTES) == ALIGN_BYTES
108
109
110class align(sympy.Function):
111    """Symbolically round up to the nearest multiple of ALIGN_BYTES"""
112
113    nargs = (1,)
114    is_integer = True
115
116    @classmethod
117    def eval(cls, value):
118        if isinstance(value, (int, sympy.Integer)):
119            return _align(int(value))
120        if _is_aligned(value):
121            return value
122
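# A worked example of the alignment arithmetic above (illustrative; assumes the
# ALIGN_BYTES = 64 default defined in this file):
#
#   _align(1)  == 64    # (1 + 63) & -64
#   _align(64) == 64    # already a multiple of 64, unchanged
#   _align(65) == 128   # rounds up to the next multiple of 64
#
# For symbolic sizes, align(s) stays unevaluated unless _is_aligned(s) can prove
# alignment; e.g. align(64 * s) simplifies to 64 * s because
# sympy.gcd(64 * s, 64) == 64.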
123
124def do_bench_using_profiling(fn: Callable[[], Any], warmup=25, rep=100) -> float:
125    """
126    Returns benchmark results by examining torch profiler events.
127    This can be more accurate as it doesn't count CPU-side overhead.
128    However, it also requires manually excluding irrelevant events, e.g.
129    vectorized_elementwise_kernel (which is used to fill the L2 cache),
130    various CUDA events, etc., so it can also be fragile.
131    """
132
133    fn()
134    torch.cuda.synchronize()
135    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")
136
137    # Estimate the runtime of the function
138    start_event = torch.cuda.Event(enable_timing=True)
139    end_event = torch.cuda.Event(enable_timing=True)
140    start_event.record()
141    for _ in range(5):
142        cache.zero_()
143        fn()
144    end_event.record()
145    torch.cuda.synchronize()
146    estimate_ms = start_event.elapsed_time(end_event) / 5
147
148    # compute number of warmup and repeat
149    n_warmup = max(1, int(warmup / estimate_ms))
150    n_repeat = max(1, int(rep / estimate_ms))
151
152    # Warm-up
153    for _ in range(n_warmup):
154        fn()
155
156    with torch.profiler.profile(
157        activities=[
158            torch.profiler.ProfilerActivity.CUDA,
159        ]
160    ) as p:
161        # Benchmark
162        for i in range(n_repeat):
163            # we clear the L2 cache before each run
164            cache.zero_()
165            # record time of `fn`
166            fn()
167        # Record clocks
168        torch.cuda.synchronize()
169
170    log.debug("raw events")
171    log.debug(p.key_averages().table(sort_by="self_device_time_total", row_limit=-1))
172
173    filtered_events = EventList(
174        [
175            event
176            for event in p.events()
177            if event.device_type == DeviceType.CUDA and event.name != "Context Sync"
178        ]
179    )
180    if len(filtered_events) % n_repeat != 0:
181        raise RuntimeError(
182            "Failed to divide all profiling events into #repeat groups. "
183            f"#CUDA events: {len(filtered_events)}, #repeats: {n_repeat}"
184        )
187    num_event_per_group = len(filtered_events) // n_repeat
188    actual_events = EventList(
189        [
190            event
191            for i, event in enumerate(filtered_events)
192            if i % num_event_per_group != 0
193        ]
194    )
195    actual_events._build_tree()
196    actual_events = actual_events.key_averages()
197
198    log.debug("profiling time breakdown")
199    log.debug(actual_events.table(row_limit=-1))
200
201    res = sum(event.device_time_total for event in actual_events) / 1000.0 / n_repeat
202    log.debug("profiling results: %s ms", res)
203    return res
204
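# Minimal usage sketch (illustrative, assumes a CUDA-capable machine; `a` and `b`
# are example tensors): benchmark a matmul, counting only device-side kernel time.
#
#   a = torch.randn(1024, 1024, device="cuda")
#   b = torch.randn(1024, 1024, device="cuda")
#   ms = do_bench_using_profiling(lambda: torch.mm(a, b))
#   print(f"torch.mm: {ms:.3f} ms")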
205
206@functools.lru_cache(None)
207def has_torchvision_roi_align() -> bool:
208    try:
209        from torchvision.ops import roi_align  # noqa: F401
210
211        torch._C._dispatch_has_kernel_for_dispatch_key("torchvision::nms", "Meta")
212        return roi_align is not None and hasattr(
213            getattr(torch.ops, "torchvision", None), "roi_align"
214        )
215    except ImportError:
216        return False
217    except RuntimeError as e:
218        assert "torchvision::nms does not exist" in str(e)
219        return False
220
221
222def decode_device(device: Union[Optional[torch.device], str]) -> torch.device:
223    if device is None:
224        return torch.tensor(0.0).device  # default device
225    if isinstance(device, str):
226        device = torch.device(device)
227    if device.type not in ("cpu", "meta") and device.index is None:
228        device_interface = get_interface_for_device(device.type)
229        return torch.device(device.type, index=device_interface.Worker.current_device())
230    return device
231
232
233def sympy_product(it):
234    return functools.reduce(operator.mul, it, sympy.Integer(1))
235
236
237def sympy_dot(seq1, seq2):
238    assert len(seq1) == len(seq2)
239    return sympy.expand(sum(a * b for a, b in zip(seq1, seq2)))
240
241
242def unique(it: Iterable[_T]) -> ValuesView[_T]:
243    return {id(x): x for x in it}.values()
244
245
246def ceildiv(
247    numer: Union[int, sympy.Expr], denom: Union[int, sympy.Expr]
248) -> Union[int, sympy.Expr]:
249    if isinstance(numer, sympy.Expr) or isinstance(denom, sympy.Expr):
250        return CeilDiv(sympy.sympify(numer), sympy.sympify(denom))
251    # TODO: There is a bug in a call to this function, to repro:
252    # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy
253    # --amp --only YituTechConvBert --dynamic-shapes
254    assert isinstance(numer, int) and isinstance(
255        denom, int
256    ), f"{numer}: {type(numer)}, {denom}: {type(denom)}"
257    return runtime_ceildiv(numer, denom)
258
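# For example (illustrative): ceildiv(10, 4) == 3 and ceildiv(8, 4) == 2, while
# ceildiv(n, 128) with a sympy symbol n returns a symbolic ceiling-division
# expression rather than a Python int.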
259
260def _type_of(key):
261    # Use this function to avoid a dependency on Triton during codegen.
262    # Refer to Triton implementation here:
263    # https://github.com/openai/triton/blob/98b5945d2aef679e00ebca8e07c35c3658ec76de/python/triton/runtime/jit.py#L238
264    # `None` is nullptr.  Implicitly convert to *i8.
265    if key is None:
266        return "*i8"
267    dtype_str = str(key).split(".")[-1]
268    tys = {
269        "bool": "i1",
270        "float8e4nv": "fp8e4nv",
271        "float8e5": "fp8e5",
272        "float8e4b15": "fp8e4b15",
273        "float8e4b15x4": "fp8e4b15x4",
274        "float8_e4m3fn": "fp8e4nv",
275        "float8_e5m2": "fp8e5",
276        "float16": "fp16",
277        "bfloat16": "bf16",
278        "float32": "fp32",
279        "float64": "fp64",
280        "int8": "i8",
281        "int16": "i16",
282        "int32": "i32",
283        "int64": "i64",
284        "uint8": "u8",
285        "uint16": "u16",
286        "uint32": "u32",
287        "uint64": "u64",
288    }
289    # reinterpret can create triton type
290    for v in list(tys.values()):
291        tys[v] = v
292    return key if isinstance(key, str) else f"*{tys[dtype_str]}"
293
294
295def convert_shape_to_inductor(
296    lst: Iterable[Union[int, torch.SymInt]]
297) -> List[sympy.Expr]:
298    """
299    Converts a tensor's shape or stride (a list of ints and/or torch.SymInts)
300    into sympy.Exprs. For non-symbolic values this is trivial, but symbolic
301    values need to be mapped from SymIntNode into sympy.Expr.
302    """
303    return [sympy.sympify(i) for i in lst]
304
305
306def convert_shape_to_symint(
307    lst: Iterable[Union[int, sympy.Expr]]
308) -> List[Union[int, torch.SymInt]]:
309    """
310    Takes a list of shapes from Inductor and converts them into symints (or just
311    ints if all shapes are static).
312    """
313    from .virtualized import V
314
315    return [
316        i
317        if isinstance(i, int)
318        else int(i)
319        if isinstance(i, sympy.Integer)
320        else V.graph.sizevars.shape_env.create_symintnode(i, hint=None)
321        for i in lst
322    ]
323
324
325def is_view(op: torch._ops.OpOverload):
326    """
327    Does this op overload have aliasing
328    """
329    assert isinstance(op, torch._ops.OpOverload)
330    return any(a.alias_info is not None for a in op._schema.arguments)
331
332
333def is_pointwise_use(
334    use, is_pointwise_fn: Optional[Callable[[torch._ops.OpOverload], bool]] = None
335):
336    """
337    Do all uses of this op have torch.Tag.pointwise, or return True from the optional `is_pointwise_fn`?
338
339    Uses in view ops are followed through to the view's own uses.
340    """
341
342    if not use.op == "call_function":
343        return False
344
345    if not (
346        isinstance(use.target, torch._ops.OpOverload) or use.target is operator.getitem
347    ):
348        return False
349
350    if use.target is operator.getitem or is_view(use.target):
351        return all(is_pointwise_use(u, is_pointwise_fn) for u in use.users)
352
353    return torch.Tag.pointwise in use.target.tags or (
354        is_pointwise_fn is not None and is_pointwise_fn(use.target)
355    )
356
357
358def gen_gm_and_inputs(target, args, kwargs):
359    g = torch.fx.Graph()
360    g_args = []
361    a_args = []
362    for n, arg in enumerate(args):
363        if isinstance(arg, torch.Tensor):
364            g_args.append(g.placeholder(f"arg{n}"))
365            a_args.append(arg)
366        else:
367            g_args.append(arg)
368    assert all(not isinstance(x, torch.Tensor) for x in kwargs.values())
369    node = g.call_function(target, tuple(g_args), kwargs)
370    if (
371        len(target._schema.returns) == 1
372        and str(target._schema.returns[0].type) == "Tensor"
373    ):
374        node = (node,)  # type: ignore[assignment]
375    g.output(node)
376
377    gm = torch.fx.GraphModule({}, g)
378    return gm, a_args
379
380
381def synchronize(device: str = "cuda"):
382    if device == "cpu":
383        return
384    device_interface = get_interface_for_device(device)
385    if device_interface.is_available():
386        device_interface.synchronize()
387
388
389def timed(
390    model: Callable[..., Any], example_inputs, times: int = 1, device: str = "cuda"
391) -> float:
392    synchronize(device)
393    torch.manual_seed(1337)
394    t0 = time.perf_counter()
395    for _ in range(times):
396        result = model(*example_inputs)
397        synchronize(device)
398    t1 = time.perf_counter()
399    # GC the result after timing
400    assert result is not None  # type: ignore[possibly-undefined]
401    return t1 - t0
402
403
404def print_performance(
405    fn, args=(), times=10, repeat=10, baseline=1.0, device: str = "cuda"
406):
407    timings = torch.tensor([timed(fn, args, times, device) for _ in range(repeat)])
408    took = torch.median(timings) / times
409    print(f"{took / baseline:.6f}")
410    return took
411
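# Illustrative sketch: time a compiled model on CUDA. `model` and `x` are assumed
# example objects; `times` controls inner iterations per sample and `repeat` the
# number of timed samples. The median per-iteration time (relative to `baseline`)
# is printed and returned.
#
#   compiled = torch.compile(model)
#   seconds_per_iter = print_performance(compiled, (x,), times=10, repeat=10)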
412
413def precompute_method(obj: Any, method: str):
414    """Replace obj.method() with a new method that returns a precomputed constant."""
415    result = getattr(obj, method)()
416    setattr(obj, method, lambda: result)
417
418
419def precompute_methods(obj: Any, methods: List[str]):
420    """Replace methods with new methods that returns a precomputed constants."""
421    for method in methods:
422        precompute_method(obj, method)
423
424
425def cmp(a, b) -> int:
426    return int(a > b) - int(a < b)
427
428
429def pad_listlike(x, size):
430    if len(x) == 1:
431        return type(x)([x[0]]) * size
432    else:
433        return x
434
435
436# Used to ensure that iterating over a set is deterministic
437def tuple_sorted(x):
438    if len(x) == 0:
439        return []
440
441    def sort_func(elem):
442        if isinstance(elem, str):
443            return elem
444        else:
445            # We expect `elem` to be `scheduler.BaseSchedulerNode` type here,
446            # but we are not able to do isinstance assert because of circular dependency
447            return elem.get_name()
448
449    return sorted(x, key=sort_func)
450
451
452P = ParamSpec("P")
453RV = TypeVar("RV", covariant=True)
454
455
456class CachedMethod(Protocol, Generic[P, RV]):
457    @staticmethod
458    def clear_cache(self) -> None:
459        ...
460
461    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV:
462        ...
463
464
465# See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature
466def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]:
467    key = f"__{fn.__name__}_cache"
468
469    @functools.wraps(fn)
470    def wrapper(self):
471        if not hasattr(self, key):
472            setattr(self, key, fn(self))
473        return getattr(self, key)
474
475    def clear_cache(self):
476        if hasattr(self, key):
477            delattr(self, key)
478
479    wrapper.clear_cache = clear_cache  # type: ignore[attr-defined]
480    return wrapper  # type: ignore[return-value]
481
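# Usage sketch (illustrative): the decorated method is computed once per instance
# and memoized on a private attribute; clear_cache(obj) drops the memoized value.
# `compute()` below is a hypothetical helper.
#
#   class Node:
#       @cache_on_self
#       def expensive(self):
#           return compute()
#
#   n = Node()
#   n.expensive()                  # computes and stores n.__expensive_cache
#   n.expensive()                  # returns the cached value
#   Node.expensive.clear_cache(n)  # invalidates the cache for this instance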
482
483def aggregate_origins(node_schedule):
484    from . import ir
485
486    if isinstance(node_schedule, list):
487        return functools.reduce(
488            operator.or_,
489            [
490                node.node.origins
491                for node in node_schedule
492                if hasattr(node, "node") and node.node
493            ],
494            set(),
495        )
496    elif isinstance(node_schedule, ir.ExternKernel):
497        return node_schedule.origins
498    else:
499        return set()
500
501
502def get_fused_kernel_name(node_schedule, descriptive_names):
503    all_origins = aggregate_origins(node_schedule)
504    if descriptive_names == "original_aten":
505        # Bases the kernel name off of the top-level aten operator (i.e. pre-decompositions)
506        sources = [
507            origin.meta["original_aten"]._overloadpacket.__name__
508            for origin in all_origins
509            if origin.op == "call_function"
510            and "original_aten" in origin.meta
511            and origin.meta["original_aten"] is not None
512        ]
513        sources = sorted(set(sources))
514    elif descriptive_names == "torch":
515        # Bases the kernel name off of the top-level "torch" operator (i.e. post-dynamo graph)
516        sources = []
517        for origin in all_origins:
518            if origin.op == "call_function" and "source_fn_stack" in origin.meta:
519                source_fn = origin.meta["source_fn_stack"][-1]
520                if isinstance(source_fn[1], str):
521                    sources.append(source_fn[1])
522                else:
523                    sources.append(source_fn[1].__name__)
524        sources = sorted(set(sources))
525    elif descriptive_names == "inductor_node":
526        sources = [
527            origin.name for origin in all_origins if origin.op == "call_function"
528        ]
529    else:
530        raise NotImplementedError
532    return "_".join(["fused"] + sources)
533
534
535def get_kernel_metadata(node_schedule, wrapper):
536    all_origins = aggregate_origins(node_schedule)
537    inductor_nodes = [origin for origin in all_origins if origin.op == "call_function"]
538
539    from_node_dict = collections.defaultdict(list)
540    original_aten_dict = collections.defaultdict(list)
541
542    # Attempt to sort `inductor_nodes` topologically. Note that the case
543    # where `inductor_nodes` contains nodes from multiple graph instances
544    # is not supported. An example of this is conditional statements.
545    single_graph = None
546    if len(inductor_nodes):
547        unique_graphs = {n.graph for n in inductor_nodes}
548        if len(unique_graphs) == 1:
549            single_graph = inductor_nodes[0].graph
550            # create a map of idx -> node and cache it
551            if not hasattr(single_graph, "_inductor_kernel_metadata_node_to_idx_map"):
552                node_to_idx_map = {}
553                for idx, n in enumerate(single_graph.nodes):
554                    node_to_idx_map[n] = idx
555                single_graph._inductor_kernel_metadata_node_to_idx_map = node_to_idx_map
556            inductor_nodes.sort(
557                key=lambda n: single_graph._inductor_kernel_metadata_node_to_idx_map[n]
558            )
559
560    for node in inductor_nodes:
561        if "original_aten" in node.meta and node.meta["original_aten"] is not None:
562            key = str(node.meta["original_aten"]._overloadpacket)
563            original_aten_dict[key].append(node.name)
564        if "from_node" in node.meta:
565            key = node.meta["from_node"][0][0]
566            from_node_dict[key].append(node.name)
567    sort_str = "Topologically Sorted" if single_graph is not None else "Unsorted"
568    metadata = (
569        f"{wrapper.comment} {sort_str} Source Nodes: [{', '.join(from_node_dict.keys())}], "
570        f"Original ATen: [{', '.join(original_aten_dict.keys())}]"
571    )
572
573    # trace back to original node here
574    detailed_metadata = [f"{wrapper.comment} Source node to ATen node mapping:"]
575    for original_node, nodes in sorted(from_node_dict.items()):
576        detailed_metadata.append(
577            f"{wrapper.comment}   {original_node} => {', '.join(sorted(nodes))}"
578        )
579
580    # print the aot_autograd graph fragment
581    if single_graph is not None:
582        detailed_metadata.append(f"{wrapper.comment} Graph fragment:")
583        for n in inductor_nodes:
584            # TODO(future): maybe refactor torch/fx/graph.py to make it easy to
585            # generate python code for graph fragments
586            detailed_metadata.append(f"{wrapper.comment}   {n.format_node()}")
587
588    return metadata, "\n".join(detailed_metadata)
589
590
591def dominated_nodes(
592    initial_queue: Iterable[torch.fx.Node], skip_filter=None
593) -> Set[torch.fx.Node]:
594    """Returns the set of nodes whose values depend on those within initial_queue"""
595    initial_queue = list(initial_queue)
596    dominated_set = set(initial_queue)
597
598    while initial_queue:
599        node = initial_queue.pop()
600        for user in node.users:
601            if skip_filter and skip_filter(user):
602                continue
603            if user not in dominated_set:
604                dominated_set.add(user)
605                initial_queue.append(user)
606
607    return dominated_set
608
609
610def gather_origins(args, kwargs):
611    import itertools
612
613    from . import ir
614
615    def is_unrealized_node(n):
616        if isinstance(n, ir.TensorBox):
617            return is_unrealized_node(n.data)
618        if isinstance(n, ir.StorageBox):
619            return is_unrealized_node(n.data)
620        return isinstance(n, ir.IRNode) and isinstance(n, ir.Pointwise)
621
622    kwarg_origins = [val.origins for val in kwargs.values() if is_unrealized_node(val)]
623    arg_origins = [arg.origins for arg in args if is_unrealized_node(arg)]
624    return set(itertools.chain(*arg_origins, *kwarg_origins))
625
626
627def sympy_str(expr: sympy.Expr) -> str:
628    """
629    Normal sympy str is very slow; this is a lot faster.  The results are
630    somewhat worse, as it doesn't do as much simplification.  So don't
631    use this for final codegen.
632    """
633    if isinstance(expr, sympy.Symbol):
634        return expr.name
635    if isinstance(expr, sympy.Add):
636        return " + ".join(map(sympy_str, expr.args))
637    if isinstance(expr, sympy.Mul):
638        return " * ".join(map(sympy_str, expr.args))
639
640    if isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv, Identity)):
641        return f"{expr.func.__name__}({', '.join(map(sympy_str, expr.args))})"
642    return str(expr)
643
644
645def get_bounds_index_expr(index):
646    from .virtualized import V
647
648    # If this expression does not come from an FX node, we compute its bounds
649    if (
650        config.compute_all_bounds
651        and (fx_node := getattr(V.interpreter, "current_node", None))
652        and fx_node.target != "index_expr"
653    ):
654        return bound_sympy(index)
655    else:
656        return ValueRanges.unknown()
657
658
659def sympy_index_symbol_with_prefix(prefix: SymT, idx: int) -> sympy.Symbol:
660    """
661    Used to generate a nonnegative integer symbol.
662    """
663    # This should never be used for creating shape/stride symbols, as those
664    # should all be allocated before Inductor.
665    assert prefix != SymT.SIZE
666    # NOTE: shape symbols are positive (> 0), but index variables are only
667    # non-negative (>= 0).
668    return make_symbol(prefix, idx, integer=True, nonnegative=True)
669
670
671def generate_assert(check):
672    return (check or config.debug_index_asserts) and config.assert_indirect_indexing
673
674
675def sympy_index_symbol(name: str) -> sympy.Symbol:
676    """
677    Used to generate a nonnegative integer symbol.
678    """
679    # This should never be used for creating shape/stride symbols, as those
680    # should all be allocated before Inductor.
681    assert name[0] != "s"
682    # NOTE: shape symbols are positive (> 0), but index variables are only
683    # non-negative (>= 0).
684    return sympy.Symbol(name, integer=True, nonnegative=True)
685
686
687def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.Expr:
688    """
689    When a replacement value v is a string, it is converted to a symbol named v that
690    has the same integer and nonnegative properties as the expression it replaces.
691    """
692
693    def to_symbol(replaced, replacement):
694        assert isinstance(replaced, sympy.Expr)
695        if isinstance(replacement, str):
696            return sympy.Symbol(
697                replacement,
698                integer=replaced.is_integer,  # type: ignore[attr-defined]
699                nonnegative=replaced.is_nonnegative,  # type: ignore[attr-defined]
700            )
701        else:
702            return replacement
703
704    # xreplace is faster than subs, but is way more picky
705    return sympy.sympify(expr).xreplace(
706        {k: to_symbol(k, v) for k, v in replacements.items()}
707    )
708
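# Illustrative example: replacing a symbol by a string creates a fresh symbol that
# inherits the integer/nonnegative assumptions of the expression it replaces.
#
#   x = sympy.Symbol("x", integer=True, nonnegative=True)
#   expr = sympy_subs(x + 1, {x: "y"})
#   # expr == sympy.Symbol("y", integer=True, nonnegative=True) + 1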
709
710def is_symbolic(a: Any) -> bool:
711    return isinstance(a, torch.SymInt) or (
712        isinstance(a, torch.Tensor)
713        and any(is_symbolic(x) for x in itertools.chain(a.size(), a.stride()))
714    )
715
716
717def any_is_symbolic(*args: Any) -> bool:
718    return any(is_symbolic(a) for a in args)
719
720
721def get_first_incompatible_cudagraph_node(gm):
722    from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
723
724    forbidden_set = {
725        "aten._fused_moving_avg_obs_fq_helper.default",
726        "aten._fused_moving_avg_obs_fq_helper_functional.default",
727        "aten.multinomial.default",
728        "fbgemm.dense_to_jagged.default",
729        "fbgemm.jagged_to_padded_dense.default",
730        "run_and_save_rng_state",
731        "run_with_rng_state",
732        "aten._local_scalar_dense",
733        # Technically, it's not necessary to ban this, because an
734        # assert_scalar with constant arguments can be validly run
735        # with CUDA graphs, but the operator is also pointless with
736        # constant arguments, so we might as well ban it.
737        "aten._assert_scalar",
738    }
739    if torch.are_deterministic_algorithms_enabled():
740        forbidden_set.update(
741            {
742                "aten._unsafe_index_put.default",
743                "aten._unsafe_masked_index_put_accumulate.default",
744                "aten.index_put.default",
745                "aten.index_put_.default",
746                "aten.scatter.src",
747                "aten.scatter.reduce",
748                "aten.scatter.value_reduce",
749                "aten.scatter_add_",
750                "aten.scatter_add.default",
751                "aten.scatter_reduce.two",
752                "aten.scatter_reduce_.two",
753                "aten.scatter_reduce.two_out",
754            }
755        )
756    for node in gm.graph.nodes:
757        if str(node.target) in forbidden_set:
758            return node
759        if (val := node.meta.get("val")) is not None and free_unbacked_symbols(val):
760            return node
761    return None
762
763
764def has_incompatible_cudagraph_ops(gm):
765    return get_first_incompatible_cudagraph_node(gm) is not None
766
767
768def output_node(gm: torch.fx.GraphModule):
769    """Get the output node from an FX graph"""
770    last_node = next(iter(reversed(gm.graph.nodes)))
771    assert last_node.op == "output"
772    return last_node
773
774
775_registered_caches: List[Any] = []
776
777
778def clear_on_fresh_inductor_cache(obj: Any):
779    """
780    Use this decorator to register any caches that should be cache_clear'd
781    with fresh_inductor_cache().
782    """
783    if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear):
784        raise AttributeError(f"{obj} does not have a cache_clear method")
785
786    _registered_caches.append(obj)
787    return obj
788
789
790def clear_inductor_caches():
791    """
792    Clear all registered caches.
793    """
794    for obj in _registered_caches:
795        obj.cache_clear()
796
797
798@contextlib.contextmanager
799def fresh_inductor_cache(cache_entries=None, dir=None, delete=True):
800    """
801    Contextmanager that provides a clean tmp cachedir for inductor.
802
803    Optionally, pass an empty dict as 'cache_entries' to receive a mapping of the
804    filenames and sizes generated with this cache instance.
805    """
806    clear_inductor_caches()
807
808    inductor_cache_dir = tempfile.mkdtemp(dir=dir)
809    try:
810        with mock.patch.dict(
811            os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir}
812        ):
813            log.debug("Using inductor cache dir %s", inductor_cache_dir)
814            triton_cache_dir = os.path.join(inductor_cache_dir, "triton")
815            with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": triton_cache_dir}):
816                yield
817                if isinstance(cache_entries, dict):
818                    assert len(cache_entries) == 0, "expected empty cache_entries dict"
819                    if os.path.exists(triton_cache_dir):
820                        files = os.listdir(triton_cache_dir)
821                        cache_entries.update(
822                            {
823                                f: os.path.getsize(os.path.join(triton_cache_dir, f))
824                                for f in files
825                                if ".lock" not in f
826                            }
827                        )
828        if delete:
829            shutil.rmtree(inductor_cache_dir)
830    except Exception:
831        if not _IS_WINDOWS:
832            """
833            Windows can't delete the loaded modules, because the modules binaries are opened.
834            TODO: discuss if have better solution to handle this issue.
835            """
836            log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir)
837            raise
838    finally:
839        clear_inductor_caches()
840
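# Usage sketch (illustrative, e.g. in tests; `compiled_fn` and `x` are assumed
# examples): compile inside a temporary cache dir and inspect what Triton wrote.
#
#   cache_entries = {}
#   with fresh_inductor_cache(cache_entries):
#       compiled_fn(x)
#   # cache_entries now maps generated Triton cache filenames to their sizes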
841
842def argsort(seq) -> List[int]:
843    # preserve original order for equal strides
844    getter = seq.__getitem__
845    a_r = range(len(seq))
846    return list(reversed(sorted(a_r, key=getter, reverse=True)))  # noqa: C413
847
848
849@functools.lru_cache(8)
850def get_dtype_size(dtype):
851    return torch.empty((), dtype=dtype).element_size()
852
853
854class LineContext(NamedTuple):
855    context: Any
856
857
858class IndentedBuffer:
859    tabwidth = 4
860
861    def __init__(self, initial_indent=0):
862        self._lines = []
863        self._indent = initial_indent
864
865    def getvaluewithlinemap(self) -> tuple[str, list[tuple[int, LineContext]]]:
866        buf = StringIO()
867        p = 1
868        linemap = []
869        for line in self._lines:
870            if isinstance(line, DeferredLineBase):
871                line = line()
872                if line is None:
873                    continue
874            elif isinstance(line, LineContext):
875                linemap.append((p, line.context))
876                continue
877            assert isinstance(line, str)
878            buf.write(line)
879            buf.write("\n")
880            p += 1 + line.count("\n")
881        return buf.getvalue(), linemap
882
883    def getvalue(self) -> str:
884        v, _ = self.getvaluewithlinemap()
885        return v
886
887    def getrawvalue(self) -> str:
888        buf = StringIO()
889        for line in self._lines:
890            if isinstance(line, DeferredLineBase):
891                line = line()
892                if line is None:
893                    continue
894            elif isinstance(line, LineContext):
895                continue
896            assert isinstance(line, str)
897            # backslash implies line continuation
898            if line.endswith("\\"):
899                buf.write(line[:-1])
900            else:
901                buf.write(line)
902                buf.write("\n")
903        return buf.getvalue()
904
905    def clear(self):
906        self._lines.clear()
907
908    def __bool__(self):
909        return bool(self._lines)
910
911    def prefix(self):
912        return " " * (self._indent * self.tabwidth)
913
914    def newline(self):
915        self.writeline("\n")
916
917    def writeline(self, line):
918        if isinstance(line, LineContext):
919            self._lines.append(line)
920        elif isinstance(line, DeferredLineBase):
921            self._lines.append(line.with_prefix(self.prefix()))
922        elif line.strip():
923            self._lines.append(f"{self.prefix()}{line}")
924        else:
925            self._lines.append("")
926
927    def writelines(self, lines):
928        for line in lines:
929            self.writeline(line)
930
931    def indent(self, offset=1):
932        @contextlib.contextmanager
933        def ctx():
934            self._indent += offset
935            try:
936                yield
937            finally:
938                self._indent -= offset
939
940        return ctx()
941
942    def do_indent(self, offset=1):
943        self._indent += offset
944
945    def do_unindent(self, offset=1):
946        self._indent -= offset
947
948    def splice(self, other_code, strip=False):
949        if isinstance(other_code, IndentedBuffer):
950            dedent = float("inf")
951            for line in other_code._lines:
952                if not isinstance(line, LineContext) and line:
953                    dedent = min(dedent, len(line) - len(line.lstrip()))
954            if math.isinf(dedent):
955                dedent = 0
956            for line in other_code._lines:
957                if isinstance(line, LineContext):
958                    self._lines.append(line)
959                else:
960                    IndentedBuffer.writeline(self, line[int(dedent) :])
961        else:
962            other_code = textwrap.dedent(other_code)
963            if strip:
964                other_code = other_code.lstrip()
965            if not other_code:
966                return
967            other_code = other_code.rstrip()
968            for line in other_code.split("\n"):
969                self.writeline(line)
970
971    def map(self, func: Callable[[Any], Any]) -> IndentedBuffer:
972        res = IndentedBuffer(initial_indent=self._indent)
973        res._lines = [func(line) for line in self._lines]
974        return res
975
976    def __repr__(self):
977        return f"{type(self)}({self.getvalue()})"
978
979    def __add__(self, other):
980        assert self._indent == other._indent
981        res = IndentedBuffer(initial_indent=self._indent)
982        res.writelines(self._lines)
983        res.writelines(other._lines)
984        return res
985
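# Usage sketch (illustrative): IndentedBuffer is the basic code-emission helper
# used throughout codegen; lines written inside an indent() context get one extra
# level of indentation (tabwidth spaces).
#
#   buf = IndentedBuffer()
#   buf.writeline("def kernel():")
#   with buf.indent():
#       buf.writeline("return 0")
#   assert buf.getvalue() == "def kernel():\n    return 0\n"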
986
987class FakeIndentedBuffer(IndentedBuffer):
988    def __init__(self) -> None:
989        super().__init__()
990
991    def __getattribute__(self, name):
992        if name == "__class__":  # Allow access to the class attribute
993            return object.__getattribute__(self, name)
994        raise RuntimeError(
995            f"Tried to call self.{name} on FakeIndentedBuffer. This buffer"
996            "is currently used on TritonTemplateKernel to prevent actual"
997            "writes to the body without explicitly specifying the body with"
998            "`TritonTemplateKernel.set_subgraph_body(name)`"
999        )
1000
1001
1002@contextlib.contextmanager
1003def restore_stdout_stderr(initial_stdout, initial_stderr):
1004    try:
1005        yield
1006    finally:
1007        sys.stdout = initial_stdout
1008        sys.stderr = initial_stderr
1009
1010
1011class DeferredLineBase:
1012    """A line that can be 'unwritten' at a later time"""
1013
1014    def __init__(self, line):
1015        if not line.strip():
1016            line = ""
1017        self.line = line
1018
1019    def __call__(self) -> Optional[str]:
1020        """Returns either self.line or None to indicate the line has been 'unwritten'"""
1021        raise NotImplementedError
1022
1023    def _new_line(self, line: str) -> DeferredLineBase:
1024        """Returns a new deferred line with the same condition"""
1025        raise NotImplementedError
1026
1027    def with_prefix(self, prefix):
1028        return self._new_line(f"{prefix}{self.line}")
1029
1030    def lstrip(self):
1031        return self._new_line(self.line.lstrip())
1032
1033    def __getitem__(self, index):
1034        return self._new_line(self.line[index])
1035
1036    def __bool__(self):
1037        return bool(self.line)
1038
1039    def __len__(self):
1040        return len(self.line)
1041
1042
1043@functools.lru_cache(None)
1044def is_big_gpu(index) -> bool:
1045    min_sms = 68  # 3080
1046    avail_sms = torch.cuda.get_device_properties(index).multi_processor_count
1047    if avail_sms < min_sms:
1048        log.warning(
1049            "Not enough SMs to use max_autotune_gemm mode",
1050            extra={"min_sms": min_sms, "avail_sms": avail_sms},
1051        )
1052        return False
1053    return True
1054
1055
1056def use_max_autotune() -> bool:
1057    return config.max_autotune or config.max_autotune_gemm
1058
1059
1060def _use_template_for_cuda(layout, allowed_layout_dtypes: List[torch.dtype]) -> bool:
1061    return (
1062        use_max_autotune()
1063        and layout.device.type == "cuda"
1064        and layout.dtype in allowed_layout_dtypes
1065        and is_big_gpu(layout.device.index or 0)
1066    )
1067
1068
1069def _use_autotune_backend(backend: str) -> bool:
1070    return backend.upper() in [
1071        x.strip() for x in config.max_autotune_gemm_backends.upper().split(",")
1072    ]
1073
1074
1075def _use_conv_autotune_backend(backend: str) -> bool:
1076    return backend.upper() in [
1077        x.strip() for x in config.max_autotune_conv_backends.upper().split(",")
1078    ]
1079
1080
1081def use_triton_template(layout, *, enable_int32=False, enable_float8=False):
1082    from .codegen.common import BackendFeature, has_backend_feature
1083
1084    layout_dtypes = [torch.float16, torch.bfloat16, torch.float32]
1085    if enable_int32:
1086        layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32]
1087    if enable_float8:
1088        layout_dtypes.extend([torch.float8_e4m3fn, torch.float8_e5m2])
1089    return (
1090        _use_template_for_cuda(layout, layout_dtypes)
1091        and _use_autotune_backend("TRITON")
1092        and has_backend_feature(layout.device, BackendFeature.TRITON_TEMPLATES)
1093    )
1094
1095
1096def use_cutlass_template(layout, m, n, k):
1097    from .virtualized import V
1098
1099    gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1)
1100    if gemm_size <= 0 or gemm_size < config.cuda.cutlass_backend_min_gemm_size:
1101        return False
1102    from .codegen.cuda.cutlass_utils import try_import_cutlass
1103
1104    # Do not use cutlass template on ROCm
1105    if torch.version.hip:
1106        return False
1107
1108    layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32]
1109    res = _use_template_for_cuda(layout, layout_dtypes) and _use_autotune_backend(
1110        "CUTLASS"
1111    )
1112
1113    if res:
1114        if not try_import_cutlass():
1115            log.warning(
1116                "Failed to import CUTLASS lib. Please check whether "
1117                "_inductor.config.cuda.cutlass_dir is set correctly. "
1118                "Skipping CUTLASS backend for now."
1119            )
1120            return False
1121    return res
1122
1123
1124@functools.lru_cache(None)
1125def _rocm_native_device_arch_name(device):
1126    return torch.cuda.get_device_properties(device).gcnArchName
1127
1128
1129@functools.lru_cache(None)
1130def try_import_ck_lib():
1131    try:
1132        import ck4inductor  # type: ignore[import]
1133        from ck4inductor.universal_gemm.gen_instances import (  # type: ignore[import]
1134            gen_ops_library,
1135            gen_ops_preselected,
1136        )
1137        from ck4inductor.universal_gemm.op import (  # type: ignore[import]
1138            CKGemmOperation,
1139        )
1140
1141        package_dirname = os.path.dirname(ck4inductor.__file__)
1142    except ImportError:
1143
1144        def gen_ops_library():
1145            return []
1146
1147        def gen_ops_preselected():
1148            return []
1149
1150        class CKGemmOperation:  # type: ignore[no-redef]
1151            pass
1152
1153        package_dirname = None
1154    return package_dirname, gen_ops_library, gen_ops_preselected, CKGemmOperation
1155
1156
1157def use_ck_template(layout, m, n, k):
1158    # config knobs check 1
1159    if not use_max_autotune():
1160        return False
1161    # config knobs check 2
1162    if not _use_autotune_backend("CK"):
1163        return False
1164    # platform check
1165    if not torch.version.hip:
1166        return False
1167    # tensors must be on GPU
1168    if not layout.device.type == "cuda":
1169        return False
1170    # hardware check
1171    # if config arch list is not specified, get the native arch from the device properties
1172    native_arch = _rocm_native_device_arch_name(layout.device)
1173    requested_archs = {k.split(":")[0]: k for k in config.rocm.arch} or {
1174        native_arch.split(":")[0]: native_arch
1175    }
1176    requested_supported_archs = [
1177        requested_archs[k]
1178        for k in requested_archs.keys() & config.rocm.ck_supported_arch
1179    ]
1180    if not requested_supported_archs:
1181        return False
1182    # supported input dtypes
1183    if layout.dtype not in [torch.float16, torch.bfloat16]:
1184        return False
1185    # TBD: investigate if we need to disable backend based on number of available CUs similar to `is_big_gpu`
1186    # check if shape is static and gemm size is not 0
1187    from .virtualized import V
1188
1189    gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1)
1190    if gemm_size <= 0:
1191        return False
1192    # TBD: investigate if backend needs to be disabled for small gemms similar to CUTLASS
1193
1194    ck_package_dirname, _, _, _ = try_import_ck_lib()
1195
1196    if not ck_package_dirname:
1197        log.warning("Please pip install Composable Kernel package")
1198        return False
1199
1200    if not config.rocm.ck_dir:
1201        log.warning("Please set TORCHINDUCTOR_CK_DIR env variable")
1202        return False
1203
1204    if ck_package_dirname != config.rocm.ck_dir:
1205        log.warning("Invalid path to CK library")
1206        return False
1207
1208    return True
1209
1210
1211def _use_template_for_cpu(layout):
1212    return use_max_autotune() and layout.device.type == "cpu"
1213
1214
1215def use_cpp_packed_gemm_template(layout, mat1, mat2, mat2_transposed=False):
1216    from . import ir
1217    from .codegen.cpp_micro_gemm import create_micro_gemm
1218    from .codegen.cpp_utils import get_gemm_template_output_and_compute_dtype
1219    from .kernel.mm_common import mm_args
1220
1221    if not _use_template_for_cpu(layout) or not _use_autotune_backend("CPP"):
1222        return False
1223
1224    if not config.cpp.weight_prepack:
1225        return False
1226
1227    int8_gemm = mat1.get_dtype() == torch.uint8
1228    layout_dtypes = [torch.float32, torch.bfloat16, torch.half, torch.uint8]
1229    m, n, k, layout, mat1, mat2 = mm_args(
1230        mat1,
1231        mat2,
1232        out_dtype=layout.dtype if int8_gemm else None,
1233        mat2_transposed=mat2_transposed,
1234    )
1235
1236    # TODO(jgong5): support dynamic shapes for n or k
1237    if has_free_symbols((n, k)):
1238        return False
1239    if isinstance(mat2, ir.BaseView):
1240        mat2 = mat2.unwrap_view()
1241
1242    output_dtype, _ = get_gemm_template_output_and_compute_dtype(mat1.get_dtype())
1243    micro_gemm = create_micro_gemm(
1244        "micro_gemm",
1245        m,
1246        n,
1247        k,
1248        input_dtype=mat1.get_dtype(),
1249        input2_dtype=mat2.get_dtype(),
1250        output_dtype=output_dtype,
1251        num_threads=parallel_num_threads(),
1252    )
1253
1254    def is_last_dim_stride1(x):
1255        x.freeze_layout()
1256        return x.get_stride()[-1] == 1
1257
1258    return (
1259        layout.dtype in layout_dtypes
1260        and micro_gemm is not None
1261        and is_last_dim_stride1(mat1)  # TODO(jgong5): support transposed input
1262        and isinstance(mat2, ir.StorageBox)
1263        and mat2.is_module_buffer()
1264    )
1265
1266
1267def use_aten_gemm_kernels():
1268    return not use_max_autotune() or _use_autotune_backend("ATEN")
1269
1270
1271class DebugDirManager:
1272    counter = itertools.count(0)
1273    prev_debug_name: str
1274
1275    def __init__(self) -> None:
1276        self.id = next(DebugDirManager.counter)
1277
1278    def __enter__(self):
1279        self.prev_debug_name = torch._dynamo.config.debug_dir_root
1280        self.new_name = f"{self.prev_debug_name}_tmp_{self.id}"
1281        torch._dynamo.config.debug_dir_root = self.new_name
1282
1283    def __exit__(self, *args):
1284        shutil.rmtree(self.new_name)
1285        torch._dynamo.config.debug_dir_root = self.prev_debug_name
1286
1287
1288def run_and_get_code(fn, *args, **kwargs):
1289    from .graph import GraphLowering
1290
1291    source_codes: List[str] = []
1292
1293    def save_output_code(code: str):
1294        source_codes.append(code)
1295
1296    with mock.patch.object(GraphLowering, "save_output_code", save_output_code):
1297        torch._dynamo.reset()
1298        result = fn(*args, **kwargs)
1299    return result, source_codes
1300
1301
1302def run_fw_bw_and_get_code(fn):
1303    def run_with_backward():
1304        result = fn()
1305        result.sum().backward()
1306        return result
1307
1308    return run_and_get_code(run_with_backward)
1309
1310
1311def get_code(fn, *args, **kwargs):
1312    """Get the inductor-generated code, but skip any actual compilation or running."""
1313    from .graph import GraphLowering
1314
1315    source_codes: List[str] = []
1316
1317    def save_output_code(code: str):
1318        source_codes.append(code)
1319
1320    def patched_compile_to_module(self: GraphLowering):
1321        class DummyModule:
1322            """This is empty to replace the generated triton module"""
1323
1324            def __init__(self) -> None:
1325                pass
1326
1327            def call(self, *args, **kwargs):
1328                # Don't do anything when called
1329                pass
1330
1331        code, _ = (
1332            self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
1333        )
1334        # Skip all the actual compiling.
1335        nonlocal save_output_code
1336        save_output_code(code)
1337
1338        return DummyModule()
1339
1340    with mock.patch.object(
1341        GraphLowering, "compile_to_module", patched_compile_to_module
1342    ), mock.patch.object(GraphLowering, "save_output_code", save_output_code):
1343        torch._dynamo.reset()
1344        # Note the return here is None
1345        _ = fn(*args, **kwargs)
1346
1347    return source_codes
1348
1349
1350def get_triton_code(fn, *args, **kwargs):
1351    source_codes = get_code(fn, *args, **kwargs)
1352    # Can have two outputs if backwards was eagerly compiled
1353    assert (
1354        1 <= len(source_codes) <= 2
1355    ), f"expected one or two code outputs got {len(source_codes)}"
1356    return source_codes[0]
1357
1358
1359def run_and_get_triton_code(fn, *args, **kwargs):
1360    _, source_codes = run_and_get_code(fn, *args, **kwargs)
1361    # Can have two outputs if backwards was eagerly compiled
1362    assert (
1363        1 <= len(source_codes) <= 2
1364    ), f"expected one or two code outputs got {len(source_codes)}"
1365    return source_codes[0]
1366
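# Usage sketch (illustrative, typically in tests; assumes a CUDA device): grab the
# generated wrapper/kernel source for a compiled function.
#
#   @torch.compile
#   def f(x):
#       return torch.relu(x) + 1
#
#   code = run_and_get_triton_code(f, torch.randn(8, device="cuda"))
#   print(code)  # the Python wrapper containing the generated Triton kernels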
1367
1368def run_and_get_graph_lowering(fn, *args, **kwargs):
1369    from torch._inductor.codecache import CompiledFxGraph
1370    from torch._inductor.graph import GraphLowering
1371
1372    real_init = CompiledFxGraph.__init__
1373    graph_lowerings = []
1374
1375    def fake_init(*args, **kwargs):
1376        real_init(*args, **kwargs)
1377        graph = args[2]
1378        assert isinstance(graph, GraphLowering)
1379        graph_lowerings.append(graph)
1380
1381    with mock.patch.object(CompiledFxGraph, "__init__", fake_init):
1382        result = fn(*args, **kwargs)
1383
1384    return result, graph_lowerings
1385
1386
1387@contextlib.contextmanager
1388def override_lowering(aten_op, override_fn):
1389    """
1390    Override the lowering of aten_op with override_fn.
1391    The first argument of override_fn is the original lowering fn.
1392    """
1393    from torch._inductor import lowering
1394
1395    orig_fn = lowering.lowerings[aten_op]
1396    try:
1397        lowering.lowerings[aten_op] = functools.partial(override_fn, orig_fn)
1398        yield
1399    finally:
1400        lowering.lowerings[aten_op] = orig_fn
1401
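# Usage sketch (illustrative, e.g. in a test; `fn` and `x` are assumed examples):
# wrap the lowering of aten.add to observe or tweak its inputs; the original
# lowering is passed in as the first argument.
#
#   def my_add_lowering(orig_fn, *args, **kwargs):
#       # inspect or modify args here
#       return orig_fn(*args, **kwargs)
#
#   with override_lowering(torch.ops.aten.add.Tensor, my_add_lowering):
#       torch.compile(fn)(x)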
1402
1403def add_scheduler_init_hook(pre_fn, post_fn=None):
1404    """
1405    Add hook functions to be called at the beginning and end of Scheduler.__init__.
1406    Used for unit tests.
1407    """
1408    from torch._inductor.scheduler import Scheduler
1409
1410    orig_fn = Scheduler.__init__
1411
1412    def wrapper(scheduler, nodes):
1413        pre_fn(scheduler, nodes)
1414        out = orig_fn(scheduler, nodes)
1415        if post_fn:
1416            post_fn(scheduler, nodes)
1417        return out
1418
1419    return unittest.mock.patch.object(Scheduler, "__init__", wrapper)
1420
1421
1422def developer_warning(msg):
1423    """
1424    Warnings that will be actionable for PyTorch developers, but not
1425    end users.  Allows us to easily disable them in stable releases but
1426    keep them on for nightly builds.
1427    """
1428    if config.developer_warnings:
1429        log.warning(msg)
1430    else:
1431        log.info(msg)
1432
1433
1434def get_benchmark_name():
1435    """
1436    An experimental API used only when config.benchmark_kernel is true.
1437
1438    The benchmark name is only available at codegen time, so we cannot
1439    access it directly in benchmark_all_kernels, which runs after codegen.
1440
1441    The function assumes the argument after --only is the benchmark name.
1442    It works for torchbench.py/huggingface.py/timm_models.py. But for ad-hoc
1443    scripts, this function may return None.
1444
1445    There are 2 flavors of the --only argument we need to handle:
1446    1. --only model_name
1447    2. --only=model_name
1448    """
1449    try:
1450        idx = sys.argv.index("--only")
1451        if (
1452            idx + 1 < len(sys.argv)
1453            and len(sys.argv[idx + 1]) > 0
1454            and sys.argv[idx + 1][0] != "-"
1455        ):
1456            return sys.argv[idx + 1]
1457    except ValueError:
1458        pass
1459
1460    for arg in sys.argv:
1461        if arg.startswith("--only="):
1462            return arg[len("--only=") :]
1463
1464
1465def is_ones(items):
1466    return all(x == 1 for x in items)
1467
1468
1469def is_zeros(items):
1470    return all(x == 0 for x in items)
1471
1472
1473def is_cpu_device(inputs):
1474    return all(
1475        item.device == torch.device("cpu")
1476        for item in inputs
1477        if isinstance(item, torch.Tensor)
1478    )
1479
1480
1481def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype:
1482    assert isinstance(
1483        val, sympy.Expr
1484    ), "only support sympy.Expr as input to get_sympy_Expr_dtype"
1485    if val.is_integer:  # type: ignore[attr-defined]
1486        return torch.int64
1487    else:
1488        return torch.float64
1489
1490
1491@contextlib.contextmanager
1492def maybe_profile(should_profile, *args, **kwargs):
1493    if should_profile:
1494        with torch.profiler.profile(*args, **kwargs) as p:
1495            yield p
1496    else:
1497        yield
1498
1499
1500def parallel_num_threads():
1501    threads = config.cpp.threads
1502    if threads < 1:
1503        threads = torch.get_num_threads()
1504    return threads
1505
1506
1507@functools.lru_cache(None)
1508def get_device_tflops(dtype):
1509    from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops
1510
1511    assert dtype in (torch.float16, torch.bfloat16, torch.float32)
1512
1513    if inspect.signature(get_max_simd_tflops).parameters.get("clock_rate"):
1514        # Triton API change in https://github.com/openai/triton/pull/2293
1515        from torch._utils_internal import max_clock_rate
1516
1517        sm_clock = max_clock_rate()
1518        if dtype in (torch.float16, torch.bfloat16):
1519            return get_max_tensorcore_tflops(dtype, sm_clock)
1520
1521        if torch.backends.cuda.matmul.allow_tf32:
1522            return get_max_tensorcore_tflops(torch.float32, sm_clock)
1523        else:
1524            return get_max_simd_tflops(torch.float32, sm_clock)
1525    else:
1526        if dtype in (torch.float16, torch.bfloat16):
1527            return get_max_tensorcore_tflops(dtype)
1528
1529        if torch.backends.cuda.matmul.allow_tf32:
1530            return get_max_tensorcore_tflops(torch.float32)
1531        else:
1532            return get_max_simd_tflops(torch.float32)
1533
1534
1535@functools.lru_cache(None)
1536def get_gpu_dram_gbps():
1537    from triton.testing import get_dram_gbps
1538
1539    return get_dram_gbps()
1540
1541
1542def get_gpu_shared_memory():
1543    from triton.runtime import driver
1544
1545    return driver.active.utils.get_device_properties(0).get("max_shared_mem", 0)
1546
1547
1548def is_welford_reduction(reduction_type):
1549    return reduction_type.startswith("welford")
1550
1551
1552def reduction_num_outputs(reduction_type):
1553    return 3 if is_welford_reduction(reduction_type) else 1
1554
1555
1556def is_linux() -> bool:
1557    return platform.system() == "Linux"
1558
1559
1560def is_windows():
1561    return sys.platform == "win32"
1562
1563
1564def has_free_symbols(itr: Iterable[Any]):
1565    return any(isinstance(x, sympy.Expr) and not x.is_number for x in itr)
1566
1567
1568def is_dynamic(*args):
1569    from . import ir
1570
1571    for t in args:
1572        if isinstance(t, ir.TensorBox):
1573            if has_free_symbols(t.data.get_size()) or (
1574                hasattr(t.data, "get_stride") and has_free_symbols(t.data.get_stride())
1575            ):
1576                return True
1577        elif isinstance(t, (ir.StorageBox, ir.BaseView, ir.ComputedBuffer)):
1578            assert hasattr(t, "get_size") and hasattr(t, "get_stride")
1579            if has_free_symbols(t.get_size()) or has_free_symbols(t.get_stride()):
1580                return True
1581        elif not isinstance(t, ir.IRNode):
1582            continue
1583        else:
1584            raise TypeError(f"unexpected type for is_dynamic {type(t)}")
1585
1586    return False
1587
1588
1589# Placeholder strings used in triton codegen.
1590class Placeholder(enum.Enum):
1591    # The placeholder for the actual name of a triton kernel.
1592    # e.g. for "def triton_" it would be "triton_"
1593    KERNEL_NAME = "KERNEL_NAME"
1594
1595    # The descriptive name of the triton kernel; when unique_kernel_names = False, this
1596    # placeholder will be replaced with a string with more information.
1597    DESCRIPTIVE_NAME = "DESCRIPTIVE_NAME"
1598
1599
1600def pass_execution_and_save(func, gm, inp, msg):
1601    from .pattern_matcher import stable_topological_sort
1602
1603    with tempfile.NamedTemporaryFile(
1604        mode="w",
1605        encoding="utf-8",
1606        delete=False,
1607    ) as f:
1608        before_io = io.StringIO()
1609        after_io = io.StringIO()
1610        ShapeProp(gm=gm, fake_mode=detect_fake_mode(inp)).propagate(*inp)
1611        print(f"Before:\n{gm.graph}", file=f)
1612        print(gm.graph, file=before_io)
1613        start_time = datetime.now()
1614        with GraphTransformObserver(gm, msg, config.trace.log_url_for_graph_xform):
1615            func(gm.graph)
1616        time_elapsed = datetime.now() - start_time
1617        # recompile graph
1618        stable_topological_sort(gm.graph)
1619        gm.graph.lint()
1620        gm.recompile()
1621
1622        print(f"After:\n{gm.graph}", file=f)
1623        print(gm.graph, file=after_io)
1624        t = before_io.getvalue() == after_io.getvalue()
1625        log.info(
1626            "%s, save before/after graph to %s, graph before/after are the same = %s, time elapsed = %s",
1627            msg,
1628            f.name,
1629            t,
1630            time_elapsed,
1631        )
1632
1633
1634def is_collective(node, op=None):
1635    from . import ir
1636
1637    return type(node) == ir._CollectiveKernel and (op is None or node.op_overload is op)
1638
1639
1640def is_wait(node):
1641    from . import ir
1642
1643    return type(node) == ir._WaitKernel
1644
1645
1646def contains_collective(snode):
1647    from torch._inductor.scheduler import BaseSchedulerNode, GroupedSchedulerNode
1648
1649    assert isinstance(snode, BaseSchedulerNode)
1650    if isinstance(snode, GroupedSchedulerNode):
1651        return any(contains_collective(x) for x in snode.snodes)
1652    else:
1653        return is_collective(snode.node)
1654
1655
1656def contains_wait(snode):
1657    from torch._inductor.scheduler import BaseSchedulerNode, GroupedSchedulerNode
1658
1659    assert isinstance(snode, BaseSchedulerNode)
1660    if isinstance(snode, GroupedSchedulerNode):
1661        return any(contains_wait(x) for x in snode.snodes)
1662    else:
1663        return is_wait(snode.node)
1664
1665
1666def is_fallback_op(node, op):
1667    from . import ir
1668
1669    if isinstance(op, torch._ops.OpOverload):
1670        op = {op}
1671    return isinstance(node, ir.FallbackKernel) and node.op_overload in op
1672
1673
1674def buf_name_to_fused_snode(buf_name, name_to_buf, name_to_fused_node):
1675    return name_to_fused_node[name_to_buf[buf_name].defining_op.get_name()]
1676
1677
1678def find_recursive_deps_of_node(
1679    snode, collected_node_set, name_to_buf, name_to_fused_node, criteria_cb=None
1680):
1681    if criteria_cb and criteria_cb(snode):
1682        return
1683    collected_node_set.add(snode)
1684    for dep in snode.unmet_dependencies:
1685        defining_op_for_dep = buf_name_to_fused_snode(
1686            dep.name, name_to_buf, name_to_fused_node
1687        )
1688        if defining_op_for_dep in collected_node_set:
1689            continue
1690        find_recursive_deps_of_node(
1691            defining_op_for_dep,
1692            collected_node_set,
1693            name_to_buf,
1694            name_to_fused_node,
1695            criteria_cb=criteria_cb,
1696        )
1697
1698
1699def find_recursive_users_of_node(
1700    snode, collected_node_set, name_to_buf, name_to_fused_node, criteria_cb=None
1701):
1702    if criteria_cb and criteria_cb(snode):
1703        return
1704    collected_node_set.add(snode)
1705    for o in snode.get_outputs():
1706        for user in o.users:
1707            assert user.node is not None
1708            if user.node.get_name() == "OUTPUT":
1709                continue
1710            if user.node.get_name() not in name_to_fused_node:
1711                continue
1712            user_op = name_to_fused_node[user.node.get_name()]
1713            if user_op in collected_node_set:
1714                continue
1715            find_recursive_users_of_node(
1716                user_op,
1717                collected_node_set,
1718                name_to_buf,
1719                name_to_fused_node,
1720                criteria_cb=criteria_cb,
1721            )
1722
1723
1724def num_fw_fixed_arguments(dynamo_gm_num_inputs: int, aot_fw_gm_num_inputs: int):
1725    "Computes the number of inputs to the aot fw graph which have fixed addresses (params and buffers)"
1726    num_rng_seed_offset_inputs = (
1727        2 if torch._functorch.config.functionalize_rng_ops else 0
1728    )
1729    # AOT won't lift any parameters if we're inlining NN Modules;
1730    # however, desugaring subclasses will still add arguments,
1731    # resulting in extra fixed inputs https://github.com/pytorch/pytorch/issues/130502
1732    if (
1733        torch._dynamo.config.inline_inbuilt_nn_modules
1734        and not torch._dynamo.utils.is_parameter_freezing()
1735    ):
1736        return 0
1737
1738    return aot_fw_gm_num_inputs - dynamo_gm_num_inputs - num_rng_seed_offset_inputs
1739
1740
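# Illustrative sketch (the numbers are made up): with rng-op functionalization
# and NN-module inlining both disabled, the count is simply the extra inputs the
# aot forward graph has over the dynamo graph, i.e. the lifted params/buffers.
def _example_num_fw_fixed_arguments() -> None:
    if (
        not torch._functorch.config.functionalize_rng_ops
        and not torch._dynamo.config.inline_inbuilt_nn_modules
    ):
        # 3 user inputs, 5 aot forward inputs -> 2 lifted params/buffers
        assert num_fw_fixed_arguments(3, 5) == 2

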
1741def count_tangents(fx_g: torch.fx.GraphModule):
1742    """
1743    Infers which inputs are static for a backwards graph
1744    """
1745
1746    def is_saved_tensor(x):
1747        return (
1748            "tangents" not in x.name
1749            and "bwd_seed" not in x.name
1750            and "bwd_base_offset" not in x.name
1751        )
1752
1753    arg_count = 0
1754    static_arg_idxs = []
1755    for n in fx_g.graph.nodes:
1756        if n.op == "placeholder":
1757            if is_saved_tensor(n):
1758                static_arg_idxs.append(arg_count)
1759            arg_count += 1
1760
1761    assert static_arg_idxs == list(range(len(static_arg_idxs)))
1762    return len(static_arg_idxs)
1763
1764
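# Illustrative sketch (a hand-built graph, not a real aot backward graph): only
# the leading placeholders that are not tangents / rng inputs count as static.
def _example_count_tangents() -> None:
    g = torch.fx.Graph()
    g.placeholder("primals_1")
    g.placeholder("tangents_1")
    g.output(None)
    gm = torch.fx.GraphModule(torch.nn.Module(), g)
    assert count_tangents(gm) == 1  # only "primals_1" is a saved/static input

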
1765@dataclasses.dataclass
1766class BoxedBool:
1767    value: bool
1768
1769    def __bool__(self):
1770        return self.value
1771
1772    @staticmethod
1773    def disable(obj):
1774        if isinstance(obj, BoxedBool):
1775            obj.value = False
1776            return obj
1777        return False
1778
1779
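# Illustrative sketch: BoxedBool lets a callee flip a flag that the caller still
# holds a reference to, while plain bools just collapse to False.
def _example_boxed_bool() -> None:
    flag = BoxedBool(True)
    BoxedBool.disable(flag)                  # mutates the shared box in place
    assert not flag
    assert BoxedBool.disable(True) is False  # non-boxed values are not mutated

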
1780@contextlib.contextmanager
1781def collect_defined_kernels(kernel_list):
1782    from .codegen.wrapper import WrapperCodeGen
1783
1784    orig_define_kernel = WrapperCodeGen.define_kernel
1785
1786    def new_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs):
1787        nonlocal kernel_list
1788        kernel_list.append(kernel_code)
1789        return orig_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs)
1790
1791    with unittest.mock.patch.object(WrapperCodeGen, "define_kernel", new_define_kernel):
1792        yield
1793
1794
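# Illustrative usage sketch (whether any kernels are recorded depends on the
# current backend/config and on cache hits): capture the source of every kernel
# the wrapper codegen defines while compiling a small function.
def _example_collect_defined_kernels() -> None:
    kernels: List[str] = []
    with collect_defined_kernels(kernels):
        torch.compile(lambda x: x.sin() + 1)(torch.randn(8))
    # `kernels` now holds the code string of each kernel defined during codegen.

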
1795def get_cloned_parameter_buffer_name(name: str):
1796    return name + "__original__"
1797
1798
1799def is_gpu(device: Optional[str]):
1800    assert isinstance(device, str) or device is None, device
1801    return device in GPU_TYPES
1802
1803
1804def device_need_guard(device: str):
1805    assert isinstance(device, str)
1806    return is_gpu(device)
1807
1808
1809def needs_fallback_due_to_atomic_add_limitations(dtype):
1810    # tl.atomic_add does NOT support the following types
1811    return dtype in {torch.int64, torch.bool, torch.bfloat16}
1812
1813
1814def use_scatter_fallback(
1815    op_overload: torch._ops.OpOverload,
1816    reduction_type,
1817    self_dtype,
1818    src_dtype,
1819    src_device_type,
1820    src_is_tensor,
1821):
1822    if (
1823        op_overload.overloadpacket
1824        in (torch.ops.aten.scatter_reduce_, torch.ops.aten.scatter_reduce)
1825        and reduction_type is None
1826    ):
1827        return False
1828
1829    reduce_ty = (
1830        "add" if op_overload.overloadpacket == torch.ops.aten.scatter_ else "sum"
1831    )
1832
1833    return (
1834        reduction_type not in {None, reduce_ty}
1835        or (
1836            src_is_tensor
1837            and is_gpu(src_device_type)
1838            and needs_fallback_due_to_atomic_add_limitations(src_dtype)
1839        )
1840        or (
1841            op_overload.overloadpacket == torch.ops.aten.scatter_reduce_
1842            and reduction_type == "sum"
1843            and src_is_tensor
1844            and src_device_type == "cpu"
1845            and config.cpp.fallback_scatter_reduce_sum
1846            and (config.cpp.dynamic_threads or parallel_num_threads() != 1)
1847        )
1848        or (reduction_type == reduce_ty and self_dtype in {torch.bool, torch.int64})
1849        or torch.are_deterministic_algorithms_enabled()
1850    )
1851
1852
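# Illustrative sketch: the first early-return means a scatter_reduce with no
# reduction kwarg never takes the fallback path, regardless of dtypes/devices.
def _example_use_scatter_fallback() -> None:
    assert not use_scatter_fallback(
        torch.ops.aten.scatter_reduce_.two,
        reduction_type=None,
        self_dtype=torch.float32,
        src_dtype=torch.float32,
        src_device_type="cpu",
        src_is_tensor=True,
    )

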
1853def dump_node_schedule(node_schedule):
1854    """
1855    An API that can be used in pdb to dump a node_schedule.
1856    Right now it mainly dumps the read/write dependencies, but more can be added as needed.
1857    """
1858    from torch._inductor.codegen.simd import DisableReduction, EnableReduction
1859    from torch._inductor.scheduler import SchedulerNode
1860
1861    print(f"Node schedule with {len(node_schedule)} nodes")
1862    for idx, node in enumerate(node_schedule):
1863        print(f" {idx:3}:")
1864        if node is EnableReduction:
1865            print("enable reduction")
1866        elif node is DisableReduction:
1867            print("disable reduction")
1868        elif isinstance(node, SchedulerNode):
1869            is_red = node.is_reduction()
1870            print(f"{'red' if is_red else 'pw'} scheduler node")
1871            if is_red:
1872                assert node.node is not None
1873                print(f"original reduction hint {node.node.data.reduction_hint}")  # type: ignore[attr-defined]
1874            print("ReadDep:")
1875            for dep in node.read_writes.reads:
1876                print(dep)
1877            print("WriteDep:")
1878            for dep in node.read_writes.writes:
1879                print(dep)
1880        else:
1881            raise RuntimeError(f"Unrecognized node type: {type(node)}")
1882
1883
1884def tensor_is_aligned(tensor: torch.Tensor):
1885    # See Note: [Input Alignment handling in Inductor]
1886    # Right now, we don't try to guard on the alignment of the storage offset.
1887    # When this comment was written, non-symbolic storage_offsets are not guarded on
1888    # but symbolic storage_offsets are. For consistency, we suppress guard creation
1889    # upon performing this check: that ensures that we don't add recompiles when we
1890    # add this logic.
1891    from torch.fx.experimental.symbolic_shapes import statically_known_true
1892
1893    return statically_known_true(
1894        (tensor.storage_offset() * get_dtype_size(tensor.dtype)) % GPU_ALIGN_BYTES == 0
1895    )
1896
1897
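# Illustrative sketch: alignment is decided purely from storage_offset * itemsize
# against GPU_ALIGN_BYTES; slicing one element into a float32 tensor gives a
# 4-byte offset, which is not 16-byte aligned.
def _example_tensor_is_aligned() -> None:
    base = torch.zeros(8, dtype=torch.float32)
    assert tensor_is_aligned(base)          # storage_offset 0
    assert not tensor_is_aligned(base[1:])  # storage_offset 1 -> 4 bytes

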
1898def should_assume_input_aligned(example_input: torch.Tensor):
1899    # See Note: [Input Alignment handling in Inductor]
1900
1901    # right now, we only care about alignment for cuda tensors.
1902    if not is_gpu(example_input.device.type):
1903        return False
1904    return config.assume_aligned_inputs or tensor_is_aligned(example_input)
1905
1906
1907def maybe_get_suppress_shape_guards_ctx():
1908    # Try to get TracingContext.try_get().fake_mode.shape_env.suppress_guards()
1909    # If it's not available, return a nullcontext.
1910
1911    # If we're dealing with cudagraphs, we might not have a tracing_context
1912    tracing_context = torch._guards.TracingContext.try_get()
1913    if not tracing_context:
1914        return contextlib.nullcontext()
1915
1916    # In standalone inductor compile mode, we might not have a shape_env attached to the fake mode
1917    shape_env = tracing_context.fake_mode.shape_env
1918    if not shape_env:
1919        return contextlib.nullcontext()
1920
1921    return shape_env.suppress_guards()
1922
1923
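# Illustrative usage sketch: callers don't need to know whether a tracing context
# (and thus a shape_env) exists; they just wrap guard-sensitive checks in the
# returned context manager, which degrades to a nullcontext when unavailable.
def _example_suppress_shape_guards() -> None:
    with maybe_get_suppress_shape_guards_ctx():
        pass  # e.g. probe sizes/strides without installing new guards

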
1924def run_and_get_cpp_code(fn, *args, **kwargs):
1925    # We use the patch context manager instead of using it as a decorator.
1926    # In this way, we can ensure that the attribute is patched and unpatched correctly
1927    # even if this run_and_get_cpp_code function is called multiple times.
1928    with unittest.mock.patch.object(config, "debug", True):
1929        torch._dynamo.reset()
1930        import io
1931        import logging
1932
1933        log_capture_string = io.StringIO()
1934        ch = logging.StreamHandler(log_capture_string)
1935        from torch._inductor.codecache import output_code_log
1936
1937        output_code_log.addHandler(ch)
1938        prev_level = output_code_log.level
1939        output_code_log.setLevel(logging.DEBUG)
1940        result = fn(*args, **kwargs)
1941        s = log_capture_string.getvalue()
1942        output_code_log.setLevel(prev_level)
1943        output_code_log.removeHandler(ch)
1944    return result, s
1945
1946
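# Illustrative usage sketch (the captured string depends on the configured
# backend and may contain Python or C++ wrapper code): compile and run `fn` with
# output-code logging forced on, returning both the result and the captured log.
def _example_run_and_get_cpp_code() -> None:
    fn = torch.compile(lambda x: x * 2)
    result, code = run_and_get_cpp_code(fn, torch.ones(4))
    assert torch.equal(result, torch.full((4,), 2.0))

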
1947def shape_env_from_inputs(inputs: List[torch.Tensor]):
1948    shape_env = None
1949    fake_mode = detect_fake_mode(inputs)
1950
1951    # TODO(voz): It would be nice to enable this assert, but there are lots of tests that
1952    # pass in real inputs for now.
1953    # if len(inputs) > 0:
1954    # assert fake_mode is not None, breakpoint()
1955
1956    if fake_mode is not None:
1957        return fake_mode.shape_env
1958
1959    # When there are no tensor inputs, get shape_env from the first SymInt.
1960    for input in inputs:
1961        if isinstance(input, torch.SymInt):
1962            return input.node.shape_env
1963
1964    # TODO(voz): Should we always have one anyway?
1965    return None
1966
1967
1968def align_inputs_from_check_idxs(
1969    model: Callable[[List[InputType]], Any],
1970    inputs_to_check: Sequence[int],
1971) -> Callable[[List[InputType]], Any]:
1972    if len(inputs_to_check) == 0:
1973        return model
1974
1975    def run(new_inputs: List[InputType]):
1976        copy_misaligned_inputs(new_inputs, inputs_to_check)
1977        return model(new_inputs)
1978
1979    return run
1980
1981
1982def clone_preserve_strides(x: torch.Tensor):
1983    needed_size = (
1984        sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
1985    )
1986    buffer = torch.as_strided(x, (needed_size,), (1,)).clone()
1987    return torch.as_strided(buffer, x.size(), x.stride())
1988
1989
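# Illustrative sketch of the needed_size arithmetic: for a (2, 2) view with
# strides (3, 2), the highest reachable element sits at offset
# (2 - 1) * 3 + (2 - 1) * 2 = 5, so 6 elements of backing storage are cloned.
def _example_clone_preserve_strides() -> None:
    x = torch.arange(6.0).reshape(2, 3)[:, ::2]  # size (2, 2), stride (3, 2)
    y = clone_preserve_strides(x)
    assert y.size() == x.size() and y.stride() == x.stride()
    assert torch.equal(y, x)

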
1990def copy_misaligned_inputs(
1991    new_inputs: List[InputType], check_inputs_idxs: Sequence[int]
1992) -> None:
1993    for i in check_inputs_idxs:
1994        _inp = new_inputs[i]
1995        assert isinstance(_inp, torch.Tensor)
1996        if _inp.data_ptr() % ALIGNMENT:
1997            new_inputs[i] = clone_preserve_strides(_inp)
1998
1999
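# Illustrative sketch (assumes the allocator hands back ALIGNMENT-aligned
# storage, which is the common case): an input sliced one byte into its storage
# is misaligned and gets replaced in-place by an aligned, stride-preserving clone.
def _example_copy_misaligned_inputs() -> None:
    base = torch.zeros(ALIGNMENT + 1, dtype=torch.uint8)
    inputs: List[InputType] = [base[1:]]
    copy_misaligned_inputs(inputs, [0])
    assert isinstance(inputs[0], torch.Tensor)
    assert inputs[0].data_ptr() % ALIGNMENT == 0

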
2000def remove_unaligned_input_idxs(
2001    inputs: List[InputType],
2002    static_input_idxs: Sequence[int],
2003):
2004    """
2005    We require all inputs to be aligned, so introduce a copy for any
2006    that aren't.
2007    """
2008    aligned_static_input_idxs = []
2009    for idx in static_input_idxs:
2010        input = inputs[idx]
2011        if isinstance(input, torch.Tensor) and (input.data_ptr() % ALIGNMENT) == 0:
2012            aligned_static_input_idxs.append(idx)
2013    if len(aligned_static_input_idxs) != len(static_input_idxs):
2014        return aligned_static_input_idxs
2015    return static_input_idxs
2016
2017
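# Illustrative sketch (same allocator-alignment assumption as above): index 1 is
# misaligned and index 2 is not a tensor, so only index 0 stays static.
def _example_remove_unaligned_input_idxs() -> None:
    base = torch.zeros(ALIGNMENT * 2, dtype=torch.uint8)
    inputs: List[InputType] = [base, base[1:], 7]
    assert remove_unaligned_input_idxs(inputs, [0, 1, 2]) == [0]

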
2018def set_tracing_context_output_strides(example_inputs, compiled_graph):
2019    # Return the output strides to the caller via TracingContext
2020    context = torch._guards.TracingContext.try_get()
2021    if context is not None and context.output_strides is not None:
2022        assert len(context.output_strides) == 0
2023        shape_env = shape_env_from_inputs(example_inputs)
2024        for exprs in compiled_graph.output_strides:
2025            if exprs is None:
2026                context.output_strides.append(None)
2027            else:
2028                context.output_strides.append(
2029                    tuple(
2030                        (
2031                            shape_env.evaluate_symexpr(e)
2032                            if shape_env is not None
2033                            else int(e)
2034                        )
2035                        for e in exprs
2036                    )
2037                )
2038