# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Callable, List, Optional, Union

import torch
import torch.fx.passes.infra.pass_manager as fx
import torch.utils._pytree as pytree
from executorch.exir.error import ExportError, ExportErrorType
from torch.fx.passes.infra.pass_base import PassResult
from typing_extensions import TypeAlias

# A pass is any callable that takes a graph module and returns either a
# PassResult or None.
PassType: TypeAlias = Callable[[torch.fx.GraphModule], Optional[PassResult]]


class PassManager(fx.PassManager):
    """
    Class to run multiple passes on a given graph module. The PassManager is
    callable so to run it, we can just call the PassManager instance (the
    call behavior is inherited from the base ``fx.PassManager``).

    Compared to the base class, this subclass:
        * accepts either a flat list of passes or a list of lists of passes
          and flattens it before handing it to the base class;
        * wraps every pass with ``fx.pass_result_wrapper``;
        * overrides ``check`` to reject graphs containing ``call_method``
          nodes.
    """

    def __init__(
        self,
        passes: Optional[Union[List[PassType], List[List[PassType]]]] = None,
        run_checks_after_each_pass: bool = False,
        suppress_check_failures: bool = False,
    ) -> None:
        r"""
        Args:
            passes: A list of passes, or a list of lists of passes. Nested
                lists are flattened in order. ``None`` or an empty list
                results in a manager with no passes.
            run_checks_after_each_pass: whether to run checks and linting
                after each pass. Forwarded unchanged to the base
                ``fx.PassManager``.
            suppress_check_failures: forwarded unchanged to the base
                ``fx.PassManager``; see its documentation for the exact
                semantics.
        """

        # Flatten the (possibly nested) passes to a flat list of callables,
        # wrapping each one with the base module's pass_result_wrapper.
        flattened_passes = [
            fx.pass_result_wrapper(fn) for fn in pytree.tree_flatten(passes or [])[0]
        ]

        super().__init__(
            flattened_passes,
            run_checks_after_each_pass=run_checks_after_each_pass,
            suppress_check_failures=suppress_check_failures,
        )

    def check(self, module: torch.nn.Module) -> None:
        """
        Runs checks on the given graph module to make sure it is in a state
        that passes can work with.

        Checks currently performed, in order:
            - Assert that the module has type torch.fx.GraphModule.
            - Recompile the module and lint its graph.
            - Reject any node whose op is ``call_method``.

        Note: verifying that operator node types match the node's spec field
        is not implemented in this method.

        Args:
            module: The module to validate; expected to be a
                torch.fx.GraphModule.

        Raises:
            AssertionError: If ``module`` is not a torch.fx.GraphModule.
            ExportError: With type ``ExportErrorType.NOT_SUPPORTED`` if the
                graph contains a ``call_method`` node.
        """
        # Use the canonical torch.fx.GraphModule (as in PassType) rather than
        # relying on the pass_manager module happening to re-export it.
        assert isinstance(module, torch.fx.GraphModule)
        module.recompile()
        module.graph.lint()
        # TODO(qihan): use verifier.check_is_exir

        for node in module.graph.nodes:
            if node.op == "call_method":
                raise ExportError(
                    ExportErrorType.NOT_SUPPORTED,
                    f"call_method `{node}` is not supported except for backend delegate.",
                )