# mypy: allow-untyped-defs
from __future__ import annotations

import functools
import logging
import multiprocessing
import os
import sys
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from functools import partial
from time import time
from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING

import torch
from torch._dynamo.device_interface import get_registered_device_interfaces
from torch._inductor import config
from torch._inductor.codecache import (
    CodeCacheFuture,
    CppCodeCache,
    CppPythonBindingsCodeCache,
    CUDACodeCache,
    HalideCodeCache,
    LambdaFuture,
    ROCmCodeCache,
    TritonCodeCache,
    TritonFuture,
)
from torch._inductor.compile_worker.subproc_pool import (
    _warm_process_pool,
    AnyPool,
    SubprocPool,
)
from torch._inductor.compile_worker.watchdog import _async_compile_initializer
from torch._inductor.runtime.compile_tasks import (
    _set_triton_ptxas_path,
    _worker_compile_triton,
)
from torch.hub import _Faketqdm, tqdm
from torch.utils._triton import has_triton_package


if TYPE_CHECKING:
    from torch._inductor.runtime.hints import HalideMeta

# Timing metrics for time spent in compilation.
_cumulative_compile_time = 0.0
_t0: Optional[float] = None

kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code")


def pre_fork_setup():
    """
    Setup that must be done prior to forking with a process pool.
    """
    # ensure properties have been calculated before processes
    # are forked
    caching_device_properties()

    # Computing the triton key can be slow. If we call it before fork,
    # it will be cached for the forked subprocesses.
    try:
        from triton.compiler.compiler import triton_key

        triton_key()
    except ImportError:
        # Triton might not be installed or might be an old version.
        pass


def caching_device_properties():
    # Query properties for every available device so the results are cached
    # before worker processes fork and can be inherited by them.
    for _, device_interface in get_registered_device_interfaces():
        if device_interface.is_available():
            device_interface.Worker.get_device_properties()


def _compile_start() -> None:
    global _t0
    if _t0 is None:
        _t0 = time()


def _compile_end() -> None:
    global _cumulative_compile_time, _t0
    if _t0 is not None:
        t1 = time()
        _cumulative_compile_time += t1 - _t0
        _t0 = None
        # print("CUMULATIVE COMPILE TIME", _cumulative_compile_time)


_IS_WINDOWS = sys.platform == "win32"

log = logging.getLogger(__name__)


# Used to keep track of all process pools invoked so far.
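# shutdown_compile_workers() shuts each tracked pool down, and after_fork()
# clears the set so a forked child process starts with no stale pool handles.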
_pool_set: Set[AnyPool] = set()


def shutdown_compile_workers() -> None:
    """Shut down all outstanding compile-worker pools."""
    for pool in _pool_set:
        pool.shutdown()
    after_fork()


def after_fork():
    """Reset pools to their initial state without shutting them down."""
    _pool_set.clear()
    AsyncCompile.process_pool.cache_clear()


try:
    os.register_at_fork(after_in_child=after_fork)
except AttributeError:
    pass  # register_at_fork does not exist on Windows


class AsyncCompile:
    def __init__(self) -> None:
        pass

    @staticmethod
    @functools.lru_cache(1)
    def pool() -> ThreadPoolExecutor:
        assert config.compile_threads > 1
        return ThreadPoolExecutor(config.compile_threads)

    @staticmethod
    def _get_ready():
        """No-op function to help mark when the subprocess pool is ready."""
        return "ready"

    @staticmethod
    @functools.lru_cache(1)
    def process_pool() -> AnyPool:
        assert config.compile_threads > 1
        pool: AnyPool
        if config.worker_start_method == "subprocess":
            # SubprocPool wraps a ProcessPoolExecutor that forks inside a
            # subprocess we control.
            pool = SubprocPool(config.compile_threads)
        else:
            pre_fork_setup()
            ctx = multiprocessing.get_context(config.worker_start_method)
            pool = ProcessPoolExecutor(
                config.compile_threads,
                mp_context=ctx,
                initializer=partial(_async_compile_initializer, os.getpid()),
            )
            # When this pool is created inside a subprocess, the normal exit
            # handler doesn't run, so we need to register our own handler.
            # exitpriority has to be high, because another one of the finalizers
            # will kill the worker thread that sends the shutdown message to
            # the workers...
            multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize)

        # Set an attribute we can check to see if the pool is ready.
        pool.ready_future = pool.submit(AsyncCompile._get_ready)  # type: ignore[union-attr]
        _pool_set.add(pool)
        return pool

    @classmethod
    def warm_pool(cls) -> None:
        if config.compile_threads <= 1:
            return
        _compile_start()
        _warm_process_pool(cls.process_pool(), config.compile_threads)
        _compile_end()

    @classmethod
    def submit(cls, task: Callable[..., Any]) -> Any:
        if config.compile_threads <= 1:
            return task()
        return cls.pool().submit(task)

    def _use_process_pool(self):
        return (
            config.compile_threads > 1
            and self.process_pool().ready_future.done()  # type: ignore[union-attr]
        )

    def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"):
        kernel_code_log.info("Triton Kernel:\n%s", source_code)
        _compile_start()
        _set_triton_ptxas_path()

        kernel = TritonCodeCache.load(kernel_name, source_code)
        if self._use_process_pool():
            # We want to support changing these env vars after (and while) the
            # process pool is running, so pass them to the subprocess to reset.
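            # Note: only variables actually present in os.environ are
            # forwarded; anything unset here is left untouched in the worker.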
            env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"]
            extra_env = {v: os.environ[v] for v in env_vars if v in os.environ}
            return TritonFuture(
                kernel,
                self.process_pool().submit(
                    _worker_compile_triton,
                    kernel._reload_in_subproc,
                    extra_env,
                ),
            )
        else:
            kernel.precompile()
            return kernel

    def multi_kernel(self, *args, **kwargs) -> Any:
        from torch._inductor.codegen.multi_kernel import MultiKernelCall

        # No need to compile this in parallel; the sub-kernels are already
        # parallel tasks.
        return MultiKernelCall(*args, **kwargs)

    def cpp(self, source_code: str):
        kernel_code_log.info("CPP Kernel:\n%s", source_code)
        if config.compile_threads <= 1:
            return CppCodeCache.load(source_code).kernel
        else:
            get_result = CppCodeCache.load_async(source_code, submit_fn=self.submit)
            return LambdaFuture(lambda: get_result().kernel)

    def cpp_pybinding(self, argtypes: List[str], source_code: str):
        kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code)
        if config.compile_threads <= 1:
            return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code)
        else:
            get_result = CppPythonBindingsCodeCache.load_pybinding_async(
                argtypes, source_code, submit_fn=self.submit
            )
            return LambdaFuture(get_result)

    def cuda(self, source_code, dst_file_ext):
        kernel_code_log.info("CUDA Kernel:\n%s", source_code)

        def task():
            return CUDACodeCache.load(source_code, dst_file_ext)[0]

        return self.submit(task)

    def rocm(self, source_code, dst_file_ext):
        kernel_code_log.info("ROCm Kernel:\n%s", source_code)

        def task():
            return ROCmCodeCache.load(source_code, dst_file_ext)[0]

        return self.submit(task)

    def halide(self, meta: HalideMeta, source_code: str):
        kernel_code_log.info("Halide Kernel:\n%r\n%s", meta, source_code)
        if config.compile_threads <= 1:
            return HalideCodeCache.generate_halide(meta, source_code)
        else:
            get_result = HalideCodeCache.generate_halide_async(
                meta, source_code, submit_fn=self.submit
            )
            return LambdaFuture(get_result)

    def wait(self, scope: Dict[str, Any]) -> None:
        num_kernels = sum(
            1
            for value in scope.values()
            if isinstance(value, (Future, CodeCacheFuture))
        )
        pbar = tqdm(
            total=num_kernels,
            desc="Inductor Compilation",
            disable=config.disable_progress,
            delay=0,
        )
        if config.compile_threads > 1:
            for key, result in scope.items():
                if config.verbose_progress and not isinstance(pbar, _Faketqdm):
                    pbar.set_postfix_str(key)
                if isinstance(result, (Future, CodeCacheFuture)):
                    try:
                        scope[key] = result.result()
                    except BrokenProcessPool as e:
                        raise RuntimeError(
                            "A compilation subprocess exited unexpectedly. This "
                            "is likely due to a crash. To facilitate debugging, "
                            "you can re-run with TORCHINDUCTOR_COMPILE_THREADS=1 "
                            "to cause compilation to occur in the main process."
                        ) from e
                    pbar.update(1)

        _compile_end()


# Warm the worker pool unless TorchTNT is in use, warmup has been explicitly
# disabled, or Triton is unavailable.
if (
    os.environ.get("TORCH_TNT_IN_USE", "0") != "1"
    and os.environ.get("TORCH_WARM_POOL", "1") == "1"
    # The subprocess pool is only used for the Triton backend
    and has_triton_package()
):
    AsyncCompile.warm_pool()
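
# Typical usage, as emitted in Inductor-generated output code (an illustrative
# sketch; `triton_src` is a hypothetical stand-in for a real kernel string):
#
#     async_compile = AsyncCompile()
#     # With a warm process pool this returns a TritonFuture immediately;
#     # otherwise the kernel is precompiled eagerly in this process.
#     triton_fused_0 = async_compile.triton("triton_", triton_src, device_str="cuda")
#     # Block until every Future/CodeCacheFuture in the scope resolves in place.
#     async_compile.wait(globals())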