xref: /aosp_15_r20/external/executorch/backends/qualcomm/builders/op_prelu.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6from typing import Dict
7
8import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
9
10import torch
11from executorch.backends.qualcomm.utils.constants import (
12    QCOM_AXIS_ORDER,
13    QCOM_QUANT_ATTRS,
14    QCOM_QUANT_MAX,
15    QCOM_QUANT_MIN,
16    QCOM_SCALE,
17    QCOM_ZERO_POINT,
18)
19from executorch.exir.dialects._ops import ops as exir_ops
20
21from .node_visitor import get_parameter, NodeVisitor, register_node_visitor
22from .qnn_constants import OpPRelu, QNN_OP_PACKAGE_NAME_QTI_AISW
23
24
@register_node_visitor
class PReLU(NodeVisitor):
    """Lower ``aten.leaky_relu.default`` / ``aten.prelu.default`` to QNN PRelu.

    QNN PRelu takes the negative-slope coefficients as a second, static input
    tensor. This visitor materializes that tensor with the same shape as the
    activation input and wires both into a single ``OpPRelu`` node.
    """

    target = ["aten.leaky_relu.default", "aten.prelu.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    @staticmethod
    def _full_coeff(shape, coeff) -> torch.Tensor:
        """Return a float32 tensor of ``shape`` filled with the scalar ``coeff``."""
        return torch.full(shape, coeff).to(torch.float32)

    def _leaky_relu_coeff(self, node: torch.fx.Node, input_tensor: torch.Tensor):
        """Return ``(coeff, coeff_tensor)`` for an ``aten.leaky_relu`` node.

        ``coeff`` is the negative slope taken from ``node.args[1]`` (1e-2 when
        the argument is omitted); ``coeff_tensor`` fills ``input_tensor.shape``
        with it.
        """
        coeff = 1e-2 if len(node.args) < 2 else node.args[1]
        return coeff, self._full_coeff(input_tensor.shape, coeff)

    def _prelu_coeff(self, node: torch.fx.Node, input_tensor: torch.Tensor):
        """Return ``(coeff, coeff_tensor)`` for an ``aten.prelu`` node.

        ``coeff`` is the scalar used to derive the quantization scale: the
        single weight, or the maximum over per-channel weights.
        ``coeff_tensor`` holds the applicable weight for every input element.
        """
        input_node = node.args[0]
        coeff_node = node.args[1]
        coeff = get_parameter(coeff_node, self.edge_program)
        # Parameter nodes are FakeTensors during partitioning; substitute
        # concrete values so the op can still be built for validation.
        if isinstance(coeff, torch._subclasses.fake_tensor.FakeTensor):
            coeff = torch.ones(coeff.shape)

        if coeff_node.meta["val"].shape[0] <= 1:
            # Single shared weight: the same value for every element.
            coeff = coeff.item()
            return coeff, self._full_coeff(input_tensor.shape, coeff)

        # Per-channel weights: weight c applies along dim 1 of the input as
        # described by its FX meta shape (before any layout permutation).
        input_shape = input_node.meta["val"].shape
        num_channels = input_shape[1]
        view_shape = [1] * len(input_shape)
        view_shape[1] = num_channels
        coeff_tensor = torch.zeros(input_shape)
        # One broadcast copy instead of one full-tensor index_fill per channel;
        # copy_ casts into coeff_tensor's dtype, as index_fill did.
        coeff_tensor.copy_(coeff[:num_channels].reshape(view_shape))
        if QCOM_AXIS_ORDER in input_node.meta:
            # Bring the coefficients into the same axis order as the input.
            axis_order = input_node.meta[QCOM_AXIS_ORDER]
            coeff_tensor = coeff_tensor.permute(dims=axis_order).contiguous()
        # Simple min-max quantization: the scale is derived from the largest weight.
        return torch.max(coeff).item(), coeff_tensor

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the QNN PRelu op wrapper for ``node``.

        Args:
            node: An ``aten.leaky_relu.default`` or ``aten.prelu.default`` node.
            nodes_to_wrappers: Tensor-wrapper cache passed through to
                ``define_tensor``.

        Returns:
            A PRelu op wrapper whose inputs are the activation and the static
            coefficient tensor, and whose output is ``node``'s tensor.
        """
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        prelu_inp_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        if node.target.__name__ == "aten.leaky_relu.default":
            coeff, coeff_tensor = self._leaky_relu_coeff(node, input_tensor)
        else:
            coeff, coeff_tensor = self._prelu_coeff(node, input_tensor)

        # The coefficient tensor has no FX node of its own, so a standalone
        # node (not inserted into node.graph) is created to carry its name and
        # quant attrs. torch.fx.Node args: graph, name, op, target, args, kwargs.
        scalar_node = torch.fx.Node(
            node.graph,
            node.name + "_runtime_scalar",
            "call_function",
            exir_ops.edge.aten.full.default,
            (),  # args
            {},  # kwargs
        )
        if node_quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
            quant_attrs = node_quant_attrs.copy()
            quant_range = quant_attrs[QCOM_QUANT_MAX] - quant_attrs[QCOM_QUANT_MIN]
            # Choose the scale so that coeff spans the full quantized range
            # with zero_point 0.
            # NOTE(review): this assumes coeff > 0. A zero or negative slope
            # (or all-negative per-channel weights) yields a zero/negative
            # scale; confirm callers guarantee positive coefficients.
            quant_attrs[QCOM_ZERO_POINT] = 0
            quant_attrs[QCOM_SCALE] = coeff / quant_range
            scalar_node.meta[QCOM_QUANT_ATTRS] = quant_attrs

        scalar_tensor_wrapper = self.define_tensor(
            scalar_node,
            coeff_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
            nodes_to_wrappers,
            is_input_tensor=True,
        )
        prelu_input_tensors = [prelu_inp_tensor_wrapper, scalar_tensor_wrapper]

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )
        prelu_output_tensors = [output_tensor_wrapper]

        prelu_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpPRelu.op_name,
        )
        prelu_op.AddInputTensors(prelu_input_tensors)
        prelu_op.AddOutputTensors(prelu_output_tensors)

        return prelu_op
118