# xref: /aosp_15_r20/external/pytorch/torch/fx/passes/infra/pass_manager.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import inspect
import logging
from collections import deque
from functools import wraps
from queue import Queue
from typing import Callable, Dict, List

import torch.nn as nn
from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule
from torch.fx.passes.infra.pass_base import PassResult

# Module logger. Its level is pinned to WARNING here, so the ``logger.debug``
# calls in this module are filtered out unless a caller explicitly lowers
# this logger's level.
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

# Public API of this module.
__all__ = [
    "pass_result_wrapper",
    "this_before_that_pass_constraint",
    "PassManager",
]

@compatibility(is_backward_compatible=False)
def pass_result_wrapper(fn: Callable) -> Callable:
    """
    Wrapper for passes which currently do not return a PassResult.

    The returned callable always produces a PassResult:

    * ``None`` (presumably the pass modified ``gm`` in place) becomes
      ``PassResult(gm, True)``, where ``gm`` is the module passed in;
    * a ``PassResult`` is returned unchanged;
    * an ``nn.Module`` becomes ``PassResult(module, True)``.

    The "modified" flag is set to True in the converted cases because the
    wrapper has no way to tell whether the pass actually changed anything.

    Args:
        fn (Callable[Module, Any]): the pass to wrap. ``None`` is passed
            through unchanged.

    Returns:
        wrapped_fn (Callable[Module, PassResult]), or ``None`` if ``fn`` is
        ``None``.

    Raises:
        TypeError: raised by ``wrapped_fn`` when ``fn`` returns something that
            is not ``None``, a ``PassResult`` or an ``nn.Module``.
    """
    if fn is None:
        return None

    @wraps(fn)
    def wrapped_fn(gm):
        res = fn(gm)
        if res is None:
            return PassResult(gm, True)
        if isinstance(res, PassResult):
            return res
        if isinstance(res, nn.Module):
            return PassResult(res, True)
        # Any other return value is a bug in the pass; fail loudly here
        # instead of silently returning None to the caller.
        raise TypeError(
            f"Pass {wrapped_fn.__name__} returned an object of type "
            f"{type(res).__name__}; expected None, PassResult or nn.Module."
        )

    if not inspect.isfunction(fn):
        # Non-function callables (e.g. class instances) typically have no
        # __name__ for functools.wraps to copy; use the class name instead.
        wrapped_fn.__name__ = type(fn).__name__

    return wrapped_fn


49def _validate_pass_schedule_constraint(
50    constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
51) -> None:
52    for i, a in enumerate(passes):
53        for j, b in enumerate(passes[i + 1 :]):
54            if constraint(a, b):
55                continue
56            raise RuntimeError(
57                f"pass schedule constraint violated. Expected {a} before {b}"
58                f" but found {a} at index {i} and {b} at index{j} in pass"
59                f" list."
60            )
61
62def _topological_sort_passes(
63    passes: List[Callable], constraints: List[Callable]
64) -> List[Callable]:
65    """
66    Args
67        passes: Passes that we are ordering
68        constraints: Constraints applied on these passes
69
70    Returns
71        A sorted list of callables and a boolean of if a circular dependency
72        existed
73    """
74    if len(constraints) == 0:
75        return passes
76
77    # Contruct a graph mapping nodes to a list of their users
78    graph: Dict[Callable, List[Callable]] = {p : [] for p in passes}
79    indegree_map: Dict[Callable, int] = dict.fromkeys(passes, 0)
80    candidates: Queue = Queue()
81    for a in passes:
82        for b in passes:
83            if a == b:
84                continue
85
86            for constraint in constraints:
87                if not constraint(a, b):
88                    graph[b].append(a)
89                    indegree_map[a] += 1
90
91        if indegree_map[a] == 0:
92            candidates.put(a)
93
94    visited: Dict[Callable, bool] = dict.fromkeys(passes, False)
95    sorted_passes: List[Callable] = []
96
97    while not candidates.empty():
98        p = candidates.get()
99        sorted_passes.append(p)
100        visited[p] = True
101
102        for n in graph[p]:
103            if not visited[n]:
104                indegree_map[n] -= 1
105                if indegree_map[n] == 0:
106                    candidates.put(n)
107
108    # Check if there are unvisited nodes (aka cycles in the graph)
109    cycle_passes = list(filter(lambda p: indegree_map[p] != 0, indegree_map.keys()))
110    if len(cycle_passes) != 0:
111        error = f"Circular dependency detected within the following passes: {cycle_passes}"
112        raise RuntimeError(error)
113
114    return sorted_passes
115
@compatibility(is_backward_compatible=False)
def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable:
    """
    Build a constraint requiring pass ``this`` to be scheduled before ``that``.

    The returned callable ``depends_on(a, b)`` answers "may ``a`` be scheduled
    before ``b``?". It is False only for the pair ``(that, this)``, i.e. when
    ``that`` would run ahead of ``this``; every other pair is allowed.

    For example, the following pass list and constraint list would be
    invalid, because ``pass_b`` is scheduled ahead of ``pass_a``:

    ```
    passes = [pass_b, pass_a]

    constraints = [
        this_before_that_pass_constraint(pass_a, pass_b)
    ]
    ```

    Args:
        this (Callable): pass which should occur first
        that (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Object, Object], bool])
    """

    def depends_on(a: Callable, b: Callable):
        # Only the arrangement "that, then this" violates the constraint.
        reversed_order = a == that and b == this
        return not reversed_order

    return depends_on


@compatibility(is_backward_compatible=False)
class PassManager:
    """
    Construct a PassManager.

    Collects passes and constraints. This defines the pass schedule, manages
    pass constraints and pass execution.

    Args:
        passes (Optional[List[Callable]]): List of passes. A pass is a
            callable which modifies an object and returns a PassResult. The
            list is copied, so later ``add_pass`` calls do not mutate it.
        constraints (Optional[List[Callable]]): List of constraints. A
            constraint is a callable which takes two passes (A, B) and returns
            False if A must not be scheduled before B, True otherwise. See
            implementation of `this_before_that_pass_constraint` for example.
            The list is copied as well.
        steps (int): Max number of times we run the passes (default = 1).
            A falsy value (None or 0) keeps the default.
        run_checks_after_each_pass (bool): Whether to run checks and linting
            after each pass
        suppress_check_failures (bool): If True, exceptions raised by the
            check function are logged as warnings instead of being raised.
    """

    passes: List[Callable[[nn.Module], PassResult]]
    constraints: List[Callable[[Callable, Callable], bool]]
    _validated: bool = False
    steps: int = 1

    def __init__(
        self,
        passes=None,
        constraints=None,
        steps=None,
        run_checks_after_each_pass: bool = False,
        suppress_check_failures: bool = False,
    ):
        # Copy the caller's lists so add_pass/add_constraint do not mutate them.
        self.passes = list(passes) if passes else []
        self.constraints = list(constraints) if constraints else []
        if steps:
            self.steps = steps

        self.run_checks_after_each_pass = run_checks_after_each_pass
        self.suppress_check_failures = suppress_check_failures

    def add_pass(self, _pass: Callable):
        """
        Adds a pass into the current list of passes.
        """
        self.passes.append(_pass)
        self._validated = False

    def add_constraint(self, constraint: Callable):
        """
        Adds a constraint into the current list of constraints.
        """
        self.constraints.append(constraint)
        self._validated = False

    def validate_constraints(self):
        """
        Validates that current pass schedule defined by `self.passes` is valid
        according to all constraints in `self.constraints`.

        Raises:
            RuntimeError: if any constraint is violated.
        """
        if self._validated:
            return
        for constraint in self.constraints:
            _validate_pass_schedule_constraint(constraint, self.passes)
        self._validated = True

    def solve_constraints(self):
        """
        Finds a valid traversal order based on the given constraints and orders
        the passes based on this order.

        Raises:
            RuntimeError: if the constraints contain a circular dependency.
                This happens regardless of the value of ``steps``.
        """
        self.passes = _topological_sort_passes(self.passes, self.constraints)
        self._validated = True

    def add_checks(self, check: Callable) -> None:
        """
        Adds a function which runs various checks on a given graph module.

        The check is run on the input module before any pass executes and,
        if `run_checks_after_each_pass` is enabled, again after each pass.

        Raises:
            TypeError: if ``check`` does not take exactly one parameter.
        """
        sig = inspect.signature(check)

        if len(list(sig.parameters.values())) != 1:
            raise TypeError("PassManager check function should only take in one variable, a module")

        setattr(self, "check", check)  # noqa: B010

    def check(self, module: nn.Module) -> None:
        # Default check is a no-op; replaced via `add_checks`.
        pass

    @staticmethod
    def _pass_name(fn: Callable) -> str:
        """Name of a pass as shown in log and error messages."""
        return fn.__name__ if inspect.isfunction(fn) else type(fn).__name__

    def _run_checks(self, module: nn.Module) -> None:
        """Run ``self.check`` on ``module``, honoring `suppress_check_failures`."""
        try:
            self.check(module)
        except Exception:
            if not self.suppress_check_failures:
                raise
            logger.warning(
                "PassManager check failed on module of type %s; continuing "
                "because suppress_check_failures=True",
                type(module).__name__,
                exc_info=True,
            )

    def __call__(self, module: nn.Module) -> PassResult:
        """
        Runs a list of passes in the order based on `self.passes` on the given
        graph module. Each time a pass is run, checks and linting will be run on
        the graph module if `run_checks_after_each_pass` is set.

        The whole list of passes is run up to `steps` times, stopping early
        after the first round in which no pass reports a modification.

        Returns:
            PassResult holding the final module and whether any pass reported
            a modification.

        Raises:
            Exception: wrapping (as ``__cause__``) any error raised while
                running a pass or a post-pass check.
        """
        # Order the passes based on the constraints
        if not self._validated:
            self.solve_constraints()

        # Check graph invariants
        self._run_checks(module)

        # Run the set of passes `steps` number of times or until the graph stops
        # changing
        overall_modified = False
        for _ in range(self.steps):
            modified = False

            # Run the set of passes on the graph module
            for i, fn in enumerate(self.passes):
                fn_name = self._pass_name(fn)
                logger.debug("Running pass '%s'", fn_name)

                try:
                    res = fn(module)

                    # Accept PassResult or any object that duck-types it.
                    if not isinstance(res, PassResult) and not hasattr(
                        res, "graph_module"
                    ):
                        raise TypeError(
                            f"The result of the pass {fn_name} should be type PassResult. "
                            "Please wrap it with pass_result_wrapper()"
                        )
                    module = res.graph_module
                    modified = modified or res.modified

                    if isinstance(module, GraphModule):
                        logger.debug("Graph after pass '%s': %s", fn_name, module.graph)
                        # Regenerate the module's code from its (possibly
                        # edited) graph.
                        module.recompile()

                    # Check graph invariants
                    if self.run_checks_after_each_pass:
                        self._run_checks(module)

                except Exception as e:
                    prev_pass_names = [self._pass_name(p) for p in self.passes[:i]]
                    msg = f"An error occurred when running the '{fn_name}' pass after the following passes: {prev_pass_names}"
                    raise Exception(msg) from e  # noqa: TRY002

            # If the graph no longer changes, then we can stop running these passes
            overall_modified = overall_modified or modified
            if not modified:
                break

        return PassResult(module, overall_modified)