# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts
import torch
import tosa.Op as TosaOp

from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_utils import tosa_shape


@register_node_visitor
class ViewVisitor(NodeVisitor):
    """Lower ``aten.view_copy.default`` nodes to a single TOSA RESHAPE op."""

    target = "aten.view_copy.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Append a RESHAPE from ``inputs[0]`` to ``output`` onto ``tosa_graph``.

        Args:
            node: The FX node being lowered (not read here).
            tosa_graph: Serializer the RESHAPE operator is added to.
            inputs: ``inputs[0]`` is the tensor being reshaped;
                ``inputs[1].special`` supplies the requested shape
                (presumably the ``size`` argument of ``view_copy`` --
                confirm against ``TosaArg``).
            output: Destination tensor; its ``name`` and ``dim_order``
                are read.
            is_quant_node: Not referenced here.
        """
        reshape_attr = ts.TosaSerializerAttribute()
        # NOTE(review): tosa_shape presumably reorders the requested shape to
        # match output.dim_order -- confirm in tosa_utils.
        reshape_attr.ReshapeAttribute(
            tosa_shape(inputs[1].special, output.dim_order)
        )
        source_names = [inputs[0].name]
        result_names = [output.name]
        tosa_graph.addOperator(
            TosaOp.Op().RESHAPE, source_names, result_names, reshape_attr
        )