xref: /aosp_15_r20/external/executorch/backends/arm/operators/op_conv2d.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2023-2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7from typing import List
8
9import serializer.tosa_serializer as ts
10import torch
11from executorch.backends.arm.operators.node_visitor import (
12    NodeVisitor,
13    register_node_visitor,
14)
15from executorch.backends.arm.tosa_mapping import TosaArg
16from executorch.backends.arm.tosa_quant_utils import (
17    build_rescale_conv_output,
18    get_quant_arg_downstream,
19    get_quant_arg_upstream,
20)
21from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape
22
23from serializer.tosa_serializer import TosaOp
24
25
@register_node_visitor
class Conv2dVisitor(NodeVisitor):
    """Lower ``aten.convolution.default`` to TOSA CONV2D or DEPTHWISE_CONV2D.

    For quantized nodes the convolution writes an INT32 intermediate, which is
    then rescaled into the output tensor using the input, weight and output
    quantization parameters.
    """

    target = "aten.convolution.default"

    def __init__(self, *args):
        super().__init__(*args)

    def adjust_pad_if_needed(self, input, weight, stride, pad, dilation):
        """Shrink the trailing pad of one spatial axis to satisfy TOSA.

        torch.nn.Conv2d does not require
        ``(input + 2 * pad - dilation * (weight - 1) - 1) / stride`` to be an
        integer (the output size is floored), but TOSA currently requires the
        division to be exact. Subtracting the remainder from the trailing pad
        makes the division exact without changing the floored output size.

        Args:
            input: Input size along the axis.
            weight: Kernel size along the axis.
            stride: Stride along the axis.
            pad: Pad torch applies on each side of the axis (symmetric).
            dilation: Dilation along the axis.

        Returns:
            The trailing pad to use, reduced by the remainder if necessary.

        Raises:
            RuntimeError: If the remainder is larger than the pad, so the pad
                cannot absorb it. Per the error message, that case is meant to
                be handled earlier by the SizeAdjustConv2d pass.
        """
        mod_remainder = (input + 2 * pad - dilation * (weight - 1) - 1) % stride

        # Already exact: nothing to adjust.
        if mod_remainder == 0:
            return pad

        if mod_remainder > pad:
            raise RuntimeError(
                "This case should be handled by the SizeAdjustConv2d pass, is it enabled?"
            )
        return pad - mod_remainder

    @staticmethod
    def _bias_name(node: torch.fx.Node) -> str:
        """Return the tensor name for the zero bias synthesized for ``node``.

        Names containing ``"default"`` (e.g. ``aten_convolution_default_1``)
        keep the historical ``"bias" + <text after the first "default">``
        form. Any other name (which previously raised IndexError) falls back
        to a name derived from the full node name.
        """
        _, found, suffix = node.name.partition("default")
        if found:
            return "bias" + suffix
        return f"bias_{node.name}"

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Emit the TOSA convolution (plus a rescale when quantized) for ``node``.

        ``inputs`` is unpacked as (input, weight, bias, stride, pad, dilation,
        <unused>, <unused>, group).

        Raises:
            RuntimeError: For a grouped convolution that is neither regular
                (groups == 1) nor depthwise (groups == in_channels), which
                TOSA CONV2D cannot express. Also raised by
                ``adjust_pad_if_needed`` when the padding cannot be adjusted.
        """
        input_arg, weight, bias, stride, pad, dilation, _, _, group = inputs

        # Given input.shape is (N, Ci, H, W) and weight.shape is (Co, Ci/G, H, W).
        in_channels = input_arg.shape[1]
        out_channels = weight.shape[0]
        groups = group.number
        is_depthwise = in_channels == groups and out_channels % in_channels == 0

        # Validate before touching the graph, so nothing is half-emitted on
        # failure. A plain CONV2D with groups > 1 would see a weight channel
        # count (Ci/G) that does not match the input (Ci).
        if not is_depthwise and groups != 1:
            raise RuntimeError(
                f"Grouped convolution with groups={groups} and "
                f"in_channels={in_channels} is not supported; only regular "
                "(groups=1) and depthwise (groups=in_channels) convolutions "
                "can be lowered."
            )

        # Currently only int8 is supported in quantized types.
        actual_out_type = ts.DType.INT8 if is_quant_node else output.dtype

        # Duplicate each per-axis pad into a (leading, trailing) pair:
        # [h, h, w, w].
        pad_attr = [val for val in pad.special for _ in (0, 1)]
        stride_attr = stride.special
        dilation_attr = dilation.special

        # Adjust the trailing pad of H (pad_attr[1]) and W (pad_attr[3]) to
        # meet TOSA's strict output shape calculation.
        for axis in (0, 1):
            trailing = 2 * axis + 1
            pad_attr[trailing] = self.adjust_pad_if_needed(
                input_arg.shape[2 + axis],
                weight.shape[2 + axis],
                stride_attr[axis],
                pad_attr[trailing],
                dilation_attr[axis],
            )

        input_zp = (
            get_quant_arg_upstream(node.all_input_nodes[0]).zp if is_quant_node else 0
        )

        attr = ts.TosaSerializerAttribute()
        attr.ConvAttribute(
            pad=pad_attr,
            stride=stride_attr,
            dilation=dilation_attr,
            input_zp=input_zp,
            weight_zp=0,
            local_bound=False,
        )

        # Non-bias case: only input and weight are graph nodes. Create an
        # all-zero bias tensor so the operator always gets a bias input.
        if len(node.all_input_nodes) == 2:
            bias = tosa_graph.addConst(
                [out_channels],
                ts.DType.INT32 if is_quant_node else output.dtype,
                [0] * out_channels,
                name=self._bias_name(node),
            )

        # The output type is int32 when input type is int8: write into an
        # INT32 intermediate that is rescaled into `output` below.
        conv2d_res = None
        conv2d_output_name = output.name
        if is_quant_node:
            conv2d_res = tosa_graph.addIntermediate(
                tosa_shape(output.shape, output.dim_order), ts.DType.INT32
            )
            conv2d_output_name = conv2d_res.name

        if is_depthwise:
            # Depthwise convolution: reshape the torch weight tensor to the
            # TOSA-required [KH, KW, C, M] shape.
            # https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d
            m_length = out_channels // in_channels
            weight_post_shape = (
                weight.shape[2],
                weight.shape[3],
                in_channels,
                m_length,
            )

            weight_reshaped = tosa_graph.addIntermediate(
                weight_post_shape,
                ts.DType.INT8 if is_quant_node else weight.dtype,
            )
            build_reshape(
                tosa_graph, weight.name, weight_post_shape, weight_reshaped.name
            )
            tosa_op = TosaOp.Op().DEPTHWISE_CONV2D
            weight_name = weight_reshaped.name
        else:
            # Regular convolution (groups == 1).
            tosa_op = TosaOp.Op().CONV2D
            weight_name = weight.name

        tosa_graph.addOperator(
            tosa_op,
            [
                input_arg.name,
                weight_name,
                bias.name,
            ],
            [conv2d_output_name],
            attr,
        )

        # For quantized convolution, rescale the INT32 result back to the
        # integer value domain of the next op. Otherwise the operator above
        # already wrote directly to `output`.
        if is_quant_node:
            # Get scale_factor from input, weight, and output.
            input_scale = get_quant_arg_upstream(node.all_input_nodes[0]).scale
            weight_scale = get_quant_arg_upstream(node.all_input_nodes[1]).scale
            output_qargs = get_quant_arg_downstream(list(node.users)[0])

            build_rescale_conv_output(
                tosa_graph,
                conv2d_res,
                output.name,
                actual_out_type,
                input_scale,
                weight_scale,
                output_qargs.scale,
                output_qargs.zp,
            )
177