1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates. 2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved. 3*523fa7a6SAndroid Build Coastguard Worker# 4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the 5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree. 6*523fa7a6SAndroid Build Coastguard Worker 7*523fa7a6SAndroid Build Coastguard Workerfrom typing import Dict, final 8*523fa7a6SAndroid Build Coastguard Worker 9*523fa7a6SAndroid Build Coastguard Workerimport torch 10*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.example.example_backend import ExampleBackend 11*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.example.example_operators.ops import module_to_annotator 12*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import ( 13*523fa7a6SAndroid Build Coastguard Worker generate_partitions_from_list_of_nodes, 14*523fa7a6SAndroid Build Coastguard Worker) 15*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.partitioner import ( 16*523fa7a6SAndroid Build Coastguard Worker DelegationSpec, 17*523fa7a6SAndroid Build Coastguard Worker Partitioner, 18*523fa7a6SAndroid Build Coastguard Worker PartitionResult, 19*523fa7a6SAndroid Build Coastguard Worker) 20*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dialects._ops import ops as exir_ops 21*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.graph_module import get_control_flow_submodules 22*523fa7a6SAndroid Build Coastguard Workerfrom torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions 23*523fa7a6SAndroid Build Coastguard Workerfrom torch.export import ExportedProgram 24*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.passes.operator_support import OperatorSupportBase 25*523fa7a6SAndroid Build 
class _DequantQuantOperatorSupport(OperatorSupportBase):
    """Support predicate accepting only edge-dialect per-tensor quantize/dequantize ops.

    Handed to ``generate_partitions_from_list_of_nodes`` by ``ExamplePartitioner``.
    NOTE(review): presumably this lets the q/dq nodes surrounding a matched
    pattern join that pattern's partition -- confirm against
    ``generate_partitions_from_list_of_nodes``.
    """

    def is_node_supported(self, _submodules, node: torch.fx.Node) -> bool:
        # Resolved at call time (not import time) so the edge-op registry is
        # only touched when partitioning actually runs.
        return node.op == "call_function" and node.target in (
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        )


@final
class ExamplePartitioner(Partitioner):
    """
    Delegate nodes matching the patterns registered in ``module_to_annotator``
    to ``ExampleBackend``.

    Matches are found with ``find_sequential_partitions`` and grouped by
    ``generate_partitions_from_list_of_nodes``, which additionally accepts
    per-tensor quantize/dequantize ops. Control-flow submodules are partitioned
    recursively. (The registered patterns are presumably the example add/mul
    operators -- see ``module_to_annotator``.)
    """

    def __init__(self) -> None:
        # Live view of the registered pattern keys; each key is passed as a
        # pattern to find_sequential_partitions.
        self.patterns = module_to_annotator.keys()
        # Every tag produced by this partitioner maps to this single spec.
        self.delegation_spec = DelegationSpec(ExampleBackend.__name__, [])
        self.dequant_quant_support = _DequantQuantOperatorSupport()

    def _partition_graph_module(
        self, edge_graph_module: torch.fx.GraphModule
    ) -> Dict[str, DelegationSpec]:
        """Tag supported nodes of ``edge_graph_module`` and its control-flow submodules.

        Args:
            edge_graph_module: Graph module to tag in place. Tagged nodes get a
                ``"delegation_tag"`` entry in their ``meta`` dict.

        Returns:
            Mapping from every delegation tag assigned (here and, recursively,
            in control-flow submodules) to ``self.delegation_spec``.
        """
        partition_tags: Dict[str, DelegationSpec] = {}

        # Collect the node groups of every sequential match of every pattern.
        partition_nodes = []
        for pattern in self.patterns:
            for fused_partition in find_sequential_partitions(
                edge_graph_module, pattern
            ):
                partition_nodes.extend(p.nodes for p in fused_partition)

        partitions = generate_partitions_from_list_of_nodes(
            edge_graph_module, partition_nodes, self.dequant_quant_support
        )

        for partition in partitions:
            # Skip empty partitions so no tag is registered without a tagged node.
            if not partition.nodes:
                continue
            # The tag depends only on the partition, so compute/register it once.
            delegation_tag = f"tag{partition.id}"
            partition_tags[delegation_tag] = self.delegation_spec
            for node in partition.nodes:
                node.meta["delegation_tag"] = delegation_tag
                if node.op != "call_function":
                    continue
                # Constant inputs (get_attr) share the consumer's tag, presumably
                # so the delegate owns them. Only top-level positional args are
                # inspected; kwargs and list-valued args are not.
                for arg_node in node.args:
                    if (
                        isinstance(arg_node, torch.fx.Node)
                        and arg_node.op == "get_attr"
                    ):
                        arg_node.meta["delegation_tag"] = delegation_tag

        # Submodule tags may collide with tags here, but every value is
        # self.delegation_spec, so merging loses no information.
        for _, submodule, _ in get_control_flow_submodules(edge_graph_module):
            partition_tags.update(self._partition_graph_module(submodule))

        return partition_tags

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """Tag ``exported_program``'s graph in place and return it with its tags.

        Args:
            exported_program: Program whose graph module (and control-flow
                submodules) is tagged in place.

        Returns:
            A ``PartitionResult`` wrapping the same program and the tag -> spec map.
        """
        partition_tags = self._partition_graph_module(exported_program.graph_module)
        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )