xref: /aosp_15_r20/external/executorch/backends/arm/operator_support/tosa_supported_operators.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7
8import operator
9from typing import Type
10
11import torch.fx as fx
12from executorch.backends.arm.tosa_specification import TosaSpecification
13from executorch.exir.dialects._ops import ops as exir_ops
14from torch.fx.passes.operator_support import OperatorSupportBase
15
16
class SupportedTOSAOperatorCheck:
    """
    Base class for per-operator TOSA support checks.

    Subclasses list the TOSA specifications and operator targets they handle
    in ``tosa_specs`` and ``targets``, override ``is_node_supported``, and are
    registered with the ``register_tosa_support_check`` decorator.
    """

    # Should be populated by subclass implementation. These class-level lists
    # are shared by every subclass that does not override them, so subclasses
    # should assign fresh lists rather than mutate these in place.
    tosa_specs: list[TosaSpecification] = []
    # Looked up against fx.Node.target when selecting a check.
    # NOTE(review): annotated as str, but call_function targets are usually
    # op callables rather than strings — confirm the intended element type.
    targets: list[str] = []

    def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:
        """
        Check whether ``node`` is lowerable under the TOSA specification ``tosa_spec``.

        Must be overridden by subclasses.

        Args:
            node: The fx node to inspect.
            tosa_spec: The TOSA specification being targeted.

        Returns:
            True if the node can be lowered, False otherwise.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must override is_node_supported."
        )
32
33
# Registry of SupportedTOSAOperatorCheck subclasses: one inner dict per
# supported TOSA specification, mapping each target to the checker class
# registered for it.
_tosa_spec_dicts: dict[
    TosaSpecification, dict[str, Type[SupportedTOSAOperatorCheck]]
] = {
    TosaSpecification.create_from_string(spec_string): {}
    for spec_string in ("TOSA-0.80.0+BI", "TOSA-0.80.0+MI")
}
41
42
def register_tosa_support_check(checker):
    """
    Class decorator that registers a SupportedTOSAOperatorCheck subclass.

    The checker class is stored in the module registry under every pairing of
    one of its ``tosa_specs`` with one of its ``targets``, so that it can later
    be used to decide whether a torch.fx.Node is lowerable for that TOSA
    specification. A later registration for the same pairing replaces an
    earlier one. The class itself is returned unchanged.
    """
    spec_target_pairs = (
        (spec, op_target)
        for spec in checker.tosa_specs
        for op_target in checker.targets
    )
    for spec, op_target in spec_target_pairs:
        _tosa_spec_dicts[spec][op_target] = checker
    return checker
53
54
def get_registered_tosa_support_checks(
    tosa_spec: TosaSpecification,
) -> dict[str, SupportedTOSAOperatorCheck]:
    """
    Instantiate every support check registered for ``tosa_spec``.

    Args:
        tosa_spec: The TOSA specification whose checks are requested.

    Returns:
        A new dict mapping each registered target to a freshly constructed
        instance of its checker class. It is rebuilt on every call.

    Raises:
        RuntimeError: If ``tosa_spec`` has no entry in the registry.
    """
    if tosa_spec not in _tosa_spec_dicts:
        raise RuntimeError(
            f"No TOSA support checks are registered for specification {tosa_spec}"
        )

    return {
        target: tosa_check()
        for target, tosa_check in _tosa_spec_dicts[tosa_spec].items()
    }
67
68
class TOSASupportedOperators(OperatorSupportBase):
    """
    Operator-support policy deciding which fx nodes are lowered to TOSA.

    A node is supported when it is a ``call_function`` whose target appears in
    a fixed list of edge / quantized ops (plus ``operator.getitem``). Failing
    that, it is supported if a check registered for ``tosa_spec`` accepts it.
    A pre-partition pass may further restrict the result through the
    ``"arm_override_partition"`` entry in ``node.meta``.
    """

    def __init__(self, tosa_spec: TosaSpecification):
        super().__init__()
        # Specification used to select the registered custom checks.
        self.tosa_spec = tosa_spec

    def is_node_supported(self, submodules, node: fx.Node) -> bool:
        """
        Return whether ``node`` can be placed in a TOSA partition.

        Args:
            submodules: Unused by this implementation.
            node: The fx node to classify.

        Returns:
            True if the node is supported, False otherwise.

        Side effects:
            Removes the ``"arm_override_partition"`` entry from ``node.meta``
            if present, so the override is applied only once.
        """
        # NOTE: this list is rebuilt on every call; membership uses equality.
        supported = node.op == "call_function" and node.target in [
            exir_ops.edge.aten.add.Tensor,
            exir_ops.edge.aten.expand_copy.default,
            exir_ops.edge.aten.cat.default,
            exir_ops.edge.aten.bmm.default,
            exir_ops.edge.aten.permute_copy.default,
            exir_ops.edge.aten.hardtanh.default,
            exir_ops.edge.aten.convolution.default,
            exir_ops.edge.aten.div.Tensor,
            exir_ops.edge.aten.exp.default,
            exir_ops.edge.aten.log.default,
            exir_ops.edge.aten.linear.default,
            exir_ops.edge.aten.split_with_sizes_copy.default,
            exir_ops.edge.aten.full.default,
            exir_ops.edge.aten.mul.Tensor,
            exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
            exir_ops.edge.aten.native_layer_norm.default,
            exir_ops.edge.aten.avg_pool2d.default,
            exir_ops.edge.aten.max_pool2d_with_indices.default,
            exir_ops.edge.aten.sigmoid.default,
            exir_ops.edge.aten.mm.default,
            exir_ops.edge.aten.repeat.default,
            exir_ops.edge.aten.reciprocal.default,
            exir_ops.edge.aten.relu.default,
            exir_ops.edge.aten.rsqrt.default,
            exir_ops.edge.aten._softmax.default,
            exir_ops.edge.aten.select_copy.int,
            exir_ops.edge.aten._log_softmax.default,
            exir_ops.edge.aten.slice_copy.Tensor,
            exir_ops.edge.aten.sub.Tensor,
            exir_ops.edge.aten.sum.dim_IntList,
            exir_ops.edge.aten.tanh.default,
            exir_ops.edge.aten.upsample_nearest2d.vec,
            exir_ops.edge.aten.view_copy.default,
            exir_ops.edge.aten.clone.default,
            exir_ops.edge.aten.unsqueeze_copy.default,
            exir_ops.edge.aten.squeeze_copy.dims,
            operator.getitem,
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        ]

        # Fall back to the per-operator checks registered for this spec.
        if not supported:
            supported = self.is_node_supported_custom(node)

        # Override partitioning based on pre-partition passes: the metadata flag
        # can only veto (AND) the decision, and is consumed once read.
        if "arm_override_partition" in node.meta:
            supported = supported & node.meta["arm_override_partition"]
            node.meta.pop("arm_override_partition")

        return supported

    def is_node_supported_custom(self, node: fx.Node) -> bool:
        """
        Return whether the check registered for ``node.target`` accepts ``node``.

        Returns False when no check is registered for the node's target.

        Raises:
            RuntimeError: If ``self.tosa_spec`` has no entry in the registry
                (propagated from get_registered_tosa_support_checks).
        """
        # NOTE(review): this instantiates every registered checker on each
        # call; consider caching if it shows up in partitioning profiles.
        tosa_checks = get_registered_tosa_support_checks(self.tosa_spec)
        # Single lookup instead of a `.keys()` membership test plus indexing;
        # stored values are checker instances, never None.
        tosa_check = tosa_checks.get(node.target)
        if tosa_check is None:
            return False
        return tosa_check.is_node_supported(node, self.tosa_spec)
132