# mypy: allow-untyped-defs
from __future__ import annotations

import functools
import os
import sys
import warnings
from types import ModuleType
from typing import Any, Callable, Dict


def _reload_triton_kernel_in_subproc(reload_module, kernel_name):
    """Reload a module and return the named kernel from it.

    Args:
        reload_module: Zero-argument callable that returns the (re)loaded
            module, e.g. the ``_reload_in_subproc`` hook of a module.
        kernel_name: Name of the kernel attribute to look up on that module.

    Returns:
        The kernel object, with a fresh ``_reload_in_subproc`` hook attached
        by ``_module_to_triton_kernel``.
    """
    return _module_to_triton_kernel(reload_module(), kernel_name)


def _module_to_triton_kernel(mod, kernel_name):
    """Fetch ``kernel_name`` from ``mod`` and attach a reload hook to it.

    The attached ``kernel._reload_in_subproc`` is a ``functools.partial``
    binding ``mod._reload_in_subproc`` and ``kernel_name``. Calling it reloads
    the module via that hook and looks the kernel up again. Presumably it is
    used to rebuild the kernel in another process, given the name.

    Args:
        mod: Module (or module-like object) exposing ``kernel_name`` and a
            ``_reload_in_subproc`` attribute; ``AttributeError`` is raised if
            either is missing.
        kernel_name: Attribute name of the kernel on ``mod``.

    Returns:
        The kernel object itself (mutated in place to carry the hook).
    """
    kernel = getattr(mod, kernel_name)
    kernel._reload_in_subproc = functools.partial(
        _reload_triton_kernel_in_subproc,
        mod._reload_in_subproc,
        kernel_name,
    )
    return kernel


def _reload_python_module_in_subproc(key, path):
    """Reload a generated Python module identified by ``key`` and ``path``.

    If ``torch._inductor.codecache`` has already been imported in this process
    (it is looked up in ``sys.modules``, never imported here), delegate to
    ``PyCodeCache.load_by_key_path``. Otherwise, fall back to compiling the
    file directly with ``_reload_python_module``.
    """
    codecache = sys.modules.get("torch._inductor.codecache")
    if codecache:
        return codecache.PyCodeCache.load_by_key_path(key, path)
    else:
        return _reload_python_module(key, path)


def _reload_python_module(key, path):
    """Compile and execute the Python source at ``path`` as a fresh module.

    Args:
        key: Identifier for the module. The module is named
            ``f"{__name__}.{key}"`` and the key is also stored as ``mod.key``.
        path: Filesystem path of the Python source file. It is stored as
            ``mod.__file__``.

    Returns:
        The executed module. It is registered in ``sys.modules`` under its
        name, but only after its body executed successfully.

    Raises:
        OSError: If the file cannot be opened.
        RuntimeError: If reading or compiling the source fails. The original
            exception's type and message are embedded, and its context is
            suppressed (``from None``).
        Exceptions raised while executing the module body propagate unchanged.
    """
    # Hold the file handle only for as long as it takes to read and compile
    # the source; the module body below may run arbitrary generated code and
    # should not keep the file open while it does.
    with open(path) as f:
        try:
            # dont_inherit=True: do not apply this file's own __future__
            # imports (e.g. `annotations`) to the code being compiled.
            code = compile(f.read(), path, "exec", dont_inherit=True)
        except Exception as e:
            raise RuntimeError(
                f"Failed to import {path}\n{type(e).__name__}: {e}"
            ) from None
    mod = ModuleType(f"{__name__}.{key}")
    mod.__file__ = path
    mod.key = key  # type: ignore[attr-defined]
    # Same dict for globals and locals, so top-level definitions in the
    # source become attributes of the module.
    exec(code, mod.__dict__, mod.__dict__)
    sys.modules[mod.__name__] = mod
    return mod


@functools.lru_cache(None)
def _set_triton_ptxas_path() -> None:
    """Point ``TRITON_PTXAS_PATH`` at a ``ptxas`` bundled next to this file.

    The candidate is ``<dir of this file>/../bin/ptxas``. The function does
    nothing if ``TRITON_PTXAS_PATH`` is already present in the environment
    (even as an empty string) or if the candidate does not exist. It emits a
    warning, without setting anything, if the candidate exists but is not an
    executable regular file.

    Wrapped in ``lru_cache`` with no arguments, so the body runs at most once
    per process (unless the cache is cleared). Later environment changes are
    not re-examined.
    """
    if os.environ.get("TRITON_PTXAS_PATH") is not None:
        return
    ptxas_path = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "..", "bin", "ptxas")
    )
    if not os.path.exists(ptxas_path):
        return
    if os.path.isfile(ptxas_path) and os.access(ptxas_path, os.X_OK):
        os.environ["TRITON_PTXAS_PATH"] = ptxas_path
    else:
        warnings.warn(f"{ptxas_path} exists but is not an executable")


def _worker_compile_triton(load_kernel: Callable[[], Any], extra_env: Dict[str, str]):
    """Load a kernel and precompile it with ``warm_cache_only=True``.

    The name suggests this runs in a compile worker process. Note that it
    permanently mutates ``os.environ`` of whichever process calls it.

    Args:
        load_kernel: Zero-argument callable returning an object with a
            ``precompile(warm_cache_only=...)`` method.
        extra_env: Variables copied into ``os.environ`` before the kernel is
            loaded.
    """
    _set_triton_ptxas_path()
    # Applied after the ptxas default, so extra_env may override
    # TRITON_PTXAS_PATH.
    os.environ.update(extra_env)
    load_kernel().precompile(warm_cache_only=True)