# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.utils.utils import (
    check_or_raise,
    get_param_tensor,
    is_param_node,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult


class PReLUReshapePass(XNNPACKPass):
    """
    This pass modifies the args of a PReLU node to make it compatible with
    running via the XNNPACK delegate. If the weight tensor contains only a
    single parameter, it is repeated so that the tensor's length equals
    num_channels. This is needed because PyTorch supports either per-tensor
    or per-channel weight parameters for PReLU, whereas XNNPACK supports only
    per-channel weights.
    """

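    # Illustration (a sketch, not executed by this pass): with a 4D NCHW input
    # of C = 3 channels, a per-tensor weight of shape (1, 1, 1, 1) is repeated
    # into a per-channel weight of shape (1, 3, 1, 1):
    #
    #     w = torch.full((1, 1, 1, 1), 0.25)
    #     w_per_channel = w.repeat(1, 3, 1, 1)  # shape: (1, 3, 1, 1)
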
    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
        node_list = list(graph.nodes)
        for node in node_list:
            if (
                node.op == "call_function"
                and node.target == exir_ops.edge.aten._prelu_kernel.default
            ):
                # _prelu_kernel takes (input, weight) as its args
                weight_node = node.args[1]

                check_or_raise(
                    is_param_node(self.exported_program, weight_node),
                    "Only constant weight PReLU is supported by XNNPACK",
                )

                weight_data = get_param_tensor(self.exported_program, weight_node)
                if weight_data is None:
                    raise AssertionError("Expected PReLU weight tensor not to be None")

                weight_data = weight_data.data.contiguous()

                check_or_raise(
                    weight_data.dim() == 4,
                    f"4D weight required for XNNPACK PReLU, got: {weight_data.dim()}D",
                )

                # A single-element weight is per-tensor; broadcast it to
                # per-channel so XNNPACK can consume it
                if weight_data.numel() == 1:
                    input_shape = node.args[0].meta["val"].shape

                    check_or_raise(
                        len(input_shape) == 4,
                        f"4D input required for XNNPACK PReLU, got: {len(input_shape)}D",
                    )

                    # NCHW layout: the channel dimension is dim 1
                    num_channels = input_shape[1]

                    weight_data_per_channel = weight_data.repeat(1, num_channels, 1, 1)

                    # Overwrite the stored parameter with the repeated
                    # per-channel weight tensor
                    setattr(
                        weight_node.graph.owning_module,
                        weight_node.target,
                        torch.nn.Parameter(data=weight_data_per_channel),
                    )

        # Since we are overriding "call", we need to call the parent's "call"
        # to retrace the graph and regenerate metadata
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
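
# A minimal invocation sketch (hypothetical; assumes XNNPACKPass takes the
# ExportedProgram in its constructor, and that the PReLU weight has already
# been made 4D by earlier lowering stages, as the dim() == 4 check requires):
#
#     prelu_pass = PReLUReshapePass(exported_program)
#     result = prelu_pass.call(exported_program.graph_module)
#     graph_module = result.graph_module  # PassResult from exir.pass_base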