# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

from typing import Callable, final, Iterable, List, Optional, Tuple

import torch
from executorch.backends.mediatek.preprocess import NeuropilotBackend
from executorch.exir.backend.backend_details import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data

from mtk_converter.python.converters.pytorch import importer_v2
from torch.export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase


class NeuropilotOperatorsSupport(OperatorSupportBase):
    """Decide which FX nodes may be delegated to the Neuropilot backend.

    A node is rejected when it is not a ``call_function`` node, when the
    ``__name__`` of its target is listed in ``op_types_to_skip``, or when its
    FX node name is listed in ``op_names_to_skip``. Every remaining node is
    accepted or rejected by ``importer_v2.is_fx_node_supported``.
    """

    def __init__(
        self,
        op_types_to_skip: Optional[Iterable[str]] = None,
        op_names_to_skip: Optional[Iterable[str]] = None,
    ) -> None:
        """Create the support checker.

        Args:
            op_types_to_skip: Operator type names, matched against the node
                target's ``__name__``, that must never be delegated.
            op_names_to_skip: FX node names that must never be delegated.
        """
        # Copy into private sets. This gives O(1) membership checks for any
        # iterable the caller passes (list, tuple, set, ...).
        self._op_types_to_skip = (
            set() if op_types_to_skip is None else set(op_types_to_skip)
        )
        self._op_names_to_skip = (
            set() if op_names_to_skip is None else set(op_names_to_skip)
        )

    def is_node_supported(self, _, node: torch.fx.Node) -> bool:
        """Return True if ``node`` should be placed in a Neuropilot partition."""
        # Handle 'call_function' only, because 'placeholder' and 'output' cannot be tagged.
        # Ref: https://github.com/pytorch/executorch/pull/1398
        if node.op != "call_function":
            return False

        # Not every callable target defines ``__name__`` (e.g. functools.partial
        # objects); fall back to its string form instead of raising.
        op_type = getattr(node.target, "__name__", str(node.target))
        if op_type in self._op_types_to_skip or node.name in self._op_names_to_skip:
            print(
                f"[Neuropilot Backend] The {op_type} operator with name '{node.name}' is skipped."
            )
            return False

        return importer_v2.is_fx_node_supported(node)


@final
class NeuropilotPartitioner(Partitioner):
    """Partitioner that delegates supported subgraphs to ``NeuropilotBackend``.

    Nodes accepted by ``NeuropilotOperatorsSupport`` are grouped by
    ``CapabilityBasedPartitioner``. Each resulting partition is tagged and
    mapped to a ``DelegationSpec`` built from ``compile_spec``.
    """

    def __init__(
        self,
        compile_spec: List[CompileSpec],
        op_types_to_skip: Optional[Iterable[str]] = None,
        op_names_to_skip: Optional[Iterable[str]] = None,
    ) -> None:
        """Create the partitioner.

        Args:
            compile_spec: Compile specs forwarded to ``NeuropilotBackend``
                through the delegation spec.
            op_types_to_skip: Operator type names that must never be delegated.
            op_names_to_skip: FX node names that must never be delegated.
        """
        self.delegation_spec = DelegationSpec(NeuropilotBackend.__name__, compile_spec)
        # Stored as given; NeuropilotOperatorsSupport normalizes them each
        # time partition() builds a new checker.
        self._op_types_to_skip = op_types_to_skip
        self._op_names_to_skip = op_names_to_skip

    def ops_to_not_decompose(
        self,
        ep: ExportedProgram,
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        """Return the ATen ops that should be kept intact rather than decomposed.

        The list is fixed and does not depend on ``ep``. No per-node filter
        callable is supplied, so the second tuple element is None.
        """
        ops_not_decompose = [
            torch.ops.aten.pixel_shuffle.default,
            torch.ops.aten.upsample_bilinear2d.default,
            torch.ops.aten.upsample_bilinear2d.vec,
            torch.ops.aten.upsample_nearest2d.default,
            torch.ops.aten.upsample_nearest2d.vec,
        ]
        return (ops_not_decompose, None)

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """Tag Neuropilot-supported nodes of ``exported_program`` for delegation.

        The program is modified in place:
        - every node of a proposed partition with id ``i`` gets
          ``node.meta["delegation_tag"] = "tag<i>"``;
        - ``tag_constant_data`` is then run on the program.

        The same program object is returned inside the ``PartitionResult``,
        together with a mapping from each tag to this partitioner's
        delegation spec.
        """
        capability_partitioner = CapabilityBasedPartitioner(
            exported_program.graph_module,
            NeuropilotOperatorsSupport(self._op_types_to_skip, self._op_names_to_skip),
            allows_single_node_partition=True,
        )
        partition_list = capability_partitioner.propose_partitions()

        partition_tags = {}
        for partition in partition_list:
            # Only partitions that actually contain nodes get a tag entry.
            if not partition.nodes:
                continue
            # The tag and its spec are per-partition invariants; compute once.
            tag = f"tag{partition.id}"
            partition_tags[tag] = self.delegation_spec
            for node in partition.nodes:
                node.meta["delegation_tag"] = tag

        tag_constant_data(exported_program)

        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )