# mypy: allow-untyped-defs
from typing import Any, Dict, List, Optional

import sympy

import torch

from .. import config
from ..runtime.hints import instance_descriptor
from ..utils import _type_of
from ..virtualized import V
from .common import KernelArgType, SizeArg, TensorArg, WorkspaceArg


def should_unwrap_unspec_arg(name: str) -> bool:
    """Return True if an unspec arg (a scalar that was wrapped into a 0-dim
    tensor) should be passed to the kernel as a plain scalar again."""
    if V.graph.is_unspec_arg(name):
        # Unwrap on all devices except CPU
        if V.graph.scheduler.get_current_device_or_throw().type != "cpu":
            return True
        # Only unwrap on CPU if the input is not used as an output
        if name not in V.graph.mutated_buffers:
            return True
    return False


def signature_of(arg: KernelArgType, *, size_dtype: str) -> str:
    """Return the Triton signature string for a kernel argument,
    e.g. "*fp32" for a float32 tensor pointer or "i32" for an int32 size."""
    if isinstance(arg, TensorArg):
        # TODO: Remove fp8 special handling when Triton supports PyTorch fp8 dtypes.
        # Related PR: https://github.com/openai/triton/pull/2279/
        if arg.dtype == torch.float8_e4m3fn:
            tye = "*fp8e4nv"
        elif arg.dtype == torch.float8_e5m2:
            tye = "*fp8e5"
        elif arg.dtype == torch.float8_e4m3fnuz:
            tye = "*fp8e4b8"
        elif arg.dtype == torch.float8_e5m2fnuz:
            tye = "*fp8e5b16"
        else:
            tye = _type_of(arg.dtype)
        if should_unwrap_unspec_arg(arg.buffer):
            # the 0-dim tensor has been unwrapped to a scalar, so drop the pointer
            new_tye = tye.lstrip("*")
            if new_tye in ["fp16", "bf16"]:
                return "fp32"
            else:
                return new_tye
        else:
            return tye
    if isinstance(arg, SizeArg):
        if arg.expr is None:
            # From triton/runtime/jit.py
            # `None` is nullptr. Implicitly convert to *i8.
            return "*i8"
        elif isinstance(arg.expr, (float, sympy.Float)):
            return "fp32"
        if size_dtype == "tl.int32":
            return "i32"
        elif size_dtype == "tl.int64":
            return "i64"
        else:
            raise NotImplementedError(f"unhandled size_dtype {size_dtype}")
    if isinstance(arg, WorkspaceArg):
        return "*i8"
    raise NotImplementedError(f"unhandled {type(arg)}: {arg}")


def signature_to_meta(
    signature: List[KernelArgType],
    *,
    size_dtype: str,
    indices: Optional[List[int]] = None,
) -> Dict[int, str]:
    """Map each argument's index to its Triton signature string."""
    if indices is None:
        indices = list(range(len(signature)))
    return {
        i: signature_of(arg, size_dtype=size_dtype)
        for i, arg in zip(indices, signature)
    }


def is_unaligned_buffer(arg: TensorArg) -> bool:
    """Return True if we cannot prove that the buffer backing ``arg`` is aligned."""
    buf_name = arg.buffer
    if buf_name in V.graph.graph_inputs:
        # See Note: [Input Alignment handling in Inductor]
        return buf_name not in V.graph.aligned_inputs

    if buf_name in V.graph.constants:
        # all constants are assumed to be aligned
        return False

    if V.graph.scheduler:
        layout = V.graph.scheduler.get_buffer_layout(buf_name)
    else:
        buffer = V.graph.try_get_buffer(buf_name)
        # output arg
        if not buffer:
            assert buf_name == V.kernel.output_node.name
            layout = V.kernel.output_node.layout
        else:
            layout = buffer.get_layout()

    if isinstance(layout, torch._inductor.ir.NonOwningLayout):
        return not layout.maybe_guard_aligned()
    else:
        return False


def config_of(
    args: List[KernelArgType],
    *,
    indices: Optional[List[int]] = None,
) -> Any:
    """Build the Triton ``instance_descriptor`` for ``args``: which argument
    indices are divisible by 16, equal to 1, folded away, or divisible by 8."""
    if indices is None:
        indices = list(range(len(args)))

    def is_aligned(x: KernelArgType, alignment: int, include_tensor: bool) -> bool:
        """
        Roughly follow triton code here:
        https://github.com/openai/triton/blob/5282ed890d453e10b9ee30076ef89115dd197761/python/triton/runtime/jit.py#L208-L222
        """
        if isinstance(x, TensorArg):
            if include_tensor:
                offset_aligned = V.graph.sizevars.statically_known_multiple_of(
                    x.offset * x.dtype.itemsize, alignment  # type: ignore[arg-type]
                )
                return offset_aligned and not is_unaligned_buffer(x)
            else:
                return False
        if isinstance(x, SizeArg):
            # TODO(voz): These are kinda redundant, if we can solve out statically_known_multiple_of with
            # _maybe_evaluate_static...
            if x.name.startswith("load_seed_offset"):
                return False
            if x.expr is None:
                return False
            if isinstance(x.expr, float):
                return False
            return V.graph.sizevars.statically_known_multiple_of(x.expr, alignment)  # type: ignore[arg-type]
        if isinstance(x, WorkspaceArg):
            return V.graph.sizevars.statically_known_multiple_of(x.nbytes, alignment)  # type: ignore[arg-type]
        raise NotImplementedError(f"unhandled {type(x)}: {x}")

    if config.triton.divisible_by_16:
        divisible_by_16 = tuple(
            i
            for i, arg in zip(indices, args)
            if is_aligned(arg, alignment=16, include_tensor=True)
        )
    else:
        divisible_by_16 = ()
    divisible_by_8 = tuple(
        i
        for i, arg in zip(indices, args)
        if is_aligned(arg, alignment=8, include_tensor=False)
    )

    equal_to_1 = tuple(
        i
        for i, arg in zip(indices, args)
        if isinstance(arg, SizeArg)
        and isinstance(arg.expr, (int, sympy.Integer))
        and V.graph.sizevars.statically_known_equals(arg.expr, 1)  # type: ignore[arg-type]
    )
    # ids_of_folded_args is set from equal_to_1
    # and None args by the Triton compiler
    ids_of_folded_args = tuple(equal_to_1)

    return instance_descriptor(
        divisible_by_16, equal_to_1, ids_of_folded_args, divisible_by_8
    )
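

# Illustrative sketch (an assumption about typical usage, not part of this
# module's API): the helpers above get combined when a caller emits a Triton
# kernel's launch metadata. The names "in_ptr0", "buf0", and "xnumel" and the
# dtype below are hypothetical placeholders; real callers build TensorArg /
# SizeArg instances from the kernel's collected args, and V.graph must already
# be set up for signature_of/config_of to resolve alignment and unspec-arg
# checks, so this is kept as a commented-out sketch rather than runnable code.
#
#   args: List[KernelArgType] = [
#       TensorArg(name="in_ptr0", buffer="buf0", dtype=torch.float32),  # -> "*fp32"
#       SizeArg("xnumel", sympy.Integer(1024)),                         # -> "i32"
#   ]
#   triton_meta = {
#       "signature": signature_to_meta(args, size_dtype="tl.int32"),
#       "configs": [config_of(args)],
#   }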