# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.utils.utils import (
    check_or_raise,
    get_param_tensor,
    is_param_node,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult


class PReLUReshapePass(XNNPACKPass):
    """
    Expand per-tensor PReLU weights into per-channel weights for XNNPACK.

    PyTorch lets the PReLU slope be either a single value shared by the whole
    tensor or one value per channel, while XNNPACK only accepts the
    per-channel form. For every ``_prelu_kernel`` call whose constant weight
    holds exactly one element, this pass repeats that element along
    dimension 1 so the weight carries one entry per input channel.
    """

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Rewrite single-element PReLU weights in ``graph_module``, then retrace.

        Args:
            graph_module: Graph to transform. A replacement weight is
                installed with ``setattr`` on the weight node's owning module.

        Returns:
            ``PassResult`` wrapping the retraced graph module. It is always
            reported as modified.

        Raises:
            AssertionError: If the PReLU weight tensor cannot be retrieved.
            Errors from ``check_or_raise`` in three cases: the weight is not
            a constant parameter, the weight is not 4D, or the weight has a
            single element and the input is not 4D.
        """
        for node in list(graph_module.graph.nodes):
            is_prelu = (
                node.op == "call_function"
                and node.target == exir_ops.edge.aten._prelu_kernel.default
            )
            if not is_prelu:
                continue

            weight_node = node.args[1]
            check_or_raise(
                is_param_node(self.exported_program, weight_node),
                "Only constant weight PReLU is supported by XNNPACK",
            )

            weight = get_param_tensor(self.exported_program, weight_node)
            if weight is None:
                raise AssertionError("Expected weight tensor to be not None")
            weight = weight.data.contiguous()

            check_or_raise(
                weight.dim() == 4,
                f"4D weight required for XNNPACK PReLU, got: {weight.dim()}D",
            )

            # A weight with more than one element is left untouched; only the
            # single-value (per-tensor) case needs expanding.
            if weight.numel() != 1:
                continue

            input_shape = node.args[0].meta["val"].shape
            check_or_raise(
                len(input_shape) == 4,
                f"4D input required for XNNPACK PReLU, got: {len(input_shape)}D",
            )

            # The weight is 4D with one element, so its shape is (1, 1, 1, 1).
            # Repeating along dim 1 by the input's dim-1 size gives (1, C, 1, 1).
            # NOTE(review): the new Parameter is attached to the owning module
            # under the placeholder's target. If get_param_tensor reads lifted
            # parameters from the ExportedProgram's state_dict instead, that
            # copy may not see this update. Confirm against get_param_tensor.
            setattr(
                weight_node.graph.owning_module,
                weight_node.target,
                torch.nn.Parameter(data=weight.repeat(1, input_shape[1], 1, 1)),
            )

        # Overriding `call` skips the base implementation, so invoke it
        # explicitly to retrace the graph and regenerate node metadata.
        retraced = super().call(graph_module).graph_module
        return PassResult(retraced, True)