xref: /aosp_15_r20/external/executorch/backends/qualcomm/builders/op_to.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpCast, OpConvert, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class To(NodeVisitor):
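    """Lowers aten._to_copy.default to a QNN op.

    Emits OpCast when the copy amounts to a plain data-type cast, and
    OpConvert when the quantization encodings differ enough to require
    requantization. The distinction is made by is_cast_node below.
    """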
    target = ["aten._to_copy.default"]
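    # Offset gap between the signed and unsigned fixed-point encodings of the
    # same bit width (2**7 for 8-bit, 2**15 for 16-bit).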
    sufixed_8_offset_diff = 128
    sufixed_16_offset_diff = 32768
    # Tolerance when comparing floating-point quantization scales.
    epsilon = 1e-6
    sufixed_8 = {
        PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_8,
        PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8,
    }
    sufixed_16 = {
        PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_16,
        PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
    }

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def is_cast_node(self, node):
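        """Returns True when this _to_copy can be lowered to OpCast.

        A cast suffices when at most one side carries quantization
        attributes, or when both sides share the same scale (within
        epsilon) and their offsets differ by exactly the signed/unsigned
        gap for that bit width; everything else needs OpConvert.
        """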
        input_node = node.args[0]

        # Not a case with two quantized nodes; no need to consider the convert op.
        if not all(
            [
                input_node.meta.get(QCOM_QUANT_ATTRS),
                node.meta.get(QCOM_QUANT_ATTRS),
            ]
        ):
            return True

        input_tensor = self.get_tensor(input_node, node)
        _, inp_qconfs = self.get_quant_encoding_conf(input_node, False)
        inp_dtype = self.get_data_type(input_tensor, inp_qconfs)

        output_tensor = self.get_tensor(node, node)
        _, out_qconfs = self.get_quant_encoding_conf(node, False)
        out_dtype = self.get_data_type(output_tensor, out_qconfs)
        # Castable iff scales match (within epsilon) and offsets differ by
        # exactly the expected signed/unsigned gap.
        is_qparam_castable = (
            lambda o1, o2, s1, s2, diff: abs(s1 - s2) < self.epsilon
            and abs(o1 - o2) == diff
        )

        if {inp_dtype, out_dtype} == self.sufixed_8:
            return is_qparam_castable(
                inp_qconfs["offset"],
                out_qconfs["offset"],
                inp_qconfs["scale"],
                out_qconfs["scale"],
                self.sufixed_8_offset_diff,
            )
        elif {inp_dtype, out_dtype} == self.sufixed_16:
            return is_qparam_castable(
                inp_qconfs["offset"],
                out_qconfs["offset"],
                inp_qconfs["scale"],
                out_qconfs["scale"],
                self.sufixed_16_offset_diff,
            )
        return False

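    # Minimal numeric sketch of the check above (hypothetical values): an
    # SFIXED_POINT_8 input with scale=0.1, offset=-128 copied to a
    # UFIXED_POINT_8 output with scale=0.1, offset=0 gives matching scales
    # and abs(-128 - 0) == 128, so OpCast is emitted; any other combination
    # of scales/offsets falls through to OpConvert.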
    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
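        """Defines the QNN tensors for this node and returns the op wrapper."""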
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)

        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        output_tensor = self.get_tensor(node, node)

        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        # Pick Cast when the quant params line up; otherwise fall back to Convert.
        qnn_op = OpCast if self.is_cast_node(node) else OpConvert
        op = PyQnnWrapper.PyQnnOpWrapper(
            node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, qnn_op.op_name
        )
        op.AddInputTensors([input_tensor_wrapper])
        op.AddOutputTensors([output_tensor_wrapper])

        return op