xref: /aosp_15_r20/external/pytorch/torch/_inductor/codegen/wrapper.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import collections
5import contextlib
6import dataclasses
7import dis
8import functools
9import inspect
10import logging
11import operator
12import re
13import tempfile
14from itertools import count
15from typing import (
16    Any,
17    Callable,
18    Dict,
19    Iterator,
20    List,
21    Optional,
22    Set,
23    Tuple,
24    TYPE_CHECKING,
25    Union,
26)
27
28import sympy
29from sympy import Expr
30
31import torch
32import torch._ops
33from torch import dtype as torch_dtype
34from torch._dynamo.utils import counters, dynamo_timed
35from torch._inductor.codegen.debug_utils import DebugPrinterManager
36from torch._inductor.codegen.multi_kernel import MultiKernelState
37from torch._inductor.runtime.runtime_utils import cache_dir
38from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey, SymTypes
39from torch.fx.node import _get_qualified_name
40from torch.utils._sympy.singleton_int import SingletonInt
41from torch.utils._sympy.symbol import symbol_is_type, SymT
42
43from .. import async_compile, config, ir
44from ..codecache import output_code_log
45from ..ir import ReinterpretView
46from ..runtime import triton_heuristics
47from ..runtime.hints import DeviceProperties
48from ..utils import (
49    cache_on_self,
50    get_benchmark_name,
51    LineContext,
52    sympy_product,
53    sympy_str,
54)
55from ..virtualized import V
56from .aoti_hipify_utils import maybe_hipify_code_wrapper
57from .common import CodeGen, DeferredLine, IndentedBuffer, PythonPrinter
58from .triton_utils import config_of, should_unwrap_unspec_arg, signature_to_meta
59
60
61if TYPE_CHECKING:
62    import triton
63
64    from ..graph import GraphLowering
65
66
67pexpr = PythonPrinter().doprint
68
69
70ReuseKey = Tuple[torch.device, torch.dtype, str]
71
72
73def buffer_reuse_key(node: ir.Buffer) -> ReuseKey:
74    return (
75        node.get_device(),
76        node.get_dtype(),
77        # NB: this is symbolic so that we don't try to reuse a buffer sized
78        # for s0 as one sized for s1, just because they happen to share the
79        # same size hint
80        sympy_str(V.graph.sizevars.simplify(node.layout.storage_size())),
81    )
82
83
84def convert_arg_type(arg: torch.Argument) -> str:
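    """Map a schema argument's Python type to the C++ type used on the C++ side,
    e.g. a mutated Tensor becomes "at::Tensor&" while a read-only Tensor becomes
    "at::Tensor const&"."""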
85    from .cpp import CONTAINER_PYTHON_TO_CPP, PYTHON_TO_CPP
86
87    # use x.real_type instead of x.type so that we get ScalarType instead of int
88    python_type = repr(arg.real_type)  # type: ignore[attr-defined]
89
90    if python_type == "Tensor":
91        # Conversion rules follow https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native#func
92        if arg.alias_info is not None and arg.alias_info.is_write:
93            return f"at::{python_type}&"
94        else:
95            return f"at::{python_type} const&"
96
97    if python_type in PYTHON_TO_CPP:
98        cpp_type = PYTHON_TO_CPP[python_type]
99        return cpp_type
100
101    # Convert args of container types e.g. Optional[*]
102    for py_container, cpp_container in CONTAINER_PYTHON_TO_CPP.items():
103        container_match = re.findall(py_container + r"\[([a-zA-Z_]+)]", python_type)
104        if len(container_match) == 1:
105            contained_type = container_match[0]
106            assert (
107                contained_type in PYTHON_TO_CPP
108            ), f"unsupported {py_container} type in convert_arg_type: {contained_type}"
109            cpp_contained_type = PYTHON_TO_CPP[contained_type]
110            return f"{cpp_container}<{cpp_contained_type}>"
111
112    raise AssertionError(f"unsupported python_type: {python_type}")
113
114
115def convert_return_type(ret: torch.Argument) -> str:
116    # use x.real_type instead of x.type so that we get ScalarType instead of int
117    python_type = repr(ret.real_type)  # type: ignore[attr-defined]
118    python_to_cpp = {
119        "Tensor": "at::Tensor",
120        "List[Tensor]": "std::vector<at::Tensor>",
121    }
122
123    cpp_type = python_to_cpp.get(python_type, None)
124    assert cpp_type is not None, f"NYI return type: {python_type}"
125    # An output aliasing an input is returned by reference only when it's a
126    # Tensor, not when it's a Tensor[]. For example, aten.split.Tensor's output
127    # aliases the input tensor, but the op returns a vector by value.
128    if python_type == "Tensor" and ret.alias_info is not None:
129        cpp_type += "&"
130    return cpp_type
131
132
133def get_cpp_op_schema(kernel: torch._ops.OpOverload) -> str:
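    # Build a C++ function-signature string from the op's schema, e.g.
    # "at::Tensor(at::Tensor const& self, at::Tensor const& other)".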
134    args = kernel._schema.arguments
135    returns = kernel._schema.returns
136
137    num_returns = len(returns)
138    assert num_returns > 0, "must have at least one return value"
139
140    if num_returns == 1:
141        cpp_return_value = convert_return_type(returns[0])
142    elif num_returns > 1:
143        tuple_returns = ", ".join([convert_return_type(r) for r in returns])
144        cpp_return_value = f"std::tuple<{tuple_returns}>"
145
146    cpp_arg_type = [f"{convert_arg_type(arg)} {arg.name}" for arg in args]
147    return f"{cpp_return_value}({', '.join(cpp_arg_type)})"  # type: ignore[possibly-undefined]
148
149
150# TODO: Move to a well known place
151TritonMetaParams = Dict[str, int]
152TritonGrid = Union[
153    Tuple[Union[int, sympy.Expr], ...], Callable[[TritonMetaParams], Tuple[int, ...]]
154]
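# A grid is either a tuple of launch sizes (ints or symbolic Exprs) or a callable
# that maps the kernel's meta-parameters to such a tuple.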
155
156
157def user_defined_kernel_grid_fn_code(
158    name: str,
159    configs: List[triton.Config],  # type: ignore[name-defined]
160    grids: List[TritonGrid],
161    wrapper: Optional[WrapperCodeGen] = None,
162) -> Tuple[str, str]:
163    output = IndentedBuffer()
164
165    def _convert_to_sympy_expr(item: Union[int, sympy.Expr]) -> sympy.Expr:
166        return item if isinstance(item, sympy.Expr) else sympy.Integer(item)
167
168    def determine_grid(
169        grid: TritonGrid,
170    ):
171        """
172        This function returns a tuple of two values: the first one is the real
173        grid, which is used in the generated code; the second one is an example
174        grid with concrete values, which is used in the autotune block to run the
175        generated kernels at compile time.
176        """
177        if wrapper is None or callable(grid):
178            # return as-is when used in eager mode or when grid is callable
179            return grid, grid
180        # Grid contains ints/Expr, so utilize wrapper's expr printer for codegen
181        sympy_grid = tuple(_convert_to_sympy_expr(g) for g in grid)
182        return (
183            wrapper.codegen_shape_tuple(sympy_grid),
184            wrapper.codegen_shape_tuple(
185                tuple(
186                    wrapper.generate_example_arg_value(g, type(g)) for g in sympy_grid
187                )
188            )
189            if config.triton.autotune_at_compile_time
190            else None,
191        )
192
193    def writeline(line: str, example_grid: Optional[str] = None):
194        output.writeline(line)
195        if (
196            wrapper
197            and config.triton.autotune_at_compile_time
198            and name not in wrapper.kernel_autotune_names
199        ):
200            wrapper.kernel_autotune_calls.writeline(example_grid or line)
201
202    fn_name = f"grid_wrapper_for_{name}"
203    writeline(f"def {fn_name}(meta):")
204    kernel_autotune_calls_indent = (
205        wrapper.kernel_autotune_calls.indent()
206        if wrapper and config.triton.autotune_at_compile_time
207        else contextlib.nullcontext()
208    )
209    with output.indent(), kernel_autotune_calls_indent:
210        if len(grids) == 1:
211            grid, example_grid = determine_grid(grids[0])
212            writeline(f"return {grid}", f"return {example_grid}")
213        else:
214            assert len(grids) > 1
215            assert len(grids) == len(configs)
216            seen = set()
217            for grid, c in zip(grids, configs):
218                guards = [f"meta['{name}'] == {val}" for name, val in c.kwargs.items()]
219                guards = " and ".join(guards)
220                grid, example_grid = determine_grid(grid)
221                statement = f"if {guards}: return {grid}"
222                if statement in seen:
223                    continue
224                seen.add(statement)
225                writeline(statement, f"if {guards}: return {example_grid}")
226
227    return fn_name, output.getvalue()
228
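# The emitted helper looks roughly like this sketch (names and values illustrative):
#     def grid_wrapper_for_kernel0(meta):
#         if meta['BLOCK'] == 64: return (32, 1, 1)
#         if meta['BLOCK'] == 128: return (16, 1, 1)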
229
230@dataclasses.dataclass
231class SymbolicCallArg:
232    inner: str
233    # the original symbolic expression represented by inner
234    inner_expr: sympy.Expr
235
236    def __str__(self):
237        return str(self.inner)
238
239
240# Default thread stack sizes vary by platform:
241# - Linux: 8 MB
242# - macOS: 512 KB
243# - Windows: 1 MB
244# Just pick something comfortably smaller than the smallest for now.
245MAX_STACK_ALLOCATION_SIZE = 1024 * 100
246
247
248class MemoryPlanningState:
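    """Tracks buffers freed during the planning pass, keyed by (device, dtype,
    symbolic storage size), so that later allocations with a matching key can
    reuse them instead of allocating anew."""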
249    def __init__(self):
250        super().__init__()
251        self.reuse_pool: Dict[
252            ReuseKey, List[FreeIfNotReusedLine]
253        ] = collections.defaultdict(list)
254        self.total_allocated_buffer_size: int = 0
255
256    def __contains__(self, key: ReuseKey) -> bool:
257        return bool(self.reuse_pool.get(key, None))
258
259    def pop(self, key: ReuseKey) -> FreeIfNotReusedLine:
260        item = self.reuse_pool[key].pop()
261        assert not item.is_reused
262        return item
263
264    def push(self, key: ReuseKey, item: FreeIfNotReusedLine) -> None:
265        assert not item.is_reused
266        self.reuse_pool[key].append(item)
267
268
269class WrapperLine:
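    """Base class for deferred wrapper lines; instances are buffered in
    WrapperCodeGen.lines and rendered later via their codegen(code) method."""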
270    pass
271
272
273@dataclasses.dataclass
274class EnterSubgraphLine(WrapperLine):
275    wrapper: WrapperCodeGen
276    graph: GraphLowering
277
278    def __post_init__(self) -> None:
279        self.wrapper.push_computed_sizes(self.wrapper.computed_sizes)
280
281    def codegen(self, code: IndentedBuffer) -> None:
282        self.wrapper.push_codegened_graph(self.graph)
283        code.do_indent()
284
285
286@dataclasses.dataclass
287class ExitSubgraphLine(WrapperLine):
288    wrapper: WrapperCodeGen
289
290    def __post_init__(self) -> None:
291        self.wrapper.computed_sizes = self.wrapper.pop_computed_sizes()
292
293    def codegen(self, code: IndentedBuffer) -> None:
294        self.wrapper.pop_codegened_graph()
295        code.do_unindent()
296
297
298@dataclasses.dataclass
299class EnterDeviceContextManagerLine(WrapperLine):
300    device_idx: int
301    last_seen_device_guard_index: Optional[int]
302
303    def codegen(self, code: IndentedBuffer) -> None:
304        if V.graph.cpp_wrapper:
305            code.writeline("\n")
306            if V.graph.aot_mode:
307                # In AOT mode, we have a stream provided as a param. A stream is
308                # associated with a device, so we never expect the device to change.
309                # CUDAStreamGuard sets the stream and the device.
310                if self.last_seen_device_guard_index is None:
311                    if config.abi_compatible:
312                        code.writeline(
313                            "AOTICudaStreamGuard stream_guard(stream, this->device_idx_);"
314                        )
315                    else:
316                        code.writeline(
317                            maybe_hipify_code_wrapper(
318                                "at::cuda::CUDAStreamGuard stream_guard("
319                                + "at::cuda::getStreamFromExternal(stream, this->device_idx_));"
320                            )
321                        )
322                else:
323                    assert (
324                        self.last_seen_device_guard_index == self.device_idx
325                    ), "AOTInductor only supports running on one CUDA device"
326            else:
327                if self.last_seen_device_guard_index is None:
328                    code.writeline(
329                        f"AOTICudaGuard device_guard({self.device_idx});"
330                        if config.abi_compatible
331                        else maybe_hipify_code_wrapper(
332                            f"at::cuda::CUDAGuard device_guard({self.device_idx});"
333                        )
334                    )
335                else:
336                    code.writeline(f"device_guard.set_index({self.device_idx});")
337        else:
338            # Note _DeviceGuard has less overhead than device, but only accepts
339            # integers
340            code.writeline(f"with {V.graph.device_ops.device_guard(self.device_idx)}:")
341            code.do_indent()
342            code.writeline(V.graph.device_ops.set_device(self.device_idx))
343
344
345class ExitDeviceContextManagerLine(WrapperLine):
346    def codegen(self, code: IndentedBuffer) -> None:
347        if not V.graph.cpp_wrapper:
348            code.do_unindent()
349
350
351@dataclasses.dataclass
352class MemoryPlanningLine(WrapperLine):
353    wrapper: WrapperCodeGen
354
355    def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
356        """First pass to find reuse"""
357        return self
358
359    def codegen(self, code: IndentedBuffer) -> None:
360        """Second pass to output code"""
361
362    def __str__(self) -> str:
363        """
364        Emits a string representation that fits on one line.
365        """
366        args: List[str] = []
367        for field in dataclasses.fields(self):
368            if field.name == "wrapper":
369                continue
370            val = getattr(self, field.name)
371            args.append(
372                f"{field.name}={val.get_name() if field.type is ir.Buffer else val}"
373            )
374        return f"{type(self).__name__}({', '.join(args)})"
375
376
377@dataclasses.dataclass
378class AllocateLine(MemoryPlanningLine):
379    node: ir.Buffer
380
381    def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
382        if self.node.get_name() in V.graph.removed_buffers:
383            return NullLine(self.wrapper)
384
385        # try to reuse a recently freed buffer
386        key = buffer_reuse_key(self.node)
387        if config.allow_buffer_reuse and key in state:
388            free_line = state.pop(key)
389            free_line.is_reused = True
390            return ReuseLine(self.wrapper, free_line.node, self.node)
391
392        if self.node.get_device().type == "cpu":
393            static_shape = self.wrapper.static_shape_for_buffer_or_none(self.node)
394            if static_shape is not None:
395                state.total_allocated_buffer_size += int(
396                    functools.reduce(operator.mul, static_shape, 1)
397                )
398
399        return self
400
401    def codegen(self, code: IndentedBuffer) -> None:
402        assert self.node.get_name() not in V.graph.removed_buffers
403        line = self.wrapper.make_buffer_allocation(self.node)
404        code.writeline(line)
405
406
407@dataclasses.dataclass
408class FreeIfNotReusedLine(MemoryPlanningLine):
409    node: ir.Buffer
410    is_reused: bool = False
411
412    def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
413        if len(self.node.get_inputs_that_alias_output()) > 0:
414            return self
415        if isinstance(self.node.layout, ir.MultiOutputLayout):
416            return self
417        assert not self.is_reused
418        if self.node.get_name() in V.graph.removed_buffers:
419            return NullLine(self.wrapper)
420        if config.allow_buffer_reuse:
421            state.push(buffer_reuse_key(self.node), self)
422        return self
423
424    def codegen(self, code: IndentedBuffer) -> None:
425        assert self.node.get_name() not in V.graph.removed_buffers
426        if not self.is_reused:
427            code.writeline(self.wrapper.make_buffer_free(self.node))
428
429
430@dataclasses.dataclass
431class ReuseLine(MemoryPlanningLine):
432    node: ir.Buffer
433    reused_as: ir.Buffer
434    delete_old: bool = True
435
436    def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
437        if self.node.get_name() in V.graph.removed_buffers:
438            assert self.reused_as.get_name() in V.graph.removed_buffers
439            return NullLine(self.wrapper)
440        assert self.reused_as.get_name() not in V.graph.removed_buffers
441        return self
442
443    def codegen(self, code: IndentedBuffer) -> None:
444        assert self.node.get_name() not in V.graph.removed_buffers
445        assert self.reused_as.get_name() not in V.graph.removed_buffers
446        code.writeline(
447            self.wrapper.make_buffer_reuse(self.node, self.reused_as, self.delete_old)
448        )
449
450
451class NullLine(MemoryPlanningLine):
452    pass
453
454
455BufferName = str
456
457
458class WrapperCodeGen(CodeGen):
459    """
460    Generate outer wrapper in Python that calls the kernels.
461    """
462
463    def __init__(self):
464        super().__init__()
465        self._names_iter: Iterator[int] = count()
466        self.imports = IndentedBuffer()
467        self.header = IndentedBuffer()
468        self.prefix = IndentedBuffer()
469        self.suffix = IndentedBuffer()
470        self.wrapper_call = IndentedBuffer()
471        self.kernel_autotune_defs = IndentedBuffer()
472        self.kernel_autotune_calls = IndentedBuffer()
473        self.kernel_autotune_names: Set[str] = set()
474        # If the generated source code is exactly the same, reuse the
475        # pre-existing kernel for it
476        self.src_to_kernel: Dict[str, str] = {}
477        self.kernel_numel_expr: Set[Tuple[str, GraphLowering]] = set()
478        self.lines: List[Union[MemoryPlanningLine, LineContext]] = []
479        self.declare = ""
480        self.declare_maybe_reference = ""
481        self.ending = ""
482        self.open_bracket = "["
483        self.closed_bracket = "]"
484        self.comment = "#"
485        self.namespace = ""
486        self.none_str = "None"
487        self.size = "size()"
488        self.stride = "stride()"
489        self.last_seen_device_guard_index: Optional[int] = None
490        self.supports_intermediate_hooks = True
491        self.expr_printer: Callable[[Any], str] = pexpr
492        self.user_defined_kernel_cache: Dict[Tuple[Any, ...], Tuple[str, Any]] = {}
493        self.unbacked_symbol_decls: Set[str] = set()  # str of sympy.Symbol
494        self.allow_stack_allocation: Optional[bool] = None
495        self.stack_allocated_buffers: Dict[BufferName, ir.Buffer] = {}
496        self.computed_sizes: Set[sympy.Symbol] = set()
497
498        # this is used for tracking which GraphLowering instance---parent graph
499        # or (nested) subgraph---is currently being codegened; the primary use case is
500        # including the graph instance into a cache key to avoid cross-graph
501        # caching during lowering of nested subgraphs
502        self.codegened_graph_stack = []
503        self.computed_sizes_stack = []
504
505        self.write_header()
506        self.write_prefix()
507        self.write_kernel_autotune_defs_header()
508
509        if not V.graph.aot_mode:
510            for name, hashed in V.graph.constant_reprs.items():
511                # include a hash so our code cache puts different constants into different files
512                self.write_constant(name, hashed)
513
514        self.allocated: Set[BufferName] = set()
515        self.freed: Set[BufferName] = set()
516
517        # maps from reusing buffer to reused buffer
518        self.reuses: Dict[BufferName, BufferName] = {}
519
520        self.write_get_raw_stream = functools.lru_cache(None)(  # type: ignore[assignment]
521            self.write_get_raw_stream
522        )
523
524        @functools.lru_cache(None)
525        def add_import_once(line: str) -> None:
526            self.imports.writeline(line)
527            if config.triton.autotune_at_compile_time:
528                self.kernel_autotune_calls.writeline(line)
529
530        self.add_import_once = add_import_once
531        self._metas: Dict[str, str] = {}
532        self._meta_vars: Set[str] = set()
533        self.multi_kernel_state = MultiKernelState()
534
535        # intermediate tensor value printing utility
536        self.debug_printer = DebugPrinterManager(
537            debug_printer_level=config.aot_inductor.debug_intermediate_value_printer
538        )
539
540    def write_constant(self, name: str, hashed: str) -> None:
541        self.header.writeline(f"{name} = None  # {hashed}")
542
543    def write_header(self) -> None:
544        context = torch._guards.TracingContext.try_get()
545        aot_config_comment = ""
546        if context is not None and context.aot_graph_name is not None:
547            aot_config_comment = f"# AOT ID: {context.aot_graph_name}"
548        self.imports.splice(
549            f"""
550                {aot_config_comment}
551                from ctypes import c_void_p, c_long, c_int
552                import torch
553                import math
554                import random
555                import os
556                import tempfile
557                from math import inf, nan
558                from torch._inductor.hooks import run_intermediate_hooks
559                from torch._inductor.utils import maybe_profile
560                from torch._inductor.codegen.memory_planning import _align as align
561                from torch import device, empty_strided
562                from {async_compile.__name__} import AsyncCompile
563                from torch._inductor.select_algorithm import extern_kernels
564                from torch._inductor.codegen.multi_kernel import MultiKernelCall
565            """,
566            strip=True,
567        )
568        self.header.splice(
569            """
570                aten = torch.ops.aten
571                inductor_ops = torch.ops.inductor
572                _quantized = torch.ops._quantized
573                assert_size_stride = torch._C._dynamo.guards.assert_size_stride
574                empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
575                empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
576                empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
577                reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
578                alloc_from_pool = torch.ops.inductor._alloc_from_pool
579                async_compile = AsyncCompile()
580            """,
581            strip=True,
582        )
583
584    def write_kernel_autotune_defs_header(self) -> None:
585        self.kernel_autotune_defs.splice(
586            f"""
587                import torch
588                from torch._dynamo.testing import rand_strided
589                from torch._dynamo.utils import preserve_rng_state
590                from torch._inductor.select_algorithm import AlgorithmSelectorCache
591                from {async_compile.__name__} import AsyncCompile
592
593                async_compile = AsyncCompile()
594                generate_example_value = AlgorithmSelectorCache.generate_example_value
595            """
596        )
597
598    @cache_on_self
599    def write_triton_header_once(self) -> None:
600        import_str = f"""
601            import triton
602            import triton.language as tl
603            from {triton_heuristics.__name__} import grid, split_scan_grid, grid_combo_kernels, start_graph, end_graph
604            """
605        self.imports.splice(import_str, strip=True)
606        if config.triton.autotune_at_compile_time:
607            self.kernel_autotune_calls.splice(import_str)
608        self.write_get_raw_stream_header_once()
609
610    @cache_on_self
611    def write_get_raw_stream_header_once(self) -> None:
612        self.imports.writeline(
613            V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
614        )
615        if config.triton.autotune_at_compile_time:
616            self.kernel_autotune_calls.writeline(
617                V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
618            )
619
620    def add_meta_once(self, meta: TritonMetaParams) -> str:
621        meta = repr(meta)
622        if meta not in self._metas:
623            var = f"meta{len(self._metas)}"
624            self._metas[meta] = var
625            self.header.writeline(f"{var} = {meta}")
626            if config.triton.autotune_at_compile_time:
627                self.kernel_autotune_calls.writeline(f"{var} = {meta}")
628                self._meta_vars.add(var)
629        return self._metas[meta]
630
631    @cache_on_self
632    def get_output_refs(self) -> List[str]:
633        return [x.codegen_reference(self.wrapper_call) for x in V.graph.graph_outputs]
634
635    def mark_output_type(self) -> None:
636        return
637
638    def codegen_input_size_asserts(self) -> None:
639        for name, buf in V.graph.graph_inputs.items():
640            if isinstance(buf, sympy.Expr):
641                continue
642
643            # comparing strides for 0-size tensors is tricky; ignore them for now.
644            if sympy_product(buf.get_size()) == 0:
645                continue
646            size = self.codegen_shape_tuple(buf.get_size())
647            stride = self.codegen_shape_tuple(buf.get_stride())
648            self.prefix.writeline(f"assert_size_stride({name}, {size}, {stride})")
649
650    def codegen_input_nan_asserts(self) -> None:
651        self.prefix.writeline("# make sure graph inputs are not nan/inf")
652        for name, buf in V.graph.graph_inputs.items():
653            if isinstance(buf, sympy.Expr):
654                continue
655
656            line = f"assert not {name}.isnan().any().item()"
657            self.prefix.writeline(line)
658            line = f"assert not {name}.isinf().any().item()"
659            self.prefix.writeline(line)
660
661    def write_prefix(self) -> None:
662        self.prefix.splice(
663            """
664
665            async_compile.wait(globals())
666            del async_compile
667
668            def call(args):
669            """
670        )
671        with self.prefix.indent():
672            if config.triton.debug_sync_graph:
673                self.prefix.writeline(V.graph.device_ops.synchronize())
674            if V.graph.graph_inputs:
675                lhs = ", ".join(V.graph.graph_input_names)
676                if len(V.graph.graph_input_names) == 1:
677                    lhs += ","
678                self.prefix.writeline(f"{lhs} = args")
679                self.prefix.writeline("args.clear()")
680
681            self.codegen_inputs(self.prefix, V.graph.graph_inputs)
682            if config.size_asserts:
683                self.codegen_input_size_asserts()
684            if config.nan_asserts:
685                self.codegen_input_nan_asserts()
686
687    # This function (and the ones below) takes a graph as input so
688    # that stream caching happens per graph instance. This
689    # is important for nested subgraph codegen.
690    def write_get_raw_stream(self, device_idx: int, graph=None) -> str:
691        self.write_get_raw_stream_header_once()
692        name = f"stream{device_idx}"
693        self.writeline(f"{name} = get_raw_stream({device_idx})")
694        return name
695
696    def get_codegened_graph(self):
697        return self.codegened_graph_stack[-1]
698
699    def push_codegened_graph(self, graph):
700        self.codegened_graph_stack.append(graph)
701
702    def pop_codegened_graph(self):
703        return self.codegened_graph_stack.pop()
704
705    def push_computed_sizes(self, computed_sizes):
706        from copy import deepcopy
707
708        return self.computed_sizes_stack.append(deepcopy(computed_sizes))
709
710    def pop_computed_sizes(self):
711        return self.computed_sizes_stack.pop()
712
713    def next_kernel_suffix(self) -> str:
714        return f"{next(self._names_iter)}"
715
716    def codegen_device_guard_enter(self, device_idx: int) -> None:
717        self.writeline(
718            EnterDeviceContextManagerLine(device_idx, self.last_seen_device_guard_index)
719        )
720        if config.triton.autotune_at_compile_time:
721            # mimic logic of EnterDeviceContextManagerLine.codegen for the autotune code block
722            self.write_triton_header_once()
723            self.kernel_autotune_calls.writeline(
724                f"with {V.graph.device_ops.device_guard(device_idx)}:"
725            )
726            self.kernel_autotune_calls.do_indent()
727            self.kernel_autotune_calls.writeline(
728                V.graph.device_ops.set_device(device_idx)
729            )
730            self.kernel_autotune_calls.writeline(
731                f"stream{device_idx} = get_raw_stream({device_idx})"
732            )
733        self.last_seen_device_guard_index = device_idx
734
735    def codegen_device_guard_exit(self) -> None:
736        self.writeline(ExitDeviceContextManagerLine())
737        if config.triton.autotune_at_compile_time:
738            self.kernel_autotune_calls.do_unindent()
739
740    def generate_return(self, output_refs: List[str]) -> None:
741        if output_refs:
742            self.wrapper_call.writeline("return (" + ", ".join(output_refs) + ", )")
743        else:
744            self.wrapper_call.writeline("return ()")
745
746    def generate_before_suffix(self, result: IndentedBuffer) -> None:
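        # Hook point before the suffix is emitted; the base Python wrapper emits
        # nothing here, but subclasses may override it.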
747        return
748
749    def generate_end(self, result: IndentedBuffer) -> None:
750        return
751
752    def generate_fallback_kernel(self, fallback_kernel, args):
753        self.generate_extern_kernel_alloc(fallback_kernel, args)
754
755    def generate_extern_kernel_alloc(self, extern_kernel, args):
756        # If it's a NoneLayout then the extern_kernel should essentially be
757        # treated as if it doesn't return anything
758        no_return = isinstance(extern_kernel.layout, ir.NoneLayout)
759        output_name = extern_kernel.get_name()
760        origin_node = extern_kernel.get_origin_node()
761        kernel_name = extern_kernel.get_kernel_name()
762        ending = self.ending
763        if config.memory_planning and "view_as_complex" in kernel_name:
764            # view operation fallbacks cause issues since inductor
765            # doesn't know the memory is still needed and might reuse it.
766            ending = f".clone(){ending}"
767
768        if no_return:
769            self.writeline(f"{self.declare}{kernel_name}({', '.join(args)}){ending}")
770        else:
771            self.writeline(
772                f"{self.declare}{output_name} = {kernel_name}({', '.join(args)}){ending}"
773            )
774            if (
775                self.supports_intermediate_hooks
776                and config.generate_intermediate_hooks
777                and origin_node is not None
778            ):
779                counters["inductor"]["intermediate_hooks"] += 1
780                self.writeline(
781                    f"run_intermediate_hooks({origin_node.name!r}, {output_name})"
782                )
783
784    def generate_extern_kernel_out(
785        self, kernel: str, out: str, out_view: Optional[str], args: List[str]
786    ):
787        args.append(f"out={out_view if out_view else out}")
788        self.writeline(f"{kernel}({', '.join(args)})")
789
790    def generate_user_defined_triton_kernel(
791        self,
792        kernel_name: str,
793        raw_args: List[Any],
794        grid: List[Any],
795        configs,
796        triton_meta,
797        constexprs,
798    ):
799        grid_fn, code = user_defined_kernel_grid_fn_code(
800            kernel_name, configs, grid, wrapper=self
801        )
802        # Must happen after free symbols are already codegened
803        # Emit the grid wrapper function right before the call
804        for line in code.split("\n"):
805            self.writeline(line)
806
807        args = [self.val_to_arg_str(v) for v in raw_args]
808        arg_types = [
809            arg.get_dtype() if hasattr(arg, "get_dtype") else type(arg)
810            for arg in raw_args
811        ]
812        self.generate_kernel_call(
813            kernel_name, args, grid_fn=grid_fn, arg_types=arg_types, raw_args=raw_args
814        )
815
816    def generate_scatter_fallback(
817        self,
818        output,
819        inputs,
820        cpp_kernel_name,
821        python_kernel_name,
822        src_is_tensor,
823        reduce,
824        kwargs,
825    ):
826        line = f"{python_kernel_name}({','.join(map(str, inputs))}"
827        if python_kernel_name.startswith("aten.scatter_reduce"):
828            line += ", ".join([""] + kwargs)
829        else:
830            if reduce:
831                line += f", reduce={repr(reduce)}"
832        line += ")"
833        self.writeline(line)
834
835    def generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
836        indices_str = f"{self.open_bracket}{', '.join(indices)}{self.closed_bracket}"
837        args = [x, indices_str, values, accumulate]
838        self.writeline(self.wrap_kernel_call(kernel, args))
839
840    def generate_extern_kernel_alloc_and_find_schema_if_needed(
841        self,
842        buf_name: str,
843        python_kernel_name: str,
844        cpp_kernel_name: str,
845        codegen_args: List[str],
846        cpp_op_schema: str,
847        cpp_kernel_key: str,
848        cpp_kernel_overload_name: str = "",
849        op_overload: Optional[torch._ops.OpOverload] = None,
850        raw_args=None,
851        outputs=None,
852    ):
853        self.writeline(f"{buf_name} = {python_kernel_name}({', '.join(codegen_args)})")
854
855    def generate(self, is_inference):
856        with dynamo_timed("WrapperCodeGen.generate"):
857            return self._generate(is_inference)
858
859    def _generate(self, is_inference):
860        if config.profile_bandwidth:
861            self.write_triton_header_once()
862        result = IndentedBuffer()
863        result.splice(self.imports)
864        result.writeline("")
865        result.splice(self.header)
866        # We do not want the cpp header for the intermediate const graph. Headers would be
867        # rendered by the main module instead.
868        if V.graph.aot_mode and V.graph.cpp_wrapper and V.graph.is_const_graph:
869            result = IndentedBuffer()
870
871        with contextlib.ExitStack() as stack:
872            stack.enter_context(self.wrapper_call.indent())
873            if config.profiler_mark_wrapper_call:
874                self.generate_profiler_mark_wrapper_call(stack)
875            if config.profile_bandwidth:
876                self.generate_start_graph()
877
878            # We disable planning during training because it presently increases peak memory consumption.
879            if is_inference and config.memory_planning:
880                self.memory_plan()
881                # TODO: integrate memory planning & stack allocation?
882                self.allow_stack_allocation = False
883            else:
884                self.memory_plan_reuse()
885
886            if config.triton.store_cubin:
887                self.generate_reset_kernel_saved_flags()
888
889            for line in self.lines:
890                if isinstance(line, WrapperLine):
891                    line.codegen(self.wrapper_call)
892                else:
893                    self.wrapper_call.writeline(line)
894
895            output_refs = self.get_output_refs()
896            self.mark_output_type()
897            if config.triton.debug_sync_graph:
898                self.wrapper_call.writeline(V.graph.device_ops.synchronize())
899
900            if config.profile_bandwidth:
901                self.generate_end_graph()
902
903            if config.triton.store_cubin:
904                self.generate_save_uncompiled_kernels()
905
906            if config.triton.autotune_at_compile_time:
907                self.generate_and_run_autotune_block()
908
909            self.generate_return(output_refs)
910
911        self.finalize_prefix()
912        result.splice(self.prefix)
913
914        with result.indent():
915            result.splice(self.wrapper_call)
916
917        self.generate_before_suffix(result)
918        result.splice(self.suffix)
919
920        self.generate_end(result)
921
922        self.add_benchmark_harness(result)
923
924        return result.getvaluewithlinemap()
925
926    def generate_and_run_autotune_block(self):
927        """
928        Compose self.kernel_autotune_defs and self.kernel_autotune_calls into a single block of
929        code and execute it to trigger Triton kernel compilation and auto-tuning
930        """
931        self.kernel_autotune_defs.splice(
932            """
933            async_compile.wait(globals())
934            del async_compile
935        """
936        )
937        scope = {}  # type: ignore[var-annotated]
938        tuning_code = (
939            self.kernel_autotune_defs.getvalue() + self.kernel_autotune_calls.getvalue()
940        )
941        if output_code_log.level == logging.DEBUG:
942            # Save the autotuning code block into a file
943            # Create a temporary file
944            with tempfile.NamedTemporaryFile(
945                dir=cache_dir(), suffix=".py", delete=False
946            ) as f:
947                f.write(tuning_code.encode("utf-8"))
948                file_path = f.name
949            output_code_log.debug(
950                "\nCompile-time auto-tuning code: \n%s\nAuto-tuning code written to %s",
951                tuning_code,
952                file_path,
953            )
954        # Execute the code to autotune kernels
955        exec(tuning_code, scope)
956
957    def memory_plan(self):
958        from .memory_planning import MemoryPlanner
959
960        self.lines = MemoryPlanner(self).plan(self.lines)
961
962    def memory_plan_reuse(self):
963        out_names = V.graph.get_output_names()
964
965        while (
966            self.lines
967            and isinstance(self.lines[-1], MemoryPlanningLine)
968            # TODO: this seems legit, NullLine has no node
969            and self.lines[-1].node.name not in out_names  # type: ignore[attr-defined]
970        ):
971            # these lines will be pointless
972            self.lines.pop()
973
974        # codegen allocations in two passes
975        planning_states = [MemoryPlanningState()]
976        past_planning_states = []
977        for i in range(len(self.lines)):
978            line = self.lines[i]
979            if isinstance(line, MemoryPlanningLine):
980                self.lines[i] = line.plan(planning_states[-1])
981            elif isinstance(line, EnterSubgraphLine):
982                planning_states.append(MemoryPlanningState())
983            elif isinstance(line, ExitSubgraphLine):
984                past_planning_states.append(planning_states.pop())
985        past_planning_states.append(planning_states.pop())
986        assert len(planning_states) == 0
987
988        # conservatively use the sum of all allocated buffer sizes
989        # in potentially nested scopes as the total allocated size
990        total_allocated_buffer_size = sum(
991            s.total_allocated_buffer_size for s in past_planning_states
992        )
993
994        self.allow_stack_allocation = (
995            self.allow_stack_allocation is not False
996            and config.allow_stack_allocation
997            and total_allocated_buffer_size <= MAX_STACK_ALLOCATION_SIZE
998        )
999
1000    def codegen_input_size_var_decl(self, code: IndentedBuffer, name):
1001        code.writeline(f"{self.declare}{name}_size = {name}.{self.size}{self.ending}")
1002
1003    def codegen_input_stride_var_decl(self, code: IndentedBuffer, name):
1004        code.writeline(
1005            f"{self.declare}{name}_stride = {name}.{self.stride}{self.ending}"
1006        )
1007
1008    def codegen_inputs(
1009        self, code: IndentedBuffer, graph_inputs: Dict[str, ir.TensorBox]
1010    ):
1011        """Assign all symbolic shapes to locals"""
1012
1013        @functools.lru_cache(None)
1014        def sizeof(name):
1015            self.codegen_input_size_var_decl(code, name)
1016            return f"{name}_size"
1017
1018        @functools.lru_cache(None)
1019        def strideof(name):
1020            self.codegen_input_stride_var_decl(code, name)
1021            return f"{name}_stride"
1022
1023        # Assign all symbolic shapes needed to local variables
1024        bound_vars: Set[sympy.Symbol] = set()
1025
1026        def is_expr(x):
1027            return isinstance(x[1], sympy.Expr)
1028
1029        graph_inputs_expr = list(filter(is_expr, graph_inputs.items()))
1030        graph_inputs_tensors = list(
1031            filter(lambda x: not is_expr(x), graph_inputs.items())
1032        )
1033
1034        for name, shape in graph_inputs_expr:
1035            if isinstance(shape, sympy.Symbol) and shape not in bound_vars:
1036                code.writeline(f"{self.declare}{shape} = {name}{self.ending}")
1037                bound_vars.add(shape)
1038
1039        for name, value in graph_inputs_tensors:
1040            shapes = value.get_size()
1041            for dim, shape in enumerate(shapes):
1042                if isinstance(shape, sympy.Symbol) and shape not in bound_vars:
1043                    code.writeline(
1044                        f"{self.declare}{shape} = {sizeof(name)}[{dim}]{self.ending}"
1045                    )
1046                    bound_vars.add(shape)
1047
1048        for name, value in graph_inputs_tensors:
1049            shapes = value.get_stride()
1050            for dim, shape in enumerate(shapes):
1051                if isinstance(shape, sympy.Symbol) and shape not in bound_vars:
1052                    code.writeline(
1053                        f"{self.declare}{shape} = {strideof(name)}[{dim}]{self.ending}"
1054                    )
1055                    bound_vars.add(shape)
1056
1057    def ensure_size_computed(self, sym: sympy.Symbol):
1058        if isinstance(sym, sympy.Symbol) and symbol_is_type(sym, SymT.PRECOMPUTED_SIZE):
1059            if sym in self.computed_sizes:
1060                return
1061            self.computed_sizes.add(sym)
1062            expr = V.graph.sizevars.inv_precomputed_replacements[sym]
1063            self.writeline(
1064                f"{self.declare}{sym} = {self.expr_printer(expr)}{self.ending}"
1065            )
1066
1067    def finalize_prefix(self):
1068        pass
1069
1070    def codegen_python_sizevar(self, x: Expr, *, simplify: bool = True) -> str:
1071        return pexpr(x, simplify=simplify)
1072
1073    def codegen_sizevar(self, x: Expr) -> str:
1074        return self.codegen_python_sizevar(x)
1075
1076    def codegen_tuple_access(self, basename: str, name: str, index: str) -> str:
1077        return f"{basename}[{index}]"
1078
1079    def codegen_python_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
1080        parts = list(map(self.codegen_python_sizevar, shape))
1081        if len(parts) == 0:
1082            return "()"
1083        if len(parts) == 1:
1084            return f"({parts[0]}, )"
1085        return f"({', '.join(parts)})"
1086
1087    def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
1088        return self.codegen_python_shape_tuple(shape)
1089
1090    def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str:
1091        return "alloc_from_pool({})".format(
1092            ", ".join(
1093                [
1094                    name,
1095                    pexpr(offset),  # bytes not numel
1096                    str(dtype),
1097                    self.codegen_shape_tuple(shape),
1098                    self.codegen_shape_tuple(stride),
1099                ]
1100            )
1101        )
1102
1103    def codegen_reinterpret_view(
1104        self, data, size, stride, offset, writer, dtype=None
1105    ) -> str:
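        # Fast path: when the requested size/stride/offset match the buffer's own
        # layout, reuse its name directly (adding only a dtype view if needed);
        # otherwise emit a reinterpret_tensor(...) call.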
1106        if (
1107            size == data.layout.size
1108            and stride == data.layout.stride
1109            and offset == data.layout.offset
1110        ):
1111            if dtype is not None and dtype != data.dtype:
1112                return f"aten.view.dtype({data.get_name()}, {dtype})"
1113            else:
1114                return f"{data.get_name()}"
1115        else:
1116            size = self.codegen_shape_tuple(size)
1117            stride = self.codegen_shape_tuple(stride)
1118            offset = self.codegen_sizevar(offset)
1119            if dtype is not None and dtype != data.dtype:
1120                return f"aten.view.dtype(reinterpret_tensor({data.get_name()}, {size}, {stride}, {offset}), {dtype})"
1121            else:
1122                return (
1123                    f"reinterpret_tensor({data.get_name()}, {size}, {stride}, {offset})"
1124                )
1125
1126    def codegen_device_copy(self, src, dst):
1127        self.writeline(f"{dst}.copy_({src})")
1128
1129    def codegen_multi_output(self, name, value):
1130        self.writeline(f"{self.declare}{name} = {value}{self.ending}")
1131
1132    def codegen_dynamic_scalar(self, node):
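        # node.keypath describes how to post-process the scalar read via .item():
        # an empty keypath uses the value as-is, ConvertIntKey turns a bool into
        # 0/1, and DivideByKey asserts divisibility and divides by a constant.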
1133        (data,) = (t.codegen_reference() for t in node.inputs)
1134        if len(node.keypath) == 0:
1135            self.writeline(f"{node.sym} = {data}.item()")
1136        elif len(node.keypath) == 1 and isinstance(node.keypath[0], ConvertIntKey):
1137            self.writeline(f"{node.sym} = 1 if {data}.item() else 0")
1138        elif len(node.keypath) == 1 and isinstance(node.keypath[0], DivideByKey):
1139            self.writeline(f"{node.sym}_undivided = {data}.item()")
1140            self.writeline(
1141                f"assert {node.sym}_undivided % {node.keypath[0].divisor} == 0, "
1142                f"f'{{{node.sym}_undivided}} not divisible by {node.keypath[0].divisor}'"
1143            )
1144            self.writeline(
1145                f"{node.sym} = {node.sym}_undivided // {node.keypath[0].divisor}"
1146            )
1147        else:
1148            raise AssertionError(f"unrecognized keypath {node.keypath}")
1149        # No one should ever use this buffer, but for uniformity
1150        # define the variable and assign it None
1151        self.writeline(f"{node.get_name()} = None")
1152
1153    def benchmark_compiled_module(self, output):
1154        def add_fake_input(name, shape, stride, device, dtype):
1155            output.writeline(
1156                f"{name} = rand_strided("
1157                f"{self.codegen_python_shape_tuple(shape)}, "
1158                f"{self.codegen_python_shape_tuple(stride)}, "
1159                f"device='{device}', dtype={dtype})"
1160            )
1161
1162        def add_expr_input(name, val):
1163            output.writeline(f"{name} = {val}")
1164
1165        def add_torchbind_input(name, value):
1166            import pickle
1167
1168            output.writeline(f"{name} = pickle.loads({pickle.dumps(value)!r})")
1169
1170        output.writelines(
1171            ["", "", "def benchmark_compiled_module(times=10, repeat=10):"]
1172        )
1173        with output.indent():
1174            output.splice(
1175                """
1176                from torch._dynamo.testing import rand_strided
1177                from torch._inductor.utils import print_performance
1178                """,
1179                strip=True,
1180            )
1181
1182            for name, value in V.graph.constants.items():
1183                # all the constants are global variables; that's why we need
1184                # these 'global var_name' lines
1185                output.writeline(f"global {name}")
1186                add_fake_input(
1187                    name, value.size(), value.stride(), value.device, value.dtype
1188                )
1189
1190            if len(V.graph.torchbind_constants) > 0:
1191                output.writeline("import pickle")
1192                for name, torchbind_obj in V.graph.torchbind_constants.items():
1193                    # all the constants are global variables; that's why we need
1194                    # these 'global var_name' lines
1195                    output.writeline(f"global {name}")
1196                    add_torchbind_input(name, torchbind_obj)
1197
1198            for name, value in V.graph.graph_inputs.items():
1199                if isinstance(value, sympy.Symbol) and isinstance(
1200                    V.graph.sizevars.var_to_val.get(value, None), SingletonInt
1201                ):
1202                    # Inductor should only work with dense -> dense graph, and
1203                    # SingletonInts belong to metadata that should only live on
1204                    # the subclass.
1205                    continue
1206                if isinstance(value, sympy.Expr):  # Don't need to add symbolic
1207                    # TODO: this fallback and those below may generate invalid
1208                    # benchmark code, because it's not guaranteed that 42 is
1209                    # actually a valid value for the kernel in question.
1210                    # See https://github.com/pytorch/pytorch/issues/124686
1211                    add_expr_input(name, V.graph.sizevars.size_hint(value, fallback=42))
1212                else:
1213                    shape = [
1214                        V.graph.sizevars.size_hint(x, fallback=42)
1215                        for x in value.get_size()
1216                    ]
1217                    stride = [
1218                        V.graph.sizevars.size_hint(x, fallback=42)
1219                        for x in value.get_stride()
1220                    ]
1221                    add_fake_input(
1222                        name,
1223                        shape,
1224                        stride,
1225                        value.get_device(),
1226                        value.get_dtype(),
1227                    )
1228
1229            call_str = f"call([{', '.join(V.graph.graph_inputs.keys())}])"
1230            output.writeline(f"fn = lambda: {call_str}")
1231            output.writeline("return print_performance(fn, times=times, repeat=repeat)")
1232
1233    def add_benchmark_harness(self, output):
1234        """
1235        Append a benchmark harness to generated code for debugging
1236        """
1237        if not config.benchmark_harness:
1238            return
1239
1240        self.benchmark_compiled_module(output)
1241
1242        output.writelines(["", "", 'if __name__ == "__main__":'])
1243        with output.indent():
1244            output.writelines(
1245                [
1246                    "from torch._inductor.wrapper_benchmark import compiled_module_main",
1247                    f"compiled_module_main('{get_benchmark_name()}', benchmark_compiled_module)",
1248                ]
1249            )
1250
1251    def define_kernel(
1252        self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True
1253    ):
1254        metadata_comment = f"{metadata}\n" if metadata else ""
1255        body = f"\n\n{metadata_comment}{name} = {kernel}"
1256        self.header.splice(body)
1257        if config.triton.autotune_at_compile_time:
1258            self.kernel_autotune_defs.splice(body)
1259
1260    def define_user_defined_triton_kernel(self, kernel, configs, kwargs):
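        # Emit a standalone async_compile.triton(...) definition for a user-written
        # Triton kernel: build its signature and triton_meta, wrap it with the
        # user_autotune heuristic, and inline the kernel source together with any
        # JIT functions or constexpr globals it references (see traverse below).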
1261        from torch.utils._triton import patch_triton_dtype_repr
1262
1263        patch_triton_dtype_repr()
1264
1265        original_name = kernel.__name__
1266
1267        from .common import KernelArgType, SizeArg, TensorArg
1268
1269        signature: List[KernelArgType] = []
1270        constants: Dict[int, Any] = {}
1271        non_constant_indices = []
1272        equal_to_1_arg_idx: List[int] = []
1273        for idx, key in enumerate(kernel.arg_names):
1274            if key not in kwargs:
1275                continue
1276            arg = kwargs[key]
1277            if idx in kernel.constexprs:
1278                constants[idx] = arg
1279            else:
1280                non_constant_indices.append(idx)
1281                if isinstance(arg, ir.Buffer):
1282                    signature.append(
1283                        TensorArg(
1284                            name=key,
1285                            buffer=arg.get_name(),
1286                            dtype=arg.get_dtype(),
1287                        )
1288                    )
1289                elif isinstance(arg, ir.ReinterpretView):
1290                    # for ReinterpretView we use the underlying
1291                    # buffer name and note the (possibly non-zero)
1292                    # offset relative to the underlying buffer
1293                    signature.append(
1294                        TensorArg(
1295                            name=key,
1296                            buffer=arg.data.get_name(),
1297                            dtype=arg.get_dtype(),
1298                            offset=arg.layout.offset,
1299                        )
1300                    )
1301                else:
1302                    signature.append(SizeArg(key, arg))
1303                    if isinstance(
1304                        arg, (int, sympy.Integer)
1305                    ) and V.graph.sizevars.statically_known_equals(
1306                        arg, 1  # type: ignore[arg-type]
1307                    ):
1308                        equal_to_1_arg_idx.append(idx)
1309        index_dtype = "tl.int32"
1310        triton_meta = {
1311            "signature": signature_to_meta(
1312                signature,
1313                size_dtype=index_dtype,
1314                indices=non_constant_indices,
1315            ),
1316            "device": DeviceProperties.create(
1317                V.graph.scheduler.get_current_device_or_throw()
1318            ),
1319            # Triton compiler includes equal_to_1 args into constants even
1320            # when they are not constexpr; otherwise there may be a segfault
1321            # when launching the Inductor-compiled Triton kernel.
1322            # TODO(aakhundov): add None args to constants, too. currently, this
1323            # causes CUDA errors in test_aot_inductor.test_triton_kernel_with_none_input.
1324            # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307
1325            # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
1326            "constants": {
1327                **constants,
1328                **dict.fromkeys(equal_to_1_arg_idx, 1),
1329            },
1330            "configs": [
1331                config_of(
1332                    signature,
1333                    indices=non_constant_indices,
1334                )
1335            ],
1336        }
1337
1338        # Distinguish between different functions using function id
1339        cache_key: List[Any] = [id(kernel.fn)]
1340        if len(configs) > 0:
1341            for arg in kwargs.values():
1342                # We need to key on non-tensor args only in autotune mode
1343                if not isinstance(arg, (ir.Buffer, ir.ReinterpretView)):
1344                    cache_key.append(arg)
1345        cache_key.append(str(triton_meta))
1346        cache_key = tuple(cache_key)
1347
1348        if cache_key in self.user_defined_kernel_cache:
1349            return self.user_defined_kernel_cache[cache_key]
1350
1351        name = f"{original_name}_{len(self.user_defined_kernel_cache)}"
1352        # Add to the cache for the next use
1353        self.user_defined_kernel_cache[cache_key] = (name, triton_meta)
1354
1355        compile_wrapper = IndentedBuffer()
1356        compile_wrapper.writeline(f"async_compile.triton({original_name!r}, '''")
1357
1358        from .triton import gen_common_triton_imports, TritonKernel
1359
1360        compile_wrapper.splice(gen_common_triton_imports())
1361
1362        inductor_meta = {
1363            "kernel_name": name,
1364            **TritonKernel.inductor_meta_common(),
1365        }
1366
1367        configs = [
1368            {
1369                "kwargs": config.kwargs,
1370                "num_warps": config.num_warps,
1371                "num_stages": config.num_stages,
1372            }
1373            for config in configs
1374        ]
1375
1376        compile_wrapper.splice(
1377            f"""
1378            @triton_heuristics.user_autotune(
1379                configs={configs!r},
1380                inductor_meta={inductor_meta!r},
1381                triton_meta={triton_meta!r},
1382                filename=__file__,
1383                custom_kernel=True,
1384            )
1385            @triton.jit
1386            """
1387        )
1388        compile_wrapper.splice(kernel.src, strip=True)
1389
1390        # Also include any kernels that may be called indirectly
1391        from triton import JITFunction  # type: ignore[name-defined, attr-defined]
1392        from triton.language import constexpr  # type: ignore[name-defined]
1393
1394        # global constexpr vars handled above
1395        symbols_included = {original_name}
1396
1397        def traverse(cur_kernel):
1398            # here we extract the unqualified names (i.e., not attributes and
1399            # without prepended module name) loaded in the kernel code, which
1400            # are matched with the co_names and __globals__ below to codegen
1401            # the respective imports necessary for the kernel compilation
1402            unqualified_loads = {
1403                inst.argval
1404                for inst in dis.Bytecode(cur_kernel.fn)
1405                if inst.opname == "LOAD_GLOBAL"
1406            }
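            # Hedged example: for a kernel body containing `y = helper(x)` and
            # `tl.store(out_ptr, y)`, LOAD_GLOBAL instructions typically reference
            # `helper` and `tl`, so unqualified_loads would be {"helper", "tl"};
            # the attribute `store` appears in co_names but not as a LOAD_GLOBAL.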
1407            global_annotations = cur_kernel.fn.__globals__.get("__annotations__", {})
1408            for symbol_name in cur_kernel.fn.__code__.co_names:
1409                if symbol_name in symbols_included:
1410                    continue
1411                if symbol_name in cur_kernel.fn.__globals__:
1412                    symbol = cur_kernel.fn.__globals__[symbol_name]
1413                    if isinstance(symbol, JITFunction):
1414                        compile_wrapper.newline()
1415                        compile_wrapper.writeline("@triton.jit")
1416                        compile_wrapper.splice(symbol.src, strip=True)
1417                        symbols_included.add(symbol_name)
1418                        traverse(symbol)
1419                    elif isinstance(symbol, (int, str, bool, constexpr)):
1420                        compile_wrapper.newline()
1421                        if isinstance(symbol, constexpr):
1422                            symbol_str = f"tl.constexpr({symbol.value!r})"
1423                        else:
1424                            symbol_str = f"{symbol!r}"
1425                        if annotation := global_annotations.get(symbol_name):
1426                            annotation_code = ""
1427                            if isinstance(annotation, type):
1428                                annotation_code = (
1429                                    f": {annotation.__module__}.{annotation.__name__}"
1430                                )
1431                            else:
1432                                annotation_code = f": {annotation!r}"
1433                            compile_wrapper.writeline(
1434                                f"{symbol_name}{annotation_code} = {symbol_str}"
1435                            )
1436                        else:
1437                            compile_wrapper.writeline(f"{symbol_name} = {symbol_str}")
1438                        symbols_included.add(symbol_name)
1439                    elif (
1440                        symbol_name in unqualified_loads
1441                        and symbol_name != "tl"  # already imported
1442                        and hasattr(symbol, "__module__")
1443                        # only codegen imports from triton; JITFunctions
1444                        # imported from other modules will be codegened
1445                        # in the separate branch above
1446                        and symbol.__module__.startswith("triton")
1447                    ):
1448                        # a global symbol imported from triton is referenced
1449                        # without module qualification (e.g., `store` instead
1450                        # of `tl.store`): need to codegen an import
1451                        compile_wrapper.writeline(
1452                            f"from {symbol.__module__} import {symbol.__name__} as {symbol_name}"
1453                        )
1454                        symbols_included.add(symbol_name)
1455
1456        traverse(kernel)
1457
1458        current_device = V.graph.scheduler.get_current_device_or_throw()
1459        compile_wrapper.writeline(f"''', device_str='{current_device.type}')")
1460        _, lineno = inspect.getsourcelines(kernel.fn)
1461        srcfile = inspect.getsourcefile(kernel.fn)
1462        metadata = f"# Original path: {srcfile}:{lineno}"
1463        self.define_kernel(
1464            name,
1465            compile_wrapper.getvalue(),
1466            metadata,
1467        )
1468        return name, triton_meta
1469
1470    def generate_numel_expr(self, kernel_name: str, tree, suffix: Optional[str] = None):
1471        expr = f"{kernel_name}_{tree.prefix}numel"
1472        if suffix is not None:
1473            expr += f"_{suffix}"
1474        if (expr, V.graph) not in self.kernel_numel_expr:
1475            # declare expr once in each graph (scope)
1476            self.kernel_numel_expr.add((expr, V.graph))
1477            self.writeline(
1478                f"{self.declare}{expr} = {self.expr_printer(tree.numel)}{self.ending}"
1479            )
1480        else:
1481            self.writeline(f"{expr} = {self.expr_printer(tree.numel)}{self.ending}")
1482        # We can get symbolic expressions here, like s0*64.
1483        # It is fine to have them here, but we need to handle them correctly as their own type,
1484        # which is tricky to do, so we wrap them in a custom type that is distinct both from plain
1485        # scalars and from sympy scalars.
1486        # This is handled in `generate_args_decl`, which carries the comment "TODO: only works for
1487        # constant now, need type info". That is accurate: while this wrapper is not true type info,
1488        # it suffices as a type hint for producing the correct code for this type.
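        # Hedged example: for a (hypothetical) kernel "triton_poi_fused_add_0" and an
        # "x" tree whose numel is 64*s0, this writes something like
        #     triton_poi_fused_add_0_xnumel = 64*s0
        # and returns SymbolicCallArg("triton_poi_fused_add_0_xnumel", 64*s0).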
1489        return SymbolicCallArg(expr, tree.numel)
1490
1491    def generate_workspace_allocation(self, nbytes, device, zero_fill):
1492        line = self.make_allocation(
1493            "workspace", device, torch.uint8, shape=(nbytes,), stride=(1,)
1494        )
1495        self.writeline(line)
1496        if zero_fill:
1497            self.writeline(f"workspace.zero_(){self.ending}")
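    # Hedged example of what generate_workspace_allocation above emits for a CUDA
    # device with zero_fill=True (the exact allocation text comes from make_allocation):
    #     workspace = empty_strided_cuda((nbytes, ), (1, ), torch.uint8)
    #     workspace.zero_()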
1498
1499    def wrap_kernel_call(self, name, call_args):
1500        return f"{name}({', '.join(call_args)}){self.ending}"
1501
1502    def generate_profiler_mark_wrapper_call(self, stack):
1503        self.wrapper_call.writeline("from torch.profiler import record_function")
1504        self.wrapper_call.writeline(
1505            f"with record_function('graph_{V.graph.graph_id}_inductor_wrapper_call'):"
1506        )
1507        stack.enter_context(self.wrapper_call.indent())
1508
1509    def generate_start_graph(self):
1510        self.wrapper_call.writeline("start_graph()")
1511
1512    def generate_end_graph(self):
1513        self.wrapper_call.writeline(f"end_graph({config.profile_bandwidth_output!r})")
1514
1515    def generate_reset_kernel_saved_flags(self):
1516        self.wrapper_call.splice(
1517            f"""
1518            for kernel in globals().values():
1519                if isinstance(kernel, {triton_heuristics.__name__}.CachingAutotuner):
1520                    kernel.cuda_kernel_saved = False
1521            """
1522        )
1523
1524    def generate_save_uncompiled_kernels(self):
1525        """
1526        Precompile and save the CUBINs of the Triton kernels that haven't
1527        been precompiled and saved as a side effect of running the generated
1528        JIT model (Python wrapper). This can happen when the model contains
1529        control flow: a single pass through the control flow operators only
1530        launches (and hence saves) the kernels on the taken path; the remaining
1531        kernels are never launched and therefore never saved. The main purpose of
1532        this codegen is to compile and save the Triton kernels outside the active
1533        control flow path for subsequent AOTInductor code generation and compilation.
1534        """
1535        self.wrapper_call.splice(
1536            f"""
1537            for kernel in globals().values():
1538                if isinstance(kernel, {triton_heuristics.__name__}.CachingAutotuner):
1539                    if not kernel.cuda_kernel_saved:
1540                        if len(kernel.launchers) == 0:
1541                            kernel.precompile()
1542                        kernel.save_gpu_kernel(
1543                            grid=(0, 0, 0),   # use dummy grid
1544                            stream="stream",  # use dummy stream
1545                            launcher=kernel.launchers[0],
1546                        )
1547            """
1548        )
1549
1550    def generate_default_grid(
1551        self,
1552        kernel_name: str,
1553        grid: List[Any],
1554        cuda: bool = True,
1555        grid_callable: Optional[Callable[..., Any]] = None,
1556        **grid_extra_kwargs,
1557    ):
1558        return grid
1559
1560    def prepare_triton_kernel_call(self, device_index, call_args):
1561        def wrap_arg(arg):
1562            if isinstance(arg, str):
1563                # dynamo wraps unspec variables as 0d CPU tensors; we need to convert them to scalars
1564                return arg + ".item()" if should_unwrap_unspec_arg(arg) else arg
1565            elif isinstance(arg, (int, float, bool, SymbolicCallArg)):
1566                return str(arg)
1567            else:
1568                return self.expr_printer(V.graph.sizevars.simplify(arg))
1569
1570        call_args = [wrap_arg(arg) for arg in call_args]
1571
1572        if device_index is None:
1573            current_device = V.graph.scheduler.get_current_device_or_throw()
1574            device_index = current_device.index
1575
1576        return device_index, call_args
1577
1578    def generate_example_arg_value(self, arg, arg_type, raw_arg=None, index=None):
1579        if isinstance(arg_type, torch_dtype):
1580            if V.graph.try_get_buffer(arg) is not None:
1581                buf_name = arg
1582                buf = V.graph.get_buffer(arg)
1583            else:
1584                assert (
1585                    raw_arg is not None
1586                ), "V.graph.get_buffer(arg) and raw_arg can't be None at the same time"
1587                buf_name = f"tmp_arg_{index}"
1588                buf = raw_arg
1589
1590            size = V.graph.sizevars.size_hints(
1591                buf.get_size(),
1592                fallback=config.unbacked_symint_fallback,
1593            )
1594            stride = V.graph.sizevars.size_hints(
1595                buf.get_stride(),
1596                fallback=config.unbacked_symint_fallback,
1597            )
1598            device = buf.get_device()
1599            dtype = buf.get_dtype()
1600            offset = V.graph.sizevars.size_hint(
1601                buf.layout.offset,
1602                fallback=config.unbacked_symint_fallback,
1603            )
1604            value = f"generate_example_value({size}, {stride}, '{device}', {dtype}, {offset})"
1605            self.kernel_autotune_calls.writeline(f"{buf_name} = {value}")
1606            return buf_name
1607        elif issubclass(arg_type, sympy.Basic) or isinstance(arg, SymbolicCallArg):
1608            # arg is a symbol or symbolic expression
1609            if isinstance(arg, str):
1610                if arg in self._meta_vars:
1611                    return arg
1612                if raw_arg is None:
1613                    return "None"
1614                arg = raw_arg
1615            if isinstance(arg, SymbolicCallArg):
1616                arg = arg.inner_expr
1617            if arg in V.graph.sizevars.inv_precomputed_replacements:
1618                arg = V.graph.sizevars.inv_precomputed_replacements[arg]
1619            return str(
1620                V.graph.sizevars.size_hint(
1621                    arg,
1622                    fallback=config.unbacked_symint_fallback,
1623                )
1624            )
1625        elif isinstance(arg, (str, int, float, bool)):
1626            return str(arg)
1627        elif isinstance(arg, list):
1628            return f"[{', '.join(self.generate_example_arg_value(a, type(a)) for a in arg)}]"
1629        else:
1630            raise NotImplementedError(f"Unsupported type {type(arg)}")
1631
1632    def _grid_dim_str(self, grid_per_dim):
1633        if isinstance(grid_per_dim, list):
1634            return (
1635                "[" + ", ".join(self._grid_dim_str(item) for item in grid_per_dim) + "]"
1636            )
1637        else:
1638            return pexpr(grid_per_dim)
1639
1640    def generate_kernel_call(
1641        self,
1642        kernel_name,
1643        call_args,
1644        grid=None,
1645        device_index=None,
1646        cuda=True,
1647        triton=True,
1648        arg_types=None,
1649        raw_args=None,
1650        grid_fn: str = "grid",
1651        triton_meta=None,
1652        autotune_configs=None,
1653        grid_extra_kwargs="",
1654    ):
1655        """
1656        Generates kernel call code.
1657
1658        cuda: Defines whether the backend is GPU. Otherwise the backend is CPU.
1659
1660        triton: Defines whether the GPU backend uses Triton for codegen.
1661                Otherwise it uses the CUDA language for codegen.
1662                Only valid when cuda == True.
1663        """
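        # Hypothetical usage sketch (kernel and buffer names are illustrative only):
        #     self.generate_kernel_call(
        #         "triton_poi_fused_add_0",
        #         ["buf0", "arg0_1", "xnumel"],
        #         grid=[xnumel],
        #         cuda=True,
        #         triton=True,
        #     )
        # would emit roughly:
        #     triton_poi_fused_add_0.run(buf0, arg0_1, xnumel, grid=grid(xnumel), stream=stream0)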
1664        if cuda:
1665            device_index, call_args_str = self.prepare_triton_kernel_call(
1666                device_index, call_args
1667            )
1668            call_args_str = ", ".join(call_args_str)
1669            stream_name = self.write_get_raw_stream(device_index, V.graph)
1670            if triton:
1671                self.write_triton_header_once()
1672                if grid is None:
1673                    grid_str = grid_fn
1674                else:
1675                    grid_str = ", ".join(self._grid_dim_str(item) for item in grid)
1676                    if grid_extra_kwargs:
1677                        grid_str = f"{grid_str}, {grid_extra_kwargs}"
1678                    grid_str = f"{grid_fn}({grid_str})"
1679                self.writeline(
1680                    f"{kernel_name}.run({call_args_str}, grid={grid_str}, stream={stream_name})"
1681                )
1682                if (
1683                    config.triton.autotune_at_compile_time
1684                    and kernel_name not in self.kernel_autotune_names
1685                ):
1686                    # Create example args for autotune in a separate epilogue
1687                    assert arg_types is not None and len(call_args) == len(
1688                        arg_types
1689                    ), "call_args and arg_types do not match"
1690
1691                    tensor_args = {}
1692                    all_args = []
1693                    if raw_args is None:
1694                        # create a dummy raw_args for uniform behavior in the following loop
1695                        raw_args = [None] * len(call_args)
1696                    else:
1697                        assert len(raw_args) == len(
1698                            call_args
1699                        ), "call_args and raw_args do not match"
1700
1701                    for i, (arg, arg_type, raw_arg) in enumerate(
1702                        zip(call_args, arg_types, raw_args)
1703                    ):
1704                        key = None
1705                        if isinstance(arg, str) and "=" in str(arg):
1706                            # arg may be passed in a kwarg style, and then we need to extract its value
1707                            key, arg = arg.split("=")
1708
1709                        if isinstance(arg_type, torch_dtype):
1710                            if arg not in tensor_args:
1711                                arg_str = self.generate_example_arg_value(
1712                                    arg, arg_type, raw_arg, i
1713                                )
1714                                tensor_args[arg] = arg_str
1715                            else:
1716                                arg_str = tensor_args[arg]
1717                        else:
1718                            arg_str = self.generate_example_arg_value(
1719                                arg, arg_type, raw_arg, i
1720                            )
1721                        all_args.append(arg_str if key is None else f"{key}={arg_str}")
1722
1723                    if grid is None:
1724                        grid_str = grid_fn
1725                    else:
1726                        grid_str = ", ".join(
1727                            self.generate_example_arg_value(g, type(g)) for g in grid
1728                        )
1729                        if grid_extra_kwargs:
1730                            grid_str = f"{grid_str}, {grid_extra_kwargs}"
1731                        grid_str = f"{grid_fn}({grid_str})"
1732
1733                    self.kernel_autotune_calls.writeline(
1734                        f"{kernel_name}.run({', '.join(all_args)}, grid={grid_str}, stream={stream_name})"
1735                    )
1736                    self.kernel_autotune_calls.writeline(
1737                        f"del {', '.join(arg for arg in tensor_args.values())}\n",
1738                    )
1739                    self.kernel_autotune_names.add(kernel_name)
1740            else:
1741                stream_ptr = f"c_void_p({stream_name})"
1742                self.writeline(
1743                    f"{kernel_name}.{kernel_name}({call_args_str}, {stream_ptr})"
1744                )
1745        else:
1746            self.writeline(self.wrap_kernel_call(kernel_name, call_args))
1747
1748    def writeline(self, line):
1749        self.lines.append(line)
1750
1751    def writelines(self, lines):
1752        for line in lines:
1753            self.writeline(line)
1754
1755    def enter_context(self, ctx):
1756        self.lines.append(LineContext(ctx))
1757
1758    def val_to_arg_str(self, s, type_=None):
1759        from torch.utils._triton import dtype_to_string, has_triton_package
1760
1761        if has_triton_package():
1762            import triton
1763
1764        if isinstance(s, SymTypes):
1765            return pexpr(s.node.expr)
1766        elif isinstance(s, sympy.Expr):
1767            return pexpr(s)
1768        elif isinstance(s, (tuple, list)):
1769
1770            @dataclasses.dataclass
1771            class Shim:
1772                ref: Any
1773
1774                def __repr__(self):
1775                    return self.ref
1776
1777            return repr(type(s)(Shim(self.val_to_arg_str(a)) for a in s))
1778        elif isinstance(s, torch._ops.OpOverload):
1779            return _get_qualified_name(s)
1780        elif isinstance(s, (ir.Buffer, ReinterpretView)):
1781            return s.codegen_reference()
1782        elif has_triton_package() and isinstance(s, triton.language.dtype):  # type: ignore[possibly-undefined]
1783            return dtype_to_string(s)
1784        else:
1785            return repr(s)
1786
1787    # The following methods are for memory management
1788    def make_buffer_allocation(self, buffer):
1789        device = buffer.get_device()
1790        dtype = buffer.get_dtype()
1791        shape = tuple(buffer.get_size())
1792        stride = tuple(buffer.get_stride())
1793        return self.make_allocation(buffer.get_name(), device, dtype, shape, stride)
1794
1795    def make_allocation(self, name, device, dtype, shape, stride):
1796        if device.type in ("cpu", "cuda", "xpu"):
1797            # optimized path for faster allocations, saving ~2us versus the generic path below
1798            return (
1799                f"{name} = empty_strided_{device.type}("
1800                f"{self.codegen_shape_tuple(shape)}, "
1801                f"{self.codegen_shape_tuple(stride)}, "
1802                f"{dtype})"
1803            )
1804        # all other devices:
1805        return (
1806            f"{name} = empty_strided("
1807            f"{self.codegen_shape_tuple(shape)}, "
1808            f"{self.codegen_shape_tuple(stride)}, "
1809            f"device='{device.type}', dtype={dtype})"
1810        )
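    # Hedged example: make_allocation("buf0", <cuda device>, torch.float32, (8, 16), (16, 1))
    # would return something like
    #     buf0 = empty_strided_cuda((8, 16), (16, 1), torch.float32)
    # while any other device type falls back to the generic empty_strided(...) form.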
1811
1812    def make_tensor_alias(self, new_name, old_name, comment=""):
1813        return f"{self.declare}{new_name} = {old_name}{self.ending}  {self.comment} {comment}"
1814
1815    def make_buffer_free(self, buffer):
1816        return f"del {buffer.get_name()}"
1817
1818    def make_free_by_names(self, names_to_del: List[str]):
1819        return f"del {', '.join(name for name in names_to_del)}"
1820
1821    def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str):
1822        return f"{self.declare_maybe_reference}{new_name} = {old_name}{del_line}{self.ending}  {self.comment} reuse"
1823
1824    def make_buffer_reuse(self, old: ir.Buffer, new: ir.Buffer, delete_old: bool):
1825        assert old.get_dtype() == new.get_dtype()
1826        old_name = old.get_name()
1827        new_name = new.get_name()
1828        del_line = ";"
1829        if old_name not in V.graph.get_output_names() and delete_old:
1830            del_line = f"; {self.make_buffer_free(old)}"
1831
1832        if old.get_size() == new.get_size() and old.get_stride() == new.get_stride():
1833            if old_name in self.stack_allocated_buffers:
1834                self.stack_allocated_buffers[new_name] = new
1835            return self.codegen_exact_buffer_reuse(old_name, new_name, del_line)
1836
1837        reinterpret_view = self.codegen_reinterpret_view(
1838            old, new.get_size(), new.get_stride(), 0, self.wrapper_call
1839        )
1840        if reinterpret_view in self.stack_allocated_buffers:
1841            self.stack_allocated_buffers[new_name] = new
1842        return f"{self.declare_maybe_reference}{new_name} = {reinterpret_view}{del_line}  {self.comment} reuse"
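    # Hedged example: reusing buf1's storage for buf3 with matching size and stride
    # produces a line along the lines of
    #     buf3 = buf1; del buf1  # reuse
    # while mismatched size/stride goes through codegen_reinterpret_view instead.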
1843
1844    def codegen_deferred_allocation(self, name, layout):
1845        self.writeline(
1846            DeferredLine(
1847                name,
1848                f"{self.declare_maybe_reference}{name} = {layout.view.codegen_reference()}{self.ending}  "
1849                f"{self.comment} alias",
1850            )
1851        )
1852
1853    def codegen_allocation(self, buffer: ir.Buffer):
1854        name = buffer.get_name()
1855
1856        if name in V.graph.removed_buffers or name in self.allocated:
1857            return
1858        self.allocated.add(name)
1859        if isinstance(
1860            buffer.get_defining_op(),
1861            (ir.ExternKernelAlloc, ir.MultiOutput),
1862        ):
1863            return
1864
1865        layout = buffer.get_layout()
1866        if isinstance(layout, ir.MutationLayoutSHOULDREMOVE):
1867            return
1868        if isinstance(layout, ir.NoneLayout):
1869            return
1870        if isinstance(layout, ir.NonOwningLayout):
1871            assert isinstance(
1872                layout.view, ir.ReinterpretView
1873            ), f"unexpected {type(layout.view)}: {layout.view}"
1874            assert isinstance(layout.view.data, ir.StorageBox), type(layout.view.data)
1875            assert isinstance(layout.view.data.data, ir.Buffer), type(layout.view.data.data)
1876            self.codegen_allocation(layout.view.data.data)
1877            self.codegen_deferred_allocation(name, layout)
1878            return
1879
1880        self.writeline(AllocateLine(self, buffer))
1881
1882    def codegen_free(self, buffer):
1883        name = buffer.get_name()
1884
1885        # can be freed but not reused
1886        if isinstance(buffer, ir.InputBuffer):
1887            self.writeline(self.make_buffer_free(buffer))
1888            return
1889
1890        if not self.can_reuse(buffer):
1891            return
1892        self.freed.add(name)
1893
1894        self.writeline(FreeIfNotReusedLine(self, buffer))
1895
1896    def can_reuse(self, input_buffer, output_buffer=None):
1897        name = input_buffer.get_name()
1898        return not (
1899            name in V.graph.removed_buffers
1900            or name in V.graph.graph_inputs
1901            or name in V.graph.constants
1902            or name in V.graph.torchbind_constants
1903            or name in V.graph.never_reuse_buffers
1904            or name in self.freed
1905        )
1906
1907    def did_reuse(self, buffer, reused_buffer):
1908        # Check whether a given buffer was reused by a possible reuser in the wrapper codegen
1909        # Can be consulted from inside ir codegen, e.g. to determine whether a copy is needed
1910        return (
1911            buffer.get_name() in self.reuses
1912            and self.reuses[buffer.get_name()] == reused_buffer.get_name()
1913        )
1914
1915    def codegen_inplace_reuse(self, input_buffer: ir.Buffer, output_buffer: ir.Buffer):
1916        assert buffer_reuse_key(input_buffer) == buffer_reuse_key(output_buffer)
1917        self.codegen_allocation(input_buffer)
1918        self.freed.add(input_buffer.get_name())
1919        self.allocated.add(output_buffer.get_name())
1920        self.reuses[output_buffer.get_name()] = input_buffer.get_name()
1921        self.writeline(ReuseLine(self, input_buffer, output_buffer))
1922
1923    def codegen_unbacked_symbol_decl(self, symbol):
1924        name = str(symbol)
1925        if name in self.unbacked_symbol_decls:
1926            return name
1927        else:
1928            # When in CppWrapperCpu, we should only generate the declaration once
1929            self.unbacked_symbol_decls.add(name)
1930            return self.declare + name
1931
1932    def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs):
1933        for inner_input, outer_input in zip(subgraph.graph.graph_inputs, outer_inputs):
1934            self.writeline(f"{self.declare}{inner_input} = {outer_input}{self.ending}")
1935
1936    def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs):
1937        for inner_output, outer_output in zip(
1938            subgraph.graph.graph_outputs, outer_outputs
1939        ):
1940            self.writeline(
1941                f"{outer_output} = {inner_output.codegen_reference()}{self.ending}"
1942            )
1943
1944    def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs):
1945        try:
1946            self.push_codegened_graph(subgraph.graph)
1947            self.writeline(f"{self.comment} subgraph: {subgraph.name}")
1948            self.codegen_subgraph_prefix(subgraph, outer_inputs, outer_outputs)
1949            parent_graph = V.graph
1950            with V.set_graph_handler(subgraph.graph):
1951                subgraph.graph.codegen_subgraph(
1952                    parent_graph=parent_graph,
1953                )
1954            self.codegen_subgraph_suffix(subgraph, outer_inputs, outer_outputs)
1955        finally:
1956            self.pop_codegened_graph()
1957
1958    def codegen_conditional(self, conditional):
1959        name = conditional.get_name()
1960
1961        self.writeline(f"{name} = [None] * {len(conditional.outputs)}")
1962
1963        outer_inputs = [buf.codegen_reference() for buf in conditional.operands]
1964        outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))]
1965
1966        predicate = conditional.predicate.codegen_reference()
1967        if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer):
1968            # move the Tensor predicate to host
1969            predicate = f"{predicate}.item()"
1970
1972        self.writeline(f"if {predicate}:")
1973        self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph))
1974        self.codegen_subgraph(conditional.true_subgraph, outer_inputs, outer_outputs)
1975        self.writeline(ExitSubgraphLine(self))
1976        self.writeline("else:")
1977        self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph))
1978        self.codegen_subgraph(conditional.false_subgraph, outer_inputs, outer_outputs)
1979        self.writeline(ExitSubgraphLine(self))
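    # Illustrative sketch of the wrapper code codegen_conditional produces
    # (buffer/argument names are hypothetical):
    #     buf5 = [None] * 2
    #     if arg0_1.item():
    #         ...  # true subgraph fills buf5[0], buf5[1]
    #     else:
    #         ...  # false subgraph fills buf5[0], buf5[1]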
1980
1981    def codegen_while_loop(self, while_loop):
1982        name = while_loop.get_name()
1983        outer_carried_inputs = [
1984            buf.codegen_reference() for buf in while_loop.carried_inputs
1985        ]
1986        outer_additional_inputs = [
1987            buf.codegen_reference() for buf in while_loop.additional_inputs
1988        ]
1989
1990        self.writeline(f"{name} = [None] * {len(outer_carried_inputs)}")
1991        for i, inp in enumerate(outer_carried_inputs):
1992            # set the initial state before the loop
1993            self.writeline(f"{name}[{i}] = {inp}")
1994
1995        cond_outer_inputs = [
1996            *[f"{name}[{i}]" for i in range(len(outer_carried_inputs))],
1997            *outer_additional_inputs,
1998        ]
1999        cond_outer_outputs = [f"{name}_cond_result"]
2000        body_outer_inputs = list(
2001            cond_outer_inputs
2002        )  # same inputs for cond_fn and body_fn
2003        # Carry over the state from body_fn. Note: we only carry over
2004        # the carried_inputs part of the inputs; the additional ones
2005        # are passed in unchanged.
2006        body_outer_outputs = body_outer_inputs[: len(outer_carried_inputs)]
2007
2008        self.writeline("while True:")
2009        self.writeline(EnterSubgraphLine(self, while_loop.cond_subgraph.graph))
2010        self.codegen_subgraph(
2011            while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
2012        )
2013        self.writeline(
2014            f"if not {cond_outer_outputs[0]}.item(): break"
2015        )  # condition doesn't hold
2016        self.writeline(ExitSubgraphLine(self))
2017        self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
2018        self.codegen_subgraph(
2019            while_loop.body_subgraph, body_outer_inputs, body_outer_outputs
2020        )
2021        self.writeline(ExitSubgraphLine(self))
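    # Illustrative sketch of the wrapper code codegen_while_loop produces
    # (buffer/argument names are hypothetical):
    #     buf7 = [None] * 1
    #     buf7[0] = arg0_1            # initial carried state
    #     while True:
    #         ...                     # cond subgraph fills buf7_cond_result
    #         if not buf7_cond_result.item(): break
    #         ...                     # body subgraph writes back into buf7[0]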
2022
2023    @staticmethod
2024    def statically_known_int_or_none(x):
2025        try:
2026            if getattr(x, "free_symbols", None):
2027                # _maybe_evaluate_static will return (s0 // (2 // s0)) as 2, but
2028                # the actual codegen will still generate the full expression here.
2029                return None
2030            if isinstance(x, int):
2031                return x
2032            val = V.graph._shape_env._maybe_evaluate_static(x)
2033            return int(val)
2034        except Exception:
2035            return None
2036
2037    @staticmethod
2038    def statically_known_list_of_ints_or_none(lst):
2039        result = []
2040        for x in lst:
2041            num = WrapperCodeGen.statically_known_int_or_none(x)
2042            if num is None:
2043                return None
2044            result.append(num)
2045        return result
2046
2047    @staticmethod
2048    def is_statically_known_list_of_ints(lst):
2049        return WrapperCodeGen.statically_known_list_of_ints_or_none(lst) is not None
2050
2051    @staticmethod
2052    def static_shape_for_buffer_or_none(buffer):
2053        return WrapperCodeGen.statically_known_list_of_ints_or_none(buffer.get_size())
2054
2055    @staticmethod
2056    def can_prove_buffer_has_static_shape(buffer):
2057        return WrapperCodeGen.static_shape_for_buffer_or_none(buffer) is not None
2058