# Copyright 2022 The ChromiumOS Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

# Tools for refactoring references in rust code.
#
# Contains the last run refactoring for reference. Don't run this script, it'll
# fail, but use it as a foundation for other refactorings.

from contextlib import contextmanager
import os
import re
import subprocess
from pathlib import Path
from typing import Callable, Iterator, NamedTuple, Union

SearchPattern = Union[str, re.Pattern[str]]

# Characters that separate tokens in `tokenize`.
_WHITESPACE = (" ", "\n", "\t")

# Matches a `::`-separated path such as `crate::foo::Bar`. `$` is accepted in
# the leading segments so macro paths like `$crate::foo` are matched as a whole
# and are therefore not mistaken for plain `crate::` paths.
_USE_REFERENCE_RE = re.compile(r"([\w\*\_\$]+\:\:)+[\w\*\_]+")

# Modules that really live under `win::`; everything else found in the `win`
# directory belongs to the crate's top level.
_WIN_SUBMODULES = (
    "file_traits",
    "syslog",
    "platform_timer_utils",
    "file_util",
    "shm",
    "wait",
    "mmap",
    "stream_channel",
    "timer",
)


class Token(NamedTuple):
    token: str  # The token text, with surrounding whitespace stripped.
    start: int  # Index of the whitespace char preceding the token (0 for a leading token).
    end: int  # Index of the whitespace char following the token, or len(source).


def tokenize(source: str) -> Iterator[Token]:
    """Split source by whitespace with start/end indices annotated.

    `start` is the index of the whitespace character immediately preceding
    the token (or 0 when the token begins the source). `end` is the index of
    the whitespace character that terminates it, or `len(source)` when the
    token runs to the end of the source.
    """
    start = 0
    for i, char in enumerate(source):
        if char in _WHITESPACE and i - start > 0:
            token = source[start:i].strip()
            if token:
                yield Token(token, start, i)
            start = i
    # The loop only emits a token once it sees the whitespace after it, so a
    # token at the very end of the source (e.g. a closing `}`) needs this.
    token = source[start:].strip()
    if token:
        yield Token(token, start, len(source))


def parse_module_chunks(source: str) -> Iterator[tuple[str, str]]:
    """Terrible parser to split code by `mod foo { ... }` statements. Please don't judge me.

    Returns the original source split with module names annotated as
    ('module name', 'source'). Text outside any inline module gets an empty
    module name, and concatenating all yielded sources reproduces `source`.

    Only the outermost inline modules are split out; a `mod` nested inside one
    stays part of that module's chunk (call this again on the chunk to split
    it further). Braces are counted per whitespace-separated token, so braces
    inside strings, chars or comments can confuse it.
    """
    tokens = list(tokenize(source))
    prev = 0
    for i in range(len(tokens) - 2):
        # Tokens inside a module body that was already emitted must not be
        # scanned again; doing so would emit overlapping, duplicated text.
        if tokens[i].start < prev:
            continue
        if tokens[i].token != "mod" or tokens[i + 2].token != "{":
            continue
        depth = 1
        for j in range(i + 3, len(tokens)):
            text = tokens[j].token
            # A token holding both braces (e.g. `{}`) is assumed balanced.
            if "{" in text and "}" in text:
                continue
            if "{" in text:
                depth += 1
            elif "}" in text:
                depth -= 1
            if depth == 0:
                start = tokens[i + 2].end
                end = tokens[j].start
                yield ("", source[prev:start])
                yield (tokens[i + 1].token, source[start:end])
                prev = end
                break
    if prev != len(source):
        yield ("", source[prev:])


def _rewrite_references(
    source: str,
    module_parts: list[str],
    callback: Callable[[list[str], str], str],
) -> str:
    """Apply `callback` to every path reference in `source`, recursing into inline modules."""
    pieces: list[str] = []
    for module, chunk in parse_module_chunks(source):
        if module:
            pieces.append(_rewrite_references(chunk, module_parts + [module], callback))
        else:
            # Hand every call its own copy so a callback that mutates the
            # module path cannot affect later calls.
            pieces.append(
                _USE_REFERENCE_RE.sub(
                    lambda m: callback(list(module_parts), m.group(0)), chunk
                )
            )
    return "".join(pieces)


def replace_use_references(file_path: Path, callback: Callable[[list[str], str], str]):
    """Calls 'callback' for each foo::bar reference in `file_path`.

    The callback is called with the absolute module path (a fresh list per
    call) and the reference, and is expected to return the rewritten
    reference. The module path takes into account the file path (interpreted
    relative to the crate's source root) as well as inline modules, including
    nested ones, defined in the source itself.

    eg. with the working directory at the crate's `src` directory,
    foo.rs:
    ```
    mod tests {
        use crate::baz;
    }
    ```
    will call `callback(['foo', 'tests'], 'crate::baz')`

    The file is read and written as UTF-8 with line endings preserved; it is
    only rewritten when something changed.
    """
    module_parts = list(file_path.parts[:-1])
    if file_path.stem not in ("mod", "lib"):
        module_parts.append(file_path.stem)

    # Rust source files are UTF-8; newline="" keeps CRLF/LF untouched.
    with open(file_path, "r", encoding="utf-8", newline="") as file:
        contents = file.read()
    new_contents = _rewrite_references(contents, module_parts, callback)
    if new_contents != contents:
        with open(file_path, "w", encoding="utf-8", newline="") as file:
            file.write(new_contents)


@contextmanager
def chdir(path: Union[Path, str]) -> Iterator[None]:
    """Temporarily change the working directory, restoring it on exit."""
    origin = Path().absolute()
    try:
        os.chdir(path)
        yield
    finally:
        os.chdir(origin)


def use_super_instead_of_crate(root: Path):
    """Rewrite `crate::` paths in all `.rs` files under `root` into `super::` chains.

    Expects to be run directly on the src directory and assumes
    that directory to be the module crate:: refers to."""

    def replace(module: list[str], use: str) -> str:
        # Work on a copy: the module path must not be mutated for the caller.
        module = list(module)
        # Patch up weird module structure...
        # Only the listed modules are actually in win::.
        # The rest is in the top level.
        if len(module) > 1 and module[0] == "win" and module[1] not in _WIN_SUBMODULES:
            del module[0]
        # punch_hole / write_zeros are treated as living under write_zeroes::.
        if module and module[0] in ("punch_hole", "write_zeros"):
            module = ["write_zeroes", module[0]]

        prefix = "crate::"
        if use.startswith(prefix):
            # Replace only the leading `crate::`; segments such as
            # `sys_crate::` elsewhere in the path must stay untouched.
            new_use = "super::" * len(module) + use[len(prefix):]
            print("::".join(module), use, "->", new_use)
            return new_use
        return use

    with chdir(root):
        for file in Path().glob("**/*.rs"):
            replace_use_references(file, replace)


def main():
    path = Path("common") / "win_sys_util" / "src"
    lib_rs = str(path / "lib.rs")
    subprocess.check_call(["git", "checkout", "-f", str(path)])

    # Use rustfmt to re-format use statements to be one per line.
    subprocess.check_call(
        ["rustfmt", "+nightly", "--config=imports_granularity=item", lib_rs]
    )
    use_super_instead_of_crate(path)
    subprocess.check_call(
        ["rustfmt", "+nightly", "--config=imports_granularity=crate", lib_rs]
    )


# Only run the (destructive: `git checkout -f`) refactoring when executed as a
# script, never on import.
if __name__ == "__main__":
    main()