import os
import re
import subprocess
import sys
import traceback
import warnings


MIN_CUDA_VERSION = "11.6"
MIN_ROCM_VERSION = "5.4"
MIN_PYTHON_VERSION = (3, 8)


class VerifyDynamoError(BaseException):
    """Raised when the environment fails a hard verification requirement.

    NOTE(review): derives from ``BaseException``, so ``except Exception``
    handlers (such as the one in ``check_dynamo``) will not catch it --
    presumably intentional.
    """


def check_python():
    """Return ``sys.version_info`` after checking it meets MIN_PYTHON_VERSION.

    Raises:
        VerifyDynamoError: if the running interpreter is older than
            MIN_PYTHON_VERSION.
    """
    if sys.version_info < MIN_PYTHON_VERSION:
        raise VerifyDynamoError(
            f"Python version not supported: {sys.version_info} "
            f"- minimum requirement: {MIN_PYTHON_VERSION}"
        )
    return sys.version_info


def check_torch():
    """Import ``torch`` and return its ``__version__`` string."""
    import torch

    return torch.__version__


# based on torch/utils/cpp_extension.py
def get_cuda_version():
    """Return the system CUDA toolkit version parsed from ``nvcc --version``.

    Returns:
        A ``TorchVersion`` holding the "major.minor" release string.

    Raises:
        VerifyDynamoError: if no CUDA home is found, or the ``nvcc`` output
            does not contain a "release X.Y" string.
    """
    from torch.torch_version import TorchVersion
    from torch.utils import cpp_extension

    CUDA_HOME = cpp_extension._find_cuda_home()
    if not CUDA_HOME:
        raise VerifyDynamoError(cpp_extension.CUDA_NOT_FOUND_MESSAGE)

    nvcc = os.path.join(CUDA_HOME, "bin", "nvcc")
    cuda_version_str = (
        subprocess.check_output([nvcc, "--version"])
        .strip()
        .decode(*cpp_extension.SUBPROCESS_DECODE_ARGS)
    )
    cuda_version = re.search(r"release (\d+[.]\d+)", cuda_version_str)
    if cuda_version is None:
        raise VerifyDynamoError("CUDA version not found in `nvcc --version` output")

    return TorchVersion(cuda_version.group(1))


def get_rocm_version():
    """Return the system HIP version parsed from ``hipcc --version``.

    Returns:
        A ``TorchVersion`` holding the "major.minor" HIP version string.

    Raises:
        VerifyDynamoError: if no ROCm home is found, or the ``hipcc`` output
            does not contain a "HIP version: X.Y" string.
    """
    from torch.torch_version import TorchVersion
    from torch.utils import cpp_extension

    ROCM_HOME = cpp_extension._find_rocm_home()
    if not ROCM_HOME:
        raise VerifyDynamoError(
            "ROCM was not found on the system, please set ROCM_HOME environment variable"
        )

    hipcc = os.path.join(ROCM_HOME, "bin", "hipcc")
    hip_version_str = (
        subprocess.check_output([hipcc, "--version"])
        .strip()
        .decode(*cpp_extension.SUBPROCESS_DECODE_ARGS)
    )
    hip_version = re.search(r"HIP version: (\d+[.]\d+)", hip_version_str)

    if hip_version is None:
        raise VerifyDynamoError("HIP version not found in `hipcc --version` output")

    return TorchVersion(hip_version.group(1))


def _warn_on_version_issues(label, torch_ver, env_ver, min_ver):
    """Emit soft warnings for toolkit version problems (shared by CUDA/ROCm).

    Warns when the version ``torch`` was built against differs from the
    system (env) version, and when either is below ``min_ver``.
    ``stacklevel=2`` attributes each warning to the calling check function
    rather than to this helper.
    """
    if env_ver != torch_ver:
        warnings.warn(
            f"{label} version mismatch, `torch` version: {torch_ver}, "
            f"env version: {env_ver}",
            stacklevel=2,
        )
    if torch_ver < min_ver:
        warnings.warn(
            f"(`torch`) {label} version not supported: {torch_ver} "
            f"- minimum requirement: {min_ver}",
            stacklevel=2,
        )
    if env_ver < min_ver:
        warnings.warn(
            f"(env) {label} version not supported: {env_ver} "
            f"- minimum requirement: {min_ver}",
            stacklevel=2,
        )


def check_cuda():
    """Check the CUDA setup on a CUDA (non-HIP) build of ``torch``.

    Returns:
        ``None`` when CUDA is unavailable or ``torch`` is a HIP build;
        otherwise the system CUDA version from ``get_cuda_version``.
        Version mismatches and too-old versions only produce warnings.
    """
    import torch
    from torch.torch_version import TorchVersion

    if not torch.cuda.is_available() or torch.version.hip is not None:
        return None

    torch_cuda_ver = TorchVersion(torch.version.cuda)
    cuda_ver = get_cuda_version()
    _warn_on_version_issues("CUDA", torch_cuda_ver, cuda_ver, MIN_CUDA_VERSION)
    return cuda_ver


def check_rocm():
    """Check the ROCm setup on a HIP build of ``torch``.

    Returns:
        ``None`` when no GPU is available or ``torch`` is not a HIP build;
        otherwise the system HIP version from ``get_rocm_version``.
        Version mismatches and too-old versions only produce warnings.
    """
    import torch
    from torch.torch_version import TorchVersion

    if not torch.cuda.is_available() or torch.version.hip is None:
        return None

    # Keep only "major.minor" from the full HIP version string.
    torch_rocm_ver = TorchVersion(".".join(torch.version.hip.split(".")[:2]))
    rocm_ver = get_rocm_version()
    _warn_on_version_issues("ROCm", torch_rocm_ver, rocm_ver, MIN_ROCM_VERSION)
    return rocm_ver


def check_dynamo(backend, device, err_msg) -> None:
    """Smoke-test a Dynamo ``backend`` on ``device``.

    Compiles ``x + x`` both as a plain function and as an ``nn.Module`` with
    ``dynamo.optimize(backend, nopython=True)``, then checks the forward
    result and the gradient. The check is skipped (with a printed message)
    when ``device == "cuda"`` and CUDA or triton is unavailable.

    On any ``Exception`` the traceback and ``err_msg`` are written to stderr
    and the process exits with status 1.
    """
    import torch

    if device == "cuda" and not torch.cuda.is_available():
        print(f"CUDA not available -- skipping CUDA check on {backend} backend\n")
        return

    try:
        import torch._dynamo as dynamo

        if device == "cuda":
            from torch.utils._triton import has_triton

            if not has_triton():
                print(
                    f"WARNING: CUDA available but triton cannot be used. "
                    f"Your GPU may not be supported. "
                    f"Skipping CUDA check on {backend} backend\n"
                )
                return

        dynamo.reset()

        @dynamo.optimize(backend, nopython=True)
        def fn(x):
            return x + x

        class Module(torch.nn.Module):
            def forward(self, x):
                return x + x

        mod = Module()
        opt_mod = dynamo.optimize(backend, nopython=True)(mod)

        for f in (fn, opt_mod):
            x = torch.randn(10, 10).to(device)
            x.requires_grad = True
            y = f(x)
            torch.testing.assert_close(y, x + x)
            z = y.sum()
            z.backward()
            # d(sum(x + x))/dx == 2 elementwise
            torch.testing.assert_close(x.grad, 2 * torch.ones_like(x))
    except Exception:
        sys.stderr.write(traceback.format_exc() + "\n" + err_msg + "\n\n")
        sys.exit(1)


# (backend, device, error message) triples run by ``main``.
_SANITY_CHECK_ARGS = (
    ("eager", "cpu", "CPU eager sanity check failed"),
    ("eager", "cuda", "CUDA eager sanity check failed"),
    ("aot_eager", "cpu", "CPU aot_eager sanity check failed"),
    ("aot_eager", "cuda", "CUDA aot_eager sanity check failed"),
    ("inductor", "cpu", "CPU inductor sanity check failed"),
    (
        "inductor",
        "cuda",
        "CUDA inductor sanity check failed\n"
        + "NOTE: Please check that you installed the correct hash/version of `triton`",
    ),
)


def main() -> None:
    """Print environment versions, then run every Dynamo sanity check."""
    python_ver = check_python()
    torch_ver = check_torch()
    cuda_ver = check_cuda()
    rocm_ver = check_rocm()
    print(
        f"Python version: {python_ver.major}.{python_ver.minor}.{python_ver.micro}\n"
        f"`torch` version: {torch_ver}\n"
        f"CUDA version: {cuda_ver}\n"
        f"ROCM version: {rocm_ver}\n"
    )
    # The interpreter version is loop-invariant: decide once whether to run
    # the Dynamo checks at all instead of re-testing per check.
    if sys.version_info >= (3, 13):
        warnings.warn("Dynamo not yet supported in Python 3.13. Skipping check.")
    else:
        for args in _SANITY_CHECK_ARGS:
            check_dynamo(*args)
    print("All required checks passed")


if __name__ == "__main__":
    main()