xref: /aosp_15_r20/external/pytorch/torch/_inductor/runtime/compile_tasks.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import functools
5import os
6import sys
7import warnings
8from types import ModuleType
9from typing import Any, Callable, Dict
10
11
def _reload_triton_kernel_in_subproc(reload_module, kernel_name):
    """Re-create a kernel object by reloading the module that defines it.

    Args:
        reload_module: zero-argument callable that returns the module
            object holding the kernel.
        kernel_name: attribute name of the kernel on that module.

    Returns:
        Whatever ``_module_to_triton_kernel`` returns for the reloaded
        module and ``kernel_name``.
    """
    reloaded_mod = reload_module()
    return _module_to_triton_kernel(reloaded_mod, kernel_name)
14
15
def _module_to_triton_kernel(mod, kernel_name):
    """Fetch ``kernel_name`` from ``mod`` and attach a reload hook to it.

    The returned kernel gets a ``_reload_in_subproc`` attribute: a
    ``functools.partial`` that, when called, invokes
    ``_reload_triton_kernel_in_subproc`` with the module's own
    ``_reload_in_subproc`` callable and the same ``kernel_name``.

    Args:
        mod: module-like object exposing the kernel as an attribute and a
            ``_reload_in_subproc`` attribute.
        kernel_name: attribute name of the kernel on ``mod``.

    Returns:
        The kernel object taken from ``mod`` (mutated in place).
    """
    kernel = getattr(mod, kernel_name)
    # Bind everything needed to rebuild this kernel later, so the hook is a
    # self-contained zero-argument callable.
    reload_hook = functools.partial(
        _reload_triton_kernel_in_subproc, mod._reload_in_subproc, kernel_name
    )
    kernel._reload_in_subproc = reload_hook
    return kernel
24
25
def _reload_python_module_in_subproc(key, path):
    """Load the module for ``key``/``path``, via codecache if already imported.

    If ``torch._inductor.codecache`` is present in ``sys.modules``, the load
    is delegated to its ``PyCodeCache.load_by_key_path``; otherwise the
    local ``_reload_python_module`` is used. The codecache module is only
    looked up, never imported, by this function.

    Args:
        key: identifier forwarded to the loader.
        path: filesystem path forwarded to the loader.

    Returns:
        The value returned by whichever loader was used.
    """
    codecache = sys.modules.get("torch._inductor.codecache")
    if not codecache:
        return _reload_python_module(key, path)
    return codecache.PyCodeCache.load_by_key_path(key, path)
32
33
def _reload_python_module(key, path):
    """Compile and execute the Python source at ``path`` as a new module.

    The source is read as raw bytes and handed to ``compile()``, so decoding
    follows Python's source-file rules (UTF-8 by default, or a PEP 263
    coding declaration) instead of the process locale. The file is closed
    before the module body is executed.

    Args:
        key: identifier for the module; used as the last component of the
            module name (``f"{__name__}.{key}"``) and stored as ``mod.key``.
        path: filesystem path of the source file; also stored as
            ``mod.__file__`` and used as the filename in tracebacks.

    Returns:
        The newly created and executed module. It is registered in
        ``sys.modules`` under its name only after execution succeeds.

    Raises:
        RuntimeError: if reading or compiling the source fails. The message
            embeds the original exception type and text; exception chaining
            is suppressed.
        OSError: if ``path`` cannot be opened (propagated unchanged).
    """
    # Binary mode: let compile() decode the source the way the interpreter
    # would, rather than relying on the locale's preferred encoding.
    with open(path, "rb") as f:
        try:
            code = compile(f.read(), path, "exec", dont_inherit=True)
        except Exception as e:
            raise RuntimeError(
                f"Failed to import {path}\n{type(e).__name__}: {e}"
            ) from None

    mod = ModuleType(f"{__name__}.{key}")
    mod.__file__ = path
    mod.key = key  # type: ignore[attr-defined]
    exec(code, mod.__dict__, mod.__dict__)
    # Register only after a successful exec so a failing module body does not
    # leave a half-initialized entry behind.
    sys.modules[mod.__name__] = mod
    return mod
48
49
@functools.lru_cache(None)
def _set_triton_ptxas_path() -> None:
    """Point ``TRITON_PTXAS_PATH`` at the ``ptxas`` next to this package.

    Does nothing when ``TRITON_PTXAS_PATH`` is already set in the
    environment, or when ``../bin/ptxas`` (relative to this file's
    directory) does not exist. If that path exists and is an executable
    regular file, it is exported via ``os.environ``; if it exists but is not
    one, a warning is issued instead. The ``lru_cache`` decorator makes the
    body run at most once per process.
    """
    if "TRITON_PTXAS_PATH" in os.environ:
        # An explicit setting always wins over the bundled binary.
        return

    here = os.path.dirname(__file__)
    bundled_ptxas = os.path.abspath(os.path.join(here, "..", "bin", "ptxas"))

    if os.path.isfile(bundled_ptxas) and os.access(bundled_ptxas, os.X_OK):
        os.environ["TRITON_PTXAS_PATH"] = bundled_ptxas
    elif os.path.exists(bundled_ptxas):
        warnings.warn(f"{bundled_ptxas} exists but is not an executable")
63
64
def _worker_compile_triton(load_kernel: Callable[[], Any], extra_env: Dict[str, str]):
    """Load a kernel and precompile it with ``warm_cache_only=True``.

    Args:
        load_kernel: zero-argument callable returning an object with a
            ``precompile`` method.
        extra_env: environment variables merged into ``os.environ`` before
            the kernel is loaded.
    """
    _set_triton_ptxas_path()
    # Applied after the ptxas default so entries in extra_env take precedence.
    os.environ.update(extra_env)
    kernel = load_kernel()
    kernel.precompile(warm_cache_only=True)
69