# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import numpy as np

import serializer.tosa_serializer as ts
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import (
    get_quant_arg_downstream,
    quantize_value,
)
from executorch.backends.arm.tosa_utils import tosa_shape
from torch.fx import Node


@register_node_visitor
class FullVisitor(NodeVisitor):
    """Lower ``aten.full.default`` to a TOSA constant plus an IDENTITY op.

    The fill value is baked into a constant tensor at lowering time. An
    IDENTITY operator then copies that constant into the node's output tensor.
    """

    target = "aten.full.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Emit the constant tensor and IDENTITY op for one ``full`` node.

        Args:
            node: The ``aten.full.default`` FX node being lowered.
            tosa_graph: Serializer that receives the constant and the operator.
            inputs: Lowered arguments. ``inputs[0].special`` holds the requested
                size (presumably a sequence of ints; confirm against TosaArg)
                and ``inputs[1].number`` holds the fill value.
            output: Lowered output tensor. Its ``dim_order`` is used to reorder
                the shape, and its ``name`` is the IDENTITY destination.
            is_quant_node: If true, the fill value is quantized to INT8 using
                quantization parameters obtained from the node's first user.

        Raises:
            ValueError: If a quantized node has no users to read quantization
                parameters from, or if an unquantized node's output dtype is
                not FP32.
        """
        shape = tosa_shape(inputs[0].special, output.dim_order)
        value = inputs[1].number
        # Computed once so the constant's name and the IDENTITY input match.
        const_name = node.name + "full-const"

        if is_quant_node:
            users = list(node.users)
            if not users:
                # Check explicitly: indexing an empty user list would only
                # raise a bare IndexError with no hint about the cause.
                raise ValueError(
                    f"Quantized 'Full' node {node.name} has no users to read "
                    "quantization parameters from."
                )
            # Quantization params are taken from the consumer of this node
            # (presumably the quantize op that follows it; TODO confirm).
            qargs = get_quant_arg_downstream(users[0])
            qvalue = quantize_value(value, qargs)
            dtype = ts.DType.INT8
            data = np.full(shape, qvalue, dtype=np.int8)
        else:
            # Explicit check instead of `assert`. Asserts are stripped under
            # `python -O`, which would silently emit an FP32 constant for a
            # non-FP32 output tensor.
            if output.dtype != ts.DType.FP32:
                raise ValueError(
                    "'Full' currently only supports FP32 for unquantized models."
                )
            dtype = ts.DType.FP32
            data = np.full(shape, value, dtype=np.float32)

        tosa_graph.addConst(shape, dtype, data, const_name)
        tosa_graph.addOperator(ts.TosaOp.Op.IDENTITY, [const_name], [output.name])