xref: /aosp_15_r20/external/executorch/backends/mediatek/partitioner.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) 2024 MediaTek Inc.
2#
3# Licensed under the BSD License (the "License"); you may not use this file
4# except in compliance with the License. See the license file in the root
5# directory of this source tree for more details.
6
7from typing import Callable, final, List, Optional, Tuple
8
9import torch
10from executorch.backends.mediatek.preprocess import NeuropilotBackend
11from executorch.exir.backend.backend_details import CompileSpec
12from executorch.exir.backend.partitioner import (
13    DelegationSpec,
14    Partitioner,
15    PartitionResult,
16)
17from executorch.exir.backend.utils import tag_constant_data
18
19from mtk_converter.python.converters.pytorch import importer_v2
20from torch.export.exported_program import ExportedProgram
21from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
22from torch.fx.passes.operator_support import OperatorSupportBase
23
24
class NeuropilotOperatorsSupport(OperatorSupportBase):
    """Operator-support policy used by the Neuropilot partitioner.

    A node is eligible for delegation when it is a 'call_function' node,
    is not explicitly excluded by operator type or node name, and is
    accepted by the mtk_converter importer check.
    """

    def __init__(
        self,
        op_types_to_skip: Optional[set] = None,
        op_names_to_skip: Optional[set] = None,
    ) -> None:
        # Normalize the optional skip collections so lookups below never see None.
        self._op_types_to_skip = set() if op_types_to_skip is None else op_types_to_skip
        self._op_names_to_skip = set() if op_names_to_skip is None else op_names_to_skip

    def is_node_supported(self, _, node: torch.fx.Node) -> bool:
        # Only 'call_function' nodes can be delegated; 'placeholder' and
        # 'output' nodes cannot be tagged.
        # Ref: https://github.com/pytorch/executorch/pull/1398
        if node.op != "call_function":
            return False

        op_type = node.target.__name__
        user_skipped = (
            op_type in self._op_types_to_skip or node.name in self._op_names_to_skip
        )
        if user_skipped:
            print(
                f"[Neuropilot Backend] The {op_type} operator with name '{node.name}' is skipped."
            )
            return False

        # Defer the final decision to the converter's importer.
        return importer_v2.is_fx_node_supported(node)
55
@final
class NeuropilotPartitioner(Partitioner):
    """Partitions an ExportedProgram into subgraphs delegated to NeuropilotBackend.

    Args:
        compile_spec: Compile specs forwarded verbatim to the Neuropilot backend.
        op_types_to_skip: Operator type names that must stay un-delegated.
        op_names_to_skip: Specific graph node names that must stay un-delegated.
    """

    def __init__(
        self,
        compile_spec: List[CompileSpec],
        op_types_to_skip: Optional[set] = None,
        op_names_to_skip: Optional[set] = None,
    ) -> None:
        self.delegation_spec = DelegationSpec(NeuropilotBackend.__name__, compile_spec)
        self._op_types_to_skip = op_types_to_skip
        self._op_names_to_skip = op_names_to_skip

    def ops_to_not_decompose(
        self,
        ep: ExportedProgram,
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        """Return ops that must not be decomposed so the backend can match them whole.

        The second tuple element (a per-node filter) is None: the listed ops
        are unconditionally kept intact.
        """
        ops_not_decompose = [
            torch.ops.aten.pixel_shuffle.default,
            torch.ops.aten.upsample_bilinear2d.default,
            torch.ops.aten.upsample_bilinear2d.vec,
            torch.ops.aten.upsample_nearest2d.default,
            torch.ops.aten.upsample_nearest2d.vec,
        ]
        return (ops_not_decompose, None)

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """Tag every node of each supported partition for Neuropilot delegation.

        Returns:
            A PartitionResult holding the tagged program and a mapping from
            partition tag to this partitioner's DelegationSpec.
        """
        capability_partitioner = CapabilityBasedPartitioner(
            exported_program.graph_module,
            NeuropilotOperatorsSupport(self._op_types_to_skip, self._op_names_to_skip),
            allows_single_node_partition=True,
        )
        partition_list = capability_partitioner.propose_partitions()

        partition_tags = {}
        for partition in partition_list:
            # The tag and spec entry are invariant per partition; compute and
            # record them once instead of once per node (was inside the inner loop).
            tag = f"tag{partition.id}"
            partition_tags[tag] = self.delegation_spec
            for node in partition.nodes:
                node.meta["delegation_tag"] = tag

        # Propagate tags to constant data (e.g. weights) used by tagged nodes so
        # the constants are lowered together with their consumers.
        tag_constant_data(exported_program)

        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )
102