# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch

from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import OutputMinMax

from executorch.backends.xnnpack.utils.utils import check_or_raise
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult


class FuseActivationPass(XNNPACKPass):
    """
    Some activations, such as ReLU and hardtanh, can be fused with certain
    operators that precede them. When fusion applies, we delete the activation
    node and embed its output-range constraints in the metadata of the
    preceding node.
    """

    FUSED_ACTIVATION_TAG = "XNN_FUSED_ACTIVATION"

    FUSEABLE_OPS = [
        exir_ops.edge.aten.convolution.default,
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.sub.Tensor,
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.linear.default,
    ]
    FUSEABLE_ACTIVATIONS = [
        exir_ops.edge.aten.relu.default,
        exir_ops.edge.aten.hardtanh.default,
    ]

    @staticmethod
    def get_fused_activation(node):
        # Return the OutputMinMax recorded by this pass, or None if no
        # activation was fused into the given node.
        return node.meta.get(FuseActivationPass.FUSED_ACTIVATION_TAG, None)

    def get_output_min_max_from_activation(self, activation_node):
        check_or_raise(
            activation_node.target in self.FUSEABLE_ACTIVATIONS,
            f"Attempted to fuse activation: {activation_node.target}, but it is not a fuseable activation",
        )
        if activation_node.target == exir_ops.edge.aten.relu.default:
            # ReLU clamps its output to [0, +inf).
            output_min = 0
            output_max = "+inf"
        elif activation_node.target == exir_ops.edge.aten.hardtanh.default:
            # hardtanh clamps to [min_val, max_val], defaulting to [-1, 1];
            # explicit bounds, when given, arrive as args[1] and args[2].
            output_min = -1
            output_max = 1
            if len(activation_node.args) > 1:
                output_min = activation_node.args[1]
                output_max = activation_node.args[2]

        return OutputMinMax(output_min, output_max)

    def call(self, graph_module: torch.fx.GraphModule):
        for activation_node in graph_module.graph.nodes:
            if (
                activation_node.op == "call_function"
                and activation_node.target in self.FUSEABLE_ACTIVATIONS
            ):
                preceding_op = activation_node.args[0]
                if (
                    preceding_op.op == "call_function"
                    and preceding_op.target in self.FUSEABLE_OPS
                ):
                    # Delete the activation and embed its output range as
                    # metadata on the preceding op.
                    output_min_max = self.get_output_min_max_from_activation(
                        activation_node
                    )
                    preceding_op.meta[self.FUSED_ACTIVATION_TAG] = output_min_max
                    activation_node.replace_all_uses_with(preceding_op)
                    graph_module.graph.erase_node(activation_node)

        graph_module.recompile()
        # Retrace the module to regenerate metadata and shape information.
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
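

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the pass): runs this pass by
# hand on a Conv2d -> ReLU graph. It assumes XNNPACKPass is constructed with
# the ExportedProgram and relies on the public torch.export / executorch.exir
# export flow; in a normal lowering, the XNNPACK backend runs this pass for
# you, so this block exists purely to show the before/after effect.
if __name__ == "__main__":
    from executorch.exir import to_edge

    class ConvReLU(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)

        def forward(self, x):
            return torch.relu(self.conv(x))

    edge = to_edge(torch.export.export(ConvReLU(), (torch.randn(1, 3, 8, 8),)))
    ep = edge.exported_program()
    result = FuseActivationPass(ep).call(ep.graph_module)

    # After the pass, the relu node is gone and the convolution node carries
    # OutputMinMax(0, "+inf") under FuseActivationPass.FUSED_ACTIVATION_TAG.
    for node in result.graph_module.graph.nodes:
        fused = FuseActivationPass.get_fused_activation(node)
        if fused is not None:
            print(node.target, "->", fused)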