xref: /aosp_15_r20/external/executorch/backends/xnnpack/operators/op_multiply.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7from typing import Dict
8
9import torch
10from executorch.backends.xnnpack.operators.node_visitor import (
11    NodeVisitor,
12    register_node_visitor,
13)
14from executorch.backends.xnnpack.operators.quant_params import QuantParams
15from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
16    OutputMinMax,
17    XNNGraph,
18    XNNMultiply,
19    XNode,
20)
21
22from executorch.backends.xnnpack.utils.utils import get_input_node, get_relu_fused_node
23from executorch.exir.dialects._ops import ops as exir_ops
24
25
@register_node_visitor
class MultiplyVisitor(NodeVisitor):
    """Lower ``aten.mul.Tensor`` nodes to XNNPACK ``XNNMultiply`` nodes.

    When ``get_relu_fused_node`` returns a ReLU (``edge.aten.relu.default``)
    for the multiply, that ReLU is absorbed into the multiply: it is tagged
    with ``meta["XNNPACK_FUSED"]`` and the multiply's output is clamped to
    ``[0, +inf)`` via ``OutputMinMax``.
    """

    target = "aten.mul.Tensor"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Serialize ``node`` (plus an optional fused ReLU) into ``xnn_graph``.

        Args:
            node: The ``aten.mul.Tensor`` node being lowered.
            xnn_graph: Graph under construction; exactly one ``XNode`` is
                appended to its ``xnodes`` list.
            vals_to_ids: Map from FX nodes to serialized tensor ids. Entries
                are read back right after each ``define_tensor`` call, so
                ``define_tensor`` is expected to register them.
            debug_handle: Forwarded unchanged to the emitted ``XNode``.
        """
        # Define both operands in order (position 0, then position 1) and
        # collect the tensor id assigned to each.
        operand_ids = []
        for position in (0, 1):
            operand = get_input_node(node, position)
            self.define_tensor(
                operand,
                xnn_graph,
                vals_to_ids,
                quant_params=QuantParams.from_inputs(operand, self._exported_program),
            )
            operand_ids.append(vals_to_ids[operand])
        lhs_id, rhs_id = operand_ids

        # If a following ReLU can be fused, the ReLU node is the tensor that
        # carries this op's result; otherwise the mul node itself is.
        result_node = get_relu_fused_node(node) or node
        clamp = None
        if result_node.target == exir_ops.edge.aten.relu.default:
            # NOTE(review): presumably lets the ReLU's own visitor skip this
            # node since it is now fused here — confirm against the flag's reader.
            result_node.meta["XNNPACK_FUSED"] = True
            clamp = OutputMinMax(output_min=0, output_max="+inf")

        self.define_tensor(
            result_node,
            xnn_graph,
            vals_to_ids,
            quant_params=QuantParams.from_outputs(result_node),
        )

        multiply = XNNMultiply(
            input1_id=lhs_id,
            input2_id=rhs_id,
            output_id=vals_to_ids[result_node],
            flags=0,
        )
        xnn_graph.xnodes.append(
            XNode(
                xnode_union=multiply,
                debug_handle=debug_handle,
                output_min_max=clamp,
            )
        )
85