xref: /aosp_15_r20/external/executorch/backends/arm/operators/op_full.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7from typing import List
8
9import numpy as np
10
11import serializer.tosa_serializer as ts
12from executorch.backends.arm.operators.node_visitor import (
13    NodeVisitor,
14    register_node_visitor,
15)
16from executorch.backends.arm.tosa_mapping import TosaArg
17from executorch.backends.arm.tosa_quant_utils import (
18    get_quant_arg_downstream,
19    quantize_value,
20)
21from executorch.backends.arm.tosa_utils import tosa_shape
22from torch.fx import Node
23
24
@register_node_visitor
class FullVisitor(NodeVisitor):
    """Node visitor that lowers ``aten.full.default`` into the TOSA graph.

    The filled tensor is materialised at lowering time as a NumPy array,
    registered as a graph constant, and forwarded to the node's output
    through an IDENTITY operator.
    """

    target = "aten.full.default"

    def __init__(self, *args):
        # No state of its own; all arguments are forwarded to NodeVisitor.
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Emit a constant plus an IDENTITY operator for a ``full`` node.

        Args:
            node: The FX node being lowered. Its name is used to build the
                constant's name, and its first user supplies the
                quantization arguments on the quantized path.
            tosa_graph: Serializer that receives the constant and operator.
            inputs: ``inputs[0].special`` is passed to ``tosa_shape`` together
                with ``output.dim_order`` to obtain the constant's shape;
                ``inputs[1].number`` is the fill value.
            output: Output argument; provides ``dim_order``, ``dtype`` and the
                ``name`` the IDENTITY operator writes to.
            is_quant_node: When true the fill value is quantized and stored
                as INT8; otherwise it is stored as FP32.

        Raises:
            AssertionError: If ``is_quant_node`` is false and ``output.dtype``
                is not ``ts.DType.FP32``.
        """
        fill_shape = tosa_shape(inputs[0].special, output.dim_order)
        fill_value = inputs[1].number

        if not is_quant_node:
            assert (
                output.dtype == ts.DType.FP32
            ), "'Full' currently only supports FP32 for unquantized models."
            tosa_dtype, numpy_dtype = ts.DType.FP32, np.float32
        else:
            # Quantization parameters are looked up starting from the first
            # user of this node.
            first_user = list(node.users)[0]
            quant_args = get_quant_arg_downstream(first_user)
            fill_value = quantize_value(fill_value, quant_args)
            tosa_dtype, numpy_dtype = ts.DType.INT8, np.int8

        const_data = np.full(fill_shape, fill_value, dtype=numpy_dtype)

        # The constant gets its own name; IDENTITY then ties it to the
        # output tensor's name.
        const_name = node.name + "full-const"
        tosa_graph.addConst(fill_shape, tosa_dtype, const_data, const_name)
        tosa_graph.addOperator(ts.TosaOp.Op.IDENTITY, [const_name], [output.name])
59        )
60