xref: /aosp_15_r20/external/executorch/backends/arm/arm_partitioner.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2023-2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7
8import logging
9import os
10from typing import Callable, final, List, Optional, Tuple
11
12import torch
13from executorch.backends.arm.arm_backend import ArmBackend  # usort: skip
14from executorch.backends.arm._passes.tag_io_quant_pass import TagIOQuantPass
15from executorch.backends.arm.operator_support.tosa_supported_operators import (
16    TOSASupportedOperators,
17)
18from executorch.backends.arm.tosa_specification import TosaSpecification
19from executorch.exir.backend.compile_spec_schema import CompileSpec
20from executorch.exir.backend.partitioner import (
21    DelegationSpec,
22    Partitioner,
23    PartitionResult,
24)
25from executorch.exir.backend.utils import tag_constant_data
26from executorch.exir.passes import PassManager
27from torch.export.exported_program import ExportedProgram
28from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
29
# Module logging setup. The module logger defaults to WARNING. When the
# TOSA_DBG_VERBOSE environment variable equals "1", the root logger is
# configured via logging.basicConfig(level=INFO) and this logger is raised
# to INFO.
TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1"
logger = logging.getLogger(__name__)
if TOSA_DBG_VERBOSE:
    logging.basicConfig(level=logging.INFO)
logger.setLevel(logging.INFO if TOSA_DBG_VERBOSE else logging.WARNING)
36
37
@final
class ArmPartitioner(Partitioner):
    """Partitioner that tags TOSA-supported subgraphs for delegation to ``ArmBackend``.

    Whether a node is supported is decided by ``TOSASupportedOperators`` for the
    TOSA specification built from the compile specs given to the constructor.
    """

    def __init__(self, compile_spec: List[CompileSpec]) -> None:
        """Wrap ``compile_spec`` in a ``DelegationSpec`` targeting ``ArmBackend``.

        Args:
            compile_spec: Compile specs attached to every partition tag. They are
                also used to build the TOSA specification and to decide whether
                IO quantization is excluded from the partitions.
        """
        self.delegation_spec = DelegationSpec(ArmBackend.__name__, compile_spec)

    def _quantize_io_requested(self) -> bool:
        """Return True if some ``quantize_io`` compile spec decodes to ``"True"``."""
        return any(
            spec.key == "quantize_io" and spec.value.decode() == "True"
            for spec in self.delegation_spec.compile_specs
        )

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """Tag the largest TOSA-supportable subgraphs of ``exported_program``.

        The program is modified in place. Every node of each proposed partition
        gets ``node.meta["delegation_tag"]`` set to ``"tag<partition id>"``, and
        ``tag_constant_data`` is run on the program. The same program object is
        returned inside the result.

        Args:
            exported_program: Program whose graph module is partitioned.

        Returns:
            A ``PartitionResult`` mapping every applied tag to this partitioner's
            delegation spec.
        """
        logger.info("ArmPartitioner::partition")

        tosa_spec = TosaSpecification.create_from_compilespecs(
            self.delegation_spec.compile_specs
        )
        # Lazy %-style arguments: the string is only built if INFO is enabled.
        logger.info("Partitioning for %s", tosa_spec)

        if self._quantize_io_requested():
            # Exclude IO quantization from the partition. The pass runs once,
            # even if several matching compile specs are present.
            passes = PassManager(passes=[TagIOQuantPass()])
            passes(exported_program.graph_module)

        # Run the CapabilityBasedPartitioner to return the largest possible
        # subgraphs of supported nodes. Single-node partitions are allowed.
        capability_partitioner = CapabilityBasedPartitioner(
            exported_program.graph_module,
            TOSASupportedOperators(tosa_spec),
            allows_single_node_partition=True,
        )

        partition_tags: dict[str, DelegationSpec] = {}
        for partition in capability_partitioner.propose_partitions():
            # Only register tags that are actually applied to at least one node.
            if not partition.nodes:
                continue
            tag = f"tag{partition.id}"
            for node in partition.nodes:
                node.meta["delegation_tag"] = tag
            partition_tags[tag] = self.delegation_spec

        # NOTE(review): presumably tags constant data (params/buffers) used by the
        # tagged nodes so it travels with the delegate; confirm in exir utils.
        tag_constant_data(exported_program)

        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )

    def ops_to_not_decompose(
        self,
        ep: ExportedProgram,
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        """Return the ATen ops this partitioner asks to keep un-decomposed.

        Args:
            ep: Exported program. It is not consulted; the list is fixed.

        Returns:
            ``(ops, None)``, where ``ops`` contains ``aten.linear.default`` and
            ``aten.upsample_nearest2d.vec``. No per-node filter callable is given.
        """
        ops_to_not_decompose = [
            torch.ops.aten.linear.default,
            torch.ops.aten.upsample_nearest2d.vec,
        ]
        return (ops_to_not_decompose, None)
92