xref: /aosp_15_r20/external/pytorch/tools/dynamo/verify_dynamo.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import os
2import re
3import subprocess
4import sys
5import traceback
6import warnings
7
8
# Minimum toolkit/interpreter versions this script verifies the environment
# against.  CUDA/ROCm violations only produce warnings; the Python minimum is
# a hard failure (see check_python).
MIN_CUDA_VERSION = "11.6"
MIN_ROCM_VERSION = "5.4"
MIN_PYTHON_VERSION = (3, 8)
12
13
class VerifyDynamoError(Exception):
    """Raised when an environment check (Python/CUDA/ROCm) fails.

    Derives from ``Exception`` rather than ``BaseException``: subclassing
    ``BaseException`` directly is reserved for exceptions such as
    ``SystemExit``/``KeyboardInterrupt`` that must escape ordinary
    ``except Exception`` handlers, which is not the intent here.  Callers
    catching ``VerifyDynamoError`` (or ``BaseException``) are unaffected.
    """
16
17
def check_python():
    """Verify the interpreter meets MIN_PYTHON_VERSION; return ``sys.version_info``.

    Raises:
        VerifyDynamoError: if the running Python is older than the minimum.
    """
    version = sys.version_info
    if version >= MIN_PYTHON_VERSION:
        return version
    raise VerifyDynamoError(
        f"Python version not supported: {version} "
        f"- minimum requirement: {MIN_PYTHON_VERSION}"
    )
25
26
def check_torch():
    """Return the installed ``torch`` version (import fails loudly if absent)."""
    import torch

    torch_version = torch.__version__
    return torch_version
31
32
33# based on torch/utils/cpp_extension.py
# based on torch/utils/cpp_extension.py
def get_cuda_version():
    """Return the system CUDA toolkit version reported by ``nvcc --version``.

    Locates the toolkit via torch's ``_find_cuda_home`` helper and parses the
    ``release X.Y`` token out of nvcc's output.

    Raises:
        VerifyDynamoError: if no CUDA home is found or the version cannot be
            parsed from nvcc's output.
    """
    from torch.torch_version import TorchVersion
    from torch.utils import cpp_extension

    cuda_home = cpp_extension._find_cuda_home()
    if not cuda_home:
        raise VerifyDynamoError(cpp_extension.CUDA_NOT_FOUND_MESSAGE)

    nvcc_path = os.path.join(cuda_home, "bin", "nvcc")
    raw_output = subprocess.check_output([nvcc_path, "--version"])
    version_text = raw_output.strip().decode(*cpp_extension.SUBPROCESS_DECODE_ARGS)

    match = re.search(r"release (\d+[.]\d+)", version_text)
    if match is None:
        raise VerifyDynamoError("CUDA version not found in `nvcc --version` output")

    return TorchVersion(match.group(1))
54
55
def get_rocm_version():
    """Return the system ROCm/HIP version reported by ``hipcc --version``.

    Locates the ROCm install via torch's ``_find_rocm_home`` helper and parses
    the ``HIP version: X.Y`` token out of hipcc's output.

    Raises:
        VerifyDynamoError: if no ROCm home is found or the version cannot be
            parsed from hipcc's output.
    """
    from torch.torch_version import TorchVersion
    from torch.utils import cpp_extension

    rocm_home = cpp_extension._find_rocm_home()
    if not rocm_home:
        raise VerifyDynamoError(
            "ROCM was not found on the system, please set ROCM_HOME environment variable"
        )

    hipcc_path = os.path.join(rocm_home, "bin", "hipcc")
    raw_output = subprocess.check_output([hipcc_path, "--version"])
    version_text = raw_output.strip().decode(*cpp_extension.SUBPROCESS_DECODE_ARGS)

    match = re.search(r"HIP version: (\d+[.]\d+)", version_text)
    if match is None:
        raise VerifyDynamoError("HIP version not found in `hipcc --version` output")

    return TorchVersion(match.group(1))
80
81
def check_cuda():
    """Cross-check torch's CUDA build version against the system toolkit.

    Returns:
        ``None`` on CPU-only or ROCm builds (ROCm is handled by
        :func:`check_rocm`); otherwise the system toolkit version as a
        ``TorchVersion``.

    Version mismatches and below-minimum versions are reported as warnings
    rather than hard failures.
    """
    import torch
    from torch.torch_version import TorchVersion

    # ROCm builds populate torch.version.hip; bail out so check_rocm handles them.
    if not torch.cuda.is_available() or torch.version.hip is not None:
        return None

    torch_cuda_ver = TorchVersion(torch.version.cuda)

    # Warn (not fail) when the torch wheel's CUDA differs from the system toolkit.
    cuda_ver = get_cuda_version()
    if cuda_ver != torch_cuda_ver:
        warnings.warn(
            f"CUDA version mismatch, `torch` version: {torch_cuda_ver}, env version: {cuda_ver}"
        )

    if torch_cuda_ver < MIN_CUDA_VERSION:
        warnings.warn(
            f"(`torch`) CUDA version not supported: {torch_cuda_ver} "
            f"- minimum requirement: {MIN_CUDA_VERSION}"
        )
    if cuda_ver < MIN_CUDA_VERSION:
        warnings.warn(
            f"(env) CUDA version not supported: {cuda_ver} "
            f"- minimum requirement: {MIN_CUDA_VERSION}"
        )

    # torch.version.hip is known to be None here (early return above), so the
    # original `cuda_ver if torch.version.hip is None else "None"` ternary was
    # dead code; return the toolkit version unconditionally.
    return cuda_ver
113
114
def check_rocm():
    """Cross-check torch's ROCm build version against the system install.

    Returns:
        ``None`` on CPU-only or CUDA builds; otherwise the system ROCm/HIP
        version as a ``TorchVersion``.

    Version mismatches and below-minimum versions are reported as warnings
    rather than hard failures.
    """
    import torch
    from torch.torch_version import TorchVersion

    # Only meaningful on ROCm builds, where torch.version.hip is populated.
    if not torch.cuda.is_available() or torch.version.hip is None:
        return None

    # Extracts main ROCm version (major.minor) from the full hip string.
    torch_rocm_ver = TorchVersion(".".join(list(torch.version.hip.split(".")[0:2])))

    # Warn (not fail) when the torch wheel's ROCm differs from the system install.
    rocm_ver = get_rocm_version()
    if rocm_ver != torch_rocm_ver:
        warnings.warn(
            f"ROCm version mismatch, `torch` version: {torch_rocm_ver}, env version: {rocm_ver}"
        )
    if torch_rocm_ver < MIN_ROCM_VERSION:
        warnings.warn(
            f"(`torch`) ROCm version not supported: {torch_rocm_ver} "
            f"- minimum requirement: {MIN_ROCM_VERSION}"
        )
    if rocm_ver < MIN_ROCM_VERSION:
        warnings.warn(
            f"(env) ROCm version not supported: {rocm_ver} "
            f"- minimum requirement: {MIN_ROCM_VERSION}"
        )

    # torch.version.hip is known to be truthy here (early return above), so the
    # original `rocm_ver if torch.version.hip else "None"` ternary was dead
    # code; return the system version unconditionally.
    return rocm_ver
143
144
def check_dynamo(backend, device, err_msg) -> None:
    """Run a tiny compile + forward/backward sanity check on one backend/device.

    Prints a skip notice and returns early when CUDA (or triton, for CUDA
    devices) is unavailable.  On any failure, writes the traceback followed by
    ``err_msg`` to stderr and exits the process with status 1.
    """
    import torch

    if device == "cuda" and not torch.cuda.is_available():
        print(f"CUDA not available -- skipping CUDA check on {backend} backend\n")
        return

    try:
        import torch._dynamo as dynamo

        if device == "cuda":
            from torch.utils._triton import has_triton

            if not has_triton():
                print(
                    f"WARNING: CUDA available but triton cannot be used. "
                    f"Your GPU may not be supported. "
                    f"Skipping CUDA check on {backend} backend\n"
                )
                return

        dynamo.reset()

        # Exercise both a plain function and an nn.Module through the backend.
        @dynamo.optimize(backend, nopython=True)
        def doubled(x):
            return x + x

        class Module(torch.nn.Module):
            def forward(self, x):
                return x + x

        compiled_module = dynamo.optimize(backend, nopython=True)(Module())

        for compiled in (doubled, compiled_module):
            inp = torch.randn(10, 10).to(device)
            inp.requires_grad = True
            out = compiled(inp)
            torch.testing.assert_close(out, inp + inp)
            out.sum().backward()
            torch.testing.assert_close(inp.grad, 2 * torch.ones_like(inp))
    except Exception:
        sys.stderr.write(traceback.format_exc() + "\n" + err_msg + "\n\n")
        sys.exit(1)
190
191
# (backend, device, failure-message) triples that main() unpacks positionally
# into check_dynamo().  CUDA entries are skipped at runtime when no GPU is
# available.
_SANITY_CHECK_ARGS = (
    ("eager", "cpu", "CPU eager sanity check failed"),
    ("eager", "cuda", "CUDA eager sanity check failed"),
    ("aot_eager", "cpu", "CPU aot_eager sanity check failed"),
    ("aot_eager", "cuda", "CUDA aot_eager sanity check failed"),
    ("inductor", "cpu", "CPU inductor sanity check failed"),
    (
        "inductor",
        "cuda",
        "CUDA inductor sanity check failed\n"
        + "NOTE: Please check that you installed the correct hash/version of `triton`",
    ),
)
205
206
def main() -> None:
    """Print environment versions, then run every dynamo sanity check.

    Exits with status 1 (via check_dynamo) on the first failing check.
    """
    python_ver = check_python()
    torch_ver = check_torch()
    cuda_ver = check_cuda()
    rocm_ver = check_rocm()
    print(
        f"Python version: {python_ver.major}.{python_ver.minor}.{python_ver.micro}\n"
        f"`torch` version: {torch_ver}\n"
        f"CUDA version: {cuda_ver}\n"
        f"ROCM version: {rocm_ver}\n"
    )
    # The interpreter version is loop-invariant, so check it once up front
    # instead of re-warning on every backend/device pair as the original did.
    if sys.version_info >= (3, 13):
        warnings.warn("Dynamo not yet supported in Python 3.13. Skipping check.")
    else:
        for args in _SANITY_CHECK_ARGS:
            check_dynamo(*args)
    print("All required checks passed")
224
225
if __name__ == "__main__":
    # Allow running as a standalone script: `python verify_dynamo.py`.
    main()
228