xref: /aosp_15_r20/external/executorch/backends/arm/operators/op_transpose.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7
8from typing import List
9
10import serializer.tosa_serializer as ts
11import torch
12from executorch.backends.arm.operators.node_visitor import (
13    NodeVisitor,
14    register_node_visitor,
15)
16from executorch.backends.arm.tosa_mapping import TosaArg
17from serializer.tosa_serializer import TosaOp
18
19
@register_node_visitor
class TransposeVisitor(NodeVisitor):
    """
    This node visitor targets the _transpose op defined in the
    passthrough_to_tosa library. Used when switching between tosa_dim_orders.
    Inserts a TOSA TRANSPOSE.
    """

    target = "_transpose"

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Append a TOSA TRANSPOSE operator for a ``_transpose`` node.

        Args:
            node: The FX node being lowered (not read by this visitor).
            tosa_graph: Serializer the TRANSPOSE operator is appended to.
            inputs: ``inputs[0]`` is the tensor to transpose; ``inputs[1]``
                carries the requested dimension order in its ``special``
                field. Negative entries count from the last dimension.
            output: The output tensor; its rank bounds the permutation.
            is_quant_node: Unused; part of the NodeVisitor interface.

        Raises:
            ValueError: If a permutation entry lies outside
                ``[-rank, rank)``, or if the normalized entries are not a
                permutation of ``range(rank)`` (duplicate or missing axes).
        """
        output_rank = len(output.shape)

        perms = []
        for dim in inputs[1].special:
            # Only negative indices are wrapped. Applying `%` to everything
            # would silently alias an out-of-range positive index (e.g. 5 on
            # a rank-4 tensor) onto a valid axis and emit a wrong TRANSPOSE.
            if not -output_rank <= dim < output_rank:
                raise ValueError(
                    f"_transpose: permutation entry {dim} is out of range "
                    f"for rank {output_rank}"
                )
            perms.append(dim % output_rank)

        # TOSA TRANSPOSE requires every axis exactly once; catch duplicates or
        # a length mismatch here rather than serializing an invalid operator.
        if sorted(perms) != list(range(output_rank)):
            raise ValueError(
                f"_transpose: {list(inputs[1].special)} is not a permutation "
                f"of {output_rank} dimensions"
            )

        attr = ts.TosaSerializerAttribute()
        attr.TransposeAttribute(perms)
        tosa_graph.addOperator(
            TosaOp.Op().TRANSPOSE, [inputs[0].name], [output.name], attr
        )
45