# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.operators.quant_params import QuantParams
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    OutputMinMax,
    XNNGraph,
    XNNMultiply,
    XNode,
)

from executorch.backends.xnnpack.utils.utils import get_input_node, get_relu_fused_node
from executorch.exir.dialects._ops import ops as exir_ops


@register_node_visitor
class MultiplyVisitor(NodeVisitor):
    """Node visitor that serializes ``aten.mul.Tensor`` as an ``XNNMultiply`` node."""

    target = "aten.mul.Tensor"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Define the tensors used by ``node`` and append its multiply XNode.

        Both operands and the result are defined through ``define_tensor``,
        their ids are read back from ``vals_to_ids``, and the resulting
        ``XNode`` is appended to ``xnn_graph.xnodes``.

        Args:
            node: The multiply node being serialized.
            xnn_graph: Graph whose ``xnodes`` list receives the new node.
            vals_to_ids: Mapping from fx nodes to tensor ids; it is passed to
                ``define_tensor`` and then indexed for each operand/result.
            debug_handle: Value stored on the serialized ``XNode``.
        """
        # Operands are handled one at a time, in positional order (0 then 1).
        operand_ids = []
        for position in (0, 1):
            operand = get_input_node(node, position)
            self.define_tensor(
                operand,
                xnn_graph,
                vals_to_ids,
                quant_params=QuantParams.from_inputs(operand, self._exported_program),
            )
            operand_ids.append(vals_to_ids[operand])
        lhs_id, rhs_id = operand_ids

        # The result tensor belongs to the relu-fused node when the helper
        # returns one, otherwise to the multiply node itself.
        result_node = get_relu_fused_node(node) or node
        clamp = None
        if result_node.target == exir_ops.edge.aten.relu.default:
            # Flag the relu as fused and express it as an output clamp of
            # [0, +inf] on the multiply node.
            result_node.meta["XNNPACK_FUSED"] = True
            clamp = OutputMinMax(output_min=0, output_max="+inf")

        self.define_tensor(
            result_node,
            xnn_graph,
            vals_to_ids,
            quant_params=QuantParams.from_outputs(result_node),
        )
        result_id = vals_to_ids[result_node]

        multiply = XNNMultiply(
            input1_id=lhs_id,
            input2_id=rhs_id,
            output_id=result_id,
            flags=0,
        )
        xnn_graph.xnodes.append(
            XNode(
                xnode_union=multiply,
                debug_handle=debug_handle,
                output_min_max=clamp,
            )
        )