xref: /aosp_15_r20/external/pytorch/torch/_inductor/graph.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import functools
3import itertools
4import logging
5import operator
6import os
7import re
8import sys
9import time
10from collections import defaultdict
11from contextlib import contextmanager
12from typing import (
13    Any,
14    Callable,
15    DefaultDict,
16    Dict,
17    List,
18    Optional,
19    Set,
20    Tuple,
21    TYPE_CHECKING,
22    Union,
23)
24
25import sympy
26
27import torch
28import torch._logging
29import torch.fx
30from torch._decomp import get_decompositions
31from torch._dynamo.utils import defake, dynamo_timed
32from torch._logging import LazyString, trace_structured
33from torch._prims_common import make_channels_last_strides_for
34from torch._subclasses.fake_tensor import FakeTensor
35from torch.fx.experimental._backward_state import BackwardState
36from torch.fx.experimental.sym_node import magic_methods, method_to_operator
37from torch.fx.experimental.symbolic_shapes import (
38    free_unbacked_symbols,
39    has_free_symbols,
40    resolve_unbacked_bindings,
41    RuntimeAssert,
42    ShapeEnv,
43    SymTypes,
44)
45from torch.utils._mode_utils import no_dispatch
46from torch.utils._sympy.numbers import int_oo
47
48from . import config, ir
49from .codegen.common import (
50    DeviceOpOverrides,
51    get_device_op_overrides,
52    get_scheduling_for_device,
53    get_wrapper_codegen_for_device,
54    register_backend_for_device,
55)
56from .codegen.cpp_wrapper_cpu import CppWrapperCpu
57from .codegen.cpp_wrapper_cuda import CppWrapperCuda
58from .codegen.wrapper import WrapperCodeGen
59from .exc import (
60    CppWrapperCodeGenError,
61    LoweringException,
62    MissingOperatorWithDecomp,
63    MissingOperatorWithoutDecomp,
64)
65from .ir import (
66    Constant,
67    FixedLayout,
68    InputBuffer,
69    Pointwise,
70    Reduction,
71    StorageBox,
72    TensorBox,
73    TorchBindObject,
74)
75from .lowering import (
76    constrain_to_fx_strides,
77    FALLBACK_ALLOW_LIST,
78    fallback_handler,
79    fallback_node_due_to_unsupported_type,
80    layout_constraints,
81    lowerings,
82    make_fallback,
83    needs_realized_inputs,
84    unsupported_output_tensor,
85)
86from .sizevars import SizeVarAllocator
87from .utils import (
88    convert_shape_to_inductor,
89    gather_origins,
90    get_cloned_parameter_buffer_name,
91    get_sympy_Expr_dtype,
92    maybe_get_suppress_shape_guards_ctx,
93    should_assume_input_aligned,
94)
95from .virtualized import NullHandler, V
96
97if TYPE_CHECKING:
98    from torch._higher_order_ops.effects import _EffectType
99
100log = logging.getLogger(__name__)
101perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
102output_code_log = torch._logging.getArtifactLogger(__name__, "output_code")
103aten = torch.ops.aten
104
105_post_grad_graph_counter = itertools.count()
106
107if config.is_fbcode():
108    from torch._inductor.fb.utils import log_module_code
109else:
110
111    def log_module_code(*args, **kwargs):
112        pass
113
114
115def supported_dtype_of_cpp_wrapper(dtype, cuda):
116    supported_dtype = {
117        torch.float32,
118        torch.float64,
119        torch.int64,
120        torch.int32,
121        torch.int16,
122        torch.int8,
123        torch.uint8,
124        torch.bool,
125        torch.bfloat16,
126        torch.complex32,
127        torch.complex64,
128        torch.complex128,
129        torch.float16,
130    }
131    if cuda:
132        supported_dtype.add(torch.float8_e4m3fn)
133        supported_dtype.add(torch.float8_e5m2)
134        supported_dtype.add(torch.float8_e4m3fnuz)
135        supported_dtype.add(torch.float8_e5m2fnuz)
136
137    return dtype in supported_dtype
138
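# A minimal illustrative sketch: the helper below is hypothetical and unused.
# It shows how a caller could use supported_dtype_of_cpp_wrapper() to find
# dtypes that would block cpp_wrapper codegen for a given device.
def _example_unsupported_cpp_wrapper_dtypes(dtypes, cuda):
    # Return the subset of dtypes the cpp wrapper cannot handle on this device.
    return [d for d in dtypes if not supported_dtype_of_cpp_wrapper(d, cuda)]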
139
140def may_get_constant_buffer_dtype(constant_buffer):
141    assert isinstance(
142        constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
143    ), "get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer"
144    if isinstance(constant_buffer, sympy.core.numbers.Integer):
145        return torch.int64
146
147    if isinstance(constant_buffer, sympy.Expr):
148        return get_sympy_Expr_dtype(constant_buffer)
149
150    if constant_buffer.is_integer:
151        return torch.int64
152    elif constant_buffer.is_float:
153        return torch.float32
154    else:
155        return None
156
157
158def is_magic_method(op):
159    magic_ops = {method_to_operator(m) for m in magic_methods}
160    return op in magic_ops
161
162
163def getattr_recursive(obj, target):
164    target_atoms = target.split(".")
165    attr_itr = obj
166    for i, atom in enumerate(target_atoms):
167        if not hasattr(attr_itr, atom):
168            raise RuntimeError(
169                f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
170            )
171        attr_itr = getattr(attr_itr, atom)
172    return attr_itr
173
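# A minimal illustrative sketch (hypothetical, unused): getattr_recursive
# resolves the dotted targets carried by FX get_attr nodes one attribute at a
# time.
def _example_getattr_recursive():
    m = torch.nn.Sequential(torch.nn.Linear(2, 2))
    # Equivalent to m[0].weight, but driven by the dotted string an FX
    # get_attr node would carry as its target.
    return getattr_recursive(m, "0.weight")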
174
175def mark_nodes_dislike_padding(g):
176    """
177    Nodes like convolution/convolution_backward want their inputs to be dense.
178    If we pad their inputs, we end up with extra calls to copy kernels! On the other hand, padding usually helps reductions.
179
180    This pass finds nodes that dislike padding: nodes that can be reached
181    from a convolution/convolution_backward in the backward direction without
182    going through a reduction.
183    """
184    if not config.comprehensive_padding:
185        return
186    ops_dislike_padding = {
187        aten.convolution,
188        aten.convolution_backward,
189    }
190    # what's a better way to collect the reduction ops?
191    ops_like_padding = {
192        aten.var_mean,
193        aten.sum,
194        aten.mean,
195        aten.prod,
196        aten.any,
197        aten.amin,
198        aten.amax,
199        aten.min,
200        aten.max,
201        aten.argmin,
202        aten.argmax,
203        aten.scatter_reduce,
204    }
205
206    def _get_overload_packet(node):
207        return (
208            node.target._overloadpacket
209            if node.op == "call_function" and hasattr(node.target, "_overloadpacket")
210            else None
211        )
212
213    for cur in reversed(g.nodes):
214        op = _get_overload_packet(cur)
215        if not op:
216            continue
217        if op in ops_dislike_padding:
218            cur.meta["dislike_padding"] = True
219
220        if cur.meta.get("dislike_padding", False):
221            # propagate
222            for prior in cur.all_input_nodes:
223                prior_op = _get_overload_packet(prior)
224                if not prior_op:
225                    continue
226                if prior_op not in ops_like_padding:
227                    prior.meta["dislike_padding"] = True
228
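# A minimal illustrative sketch (hypothetical, unused): later passes can
# consult the flag written by mark_nodes_dislike_padding when deciding whether
# a node's layout may be padded.
def _example_node_dislikes_padding(node: torch.fx.Node) -> bool:
    return node.meta.get("dislike_padding", False)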
229
230class GraphLowering(torch.fx.Interpreter):
231    graph_outputs: List[ir.IRNode]
232
233    def symbolic_sizes_strides(self, ex: torch.Tensor):
234        """
235        Support dynamic shapes and dynamic strides by assigning variables
236        to each dimension.  We duck-shape tensors, so if two tensors
237        have the same size they get assigned the same symbolic variable.
238        """
239        if self.reuse_shape_env:
240            return convert_shape_to_inductor(ex.size()), convert_shape_to_inductor(
241                ex.stride()
242            )
243        else:
244            from torch._dynamo.source import ConstantSource
245
246            # TODO: this should not be needed once #93059 lands
247            # https://github.com/pytorch/pytorch/pull/94031#discussion_r1096044816
248            # TODO: make a dedicated UnknownSource for this?
249            # NB: This is using the legacy default behavior from
250            # create_symbolic_sizes_strides_storage_offset but we hope we can
251            # just delete this entirely
252            source = ConstantSource(
253                f"__inductor_unknown_tensor_{len(self._shape_env.var_to_val)}"
254            )
255            (
256                size,
257                stride,
258                _,
259            ) = self._shape_env.create_symbolic_sizes_strides_storage_offset(
260                ex,
261                source,
262            )
263
264        size = [i.node.expr if isinstance(i, torch.SymInt) else i for i in size]
265        stride = [i.node.expr if isinstance(i, torch.SymInt) else i for i in stride]
266        return size, stride
267
268    def static_sizes_strides(self, ex: torch.Tensor):
269        """
270        Primarily used for weights.
271        """
272        size = [sympy.Integer(i) for i in ex.size()]
273        stride = [sympy.Integer(i) for i in ex.stride()]
274        return size, stride
275
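    # A minimal illustrative sketch (hypothetical, unused): for a contiguous
    # (2, 3) weight, static_sizes_strides() lifts the concrete sizes/strides
    # into sympy integers, i.e. sizes [2, 3] and strides [3, 1]; dynamic
    # inputs instead go through the duck-shaped symbolic path above.
    def _example_static_sizes_strides(self):
        return self.static_sizes_strides(torch.empty(2, 3))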
276    def init_backend_registration(self):
277        if get_scheduling_for_device("cpu") is None:
278            from .codegen.cpp import CppScheduling
279
280            register_backend_for_device(
281                "cpu", CppScheduling, WrapperCodeGen, CppWrapperCpu
282            )
283
284        if get_scheduling_for_device("cuda") is None:
285            from .codegen.cuda_combined_scheduling import CUDACombinedScheduling
286
287            # CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation
288            register_backend_for_device(
289                "cuda", CUDACombinedScheduling, WrapperCodeGen, CppWrapperCuda
290            )
291
292        if get_scheduling_for_device("xpu") is None:
293            from .codegen.triton import TritonScheduling
294
295            register_backend_for_device("xpu", TritonScheduling, WrapperCodeGen)
296
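    # A minimal illustrative sketch (hypothetical, unused): an out-of-tree
    # backend could plug into the same registration hook used above. The
    # "my_device" key and the reuse of TritonScheduling are assumptions.
    def _example_register_custom_backend(self):
        from .codegen.triton import TritonScheduling

        if get_scheduling_for_device("my_device") is None:
            register_backend_for_device("my_device", TritonScheduling, WrapperCodeGen)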
297    def __init__(
298        self,
299        gm: torch.fx.GraphModule,
300        example_inputs: Optional[List[torch.Tensor]] = None,
301        shape_env=None,
302        graph_id=None,
303        cpp_wrapper=False,
304        aot_mode=False,
305        user_visible_outputs=None,
306        layout_opt=None,
307        extern_node_serializer=None,
308        is_inference=False,
309        is_const_graph=False,
310        const_output_index=None,
311        const_code=None,
312        const_module=None,
313        name=None,
314    ):
315        super().__init__(gm)
316        self.example_inputs = example_inputs
317        self.layout_opt = (
318            layout_opt
319            if layout_opt is not None
320            else self.decide_layout_opt(gm, is_inference=is_inference)
321        )
322        self.num_channels_last_conv = 0
323        self.is_inference = is_inference
324        self.is_const_graph = is_const_graph
325        self.const_code = const_code
326        self.const_module = const_module
327
328        self.extra_traceback = False  # we do our own error wrapping
329        if shape_env is None:
330            shape_env = ShapeEnv()
331            self.reuse_shape_env = False
332        else:
333            self._shape_env = shape_env
334            self.reuse_shape_env = True
335        self._shape_env = shape_env
336        # We are going to start code generating runtime asserts, so make sure
337        # you don't start adding new ones in the lowering process
338        shape_env.freeze_runtime_asserts()
339        # We're going to mutate ras_by_symbol as we finish generating them
340        self.ras_by_symbol: Dict[
341            sympy.Symbol, List[RuntimeAssert]
342        ] = shape_env.deferred_runtime_asserts.copy()
343        self.bound_unbacked_symbols: Set[sympy.Symbol] = set()
344        self.sizevars = SizeVarAllocator(shape_env)
345        self.graph_input_names: List[str] = []
346        self.graph_inputs: Dict[str, TensorBox] = {}
347        self.graph_inputs_original: Dict[str, InputBuffer] = {}
348        self.device_types: Set[str] = (
349            const_module.device_types if const_module else set()
350        )
351        self.device_idxs: Set[int] = const_module.device_idxs if const_module else set()
352        self.cuda = False
353        self.buffers: List[ir.Buffer] = []
354        self.const_output_index: Dict[str, int] = (
355            const_output_index if const_output_index else {}
356        )
357        self.folded_constants: Set[str] = (
358            set(const_output_index.keys()) if const_output_index else set()
359        )
360        self.constants: Dict[str, torch.Tensor] = (
361            const_module.constants if const_module else {}
362        )
363        self.torchbind_constants: Dict[str, torch._C.ScriptObject] = {}
364        self.constant_reprs: Dict[str, str] = {}
365        self.removed_buffers: Set[str] = set()
366        self.removed_inplace_buffers: Set[str] = set()
367        self.mutated_buffers: Set[str] = set()
368        self.never_reuse_buffers: Set[str] = set()
369        self.inplaced_to_remove: Set[str] = set()
370        self.device_ops: DeviceOpOverrides = None  # type: ignore[assignment]
371        self.wrapper_code: WrapperCodeGen = None  # type: ignore[assignment]
372        # See `ProxyExecutor Design Note` in ir.py for more details
373        self.extern_kernel_nodes: List[ir.ExternKernelNode] = []
374        self.extern_node_serializer: Optional[
375            Callable[[List[ir.ExternKernelNode]], Any]
376        ] = extern_node_serializer
377        self.current_node: torch.fx.Node = None  # type: ignore[assignment]
378        self.lists: Dict[str, List[str]] = {}
379        self.mutated_inputs: Set[str] = set()
380        self.mutated_input_idxs: List[int] = []
381        self.name_to_buffer: Dict[str, ir.Buffer] = {}
382        self.name_to_users: DefaultDict[str, List[ir.IRNode]] = defaultdict(list)
383        self.creation_time = time.time()
384        self.name = name
385        self.cpp_wrapper = cpp_wrapper
386
387        # Record the multi_kernel choice for cpp_wrapper so the second pass knows
388        # which sub-kernel is picked. Copy cpp_wrapper to another variable
389        # since the cpp_wrapper flag is set to False for the first pass of codegen.
390        self.record_multi_kernel_choice = cpp_wrapper
391        self.multi_kernel_to_choice: Dict[str, int] = {}
392
393        self.aot_mode = aot_mode
394        self.graph_id = graph_id
395        self.post_grad_graph_id = next(_post_grad_graph_counter)
396        self.scheduler: torch._inductor.scheduler.Scheduler = None  # type: ignore[assignment]
397        self.nodes_prefer_channels_last = (
398            self.find_nodes_prefer_channels_last() if self.layout_opt else set()
399        )
400        mark_nodes_dislike_padding(gm.graph)
401        self._warned_fallback = {"aten.convolution_backward"}
402        self.user_visible_outputs = (
403            user_visible_outputs if user_visible_outputs is not None else {}
404        )
405        self.cache_key: str = ""  # This is the cache key for the compiled artifact
406        self.cache_path: str = ""  # This is the path in the filesystem where the compiled artifact is stored
407        self.cache_linemap: List[
408            Tuple[int, str]
409        ] = (
410            []
411        )  # This is the linemap used by the profiler to mark custom compiled kernels getting run
412        # Used if lowering encounters cases where cudagraphs are not supported
413        self.disable_cudagraphs_reason: Optional[str] = None
414
415        # only keeping one node per device for stack trace purposes
416        self.device_node_mapping: Dict[torch.device, torch.fx.Node] = {}
417        self.orig_gm: torch.fx.GraphModule = gm.__copy__()
418        self.dynamo_flat_name_to_original_fqn = self.module.meta.get(
419            "dynamo_flat_name_to_original_fqn", {}
420        )
421        self.allocated_constant_name = (
422            const_module.allocated_constant_name if const_module is not None else {}
423        )
424        self.init_backend_registration()
425
426        self.effectful_ops: Dict[_EffectType, ir.Buffer] = {}
427
428        self.aligned_inputs: Set[str] = set()
429
430    @staticmethod
431    def decide_layout_opt(gm, *, is_inference) -> bool:
432        """
433        Decide if we should enable layout optimization for this graph based on
434        heuristics.
435        """
436        if not config.layout_optimization:
437            return False
438
439        if config.force_layout_optimization:
440            return True
441
442        conv_nodes = [
443            n for n in gm.graph.nodes if n.target == torch.ops.aten.convolution.default
444        ]
445        nconv = len(conv_nodes)
446
447        if nconv == 0:
448            return False
449
450        # For the CPU backend with mkldnn enabled, we always use channels_last for better performance.
451        if (
452            torch.backends.mkldnn.enabled
453            and torch.backends.mkldnn.is_available()
454            and all(
455                n.args[idx].meta["val"].device == torch.device("cpu")
456                for n in conv_nodes
457                for idx in [0, 1]
458            )
459        ):
460            return True
461
462        # Following models are skipped due to this:
463        # jx_nest_base
464        # volo_d1_224
465        if len(list(gm.graph.nodes)) >= 300 * nconv:
466            log.debug("Skipped layout opt because only a few conv")
467            return False
468
469        if any(
470            has_free_symbols(n.args[idx].meta["val"])
471            for n in conv_nodes
472            for idx in [0, 1]
473        ):
474            log.debug(
475                "See perf regression with dynamic shape. Follow up in https://github.com/pytorch/pytorch/issues/102670"
476            )
477            return False
478
479        def is_grouped(n):
480            return n.args[-1] > 1 and n.args[1].meta["val"].size(1) > 1
481
482        def is_in_out_channel(n):
483            return (
484                n.args[1].meta["val"].size(0) * 2 <= n.args[1].meta["val"].size(1)
485                and n.args[1].meta["val"].size(2) > 1
486            )
487
488        def is_small_channel(n):
489            return (
490                n.args[1].meta["val"].size(0) <= 64
491                and n.args[1].meta["val"].size(1) <= 64
492            )
493
494        # In inference, only grouped convolutions were benchmarked as slower in the conv samples.
495        if is_inference:
496            from torch.utils.flop_counter import FlopCounterMode
497
498            flop_counts: Dict[str, float] = defaultdict(float)
499            for node in conv_nodes:
500                success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs(
501                    node
502                )
503
504                if success:
505                    with FlopCounterMode(display=False) as flop_counter_mode:
506                        with V.fake_mode:
507                            node.target(*args, **kwargs)
508
509                    counted_flops = flop_counter_mode.get_total_flops()
510                    if is_grouped(node):
511                        node_type = "grouped"
512                    elif is_small_channel(node):
513                        node_type = "small"
514                    elif is_in_out_channel(node):
515                        node_type = "in_out"
516                    else:
517                        node_type = "default"
518
519                    flop_counts[node_type] += counted_flops
520                else:
521                    log.debug("Conv inputs meta not found")
522
523            # average benchmarked channels last speedup / slowdown, < 1 is speedup.
524            # taken from the set of convolution inputs in benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/
525            # To regenerate these numbers follow https://gist.github.com/eellison/55d7a6ed6f39829d68ac56f95f4df5bb
526            GROUPED_MULTIPLIER = 1.358
527            DEFAULT_MULTIPLIER = 0.823
528            IN_OUT_MULTIPLIER = 0.725
529            SMALL_MULTIPLIER = 0.783
530
531            total_flops = sum(flop_counts.values())
532            # TODO - get different values per hardware
533            weighted_flops = (
534                flop_counts["grouped"] * GROUPED_MULTIPLIER
535                + flop_counts["small"] * SMALL_MULTIPLIER
536                + flop_counts["in_out"] * IN_OUT_MULTIPLIER
537                + flop_counts["default"] * DEFAULT_MULTIPLIER
538            )
539            do_layout_opt = weighted_flops <= total_flops
540            if not do_layout_opt:
541                log.debug(
542                    "Skipped layout opt in inference because weighted flops indicate slowdown, default: %d, channels last: %d",
543                    total_flops,
544                    weighted_flops,
545                )
546            return do_layout_opt
547
548        # Channels last layout can dramatically hurt grouped conv perf. E.g.
549        # Conv with arguments like
550        #   {"input_shape": [32, 224, 112, 112], "weight_shape": [224, 112, 3, 3],
551        #    "stride": [2, 2], "padding": [1, 1], "groups": 2}
552        # slows down 31x using channels last.
553
554        # But a lot of timm models use depthwise separable convolutions, which
555        # result in grouped convolutions with in-channel size == 1.
556        # For those grouped convolutions, channels last still helps a lot.
557        # E.g.
558        # Conv with arguments
559        #   {"input_shape": [128, 58, 56, 56], "weight_shape": [58, 1, 3, 3],
560        #    "stride": [2, 2], "padding": [1, 1], "groups": 58}
561        # gets a 1.86x speedup with channels last layout.
562        #
563        # The following heuristic skips channels last if the model contains
564        # grouped convolutions with in-channels > 1.
565        if any(map(is_grouped, conv_nodes)):
566            log.debug(
567                "Skip layout opt because found grouped convolution with >1 in_channels!"
568            )
569            return False
570
571        # For some models that contain convolutions with more in-channels than out-channels, applying
572        # channels last hurts performance.
573        # Following models are skipped due to this:
574        # - pytorch_unet
575        # - phlippe_densenet (slightly worse)
576        # - Background_Matting (1.22x -> 0.821x)
577        # - pytorch_CycleGAN_and_pix2pix (1.597x -> 1.294x)
578        if any(map(is_in_out_channel, conv_nodes)):
579            log.debug(
580                "Skip layout opt because some convolutions have smaller out_channel"
581            )
582            return False
583
584        # Following models are skipped due to this:
585        # - functorch_maml_omniglot
586        if all(map(is_small_channel, conv_nodes)):
587            log.debug("Skip layout opt because all convolution channels are too small")
588            return False
589
590        return True
591
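    # A minimal illustrative sketch (hypothetical, unused): the inference
    # branch above compares raw conv FLOPs with FLOPs weighted by the
    # benchmarked channels-last multipliers; the counts below are made up.
    @staticmethod
    def _example_weighted_flops_decision():
        flop_counts = {"grouped": 0.0, "small": 1e9, "in_out": 0.0, "default": 4e9}
        weighted = (
            flop_counts["grouped"] * 1.358
            + flop_counts["small"] * 0.783
            + flop_counts["in_out"] * 0.725
            + flop_counts["default"] * 0.823
        )
        # 0.783e9 + 3.292e9 = 4.075e9 <= 5e9, so layout opt would be enabled here.
        return weighted <= sum(flop_counts.values())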
592    def qualify_name(self, name: str) -> str:
593        """Prepend the given name with the graph name if any."""
594        if self.name is not None:
595            return f"{self.name}_{name}"
596        return name
597
598    def make_subgraph(
599        self,
600        gm: torch.fx.GraphModule,
601        example_inputs: List[torch.Tensor],
602        subgraph_name: str,
603    ) -> "GraphLowering":
604        """
605        Make a subgraph of the current graph with all inherited
606        parts, except the graph module (`gm`) and `example_inputs`.
607        The subgraphs are lowered separately, but intended to be
608        inlined in the parent graph's codegen. Hence the need
609        for maintaining the same `shape_env` and other properties.
610        The subgraph name is qualified by the parent graph's name.
611        """
612        return GraphLowering(
613            gm=gm,
614            example_inputs=example_inputs,
615            shape_env=self._shape_env,
616            cpp_wrapper=self.cpp_wrapper,
617            aot_mode=self.aot_mode,
618            extern_node_serializer=self.extern_node_serializer,
619            is_inference=self.is_inference,
620            name=self.qualify_name(subgraph_name),
621        )
622
623    def find_nodes_prefer_channels_last(self):
624        """
625        The rules to decide if a node prefers channels last are simple:
626        1. it is an input/output of a convolution
627        2. one of its users prefers channels last
628
629        We have rule 1 because cudnn runs a faster convolution kernel for channels last inputs;
630        rule 2 is also important. It makes sure that indirect inputs to a convolution also prefer
631        channels last.
632
633        Consider the scenario: conv -> batch-norm -> relu -> conv
634        Without rule 2, the batch-norm output may use a contiguous layout. That will cause 2 extra copies:
635        1. the output of batch-norm should be channels last initially since its input is a conv's output.
636           Forcing the batch-norm's output to be contiguous results in the first copy.
637        2. The second conv's input is initially contiguous. This layout is propagated from the batch-norm's output.
638           We need to convert it to channels last layout, which results in the second copy.
639        With rule 2, we make sure all the tensors in the chain use channels last layout, so both copies
640        can be saved.
641        """
642        output_set = set()
643        for n in reversed(self.module.graph.nodes):
644            if n.target == torch.ops.aten.convolution.default:
645                output_set.add(n)
646                continue
647
648            for user in n.users:
649                if user in output_set:
650                    output_set.add(n)
651                    break
652
653        # We need a second pass to add downstream nodes of those channels last nodes to the set.
654        # This pass is especially needed to avoid mixed-layout kernel inputs in the backward pass.
655        #
656        # Let's say a conv-batchnorm's output is passed to relu whose output is in turn returned
657        # from the fwd graph. Without this second pass, we will force relu's output to be contiguous.
658        # Then in a kernel in the backward pass, the contiguous output of relu may be mixed with other channels last
659        # tensors and passed to a kernel.
660        #
661        # This pass improves yolov3 training speedup from 1.116x (worse than the 1.196x speedup from disabling layout optimization) to 1.457x.
662        # It also improves dla102 training speedup from 1.240x (worse than the 1.523x speedup from disabling layout optimization) to 1.835x.
663        # This also helps the following models:
664        # - res2net101_26w_4s
665        # - res2net50_14w_8s
666        # - sebotnet33ts_256
667        for n in self.module.graph.nodes:
668            if n in output_set:
669                output_set.update(n.users)
670
671        return output_set
672
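    # A minimal illustrative sketch (hypothetical, unused): run_node() later
    # consults this set, roughly as below, to request NHWC stride order for
    # 4-D intermediates that are not user-visible outputs.
    def _example_wants_channels_last(self, node: torch.fx.Node) -> bool:
        return (
            node in self.nodes_prefer_channels_last
            and node.name not in self.user_visible_outputs
        )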
673    def warn_fallback(self, name):
674        if name not in self._warned_fallback:
675            self._warned_fallback.add(name)
676            perf_hint_log.info("Using FallbackKernel: %s", name)
677
678    def add_device_info(self, device: torch.device):
679        self.device_types.add(device.type)
680        if device.index is not None:
681            self.device_idxs.add(device.index)
682        if V.graph.current_node and device not in self.device_node_mapping:
683            self.device_node_mapping[device] = V.graph.current_node
684
685    @property
686    def fake_mode(self):
687        return V.fake_mode
688
689    def get_buffer(self, buffer_name: str):
690        if buffer_name in self.name_to_buffer:
691            return self.name_to_buffer[buffer_name]
692        if buffer_name in self.graph_inputs:
693            return self.graph_inputs[buffer_name]
694        if buffer_name in self.constants:
695            data = V.graph.constants[buffer_name]
696            return ir.ConstantBuffer(
697                buffer_name,
698                ir.FixedLayout(
699                    data.device, data.dtype, *V.graph.static_sizes_strides(data)
700                ),
701            )
702        return None
703
704    def get_dtype(self, buffer_name: str):
705        if buffer_name in self.constants:
706            return self.constants[buffer_name].dtype
707        if buffer_name in self.name_to_buffer:
708            return self.name_to_buffer[buffer_name].get_dtype()
709        if buffer_name in self.graph_inputs:
710            return self.graph_inputs[buffer_name].get_dtype()
711        m = re.match(r"(as_strided|reinterpret_tensor)\(([a-zA-Z0-9_]+),", buffer_name)
712        if m:
713            return self.get_dtype(m.group(1))
714        raise KeyError(f"could not find {buffer_name}")
715
716    def get_numel(self, buffer_name: str):
717        from .ir import MultiOutputLayout
718
719        if buffer_name in self.constants:
720            return self.constants[buffer_name].numel()
721        if buffer_name in self.name_to_buffer:
722            buf = self.name_to_buffer[buffer_name]
723            if isinstance(getattr(buf, "layout", None), MultiOutputLayout):
724                return 1
725            return buf.get_numel()
726        if buffer_name in self.graph_inputs:
727            return self.graph_inputs[buffer_name].get_numel()
728        raise KeyError(f"could not find {buffer_name}")
729
730    @dynamo_timed
731    def run(self, *args):
732        return super().run(*args)
733
734    def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False):
735        name = self.qualify_name(f"buf{len(self.buffers)}")
736        self.buffers.append(buffer)
737        self.name_to_buffer[name] = buffer
738        # Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144
739        if (
740            not (isinstance(buffer, ir.ComputedBuffer) and buffer.is_zero_elements())
741            and buffer.get_device() is not None
742        ):
743            self.add_device_info(buffer.get_device())
744
745        if set_name:
746            buffer.name = name
747        return name
748
749    def register_list(self, buffer_names: List[str]):
750        name = self.qualify_name("list_" + "_".join(buffer_names))
751        self.lists[name] = buffer_names
752        return name
753
754    def register_users_of(self, node_output):
755        def register(value):
756            if isinstance(value, (list, tuple)):
757                for x in value:
758                    register(x)
759            if isinstance(value, ir.IRNode):
760                if (
761                    not hasattr(value, "data")
762                    or not isinstance(value.data, ir.IRNode)
763                    or not (
764                        hasattr(value.data, "data")
765                        and isinstance(value.data.data, ir.IRNode)
766                    )
767                ):
768                    return
769
770                for read_name in value.get_read_names():
771                    self.name_to_users[read_name].append(value)
772
773        register(node_output)
774
775    def mark_buffer_mutated(self, name: str):
776        """
777        When a buffer is mutated, we need to make sure all the reads of
778        the old version are realized before the mutation happens.
779        """
780        assert isinstance(name, str)
781        self.mutated_buffers.add(name)
782
783        if name not in self.name_to_users:
784            return
785
786        for user in self.name_to_users[name]:
787            user.realize()
788
789    def get_original_value_of_constant(self, name: str):
790        """
791        In AOTI, module buffers may have been mutated during the tracing and compilation.
792        Thus we need to read from previously stored original buffers, to make sure the
793        generated model.so uses correct initial values.
794        """
795        assert name in self.allocated_constant_name and name in self.constants, (
796            "Can not find the original value for " + name
797        )
798        orig_name = get_cloned_parameter_buffer_name(self.allocated_constant_name[name])
799        return (
800            self.module.meta[orig_name]
801            if orig_name in self.module.meta
802            else self.constants[name]
803        )
804
805    def allocate_non_dup_const_name(self, name, data):
806        orig_name = name
807        if not config.aot_inductor.use_runtime_constant_folding:
808            for constant_name, value in self.constants.items():
809                if (
810                    not data.is_mkldnn
811                    and data.size() == value.size()
812                    and data.stride() == value.stride()
813                    and data.dtype == value.dtype
814                    and data.device == value.device
815                    and data.untyped_storage().data_ptr()
816                    == value.untyped_storage().data_ptr()
817                    and data.storage_offset() == value.storage_offset()
818                ):
819                    return constant_name
820
821        if name is None:
822            name = f"constant{len(self.constants)}"
823        if name[0].isdigit():
824            name = f"constant_{name}"
825        name = self.qualify_name(name)
826        # We may generate a var name for each constant in the codegen.
827        # Let's only keep sane characters.
828        prefix = re.sub(r"[^a-zA-Z0-9_]", "_", name)
829        name = prefix
830        cnt = 0
831        while name in self.constants:
832            name = f"{prefix}_{cnt}"
833            cnt += 1
834        self.constants[name] = data
835        self.constant_reprs[name] = (
836            f"{data.device!r} {data.dtype!r} "
837            f"{tuple(data.size())!r} {tuple(data.stride())!r} "
838            f"{hash(data):x}"
839        )
840        self.allocated_constant_name[name] = orig_name
841        return name
842
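    # A minimal illustrative sketch (hypothetical, unused): two constants
    # aliasing the same storage (same size/stride/dtype/device/data_ptr/offset)
    # deduplicate to one registered name when runtime constant folding is off.
    def _example_constant_dedup(self):
        w = torch.zeros(4)
        name_a = self.allocate_non_dup_const_name("w", w)
        name_b = self.allocate_non_dup_const_name("w_alias", w)
        return name_a == name_b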
843    def add_tensor_constant(self, data, name=None):
844        new_name = self.allocate_non_dup_const_name(name, data)
845        return TensorBox.create(
846            ir.ConstantBuffer(
847                new_name,
848                FixedLayout(data.device, data.dtype, *self.static_sizes_strides(data)),
849            )
850        )
851
852    def constant_name(self, name: str, device_override: Optional[torch.device]):
853        """
854        We AOT copy constants to the devices they are needed on.
855        If device_override doesn't match the constant's device, then
856        copy it and return a different name.
857        """
858        if self.constants[name].device == device_override or device_override is None:
859            return name
860        with torch.utils._python_dispatch._disable_current_modes():
861            # caller might have set fake tensor mode which will create a fake tensor
862            # when calling .to, so unset modes here
863            return self.allocate_non_dup_const_name(
864                f"{name}_{device_override.type}{device_override.index or 0}",
865                self.constants[name].to(device_override),
866            )
867
868    def placeholder(self, target: str, args, kwargs):
869        example = super().placeholder(target, args, kwargs)
870        self.graph_input_names.append(target)
871        if isinstance(example, SymTypes):
872            expr = example.node.expr
873            self.graph_inputs[target] = expr
874            return expr
875        elif isinstance(example, (int, bool, float)):
876            expr = sympy.sympify(example)
877            self.graph_inputs[target] = expr
878            return expr
879        if isinstance(example, BackwardState):
880            # Ignored arg, must be unused
881            # Alternately we could filter this out in AotAutograd
882            return None
883        assert isinstance(example, torch.Tensor), example
884        # todo(chilli): We can remove the last check once we turn buffers into
885        # static shape tensors. That's a hack to work around Inductor believing
886        # the buffer should be static but us passing in a fake tensor with
887        # symbolic shapes.
888        if not example._has_symbolic_sizes_strides:
889            # the first N inputs are weights
890            sizes, strides = self.static_sizes_strides(example)
891        else:
892            sizes, strides = self.symbolic_sizes_strides(example)
893        # TODO(jansel): handle input aliasing
894        target = self.qualify_name(target)
895        tensor = TensorBox.create(
896            InputBuffer(
897                target,
898                FixedLayout(example.device, example.dtype, sizes, strides),
899            )
900        )
901        self.graph_inputs[target] = tensor
902        self.graph_inputs_original[target] = tensor.data.data
903        self.add_device_info(example.device)
904
905        # Note: [Input Alignment handling in Inductor]
906        # Alignment matters for generating efficient code. Some operations,
907        # e.g. vectorized loads, can only be performed on aligned inputs.
908        #
909        # But if we codegen assuming aligned inputs and then get unaligned
910        # inputs at runtime, then we are forced to clone - which is bad for
911        # both perf and memory usage.
912        #
913        # One option would be to guard on storage_offset%ALIGNMENT, and then
914        # codegen based on this. But storage_offset guards turned out to be
915        # expensive and cause recompiles; instead, we're generating code
916        # based on the alignment of the example input without guarding.
917        with maybe_get_suppress_shape_guards_ctx():
918            if should_assume_input_aligned(example):
919                self.aligned_inputs.add(target)
920        return tensor
921
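    # A minimal illustrative sketch (hypothetical, an assumption about the real
    # helper): the alignment decision above boils down to inspecting the example
    # input's data pointer rather than installing a guard, roughly:
    @staticmethod
    def _example_is_aligned(example: torch.Tensor, alignment_bytes: int = 16) -> bool:
        return example.data_ptr() % alignment_bytes == 0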
922    def call_function(self, target, args, kwargs):
923        if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
924            return super().call_function(target, args, kwargs)
925
926        if hasattr(target, "_inductor_lowering_function"):
927            # passthrough lowerings from .pattern_matcher
928            return target(*args, **kwargs)
929
930        def get_custom_op_layout_constraints(target, args, kwargs):
931            # Custom operations that require a preserved stride order and
932            # that run through the implicit fallback must constrain their
933            # arguments' fx strides.
934            layout_constraint = None
935            if torch._C.Tag.needs_fixed_stride_order in target.tags:
936                # We have to set the current args because call_function will immediately
937                # evaluate this lowering after creating the fallback, without evaluating
938                # the layout constraint
939                constrain_fn = functools.partial(
940                    constrain_to_fx_strides, ignore_mutated_args_FIXME=True
941                )
942                args, kwargs = constrain_fn(self.current_node, *args, **kwargs)
943                # Also register the layout constraint so when the fallback
944                # is used again, we can constrain the args to the same layout
945                layout_constraint = constrain_fn
946            return layout_constraint, args, kwargs
947
948        if target not in lowerings:
949            assert isinstance(
950                target, torch._ops.OpOverload
951            ), f"{target} is not an OpOverload"
952            base_name = target.name().split(".")[0]
953            if base_name in FALLBACK_ALLOW_LIST:
954                make_fallback(target)
955            elif config.implicit_fallbacks:
956                layout_constraint, args, kwargs = get_custom_op_layout_constraints(
957                    target, args, kwargs
958                )
959                error = (
960                    MissingOperatorWithDecomp
961                    if get_decompositions([target])
962                    else MissingOperatorWithoutDecomp
963                )
964                log.info(
965                    "Creating implicit fallback for:\n%s",
966                    error.operator_str(target, args, kwargs),
967                )
968                make_fallback(target, layout_constraint)
969
970            elif get_decompositions([target]):
971                # There isn't a good way to dynamically patch this in
972                # since AOT Autograd already ran.  The error message tells
973                # the user how to fix it.
974                raise MissingOperatorWithDecomp(target, args, kwargs)
975            else:
976                raise MissingOperatorWithoutDecomp(target, args, kwargs)
977
978        try:
979            log.debug("  via %s", lowerings[target])
980            out = lowerings[target](*args, **kwargs)
981            return out
982        except Exception as e:
983            raise LoweringException(e, target, args, kwargs).with_traceback(
984                e.__traceback__
985            ) from None
986
987    @staticmethod
988    def can_inline_constant(t: torch.Tensor) -> bool:
989        """
990        True if this is a small constant attr that will be inlined.
991        """
992        return len(t.shape) == 1 and t.shape[0] <= 8
993
994    def get_attr(self, target, args, kwargs):
995        # this is a constant
996        value = getattr_recursive(self.module, target)
997
998        if isinstance(value, torch.fx.GraphModule):
999            return ir.Subgraph(name=target, graph_module=value)
1000
1001        if isinstance(value, torch._C.ScriptObject):
1002            self.torchbind_constants[target] = value
1003            self.constant_reprs[target] = ""
1004            return TorchBindObject(target, value)
1005
1006        if (
1007            config.aot_inductor.use_runtime_constant_folding
1008            or config.always_keep_tensor_constants
1009            or unsupported_output_tensor(value)
1010        ):
1011            return self.add_tensor_constant(value, target)
1012
1013        with no_dispatch():
1014            if value.shape == ():
1015                return Constant(value.item(), value.dtype, value.device)
1016            if self.can_inline_constant(value):
1017                # tensor lowering has constant inlining logic
1018                from .lowering import tensor
1019
1020                return tensor(value.tolist(), dtype=value.dtype, device=value.device)
1021
1022        return self.add_tensor_constant(value, target)
1023
1024    def call_module(self, target, args, kwargs):
1025        raise AssertionError
1026
1027    def call_method(self, target, args, kwargs):
1028        raise AssertionError
1029
1030    def output(self, target, args, kwargs):
1031        result = super().output(target, args, kwargs)
1032        if not isinstance(result, (tuple, list)):
1033            # nested subgraphs can have singleton outputs
1034            result = (result,)
1035        assert isinstance(result, (tuple, list)), type(result)
1036        assert all(
1037            isinstance(
1038                x,
1039                (
1040                    TensorBox,
1041                    ir.Constant,
1042                    type(None),
1043                    ir.ConstantBuffer,
1044                    sympy.Expr,
1045                    sympy.logic.boolalg.Boolean,
1046                    int,
1047                    ir.EffectfulKernel,
1048                ),
1049            )
1050            for x in result
1051        ), result
1052
1053        fx_node_args = V.graph.current_node.args[0]  # type: ignore[arg-type]
1054        if not isinstance(fx_node_args, (tuple, list)):
1055            # nested subgraphs can have singleton outputs
1056            fx_node_args = (fx_node_args,)
1057        result = [ir.ExternKernel.realize_input(x) for x in result]
1058        result_correct_strides = []
1059
1060        assert len(fx_node_args) == len(result)
1061        for r, fx_node in zip(result, fx_node_args):
1062            if not isinstance(r, (ir.TensorBox, ir.BaseView)):
1063                result_correct_strides.append(r)
1064            else:
1065                # AOT Autograd tries to detect stride divergence of inductor from output metadata.
1066                # Here, we try to avoid spurious divergence by matching insignificant strides such as
1067                result_correct_strides.append(
1068                    self.try_match_insignificant_strides(
1069                        r, fx_node.meta["val"].stride()
1070                    )
1071                )
1072
1073        self.graph_outputs = result_correct_strides
1074        value: ir.IRNode
1075        for name, value in self.graph_inputs.items():
1076            assert isinstance(
1077                value, (TensorBox, sympy.Expr)
1078            ), f"Unsupported inductor graph input type: {type(value)}"
1079            if not isinstance(value, TensorBox):
1080                continue
1081            value.realize()
1082            assert isinstance(value, TensorBox)
1083            value = value.data
1084            assert isinstance(value, ir.StorageBox)
1085            value_storage_box = value
1086            value = value.data
1087            if not isinstance(value, InputBuffer) or value.get_name() != name:
1088                # one of our inputs was mutated, need to turn that into a copy
1089                ir.MutationLayoutSHOULDREMOVE.realize_into(
1090                    value, self.graph_inputs_original[name]
1091                )
1092                # replace output with mutated input
1093                try:
1094                    ind = self.graph_outputs.index(value_storage_box)
1095                    self.graph_outputs[ind] = self.graph_inputs_original[name]
1096                except ValueError:
1097                    pass
1098
1099        self.finalize()
1100        log.debug(
1101            "Force channels last inputs for %d conv for the current graph with id %d",
1102            self.num_channels_last_conv,
1103            self.graph_id if self.graph_id is not None else -1,
1104        )
1105
1106    def finalize(self):
1107        for buf in self.buffers:
1108            buf.decide_layout()
1109
1110    @contextmanager
1111    def set_current_node(self, node: torch.fx.Node):
1112        old = self.current_node
1113        try:
1114            self.current_node = node
1115            yield
1116        finally:
1117            self.current_node = old
1118
1119    def try_match_insignificant_strides(
1120        self,
1121        tensor,
1122        meta_strides_inp: Tuple[Union[int, torch.SymInt], ...],
1123    ) -> ir.TensorBox:
1124        """
1125        Tries to match the strides of the tensor to those in the meta_strides. Strides of insignificant
1126        dimensions - size 0 or 1 - will be updated.
1127
1128        If there are real stride differences (NHWC vs NCHW) then the input will be returned.
1129        """
1130
1131        # should have already been realized
1132        assert torch._inductor.ir.is_storage_and_layout(tensor)
1133
1134        meta_strides = [
1135            s.node.expr if isinstance(s, torch.SymInt) else s for s in meta_strides_inp
1136        ]
1137
1138        if all(
1139            self.sizevars.statically_known_equals(s1, s2)
1140            for s1, s2 in zip(meta_strides, tensor.get_stride())
1141        ):
1142            return tensor
1143
1144        def significant_strides_equal(shape, meta_strides, tensor_strides):
1145            for dim, s1, s2 in zip(shape, meta_strides, tensor_strides):
1146                if self.sizevars.statically_known_leq(dim, 1):  # type: ignore[arg-type]
1147                    continue
1148
1149                if not self.sizevars.statically_known_equals(s1, s2):
1150                    return False
1151
1152            return True
1153
1154        if not significant_strides_equal(
1155            tensor.get_size(), meta_strides, tensor.get_stride()
1156        ):
1157            return tensor
1158
1159        storage, old_layout = torch._inductor.ir.as_storage_and_layout(tensor)
1160        new_stride = list(old_layout.stride)
1161        for i, s in enumerate(tensor.get_size()):
1162            if self.sizevars.statically_known_leq(s, 1):  # type: ignore[arg-type]
1163                new_stride[i] = meta_strides[i]
1164
1165        new_layout = torch._inductor.ir.FixedLayout(
1166            old_layout.device,
1167            old_layout.dtype,
1168            old_layout.size,
1169            new_stride,
1170            old_layout.offset,
1171        )
1172        return ir.TensorBox(torch._inductor.ir.ReinterpretView(storage, new_layout))
1173
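    # A minimal illustrative sketch (hypothetical, unused): strides of size-0/1
    # dimensions never affect addressing, so e.g. strides (4, 1) and (1, 1)
    # describe the same memory for a (1, 4) view of the same storage.
    @staticmethod
    def _example_insignificant_stride() -> bool:
        base = torch.arange(4.0)
        a = base.as_strided((1, 4), (4, 1))
        b = base.as_strided((1, 4), (1, 1))
        return torch.equal(a, b)  # True: only dims with size > 1 matter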
1174    def run_node(self, n: torch.fx.Node):
1175        def debug(msg):
1176            log.debug("lowering %s %s", LazyString(n.format_node), msg)
1177
1178        buffer_watermark = len(self.buffers)
1179
1180        origins = {n}
1181        if n.op == "call_function":
1182            args, kwargs = self.fetch_args_kwargs_from_env(n)
1183            origins |= gather_origins(args, kwargs)
1184        with ir.IRNode.current_origins(origins), self.set_current_node(
1185            n
1186        ), V.set_current_node(n):
1187            if (
1188                n.op == "call_function"
1189                and n.target is not operator.getitem
1190                and fallback_node_due_to_unsupported_type(n)
1191            ):
1192                debug("fallback_handler")
1193                result = fallback_handler(n.target, add_to_fallback_set=False)(
1194                    *args, **kwargs  # type: ignore[possibly-undefined]
1195                )
1196            elif n.op == "call_function" and n.target in layout_constraints:
1197                debug("layout_constraints")
1198                args, kwargs = layout_constraints[n.target](n, *args, **kwargs)  # type: ignore[index]
1199                result = self.call_function(n.target, args, kwargs)
1200            elif is_magic_method(n.target):
1201                # TODO: this is sus, it probably should be handled in the
1202                # lowerings themselves similarly to sym_size/sym-stride
1203                # https://github.com/pytorch/pytorch/issues/127789
1204                debug("is_magic_method")
1205                if isinstance(
1206                    n.meta["val"], (torch.SymInt, torch.SymFloat, torch.SymBool)
1207                ):
1208                    result = n.meta["val"].node.expr
1209                else:
1210                    result = super().run_node(n)
1211            else:
1212                debug("")
1213                result = super().run_node(n)
1214
1215            # Require the same stride order for dense outputs, so that:
1216            # 1. user-land view() will not throw because inductor
1217            # outputs different strides than eager.
1218            # Long term the solution is to make view() always succeed
1219            # with infallible strides.
1220            # 2. for as_strided ops, we need to make sure their inputs have the same
1221            # size/stride as in eager to align with eager behavior.
1222            as_strided_ops = [
1223                torch.ops.aten.as_strided.default,
1224                torch.ops.aten.as_strided_.default,
1225                torch.ops.aten.as_strided_scatter.default,
1226                torch.ops.aten.resize.default,
1227                torch.ops.aten.resize_as.default,
1228            ]
1229            is_output = any(user.op == "output" for user in n.users)
1230            is_input_for_as_strided = any(
1231                user.target in as_strided_ops for user in n.users
1232            )
1233
1234            if n.meta.get("inductor_realize_to_strides", False) and isinstance(
1235                result, TensorBox
1236            ):
1237                result.realize()
1238                strides = n.meta["val"].stride()
1239                sym_strides = torch._inductor.utils.any_is_symbolic(*strides)
1240                if (
1241                    not hasattr(result, "get_stride")
1242                    or result.get_stride() != strides
1243                    and not sym_strides
1244                ):
1245                    stride_order = ir.get_stride_order(strides)
1246                    result = ir.ExternKernel.require_stride_order(result, stride_order)
1247            if (
1248                is_output
1249                and isinstance(result, TensorBox)
1250                and isinstance(result.data, ir.BaseView)
1251            ):
1252                # Realize so that outputs are correctly aliased
1253                result.realize()
1254
1255            if (is_output or is_input_for_as_strided) and isinstance(
1256                n.meta["val"], torch.Tensor
1257            ):
1258                strides = n.meta["val"].stride()
1259                dense = torch._prims_common.is_non_overlapping_and_dense(n.meta["val"])
1260                unbacked_symbols_in_strides = len(free_unbacked_symbols(strides)) > 0
1261                # requiring a stride order for a non-dense output wouldn't
1262                # recreate the same strides and would fail with view; defer for now.
1263                if not unbacked_symbols_in_strides and dense and len(strides):
1264                    stride_order = ir.get_stride_order(strides)
1265                    if (
1266                        len(result.get_size()) == 4
1267                        and n in self.nodes_prefer_channels_last
1268                        and n.name not in self.user_visible_outputs
1269                        and not is_input_for_as_strided
1270                    ):
1271                        stride_order = ir.NHWC_STRIDE_ORDER
1272
1273                    allow_padding = (
1274                        n.name not in self.user_visible_outputs
1275                        and not is_input_for_as_strided
1276                    )
1277                    result = ir.ExternKernel.require_stride_order(
1278                        result, stride_order, allow_padding=allow_padding
1279                    )
1280
1281            # Realize if (1) any user needs inputs realized, or (2) there are
1282            # already too many reads and rematerializing can be bad.
1283            num_users = len(set(n.users))
1284            if num_users > 1 and isinstance(result, TensorBox):
1285                for user in n.users:
1286                    if user.target in needs_realized_inputs:
1287                        result.realize_hint()
1288                        # This inclusion is somewhat controversial (from
1289                        # discussion between Horace, Natalia, and Elias).
1290                        # Currently, it's not very clear why this is helpful.
1291                        # The general idea here is that even though a node may
1292                        # have FlexibleLayout, we still often *treat* it as if
1293                        # it was contiguous. This appears to sometimes result in
1294                        # suboptimal behavior.
1295                        #
1296                        # When we do a better job selecting layout, we should
1297                        # revisit this.
1298                        need_fixed_layout = [
1299                            torch.ops.aten.convolution_backward.default,
1300                            torch.ops.aten.mm.default,
1301                            torch.ops.aten._int_mm.default,
1302                        ]
1303                        need_fixed_channels_last_layout = []
1304                        if not self.layout_opt:
1305                            need_fixed_layout.append(torch.ops.aten.convolution.default)
1306                        if torch._C._has_mkldnn:
1307                            need_fixed_layout += [
1308                                torch.ops.mkldnn._linear_pointwise.default,
1309                                torch.ops.mkldnn._linear_pointwise.binary,
1310                                torch.ops.aten.mkldnn_rnn_layer.default,
1311                                torch.ops.onednn.qlinear_pointwise.default,
1312                                torch.ops.onednn.qlinear_pointwise.tensor,
1313                                torch.ops.onednn.qlinear_pointwise.binary,
1314                                torch.ops.onednn.qlinear_pointwise.binary_tensor,
1315                            ]
1316                            need_fixed_channels_last_layout += [
1317                                torch.ops.mkldnn._convolution_pointwise.default,
1318                                torch.ops.mkldnn._convolution_pointwise.binary,
1319                                torch.ops.mkldnn._convolution_pointwise_.binary,
1320                                torch.ops.mkldnn._convolution_transpose_pointwise.default,
1321                                torch.ops.onednn.qconv2d_pointwise.default,
1322                                torch.ops.onednn.qconv2d_pointwise.binary,
1323                            ]
1324                            if torch._C.has_mkl:
1325                                need_fixed_layout += [torch.ops.mkl._mkl_linear.default]
1326                        if user.target in need_fixed_layout:
1327                            result = ir.ExternKernel.require_stride_order(
1328                                result,
1329                                ir.get_stride_order(n.meta["val"].stride()),
1330                                allow_padding=True,
1331                            )
1332                        if (
1333                            user.target in need_fixed_channels_last_layout
1334                            and n is user.args[0]
1335                        ):
1336                            result = ir.ExternKernel.require_stride_order(
1337                                result,
1338                                ir.get_stride_order(
1339                                    make_channels_last_strides_for(n.meta["val"].shape)
1340                                ),
1341                            )
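                        # Outputs that are still unrealized Pointwise/Reduction IR are
                        # realized so they end up as named buffers rather than inlined
                        # computations.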
1342                    if user.op == "output":
1343                        if isinstance(result.data.data, (Pointwise, Reduction)):
1344                            result.realize()
1345
1346                # TODO(jansel): introduce a store vs inline choice
1347                result.mark_reuse(len(n.users))
1348
1349            # Realize if the IRNode already has accumulated lots of reads
1350            if isinstance(result, TensorBox) and result.has_exceeded_max_reads():
1351                # Prevent excessive accumulation in a computed buffer when there
1352                # are multiple branches, each with a small number of memory reads,
1353                # that converge on a single user.
1354                result.realize_hint()
1355
1356            # Realize if a Pointwise has too much stuff to be inlined,
1357            # as this may cause a RecursionError during Inductor's evaluation.
1358            if isinstance(result, TensorBox) and isinstance(result.data, StorageBox):
1359                curr = result.data.data
1360                if isinstance(curr, Pointwise):
1361                    # Use inner fn as a rough proxy. Good enough.
1362                    if curr.has_large_inner_fn():
1363                        result.realize()
1364
1365        # This is not complete, but it doesn't have to be: origin_node
1366        # tracking is best effort.  The logic here critically relies on direct
1367        # TensorBox -> StorageBox denoting a non-view; we don't bother trying
1368        # to get views to work.  Feel free to add any extra cases as needed.
1369        #
1370        # Note: we can't YOLO tree_map over this result, because if there are
1371        # buffers or a view involved, we might not be able to validly assign
1372        # the origin_node here.
1373        if isinstance(result, TensorBox) and isinstance(result.data, ir.StorageBox):
1374            if isinstance(result.data.data, ir.Loops):
1375                result.data.data.origin_node = n
1376            elif isinstance(result.data.data, ir.Buffer):
1377                result.data.data.origin_node = n
1378                if isinstance(result.data.data, ir.ComputedBuffer) and isinstance(
1379                    result.data.data.data, ir.Loops
1380                ):
1381                    result.data.data.data.origin_node = n
1382                # Not really multi-output, can straightforwardly recurse in
1383                elif (
1384                    isinstance(result.data.data, ir.MultiOutput)
1385                    and not result.data.data.indices
1386                ):
1387                    if isinstance(result.data.data.inputs[0], ir.Buffer):
1388                        result.data.data.inputs[0].origin_node = n
1389
1390        self.register_users_of(result)
1391
1392        new_unbacked_defs = set()
1393        for i in range(buffer_watermark, len(self.buffers)):
1394            new_unbacked_defs |= self.buffers[i].get_unbacked_symbol_defs()
1395
1396        def format_buffers():
1397            r = []
1398            for b in self.buffers[buffer_watermark:]:
1399                r.append(
1400                    f"unbacked_symbol_defs={b.get_unbacked_symbol_defs()} in:\n{b}\n"
1401                )
1402            return "***\n".join(r)
1403
1404        if n.op != "placeholder":
1405            # Note [Backwards runtime asserts]
1406            # Backwards poses an interesting problem for deferred runtime
1407            # asserts.  In the easy case, we may solely close over data
1408            # dependent sized tensors, and there are no binding sites for
1409            # unbacked SymInts.  In this case, we can just drop all the
1410            # runtime asserts on the floor: no non-placeholder bindings, no
1411            # problem.
1412            #
1413            # However, it is *possible* for a fresh runtime assert to show up
1414            # between forwards and backwards.  Right now, the freezing process
1415            # that happens when we lower forwards means that we will freeze
1416            # runtime asserts, and then the moment the backwards lowering
1417            # process attempts to add a new deferred runtime assert, we will
1418            # fail.  Let's say you remove that assert.  Now when we get here,
1419            # we need to make sure we actually emit these asserts (because we
1420            # can't emit them in forwards, we already compiled it).  So we
1421            # have to do something here.  But we don't want to reemit ALL
1422            # deferred runtime asserts; we only want to emit the NEW ones,
1423            # which requires some sort of stratification in the ShapeEnv.
1424            # This is all doable, it just hasn't been done yet.
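                # For example, a backward graph that calls .item() on a saved tensor
                # may introduce a fresh unbacked SymInt with a deferred assert (say
                # u0 >= 0) that has no binding site in forwards, so it has to be
                # emitted against the definitions produced while lowering this graph.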
1425            shape_env = V.graph.sizevars.shape_env
1426
1427            for i0 in new_unbacked_defs:
1428                ras = self.ras_by_symbol.pop(i0, [])
1429                # NB: size-like not needed, we won't retrace
1430                vr = shape_env.var_to_range[i0]
1431                if not shape_env._default_unspecified_value_range().issubset(vr):
1432
1433                    def is_convertible(s):
1434                        if s in (int_oo, -int_oo):
1435                            return False
1436                        try:
1437                            int(s)
1438                            return True
1439                        except TypeError:
1440                            return False
1441
1442                    if is_convertible(vr.lower):
1443                        self.register_buffer(
1444                            ir.AssertScalar(i0 >= vr.lower, f"{i0} >= {vr.lower}"),
1445                            set_name=True,
1446                        )
1447                    if is_convertible(vr.upper):
1448                        self.register_buffer(
1449                            ir.AssertScalar(i0 <= vr.upper, f"{i0} <= {vr.upper}"),
1450                            set_name=True,
1451                        )
1452
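                    # Emit each deferred runtime assert whose free unbacked symbols are
                    # all bound; otherwise re-queue it on the first missing symbol so it
                    # is retried once that symbol is defined.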
1453                for ra in ras:
1454                    fvs = free_unbacked_symbols(ra.expr)
1455                    missing = fvs - self.bound_unbacked_symbols
1456                    if missing:
1457                        i1 = sorted(missing, key=lambda x: str(x))[0]
1458                        self.ras_by_symbol.setdefault(i1, []).append(ra)
1459                    else:
1460                        self.register_buffer(
1461                            ir.AssertScalar(ra.expr, f"{ra.expr}"), set_name=True
1462                        )
1463
1464            self.bound_unbacked_symbols |= new_unbacked_defs
1465
1466            unbacked_bindings = resolve_unbacked_bindings(
1467                V.graph.sizevars.shape_env, n.meta.get("unbacked_bindings", {})
1468            )
1469            # When we do lowering, it is possible we reallocate unbacked SymInts.
1470            # So we need to line up the unbacked SymInts when performing the test
1471            # here
1472            #
1473            # In principle, we could permit lowering to introduce MORE unbacked
1474            # SymInts: as long as all the old unbacked ones are accounted for,
1475            # it's fine for Inductor to introduce extra calls to item() and the
1476            # like.  This actually happens in practice when an unbacked SymInt
1477            # gets memoized away; naively, when Inductor reprocesses a kernel, it
1478            # doesn't know that the memo still applies, and ends up allocating a
1479            # new symbol.  However, this is generally a bad thing: we may still
1480            # end up needing to test equalities on the symbols, and a fresh
1481            # symbol is likely to hit lots of GuardOnDataDependent errors that
1482            # we already know facts for.
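                # e.g. if the fx graph recorded u0 but lowering re-allocated the same
                # value under a new name u1, unbacked_renamings maps u0 -> u1, so the
                # assert below compares the fx-side bindings to Inductor's definitions
                # under their current names.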
1483            renamed_unbacked_bindings = {
1484                V.fake_mode.shape_env.unbacked_renamings.get(s, s)
1485                for s in unbacked_bindings.keys()
1486            }
1487            assert new_unbacked_defs >= renamed_unbacked_bindings, (
1488                f"failed {new_unbacked_defs} >= {renamed_unbacked_bindings} (inductor >= fx)\n"
1489                f"fx node is: {n.format_node()}\n"
1490                f"new buffers are:\n\n{format_buffers()}"
1491            )
1492
1493        return result
1494
1495    def validate_can_generate_cpp_wrapper(self):
1496        if config.disable_cpp_codegen:
1497            raise CppWrapperCodeGenError("C++ codegen is disabled")
1498
1499        if sys.platform not in ["linux", "darwin"]:
1500            raise CppWrapperCodeGenError(f"Unsupported platform {sys.platform}")
1501
1502        for value in self.graph_inputs.values():
1503            dtype = None
1504            if isinstance(value, TensorBox):
1505                dtype = value.get_dtype()
1506            elif isinstance(
1507                value, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
1508            ):
1509                dtype = may_get_constant_buffer_dtype(value)
1510
1511            if not supported_dtype_of_cpp_wrapper(dtype, self.cuda):
1512                raise CppWrapperCodeGenError(f"Unsupported input dtype {dtype}")
1513
1514    def init_wrapper_code(self):
1515        self.cuda = "cuda" in self.device_types
1516        if self.cpp_wrapper:
1517            self.validate_can_generate_cpp_wrapper()
1518
1519        device_types = self.device_types.copy()
1520        device_types.discard("cpu")
1521        device_types.discard("meta")
1522        # TODO(Eikan): Currently we only support mixing cpu with one other device type.
1523        assert len(device_types) <= 1, "Does not support mixing {}".format(
1524            "+".join(device_types)
1525        )
1526        only_cpu = len(device_types) == 0
1527        device_type = "cpu" if only_cpu else device_types.pop()
1528
1529        self.device_ops = get_device_op_overrides(device_type)
1530        wrapper_code_gen_cls = get_wrapper_codegen_for_device(
1531            device_type, self.cpp_wrapper
1532        )
1533        assert wrapper_code_gen_cls is not None, f"Device {device_type} not supported"
1534        self.wrapper_code = wrapper_code_gen_cls()
1535
1536        if self.const_module:
1537            # If we have a const module, we can reuse its kernels.
1538            # This avoids duplication and saves recompilation time (for Triton).
1539            self.wrapper_code._names_iter = self.const_module.wrapper_code._names_iter
1540            self.wrapper_code.src_to_kernel = (
1541                self.const_module.wrapper_code.src_to_kernel
1542            )
1543
1544    def codegen_with_cpp_wrapper(self):
1545        """
1546        For CPU, the cpp wrapper codegen is done in one pass.
1547        For GPU, the cpp wrapper codegen is done in two passes: JIT-compile the model with python
1548        wrapper code and run it to generate autotuned kernel binaries in the first pass; and then
1549        generate cpp wrapper code and compile it to a dynamic library in the second pass.
1550        """
1551        if "cuda" in self.device_types:
1552            # first pass
1553            self.cpp_wrapper = False
1554            # Although triton.store_cubin was set in compile_fx, the backward pass does not pick
1555            # that up. In theory it would suffice to set triton.store_cubin to True only here,
1556            # but that causes a problem when use_runtime_constant_folding is set.
1557            with config.patch({"triton.store_cubin": True}):
1558                compiled = self.compile_to_module().call
1559
1560            def materialize(x):
1561                if isinstance(x, (torch.SymInt, torch.SymFloat)):
1562                    # Need concrete value to run dynamic shapes and tune the result
1563                    return x.node.hint
1564                elif isinstance(x, FakeTensor):
1565                    return defake(x)
1566                else:
1567                    assert isinstance(
1568                        x, torch.Tensor
1569                    ), "Unknown type when creating real inputs: " + str(type(x))
1570                    return x
1571
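                # Prefer real inputs from dynamo's tracing context (flattened params
                # followed by V.real_inputs); otherwise fall back to the example
                # inputs handled in the else branch below.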
1572            tracing_context = torch._guards.TracingContext.try_get()
1573            if tracing_context is not None and not isinstance(
1574                V.real_inputs, NullHandler
1575            ):
1576                if tracing_context.output_strides:
1577                    tracing_context.output_strides.clear()
1578
1579                params_flat = [
1580                    param
1581                    for param in tracing_context.params_flat  # type: ignore[union-attr]
1582                    if param is not None
1583                ]
1584                real_inputs = [
1585                    materialize(x) for x in itertools.chain(params_flat, V.real_inputs)
1586                ]
1587            else:
1588                # In the backward pass, V.real_inputs is not set.
1589                # Generating random inputs based on self.example_inputs can sometimes be problematic,
1590                # e.g. illegal memory access. A comprehensive fix is to autotune in a separate process.
1591                real_inputs = [
1592                    materialize(x)
1593                    for x in (
1594                        self.example_inputs
1595                        if isinstance(V.real_inputs, NullHandler)
1596                        else V.real_inputs
1597                    )
1598                ]
1599
1600            if self.mutated_inputs:
1601                from .compile_fx import clone_preserve_strides
1602
1603                mutated_input_idxs = [
1604                    idx
1605                    for idx, name in enumerate(self.graph_inputs)
1606                    if name in self.mutated_inputs
1607                    and isinstance(real_inputs[idx], torch.Tensor)
1608                ]
1609                for idx in mutated_input_idxs:
1610                    # clone mutated Tensor inputs to avoid mutating them in
1611                    # the first pass of the CPP wrapper-based compilation, as
1612                    # this will lead to a side effect on the example inputs:
1613                    # e.g. if torch.compile(f)(x) is called on an input-mutating
1614                    # f, the input x will be mutated twice in the process:
1615                    # once here, and again when running the compiled model;
1616                    # this will also lead to a numerically incorrect output
1617                    real_inputs[idx] = clone_preserve_strides(real_inputs[idx])
1618
1619            with torch.utils._python_dispatch._disable_current_modes():
1620                compiled(real_inputs)
1621            del real_inputs
1622
1623            # second pass
1624            # TODO: reuse self.scheduler from the first pass to speed up the second pass
1625            self.cpp_wrapper = True
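                # Reset state accumulated during the first pass so the cpp wrapper
                # codegen below starts from a clean slate.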
1626            self.removed_buffers.clear()
1627            self.inplaced_to_remove.clear()
1628            V.graph.sizevars.precomputed_replacements.clear()
1629            V.graph.sizevars.inv_precomputed_replacements.clear()
1630            return self.codegen()
1631        else:
1632            # cpu
1633            return self.codegen()
1634
1635    def codegen(self):
1636        from .scheduler import Scheduler
1637
1638        self.init_wrapper_code()
1639
1640        self.scheduler = Scheduler(self.buffers)
1641        V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes)
1642
1643        self.wrapper_code.push_codegened_graph(self)
1644        self.scheduler.codegen()
1645        result = self.wrapper_code.generate(self.is_inference)
1646        self.wrapper_code.pop_codegened_graph()
1647        return result
1648
1649    def codegen_subgraph(self, parent_graph):
1650        """
1651        This is a more compact version of the `codegen()` above
1652        where we codegen this graph as a subgraph of some parent
1653        graph. The parent graph is passed as an argument: the
1654        intention is to inline the code generation of the subgraph into
1655        the parent graph's wrapper code (including the generated
1656        kernels). The wrapper code is not finalized (via a `.generate()`
1657        call), as this will be done in the parent graph's `codegen()`.
1658        """
1659        from .scheduler import Scheduler
1660
1661        self.wrapper_code = parent_graph.wrapper_code
1662        self.device_ops = parent_graph.device_ops
1663        self.cpp_wrapper = parent_graph.cpp_wrapper
1664
1665        self.scheduler = Scheduler(self.buffers)
1666        self.scheduler.codegen()
1667
1668    def count_bytes(self):
1669        total_bytes = 0
1670        node_counts = []
1671        node_runtimes = []
1672        for node in self.scheduler.nodes:
1673            num_bytes = node.get_read_write_buffers_sizes()
1674            total_bytes += num_bytes
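                # num_bytes // 4 reports the traffic in 4-byte units (roughly an
                # element count if the buffers held 4-byte dtypes).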
1675            node_counts.append((node, num_bytes // 4))
1676            node_runtimes.append((node, node.get_estimated_runtime()))
1677        return total_bytes, node_counts, node_runtimes
1678
1679    @dynamo_timed(phase_name="code_gen", fwd_only=False)
1680    def compile_to_module(self):
1681        from .codecache import PyCodeCache
1682
1683        code, linemap = (
1684            self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
1685        )
1686
1687        output_code_log.debug("Output code: \n%s", code)
1688        try:
1689            linemap = [(line_no, node.stack_trace) for line_no, node in linemap]
1690            key, path = PyCodeCache.write(code)
1691        except Exception:
1692            trace_structured(
1693                "inductor_output_code",
1694                # Just omit the filename, I still want the code though!
1695                payload_fn=lambda: code,
1696            )
1697            raise
1698        else:
1699            trace_structured(
1700                "inductor_output_code",
1701                lambda: {"filename": path},
1702                payload_fn=lambda: code,
1703            )
1704
1705        mod = PyCodeCache.load_by_key_path(
1706            key,
1707            path,
1708            linemap=linemap,
1709            attrs={**self.constants, **self.torchbind_constants},
1710        )
1711        self.cache_key = key
1712        self.cache_path = path
1713        self.cache_linemap = linemap
1714
1715        # Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029
1716        # TODO. Revisit this once the logging API is more mature
1717        assert mod.__file__ is not None
1718
1719        log_module_code(mod.__file__)
1720        log.debug("Output code written to: %s", mod.__file__)
1721        output_code_log.info("Output code written to: %s", mod.__file__)
1722        if config.benchmark_kernel:
1723            print(f"Compiled module path: {mod.__file__}", file=sys.stderr)
1724        V.debug.output_code(mod.__file__)
1725        V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug")
1726        return mod
1727
1728    def compile_to_fn(self):
1729        if self.aot_mode:
1730            from .codecache import AotCodeCompiler
1731
1732            assert self.cpp_wrapper, "AOT mode only supports C++ wrapper"
1733            code, linemap = self.codegen_with_cpp_wrapper()
1734            output_code_log.debug("Output code: \n%s", code)
1735
1736            serialized_extern_kernel_nodes = None
1737            if (
1738                config.is_fbcode()
1739                and self.extern_kernel_nodes
1740                and self.extern_node_serializer
1741            ):
1742                serialized_extern_kernel_nodes = self.extern_node_serializer(
1743                    self.extern_kernel_nodes
1744                )
1745                output_code_log.debug(
1746                    "Serialized Extern Kernel Nodes: \n%s",
1747                    serialized_extern_kernel_nodes,
1748                )
1749
1750            # Directly return the file path with the compiled code
1751            return AotCodeCompiler.compile(
1752                self, code, serialized_extern_kernel_nodes, cuda=self.cuda
1753            )
1754        else:
1755            return self.compile_to_module().call
1756
1757    def get_output_names(self):
1758        return [
1759            node.get_name()
1760            for node in self.graph_outputs
1761            if not isinstance(node, ir.NoneAsConstantBuffer)
1762            and not isinstance(node, ir.ShapeAsConstantBuffer)
1763        ]
1764
1765    def is_unspec_arg(self, name: str):
1766        # dynamo wraps unspec variables as 0d CPU tensors,
1767        # which need to be converted to scalars during codegen (triton only)
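            # (e.g. an unspecialized Python scalar argument may reach the graph as a
            # 0-dim cpu tensor, which is then passed to triton kernels as a plain scalar)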
1768        return (
1769            name in self.graph_inputs.keys()
1770            and self.graph_inputs[name].get_numel() == 1
1771            and self.graph_inputs[name].get_device().type == "cpu"
1772        )
1773