"""Verify that the local environment can run TorchDynamo.

Checks the Python, ``torch``, CUDA and ROCm versions, then runs small
``torch._dynamo`` sanity checks for each backend/device pair listed in
``_SANITY_CHECK_ARGS``.
"""

import os
import re
import subprocess
import sys
import traceback
import warnings


MIN_CUDA_VERSION = "11.6"
MIN_ROCM_VERSION = "5.4"
MIN_PYTHON_VERSION = (3, 8)


class VerifyDynamoError(BaseException):
    """Raised when a hard environment requirement is not met.

    Derives from ``BaseException`` rather than ``Exception``, so generic
    ``except Exception`` handlers (such as the one in ``check_dynamo``) do
    not catch it.
    """


def check_python():
    """Return ``sys.version_info`` after checking it against ``MIN_PYTHON_VERSION``.

    Raises:
        VerifyDynamoError: if the running interpreter is older than the minimum.
    """
    if sys.version_info < MIN_PYTHON_VERSION:
        raise VerifyDynamoError(
            f"Python version not supported: {sys.version_info} "
            f"- minimum requirement: {MIN_PYTHON_VERSION}"
        )
    return sys.version_info


def check_torch():
    """Import ``torch`` and return its ``__version__`` string."""
    import torch

    return torch.__version__


def _get_tool_version(tool, pattern, label):
    """Run ``<tool> --version`` and parse a version number out of its output.

    Args:
        tool: path to the compiler executable (e.g. ``.../bin/nvcc``).
        pattern: regex whose first group captures the version string.
        label: name used in the error message (e.g. ``"CUDA"`` or ``"HIP"``).

    Returns:
        The captured version wrapped in a ``TorchVersion``.

    Raises:
        VerifyDynamoError: if ``pattern`` does not match the tool's output.
        Errors from launching the tool (``subprocess.check_output``) propagate
        unchanged.
    """
    from torch.torch_version import TorchVersion
    from torch.utils import cpp_extension

    version_output = (
        subprocess.check_output([tool, "--version"])
        .strip()
        .decode(*cpp_extension.SUBPROCESS_DECODE_ARGS)
    )
    match = re.search(pattern, version_output)
    if match is None:
        tool_name = os.path.basename(tool)
        raise VerifyDynamoError(
            f"{label} version not found in `{tool_name} --version` output"
        )
    return TorchVersion(match.group(1))


# based on torch/utils/cpp_extension.py
def get_cuda_version():
    """Return the system CUDA toolkit version reported by ``nvcc --version``.

    Raises:
        VerifyDynamoError: if no CUDA home is found or the version cannot be
            parsed from the ``nvcc`` output.
    """
    from torch.utils import cpp_extension

    cuda_home = cpp_extension._find_cuda_home()
    if not cuda_home:
        raise VerifyDynamoError(cpp_extension.CUDA_NOT_FOUND_MESSAGE)

    nvcc = os.path.join(cuda_home, "bin", "nvcc")
    return _get_tool_version(nvcc, r"release (\d+[.]\d+)", "CUDA")


def get_rocm_version():
    """Return the system HIP version reported by ``hipcc --version``.

    Raises:
        VerifyDynamoError: if no ROCm home is found or the version cannot be
            parsed from the ``hipcc`` output.
    """
    from torch.utils import cpp_extension

    rocm_home = cpp_extension._find_rocm_home()
    if not rocm_home:
        raise VerifyDynamoError(
            "ROCM was not found on the system, please set ROCM_HOME environment variable"
        )

    hipcc = os.path.join(rocm_home, "bin", "hipcc")
    return _get_tool_version(hipcc, r"HIP version: (\d+[.]\d+)", "HIP")


def check_cuda():
    """Compare the CUDA version ``torch`` was built with against the system's.

    Returns:
        The system CUDA version, or ``None`` when CUDA is unavailable or
        ``torch`` is a ROCm (HIP) build.

    Mismatches and versions below ``MIN_CUDA_VERSION`` are reported through
    ``warnings.warn`` (not raised), so the remaining checks still run.
    """
    import torch
    from torch.torch_version import TorchVersion

    if not torch.cuda.is_available() or torch.version.hip is not None:
        return None

    torch_cuda_ver = TorchVersion(torch.version.cuda)

    # check if torch cuda version matches system cuda version
    cuda_ver = get_cuda_version()
    if cuda_ver != torch_cuda_ver:
        warnings.warn(
            f"CUDA version mismatch, `torch` version: {torch_cuda_ver}, env version: {cuda_ver}"
        )

    if torch_cuda_ver < MIN_CUDA_VERSION:
        warnings.warn(
            f"(`torch`) CUDA version not supported: {torch_cuda_ver} "
            f"- minimum requirement: {MIN_CUDA_VERSION}"
        )
    if cuda_ver < MIN_CUDA_VERSION:
        warnings.warn(
            f"(env) CUDA version not supported: {cuda_ver} "
            f"- minimum requirement: {MIN_CUDA_VERSION}"
        )

    # torch.version.hip is guaranteed None here by the early return above.
    return cuda_ver


def check_rocm():
    """Compare the ROCm version ``torch`` was built with against the system's.

    Returns:
        The system HIP version, or ``None`` when the GPU backend is unavailable
        or ``torch`` is not a ROCm (HIP) build.

    Mismatches and versions below ``MIN_ROCM_VERSION`` are reported through
    ``warnings.warn`` (not raised), so the remaining checks still run.
    """
    import torch
    from torch.torch_version import TorchVersion

    if not torch.cuda.is_available() or torch.version.hip is None:
        return None

    # Keep only "major.minor" from the full torch.version.hip string.
    torch_rocm_ver = TorchVersion(".".join(torch.version.hip.split(".")[:2]))

    # check if torch rocm version matches system rocm version
    rocm_ver = get_rocm_version()
    if rocm_ver != torch_rocm_ver:
        warnings.warn(
            f"ROCm version mismatch, `torch` version: {torch_rocm_ver}, env version: {rocm_ver}"
        )
    if torch_rocm_ver < MIN_ROCM_VERSION:
        warnings.warn(
            f"(`torch`) ROCm version not supported: {torch_rocm_ver} "
            f"- minimum requirement: {MIN_ROCM_VERSION}"
        )
    if rocm_ver < MIN_ROCM_VERSION:
        warnings.warn(
            f"(env) ROCm version not supported: {rocm_ver} "
            f"- minimum requirement: {MIN_ROCM_VERSION}"
        )

    # torch.version.hip is guaranteed non-None here by the early return above.
    return rocm_ver


def check_dynamo(backend, device, err_msg) -> None:
    """Run a tiny forward/backward sanity check through ``torch._dynamo``.

    A plain function and an ``nn.Module`` computing ``x + x`` are compiled with
    ``dynamo.optimize(backend, nopython=True)``. The outputs and the gradients
    of a ``(10, 10)`` input on ``device`` are then checked.

    The check is skipped (with a printed message) when ``device == "cuda"`` and
    CUDA is unavailable, or when triton cannot be used. On any ``Exception``,
    the traceback and ``err_msg`` are written to stderr and the process exits
    with status 1.
    """
    import torch

    if device == "cuda" and not torch.cuda.is_available():
        print(f"CUDA not available -- skipping CUDA check on {backend} backend\n")
        return

    try:
        import torch._dynamo as dynamo

        if device == "cuda":
            from torch.utils._triton import has_triton

            if not has_triton():
                print(
                    "WARNING: CUDA available but triton cannot be used. "
                    "Your GPU may not be supported. "
                    f"Skipping CUDA check on {backend} backend\n"
                )
                return

        dynamo.reset()

        @dynamo.optimize(backend, nopython=True)
        def fn(x):
            return x + x

        class Module(torch.nn.Module):
            def forward(self, x):
                return x + x

        mod = Module()
        opt_mod = dynamo.optimize(backend, nopython=True)(mod)

        for f in (fn, opt_mod):
            x = torch.randn(10, 10).to(device)
            x.requires_grad = True
            y = f(x)
            torch.testing.assert_close(y, x + x)
            z = y.sum()
            z.backward()
            # d(sum(x + x))/dx == 2 everywhere
            torch.testing.assert_close(x.grad, 2 * torch.ones_like(x))
    except Exception:
        sys.stderr.write(traceback.format_exc() + "\n" + err_msg + "\n\n")
        sys.exit(1)


# (backend, device, error message) triples passed to check_dynamo by main().
_SANITY_CHECK_ARGS = (
    ("eager", "cpu", "CPU eager sanity check failed"),
    ("eager", "cuda", "CUDA eager sanity check failed"),
    ("aot_eager", "cpu", "CPU aot_eager sanity check failed"),
    ("aot_eager", "cuda", "CUDA aot_eager sanity check failed"),
    ("inductor", "cpu", "CPU inductor sanity check failed"),
    (
        "inductor",
        "cuda",
        "CUDA inductor sanity check failed\n"
        + "NOTE: Please check that you installed the correct hash/version of `triton`",
    ),
)


def main() -> None:
    """Print environment versions, then run every dynamo sanity check."""
    python_ver = check_python()
    torch_ver = check_torch()
    cuda_ver = check_cuda()
    rocm_ver = check_rocm()
    print(
        f"Python version: {python_ver.major}.{python_ver.minor}.{python_ver.micro}\n"
        f"`torch` version: {torch_ver}\n"
        f"CUDA version: {cuda_ver}\n"
        f"ROCM version: {rocm_ver}\n"
    )
    # The interpreter version cannot change between iterations, so decide once
    # instead of re-checking (and re-warning) for every sanity-check entry.
    if sys.version_info >= (3, 13):
        warnings.warn("Dynamo not yet supported in Python 3.13. Skipping check.")
    else:
        for args in _SANITY_CHECK_ARGS:
            check_dynamo(*args)
    print("All required checks passed")


if __name__ == "__main__":
    main()