# mypy: allow-untyped-defs
from __future__ import annotations

import base64
import copyreg
import dataclasses
import functools
import hashlib
import importlib
import io
import json
import logging
import os
import pickle
import pkgutil
import platform
import re
import shlex
import shutil
import struct
import subprocess
import sys
import sysconfig
import tempfile
import textwrap
import threading
import warnings
from bisect import bisect_right
from copy import copy
from ctypes import c_void_p, cdll, CDLL
from functools import partial
from pathlib import Path
from time import time, time_ns
from types import ModuleType
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Generator,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    TYPE_CHECKING,
    Union,
)

import torch
from torch._dynamo.utils import counters, dynamo_timed
from torch._inductor import config, exc, metrics
from torch._inductor.codegen.cuda import cuda_env
from torch._inductor.runtime.compile_tasks import (
    _module_to_triton_kernel,
    _reload_python_module,
    _reload_python_module_in_subproc,
)
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.utils import ALIGN_BYTES, clear_on_fresh_inductor_cache, is_linux

from torch._logging import trace_structured
from torch._subclasses.fake_tensor import (
    extract_tensor_metadata,
    FakeTensor,
    TensorMetadata,
)
from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv

if TYPE_CHECKING:
    from concurrent.futures import Future

    from torch._inductor.graph import GraphLowering
    from torch._inductor.ir import ChoiceCaller
    from torch._inductor.runtime.hints import HalideMeta


_HERE = os.path.abspath(__file__)
_TORCH_PATH = os.path.dirname(os.path.dirname(_HERE))
_LINKER_SCRIPT = os.path.join(_TORCH_PATH, "_inductor/script.ld")

_IS_WINDOWS = sys.platform == "win32"

if config.is_fbcode():
    from triton.fb import build_paths
    from triton.fb.build import _run_build_command

    from torch._inductor.fb.utils import (
        log_global_cache_errors,
        log_global_cache_stats,
        log_global_cache_vals,
        use_global_cache,
    )
else:

    def log_global_cache_errors(*args, **kwargs):
        pass

    def log_global_cache_stats(*args, **kwargs):
        pass

    def log_global_cache_vals(*args, **kwargs):
        pass

    def use_global_cache() -> bool:
        return False


output_code_log = torch._logging.getArtifactLogger(__name__, "output_code")

LOCK_TIMEOUT = 600


log = logging.getLogger(__name__)


def cpp_wrapper_cache_dir(name: str) -> str:
    cu_str = (
        "cpu"
        if torch.version.cuda is None
        else f'cu{torch.version.cuda.replace(".", "")}'
    )
    python_version = f"py{sys.version_info.major}{sys.version_info.minor}"
    build_folder = f"{python_version}_{cu_str}"

    cpp_wrapper_dir = os.path.join(cache_dir(), build_folder)
    cpp_wrapper_build_directory = os.path.join(cpp_wrapper_dir, name)
    os.makedirs(cpp_wrapper_build_directory, exist_ok=True)
    return cpp_wrapper_build_directory


def get_cpp_wrapper_cubin_path_name():
    return "cubin_path" if torch.version.hip is None else "hsaco_path"


class CacheBase:
    @staticmethod
    @functools.lru_cache(None)
    def get_system() -> Dict[str, Any]:
        try:
            from triton.compiler.compiler import triton_key

            # Use triton_key instead of triton.__version__ as the version
            # is not updated with each code change
            triton_version = triton_key()
        except ModuleNotFoundError:
            triton_version = None

        try:
            system: Dict[str, Any] = {
                "device": {
                    "name": torch.cuda.get_device_properties(
                        torch.cuda.current_device()
                    ).name,
                },
                "version": {
                    "cuda": torch.version.cuda,
                    "triton": triton_version,
                },
            }
        except (AssertionError, RuntimeError):
            # If cuda is not installed, none of the above config is relevant.
            system = {}

        system["hash"] = hashlib.sha256(
            json.dumps(system, sort_keys=True).encode("utf-8")
        ).hexdigest()

        return system

    @staticmethod
    @clear_on_fresh_inductor_cache
    @functools.lru_cache(None)
    def get_local_cache_path() -> Path:
        return Path(os.path.join(cache_dir(), "cache", CacheBase.get_system()["hash"]))

    @staticmethod
    @functools.lru_cache(None)
    def get_global_cache_path() -> Optional[Path]:
        return (
            Path(os.path.join(config.global_cache_dir, CacheBase.get_system()["hash"]))
            if config.global_cache_dir is not None
            else None
        )

    def __init__(self) -> None:
        self.system = CacheBase.get_system()

    def get_local_cache(self) -> Dict[str, Any]:
        local_cache_path = self.get_local_cache_path()
        if not local_cache_path.is_file():
            return {}
        with open(local_cache_path) as local_cache_fp:
            local_cache = json.load(local_cache_fp)
        return local_cache["cache"]

    def update_local_cache(self, local_cache: Dict[str, Any]) -> None:
        local_cache_path = self.get_local_cache_path()
        write_atomic(
            str(local_cache_path),
            json.dumps({"system": self.system, "cache": local_cache}, indent=4),
            make_dirs=True,
        )


class LocalCache(CacheBase):
    def lookup(self, *keys: str) -> Optional[Dict[str, Any]]:
        cache = self.get_local_cache()

        sub_cache = cache
        for key in keys:
            if key in sub_cache:
                sub_cache = sub_cache[key]
            else:
                return None

        return sub_cache

    def set_value(self, *keys: str, value: Any) -> None:
        cache = self.get_local_cache()

        sub_cache = cache
        for key in keys[0:-1]:
            sub_cache.setdefault(key, {})
            sub_cache = sub_cache[key]
        sub_cache[keys[-1]] = value

        self.update_local_cache(cache)


class PersistentCache(CacheBase):
    @functools.lru_cache(None)  # noqa: B019
    def get_global_cache(self):
        global_cache_path = self.get_global_cache_path()
        if global_cache_path is None or not global_cache_path.is_file():
            return {}
        with open(global_cache_path) as global_cache_fp:
            global_cache = json.load(global_cache_fp)
        return global_cache["cache"]

    def lookup(
        self,
        choices: List[ChoiceCaller],
        op: str,
        inputs: str,
        benchmark: Optional[Callable[[Any], Dict[ChoiceCaller, float]]],
    ) -> Dict[ChoiceCaller, float]:
        """
        Check to see if we have benchmarked the given choice callers. For each
        choice caller:

            1. Check global_cache[op][inputs][precision][choice], return benchmark if cached.
            2. Check local_cache[op][inputs][precision][choice], return benchmark if cached.
            3. If benchmark is not None:
                a. `max_autotune_gemm=True`: benchmark the choice, update
                    local_cache[op][inputs][precision][choice], and return the benchmark.
                b. `max_autotune_gemm=False`: don't benchmark the choice, return nothing.
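
        As an illustration only (hypothetical op/input keys, choice hashes and
        timings), the nested cache structure consulted above looks like:

            {
                "mm": {
                    "(16, 32), (32, 64)": {
                        "highest": {"<choice hash>": 0.0123},
                    },
                },
            }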
        """
        precision = torch.get_float32_matmul_precision()

        log_stats = partial(log_global_cache_stats, self.system, op, inputs, precision)
        log_vals = partial(log_global_cache_vals, self.system, op, inputs, precision)
        log_errors = partial(
            log_global_cache_errors, self.system, op, inputs, precision
        )
        timings = {}

        def check_cache(cache, callback=None) -> bool:
            """Check if `cache` contains data for all the choices"""
            hit = True
            for choice in choices:
                choice_hash = choice.hash_key()
                if choice_hash in cache.get(op, {}).get(inputs, {}).get(precision, {}):
                    # cache hit
                    timings[choice] = cache[op][inputs][precision][choice_hash]
                else:
                    # cache miss
                    hit = False
                    break
            if callback:
                callback(cached=hit)
            return hit

        if config.max_autotune or config.max_autotune_gemm:
            local_cache = self.get_local_cache() if config.autotune_local_cache else {}
            # check local cache first since it is data specific to the current machine
            if (
                not check_cache(local_cache)
                and not (
                    use_global_cache()
                    and check_cache(self.get_global_cache(), callback=log_stats)
                )
                and benchmark is not None
            ):
                try:
                    # re-benchmark everything to try to get consistent numbers from the same machine
                    timings = benchmark(choices)
                    assert all(choice in timings for choice in choices)
                    local_cache.setdefault(op, {})
                    local_cache[op].setdefault(inputs, {}).setdefault(precision, {})
                    for choice, timing in timings.items():
                        local_cache[op][inputs][precision][choice.hash_key()] = timing
                except RuntimeError as e:
                    # catch and log autotuning failures
                    log_errors(e)
                    raise e

                self.update_local_cache(local_cache)

                timings_to_log = {
                    choice.hash_key(): timings[choice] for choice in choices
                }
                log_vals(timings_to_log)
        elif use_global_cache():
            # only check global cache, not local one
            check_cache(self.get_global_cache(), callback=log_stats)
            # may have a partial cache hit, where not everything is benchmarked

        return timings


def get_lock_dir() -> str:
    lock_dir = os.path.join(cache_dir(), "locks")
    if not os.path.exists(lock_dir):
        os.makedirs(lock_dir, exist_ok=True)
    return lock_dir


def sha256_hash(data: bytes) -> str:
    # [:51] to strip off the "Q====" suffix common to every hash value.
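    # The base32 encoding of a 32-byte digest is always 56 characters, so the
    # truncated result below is always 51 characters, e.g.
    # len(sha256_hash(b"")) == 51.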
    return base64.b32encode(hashlib.sha256(data).digest())[:51].decode("utf-8").lower()


def code_hash(code: Union[str, bytes], extra: str = ""):
    hashing_str = code if isinstance(code, bytes) else code.encode("utf-8")
    if extra != "":
        hashing_str = hashing_str + b"||" + extra.encode("utf-8")
    return "c" + sha256_hash(hashing_str)


def get_path(
    basename: str, extension: str, specified_dir: str = ""
) -> Tuple[str, str, str]:
    if specified_dir:
        if os.path.isabs(specified_dir):
            subdir = specified_dir
        else:
            subdir = os.path.join(cache_dir(), specified_dir)
    else:
        subdir = os.path.join(cache_dir(), basename[1:3])
    path = os.path.join(subdir, f"{basename}.{extension}")
    return basename, subdir, path


def get_hash(content: Union[str, bytes], extra: str = "", hash_type: str = "code"):
    if hash_type == "code":
        return code_hash(content, extra)
    if hash_type in ["cubin", "hsaco"]:
        return code_hash(repr(content))
    raise AssertionError(f"Unknown hash type {hash_type}")


def write(
    content: Union[str, bytes],
    extension: str,
    extra: str = "",
    hash_type: str = "code",
    specified_dir: str = "",
) -> Tuple[str, str]:
    # use stripped content to compute the hash so we don't end up with different
    # hashes just because the content begins/ends with different numbers of
    # spaces.
    key: str = get_hash(content.strip(), extra, hash_type)
    basename, subdir, path = get_path(key, extension, specified_dir)
    if not os.path.exists(path):
        write_atomic(path, content, make_dirs=True)
    return basename, path


def write_text(text: str) -> str:
    """
    Write the `text` to a file and return the path computed based on the hash.
    """
    return write(text, "txt")[1]


def write_atomic(
    path: str, content: Union[str, bytes], make_dirs: bool = False
) -> None:
    # Write into a temporary file first to avoid conflicts between threads.
    # Avoid using a named temporary file, as those have restricted permissions.
    assert isinstance(
        content, (str, bytes)
    ), "Only strings and byte arrays can be saved in the cache"
    path = Path(path)
    if make_dirs:
        path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp"
    write_mode = "w" if isinstance(content, str) else "wb"
    with tmp_path.open(write_mode) as f:
        f.write(content)
    tmp_path.rename(path)


@dataclasses.dataclass
class TensorMetadataAndValues:
    """
    TensorMetadata plus the elements as a list of raw values.
    Used for hashing inlined constants.
    """

    tensor_metadata: TensorMetadata
    values: List[Any]


def _ident(x: Any) -> Any:
    return x


def extract_tensor_metadata_for_cache_key(t):
    """
    Extracts the tensor metadata and removes fields of the TensorMetadata
    that are not needed for caching
    """
    meta = extract_tensor_metadata(t)
    if not hasattr(t, "_is_inductor_static"):
        meta = dataclasses.replace(meta, storage_offset=0, storage_bytes=None)
    return meta


def _reduce_fake_tensor(t):
    """
    See FxGraphCachePickler. Custom reducer to pickle FakeTensors.
    """
    metadata = extract_tensor_metadata_for_cache_key(t)
    return (_ident, (metadata,))


def _reduce_tensor(t):
    """
    See FxGraphCachePickler. Custom reducer to pickle Tensors.
    If we see tensors, we know they're constants stored as attributes on
    the GraphModule. Include the values in the key calculation. Small
    tensors will be inlined, so we can't serve the same cache entry for
    different values anyway. Large constants are treated as parameters,
    so we could conceivably reuse a cache entry. To do that, however,
    PyCodeCache would need more complexity to create a new module from its
    cache, but with the right constants attached as attributes.
    """
    if t.is_mkldnn:
        # TODO: These tensors don't currently pickle, so we can't cache a
        # compiled graph containing them. Just fail now. If mkldnn tensors
        # get pickling support, we can remove this.
        raise BypassFxGraphCache

    # Very large tensors could be expensive to copy to cpu and hash. Let's
    # at least report if we find slowness.
    start = time()
    values = t.tolist()
    elapsed = time() - start
    if elapsed > 1.0:
        warnings.warn(
            f"FX graph cache handling of a large constant took {elapsed:.1f}s. Please file an issue."
        )

    metadata = extract_tensor_metadata_for_cache_key(t)
    return (_ident, (TensorMetadataAndValues(metadata, values),))


def _reduce_symint(s):
    """
    See FxGraphCachePickler. Custom reducer to pickle SymInts.
    """
    # For hashing purposes, we only care about the name of the symbol and
    # not the backed value. We evaluate guards stored with a cached graph
    # to ensure a cached entity with SymInt args is safe to reuse.
    return (_ident, (str(s),))


def _reduce_unsupported(s):
    """
    See FxGraphCachePickler. Custom reducer to handle any objects that we don't
    support and therefore raise to bypass caching.
    """
    raise BypassFxGraphCache


class FxGraphCachePickler(pickle.Pickler):
    """
    Custom pickler to customize the pickling of some objects (Tensors), only for the
    purpose of computing a hash for keying into the FxGraphCache. Tensors contain
    objects that don't pickle and/or vary between runs, and we want to capture the
    data that allows us to compute a stable, but safe hash.
    """

    dispatch_table = copyreg.dispatch_table.copy()
    dispatch_table[FakeTensor] = _reduce_fake_tensor
    dispatch_table[torch.Tensor] = _reduce_tensor
    dispatch_table[torch.SymInt] = _reduce_symint
    dispatch_table[
        torch.fx.experimental._backward_state.BackwardState
    ] = _reduce_unsupported

    @classmethod
    def dumps(cls, obj) -> bytes:
        """
        Pickle an object using the FxGraphCachePickler.
        """
        with io.BytesIO() as stream:
            pickler = cls(stream)
            try:
                pickler.dump(obj)
            except (TypeError, AttributeError) as e:
                # Some config options are callables, e.g., post_grad_custom_pre_pass,
                # and may not pickle.
                log.warning("Can't pickle", exc_info=True)
                raise BypassFxGraphCache from e
            return stream.getvalue()

    @classmethod
    def get_hash(cls, obj: Any) -> str:
        """
        Serialize an object using the FxGraphCachePickler and return a hash
        of the pickled object.
        """
        serialized_data = cls.dumps(obj)
        return sha256_hash(serialized_data)

    @classmethod
    def debug_str(cls, inp: Any) -> str:
        """
        Get a printable string describing in more detail all the attributes
        comprising an object. Useful for debugging when one graph hashes
        to a different value than another.
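
        Each returned line has the form (illustrative):

            [<hash of the value>] <attr>[<index or key>]: <string form of the value>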
        """

        def get_str(obj) -> str:
            if isinstance(obj, torch.Tensor):
                return str(extract_tensor_metadata_for_cache_key(obj))
            elif isinstance(obj, bytes):
                return "<bytes>"
            else:
                return str(obj)

        lines = []
        for attr, obj in vars(inp).items():
            if isinstance(obj, list):
                for ii in range(len(obj)):
                    h = cls.get_hash(obj[ii])
                    lines.append(f"[{h}] {attr}[{ii}]: {get_str(obj[ii])}")
            elif isinstance(obj, dict):
                for k, v in obj.items():
                    h = cls.get_hash(v)
                    lines.append(f"[{h}] {attr}[{k}]: {get_str(v)}")
            else:
                h = cls.get_hash(obj)
                lines.append(f"[{h}] {attr}: {get_str(obj)}")
        return "\n".join(lines)


def build_code_hash(roots, prefix, hasher):
    for lib in sorted(pkgutil.iter_modules(roots, prefix), key=lambda x: x.name):
        spec = lib.module_finder.find_spec(lib.name, None)
        assert spec is not None
        module = spec.origin
        assert module is not None
        with open(module, "rb") as f:
            hasher.update(spec.name.encode("utf-8"))
            hasher.update(f.read())
        if lib.ispkg:
            # need to also hash submodules
            build_code_hash(spec.submodule_search_locations, f"{spec.name}.", hasher)


def get_code_hash(roots, extra_files=()):
    hasher = hashlib.sha256()
    hasher.update(torch.__version__.encode("utf-8"))
    build_code_hash(roots, "", hasher)
    for path in extra_files:
        if os.path.exists(path):
            with open(path, "rb") as f:
                hasher.update(f.read())
    return hasher.digest()


@functools.lru_cache(None)
def torch_key():
    """
    Compute a key that contains relevant information about torch source files
    """
    if not config.is_fbcode():
        inductor_root = os.path.dirname(__file__)
        extra_files = (
            "codegen/aoti_runtime/interface.cpp",
            "codegen/aoti_runtime/implementation.cpp",
            "codegen/cpp_prefix.h",
            "script.ld",
        )
        return get_code_hash(
            [inductor_root], [os.path.join(inductor_root, x) for x in extra_files]
        )

    from libfb.py import parutil

    return parutil.get_file_contents("torch/src_hash.txt").rstrip()


def get_inductor_root():
    return os.path.dirname(__file__)


@dataclasses.dataclass
class OrderedSetHolder:
    """
    See FxGraphHashDetails. Holds a sorted list to support stable hashing
    of set kwargs.
    """

    items: List[Any]


class BypassFxGraphCache(Exception):
    """
    Exception to indicate that the FxGraphCache should be bypassed.
    """

    pass


class FxGraphHashDetails:
    """
    Object to capture all the details for a compiled FX graph relevant to computing
    a safe and stable cache key.
    """

    # Excluded kwargs params that are not stable between runs
    EXCLUDED_KWARGS = ["graph_id"]

    def __init__(
        self,
        gm: torch.fx.GraphModule,
        example_inputs: List[torch.Tensor],
        fx_kwargs: Dict[str, Any],
        inputs_to_check: Sequence[int],
    ):
        self.gm = gm
        self.example_inputs = example_inputs

        # Order kwargs so hashing is stable to changes in kwarg order.
        self.fx_kwargs = {}
        for k in sorted(fx_kwargs):
            if k not in self.EXCLUDED_KWARGS:
                if type(fx_kwargs[k]) is set:
                    # Special case to handle set params. Python sets can't be
                    # ordered, so sort the elements and store them in a proxy.
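                    # e.g. a set kwarg {3, 1, 2} is stored as
                    # OrderedSetHolder(items=[1, 2, 3]) so it pickles deterministically.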
                    self.fx_kwargs[k] = OrderedSetHolder(sorted(fx_kwargs[k]))
                else:
                    self.fx_kwargs[k] = fx_kwargs[k]

        # Alignment checks
        self.inputs_to_check = inputs_to_check

        # 'Deterministic algorithms' can affect codegen via lowering to cuda kernels.
        self.deterministic_algorithms_settings = (
            torch.are_deterministic_algorithms_enabled(),
            torch.is_deterministic_algorithms_warn_only_enabled(),
            torch.utils.deterministic.fill_uninitialized_memory,  # type: ignore[attr-defined]
        )

        # Global settings affecting matmul codegen.
        self.cuda_matmul_settings = (
            torch.backends.cuda.matmul.allow_tf32,
            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction,
            torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction,
        )

        # Also hash on various system info (including the triton compiler version).
        self.torch_version = torch_key()
        self.system_info = CacheBase.get_system()
        self.inductor_config = config.save_config_portable()

    def debug_str(self) -> str:
        """
        Get a printable string describing in more detail all the attributes
        comprising this object. Useful for debugging when one graph hashes
        to a different value than another.
        """
        return FxGraphCachePickler.debug_str(self)


def compiled_fx_graph_hash(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    fx_kwargs: Dict[str, Any],
    inputs_to_check: Sequence[int],
) -> str:
    """
    Generate a unique hash of the FX graph for caching.
    """
    details = FxGraphHashDetails(gm, example_inputs, fx_kwargs, inputs_to_check)
    # The prefix distinguishes among the other kinds of objects we
    # cache in this module.
    key = "f" + FxGraphCachePickler.get_hash(details)
    debug_str = details.debug_str()
    log.debug(f"FX graph cache hash details for key {key}:\n{debug_str}")  # noqa: G004
    torch._logging.trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "fx_graph_cache_hash",
            "encoding": "json",
        },
        payload_fn=lambda: json.dumps(
            {"key": key, "components": debug_str.split("\n")}
        ),
    )

    return key


class FxGraphCache:
    """
    Supports caching and reusing compiled Fx graphs.

    The overall strategy is as follows:
    - This cache stores entries on disk. When saving an entry, we can't
      serialize callables (that could be C++, Triton, etc.), so we serialize
      their own disk cache location. We then recreate the compiled artifact
      after fetching from disk.
    - For indexing the cache, we gather the fields relevant to identifying an
      FxGraph (the graph module, graph inputs, system settings etc.) into an
      FxGraphCacheDetails object, pickle it, and compute a hash for the key.
      See FxGraphCachePickler.
    - Among the metadata we store, we also include a guards expression that's
      appropriate for validating any symbols for Tensor arguments that have
      symbolic bounds. On cache lookup then, we evaluate those guards in the
      current context to validate that a cached entry can be served.
    - A given graph could have multiple compiled versions, corresponding to
      different sets of guards. Therefore, we store cache entries in the form:
          <temp dir>/<fx graph hash>/<serialized metadata>
    - On lookup, we compute the key from the graph details, iterate over all
      leaf files in the corresponding subdirectory, deserialize the entry, and
      evaluate its guards expression. If the evaluation succeeds, we have a
      cache hit. If it fails, we compile the graph and store a new entry.
    - Finally, on a cache hit, we need to make sure any guards that would
      have been created during compilation are added to the current context.
    """

    # TODO(masnesral): Investigate whether it's beneficial to store compiled graphs
    # in an in-memory cache after loading from disk.
    @staticmethod
    def _get_tmp_dir() -> str:
        """
        Get the toplevel temporary directory for storing compiled graphs.
        """
        return os.path.join(cache_dir(), "fxgraph")

    @staticmethod
    def _get_tmp_dir_for_key(key: str) -> str:
        """
        Return the disk location for a given cache key.
        """
        return os.path.join(FxGraphCache._get_tmp_dir(), key[1:3], key)

    @staticmethod
    def _filter_backed_symints(inputs: List[Any]) -> List[torch.SymInt]:
        """
        Get the backed SymInt objects from the input list. Note that we can never
        have guards that depend on unbacked symint.
        """
        return [s for s in inputs if isinstance(s, torch.SymInt) and has_hint(s)]

    @staticmethod
    def _get_shape_env() -> Optional[ShapeEnv]:
        """
        Helper to get the shape env from the tracing context.
        """
        ctx = torch._guards.TracingContext.try_get()
        if not ctx:
            return None
        return ctx.fake_mode.shape_env

    @staticmethod
    def _lookup_graph(
        key: str,
        example_inputs: List[torch.Tensor],
        local,
        remote_cache,
    ) -> Optional[CompiledFxGraph]:
        """
        Lookup a compiled graph in the cache by key. On a hit, return the
        deserialized CompiledFxGraph object. On a miss, return None.
        """
        shape_env = FxGraphCache._get_shape_env()
        assert shape_env is not None

        symints = FxGraphCache._filter_backed_symints(example_inputs)
        hints = [hint_int(s) for s in symints]

        def iterate_over_candidates() -> Generator[CompiledFxGraph, None, None]:
            if local:
                subdir = FxGraphCache._get_tmp_dir_for_key(key)
                if os.path.exists(subdir):
                    for path in sorted(os.listdir(subdir)):
                        try:
                            with open(os.path.join(subdir, path), "rb") as f:
                                yield pickle.load(f)
                        except Exception:
                            log.warning(
                                "fx graph cache unable to load compiled graph",
                                exc_info=True,
                            )

            if remote_cache:
                try:
                    if (data := remote_cache.get(key)) is not None:
                        yield pickle.loads(data)
                except Exception:
                    log.warning(
                        "fx graph cache unable to load compiled graph", exc_info=True
                    )

        # Iterate over any entries in the subdir for this key and evaluate
        # their guards to determine whether there's a hit.
        graph = None

        for candidate in iterate_over_candidates():
            if not candidate.guards_expr:
                # No guards to evaluate, so this is a hit.
                graph = candidate
                break

            # Evaluate the guard expression in the current context.
            # If there's not a cache hit, we don't want the evaluation to
            # affect the current env, e.g., cause the creation of new guards,
            # so we evaluate with the hints instead of the symbols.
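            # For example (hypothetical guard): guards_expr "s0 <= 4096" with
            # hints [1024] evaluates to True, making the entry eligible to serve.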
            hit = bool(
                shape_env.evaluate_guards_expression(candidate.guards_expr, hints)
            )
            log.debug(
                "fx graph cache key %s evaluating guards [%s] with values %s => hit=%s",
                key,
                candidate.guards_expr,
                hints,
                hit,
            )
            if hit:
                graph = candidate
                break

        if graph is None:
            return None

        # See _save_graph(); we don't store the callable in the cache entry so
        # recreate it here from the PyCodeCache disk cache.
        artifact_path = get_path(graph.cache_key, "py")[2]
        if not os.path.exists(artifact_path):
            counters["inductor"]["fxgraph_lookup_write_file"] += 1
            Path(os.path.dirname(artifact_path)).mkdir(parents=True, exist_ok=True)
            code = graph.source_code
            cpp_pp = cpp_prefix_path()
            if os.path.basename(cpp_pp) in code:
                if cpp_pp in code:
                    # Great, the name is correct
                    pass
                else:
                    # Old dir name is included, replace it
                    pattern = rf'#include\s*"[^"]+{os.path.basename(cpp_pp)}"'
                    code = re.sub(pattern, f'#include "{cpp_pp}"', code)

            write_atomic(artifact_path, code, make_dirs=True)

        try:
            graph.current_callable = PyCodeCache.load_by_key_path(
                graph.cache_key,
                artifact_path,
                graph.cache_linemap,
                graph.constants,
            ).call
        except OSError:
            # Not expected, but in case the PyCodeCache entry is removed from
            # underneath us, treat it as a cache miss and recompile.
            log.error("Failed to load cached artifact: %s", artifact_path)
            return None

        # Now re-evaluate with the symints to add any guards to the current env.
        if graph.guards_expr:
            check = bool(
                shape_env.evaluate_guards_expression(graph.guards_expr, symints)
            )
            assert check is True
            log.debug(
                "fx graph cache key %s post-load guards: %s", key, shape_env.guards
            )

        # Increment the cached metrics by the amounts recorded when the FX
        # graph was compiled for this cache entry. Pretending these counters
        # were incremented normally is useful for testing with the cache enabled.
        metrics.CachedMetricsHelper.apply_deltas(graph.metrics_deltas)

        return graph

    @staticmethod
    def _save_graph(
        key: str,
        compiled_graph: CompiledFxGraph,
        example_inputs: List[torch.Tensor],
        time_taken_ns,
        local,
        remote_cache,
    ):
        """
        Store a serialized CompiledFxGraph on disk.
        """
        disk_compiled_graph = copy(compiled_graph)
        # We can't really serialize callables that may be C++/Triton/etc.,
        # so we serialize their PyCodeCache disk cache location instead.
        # TODO: This could be better if we're ever able to serialize compiled
        # models to disk.
        disk_compiled_graph.current_callable = None

        # Before serializing, compute the guard expression that will be used to
        # ensure that a CompiledFxGraph is valid when loaded from the cache. It's
        # sufficient to consider only the SymInt args to the fx graph since the
        # Tensor shapes are already captured in the hash for the cache key. Any
        # Tensor arg with a symbolic shape will have a SymInt arg for the graph.
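        # Rough sketch of what gets persisted (hypothetical expression): a graph
        # with one dynamic dimension s0 might store guards_expr "s0 <= 4096",
        # which _lookup_graph re-checks before serving the entry.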
        shape_env = FxGraphCache._get_shape_env()
        assert shape_env is not None
        symints = FxGraphCache._filter_backed_symints(example_inputs)
        guards = shape_env.get_pruned_guards(symints)
        disk_compiled_graph.guards_expr = shape_env.produce_guards_expression(
            placeholders=symints, guards=guards
        )

        try:
            content = pickle.dumps(disk_compiled_graph)
        except Exception:
            log.warning(
                "fx graph cache unable to serialize compiled graph", exc_info=True
            )
            counters["inductor"]["fxgraph_cache_pickle_error"] += 1
            return

        try:
            if local:
                subdir = FxGraphCache._get_tmp_dir_for_key(key)
                if not os.path.exists(subdir):
                    os.makedirs(subdir, exist_ok=True)

                # Use a hash of the serialized CompiledFxGraph to get a unique file
                # name. The specific name doesn't matter since a lookup involves
                # iterating over all entries in the parent subdir.
                path = os.path.join(subdir, sha256_hash(content))
                write_atomic(path, content, make_dirs=True)

            if remote_cache:
                cache_data = (
                    {
                        "data": content,
                        "time_taken_ms": time_taken_ns
                        // 1000000,  # Convert from NS to MS
                    }
                    if config.is_fbcode()
                    else content
                )
                remote_cache.put(key, cache_data)
        except Exception:
            log.warning("fx graph unable to write to cache", exc_info=True)
            counters["inductor"]["fxgraph_cache_write_error"] += 1

    @staticmethod
    def _check_can_cache(gm: torch.fx.GraphModule):
        """
        Check some conditions that would preclude caching and raise BypassFxGraphCache
        to bypass in case caching is not possible.
        """
        # Freezing can embed constants that wouldn't be static across runs.
        if config.freezing or config.aot_inductor.use_runtime_constant_folding:
            raise BypassFxGraphCache

        # The treatment of guards in the caching implementation requires that
        # we have a shape env.
        if FxGraphCache._get_shape_env() is None:
            log.debug("fx graph cache no shape env")
            raise BypassFxGraphCache

        # HigherOrderOperators should be handled on a case-by-case basis.
        # Currently, we just skip caching if we have any.
        # We also skip if there are any torchbind objects.
        for node in gm.graph.nodes:
            if isinstance(node.target, torch._ops.HigherOrderOperator):
                raise BypassFxGraphCache
            if node.op == "getattr" and isinstance(
                getattr(gm, node.target), torch._C.ScriptObject
            ):
                raise BypassFxGraphCache

    @staticmethod
    def load(
        compile_fx_fn: Callable[..., Any],
        gm: torch.fx.GraphModule,
        example_inputs: List[torch.Tensor],
        fx_kwargs: Dict[str, Any],
        inputs_to_check: Sequence[int],
        local: bool,
        remote: bool,
    ):
        """
        Load a compiled graph from the cache. If a cached entry does not exist,
        compile the graph and save it to the cache.
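
        Rough usage sketch (the compile callback and inputs shown here are
        illustrative, not prescribed by this module):

            compiled = FxGraphCache.load(
                compile_fx_fn,      # real compilation function used on a miss
                gm,                 # torch.fx.GraphModule to compile
                example_inputs,
                fx_kwargs={},
                inputs_to_check=(),
                local=True,
                remote=False,
            )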
        """
        assert local or remote, "at least one of them needs to be enabled"
        compiled_graph = None
        try:
            FxGraphCache._check_can_cache(gm)
            key = compiled_fx_graph_hash(gm, example_inputs, fx_kwargs, inputs_to_check)

            remote_cache = None
            if remote:
                cache_id = "fx-graph-v1"
                try:
                    if config.is_fbcode():
                        from triton.fb.fb_memcache import (
                            FbMemcacheRemoteFxGraphCacheBackend,
                        )

                        remote_cache = FbMemcacheRemoteFxGraphCacheBackend(cache_id)
                    else:
                        from torch._inductor.remote_cache import RedisRemoteCacheBackend

                        remote_cache = RedisRemoteCacheBackend(cache_id)
                except Exception:
                    remote_cache = None
                    log.warning("Unable to create a remote cache", exc_info=True)

            compiled_graph = FxGraphCache._lookup_graph(
                key, example_inputs, local, remote_cache
            )
            if compiled_graph is None:
                log.debug("fx graph cache miss for key %s", key)
                counters["inductor"]["fxgraph_cache_miss"] += 1
                start_time = time_ns()
                compiled_graph = compile_fx_fn(gm, example_inputs, **fx_kwargs)
                time_taken_ns = time_ns() - start_time
                FxGraphCache._save_graph(
                    key,
                    compiled_graph,
                    example_inputs,
                    time_taken_ns,
                    local,
                    remote_cache,
                )
            else:
                log.debug("fx graph cache hit for key %s", key)
                counters["inductor"]["fxgraph_cache_hit"] += 1
        except BypassFxGraphCache:
            counters["inductor"]["fxgraph_cache_bypass"] += 1
            if not compiled_graph:
                compiled_graph = compile_fx_fn(gm, example_inputs, **fx_kwargs)

        return compiled_graph

    @staticmethod
    def clear():
        """
        Clear out the on-disk cache.
        """
        try:
            shutil.rmtree(FxGraphCache._get_tmp_dir())
        except FileNotFoundError:
            pass


@dataclasses.dataclass
class CompiledFxGraph:
    """
    Class holding a compiled FX graph. This is the object serialized on disk
    to support FxGraph caching.
    """

    current_callable: Optional[Callable[..., Any]]
    cache_key: str
    source_code: str = dataclasses.field(repr=False)  # Do not display source_code
    cache_linemap: Optional[List[Tuple[int, str]]]
    device_types: Set[str]
    device_idxs: Set[int]
    mutated_inputs: Set[str]
    mutated_input_idxs: Set[int]
    constants: Dict[str, torch.Tensor]
    torchbind_constants: Dict[str, torch._C.ScriptObject]
    output_strides: Optional[List[Optional[Tuple[int, ...]]]]
    disabled_cudagraphs_reason: Optional[str]
    metrics_deltas: metrics.CachedMetricsDeltas
    # This is a string representation of an expression we serialize
    # with the object so the guards can be evaluated in a different
    # context in order to verify the validity of serving a cached
    # fx graph. The expression must be generated by:
    # ShapeEnv.produce_guards_expression()
    guards_expr: Optional[str]

    _boxed_call: Optional[bool] = None

    def __init__(
        self,
        current_callable: Optional[Callable[..., Any]],
        graph: GraphLowering,
        output_strides: List[Optional[Tuple[int, ...]]],
        disabled_cudagraphs_reason: Optional[str],
        metrics_deltas: metrics.CachedMetricsDeltas,
    ):
        self.current_callable = current_callable
        self.cache_key = graph.cache_key
        if graph.cache_path:
            with open(graph.cache_path) as f:
                self.source_code = f.read()
        self.cache_linemap = graph.cache_linemap
        self.device_types = graph.device_types
        self.device_idxs = graph.device_idxs
        self.mutated_inputs = graph.mutated_inputs
        self.mutated_input_idxs = set(graph.mutated_input_idxs)
        self.constants = graph.constants
        self.torchbind_constants = graph.torchbind_constants
        self.output_strides = output_strides
        self.disabled_cudagraphs_reason = disabled_cudagraphs_reason
        self.metrics_deltas = metrics_deltas
        self.guards_expr = None

    def __call__(self, inputs: List[Any]) -> Any:
        assert self.current_callable is not None
        return self.current_callable(inputs)


def cpp_compiler() -> str:
    if config.is_fbcode():
        return build_paths.cc() if torch.version.hip is None else build_paths.clang()
    if isinstance(config.cpp.cxx, (list, tuple)):
        search = tuple(config.cpp.cxx)
    else:
        search = (config.cpp.cxx,)
    return cpp_compiler_search(search)


@functools.lru_cache(1)
def cpp_compiler_search(search: str) -> str:
    for cxx in search:
        try:
            if cxx is None:
                # gxx package is only available for Linux
                # according to https://anaconda.org/conda-forge/gxx/
                if sys.platform != "linux":
                    continue
                # Do not install GXX by default
                if not os.getenv("TORCH_INDUCTOR_INSTALL_GXX"):
                    continue
                from filelock import FileLock

                lock_dir = get_lock_dir()
                lock = FileLock(
                    os.path.join(lock_dir, "g++.lock"), timeout=LOCK_TIMEOUT
                )
                with lock:
                    cxx = install_gcc_via_conda()
            subprocess.check_output([cxx, "--version"])
            return cxx
        except (subprocess.SubprocessError, FileNotFoundError, ImportError):
            continue
    raise exc.InvalidCxxCompiler


def install_gcc_via_conda() -> str:
    """On older systems, this is a quick way to get a modern compiler"""
    prefix = os.path.join(cache_dir(), "gcc")
    cxx_path = os.path.join(prefix, "bin", "g++")
    if not os.path.exists(cxx_path):
        log.info("Downloading GCC via conda")
        conda = os.environ.get("CONDA_EXE", "conda")
        if conda is None:
            conda = shutil.which("conda")
        if conda is not None:
            subprocess.check_call(
                [
                    conda,
                    "create",
                    f"--prefix={prefix}",
                    "--channel=conda-forge",
                    "--quiet",
                    "-y",
                    "python=3.8",
                    "gxx",
                ],
                stdout=subprocess.PIPE,
            )
    return cxx_path


def is_gcc() -> bool:
    if sys.platform == "darwin" and is_apple_clang():
        return False
    return bool(re.search(r"(gcc|g\+\+)", cpp_compiler()))


@functools.lru_cache(None)
def is_apple_clang() -> bool:
    cxx = cpp_compiler()
    version_string = subprocess.check_output([cxx, "--version"]).decode("utf8")
    return "Apple" in version_string.splitlines()[0]


def is_clang() -> bool:
    # On macOS, Apple Clang may be named gcc, so check the compiler info.
    if sys.platform == "darwin":
        return is_apple_clang()
    return bool(re.search(r"(clang|clang\+\+)", cpp_compiler()))


def get_compiler_version_info(compiler):
    SUBPROCESS_DECODE_ARGS = ("oem",) if _IS_WINDOWS else ()
    env = os.environ.copy()
    env["LC_ALL"] = "C"  # Don't localize output
    try:
        version_string = subprocess.check_output(
            [compiler, "-v"], stderr=subprocess.STDOUT, env=env
        ).decode(*SUBPROCESS_DECODE_ARGS)
    except Exception as e:
        try:
            version_string = subprocess.check_output(
                [compiler, "--version"], stderr=subprocess.STDOUT, env=env
            ).decode(*SUBPROCESS_DECODE_ARGS)
        except Exception as e:
            return ""
    # Collapse multiple lines into a one-line string.
    version_string = version_string.replace("\r", "_")
    version_string = version_string.replace("\n", "_")
    return version_string


def _get_isa_dry_compile_fingerprint(isa_flags: str) -> str:
    # The ISA dry compile adds about 1 second to each startup.
    # See the issue: https://github.com/pytorch/pytorch/issues/100378
    # The dry compile checks whether the compiler can actually build for the ISA.
    # We record the compiler version, the ISA options, and the PyTorch version,
    # and fold them into the output binary's hash path, so an existing binary
    # can be reused and the compile skipped.
    compiler_info = get_compiler_version_info(cpp_compiler())
    torch_version = torch.__version__
    fingerprint = f"{compiler_info}={isa_flags}={torch_version}"
    return fingerprint


class VecISA:
    _bit_width: int
    _macro: List[str]
    _arch_flags: str
    _dtype_nelements: Dict[torch.dtype, int]

    # Note [Checking for Vectorized Support in Inductor]
    # TorchInductor CPU vectorization reuses PyTorch vectorization utility functions.
    # Hence, TorchInductor depends on Sleef* to accelerate mathematical functions
    # such as exp, pow, sin, and cos.
    # But PyTorch and TorchInductor might use different compilers to build code. If
    # PyTorch uses gcc-7/g++-7 to build the release package, the libtorch_cpu.so
    # will not expose the Sleef* AVX512 symbols since gcc-7/g++-7 cannot pass the
    # avx512 check in CMake - FindAVX.cmake. But TorchInductor installs the latest
    # gcc/g++ compiler by default, which does support AVX512 compilation. Therefore,
    # the Sleef versions used by PyTorch and TorchInductor could conflict. Hence,
    # we dry-compile the following code to check whether the current HW platform
    # and PyTorch both support AVX512 or AVX2. ARM presumably needs the same logic.
    # In fbcode however, we are using the same compiler for pytorch and for inductor codegen,
    # making the runtime check unnecessary.
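    # A minimal sketch of how the dry-compile check is consumed (assuming an
    # AVX2-capable x86 host; the concrete ISA object is illustrative):
    #
    #   isa = VecAVX2()
    #   if isa:  # first use triggers a dry compile of _avx_code below
    #       arch_flags = isa.build_arch_flags()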
    _avx_code = """
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON)
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#endif

alignas(64) float in_out_ptr0[16] = {0.0};

extern "C" void __avx_chk_kernel() {
    auto tmp0 = at::vec::Vectorized<float>(1);
    auto tmp1 = tmp0.exp();
    tmp1.store(in_out_ptr0);
}
"""  # noqa: B950

    _avx_py_load = """
import torch
from ctypes import cdll
cdll.LoadLibrary("__lib_path__")
"""

    def bit_width(self) -> int:
        return self._bit_width

    def nelements(self, dtype: torch.dtype = torch.float) -> int:
        return self._dtype_nelements[dtype]

    def build_macro(self) -> List[str]:
        return self._macro

    def build_arch_flags(self) -> str:
        return self._arch_flags

    def __hash__(self) -> int:
        return hash(str(self))

    @functools.lru_cache(None)  # noqa: B019
    def __bool__(self) -> bool:
        from torch._inductor.cpp_builder import CppBuilder, CppTorchOptions

        if config.cpp.vec_isa_ok is not None:
            return config.cpp.vec_isa_ok

        if config.is_fbcode():
            return True

        key, input_path = write(
            VecISA._avx_code,
            "cpp",
            extra=_get_isa_dry_compile_fingerprint(self._arch_flags),
        )
        from filelock import FileLock

        lock_dir = get_lock_dir()
        lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
        with lock:
            output_dir = os.path.dirname(input_path)
            build_options = CppTorchOptions(vec_isa=self, warning_all=False)
            x86_isa_help_builder = CppBuilder(
                key,
                [input_path],
                build_options,
                output_dir,
            )
            try:
                # Check if the output file exists, and compile if not.
                output_path = x86_isa_help_builder.get_target_file_path()
                if not os.path.isfile(output_path):
                    status, target_file = x86_isa_help_builder.build()
                    if status:
                        return False

                # Check build result
                subprocess.check_call(
                    [
                        sys.executable,
                        "-c",
                        VecISA._avx_py_load.replace("__lib_path__", output_path),
                    ],
                    stderr=subprocess.DEVNULL,
                    env={**os.environ, "PYTHONPATH": ":".join(sys.path)},
                )
            except Exception as e:
                return False

            return True


@dataclasses.dataclass
class VecNEON(VecISA):
    _bit_width = 256  # This is required to leverage the compute implemented in aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h
    _macro = ["CPU_CAPABILITY_NEON"]
    if sys.platform == "darwin" and platform.processor() == "arm":
        _macro.append("AT_BUILD_ARM_VEC256_WITH_SLEEF")
    _arch_flags = ""  # Unused
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}

    def __str__(self) -> str:
        return "asimd"  # detects the presence of advanced SIMD on armv8-a kernels

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


@dataclasses.dataclass
class VecAVX512(VecISA):
    _bit_width = 512
    _macro = ["CPU_CAPABILITY_AVX512"]
    _arch_flags = (
        "-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma"
        if not _IS_WINDOWS
        else "/arch:AVX512"
    )  # TODO: use cflags
    _dtype_nelements = {torch.float: 16, torch.bfloat16: 32, torch.float16: 32}

    def __str__(self) -> str:
        return "avx512"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


@dataclasses.dataclass
class VecAVX2(VecISA):
    _bit_width = 256
    _macro = ["CPU_CAPABILITY_AVX2"]
    _arch_flags = (
        "-mavx2 -mfma" if not _IS_WINDOWS else "/arch:AVX2"
    )  # TODO: use cflags
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}

    def __str__(self) -> str:
        return "avx2"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


@dataclasses.dataclass
class VecZVECTOR(VecISA):
    _bit_width = 256
    _macro = [
        "CPU_CAPABILITY_ZVECTOR",
        "CPU_CAPABILITY=ZVECTOR",
        "HAVE_ZVECTOR_CPU_DEFINITION",
    ]
    _arch_flags = "-mvx -mzvector"
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}

    def __str__(self) -> str:
        return "zvector"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


class InvalidVecISA(VecISA):
    _bit_width = 0
    _macro = [""]
    _arch_flags = ""
    _dtype_nelements = {}

    def __str__(self) -> str:
        return "INVALID_VEC_ISA"

    def __bool__(self) -> bool:  # type: ignore[override]
        return False

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


def x86_isa_checker() -> List[str]:
    supported_isa: List[str] = []

    def _check_and_append_supported_isa(
        dest: List[str], isa_supported: bool, isa_name: str
    ):
        if isa_supported:
            dest.append(isa_name)

    Arch = platform.machine()
    """
    Arch value is x86_64 on Linux, and the value is AMD64 on Windows.
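    For example, a machine that supports both extensions makes this helper
    return ["avx2", "avx512"], in the order appended below.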
    """
    if Arch != "x86_64" and Arch != "AMD64":
        return supported_isa

    avx2 = torch.cpu._is_cpu_support_avx2()
    avx512 = torch.cpu._is_cpu_support_avx512()

    _check_and_append_supported_isa(supported_isa, avx2, "avx2")
    _check_and_append_supported_isa(supported_isa, avx512, "avx512")

    return supported_isa


invalid_vec_isa = InvalidVecISA()
supported_vec_isa_list = [VecAVX512(), VecAVX2(), VecNEON()]


# Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content
# might have too much redundant content that is useless for ISA check. Hence,
# we only cache some key isa information.
@functools.lru_cache(None)
def valid_vec_isa_list() -> List[VecISA]:
    isa_list: List[VecISA] = []
    if sys.platform == "darwin" and platform.processor() == "arm":
        isa_list.append(VecNEON())

    if sys.platform not in ["linux", "win32"]:
        return isa_list

    if platform.machine() == "s390x":
        with open("/proc/cpuinfo") as _cpu_info:
            while True:
                line = _cpu_info.readline()
                if not line:
                    break
                # process line
                featuresmatch = re.match(r"^features\s*:\s*(.*)$", line)
                if featuresmatch:
                    for group in featuresmatch.groups():
                        if re.search(r"[\^ ]+vxe[\$ ]+", group):
                            isa_list.append(VecZVECTOR())
                            break
    elif platform.machine() == "aarch64":
        isa_list.append(VecNEON())
    elif platform.machine() in ["x86_64", "AMD64"]:
        """
        platform.machine() value is x86_64 on Linux, and the value is AMD64 on Windows.
        """
        _cpu_supported_x86_isa = x86_isa_checker()
        for isa in supported_vec_isa_list:
            if str(isa) in _cpu_supported_x86_isa and isa:
                isa_list.append(isa)

    return isa_list


def pick_vec_isa() -> VecISA:
    if config.is_fbcode():
        return VecAVX2()

    _valid_vec_isa_list: List[VecISA] = valid_vec_isa_list()
    if not _valid_vec_isa_list:
        return invalid_vec_isa

    # If simdlen is None, determine the vectorization length automatically.
    if config.cpp.simdlen is None:
        assert _valid_vec_isa_list
        return _valid_vec_isa_list[0]

    for isa in _valid_vec_isa_list:
        if config.cpp.simdlen == isa.bit_width():
            return isa

    return invalid_vec_isa


def get_compile_only(compile_only: bool = True) -> str:
    return "-c" if compile_only else ""


def get_shared(shared: bool = True, compile_only: bool = False) -> str:
    if not shared:
        return ""
    if compile_only:
        return "-fPIC"
    if platform.system() == "Darwin" and "clang" in cpp_compiler():
        # This causes undefined symbols to behave the same as linux
        return "-shared -fPIC -undefined dynamic_lookup"
    else:
        return "-shared -fPIC"


def get_warning_all_flag(warning_all: bool = True) -> str:
    return "-Wall" if warning_all else ""


def get_glibcxx_abi_build_flags() -> str:
    return "-D_GLIBCXX_USE_CXX11_ABI=" + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))


def cpp_flags() -> str:
    flags = ["-std=c++17", "-Wno-unused-variable", "-Wno-unknown-pragmas"]
    if is_clang():
        flags.append("-Werror=ignored-optimization-argument")
    return " ".join(flags)


def cpp_wrapper_flags() -> str:
    return "-D TORCH_INDUCTOR_CPP_WRAPPER"


def optimization_flags() -> str:
    base_flags = "-O0 -g" if config.aot_inductor.debug_compile else "-O3 -DNDEBUG"
    base_flags += " -ffast-math -fno-finite-math-only"
    if not config.cpp.enable_unsafe_math_opt_flag:
        base_flags += " -fno-unsafe-math-optimizations"
    if not config.cpp.enable_floating_point_contract_flag:
        base_flags += " -ffp-contract=off"

    if config.is_fbcode():
        # FIXME: passing `-fopenmp` adds libgomp.so to the generated shared library's dependencies.
        # This causes `ldopen` to fail in fbcode, because libgomp does not exist in the default paths.
        # We will fix it later by exposing the lib path.
        return base_flags

    if sys.platform == "darwin":
        # Per https://mac.r-project.org/openmp/ the right way to pass `openmp` flags on macOS is via `-Xclang`.
        # Also, `-march=native` is an unrecognized option on M1.
        base_flags += " -Xclang"
    else:
        if platform.machine() == "ppc64le":
            base_flags += " -mcpu=native"
        else:
            base_flags += " -march=native"

    # Internal cannot find libgomp.so
    if not config.is_fbcode():
        base_flags += " -fopenmp"
    return base_flags


def use_custom_generated_macros() -> str:
    return "-D C10_USING_CUSTOM_GENERATED_MACROS"


def use_fb_internal_macros() -> str:
    if config.is_fbcode():
        # TODO: this is to avoid FC breakage for fbcode. When using a newly
        # generated model.so on an older version of PyTorch, we need to use
        # the v1 version of aoti_torch_create_tensor_from_blob
        create_tensor_from_blob_v1 = "-D AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1"
        openmp_lib = build_paths.openmp_lib()
        preprocessor_flags = " ".join(
            (
                "-D C10_USE_GLOG",
                "-D C10_USE_MINIMAL_GLOG",
                "-D C10_DISABLE_TENSORIMPL_EXTENSIBILITY",
            )
        )
        return f"-Wp,-fopenmp {openmp_lib} {preprocessor_flags} {create_tensor_from_blob_v1}"
    else:
        return ""


def use_standard_sys_dir_headers() -> str:
    if config.is_fbcode():
        return "-nostdinc"
    else:
        return ""


@functools.lru_cache(None)
def is_conda_llvm_openmp_installed() -> bool:
    try:
        command = "conda list llvm-openmp --json"
        output = subprocess.check_output(command.split()).decode("utf8")
        return len(json.loads(output)) > 0
    except subprocess.SubprocessError:
        return False


@functools.lru_cache(None)
def homebrew_libomp() -> Tuple[bool, str]:
    try:
        # check if `brew` is installed
        subprocess.check_output(["which", "brew"])
        # get the location of `libomp` if it is installed
        # this is the location that `libomp` **would** be installed
        # see https://github.com/Homebrew/brew/issues/10261#issuecomment-756563567 for details
        libomp_path = (
            subprocess.check_output(["brew", "--prefix", "libomp"])
            .decode("utf8")
            .strip()
        )
        # check if `libomp` is installed
        omp_available = os.path.exists(libomp_path)
        return omp_available, libomp_path
    except subprocess.SubprocessError:
        return False, ""


def _set_gpu_runtime_env() -> None:
    if (
        config.is_fbcode()
        and torch.version.hip is None
        and "CUDA_HOME" not in os.environ
        and "CUDA_PATH" not in os.environ
    ):
        os.environ["CUDA_HOME"] = build_paths.cuda()


def _get_python_include_dirs():
    include_dir = Path(sysconfig.get_path("include"))
    # On Darwin, the Python executable from a framework can return a
    # non-existing /Library/Python/... include path, in which case
    # one should use the Headers folder from the framework.
    if not include_dir.exists() and platform.system() == "Darwin":
        std_lib = Path(sysconfig.get_path("stdlib"))
        include_dir = (std_lib.parent.parent / "Headers").absolute()
    if not (include_dir / "Python.h").exists():
        warnings.warn(f"Can't find Python.h in {str(include_dir)}")
    return [str(include_dir)]


def _transform_cuda_paths(lpaths):
    # This handles two cases:
    # 1. Meta internal cuda-12 where libs are in lib/cuda-12 and lib/cuda-12/stubs
    # 2. Linux machines may have CUDA installed under either lib64/ or lib/
    for i, path in enumerate(lpaths):
        if (
            "CUDA_HOME" in os.environ
            and path.startswith(os.environ["CUDA_HOME"])
            and not os.path.exists(f"{path}/libcudart_static.a")
        ):
            for root, dirs, files in os.walk(path):
                if "libcudart_static.a" in files:
                    lpaths[i] = os.path.join(path, root)
                    lpaths.append(os.path.join(lpaths[i], "stubs"))
                    break


def get_include_and_linking_paths(
    include_pytorch: bool = False,
    vec_isa: VecISA = invalid_vec_isa,
    cuda: bool = False,
    aot_mode: bool = False,
) -> Tuple[List[str], str, str, str, str]:
    _set_gpu_runtime_env()
    from torch.utils import cpp_extension

    # Remove below in the future
    # macros = "-D {}".format(vec_isa.build_macro()) if vec_isa != invalid_vec_isa else ""
    macros = ""
    if vec_isa != invalid_vec_isa:
        for x in vec_isa.build_macro():
            macros_def = f"-D {x} "
            macros += macros_def

    build_arch_flags = ""
    if sys.platform == "linux" and (
        include_pytorch
        or vec_isa != invalid_vec_isa
        or cuda
        or config.cpp.enable_kernel_profile
    ):
        # Note - We include pytorch only on linux right now. There is more work
        # to do to enable OMP build on darwin where PyTorch is built with IOMP
        # and we need a way to link to what PyTorch links.
        ipaths = cpp_extension.include_paths(cuda) + _get_python_include_dirs()
        lpaths = cpp_extension.library_paths(cuda) + [
            sysconfig.get_config_var("LIBDIR")
        ]

        libs = []

        # No need to manually specify libraries in fbcode.
        if not config.is_fbcode():
            libs += ["torch", "torch_cpu"]
            libs += ["gomp"]
            if not aot_mode:
                libs += ["torch_python"]
        else:
            # internal remote execution is able to find omp, but not gomp
            libs += ["omp"]
        if aot_mode:
            ipaths += [os.path.dirname(cpp_prefix_path())]
            if cuda and torch.version.hip is None:
                _transform_cuda_paths(lpaths)
        if macros:
            if config.is_fbcode() and vec_isa != invalid_vec_isa:
                cap = str(vec_isa).upper()
                macros = " ".join(
                    [
                        vec_isa.build_arch_flags(),
                        f"-D CPU_CAPABILITY={cap}",
                        f"-D CPU_CAPABILITY_{cap}",
                        f"-D HAVE_{cap}_CPU_DEFINITION",
                    ]
                )

        if cuda:
            if macros is None:
                macros = ""
            macros += " -D USE_ROCM" if torch.version.hip else " -D USE_CUDA"

        if cuda:
            if torch.version.hip is not None:
                if config.is_fbcode():
                    libs += ["amdhip64"]
                else:
                    libs += ["c10_hip", "torch_hip"]
                macros += " -D __HIP_PLATFORM_AMD__"
            else:
                if config.is_fbcode():
                    libs += ["cuda"]
                else:
                    libs += ["c10_cuda", "cuda", "torch_cuda"]
        build_arch_flags = vec_isa.build_arch_flags()
    else:
        # Note - this is effectively a header-only inclusion. Using some headers may result in
Using some header files may result in a 1771 # "symbol not found" error if those header files require a library. 1772 # For those cases, include the library paths and libs as we do for pytorch above. 1773 # This approach allows us to only pay for what we use. 1774 ipaths = cpp_extension.include_paths(cuda) + _get_python_include_dirs() 1775 if aot_mode: 1776 ipaths += [os.path.dirname(cpp_prefix_path())] 1777 lpaths = [] 1778 if sys.platform == "darwin": 1779 # only the Apple builtin compiler (Apple Clang++) requires a separately installed OpenMP 1780 omp_available = not is_apple_clang() 1781 1782 # check the `OMP_PREFIX` environment variable first 1783 if os.getenv("OMP_PREFIX") is not None: 1784 header_path = os.path.join(os.getenv("OMP_PREFIX"), "include", "omp.h") # type: ignore[arg-type] 1785 valid_env = os.path.exists(header_path) 1786 if valid_env: 1787 ipaths.append(os.path.join(os.getenv("OMP_PREFIX"), "include")) # type: ignore[arg-type] 1788 lpaths.append(os.path.join(os.getenv("OMP_PREFIX"), "lib")) # type: ignore[arg-type] 1789 else: 1790 warnings.warn("environment variable `OMP_PREFIX` is invalid.") 1791 omp_available = omp_available or valid_env 1792 1793 libs = [] if omp_available else ["omp"] 1794 1795 # prefer to use openmp from `conda install llvm-openmp` 1796 if not omp_available and os.getenv("CONDA_PREFIX") is not None: 1797 omp_available = is_conda_llvm_openmp_installed() 1798 if omp_available: 1799 conda_lib_path = os.path.join(os.getenv("CONDA_PREFIX"), "lib") # type: ignore[arg-type] 1800 ipaths.append(os.path.join(os.getenv("CONDA_PREFIX"), "include")) # type: ignore[arg-type] 1801 lpaths.append(conda_lib_path) 1802 # Prefer Intel OpenMP on x86 machines 1803 if os.uname().machine == "x86_64" and os.path.exists( 1804 os.path.join(conda_lib_path, "libiomp5.dylib") 1805 ): 1806 libs = ["iomp5"] 1807 1808 # next, try to use openmp from `brew install libomp` 1809 if not omp_available: 1810 omp_available, libomp_path = homebrew_libomp() 1811 if omp_available: 1812 ipaths.append(os.path.join(libomp_path, "include")) 1813 lpaths.append(os.path.join(libomp_path, "lib")) 1814 1815 # if openmp is still not available, let the compiler have a try, 1816 # and raise an error together with instructions if compilation fails later 1817 else: 1818 libs = ["omp"] if config.is_fbcode() else ["gomp"] 1819 1820 # For AOT mode, the produced library relies on torch cpu to set grad mode 1821 # like aoti_torch_grad_mode_set_enabled 1822 if aot_mode and sys.platform == "linux" and not config.is_fbcode(): 1823 libs += ["torch", "torch_cpu"] 1824 1825 # Unconditionally import c10 for non-abi-compatible mode to use TORCH_CHECK - See PyTorch #108690 1826 if not config.abi_compatible: 1827 libs += ["c10"] 1828 lpaths += [cpp_extension.TORCH_LIB_PATH] 1829 1830 # third party libs 1831 if config.is_fbcode(): 1832 # Note that the order of include paths does matter, as a result 1833 # we need to have several branches interleaved here 1834 if torch.version.hip is None: 1835 ipaths.append(build_paths.sleef()) 1836 ipaths.append(build_paths.openmp()) 1837 ipaths.append(build_paths.python()) 1838 if torch.version.hip is not None: 1839 ipaths.append(build_paths.clang_include()) 1840 ipaths.append(build_paths.gcc_include()) 1841 ipaths.append(build_paths.gcc_install_tools_include()) 1842 else: 1843 ipaths.append(build_paths.cc_include()) 1844 ipaths.append(build_paths.libgcc()) 1845 ipaths.append(build_paths.libgcc_arch()) 1846 ipaths.append(build_paths.libgcc_backward()) 1847 ipaths.append(build_paths.glibc()) 1848
ipaths.append(build_paths.linux_kernel()) 1849 if torch.version.hip is not None: 1850 ipaths.append(build_paths.rocm()) 1851 else: 1852 ipaths.append(os.path.join(build_paths.cuda(), "include")) 1853 # We also need to bundle includes with absolute paths into a remote directory 1854 # (later on, we copy the include paths from cpp_extensions into our remote dir) 1855 ipaths.append("include") 1856 1857 static_link_libs = [] 1858 if aot_mode and cuda and config.is_fbcode(): 1859 # For Meta internal cuda-12, it is recommended to static link cudart 1860 if torch.version.hip is None: 1861 static_link_libs = ["-Wl,-Bstatic", "-lcudart_static", "-Wl,-Bdynamic"] 1862 1863 lpaths_str = " ".join(["-L" + p for p in lpaths]) 1864 libs_str = " ".join(static_link_libs + ["-l" + p for p in libs]) 1865 return ipaths, lpaths_str, libs_str, macros, build_arch_flags 1866 1867 1868def cpp_compile_command( 1869 input: Union[str, List[str]], 1870 output: str, 1871 warning_all: bool = True, 1872 shared: bool = True, 1873 include_pytorch: bool = False, 1874 vec_isa: VecISA = invalid_vec_isa, 1875 cuda: bool = False, 1876 aot_mode: bool = False, 1877 compile_only: bool = False, 1878 use_absolute_path: bool = False, 1879 use_mmap_weights: bool = False, 1880 extra_flags: Sequence[str] = (), 1881) -> str: 1882 ipaths, lpaths, libs, macros, build_arch_flags = get_include_and_linking_paths( 1883 include_pytorch, vec_isa, cuda, aot_mode 1884 ) 1885 if isinstance(input, str): 1886 input = [input] 1887 ipaths_str = " ".join(["-I" + p for p in ipaths]) 1888 clang_flags = "" 1889 if config.is_fbcode(): 1890 if aot_mode and not use_absolute_path: 1891 inp_name = input 1892 out_name = output 1893 linker_script = _LINKER_SCRIPT 1894 else: 1895 # We need to copy any absolute-path torch includes 1896 inp_name = [os.path.basename(i) for i in input] 1897 out_name = os.path.basename(output) 1898 linker_script = os.path.basename(_LINKER_SCRIPT) 1899 assert is_clang() 1900 # Use clang runtime instead of libgcc 1901 clang_flags += " --rtlib=compiler-rt" 1902 clang_flags += " -fuse-ld=lld" 1903 clang_flags += f" -Wl,--script={linker_script}" 1904 linker_paths = "-B" + build_paths.glibc_lib() 1905 linker_paths += " -L" + build_paths.glibc_lib() 1906 else: 1907 inp_name = input 1908 out_name = output 1909 linker_paths = "" # let the compiler pick 1910 if compile_only: 1911 libs, lpaths = "", "" 1912 inp_name_str = " ".join(inp_name) 1913 if use_mmap_weights: 1914 macros += " -D USE_MMAP_SELF" 1915 1916 return re.sub( 1917 r"[ \n]+", 1918 " ", 1919 f""" 1920 {cpp_compiler()} {inp_name_str} {get_shared(shared, compile_only)} 1921 {get_warning_all_flag(warning_all)} {cpp_flags()} 1922 {get_glibcxx_abi_build_flags()} 1923 {ipaths_str} {lpaths} {libs} {build_arch_flags} 1924 {macros} {linker_paths} {clang_flags} 1925 {optimization_flags()} {cpp_wrapper_flags()} 1926 {use_custom_generated_macros()} 1927 {use_fb_internal_macros()} 1928 {use_standard_sys_dir_headers()} 1929 {get_compile_only(compile_only)} 1930 {' '.join(extra_flags)} 1931 -o {out_name} 1932 """, 1933 ).strip() 1934 1935 1936def run_command_and_check(cmd: str): 1937 cmd = shlex.split(cmd) 1938 try: 1939 subprocess.check_call(cmd) 1940 except subprocess.CalledProcessError as e: 1941 raise exc.CppCompileError(cmd, e.output) from e 1942 1943 1944@functools.lru_cache(None) 1945def split_aot_inductor_output_path(path: str) -> Tuple[str, str]: 1946 """Returns the path where the AOT Inductor compiled kernels are stored.""" 1947 if path.endswith(".so"): 1948 return os.path.split(path) 
1949 else: 1950 return path, "" 1951 1952 1953@clear_on_fresh_inductor_cache 1954class CudaKernelParamCache: 1955 cache: Dict[str, Dict[str, str]] = dict() 1956 cache_clear = staticmethod(cache.clear) 1957 1958 @classmethod 1959 def set(cls, key: str, params: Dict[str, str], cubin: str) -> None: 1960 bin_type = "cubin" if torch.version.hip is None else "hsaco" 1961 _, path = write( 1962 cubin, 1963 bin_type, 1964 hash_type=bin_type, 1965 specified_dir=split_aot_inductor_output_path( 1966 config.aot_inductor.output_path 1967 )[0], 1968 ) 1969 1970 params[get_cpp_wrapper_cubin_path_name()] = path 1971 1972 cls.cache[key] = params 1973 1974 @classmethod 1975 def get(cls, key: str) -> Optional[Dict[str, str]]: 1976 return cls.cache.get(key, None) 1977 1978 @classmethod 1979 def get_keys(cls): 1980 return cls.cache.keys() 1981 1982 1983class AotCodeCompiler: 1984 @classmethod 1985 def compile( 1986 cls, 1987 graph: GraphLowering, 1988 source_code: str, 1989 serialized_extern_kernel_nodes: Optional[str], 1990 cuda: bool, 1991 ) -> str: 1992 picked_vec_isa = pick_vec_isa() 1993 cpp_command = repr( 1994 cpp_compile_command( 1995 "i", 1996 "o", 1997 vec_isa=picked_vec_isa, 1998 cuda=cuda, 1999 aot_mode=graph.aot_mode, 2000 ) 2001 ) 2002 fbcode_aot_cpu_re = False 2003 use_absolute_path = False 2004 if config.is_fbcode(): 2005 ld_command = build_paths.ld() 2006 if not cuda and graph.aot_mode: # Meta internal AOTInductor CPU 2007 objcopy_command = build_paths.objcopy_fallback() 2008 fbcode_aot_cpu_re = True 2009 use_absolute_path = True 2010 else: 2011 objcopy_command = build_paths.objcopy() 2012 else: 2013 ld_command = "ld" 2014 objcopy_command = "objcopy" 2015 2016 ( 2017 specified_output_path, 2018 specified_so_name, 2019 ) = split_aot_inductor_output_path(config.aot_inductor.output_path) 2020 key, input_path = write( 2021 source_code, 2022 "cpp", 2023 extra=cpp_command, 2024 specified_dir=specified_output_path, 2025 ) 2026 output_code_log.info("Output code written to: %s", input_path) 2027 trace_structured( 2028 "graph_dump", 2029 lambda: { 2030 "name": "inductor_aot_code", 2031 "type": "cpp", 2032 "filename": input_path, 2033 }, 2034 payload_fn=lambda: source_code, 2035 ) 2036 2037 def _compile_consts_linux(consts: bytes) -> str: 2038 _, consts_path = write( 2039 consts, 2040 "bin", 2041 specified_dir=specified_output_path, 2042 ) 2043 2044 consts_o = os.path.splitext(consts_path)[0] + ".o" 2045 if fbcode_aot_cpu_re: 2046 cmd = f"{ld_command} -r -b binary -o {os.path.basename(consts_o)} {os.path.basename(consts_path)}" 2047 compile_file(consts_path, consts_o, cmd.split()) 2048 os.chmod(consts_o, 0o644) 2049 else: 2050 cmd = f"{ld_command} -r -b binary -o {consts_o} {consts_path}" 2051 run_command_and_check(cmd) 2052 log.debug("aot constant binary command: %s", cmd) 2053 2054 if graph.mutated_buffers & set(graph.constants.keys()): 2055 # .data section is between .text and .bss. When the size of .data is large, 2056 # during the linking, the relocation of .text against .bss may overflow. 2057 # Rename it to .ldata so that it won't be in between the .text and .bss section 2058 if len(consts) > 2_000_000_000: 2059 raise ValueError( 2060 "Models with buffer mutation included doesn't support constants greater than 2GB!" 2061 ) 2062 rename_data = " .data=.ldata" 2063 else: 2064 # if no buffer mutation is needed, we could instead set the data region 2065 # as read-only (i.e. .lrodata) which could accomodate larger size of data 2066 # to be linked. 
2067 rename_data = " .data=.lrodata,alloc,load,readonly,data,contents" 2068 2069 assert ( 2070 ALIGN_BYTES & (ALIGN_BYTES - 1) 2071 ) == 0 and ALIGN_BYTES >= 64, "must be power of 2 and >= 64" 2072 cmd = ( 2073 f"{objcopy_command} --rename-section" 2074 f"{rename_data}" 2075 f" --set-section-alignment .data={ALIGN_BYTES}" # following the gAlignment of CPU in c10/core/alignment.h 2076 f" {consts_o} {consts_o}" 2077 ) 2078 log.debug("aot constant rename section command: %s", cmd) 2079 run_command_and_check(cmd) 2080 2081 cmd = f"rm {consts_path}" 2082 log.debug("aot constant bin removal command: %s", cmd) 2083 run_command_and_check(cmd) 2084 2085 if fbcode_aot_cpu_re: 2086 body = re.sub(r"[\W]", "_", os.path.basename(consts_path)) 2087 else: 2088 body = re.sub(r"[\W]", "_", consts_path) 2089 2090 symbol_list = [] 2091 symbol_list.append( 2092 f"{objcopy_command} --redefine-sym _binary_{body}_start=_binary_constants_bin_start {consts_o}" 2093 ) 2094 symbol_list.append( 2095 f"{objcopy_command} --redefine-sym _binary_{body}_size=_binary_constants_bin_size {consts_o}" 2096 ) 2097 symbol_list.append( 2098 f"{objcopy_command} --redefine-sym _binary_{body}_end=_binary_constants_bin_end {consts_o}" 2099 ) 2100 log.debug("aot constant binary redefine symbol: %s", " ".join(symbol_list)) 2101 for cmd in symbol_list: 2102 run_command_and_check(cmd) 2103 return consts_o 2104 2105 def _compile_consts_darwin(consts: bytes) -> str: 2106 if config.aot_inductor.debug_dump_consts_bin: 2107 _, _binary_constants_path = write( 2108 consts, 2109 "bin", 2110 specified_dir=specified_output_path, 2111 ) 2112 log.debug("binary constants path: %s", _binary_constants_path) 2113 2114 is_large_consts = len(consts) > 1024 2115 consts_asm = "\t.section\t__DATA,__data\n" 2116 consts_asm += "\t.globl\t__binary_constants_bin_start\n" 2117 consts_asm += "__binary_constants_bin_start:\n" 2118 if not is_large_consts: 2119 for c in consts: 2120 consts_asm += f"\t.byte {c}\n" 2121 # Add one element even if constants are empty 2122 # Otherwise assembler will not put them in data section 2123 if not consts: 2124 consts_asm += "\t.space 1\n" 2125 else: 2126 consts_asm += "\t.quad 0x1234567899abcdef\n" 2127 consts_asm += f"\t.space {len(consts) - 8}\n" 2128 consts_asm += ".globl\t__binary_constants_bin_end\n" 2129 consts_asm += "__binary_constants_bin_end:\n" 2130 _, consts_path = write( 2131 consts_asm, 2132 "S", 2133 specified_dir=specified_output_path, 2134 ) 2135 consts_o = os.path.splitext(consts_path)[0] + ".o" 2136 cmd = f"{cpp_compiler()} -c -o {consts_o} {consts_path}" 2137 run_command_and_check(cmd) 2138 if is_large_consts: 2139 with open(consts_o, "r+b") as f: 2140 f.seek(0) 2141 hdr = f.read(1024) 2142 # Search for magic number and write the actual data over it 2143 start_idx = hdr.find(b"\xef\xcd\xab\x99\x78\x56\x34\x12") 2144 assert start_idx != -1 2145 f.seek(start_idx) 2146 pos = 0 2147 while pos < len(consts): 2148 rc = f.write(consts[pos:]) 2149 pos += rc 2150 return consts_o 2151 2152 from filelock import FileLock 2153 2154 lock_dir = get_lock_dir() 2155 lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) 2156 with lock: 2157 # Currently, this only support serializing extern nodes in fbcode 2158 # Eventually, we should also have a serializer for OSS. 
2159 if config.is_fbcode() and serialized_extern_kernel_nodes: 2160 output_json = os.path.splitext(input_path)[0] + ".json" 2161 with open(output_json, "w") as f: 2162 f.write(serialized_extern_kernel_nodes) 2163 2164 output_so = ( 2165 config.aot_inductor.output_path 2166 if specified_so_name 2167 else os.path.splitext(input_path)[0] + ".so" 2168 ) 2169 2170 output_o = os.path.splitext(input_path)[0] + ".o" 2171 consts_size = sum( 2172 torch.ops.mkldnn._nbytes(tensor) 2173 if tensor.is_mkldnn 2174 else tensor.untyped_storage().nbytes() 2175 for (name, tensor) in graph.constants.items() 2176 if name not in graph.folded_constants 2177 ) 2178 # TODO: Fix mmap weights with cuda 2179 use_mmap_weights = not config.is_fbcode() and consts_size > 2_000_000_000 2180 if config.aot_inductor.force_mmap_weights: 2181 use_mmap_weights = True 2182 compile_cmd = cpp_compile_command( 2183 input=input_path, 2184 output=output_o, 2185 vec_isa=picked_vec_isa, 2186 cuda=cuda, 2187 aot_mode=graph.aot_mode, 2188 compile_only=True, 2189 use_absolute_path=use_absolute_path, 2190 use_mmap_weights=use_mmap_weights, 2191 ) 2192 log.debug("aot compilation command: %s", compile_cmd) 2193 if fbcode_aot_cpu_re: 2194 compile_file(input_path, output_o, compile_cmd.split()) 2195 os.chmod(output_o, 0o644) 2196 else: 2197 run_command_and_check(compile_cmd) 2198 2199 def _to_bytes(t: torch.Tensor, all_cuda: bool) -> bytes: 2200 def _pad_to_alignment(raw_bytes): 2201 padded_bytes = raw_bytes.ljust( 2202 (len(raw_bytes) + ALIGN_BYTES - 1) // ALIGN_BYTES * ALIGN_BYTES, 2203 b"\x00", 2204 ) 2205 return padded_bytes 2206 2207 # This serializes the tensor's untyped_storage to bytes by accessing 2208 # the raw data of the underlying structure. 2209 import ctypes 2210 2211 if t.numel() == 0: 2212 return b"" 2213 2214 if t.is_mkldnn: 2215 data_ptr = torch.ops.mkldnn.data_ptr(t) 2216 nbytes = torch.ops.mkldnn._nbytes(t) 2217 else: 2218 t_cpu = t.untyped_storage().cpu() 2219 data_ptr = t_cpu.data_ptr() 2220 nbytes = t_cpu.nbytes() 2221 2222 raw_array = ctypes.cast( 2223 data_ptr, 2224 ctypes.POINTER(ctypes.c_ubyte * nbytes), 2225 ) 2226 raw_bytes = bytes(raw_array.contents) 2227 return raw_bytes if all_cuda else _pad_to_alignment(raw_bytes) 2228 2229 all_cuda = all( 2230 graph.get_original_value_of_constant(name).is_cuda 2231 for name in graph.constants.keys() 2232 if name not in graph.folded_constants 2233 ) 2234 serialized_weights = b"".join( 2235 _to_bytes(graph.get_original_value_of_constant(name), all_cuda) 2236 for name in graph.constants.keys() 2237 if name not in graph.folded_constants 2238 ) 2239 if not use_mmap_weights: 2240 aot_constants = serialized_weights 2241 magic_number = 0 2242 else: 2243 magic_number = cast( 2244 int, torch.randint(0, torch.iinfo(torch.int64).max, (1,)).item() 2245 ) 2246 aot_constants = struct.pack("qq", consts_size + 8, magic_number) 2247 consts_o = { 2248 "linux": _compile_consts_linux, 2249 "darwin": _compile_consts_darwin, 2250 }[sys.platform](aot_constants) 2251 2252 link_cmd = cpp_compile_command( 2253 input=[output_o, consts_o], 2254 output=output_so, 2255 vec_isa=picked_vec_isa, 2256 cuda=cuda, 2257 aot_mode=graph.aot_mode, 2258 use_absolute_path=use_absolute_path, 2259 ) 2260 log.debug("aot linkage command: %s", link_cmd) 2261 if fbcode_aot_cpu_re: 2262 compile_file([output_o, consts_o], output_so, link_cmd.split()) 2263 os.chmod(output_so, 0o755) 2264 else: 2265 run_command_and_check(link_cmd) 2266 2267 if use_mmap_weights: 2268 with open(output_so, "a+b") as f_so: 2269 so_size = f_so.tell() 
2270 # Page align the weights 2271 f_so.write(b" " * (16384 - so_size % 16384)) 2272 f_so.write(serialized_weights) 2273 f_so.write(struct.pack("q", magic_number)) 2274 2275 # Append cmds to the end of codegen-ed wrapper file 2276 with open(input_path, "a") as f: 2277 f.write("\n") 2278 f.write(f"// Compile cmd\n// {compile_cmd}\n") 2279 f.write(f"// Link cmd\n// {link_cmd}\n") 2280 2281 return output_so 2282 2283 2284# Putting this fn in cpp.py (unfortunately) causes a deadlock, which is why it's in codecache.py. 2285# Why? importing from cpp.py invokes codecache.pick_vec_isa(), which takes out a lock. 2286# Cycle goes: 2287# - CppCodeCache.load() 2288# - pick_vec_isa() 2289# - valid_vec_isa_list() 2290# - VecISA.__bool__() <-- takes out a lock 2291# - compile_file() <-- imports cpp_prefix_path from cpp, which causes us to try to take out the same lock. 2292@clear_on_fresh_inductor_cache 2293@functools.lru_cache 2294def cpp_prefix_path() -> str: 2295 path = Path(__file__).parent / "codegen/cpp_prefix.h" 2296 with path.open() as f: 2297 content = f.read() 2298 _, filename = write( 2299 content, 2300 "h", 2301 ) 2302 return filename 2303 2304 2305def cpp_prefix() -> str: 2306 filename = cpp_prefix_path() 2307 if config.is_fbcode(): 2308 # We need relative paths, since we bundle up 2309 # everything that we compile into a folder for remote compilation. 2310 return f'#include "{os.path.basename(filename)}"' 2311 else: 2312 return f'#include "{filename}"' 2313 2314 2315# Given a path to an input cpp file and an output path, 2316# Attempts to compile the file, storing the output in "output_path" 2317@dynamo_timed 2318def compile_file( 2319 input_path: Union[str, List[str]], output_path: str, cmd: List[str] 2320) -> None: 2321 input_paths = [input_path] if isinstance(input_path, str) else input_path 2322 input_files = [ 2323 os.path.basename(ip) if config.is_fbcode() else ip for ip in input_paths 2324 ] 2325 try: 2326 if config.is_fbcode(): 2327 # Need to copy our header into the same folder as the sourcecode. 2328 header_path = cpp_prefix_path() 2329 header_name = os.path.basename(header_path) 2330 output_name = os.path.basename(output_path) 2331 # When we build remotely, we need to make sure to carefully copy any files 2332 # that are required during the compilation process into our build directly. 2333 # This is where all of the ATen/c10/Torch includes come from. 2334 torch_includes_path = os.path.join(_TORCH_PATH, "include") 2335 with tempfile.TemporaryDirectory() as tmp_dir: 2336 # Copy everything to tmp compilation folder 2337 shutil.copy(header_path, os.path.join(tmp_dir, header_name)) 2338 shutil.copy(_LINKER_SCRIPT, os.path.join(tmp_dir, "script.ld")) 2339 for p, f in zip(input_paths, input_files): 2340 shutil.copy(p, os.path.join(tmp_dir, f)) 2341 dest_include_path = os.path.join(tmp_dir, "include") 2342 shutil.copytree(torch_includes_path, dest_include_path) 2343 # Run the build 2344 output_file_path = _run_build_command(cmd, tmp_dir, output_name) 2345 # Copy output from the build 2346 if os.path.exists(output_path): 2347 os.remove(output_path) 2348 shutil.copy(output_file_path, output_path) 2349 else: 2350 subprocess.check_output(cmd, stderr=subprocess.STDOUT) 2351 except subprocess.CalledProcessError as e: 2352 output = e.output.decode("utf-8") 2353 openmp_problem = "'omp.h' file not found" in output or "libomp" in output 2354 if openmp_problem and sys.platform == "darwin": 2355 instruction = ( 2356 "\n\nOpenMP support not found. 
Please try one of the following solutions:\n" 2357 "(1) Set the `CXX` environment variable to a compiler other than Apple clang++/g++ " 2358 "that has builtin OpenMP support;\n" 2359 "(2) install OpenMP via conda: `conda install llvm-openmp`;\n" 2360 "(3) install libomp via brew: `brew install libomp`;\n" 2361 "(4) manually setup OpenMP and set the `OMP_PREFIX` environment variable to point to a path" 2362 " with `include/omp.h` under it." 2363 ) 2364 output += instruction 2365 raise exc.CppCompileError(cmd, output) from e 2366 2367 2368_libgomp: Optional[CDLL] = None 2369 2370 2371def custom_op_wrapper(op: str, *args): 2372 # This function will be called from generated cpp wrapper code in the JIT mode. 2373 # Because tensors will be passed in as AtenTensorHandle, we need to explicitly convert them. 2374 def convert_arg(arg): 2375 if str(type(arg)) == "<class 'PyCapsule'>": 2376 # No easy way to do isinstance check on PyCapsule 2377 return torch._C._aoti.alloc_tensor_by_stealing_from_void_ptr(arg) 2378 elif isinstance(arg, (list, tuple)): 2379 return type(arg)(convert_arg(a) for a in arg) 2380 else: 2381 return arg 2382 2383 converted_args = [convert_arg(arg) for arg in args] 2384 2385 assert op.startswith("torch.ops."), ( 2386 op + " can not be called through custom_op_wrapper" 2387 ) 2388 func = None 2389 for i, s in enumerate(op.split(".")): 2390 if i == 0: 2391 func = importlib.import_module(s) 2392 func = getattr(func, s) 2393 2394 assert callable(func), op + " can not be loaded through custom_op_wrapper" 2395 result = func(*converted_args) 2396 if isinstance(result, (list, tuple)): 2397 for r in result: 2398 assert isinstance(r, torch.Tensor), op + " returns a list of non-tensors" 2399 return torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(result) # type: ignore[arg-type] 2400 else: 2401 assert isinstance(result, torch.Tensor), op + " returns a non-tensor" 2402 return torch._C._aoti.unsafe_alloc_void_ptr_from_tensor(result) 2403 2404 2405@clear_on_fresh_inductor_cache 2406class CppCodeCache: 2407 cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {} 2408 cache_clear = staticmethod(cache.clear) 2409 cpp_compile_command_flags: Dict[str, Any] = {} 2410 2411 @staticmethod 2412 def _load_library_inner(path: str, key: str) -> Union[CDLL, ModuleType]: 2413 return cdll.LoadLibrary(path) 2414 2415 @classmethod 2416 def _load_library(cls, path: str, key: str) -> Union[CDLL, ModuleType]: 2417 try: 2418 result = cls._load_library_inner(path, key) 2419 result.key = key # type: ignore[union-attr] 2420 return result 2421 except (ImportError, OSError) as e: 2422 if "gomp" in str(e) and os.path.exists("/usr/lib64/libgomp.so.1"): 2423 # hacky workaround for fbcode/buck 2424 global _libgomp 2425 _libgomp = cdll.LoadLibrary("/usr/lib64/libgomp.so.1") 2426 result = cls._load_library_inner(path, key) 2427 result.key = key # type: ignore[union-attr] 2428 return result 2429 if "failed to map segment from shared object" in str(e): 2430 raise OSError( 2431 f"{e}. The most common reason this may occur is if the {tempfile.gettempdir()} folder " 2432 "is mounted with noexec (e.g., by default Docker mounts tmp file systems " 2433 f"as noexec). Please remount {tempfile.gettempdir()} with exec enabled, or set another " 2434 "temporary directory with TORCHINDUCTOR_CACHE_DIR environment variable." 
2435 ) from e 2436 raise 2437 2438 @classmethod 2439 def load_async(cls, source_code: str, cuda=False, submit_fn=None, extra_flags=()): 2440 compile_command = { 2441 **cls.cpp_compile_command_flags, 2442 "cuda": cuda, 2443 "vec_isa": pick_vec_isa(), 2444 "extra_flags": extra_flags, 2445 } 2446 2447 _set_gpu_runtime_env() # cpp_extension consults the env 2448 2449 from torch._inductor.cpp_builder import CppBuilder, CppTorchCudaOptions 2450 2451 dummy_builder = CppBuilder( 2452 name="o", sources="i", BuildOption=CppTorchCudaOptions(**compile_command) 2453 ) 2454 # write function will calc source_code hash, the same source code with different 2455 # ISA level should be generate different hash. 2456 # So we need get a command_line which contains isa related parameter as a part of hash key. 2457 # And then pass the command_line to below write function as extra parameter to 2458 # guarantee the source code hash contains ISA difference. 2459 dummy_cmd = repr(dummy_builder.get_command_line()) 2460 key, input_path = write(source_code, "cpp", extra=dummy_cmd) 2461 2462 if key not in cls.cache: 2463 from filelock import FileLock 2464 2465 lock_path = os.path.join(get_lock_dir(), key + ".lock") 2466 output_path = input_path[:-3] + "so" 2467 future: Optional[Future[Any]] = None 2468 lib = None 2469 worker_fn = functools.partial( 2470 _worker_compile_cpp, 2471 lock_path, 2472 input_path, 2473 output_path, 2474 cpp_compile_command( 2475 input=input_path, output=output_path, **compile_command 2476 ), 2477 ) 2478 2479 def load_fn(): 2480 nonlocal lib 2481 if lib is None: 2482 if future is not None: 2483 future.result() 2484 result = worker_fn() 2485 assert result is None 2486 lib = cls._load_library(output_path, key) 2487 assert lib is not None 2488 return lib 2489 2490 if submit_fn is not None: 2491 with FileLock(lock_path, timeout=LOCK_TIMEOUT): 2492 if not os.path.exists(output_path): 2493 future = submit_fn(worker_fn) 2494 2495 cls.cache[key] = load_fn 2496 2497 return cls.cache[key] 2498 2499 @classmethod 2500 def load(cls, source_code: str, cuda: bool = False): 2501 return cls.load_async(source_code, cuda)() 2502 2503 2504def _worker_compile_cpp(lock_path, input_path, output_path, cmd): 2505 from filelock import FileLock 2506 2507 with FileLock(lock_path, timeout=LOCK_TIMEOUT): 2508 if not os.path.exists(output_path): 2509 compile_file(input_path, output_path, shlex.split(cmd)) 2510 2511 2512# Customized Python binding for cpp kernels 2513@clear_on_fresh_inductor_cache 2514class CppPythonBindingsCodeCache(CppCodeCache): 2515 cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {} 2516 cache_clear = staticmethod(cache.clear) 2517 cpp_compile_command_flags = { 2518 # kernels have no dependency on libtorch 2519 "include_pytorch": False, 2520 "shared": True, 2521 } 2522 entry_function = "kernel" 2523 call_entry_function = "kernel(%s);Py_RETURN_NONE;" 2524 extra_parse_arg = "" 2525 suffix_template = textwrap.dedent( 2526 """ 2527 // Python bindings to call %s(): 2528 #define PY_SSIZE_T_CLEAN 2529 #include <Python.h> 2530 #include <sstream> 2531 #include <cstdlib> 2532 2533 #ifndef _MSC_VER 2534 #if __cplusplus < 202002L 2535 // C++20 earlier code 2536 // https://en.cppreference.com/w/cpp/language/attributes/likely 2537 #define likely(x) __builtin_expect(!!(x), 1) 2538 #define unlikely(x) __builtin_expect(!!(x), 0) 2539 #endif 2540 #endif 2541 2542 // This is defined in guards.cpp so we don't need to import PyTorch headers that are slooow. 
2543 // We manually link it below to workaround issues with fbcode build. 2544 static void* (*_torchinductor_pyobject_tensor_data_ptr)(PyObject* obj); 2545 2546 template <typename T> static inline T parse_arg(PyObject* args, size_t n) { 2547 static_assert(std::is_pointer<T>::value, "arg type must be pointer or long"); 2548 return static_cast<T>(_torchinductor_pyobject_tensor_data_ptr(PyTuple_GET_ITEM(args, n))); 2549 } 2550 template <> inline long parse_arg<long>(PyObject* args, size_t n) { 2551 auto result = PyLong_AsSsize_t(PyTuple_GET_ITEM(args, n)); 2552 if(unlikely(result == -1 && PyErr_Occurred())) 2553 throw std::runtime_error("expected int arg"); 2554 return result; 2555 } 2556 template <> inline uintptr_t parse_arg<uintptr_t>(PyObject* args, size_t n) { 2557 auto result = PyLong_AsVoidPtr(PyTuple_GET_ITEM(args, n)); 2558 if(unlikely(result == reinterpret_cast<void*>(-1) && PyErr_Occurred())) 2559 throw std::runtime_error("expected int arg"); 2560 return reinterpret_cast<uintptr_t>(result); 2561 } 2562 2563 %s 2564 2565 static PyObject* %s_py(PyObject* self, PyObject* args) { 2566 try { 2567 if(unlikely(!PyTuple_CheckExact(args))) 2568 throw std::runtime_error("tuple args required"); 2569 if(unlikely(PyTuple_GET_SIZE(args) != %s)) 2570 throw std::runtime_error("requires %s args"); 2571 %s 2572 } catch(std::exception const& e) { 2573 PyErr_SetString(PyExc_RuntimeError, e.what()); 2574 return nullptr; 2575 } catch(...) { 2576 PyErr_SetString(PyExc_RuntimeError, "unhandled error"); 2577 return nullptr; 2578 } 2579 } 2580 2581 static PyMethodDef py_methods[] = { 2582 {"%s", %s_py, METH_VARARGS, ""}, 2583 {NULL, NULL, 0, NULL}}; 2584 2585 static struct PyModuleDef py_module = 2586 {PyModuleDef_HEAD_INIT, "%s", NULL, -1, py_methods}; 2587 2588 PyMODINIT_FUNC PyInit_%s(void) { 2589 const char* str_addr = std::getenv("_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR"); 2590 if(!str_addr) { 2591 PyErr_SetString(PyExc_RuntimeError, "_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR must be set"); 2592 return nullptr; 2593 } 2594 std::istringstream iss(str_addr); 2595 uintptr_t addr = 0; 2596 iss >> addr; 2597 _torchinductor_pyobject_tensor_data_ptr = 2598 reinterpret_cast<decltype(_torchinductor_pyobject_tensor_data_ptr)>(addr); 2599 return PyModule_Create(&py_module); 2600 } 2601 """ 2602 ) 2603 2604 @classmethod 2605 def _load_library_inner(cls, path: str, key: str) -> ModuleType: 2606 os.environ["_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR"] = str( 2607 torch._C._dynamo.guards._torchinductor_pyobject_tensor_data_ptr # type: ignore[attr-defined] 2608 ) 2609 module_name = f"{key}.{cls.entry_function}" 2610 try: 2611 return sys.modules[module_name] 2612 except KeyError: 2613 pass 2614 spec = importlib.util.spec_from_file_location(module_name, path) 2615 assert spec is not None 2616 module = importlib.util.module_from_spec(spec) 2617 sys.modules[module_name] = module 2618 spec.loader.exec_module(module) # type: ignore[union-attr] 2619 return module 2620 2621 @classmethod 2622 def load_pybinding_async( 2623 cls, 2624 argtypes: List[str], 2625 source_code: str, 2626 cuda: bool = False, 2627 num_outputs: int = -1, 2628 submit_fn=None, 2629 extra_flags=(), 2630 ) -> Any: 2631 """ 2632 Wrap a C++ function in fast Python bindings. 2633 2634 Args: 2635 argtypes: The types of args to ENTRY_FUNCTION(), e.g. 
["float*", "long"] 2636 source_code: C++ source code containing a ENTRY_FUNCTION() function 2637 2638 Returns: 2639 A python version of ENTRY_FUNCTION() 2640 """ 2641 parseargs = ", ".join( 2642 f"parse_arg<{argtype.replace('const ', '')}>(args, {n})" 2643 for n, argtype in enumerate(argtypes) 2644 ) 2645 suffix = cls.suffix_template % ( 2646 cls.entry_function, 2647 cls.extra_parse_arg % num_outputs if cls.extra_parse_arg else "", 2648 cls.entry_function, 2649 len(argtypes), 2650 len(argtypes), 2651 cls.call_entry_function % parseargs, 2652 cls.entry_function, 2653 cls.entry_function, 2654 cls.entry_function, 2655 cls.entry_function, 2656 ) 2657 get_result = cls.load_async( 2658 source_code + suffix, cuda, submit_fn=submit_fn, extra_flags=extra_flags 2659 ) 2660 result = None 2661 2662 def future(): 2663 nonlocal result 2664 if result is None: 2665 result = get_result() 2666 assert isinstance(result, ModuleType) 2667 return getattr(result, cls.entry_function) 2668 2669 return future 2670 2671 @classmethod 2672 def load_pybinding(cls, *args, **kwargs) -> Any: 2673 return cls.load_pybinding_async(*args, **kwargs)() 2674 2675 2676@clear_on_fresh_inductor_cache 2677class CppWrapperCodeCache(CppPythonBindingsCodeCache): 2678 cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {} 2679 cache_clear = staticmethod(cache.clear) 2680 cpp_compile_command_flags = { 2681 "include_pytorch": True, 2682 "shared": True, 2683 } 2684 entry_function = "inductor_entry_cpp" 2685 call_entry_function = "return inductor_entry_cpp(%s);" 2686 extra_parse_arg = textwrap.dedent( 2687 """ 2688 #include <torch/csrc/inductor/aoti_torch/c/shim.h> 2689 2690 static inline std::vector<AtenTensorHandle> unpack_tensor_handle_list(PyObject* pyvec) { 2691 std::vector<AtenTensorHandle> result; 2692 size_t result_len = PyList_GET_SIZE(pyvec); 2693 result.reserve(result_len); 2694 for (size_t i = 0; i < result_len; i++) { 2695 // AtenTensorHandle is essentially a pointer 2696 void* elem = PyCapsule_GetPointer(PyList_GET_ITEM(pyvec, i), NULL); 2697 result.push_back(reinterpret_cast<AtenTensorHandle>(elem)); 2698 } 2699 return result; 2700 } 2701 2702 static inline PyObject* pack_tensor_handle_list(const std::vector<AtenTensorHandle>& cppvec) { 2703 size_t result_len = cppvec.size(); 2704 PyObject* result = PyList_New(static_cast<Py_ssize_t>(result_len)); 2705 for (size_t i = 0; i < result_len; i++) { 2706 PyObject *elem = 2707 cppvec[i] == nullptr 2708 ? Py_None 2709 // Store AtenTensorHandle as PyCapsulate 2710 : PyCapsule_New(reinterpret_cast<void*>(cppvec[i]), NULL, NULL); 2711 PyList_SET_ITEM(result, i, elem); 2712 } 2713 return result; 2714 } 2715 2716 template <> inline std::vector<AtenTensorHandle> parse_arg<std::vector<AtenTensorHandle>>(PyObject* args, size_t n) { 2717 return unpack_tensor_handle_list(PyTuple_GET_ITEM(args, n)); 2718 } 2719 2720 PyObject* inductor_entry_cpp(std::vector<AtenTensorHandle>&& input_handles) { 2721 // For outputs, we only allocate a vector to hold returned tensor handles, 2722 // not allocating the actual output tensor storage here 2723 std::vector<AtenTensorHandle> output_handles(%s); 2724 try { 2725 inductor_entry_impl(input_handles.data(), output_handles.data()); 2726 return pack_tensor_handle_list(output_handles); 2727 } catch(std::exception const& e) { 2728 PyErr_SetString(PyExc_RuntimeError, e.what()); 2729 return {}; 2730 } catch(...) 
{ 2731 PyErr_SetString(PyExc_RuntimeError, "unhandled error"); 2732 return {}; 2733 } 2734 } 2735 """ 2736 ) 2737 2738 2739# TODO: Remove this temporary code after the switch to the new cpp_builder 2740def _temp_validate_new_and_old_command(new_cmd: List[str], old_cmd: List[str]): 2741 new_diff: List[str] = [x for x in new_cmd if x not in old_cmd] 2742 old_diff: List[str] = [y for y in old_cmd if y not in new_cmd] 2743 2744 if new_diff or old_diff: 2745 print("!!! new_cmd: ", new_cmd) 2746 print("!!! old_cmd: ", old_cmd) 2747 print("!!! new_diff: ", new_diff) 2748 print("!!! old_diff: ", old_diff) 2749 raise RuntimeError("New and old compile commands differ.") 2750 2751 2752def _do_validate_cpp_commands( 2753 include_pytorch: bool, 2754 cuda: bool, 2755 compile_only: bool, 2756 mmap_weights: bool, 2757 use_absolute_path: bool, 2758): 2759 # Pre-CI will fail if the test machine can't run CUDA. 2760 temp_dir = tempfile.TemporaryDirectory() 2761 test_dir_path = temp_dir.name 2762 test_cuda = torch.cuda.is_available() and cuda 2763 input_path = os.path.join(test_dir_path, "dummy_input.cpp") 2764 output_path = os.path.join(test_dir_path, "dummy_output.so") 2765 extra_flags = ["-D TEST_EXTRA_FLAGS"] 2766 if compile_only: 2767 output_path = os.path.join(test_dir_path, "dummy_output.o") 2768 picked_isa = pick_vec_isa() 2769 2770 old_cmd = cpp_compile_command( 2771 input=input_path, 2772 output=output_path, 2773 include_pytorch=include_pytorch, 2774 vec_isa=picked_isa, 2775 cuda=test_cuda, 2776 aot_mode=False, 2777 compile_only=compile_only, 2778 use_absolute_path=use_absolute_path, 2779 use_mmap_weights=mmap_weights, 2780 extra_flags=extra_flags, 2781 ).split(" ") 2782 2783 from torch._inductor.cpp_builder import CppBuilder, CppTorchCudaOptions 2784 2785 dummy_build_option = CppTorchCudaOptions( 2786 vec_isa=picked_isa, 2787 include_pytorch=include_pytorch, 2788 cuda=test_cuda, 2789 compile_only=compile_only, 2790 use_absolute_path=use_absolute_path, 2791 use_mmap_weights=mmap_weights, 2792 extra_flags=extra_flags, 2793 ) 2794 2795 dummy_builder = CppBuilder( 2796 name="dummy_output", 2797 sources=input_path, 2798 BuildOption=dummy_build_option, 2799 output_dir=test_dir_path, 2800 ) 2801 new_cmd = dummy_builder.get_command_line().split(" ") 2802 2803 _temp_validate_new_and_old_command(new_cmd, old_cmd) 2804 2805 temp_dir.cleanup() 2806 2807 2808# TODO: Remove this temporary code after the switch to the new cpp_builder. 2809# It helps verify that the new cpp_builder generates the same command line as the old one. 2810def validate_new_cpp_commands(): 2811 cuda = [True, False] 2812 use_mmap_weights = [True, False] 2813 compile_only = [True, False] 2814 include_pytorch = [True, False] 2815 use_absolute_path = [True, False] 2816 2817 for x in cuda: 2818 for y in use_mmap_weights: 2819 for z in compile_only: 2820 for m in include_pytorch: 2821 for n in use_absolute_path: 2822 print( 2823 f"!!!
cuda:{x}, use_mmap_weights:{y}, compile_only:{z}, include_pytorch:{m}, use_absolute_path:{n}" 2824 ) 2825 _do_validate_cpp_commands( 2826 include_pytorch=m, 2827 cuda=x, 2828 mmap_weights=y, 2829 compile_only=z, 2830 use_absolute_path=n, 2831 ) 2832 2833 2834@clear_on_fresh_inductor_cache 2835class HalideCodeCache(CppPythonBindingsCodeCache): 2836 cache: Dict[str, Callable[[], Union[ModuleType, CDLL]]] = {} 2837 cache_clear = staticmethod(cache.clear) 2838 glue_template = textwrap.dedent( 2839 """ 2840 #include "{halidebuffer_h}" 2841 #include "{headerfile}" 2842 #include <stdexcept> 2843 #include <cmath> 2844 void kernel({argdefs}) {{ 2845 {buffers} 2846 int err = halide_kernel({buffer_names}); 2847 if(err != 0) {{ 2848 throw std::runtime_error("halide_kernel failed"); 2849 }} 2850 }} 2851 """ 2852 ) 2853 2854 @classmethod 2855 def _codegen_glue(cls, argtypes, headerfile): 2856 buffers = [] 2857 buffer_names = [] 2858 for i, arg in enumerate(argtypes): 2859 if arg.numel: 2860 buffer_names.append(f"hl_buf_{i}") 2861 buffers.append( 2862 f" Halide::Runtime::Buffer {buffer_names[-1]}({arg.halide_type()}, {arg.name}, {arg.numel});" 2863 ) 2864 else: 2865 assert "*" not in arg.ctype 2866 buffer_names.append(arg.name) 2867 glue_code = cls.glue_template.format( 2868 halidebuffer_h=cls.find_header("HalideBuffer.h"), 2869 headerfile=headerfile, 2870 argdefs=", ".join(f"{a.bindings_type()} {a.name}" for a in argtypes), 2871 buffers="\n".join(buffers).lstrip(), 2872 buffer_names=", ".join(buffer_names), 2873 ) 2874 return glue_code 2875 2876 @classmethod 2877 @functools.lru_cache(None) 2878 def config_hash(cls): 2879 return sha256_hash( 2880 "\n".join( 2881 [ 2882 cls.glue_template, 2883 f"{cls.cpu_cache_size()}", 2884 cpp_compile_command("I", "O"), 2885 ] 2886 ).encode("utf-8") 2887 ) 2888 2889 @staticmethod 2890 @functools.lru_cache(None) 2891 def cpu_cache_size(): 2892 try: 2893 cpuinfo = open("/proc/cpuinfo").read() 2894 except OSError: 2895 return 16777216 2896 m = re.search(r"cache size\s*: (\d+) KB", cpuinfo) 2897 if m: 2898 return int(m.group(1)) * 1024 2899 m = re.search(r"cache size\s*: (\d+) MB", cpuinfo) 2900 if m: 2901 return int(m.group(1)) * 1024 * 1024 2902 raise RuntimeError("failed to find 'cache size: ... 
KB' in /proc/cpuinfo") 2903 2904 @staticmethod 2905 def _search_for_file(suffix, errmsg): 2906 try: 2907 search, *_ = importlib.machinery.PathFinder.find_spec( # type: ignore[union-attr,misc] 2908 "halide" 2909 ).submodule_search_locations 2910 for file in os.listdir(search): 2911 if file.endswith(".so"): 2912 try: 2913 out = subprocess.check_output( 2914 ["ldd", os.path.join(search, file)] 2915 ) 2916 except subprocess.SubprocessError: 2917 continue 2918 m = re.search(r"(/.*)/libHalide.so", out.decode("utf-8")) 2919 if m: 2920 path = os.path.join(os.path.abspath(m.group(1)), suffix) 2921 if os.path.exists(path): 2922 return os.path.abspath(path) 2923 except Exception as e: 2924 raise RuntimeError(errmsg) from e 2925 raise RuntimeError(errmsg) 2926 2927 @staticmethod 2928 @functools.lru_cache(None) 2929 def find_libautoschedule(name): 2930 sofile = f"libautoschedule_{name.lower()}.so" 2931 if "HALIDE_LIB" in os.environ: 2932 path = os.path.join(os.environ["HALIDE_LIB"], sofile) 2933 if os.path.exists(path): 2934 return path 2935 errmsg = ( 2936 f"Can't find {sofile}, set env HALIDE_LIB to the directory containing it" 2937 ) 2938 return HalideCodeCache._search_for_file(sofile, errmsg) 2939 2940 @staticmethod 2941 @functools.lru_cache(None) 2942 def find_header(name): 2943 if "HALIDE_INCLUDE" in os.environ: 2944 path = os.path.join(os.environ["HALIDE_INCLUDE"], name) 2945 if os.path.exists(path): 2946 return path 2947 if "HALIDE_LIB" in os.environ: 2948 path = os.path.abspath( 2949 os.path.join(os.environ["HALIDE_LIB"], f"../include/{name}") 2950 ) 2951 if os.path.exists(path): 2952 return path 2953 errmsg = ( 2954 f"Can't find {name}, set env HALIDE_INCLUDE to the directory containing it" 2955 ) 2956 return HalideCodeCache._search_for_file(f"../include/{name}", errmsg) 2957 2958 @classmethod 2959 def generate_halide_async(cls, meta: HalideMeta, source_code: str, submit_fn=None): 2960 dirpath = Path( 2961 get_path( 2962 code_hash( 2963 source_code, 2964 extra=repr((cls.config_hash(), meta)), 2965 ), 2966 "halide", 2967 )[2] 2968 ) 2969 os.makedirs(dirpath, exist_ok=True) 2970 wait_for_compile = None 2971 genfile = str(dirpath / "generate_kernel.py") 2972 libfile = str(dirpath / "halide_kernel.a") 2973 headerfile = str(dirpath / "halide_kernel.h") 2974 donefile = str(dirpath / "done") 2975 lockfile = str(dirpath / "lock") 2976 need_compile = not os.path.exists(donefile) 2977 jobs = [] 2978 2979 if need_compile: 2980 write_atomic(genfile, source_code) 2981 jobs.append( 2982 functools.partial( 2983 subprocess.check_call, 2984 [ 2985 sys.executable, 2986 genfile, 2987 "-g", 2988 "kernel", 2989 "-o", 2990 f"{dirpath}", 2991 "-f", 2992 "halide_kernel", 2993 "-e", 2994 "static_library,h,schedule,pytorch_wrapper", 2995 "-p", 2996 cls.find_libautoschedule(meta.scheduler), 2997 *meta.args(), 2998 ], 2999 ) 3000 ) 3001 3002 bindings_future = cls.load_pybinding_async( 3003 [arg.bindings_type() for arg in meta.argtypes], 3004 cls._codegen_glue(meta.argtypes, headerfile), 3005 extra_flags=(libfile,), 3006 submit_fn=jobs.append if need_compile else None, 3007 ) 3008 3009 if need_compile: 3010 jobs.append(functools.partial(touch, donefile)) 3011 task = functools.partial(_worker_task_halide, lockfile, jobs) 3012 if submit_fn: 3013 wait_for_compile = submit_fn(task).result 3014 else: 3015 task() 3016 3017 def load(): 3018 if wait_for_compile: 3019 wait_for_compile() 3020 return bindings_future() 3021 3022 return load 3023 3024 @classmethod 3025 def generate_halide(cls, *args, **kwargs): 3026 return 
cls.generate_halide_async(*args, **kwargs)() 3027 3028 3029def _worker_task_halide(lockfile, jobs): 3030 from filelock import FileLock 3031 3032 with FileLock(lockfile, LOCK_TIMEOUT): 3033 for job in jobs: 3034 job() 3035 3036 3037def touch(filename): 3038 open(filename, "a").close() 3039 3040 3041@clear_on_fresh_inductor_cache 3042class PyCodeCache: 3043 cache: Dict[str, ModuleType] = dict() 3044 linemaps: Dict[str, List[Tuple[Any, ...]]] = dict() 3045 cache_clear = staticmethod(cache.clear) 3046 3047 @classmethod 3048 def write(cls, source_code: str, extra: str = "") -> Tuple[str, str]: 3049 return write(source_code, "py", extra=extra) 3050 3051 @classmethod 3052 def load( 3053 cls, 3054 source_code: str, 3055 extra: str = "", 3056 linemap: Optional[List[Tuple[int, str]]] = None, 3057 attrs: Optional[Dict[str, Any]] = None, 3058 ) -> ModuleType: 3059 key, path = write(source_code, "py", extra=extra) 3060 return cls.load_by_key_path(key, path, linemap, attrs) 3061 3062 @classmethod 3063 def load_by_key_path( 3064 cls, 3065 key: str, 3066 path: str, 3067 linemap: Optional[List[Tuple[int, str]]] = None, 3068 attrs: Optional[Dict[str, Any]] = None, 3069 ) -> ModuleType: 3070 if linemap is None: 3071 linemap = [] 3072 if key not in cls.cache: 3073 mod = _reload_python_module(key, path) 3074 3075 # another thread might set this first 3076 cls.cache.setdefault(key, mod) 3077 # unzip into separate lines/nodes lists 3078 cls.linemaps[path] = list(zip(*linemap)) 3079 3080 if attrs is not None: 3081 for k, v in attrs.items(): 3082 setattr(mod, k, v) 3083 3084 if not (linemap or attrs): 3085 mod._reload_in_subproc = functools.partial( # type: ignore[attr-defined] 3086 _reload_python_module_in_subproc, key, path 3087 ) 3088 3089 return cls.cache[key] 3090 3091 @classmethod 3092 @functools.lru_cache(None) 3093 def stack_frames_for_code( 3094 cls, path: str, lineno: int 3095 ) -> Optional[List[Dict[str, Any]]]: 3096 if path not in cls.linemaps: 3097 return None 3098 # [(starting_line, <fx node>), ...] 
3099 lines, nodes = cls.linemaps[path] 3100 p = bisect_right(lines, lineno) 3101 if p == 0: 3102 return None 3103 entry = nodes[p - 1] 3104 if not entry: 3105 return None 3106 3107 def parse_stack_trace(stack_trace: str) -> List[Dict[str, Any]]: 3108 # ideally fx stores stack traces as data rather than a string 3109 # but this is not along a performance critical path 3110 regex = r'File "(.+)", line (\d+), in (.+)\n' 3111 matches = re.findall(regex, stack_trace) 3112 return [ 3113 {"filename": f, "line": int(l), "name": n} 3114 for f, l, n in reversed(matches) 3115 ] 3116 3117 return parse_stack_trace(entry) 3118 3119 3120class TritonCodeCache: 3121 @classmethod 3122 def load(cls, kernel_name: str, source_code: str) -> ModuleType: 3123 return _module_to_triton_kernel(PyCodeCache.load(source_code), kernel_name) 3124 3125 3126def _cuda_compiler() -> Optional[str]: 3127 if cuda_env.nvcc_exist(config.cuda.cuda_cxx): 3128 return config.cuda.cuda_cxx 3129 if config.is_fbcode(): 3130 return os.path.join(build_paths.cuda(), "bin", "nvcc") 3131 if cuda_env.nvcc_exist(os.getenv("CUDACXX")): 3132 return os.getenv("CUDACXX", "") 3133 if cuda_env.nvcc_exist(os.getenv("CUDA_HOME")): 3134 return os.path.realpath(os.path.join(os.getenv("CUDA_HOME", ""), "bin/nvcc")) 3135 return "nvcc" 3136 3137 3138def _cutlass_include_paths() -> List[str]: 3139 if config.is_fbcode(): 3140 from libfb.py import parutil 3141 3142 cutlass_path = parutil.get_dir_path("cutlass-3-headers") 3143 else: 3144 cutlass_path = config.cuda.cutlass_dir 3145 return [ 3146 # Use realpath to get canonical absolute paths, in order not to mess up cache keys 3147 os.path.realpath(os.path.join(cutlass_path, "include")), 3148 os.path.realpath(os.path.join(cutlass_path, "tools/library/include")), 3149 os.path.realpath(os.path.join(cutlass_path, "tools/library/src")), 3150 os.path.realpath(os.path.join(cutlass_path, "tools/util/include")), 3151 ] 3152 3153 3154def _cuda_lib_options() -> List[str]: 3155 _set_gpu_runtime_env() # cpp_extension consults the env 3156 from torch.utils import cpp_extension 3157 3158 lpaths = cpp_extension.library_paths(cuda=True) + [ 3159 sysconfig.get_config_var("LIBDIR") 3160 ] 3161 extra_ldflags: List[str] = [] 3162 if is_linux(): 3163 _transform_cuda_paths(lpaths) 3164 for path in lpaths: 3165 # -rpath ensures the DLL can find its dependencies when loaded, even 3166 # if the library path is non-standard. 3167 extra_ldflags.extend([f"-L{path}", "-Xlinker", f"-rpath={path}"]) 3168 extra_ldflags.append("-lcuda") 3169 extra_ldflags.append("-lcudart") 3170 else: 3171 raise NotImplementedError( 3172 "Unsupported env, failed to find cuda libs! Currently only Linux is supported." 3173 ) 3174 return extra_ldflags 3175 3176 3177def _nvcc_host_compiler_options() -> List[str]: 3178 return [ 3179 "-fPIC", 3180 "-fno-strict-aliasing", 3181 "-fvisibility=hidden", 3182 "-Wconversion", 3183 ] 3184 3185 3186def _nvcc_compiler_options() -> List[str]: 3187 arch = cuda_env.get_cuda_arch() 3188 if arch == "90": 3189 # Required by cutlass compilation. 
3190 arch = "90a" 3191 code = [f"sm_{arch}", f"compute_{arch}"] 3192 if config.cuda.enable_cuda_lto: 3193 code += [f"lto_{arch}"] 3194 options = [ 3195 "-t=0", 3196 "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", 3197 "-w", 3198 f"-gencode=arch=compute_{arch},code=[{','.join(code)}]", 3199 config.cuda.compile_opt_level, 3200 "-std=c++17", 3201 "--expt-relaxed-constexpr", 3202 "-DNDEBUG", 3203 ] 3204 if config.is_fbcode(): 3205 options.extend(["-ccbin", os.path.dirname(build_paths.gcc())]) 3206 if config.cuda.enable_debug_info: 3207 options.extend(["-lineinfo", "-g", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"]) 3208 if config.cuda.enable_ptxas_info: 3209 options.extend( 3210 [ 3211 "--keep", # Keep the intermediate files for debugging (including ptx, sass, cubin etc.) 3212 "--ptxas-options=--warn-on-local-memory-usage", # warn us if local memory is used in CUDA Kernels 3213 "--ptxas-options=--warn-on-spills", # warn us if register spilling happens in CUDA Kernels 3214 "--resource-usage", # Report on CUDA resource usage (shared mem, registers etc.) 3215 "--source-in-ptx", 3216 ] 3217 ) # Annotate the ptx file with source information 3218 if config.cuda.use_fast_math: 3219 options.extend( 3220 [ 3221 "--use_fast_math", 3222 "-DCUTLASS_USE_TANH_FOR_SIGMOID=1", 3223 ] 3224 ) 3225 return options 3226 3227 3228def cuda_compile_command( 3229 src_files: List[str], 3230 dst_file: str, 3231 dst_file_ext: str, 3232 extra_args: Optional[List[str]] = None, 3233) -> str: 3234 if extra_args is None: 3235 extra_args = [] 3236 include_paths = _cutlass_include_paths() 3237 cuda_lib_options = _cuda_lib_options() 3238 nvcc_host_compiler_options = _nvcc_host_compiler_options() 3239 nvcc_compiler_options = _nvcc_compiler_options() 3240 options = ( 3241 nvcc_compiler_options 3242 + extra_args 3243 + [ 3244 f"-Xcompiler {opt}" if "=" in opt else f"-Xcompiler={opt}" 3245 for opt in nvcc_host_compiler_options 3246 ] 3247 + ["-I" + path for path in include_paths] 3248 + cuda_lib_options 3249 ) 3250 src_file = " ".join(src_files) 3251 res = "" 3252 if dst_file_ext == "o": 3253 res = f"{_cuda_compiler()} {' '.join(options)} -c -o {dst_file} {src_file}" 3254 elif dst_file_ext == "so": 3255 options.append("-shared") 3256 res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}" 3257 elif dst_file_ext == "exe": 3258 res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}" 3259 else: 3260 raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!") 3261 log.debug("CUDA command: %s", res) 3262 return res 3263 3264 3265class DLLWrapper: 3266 """A wrapper for a dynamic library.""" 3267 3268 def __init__( 3269 self, 3270 lib_path: str, 3271 ): 3272 self.lib_path = lib_path 3273 self.is_open = False 3274 self.DLL = cdll.LoadLibrary(lib_path) 3275 self.is_open = True 3276 3277 def close(self): 3278 if self.is_open: 3279 self._dlclose() 3280 self.is_open = False 3281 3282 def _dlclose(self): 3283 f_dlclose = None 3284 3285 if is_linux(): 3286 syms = CDLL(None) 3287 if not hasattr(syms, "dlclose"): 3288 # Apline Linux 3289 syms = CDLL("libc.so") 3290 3291 if hasattr(syms, "dlclose"): 3292 f_dlclose = syms.dlclose 3293 else: 3294 raise NotImplementedError("Unsupported env, failed to do dlclose!") 3295 3296 if f_dlclose is not None: 3297 f_dlclose.argtypes = [c_void_p] 3298 f_dlclose(self.DLL._handle) 3299 else: 3300 log.warning( 3301 "dll unloading function was not found, library may not be unloaded properly!" 
3302 ) 3303 3304 def __getattr__(self, name): 3305 if not self.is_open: 3306 raise RuntimeError(f"Cannot use closed DLL library: {self.lib_path}") 3307 3308 method = getattr(self.DLL, name) 3309 3310 def _wrapped_func(*args): 3311 err = method(*args) 3312 if err: 3313 raise RuntimeError(f"Error in function: {method.__name__}") 3314 3315 return _wrapped_func 3316 3317 def __enter__(self): 3318 return self 3319 3320 def __exit__(self, *args): 3321 self.close() 3322 3323 def __del__(self): 3324 self.close() 3325 3326 3327@clear_on_fresh_inductor_cache 3328class CUDACodeCache: 3329 @dataclasses.dataclass 3330 class CacheEntry: 3331 input_path: str 3332 output_path: str 3333 3334 cache: Dict[str, CacheEntry] = dict() 3335 cache_clear = staticmethod(cache.clear) 3336 _SOURCE_CODE_SUFFIX = "cu" 3337 3338 @classmethod 3339 def write(cls, source_code, dst_file_ext) -> Tuple[str, str]: 3340 """ 3341 Writes source code into a file with dst_file_ext as the file extension. 3342 Returns the hash key of source code, and the path to the file. 3343 """ 3344 3345 cuda_command = repr( 3346 cuda_compile_command(["dummy_input"], "dummy_output", dst_file_ext) 3347 ) 3348 key, input_path = write( 3349 source_code, cls._SOURCE_CODE_SUFFIX, extra=cuda_command 3350 ) 3351 return key, input_path 3352 3353 @classmethod 3354 def compile( 3355 cls, source_code, dst_file_ext, extra_args: Optional[List[str]] = None 3356 ) -> Tuple[str, str, str]: 3357 """ 3358 Compiles CUDA source_code into a file with dst_file_ext extension. 3359 Returns a tuple of dst_file_path, hash_key, source_code_path 3360 """ 3361 key, input_path = cls.write(source_code, dst_file_ext) 3362 if key not in cls.cache: 3363 from filelock import FileLock 3364 3365 lock_dir = get_lock_dir() 3366 lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) 3367 with lock: 3368 output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext 3369 if not os.path.exists(output_path): 3370 cmd = cuda_compile_command( 3371 [input_path], output_path, dst_file_ext, extra_args 3372 ) 3373 start_time = time() 3374 log.debug("CUDA Compilation: %s", cmd) 3375 cmd_parts = cmd.split(" ") 3376 try: 3377 subprocess.check_output( 3378 cmd_parts, stderr=subprocess.STDOUT, env=os.environ 3379 ) 3380 except subprocess.CalledProcessError as error: 3381 raise exc.CUDACompileError(cmd_parts, error.output) from error 3382 end_time = time() 3383 log_duration_msg = f"CUDA Compilation took {end_time-start_time} seconds. Compile command: {cmd}" 3384 log.info(log_duration_msg) 3385 else: 3386 log.debug( 3387 "CUDA Compilation skipped: %s since output already exists", 3388 input_path, 3389 ) 3390 cls.cache[key] = CUDACodeCache.CacheEntry(input_path, output_path) 3391 3392 return (cls.cache[key].output_path, key, input_path) 3393 3394 @classmethod 3395 def load(cls, source_code, dst_file_ext) -> Tuple[DLLWrapper, str, str]: 3396 """ 3397 Compiles source code and loads the generated .so file. 3398 Returns a tuple of DLLWrapper, hash_key, source_code_path 3399 """ 3400 3401 if dst_file_ext != "so": 3402 raise RuntimeError( 3403 f"Only support loading a .so file for now. " 3404 f"Requested file extension: {dst_file_ext}. 
Source code: {source_code}" 3405 ) 3406 dst_file_path, hash_key, source_code_path = cls.compile( 3407 source_code, dst_file_ext 3408 ) 3409 return (DLLWrapper(dst_file_path), hash_key, source_code_path) 3410 3411 3412class CodeCacheFuture: 3413 def result(self): 3414 raise NotImplementedError 3415 3416 3417class TritonFuture(CodeCacheFuture): 3418 kernel: ModuleType 3419 3420 def __init__( 3421 self, 3422 kernel: Any, 3423 future: Optional[Future[Any]], 3424 ) -> None: 3425 self.kernel = kernel 3426 self.future = future 3427 3428 # @dynamo_utils.dynamo_timed 3429 def result(self) -> ModuleType: 3430 if self.future is not None: 3431 # If the worker failed this will throw an exception. 3432 result = self.future.result() 3433 assert result is None 3434 self.future = None 3435 self.kernel.precompile() 3436 return self.kernel 3437 3438 3439class LambdaFuture(CodeCacheFuture): 3440 def __init__(self, result_fn): 3441 self.result_fn = result_fn 3442 3443 def result(self): 3444 return self.result_fn() 3445
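# ---------------------------------------------------------------------------
# Usage sketches (illustrative only, not exercised anywhere in this module).
# They assume an environment where inductor's C++ toolchain is usable (a host
# C++ compiler plus the Python development headers); the `_demo_*` names below
# are local to this sketch and are not part of the module's API.
if __name__ == "__main__":
    # PyCodeCache: write Python source into the on-disk cache and import it
    # as a module keyed by the hash of the source.
    _demo_mod = PyCodeCache.load("def add(a, b):\n    return a + b\n")
    assert _demo_mod.add(2, 3) == 5

    # CppPythonBindingsCodeCache: compile a tiny C++ entry point named
    # `kernel` (the default entry_function) and call it through the generated
    # Python bindings. Pointer arguments are passed as torch.Tensors and
    # integer arguments as Python ints, matching the parse_arg<> helpers above.
    _demo_src = """
    void kernel(float* out, long n) {
        for (long i = 0; i < n; ++i) {
            out[i] = static_cast<float>(i);
        }
    }
    """
    _demo_fn = CppPythonBindingsCodeCache.load_pybinding(["float*", "long"], _demo_src)
    _demo_out = torch.empty(8, dtype=torch.float32)
    _demo_fn(_demo_out, _demo_out.numel())
    assert torch.equal(_demo_out, torch.arange(8, dtype=torch.float32))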