# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, List, Optional, Type

import torch
import torch.fx
import torch.utils._pytree as pytree
from executorch.backends.cadence.aot.fuse_ops import (
    CadenceFuseOpsInGraph,
    FuseFullThenReshapePass,
    FuseTransposeOpPairsPass,
)
from executorch.backends.cadence.aot.pass_utils import (
    CadencePassAttribute,
    create_cadence_pass_filter,
    register_cadence_pass,
)

from executorch.backends.cadence.aot.remove_ops import (
    CadenceRemoveNops,
    RemoveNopSliceOrViewOpPass,
    RemoveRedundantOps,
)
from executorch.backends.cadence.aot.reorder_ops import CadenceReorderOpsInGraph
from executorch.backends.cadence.aot.replace_ops import CadenceReplaceOpsInGraph
from executorch.backends.cadence.aot.simplify_ops import CadenceSimplifyOpsInGraph
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.pass_manager import PassManager, PassType
from executorch.exir.passes import dead_code_elimination_pass
from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
from executorch.exir.passes.spec_prop_pass import SpecPropPass


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class InitializePipeline(ExportPass):
    """
    Initialize the pass pipeline. This should invariably be the first pass to
    run.

    Dead-code-eliminates the incoming graph module, then runs SpecPropPass on
    it and returns that pass's result.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # Prune dead nodes before spec propagation.
        dead_code_elimination_pass(graph_module)
        result = SpecPropPass()(graph_module)
        # The pass call's result is presumably typed Optional[PassResult];
        # the assert narrows it to the PassResult this method promises.
        assert result is not None
        return result


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class FinalizePipeline(ExportPass):
    """
    Cleanup pass run after the main Cadence pass groups.

    Runs ScalarToTensorPass followed by SpecPropPass through a PassManager,
    then dead-code-eliminates the resulting graph module. Note that
    get_passes_in_default_order() still schedules a few passes after this
    one, so it is not necessarily the last pass to run.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        finalize_passes: List[PassType] = [
            ScalarToTensorPass(),
            SpecPropPass(),
        ]
        result = PassManager(passes=finalize_passes)(graph_module)
        # The return value is ignored: the graph module held by `result` is
        # relied on to be cleaned up in place.
        dead_code_elimination_pass(result.graph_module)
        return result


# Similar to what's done in executorch/exir/pass_base.py
# Not referenced in this module; presumably kept for importers.
Argument = Any  # pyre-ignore


def get_passes_in_default_order() -> List[Type[PassType]]:
    """
    Return every Cadence pass class in the default execution order.

    Pass groups (the ``.passes`` attributes, presumably lists of pass
    classes) are flattened with pytree so the result is one flat list.
    get_cadence_passes() later decides which of them run, based on the
    optimization level.
    """
    passes = [
        InitializePipeline,
        RemoveRedundantOps.passes,
        CadenceReorderOpsInGraph.passes,
        # Phase ordering: remove -> fusion -> replacement -> simplification
        # passes.
        CadenceRemoveNops.passes,
        CadenceFuseOpsInGraph.passes,
        CadenceReplaceOpsInGraph.passes,
        CadenceSimplifyOpsInGraph.passes,
        FinalizePipeline,
        # NOTE(review): the following passes are scheduled after
        # FinalizePipeline; confirm this ordering is intentional.
        FuseFullThenReshapePass,
        FuseTransposeOpPairsPass,
        RemoveNopSliceOrViewOpPass,
    ]
    return pytree.tree_flatten(passes)[0]


def get_cadence_passes(
    opt_level: int,
) -> List[Optional[PassResult]]:
    """
    Instantiate the default-ordered Cadence passes enabled at ``opt_level``.

    Args:
        opt_level: Optimization level passed to create_cadence_pass_filter()
            to decide which passes are kept.

    Returns:
        One instance of each selected pass class, in default order. The
        annotation mirrors what Pyre infers for ``filtered_pass()``; at
        runtime the elements are pass objects, not PassResults.
    """
    passes = get_passes_in_default_order()
    pass_filter = create_cadence_pass_filter(opt_level)
    # filter() is iterated directly; materializing it with list() first is
    # unnecessary.
    return [
        # pyre-fixme[20]: Call `torch.fx.passes.infra.pass_base.PassBase.__call__` expects argument `graph_module`.
        filtered_pass()
        # pyre-fixme[6]: In call `filter.__new__` ... got `List[Type[typing.Callable[[GraphModule], Optional[PassResult]]]]`.
        for filtered_pass in filter(pass_filter, passes)
    ]