xref: /aosp_15_r20/external/pytorch/tools/nvcc_fix_deps.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
"""Tool to fix nvcc's dependency file output
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard WorkerUsage: python nvcc_fix_deps.py nvcc [nvcc args]...
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard WorkerThis wraps nvcc to ensure that the dependency file created by nvcc with the
6*da0073e9SAndroid Build Coastguard Worker-MD flag always uses absolute paths. nvcc sometimes outputs relative paths,
7*da0073e9SAndroid Build Coastguard Workerwhich ninja interprets as an unresolved dependency, so it triggers a rebuild
8*da0073e9SAndroid Build Coastguard Workerof that file every time.
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard WorkerThe easiest way to use this is to define:
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard WorkerCMAKE_CUDA_COMPILER_LAUNCHER="python;tools/nvcc_fix_deps.py;ccache"
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker"""
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Workerimport subprocess
19*da0073e9SAndroid Build Coastguard Workerimport sys
20*da0073e9SAndroid Build Coastguard Workerfrom pathlib import Path
21*da0073e9SAndroid Build Coastguard Workerfrom typing import TextIO
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker
def resolve_include(path: Path, include_dirs: list[Path]) -> Path:
    """Return the first ``include_dir / path`` that exists on disk.

    *include_dirs* is searched in order, and the first existing candidate
    wins. If none of the candidates exists, a RuntimeError is raised whose
    message lists every candidate that was tried.
    """
    candidates = [directory / path for directory in include_dirs]
    hit = next((c for c in candidates if c.exists()), None)
    if hit is not None:
        return hit

    paths = "\n    ".join(map(str, candidates))
    raise RuntimeError(
        f"""
ERROR: Failed to resolve dependency:
    {path}
Tried the following paths, but none existed:
    {paths}
"""
    )
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard Worker
def repair_depfile(depfile: TextIO, include_dirs: list[Path]) -> None:
    """Rewrite *depfile* in place so every dependency path is absolute.

    Each line is split at its last ``:``. Text up to and including that
    colon (the target) is copied unchanged. The remainder, minus a trailing
    ``\\`` continuation, is treated as a single dependency path. Relative
    paths are resolved with ``resolve_include`` against *include_dirs*,
    which raises RuntimeError if no candidate exists.

    Empty fragments are passed through without being treated as a path.
    These come from blank lines and from targets with nothing after the
    colon, such as ``foo.h:`` or ``foo.o: \\``.

    *depfile* must be open for reading and writing (e.g. mode ``"r+"``). It
    is only rewritten if at least one relative path was found.

    NOTE(review): because the split uses the *last* colon, a Windows drive
    letter (``C:\\...``) on a line would be mistaken for the target
    separator.
    """
    changes_made = False
    out: list[str] = []
    for line in depfile:
        if ":" in line:
            colon_pos = line.rfind(":")
            out.append(line[: colon_pos + 1])
            line = line[colon_pos + 1 :]

        line = line.strip()

        if line.endswith("\\"):
            end = " \\"
            line = line[:-1].strip()
        else:
            end = ""

        if not line:
            # Path("") is Path("."), which is relative and would otherwise
            # "resolve" to an include directory itself (or raise when there
            # are no include dirs). Keep the line structure and move on.
            out.append(f"{end}\n")
            continue

        path = Path(line)
        if not path.is_absolute():
            changes_made = True
            path = resolve_include(path, include_dirs)
        out.append(f"    {path}{end}\n")

    # If any paths were changed, rewrite the entire file
    if changes_made:
        depfile.seek(0)
        depfile.write("".join(out))
        depfile.truncate()
69*da0073e9SAndroid Build Coastguard Worker
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard WorkerPRE_INCLUDE_ARGS = ["-include", "--pre-include"]
72*da0073e9SAndroid Build Coastguard WorkerPOST_INCLUDE_ARGS = ["-I", "--include-path", "-isystem", "--system-include"]
73*da0073e9SAndroid Build Coastguard Worker
74*da0073e9SAndroid Build Coastguard Worker
75*da0073e9SAndroid Build Coastguard Workerdef extract_include_arg(include_dirs: list[Path], i: int, args: list[str]) -> None:
76*da0073e9SAndroid Build Coastguard Worker    def extract_one(name: str, i: int, args: list[str]) -> str | None:
77*da0073e9SAndroid Build Coastguard Worker        arg = args[i]
78*da0073e9SAndroid Build Coastguard Worker        if arg == name:
79*da0073e9SAndroid Build Coastguard Worker            return args[i + 1]
80*da0073e9SAndroid Build Coastguard Worker        if arg.startswith(name):
81*da0073e9SAndroid Build Coastguard Worker            arg = arg[len(name) :]
82*da0073e9SAndroid Build Coastguard Worker            return arg[1:] if arg[0] == "=" else arg
83*da0073e9SAndroid Build Coastguard Worker        return None
84*da0073e9SAndroid Build Coastguard Worker
85*da0073e9SAndroid Build Coastguard Worker    for name in PRE_INCLUDE_ARGS:
86*da0073e9SAndroid Build Coastguard Worker        path = extract_one(name, i, args)
87*da0073e9SAndroid Build Coastguard Worker        if path is not None:
88*da0073e9SAndroid Build Coastguard Worker            include_dirs.insert(0, Path(path).resolve())
89*da0073e9SAndroid Build Coastguard Worker            return
90*da0073e9SAndroid Build Coastguard Worker
91*da0073e9SAndroid Build Coastguard Worker    for name in POST_INCLUDE_ARGS:
92*da0073e9SAndroid Build Coastguard Worker        path = extract_one(name, i, args)
93*da0073e9SAndroid Build Coastguard Worker        if path is not None:
94*da0073e9SAndroid Build Coastguard Worker            include_dirs.append(Path(path).resolve())
95*da0073e9SAndroid Build Coastguard Worker            return
96*da0073e9SAndroid Build Coastguard Worker
97*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Run the wrapped command (everything after this script's own path)
    # first, because the depfile is produced by it.
    ret = subprocess.run(
        sys.argv[1:], stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr
    )

    depfile_path = None
    include_dirs: list[Path] = []

    # Parse only the nvcc arguments we care about. argv[1] is the wrapped
    # program itself (see the module docstring), so it is skipped.
    args = sys.argv[2:]
    for i, arg in enumerate(args):
        # Value of a "flag VALUE" pair; None if the flag is the last argument,
        # so a malformed trailing flag cannot crash us after the compile ran.
        value = args[i + 1] if i + 1 < len(args) else None
        if arg == "-MF":
            if value is not None:
                depfile_path = Path(value)
        elif arg == "-c":
            # Include the base path of the cuda file
            if value is not None:
                include_dirs.append(Path(value).resolve().parent)
        else:
            extract_include_arg(include_dirs, i, args)

    if depfile_path is not None and depfile_path.exists():
        # "r+" so repair_depfile can read and then rewrite the file in place.
        with depfile_path.open("r+") as f:
            repair_depfile(f, include_dirs)

    # Propagate the wrapped command's exit status.
    sys.exit(ret.returncode)
122