import contextlib
import threading
from typing import Callable, Generator, Iterable, Optional, Union

from .custom_ops import custom_op
from .infer_schema import infer_schema


def triton_op(
    name: str,
    fn: Optional[Callable] = None,
    /,
    *,
    mutates_args: Union[str, Iterable[str]],
    schema: Optional[str] = None,
) -> Callable:
    """Create a custom operator whose implementation is backed by 1+ triton kernels.

    Use this instead of :func:`torch.library.custom_op` when the implementation
    consists of 1+ triton kernels. :func:`torch.library.custom_op` treats
    custom operators as opaque (:func:`torch.compile` and
    :func:`torch.export.export` will never trace into them), but ``triton_op``
    makes the implementation visible to these subsystems, allowing them
    to optimize the triton kernel(s).

    Note that ``fn`` must only consist of calls to PyTorch-understood
    operators and triton kernels. Any triton kernels called inside ``fn``
    must be wrapped in a call to :func:`torch._library.capture_triton`.

    Args:
        name (str): A name for the custom op that looks like "{namespace}::{name}",
            e.g. "mylib::my_linear". The name is used as the op's stable identifier
            in PyTorch subsystems (e.g. torch.export, FX graphs).
            To avoid name collisions, please use your project name as the namespace;
            e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace.
        mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates.
            This MUST be accurate, otherwise the behavior is undefined. If "unknown",
            we pessimistically assume that all inputs to the operator are being mutated.
        schema (None | str): A schema string for the operator. If None
            (recommended) we'll infer a schema for the operator from its type
            annotations. We recommend letting us infer a schema unless you
            have a specific reason not to.
            Example: "(Tensor x, int y) -> (Tensor, Tensor)".

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> from torch._library import triton_op, capture_triton
        >>>
        >>> import triton
        >>> from triton import language as tl
        >>>
        >>> @triton.jit
        >>> def add_kernel(
        >>>     in_ptr0,
        >>>     in_ptr1,
        >>>     out_ptr,
        >>>     n_elements,
        >>>     BLOCK_SIZE: "tl.constexpr",
        >>> ):
        >>>     pid = tl.program_id(axis=0)
        >>>     block_start = pid * BLOCK_SIZE
        >>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
        >>>     mask = offsets < n_elements
        >>>     x = tl.load(in_ptr0 + offsets, mask=mask)
        >>>     y = tl.load(in_ptr1 + offsets, mask=mask)
        >>>     output = x + y
        >>>     tl.store(out_ptr + offsets, output, mask=mask)
        >>>
        >>> @triton_op("mylib::add", mutates_args=())
        >>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        >>>     output = torch.empty_like(x)
        >>>     n_elements = output.numel()
        >>>
        >>>     def grid(meta):
        >>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        >>>
        >>>     # NB: we need to wrap the triton kernel in a call to capture_triton
        >>>     capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
        >>>     return output
        >>>
        >>> @torch.compile
        >>> def f(x, y):
        >>>     return add(x, y)
        >>>
        >>> x = torch.randn(3, device="cuda")
        >>> y = torch.randn(3, device="cuda")
        >>>
        >>> z = f(x, y)
        >>> assert torch.allclose(z, x + y)

    """

    def dec(fn: Callable) -> Callable:
        def backend_fn(*args, **kwargs):  # type: ignore[no-untyped-def]
            # Optimization: we're passing regular Tensors into the triton kernel, so
            # no need to go through HOP dispatch
            with set_capture_triton_enabled(False):
                return fn(*args, **kwargs)

        # Use the user-provided schema if one was given; otherwise infer it
        # from fn's type annotations.
        op_schema = schema if schema is not None else infer_schema(fn, mutates_args=mutates_args)
        result = custom_op(
            name,
            backend_fn,
            mutates_args=mutates_args,
            schema=op_schema,
        )
        from .._subclasses.functional_tensor import FunctionalTensorMode

        # We require that the user pass us a function that is make_fx traceable,
        # so we can just register it as the Fake/meta kernel.
        result.register_fake(fn)

        # We decompose the operator when FunctionalTensorMode is active.
        # The goal is to decompose the operator in AOTDispatcher.
        # - With torch.compile, this means that the backend (usually Inductor)
        #   can see a call to the triton kernel(s) and so it can directly optimize
        #   them by inlining them into the lowering process.
        # - With post-dispatch torch.export, this means that there will be
        #   call(s) to the triton_kernel_wrapper_functional HOP in the
        #   graph (which we have yet to figure out how to serialize).
        def functional_decomp(  # type: ignore[no-untyped-def]
            mode, _, types, args, kwargs
        ):
            with mode:
                return fn(*args, **kwargs)

        result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
        return result

    if fn is None:
        return dec
    else:
        return dec(fn)


capture_triton_enabled = threading.local()
capture_triton_enabled_default = True


@contextlib.contextmanager
def set_capture_triton_enabled(enabled: bool) -> Generator[None, None, None]:
    """Whether triton kernels wrapped with ``capture_triton`` should dispatch
    via HOP or go straight to the triton kernel execution.

    We have this switch because eager-mode performance of HOP dispatch is slow
    enough to matter (~1ms) and we know that capture_triton isn't necessary in
    some situations (eager-mode with regular Tensors).
    """
    try:
        prev = is_capture_triton_enabled()
        capture_triton_enabled.value = enabled
        yield
    finally:
        capture_triton_enabled.value = prev


def is_capture_triton_enabled() -> bool:
    return getattr(capture_triton_enabled, "value", capture_triton_enabled_default)


def capture_triton(triton_kernel: Callable, /) -> Callable:
    """Allows capture of a triton kernel into a graph via make_fx or
    non-strict export (coming soon).

    These technologies perform Dispatcher-based tracing (via
    ``__torch_dispatch__``) and cannot see calls to raw triton kernels.
    The ``capture_triton`` API returns a new callable that can actually
    be traced into a graph.

    Examples:

        >>> # xdoctest: +SKIP
        >>> import torch
        >>> import triton
        >>> from triton import language as tl
        >>> from torch.fx.experimental.proxy_tensor import make_fx
        >>> from torch._higher_order_ops.triton_kernel_wrap import capture_triton
        >>>
        >>> @triton.jit
        >>> def add_kernel(
        >>>     in_ptr0,
        >>>     in_ptr1,
        >>>     out_ptr,
        >>>     n_elements,
        >>>     BLOCK_SIZE: "tl.constexpr",
        >>> ):
        >>>     pid = tl.program_id(axis=0)
        >>>     block_start = pid * BLOCK_SIZE
        >>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
        >>>     mask = offsets < n_elements
        >>>     x = tl.load(in_ptr0 + offsets, mask=mask)
        >>>     y = tl.load(in_ptr1 + offsets, mask=mask)
        >>>     output = x + y
        >>>     tl.store(out_ptr + offsets, output, mask=mask)
        >>>
        >>> def add(x, y):
        >>>     output = torch.empty_like(x)
        >>>     n_elements = output.numel()
        >>>
        >>>     def grid_fn(meta):
        >>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        >>>
        >>>     capture_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16)
        >>>     return output
        >>>
        >>> x = torch.randn(3, device="cuda")
        >>> y = torch.randn(3, device="cuda")
        >>> gm = make_fx(add)(x, y)
        >>> print(gm.code)
        >>> # def forward(self, x_1, y_1):
        >>> #     empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False)
        >>> #     triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation(
        >>> #         kernel_idx = 0, constant_args_idx = 0,
        >>> #         grid = [(1, 1, 1)], kwargs = {
        >>> #             'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like,
        >>> #             'n_elements': 3, 'BLOCK_SIZE': 16
        >>> #         })
        >>> #     return empty_like

    """
    from triton.runtime.autotuner import Autotuner
    from triton.runtime.jit import JITFunction

    from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper

    if not isinstance(triton_kernel, (JITFunction, Autotuner)):
        raise RuntimeError(
            "capture_triton only works on functions annotated with triton.jit or triton.autotune"
        )
    if not is_capture_triton_enabled():
        return triton_kernel
    return TraceableTritonKernelWrapper(triton_kernel, None, None)
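

# Illustrative sketch only: the helper below is hypothetical (not part of this
# module's API and never called at import time). It shows how the thread-local
# capture_triton toggle defined above behaves, using only names defined in
# this file.
def _example_capture_toggle_usage() -> None:
    # By default the toggle is on, so capture_triton wraps kernels for tracing.
    assert is_capture_triton_enabled()
    with set_capture_triton_enabled(False):
        # Inside the context manager, capture_triton(kernel) returns the kernel
        # unchanged, skipping HOP dispatch; triton_op's backend_fn relies on
        # this for eager-mode performance.
        assert not is_capture_triton_enabled()
    # The previous value is restored when the context manager exits.
    assert is_capture_triton_enabled()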