# mypy: allow-untyped-defs
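"""Async compilation support for TorchInductor: thread/process pools used to
compile Triton, C++, CUDA, ROCm, and Halide kernels in parallel."""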
from __future__ import annotations

import functools
import logging
import multiprocessing
import os
import sys
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from functools import partial
from time import time
from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING

import torch
from torch._dynamo.device_interface import get_registered_device_interfaces
from torch._inductor import config
from torch._inductor.codecache import (
    CodeCacheFuture,
    CppCodeCache,
    CppPythonBindingsCodeCache,
    CUDACodeCache,
    HalideCodeCache,
    LambdaFuture,
    ROCmCodeCache,
    TritonCodeCache,
    TritonFuture,
)
from torch._inductor.compile_worker.subproc_pool import (
    _warm_process_pool,
    AnyPool,
    SubprocPool,
)
from torch._inductor.compile_worker.watchdog import _async_compile_initializer
from torch._inductor.runtime.compile_tasks import (
    _set_triton_ptxas_path,
    _worker_compile_triton,
)
from torch.hub import _Faketqdm, tqdm
from torch.utils._triton import has_triton_package


if TYPE_CHECKING:
    from torch._inductor.runtime.hints import HalideMeta

# Timing metrics for time spent in compilation.
_cumulative_compile_time = 0.0
_t0: Optional[float] = None

kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code")


def pre_fork_setup():
    """
    Setup that must be done prior to forking with a process pool.
    """
    # ensure properties have been calculated before processes
    # are forked
    caching_device_properties()

    # Computing the triton key can be slow. If we call it before fork,
    # it will be cached for the forked subprocesses.
    try:
        from triton.compiler.compiler import triton_key

        triton_key()
    except ImportError:
        # Triton might not be installed or might be an old version.
        pass


def caching_device_properties():
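    """Compute device properties for every available device so the results are
    cached before worker processes are forked."""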
    for _, device_interface in get_registered_device_interfaces():
        if device_interface.is_available():
            device_interface.Worker.get_device_properties()


def _compile_start() -> None:
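    """Record the start of a compilation interval (no-op if one is already open)."""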
    global _t0
    if _t0 is None:
        _t0 = time()


def _compile_end() -> None:
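    """Close the current compilation interval and add its elapsed time to
    _cumulative_compile_time."""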
    global _cumulative_compile_time, _t0
    if _t0 is not None:
        t1 = time()
        _cumulative_compile_time += t1 - _t0
        _t0 = None
        # print("CUMULATIVE COMPILE TIME", _cumulative_compile_time)


_IS_WINDOWS = sys.platform == "win32"

log = logging.getLogger(__name__)


# Used to keep track of all process pools invoked so far.
_pool_set: Set[AnyPool] = set()


def shutdown_compile_workers() -> None:
    """Shut down all outstanding compile-worker pools."""
    for pool in _pool_set:
        pool.shutdown()
    after_fork()


def after_fork():
    """Reset pools to initial state without shutting them down"""
    _pool_set.clear()
    AsyncCompile.process_pool.cache_clear()


try:
    os.register_at_fork(after_in_child=after_fork)
except AttributeError:
    pass  # register_at_fork does not exist on Windows


class AsyncCompile:
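    """Entry points for compiling kernels either inline or asynchronously,
    via a thread pool and (for Triton) a process pool, depending on
    config.compile_threads and config.worker_start_method."""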
    def __init__(self) -> None:
        pass

    @staticmethod
    @functools.lru_cache(1)
    def pool() -> ThreadPoolExecutor:
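        """Shared in-process thread pool for parallel compile tasks."""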
        assert config.compile_threads > 1
        return ThreadPoolExecutor(config.compile_threads)

    @staticmethod
    def _get_ready():
        """No-op function to help mark when the subprocess pool is ready."""
        return "ready"

    @staticmethod
    @functools.lru_cache(1)
    def process_pool() -> AnyPool:
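        """Shared pool of worker processes, used for parallel Triton compilation."""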
        assert config.compile_threads > 1
        pool: AnyPool
        if config.worker_start_method == "subprocess":
            # SubprocPool is a wrapper around ProcessPoolExecutor that forks
            # in a new process we control
            pool = SubprocPool(config.compile_threads)
        else:
            pre_fork_setup()
            ctx = multiprocessing.get_context(config.worker_start_method)
            pool = ProcessPoolExecutor(
                config.compile_threads,
                mp_context=ctx,
                initializer=partial(_async_compile_initializer, os.getpid()),
            )
            # When this pool is created inside a subprocess, the normal exit
            # handler doesn't run, so we need to register our own handler.
            # exitpriority has to be high, because another one of the finalizers
            # will kill the worker thread that sends the shutdown message to the
            # workers...
            multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize)

        # Set an attribute we can check to see if the pool is ready.
        pool.ready_future = pool.submit(AsyncCompile._get_ready)  # type: ignore[union-attr]
        _pool_set.add(pool)
        return pool

    @classmethod
    def warm_pool(cls) -> None:
        if config.compile_threads <= 1:
            return
        _compile_start()
        _warm_process_pool(cls.process_pool(), config.compile_threads)
        _compile_end()

    @classmethod
    def submit(cls, task: Callable[..., Any]) -> Any:
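        """Run task inline when compiling serially; otherwise submit it to the
        thread pool and return a Future."""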
        if config.compile_threads <= 1:
            return task()
        return cls.pool().submit(task)

    def _use_process_pool(self):
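        """Use the process pool only when parallel compile is enabled and the
        pool's warm-up task has already completed."""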
        return (
            config.compile_threads > 1
            and self.process_pool().ready_future.done()  # type: ignore[union-attr]
        )

    def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"):
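        """Compile a Triton kernel, in a worker process (returning a TritonFuture)
        when the process pool is ready, otherwise synchronously in this process."""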
        kernel_code_log.info("Triton Kernel:\n%s", source_code)
        _compile_start()
        _set_triton_ptxas_path()

        kernel = TritonCodeCache.load(kernel_name, source_code)
        if self._use_process_pool():
            # We want to support changing these env vars after (and while) the
            # process pool is running, so pass them to the subprocess to reset.
            env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"]
            extra_env = {v: os.environ[v] for v in env_vars if v in os.environ}
            return TritonFuture(
                kernel,
                self.process_pool().submit(
                    _worker_compile_triton,
                    kernel._reload_in_subproc,
                    extra_env,
                ),
            )
        else:
            kernel.precompile()
            return kernel

    def multi_kernel(self, *args, **kwargs) -> Any:
        from torch._inductor.codegen.multi_kernel import MultiKernelCall

        # no need to call this in parallel since the sub-kernels are already parallel tasks
        return MultiKernelCall(*args, **kwargs)

    def cpp(self, source_code: str):
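        """Compile C++ kernel source, returning the kernel callable directly or a
        LambdaFuture resolving to it when compiling in parallel."""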
        kernel_code_log.info("CPP Kernel:\n%s", source_code)
        if config.compile_threads <= 1:
            return CppCodeCache.load(source_code).kernel
        else:
            get_result = CppCodeCache.load_async(source_code, submit_fn=self.submit)
            return LambdaFuture(lambda: get_result().kernel)

    def cpp_pybinding(self, argtypes: List[str], source_code: str):
        kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code)
        if config.compile_threads <= 1:
            return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code)
        else:
            get_result = CppPythonBindingsCodeCache.load_pybinding_async(
                argtypes, source_code, submit_fn=self.submit
            )
            return LambdaFuture(get_result)

    def cuda(self, source_code, dst_file_ext):
        kernel_code_log.info("CUDA Kernel:\n%s", source_code)

        def task():
            return CUDACodeCache.load(source_code, dst_file_ext)[0]

        return self.submit(task)

    def rocm(self, source_code, dst_file_ext):
        kernel_code_log.info("ROCm Kernel:\n%s", source_code)

        def task():
            return ROCmCodeCache.load(source_code, dst_file_ext)[0]

        return self.submit(task)

    def halide(self, meta: HalideMeta, source_code: str):
        kernel_code_log.info("Halide Kernel:\n%r\n%s", meta, source_code)
        if config.compile_threads <= 1:
            return HalideCodeCache.generate_halide(meta, source_code)
        else:
            get_result = HalideCodeCache.generate_halide_async(
                meta, source_code, submit_fn=self.submit
            )
            return LambdaFuture(get_result)

    def wait(self, scope: Dict[str, Any]) -> None:
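        """Wait for all Future/CodeCacheFuture values in scope to finish
        compiling, replacing each one in-place with its result."""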
        num_kernels = len(
            [
                value
                for key, value in scope.items()
                if isinstance(value, (Future, CodeCacheFuture))
            ]
        )
        pbar = tqdm(
            total=num_kernels,
            desc="Inductor Compilation",
            disable=config.disable_progress,
            delay=0,
        )
        if config.compile_threads > 1:
            for key, result in scope.items():
                if config.verbose_progress and not isinstance(pbar, _Faketqdm):
                    pbar.set_postfix_str(key)
                if isinstance(result, (Future, CodeCacheFuture)):
                    try:
                        scope[key] = result.result()
                    except BrokenProcessPool as e:
                        raise RuntimeError(
                            "A compilation subprocess exited unexpectedly. This "
                            "is likely due to a crash. To facilitate debugging, "
                            "you can re-run with TORCHINDUCTOR_COMPILE_THREADS=1 "
                            "to cause compilation to occur in the main process."
                        ) from e
                    pbar.update(1)

        _compile_end()


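# Warm up the compile-worker pool at import time, unless disabled via the
# TORCH_TNT_IN_USE / TORCH_WARM_POOL environment variables or Triton is absent.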
if (
    os.environ.get("TORCH_TNT_IN_USE", "0") != "1"
    and os.environ.get("TORCH_WARM_POOL", "1") == "1"
    # The subprocess pool is only used for the Triton backend
    and has_triton_package()
):
    AsyncCompile.warm_pool()
298