xref: /aosp_15_r20/external/executorch/backends/arm/operators/op_rsqrt.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7from typing import List
8
9import numpy as np
10import serializer.tosa_serializer as ts
11import torch
12from executorch.backends.arm.operators.node_visitor import (
13    NodeVisitor,
14    register_node_visitor,
15)
16from executorch.backends.arm.tosa_mapping import TosaArg
17from executorch.backends.arm.tosa_quant_utils import (
18    dequantize_value,
19    get_quant_arg_downstream,
20    get_quant_arg_upstream,
21    QuantArgs,
22    quantize_value,
23)
24from serializer.tosa_serializer import TosaOp
25
26
@register_node_visitor
class RsqrtVisitor(NodeVisitor):
    """Serialize ``aten.rsqrt.default`` nodes into the TOSA graph.

    Quantized nodes are lowered to a TOSA TABLE lookup whose entries are
    produced by ``rsqrt_table_8bit``; non-quantized nodes map directly onto the
    TOSA RSQRT operator.
    """

    target = "aten.rsqrt.default"

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Append the TOSA operator implementing ``node`` to ``tosa_graph``.

        Args:
            node: The FX node being lowered.
            tosa_graph: Serializer that receives the new operator.
            inputs: Operator inputs; only ``inputs[0]`` is consumed.
            output: Operator output tensor.
            is_quant_node: When true, emit a TABLE lookup; otherwise emit
                RSQRT.
        """
        if not is_quant_node:
            # Float path: TOSA has a native RSQRT operator.
            tosa_graph.addOperator(TosaOp.Op().RSQRT, [inputs[0].name], [output.name])
            return

        # Quantized path. The input is assumed to be 8-bit, so rsqrt is
        # precomputed per quantized input value and emitted as a TABLE op.
        producer = node.all_input_nodes[0]
        in_qparams = get_quant_arg_upstream(producer)
        # Output quantization parameters are obtained via the first user of
        # this node.
        consumer = list(node.users)[0]
        out_qparams = get_quant_arg_downstream(consumer)

        lookup = rsqrt_table_8bit(in_qparams, out_qparams)
        attr = ts.TosaSerializerAttribute()
        attr.TableAttribute(lookup)

        tosa_graph.addOperator(
            TosaOp.Op().TABLE, [inputs[0].name], [output.name], attr
        )
54
55
def rsqrt_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs):
    """
    Return a 256-entry lookup table implementing a quantized rsqrt.

    Entry ``i`` is ``quantize(1 / sqrt(dequantize(q_i)))``, where the ``q_i``
    are 256 evenly spaced integers from ``in_quantargs.qmin`` to
    ``in_quantargs.qmax`` (inclusive).

    Args:
        in_quantargs: Quantization parameters of the operator input, used to
            dequantize each sample point.
        out_quantargs: Quantization parameters of the operator output, used to
            requantize each rsqrt result.

    Returns:
        A list of 256 values, each as returned by ``quantize_value``.

    Reference: https://www.mlplatform.org/tosa/tosa_spec.html#_rsqrt
    """

    def rsqrt(q):
        # Convert the quantized sample to the real-valued rsqrt input space.
        v = dequantize_value(q, in_quantargs)
        # NOTE(review): for v == 0 this yields inf and for v < 0 nan (rsqrt is
        # undefined there); how quantize_value maps those values is not
        # visible here — confirm the resulting table entries are acceptable.
        v = 1 / np.sqrt(v)
        # Convert the rsqrt result back to the output quantized space.
        return quantize_value(v, out_quantargs)

    # Generate the sample points in a wide integer type. With dtype=np.int8
    # each point would be an int8 scalar, and the dequantization arithmetic
    # (presumably (q - zero_point) * scale — verify in tosa_quant_utils) could
    # overflow and wrap, e.g. np.int8(-50) - 100 == 106. The sampled values
    # themselves are unchanged for any range within [-128, 127].
    quantized_inputs = np.linspace(
        in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int64
    )
    return [rsqrt(q) for q in quantized_inputs]
73    ]
74