xref: /aosp_15_r20/external/pytorch/torch/_inductor/codecache.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import base64
5import copyreg
6import dataclasses
7import functools
8import hashlib
9import importlib
10import io
11import json
12import logging
13import os
14import pickle
15import pkgutil
16import platform
17import re
18import shlex
19import shutil
20import struct
21import subprocess
22import sys
23import sysconfig
24import tempfile
25import textwrap
26import threading
27import warnings
28from bisect import bisect_right
29from copy import copy
30from ctypes import c_void_p, cdll, CDLL
31from functools import partial
32from pathlib import Path
33from time import time, time_ns
34from types import ModuleType
35from typing import (
36    Any,
37    Callable,
38    cast,
39    Dict,
40    Generator,
41    List,
42    Optional,
43    Sequence,
44    Set,
45    Tuple,
46    TYPE_CHECKING,
47    Union,
48)
49
50import torch
51from torch._dynamo.utils import counters, dynamo_timed
52from torch._inductor import config, exc, metrics
53from torch._inductor.codegen.cuda import cuda_env
54from torch._inductor.runtime.compile_tasks import (
55    _module_to_triton_kernel,
56    _reload_python_module,
57    _reload_python_module_in_subproc,
58)
59from torch._inductor.runtime.runtime_utils import cache_dir
60from torch._inductor.utils import ALIGN_BYTES, clear_on_fresh_inductor_cache, is_linux
61
62from torch._logging import trace_structured
63from torch._subclasses.fake_tensor import (
64    extract_tensor_metadata,
65    FakeTensor,
66    TensorMetadata,
67)
68from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv
69
70if TYPE_CHECKING:
71    from concurrent.futures import Future
72
73    from torch._inductor.graph import GraphLowering
74    from torch._inductor.ir import ChoiceCaller
75    from torch._inductor.runtime.hints import HalideMeta
76
77
78_HERE = os.path.abspath(__file__)
79_TORCH_PATH = os.path.dirname(os.path.dirname(_HERE))
80_LINKER_SCRIPT = os.path.join(_TORCH_PATH, "_inductor/script.ld")
81
82_IS_WINDOWS = sys.platform == "win32"
83
84if config.is_fbcode():
85    from triton.fb import build_paths
86    from triton.fb.build import _run_build_command
87
88    from torch._inductor.fb.utils import (
89        log_global_cache_errors,
90        log_global_cache_stats,
91        log_global_cache_vals,
92        use_global_cache,
93    )
94else:
95
96    def log_global_cache_errors(*args, **kwargs):
97        pass
98
99    def log_global_cache_stats(*args, **kwargs):
100        pass
101
102    def log_global_cache_vals(*args, **kwargs):
103        pass
104
105    def use_global_cache() -> bool:
106        return False
107
108
109output_code_log = torch._logging.getArtifactLogger(__name__, "output_code")
110
111LOCK_TIMEOUT = 600
112
113_IS_WINDOWS = sys.platform == "win32"
114
115
116log = logging.getLogger(__name__)
117
118
119def cpp_wrapper_cache_dir(name: str) -> str:
120    cu_str = (
121        "cpu"
122        if torch.version.cuda is None
123        else f'cu{torch.version.cuda.replace(".", "")}'
124    )
125    python_version = f"py{sys.version_info.major}{sys.version_info.minor}"
126    build_folder = f"{python_version}_{cu_str}"
127
128    cpp_wrapper_dir = os.path.join(cache_dir(), build_folder)
129    cpp_wrapper_build_directory = os.path.join(cpp_wrapper_dir, name)
130    os.makedirs(cpp_wrapper_build_directory, exist_ok=True)
131    return cpp_wrapper_build_directory
132
133
134def get_cpp_wrapper_cubin_path_name():
135    return "cubin_path" if torch.version.hip is None else "hsaco_path"
136
137
138class CacheBase:
139    @staticmethod
140    @functools.lru_cache(None)
141    def get_system() -> Dict[str, Any]:
142        try:
143            from triton.compiler.compiler import triton_key
144
145            # Use triton_key instead of triton.__version__ as the version
146            # is not updated with each code change
147            triton_version = triton_key()
148        except ModuleNotFoundError:
149            triton_version = None
150
151        try:
152            system: Dict[str, Any] = {
153                "device": {
154                    "name": torch.cuda.get_device_properties(
155                        torch.cuda.current_device()
156                    ).name,
157                },
158                "version": {
159                    "cuda": torch.version.cuda,
160                    "triton": triton_version,
161                },
162            }
163        except (AssertionError, RuntimeError):
164            # If cuda is not installed, none of the above config is relevant.
165            system = {}
166
167        system["hash"] = hashlib.sha256(
168            json.dumps(system, sort_keys=True).encode("utf-8")
169        ).hexdigest()
170
171        return system
172
173    @staticmethod
174    @clear_on_fresh_inductor_cache
175    @functools.lru_cache(None)
176    def get_local_cache_path() -> Path:
177        return Path(os.path.join(cache_dir(), "cache", CacheBase.get_system()["hash"]))
178
179    @staticmethod
180    @functools.lru_cache(None)
181    def get_global_cache_path() -> Optional[Path]:
182        return (
183            Path(os.path.join(config.global_cache_dir, CacheBase.get_system()["hash"]))
184            if config.global_cache_dir is not None
185            else None
186        )
187
188    def __init__(self) -> None:
189        self.system = CacheBase.get_system()
190
191    def get_local_cache(self) -> Dict[str, Any]:
192        local_cache_path = self.get_local_cache_path()
193        if not local_cache_path.is_file():
194            return {}
195        with open(local_cache_path) as local_cache_fp:
196            local_cache = json.load(local_cache_fp)
197        return local_cache["cache"]
198
199    def update_local_cache(self, local_cache: Dict[str, Any]) -> None:
200        local_cache_path = self.get_local_cache_path()
201        write_atomic(
202            str(local_cache_path),
203            json.dumps({"system": self.system, "cache": local_cache}, indent=4),
204            make_dirs=True,
205        )
206
207
208class LocalCache(CacheBase):
209    def lookup(self, *keys: str) -> Optional[Dict[str, Any]]:
210        cache = self.get_local_cache()
211
212        sub_cache = cache
213        for key in keys:
214            if key in cache:
215                sub_cache = cache[key]
216            else:
217                return None
218
219        return sub_cache
220
221    def set_value(self, *keys: str, value: Any) -> None:
222        cache = self.get_local_cache()
223
224        sub_cache = cache
225        for key in keys[0:-1]:
226            sub_cache.setdefault(key, {})
227            sub_cache = sub_cache[key]
228        sub_cache[keys[-1]] = value
229
230        self.update_local_cache(cache)
231
232
233class PersistentCache(CacheBase):
234    @functools.lru_cache(None)  # noqa: B019
235    def get_global_cache(self):
236        global_cache_path = self.get_global_cache_path()
237        if global_cache_path is None or not global_cache_path.is_file():
238            return {}
239        with open(global_cache_path) as global_cache_fp:
240            global_cache = json.load(global_cache_fp)
241        return global_cache["cache"]
242
243    def lookup(
244        self,
245        choices: List[ChoiceCaller],
246        op: str,
247        inputs: str,
248        benchmark: Optional[Callable[[Any], Dict[ChoiceCaller, float]]],
249    ) -> Dict[ChoiceCaller, float]:
250        """
251        Check to see if we have benchmarked the given choice callers. For each
252        choice caller:
253
254            1. Check global_cache[op][inputs][choice][precision], return benchmark if cached.
255            2. Check local_cache[op][inputs][choice][precision], return benchmark if cached.
256            3. If benchmark is not None:
257                a. `max_autotune_gemm=True`: benchmark the choice, update
258                    local_cache[op][inputs][choice], and return the benchmark.
259                b. `max_autotune_gemm=False`: don't benchmark the choice, return nothing.
260        """
261        precision = torch.get_float32_matmul_precision()
262
263        log_stats = partial(log_global_cache_stats, self.system, op, inputs, precision)
264        log_vals = partial(log_global_cache_vals, self.system, op, inputs, precision)
265        log_errors = partial(
266            log_global_cache_errors, self.system, op, inputs, precision
267        )
268        timings = {}
269
270        def check_cache(cache, callback=None) -> bool:
271            """Check if `cache` contains data for all the choices"""
272            hit = True
273            for choice in choices:
274                choice_hash = choice.hash_key()
275                if choice_hash in cache.get(op, {}).get(inputs, {}).get(precision, {}):
276                    # cache hit
277                    timings[choice] = cache[op][inputs][precision][choice_hash]
278                else:
279                    # cache miss
280                    hit = False
281                    break
282            if callback:
283                callback(cached=hit)
284            return hit
285
286        if config.max_autotune or config.max_autotune_gemm:
287            local_cache = self.get_local_cache() if config.autotune_local_cache else {}
288            # check local cache first since it is data specific to the current machine
289            if (
290                not check_cache(local_cache)
291                and not (
292                    use_global_cache()
293                    and check_cache(self.get_global_cache(), callback=log_stats)
294                )
295                and benchmark is not None
296            ):
297                try:
298                    # re-benchmark everything to try to get consistent numbers from the same machine
299                    timings = benchmark(choices)
300                    assert all(choice in timings for choice in choices)
301                    local_cache.setdefault(op, {})
302                    local_cache[op].setdefault(inputs, {}).setdefault(precision, {})
303                    for choice, timing in timings.items():
304                        local_cache[op][inputs][precision][choice.hash_key()] = timing
305                except RuntimeError as e:
306                    # catch and log autotuning failures
307                    log_errors(e)
308                    raise e
309
310                self.update_local_cache(local_cache)
311
312                timings_to_log = {
313                    choice.hash_key(): timings[choice] for choice in choices
314                }
315                log_vals(timings_to_log)
316        elif use_global_cache():
317            # only check global cache, not local one
318            check_cache(self.get_global_cache(), callback=log_stats)
319            # may have a partial cache hit, where not everything is benchmarked
320
321        return timings
322
323
324def get_lock_dir() -> str:
325    lock_dir = os.path.join(cache_dir(), "locks")
326    if not os.path.exists(lock_dir):
327        os.makedirs(lock_dir, exist_ok=True)
328    return lock_dir
329
330
331def sha256_hash(data: bytes) -> str:
332    # [:51] to strip off the "Q====" suffix common to every hash value.
333    return base64.b32encode(hashlib.sha256(data).digest())[:51].decode("utf-8").lower()
334
335
336def code_hash(code: Union[str, bytes], extra: str = ""):
337    hashing_str = code if isinstance(code, bytes) else code.encode("utf-8")
338    if extra != "":
339        hashing_str = hashing_str + b"||" + extra.encode("utf-8")
340    return "c" + sha256_hash(hashing_str)
341
342
343def get_path(
344    basename: str, extension: str, specified_dir: str = ""
345) -> Tuple[str, str, str]:
346    if specified_dir:
347        if os.path.isabs(specified_dir):
348            subdir = specified_dir
349        else:
350            subdir = os.path.join(cache_dir(), specified_dir)
351    else:
352        subdir = os.path.join(cache_dir(), basename[1:3])
353    path = os.path.join(subdir, f"{basename}.{extension}")
354    return basename, subdir, path
355
356
357def get_hash(content: Union[str, bytes], extra: str = "", hash_type: str = "code"):
358    if hash_type == "code":
359        return code_hash(content, extra)
360    if hash_type in ["cubin", "hsaco"]:
361        return code_hash(repr(content))
362    raise AssertionError(f"Unknown hash type {hash_type}")
363
364
365def write(
366    content: Union[str, bytes],
367    extension: str,
368    extra: str = "",
369    hash_type: str = "code",
370    specified_dir: str = "",
371) -> Tuple[str, str]:
372    # use striped content to compute hash so we don't end up with different
373    # hashes just because the content begins/ends with different number of
374    # spaces.
375    key: str = get_hash(content.strip(), extra, hash_type)
376    basename, subdir, path = get_path(key, extension, specified_dir)
377    if not os.path.exists(path):
378        write_atomic(path, content, make_dirs=True)
379    return basename, path
380
381
382def write_text(text: str) -> str:
383    """
384    Write the `text` to a file and return the path computed based on the hash.
385    """
386    return write(text, "txt")[1]
387
388
389def write_atomic(
390    path: str, content: Union[str, bytes], make_dirs: bool = False
391) -> None:
392    # Write into temporary file first to avoid conflicts between threads
393    # Avoid using a named temporary file, as those have restricted permissions
394    assert isinstance(
395        content, (str, bytes)
396    ), "Only strings and byte arrays can be saved in the cache"
397    path = Path(path)
398    if make_dirs:
399        path.parent.mkdir(parents=True, exist_ok=True)
400    tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp"
401    write_mode = "w" if isinstance(content, str) else "wb"
402    with tmp_path.open(write_mode) as f:
403        f.write(content)
404    tmp_path.rename(path)
405
406
407@dataclasses.dataclass
408class TensorMetadataAndValues:
409    """
410    TensorMetadata plus the elements as a list of raw values.
411    Used for hashing inlined constants.
412    """
413
414    tensor_metadata: TensorMetadata
415    values: List[Any]
416
417
418def _ident(x: Any) -> Any:
419    return x
420
421
422def extract_tensor_metadata_for_cache_key(t):
423    """
424    Extracts the tensor metadata and removes fields of the TensorMetadata
425    that are not needed for caching
426    """
427    meta = extract_tensor_metadata(t)
428    if not hasattr(t, "_is_inductor_static"):
429        meta = dataclasses.replace(meta, storage_offset=0, storage_bytes=None)
430    return meta
431
432
433def _reduce_fake_tensor(t):
434    """
435    See FxGraphCachePickler. Custom reducer to pickle FakeTensors.
436    """
437    metadata = extract_tensor_metadata_for_cache_key(t)
438    return (_ident, (metadata,))
439
440
441def _reduce_tensor(t):
442    """
443    See FxGraphCachePickler. Custom reducer to pickle Tensors.
444    If we see tensors, we know they're constants stored as attributes on
445    the GraphModule. Include the values in the key calculation. Small
446    tensors will be inlined, so we can't serve the same cache entry for
447    different values anyway. Large constants are treated as parameters,
448    so we could conceivably reuse a cache entry. To do that, however,
449    PyCodeCache would need more complexity to create a new module from its
450    cache, but with the right constants attached as attributes.
451    """
452    if t.is_mkldnn:
453        # TODO: These tensors don't currently pickle, so we can't cache a
454        # compiled graph containing them. Just fail now. If mkldnn tensors
455        # get pickling support, we can remove this.
456        raise BypassFxGraphCache
457
458    # Very large tensors could be expensive to copy to cpu and hash. Let's
459    # at least report if we find slowness.
460    start = time()
461    values = t.tolist()
462    elapsed = time() - start
463    if elapsed > 1.0:
464        warnings.warn(
465            f"FX graph cache handling of a large constant took {elapsed:.1}s. Please file an issue."
466        )
467
468    metadata = extract_tensor_metadata_for_cache_key(t)
469    return (_ident, (TensorMetadataAndValues(metadata, values),))
470
471
472def _reduce_symint(s):
473    """
474    See FxGraphCachePickler. Custom reducer to pickle SymInts.
475    """
476    # For hashing purposes, we only care about the name of the symbol and
477    # not the backed value. We evaluate guards stored with a cached graph
478    # to ensure a cached entity with SymInt args is safe to reuse.
479    return (_ident, (str(s),))
480
481
482def _reduce_unsupported(s):
483    """
484    See FxGraphCachePickler. Custom reducer to handle any objects that we don't
485    support and therefore raise to bypass caching.
486    """
487    raise BypassFxGraphCache
488
489
490class FxGraphCachePickler(pickle.Pickler):
491    """
492    Custom pickler to customize the pickling of some objects (Tensors), only for the
493    purpose of computing a hash for keying into the FxGraphCache. Tensors contain
494    objects that don't pickle and/or vary between runs, and we want to capture the
495    data that allow us to compute a stable, but safe hash.
496    """
497
498    dispatch_table = copyreg.dispatch_table.copy()
499    dispatch_table[FakeTensor] = _reduce_fake_tensor
500    dispatch_table[torch.Tensor] = _reduce_tensor
501    dispatch_table[torch.SymInt] = _reduce_symint
502    dispatch_table[
503        torch.fx.experimental._backward_state.BackwardState
504    ] = _reduce_unsupported
505
506    @classmethod
507    def dumps(cls, obj) -> bytes:
508        """
509        Pickle an object using the FxGraphCachePickler.
510        """
511        with io.BytesIO() as stream:
512            pickler = cls(stream)
513            try:
514                pickler.dump(obj)
515            except (TypeError, AttributeError) as e:
516                # Some configs options are callables, e.g., post_grad_custom_pre_pass,
517                # and may not pickle.
518                log.warning("Can't pickle", exc_info=True)
519                raise BypassFxGraphCache from e
520            return stream.getvalue()
521
522    @classmethod
523    def get_hash(cls, obj: Any) -> str:
524        """
525        Serialize an object using the FxGraphCachePickler and return a hash
526        of the pickled object.
527        """
528        serialized_data = cls.dumps(obj)
529        return sha256_hash(serialized_data)
530
531    @classmethod
532    def debug_str(cls, inp: Any) -> str:
533        """
534        Get a printable string describing in more detail all the attributes
535        comprising an object. Useful for debugging when one graph hashes
536        to a different value than another.
537        """
538
539        def get_str(obj) -> str:
540            if isinstance(obj, torch.Tensor):
541                return str(extract_tensor_metadata_for_cache_key(obj))
542            elif isinstance(obj, bytes):
543                return "<bytes>"
544            else:
545                return str(obj)
546
547        lines = []
548        for attr, obj in vars(inp).items():
549            if isinstance(obj, list):
550                for ii in range(len(obj)):
551                    h = cls.get_hash(obj[ii])
552                    lines.append(f"[{h}] {attr}[{ii}]: {get_str(obj[ii])}")
553            elif isinstance(obj, dict):
554                for k, v in obj.items():
555                    h = cls.get_hash(v)
556                    lines.append(f"[{h}] {attr}[{k}]: {get_str(v)}")
557            else:
558                h = cls.get_hash(obj)
559                lines.append(f"[{h}] {attr}: {get_str(obj)}")
560        return "\n".join(lines)
561
562
563def build_code_hash(roots, prefix, hasher):
564    for lib in sorted(pkgutil.iter_modules(roots, prefix), key=lambda x: x.name):
565        spec = lib.module_finder.find_spec(lib.name, None)
566        assert spec is not None
567        module = spec.origin
568        assert module is not None
569        with open(module, "rb") as f:
570            hasher.update(spec.name.encode("utf-8"))
571            hasher.update(f.read())
572        if lib.ispkg:
573            # need to also hash submodules
574            build_code_hash(spec.submodule_search_locations, f"{spec.name}.", hasher)
575
576
577def get_code_hash(roots, extra_files=()):
578    hasher = hashlib.sha256()
579    hasher.update(torch.__version__.encode("utf-8"))
580    build_code_hash(roots, "", hasher)
581    for path in extra_files:
582        if os.path.exists(path):
583            with open(path, "rb") as f:
584                hasher.update(f.read())
585    return hasher.digest()
586
587
588@functools.lru_cache(None)
589def torch_key():
590    """
591    Compute a key that contains relevant information about torch source files
592    """
593    if not config.is_fbcode():
594        inductor_root = os.path.dirname(__file__)
595        extra_files = (
596            "codegen/aoti_runtime/interface.cpp",
597            "codegen/aoti_runtime/implementation.cpp",
598            "codegen/cpp_prefix.h",
599            "script.ld",
600        )
601        return get_code_hash(
602            [inductor_root], [os.path.join(inductor_root, x) for x in extra_files]
603        )
604
605    from libfb.py import parutil
606
607    return parutil.get_file_contents("torch/src_hash.txt").rstrip()
608
609
610def get_inductor_root():
611    return os.path.dirname(__file__)
612
613
614@dataclasses.dataclass
615class OrderedSetHolder:
616    """
617    See FxGraphHashDetails. Holds a sorted list to support stable hashing
618    of set kwargs.
619    """
620
621    items: List[Any]
622
623
624class BypassFxGraphCache(Exception):
625    """
626    Exception to indicate that the FxGraphCache should be bypassed.
627    """
628
629    pass
630
631
632class FxGraphHashDetails:
633    """
634    Object to capture all the details for a compiled FX graph relevant to computing
635    a safe and stable cache key.
636    """
637
638    # Excluded kwargs param that are not stable between runs
639    EXCLUDED_KWARGS = ["graph_id"]
640
641    def __init__(
642        self,
643        gm: torch.fx.GraphModule,
644        example_inputs: List[torch.Tensor],
645        fx_kwargs: Dict[str, Any],
646        inputs_to_check: Sequence[int],
647    ):
648        self.gm = gm
649        self.example_inputs = example_inputs
650
651        # Order kwargs so hashing is stable to changes in kwarg order.
652        self.fx_kwargs = {}
653        for k in sorted(fx_kwargs):
654            if k not in self.EXCLUDED_KWARGS:
655                if type(fx_kwargs[k]) is set:
656                    # Special case to handle set params. Python sets can't be
657                    # ordered, so sort the elements and store them in a proxy.
658                    self.fx_kwargs[k] = OrderedSetHolder(sorted(fx_kwargs[k]))
659                else:
660                    self.fx_kwargs[k] = fx_kwargs[k]
661
662        # Alignment checks
663        self.inputs_to_check = inputs_to_check
664
665        # 'Deterministic algorithms' can affect codegen via lowering to cuda kernels.
666        self.deterministic_algorithms_settings = (
667            torch.are_deterministic_algorithms_enabled(),
668            torch.is_deterministic_algorithms_warn_only_enabled(),
669            torch.utils.deterministic.fill_uninitialized_memory,  # type: ignore[attr-defined]
670        )
671
672        # Global settings affecting matmul codegen.
673        self.cuda_matmul_settings = (
674            torch.backends.cuda.matmul.allow_tf32,
675            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction,
676            torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction,
677        )
678
679        # Also hash on various system info (including the triton compiler version).
680        self.torch_version = torch_key()
681        self.system_info = CacheBase.get_system()
682        self.inductor_config = config.save_config_portable()
683
684    def debug_str(self) -> str:
685        """
686        Get a printable string describing in more detail all the attributes
687        comprising this object. Useful for debugging when one graph hashes
688        to a different value than another.
689        """
690        return FxGraphCachePickler.debug_str(self)
691
692
693def compiled_fx_graph_hash(
694    gm: torch.fx.GraphModule,
695    example_inputs: List[torch.Tensor],
696    fx_kwargs: Dict[str, Any],
697    inputs_to_check: Sequence[int],
698) -> str:
699    """
700    Generate a unique hash of the FX graph for caching.
701    """
702    details = FxGraphHashDetails(gm, example_inputs, fx_kwargs, inputs_to_check)
703    # The prefix distinguishes among the other kinds of objects we
704    # cache in this module.
705    key = "f" + FxGraphCachePickler.get_hash(details)
706    debug_str = details.debug_str()
707    log.debug(f"FX graph cache hash details for key {key}:\n{debug_str}")  # noqa: G004
708    torch._logging.trace_structured(
709        "artifact",
710        metadata_fn=lambda: {
711            "name": "fx_graph_cache_hash",
712            "encoding": "json",
713        },
714        payload_fn=lambda: json.dumps(
715            {"key": key, "components": debug_str.split("\n")}
716        ),
717    )
718
719    return key
720
721
722class FxGraphCache:
723    """
724    Supports caching and reusing compiled Fx graphs.
725
726    The overall strategy is as follows:
727    - This cache stores entries on disk. When saving an entry, we can't
728      serialize callables (that could be C++, Triton, etc.), so we serialize
729      their own disk cache location. We then recreate the compiled artifact
730      after fetching from disk.
731    - For indexing the cache, we gather the fields relevant to identifying an
732      FxGraph (the graph module, graph inputs, system settings etc.) into an
733      FxGraphCacheDetails object, pickle it, and compute a hash for the key.
734      See FxGraphCachePickler.
735    - Among the metadata we store, we also include a guards expression that's
736      appropriate for validating any symbols for Tensor arguments that have
737      symbolic bounds. On cache lookup then, we evaluate those guards in the
738      current context to validate that a cached entry can be served.
739    - A given graph could have multiple compiled versions, corresponding to
740      different sets of guards. Therefore, we store cache entries in the form:
741          <temp dir>/<fx graph hash>/<serialized metatdata>
742    - On lookup, we compute the key from the graph details, iterate over all
743      leaf files in the corresponding subdirectory, deserialize the entry, and
744      evaluate its guards expression. If the evaluation succeeds, we have a
745      cache hit. If it fails, we compile the graph and store a new entry.
746    - Finally, on a cache hit, we need to make sure any guards that would
747      have been created during compilation are added to the current context.
748    """
749
750    # TODO(masnesral): Investigate whether it's beneficial to store compiled graphs
751    # in an in-memory cache after loading from disk.
752    @staticmethod
753    def _get_tmp_dir() -> str:
754        """
755        Get the toplevel temporary directory for storing compiled graphs.
756        """
757        return os.path.join(cache_dir(), "fxgraph")
758
759    @staticmethod
760    def _get_tmp_dir_for_key(key: str) -> str:
761        """
762        Return the disk location for a given cache key.
763        """
764        return os.path.join(FxGraphCache._get_tmp_dir(), key[1:3], key)
765
766    @staticmethod
767    def _filter_backed_symints(inputs: List[Any]) -> List[torch.SymInt]:
768        """
769        Get the backed SymInt objects from the input list. Note that we can never
770        have guards that depend on unbacked symint.
771        """
772        return [s for s in inputs if isinstance(s, torch.SymInt) and has_hint(s)]
773
774    @staticmethod
775    def _get_shape_env() -> Optional[ShapeEnv]:
776        """
777        Helper to get the shape env from the tracing context.
778        """
779        ctx = torch._guards.TracingContext.try_get()
780        if not ctx:
781            return None
782        return ctx.fake_mode.shape_env
783
784    @staticmethod
785    def _lookup_graph(
786        key: str,
787        example_inputs: List[torch.Tensor],
788        local,
789        remote_cache,
790    ) -> Optional[CompiledFxGraph]:
791        """
792        Lookup a compiled graph in the cache by key. On a hit, return the
793        deserialized CompiledFxGraph object. On a miss, return None.
794        """
795        shape_env = FxGraphCache._get_shape_env()
796        assert shape_env is not None
797
798        symints = FxGraphCache._filter_backed_symints(example_inputs)
799        hints = [hint_int(s) for s in symints]
800
801        def iterate_over_candidates() -> Generator[CompiledFxGraph, None, None]:
802            if local:
803                subdir = FxGraphCache._get_tmp_dir_for_key(key)
804                if os.path.exists(subdir):
805                    for path in sorted(os.listdir(subdir)):
806                        try:
807                            with open(os.path.join(subdir, path), "rb") as f:
808                                yield pickle.load(f)
809                        except Exception:
810                            log.warning(
811                                "fx graph cache unable to load compiled graph",
812                                exc_info=True,
813                            )
814
815            if remote_cache:
816                try:
817                    if (data := remote_cache.get(key)) is not None:
818                        yield pickle.loads(data)
819                except Exception:
820                    log.warning(
821                        "fx graph cache unable to load compiled graph", exc_info=True
822                    )
823
824        # Iterate over any entries in the subdir for this key and evaluate
825        # their guards to determine whether there's a hit.
826        graph = None
827
828        for candidate in iterate_over_candidates():
829            if not candidate.guards_expr:
830                # No guards to evaluate, so this is a hit.
831                graph = candidate
832                break
833
834            # Evaluate the guard expression in the current context.
835            # If there's not a cache hit, we don't want the evaluation to
836            # affect the current env, e.g., cause the creation of new guards,
837            # so we evaluate with the hints instead of the symbols.
838            hit = bool(
839                shape_env.evaluate_guards_expression(candidate.guards_expr, hints)
840            )
841            log.debug(
842                "fx graph cache key %s evaluating guards [%s] with values %s => hit=%s",
843                key,
844                candidate.guards_expr,
845                hints,
846                hit,
847            )
848            if hit:
849                graph = candidate
850                break
851
852        if graph is None:
853            return None
854
855        # See _save_graph(); we don't store the callable in the cache entry so
856        # recreate it here from the PyCodeCache disk cache.
857        artifact_path = get_path(graph.cache_key, "py")[2]
858        if not os.path.exists(artifact_path):
859            counters["inductor"]["fxgraph_lookup_write_file"] += 1
860            Path(os.path.dirname(artifact_path)).mkdir(parents=True, exist_ok=True)
861            code = graph.source_code
862            cpp_pp = cpp_prefix_path()
863            if os.path.basename(cpp_pp) in code:
864                if cpp_pp in code:
865                    # Great the name is correct
866                    pass
867                else:
868                    # Old dir name is included, replace it
869                    pattern = rf'#include\s*"[^"]+{os.path.basename(cpp_pp)}"'
870                    code = re.sub(pattern, f'#include "{cpp_pp}"', code)
871
872            write_atomic(artifact_path, code, make_dirs=True)
873
874        try:
875            graph.current_callable = PyCodeCache.load_by_key_path(
876                graph.cache_key,
877                artifact_path,
878                graph.cache_linemap,
879                graph.constants,
880            ).call
881        except OSError:
882            # Not expected, but in case the PyCodeCache entry is removed from
883            # underneath us, treat it as a cache miss and recompile.
884            log.error("Failed to load cached artifact: %s", artifact_path)
885            return None
886
887        # Now re-evaluate with the symints to add any guards to the current env.
888        if graph.guards_expr:
889            check = bool(
890                shape_env.evaluate_guards_expression(graph.guards_expr, symints)
891            )
892            assert check is True
893            log.debug(
894                "fx graph cache key %s post-load guards: %s", key, shape_env.guards
895            )
896
897        # Increment the cached metrics by the amounts recorded when the FX
898        # graph was compiled for this cache entry. Pretending these counters
899        # were incremented normally is useful for testing with the cache enabled.
900        metrics.CachedMetricsHelper.apply_deltas(graph.metrics_deltas)
901
902        return graph
903
904    @staticmethod
905    def _save_graph(
906        key: str,
907        compiled_graph: CompiledFxGraph,
908        example_inputs: List[torch.Tensor],
909        time_taken_ns,
910        local,
911        remote_cache,
912    ):
913        """
914        Store a serialized CompiledFxGraph on disk.
915        """
916        disk_compiled_graph = copy(compiled_graph)
917        # We can't really serialize callables that may be C++/Triton/etc.,
918        # so we serialize their PyCodeCache disk cache location instead.
919        # TODO: This could be better if we're ever able to serialize compiled
920        # models to disk.
921        disk_compiled_graph.current_callable = None
922
923        # Before serializing, compute the guard expression that will be used to
924        # ensure that a CompiledFxGraph is valid when loaded from the cache. It's
925        # sufficient to consider only the SymInt args to the fx graph since the
926        # Tensor shapes are already captured in the hash for the cache key. Any
927        # Tensor arg with a symbolic shape will have a SymInt arg for the graph.
928        shape_env = FxGraphCache._get_shape_env()
929        assert shape_env is not None
930        symints = FxGraphCache._filter_backed_symints(example_inputs)
931        guards = shape_env.get_pruned_guards(symints)
932        disk_compiled_graph.guards_expr = shape_env.produce_guards_expression(
933            placeholders=symints, guards=guards
934        )
935
936        try:
937            content = pickle.dumps(disk_compiled_graph)
938        except Exception:
939            log.warning(
940                "fx graph cache unable to serialize compiled graph", exc_info=True
941            )
942            counters["inductor"]["fxgraph_cache_pickle_error"] += 1
943            return
944
945        try:
946            if local:
947                subdir = FxGraphCache._get_tmp_dir_for_key(key)
948                if not os.path.exists(subdir):
949                    os.makedirs(subdir, exist_ok=True)
950
951                # Use a hash of the serialized CompiledFxGraph to get a unique file
952                # name. The specific name doesn't matter since a lookup involves
953                # iterating over all entries in the parent subdir.
954                path = os.path.join(subdir, sha256_hash(content))
955                write_atomic(path, content, make_dirs=True)
956
957            if remote_cache:
958                cache_data = (
959                    {
960                        "data": content,
961                        "time_taken_ms": time_taken_ns
962                        // 1000000,  # Convert from NS to MS
963                    }
964                    if config.is_fbcode()
965                    else content
966                )
967                remote_cache.put(key, cache_data)
968        except Exception:
969            log.warning("fx graph unable to write to cache", exc_info=True)
970            counters["inductor"]["fxgraph_cache_write_error"] += 1
971
972    @staticmethod
973    def _check_can_cache(gm: torch.fx.GraphModule):
974        """
975        Check some conditions that would preclude caching and raise BypassFxGraphCache
976        to bypass in case caching is not possible.
977        """
978        # Freezing can embed constants that wouldn't be static across runs.
979        if config.freezing or config.aot_inductor.use_runtime_constant_folding:
980            raise BypassFxGraphCache
981
982        # The treatment of guards in the caching implementation requires that
983        # we have a shape env.
984        if FxGraphCache._get_shape_env() is None:
985            log.debug("fx graph cache no shape env")
986            raise BypassFxGraphCache
987
988        # HigherOrderOperators should be handled on a case-by-case basis.
989        # Currently, we just skip caching if we have any.
990        # We also skip if there are any torchbind objects.
991        for node in gm.graph.nodes:
992            if isinstance(node.target, torch._ops.HigherOrderOperator):
993                raise BypassFxGraphCache
994            if node.op == "getattr" and isinstance(
995                getattr(gm, node.target), torch._C.ScriptObject
996            ):
997                raise BypassFxGraphCache
998
999    @staticmethod
1000    def load(
1001        compile_fx_fn: Callable[..., Any],
1002        gm: torch.fx.GraphModule,
1003        example_inputs: List[torch.Tensor],
1004        fx_kwargs: Dict[str, Any],
1005        inputs_to_check: Sequence[int],
1006        local: bool,
1007        remote: bool,
1008    ):
1009        """
1010        Load a compiled graph from the cache. If a cached entry does not exist,
1011        compile the graph and save it to the cache.
1012        """
1013        assert local or remote, "at least one of them needs to be enabled"
1014        compiled_graph = None
1015        try:
1016            FxGraphCache._check_can_cache(gm)
1017            key = compiled_fx_graph_hash(gm, example_inputs, fx_kwargs, inputs_to_check)
1018
1019            remote_cache = None
1020            if remote:
1021                cache_id = "fx-graph-v1"
1022                try:
1023                    if config.is_fbcode():
1024                        from triton.fb.fb_memcache import (
1025                            FbMemcacheRemoteFxGraphCacheBackend,
1026                        )
1027
1028                        remote_cache = FbMemcacheRemoteFxGraphCacheBackend(cache_id)
1029                    else:
1030                        from torch._inductor.remote_cache import RedisRemoteCacheBackend
1031
1032                        remote_cache = RedisRemoteCacheBackend(cache_id)
1033                except Exception:
1034                    remote_cache = None
1035                    log.warning("Unable to create a remote cache", exc_info=True)
1036
1037            compiled_graph = FxGraphCache._lookup_graph(
1038                key, example_inputs, local, remote_cache
1039            )
1040            if compiled_graph is None:
1041                log.debug("fx graph cache miss for key %s", key)
1042                counters["inductor"]["fxgraph_cache_miss"] += 1
1043                start_time = time_ns()
1044                compiled_graph = compile_fx_fn(gm, example_inputs, **fx_kwargs)
1045                time_taken_ns = time_ns() - start_time
1046                FxGraphCache._save_graph(
1047                    key,
1048                    compiled_graph,
1049                    example_inputs,
1050                    time_taken_ns,
1051                    local,
1052                    remote_cache,
1053                )
1054            else:
1055                log.debug("fx graph cache hit for key %s", key)
1056                counters["inductor"]["fxgraph_cache_hit"] += 1
1057        except BypassFxGraphCache:
1058            counters["inductor"]["fxgraph_cache_bypass"] += 1
1059            if not compiled_graph:
1060                compiled_graph = compile_fx_fn(gm, example_inputs, **fx_kwargs)
1061
1062        return compiled_graph
1063
1064    @staticmethod
1065    def clear():
1066        """
1067        Clear out the on-disk cache.
1068        """
1069        try:
1070            shutil.rmtree(FxGraphCache._get_tmp_dir())
1071        except FileNotFoundError:
1072            pass
1073
1074
1075@dataclasses.dataclass
1076class CompiledFxGraph:
1077    """
1078    Class holding a compiled FX graph. This is the object serialized on disk
1079    to support FxGraph caching.
1080    """
1081
1082    current_callable: Optional[Callable[..., Any]]
1083    cache_key: str
1084    source_code: str = dataclasses.field(repr=False)  # Do not display source_code
1085    cache_linemap: Optional[List[Tuple[int, str]]]
1086    device_types: Set[str]
1087    device_idxs: Set[int]
1088    mutated_inputs: Set[str]
1089    mutated_input_idxs: Set[int]
1090    constants: Dict[str, torch.Tensor]
1091    torchbind_constants: Dict[str, torch._C.ScriptObject]
1092    output_strides: Optional[List[Optional[Tuple[int, ...]]]]
1093    disabled_cudagraphs_reason: Optional[str]
1094    metrics_deltas: metrics.CachedMetricsDeltas
1095    # This is a string representation of an expression we serialize
1096    # with the object so the guards can be evaluated in a different
1097    # context in order to verify the validity of serving a cached
1098    # fx graph. The expression must be generated by:
1099    # ShapeEnv.produce_guards_expression()
1100    guards_expr: Optional[str]
1101
1102    _boxed_call: Optional[bool] = None
1103
1104    def __init__(
1105        self,
1106        current_callable: Optional[Callable[..., Any]],
1107        graph: GraphLowering,
1108        output_strides: List[Optional[Tuple[int, ...]]],
1109        disabled_cudagraphs_reason: Optional[str],
1110        metrics_deltas: metrics.CachedMetricsDeltas,
1111    ):
1112        self.current_callable = current_callable
1113        self.cache_key = graph.cache_key
1114        if graph.cache_path:
1115            with open(graph.cache_path) as f:
1116                self.source_code = f.read()
1117        self.cache_linemap = graph.cache_linemap
1118        self.device_types = graph.device_types
1119        self.device_idxs = graph.device_idxs
1120        self.mutated_inputs = graph.mutated_inputs
1121        self.mutated_input_idxs = set(graph.mutated_input_idxs)
1122        self.constants = graph.constants
1123        self.torchbind_constants = graph.torchbind_constants
1124        self.output_strides = output_strides
1125        self.disabled_cudagraphs_reason = disabled_cudagraphs_reason
1126        self.metrics_deltas = metrics_deltas
1127        self.guards_expr = None
1128
1129    def __call__(self, inputs: List[Any]) -> Any:
1130        assert self.current_callable is not None
1131        return self.current_callable(inputs)
1132
1133
1134def cpp_compiler() -> str:
1135    if config.is_fbcode():
1136        return build_paths.cc() if torch.version.hip is None else build_paths.clang()
1137    if isinstance(config.cpp.cxx, (list, tuple)):
1138        search = tuple(config.cpp.cxx)
1139    else:
1140        search = (config.cpp.cxx,)
1141    return cpp_compiler_search(search)
1142
1143
1144@functools.lru_cache(1)
1145def cpp_compiler_search(search: str) -> str:
1146    for cxx in search:
1147        try:
1148            if cxx is None:
1149                # gxx package is only available for Linux
1150                # according to https://anaconda.org/conda-forge/gxx/
1151                if sys.platform != "linux":
1152                    continue
1153                # Do not install GXX by default
1154                if not os.getenv("TORCH_INDUCTOR_INSTALL_GXX"):
1155                    continue
1156                from filelock import FileLock
1157
1158                lock_dir = get_lock_dir()
1159                lock = FileLock(
1160                    os.path.join(lock_dir, "g++.lock"), timeout=LOCK_TIMEOUT
1161                )
1162                with lock:
1163                    cxx = install_gcc_via_conda()
1164            subprocess.check_output([cxx, "--version"])
1165            return cxx
1166        except (subprocess.SubprocessError, FileNotFoundError, ImportError):
1167            continue
1168    raise exc.InvalidCxxCompiler
1169
1170
1171def install_gcc_via_conda() -> str:
1172    """On older systems, this is a quick way to get a modern compiler"""
1173    prefix = os.path.join(cache_dir(), "gcc")
1174    cxx_path = os.path.join(prefix, "bin", "g++")
1175    if not os.path.exists(cxx_path):
1176        log.info("Downloading GCC via conda")
1177        conda = os.environ.get("CONDA_EXE", "conda")
1178        if conda is None:
1179            conda = shutil.which("conda")
1180        if conda is not None:
1181            subprocess.check_call(
1182                [
1183                    conda,
1184                    "create",
1185                    f"--prefix={prefix}",
1186                    "--channel=conda-forge",
1187                    "--quiet",
1188                    "-y",
1189                    "python=3.8",
1190                    "gxx",
1191                ],
1192                stdout=subprocess.PIPE,
1193            )
1194    return cxx_path
1195
1196
1197def is_gcc() -> bool:
1198    if sys.platform == "darwin" and is_apple_clang():
1199        return False
1200    return bool(re.search(r"(gcc|g\+\+)", cpp_compiler()))
1201
1202
1203@functools.lru_cache(None)
1204def is_apple_clang() -> bool:
1205    cxx = cpp_compiler()
1206    version_string = subprocess.check_output([cxx, "--version"]).decode("utf8")
1207    return "Apple" in version_string.splitlines()[0]
1208
1209
1210def is_clang() -> bool:
1211    # Mac OS apple clang maybe named as gcc, need check compiler info.
1212    if sys.platform == "darwin":
1213        return is_apple_clang()
1214    return bool(re.search(r"(clang|clang\+\+)", cpp_compiler()))
1215
1216
1217def get_compiler_version_info(compiler):
1218    SUBPROCESS_DECODE_ARGS = ("oem",) if _IS_WINDOWS else ()
1219    env = os.environ.copy()
1220    env["LC_ALL"] = "C"  # Don't localize output
1221    try:
1222        version_string = subprocess.check_output(
1223            [compiler, "-v"], stderr=subprocess.STDOUT, env=env
1224        ).decode(*SUBPROCESS_DECODE_ARGS)
1225    except Exception as e:
1226        try:
1227            version_string = subprocess.check_output(
1228                [compiler, "--version"], stderr=subprocess.STDOUT, env=env
1229            ).decode(*SUBPROCESS_DECODE_ARGS)
1230        except Exception as e:
1231            return ""
1232    # Mutiple lines to one line string.
1233    version_string = version_string.replace("\r", "_")
1234    version_string = version_string.replace("\n", "_")
1235    return version_string
1236
1237
1238def _get_isa_dry_compile_fingerprint(isa_flags: str) -> str:
1239    # ISA dry compile will cost about 1 sec time each startup time.
1240    # Please check the issue: https://github.com/pytorch/pytorch/issues/100378
1241    # Actually, dry compile is checking compile capability for ISA.
1242    # We just record the compiler version, isa options and pytorch version info,
1243    # and generated them to output binary hash path.
1244    # It would optimize and skip compile existing binary.
1245    compiler_info = get_compiler_version_info(cpp_compiler())
1246    torch_version = torch.__version__
1247    fingerprint = f"{compiler_info}={isa_flags}={torch_version}"
1248    return fingerprint
1249
1250
1251class VecISA:
1252    _bit_width: int
1253    _macro: List[str]
1254    _arch_flags: str
1255    _dtype_nelements: Dict[torch.dtype, int]
1256
1257    # Note [Checking for Vectorized Support in Inductor]
1258    # TorchInductor CPU vectorization reuses PyTorch vectorization utility functions
1259    # Hence, TorchInductor would depend on Sleef* to accelerate mathematical functions
1260    # like exp, pow, sin, cos and etc.
1261    # But PyTorch and TorchInductor might use different compilers to build code. If
1262    # PyTorch uses gcc-7/g++-7 to build the release package, the libtorch_cpu.so
1263    # will not expose the Sleef* AVX512 symbols since gcc-7/g++-7 cannot pass
1264    # avx512 check in CMake - FindAVX.cmake. But TorchInductor install the latest
1265    # gcc/g++ compiler by default while it could support the AVX512 compilation.
1266    # Therefore, there would be a conflict sleef version between PyTorch and
1267    # TorchInductor. Hence, we dry-compile the following code to check whether current
1268    # HW platform and PyTorch both could support AVX512 or AVX2. And suppose ARM
1269    # also needs the logic
1270    # In fbcode however, we are using the same compiler for pytorch and for inductor codegen,
1271    # making the runtime check unnecessary.
1272    _avx_code = """
1273#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON)
1274#include <ATen/cpu/vec/functional.h>
1275#include <ATen/cpu/vec/vec.h>
1276#endif
1277
1278alignas(64) float in_out_ptr0[16] = {0.0};
1279
1280extern "C" void __avx_chk_kernel() {
1281    auto tmp0 = at::vec::Vectorized<float>(1);
1282    auto tmp1 = tmp0.exp();
1283    tmp1.store(in_out_ptr0);
1284}
1285"""  # noqa: B950
1286
1287    _avx_py_load = """
1288import torch
1289from ctypes import cdll
1290cdll.LoadLibrary("__lib_path__")
1291"""
1292
1293    def bit_width(self) -> int:
1294        return self._bit_width
1295
1296    def nelements(self, dtype: torch.dtype = torch.float) -> int:
1297        return self._dtype_nelements[dtype]
1298
1299    def build_macro(self) -> List[str]:
1300        return self._macro
1301
1302    def build_arch_flags(self) -> str:
1303        return self._arch_flags
1304
1305    def __hash__(self) -> int:
1306        return hash(str(self))
1307
1308    @functools.lru_cache(None)  # noqa: B019
1309    def __bool__(self) -> bool:
1310        from torch._inductor.cpp_builder import CppBuilder, CppTorchOptions
1311
1312        if config.cpp.vec_isa_ok is not None:
1313            return config.cpp.vec_isa_ok
1314
1315        if config.is_fbcode():
1316            return True
1317
1318        key, input_path = write(
1319            VecISA._avx_code,
1320            "cpp",
1321            extra=_get_isa_dry_compile_fingerprint(self._arch_flags),
1322        )
1323        from filelock import FileLock
1324
1325        lock_dir = get_lock_dir()
1326        lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
1327        with lock:
1328            output_dir = os.path.dirname(input_path)
1329            buid_options = CppTorchOptions(vec_isa=self, warning_all=False)
1330            x86_isa_help_builder = CppBuilder(
1331                key,
1332                [input_path],
1333                buid_options,
1334                output_dir,
1335            )
1336            try:
1337                # Check if the output file exist, and compile when not.
1338                output_path = x86_isa_help_builder.get_target_file_path()
1339                if not os.path.isfile(output_path):
1340                    status, target_file = x86_isa_help_builder.build()
1341                    if status:
1342                        return False
1343
1344                # Check build result
1345                subprocess.check_call(
1346                    [
1347                        sys.executable,
1348                        "-c",
1349                        VecISA._avx_py_load.replace("__lib_path__", output_path),
1350                    ],
1351                    stderr=subprocess.DEVNULL,
1352                    env={**os.environ, "PYTHONPATH": ":".join(sys.path)},
1353                )
1354            except Exception as e:
1355                return False
1356
1357            return True
1358
1359
1360@dataclasses.dataclass
1361class VecNEON(VecISA):
1362    _bit_width = 256  # This is required to leverage the compute implemented in aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h
1363    _macro = ["CPU_CAPABILITY_NEON"]
1364    if sys.platform == "darwin" and platform.processor() == "arm":
1365        _macro.append("AT_BUILD_ARM_VEC256_WITH_SLEEF")
1366    _arch_flags = ""  # Unused
1367    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
1368
1369    def __str__(self) -> str:
1370        return "asimd"  # detects the presence of advanced SIMD on armv8-a kernels
1371
1372    __hash__: Callable[[VecISA], Any] = VecISA.__hash__
1373
1374
1375@dataclasses.dataclass
1376class VecAVX512(VecISA):
1377    _bit_width = 512
1378    _macro = ["CPU_CAPABILITY_AVX512"]
1379    _arch_flags = (
1380        "-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma"
1381        if not _IS_WINDOWS
1382        else "/arch:AVX512"
1383    )  # TODO: use cflags
1384    _dtype_nelements = {torch.float: 16, torch.bfloat16: 32, torch.float16: 32}
1385
1386    def __str__(self) -> str:
1387        return "avx512"
1388
1389    __hash__: Callable[[VecISA], Any] = VecISA.__hash__
1390
1391
1392@dataclasses.dataclass
1393class VecAVX2(VecISA):
1394    _bit_width = 256
1395    _macro = ["CPU_CAPABILITY_AVX2"]
1396    _arch_flags = (
1397        "-mavx2 -mfma" if not _IS_WINDOWS else "/arch:AVX2"
1398    )  # TODO: use cflags
1399    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
1400
1401    def __str__(self) -> str:
1402        return "avx2"
1403
1404    __hash__: Callable[[VecISA], Any] = VecISA.__hash__
1405
1406
1407@dataclasses.dataclass
1408class VecZVECTOR(VecISA):
1409    _bit_width = 256
1410    _macro = [
1411        "CPU_CAPABILITY_ZVECTOR",
1412        "CPU_CAPABILITY=ZVECTOR",
1413        "HAVE_ZVECTOR_CPU_DEFINITION",
1414    ]
1415    _arch_flags = "-mvx -mzvector"
1416    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
1417
1418    def __str__(self) -> str:
1419        return "zvector"
1420
1421    __hash__: Callable[[VecISA], Any] = VecISA.__hash__
1422
1423
1424class InvalidVecISA(VecISA):
1425    _bit_width = 0
1426    _macro = [""]
1427    _arch_flags = ""
1428    _dtype_nelements = {}
1429
1430    def __str__(self) -> str:
1431        return "INVALID_VEC_ISA"
1432
1433    def __bool__(self) -> bool:  # type: ignore[override]
1434        return False
1435
1436    __hash__: Callable[[VecISA], Any] = VecISA.__hash__
1437
1438
1439def x86_isa_checker() -> List[str]:
1440    supported_isa: List[str] = []
1441
1442    def _check_and_append_supported_isa(
1443        dest: List[str], isa_supported: bool, isa_name: str
1444    ):
1445        if isa_supported:
1446            dest.append(isa_name)
1447
1448    Arch = platform.machine()
1449    """
1450    Arch value is x86_64 on Linux, and the value is AMD64 on Windows.
1451    """
1452    if Arch != "x86_64" and Arch != "AMD64":
1453        return supported_isa
1454
1455    avx2 = torch.cpu._is_cpu_support_avx2()
1456    avx512 = torch.cpu._is_cpu_support_avx512()
1457
1458    _check_and_append_supported_isa(supported_isa, avx2, "avx2")
1459    _check_and_append_supported_isa(supported_isa, avx512, "avx512")
1460
1461    return supported_isa
1462
1463
1464invalid_vec_isa = InvalidVecISA()
1465supported_vec_isa_list = [VecAVX512(), VecAVX2(), VecNEON()]
1466
1467
1468# Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content
1469# might have too much redundant content that is useless for ISA check. Hence,
1470# we only cache some key isa information.
1471@functools.lru_cache(None)
1472def valid_vec_isa_list() -> List[VecISA]:
1473    isa_list: List[VecISA] = []
1474    if sys.platform == "darwin" and platform.processor() == "arm":
1475        isa_list.append(VecNEON())
1476
1477    if sys.platform not in ["linux", "win32"]:
1478        return isa_list
1479
1480    if platform.machine() == "s390x":
1481        with open("/proc/cpuinfo") as _cpu_info:
1482            while True:
1483                line = _cpu_info.readline()
1484                if not line:
1485                    break
1486                # process line
1487                featuresmatch = re.match(r"^features\s*:\s*(.*)$", line)
1488                if featuresmatch:
1489                    for group in featuresmatch.groups():
1490                        if re.search(r"[\^ ]+vxe[\$ ]+", group):
1491                            isa_list.append(VecZVECTOR())
1492                            break
1493    elif platform.machine() == "aarch64":
1494        isa_list.append(VecNEON())
1495    elif platform.machine() in ["x86_64", "AMD64"]:
1496        """
1497        platform.machine() value is x86_64 on Linux, and the value is AMD64 on Windows.
1498        """
1499        _cpu_supported_x86_isa = x86_isa_checker()
1500        for isa in supported_vec_isa_list:
1501            if str(isa) in _cpu_supported_x86_isa and isa:
1502                isa_list.append(isa)
1503
1504    return isa_list
1505
1506
1507def pick_vec_isa() -> VecISA:
1508    if config.is_fbcode():
1509        return VecAVX2()
1510
1511    _valid_vec_isa_list: List[VecISA] = valid_vec_isa_list()
1512    if not _valid_vec_isa_list:
1513        return invalid_vec_isa
1514
1515    # If the simdlen is None, it indicates determine the vectorization length automatically
1516    if config.cpp.simdlen is None:
1517        assert _valid_vec_isa_list
1518        return _valid_vec_isa_list[0]
1519
1520    for isa in _valid_vec_isa_list:
1521        if config.cpp.simdlen == isa.bit_width():
1522            return isa
1523
1524    return invalid_vec_isa
1525
1526
1527def get_compile_only(compile_only: bool = True) -> str:
1528    return "-c" if compile_only else ""
1529
1530
1531def get_shared(shared: bool = True, compile_only: bool = False) -> str:
1532    if not shared:
1533        return ""
1534    if compile_only:
1535        return "-fPIC"
1536    if platform.system() == "Darwin" and "clang" in cpp_compiler():
1537        # This causes undefined symbols to behave the same as linux
1538        return "-shared -fPIC -undefined dynamic_lookup"
1539    else:
1540        return "-shared -fPIC"
1541
1542
1543def get_warning_all_flag(warning_all: bool = True) -> str:
1544    return "-Wall" if warning_all else ""
1545
1546
1547def get_glibcxx_abi_build_flags() -> str:
1548    return "-D_GLIBCXX_USE_CXX11_ABI=" + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))
1549
1550
1551def cpp_flags() -> str:
1552    flags = ["-std=c++17", "-Wno-unused-variable", "-Wno-unknown-pragmas"]
1553    if is_clang():
1554        flags.append("-Werror=ignored-optimization-argument")
1555    return " ".join(flags)
1556
1557
1558def cpp_wrapper_flags() -> str:
1559    return "-D TORCH_INDUCTOR_CPP_WRAPPER"
1560
1561
1562def optimization_flags() -> str:
1563    base_flags = "-O0 -g" if config.aot_inductor.debug_compile else "-O3 -DNDEBUG"
1564    base_flags += " -ffast-math -fno-finite-math-only"
1565    if not config.cpp.enable_unsafe_math_opt_flag:
1566        base_flags += " -fno-unsafe-math-optimizations"
1567    if not config.cpp.enable_floating_point_contract_flag:
1568        base_flags += " -ffp-contract=off"
1569
1570    if config.is_fbcode():
1571        # FIXME: passing `-fopenmp` adds libgomp.so to the generated shared library's dependencies.
1572        # This causes `ldopen` to fail in fbcode, because libgomp does not exist in the default paths.
1573        # We will fix it later by exposing the lib path.
1574        return base_flags
1575
1576    if sys.platform == "darwin":
1577        # Per https://mac.r-project.org/openmp/ right way to pass `openmp` flags to MacOS is via `-Xclang`
1578        # Also, `-march=native` is unrecognized option on M1
1579        base_flags += " -Xclang"
1580    else:
1581        if platform.machine() == "ppc64le":
1582            base_flags += " -mcpu=native"
1583        else:
1584            base_flags += " -march=native"
1585
1586    # Internal cannot find libgomp.so
1587    if not config.is_fbcode():
1588        base_flags += " -fopenmp"
1589    return base_flags
1590
1591
1592def use_custom_generated_macros() -> str:
1593    return "-D C10_USING_CUSTOM_GENERATED_MACROS"
1594
1595
1596def use_fb_internal_macros() -> str:
1597    if config.is_fbcode():
1598        # TODO: this is to avoid FC breakage for fbcode. When using newly
1599        # generated model.so on an older verion of PyTorch, need to use
1600        # the v1 version for aoti_torch_create_tensor_from_blob
1601        create_tensor_from_blob_v1 = "-D AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1"
1602        openmp_lib = build_paths.openmp_lib()
1603        preprocessor_flags = " ".join(
1604            (
1605                "-D C10_USE_GLOG",
1606                "-D C10_USE_MINIMAL_GLOG",
1607                "-D C10_DISABLE_TENSORIMPL_EXTENSIBILITY",
1608            )
1609        )
1610        return f"-Wp,-fopenmp {openmp_lib} {preprocessor_flags} {create_tensor_from_blob_v1}"
1611    else:
1612        return ""
1613
1614
1615def use_standard_sys_dir_headers() -> str:
1616    if config.is_fbcode():
1617        return "-nostdinc"
1618    else:
1619        return ""
1620
1621
1622@functools.lru_cache(None)
1623def is_conda_llvm_openmp_installed() -> bool:
1624    try:
1625        command = "conda list llvm-openmp --json"
1626        output = subprocess.check_output(command.split()).decode("utf8")
1627        return len(json.loads(output)) > 0
1628    except subprocess.SubprocessError:
1629        return False
1630
1631
1632@functools.lru_cache(None)
1633def homebrew_libomp() -> Tuple[bool, str]:
1634    try:
1635        # check if `brew` is installed
1636        subprocess.check_output(["which", "brew"])
1637        # get the location of `libomp` if it is installed
1638        # this is the location that `libomp` **would** be installed
1639        # see https://github.com/Homebrew/brew/issues/10261#issuecomment-756563567 for details
1640        libomp_path = (
1641            subprocess.check_output(["brew", "--prefix", "libomp"])
1642            .decode("utf8")
1643            .strip()
1644        )
1645        # check if `libomp` is installed
1646        omp_available = os.path.exists(libomp_path)
1647        return omp_available, libomp_path
1648    except subprocess.SubprocessError:
1649        return False, ""
1650
1651
1652def _set_gpu_runtime_env() -> None:
1653    if (
1654        config.is_fbcode()
1655        and torch.version.hip is None
1656        and "CUDA_HOME" not in os.environ
1657        and "CUDA_PATH" not in os.environ
1658    ):
1659        os.environ["CUDA_HOME"] = build_paths.cuda()
1660
1661
1662def _get_python_include_dirs():
1663    include_dir = Path(sysconfig.get_path("include"))
1664    # On Darwin Python executable from a framework can return
1665    # non-existing /Library/Python/... include path, in which case
1666    # one should use Headers folder from the framework
1667    if not include_dir.exists() and platform.system() == "Darwin":
1668        std_lib = Path(sysconfig.get_path("stdlib"))
1669        include_dir = (std_lib.parent.parent / "Headers").absolute()
1670    if not (include_dir / "Python.h").exists():
1671        warnings.warn(f"Can't find Python.h in {str(include_dir)}")
1672    return [str(include_dir)]
1673
1674
1675def _transform_cuda_paths(lpaths):
1676    # This handles two cases:
1677    # 1. Meta internal cuda-12 where libs are in lib/cuda-12 and lib/cuda-12/stubs
1678    # 2. Linux machines may have CUDA installed under either lib64/ or lib/
1679    for i, path in enumerate(lpaths):
1680        if (
1681            "CUDA_HOME" in os.environ
1682            and path.startswith(os.environ["CUDA_HOME"])
1683            and not os.path.exists(f"{path}/libcudart_static.a")
1684        ):
1685            for root, dirs, files in os.walk(path):
1686                if "libcudart_static.a" in files:
1687                    lpaths[i] = os.path.join(path, root)
1688                    lpaths.append(os.path.join(lpaths[i], "stubs"))
1689                    break
1690
1691
1692def get_include_and_linking_paths(
1693    include_pytorch: bool = False,
1694    vec_isa: VecISA = invalid_vec_isa,
1695    cuda: bool = False,
1696    aot_mode: bool = False,
1697) -> Tuple[List[str], str, str, str, str]:
1698    _set_gpu_runtime_env()
1699    from torch.utils import cpp_extension
1700
1701    # Remove below in the further
1702    # macros = "-D {}".format(vec_isa.build_macro()) if vec_isa != invalid_vec_isa else ""
1703    macros = ""
1704    if vec_isa != invalid_vec_isa:
1705        for x in vec_isa.build_macro():
1706            macros_def = f"-D {x} "
1707            macros += macros_def
1708
1709    build_arch_flags = ""
1710    if sys.platform == "linux" and (
1711        include_pytorch
1712        or vec_isa != invalid_vec_isa
1713        or cuda
1714        or config.cpp.enable_kernel_profile
1715    ):
1716        # Note - We include pytorch only on linux right now. There is more work
1717        # to do to enable OMP build on darwin where PyTorch is built with IOMP
1718        # and we need a way to link to what PyTorch links.
1719        ipaths = cpp_extension.include_paths(cuda) + _get_python_include_dirs()
1720        lpaths = cpp_extension.library_paths(cuda) + [
1721            sysconfig.get_config_var("LIBDIR")
1722        ]
1723
1724        libs = []
1725
1726        # No need to manually specify libraries in fbcode.
1727        if not config.is_fbcode():
1728            libs += ["torch", "torch_cpu"]
1729            libs += ["gomp"]
1730            if not aot_mode:
1731                libs += ["torch_python"]
1732        else:
1733            # internal remote execution is able to find omp, but not gomp
1734            libs += ["omp"]
1735            if aot_mode:
1736                ipaths += [os.path.dirname(cpp_prefix_path())]
1737                if cuda and torch.version.hip is None:
1738                    _transform_cuda_paths(lpaths)
1739        if macros:
1740            if config.is_fbcode() and vec_isa != invalid_vec_isa:
1741                cap = str(vec_isa).upper()
1742                macros = " ".join(
1743                    [
1744                        vec_isa.build_arch_flags(),
1745                        f"-D CPU_CAPABILITY={cap}",
1746                        f"-D CPU_CAPABILITY_{cap}",
1747                        f"-D HAVE_{cap}_CPU_DEFINITION",
1748                    ]
1749                )
1750
1751        if cuda:
1752            if macros is None:
1753                macros = ""
1754            macros += " -D USE_ROCM" if torch.version.hip else " -D USE_CUDA"
1755
1756        if cuda:
1757            if torch.version.hip is not None:
1758                if config.is_fbcode():
1759                    libs += ["amdhip64"]
1760                else:
1761                    libs += ["c10_hip", "torch_hip"]
1762                macros += " -D __HIP_PLATFORM_AMD__"
1763            else:
1764                if config.is_fbcode():
1765                    libs += ["cuda"]
1766                else:
1767                    libs += ["c10_cuda", "cuda", "torch_cuda"]
1768        build_arch_flags = vec_isa.build_arch_flags()
1769    else:
1770        # Note - this is effectively a header only inclusion. Usage of some header files may result in
1771        # symbol not found, if those header files require a library.
1772        # For those cases, include the lpath and libs command as we do for pytorch above.
1773        # This approach allows us to only pay for what we use.
1774        ipaths = cpp_extension.include_paths(cuda) + _get_python_include_dirs()
1775        if aot_mode:
1776            ipaths += [os.path.dirname(cpp_prefix_path())]
1777        lpaths = []
1778        if sys.platform == "darwin":
1779            # only Apple builtin compilers (Apple Clang++) require openmp
1780            omp_available = not is_apple_clang()
1781
1782            # check the `OMP_PREFIX` environment first
1783            if os.getenv("OMP_PREFIX") is not None:
1784                header_path = os.path.join(os.getenv("OMP_PREFIX"), "include", "omp.h")  # type: ignore[arg-type]
1785                valid_env = os.path.exists(header_path)
1786                if valid_env:
1787                    ipaths.append(os.path.join(os.getenv("OMP_PREFIX"), "include"))  # type: ignore[arg-type]
1788                    lpaths.append(os.path.join(os.getenv("OMP_PREFIX"), "lib"))  # type: ignore[arg-type]
1789                else:
1790                    warnings.warn("environment variable `OMP_PREFIX` is invalid.")
1791                omp_available = omp_available or valid_env
1792
1793            libs = [] if omp_available else ["omp"]
1794
1795            # prefer to use openmp from `conda install llvm-openmp`
1796            if not omp_available and os.getenv("CONDA_PREFIX") is not None:
1797                omp_available = is_conda_llvm_openmp_installed()
1798                if omp_available:
1799                    conda_lib_path = os.path.join(os.getenv("CONDA_PREFIX"), "lib")  # type: ignore[arg-type]
1800                    ipaths.append(os.path.join(os.getenv("CONDA_PREFIX"), "include"))  # type: ignore[arg-type]
1801                    lpaths.append(conda_lib_path)
1802                    # Prefer Intel OpenMP on x86 machine
1803                    if os.uname().machine == "x86_64" and os.path.exists(
1804                        os.path.join(conda_lib_path, "libiomp5.dylib")
1805                    ):
1806                        libs = ["iomp5"]
1807
1808            # next, try to use openmp from `brew install libomp`
1809            if not omp_available:
1810                omp_available, libomp_path = homebrew_libomp()
1811                if omp_available:
1812                    ipaths.append(os.path.join(libomp_path, "include"))
1813                    lpaths.append(os.path.join(libomp_path, "lib"))
1814
1815            # if openmp is still not available, we let the compiler to have a try,
1816            # and raise error together with instructions at compilation error later
1817        else:
1818            libs = ["omp"] if config.is_fbcode() else ["gomp"]
1819
1820        # For AOT mode, the produced library relies on torch cpu to set grad mode
1821        # like aoti_torch_grad_mode_set_enabled
1822        if aot_mode and sys.platform == "linux" and not config.is_fbcode():
1823            libs += ["torch", "torch_cpu"]
1824
1825    # Unconditionally import c10 for non-abi-compatible mode to use TORCH_CHECK - See PyTorch #108690
1826    if not config.abi_compatible:
1827        libs += ["c10"]
1828        lpaths += [cpp_extension.TORCH_LIB_PATH]
1829
1830    # third party libs
1831    if config.is_fbcode():
1832        # Note that the order of include paths do matter, as a result
1833        # we need to have several branches interleaved here
1834        if torch.version.hip is None:
1835            ipaths.append(build_paths.sleef())
1836        ipaths.append(build_paths.openmp())
1837        ipaths.append(build_paths.python())
1838        if torch.version.hip is not None:
1839            ipaths.append(build_paths.clang_include())
1840            ipaths.append(build_paths.gcc_include())
1841            ipaths.append(build_paths.gcc_install_tools_include())
1842        else:
1843            ipaths.append(build_paths.cc_include())
1844            ipaths.append(build_paths.libgcc())
1845            ipaths.append(build_paths.libgcc_arch())
1846        ipaths.append(build_paths.libgcc_backward())
1847        ipaths.append(build_paths.glibc())
1848        ipaths.append(build_paths.linux_kernel())
1849        if torch.version.hip is not None:
1850            ipaths.append(build_paths.rocm())
1851        else:
1852            ipaths.append(os.path.join(build_paths.cuda(), "include"))
1853        # We also need to bundle includes with absolute paths into a remote directory
1854        # (later on, we copy the include paths from cpp_extensions into our remote dir)
1855        ipaths.append("include")
1856
1857    static_link_libs = []
1858    if aot_mode and cuda and config.is_fbcode():
1859        # For Meta internal cuda-12, it is recommended to static link cudart
1860        if torch.version.hip is None:
1861            static_link_libs = ["-Wl,-Bstatic", "-lcudart_static", "-Wl,-Bdynamic"]
1862
1863    lpaths_str = " ".join(["-L" + p for p in lpaths])
1864    libs_str = " ".join(static_link_libs + ["-l" + p for p in libs])
1865    return ipaths, lpaths_str, libs_str, macros, build_arch_flags
1866
1867
1868def cpp_compile_command(
1869    input: Union[str, List[str]],
1870    output: str,
1871    warning_all: bool = True,
1872    shared: bool = True,
1873    include_pytorch: bool = False,
1874    vec_isa: VecISA = invalid_vec_isa,
1875    cuda: bool = False,
1876    aot_mode: bool = False,
1877    compile_only: bool = False,
1878    use_absolute_path: bool = False,
1879    use_mmap_weights: bool = False,
1880    extra_flags: Sequence[str] = (),
1881) -> str:
1882    ipaths, lpaths, libs, macros, build_arch_flags = get_include_and_linking_paths(
1883        include_pytorch, vec_isa, cuda, aot_mode
1884    )
1885    if isinstance(input, str):
1886        input = [input]
1887    ipaths_str = " ".join(["-I" + p for p in ipaths])
1888    clang_flags = ""
1889    if config.is_fbcode():
1890        if aot_mode and not use_absolute_path:
1891            inp_name = input
1892            out_name = output
1893            linker_script = _LINKER_SCRIPT
1894        else:
1895            # We need to copy any absolute-path torch includes
1896            inp_name = [os.path.basename(i) for i in input]
1897            out_name = os.path.basename(output)
1898            linker_script = os.path.basename(_LINKER_SCRIPT)
1899        assert is_clang()
1900        # Use clang runtime instead of libgcc
1901        clang_flags += " --rtlib=compiler-rt"
1902        clang_flags += " -fuse-ld=lld"
1903        clang_flags += f" -Wl,--script={linker_script}"
1904        linker_paths = "-B" + build_paths.glibc_lib()
1905        linker_paths += " -L" + build_paths.glibc_lib()
1906    else:
1907        inp_name = input
1908        out_name = output
1909        linker_paths = ""  # let the compiler pick
1910    if compile_only:
1911        libs, lpaths = "", ""
1912    inp_name_str = " ".join(inp_name)
1913    if use_mmap_weights:
1914        macros += " -D USE_MMAP_SELF"
1915
1916    return re.sub(
1917        r"[ \n]+",
1918        " ",
1919        f"""
1920            {cpp_compiler()} {inp_name_str} {get_shared(shared, compile_only)}
1921            {get_warning_all_flag(warning_all)} {cpp_flags()}
1922            {get_glibcxx_abi_build_flags()}
1923            {ipaths_str} {lpaths} {libs} {build_arch_flags}
1924            {macros} {linker_paths} {clang_flags}
1925            {optimization_flags()} {cpp_wrapper_flags()}
1926            {use_custom_generated_macros()}
1927            {use_fb_internal_macros()}
1928            {use_standard_sys_dir_headers()}
1929            {get_compile_only(compile_only)}
1930            {' '.join(extra_flags)}
1931            -o {out_name}
1932        """,
1933    ).strip()
1934
1935
1936def run_command_and_check(cmd: str):
1937    cmd = shlex.split(cmd)
1938    try:
1939        subprocess.check_call(cmd)
1940    except subprocess.CalledProcessError as e:
1941        raise exc.CppCompileError(cmd, e.output) from e
1942
1943
1944@functools.lru_cache(None)
1945def split_aot_inductor_output_path(path: str) -> Tuple[str, str]:
1946    """Returns the path where the AOT Inductor compiled kernels are stored."""
1947    if path.endswith(".so"):
1948        return os.path.split(path)
1949    else:
1950        return path, ""
1951
1952
1953@clear_on_fresh_inductor_cache
1954class CudaKernelParamCache:
1955    cache: Dict[str, Dict[str, str]] = dict()
1956    cache_clear = staticmethod(cache.clear)
1957
1958    @classmethod
1959    def set(cls, key: str, params: Dict[str, str], cubin: str) -> None:
1960        bin_type = "cubin" if torch.version.hip is None else "hsaco"
1961        _, path = write(
1962            cubin,
1963            bin_type,
1964            hash_type=bin_type,
1965            specified_dir=split_aot_inductor_output_path(
1966                config.aot_inductor.output_path
1967            )[0],
1968        )
1969
1970        params[get_cpp_wrapper_cubin_path_name()] = path
1971
1972        cls.cache[key] = params
1973
1974    @classmethod
1975    def get(cls, key: str) -> Optional[Dict[str, str]]:
1976        return cls.cache.get(key, None)
1977
1978    @classmethod
1979    def get_keys(cls):
1980        return cls.cache.keys()
1981
1982
1983class AotCodeCompiler:
1984    @classmethod
1985    def compile(
1986        cls,
1987        graph: GraphLowering,
1988        source_code: str,
1989        serialized_extern_kernel_nodes: Optional[str],
1990        cuda: bool,
1991    ) -> str:
1992        picked_vec_isa = pick_vec_isa()
1993        cpp_command = repr(
1994            cpp_compile_command(
1995                "i",
1996                "o",
1997                vec_isa=picked_vec_isa,
1998                cuda=cuda,
1999                aot_mode=graph.aot_mode,
2000            )
2001        )
2002        fbcode_aot_cpu_re = False
2003        use_absolute_path = False
2004        if config.is_fbcode():
2005            ld_command = build_paths.ld()
2006            if not cuda and graph.aot_mode:  # Meta internal AOTInductor CPU
2007                objcopy_command = build_paths.objcopy_fallback()
2008                fbcode_aot_cpu_re = True
2009                use_absolute_path = True
2010            else:
2011                objcopy_command = build_paths.objcopy()
2012        else:
2013            ld_command = "ld"
2014            objcopy_command = "objcopy"
2015
2016        (
2017            specified_output_path,
2018            specified_so_name,
2019        ) = split_aot_inductor_output_path(config.aot_inductor.output_path)
2020        key, input_path = write(
2021            source_code,
2022            "cpp",
2023            extra=cpp_command,
2024            specified_dir=specified_output_path,
2025        )
2026        output_code_log.info("Output code written to: %s", input_path)
2027        trace_structured(
2028            "graph_dump",
2029            lambda: {
2030                "name": "inductor_aot_code",
2031                "type": "cpp",
2032                "filename": input_path,
2033            },
2034            payload_fn=lambda: source_code,
2035        )
2036
2037        def _compile_consts_linux(consts: bytes) -> str:
2038            _, consts_path = write(
2039                consts,
2040                "bin",
2041                specified_dir=specified_output_path,
2042            )
2043
2044            consts_o = os.path.splitext(consts_path)[0] + ".o"
2045            if fbcode_aot_cpu_re:
2046                cmd = f"{ld_command} -r -b binary -o {os.path.basename(consts_o)} {os.path.basename(consts_path)}"
2047                compile_file(consts_path, consts_o, cmd.split())
2048                os.chmod(consts_o, 0o644)
2049            else:
2050                cmd = f"{ld_command} -r -b binary -o {consts_o} {consts_path}"
2051                run_command_and_check(cmd)
2052            log.debug("aot constant binary command: %s", cmd)
2053
2054            if graph.mutated_buffers & set(graph.constants.keys()):
2055                # .data section is between .text and .bss. When the size of .data is large,
2056                # during the linking, the relocation of .text against .bss may overflow.
2057                # Rename it to .ldata so that it won't be in between the .text and .bss section
2058                if len(consts) > 2_000_000_000:
2059                    raise ValueError(
2060                        "Models with buffer mutation included doesn't support constants greater than 2GB!"
2061                    )
2062                rename_data = " .data=.ldata"
2063            else:
2064                # if no buffer mutation is needed, we could instead set the data region
2065                # as read-only (i.e. .lrodata) which could accomodate larger size of data
2066                # to be linked.
2067                rename_data = " .data=.lrodata,alloc,load,readonly,data,contents"
2068
2069            assert (
2070                ALIGN_BYTES & (ALIGN_BYTES - 1)
2071            ) == 0 and ALIGN_BYTES >= 64, "must be power of 2 and >= 64"
2072            cmd = (
2073                f"{objcopy_command} --rename-section"
2074                f"{rename_data}"
2075                f" --set-section-alignment .data={ALIGN_BYTES}"  # following the gAlignment of CPU in c10/core/alignment.h
2076                f" {consts_o} {consts_o}"
2077            )
2078            log.debug("aot constant rename section command: %s", cmd)
2079            run_command_and_check(cmd)
2080
2081            cmd = f"rm {consts_path}"
2082            log.debug("aot constant bin removal command: %s", cmd)
2083            run_command_and_check(cmd)
2084
2085            if fbcode_aot_cpu_re:
2086                body = re.sub(r"[\W]", "_", os.path.basename(consts_path))
2087            else:
2088                body = re.sub(r"[\W]", "_", consts_path)
2089
2090            symbol_list = []
2091            symbol_list.append(
2092                f"{objcopy_command} --redefine-sym _binary_{body}_start=_binary_constants_bin_start {consts_o}"
2093            )
2094            symbol_list.append(
2095                f"{objcopy_command} --redefine-sym _binary_{body}_size=_binary_constants_bin_size {consts_o}"
2096            )
2097            symbol_list.append(
2098                f"{objcopy_command} --redefine-sym _binary_{body}_end=_binary_constants_bin_end {consts_o}"
2099            )
2100            log.debug("aot constant binary redefine symbol: %s", " ".join(symbol_list))
2101            for cmd in symbol_list:
2102                run_command_and_check(cmd)
2103            return consts_o
2104
2105        def _compile_consts_darwin(consts: bytes) -> str:
2106            if config.aot_inductor.debug_dump_consts_bin:
2107                _, _binary_constants_path = write(
2108                    consts,
2109                    "bin",
2110                    specified_dir=specified_output_path,
2111                )
2112                log.debug("binary constants path: %s", _binary_constants_path)
2113
2114            is_large_consts = len(consts) > 1024
2115            consts_asm = "\t.section\t__DATA,__data\n"
2116            consts_asm += "\t.globl\t__binary_constants_bin_start\n"
2117            consts_asm += "__binary_constants_bin_start:\n"
2118            if not is_large_consts:
2119                for c in consts:
2120                    consts_asm += f"\t.byte {c}\n"
2121                # Add one element even if constants are empty
2122                # Otherwise assembler will not put them in data section
2123                if not consts:
2124                    consts_asm += "\t.space 1\n"
2125            else:
2126                consts_asm += "\t.quad 0x1234567899abcdef\n"
2127                consts_asm += f"\t.space {len(consts) - 8}\n"
2128            consts_asm += ".globl\t__binary_constants_bin_end\n"
2129            consts_asm += "__binary_constants_bin_end:\n"
2130            _, consts_path = write(
2131                consts_asm,
2132                "S",
2133                specified_dir=specified_output_path,
2134            )
2135            consts_o = os.path.splitext(consts_path)[0] + ".o"
2136            cmd = f"{cpp_compiler()} -c -o {consts_o} {consts_path}"
2137            run_command_and_check(cmd)
2138            if is_large_consts:
2139                with open(consts_o, "r+b") as f:
2140                    f.seek(0)
2141                    hdr = f.read(1024)
2142                    # Search for magic number and write the actual data over it
2143                    start_idx = hdr.find(b"\xef\xcd\xab\x99\x78\x56\x34\x12")
2144                    assert start_idx != -1
2145                    f.seek(start_idx)
2146                    pos = 0
2147                    while pos < len(consts):
2148                        rc = f.write(consts[pos:])
2149                        pos += rc
2150            return consts_o
2151
2152        from filelock import FileLock
2153
2154        lock_dir = get_lock_dir()
2155        lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
2156        with lock:
2157            # Currently, this only support serializing extern nodes in fbcode
2158            # Eventually, we should also have a serializer for OSS.
2159            if config.is_fbcode() and serialized_extern_kernel_nodes:
2160                output_json = os.path.splitext(input_path)[0] + ".json"
2161                with open(output_json, "w") as f:
2162                    f.write(serialized_extern_kernel_nodes)
2163
2164            output_so = (
2165                config.aot_inductor.output_path
2166                if specified_so_name
2167                else os.path.splitext(input_path)[0] + ".so"
2168            )
2169
2170            output_o = os.path.splitext(input_path)[0] + ".o"
2171            consts_size = sum(
2172                torch.ops.mkldnn._nbytes(tensor)
2173                if tensor.is_mkldnn
2174                else tensor.untyped_storage().nbytes()
2175                for (name, tensor) in graph.constants.items()
2176                if name not in graph.folded_constants
2177            )
2178            # TODO: Fix mmap weights with cuda
2179            use_mmap_weights = not config.is_fbcode() and consts_size > 2_000_000_000
2180            if config.aot_inductor.force_mmap_weights:
2181                use_mmap_weights = True
2182            compile_cmd = cpp_compile_command(
2183                input=input_path,
2184                output=output_o,
2185                vec_isa=picked_vec_isa,
2186                cuda=cuda,
2187                aot_mode=graph.aot_mode,
2188                compile_only=True,
2189                use_absolute_path=use_absolute_path,
2190                use_mmap_weights=use_mmap_weights,
2191            )
2192            log.debug("aot compilation command: %s", compile_cmd)
2193            if fbcode_aot_cpu_re:
2194                compile_file(input_path, output_o, compile_cmd.split())
2195                os.chmod(output_o, 0o644)
2196            else:
2197                run_command_and_check(compile_cmd)
2198
2199            def _to_bytes(t: torch.Tensor, all_cuda: bool) -> bytes:
2200                def _pad_to_alignment(raw_bytes):
2201                    padded_bytes = raw_bytes.ljust(
2202                        (len(raw_bytes) + ALIGN_BYTES - 1) // ALIGN_BYTES * ALIGN_BYTES,
2203                        b"\x00",
2204                    )
2205                    return padded_bytes
2206
2207                # This serializes the tensor's untyped_storage to bytes by accessing
2208                # the raw data of the underlying structure.
2209                import ctypes
2210
2211                if t.numel() == 0:
2212                    return b""
2213
2214                if t.is_mkldnn:
2215                    data_ptr = torch.ops.mkldnn.data_ptr(t)
2216                    nbytes = torch.ops.mkldnn._nbytes(t)
2217                else:
2218                    t_cpu = t.untyped_storage().cpu()
2219                    data_ptr = t_cpu.data_ptr()
2220                    nbytes = t_cpu.nbytes()
2221
2222                raw_array = ctypes.cast(
2223                    data_ptr,
2224                    ctypes.POINTER(ctypes.c_ubyte * nbytes),
2225                )
2226                raw_bytes = bytes(raw_array.contents)
2227                return raw_bytes if all_cuda else _pad_to_alignment(raw_bytes)
2228
2229            all_cuda = all(
2230                graph.get_original_value_of_constant(name).is_cuda
2231                for name in graph.constants.keys()
2232                if name not in graph.folded_constants
2233            )
2234            serialized_weights = b"".join(
2235                _to_bytes(graph.get_original_value_of_constant(name), all_cuda)
2236                for name in graph.constants.keys()
2237                if name not in graph.folded_constants
2238            )
2239            if not use_mmap_weights:
2240                aot_constants = serialized_weights
2241                magic_number = 0
2242            else:
2243                magic_number = cast(
2244                    int, torch.randint(0, torch.iinfo(torch.int64).max, (1,)).item()
2245                )
2246                aot_constants = struct.pack("qq", consts_size + 8, magic_number)
2247            consts_o = {
2248                "linux": _compile_consts_linux,
2249                "darwin": _compile_consts_darwin,
2250            }[sys.platform](aot_constants)
2251
2252            link_cmd = cpp_compile_command(
2253                input=[output_o, consts_o],
2254                output=output_so,
2255                vec_isa=picked_vec_isa,
2256                cuda=cuda,
2257                aot_mode=graph.aot_mode,
2258                use_absolute_path=use_absolute_path,
2259            )
2260            log.debug("aot linkage command: %s", link_cmd)
2261            if fbcode_aot_cpu_re:
2262                compile_file([output_o, consts_o], output_so, link_cmd.split())
2263                os.chmod(output_so, 0o755)
2264            else:
2265                run_command_and_check(link_cmd)
2266
2267            if use_mmap_weights:
2268                with open(output_so, "a+b") as f_so:
2269                    so_size = f_so.tell()
2270                    # Page align the weights
2271                    f_so.write(b" " * (16384 - so_size % 16384))
2272                    f_so.write(serialized_weights)
2273                    f_so.write(struct.pack("q", magic_number))
2274
2275            # Append cmds to the end of codegen-ed wrapper file
2276            with open(input_path, "a") as f:
2277                f.write("\n")
2278                f.write(f"// Compile cmd\n// {compile_cmd}\n")
2279                f.write(f"// Link cmd\n// {link_cmd}\n")
2280
2281        return output_so
2282
2283
2284# Putting this fn in cpp.py (unfortunately) causes a deadlock, which is why it's in codecache.py.
2285# Why? importing from cpp.py invokes codecache.pick_vec_isa(), which takes out a lock.
2286# Cycle goes:
2287# - CppCodeCache.load()
2288# - pick_vec_isa()
2289# - valid_vec_isa_list()
2290# - VecISA.__bool__() <-- takes out a lock
2291# - compile_file() <-- imports cpp_prefix_path from cpp, which causes us to try to take out the same lock.
2292@clear_on_fresh_inductor_cache
2293@functools.lru_cache
2294def cpp_prefix_path() -> str:
2295    path = Path(__file__).parent / "codegen/cpp_prefix.h"
2296    with path.open() as f:
2297        content = f.read()
2298        _, filename = write(
2299            content,
2300            "h",
2301        )
2302    return filename
2303
2304
2305def cpp_prefix() -> str:
2306    filename = cpp_prefix_path()
2307    if config.is_fbcode():
2308        # We need relative paths, since we bundle up
2309        # everything that we compile into a folder for remote compilation.
2310        return f'#include "{os.path.basename(filename)}"'
2311    else:
2312        return f'#include "{filename}"'
2313
2314
2315# Given a path to an input cpp file and an output path,
2316# Attempts to compile the file, storing the output in "output_path"
2317@dynamo_timed
2318def compile_file(
2319    input_path: Union[str, List[str]], output_path: str, cmd: List[str]
2320) -> None:
2321    input_paths = [input_path] if isinstance(input_path, str) else input_path
2322    input_files = [
2323        os.path.basename(ip) if config.is_fbcode() else ip for ip in input_paths
2324    ]
2325    try:
2326        if config.is_fbcode():
2327            # Need to copy our header into the same folder as the sourcecode.
2328            header_path = cpp_prefix_path()
2329            header_name = os.path.basename(header_path)
2330            output_name = os.path.basename(output_path)
2331            # When we build remotely, we need to make sure to carefully copy any files
2332            # that are required during the compilation process into our build directly.
2333            # This is where all of the ATen/c10/Torch includes come from.
2334            torch_includes_path = os.path.join(_TORCH_PATH, "include")
2335            with tempfile.TemporaryDirectory() as tmp_dir:
2336                # Copy everything to tmp compilation folder
2337                shutil.copy(header_path, os.path.join(tmp_dir, header_name))
2338                shutil.copy(_LINKER_SCRIPT, os.path.join(tmp_dir, "script.ld"))
2339                for p, f in zip(input_paths, input_files):
2340                    shutil.copy(p, os.path.join(tmp_dir, f))
2341                dest_include_path = os.path.join(tmp_dir, "include")
2342                shutil.copytree(torch_includes_path, dest_include_path)
2343                # Run the build
2344                output_file_path = _run_build_command(cmd, tmp_dir, output_name)
2345                # Copy output from the build
2346                if os.path.exists(output_path):
2347                    os.remove(output_path)
2348                shutil.copy(output_file_path, output_path)
2349        else:
2350            subprocess.check_output(cmd, stderr=subprocess.STDOUT)
2351    except subprocess.CalledProcessError as e:
2352        output = e.output.decode("utf-8")
2353        openmp_problem = "'omp.h' file not found" in output or "libomp" in output
2354        if openmp_problem and sys.platform == "darwin":
2355            instruction = (
2356                "\n\nOpenMP support not found. Please try one of the following solutions:\n"
2357                "(1) Set the `CXX` environment variable to a compiler other than Apple clang++/g++ "
2358                "that has builtin OpenMP support;\n"
2359                "(2) install OpenMP via conda: `conda install llvm-openmp`;\n"
2360                "(3) install libomp via brew: `brew install libomp`;\n"
2361                "(4) manually setup OpenMP and set the `OMP_PREFIX` environment variable to point to a path"
2362                " with `include/omp.h` under it."
2363            )
2364            output += instruction
2365        raise exc.CppCompileError(cmd, output) from e
2366
2367
2368_libgomp: Optional[CDLL] = None
2369
2370
2371def custom_op_wrapper(op: str, *args):
2372    # This function will be called from generated cpp wrapper code in the JIT mode.
2373    # Because tensors will be passed in as AtenTensorHandle, we need to explicitly convert them.
2374    def convert_arg(arg):
2375        if str(type(arg)) == "<class 'PyCapsule'>":
2376            # No easy way to do isinstance check on PyCapsule
2377            return torch._C._aoti.alloc_tensor_by_stealing_from_void_ptr(arg)
2378        elif isinstance(arg, (list, tuple)):
2379            return type(arg)(convert_arg(a) for a in arg)
2380        else:
2381            return arg
2382
2383    converted_args = [convert_arg(arg) for arg in args]
2384
2385    assert op.startswith("torch.ops."), (
2386        op + " can not be called through custom_op_wrapper"
2387    )
2388    func = None
2389    for i, s in enumerate(op.split(".")):
2390        if i == 0:
2391            func = importlib.import_module(s)
2392        func = getattr(func, s)
2393
2394    assert callable(func), op + " can not be loaded through custom_op_wrapper"
2395    result = func(*converted_args)
2396    if isinstance(result, (list, tuple)):
2397        for r in result:
2398            assert isinstance(r, torch.Tensor), op + " returns a list of non-tensors"
2399        return torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(result)  # type: ignore[arg-type]
2400    else:
2401        assert isinstance(result, torch.Tensor), op + " returns a non-tensor"
2402        return torch._C._aoti.unsafe_alloc_void_ptr_from_tensor(result)
2403
2404
2405@clear_on_fresh_inductor_cache
2406class CppCodeCache:
2407    cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
2408    cache_clear = staticmethod(cache.clear)
2409    cpp_compile_command_flags: Dict[str, Any] = {}
2410
2411    @staticmethod
2412    def _load_library_inner(path: str, key: str) -> Union[CDLL, ModuleType]:
2413        return cdll.LoadLibrary(path)
2414
2415    @classmethod
2416    def _load_library(cls, path: str, key: str) -> Union[CDLL, ModuleType]:
2417        try:
2418            result = cls._load_library_inner(path, key)
2419            result.key = key  # type: ignore[union-attr]
2420            return result
2421        except (ImportError, OSError) as e:
2422            if "gomp" in str(e) and os.path.exists("/usr/lib64/libgomp.so.1"):
2423                # hacky workaround for fbcode/buck
2424                global _libgomp
2425                _libgomp = cdll.LoadLibrary("/usr/lib64/libgomp.so.1")
2426                result = cls._load_library_inner(path, key)
2427                result.key = key  # type: ignore[union-attr]
2428                return result
2429            if "failed to map segment from shared object" in str(e):
2430                raise OSError(
2431                    f"{e}.  The most common reason this may occur is if the {tempfile.gettempdir()} folder "
2432                    "is mounted with noexec (e.g., by default Docker mounts tmp file systems "
2433                    f"as noexec).  Please remount {tempfile.gettempdir()} with exec enabled, or set another "
2434                    "temporary directory with TORCHINDUCTOR_CACHE_DIR environment variable."
2435                ) from e
2436            raise
2437
2438    @classmethod
2439    def load_async(cls, source_code: str, cuda=False, submit_fn=None, extra_flags=()):
2440        compile_command = {
2441            **cls.cpp_compile_command_flags,
2442            "cuda": cuda,
2443            "vec_isa": pick_vec_isa(),
2444            "extra_flags": extra_flags,
2445        }
2446
2447        _set_gpu_runtime_env()  # cpp_extension consults the env
2448
2449        from torch._inductor.cpp_builder import CppBuilder, CppTorchCudaOptions
2450
2451        dummy_builder = CppBuilder(
2452            name="o", sources="i", BuildOption=CppTorchCudaOptions(**compile_command)
2453        )
2454        # write function will calc source_code hash, the same source code with different
2455        # ISA level should be generate different hash.
2456        # So we need get a command_line which contains isa related parameter as a part of hash key.
2457        # And then pass the command_line to below write function as extra parameter to
2458        # guarantee the source code hash contains ISA difference.
2459        dummy_cmd = repr(dummy_builder.get_command_line())
2460        key, input_path = write(source_code, "cpp", extra=dummy_cmd)
2461
2462        if key not in cls.cache:
2463            from filelock import FileLock
2464
2465            lock_path = os.path.join(get_lock_dir(), key + ".lock")
2466            output_path = input_path[:-3] + "so"
2467            future: Optional[Future[Any]] = None
2468            lib = None
2469            worker_fn = functools.partial(
2470                _worker_compile_cpp,
2471                lock_path,
2472                input_path,
2473                output_path,
2474                cpp_compile_command(
2475                    input=input_path, output=output_path, **compile_command
2476                ),
2477            )
2478
2479            def load_fn():
2480                nonlocal lib
2481                if lib is None:
2482                    if future is not None:
2483                        future.result()
2484                    result = worker_fn()
2485                    assert result is None
2486                    lib = cls._load_library(output_path, key)
2487                    assert lib is not None
2488                return lib
2489
2490            if submit_fn is not None:
2491                with FileLock(lock_path, timeout=LOCK_TIMEOUT):
2492                    if not os.path.exists(output_path):
2493                        future = submit_fn(worker_fn)
2494
2495            cls.cache[key] = load_fn
2496
2497        return cls.cache[key]
2498
2499    @classmethod
2500    def load(cls, source_code: str, cuda: bool = False):
2501        return cls.load_async(source_code, cuda)()
2502
2503
2504def _worker_compile_cpp(lock_path, input_path, output_path, cmd):
2505    from filelock import FileLock
2506
2507    with FileLock(lock_path, timeout=LOCK_TIMEOUT):
2508        if not os.path.exists(output_path):
2509            compile_file(input_path, output_path, shlex.split(cmd))
2510
2511
2512# Customized Python binding for cpp kernels
2513@clear_on_fresh_inductor_cache
2514class CppPythonBindingsCodeCache(CppCodeCache):
2515    cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
2516    cache_clear = staticmethod(cache.clear)
2517    cpp_compile_command_flags = {
2518        # kernels have no dependency on libtorch
2519        "include_pytorch": False,
2520        "shared": True,
2521    }
2522    entry_function = "kernel"
2523    call_entry_function = "kernel(%s);Py_RETURN_NONE;"
2524    extra_parse_arg = ""
2525    suffix_template = textwrap.dedent(
2526        """
2527        // Python bindings to call %s():
2528        #define PY_SSIZE_T_CLEAN
2529        #include <Python.h>
2530        #include <sstream>
2531        #include <cstdlib>
2532
2533        #ifndef _MSC_VER
2534        #if __cplusplus < 202002L
2535        // C++20 earlier code
2536        // https://en.cppreference.com/w/cpp/language/attributes/likely
2537        #define likely(x)       __builtin_expect(!!(x), 1)
2538        #define unlikely(x)     __builtin_expect(!!(x), 0)
2539        #endif
2540        #endif
2541
2542        // This is defined in guards.cpp so we don't need to import PyTorch headers that are slooow.
2543        // We manually link it below to workaround issues with fbcode build.
2544        static void* (*_torchinductor_pyobject_tensor_data_ptr)(PyObject* obj);
2545
2546        template <typename T> static inline T parse_arg(PyObject* args, size_t n) {
2547            static_assert(std::is_pointer<T>::value, "arg type must be pointer or long");
2548            return static_cast<T>(_torchinductor_pyobject_tensor_data_ptr(PyTuple_GET_ITEM(args, n)));
2549        }
2550        template <> inline long parse_arg<long>(PyObject* args, size_t n) {
2551            auto result = PyLong_AsSsize_t(PyTuple_GET_ITEM(args, n));
2552            if(unlikely(result == -1 && PyErr_Occurred()))
2553                throw std::runtime_error("expected int arg");
2554            return result;
2555        }
2556        template <> inline uintptr_t parse_arg<uintptr_t>(PyObject* args, size_t n) {
2557            auto result = PyLong_AsVoidPtr(PyTuple_GET_ITEM(args, n));
2558            if(unlikely(result == reinterpret_cast<void*>(-1) && PyErr_Occurred()))
2559                throw std::runtime_error("expected int arg");
2560            return reinterpret_cast<uintptr_t>(result);
2561        }
2562
2563        %s
2564
2565        static PyObject* %s_py(PyObject* self, PyObject* args) {
2566            try {
2567                if(unlikely(!PyTuple_CheckExact(args)))
2568                    throw std::runtime_error("tuple args required");
2569                if(unlikely(PyTuple_GET_SIZE(args) != %s))
2570                    throw std::runtime_error("requires %s args");
2571                %s
2572            } catch(std::exception const& e) {
2573                PyErr_SetString(PyExc_RuntimeError, e.what());
2574                return nullptr;
2575            } catch(...) {
2576                PyErr_SetString(PyExc_RuntimeError, "unhandled error");
2577                return nullptr;
2578            }
2579        }
2580
2581        static PyMethodDef py_methods[] = {
2582            {"%s", %s_py, METH_VARARGS, ""},
2583            {NULL, NULL, 0, NULL}};
2584
2585        static struct PyModuleDef py_module =
2586            {PyModuleDef_HEAD_INIT, "%s", NULL, -1, py_methods};
2587
2588        PyMODINIT_FUNC PyInit_%s(void) {
2589            const char* str_addr = std::getenv("_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR");
2590            if(!str_addr) {
2591                PyErr_SetString(PyExc_RuntimeError, "_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR must be set");
2592                return nullptr;
2593            }
2594            std::istringstream iss(str_addr);
2595            uintptr_t addr = 0;
2596            iss >> addr;
2597            _torchinductor_pyobject_tensor_data_ptr =
2598                reinterpret_cast<decltype(_torchinductor_pyobject_tensor_data_ptr)>(addr);
2599            return PyModule_Create(&py_module);
2600        }
2601        """
2602    )
2603
2604    @classmethod
2605    def _load_library_inner(cls, path: str, key: str) -> ModuleType:
2606        os.environ["_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR"] = str(
2607            torch._C._dynamo.guards._torchinductor_pyobject_tensor_data_ptr  # type: ignore[attr-defined]
2608        )
2609        module_name = f"{key}.{cls.entry_function}"
2610        try:
2611            return sys.modules[module_name]
2612        except KeyError:
2613            pass
2614        spec = importlib.util.spec_from_file_location(module_name, path)
2615        assert spec is not None
2616        module = importlib.util.module_from_spec(spec)
2617        sys.modules[module_name] = module
2618        spec.loader.exec_module(module)  # type: ignore[union-attr]
2619        return module
2620
2621    @classmethod
2622    def load_pybinding_async(
2623        cls,
2624        argtypes: List[str],
2625        source_code: str,
2626        cuda: bool = False,
2627        num_outputs: int = -1,
2628        submit_fn=None,
2629        extra_flags=(),
2630    ) -> Any:
2631        """
2632        Wrap a C++ function in fast Python bindings.
2633
2634        Args:
2635            argtypes: The types of args to ENTRY_FUNCTION(), e.g. ["float*", "long"]
2636            source_code: C++ source code containing a ENTRY_FUNCTION() function
2637
2638        Returns:
2639            A python version of ENTRY_FUNCTION()
2640        """
2641        parseargs = ", ".join(
2642            f"parse_arg<{argtype.replace('const ', '')}>(args, {n})"
2643            for n, argtype in enumerate(argtypes)
2644        )
2645        suffix = cls.suffix_template % (
2646            cls.entry_function,
2647            cls.extra_parse_arg % num_outputs if cls.extra_parse_arg else "",
2648            cls.entry_function,
2649            len(argtypes),
2650            len(argtypes),
2651            cls.call_entry_function % parseargs,
2652            cls.entry_function,
2653            cls.entry_function,
2654            cls.entry_function,
2655            cls.entry_function,
2656        )
2657        get_result = cls.load_async(
2658            source_code + suffix, cuda, submit_fn=submit_fn, extra_flags=extra_flags
2659        )
2660        result = None
2661
2662        def future():
2663            nonlocal result
2664            if result is None:
2665                result = get_result()
2666                assert isinstance(result, ModuleType)
2667            return getattr(result, cls.entry_function)
2668
2669        return future
2670
2671    @classmethod
2672    def load_pybinding(cls, *args, **kwargs) -> Any:
2673        return cls.load_pybinding_async(*args, **kwargs)()
2674
2675
2676@clear_on_fresh_inductor_cache
2677class CppWrapperCodeCache(CppPythonBindingsCodeCache):
2678    cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
2679    cache_clear = staticmethod(cache.clear)
2680    cpp_compile_command_flags = {
2681        "include_pytorch": True,
2682        "shared": True,
2683    }
2684    entry_function = "inductor_entry_cpp"
2685    call_entry_function = "return inductor_entry_cpp(%s);"
2686    extra_parse_arg = textwrap.dedent(
2687        """
2688        #include <torch/csrc/inductor/aoti_torch/c/shim.h>
2689
2690        static inline std::vector<AtenTensorHandle> unpack_tensor_handle_list(PyObject* pyvec) {
2691            std::vector<AtenTensorHandle> result;
2692            size_t result_len = PyList_GET_SIZE(pyvec);
2693            result.reserve(result_len);
2694            for (size_t i = 0; i < result_len; i++) {
2695                // AtenTensorHandle is essentially a pointer
2696                void* elem = PyCapsule_GetPointer(PyList_GET_ITEM(pyvec, i), NULL);
2697                result.push_back(reinterpret_cast<AtenTensorHandle>(elem));
2698            }
2699            return result;
2700        }
2701
2702        static inline PyObject* pack_tensor_handle_list(const std::vector<AtenTensorHandle>& cppvec) {
2703            size_t result_len = cppvec.size();
2704            PyObject* result = PyList_New(static_cast<Py_ssize_t>(result_len));
2705            for (size_t i = 0; i < result_len; i++) {
2706                PyObject *elem =
2707                    cppvec[i] == nullptr
2708                        ? Py_None
2709                        // Store AtenTensorHandle as PyCapsulate
2710                        : PyCapsule_New(reinterpret_cast<void*>(cppvec[i]), NULL, NULL);
2711                PyList_SET_ITEM(result, i, elem);
2712            }
2713            return result;
2714        }
2715
2716        template <> inline std::vector<AtenTensorHandle> parse_arg<std::vector<AtenTensorHandle>>(PyObject* args, size_t n) {
2717            return unpack_tensor_handle_list(PyTuple_GET_ITEM(args, n));
2718        }
2719
2720        PyObject* inductor_entry_cpp(std::vector<AtenTensorHandle>&& input_handles) {
2721            // For outputs, we only allocate a vector to hold returned tensor handles,
2722            // not allocating the actual output tensor storage here
2723            std::vector<AtenTensorHandle> output_handles(%s);
2724            try {
2725                inductor_entry_impl(input_handles.data(), output_handles.data());
2726                return pack_tensor_handle_list(output_handles);
2727            } catch(std::exception const& e) {
2728                PyErr_SetString(PyExc_RuntimeError, e.what());
2729                return {};
2730            } catch(...) {
2731                PyErr_SetString(PyExc_RuntimeError, "unhandled error");
2732                return {};
2733            }
2734        }
2735        """
2736    )
2737
2738
2739# TODO: Will remove the temp code after switch to new cpp_builder
2740def _temp_validate_new_and_old_command(new_cmd: List[str], old_cmd: List[str]):
2741    new_diff: List[str] = [x for x in new_cmd if x not in old_cmd]
2742    old_diff: List[str] = [y for y in old_cmd if y not in new_cmd]
2743
2744    if new_diff or old_diff:
2745        print("!!! new_cmd: ", new_cmd)
2746        print("!!! old_cmd: ", old_cmd)
2747        print("!!! new_diff: ", new_diff)
2748        print("!!! old_diff: ", old_diff)
2749        raise RuntimeError("Error in new and old command different.")
2750
2751
2752def _do_validate_cpp_commands(
2753    include_pytorch: bool,
2754    cuda: bool,
2755    compile_only: bool,
2756    mmap_weights: bool,
2757    use_absolute_path: bool,
2758):
2759    # PreCI will failed if test machine can't run cuda.
2760    temp_dir = tempfile.TemporaryDirectory()
2761    test_dir_path = temp_dir.name
2762    test_cuda = torch.cuda.is_available() and cuda
2763    input_path = os.path.join(test_dir_path, "dummy_input.cpp")
2764    output_path = os.path.join(test_dir_path, "dummy_output.so")
2765    extra_flags = ["-D TEST_EXTRA_FLAGS"]
2766    if compile_only:
2767        output_path = os.path.join(test_dir_path, "dummy_output.o")
2768    picked_isa = pick_vec_isa()
2769
2770    old_cmd = cpp_compile_command(
2771        input=input_path,
2772        output=output_path,
2773        include_pytorch=include_pytorch,
2774        vec_isa=picked_isa,
2775        cuda=test_cuda,
2776        aot_mode=False,
2777        compile_only=compile_only,
2778        use_absolute_path=use_absolute_path,
2779        use_mmap_weights=mmap_weights,
2780        extra_flags=extra_flags,
2781    ).split(" ")
2782
2783    from torch._inductor.cpp_builder import CppBuilder, CppTorchCudaOptions
2784
2785    dummy_build_option = CppTorchCudaOptions(
2786        vec_isa=picked_isa,
2787        include_pytorch=include_pytorch,
2788        cuda=test_cuda,
2789        compile_only=compile_only,
2790        use_absolute_path=use_absolute_path,
2791        use_mmap_weights=mmap_weights,
2792        extra_flags=extra_flags,
2793    )
2794
2795    dummy_builder = CppBuilder(
2796        name="dummy_output",
2797        sources=input_path,
2798        BuildOption=dummy_build_option,
2799        output_dir=test_dir_path,
2800    )
2801    new_cmd = dummy_builder.get_command_line().split(" ")
2802
2803    _temp_validate_new_and_old_command(new_cmd, old_cmd)
2804
2805    temp_dir.cleanup()
2806
2807
2808# TODO: Will remove the temp code after switch to new cpp_builder
2809# It could help on sync new cpp_builder generate same command line as the old one.
2810def validate_new_cpp_commands():
2811    cuda = [True, False]
2812    use_mmap_weights = [True, False]
2813    compile_only = [True, False]
2814    include_pytorch = [True, False]
2815    use_absolute_path = [True, False]
2816
2817    for x in cuda:
2818        for y in use_mmap_weights:
2819            for z in compile_only:
2820                for m in include_pytorch:
2821                    for n in use_absolute_path:
2822                        print(
2823                            f"!!! cuda:{x}, use_mmap_weights:{y}, compile_only:{z}, include_pytorch:{m}, use_absolute_path:{n}"
2824                        )
2825                        _do_validate_cpp_commands(
2826                            include_pytorch=m,
2827                            cuda=x,
2828                            mmap_weights=y,
2829                            compile_only=z,
2830                            use_absolute_path=n,
2831                        )
2832
2833
2834@clear_on_fresh_inductor_cache
2835class HalideCodeCache(CppPythonBindingsCodeCache):
2836    cache: Dict[str, Callable[[], Union[ModuleType, CDLL]]] = {}
2837    cache_clear = staticmethod(cache.clear)
2838    glue_template = textwrap.dedent(
2839        """
2840        #include "{halidebuffer_h}"
2841        #include "{headerfile}"
2842        #include <stdexcept>
2843        #include <cmath>
2844        void kernel({argdefs}) {{
2845            {buffers}
2846            int err = halide_kernel({buffer_names});
2847            if(err != 0) {{
2848                throw std::runtime_error("halide_kernel failed");
2849            }}
2850        }}
2851        """
2852    )
2853
2854    @classmethod
2855    def _codegen_glue(cls, argtypes, headerfile):
2856        buffers = []
2857        buffer_names = []
2858        for i, arg in enumerate(argtypes):
2859            if arg.numel:
2860                buffer_names.append(f"hl_buf_{i}")
2861                buffers.append(
2862                    f"    Halide::Runtime::Buffer {buffer_names[-1]}({arg.halide_type()}, {arg.name}, {arg.numel});"
2863                )
2864            else:
2865                assert "*" not in arg.ctype
2866                buffer_names.append(arg.name)
2867        glue_code = cls.glue_template.format(
2868            halidebuffer_h=cls.find_header("HalideBuffer.h"),
2869            headerfile=headerfile,
2870            argdefs=", ".join(f"{a.bindings_type()} {a.name}" for a in argtypes),
2871            buffers="\n".join(buffers).lstrip(),
2872            buffer_names=", ".join(buffer_names),
2873        )
2874        return glue_code
2875
2876    @classmethod
2877    @functools.lru_cache(None)
2878    def config_hash(cls):
2879        return sha256_hash(
2880            "\n".join(
2881                [
2882                    cls.glue_template,
2883                    f"{cls.cpu_cache_size()}",
2884                    cpp_compile_command("I", "O"),
2885                ]
2886            ).encode("utf-8")
2887        )
2888
2889    @staticmethod
2890    @functools.lru_cache(None)
2891    def cpu_cache_size():
2892        try:
2893            cpuinfo = open("/proc/cpuinfo").read()
2894        except OSError:
2895            return 16777216
2896        m = re.search(r"cache size\s*: (\d+) KB", cpuinfo)
2897        if m:
2898            return int(m.group(1)) * 1024
2899        m = re.search(r"cache size\s*: (\d+) MB", cpuinfo)
2900        if m:
2901            return int(m.group(1)) * 1024 * 1024
2902        raise RuntimeError("failed to find 'cache size: ... KB' in /proc/cpuinfo")
2903
2904    @staticmethod
2905    def _search_for_file(suffix, errmsg):
2906        try:
2907            search, *_ = importlib.machinery.PathFinder.find_spec(  # type: ignore[union-attr,misc]
2908                "halide"
2909            ).submodule_search_locations
2910            for file in os.listdir(search):
2911                if file.endswith(".so"):
2912                    try:
2913                        out = subprocess.check_output(
2914                            ["ldd", os.path.join(search, file)]
2915                        )
2916                    except subprocess.SubprocessError:
2917                        continue
2918                    m = re.search(r"(/.*)/libHalide.so", out.decode("utf-8"))
2919                    if m:
2920                        path = os.path.join(os.path.abspath(m.group(1)), suffix)
2921                        if os.path.exists(path):
2922                            return os.path.abspath(path)
2923        except Exception as e:
2924            raise RuntimeError(errmsg) from e
2925        raise RuntimeError(errmsg)
2926
2927    @staticmethod
2928    @functools.lru_cache(None)
2929    def find_libautoschedule(name):
2930        sofile = f"libautoschedule_{name.lower()}.so"
2931        if "HALIDE_LIB" in os.environ:
2932            path = os.path.join(os.environ["HALIDE_LIB"], sofile)
2933            if os.path.exists(path):
2934                return path
2935        errmsg = (
2936            f"Can't find {sofile}, set env HALIDE_LIB to the directory containing it"
2937        )
2938        return HalideCodeCache._search_for_file(sofile, errmsg)
2939
2940    @staticmethod
2941    @functools.lru_cache(None)
2942    def find_header(name):
2943        if "HALIDE_INCLUDE" in os.environ:
2944            path = os.path.join(os.environ["HALIDE_INCLUDE"], name)
2945            if os.path.exists(path):
2946                return path
2947        if "HALIDE_LIB" in os.environ:
2948            path = os.path.abspath(
2949                os.path.join(os.environ["HALIDE_LIB"], f"../include/{name}")
2950            )
2951            if os.path.exists(path):
2952                return path
2953        errmsg = (
2954            f"Can't find {name}, set env HALIDE_INCLUDE to the directory containing it"
2955        )
2956        return HalideCodeCache._search_for_file(f"../include/{name}", errmsg)
2957
2958    @classmethod
2959    def generate_halide_async(cls, meta: HalideMeta, source_code: str, submit_fn=None):
2960        dirpath = Path(
2961            get_path(
2962                code_hash(
2963                    source_code,
2964                    extra=repr((cls.config_hash(), meta)),
2965                ),
2966                "halide",
2967            )[2]
2968        )
2969        os.makedirs(dirpath, exist_ok=True)
2970        wait_for_compile = None
2971        genfile = str(dirpath / "generate_kernel.py")
2972        libfile = str(dirpath / "halide_kernel.a")
2973        headerfile = str(dirpath / "halide_kernel.h")
2974        donefile = str(dirpath / "done")
2975        lockfile = str(dirpath / "lock")
2976        need_compile = not os.path.exists(donefile)
2977        jobs = []
2978
2979        if need_compile:
2980            write_atomic(genfile, source_code)
2981            jobs.append(
2982                functools.partial(
2983                    subprocess.check_call,
2984                    [
2985                        sys.executable,
2986                        genfile,
2987                        "-g",
2988                        "kernel",
2989                        "-o",
2990                        f"{dirpath}",
2991                        "-f",
2992                        "halide_kernel",
2993                        "-e",
2994                        "static_library,h,schedule,pytorch_wrapper",
2995                        "-p",
2996                        cls.find_libautoschedule(meta.scheduler),
2997                        *meta.args(),
2998                    ],
2999                )
3000            )
3001
3002        bindings_future = cls.load_pybinding_async(
3003            [arg.bindings_type() for arg in meta.argtypes],
3004            cls._codegen_glue(meta.argtypes, headerfile),
3005            extra_flags=(libfile,),
3006            submit_fn=jobs.append if need_compile else None,
3007        )
3008
3009        if need_compile:
3010            jobs.append(functools.partial(touch, donefile))
3011            task = functools.partial(_worker_task_halide, lockfile, jobs)
3012            if submit_fn:
3013                wait_for_compile = submit_fn(task).result
3014            else:
3015                task()
3016
3017        def load():
3018            if wait_for_compile:
3019                wait_for_compile()
3020            return bindings_future()
3021
3022        return load
3023
3024    @classmethod
3025    def generate_halide(cls, *args, **kwargs):
3026        return cls.generate_halide_async(*args, **kwargs)()
3027
3028
3029def _worker_task_halide(lockfile, jobs):
3030    from filelock import FileLock
3031
3032    with FileLock(lockfile, LOCK_TIMEOUT):
3033        for job in jobs:
3034            job()
3035
3036
3037def touch(filename):
3038    open(filename, "a").close()
3039
3040
3041@clear_on_fresh_inductor_cache
3042class PyCodeCache:
3043    cache: Dict[str, ModuleType] = dict()
3044    linemaps: Dict[str, List[Tuple[Any, ...]]] = dict()
3045    cache_clear = staticmethod(cache.clear)
3046
3047    @classmethod
3048    def write(cls, source_code: str, extra: str = "") -> Tuple[str, str]:
3049        return write(source_code, "py", extra=extra)
3050
3051    @classmethod
3052    def load(
3053        cls,
3054        source_code: str,
3055        extra: str = "",
3056        linemap: Optional[List[Tuple[int, str]]] = None,
3057        attrs: Optional[Dict[str, Any]] = None,
3058    ) -> ModuleType:
3059        key, path = write(source_code, "py", extra=extra)
3060        return cls.load_by_key_path(key, path, linemap, attrs)
3061
3062    @classmethod
3063    def load_by_key_path(
3064        cls,
3065        key: str,
3066        path: str,
3067        linemap: Optional[List[Tuple[int, str]]] = None,
3068        attrs: Optional[Dict[str, Any]] = None,
3069    ) -> ModuleType:
3070        if linemap is None:
3071            linemap = []
3072        if key not in cls.cache:
3073            mod = _reload_python_module(key, path)
3074
3075            # another thread might set this first
3076            cls.cache.setdefault(key, mod)
3077            # unzip into separate lines/nodes lists
3078            cls.linemaps[path] = list(zip(*linemap))
3079
3080            if attrs is not None:
3081                for k, v in attrs.items():
3082                    setattr(mod, k, v)
3083
3084            if not (linemap or attrs):
3085                mod._reload_in_subproc = functools.partial(  # type: ignore[attr-defined]
3086                    _reload_python_module_in_subproc, key, path
3087                )
3088
3089        return cls.cache[key]
3090
3091    @classmethod
3092    @functools.lru_cache(None)
3093    def stack_frames_for_code(
3094        cls, path: str, lineno: int
3095    ) -> Optional[List[Dict[str, Any]]]:
3096        if path not in cls.linemaps:
3097            return None
3098        # [(starting_line, <fx node>), ...]
3099        lines, nodes = cls.linemaps[path]
3100        p = bisect_right(lines, lineno)
3101        if p == 0:
3102            return None
3103        entry = nodes[p - 1]
3104        if not entry:
3105            return None
3106
3107        def parse_stack_trace(stack_trace: str) -> List[Dict[str, Any]]:
3108            # ideally fx stores stack traces as data rather than a string
3109            # but this is not along a performance critical path
3110            regex = r'File "(.+)", line (\d+), in (.+)\n'
3111            matches = re.findall(regex, stack_trace)
3112            return [
3113                {"filename": f, "line": int(l), "name": n}
3114                for f, l, n in reversed(matches)
3115            ]
3116
3117        return parse_stack_trace(entry)
3118
3119
3120class TritonCodeCache:
3121    @classmethod
3122    def load(cls, kernel_name: str, source_code: str) -> ModuleType:
3123        return _module_to_triton_kernel(PyCodeCache.load(source_code), kernel_name)
3124
3125
3126def _cuda_compiler() -> Optional[str]:
3127    if cuda_env.nvcc_exist(config.cuda.cuda_cxx):
3128        return config.cuda.cuda_cxx
3129    if config.is_fbcode():
3130        return os.path.join(build_paths.cuda(), "bin", "nvcc")
3131    if cuda_env.nvcc_exist(os.getenv("CUDACXX")):
3132        return os.getenv("CUDACXX", "")
3133    if cuda_env.nvcc_exist(os.getenv("CUDA_HOME")):
3134        return os.path.realpath(os.path.join(os.getenv("CUDA_HOME", ""), "bin/nvcc"))
3135    return "nvcc"
3136
3137
3138def _cutlass_include_paths() -> List[str]:
3139    if config.is_fbcode():
3140        from libfb.py import parutil
3141
3142        cutlass_path = parutil.get_dir_path("cutlass-3-headers")
3143    else:
3144        cutlass_path = config.cuda.cutlass_dir
3145    return [
3146        # Use realpath to get canonical absolute paths, in order not to mess up cache keys
3147        os.path.realpath(os.path.join(cutlass_path, "include")),
3148        os.path.realpath(os.path.join(cutlass_path, "tools/library/include")),
3149        os.path.realpath(os.path.join(cutlass_path, "tools/library/src")),
3150        os.path.realpath(os.path.join(cutlass_path, "tools/util/include")),
3151    ]
3152
3153
3154def _cuda_lib_options() -> List[str]:
3155    _set_gpu_runtime_env()  # cpp_extension consults the env
3156    from torch.utils import cpp_extension
3157
3158    lpaths = cpp_extension.library_paths(cuda=True) + [
3159        sysconfig.get_config_var("LIBDIR")
3160    ]
3161    extra_ldflags: List[str] = []
3162    if is_linux():
3163        _transform_cuda_paths(lpaths)
3164        for path in lpaths:
3165            # -rpath ensures the DLL can find its dependencies when loaded, even
3166            # if the library path is non-standard.
3167            extra_ldflags.extend([f"-L{path}", "-Xlinker", f"-rpath={path}"])
3168        extra_ldflags.append("-lcuda")
3169        extra_ldflags.append("-lcudart")
3170    else:
3171        raise NotImplementedError(
3172            "Unsupported env, failed to find cuda libs! Currently only Linux is supported."
3173        )
3174    return extra_ldflags
3175
3176
3177def _nvcc_host_compiler_options() -> List[str]:
3178    return [
3179        "-fPIC",
3180        "-fno-strict-aliasing",
3181        "-fvisibility=hidden",
3182        "-Wconversion",
3183    ]
3184
3185
3186def _nvcc_compiler_options() -> List[str]:
3187    arch = cuda_env.get_cuda_arch()
3188    if arch == "90":
3189        # Required by cutlass compilation.
3190        arch = "90a"
3191    code = [f"sm_{arch}", f"compute_{arch}"]
3192    if config.cuda.enable_cuda_lto:
3193        code += [f"lto_{arch}"]
3194    options = [
3195        "-t=0",
3196        "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
3197        "-w",
3198        f"-gencode=arch=compute_{arch},code=[{','.join(code)}]",
3199        config.cuda.compile_opt_level,
3200        "-std=c++17",
3201        "--expt-relaxed-constexpr",
3202        "-DNDEBUG",
3203    ]
3204    if config.is_fbcode():
3205        options.extend(["-ccbin", os.path.dirname(build_paths.gcc())])
3206    if config.cuda.enable_debug_info:
3207        options.extend(["-lineinfo", "-g", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"])
3208    if config.cuda.enable_ptxas_info:
3209        options.extend(
3210            [
3211                "--keep",  # Keep the intermediate files for debugging (including ptx, sass, cubin etc.)
3212                "--ptxas-options=--warn-on-local-memory-usage",  # warn us if local memory is used in CUDA Kernels
3213                "--ptxas-options=--warn-on-spills",  # warn us if register spilling happens in CUDA Kernels
3214                "--resource-usage",  # Report on CUDA resource usage (shared mem, registers etc.)
3215                "--source-in-ptx",
3216            ]
3217        )  # Annotate the ptx file with source information
3218    if config.cuda.use_fast_math:
3219        options.extend(
3220            [
3221                "--use_fast_math",
3222                "-DCUTLASS_USE_TANH_FOR_SIGMOID=1",
3223            ]
3224        )
3225    return options
3226
3227
3228def cuda_compile_command(
3229    src_files: List[str],
3230    dst_file: str,
3231    dst_file_ext: str,
3232    extra_args: Optional[List[str]] = None,
3233) -> str:
3234    if extra_args is None:
3235        extra_args = []
3236    include_paths = _cutlass_include_paths()
3237    cuda_lib_options = _cuda_lib_options()
3238    nvcc_host_compiler_options = _nvcc_host_compiler_options()
3239    nvcc_compiler_options = _nvcc_compiler_options()
3240    options = (
3241        nvcc_compiler_options
3242        + extra_args
3243        + [
3244            f"-Xcompiler {opt}" if "=" in opt else f"-Xcompiler={opt}"
3245            for opt in nvcc_host_compiler_options
3246        ]
3247        + ["-I" + path for path in include_paths]
3248        + cuda_lib_options
3249    )
3250    src_file = " ".join(src_files)
3251    res = ""
3252    if dst_file_ext == "o":
3253        res = f"{_cuda_compiler()} {' '.join(options)} -c -o {dst_file} {src_file}"
3254    elif dst_file_ext == "so":
3255        options.append("-shared")
3256        res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}"
3257    elif dst_file_ext == "exe":
3258        res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}"
3259    else:
3260        raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!")
3261    log.debug("CUDA command: %s", res)
3262    return res
3263
3264
3265class DLLWrapper:
3266    """A wrapper for a dynamic library."""
3267
3268    def __init__(
3269        self,
3270        lib_path: str,
3271    ):
3272        self.lib_path = lib_path
3273        self.is_open = False
3274        self.DLL = cdll.LoadLibrary(lib_path)
3275        self.is_open = True
3276
3277    def close(self):
3278        if self.is_open:
3279            self._dlclose()
3280            self.is_open = False
3281
3282    def _dlclose(self):
3283        f_dlclose = None
3284
3285        if is_linux():
3286            syms = CDLL(None)
3287            if not hasattr(syms, "dlclose"):
3288                # Apline Linux
3289                syms = CDLL("libc.so")
3290
3291            if hasattr(syms, "dlclose"):
3292                f_dlclose = syms.dlclose
3293        else:
3294            raise NotImplementedError("Unsupported env, failed to do dlclose!")
3295
3296        if f_dlclose is not None:
3297            f_dlclose.argtypes = [c_void_p]
3298            f_dlclose(self.DLL._handle)
3299        else:
3300            log.warning(
3301                "dll unloading function was not found, library may not be unloaded properly!"
3302            )
3303
3304    def __getattr__(self, name):
3305        if not self.is_open:
3306            raise RuntimeError(f"Cannot use closed DLL library: {self.lib_path}")
3307
3308        method = getattr(self.DLL, name)
3309
3310        def _wrapped_func(*args):
3311            err = method(*args)
3312            if err:
3313                raise RuntimeError(f"Error in function: {method.__name__}")
3314
3315        return _wrapped_func
3316
3317    def __enter__(self):
3318        return self
3319
3320    def __exit__(self, *args):
3321        self.close()
3322
3323    def __del__(self):
3324        self.close()
3325
3326
3327@clear_on_fresh_inductor_cache
3328class CUDACodeCache:
3329    @dataclasses.dataclass
3330    class CacheEntry:
3331        input_path: str
3332        output_path: str
3333
3334    cache: Dict[str, CacheEntry] = dict()
3335    cache_clear = staticmethod(cache.clear)
3336    _SOURCE_CODE_SUFFIX = "cu"
3337
3338    @classmethod
3339    def write(cls, source_code, dst_file_ext) -> Tuple[str, str]:
3340        """
3341        Writes source code into a file with dst_file_ext as the file extension.
3342        Returns the hash key of source code, and the path to the file.
3343        """
3344
3345        cuda_command = repr(
3346            cuda_compile_command(["dummy_input"], "dummy_output", dst_file_ext)
3347        )
3348        key, input_path = write(
3349            source_code, cls._SOURCE_CODE_SUFFIX, extra=cuda_command
3350        )
3351        return key, input_path
3352
3353    @classmethod
3354    def compile(
3355        cls, source_code, dst_file_ext, extra_args: Optional[List[str]] = None
3356    ) -> Tuple[str, str, str]:
3357        """
3358        Compiles CUDA source_code into a file with dst_file_ext extension.
3359        Returns a tuple of dst_file_path, hash_key, source_code_path
3360        """
3361        key, input_path = cls.write(source_code, dst_file_ext)
3362        if key not in cls.cache:
3363            from filelock import FileLock
3364
3365            lock_dir = get_lock_dir()
3366            lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
3367            with lock:
3368                output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext
3369                if not os.path.exists(output_path):
3370                    cmd = cuda_compile_command(
3371                        [input_path], output_path, dst_file_ext, extra_args
3372                    )
3373                    start_time = time()
3374                    log.debug("CUDA Compilation: %s", cmd)
3375                    cmd_parts = cmd.split(" ")
3376                    try:
3377                        subprocess.check_output(
3378                            cmd_parts, stderr=subprocess.STDOUT, env=os.environ
3379                        )
3380                    except subprocess.CalledProcessError as error:
3381                        raise exc.CUDACompileError(cmd_parts, error.output) from error
3382                    end_time = time()
3383                    log_duration_msg = f"CUDA Compilation took {end_time-start_time} seconds. Compile command: {cmd}"
3384                    log.info(log_duration_msg)
3385                else:
3386                    log.debug(
3387                        "CUDA Compilation skipped: %s since output already exists",
3388                        input_path,
3389                    )
3390                cls.cache[key] = CUDACodeCache.CacheEntry(input_path, output_path)
3391
3392        return (cls.cache[key].output_path, key, input_path)
3393
3394    @classmethod
3395    def load(cls, source_code, dst_file_ext) -> Tuple[DLLWrapper, str, str]:
3396        """
3397        Compiles source code and loads the generated .so file.
3398        Returns a tuple of DLLWrapper, hash_key, source_code_path
3399        """
3400
3401        if dst_file_ext != "so":
3402            raise RuntimeError(
3403                f"Only support loading a .so file for now. "
3404                f"Requested file extension: {dst_file_ext}. Source code: {source_code}"
3405            )
3406        dst_file_path, hash_key, source_code_path = cls.compile(
3407            source_code, dst_file_ext
3408        )
3409        return (DLLWrapper(dst_file_path), hash_key, source_code_path)
3410
3411
3412class CodeCacheFuture:
3413    def result(self):
3414        raise NotImplementedError
3415
3416
3417class TritonFuture(CodeCacheFuture):
3418    kernel: ModuleType
3419
3420    def __init__(
3421        self,
3422        kernel: Any,
3423        future: Optional[Future[Any]],
3424    ) -> None:
3425        self.kernel = kernel
3426        self.future = future
3427
3428    # @dynamo_utils.dynamo_timed
3429    def result(self) -> ModuleType:
3430        if self.future is not None:
3431            # If the worker failed this will throw an exception.
3432            result = self.future.result()
3433            assert result is None
3434            self.future = None
3435            self.kernel.precompile()
3436        return self.kernel
3437
3438
3439class LambdaFuture(CodeCacheFuture):
3440    def __init__(self, result_fn):
3441        self.result_fn = result_fn
3442
3443    def result(self):
3444        return self.result_fn()
3445