xref: /aosp_15_r20/external/executorch/backends/qualcomm/builders/op_dequantize.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6from typing import Dict
7
8import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
9
10import torch
11
12from .node_visitor import NodeVisitor, register_node_visitor
13from .qnn_constants import OpDequantize, QNN_OP_PACKAGE_NAME_QTI_AISW
14
15
class DequantizeOpBase(NodeVisitor):
    """Common visitor that turns a dequantize node into an ``OpDequantize`` op.

    Concrete subclasses in this file only declare the ``target`` names they
    handle; the op construction itself lives in ``define_node`` here.
    """

    def __init__(self, *args) -> None:
        # Forward all positional arguments unchanged to NodeVisitor.
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the op wrapper for ``node`` with one input and one output.

        Args:
            node: The dequantize node; its first positional argument
                (``node.args[0]``) is used as the op's only input.
            nodes_to_wrappers: Mapping handed through to ``define_tensor``
                for both the input and the output tensor.

        Returns:
            A ``PyQnnOpWrapper`` named after ``node.target.__name__`` with the
            input and output tensor wrappers attached.
        """
        # Both tensors are declared with the same native tensor type.
        native_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE

        # Input side: the tensor produced by the node feeding this op.
        source_node = node.args[0]
        source_wrapper = self.define_tensor(
            source_node,
            self.get_tensor(source_node, node),
            native_type,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        # Output side: the tensor produced by this node itself.
        result_wrapper = self.define_tensor(
            node,
            self.get_tensor(node, node),
            native_type,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        op_wrapper = PyQnnWrapper.PyQnnOpWrapper(
            node.target.__name__,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpDequantize.op_name,
        )
        op_wrapper.AddInputTensors([source_wrapper])
        op_wrapper.AddOutputTensors([result_wrapper])
        return op_wrapper
56
57
@register_node_visitor
class PerTensorDequantize(DequantizeOpBase):
    """Dequantize visitor for the per-tensor targets listed in ``target``."""

    # Both the ``.default`` and ``.tensor`` overload names are handled.
    target = [
        "quantized_decomposed.dequantize_per_tensor.default",
        "quantized_decomposed.dequantize_per_tensor.tensor",
    ]
64
65
@register_node_visitor
class PerChannelDequantize(DequantizeOpBase):
    """Dequantize visitor for the per-channel targets listed in ``target``."""

    # Both the ``.default`` and ``.tensor`` overload names are handled.
    target = [
        "quantized_decomposed.dequantize_per_channel.default",
        "quantized_decomposed.dequantize_per_channel.tensor",
    ]
72