xref: /aosp_15_r20/external/executorch/backends/arm/operators/op_exp.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7from typing import List
8
9import numpy as np
10
11import serializer.tosa_serializer as ts
12from executorch.backends.arm.operators.node_visitor import (
13    NodeVisitor,
14    register_node_visitor,
15)
16from executorch.backends.arm.tosa_mapping import TosaArg
17
18from executorch.backends.arm.tosa_quant_utils import (
19    dequantize_value,
20    get_quant_arg_downstream,
21    get_quant_arg_upstream,
22    QuantArgs,
23    quantize_value,
24)
25from serializer.tosa_serializer import TosaOp
26from torch.fx import Node
27
28
@register_node_visitor
class ExpVisitor(NodeVisitor):
    """Node visitor that serializes ``aten.exp.default`` into a TOSA graph.

    Non-quantized nodes become a TOSA ``EXP`` operator. Quantized nodes
    become a TOSA ``TABLE`` operator whose lookup table is generated by
    ``exp_table_8bit``.
    """

    target = "aten.exp.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Append the operator implementing ``node`` to ``tosa_graph``.

        Args:
            node: The exp node being lowered; it must have exactly one
                input node.
            tosa_graph: Serializer that receives the new operator.
            inputs: TOSA arguments of the node; only ``inputs[0]`` is used.
            output: TOSA argument naming the result tensor.
            is_quant_node: Selects the table-lookup lowering when true and
                the plain ``EXP`` lowering otherwise.
        """
        assert len(node.all_input_nodes) == 1

        # Non-quantized path: exp maps directly onto the TOSA EXP operator.
        if not is_quant_node:
            tosa_graph.addOperator(TosaOp.Op().EXP, [inputs[0].name], [output.name])
            return

        # Quantized path. The input is assumed to be 8 bit, so the whole
        # function is expressed as a lookup table.
        producer = node.all_input_nodes[0]
        input_qargs = get_quant_arg_upstream(producer)
        # Output quantization parameters are looked up via the first user
        # of this node.
        first_user = list(node.users)[0]
        output_qargs = get_quant_arg_downstream(first_user)

        lookup_values = exp_table_8bit(input_qargs, output_qargs)
        lookup_attr = ts.TosaSerializerAttribute()
        lookup_attr.TableAttribute(lookup_values)

        tosa_graph.addOperator(
            TosaOp.Op().TABLE, [inputs[0].name], [output.name], lookup_attr
        )
65
66
def exp_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs):
    """
    Build the 256-entry lookup table for an 8-bit quantized exp.

    The table is ordered by int8 input value: entry ``i`` is the quantized
    exp of the input value ``i - 128``, so all inputs from -128 to 127 are
    covered exactly once.

    Args:
        in_quantargs: Quantization parameters passed to ``dequantize_value``
            to map each int8 input value to floating point.
        out_quantargs: Quantization parameters passed to ``quantize_value``
            to map each exp result back to the quantized domain.

    Returns:
        A list of 256 values as returned by ``quantize_value``.
    """

    def exp(x):
        # Convert quantized input to floating point exp input space.
        v = dequantize_value(x, in_quantargs)
        # Compute exp.
        v = np.exp(v)
        # Convert exp output back to quantized space.
        return quantize_value(v, out_quantargs)

    # Enumerate every int8 value explicitly instead of sampling
    # np.linspace(qmin, qmax, 256): that only yields integers when
    # qmax - qmin == 255. With a narrower range (e.g. qmin == -127) the
    # samples are fractional, the int8 cast truncates them, and entries end
    # up duplicated and shifted away from the input value they belong to.
    # For the full -128..127 range the values below match the linspace ones.
    int8_info = np.iinfo(np.int8)
    int8_values = np.arange(int8_info.min, int8_info.max + 1).astype(np.int8)

    return [exp(x) for x in int8_values]
84