# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import operator
from typing import List, Optional

import torch
from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
    FuseBatchNormWithConvPass,
)
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
    ConfigPrecisionType,
    XNNPartitionerConfig,
)
from executorch.backends.xnnpack.utils.utils import is_param_node
from executorch.exir.backend.canonical_partitioners.config_partitioner import (
    format_target_name,
)
from executorch.exir.backend.utils import WhyNoPartition
from torch.export import ExportedProgram

logger = logging.getLogger(__name__)
why = WhyNoPartition(logger=logger)


class BatchNormConfig(XNNPartitionerConfig):
    target_name = "_native_batch_norm_legit_no_training.default"

    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        if not self.check_common_constraints(node, ep):
            return False

        bn = node
        conv = node.all_input_nodes[0]

        if conv.op != "call_function":
            return False

        conv_name = format_target_name(conv.target.__name__)  # pyre-ignore

        if conv_name not in ["convolution.default"]:
            why(node, f"Invalid conv target {conv_name}")
            return False

        can_fuse = FuseBatchNormWithConvPass.can_fuse(conv, bn, ep)
        if not can_fuse:
            why(node, "BatchNorm cannot be fused with Convolution")
            return False

        return True

    def get_node_and_deps(
        self, node: torch.fx.Node, ep: ExportedProgram
    ) -> List[torch.fx.Node]:
        deps = [node]

        # weight, bias, running_mean, running_var
        deps.extend(node.all_input_nodes[1:5])

        # All users of the batch norm node must be getitem ops. Batch norm
        # returns a 3-element tuple, and each user may only access the first
        # element of the tuple.
        if not all(
            user.target == operator.getitem and user.args[1] == 0
            for user in node.users
        ):
            return []

        deps.extend(node.users)
        return deps

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32]


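# A minimal usage sketch (illustrative; not part of the original file). After
# torch.export of an eval-mode model, Conv2d followed by BatchNorm2d lowers to
# convolution.default feeding _native_batch_norm_legit_no_training.default,
# whose tuple output is read via operator.getitem(..., 0): exactly the pattern
# BatchNormConfig accepts. Kept commented out so importing this module stays
# side-effect free; the module and names below are hypothetical.
#
#   import torch
#
#   class ConvBN(torch.nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)
#           self.bn = torch.nn.BatchNorm2d(8)
#
#       def forward(self, x):
#           return self.bn(self.conv(x))
#
#   ep = torch.export.export(ConvBN().eval(), (torch.randn(1, 3, 16, 16),))
#   print(ep.graph_module.graph)  # conv -> batch_norm -> getitem(bn, 0)

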
class MaxDimConfig(XNNPartitionerConfig):
    target_name = "max.dim"

    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        # We support max.dim as long as the indices output is never consumed
        supported_dtypes = {torch.float32, torch.float16, torch.int8, torch.qint8}
        node_val = node.meta.get("val")
        output_0 = node_val[0]

        input_node = node.all_input_nodes[0]
        input_val = input_node.meta.get("val")
        if len(input_val.shape) != 4:
            why(node, f"Unsupported input rank {len(input_val.shape)}")
            return False
        # Don't check the indices dtype; only the values output is consumed
        if output_0.dtype not in supported_dtypes:
            why(node, f"Unsupported output dtype {output_0.dtype}")
            return False

        if input_val.dtype not in supported_dtypes:
            why(node, f"Unsupported input dtype {input_val.dtype}")
            return False

        # Make sure that all users are getitems of the first output
        for user in node.users:
            if not (user.target == operator.getitem and user.args[1] == 0):
                why(node, "Unsupported user of max.dim")
                return False

        return True
111    def get_node_and_deps(
112        self, node: torch.fx.Node, ep: ExportedProgram
113    ) -> List[torch.fx.Node]:
114        getitems = list(node.users)
115
116        return [node] + getitems
117
118    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
119        return None
120
121    def supported_precision_types(self) -> List[ConfigPrecisionType]:
122        return [ConfigPrecisionType.FP32]
123
124
125class PreluConfig(XNNPartitionerConfig):
126    target_name = "prelu.default"
127
128    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
129        if not self.check_common_constraints(node, ep):
130            return False
131
132        weight = node.all_input_nodes[1]
133        is_param = is_param_node(ep, weight)
134        if not is_param:
135            why(node, "Prelu weight must be a parameter")
136            return False
137        return True
138
139    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
140        return torch.ops.aten.prelu.default
141
142    def get_node_and_deps(
143        self, node: torch.fx.Node, ep: ExportedProgram
144    ) -> List[torch.fx.Node]:
145        weight = node.all_input_nodes[1]
146
147        return [node, weight]
148
149    def supported_precision_types(self) -> List[ConfigPrecisionType]:
150        return [ConfigPrecisionType.FP32]
151
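
# A minimal usage sketch (illustrative; not part of the original file).
# torch.nn.PReLU stores its slope as an nn.Parameter, so after export the
# second input of the prelu op is a parameter node and PreluConfig's
# constraint passes; a slope computed at runtime would be rejected. Kept
# commented out so importing this module stays side-effect free; the module
# name below is hypothetical.
#
#   import torch
#
#   class WithPrelu(torch.nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.act = torch.nn.PReLU(num_parameters=8)
#
#       def forward(self, x):
#           return self.act(x)
#
#   ep = torch.export.export(WithPrelu().eval(), (torch.randn(1, 8, 4, 4),))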