xref: /aosp_15_r20/external/crosvm/tools/contrib/refactor_use_references.py (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1# Copyright 2022 The ChromiumOS Authors
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5# Tools for refactoring references in rust code.
6#
7# Contains the last run refactoring for reference. Don't run this script, it'll
8# fail, but use it as a foundation for other refactorings.
9
10from contextlib import contextmanager
11import os
12import re
13import subprocess
14from pathlib import Path
15from typing import Callable, NamedTuple, Union
16
# Type alias for a search pattern: either a plain string or a compiled
# regular expression over `str`.
# NOTE(review): nothing in the visible part of this file references this alias.
SearchPattern = Union[str, re.Pattern[str]]
18
19
class Token(NamedTuple):
    """A whitespace-delimited fragment of source text and the slice it came from."""

    # The fragment with surrounding whitespace stripped.
    token: str
    # Index where the fragment's slice begins. Except for a token at the very
    # beginning of the source, this is the index of the whitespace character
    # directly preceding the token text (not of the token's first character).
    start: int
    # Index of the whitespace character that terminates the token, or
    # len(source) for a token that runs to the end of the source.
    end: int


def tokenize(source: str):
    """Split source by whitespace with start/end indices annotated.

    Only space, newline and tab act as delimiters. Each yielded `Token` holds
    the stripped text plus the `[start, end)` slice of `source` it was cut from;
    see the field comments on `Token` for the exact meaning of the indices.

    A token that runs up to the end of `source` (i.e. the source does not end
    in whitespace) is yielded as well.
    """
    start = 0
    for i, char in enumerate(source):
        if char in (" ", "\n", "\t") and i > start:
            token = source[start:i].strip()
            if token:
                yield Token(token, start, i)
            # Restart the slice at the delimiter itself, so the next token's
            # `start` points at the whitespace directly in front of it.
            start = i

    # The loop above only flushes a token when it sees the whitespace that
    # follows it; flush whatever is left so the final token is not lost.
    tail = source[start:].strip()
    if tail:
        yield Token(tail, start, len(source))
35
36
def parse_module_chunks(source: str):
    """Terrible parser to split code by `mod foo { ... }` statements. Please don't judge me.

    Returns the original source split with module names annotated as
    ('module name', 'source'). Text that is not inside an inline module body is
    yielded with an empty module name; the `mod foo {` header itself is part of
    the preceding unnamed chunk. Concatenating the second element of every
    yielded tuple reproduces `source` exactly.

    Limitations of the token-based approach:
    - Only the exact token sequence `mod`, `<name>`, `{` opens a module.
    - Braces are counted per whitespace-separated token: a token containing
      `{` opens a level, one containing `}` closes a level, and a token
      containing both (e.g. `{}`) is treated as balanced.
    - Modules nested inside another inline module are not reported separately;
      their text is part of the enclosing module's chunk.
    """
    tokens = list(tokenize(source))
    prev = 0
    for i in range(len(tokens) - 2):
        if tokens[i].start < prev:
            # This token lies inside a module body that was already emitted.
            # Treating a nested `mod` here as a new chunk would emit its text
            # a second time and move `prev` backwards.
            continue
        if tokens[i].token != "mod" or tokens[i + 2].token != "{":
            continue

        depth = 1
        for j in range(i + 3, len(tokens)):
            text = tokens[j].token
            opens = "{" in text
            closes = "}" in text
            # Tokens with both kinds of brace leave the depth unchanged.
            if opens != closes:
                depth += 1 if opens else -1
            if depth == 0:
                start = tokens[i + 2].end
                end = tokens[j].start
                yield ("", source[prev:start])
                yield (tokens[i + 1].token, source[start:end])
                prev = end
                break
    if prev != len(source):
        yield ("", source[prev:])
62
63
def replace_use_references(file_path: Path, callback: Callable[[list[str], str], str]):
    """Calls 'callback' for each foo::bar reference in `file_path`.

    The file is read, every `a::b[::c...]` style path is passed through
    `callback`, and the file is written back with the returned replacements.

    The callback is called with two arguments:
    - the module path of the code containing the reference, as a list of names.
      It is built from the directory components of `file_path`, the file stem
      (unless the stem is `mod` or `lib`), and the name of the inline
      `mod foo { ... }` block the reference sits in, if any. The callback
      receives its own copy of this list on every call and may modify it.
    - the matched reference as a string.
    It is expected to return the rewritten reference.

    eg. for `file_path` = `foo.rs` containing:
    ```
    mod tests {
        use crate::baz;
    }
    ```
    this will call `callback(['foo', 'tests'], 'crate::baz')`. Directory
    components of `file_path` are included, so `src/foo.rs` would yield
    `['src', 'foo', 'tests']` instead.
    """
    module_parts = list(file_path.parts[:-1])
    if file_path.stem not in ("mod", "lib"):
        module_parts.append(file_path.stem)

    # Rust source files are UTF-8; do not depend on the platform default.
    with open(file_path, "r", encoding="utf-8") as file:
        contents = file.read()

    chunks: list[str] = []
    for module, source in parse_module_chunks(contents):
        full_module_parts = module_parts + [module] if module else module_parts
        chunks.append(
            re.sub(
                r"([\w\*\_\$]+\:\:)+[\w\*\_]+",
                # Pass a fresh list each time so that a callback which edits
                # its argument cannot affect later calls.
                lambda m, parts=full_module_parts: callback(list(parts), m.group(0)),
                source,
            )
        )
    with open(file_path, "w", encoding="utf-8") as file:
        file.write("".join(chunks))
102
103
@contextmanager
def chdir(path: Union[Path, str]):
    """Run the `with` body with `path` as the current working directory.

    The working directory that was active on entry is restored when the body
    finishes, including when it exits with an exception.
    """
    previous_dir = Path.cwd()
    try:
        os.chdir(path)
        yield
    finally:
        os.chdir(previous_dir)
112
113
def use_super_instead_of_crate(root: Path):
    """Rewrites `crate::` references to relative `super::` references.

    Expects to be run directly on the src directory and assumes
    that directory to be the module crate:: refers to.

    Every `*.rs` file below `root` is rewritten in place. A reference starting
    with `crate::` gets that prefix replaced by one `super::` per level of the
    module the reference is located in; other references are left untouched.
    Each rewrite is printed to stdout.
    """

    def replace(module: list[str], use: str):
        # Work on a private copy: the adjustments below must not modify the
        # list owned by the caller.
        module = list(module)

        # Patch up weird module structure...
        if len(module) > 1 and module[0] == "win":
            # Only the listed modules are actually in win::.
            # The rest is in the top level.
            if module[1] not in (
                "file_traits",
                "syslog",
                "platform_timer_utils",
                "file_util",
                "shm",
                "wait",
                "mmap",
                "stream_channel",
                "timer",
            ):
                del module[0]
        if len(module) > 0 and module[0] in ("punch_hole", "write_zeros"):
            module = ["write_zeroes", module[0]]

        prefix = "crate::"
        if use.startswith(prefix):
            # Swap only the leading `crate::`. A plain str.replace would also
            # hit later segments that merely end in "crate", such as
            # `crate::my_crate::Foo`.
            new_use = "super::" * len(module) + use[len(prefix):]
            print("::".join(module), use, "->", new_use)
            return new_use
        return use

    with chdir(root):
        for file in Path().glob("**/*.rs"):
            replace_use_references(file, replace)
147
148
def main():
    """Runs the last refactoring this script was used for.

    Steps:
    1. Discard local changes below common/win_sys_util/src (`git checkout -f`).
    2. Run nightly rustfmt so that use statements are one per line.
    3. Rewrite `crate::` references to `super::` references.
    4. Run nightly rustfmt again with imports grouped per crate.

    Raises subprocess.CalledProcessError if git or rustfmt exits non-zero.
    """
    path = Path("common") / "win_sys_util/src"
    subprocess.check_call(["git", "checkout", "-f", str(path)])

    # Use rustfmt to re-format use statements to be one per line.
    subprocess.check_call(
        ["rustfmt", "+nightly", "--config=imports_granularity=item", f"{path}/lib.rs"]
    )
    use_super_instead_of_crate(path)
    subprocess.check_call(
        ["rustfmt", "+nightly", "--config=imports_granularity=crate", f"{path}/lib.rs"]
    )


# Only run the refactoring when executed as a script. Importing this module to
# reuse its helpers must not trigger the destructive `git checkout -f` above.
if __name__ == "__main__":
    main()
164