xref: /aosp_15_r20/external/pytorch/tools/nvcc_fix_deps.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
"""Tool to fix nvcc's dependency file output
2
3Usage: python nvcc_fix_deps.py nvcc [nvcc args]...
4
5This wraps nvcc to ensure that the dependency file created by nvcc with the
6-MD flag always uses absolute paths. nvcc sometimes outputs relative paths,
7which ninja interprets as an unresolved dependency, so it triggers a rebuild
8of that file every time.
9
The easiest way to use this is to set the CMake variable:
11
12CMAKE_CUDA_COMPILER_LAUNCHER="python;tools/nvcc_fix_deps.py;ccache"
13
14"""
15
16from __future__ import annotations
17
18import subprocess
19import sys
20from pathlib import Path
21from typing import TextIO
22
23
def resolve_include(path: Path, include_dirs: list[Path]) -> Path:
    """Locate *path* inside the first include directory that contains it.

    Directories are tried in the order given, and the joined path
    ``include_dir / path`` of the first one that exists is returned as-is
    (it is not normalized any further).

    Raises:
        RuntimeError: if ``include_dir / path`` exists for none of
            *include_dirs*; the message lists every candidate that was tried.
    """
    candidates = [include_dir / path for include_dir in include_dirs]

    # next() over a generator keeps the existence checks lazy: we stop at the
    # first hit, exactly like a for-loop with an early return.
    found = next((candidate for candidate in candidates if candidate.exists()), None)
    if found is not None:
        return found

    tried = "\n    ".join(str(candidate) for candidate in candidates)
    raise RuntimeError(
        f"""
ERROR: Failed to resolve dependency:
    {path}
Tried the following paths, but none existed:
    {tried}
"""
    )
39
40
def _split_target(line: str) -> tuple[str, str]:
    """Split a depfile line into its ``target:`` prefix and the remainder.

    The separator is the last colon that is followed by whitespace or by the
    end of the line. Any other colon, such as a Windows drive letter (``C:``)
    or a colon inside a file name, is treated as part of the path. Returns
    ``("", line)`` when the line has no target separator.
    """
    for pos in range(len(line) - 1, -1, -1):
        if line[pos] == ":" and (pos + 1 == len(line) or line[pos + 1].isspace()):
            return line[: pos + 1], line[pos + 1 :]
    return "", line


def repair_depfile(depfile: TextIO, include_dirs: list[Path]) -> None:
    """Rewrite *depfile* in place so that every dependency path is absolute.

    Each line is expected to hold at most one path. It may be preceded by a
    ``target:`` prefix and followed by a ``\\`` line continuation. Relative
    paths are resolved against *include_dirs* with :func:`resolve_include`.

    Empty entries (blank lines, a target with nothing after its colon, or a
    bare continuation) are kept as-is. They are never treated as the path
    ``.``, which would otherwise "resolve" to an include directory and inject
    a bogus dependency.

    The file is rewritten only when at least one path actually changed, so
    an already-absolute depfile is left untouched.

    Args:
        depfile: The dependency file, opened for reading and writing
            (e.g. mode ``"r+"``).
        include_dirs: Directories to search, in priority order.

    Raises:
        RuntimeError: if a relative path is not found in any include dir.
    """
    changes_made = False
    out: list[str] = []
    for raw_line in depfile:
        target, rest = _split_target(raw_line)
        rest = rest.strip()

        if rest.endswith("\\"):
            end = " \\"
            rest = rest[:-1].strip()
        else:
            end = ""

        if not rest:
            # Nothing to resolve on this line: keep its target/continuation.
            out.append(f"{target}{end}\n")
            continue

        path = Path(rest)
        if not path.is_absolute():
            changes_made = True
            path = resolve_include(path, include_dirs)
        out.append(f"{target}    {path}{end}\n")

    # If any paths were changed, rewrite the entire file
    if changes_made:
        depfile.seek(0)
        depfile.write("".join(out))
        depfile.truncate()
69
70
# nvcc options whose value is inserted at the FRONT of the include search list.
# NOTE(review): nvcc documents -include/--pre-include as naming a header *file*
# to pre-include, not a directory, so the entry recorded for them is a file
# path and ``entry / dep`` can never exist. Kept as-is for compatibility;
# confirm the intended behaviour.
PRE_INCLUDE_ARGS = ["-include", "--pre-include"]
# nvcc options whose value (an include directory) is appended to the list.
POST_INCLUDE_ARGS = ["-I", "--include-path", "-isystem", "--system-include"]


def extract_include_arg(include_dirs: list[Path], i: int, args: list[str]) -> None:
    """Record the include path named by ``args[i]``, if it is an include option.

    For each option name ``NAME`` three spellings are recognised:
    ``NAME value`` (two separate arguments), ``NAMEvalue`` and
    ``NAME=value``.

    - Values of ``PRE_INCLUDE_ARGS`` options are inserted at the front of
      *include_dirs*.
    - Values of ``POST_INCLUDE_ARGS`` options are appended to it.

    Stored paths are passed through ``Path.resolve()``, so they are
    absolute. *include_dirs* is left unchanged when ``args[i]`` is not an
    include option, or when it is a separated option that is the last
    argument and so has no value.
    """

    def extract_one(name: str, i: int, args: list[str]) -> str | None:
        arg = args[i]
        if arg == name:
            # Separated form: the value is the next argument, if there is one.
            return args[i + 1] if i + 1 < len(args) else None
        if arg.startswith(name):
            # Joined form: "-Ifoo" or "--include-path=foo". The remainder is
            # non-empty here because arg != name.
            arg = arg[len(name) :]
            return arg[1:] if arg[0] == "=" else arg
        return None

    for name in PRE_INCLUDE_ARGS:
        path = extract_one(name, i, args)
        if path is not None:
            include_dirs.insert(0, Path(path).resolve())
            return

    for name in POST_INCLUDE_ARGS:
        path = extract_one(name, i, args)
        if path is not None:
            include_dirs.append(Path(path).resolve())
            return
96
97
if __name__ == "__main__":
    # Run the wrapped command (nvcc, optionally behind a launcher such as
    # ccache), sharing this process's stdio.
    ret = subprocess.run(
        sys.argv[1:], stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr
    )

    # Only a successful compile leaves a depfile worth fixing. After a failure
    # the depfile may be missing, stale or partial. Repairing it could raise a
    # RuntimeError that buries the compiler's own diagnostics and replaces its
    # exit code, so forward the compiler's status untouched.
    if ret.returncode != 0:
        sys.exit(ret.returncode)

    depfile_path = None
    include_dirs: list[Path] = []

    # Parse only the nvcc arguments we care about. sys.argv[1] is the wrapped
    # executable itself, so it is skipped. An option that needs a value but is
    # the last argument is ignored instead of raising IndexError.
    args = sys.argv[2:]
    for i, arg in enumerate(args):
        has_value = i + 1 < len(args)
        if arg == "-MF":
            if has_value:
                depfile_path = Path(args[i + 1])
        elif arg == "-c":
            # Include the base path of the cuda file. This assumes the source
            # file directly follows -c on the command line.
            if has_value:
                include_dirs.append(Path(args[i + 1]).resolve().parent)
        else:
            extract_include_arg(include_dirs, i, args)

    if depfile_path is not None and depfile_path.exists():
        with depfile_path.open("r+") as f:
            repair_depfile(f, include_dirs)

    sys.exit(ret.returncode)
122