# mypy: allow-untyped-defs
import inspect
import logging
from queue import Queue
from functools import wraps
from typing import Callable, Dict, List

import torch.nn as nn
from torch.fx.graph_module import GraphModule
from torch.fx._compatibility import compatibility
from torch.fx.passes.infra.pass_base import PassResult

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

__all__ = ['pass_result_wrapper', 'this_before_that_pass_constraint', 'PassManager']


def _pass_name(p: Callable) -> str:
    """Return a human-readable name for pass ``p``.

    Plain functions use their ``__name__``; any other callable (e.g. an
    instance of a class defining ``__call__``) uses its type's name.
    """
    return p.__name__ if inspect.isfunction(p) else type(p).__name__


@compatibility(is_backward_compatible=False)
def pass_result_wrapper(fn: Callable) -> Callable:
    """
    Wrapper for passes which currently do not return a PassResult.
    This wrapper makes them return a PassResult containing the modified object
    and True for the "modified" flag.

    Args:
        fn (Callable[Module, Any])

    Returns:
        wrapped_fn (Callable[Module, PassResult]), or None if ``fn`` is None.
    """
    if fn is None:
        return None

    @wraps(fn)
    def wrapped_fn(gm):
        res = fn(gm)
        if res is None:
            # The pass returned nothing; report the input object as modified.
            return PassResult(gm, True)
        if isinstance(res, PassResult):
            return res
        elif isinstance(res, nn.Module):
            return PassResult(res, True)
        # NOTE(review): any other return type falls through and yields None;
        # PassManager.__call__ rejects that with a TypeError.

    if not inspect.isfunction(fn):
        # Callable objects may lack a meaningful __name__; use the class name.
        wrapped_fn.__name__ = type(fn).__name__

    return wrapped_fn


def _validate_pass_schedule_constraint(
    constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
) -> None:
    """
    Check that every ordered pair of passes in ``passes`` satisfies ``constraint``.

    For every pair where ``a`` appears before ``b`` in ``passes``,
    ``constraint(a, b)`` must return True. A False result means ``b`` was
    required to run before ``a``.

    Raises:
        RuntimeError: on the first pair that violates ``constraint``.
    """
    for i, a in enumerate(passes):
        # start=i + 1 makes `j` b's absolute index in `passes` rather than its
        # offset within the slice.
        for j, b in enumerate(passes[i + 1 :], start=i + 1):
            if constraint(a, b):
                continue
            raise RuntimeError(
                f"pass schedule constraint violated. Expected {b} before {a}"
                f" but found {a} at index {i} and {b} at index {j} in pass"
                f" list."
            )


def _topological_sort_passes(
    passes: List[Callable], constraints: List[Callable]
) -> List[Callable]:
    """
    Order ``passes`` so that every constraint is satisfied (Kahn's algorithm).

    Args:
        passes: Passes that we are ordering
        constraints: Constraints applied on these passes. ``constraint(a, b)``
            returning False means ``b`` must run before ``a``.

    Returns:
        A list of the passes in an order satisfying all constraints. When
        ``constraints`` is empty, ``passes`` itself is returned unchanged.

    Raises:
        RuntimeError: if the constraints form a circular dependency.
    """
    if not constraints:
        return passes

    # Construct a graph mapping each pass to the passes that must run after it.
    graph: Dict[Callable, List[Callable]] = {p: [] for p in passes}
    indegree_map: Dict[Callable, int] = dict.fromkeys(passes, 0)
    candidates: Queue = Queue()
    for a in passes:
        for b in passes:
            if a == b:
                continue

            for constraint in constraints:
                if not constraint(a, b):
                    # `b` must run before `a`: add the edge b -> a.
                    graph[b].append(a)
                    indegree_map[a] += 1

        # indegree_map[a] is only incremented while `a` is the outer pass,
        # so its value is complete here.
        if indegree_map[a] == 0:
            candidates.put(a)

    visited: Dict[Callable, bool] = dict.fromkeys(passes, False)
    sorted_passes: List[Callable] = []

    while not candidates.empty():
        p = candidates.get()
        sorted_passes.append(p)
        visited[p] = True

        for n in graph[p]:
            if not visited[n]:
                indegree_map[n] -= 1
                if indegree_map[n] == 0:
                    candidates.put(n)

    # Passes whose in-degree never reached zero are part of (or blocked by) a cycle.
    cycle_passes = [p for p, degree in indegree_map.items() if degree != 0]
    if cycle_passes:
        error = f"Circular dependency detected within the following passes: {cycle_passes}"
        raise RuntimeError(error)

    return sorted_passes


@compatibility(is_backward_compatible=False)
def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable:
    """
    Defines a partial order ('depends on' function) where `this` must occur
    before `that`.

    For example, the following pass list and constraint list would be invalid.
    ```
    passes = [pass_b, pass_a]

    constraints = [
        this_before_that_pass_constraint(pass_a, pass_b)
    ]
    ```

    Args:
        this (Callable): pass which should occur first
        that (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Object, Object], bool]): returns False only for
        the pair ``(that, this)``, i.e. when ``that`` would be placed before
        ``this``, and True otherwise.
    """

    def depends_on(a: Callable, b: Callable):
        return a != that or b != this

    return depends_on


@compatibility(is_backward_compatible=False)
class PassManager:
    """
    Construct a PassManager.

    Collects passes and constraints. This defines the pass schedule, manages
    pass constraints and pass execution.

    Args:
        passes (Optional[List[Callable]]): List of passes. A pass is a
            callable which modifies an object and returns a PassResult
        constraints (Optional[List[Callable]]): List of constraints. A
            constraint is a callable which takes two passes (A, B) and returns
            False if A must run after B (A depends on B) and True otherwise.
            See implementation of `this_before_that_pass_constraint` for example.
        steps (int): Max number of times we run the passes (default = 1).
        run_checks_after_each_pass (bool): Whether to run checks and linting
            after each pass
        suppress_check_failures (bool): Whether to raise errors when running
            checks. NOTE(review): stored on the instance but not read by any
            code in this class.
    """

    passes: List[Callable[[nn.Module], PassResult]]
    constraints: List[Callable[[Callable, Callable], bool]]
    _validated: bool = False
    steps: int = 1

    def __init__(
        self,
        passes=None,
        constraints=None,
        steps=None,
        run_checks_after_each_pass: bool = False,
        suppress_check_failures: bool = False,
    ):
        self.passes = passes or []
        self.constraints = constraints or []
        # A falsy `steps` (None or 0) keeps the class default of 1.
        if steps:
            self.steps = steps

        self.run_checks_after_each_pass = run_checks_after_each_pass
        self.suppress_check_failures = suppress_check_failures

    def add_pass(self, _pass: Callable):
        """
        Adds a pass into the current list of passes.

        Marks the schedule as unvalidated so the next run re-solves constraints.
        """
        self.passes.append(_pass)
        self._validated = False

    def add_constraint(self, constraint: Callable):
        """
        Adds a constraint into the current list of constraints.

        Marks the schedule as unvalidated so the next run re-solves constraints.
        """
        self.constraints.append(constraint)
        self._validated = False

    def validate_constraints(self):
        """
        Validates that current pass schedule defined by `self.passes` is valid
        according to all constraints in `self.constraints`.

        Raises:
            RuntimeError: if any constraint is violated by the current order.
        """
        if self._validated:
            return
        for constraint in self.constraints:
            _validate_pass_schedule_constraint(constraint, self.passes)
        self._validated = True

    def solve_constraints(self):
        """
        Finds a valid traversal order based on the given constraints and orders
        the passes based on this order.

        Raises:
            RuntimeError: if the constraints contain a circular dependency.
                This is raised regardless of the value of ``steps``.
        """
        self.passes = _topological_sort_passes(self.passes, self.constraints)
        self._validated = True

    def add_checks(self, check: Callable) -> None:
        """
        Adds a function which runs various checks on a given graph module.
        The check always runs once on the input module before any pass, and
        additionally after each pass if the `run_checks_after_each_pass` flag
        is enabled.

        Raises:
            TypeError: if ``check`` does not take exactly one parameter.
        """
        sig = inspect.signature(check)

        if len(sig.parameters) != 1:
            raise TypeError("PassManager check function should only take in one variable, a module")

        # Shadow the default `check` method on this instance only.
        setattr(self, "check", check)  # noqa: B010

    def check(self, module: nn.Module) -> None:
        """Default no-op check; replaced per instance by `add_checks`."""
        pass

    def __call__(self, module: nn.Module) -> PassResult:
        """
        Runs a list of passes in the order based on `self.passes` on the given
        module. Checks run once on the input module; after each pass they run
        again if `run_checks_after_each_pass` is set. GraphModule results are
        recompiled after every pass.

        The full list of passes is rerun until a complete sweep reports no
        modification, or until `steps` sweeps have been run.

        Returns:
            PassResult holding the final module and whether any pass modified it.

        Raises:
            Exception: wrapping (as ``__cause__``) any error raised while
                running a pass or a post-pass check, including the TypeError
                for a pass that does not return a PassResult.
        """
        # Order the passes based on the constraints
        if not self._validated:
            self.solve_constraints()

        # Check graph invariants
        self.check(module)

        # Run the set of passes `steps` number of times or until the graph stops
        # changing
        overall_modified = False
        for _ in range(self.steps):
            modified = False

            # Run the set of passes on the graph module
            for i, fn in enumerate(self.passes):
                fn_name = _pass_name(fn)
                logger.debug("Running pass '%s'", fn_name)

                try:
                    res = fn(module)

                    # Accept a PassResult or any duck-typed object exposing
                    # `graph_module` (and `modified`).
                    if not isinstance(res, PassResult) and not hasattr(
                        res, "graph_module"
                    ):
                        raise TypeError(
                            f"The result of the pass {fn_name} should be type PassResult. "
                            + "Please wrap it with pass_result_wrapper()"
                        )
                    module = res.graph_module
                    modified = modified or res.modified

                    if isinstance(module, GraphModule):
                        logger.debug("Graph after pass '%s': %s", fn_name, module.graph)
                        module.recompile()

                    # Check graph invariants
                    if self.run_checks_after_each_pass:
                        self.check(module)

                except Exception as e:
                    prev_pass_names = [_pass_name(p) for p in self.passes[:i]]
                    msg = f"An error occurred when running the '{fn_name}' pass after the following passes: {prev_pass_names}"
                    raise Exception(msg) from e  # noqa: TRY002

            # If the graph no longer changes, then we can stop running these passes
            overall_modified = overall_modified or modified
            if not modified:
                break

        return PassResult(module, overall_modified)