# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import operator
from typing import Type

import torch.fx as fx
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx.passes.operator_support import OperatorSupportBase


class SupportedTOSAOperatorCheck:
    """
    Base class for per-operator checks deciding whether a node is lowerable to TOSA.

    Subclasses set ``tosa_specs`` and ``targets`` and implement
    ``is_node_supported``; they are made discoverable by decorating them with
    ``register_tosa_support_check``.
    """

    # Should be populated by subclass implementation.
    # NOTE: these are class-level list objects; subclasses should assign their
    # own lists rather than mutate these in place, otherwise the mutation is
    # visible to every subclass that did not override the attribute.
    tosa_specs: list[TosaSpecification] = []
    targets: list[str] = []

    def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:
        """
        Check whether the fx.Node ``node`` is lowerable using the TOSA
        specification ``tosa_spec``.

        To be implemented by subclasses.

        Raises:
            NotImplementedError: always, when called on the base class.
        """
        raise NotImplementedError("SupportedTOSAOperatorCheck must be extended.")


# Registry of all SupportedTOSAOperatorCheck subclasses: maps each supported
# TOSA specification to a {target: checker class} dict. Only the specifications
# listed here can be registered against or queried.
_tosa_spec_dicts: dict[
    TosaSpecification, dict[str, Type[SupportedTOSAOperatorCheck]]
] = {
    TosaSpecification.create_from_string("TOSA-0.80.0+BI"): {},
    TosaSpecification.create_from_string("TOSA-0.80.0+MI"): {},
}


def register_tosa_support_check(checker):
    """
    Decorator to mark a subclass implementation of SupportedTOSAOperatorCheck
    to be registered for checking if a torch.fx.Node is lowerable given
    a TOSA specification.

    The class is registered under every (spec, target) pair formed from its
    ``tosa_specs`` and ``targets``; a later registration for the same pair
    replaces an earlier one. A spec missing from the registry raises KeyError.

    Returns:
        The decorated class, unchanged.
    """
    for tosa_spec in checker.tosa_specs:
        for target in checker.targets:
            _tosa_spec_dicts[tosa_spec][target] = checker
    return checker


def get_registered_tosa_support_checks(
    tosa_spec: TosaSpecification,
) -> dict[str, SupportedTOSAOperatorCheck]:
    """
    Instantiate every checker registered for ``tosa_spec``.

    A fresh instance of each registered class is created on every call.

    Returns:
        A dict mapping each registered target to a checker instance.

    Raises:
        RuntimeError: if ``tosa_spec`` is not one of the registry's specifications.
    """
    if tosa_spec not in _tosa_spec_dicts:
        raise RuntimeError(
            f"No TOSA operator support checks registered for specification {tosa_spec}"
        )

    return {
        target: tosa_check()
        for target, tosa_check in _tosa_spec_dicts[tosa_spec].items()
    }


class TOSASupportedOperators(OperatorSupportBase):
    """Operator-support policy deciding which fx nodes the Arm/TOSA backend can take."""

    def __init__(self, tosa_spec: TosaSpecification):
        super().__init__()
        self.tosa_spec = tosa_spec

    def is_node_supported(self, submodules, node: fx.Node) -> bool:
        """
        Return whether ``node`` can be lowered to TOSA.

        A node is supported if it is a ``call_function`` whose target is in the
        static list below, or otherwise if a registered per-operator check
        accepts it. If a pre-partition pass stored an
        ``"arm_override_partition"`` entry in ``node.meta``, the result is
        combined with that value and the entry is removed from ``node.meta``.

        ``submodules`` is part of the OperatorSupportBase interface and is not used.
        """
        # The list is built here (not hoisted to module level) so the edge ops
        # are only resolved when a node is actually checked.
        supported = node.op == "call_function" and node.target in [
            exir_ops.edge.aten.add.Tensor,
            exir_ops.edge.aten.expand_copy.default,
            exir_ops.edge.aten.cat.default,
            exir_ops.edge.aten.bmm.default,
            exir_ops.edge.aten.permute_copy.default,
            exir_ops.edge.aten.hardtanh.default,
            exir_ops.edge.aten.convolution.default,
            exir_ops.edge.aten.div.Tensor,
            exir_ops.edge.aten.exp.default,
            exir_ops.edge.aten.log.default,
            exir_ops.edge.aten.linear.default,
            exir_ops.edge.aten.split_with_sizes_copy.default,
            exir_ops.edge.aten.full.default,
            exir_ops.edge.aten.mul.Tensor,
            exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
            exir_ops.edge.aten.native_layer_norm.default,
            exir_ops.edge.aten.avg_pool2d.default,
            exir_ops.edge.aten.max_pool2d_with_indices.default,
            exir_ops.edge.aten.sigmoid.default,
            exir_ops.edge.aten.mm.default,
            exir_ops.edge.aten.repeat.default,
            exir_ops.edge.aten.reciprocal.default,
            exir_ops.edge.aten.relu.default,
            exir_ops.edge.aten.rsqrt.default,
            exir_ops.edge.aten._softmax.default,
            exir_ops.edge.aten.select_copy.int,
            exir_ops.edge.aten._log_softmax.default,
            exir_ops.edge.aten.slice_copy.Tensor,
            exir_ops.edge.aten.sub.Tensor,
            exir_ops.edge.aten.sum.dim_IntList,
            exir_ops.edge.aten.tanh.default,
            exir_ops.edge.aten.upsample_nearest2d.vec,
            exir_ops.edge.aten.view_copy.default,
            exir_ops.edge.aten.clone.default,
            exir_ops.edge.aten.unsqueeze_copy.default,
            exir_ops.edge.aten.squeeze_copy.dims,
            operator.getitem,
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        ]

        if not supported:
            supported = self.is_node_supported_custom(node)

        # Override partitioning based on pre partition passes. The entry is
        # consumed (popped) so it only affects this query.
        if "arm_override_partition" in node.meta:
            override = node.meta.pop("arm_override_partition")
            supported = supported and override

        return supported

    def is_node_supported_custom(self, node: fx.Node) -> bool:
        """
        Ask the checker registered for ``node.target`` (under this object's
        TOSA specification) whether ``node`` is supported.

        Returns False when no checker is registered for the target.

        Raises:
            RuntimeError: propagated from get_registered_tosa_support_checks
                when ``self.tosa_spec`` is not a registered specification.
        """
        tosa_checks = get_registered_tosa_support_checks(self.tosa_spec)
        checker = tosa_checks.get(node.target)
        if checker is None:
            return False
        return checker.is_node_supported(node, self.tosa_spec)