xref: /aosp_15_r20/external/executorch/backends/example/example_partitioner.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerfrom typing import Dict, final
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport torch
10*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.example.example_backend import ExampleBackend
11*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.example.example_operators.ops import module_to_annotator
12*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
13*523fa7a6SAndroid Build Coastguard Worker    generate_partitions_from_list_of_nodes,
14*523fa7a6SAndroid Build Coastguard Worker)
15*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.partitioner import (
16*523fa7a6SAndroid Build Coastguard Worker    DelegationSpec,
17*523fa7a6SAndroid Build Coastguard Worker    Partitioner,
18*523fa7a6SAndroid Build Coastguard Worker    PartitionResult,
19*523fa7a6SAndroid Build Coastguard Worker)
20*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dialects._ops import ops as exir_ops
21*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.graph_module import get_control_flow_submodules
22*523fa7a6SAndroid Build Coastguard Workerfrom torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
23*523fa7a6SAndroid Build Coastguard Workerfrom torch.export import ExportedProgram
24*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.passes.operator_support import OperatorSupportBase
25*523fa7a6SAndroid Build Coastguard Worker
26*523fa7a6SAndroid Build Coastguard Worker
@final
class ExamplePartitioner(Partitioner):
    """
    Tags nodes of an exported program for delegation to ``ExampleBackend``.

    The candidate patterns are the keys of ``module_to_annotator`` (the
    original description was "all add/mul nodes regardless of order" --
    TODO confirm against ``module_to_annotator``). The node groups matched
    by those patterns are passed to ``generate_partitions_from_list_of_nodes``
    together with an operator-support check that accepts edge-dialect
    per-tensor quantize/dequantize calls. Every resulting partition is tagged
    ``tag<partition id>``. Control-flow submodules are partitioned
    recursively.
    """

    def __init__(self) -> None:
        # Pattern keys, fed one at a time to find_sequential_partitions.
        self.patterns = module_to_annotator.keys()
        # Every partition found here is delegated to the same backend; the
        # empty list is the DelegationSpec's second argument (presumably the
        # compile specs).
        self.delegation_spec = DelegationSpec(ExampleBackend.__name__, [])

        class DequantQuantOperatorSupport(OperatorSupportBase):
            """Accepts only edge-dialect per-tensor (de)quantize call nodes."""

            def is_node_supported(self, _submodules, node: torch.fx.Node) -> bool:
                if node.op != "call_function":
                    return False
                q_dq_ops = (
                    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
                    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
                )
                return node.target in q_dq_ops

        self.dequant_quant_support = DequantQuantOperatorSupport()

    def _partition_graph_module(
        self, edge_graph_module: torch.fx.GraphModule
    ) -> Dict[str, DelegationSpec]:
        """
        Tag the partitions of ``edge_graph_module`` and of its control-flow
        submodules.

        Args:
            edge_graph_module: graph module whose partitioned nodes (and their
                ``get_attr`` arguments) receive a ``"delegation_tag"`` entry
                in ``node.meta``.

        Returns:
            A dict mapping every tag written (in this module and, via
            recursion, in its control-flow submodules) to
            ``self.delegation_spec``.
        """
        # Gather the node groups of every sequential match, pattern by pattern.
        matched_groups = []
        for pattern in self.patterns:
            for sequence in find_sequential_partitions(edge_graph_module, pattern):
                matched_groups.extend(match.nodes for match in sequence)

        partitions = generate_partitions_from_list_of_nodes(
            edge_graph_module, matched_groups, self.dequant_quant_support
        )

        tags: Dict[str, DelegationSpec] = {}
        for partition in partitions:
            tag = f"tag{partition.id}"
            for node in partition.nodes:
                node.meta["delegation_tag"] = tag
                if node.op == "call_function":
                    # Give get_attr inputs the consumer's tag as well --
                    # presumably so lowering pulls those attributes into the
                    # same delegate.
                    for arg in node.args:
                        if isinstance(arg, torch.fx.Node) and arg.op == "get_attr":
                            arg.meta["delegation_tag"] = tag
                tags[tag] = self.delegation_spec

        # NOTE(review): partition ids presumably restart on every call, so a
        # submodule may reuse a tag name already used in its parent; the
        # merge below then keys both on the same string -- verify intended.
        submodules = get_control_flow_submodules(edge_graph_module)
        for _name, nested_module, _node in submodules:
            tags.update(self._partition_graph_module(nested_module))

        return tags

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """
        Tag the delegable nodes of ``exported_program`` in place.

        Args:
            exported_program: program whose graph module (and control-flow
                submodules) get ``"delegation_tag"`` metadata written.

        Returns:
            A ``PartitionResult`` wrapping the same (now tagged) program and
            the tag -> ``DelegationSpec`` mapping.
        """
        tags = self._partition_graph_module(exported_program.graph_module)
        return PartitionResult(
            tagged_exported_program=exported_program,
            partition_tags=tags,
        )
89