xref: /aosp_15_r20/external/pytorch/tools/iwyu/fixup.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
"""Post-process include-what-you-use (iwyu) include lines.

Quoted ``#include "..."`` directives are rewritten to angle-bracket form, and
C standard headers (``<foo.h>``) are replaced with their C++ ``<cfoo>``
counterparts.
"""

import re
import sys


# Capture only up to the FIRST closing delimiter.  A greedy ``(.*)`` would run
# to the last ``"`` / ``>`` on the line, so an iwyu trailing comment such as
# ``// for operator>>`` would be swallowed into the header path.
QUOTE_INCLUDE_RE = re.compile(r'^#include "([^"]*)"')
ANGLE_INCLUDE_RE = re.compile(r"^#include <([^>]*)>")

# By default iwyu picks the C header, but we prefer the C++ headers.
# Every entry maps "<X.h>" to "<cX>", so each replacement is exactly one
# character shorter than the original path.
# NOTE: <ccomplex>, <cstdalign>, <cstdbool> and <ctgmath> are deprecated in
# C++17 and removed in C++20, and <ciso646> is removed in C++20; revisit
# these entries if the codebase moves to C++20.
STD_C_HEADER_MAP = {
    "<assert.h>": "<cassert>",
    "<complex.h>": "<ccomplex>",
    "<ctype.h>": "<cctype>",
    "<errno.h>": "<cerrno>",
    "<fenv.h>": "<cfenv>",
    "<float.h>": "<cfloat>",
    "<inttypes.h>": "<cinttypes>",
    "<iso646.h>": "<ciso646>",
    "<limits.h>": "<climits>",
    "<locale.h>": "<clocale>",
    "<math.h>": "<cmath>",
    "<setjmp.h>": "<csetjmp>",
    "<signal.h>": "<csignal>",
    "<stdalign.h>": "<cstdalign>",
    "<stdarg.h>": "<cstdarg>",
    "<stdbool.h>": "<cstdbool>",
    "<stddef.h>": "<cstddef>",
    "<stdint.h>": "<cstdint>",
    "<stdio.h>": "<cstdio>",
    "<stdlib.h>": "<cstdlib>",
    "<string.h>": "<cstring>",
    "<tgmath.h>": "<ctgmath>",
    "<time.h>": "<ctime>",
    "<uchar.h>": "<cuchar>",
    "<wchar.h>": "<cwchar>",
    "<wctype.h>": "<cwctype>",
}
37*da0073e9SAndroid Build Coastguard Worker
38*da0073e9SAndroid Build Coastguard Worker
def _fix_include_line(line: str) -> str:
    """Return *line* with its ``#include`` directive normalized, if any.

    - ``#include "foo.h"`` becomes ``#include <foo.h>``.
    - ``#include <foo.h>`` for a C standard header listed in
      ``STD_C_HEADER_MAP`` becomes the mapped C++ header (e.g. ``<cfoo>``).
    - Any other line is returned unchanged.

    Whatever follows the directive (e.g. an iwyu ``// for X`` comment and the
    line ending) is preserved.
    """
    match = QUOTE_INCLUDE_RE.match(line)
    if match is not None:
        # Quotes and angle brackets have the same width, so the tail keeps
        # its column without any padding.
        return f"#include <{match.group(1)}>{line[match.end(0):]}"

    match = ANGLE_INCLUDE_RE.match(line)
    if match is not None:
        path = f"<{match.group(1)}>"
        new_path = STD_C_HEADER_MAP.get(path, path)
        tail = line[match.end(0):]
        # Remapping "<foo.h>" -> "<cfoo>" shortens the path; pad a trailing
        # comment by the difference so it stays in its original column.
        # Unmapped headers get no padding, and a tail that is only a line
        # ending / whitespace is left alone so no trailing spaces appear.
        if tail.strip():
            tail = " " * (len(path) - len(new_path)) + tail
        return f"#include {new_path}{tail}"

    return line


def main() -> None:
    """Copy stdin to stdout, normalizing ``#include`` lines on the way.

    Every input line is echoed in order; only include directives are
    rewritten (see ``_fix_include_line``).
    """
    for line in sys.stdin:
        print(_fix_include_line(line), end="")


if __name__ == "__main__":
    main()
62