xref: /aosp_15_r20/external/pytorch/torch/_inductor/codegen/cuda/cuda_env.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import functools
2import logging
3from typing import Optional
4
5import torch
6
7from ... import config
8
9
10log = logging.getLogger(__name__)
11
12
def get_cuda_arch() -> Optional[str]:
    """Return the CUDA architecture as a string, or ``None`` on failure.

    An explicit ``config.cuda.arch`` setting takes precedence and is returned
    stringified. When it is unset, the compute capability of device 0 is
    queried through ``torch.cuda.get_device_capability`` and encoded as
    ``major * 10 + minor`` (for example ``(8, 6)`` becomes ``"86"``).

    Any exception raised along the way is logged and swallowed, in which
    case ``None`` is returned.
    """
    try:
        configured = config.cuda.arch
        if configured is not None:
            return str(configured)
        # No override configured: fall back to the compute capability of the
        # first visible device.
        capability = torch.cuda.get_device_capability(0)
        major, minor = capability
        return str(major * 10 + minor)
    except Exception as err:
        log.error("Error getting cuda arch: %s", err)
        return None
24
25
def get_cuda_version() -> Optional[str]:
    """Return the CUDA version to target, or ``None`` on failure.

    ``config.cuda.version`` is used when it is set; otherwise the value of
    ``torch.version.cuda`` is returned unchanged. Any exception raised while
    reading either value is logged and ``None`` is returned.
    """
    try:
        configured = config.cuda.version
        # Only consult torch when no explicit version was configured.
        return torch.version.cuda if configured is None else configured
    except Exception as err:
        log.error("Error getting cuda version: %s", err)
        return None
35
36
@functools.lru_cache(None)
def nvcc_exist(nvcc_path: Optional[str] = "nvcc") -> bool:
    """Report whether ``nvcc_path`` resolves to an executable.

    Args:
        nvcc_path: Command name to look up on ``PATH``, or a path to the
            compiler. ``None`` is accepted and always yields ``False``.

    Returns:
        ``True`` if ``shutil.which`` can resolve ``nvcc_path`` to an
        executable file, ``False`` otherwise.

    The result is memoized per ``nvcc_path`` by ``lru_cache``, so changes to
    the environment made after the first call for a given argument are not
    observed.
    """
    if nvcc_path is None:
        return False
    # Imported locally, as the original did for subprocess, to keep module
    # import light.
    import shutil

    # shutil.which performs the lookup in-process: no child process is
    # spawned and there is no dependency on an external `which` binary
    # (which may be missing, e.g. on Windows or minimal containers).
    return shutil.which(nvcc_path) is not None
47