# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import numpy as np

import serializer.tosa_serializer as ts
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg

from executorch.backends.arm.tosa_quant_utils import (
    dequantize_value,
    get_quant_arg_downstream,
    get_quant_arg_upstream,
    QuantArgs,
    quantize_value,
)
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
class SigmoidVisitor(NodeVisitor):
    """Lower ``aten.sigmoid.default`` to TOSA.

    Quantized nodes are lowered to a TABLE lookup whose 256 entries are
    precomputed by ``sigmoid_table_8bit`` from the input/output quantization
    parameters; non-quantized nodes map directly to the TOSA SIGMOID operator.
    """

    target = "aten.sigmoid.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Append the TOSA operator implementing ``node`` to ``tosa_graph``.

        Args:
            node: The sigmoid FX node; must have exactly one input node and
                exactly one user.
            tosa_graph: Serializer the operator is added to.
            inputs: Lowered input arguments; only ``inputs[0]`` is used.
            output: Lowered output argument.
            is_quant_node: If true, emit an 8-bit TABLE lookup; otherwise emit
                a SIGMOID operator.
        """
        assert len(node.all_input_nodes) == 1
        assert len(node.users) == 1

        if is_quant_node:
            # Assume quantized input is 8 bit.

            # The table folds dequantize -> sigmoid -> quantize into a single
            # lookup, so it needs the quantization parameters of both the
            # producer (upstream) and the single consumer (downstream).
            input_node = node.all_input_nodes[0]
            in_quantargs = get_quant_arg_upstream(input_node)
            output_node = list(node.users)[0]
            out_quantargs = get_quant_arg_downstream(output_node)

            # Create attribute for 8 bit table lookup.
            table = sigmoid_table_8bit(in_quantargs, out_quantargs)
            table_attr = ts.TosaSerializerAttribute()
            table_attr.TableAttribute(table)

            tosa_graph.addOperator(
                TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr
            )
        else:
            tosa_graph.addOperator(TosaOp.Op().SIGMOID, [inputs[0].name], [output.name])


def sigmoid_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs):
    """
    Return the 256-entry TOSA TABLE for a quantized int8 sigmoid.

    An 8-bit TOSA TABLE is indexed by the raw int8 input value (entry ``i``
    serves input value ``i - 128``), so the table must cover the full int8
    range [-128, 127] independently of the quantizer's qmin/qmax. Entries for
    values outside [qmin, qmax] are simply never looked up. Each entry is
    ``quantize(sigmoid(dequantize(x)))``.

    Reference: https://www.mlplatform.org/tosa/tosa_spec.html#_sigmoid
    (see also the TABLE operator in the same specification).

    Args:
        in_quantargs: Quantization parameters of the sigmoid input.
        out_quantargs: Quantization parameters of the sigmoid output.

    Returns:
        A list of 256 values produced by ``quantize_value``, ordered by input
        value from -128 to 127.
    """

    def sigmoid(x):
        # Convert quantized input to floating point sigmoid input space.
        v = dequantize_value(x, in_quantargs)
        # Compute sigmoid.
        v = 1.0 / (1.0 + np.exp(-v))
        # Convert sigmoid output back to quantized space.
        return quantize_value(v, out_quantargs)

    # Iterate plain Python ints rather than np.int8 scalars: this guarantees
    # exactly one entry per int8 value (no linspace spacing/truncation when
    # qmin/qmax differ from -128/127) and keeps any arithmetic on x inside
    # dequantize_value (e.g. subtracting a zero point) from wrapping at int8.
    return [sigmoid(x) for x in range(-128, 128)]