# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import logging
import os
from typing import Callable, final, List, Optional, Tuple

import torch
from executorch.backends.arm.arm_backend import ArmBackend  # usort: skip
from executorch.backends.arm._passes.tag_io_quant_pass import TagIOQuantPass
from executorch.backends.arm.operator_support.tosa_supported_operators import (
    TOSASupportedOperators,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data
from executorch.exir.passes import PassManager
from torch.export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

# Module logger: WARNING by default; raised to INFO (and the root logger
# configured) only when the TOSA_DBG_VERBOSE environment variable is "1".
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1"
if TOSA_DBG_VERBOSE:
    logging.basicConfig(level=logging.INFO)
    logger.setLevel(logging.INFO)


@final
class ArmPartitioner(Partitioner):
    """Partitioner that tags TOSA-supported nodes for delegation to ArmBackend."""

    def __init__(self, compile_spec: List[CompileSpec]) -> None:
        """Store a delegation spec pairing ArmBackend with ``compile_spec``."""
        self.delegation_spec = DelegationSpec(ArmBackend.__name__, compile_spec)

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """Tag the nodes of ``exported_program`` that can be delegated.

        Every node of each partition proposed by ``CapabilityBasedPartitioner``
        gets ``node.meta["delegation_tag"]`` set to ``"tag<partition id>"``,
        and that tag is mapped to this partitioner's delegation spec.

        Args:
            exported_program: Program whose graph module is tagged in place.

        Returns:
            A ``PartitionResult`` holding the same (now tagged) program and the
            tag -> delegation spec mapping.
        """
        logger.info("ArmPartitioner::partition")
        tag_to_spec = {}

        compile_specs = self.delegation_spec.compile_specs
        tosa_spec = TosaSpecification.create_from_compilespecs(compile_specs)

        logger.info(f"Partitioning for {tosa_spec}")

        for compile_spec in compile_specs:
            if compile_spec.key != "quantize_io":
                continue
            if compile_spec.value.decode() != "True":
                continue
            # Exclude IO quantization from the partition
            io_quant_passes = PassManager(passes=[TagIOQuantPass()])
            io_quant_passes(exported_program.graph_module)

        # Run the CapabilityBasedPartitioner to return the largest possible
        # subgraphs containing the nodes with the tags
        proposer = CapabilityBasedPartitioner(
            exported_program.graph_module,
            TOSASupportedOperators(tosa_spec),
            allows_single_node_partition=True,
        )
        for proposed in proposer.propose_partitions():
            delegation_tag = f"tag{proposed.id}"
            for member in proposed.nodes:
                member.meta["delegation_tag"] = delegation_tag
                # Registered per node so that a partition without nodes
                # contributes no tag.
                tag_to_spec[delegation_tag] = self.delegation_spec

        tag_constant_data(exported_program)

        return PartitionResult(
            tagged_exported_program=exported_program,
            partition_tags=tag_to_spec,
        )

    def ops_to_not_decompose(
        self,
        ep: ExportedProgram,
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        """Return the ops to keep undecomposed and an optional node filter.

        ``ep`` is not inspected; the op list is fixed and no filter is supplied.
        """
        return (
            [
                torch.ops.aten.linear.default,
                torch.ops.aten.upsample_nearest2d.vec,
            ],
            None,
        )