# mypy: allow-untyped-defs
from typing import Any, Dict, List, Optional

import sympy

import torch

from .. import config
from ..runtime.hints import instance_descriptor
from ..utils import _type_of
from ..virtualized import V
from .common import KernelArgType, SizeArg, TensorArg, WorkspaceArg


def should_unwrap_unspec_arg(name: str) -> bool:
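    """
    Return True if ``name`` refers to an unspecialized (0-dim tensor) input
    that should be unwrapped and passed to the kernel as a plain scalar.

    Illustrative sketch of the rule below: on non-CPU devices such inputs are
    always unwrapped; on CPU they are unwrapped only if the buffer is not also
    mutated as an output.
    """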
    if V.graph.is_unspec_arg(name):
        # Unwrap on all devices except CPU
        if V.graph.scheduler.get_current_device_or_throw().type != "cpu":
            return True
        # Only unwrap on CPU if the input is not used as an output
        if name not in V.graph.mutated_buffers:
            return True
    return False


def signature_of(arg: KernelArgType, *, size_dtype: str) -> str:
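    """
    Map a kernel argument to the type string used in the Triton signature.

    Illustrative sketch (assuming the usual ``_type_of`` mapping): a
    ``TensorArg`` with ``torch.float32`` yields ``"*fp32"``, a ``TensorArg``
    with ``torch.float8_e5m2`` yields ``"*fp8e5"``, a ``SizeArg`` with an
    integer expr and ``size_dtype="tl.int32"`` yields ``"i32"``, and a
    ``WorkspaceArg`` yields ``"*i8"``.
    """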
    if isinstance(arg, TensorArg):
        # TODO: Remove fp8 special handling when Triton supports PyTorch fp8 dtypes.
        # Related PR: https://github.com/openai/triton/pull/2279/
        if arg.dtype == torch.float8_e4m3fn:
            tye = "*fp8e4nv"
        elif arg.dtype == torch.float8_e5m2:
            tye = "*fp8e5"
        elif arg.dtype == torch.float8_e4m3fnuz:
            tye = "*fp8e4b8"
        elif arg.dtype == torch.float8_e5m2fnuz:
            tye = "*fp8e5b16"
        else:
            tye = _type_of(arg.dtype)
        if should_unwrap_unspec_arg(arg.buffer):
            # the 0-dim tensor has been unwrapped and is passed as a scalar
            new_tye = tye.lstrip("*")
            if new_tye in ["fp16", "bf16"]:
                return "fp32"
            else:
                return new_tye
        else:
            return tye
    if isinstance(arg, SizeArg):
        if arg.expr is None:
            # From triton/runtime/jit.py
            # `None` is nullptr.  Implicitly convert to *i8.
            return "*i8"
        elif isinstance(arg.expr, (float, sympy.Float)):
            return "fp32"
        if size_dtype == "tl.int32":
            return "i32"
        elif size_dtype == "tl.int64":
            return "i64"
        else:
            raise NotImplementedError(f"unhandled size_dtype {size_dtype}")
    if isinstance(arg, WorkspaceArg):
        return "*i8"
    raise NotImplementedError(f"unhandled {type(arg)}: {arg}")


def signature_to_meta(
    signature: List[KernelArgType],
    *,
    size_dtype: str,
    indices: Optional[List[int]] = None,
) -> Dict[int, str]:
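    """
    Build the per-index signature metadata dict handed to the Triton compiler,
    mapping argument index -> type string from ``signature_of``.

    Illustrative sketch: a signature of [fp32 TensorArg, integer SizeArg] with
    size_dtype="tl.int32" yields {0: "*fp32", 1: "i32"}.
    """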
    if indices is None:
        indices = list(range(len(signature)))
    return {
        i: signature_of(arg, size_dtype=size_dtype)
        for i, arg in zip(indices, signature)
    }


def is_unaligned_buffer(arg: TensorArg) -> bool:
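    """
    Return True if the buffer backing ``arg`` may violate Inductor's alignment
    assumption: a graph input that was not recorded in V.graph.aligned_inputs,
    or a NonOwningLayout whose offset cannot be guarded as aligned. Constants
    and all other buffers are treated as aligned.
    """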
    buf_name = arg.buffer
    if buf_name in V.graph.graph_inputs:
        # See Note: [Input Alignment handling in Inductor]
        return buf_name not in V.graph.aligned_inputs

    if buf_name in V.graph.constants:
        # all constants are assumed to be aligned
        return False

    if V.graph.scheduler:
        layout = V.graph.scheduler.get_buffer_layout(buf_name)
    else:
        buffer = V.graph.try_get_buffer(buf_name)
        # output arg
        if not buffer:
            assert buf_name == V.kernel.output_node.name
            layout = V.kernel.output_node.layout
        else:
            layout = buffer.get_layout()

    if isinstance(layout, torch._inductor.ir.NonOwningLayout):
        return not layout.maybe_guard_aligned()
    else:
        return False


def config_of(
    args: List[KernelArgType],
    *,
    indices: Optional[List[int]] = None,
) -> Any:
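    """
    Build the Triton ``instance_descriptor`` hints for ``args``: which argument
    indices are statically known to be divisible by 16, divisible by 8
    (non-tensor args only), statically equal to 1, and folded away by the
    compiler (currently the same set as ``equal_to_1``).
    """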
    if indices is None:
        indices = list(range(len(args)))

    def is_aligned(x: KernelArgType, alignment: int, include_tensor: bool) -> bool:
        """
        Roughly follow triton code here:
        https://github.com/openai/triton/blob/5282ed890d453e10b9ee30076ef89115dd197761/python/triton/runtime/jit.py#L208-L222
        """
        if isinstance(x, TensorArg):
            if include_tensor:
                offset_aligned = V.graph.sizevars.statically_known_multiple_of(
                    x.offset * x.dtype.itemsize, alignment  # type: ignore[arg-type]
                )
                return offset_aligned and not is_unaligned_buffer(x)
            else:
                return False
        if isinstance(x, SizeArg):
            # TODO(voz): These are kinda redundant, if we can solve out statically_known_multiple_of with
            # _maybe_evaluate_static...
            if x.name.startswith("load_seed_offset"):
                return False
            if x.expr is None:
                return False
            if isinstance(x.expr, float):
                return False
            return V.graph.sizevars.statically_known_multiple_of(x.expr, alignment)  # type: ignore[arg-type]
        if isinstance(x, WorkspaceArg):
            return V.graph.sizevars.statically_known_multiple_of(x.nbytes, alignment)  # type: ignore[arg-type]
        raise NotImplementedError(f"unhandled {type(x)}: {x}")

    if config.triton.divisible_by_16:
        divisible_by_16 = tuple(
            i
            for i, arg in zip(indices, args)
            if is_aligned(arg, alignment=16, include_tensor=True)
        )
    else:
        divisible_by_16 = ()
    divisible_by_8 = tuple(
        i
        for i, arg in zip(indices, args)
        if is_aligned(arg, alignment=8, include_tensor=False)
    )

    equal_to_1 = tuple(
        i
        for i, arg in zip(indices, args)
        if isinstance(arg, SizeArg)
        and isinstance(arg.expr, (int, sympy.Integer))
        and V.graph.sizevars.statically_known_equals(arg.expr, 1)  # type: ignore[arg-type]
    )
    # ids_of_folded_args is set from equal_to_1
    # and None args by the Triton compiler
    ids_of_folded_args = tuple(equal_to_1)

    return instance_descriptor(
        divisible_by_16, equal_to_1, ids_of_folded_args, divisible_by_8
    )