# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import (
    build_rescale_conv_output,
    get_quant_arg_downstream,
    get_quant_arg_upstream,
)
from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape

from serializer.tosa_serializer import TosaOp


@register_node_visitor
class Conv2dVisitor(NodeVisitor):
    """Serialize ``aten.convolution.default`` as TOSA CONV2D or DEPTHWISE_CONV2D."""

    target = "aten.convolution.default"

    def __init__(self, *args):
        super().__init__(*args)

    # torch.nn.Conv2d does not require the result of
    # `(input + 2 * pad - dilation * (weight - 1) - 1) / stride`
    # to be an integer, but TOSA currently strictly requires this property.
    # This function adjusts the pad value to meet the requirement.
    def adjust_pad_if_needed(self, input, weight, stride, pad, dilation):
        """Return a pad value that makes the convolution size formula exact.

        All arguments describe a single spatial dimension.

        Args:
            input: Input size along the dimension.
            weight: Kernel size along the dimension.
            stride: Stride along the dimension.
            pad: Padding used in the size formula (counted on both edges).
            dilation: Dilation along the dimension.

        Returns:
            ``pad`` unchanged when the formula above is already divisible by
            ``stride``; otherwise ``pad`` reduced by the remainder. The caller
            applies the result to one edge only (bottom/right).

        Raises:
            RuntimeError: If the remainder exceeds ``pad``, i.e. trimming the
                padding alone cannot satisfy TOSA.
        """
        mod_remainder = (input + 2 * pad - dilation * (weight - 1) - 1) % stride

        # No need to adjust
        if mod_remainder == 0:
            return pad

        if mod_remainder > pad:
            raise RuntimeError(
                "This case should be handled by the SizeAdjustConv2d pass, is it enabled?"
            )
        return pad - mod_remainder

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Add the TOSA operators implementing ``node`` to ``tosa_graph``.

        Args:
            node: The convolution FX node being lowered.
            tosa_graph: Serializer the operators/tensors are appended to.
            inputs: Nine arguments, unpacked as input, weight, bias, stride,
                pad, dilation, <ignored>, <ignored>, group (presumably the aten
                schema's ``transposed`` and ``output_padding`` are the ignored
                pair).
            output: The TOSA output tensor of ``node``.
            is_quant_node: True when the node runs in the int8 quantized domain.

        NOTE(review): the ignored ``transposed`` flag is never checked here;
        presumably transposed convolutions are rejected earlier — confirm.
        """
        input_arg, weight, bias, stride, pad, dilation, _, _, group = inputs

        # Get the attributes of convolution. Each per-dimension pad value is
        # duplicated, giving [top, bottom, left, right].
        attr = ts.TosaSerializerAttribute()
        pad_attr = [val for val in pad.special for _ in (0, 1)]
        stride_attr = stride.special
        dilation_attr = dilation.special

        # Trim the bottom (H) and right (W) padding where needed so the
        # convolution output shape calculation is exact.
        pad_attr[1] = self.adjust_pad_if_needed(
            input_arg.shape[2],
            weight.shape[2],
            stride_attr[0],
            pad_attr[1],
            dilation_attr[0],
        )
        pad_attr[3] = self.adjust_pad_if_needed(
            input_arg.shape[3],
            weight.shape[3],
            stride_attr[1],
            pad_attr[3],
            dilation_attr[1],
        )

        input_zp = (
            get_quant_arg_upstream(node.all_input_nodes[0]).zp if is_quant_node else 0
        )

        attr.ConvAttribute(
            pad=pad_attr,
            stride=stride_attr,
            dilation=dilation_attr,
            input_zp=input_zp,
            weight_zp=0,
            local_bound=False,
        )

        # Non-bias case: the TOSA operator always takes a bias operand, so
        # create a zero bias tensor if none is present.
        if len(node.all_input_nodes) == 2:
            out_channels = weight.shape[0]
            # Reuse the node-name suffix after "default" (e.g. "_1") when
            # present; otherwise fall back to the full node name instead of
            # failing with an IndexError.
            _, found, suffix = node.name.partition("default")
            bias_name = "bias" + suffix if found else f"bias_{node.name}"
            bias = tosa_graph.addConst(
                [out_channels],
                ts.DType.INT32 if is_quant_node else output.dtype,
                [0] * out_channels,
                name=bias_name,
            )

        # The convolution output type is int32 when the input type is int8, so
        # the quantized path writes to an intermediate that is rescaled below.
        conv2d_res = None
        conv2d_output_name = output.name
        if is_quant_node:
            conv2d_res = tosa_graph.addIntermediate(
                tosa_shape(output.shape, output.dim_order), ts.DType.INT32
            )
            conv2d_output_name = conv2d_res.name

        # Given input.shape is (N, Ci, H, W), and weight.shape is (Co, Ci/G, H, W)
        in_channels = input_arg.shape[1]
        out_channels = weight.shape[0]
        if (in_channels == group.number) and (out_channels % in_channels) == 0:
            # Depthwise convolution case: reshape the torch weight shape into
            # the TOSA-required (KH, KW, C, M) format, see
            # https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d
            # NOTE(review): this is a pure reshape (no transpose), so it
            # assumes the weight data is already laid out in that order
            # (e.g. via its dim_order) — confirm.
            m_length = out_channels // in_channels
            weight_post_shape = (
                weight.shape[2],
                weight.shape[3],
                in_channels,
                m_length,
            )

            weight_reshaped = tosa_graph.addIntermediate(
                weight_post_shape,
                ts.DType.INT8 if is_quant_node else weight.dtype,
            )
            build_reshape(
                tosa_graph, weight.name, weight_post_shape, weight_reshaped.name
            )
            tosa_op = TosaOp.Op().DEPTHWISE_CONV2D
            weight_name = weight_reshaped.name
        else:
            # Regular convolution case.
            tosa_op = TosaOp.Op().CONV2D
            weight_name = weight.name

        tosa_graph.addOperator(
            tosa_op,
            [
                input_arg.name,
                weight_name,
                bias.name,
            ],
            [conv2d_output_name],
            attr,
        )

        # For quantized convolution, rescale the int32 result back to the same
        # integer value domain as the next op. Otherwise the float output was
        # written directly above. conv2d_res is set iff is_quant_node.
        if conv2d_res is not None:
            # Get scale_factor from input, weight, and output.
            input_scale = get_quant_arg_upstream(node.all_input_nodes[0]).scale
            weight_scale = get_quant_arg_upstream(node.all_input_nodes[1]).scale
            output_qargs = get_quant_arg_downstream(list(node.users)[0])

            # Currently only int8 is supported in quantized types.
            build_rescale_conv_output(
                tosa_graph,
                conv2d_res,
                output.name,
                ts.DType.INT8,
                input_scale,
                weight_scale,
                output_qargs.scale,
                output_qargs.zp,
            )