xref: /aosp_15_r20/external/executorch/backends/arm/operators/op_div.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2023-2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7from typing import List
8
9import serializer.tosa_serializer as ts
10import torch
11from executorch.backends.arm.operators.node_visitor import (
12    NodeVisitor,
13    register_node_visitor,
14)
15from executorch.backends.arm.tosa_mapping import TosaArg
16from executorch.backends.arm.tosa_utils import tosa_shape
17from serializer.tosa_serializer import TosaOp
18
19
@register_node_visitor
class DivVisitor(NodeVisitor):
    """Lower ``aten.div.Tensor`` into a TOSA RECIPROCAL followed by a MUL.

    ``x / y`` is emitted as ``x * (1 / y)``: the divisor is first passed
    through RECIPROCAL into a fresh intermediate tensor, and that tensor is
    then multiplied with the dividend to produce the node's output.
    NOTE(review): the original inline comment describes this as the FP32
    lowering; presumably quantized division is handled elsewhere — confirm.
    """

    target = "aten.div.Tensor"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Append the RECIPROCAL + MUL pair for one division node to ``tosa_graph``.

        Args:
            node: The FX node being lowered (not read by this visitor).
            tosa_graph: Serializer that receives the intermediate tensor and
                both operators.
            inputs: ``[dividend, divisor]`` as TOSA arguments.
            output: TOSA argument that the final MUL writes into.
            is_quant_node: Not consulted here; the same float-style lowering
                is emitted regardless of its value.
        """
        dividend, divisor = inputs[0], inputs[1]

        # Reciprocal is elementwise, so the intermediate takes the divisor's
        # shape (in its dim order) and dtype.
        divisor_shape = tosa_shape(divisor.shape, divisor.dim_order)
        reciprocal = tosa_graph.addIntermediate(divisor_shape, divisor.dtype)
        tosa_graph.addOperator(
            TosaOp.Op().RECIPROCAL, [divisor.name], [reciprocal.name]
        )

        # dividend * (1 / divisor). The 0 passed to MulAttribute is presumably
        # the TOSA MUL `shift` (only meaningful for integer MUL) — confirm.
        mul_attr = ts.TosaSerializerAttribute()
        mul_attr.MulAttribute(0)
        tosa_graph.addOperator(
            TosaOp.Op().MUL,
            [dividend.name, reciprocal.name],
            [output.name],
            mul_attr,
        )
46