# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch

from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import OutputMinMax

from executorch.backends.xnnpack.utils.utils import check_or_raise
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult


class FuseActivationPass(XNNPACKPass):
    """
    Some activations, such as ReLU and hardtanh, can be fused with certain operators that precede them.
    When fusing, we delete the activation node and instead embed its output range constraints in the
    metadata of the preceding node. An illustrative before/after sketch follows the fuseable op lists below.
    """

    FUSED_ACTIVATION_TAG = "XNN_FUSED_ACTIVATION"

    FUSEABLE_OPS = [
        exir_ops.edge.aten.convolution.default,
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.sub.Tensor,
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.linear.default,
    ]
    FUSEABLE_ACTIVATIONS = [
        exir_ops.edge.aten.relu.default,
        exir_ops.edge.aten.hardtanh.default,
    ]
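    # Illustrative sketch (comments only, not executed): given an edge-dialect graph like
    #
    #     conv = aten.convolution.default(x, w, b, ...)
    #     relu = aten.relu.default(conv)
    #     out  = some_op(relu)
    #
    # this pass rewrites it to
    #
    #     conv = aten.convolution.default(x, w, b, ...)
    #            # conv.meta["XNN_FUSED_ACTIVATION"] = OutputMinMax(0, "+inf")
    #     out  = some_op(conv)
    #
    # so the XNNPACK delegate can clamp conv's output directly instead of lowering a separate relu node.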

    @staticmethod
    def get_fused_activation(node):
        """Return the OutputMinMax fused onto ``node``, or None if no activation was fused."""
        return node.meta.get(FuseActivationPass.FUSED_ACTIVATION_TAG, None)

    def get_output_min_max_from_activation(self, activation_node):
        """Map a fuseable activation node to the output clamp range it implies."""
        check_or_raise(
            activation_node.target in self.FUSEABLE_ACTIVATIONS,
            f"Attempted to fuse activation: {activation_node.target}, but it is not a fuseable activation",
        )
        if activation_node.target == exir_ops.edge.aten.relu.default:
            output_min = 0
            output_max = "+inf"  # ReLU has no upper bound
        elif activation_node.target == exir_ops.edge.aten.hardtanh.default:
            # hardtanh defaults to [-1, 1]; use the explicit min_val/max_val args when provided
            output_min = -1
            output_max = 1
            if len(activation_node.args) > 1:
                output_min = activation_node.args[1]
                output_max = activation_node.args[2]

        return OutputMinMax(output_min, output_max)
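    # Worked examples of the mapping above (values only, not executed):
    #   relu(x)                -> OutputMinMax(0, "+inf")
    #   hardtanh(x)            -> OutputMinMax(-1, 1)
    #   hardtanh(x, -2.0, 2.0) -> OutputMinMax(-2.0, 2.0)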

    def call(self, graph_module: torch.fx.GraphModule):
        for activation_node in graph_module.graph.nodes:
            if activation_node.op == "call_function":
                if activation_node.target in self.FUSEABLE_ACTIVATIONS:
                    preceding_op = activation_node.args[0]
                    if (
                        preceding_op.op == "call_function"
                        and preceding_op.target in self.FUSEABLE_OPS
                    ):
                        # Delete the activation and embed its output range into the preceding op's metadata
                        output_min_max = self.get_output_min_max_from_activation(
                            activation_node
                        )
                        preceding_op.meta[self.FUSED_ACTIVATION_TAG] = output_min_max
                        activation_node.replace_all_uses_with(preceding_op)
                        graph_module.graph.erase_node(activation_node)

        graph_module.recompile()
        # To regenerate metadata and shape information, retrace the module
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
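# A minimal usage sketch (comments only, since it depends on the caller's export flow).
# It assumes `edge` is an EdgeProgramManager produced by executorch.exir.to_edge, and that
# XNNPACKPass subclasses take the ExportedProgram in their constructor:
#
#     ep = edge.exported_program()
#     result = FuseActivationPass(ep)(ep.graph_module)
#     fused = [
#         n
#         for n in result.graph_module.graph.nodes
#         if FuseActivationPass.get_fused_activation(n) is not None
#     ]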