# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, final

import torch
from executorch.backends.example.example_backend import ExampleBackend
from executorch.backends.example.example_operators.ops import module_to_annotator
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
    generate_partitions_from_list_of_nodes,
)
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.graph_module import get_control_flow_submodules
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.export import ExportedProgram
from torch.fx.passes.operator_support import OperatorSupportBase


@final
class ExamplePartitioner(Partitioner):
    """
    Tag nodes matching the example backend's operator patterns for delegation
    to ``ExampleBackend``.

    The patterns are the keys of ``module_to_annotator``. NOTE(review): an
    earlier docstring described these as the add/mul patterns; confirm against
    ``example_operators.ops``. Per-tensor quantize/dequantize nodes are also
    accepted when the matched nodes are merged into partitions.
    """

    def __init__(self) -> None:
        # Live view of module_to_annotator's keys; each key is handed to
        # find_sequential_partitions as one pattern.
        self.patterns = module_to_annotator.keys()
        # Every tag produced by this partitioner maps to this single spec.
        self.delegation_spec = DelegationSpec(ExampleBackend.__name__, [])

        class DequantQuantOperatorSupport(OperatorSupportBase):
            """Accept edge-dialect per-tensor quantize/dequantize calls."""

            def is_node_supported(self, _submodules, node: torch.fx.Node) -> bool:
                return node.op == "call_function" and node.target in [
                    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
                    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
                ]

        # Operator-support predicate passed to
        # generate_partitions_from_list_of_nodes alongside the pattern nodes.
        self.dequant_quant_support = DequantQuantOperatorSupport()

    def _partition_graph_module(
        self, edge_graph_module: torch.fx.GraphModule
    ) -> Dict[str, DelegationSpec]:
        """
        Tag partitions in ``edge_graph_module`` and its control-flow submodules.

        Steps:
        1. For every pattern, collect the node lists found by
           ``find_sequential_partitions``.
        2. Merge them into partitions via
           ``generate_partitions_from_list_of_nodes``, using the
           quantize/dequantize support predicate.
        3. Write ``node.meta["delegation_tag"] = f"tag{partition.id}"`` on each
           partition node. Also tag any ``get_attr`` node passed positionally
           to a ``call_function`` node in the partition, so that constant moves
           with its consumer.
        4. Recurse into control-flow submodules.

        Args:
            edge_graph_module: Graph module to tag in place (node ``meta`` is
                mutated).

        Returns:
            Mapping from each delegation tag used to ``self.delegation_spec``.
        """
        partition_tags: Dict[str, DelegationSpec] = {}

        partition_nodes = []
        for pattern in self.patterns:
            fused_partitions = find_sequential_partitions(edge_graph_module, pattern)
            for fused_partition in fused_partitions:
                partition_nodes.extend(source.nodes for source in fused_partition)

        partitions = generate_partitions_from_list_of_nodes(
            edge_graph_module, partition_nodes, self.dequant_quant_support
        )

        for partition in partitions:
            # A tag is only registered for partitions that tag at least one node.
            if not partition.nodes:
                continue
            # The tag depends only on the partition, so compute it once.
            delegation_tag = f"tag{partition.id}"
            partition_tags[delegation_tag] = self.delegation_spec
            for node in partition.nodes:
                node.meta["delegation_tag"] = delegation_tag
                if node.op != "call_function":
                    continue
                # NOTE(review): only positional args are scanned; get_attr
                # nodes reached via kwargs or nested list args are not tagged.
                for arg_node in node.args:
                    if isinstance(arg_node, torch.fx.Node) and arg_node.op == "get_attr":
                        arg_node.meta["delegation_tag"] = delegation_tag

        # NOTE(review): submodule tags presumably restart at tag0 and may reuse
        # names from this module. Every value is self.delegation_spec, so
        # merging the dicts loses nothing.
        for _, submodule, _ in get_control_flow_submodules(edge_graph_module):
            partition_tags.update(self._partition_graph_module(submodule))

        return partition_tags

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """
        Tag ``exported_program``'s graph (in place) and report the tags used.

        Args:
            exported_program: Program whose ``graph_module`` is tagged.

        Returns:
            A ``PartitionResult`` wrapping the same program and a mapping from
            every delegation tag to this partitioner's ``DelegationSpec``.
        """
        partition_tag = self._partition_graph_module(exported_program.graph_module)
        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tag
        )