# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Dict, List

import torch
from executorch.backends.transforms import get_shape
from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.operators.quant_params import QuantParams
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNConv2d,
    XNNDepthwiseConv2d,
    XNNGraph,
    XNode,
)
from executorch.backends.xnnpack.utils.utils import check_or_raise, get_input_node

from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_INVALID_VALUE_ID


@register_node_visitor
class Conv2d(NodeVisitor):
    """Node visitor for ``aten.convolution.default``.

    Serializes the convolution as either an ``XNNConv2d`` or an
    ``XNNDepthwiseConv2d`` node and appends it to the XNNPACK graph.
    """

    target = "aten.convolution.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Serialize one convolution node into ``xnn_graph``.

        The convolution arguments are read positionally from ``node.args``:
        0 = input, 1 = weight, 2 = bias (may be ``None``), 3 = stride,
        4 = padding, 5 = dilation, 6 = transposed, 7 = output padding,
        8 = groups.

        Args:
            node: The convolution node to serialize.
            xnn_graph: Graph being built; the new ``XNode`` is appended to
                ``xnn_graph.xnodes``.
            vals_to_ids: Mapping from fx nodes to value ids. It is passed to
                ``define_tensor`` and then indexed for the input, filter,
                output and (optional) bias ids.
            debug_handle: Stored on the serialized ``XNode``.

        Raises:
            Whatever ``check_or_raise`` raises when the convolution is
            transposed, has non-zero output padding, or does not have a
            two-element stride.
        """
        # Reject unsupported convolutions up front, before define_tensor is
        # called with xnn_graph / vals_to_ids for any operand.
        # args[6] = transposed
        check_or_raise(
            not cast(bool, node.args[6]), "No support for transposed convolution"
        )
        # args[7] = output padding
        check_or_raise(
            all(out_pad == 0 for out_pad in cast(List[int], node.args[7])),
            "XNNPACK does not support output padding",
        )
        stride = cast(List[int], node.args[3])
        check_or_raise(
            len(stride) == 2, "XNNPACK currently only supports 2D convolution"
        )

        padding = cast(List[int], node.args[4])
        dilation = cast(List[int], node.args[5])
        if len(padding) == 1:
            # A single padding value applies to both spatial dimensions.
            padding = padding + padding

        # input
        input_node = get_input_node(node, 0)
        input_quant_params = QuantParams.from_inputs(input_node, self._exported_program)
        self.define_tensor(
            input_node,
            xnn_graph,
            vals_to_ids,
            convert_to_nhwc=True,
            quant_params=input_quant_params,
        )  # NHWC input
        input_id = vals_to_ids[input_node]

        # filter shape for pytorch convolution is (oc, inc/groups, height, width)
        # shape for xnnpack convolution is (oc, height, width, inc/groups), to convert
        # to the proper shape, this is essentially a NCHW to NHWC conversion
        kernel_node = get_input_node(node, 1)
        kernel_shape = get_shape(kernel_node)
        groups = cast(int, node.args[8])
        group_input_channels = kernel_shape[1]
        # Floor division keeps the channel count in integer arithmetic instead
        # of round-tripping through a float.
        group_output_channels = kernel_shape[0] // groups

        # XNNPACK expects the kernel's N and C dimensions to be swapped for
        # Depthwise Convolution, which occurs under the following conditions:
        # 1) groups = input_channels (i.e. group_input_channels = 1)
        # 2) output_channels is a positive integer multiple of input channels
        # NOTE: as written, the second clause is always true once the first
        # one holds (anything % 1 == 0); it is kept to mirror the conditions
        # listed above.
        is_depthwise_conv = (group_input_channels == 1) and (
            group_output_channels % group_input_channels == 0
        )
        weight_quant_params = QuantParams.from_weights(
            kernel_node, self._exported_program
        )
        # Forwarded to define_tensor as fp32_static_weights when the kernel's
        # dtype is float16.
        fp32_static_weights = kernel_node.meta["val"].dtype == torch.float16

        self.define_tensor(
            kernel_node,
            xnn_graph,
            vals_to_ids,
            convert_to_nhwc=True,
            swap_nc_for_depthwise_weights=is_depthwise_conv,
            quant_params=weight_quant_params,
            fp32_static_weights=fp32_static_weights,
        )
        filter_id = vals_to_ids[kernel_node]

        # output
        output_min_max = FuseActivationPass.get_fused_activation(node)
        output_quant_params = QuantParams.from_outputs(node)
        self.define_tensor(
            node,
            xnn_graph,
            vals_to_ids,
            convert_to_nhwc=True,
            quant_params=output_quant_params,
        )  # NHWC output
        output_id = vals_to_ids[node]

        # bias
        bias_id = XNN_INVALID_VALUE_ID
        if node.args[2] is not None:
            # If there is a bias
            bias_node = get_input_node(node, 2)
            bias_quant_params = QuantParams.from_bias(
                bias_node, weight_quant_params, input_quant_params
            )
            self.define_tensor(
                bias_node,
                xnn_graph,
                vals_to_ids,
                convert_to_nhwc=False,
                quant_params=bias_quant_params,
                fp32_static_weights=fp32_static_weights,
            )
            bias_id = vals_to_ids[bias_node]

        kwargs = {
            "input1_id": input_id,
            "filter_id": filter_id,
            "output_id": output_id,
            "bias_id": bias_id,
            # padding[0] is used for top/bottom and padding[1] for left/right.
            "padding_top": padding[0],
            "padding_right": padding[1],
            "padding_bottom": padding[0],
            "padding_left": padding[1],
            "kernel_height": kernel_shape[2],
            "kernel_width": kernel_shape[3],
            "subsampling_height": stride[0],
            "subsampling_width": stride[1],
            "dilation_height": dilation[0],
            "dilation_width": dilation[1],
            "group_input_channels": group_input_channels,
            "group_output_channels": group_output_channels,
            "groups": groups,
            "adjustment_height": 0,
            "adjustment_width": 0,
            "flags": 0,
        }

        if is_depthwise_conv:
            conv_node_type = XNNDepthwiseConv2d
        else:
            conv_node_type = XNNConv2d

        ser_node = XNode(
            xnode_union=conv_node_type(
                **kwargs,
            ),
            debug_handle=debug_handle,
            output_min_max=output_min_max,
        )
        xnn_graph.xnodes.append(ser_node)