xref: /aosp_15_r20/external/executorch/backends/qualcomm/builders/op_cat.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6import warnings
7from typing import cast, Dict, List
8
9import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
10
11import numpy as np
12import torch
13from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
14
15from .node_visitor import NodeVisitor, register_node_visitor
16from .qnn_constants import OpConcat, QNN_OP_PACKAGE_NAME_QTI_AISW
17
18
@register_node_visitor
class Cat(NodeVisitor):
    """Visitor that lowers ``aten.cat.default`` to the QNN ``Concat`` op."""

    target = ["aten.cat.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build a QNN Concat op wrapper for ``node``.

        Args:
            node: FX node for ``aten.cat``; ``args[0]`` is the list of input
                tensor nodes and ``args[1]`` (optional, default 0) is the
                concatenation axis.
            nodes_to_wrappers: cache mapping FX nodes to already-defined
                QNN tensor wrappers.

        Returns:
            The populated ``PyQnnOpWrapper``, or ``None`` if a tensor
            wrapper could not be created for every input.
        """
        list_of_tensors = cast(List[torch.fx.Node], node.args[0])
        list_of_tensor_wrappers = []

        for tensor_input in list_of_tensors:
            input_tensor = self.get_tensor(tensor_input, node)
            list_of_tensor_wrappers.append(
                self.define_tensor(
                    tensor_input,
                    input_tensor,
                    PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
                    nodes_to_wrappers,
                    is_input_tensor=True,
                )
            )

        # Bail out rather than emit a Concat op with missing inputs.
        # (Fixed typo in the original message: "number or" -> "number of".)
        if len(list_of_tensors) != len(list_of_tensor_wrappers):
            warnings.warn(
                "[QNN Delegate Op Builder]: The number of input tensors is not equal to the number of input tensor wrappers.",
                stacklevel=1,
            )
            return

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        # The axis argument is optional in the aten schema; default to 0.
        # Use `> 1` rather than `== 2` so any extra trailing args (should the
        # schema ever grow) do not silently drop the explicit axis.
        axis = 0
        if len(node.args) > 1:
            axis = cast(int, node.args[1])

        # Normalize a negative axis to its non-negative equivalent.
        if axis < 0:
            axis += node.meta["val"].dim()

        # Remap the axis if an axis-order (memory-format) permutation was
        # recorded for this node during graph preparation.
        if QCOM_AXIS_ORDER in node.meta:
            axis = node.meta[QCOM_AXIS_ORDER].index(axis)

        concat_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpConcat.op_name,
        )
        concat_op.AddInputTensors(list_of_tensor_wrappers)
        concat_op.AddOutputTensors([output_tensor_wrapper])

        # QNN expects the concat axis as an unsigned 32-bit scalar parameter.
        concat_op.AddScalarParam(
            OpConcat.param_axis,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: np.uint32(axis)},
        )

        return concat_op
88