# mypy: allow-untyped-defs
from __future__ import annotations

import collections
import contextlib
import dataclasses
import dis
import functools
import inspect
import logging
import operator
import re
import tempfile
from itertools import count
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Set,
    Tuple,
    TYPE_CHECKING,
    Union,
)

import sympy
from sympy import Expr

import torch
import torch._ops
from torch import dtype as torch_dtype
from torch._dynamo.utils import counters, dynamo_timed
from torch._inductor.codegen.debug_utils import DebugPrinterManager
from torch._inductor.codegen.multi_kernel import MultiKernelState
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey, SymTypes
from torch.fx.node import _get_qualified_name
from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._sympy.symbol import symbol_is_type, SymT

from .. import async_compile, config, ir
from ..codecache import output_code_log
from ..ir import ReinterpretView
from ..runtime import triton_heuristics
from ..runtime.hints import DeviceProperties
from ..utils import (
    cache_on_self,
    get_benchmark_name,
    LineContext,
    sympy_product,
    sympy_str,
)
from ..virtualized import V
from .aoti_hipify_utils import maybe_hipify_code_wrapper
from .common import CodeGen, DeferredLine, IndentedBuffer, PythonPrinter
from .triton_utils import config_of, should_unwrap_unspec_arg, signature_to_meta


if TYPE_CHECKING:
    import triton

    from ..graph import GraphLowering


pexpr = PythonPrinter().doprint


ReuseKey = Tuple[torch.device, torch.dtype, str]


def buffer_reuse_key(node: ir.Buffer) -> ReuseKey:
    return (
        node.get_device(),
        node.get_dtype(),
        # NB: this is symbolic so that we don't try to reuse a buffer
        # for s0 for s1, just because they happen to share the same
        # size hint
        sympy_str(V.graph.sizevars.simplify(node.layout.storage_size())),
    )


def convert_arg_type(arg: torch.Argument) -> str:
    from .cpp import CONTAINER_PYTHON_TO_CPP, PYTHON_TO_CPP

    # use x.real_type instead of x.type so that we get ScalarType instead of int
    python_type = repr(arg.real_type)  # type: ignore[attr-defined]

    if python_type == "Tensor":
        # Conversion rules follow https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native#func
        if arg.alias_info is not None and arg.alias_info.is_write:
            return f"at::{python_type}&"
        else:
            return f"at::{python_type} const&"

    if python_type in PYTHON_TO_CPP:
        cpp_type = PYTHON_TO_CPP[python_type]
        return cpp_type

    # Convert args of container types e.g. Optional[*]
    for py_container, cpp_container in CONTAINER_PYTHON_TO_CPP.items():
        container_match = re.findall(py_container + r"\[([a-zA-Z_]+)]", python_type)
        if len(container_match) == 1:
            contained_type = container_match[0]
            assert (
                contained_type in PYTHON_TO_CPP
            ), f"unsupported {py_container} type in convert_arg_type: {contained_type}"
            cpp_contained_type = PYTHON_TO_CPP[contained_type]
            return f"{cpp_container}<{cpp_contained_type}>"

    raise AssertionError(f"unsupported python_type: {python_type}")


def convert_return_type(ret: torch.Argument) -> str:
    # use x.real_type instead of x.type so that we get ScalarType instead of int
    python_type = repr(ret.real_type)  # type: ignore[attr-defined]
    python_to_cpp = {
        "Tensor": "at::Tensor",
        "List[Tensor]": "std::vector<at::Tensor>",
    }

    cpp_type = python_to_cpp.get(python_type, None)
    assert cpp_type is not None, f"NYI return type: {python_type}"
    # An output aliasing an input is returned by reference only when it's a
    # Tensor, not when it's a Tensor[]. For example, aten.split.Tensor's output
    # aliases the input tensor, but the op returns a vector by value.
    if python_type == "Tensor" and ret.alias_info is not None:
        cpp_type += "&"
    return cpp_type


def get_cpp_op_schema(kernel: torch._ops.OpOverload) -> str:
    args = kernel._schema.arguments
    returns = kernel._schema.returns

    num_returns = len(returns)
    assert num_returns > 0, "must have at least one return value"

    if num_returns == 1:
        cpp_return_value = convert_return_type(returns[0])
    elif num_returns > 1:
        tuple_returns = ", ".join([convert_return_type(r) for r in returns])
        cpp_return_value = f"std::tuple<{tuple_returns}>"

    cpp_arg_type = [f"{convert_arg_type(arg)} {arg.name}" for arg in args]
    return f"{cpp_return_value}({', '.join(cpp_arg_type)})"  # type: ignore[possibly-undefined]


# TODO: Move to a well known place
TritonMetaParams = Dict[str, int]
TritonGrid = Union[
    Tuple[Union[int, sympy.Expr], ...], Callable[[TritonMetaParams], Tuple[int, ...]]
]


def user_defined_kernel_grid_fn_code(
    name: str,
    configs: List[triton.Config],  # type: ignore[name-defined]
    grids: List[TritonGrid],
    wrapper: Optional[WrapperCodeGen] = None,
) -> Tuple[str, str]:
    output = IndentedBuffer()

    def _convert_to_sympy_expr(item: Union[int, sympy.Expr]) -> sympy.Expr:
        return item if isinstance(item, sympy.Expr) else sympy.Integer(item)

    def determine_grid(
        grid: TritonGrid,
    ):
        """
        This function returns a tuple of two values: the first one is for the real grid,
        which is used in the generated code; the second one is an example grid with
        concrete values, which is used in the autotune block to run the generated
        kernels at compile time.
        """
        if wrapper is None or callable(grid):
            # return as-is when used in eager mode or when grid is callable
            return grid, grid
        # Grid contains ints/Expr, so utilize wrapper's expr printer for codegen
        sympy_grid = tuple(_convert_to_sympy_expr(g) for g in grid)
        return (
            wrapper.codegen_shape_tuple(sympy_grid),
            wrapper.codegen_shape_tuple(
                tuple(
                    wrapper.generate_example_arg_value(g, type(g)) for g in sympy_grid
                )
            )
            if config.triton.autotune_at_compile_time
            else None,
        )

    def writeline(line: str, example_grid: Optional[str] = None):
        output.writeline(line)
        if (
            wrapper
            and config.triton.autotune_at_compile_time
            and name not in wrapper.kernel_autotune_names
        ):
            wrapper.kernel_autotune_calls.writeline(example_grid or line)

    fn_name = f"grid_wrapper_for_{name}"
    writeline(f"def {fn_name}(meta):")
    kernel_autotune_calls_indent = (
        wrapper.kernel_autotune_calls.indent()
        if wrapper and config.triton.autotune_at_compile_time
        else contextlib.nullcontext()
    )
    with output.indent(), kernel_autotune_calls_indent:
        if len(grids) == 1:
            grid, example_grid = determine_grid(grids[0])
            writeline(f"return {grid}", f"return {example_grid}")
        else:
            assert len(grids) > 1
            assert len(grids) == len(configs)
            seen = set()
            for grid, c in zip(grids, configs):
                guards = [f"meta['{name}'] == {val}" for name, val in c.kwargs.items()]
                guards = " and ".join(guards)
                grid, example_grid = determine_grid(grid)
                statement = f"if {guards}: return {grid}"
                if statement in seen:
                    continue
                seen.add(statement)
                writeline(statement, f"if {guards}: return {example_grid}")

    return fn_name, output.getvalue()


@dataclasses.dataclass
class SymbolicCallArg:
    inner: str
    # the original symbolic expression represented by inner
    inner_expr: sympy.Expr

    def __str__(self):
        return str(self.inner)


# Default thread stack sizes vary by platform:
# - Linux: 8 MB
# - macOS: 512 KB
# - Windows: 1 MB
# Just pick something comfortably smaller than the smallest for now.
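# (1024 * 100 bytes = 100 KiB, comfortably below the 512 KB macOS minimum above.)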
245MAX_STACK_ALLOCATION_SIZE = 1024 * 100 246 247 248class MemoryPlanningState: 249 def __init__(self): 250 super().__init__() 251 self.reuse_pool: Dict[ 252 ReuseKey, List[FreeIfNotReusedLine] 253 ] = collections.defaultdict(list) 254 self.total_allocated_buffer_size: int = 0 255 256 def __contains__(self, key: ReuseKey) -> bool: 257 return bool(self.reuse_pool.get(key, None)) 258 259 def pop(self, key: ReuseKey) -> FreeIfNotReusedLine: 260 item = self.reuse_pool[key].pop() 261 assert not item.is_reused 262 return item 263 264 def push(self, key: ReuseKey, item: FreeIfNotReusedLine) -> None: 265 assert not item.is_reused 266 self.reuse_pool[key].append(item) 267 268 269class WrapperLine: 270 pass 271 272 273@dataclasses.dataclass 274class EnterSubgraphLine(WrapperLine): 275 wrapper: WrapperCodeGen 276 graph: GraphLowering 277 278 def __post_init__(self) -> None: 279 self.wrapper.push_computed_sizes(self.wrapper.computed_sizes) 280 281 def codegen(self, code: IndentedBuffer) -> None: 282 self.wrapper.push_codegened_graph(self.graph) 283 code.do_indent() 284 285 286@dataclasses.dataclass 287class ExitSubgraphLine(WrapperLine): 288 wrapper: WrapperCodeGen 289 290 def __post_init__(self) -> None: 291 self.wrapper.computed_sizes = self.wrapper.pop_computed_sizes() 292 293 def codegen(self, code: IndentedBuffer) -> None: 294 self.wrapper.pop_codegened_graph() 295 code.do_unindent() 296 297 298@dataclasses.dataclass 299class EnterDeviceContextManagerLine(WrapperLine): 300 device_idx: int 301 last_seen_device_guard_index: Optional[int] 302 303 def codegen(self, code: IndentedBuffer) -> None: 304 if V.graph.cpp_wrapper: 305 code.writeline("\n") 306 if V.graph.aot_mode: 307 # In AOT mode, we have a stream provided as a param. A stream is 308 # associated with a device, so we never expect the device to change. 309 # CUDAStreamGuard sets the stream and the device. 
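                # The stream guard is emitted only for the first device seen;
                # afterwards we just assert that the device did not change,
                # since AOTInductor supports running on a single device.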
310 if self.last_seen_device_guard_index is None: 311 if config.abi_compatible: 312 code.writeline( 313 "AOTICudaStreamGuard stream_guard(stream, this->device_idx_);" 314 ) 315 else: 316 code.writeline( 317 maybe_hipify_code_wrapper( 318 "at::cuda::CUDAStreamGuard stream_guard(" 319 + "at::cuda::getStreamFromExternal(stream, this->device_idx_));" 320 ) 321 ) 322 else: 323 assert ( 324 self.last_seen_device_guard_index == self.device_idx 325 ), "AOTInductor only supports running on one CUDA device" 326 else: 327 if self.last_seen_device_guard_index is None: 328 code.writeline( 329 f"AOTICudaGuard device_guard({self.device_idx});" 330 if config.abi_compatible 331 else maybe_hipify_code_wrapper( 332 f"at::cuda::CUDAGuard device_guard({self.device_idx});" 333 ) 334 ) 335 else: 336 code.writeline(f"device_guard.set_index({self.device_idx});") 337 else: 338 # Note _DeviceGuard has less overhead than device, but only accepts 339 # integers 340 code.writeline(f"with {V.graph.device_ops.device_guard(self.device_idx)}:") 341 code.do_indent() 342 code.writeline(V.graph.device_ops.set_device(self.device_idx)) 343 344 345class ExitDeviceContextManagerLine(WrapperLine): 346 def codegen(self, code: IndentedBuffer) -> None: 347 if not V.graph.cpp_wrapper: 348 code.do_unindent() 349 350 351@dataclasses.dataclass 352class MemoryPlanningLine(WrapperLine): 353 wrapper: WrapperCodeGen 354 355 def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: 356 """First pass to find reuse""" 357 return self 358 359 def codegen(self, code: IndentedBuffer) -> None: 360 """Second pass to output code""" 361 362 def __str__(self) -> str: 363 """ 364 Emits a string representation that fits on one line. 365 """ 366 args: List[str] = [] 367 for field in dataclasses.fields(self): 368 if field.name == "wrapper": 369 continue 370 val = getattr(self, field.name) 371 args.append( 372 f"{field.name}={val.get_name() if field.type is ir.Buffer else val}" 373 ) 374 return f"{type(self).__name__}({', '.join(args)})" 375 376 377@dataclasses.dataclass 378class AllocateLine(MemoryPlanningLine): 379 node: ir.Buffer 380 381 def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: 382 if self.node.get_name() in V.graph.removed_buffers: 383 return NullLine(self.wrapper) 384 385 # try to reuse a recently freed buffer 386 key = buffer_reuse_key(self.node) 387 if config.allow_buffer_reuse and key in state: 388 free_line = state.pop(key) 389 free_line.is_reused = True 390 return ReuseLine(self.wrapper, free_line.node, self.node) 391 392 if self.node.get_device().type == "cpu": 393 static_shape = self.wrapper.static_shape_for_buffer_or_none(self.node) 394 if static_shape is not None: 395 state.total_allocated_buffer_size += int( 396 functools.reduce(operator.mul, static_shape, 1) 397 ) 398 399 return self 400 401 def codegen(self, code: IndentedBuffer) -> None: 402 assert self.node.get_name() not in V.graph.removed_buffers 403 line = self.wrapper.make_buffer_allocation(self.node) 404 code.writeline(line) 405 406 407@dataclasses.dataclass 408class FreeIfNotReusedLine(MemoryPlanningLine): 409 node: ir.Buffer 410 is_reused: bool = False 411 412 def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: 413 if len(self.node.get_inputs_that_alias_output()) > 0: 414 return self 415 if isinstance(self.node.layout, ir.MultiOutputLayout): 416 return self 417 assert not self.is_reused 418 if self.node.get_name() in V.graph.removed_buffers: 419 return NullLine(self.wrapper) 420 if config.allow_buffer_reuse: 421 
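            # Park this buffer in the reuse pool keyed by (device, dtype,
            # symbolic storage size); a later AllocateLine with a matching key
            # will pop it and become a ReuseLine instead of a fresh allocation.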
state.push(buffer_reuse_key(self.node), self) 422 return self 423 424 def codegen(self, code: IndentedBuffer) -> None: 425 assert self.node.get_name() not in V.graph.removed_buffers 426 if not self.is_reused: 427 code.writeline(self.wrapper.make_buffer_free(self.node)) 428 429 430@dataclasses.dataclass 431class ReuseLine(MemoryPlanningLine): 432 node: ir.Buffer 433 reused_as: ir.Buffer 434 delete_old: bool = True 435 436 def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: 437 if self.node.get_name() in V.graph.removed_buffers: 438 assert self.reused_as.get_name() in V.graph.removed_buffers 439 return NullLine(self.wrapper) 440 assert self.reused_as.get_name() not in V.graph.removed_buffers 441 return self 442 443 def codegen(self, code: IndentedBuffer) -> None: 444 assert self.node.get_name() not in V.graph.removed_buffers 445 assert self.reused_as.get_name() not in V.graph.removed_buffers 446 code.writeline( 447 self.wrapper.make_buffer_reuse(self.node, self.reused_as, self.delete_old) 448 ) 449 450 451class NullLine(MemoryPlanningLine): 452 pass 453 454 455BufferName = str 456 457 458class WrapperCodeGen(CodeGen): 459 """ 460 Generate outer wrapper in Python that calls the kernels. 461 """ 462 463 def __init__(self): 464 super().__init__() 465 self._names_iter: Iterator[int] = count() 466 self.imports = IndentedBuffer() 467 self.header = IndentedBuffer() 468 self.prefix = IndentedBuffer() 469 self.suffix = IndentedBuffer() 470 self.wrapper_call = IndentedBuffer() 471 self.kernel_autotune_defs = IndentedBuffer() 472 self.kernel_autotune_calls = IndentedBuffer() 473 self.kernel_autotune_names: Set[str] = set() 474 # If the generated source code is exactly the same, reuse the 475 # pre-existing kernel for it 476 self.src_to_kernel: Dict[str, str] = {} 477 self.kernel_numel_expr: Set[Tuple[str, GraphLowering]] = set() 478 self.lines: List[Union[MemoryPlanningLine, LineContext]] = [] 479 self.declare = "" 480 self.declare_maybe_reference = "" 481 self.ending = "" 482 self.open_bracket = "[" 483 self.closed_bracket = "]" 484 self.comment = "#" 485 self.namespace = "" 486 self.none_str = "None" 487 self.size = "size()" 488 self.stride = "stride()" 489 self.last_seen_device_guard_index: Optional[int] = None 490 self.supports_intermediate_hooks = True 491 self.expr_printer: Callable[[Any], str] = pexpr 492 self.user_defined_kernel_cache: Dict[Tuple[Any, ...], Tuple[str, Any]] = {} 493 self.unbacked_symbol_decls: Set[str] = set() # str of sympy.Symbol 494 self.allow_stack_allocation: Optional[bool] = None 495 self.stack_allocated_buffers: Dict[BufferName, ir.Buffer] = {} 496 self.computed_sizes: Set[sympy.Symbol] = set() 497 498 # this is used for tracking which GraphLowering instance---parent graph 499 # or (nested) subgraph---is currently codegened; the primary use case is 500 # including the graph instance into a cache key to avoid cross-graph 501 # caching during lowering of nested subgraphs 502 self.codegened_graph_stack = [] 503 self.computed_sizes_stack = [] 504 505 self.write_header() 506 self.write_prefix() 507 self.write_kernel_autotune_defs_header() 508 509 if not V.graph.aot_mode: 510 for name, hashed in V.graph.constant_reprs.items(): 511 # include a hash so our code cache puts different constants into different files 512 self.write_constant(name, hashed) 513 514 self.allocated: Set[BufferName] = set() 515 self.freed: Set[BufferName] = set() 516 517 # maps from reusing buffer to reused buffer 518 self.reuses: Dict[BufferName, BufferName] = {} 519 520 
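        # Memoize write_get_raw_stream so the "streamN = get_raw_stream(N)"
        # line is emitted at most once per (device index, graph) pair.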
self.write_get_raw_stream = functools.lru_cache(None)( # type: ignore[assignment] 521 self.write_get_raw_stream 522 ) 523 524 @functools.lru_cache(None) 525 def add_import_once(line: str) -> None: 526 self.imports.writeline(line) 527 if config.triton.autotune_at_compile_time: 528 self.kernel_autotune_calls.writeline(line) 529 530 self.add_import_once = add_import_once 531 self._metas: Dict[str, str] = {} 532 self._meta_vars: Set[str] = set() 533 self.multi_kernel_state = MultiKernelState() 534 535 # intermediate tensor value printing utility 536 self.debug_printer = DebugPrinterManager( 537 debug_printer_level=config.aot_inductor.debug_intermediate_value_printer 538 ) 539 540 def write_constant(self, name: str, hashed: str) -> None: 541 self.header.writeline(f"{name} = None # {hashed}") 542 543 def write_header(self) -> None: 544 context = torch._guards.TracingContext.try_get() 545 aot_config_comment = "" 546 if context is not None and context.aot_graph_name is not None: 547 aot_config_comment = f"# AOT ID: {context.aot_graph_name}" 548 self.imports.splice( 549 f""" 550 {aot_config_comment} 551 from ctypes import c_void_p, c_long, c_int 552 import torch 553 import math 554 import random 555 import os 556 import tempfile 557 from math import inf, nan 558 from torch._inductor.hooks import run_intermediate_hooks 559 from torch._inductor.utils import maybe_profile 560 from torch._inductor.codegen.memory_planning import _align as align 561 from torch import device, empty_strided 562 from {async_compile.__name__} import AsyncCompile 563 from torch._inductor.select_algorithm import extern_kernels 564 from torch._inductor.codegen.multi_kernel import MultiKernelCall 565 """, 566 strip=True, 567 ) 568 self.header.splice( 569 """ 570 aten = torch.ops.aten 571 inductor_ops = torch.ops.inductor 572 _quantized = torch.ops._quantized 573 assert_size_stride = torch._C._dynamo.guards.assert_size_stride 574 empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu 575 empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda 576 empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu 577 reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor 578 alloc_from_pool = torch.ops.inductor._alloc_from_pool 579 async_compile = AsyncCompile() 580 """, 581 strip=True, 582 ) 583 584 def write_kernel_autotune_defs_header(self) -> None: 585 self.kernel_autotune_defs.splice( 586 f""" 587 import torch 588 from torch._dynamo.testing import rand_strided 589 from torch._dynamo.utils import preserve_rng_state 590 from torch._inductor.select_algorithm import AlgorithmSelectorCache 591 from {async_compile.__name__} import AsyncCompile 592 593 async_compile = AsyncCompile() 594 generate_example_value = AlgorithmSelectorCache.generate_example_value 595 """ 596 ) 597 598 @cache_on_self 599 def write_triton_header_once(self) -> None: 600 import_str = f""" 601 import triton 602 import triton.language as tl 603 from {triton_heuristics.__name__} import grid, split_scan_grid, grid_combo_kernels, start_graph, end_graph 604 """ 605 self.imports.splice(import_str, strip=True) 606 if config.triton.autotune_at_compile_time: 607 self.kernel_autotune_calls.splice(import_str) 608 self.write_get_raw_stream_header_once() 609 610 @cache_on_self 611 def write_get_raw_stream_header_once(self) -> None: 612 self.imports.writeline( 613 V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") 614 ) 615 if config.triton.autotune_at_compile_time: 616 self.kernel_autotune_calls.writeline( 617 
V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") 618 ) 619 620 def add_meta_once(self, meta: TritonMetaParams) -> str: 621 meta = repr(meta) 622 if meta not in self._metas: 623 var = f"meta{len(self._metas)}" 624 self._metas[meta] = var 625 self.header.writeline(f"{var} = {meta}") 626 if config.triton.autotune_at_compile_time: 627 self.kernel_autotune_calls.writeline(f"{var} = {meta}") 628 self._meta_vars.add(var) 629 return self._metas[meta] 630 631 @cache_on_self 632 def get_output_refs(self) -> List[str]: 633 return [x.codegen_reference(self.wrapper_call) for x in V.graph.graph_outputs] 634 635 def mark_output_type(self) -> None: 636 return 637 638 def codegen_input_size_asserts(self) -> None: 639 for name, buf in V.graph.graph_inputs.items(): 640 if isinstance(buf, sympy.Expr): 641 continue 642 643 # comparing strides for 0 size tensor is tricky. Ignore them for now. 644 if sympy_product(buf.get_size()) == 0: 645 continue 646 size = self.codegen_shape_tuple(buf.get_size()) 647 stride = self.codegen_shape_tuple(buf.get_stride()) 648 self.prefix.writeline(f"assert_size_stride({name}, {size}, {stride})") 649 650 def codegen_input_nan_asserts(self) -> None: 651 self.prefix.writeline("# make sure graph inputs are not nan/inf") 652 for name, buf in V.graph.graph_inputs.items(): 653 if isinstance(buf, sympy.Expr): 654 continue 655 656 line = f"assert not {name}.isnan().any().item()" 657 self.prefix.writeline(line) 658 line = f"assert not {name}.isinf().any().item()" 659 self.prefix.writeline(line) 660 661 def write_prefix(self) -> None: 662 self.prefix.splice( 663 """ 664 665 async_compile.wait(globals()) 666 del async_compile 667 668 def call(args): 669 """ 670 ) 671 with self.prefix.indent(): 672 if config.triton.debug_sync_graph: 673 self.prefix.writeline(V.graph.device_ops.synchronize()) 674 if V.graph.graph_inputs: 675 lhs = ", ".join(V.graph.graph_input_names) 676 if len(V.graph.graph_input_names) == 1: 677 lhs += "," 678 self.prefix.writeline(f"{lhs} = args") 679 self.prefix.writeline("args.clear()") 680 681 self.codegen_inputs(self.prefix, V.graph.graph_inputs) 682 if config.size_asserts: 683 self.codegen_input_size_asserts() 684 if config.nan_asserts: 685 self.codegen_input_nan_asserts() 686 687 # this function (and below) takes a graph as input so 688 # that stream caching happens per graph instance. this 689 # is important for nested subgraph codegening. 
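    # e.g. write_get_raw_stream(0, graph) emits "stream0 = get_raw_stream(0)"
    # and returns "stream0" for use as the kernel launch stream argument.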
690 def write_get_raw_stream(self, device_idx: int, graph=None) -> str: 691 self.write_get_raw_stream_header_once() 692 name = f"stream{device_idx}" 693 self.writeline(f"{name} = get_raw_stream({device_idx})") 694 return name 695 696 def get_codegened_graph(self): 697 return self.codegened_graph_stack[-1] 698 699 def push_codegened_graph(self, graph): 700 self.codegened_graph_stack.append(graph) 701 702 def pop_codegened_graph(self): 703 return self.codegened_graph_stack.pop() 704 705 def push_computed_sizes(self, computed_sizes): 706 from copy import deepcopy 707 708 return self.computed_sizes_stack.append(deepcopy(computed_sizes)) 709 710 def pop_computed_sizes(self): 711 return self.computed_sizes_stack.pop() 712 713 def next_kernel_suffix(self) -> str: 714 return f"{next(self._names_iter)}" 715 716 def codegen_device_guard_enter(self, device_idx: int) -> None: 717 self.writeline( 718 EnterDeviceContextManagerLine(device_idx, self.last_seen_device_guard_index) 719 ) 720 if config.triton.autotune_at_compile_time: 721 # mimic logic of EnterDeviceContextManagerLine.codegen for the autotune code block 722 self.write_triton_header_once() 723 self.kernel_autotune_calls.writeline( 724 f"with {V.graph.device_ops.device_guard(device_idx)}:" 725 ) 726 self.kernel_autotune_calls.do_indent() 727 self.kernel_autotune_calls.writeline( 728 V.graph.device_ops.set_device(device_idx) 729 ) 730 self.kernel_autotune_calls.writeline( 731 f"stream{device_idx} = get_raw_stream({device_idx})" 732 ) 733 self.last_seen_device_guard_index = device_idx 734 735 def codegen_device_guard_exit(self) -> None: 736 self.writeline(ExitDeviceContextManagerLine()) 737 if config.triton.autotune_at_compile_time: 738 self.kernel_autotune_calls.do_unindent() 739 740 def generate_return(self, output_refs: List[str]) -> None: 741 if output_refs: 742 self.wrapper_call.writeline("return (" + ", ".join(output_refs) + ", )") 743 else: 744 self.wrapper_call.writeline("return ()") 745 746 def generate_before_suffix(self, result: IndentedBuffer) -> None: 747 return 748 749 def generate_end(self, result: IndentedBuffer) -> None: 750 return 751 752 def generate_fallback_kernel(self, fallback_kernel, args): 753 self.generate_extern_kernel_alloc(fallback_kernel, args) 754 755 def generate_extern_kernel_alloc(self, extern_kernel, args): 756 # If it's a NoneLayout then the extern_kernel should essentially be 757 # treated as if it doesn't return anything 758 no_return = isinstance(extern_kernel.layout, ir.NoneLayout) 759 output_name = extern_kernel.get_name() 760 origin_node = extern_kernel.get_origin_node() 761 kernel_name = extern_kernel.get_kernel_name() 762 ending = self.ending 763 if config.memory_planning and "view_as_complex" in kernel_name: 764 # view operation fallbacks cause issues since inductor 765 # doesn't know the memory is still needed and might reuse it. 
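            # Cloning gives the result its own storage, so memory planning
            # cannot reclaim the viewed buffer while it is still needed.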
766 ending = f".clone(){ending}" 767 768 if no_return: 769 self.writeline(f"{self.declare}{kernel_name}({', '.join(args)}){ending}") 770 else: 771 self.writeline( 772 f"{self.declare}{output_name} = {kernel_name}({', '.join(args)}){ending}" 773 ) 774 if ( 775 self.supports_intermediate_hooks 776 and config.generate_intermediate_hooks 777 and origin_node is not None 778 ): 779 counters["inductor"]["intermediate_hooks"] += 1 780 self.writeline( 781 f"run_intermediate_hooks({origin_node.name!r}, {output_name})" 782 ) 783 784 def generate_extern_kernel_out( 785 self, kernel: str, out: str, out_view: Optional[str], args: List[str] 786 ): 787 args.append(f"out={out_view if out_view else out}") 788 self.writeline(f"{kernel}({', '.join(args)})") 789 790 def generate_user_defined_triton_kernel( 791 self, 792 kernel_name: str, 793 raw_args: List[Any], 794 grid: List[Any], 795 configs, 796 triton_meta, 797 constexprs, 798 ): 799 grid_fn, code = user_defined_kernel_grid_fn_code( 800 kernel_name, configs, grid, wrapper=self 801 ) 802 # Must happen after free symbols are already codegened 803 # Emit the grid wrapper function right before the call 804 for line in code.split("\n"): 805 self.writeline(line) 806 807 args = [self.val_to_arg_str(v) for v in raw_args] 808 arg_types = [ 809 arg.get_dtype() if hasattr(arg, "get_dtype") else type(arg) 810 for arg in raw_args 811 ] 812 self.generate_kernel_call( 813 kernel_name, args, grid_fn=grid_fn, arg_types=arg_types, raw_args=raw_args 814 ) 815 816 def generate_scatter_fallback( 817 self, 818 output, 819 inputs, 820 cpp_kernel_name, 821 python_kernel_name, 822 src_is_tensor, 823 reduce, 824 kwargs, 825 ): 826 line = f"{python_kernel_name}({','.join(map(str, inputs))}" 827 if python_kernel_name.startswith("aten.scatter_reduce"): 828 line += ", ".join([""] + kwargs) 829 else: 830 if reduce: 831 line += f", reduce={repr(reduce)}" 832 line += ")" 833 self.writeline(line) 834 835 def generate_index_put_fallback(self, kernel, x, indices, values, accumulate): 836 indices_str = f"{self.open_bracket}{', '.join(indices)}{self.closed_bracket}" 837 args = [x, indices_str, values, accumulate] 838 self.writeline(self.wrap_kernel_call(kernel, args)) 839 840 def generate_extern_kernel_alloc_and_find_schema_if_needed( 841 self, 842 buf_name: str, 843 python_kernel_name: str, 844 cpp_kernel_name: str, 845 codegen_args: List[str], 846 cpp_op_schema: str, 847 cpp_kernel_key: str, 848 cpp_kernel_overload_name: str = "", 849 op_overload: Optional[torch._ops.OpOverload] = None, 850 raw_args=None, 851 outputs=None, 852 ): 853 self.writeline(f"{buf_name} = {python_kernel_name}({', '.join(codegen_args)})") 854 855 def generate(self, is_inference): 856 with dynamo_timed("WrapperCodeGen.generate"): 857 return self._generate(is_inference) 858 859 def _generate(self, is_inference): 860 if config.profile_bandwidth: 861 self.write_triton_header_once() 862 result = IndentedBuffer() 863 result.splice(self.imports) 864 result.writeline("") 865 result.splice(self.header) 866 # We do not want the cpp header for intermediate const graph. Headers would be 867 # rendered by the main module instead. 
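        # (Re-creating `result` as a fresh IndentedBuffer drops the imports and
        # header spliced above.)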
868 if V.graph.aot_mode and V.graph.cpp_wrapper and V.graph.is_const_graph: 869 result = IndentedBuffer() 870 871 with contextlib.ExitStack() as stack: 872 stack.enter_context(self.wrapper_call.indent()) 873 if config.profiler_mark_wrapper_call: 874 self.generate_profiler_mark_wrapper_call(stack) 875 if config.profile_bandwidth: 876 self.generate_start_graph() 877 878 # We disable planning during training because it presently increases peak memory consumption. 879 if is_inference and config.memory_planning: 880 self.memory_plan() 881 # TODO: integrate memory planning & stack allocation? 882 self.allow_stack_allocation = False 883 else: 884 self.memory_plan_reuse() 885 886 if config.triton.store_cubin: 887 self.generate_reset_kernel_saved_flags() 888 889 for line in self.lines: 890 if isinstance(line, WrapperLine): 891 line.codegen(self.wrapper_call) 892 else: 893 self.wrapper_call.writeline(line) 894 895 output_refs = self.get_output_refs() 896 self.mark_output_type() 897 if config.triton.debug_sync_graph: 898 self.wrapper_call.writeline(V.graph.device_ops.synchronize()) 899 900 if config.profile_bandwidth: 901 self.generate_end_graph() 902 903 if config.triton.store_cubin: 904 self.generate_save_uncompiled_kernels() 905 906 if config.triton.autotune_at_compile_time: 907 self.generate_and_run_autotune_block() 908 909 self.generate_return(output_refs) 910 911 self.finalize_prefix() 912 result.splice(self.prefix) 913 914 with result.indent(): 915 result.splice(self.wrapper_call) 916 917 self.generate_before_suffix(result) 918 result.splice(self.suffix) 919 920 self.generate_end(result) 921 922 self.add_benchmark_harness(result) 923 924 return result.getvaluewithlinemap() 925 926 def generate_and_run_autotune_block(self): 927 """ 928 Compose self.kernel_autotune_defs and self.kernel_autotune_calls into a single block of 929 code and execute it to trigger Triton kernel compilation and auto-tuning 930 """ 931 self.kernel_autotune_defs.splice( 932 """ 933 async_compile.wait(globals()) 934 del async_compile 935 """ 936 ) 937 scope = {} # type: ignore[var-annotated] 938 tuning_code = ( 939 self.kernel_autotune_defs.getvalue() + self.kernel_autotune_calls.getvalue() 940 ) 941 if output_code_log.level == logging.DEBUG: 942 # Save the autotuning code block into a file 943 # Create a temporary file 944 with tempfile.NamedTemporaryFile( 945 dir=cache_dir(), suffix=".py", delete=False 946 ) as f: 947 f.write(tuning_code.encode("utf-8")) 948 file_path = f.name 949 output_code_log.debug( 950 "\nCompile-time auto-tuning code: \n%s\nAuto-tuning code written to %s", 951 tuning_code, 952 file_path, 953 ) 954 # Execute the code to autotune kernels 955 exec(tuning_code, scope) 956 957 def memory_plan(self): 958 from .memory_planning import MemoryPlanner 959 960 self.lines = MemoryPlanner(self).plan(self.lines) 961 962 def memory_plan_reuse(self): 963 out_names = V.graph.get_output_names() 964 965 while ( 966 self.lines 967 and isinstance(self.lines[-1], MemoryPlanningLine) 968 # TODO: this seems legit, NullLine has no node 969 and self.lines[-1].node.name not in out_names # type: ignore[attr-defined] 970 ): 971 # these lines will be pointless 972 self.lines.pop() 973 974 # codegen allocations in two passes 975 planning_states = [MemoryPlanningState()] 976 past_planning_states = [] 977 for i in range(len(self.lines)): 978 line = self.lines[i] 979 if isinstance(line, MemoryPlanningLine): 980 self.lines[i] = line.plan(planning_states[-1]) 981 elif isinstance(line, EnterSubgraphLine): 982 
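                # Each nested subgraph gets its own MemoryPlanningState, so
                # buffers are never reused across subgraph boundaries.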
planning_states.append(MemoryPlanningState()) 983 elif isinstance(line, ExitSubgraphLine): 984 past_planning_states.append(planning_states.pop()) 985 past_planning_states.append(planning_states.pop()) 986 assert len(planning_states) == 0 987 988 # conservatively use the sum of all allocated buffer sizes 989 # in potentially nested scopes as the total allocated size 990 total_allocated_buffer_size = sum( 991 s.total_allocated_buffer_size for s in past_planning_states 992 ) 993 994 self.allow_stack_allocation = ( 995 self.allow_stack_allocation is not False 996 and config.allow_stack_allocation 997 and total_allocated_buffer_size <= MAX_STACK_ALLOCATION_SIZE 998 ) 999 1000 def codegen_input_size_var_decl(self, code: IndentedBuffer, name): 1001 code.writeline(f"{self.declare}{name}_size = {name}.{self.size}{self.ending}") 1002 1003 def codegen_input_stride_var_decl(self, code: IndentedBuffer, name): 1004 code.writeline( 1005 f"{self.declare}{name}_stride = {name}.{self.stride}{self.ending}" 1006 ) 1007 1008 def codegen_inputs( 1009 self, code: IndentedBuffer, graph_inputs: Dict[str, ir.TensorBox] 1010 ): 1011 """Assign all symbolic shapes to locals""" 1012 1013 @functools.lru_cache(None) 1014 def sizeof(name): 1015 self.codegen_input_size_var_decl(code, name) 1016 return f"{name}_size" 1017 1018 @functools.lru_cache(None) 1019 def strideof(name): 1020 self.codegen_input_stride_var_decl(code, name) 1021 return f"{name}_stride" 1022 1023 # Assign all symbolic shapes needed to local variables 1024 bound_vars: Set[sympy.Symbol] = set() 1025 1026 def is_expr(x): 1027 return isinstance(x[1], sympy.Expr) 1028 1029 graph_inputs_expr = list(filter(is_expr, graph_inputs.items())) 1030 graph_inputs_tensors = list( 1031 filter(lambda x: not is_expr(x), graph_inputs.items()) 1032 ) 1033 1034 for name, shape in graph_inputs_expr: 1035 if isinstance(shape, sympy.Symbol) and shape not in bound_vars: 1036 code.writeline(f"{self.declare}{shape} = {name}{self.ending}") 1037 bound_vars.add(shape) 1038 1039 for name, value in graph_inputs_tensors: 1040 shapes = value.get_size() 1041 for dim, shape in enumerate(shapes): 1042 if isinstance(shape, sympy.Symbol) and shape not in bound_vars: 1043 code.writeline( 1044 f"{self.declare}{shape} = {sizeof(name)}[{dim}]{self.ending}" 1045 ) 1046 bound_vars.add(shape) 1047 1048 for name, value in graph_inputs_tensors: 1049 shapes = value.get_stride() 1050 for dim, shape in enumerate(shapes): 1051 if isinstance(shape, sympy.Symbol) and shape not in bound_vars: 1052 code.writeline( 1053 f"{self.declare}{shape} = {strideof(name)}[{dim}]{self.ending}" 1054 ) 1055 bound_vars.add(shape) 1056 1057 def ensure_size_computed(self, sym: sympy.Symbol): 1058 if isinstance(sym, sympy.Symbol) and symbol_is_type(sym, SymT.PRECOMPUTED_SIZE): 1059 if sym in self.computed_sizes: 1060 return 1061 self.computed_sizes.add(sym) 1062 expr = V.graph.sizevars.inv_precomputed_replacements[sym] 1063 self.writeline( 1064 f"{self.declare}{sym} = {self.expr_printer(expr)}{self.ending}" 1065 ) 1066 1067 def finalize_prefix(self): 1068 pass 1069 1070 def codegen_python_sizevar(self, x: Expr, *, simplify: bool = True) -> str: 1071 return pexpr(x, simplify=simplify) 1072 1073 def codegen_sizevar(self, x: Expr) -> str: 1074 return self.codegen_python_sizevar(x) 1075 1076 def codegen_tuple_access(self, basename: str, name: str, index: str) -> str: 1077 return f"{basename}[{index}]" 1078 1079 def codegen_python_shape_tuple(self, shape: Tuple[Expr, ...]) -> str: 1080 parts = list(map(self.codegen_python_sizevar, 
shape)) 1081 if len(parts) == 0: 1082 return "()" 1083 if len(parts) == 1: 1084 return f"({parts[0]}, )" 1085 return f"({', '.join(parts)})" 1086 1087 def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str: 1088 return self.codegen_python_shape_tuple(shape) 1089 1090 def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str: 1091 return "alloc_from_pool({})".format( 1092 ", ".join( 1093 [ 1094 name, 1095 pexpr(offset), # bytes not numel 1096 str(dtype), 1097 self.codegen_shape_tuple(shape), 1098 self.codegen_shape_tuple(stride), 1099 ] 1100 ) 1101 ) 1102 1103 def codegen_reinterpret_view( 1104 self, data, size, stride, offset, writer, dtype=None 1105 ) -> str: 1106 if ( 1107 size == data.layout.size 1108 and stride == data.layout.stride 1109 and offset == data.layout.offset 1110 ): 1111 if dtype is not None and dtype != data.dtype: 1112 return f"aten.view.dtype({data.get_name()}, {dtype})" 1113 else: 1114 return f"{data.get_name()}" 1115 else: 1116 size = self.codegen_shape_tuple(size) 1117 stride = self.codegen_shape_tuple(stride) 1118 offset = self.codegen_sizevar(offset) 1119 if dtype is not None and dtype != data.dtype: 1120 return f"aten.view.dtype(reinterpret_tensor({data.get_name()}, {size}, {stride}, {offset}), {dtype})" 1121 else: 1122 return ( 1123 f"reinterpret_tensor({data.get_name()}, {size}, {stride}, {offset})" 1124 ) 1125 1126 def codegen_device_copy(self, src, dst): 1127 self.writeline(f"{dst}.copy_({src})") 1128 1129 def codegen_multi_output(self, name, value): 1130 self.writeline(f"{self.declare}{name} = {value}{self.ending}") 1131 1132 def codegen_dynamic_scalar(self, node): 1133 (data,) = (t.codegen_reference() for t in node.inputs) 1134 if len(node.keypath) == 0: 1135 self.writeline(f"{node.sym} = {data}.item()") 1136 elif len(node.keypath) == 1 and isinstance(node.keypath[0], ConvertIntKey): 1137 self.writeline(f"{node.sym} = 1 if {data}.item() else 0") 1138 elif len(node.keypath) == 1 and isinstance(node.keypath[0], DivideByKey): 1139 self.writeline(f"{node.sym}_undivided = {data}.item()") 1140 self.writeline( 1141 f"assert {node.sym}_undivided % {node.keypath[0].divisor} == 0, " 1142 f"f'{{{node.sym}_undivided}} not divisible by {node.keypath[0].divisor}'" 1143 ) 1144 self.writeline( 1145 f"{node.sym} = {node.sym}_undivided // {node.keypath[0].divisor}" 1146 ) 1147 else: 1148 raise AssertionError(f"unrecognized keypath {node.keypath}") 1149 # No one should ever use this buffer, but for uniformity 1150 # define the variable and assign it None 1151 self.writeline(f"{node.get_name()} = None") 1152 1153 def benchmark_compiled_module(self, output): 1154 def add_fake_input(name, shape, stride, device, dtype): 1155 output.writeline( 1156 f"{name} = rand_strided(" 1157 f"{self.codegen_python_shape_tuple(shape)}, " 1158 f"{self.codegen_python_shape_tuple(stride)}, " 1159 f"device='{device}', dtype={dtype})" 1160 ) 1161 1162 def add_expr_input(name, val): 1163 output.writeline(f"{name} = {val}") 1164 1165 def add_torchbind_input(name, value): 1166 import pickle 1167 1168 output.writeline(f"{name} = pickle.loads({pickle.dumps(value)!r})") 1169 1170 output.writelines( 1171 ["", "", "def benchmark_compiled_module(times=10, repeat=10):"] 1172 ) 1173 with output.indent(): 1174 output.splice( 1175 """ 1176 from torch._dynamo.testing import rand_strided 1177 from torch._inductor.utils import print_performance 1178 """, 1179 strip=True, 1180 ) 1181 1182 for name, value in V.graph.constants.items(): 1183 # all the constants are global variables, that's why 
we need 1184 # these 'global var_name' lines 1185 output.writeline(f"global {name}") 1186 add_fake_input( 1187 name, value.size(), value.stride(), value.device, value.dtype 1188 ) 1189 1190 if len(V.graph.torchbind_constants) > 0: 1191 output.writeline("import pickle") 1192 for name, torchbind_obj in V.graph.torchbind_constants.items(): 1193 # all the constants are global variables, that's why we need 1194 # these 'global var_name' lines 1195 output.writeline(f"global {name}") 1196 add_torchbind_input(name, torchbind_obj) 1197 1198 for name, value in V.graph.graph_inputs.items(): 1199 if isinstance(value, sympy.Symbol) and isinstance( 1200 V.graph.sizevars.var_to_val.get(value, None), SingletonInt 1201 ): 1202 # Inductor should only work with dense -> dense graph, and 1203 # SingletonInts belong to metadata that should only live on 1204 # the subclass. 1205 continue 1206 if isinstance(value, sympy.Expr): # Don't need to add symbolic 1207 # TODO: this fallback and those below actually will generate possibly 1208 # invalid benchmark code, because it's not guaranteed 42 1209 # is actually a valid value for the kernel in question. 1210 # See https://github.com/pytorch/pytorch/issues/124686 1211 add_expr_input(name, V.graph.sizevars.size_hint(value, fallback=42)) 1212 else: 1213 shape = [ 1214 V.graph.sizevars.size_hint(x, fallback=42) 1215 for x in value.get_size() 1216 ] 1217 stride = [ 1218 V.graph.sizevars.size_hint(x, fallback=42) 1219 for x in value.get_stride() 1220 ] 1221 add_fake_input( 1222 name, 1223 shape, 1224 stride, 1225 value.get_device(), 1226 value.get_dtype(), 1227 ) 1228 1229 call_str = f"call([{', '.join(V.graph.graph_inputs.keys())}])" 1230 output.writeline(f"fn = lambda: {call_str}") 1231 output.writeline("return print_performance(fn, times=times, repeat=repeat)") 1232 1233 def add_benchmark_harness(self, output): 1234 """ 1235 Append a benchmark harness to generated code for debugging 1236 """ 1237 if not config.benchmark_harness: 1238 return 1239 1240 self.benchmark_compiled_module(output) 1241 1242 output.writelines(["", "", 'if __name__ == "__main__":']) 1243 with output.indent(): 1244 output.writelines( 1245 [ 1246 "from torch._inductor.wrapper_benchmark import compiled_module_main", 1247 f"compiled_module_main('{get_benchmark_name()}', benchmark_compiled_module)", 1248 ] 1249 ) 1250 1251 def define_kernel( 1252 self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True 1253 ): 1254 metadata_comment = f"{metadata}\n" if metadata else "" 1255 body = f"\n\n{metadata_comment}{name} = {kernel}" 1256 self.header.splice(body) 1257 if config.triton.autotune_at_compile_time: 1258 self.kernel_autotune_defs.splice(body) 1259 1260 def define_user_defined_triton_kernel(self, kernel, configs, kwargs): 1261 from torch.utils._triton import patch_triton_dtype_repr 1262 1263 patch_triton_dtype_repr() 1264 1265 original_name = kernel.__name__ 1266 1267 from .common import KernelArgType, SizeArg, TensorArg 1268 1269 signature: List[KernelArgType] = [] 1270 constants: Dict[int, Any] = {} 1271 non_constant_indices = [] 1272 equal_to_1_arg_idx: List[int] = [] 1273 for idx, key in enumerate(kernel.arg_names): 1274 if key not in kwargs: 1275 continue 1276 arg = kwargs[key] 1277 if idx in kernel.constexprs: 1278 constants[idx] = arg 1279 else: 1280 non_constant_indices.append(idx) 1281 if isinstance(arg, ir.Buffer): 1282 signature.append( 1283 TensorArg( 1284 name=key, 1285 buffer=arg.get_name(), 1286 dtype=arg.get_dtype(), 1287 ) 1288 ) 1289 elif isinstance(arg, 
ir.ReinterpretView): 1290 # for ReinterpretView we use the underlying 1291 # buffer name and note the (possibly non-zero) 1292 # offset relative to the underlying buffer 1293 signature.append( 1294 TensorArg( 1295 name=key, 1296 buffer=arg.data.get_name(), 1297 dtype=arg.get_dtype(), 1298 offset=arg.layout.offset, 1299 ) 1300 ) 1301 else: 1302 signature.append(SizeArg(key, arg)) 1303 if isinstance( 1304 arg, (int, sympy.Integer) 1305 ) and V.graph.sizevars.statically_known_equals( 1306 arg, 1 # type: ignore[arg-type] 1307 ): 1308 equal_to_1_arg_idx.append(idx) 1309 index_dtype = "tl.int32" 1310 triton_meta = { 1311 "signature": signature_to_meta( 1312 signature, 1313 size_dtype=index_dtype, 1314 indices=non_constant_indices, 1315 ), 1316 "device": DeviceProperties.create( 1317 V.graph.scheduler.get_current_device_or_throw() 1318 ), 1319 # Triton compiler includes equal_to_1 args into constants even 1320 # when they are not constexpr. otherwise there may be a segfault 1321 # during launching the Inductor-compiled Triton kernel. 1322 # TODO(aakhundov): add None args to constants, too. currently, this 1323 # causes CUDA errors in test_aot_inductor.test_triton_kernel_with_none_input. 1324 # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307 1325 # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384 1326 "constants": { 1327 **constants, 1328 **dict.fromkeys(equal_to_1_arg_idx, 1), 1329 }, 1330 "configs": [ 1331 config_of( 1332 signature, 1333 indices=non_constant_indices, 1334 ) 1335 ], 1336 } 1337 1338 # Distinguish between different functions using function id 1339 cache_key: List[Any] = [id(kernel.fn)] 1340 if len(configs) > 0: 1341 for arg in kwargs.values(): 1342 # We need to key on non tensor arg only in autotune mode 1343 if not isinstance(arg, (ir.Buffer, ir.ReinterpretView)): 1344 cache_key.append(arg) 1345 cache_key.append(str(triton_meta)) 1346 cache_key = tuple(cache_key) 1347 1348 if cache_key in self.user_defined_kernel_cache: 1349 return self.user_defined_kernel_cache[cache_key] 1350 1351 name = f"{original_name}_{len(self.user_defined_kernel_cache)}" 1352 # Add to the cache for the next use 1353 self.user_defined_kernel_cache[cache_key] = (name, triton_meta) 1354 1355 compile_wrapper = IndentedBuffer() 1356 compile_wrapper.writeline(f"async_compile.triton({original_name!r}, '''") 1357 1358 from .triton import gen_common_triton_imports, TritonKernel 1359 1360 compile_wrapper.splice(gen_common_triton_imports()) 1361 1362 inductor_meta = { 1363 "kernel_name": name, 1364 **TritonKernel.inductor_meta_common(), 1365 } 1366 1367 configs = [ 1368 { 1369 "kwargs": config.kwargs, 1370 "num_warps": config.num_warps, 1371 "num_stages": config.num_stages, 1372 } 1373 for config in configs 1374 ] 1375 1376 compile_wrapper.splice( 1377 f""" 1378 @triton_heuristics.user_autotune( 1379 configs={configs!r}, 1380 inductor_meta={inductor_meta!r}, 1381 triton_meta={triton_meta!r}, 1382 filename=__file__, 1383 custom_kernel=True, 1384 ) 1385 @triton.jit 1386 """ 1387 ) 1388 compile_wrapper.splice(kernel.src, strip=True) 1389 1390 # Also include any possible kernel being called indirectly 1391 from triton import JITFunction # type: ignore[name-defined, attr-defined] 1392 from triton.language import constexpr # type: ignore[name-defined] 1393 1394 # global constexpr vars handled above 1395 symbols_included = {original_name} 1396 1397 def traverse(cur_kernel): 1398 # here we extract the unqualified names (i.e., not 
attributes and 1399 # without prepended module name) loaded in the kernel code, which 1400 # are matched with the co_names and __globals__ below to codegen 1401 # the respective imports necessary for the kernel compilation 1402 unqualified_loads = { 1403 inst.argval 1404 for inst in dis.Bytecode(cur_kernel.fn) 1405 if inst.opname == "LOAD_GLOBAL" 1406 } 1407 global_annotations = cur_kernel.fn.__globals__.get("__annotations__", {}) 1408 for symbol_name in cur_kernel.fn.__code__.co_names: 1409 if symbol_name in symbols_included: 1410 continue 1411 if symbol_name in cur_kernel.fn.__globals__: 1412 symbol = cur_kernel.fn.__globals__[symbol_name] 1413 if isinstance(symbol, JITFunction): 1414 compile_wrapper.newline() 1415 compile_wrapper.writeline("@triton.jit") 1416 compile_wrapper.splice(symbol.src, strip=True) 1417 symbols_included.add(symbol_name) 1418 traverse(symbol) 1419 elif isinstance(symbol, (int, str, bool, constexpr)): 1420 compile_wrapper.newline() 1421 if isinstance(symbol, constexpr): 1422 symbol_str = f"tl.constexpr({symbol.value!r})" 1423 else: 1424 symbol_str = f"{symbol!r}" 1425 if annotation := global_annotations.get(symbol_name): 1426 annotion_code = "" 1427 if isinstance(annotation, type): 1428 annotation_code = ( 1429 f": {annotation.__module__}.{annotation.__name__}" 1430 ) 1431 else: 1432 annotation_code = f": {annotation!r}" 1433 compile_wrapper.writeline( 1434 f"{symbol_name}{annotation_code} = {symbol_str}" 1435 ) 1436 else: 1437 compile_wrapper.writeline(f"{symbol_name} = {symbol!r}") 1438 symbols_included.add(symbol_name) 1439 elif ( 1440 symbol_name in unqualified_loads 1441 and symbol_name != "tl" # already imported 1442 and hasattr(symbol, "__module__") 1443 # only codegen imports from triton; JITFunctions 1444 # imported from other modules will be codegened 1445 # in the separate branch above 1446 and symbol.__module__.startswith("triton") 1447 ): 1448 # a global symbol imported from triton is referenced 1449 # without module qualification (i.e., `store` instead 1450 # of `tl.store`): need to codegen an import 1451 compile_wrapper.writeline( 1452 f"from {symbol.__module__} import {symbol.__name__} as {symbol_name}" 1453 ) 1454 symbols_included.add(symbol_name) 1455 1456 traverse(kernel) 1457 1458 current_device = V.graph.scheduler.get_current_device_or_throw() 1459 compile_wrapper.writeline(f"''', device_str='{current_device.type}')") 1460 _, lineno = inspect.getsourcelines(kernel.fn) 1461 srcfile = inspect.getsourcefile(kernel.fn) 1462 metadata = f"# Original path: {srcfile}:{lineno}" 1463 self.define_kernel( 1464 name, 1465 compile_wrapper.getvalue(), 1466 metadata, 1467 ) 1468 return name, triton_meta 1469 1470 def generate_numel_expr(self, kernel_name: str, tree, suffix: Optional[str] = None): 1471 expr = f"{kernel_name}_{tree.prefix}numel" 1472 if suffix is not None: 1473 expr += f"_{suffix}" 1474 if (expr, V.graph) not in self.kernel_numel_expr: 1475 # declare expr once in each graph (scope) 1476 self.kernel_numel_expr.add((expr, V.graph)) 1477 self.writeline( 1478 f"{self.declare}{expr} = {self.expr_printer(tree.numel)}{self.ending}" 1479 ) 1480 else: 1481 self.writeline(f"{expr} = {self.expr_printer(tree.numel)}{self.ending}") 1482 # We can get symbolic expressions here, like s0*64 1483 # It is fine to have them here, but we need to handle them correctly as their own type 1484 # This is tricky to do, so we wrap in a custom type, distinct from scalars, but also from sympy* 1485 # scalars as well. 
1486 # This is handled in `generate_args_decl` which has a correct comment of: TODO: only works for 1487 # constant now, need type info. I agree, this needs type info, and while this is not true type info 1488 # it suffices as a type hint for the purposes of producing the correct code for this type. 1489 return SymbolicCallArg(expr, tree.numel) 1490 1491 def generate_workspace_allocation(self, nbytes, device, zero_fill): 1492 line = self.make_allocation( 1493 "workspace", device, torch.uint8, shape=(nbytes,), stride=(1,) 1494 ) 1495 self.writeline(line) 1496 if zero_fill: 1497 self.writeline(f"workspace.zero_(){self.ending}") 1498 1499 def wrap_kernel_call(self, name, call_args): 1500 return f"{name}({', '.join(call_args)}){self.ending}" 1501 1502 def generate_profiler_mark_wrapper_call(self, stack): 1503 self.wrapper_call.writeline("from torch.profiler import record_function") 1504 self.wrapper_call.writeline( 1505 f"with record_function('graph_{V.graph.graph_id}_inductor_wrapper_call'):" 1506 ) 1507 stack.enter_context(self.wrapper_call.indent()) 1508 1509 def generate_start_graph(self): 1510 self.wrapper_call.writeline("start_graph()") 1511 1512 def generate_end_graph(self): 1513 self.wrapper_call.writeline(f"end_graph({config.profile_bandwidth_output!r})") 1514 1515 def generate_reset_kernel_saved_flags(self): 1516 self.wrapper_call.splice( 1517 f""" 1518 for kernel in globals().values(): 1519 if isinstance(kernel, {triton_heuristics.__name__}.CachingAutotuner): 1520 kernel.cuda_kernel_saved = False 1521 """ 1522 ) 1523 1524 def generate_save_uncompiled_kernels(self): 1525 """ 1526 Precompile and save the CUBINs of the Triton kernels that haven't 1527 been precompiled and saved as a side effect of running the generated 1528 JIT model (Python wrapper). This can happen when the model contains 1529 control flow: only one pass through the control flow operators covers 1530 the kernels that are saved, the remaining kernels are not launched, 1531 hence not saved. The main purpose of this codegen is to compile and 1532 save the Triton kernels outside the active control flow path for 1533 subsequent AOTInductor code generation and compilation. 
1534 """ 1535 self.wrapper_call.splice( 1536 f""" 1537 for kernel in globals().values(): 1538 if isinstance(kernel, {triton_heuristics.__name__}.CachingAutotuner): 1539 if not kernel.cuda_kernel_saved: 1540 if len(kernel.launchers) == 0: 1541 kernel.precompile() 1542 kernel.save_gpu_kernel( 1543 grid=(0, 0, 0), # use dummy grid 1544 stream="stream", # use dummy stream 1545 launcher=kernel.launchers[0], 1546 ) 1547 """ 1548 ) 1549 1550 def generate_default_grid( 1551 self, 1552 kernel_name: str, 1553 grid: List[Any], 1554 cuda: bool = True, 1555 grid_callable: Optional[Callable[..., Any]] = None, 1556 **grid_extra_kwags, 1557 ): 1558 return grid 1559 1560 def prepare_triton_kernel_call(self, device_index, call_args): 1561 def wrap_arg(arg): 1562 if isinstance(arg, str): 1563 # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar 1564 return arg + ".item()" if should_unwrap_unspec_arg(arg) else arg 1565 elif isinstance(arg, (int, float, bool, SymbolicCallArg)): 1566 return str(arg) 1567 else: 1568 return self.expr_printer(V.graph.sizevars.simplify(arg)) 1569 1570 call_args = [wrap_arg(arg) for arg in call_args] 1571 1572 if device_index is None: 1573 current_device = V.graph.scheduler.get_current_device_or_throw() 1574 device_index = current_device.index 1575 1576 return device_index, call_args 1577 1578 def generate_example_arg_value(self, arg, arg_type, raw_arg=None, index=None): 1579 if isinstance(arg_type, torch_dtype): 1580 if V.graph.try_get_buffer(arg) is not None: 1581 buf_name = arg 1582 buf = V.graph.get_buffer(arg) 1583 else: 1584 assert ( 1585 raw_arg is not None 1586 ), "V.graph.get_buffer(arg) and raw_arg can't be None at the same time" 1587 buf_name = f"tmp_arg_{index}" 1588 buf = raw_arg 1589 1590 size = V.graph.sizevars.size_hints( 1591 buf.get_size(), 1592 fallback=config.unbacked_symint_fallback, 1593 ) 1594 stride = V.graph.sizevars.size_hints( 1595 buf.get_stride(), 1596 fallback=config.unbacked_symint_fallback, 1597 ) 1598 device = buf.get_device() 1599 dtype = buf.get_dtype() 1600 offset = V.graph.sizevars.size_hint( 1601 buf.layout.offset, 1602 fallback=config.unbacked_symint_fallback, 1603 ) 1604 value = f"generate_example_value({size}, {stride}, '{device}', {dtype}, {offset})" 1605 self.kernel_autotune_calls.writeline(f"{buf_name} = {value}") 1606 return buf_name 1607 elif issubclass(arg_type, sympy.Basic) or isinstance(arg, SymbolicCallArg): 1608 # arg is a symbol or symbolic expression 1609 if isinstance(arg, str): 1610 if arg in self._meta_vars: 1611 return arg 1612 if raw_arg is None: 1613 return "None" 1614 arg = raw_arg 1615 if isinstance(arg, SymbolicCallArg): 1616 arg = arg.inner_expr 1617 if arg in V.graph.sizevars.inv_precomputed_replacements: 1618 arg = V.graph.sizevars.inv_precomputed_replacements[arg] 1619 return str( 1620 V.graph.sizevars.size_hint( 1621 arg, 1622 fallback=config.unbacked_symint_fallback, 1623 ) 1624 ) 1625 elif isinstance(arg, (str, int, float, bool)): 1626 return str(arg) 1627 elif isinstance(arg, list): 1628 return f"[{', '.join(self.generate_example_arg_value(a, type(a)) for a in arg)}]" 1629 else: 1630 raise NotImplementedError(f"Unsupported type {type(arg)}") 1631 1632 def _grid_dim_str(self, grid_per_dim): 1633 if isinstance(grid_per_dim, list): 1634 return ( 1635 "[" + ", ".join(self._grid_dim_str(item) for item in grid_per_dim) + "]" 1636 ) 1637 else: 1638 return pexpr(grid_per_dim) 1639 1640 def generate_kernel_call( 1641 self, 1642 kernel_name, 1643 call_args, 1644 grid=None, 1645 device_index=None, 1646 
cuda=True, 1647 triton=True, 1648 arg_types=None, 1649 raw_args=None, 1650 grid_fn: str = "grid", 1651 triton_meta=None, 1652 autotune_configs=None, 1653 grid_extra_kwargs="", 1654 ): 1655 """ 1656 Generates kernel call code. 1657 1658 cuda: Defines whether the backend is GPU. Otherwise the backend is CPU. 1659 1660 triton: Defines whether the GPU backend uses Triton for codegen. 1661 Otherwise it uses the CUDA language for codegen. 1662 Only valid when cuda == True. 1663 """ 1664 if cuda: 1665 device_index, call_args_str = self.prepare_triton_kernel_call( 1666 device_index, call_args 1667 ) 1668 call_args_str = ", ".join(call_args_str) 1669 stream_name = self.write_get_raw_stream(device_index, V.graph) 1670 if triton: 1671 self.write_triton_header_once() 1672 if grid is None: 1673 grid_str = grid_fn 1674 else: 1675 grid_str = ", ".join(self._grid_dim_str(item) for item in grid) 1676 if grid_extra_kwargs: 1677 grid_str = f"{grid_str}, {grid_extra_kwargs}" 1678 grid_str = f"{grid_fn}({grid_str})" 1679 self.writeline( 1680 f"{kernel_name}.run({call_args_str}, grid={grid_str}, stream={stream_name})" 1681 ) 1682 if ( 1683 config.triton.autotune_at_compile_time 1684 and kernel_name not in self.kernel_autotune_names 1685 ): 1686 # Create example args for autotune in a separate epilogue 1687 assert arg_types is not None and len(call_args) == len( 1688 arg_types 1689 ), "call_args and arg_types do not match" 1690 1691 tensor_args = {} 1692 all_args = [] 1693 if raw_args is None: 1694 # create a dummy raw_args for uniform behavior in the following loop 1695 raw_args = [None] * len(call_args) 1696 else: 1697 assert len(raw_args) == len( 1698 call_args 1699 ), "call_args and raw_args do not match" 1700 1701 for i, (arg, arg_type, raw_arg) in enumerate( 1702 zip(call_args, arg_types, raw_args) 1703 ): 1704 key = None 1705 if isinstance(arg, str) and "=" in str(arg): 1706 # arg may be passed in a kwarg style, and then we need to extract its value 1707 key, arg = arg.split("=") 1708 1709 if isinstance(arg_type, torch_dtype): 1710 if arg not in tensor_args: 1711 arg_str = self.generate_example_arg_value( 1712 arg, arg_type, raw_arg, i 1713 ) 1714 tensor_args[arg] = arg_str 1715 else: 1716 arg_str = tensor_args[arg] 1717 else: 1718 arg_str = self.generate_example_arg_value( 1719 arg, arg_type, raw_arg, i 1720 ) 1721 all_args.append(arg_str if key is None else f"{key}={arg_str}") 1722 1723 if grid is None: 1724 grid_str = grid_fn 1725 else: 1726 grid_str = ", ".join( 1727 self.generate_example_arg_value(g, type(g)) for g in grid 1728 ) 1729 if grid_extra_kwargs: 1730 grid_str = f"{grid_str}, {grid_extra_kwargs}" 1731 grid_str = f"{grid_fn}({grid_str})" 1732 1733 self.kernel_autotune_calls.writeline( 1734 f"{kernel_name}.run({', '.join(all_args)}, grid={grid_str}, stream={stream_name})" 1735 ) 1736 self.kernel_autotune_calls.writeline( 1737 f"del {', '.join(arg for arg in tensor_args.values())}\n", 1738 ) 1739 self.kernel_autotune_names.add(kernel_name) 1740 else: 1741 stream_ptr = f"c_void_p({stream_name})" 1742 self.writeline( 1743 f"{kernel_name}.{kernel_name}({call_args_str}, {stream_ptr})" 1744 ) 1745 else: 1746 self.writeline(self.wrap_kernel_call(kernel_name, call_args)) 1747 1748 def writeline(self, line): 1749 self.lines.append(line) 1750 1751 def writelines(self, lines): 1752 for line in lines: 1753 self.writeline(line) 1754 1755 def enter_context(self, ctx): 1756 self.lines.append(LineContext(ctx)) 1757 1758 def val_to_arg_str(self, s, type_=None): 1759 from torch.utils._triton import 
    def val_to_arg_str(self, s, type_=None):
        from torch.utils._triton import dtype_to_string, has_triton_package

        if has_triton_package():
            import triton

        if isinstance(s, SymTypes):
            return pexpr(s.node.expr)
        elif isinstance(s, sympy.Expr):
            return pexpr(s)
        elif isinstance(s, (tuple, list)):

            @dataclasses.dataclass
            class Shim:
                ref: Any

                def __repr__(self):
                    return self.ref

            return repr(type(s)(Shim(self.val_to_arg_str(a)) for a in s))
        elif isinstance(s, torch._ops.OpOverload):
            return _get_qualified_name(s)
        elif isinstance(s, (ir.Buffer, ReinterpretView)):
            return s.codegen_reference()
        elif has_triton_package() and isinstance(s, triton.language.dtype):  # type: ignore[possibly-undefined]
            return dtype_to_string(s)
        else:
            return repr(s)

    # The following methods are for memory management
    def make_buffer_allocation(self, buffer):
        device = buffer.get_device()
        dtype = buffer.get_dtype()
        shape = tuple(buffer.get_size())
        stride = tuple(buffer.get_stride())
        return self.make_allocation(buffer.get_name(), device, dtype, shape, stride)

    def make_allocation(self, name, device, dtype, shape, stride):
        if device.type in ("cpu", "cuda", "xpu"):
            # optimized path for faster allocations, saving ~2us versus the stuff below
            return (
                f"{name} = empty_strided_{device.type}("
                f"{self.codegen_shape_tuple(shape)}, "
                f"{self.codegen_shape_tuple(stride)}, "
                f"{dtype})"
            )
        # all other devices:
        return (
            f"{name} = empty_strided("
            f"{self.codegen_shape_tuple(shape)}, "
            f"{self.codegen_shape_tuple(stride)}, "
            f"device='{device.type}', dtype={dtype})"
        )

    def make_tensor_alias(self, new_name, old_name, comment=""):
        return f"{self.declare}{new_name} = {old_name}{self.ending} {self.comment} {comment}"

    def make_buffer_free(self, buffer):
        return f"del {buffer.get_name()}"

    def make_free_by_names(self, names_to_del: List[str]):
        return f"del {', '.join(name for name in names_to_del)}"

    def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str):
        return f"{self.declare_maybe_reference}{new_name} = {old_name}{del_line}{self.ending} {self.comment} reuse"

    def make_buffer_reuse(self, old: ir.Buffer, new: ir.Buffer, delete_old: bool):
        assert old.get_dtype() == new.get_dtype()
        old_name = old.get_name()
        new_name = new.get_name()
        del_line = ";"
        if old_name not in V.graph.get_output_names() and delete_old:
            del_line = f"; {self.make_buffer_free(old)}"

        if old.get_size() == new.get_size() and old.get_stride() == new.get_stride():
            if old_name in self.stack_allocated_buffers:
                self.stack_allocated_buffers[new_name] = new
            return self.codegen_exact_buffer_reuse(old_name, new_name, del_line)

        reinterpret_view = self.codegen_reinterpret_view(
            old, new.get_size(), new.get_stride(), 0, self.wrapper_call
        )
        if reinterpret_view in self.stack_allocated_buffers:
            self.stack_allocated_buffers[new_name] = new
        return f"{self.declare_maybe_reference}{new_name} = {reinterpret_view}{del_line} {self.comment} reuse"

    def codegen_deferred_allocation(self, name, layout):
        self.writeline(
            DeferredLine(
                name,
                f"{self.declare_maybe_reference}{name} = {layout.view.codegen_reference()}{self.ending} "
                f"{self.comment} alias",
            )
        )
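    # Illustrative sketch of the wrapper lines the memory-management helpers above
    # produce for the Python wrapper (buffer names, shapes, and strides are hypothetical):
    #
    #     buf0 = empty_strided_cuda((8, 16), (16, 1), torch.float32)  # make_allocation
    #     buf1 = buf0; del buf0  # reuse                               # codegen_exact_buffer_reuse
    #     del buf1                                                     # make_buffer_free
    #
    # Wrapper subclasses may override self.declare, self.ending, and self.comment to
    # change the surface syntax (e.g. for C++ output) while keeping the same structure.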
    def codegen_allocation(self, buffer: ir.Buffer):
        name = buffer.get_name()

        if name in V.graph.removed_buffers or name in self.allocated:
            return
        self.allocated.add(name)
        if isinstance(
            buffer.get_defining_op(),
            (ir.ExternKernelAlloc, ir.MultiOutput),
        ):
            return

        layout = buffer.get_layout()
        if isinstance(layout, ir.MutationLayoutSHOULDREMOVE):
            return
        if isinstance(layout, ir.NoneLayout):
            return
        if isinstance(layout, ir.NonOwningLayout):
            assert isinstance(
                layout.view, ir.ReinterpretView
            ), f"unexpected {type(layout.view)}: {layout.view}"
            assert isinstance(layout.view.data, ir.StorageBox), type(layout.view.data)
            assert isinstance(layout.view.data.data, ir.Buffer), type(
                layout.view.data.data
            )
            self.codegen_allocation(layout.view.data.data)
            self.codegen_deferred_allocation(name, layout)
            return

        self.writeline(AllocateLine(self, buffer))

    def codegen_free(self, buffer):
        name = buffer.get_name()

        # can be freed but not reused
        if isinstance(buffer, ir.InputBuffer):
            self.writeline(self.make_buffer_free(buffer))
            return

        if not self.can_reuse(buffer):
            return
        self.freed.add(name)

        self.writeline(FreeIfNotReusedLine(self, buffer))

    def can_reuse(self, input_buffer, output_buffer=None):
        name = input_buffer.get_name()
        return not (
            name in V.graph.removed_buffers
            or name in V.graph.graph_inputs
            or name in V.graph.constants
            or name in V.graph.torchbind_constants
            or name in V.graph.never_reuse_buffers
            or name in self.freed
        )
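    # Rough usage sketch for the allocation/free hooks above (hypothetical driver code;
    # in practice the scheduler invokes these while walking the graph):
    #
    #     wrapper.codegen_allocation(buf)   # queues an AllocateLine for buf
    #     ... kernel calls that read/write buf ...
    #     wrapper.codegen_free(buf)         # queues a FreeIfNotReusedLine, or a plain
    #                                       # `del` for input buffers; does nothing when
    #                                       # can_reuse(buf) is False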
    def did_reuse(self, buffer, reused_buffer):
        # Check whether a given buffer was reused by a possible reuser in the wrapper codegen
        # Can be consulted from inside ir codegen, e.g. to determine whether a copy is needed
        return (
            buffer.get_name() in self.reuses
            and self.reuses[buffer.get_name()] == reused_buffer.get_name()
        )

    def codegen_inplace_reuse(self, input_buffer: ir.Buffer, output_buffer: ir.Buffer):
        assert buffer_reuse_key(input_buffer) == buffer_reuse_key(output_buffer)
        self.codegen_allocation(input_buffer)
        self.freed.add(input_buffer.get_name())
        self.allocated.add(output_buffer.get_name())
        self.reuses[output_buffer.get_name()] = input_buffer.get_name()
        self.writeline(ReuseLine(self, input_buffer, output_buffer))

    def codegen_unbacked_symbol_decl(self, symbol):
        name = str(symbol)
        if name in self.unbacked_symbol_decls:
            return name
        else:
            # When in CppWrapperCpu, we should only generate the declaration once
            self.unbacked_symbol_decls.add(name)
            return self.declare + name

    def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs):
        for inner_input, outer_input in zip(subgraph.graph.graph_inputs, outer_inputs):
            self.writeline(f"{self.declare}{inner_input} = {outer_input}{self.ending}")

    def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs):
        for inner_output, outer_output in zip(
            subgraph.graph.graph_outputs, outer_outputs
        ):
            self.writeline(
                f"{outer_output} = {inner_output.codegen_reference()}{self.ending}"
            )

    def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs):
        try:
            self.push_codegened_graph(subgraph.graph)
            self.writeline(f"{self.comment} subgraph: {subgraph.name}")
            self.codegen_subgraph_prefix(subgraph, outer_inputs, outer_outputs)
            parent_graph = V.graph
            with V.set_graph_handler(subgraph.graph):
                subgraph.graph.codegen_subgraph(
                    parent_graph=parent_graph,
                )
            self.codegen_subgraph_suffix(subgraph, outer_inputs, outer_outputs)
        finally:
            self.pop_codegened_graph()

    def codegen_conditional(self, conditional):
        name = conditional.get_name()

        outer_inputs = [buf.codegen_reference() for buf in conditional.operands]
        outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))]

        predicate = conditional.predicate.codegen_reference()
        if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer):
            # move the Tensor predicate to host
            predicate = f"{predicate}.item()"

        self.writeline(f"{name} = [None] * {len(conditional.outputs)}")
        self.writeline(f"if {predicate}:")
        self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph))
        self.codegen_subgraph(conditional.true_subgraph, outer_inputs, outer_outputs)
        self.writeline(ExitSubgraphLine(self))
        self.writeline("else:")
        self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph))
        self.codegen_subgraph(conditional.false_subgraph, outer_inputs, outer_outputs)
        self.writeline(ExitSubgraphLine(self))
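    # Illustrative sketch of the Python wrapper code codegen_conditional emits
    # (buffer and predicate names are hypothetical; each branch body is filled in
    # by codegen_subgraph):
    #
    #     buf2 = [None] * 2
    #     if arg0_1.item():
    #         # true_subgraph: assigns buf2[0], buf2[1]
    #     else:
    #         # false_subgraph: assigns buf2[0], buf2[1]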
    def codegen_while_loop(self, while_loop):
        name = while_loop.get_name()
        outer_carried_inputs = [
            buf.codegen_reference() for buf in while_loop.carried_inputs
        ]
        outer_additional_inputs = [
            buf.codegen_reference() for buf in while_loop.additional_inputs
        ]

        self.writeline(f"{name} = [None] * {len(outer_carried_inputs)}")
        for i, inp in enumerate(outer_carried_inputs):
            # set the initial state before the loop
            self.writeline(f"{name}[{i}] = {inp}")

        cond_outer_inputs = [
            *[f"{name}[{i}]" for i in range(len(outer_carried_inputs))],
            *outer_additional_inputs,
        ]
        cond_outer_outputs = [f"{name}_cond_result"]
        body_outer_inputs = list(
            cond_outer_inputs
        )  # same inputs for cond_fn and body_fn
        # Carry over the state from body_fn. Note: we only carry over the
        # carried_inputs part of the inputs; the additional ones are passed
        # in unchanged.
        body_outer_outputs = body_outer_inputs[: len(outer_carried_inputs)]

        self.writeline("while True:")
        self.writeline(EnterSubgraphLine(self, while_loop.cond_subgraph.graph))
        self.codegen_subgraph(
            while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
        )
        self.writeline(
            f"if not {cond_outer_outputs[0]}.item(): break"
        )  # condition doesn't hold
        self.writeline(ExitSubgraphLine(self))
        self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
        self.codegen_subgraph(
            while_loop.body_subgraph, body_outer_inputs, body_outer_outputs
        )
        self.writeline(ExitSubgraphLine(self))

    @staticmethod
    def statically_known_int_or_none(x):
        try:
            if getattr(x, "free_symbols", None):
                # _maybe_evaluate_static will return (s0 // (2 // s0)) as 2, but
                # the actual codegen will still generate the full expression here.
                return None
            if isinstance(x, int):
                return x
            val = V.graph._shape_env._maybe_evaluate_static(x)
            return int(val)
        except Exception:
            return None

    @staticmethod
    def statically_known_list_of_ints_or_none(lst):
        result = []
        for x in lst:
            num = WrapperCodeGen.statically_known_int_or_none(x)
            if num is None:
                return None
            result.append(num)
        return result

    @staticmethod
    def is_statically_known_list_of_ints(lst):
        return WrapperCodeGen.statically_known_list_of_ints_or_none(lst) is not None

    @staticmethod
    def static_shape_for_buffer_or_none(buffer):
        return WrapperCodeGen.statically_known_list_of_ints_or_none(buffer.get_size())

    @staticmethod
    def can_prove_buffer_has_static_shape(buffer):
        return WrapperCodeGen.static_shape_for_buffer_or_none(buffer) is not None
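    # Rough usage sketch for the statically-known-shape helpers above (return values
    # assume an active shape env; `s0` stands for a symbolic size):
    #
    #     WrapperCodeGen.statically_known_int_or_none(8)                 # -> 8
    #     WrapperCodeGen.statically_known_int_or_none(s0)                # -> None
    #     WrapperCodeGen.statically_known_list_of_ints_or_none([2, s0])  # -> None
    #     WrapperCodeGen.statically_known_list_of_ints_or_none([2, 3])   # -> [2, 3]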