# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import numpy as np

import serializer.tosa_serializer as ts
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg

from executorch.backends.arm.tosa_quant_utils import (
    dequantize_value,
    get_quant_arg_downstream,
    get_quant_arg_upstream,
    QuantArgs,
    quantize_value,
)
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
class ExpVisitor(NodeVisitor):
    """Lower ``aten.exp.default`` to TOSA.

    Non-quantized nodes emit a TOSA ``EXP`` operator. Quantized nodes emit a
    TOSA ``TABLE`` operator whose lookup table is built by ``exp_table_8bit``.
    """

    target = "aten.exp.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Add the TOSA operator implementing ``exp`` for *node* to *tosa_graph*.

        Args:
            node: FX node being lowered; must have exactly one input node.
            tosa_graph: Serializer the operator is appended to.
            inputs: TOSA arguments for the node's inputs; only ``inputs[0]``
                is used.
            output: TOSA argument that receives the result.
            is_quant_node: When True, emit a ``TABLE`` lookup built from the
                node's quantization parameters; otherwise emit ``EXP``.
        """
        assert len(node.all_input_nodes) == 1

        if is_quant_node:
            # Assume quantized input is 8 bit.

            # Quantization parameters are read from the input node (upstream)
            # and from the first user of this node (downstream).
            input_node = node.all_input_nodes[0]
            in_quantargs = get_quant_arg_upstream(input_node)
            output_node = list(node.users)[0]
            out_quantargs = get_quant_arg_downstream(output_node)

            # Create attribute for 8 bit table lookup.
            table = exp_table_8bit(in_quantargs, out_quantargs)
            table_attr = ts.TosaSerializerAttribute()
            table_attr.TableAttribute(table)

            tosa_graph.addOperator(
                TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr
            )
        else:
            tosa_graph.addOperator(TosaOp.Op().EXP, [inputs[0].name], [output.name])


def exp_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs) -> List:
    """Build the 256-entry TABLE contents for a quantized int8 ``exp``.

    Entry ``i`` holds ``quantize(exp(dequantize(i - 128)))``: the table is
    ordered by every possible int8 input value, from -128 up to 127.

    Args:
        in_quantargs: Quantization parameters used to dequantize the input.
        out_quantargs: Quantization parameters used to quantize the result.

    Returns:
        A list of 256 entries, each as returned by ``quantize_value``.
    """

    def exp(x):
        # Convert quantized input to floating point exp input space.
        v = dequantize_value(x, in_quantargs)
        # Compute exp.
        v = np.exp(v)
        # Convert exp output back to quantized space.
        return quantize_value(v, out_quantargs)

    # A TOSA int8 TABLE is indexed by the raw input value, so the table must
    # span the whole int8 range independent of in_quantargs.qmin/qmax.
    # Sampling np.linspace(qmin, qmax, 256, dtype=np.int8) only lines up with
    # that indexing when qmin == -128 and qmax == 127; a narrower range gives
    # fractional steps that get truncated into duplicated, shifted entries.
    # The inputs are held as int32 so that, if dequantize_value subtracts a
    # zero point (its body is not visible here), the result cannot wrap
    # around the way int8 arithmetic would.
    int8_range = np.iinfo(np.int8)
    return [
        exp(x)
        for x in np.arange(int8_range.min, int8_range.max + 1, dtype=np.int32)
    ]