# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import operator
from typing import List, Optional

import torch
from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
    FuseBatchNormWithConvPass,
)
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
    ConfigPrecisionType,
    XNNPartitionerConfig,
)
from executorch.backends.xnnpack.utils.utils import is_param_node
from executorch.exir.backend.canonical_partitioners.config_partitioner import (
    format_target_name,
)
from executorch.exir.backend.utils import WhyNoPartition
from torch.export import ExportedProgram

logger = logging.getLogger(__name__)
why = WhyNoPartition(logger=logger)


class BatchNormConfig(XNNPartitionerConfig):
    """Partitions a batch norm only when it can be fused into the
    convolution that produces its input."""

    target_name = "_native_batch_norm_legit_no_training.default"

    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        if not self.check_common_constraints(node, ep):
            return False

        bn = node
        conv = node.all_input_nodes[0]

        if conv.op != "call_function":
            return False

        conv_name = format_target_name(conv.target.__name__)  # pyre-ignore

        if conv_name not in ["convolution.default"]:
            why(node, f"Invalid conv target {conv_name}")
            return False

        can_fuse = FuseBatchNormWithConvPass.can_fuse(conv, bn, ep)
        if not can_fuse:
            why(node, "BatchNorm cannot be fused with Convolution")
            return False

        return True

    def get_node_and_deps(
        self, node: torch.fx.Node, ep: ExportedProgram
    ) -> List[torch.fx.Node]:
        deps = [node]

        # weight, bias, running_mean, running_var
        deps.extend(node.all_input_nodes[1:5])

        # All users of the batchnorm node must be getitem ops that access only
        # the first element of its 3-element output tuple.
        if not all(
            user.target == operator.getitem and user.args[1] == 0
            for user in node.users
        ):
            return []

        deps.extend(list(node.users.keys()))
        return deps

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32]


class MaxDimConfig(XNNPartitionerConfig):
    """Partitions max.dim only when the indices output goes unused:
    every user must be a getitem of the first (values) output."""

    target_name = "max.dim"

    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        # We support max.dim as long as we don't need to return the indices
        supported_dtypes = {torch.float32, torch.float16, torch.int8, torch.qint8}
        node_val = node.meta.get("val")
        output_0 = node_val[0]

        input_node = node.all_input_nodes[0]
        input_val = input_node.meta.get("val")
        if len(input_val.shape) != 4:
            why(node, f"Unsupported input rank {len(input_val.shape)}, must be 4")
            return False

        # Check only the values output's dtype; the indices dtype is irrelevant
        if output_0.dtype not in supported_dtypes:
            why(node, f"Unsupported output dtype {output_0.dtype}")
            return False

        if input_val.dtype not in supported_dtypes:
            why(node, f"Unsupported input dtype {input_val.dtype}")
            return False

        # Make sure that all users are getitems of the first output
        for user in node.users:
            if not (user.target == operator.getitem and user.args[1] == 0):
                why(node, "Unsupported user of max.dim")
                return False

        return True

    def get_node_and_deps(
        self, node: torch.fx.Node, ep: ExportedProgram
    ) -> List[torch.fx.Node]:
        getitems = list(node.users)

        return [node] + getitems

    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
        return None

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32]


class PreluConfig(XNNPartitionerConfig):
    """Partitions prelu only when its weight is a static parameter."""

    target_name = "prelu.default"

    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        if not self.check_common_constraints(node, ep):
            return False

        weight = node.all_input_nodes[1]
        if not is_param_node(ep, weight):
            why(node, "Prelu weight must be a parameter")
            return False
        return True

    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
        return torch.ops.aten.prelu.default

    def get_node_and_deps(
        self, node: torch.fx.Node, ep: ExportedProgram
    ) -> List[torch.fx.Node]:
        weight = node.all_input_nodes[1]

        return [node, weight]

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32]
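
# A minimal end-to-end sketch of how these configs get exercised, assuming the
# public XnnpackPartitioner entry point (which aggregates the configs in this
# file) and the standard executorch.exir.to_edge lowering flow. Illustrative
# only; not part of this module's API.
if __name__ == "__main__":
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
        XnnpackPartitioner,
    )
    from executorch.exir import to_edge

    class ConvBN(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)
            self.bn = torch.nn.BatchNorm2d(8)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.bn(self.conv(x))

    # Export in eval mode so batch norm lowers to
    # _native_batch_norm_legit_no_training, the target BatchNormConfig matches.
    ep = torch.export.export(ConvBN().eval(), (torch.randn(1, 3, 16, 16),))

    # Partition and delegate: BatchNormConfig should accept the batch norm here
    # because it is fusible with the preceding convolution.
    edge = to_edge(ep).to_backend(XnnpackPartitioner())
    print(edge.exported_program().graph_module)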