xref: /aosp_15_r20/external/pytorch/tools/dynamo/verify_dynamo.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport os
2*da0073e9SAndroid Build Coastguard Workerimport re
3*da0073e9SAndroid Build Coastguard Workerimport subprocess
4*da0073e9SAndroid Build Coastguard Workerimport sys
5*da0073e9SAndroid Build Coastguard Workerimport traceback
6*da0073e9SAndroid Build Coastguard Workerimport warnings
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Worker
# Minimum CUDA version; check_cuda() warns when either the torch build's or
# the environment's CUDA version compares lower than this.
MIN_CUDA_VERSION = "11.6"
# Minimum ROCm version; check_rocm() warns when either the torch build's or
# the environment's ROCm version compares lower than this.
MIN_ROCM_VERSION = "5.4"
# Minimum (major, minor) Python version; check_python() raises when the
# running interpreter is older.
MIN_PYTHON_VERSION = (3, 8)
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker
class VerifyDynamoError(Exception):
    """Raised when the environment does not meet a Dynamo requirement.

    Derives from ``Exception`` (not ``BaseException``) so that it behaves like
    an ordinary application error: generic ``except Exception`` handlers see
    it, and it is not lumped together with interpreter-exit signals such as
    ``KeyboardInterrupt`` or ``SystemExit``.  Handlers that catch
    ``BaseException`` keep working because ``Exception`` is a subclass of it.
    """
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Worker
def check_python():
    """Verify that the running interpreter is new enough.

    Returns:
        ``sys.version_info`` of the running interpreter.

    Raises:
        VerifyDynamoError: if the interpreter version compares lower than
            ``MIN_PYTHON_VERSION``.
    """
    running = sys.version_info
    if running >= MIN_PYTHON_VERSION:
        return running

    message = (
        f"Python version not supported: {running} "
        f"- minimum requirement: {MIN_PYTHON_VERSION}"
    )
    raise VerifyDynamoError(message)
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Worker
def check_torch():
    """Import ``torch`` and return its ``__version__``.

    The import happens inside the function, so an import failure surfaces
    when this check is called rather than when the module is loaded.
    """
    from torch import __version__ as installed_version

    return installed_version
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard Worker
# based on torch/utils/cpp_extension.py
def get_cuda_version():
    """Return the CUDA version reported by ``nvcc --version``.

    The CUDA install directory comes from ``cpp_extension._find_cuda_home()``;
    ``<CUDA_HOME>/bin/nvcc --version`` is executed and its output is searched
    for a ``release <major>.<minor>`` token.

    Returns:
        A ``TorchVersion`` built from the ``<major>.<minor>`` string.

    Raises:
        VerifyDynamoError: if no CUDA home is found, or the nvcc output does
            not contain a release number.
    """
    from torch.torch_version import TorchVersion
    from torch.utils import cpp_extension

    cuda_home = cpp_extension._find_cuda_home()
    if not cuda_home:
        raise VerifyDynamoError(cpp_extension.CUDA_NOT_FOUND_MESSAGE)

    nvcc_path = os.path.join(cuda_home, "bin", "nvcc")
    raw_output = subprocess.check_output([nvcc_path, "--version"])
    nvcc_text = raw_output.strip().decode(*cpp_extension.SUBPROCESS_DECODE_ARGS)

    release = re.search(r"release (\d+[.]\d+)", nvcc_text)
    if release is None:
        raise VerifyDynamoError("CUDA version not found in `nvcc --version` output")

    return TorchVersion(release.group(1))
54*da0073e9SAndroid Build Coastguard Worker
55*da0073e9SAndroid Build Coastguard Worker
def get_rocm_version():
    """Return the HIP version reported by ``hipcc --version``.

    The ROCm install directory comes from ``cpp_extension._find_rocm_home()``;
    ``<ROCM_HOME>/bin/hipcc --version`` is executed and its output is searched
    for a ``HIP version: <major>.<minor>`` token.

    Returns:
        A ``TorchVersion`` built from the ``<major>.<minor>`` string.

    Raises:
        VerifyDynamoError: if no ROCm home is found, or the hipcc output does
            not contain a HIP version.
    """
    from torch.torch_version import TorchVersion
    from torch.utils import cpp_extension

    rocm_home = cpp_extension._find_rocm_home()
    if not rocm_home:
        raise VerifyDynamoError(
            "ROCM was not found on the system, please set ROCM_HOME environment variable"
        )

    hipcc_path = os.path.join(rocm_home, "bin", "hipcc")
    raw_output = subprocess.check_output([hipcc_path, "--version"])
    hipcc_text = raw_output.strip().decode(*cpp_extension.SUBPROCESS_DECODE_ARGS)

    found = re.search(r"HIP version: (\d+[.]\d+)", hipcc_text)
    if found is None:
        raise VerifyDynamoError("HIP version not found in `hipcc --version` output")

    return TorchVersion(found.group(1))
80*da0073e9SAndroid Build Coastguard Worker
81*da0073e9SAndroid Build Coastguard Worker
def check_cuda():
    """Compare the CUDA version of the torch build with the environment's.

    All findings are reported through ``warnings.warn``; this function does
    not raise for a version mismatch or a too-old version.

    Returns:
        ``None`` when ``torch.cuda.is_available()`` is false or torch is a
        HIP (ROCm) build; otherwise the environment CUDA version returned by
        ``get_cuda_version()``.

    Raises:
        VerifyDynamoError: propagated from ``get_cuda_version()``.
    """
    import torch
    from torch.torch_version import TorchVersion

    if not torch.cuda.is_available() or torch.version.hip is not None:
        return None

    torch_cuda_ver = TorchVersion(torch.version.cuda)

    # check if torch cuda version matches system cuda version
    cuda_ver = get_cuda_version()
    if cuda_ver != torch_cuda_ver:
        warnings.warn(
            f"CUDA version mismatch, `torch` version: {torch_cuda_ver}, env version: {cuda_ver}"
        )

    if torch_cuda_ver < MIN_CUDA_VERSION:
        warnings.warn(
            f"(`torch`) CUDA version not supported: {torch_cuda_ver} "
            f"- minimum requirement: {MIN_CUDA_VERSION}"
        )
    if cuda_ver < MIN_CUDA_VERSION:
        warnings.warn(
            f"(env) CUDA version not supported: {cuda_ver} "
            f"- minimum requirement: {MIN_CUDA_VERSION}"
        )

    # The guard above already returned for HIP builds, so torch.version.hip is
    # None here and no fallback value is needed.
    return cuda_ver
113*da0073e9SAndroid Build Coastguard Worker
114*da0073e9SAndroid Build Coastguard Worker
def check_rocm():
    """Compare the ROCm version of the torch build with the environment's.

    All findings are reported through ``warnings.warn``; this function does
    not raise for a version mismatch or a too-old version.

    Returns:
        ``None`` when ``torch.cuda.is_available()`` is false or
        ``torch.version.hip`` is ``None``.  Otherwise the environment ROCm
        version from ``get_rocm_version()`` if ``torch.version.hip`` is
        truthy, else the string ``"None"``.

    Raises:
        VerifyDynamoError: propagated from ``get_rocm_version()``.
    """
    import torch
    from torch.torch_version import TorchVersion

    if not torch.cuda.is_available() or torch.version.hip is None:
        return None

    # Keep only the first two dot-separated components of the HIP version.
    leading_components = torch.version.hip.split(".")[:2]
    torch_rocm_ver = TorchVersion(".".join(leading_components))

    # Compare the torch build against what hipcc reports on this system.
    rocm_ver = get_rocm_version()
    if rocm_ver != torch_rocm_ver:
        mismatch = f"ROCm version mismatch, `torch` version: {torch_rocm_ver}, env version: {rocm_ver}"
        warnings.warn(mismatch)
    if torch_rocm_ver < MIN_ROCM_VERSION:
        torch_too_old = (
            f"(`torch`) ROCm version not supported: {torch_rocm_ver} "
            f"- minimum requirement: {MIN_ROCM_VERSION}"
        )
        warnings.warn(torch_too_old)
    if rocm_ver < MIN_ROCM_VERSION:
        env_too_old = (
            f"(env) ROCm version not supported: {rocm_ver} "
            f"- minimum requirement: {MIN_ROCM_VERSION}"
        )
        warnings.warn(env_too_old)

    if torch.version.hip:
        return rocm_ver
    return "None"
143*da0073e9SAndroid Build Coastguard Worker
144*da0073e9SAndroid Build Coastguard Worker
def check_dynamo(backend, device, err_msg) -> None:
    """Run a tiny forward/backward pass through Dynamo as a sanity check.

    A function and an ``nn.Module`` that both compute ``x + x`` are wrapped
    with ``dynamo.optimize(backend, nopython=True)``; their outputs and input
    gradients are compared against the expected values.

    Args:
        backend: backend name handed to ``dynamo.optimize``.
        device: device the random input tensor is moved to (``"cuda"``
            triggers the availability and triton checks).
        err_msg: text written to stderr after the traceback on failure.

    The check is skipped (with a printed notice) when ``device`` is
    ``"cuda"`` and CUDA is unavailable or ``has_triton()`` is false.  On any
    ``Exception`` the traceback and ``err_msg`` go to stderr and the process
    exits with status 1.
    """
    import torch

    wants_cuda = device == "cuda"
    if wants_cuda and not torch.cuda.is_available():
        print(f"CUDA not available -- skipping CUDA check on {backend} backend\n")
        return

    try:
        import torch._dynamo as dynamo

        if wants_cuda:
            from torch.utils._triton import has_triton

            if not has_triton():
                print(
                    "WARNING: CUDA available but triton cannot be used. "
                    "Your GPU may not be supported. "
                    f"Skipping CUDA check on {backend} backend\n"
                )
                return

        # Start from a clean Dynamo state for every backend/device pair.
        dynamo.reset()

        @dynamo.optimize(backend, nopython=True)
        def fn(x):
            return x + x

        class Module(torch.nn.Module):
            def forward(self, x):
                return x + x

        opt_mod = dynamo.optimize(backend, nopython=True)(Module())

        for compiled in (fn, opt_mod):
            inp = torch.randn(10, 10).to(device)
            inp.requires_grad = True
            out = compiled(inp)
            torch.testing.assert_close(out, inp + inp)
            # d(sum(x + x))/dx is 2 everywhere.
            out.sum().backward()
            torch.testing.assert_close(inp.grad, 2 * torch.ones_like(inp))
    except Exception:
        report = "".join((traceback.format_exc(), "\n", err_msg, "\n\n"))
        sys.stderr.write(report)
        sys.exit(1)
190*da0073e9SAndroid Build Coastguard Worker
191*da0073e9SAndroid Build Coastguard Worker
# One (backend, device, failure message) triple per sanity check; main()
# unpacks each triple into check_dynamo(), in this order.
_SANITY_CHECK_ARGS = (
    # eager
    ("eager", "cpu", "CPU eager sanity check failed"),
    ("eager", "cuda", "CUDA eager sanity check failed"),
    # aot_eager
    ("aot_eager", "cpu", "CPU aot_eager sanity check failed"),
    ("aot_eager", "cuda", "CUDA aot_eager sanity check failed"),
    # inductor
    ("inductor", "cpu", "CPU inductor sanity check failed"),
    (
        "inductor",
        "cuda",
        "CUDA inductor sanity check failed\n"
        "NOTE: Please check that you installed the correct hash/version of `triton`",
    ),
)
205*da0073e9SAndroid Build Coastguard Worker
206*da0073e9SAndroid Build Coastguard Worker
def main() -> None:
    """Run every environment check and print a version summary.

    Checks the Python, torch, CUDA and ROCm versions, prints them, and then
    runs ``check_dynamo`` for each entry of ``_SANITY_CHECK_ARGS``.  On
    Python 3.13 and newer the Dynamo checks are skipped with a warning.

    Raises:
        VerifyDynamoError: propagated from the version checks.
    """
    python_ver = check_python()
    torch_ver = check_torch()
    cuda_ver = check_cuda()
    rocm_ver = check_rocm()
    print(
        f"Python version: {python_ver.major}.{python_ver.minor}.{python_ver.micro}\n"
        f"`torch` version: {torch_ver}\n"
        f"CUDA version: {cuda_ver}\n"
        f"ROCM version: {rocm_ver}\n"
    )
    # The interpreter version cannot change between iterations, so decide once
    # whether the Dynamo checks can run instead of re-testing (and re-warning)
    # for every backend/device pair.
    if sys.version_info >= (3, 13):
        warnings.warn("Dynamo not yet supported in Python 3.13. Skipping check.")
    else:
        for args in _SANITY_CHECK_ARGS:
            check_dynamo(*args)
    print("All required checks passed")


if __name__ == "__main__":
    main()
228