# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import numpy as np
from executorch.backends.arm._passes.arm_pass_utils import (
    create_node,
    get_first_fake_tensor,
)
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class DecomposeLinearPass(ExportPass):
    """
    This pass decomposes linear into a Conv2D with the required view operations.
    linear(x, weights, bias) becomes:
        x_reshaped = view(x)
        weights_reshaped = view(weights)
        conv2d = conv2d(x_reshaped, weights_reshaped, bias)
        output = view(conv2d)
    It also inserts q/dq pairs if the linear node was quantized.
    """

    def call(self, graph_module):
        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue
            if node.target != exir_ops.edge.aten.linear.default:
                continue
            args = node.args
            input = args[0]
            weights = args[1]
            bias = args[2] if len(args) > 2 else None
            output_shape = get_first_fake_tensor(node).shape
            input_shape = get_first_fake_tensor(input).shape
            weights_shape = get_first_fake_tensor(weights).shape
            batches = int(np.prod(input_shape[:-1])) if len(input_shape) > 1 else 1
            # input has shape (..., Ci)
            input_reshaped_shape = [batches, input_shape[-1], 1, 1]
            # weights have shape (Co, Ci)
            weights_reshaped_shape = [weights_shape[0], weights_shape[1], 1, 1]

            with graph_module.graph.inserting_before(node):
                quantize = input.op == "call_function" and input.target == dq_op
                q_params = input.args[1:] if quantize else None
                # Reshape input to 4D with shape (N, Ci, 1, 1)
                input_reshaped = create_node(
                    graph=graph_module.graph,
                    op_target=exir_ops.edge.aten.view_copy.default,
                    args=(input, input_reshaped_shape),
                    kwargs={},
                    quantize=quantize,
                    q_params=q_params,
                )

                quantize = weights.op == "call_function" and weights.target == dq_op
                q_params = weights.args[1:] if quantize else None
                # Reshape weights to 4D with shape (Co, Ci, 1, 1)
                weights_reshaped = create_node(
                    graph=graph_module.graph,
                    op_target=exir_ops.edge.aten.view_copy.default,
                    args=(weights, weights_reshaped_shape),
                    kwargs={},
                    quantize=quantize,
                    q_params=q_params,
                )

                consumer_node = list(node.users)[0]
                quantize = (
                    consumer_node.op == "call_function" and consumer_node.target == q_op
                )
                q_params = consumer_node.args[1:] if quantize else None
                conv = create_node(
                    graph=graph_module.graph,
                    op_target=exir_ops.edge.aten.convolution.default,
                    args=(
                        input_reshaped,
                        weights_reshaped,
                        bias,
                        [1, 1],  # strides
                        [0, 0],  # padding
                        [1, 1],  # dilation
                        False,  # transposed
                        [0, 0],  # output padding
                        1,  # groups
                    ),
                    kwargs={},
                    quantize=quantize,
                    q_params=q_params,
                )

            with graph_module.graph.inserting_after(conv):
                # Reshape output to same rank as original input with shape (..., Co)
                # No need to insert q/dq pair as Conv2D node above has inserted them if
                # required.
                output = create_node(
                    graph=graph_module.graph,
                    op_target=exir_ops.edge.aten.view_copy.default,
                    args=(conv, list(output_shape)),
                    kwargs={},
                )

            node.replace_all_uses_with(output)
            graph_module.graph.erase_node(node)
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)
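

# -----------------------------------------------------------------------------
# Minimal numerical sanity check of the decomposition described in the class
# docstring above. This block is illustrative only and not part of the pass;
# the tensor names and shapes below are arbitrary example values.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    x = torch.randn(2, 3, 8)  # input of shape (..., Ci) with Ci = 8
    weights = torch.randn(16, 8)  # (Co, Ci) with Co = 16
    bias = torch.randn(16)  # (Co,)

    # Same shape arithmetic as the pass: flatten leading dims into N.
    batches = int(np.prod(x.shape[:-1]))  # 2 * 3 = 6
    x_4d = x.reshape(batches, 8, 1, 1)  # (N, Ci, 1, 1)
    w_4d = weights.reshape(16, 8, 1, 1)  # (Co, Ci, 1, 1)

    # 1x1 convolution on the reshaped tensors, then view back to (..., Co).
    conv_out = torch.nn.functional.conv2d(x_4d, w_4d, bias)  # (N, Co, 1, 1)
    decomposed = conv_out.reshape(2, 3, 16)

    reference = torch.nn.functional.linear(x, weights, bias)
    assert torch.allclose(decomposed, reference, atol=1e-5)
    print("decomposed linear matches aten.linear")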