# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import numpy as np
import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import (
    dequantize_value,
    get_quant_arg_downstream,
    get_quant_arg_upstream,
    QuantArgs,
    quantize_value,
)
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class RsqrtVisitor(NodeVisitor):
    """Lower ``aten.rsqrt.default`` to TOSA.

    Quantized nodes become a TOSA ``TABLE`` lookup whose entries are
    precomputed by ``rsqrt_table_8bit``; non-quantized nodes map directly
    onto the TOSA ``RSQRT`` operator.
    """

    target = "aten.rsqrt.default"

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Append the TOSA operator implementing ``node`` to ``tosa_graph``.

        Args:
            node: The ``aten.rsqrt.default`` FX node being lowered.
            tosa_graph: Serializer the operator is added to.
            inputs: TOSA tensors for the node's arguments; only ``inputs[0]``
                is used.
            output: TOSA tensor that receives the result.
            is_quant_node: When True, emit an 8-bit ``TABLE`` lookup instead
                of ``RSQRT``.
        """
        if not is_quant_node:
            tosa_graph.addOperator(TosaOp.Op().RSQRT, [inputs[0].name], [output.name])
            return

        # Assume quantized input is 8 bit: bake rsqrt into a 256-entry lookup
        # table using the quantization parameters read from the producer
        # (upstream) and from the first user (downstream) of this node.
        input_node = node.all_input_nodes[0]
        in_quantargs = get_quant_arg_upstream(input_node)
        output_node = list(node.users)[0]
        out_quantargs = get_quant_arg_downstream(output_node)

        table = rsqrt_table_8bit(in_quantargs, out_quantargs)
        table_attr = ts.TosaSerializerAttribute()
        table_attr.TableAttribute(table)
        tosa_graph.addOperator(
            TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr
        )


def rsqrt_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs):
    """Return the 256-entry TOSA ``TABLE`` implementing rsqrt for int8 input.

    A TOSA int8 ``TABLE`` is indexed by the raw input value (entry ``i``
    holds the result for input ``i - 128``). The table must therefore span
    the full int8 range [-128, 127], independent of the quantization range
    ``[qmin, qmax]``. Sampling ``[qmin, qmax]`` instead would misalign every
    entry whenever that range is narrower than the full int8 range.

    Args:
        in_quantargs: Quantization parameters used to dequantize table inputs.
        out_quantargs: Quantization parameters used to requantize results.

    Returns:
        A list of 256 values, each as returned by ``quantize_value``.

    Reference: https://www.mlplatform.org/tosa/tosa_spec.html#_rsqrt
    """

    def _rsqrt(x: int):
        # Convert the quantized input to the floating-point rsqrt input space.
        v = dequantize_value(x, in_quantargs)
        # Compute rsqrt. NOTE(review): an input that dequantizes to 0 yields
        # inf and a negative one yields NaN; the table entry is then whatever
        # quantize_value makes of that value.
        v = 1 / np.sqrt(v)
        # Convert the rsqrt output back to the quantized space.
        return quantize_value(v, out_quantargs)

    # Plain Python ints rather than np.int8 scalars, so the zero-point
    # subtraction inside dequantize_value cannot wrap around in 8-bit
    # arithmetic (NumPy 2 keeps int8 scalar arithmetic in int8).
    return [_rsqrt(x) for x in range(-128, 128)]