# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import _operator
import copy
import json
import logging
import os
import re
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
from executorch.exir import control_flow, memory, memory_planning
from executorch.exir.common import override_logger
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.dialects.backend._ops import BackendOpOverload
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode
from executorch.exir.error import InternalError
from executorch.exir.operator.convert import (
    get_out_args_from_opoverload,
    is_out_variant,
    to_out_variant,
    to_scratch_op,
)

from executorch.exir.pass_base import ExportPass
from executorch.exir.pass_manager import PassManager, PassType
from executorch.exir.passes.const_prop_pass import ConstPropPass
from executorch.exir.passes.debug_handle_generator_pass import DebugHandleGeneratorPass

from executorch.exir.passes.executorch_prim_ops_registry import _EXECUTORCH_SYM_OPS
from executorch.exir.passes.insert_write_back_for_buffers_pass import (
    insert_write_back_for_buffers_pass,
)
from executorch.exir.passes.memory_format_ops_pass import MemoryFormatOpsPass
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from executorch.exir.passes.normalize_transpose_pass import NormalizeTransposePass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.remove_noop_pass import RemoveNoopPass, RemoveToCopyPass
from executorch.exir.passes.replace_aten_with_edge_pass import OpReplacePass
from executorch.exir.passes.replace_broken_ops_with_function_ops_pass import (
    ReplaceBrokenOpsWithFunctionalOpsPass,
)
from executorch.exir.passes.replace_edge_with_backend_pass import EdgeToBackendOpsPass
from executorch.exir.passes.replace_sym_size_op_pass import ReplaceSymSizeOpPass
from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
from executorch.exir.passes.sym_to_tensor_pass import SymToTensorPass
from executorch.exir.passes.weights_to_outputs_pass import weights_to_outputs_pass
from torch import fx
from torch._subclasses import FakeTensor
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.fx.passes.shape_prop import TensorMetadata

__all__ = [
    "ExportPass",
    "ConstPropPass",
    "QuantFusionPass",
    "OpReplacePass",
    "EdgeToBackendOpsPass",
    "MemoryFormatOpsPass",
    "MemoryPlanningPass",
    "HintBasedSymShapeEvalPass",
    "insert_write_back_for_buffers_pass",
    "weights_to_outputs_pass",
]

Argument = Optional[
    Union[
        Tuple["Argument", ...],
        List["Argument"],
        Dict[str, "Argument"],
        slice,
        torch.fx.Node,
        str,
        int,
        float,
        bool,
        complex,
        torch.dtype,
        torch.Tensor,
        torch.device,
        torch.memory_format,
        torch.layout,
    ]
]


def update_args(
    args: Tuple[Argument, ...], key: int, val: torch.fx.Node
) -> Tuple[Argument, ...]:
    """
    A helper function that returns an updated copy of an argument container
    without mutating the original. It can be used with both args (a tuple) and
    kwargs (a dict).
    """
    if isinstance(args, dict):
        new_dict = copy.copy(args)
        new_dict[key] = val
        return new_dict

    assert isinstance(args, tuple)
    new_tuple = list(args)
    new_tuple[key] = val
    return tuple(new_tuple)
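
# Usage sketch (illustrative only; `node_a`, `node_b`, and `new_node` are
# hypothetical fx.Node objects). The input container is left untouched:
#
#     new_args = update_args((node_a, node_b), 1, new_node)
#     # -> (node_a, new_node)
#     new_kwargs = update_args({"out": node_a}, "out", new_node)
#     # -> {"out": new_node}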


class DebugPass(PassBase):
    def __init__(
        self,
        msg: str = "",
        enable_debug_pass: bool = True,
        show_src: bool = False,
        show_full_path: bool = False,
        show_all_frames: bool = False,
        path_filter: Optional[str] = None,
        show_spec: bool = False,
        log_filename: Optional[str] = None,
    ) -> None:
        """
        show_src: whether to show the source code that generated each fx Node
        show_full_path: whether to show the full path of the source file or just the filename
        show_all_frames: whether to show all stack frames for each node or only the last one
        path_filter: a regular expression used to filter the paths of the stack frames
        show_spec: whether to print the TensorSpecs and the 'val' metadata for each node
        log_filename: if provided, the output will also be written to this path.
            Existing content in this file will be discarded.
        """
        self.msg = msg
        self.enable_debug_pass = enable_debug_pass
        self.show_src = show_src
        self.show_full_path = show_full_path
        self.show_all_frames = show_all_frames
        self.show_spec = show_spec
        self.log_filename = log_filename
        if path_filter:
            self.path_filter_re = re.compile(path_filter)  # pyre-ignore
        else:
            self.path_filter_re = None

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Prints the graph and counts the number of nodes per op type and per
        call_function target.
        """
        if not self.enable_debug_pass:
            return PassResult(graph_module, False)
        # It doesn't make sense to mute the DebugPass if the user has already
        # set self.enable_debug_pass to True.
        with override_logger(filename=self.log_filename):
            self.callWithLoggerEnabled(graph_module)
        return PassResult(graph_module, True)

    def printFrames(self, node: fx.Node) -> None:
        """
        The DebugPass may be used for graphs generated by either the old exir
        dispatch tracer or the new pt2 tracer.
        The former stores the 'stack_trace' field as a json string;
        the latter stores the 'stack_trace' field as a free-form string like:
          ```
            File "/data/sandcastle/boxes/fbsource/buck-out/v2/gen/fbcode/20c706e99f51cf3a/executorch/test/end2end/__end2end__/end2end#link-tree/executorch/test/end2end/test_end2end.py", line 150, in forward
                o = o * a
          ```
        This method handles both formats. In the future, we may drop support
        for the old exir dispatch tracer.
        """
        if (
            self.show_src
            and "stack_trace" in node.meta
            and len(node.meta["stack_trace"]) > 0
        ):
            try:
                stack_trace = json.loads(node.meta["stack_trace"])
                is_json = True
            except json.decoder.JSONDecodeError:
                is_json = False

            if not is_json:
                logging.debug(node.meta["stack_trace"])
                return

            frame_list = []  # tuples of (filename, frame name, line number, line)
            for frame in stack_trace:
                filename = frame["filename"]
                name = frame["name"]
                lineno = frame["lineno"]
                line = frame["line"]
                if not self.show_full_path:
                    filename = os.path.basename(filename)
                mark = "#link-tree/"
                if mark in filename:
                    filename = filename.split(mark)[-1]

                if not self.path_filter_re or self.path_filter_re.search(filename):
                    frame_list.append((filename, name, lineno, line))

            if not self.show_all_frames:
                frame_list = frame_list[-1:]
            for filename, name, lineno, line in frame_list:
                logging.debug(f"      > {filename}:{lineno} in {name}: {line}")

    def callWithLoggerEnabled(self, graph_module: torch.fx.GraphModule) -> None:
        if self.msg:
            logging.debug(self.msg)
        logging.debug("Enter debug_pass")
        graph_module.recompile()
        logging.debug(f"Code is:\n{graph_module.code}")
        op_to_cnt = defaultdict(int)  # stats for op type
        func_to_cnt = defaultdict(int)  # stats for targets of call_function nodes
        logging.debug("Nodes:")
        idx = 0
        for node in graph_module.graph.nodes:
            # TODO: better to print python code along with TensorSpecs
            logging.debug(f"{idx:4}: {node.format_node()}")
            if self.show_spec:
                specs = memory_planning.get_node_tensor_specs(node)
                for spec in specs:
                    logging.debug(f"      {spec.debug()}")
                logging.debug(f"      val: {node.meta.get('val', None)}")
            self.printFrames(node)
            idx += 1
            op_to_cnt[node.op] += 1

            if node.op == "call_function":
                target = str(node.target)
                func_to_cnt[target] += 1

        logging.debug("-- node op type stat --")
        for op, cnt in op_to_cnt.items():
            logging.debug(f" op {op}, cnt {cnt}")

        logging.debug("-- call_function stat --")
        for fn, cnt in func_to_cnt.items():
            logging.debug(f" fn {fn}, cnt {cnt}")
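
# Usage sketch (illustrative only; `gm` is a hypothetical torch.fx.GraphModule).
# The pass only logs information about the graph it is given:
#
#     DebugPass(msg="after to_out_variant", show_src=True).call(gm)
#     # emits the generated code, per-node stack frames, and op-count statistics
#     # through the logging module (optionally mirrored to `log_filename`).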


# Skip these ops when converting to out variants. They will be handled and
# removed by the emitter.
# pyre-ignore
to_out_var_skiplist: Set[Callable[[Any], Any]] = {
    _operator.getitem,
    torch.ops.higher_order.cond,
    control_flow.while_loop,
    # memory.alloc nodes are only added after the to_out_variant pass runs, so
    # we usually won't see them in the input graph of the to_out_variant pass,
    # unless the graph is retraced after to_out_variant has already run on the
    # first trace.
    memory.alloc,
    memory.view,
    executorch_call_delegate,
    torch.ops.aten.copy_.default,
}
to_out_var_skiplist.update(_EXECUTORCH_SYM_OPS)


def make_alloc_node(
    graph_module: torch.fx.GraphModule,
    val: Union[
        Optional[FakeTensor], List[Optional[FakeTensor]], Tuple[Optional[FakeTensor]]
    ],
    tensor_meta: Union[
        Optional[TensorMetadata],
        List[Optional[TensorMetadata]],
        Tuple[Optional[TensorMetadata]],
    ],
) -> torch.fx.Node:
    """
    Note: tensor_meta is only used in the case of a Tensor subclass, since
    fakifying a tensor subclass is not supported right now.
    """
    if val is None:
        if tensor_meta is not None:
            assert isinstance(tensor_meta, TensorMetadata)
            alloc_spec = (tensor_meta.shape, tensor_meta.dtype)
        else:
            raise InternalError(
                "Memory allocator node needs FakeTensor val or TensorMetadata to proceed"
            )
    elif isinstance(val, FakeTensor):
        alloc_spec = (val.shape, val.dtype)
    else:
        assert isinstance(val, list) or isinstance(val, tuple)
        assert isinstance(tensor_meta, list) or isinstance(tensor_meta, tuple)
        alloc_spec: List[memory.AllocSpec] = []
        for v, t in zip(val, tensor_meta):
            if v is not None:
                # pyre-fixme[6]: For 1st argument expected
                #  `Union[List[Tuple[List[int], dtype]], Tuple[List[int], dtype]]` but
                #  got `Tuple[Size, dtype]`.
                alloc_spec.append((v.shape, v.dtype))
            elif t is not None:
                # pyre-fixme[6]: For 1st argument expected
                #  `Union[List[Tuple[List[int], dtype]], Tuple[List[int], dtype]]` but
                #  got `Tuple[Size, dtype]`.
                alloc_spec.append((t.shape, t.dtype))
            else:
                raise InternalError(
                    "Memory allocator node needs FakeTensor val or TensorMetadata to proceed"
                )

    # pyre-fixme[6]
    alloc = graph_module.graph.call_function(memory.alloc, (alloc_spec,))
    alloc.meta["val"] = val
    alloc.meta["tensor_meta"] = tensor_meta
    return alloc
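
# Sketch of the emitted node (illustrative only; `gm` and `t` are hypothetical):
# for a FakeTensor `t` with shape [2, 3] and dtype torch.float32,
# make_alloc_node(gm, t, None) inserts roughly the following call into `gm.graph`
# at the current insertion point:
#
#     alloc = memory.alloc((torch.Size([2, 3]), torch.float32))
#
# and copies `val` / `tensor_meta` onto alloc.meta so later passes can recover
# the shape and dtype from the node metadata.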


class ToOutVarPass(PassBase):
    def __init__(self, ignore_to_out_var_failure: bool = False) -> None:
        self.ignore_to_out_var_failure = ignore_to_out_var_failure

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:  # noqa: C901
        """
        Converts each call_function node in the graph to use the out variant of
        its op, if the op is not already an out variant.
        """
        missing_out_vars: Set[str] = set()

        def get_submodule(node: torch.fx.Node) -> torch.fx.GraphModule:
            assert node.op == "get_attr"
            return getattr(graph_module, node.target)

        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue

            target = node.target
            if target == torch.ops.higher_order.cond:
                self.call(get_submodule(node.args[1]))
                self.call(get_submodule(node.args[2]))
                continue
            if target == torch.ops.higher_order.map_impl:
                self.call(get_submodule(node.args[0]))
                continue
            elif target == control_flow.while_loop:
                self.call(get_submodule(node.args[0]))
                self.call(get_submodule(node.args[1]))
                continue
            elif getattr(target, "__module__", None) in ("builtins", "_operator"):
                continue
            elif target in to_out_var_skiplist:
                continue
            if not isinstance(
                target, (torch._ops.OpOverload, EdgeOpOverload, BackendOpOverload)
            ):
                raise RuntimeError(f"Require an op overload for target: {target}")

            op_name = target._schema.name
            overload_name = target._schema.overload_name
            if is_out_variant(op_name, overload_name):
                # TODO (zhxchen17) Remove this after functionalization is always on.
                if "out" in node.kwargs and isinstance(node.kwargs["out"], fx.Node):
                    out = node.kwargs["out"]
                    if out.target is not memory.alloc and len(out.users) == 1:
                        with graph_module.graph.inserting_before(node):
                            alloc = make_alloc_node(
                                graph_module,
                                node.meta["val"],
                                node.meta["tensor_meta"],
                            )
                        out.replace_all_uses_with(alloc)
                        graph_module.graph.erase_node(out)
                continue

            try:
                if isinstance(target, (EdgeOpOverload, BackendOpOverload)):
                    out_var_target = target.to_out_variant()
                    out_args_names = get_out_args_from_opoverload(out_var_target)
                else:
                    out_var_target, out_args_names = to_out_variant(target)
            except RuntimeError as e:
                # pyre-fixme[16]: `GraphModule` has no attribute
                #  `encounter_to_out_var_failure`.
                graph_module.encounter_to_out_var_failure = True
                logging.info(
                    f"Failed converting '{target}' to its out variant with error: '{e}'"
                )
                missing_out_vars.add(op_name)
                continue

            assert out_var_target
            out_var_kwargs = {}

            # Pool the functional target's kwargs into the out variant's kwargs
            for arg in out_var_target._schema.arguments:
                if arg.name in out_args_names:
                    continue
                if arg.name in node.kwargs:
                    out_var_kwargs[arg.name] = node.kwargs[arg.name]

            with graph_module.graph.inserting_before(node):
                if len(out_args_names) == 1:
                    alloc_node = make_alloc_node(
                        graph_module, node.meta["val"], node.meta["tensor_meta"]
                    )
                    out_var_kwargs[out_args_names[0]] = alloc_node
                    if len(out_var_target._schema.returns) == 0:
                        node.replace_all_uses_with(alloc_node)
                else:
                    # If the op has multiple out args, we assume the node's
                    # metadata contains a list of fake tensors with matching
                    # sizes and dtypes.
                    fake_tensor_list = node.meta["val"]
                    tensor_metadatas = node.meta["tensor_meta"]
                    assert isinstance(
                        fake_tensor_list, (list, tuple)
                    ), "Expected a list/tuple of tensors when the op has multiple out arguments"
                    assert len(out_args_names) == len(
                        fake_tensor_list
                    ), f"Expected {len(out_args_names)} tensor specs, but got {len(node.meta['val'])}"
                    for out_arg_name, val, tensor_meta in zip(
                        out_args_names, fake_tensor_list, tensor_metadatas
                    ):
                        if val is None:
                            out_var_kwargs[out_arg_name] = None
                            continue
                        assert isinstance(val, FakeTensor)
                        out_var_kwargs[out_arg_name] = make_alloc_node(
                            graph_module, val, tensor_meta
                        )

            node.target = out_var_target
            node.kwargs = out_var_kwargs

        if (not self.ignore_to_out_var_failure) and len(missing_out_vars) > 0:
            raise RuntimeError(f"Missing out variants: {missing_out_vars}")
        return PassResult(graph_module, True)
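
# Conversion sketch (illustrative only): a functional call such as
#
#     add = torch.ops.aten.add.Tensor(x, y)
#
# is rewritten by ToOutVarPass into roughly
#
#     alloc = memory.alloc(spec_for(add))           # `spec_for` is a hypothetical helper
#     add = torch.ops.aten.add.out(x, y, out=alloc)
#
# so that each operator writes into storage that the memory planner can place.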


def to_scratch_op_pass(graph_module: torch.fx.GraphModule) -> PassResult:
    for node in graph_module.graph.nodes:
        if node.op != "call_function":
            continue
        target = node.target
        if not isinstance(target, torch._ops.OpOverload):
            # Ignore ops that are not OpOverload, e.g. operator.getitem,
            # memory.alloc, etc.
            continue

        scratch_op = to_scratch_op(target)
        if not scratch_op:
            continue

        args_vals = [nd.meta.get("val") for nd in node.args]
        kwargs_vals = {name: nd.meta.get("val") for name, nd in node.kwargs.items()}
        get_scratch_metas = getattr(target, "get_scratch_metas", None)
        if not get_scratch_metas:
            raise RuntimeError(
                "The get_scratch_metas attribute is not found on the out variant op "
                "when converting it to a scratch op. Make sure you have imported the "
                "module that attaches the get_scratch_metas attribute to the out "
                "variant op."
            )
        scratch_metas = get_scratch_metas(*args_vals, **kwargs_vals)
        scratch_kwargs = {}
        with graph_module.graph.inserting_before(node):
            for name, val in scratch_metas.items():
                scratch = make_alloc_node(graph_module, val, None)
                scratch_kwargs[name] = scratch
        node.target = scratch_op
        kwargs = dict(node.kwargs)
        kwargs.update(scratch_kwargs)
        node.kwargs = kwargs
        logging.debug(f"Out variant {target} is converted to scratch op {scratch_op}")
    return PassResult(graph_module, True)
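
# Conversion sketch (illustrative only): if an out variant op `my_op.out` has a
# registered scratch variant and a `get_scratch_metas` attribute describing the
# temporary buffers it needs, the pass rewrites
#
#     my_op.out(x, out=alloc_out)
#
# into roughly
#
#     my_op.scratch(x, out=alloc_out, scratch=memory.alloc(...))
#
# where `my_op`, the kwarg names, and the scratch shapes are all hypothetical.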


def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult:
    for subgm in graph_module.modules():
        if not isinstance(subgm, torch.fx.GraphModule):
            continue
        subgm.graph.eliminate_dead_code()
        subgm.recompile()
    return PassResult(graph_module, True)
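
# Usage sketch (illustrative only; `gm` is a hypothetical torch.fx.GraphModule,
# possibly containing control-flow submodules):
#
#     result = dead_code_elimination_pass(gm)
#     gm = result.graph_module  # unused nodes removed from gm and its submodules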


# Passes to convert a graph module from ATen to Edge IR

base_pre_op_replace_passes: List[Callable[[torch.nn.Module], PassResult]] = PassManager(
    passes=[
        # ReplaceSymSizeOpPass needs to run before the other passes, which
        # inherit from ExportPass. ExportPass cannot handle OpOverloadPacket in
        # its call_function method, and ReplaceSymSizeOpPass converts sym size
        # ops from OpOverloadPacket to OpOverload.
        ReplaceSymSizeOpPass(),
        NormalizeTransposePass(),
        ReplaceBrokenOpsWithFunctionalOpsPass(),
        ScalarToTensorPass(),
        SymToTensorPass(),
        RemoveNoopPass(),
        RemoveToCopyPass(),
    ]
).passes

base_post_op_replace_passes: List[Callable[[torch.nn.Module], PassResult]] = (
    PassManager(
        passes=[
            dead_code_elimination_pass,
            DebugHandleGeneratorPass(),
        ]
    ).passes
)


def propagate_dynamic_shape(
    dynamic_memory_planning_mode: DynamicMemoryPlanningMode = DynamicMemoryPlanningMode.UPPER_BOUND,
) -> List[PassType]:
    """
    Run a few passes on the GraphModule to propagate the dynamic shape information.

    Mainly used to provide dynamic shape information for delegation.
    """
    return [
        SpecPropPass(),
        HintBasedSymShapeEvalPass(),
    ]

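
# Usage sketch (illustrative only; `gm` is a hypothetical torch.fx.GraphModule):
#
#     for p in propagate_dynamic_shape():
#         p(gm)  # populates TensorSpecs and evaluates symbolic shapes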