xref: /aosp_15_r20/external/executorch/backends/xnnpack/operators/op_conv2d.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7from typing import cast, Dict, List
8
9import torch
10from executorch.backends.transforms import get_shape
11from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
12from executorch.backends.xnnpack.operators.node_visitor import (
13    NodeVisitor,
14    register_node_visitor,
15)
16from executorch.backends.xnnpack.operators.quant_params import QuantParams
17from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
18    XNNConv2d,
19    XNNDepthwiseConv2d,
20    XNNGraph,
21    XNode,
22)
23from executorch.backends.xnnpack.utils.utils import check_or_raise, get_input_node
24
25from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_INVALID_VALUE_ID
26
27
@register_node_visitor
class Conv2d(NodeVisitor):
    """Lower ``aten.convolution.default`` to an XNNPACK 2D convolution node.

    Emits ``XNNDepthwiseConv2d`` when each group sees exactly one input
    channel, and ``XNNConv2d`` otherwise. Transposed convolutions, non-zero
    output padding and convolutions that are not 2D are rejected via
    ``check_or_raise``.
    """

    target = "aten.convolution.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    @staticmethod
    def _expand_to_2d(values: List[int]) -> List[int]:
        """Broadcast a one-element spatial parameter list to both spatial dims.

        ATen's convolution accepts a single-element ``stride`` / ``padding`` /
        ``dilation`` list and applies it to every spatial dimension. This
        mirrors that, so such graphs are neither rejected nor mis-indexed.
        Lists of any other length are returned unchanged (as a new list) for
        the caller to validate.
        """
        expanded = list(values)
        return expanded * 2 if len(expanded) == 1 else expanded

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Define the conv's tensors and append the XNNPACK conv node.

        ``node.args`` follows the ``aten.convolution`` positional layout:
        ``(input, weight, bias, stride, padding, dilation, transposed,
        output_padding, groups)``.

        Args:
            node: The ``aten.convolution.default`` FX node being lowered.
            xnn_graph: Serialized graph. The input, filter, optional bias and
                output tensors are defined in it, and the resulting conv
                ``XNode`` is appended to ``xnn_graph.xnodes``.
            vals_to_ids: Map from FX nodes to XNNPACK value ids. It is
                populated by ``define_tensor`` and read back to wire up the
                conv node.
            debug_handle: Forwarded unchanged to the emitted ``XNode``.

        Raises:
            Whatever ``check_or_raise`` raises (error type not visible here)
            for a transposed convolution, non-zero output padding, or a
            convolution that is not 2D. All checks run before any tensor is
            defined, so a rejected node leaves ``xnn_graph`` untouched.
        """
        input_node = get_input_node(node, 0)
        kernel_node = get_input_node(node, 1)

        # PyTorch filter shape is (oc, inc/groups, kernel_h, kernel_w).
        kernel_shape = get_shape(kernel_node)
        groups = cast(int, node.args[8])

        # ATen broadcasts one-element stride/padding/dilation lists to every
        # spatial dim; do the same for all three (not only padding).
        stride = self._expand_to_2d(cast(List[int], node.args[3]))
        padding = self._expand_to_2d(cast(List[int], node.args[4]))
        dilation = self._expand_to_2d(cast(List[int], node.args[5]))

        # Validate before touching xnn_graph / vals_to_ids.
        # args[6] = transposed
        check_or_raise(
            not cast(bool, node.args[6]), "No support for transposed convolution"
        )
        # args[7] = output padding
        check_or_raise(
            all(out_pad == 0 for out_pad in cast(List[int], node.args[7])),
            "XNNPACK does not support output padding",
        )
        # A 2D convolution has a rank-4 filter and two values per spatial param.
        check_or_raise(
            len(kernel_shape) == 4
            and len(stride) == 2
            and len(padding) == 2
            and len(dilation) == 2,
            "XNNPACK currently only supports 2D convolution",
        )

        group_input_channels = kernel_shape[1]
        # Exact integer division (output channels are split evenly across groups).
        group_output_channels = kernel_shape[0] // groups

        # XNNPACK expects the kernel's N and C dimensions to be swapped for
        # depthwise convolution: each group sees exactly one input channel
        # (groups == input channels). Output channels are then trivially a
        # positive integer multiple of the per-group input channels
        # (x % 1 == 0), so this single check is sufficient.
        is_depthwise_conv = group_input_channels == 1

        kwargs = {}

        # input
        input_quant_params = QuantParams.from_inputs(input_node, self._exported_program)
        self.define_tensor(
            input_node,
            xnn_graph,
            vals_to_ids,
            convert_to_nhwc=True,
            quant_params=input_quant_params,
        )  # NHWC input
        kwargs["input1_id"] = vals_to_ids[input_node]

        # filter: XNNPACK expects (oc, kernel_h, kernel_w, inc/groups), which is
        # an NCHW -> NHWC conversion of the PyTorch layout.
        weight_quant_params = QuantParams.from_weights(
            kernel_node, self._exported_program
        )
        # fp16 kernels set fp32_static_weights; presumably define_tensor then
        # stores the static weights as fp32 -- confirm in NodeVisitor.
        fp32_static_weights = kernel_node.meta["val"].dtype == torch.float16
        self.define_tensor(
            kernel_node,
            xnn_graph,
            vals_to_ids,
            convert_to_nhwc=True,
            swap_nc_for_depthwise_weights=is_depthwise_conv,
            quant_params=weight_quant_params,
            fp32_static_weights=fp32_static_weights,
        )
        kwargs["filter_id"] = vals_to_ids[kernel_node]

        # output
        output_min_max = FuseActivationPass.get_fused_activation(node)
        output_quant_params = QuantParams.from_outputs(node)
        self.define_tensor(
            node,
            xnn_graph,
            vals_to_ids,
            convert_to_nhwc=True,
            quant_params=output_quant_params,
        )  # NHWC output
        kwargs["output_id"] = vals_to_ids[node]

        # bias (optional; XNN_INVALID_VALUE_ID marks "no bias")
        kwargs["bias_id"] = XNN_INVALID_VALUE_ID
        if node.args[2] is not None:
            bias_node = get_input_node(node, 2)
            bias_quant_params = QuantParams.from_bias(
                bias_node, weight_quant_params, input_quant_params
            )
            self.define_tensor(
                bias_node,
                xnn_graph,
                vals_to_ids,
                convert_to_nhwc=False,
                quant_params=bias_quant_params,
                fp32_static_weights=fp32_static_weights,
            )
            kwargs["bias_id"] = vals_to_ids[bias_node]

        # Padding is symmetric per axis: index 0 -> height, index 1 -> width.
        kwargs["padding_top"] = padding[0]
        kwargs["padding_right"] = padding[1]
        kwargs["padding_bottom"] = padding[0]
        kwargs["padding_left"] = padding[1]
        kwargs["kernel_height"] = kernel_shape[2]
        kwargs["kernel_width"] = kernel_shape[3]
        kwargs["subsampling_height"] = stride[0]
        kwargs["subsampling_width"] = stride[1]
        kwargs["dilation_height"] = dilation[0]
        kwargs["dilation_width"] = dilation[1]
        kwargs["group_input_channels"] = group_input_channels
        kwargs["group_output_channels"] = group_output_channels
        kwargs["groups"] = groups
        kwargs["adjustment_height"] = 0
        kwargs["adjustment_width"] = 0
        kwargs["flags"] = 0

        conv_node_type = XNNDepthwiseConv2d if is_depthwise_conv else XNNConv2d

        ser_node = XNode(
            xnode_union=conv_node_type(
                **kwargs,
            ),
            debug_handle=debug_handle,
            output_min_max=output_min_max,
        )
        xnn_graph.xnodes.append(ser_node)
166