# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_utils import tosa_shape
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class DivVisitor(NodeVisitor):
    """Lower ``aten.div.Tensor`` to TOSA as ``MUL(x, RECIPROCAL(y))``.

    The division is rewritten as multiplication by the reciprocal of the
    divisor.
    """

    target = "aten.div.Tensor"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Append TOSA operators computing ``output = inputs[0] / inputs[1]``.

        The quotient is built as ``inputs[0] * (1 / inputs[1])``. First, a
        RECIPROCAL of the divisor is written into a fresh intermediate
        tensor. Then a MUL combines the dividend with that intermediate and
        writes the result to ``output``.

        Args:
            node: FX node being lowered (not read by this method).
            tosa_graph: Serializer that the intermediate tensor and both
                operators are added to.
            inputs: Operands; ``inputs[0]`` is the dividend and
                ``inputs[1]`` the divisor.
            output: Tensor that receives the quotient.
            is_quant_node: Not consulted here. The lowering is written for
                floating-point operands. NOTE(review): confirm that quantized
                div never reaches this visitor.
        """
        dividend, divisor = inputs[0], inputs[1]

        # The intermediate that holds 1/divisor takes its shape from the
        # divisor's shape and dim_order (via tosa_shape) and uses the
        # divisor's dtype.
        reciprocal_shape = tosa_shape(divisor.shape, divisor.dim_order)
        reciprocal = tosa_graph.addIntermediate(reciprocal_shape, divisor.dtype)
        tosa_graph.addOperator(
            TosaOp.Op().RECIPROCAL, [divisor.name], [reciprocal.name]
        )

        # MUL attribute argument is 0; presumably TOSA's MUL `shift`
        # (only meaningful for integer inputs) — confirm against serializer.
        mul_attr = ts.TosaSerializerAttribute()
        mul_attr.MulAttribute(0)
        tosa_graph.addOperator(
            TosaOp.Op().MUL,
            [dividend.name, reciprocal.name],
            [output.name],
            mul_attr,
        )