import functools
import logging
from typing import Optional

import torch

from ... import config


log = logging.getLogger(__name__)


def get_cuda_arch() -> Optional[str]:
    """Return the target CUDA arch (e.g. "90"), from config or the first visible device."""
    try:
        cuda_arch = config.cuda.arch
        if cuda_arch is None:
            # Fall back to the compute capability of the first visible device.
            major, minor = torch.cuda.get_device_capability(0)
            return str(major * 10 + minor)
        return str(cuda_arch)
    except Exception as e:
        log.error("Error getting cuda arch: %s", e)
        return None


def get_cuda_version() -> Optional[str]:
    """Return the CUDA toolkit version, from config or torch.version.cuda."""
    try:
        cuda_version = config.cuda.version
        if cuda_version is None:
            cuda_version = torch.version.cuda
        return cuda_version
    except Exception as e:
        log.error("Error getting cuda version: %s", e)
        return None


@functools.lru_cache(None)
def nvcc_exist(nvcc_path: Optional[str] = "nvcc") -> bool:
    """Return True if the given nvcc executable can be found on PATH."""
    if nvcc_path is None:
        return False
    import subprocess

    res = subprocess.call(
        ["which", nvcc_path], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
    )
    return res == 0
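

# --- Usage sketch (not part of the original module; added as an illustration) ---
# A minimal example of how the helpers above might be combined when deciding
# whether to emit CUDA-specific compile flags. The helper name
# `nvcc_gencode_flag` and the exact flag layout are illustrative assumptions,
# not an API defined elsewhere in this package.
def nvcc_gencode_flag() -> Optional[str]:
    """Return an nvcc -gencode flag for the detected arch, or None if unavailable."""
    arch = get_cuda_arch()
    if arch is None or not nvcc_exist():
        return None
    # e.g. arch "90" -> "-gencode=arch=compute_90,code=sm_90"
    return f"-gencode=arch=compute_{arch},code=sm_{arch}"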