# mypy: allow-untyped-defs
import itertools
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import torch.ao.nn.intrinsic as nni
import torch.nn as nn
from torch.ao.quantization.utils import get_combined_dict, MatchAllNode, Pattern


__all__ = [
    "fuse_conv_bn",
    "fuse_conv_bn_relu",
    "fuse_linear_bn",
    "fuse_convtranspose_bn",
    "get_fuser_method",
    "get_fuser_method_new",
]


def fuse_conv_bn(is_qat, conv, bn):
    r"""Return the fused conv and bn modules.

    Given the conv and bn modules, fuses them and returns the fused module

    Args:
        is_qat: a flag for whether we are using quantization aware training fusion
            or post training quantization fusion
        conv: Module instance of type conv1d/conv2d/conv3d
        bn: Spatial BN instance that needs to be fused with the conv

    Examples::

        >>> m1 = nn.Conv2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> # xdoctest: +SKIP
        >>> m2 = fuse_conv_bn(False, m1, b1)
    """
    assert (
        conv.training == bn.training
    ), "Conv and BN both must be in the same mode (train or eval)."

    fused_module_class_map = {
        nn.Conv1d: nni.ConvBn1d,
        nn.Conv2d: nni.ConvBn2d,
        nn.Conv3d: nni.ConvBn3d,
    }

    if is_qat:
        assert (
            bn.num_features == conv.out_channels
        ), "Output channels of Conv must match num_features of BatchNorm"
        assert bn.affine, "Only support fusing BatchNorm with affine set to True"
        assert (
            bn.track_running_stats
        ), "Only support fusing BatchNorm with track_running_stats set to True"
        fused_module_class = fused_module_class_map.get(type(conv), None)
        if fused_module_class is not None:
            return fused_module_class(conv, bn)
        else:
            raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn)}")
    else:
        return nn.utils.fuse_conv_bn_eval(conv, bn)
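

# A minimal usage sketch for ``fuse_conv_bn`` (illustrative only, not part of the
# docstring examples above): with ``is_qat=False`` the post training quantization
# path is taken and the BatchNorm statistics are folded into the conv weight and
# bias, so a plain ``nn.Conv2d`` comes back. Both modules must be in eval mode
# for this path.
#
#     >>> conv = nn.Conv2d(3, 8, 3).eval()
#     >>> bn = nn.BatchNorm2d(8).eval()
#     >>> fused = fuse_conv_bn(False, conv, bn)
#     >>> isinstance(fused, nn.Conv2d)
#     True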


def fuse_conv_bn_relu(is_qat, conv, bn, relu):
    r"""Return the fused conv, bn and relu modules.

    Given the conv, bn and relu modules, fuses them and returns the fused module

    Args:
        is_qat: a flag for whether we are using quantization aware training fusion
            or post training quantization fusion
        conv: Module instance of type conv1d/conv2d/conv3d
        bn: Spatial BN instance that needs to be fused with the conv
        relu: ReLU instance that needs to be fused with the conv and bn

    Examples::

        >>> m1 = nn.Conv2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> r1 = nn.ReLU(inplace=False)
        >>> # xdoctest: +SKIP
        >>> m2 = fuse_conv_bn_relu(False, m1, b1, r1)
    """
    assert (
        conv.training == bn.training == relu.training
    ), "Conv, BN and ReLU must all be in the same mode (train or eval)."
    fused_module: Optional[Type[nn.Sequential]] = None
    if is_qat:
        map_to_fused_module_train = {
            nn.Conv1d: nni.ConvBnReLU1d,
            nn.Conv2d: nni.ConvBnReLU2d,
            nn.Conv3d: nni.ConvBnReLU3d,
        }
        assert (
            bn.num_features == conv.out_channels
        ), "Output channel of Conv must match num_features of BatchNorm"
        assert bn.affine, "Only support fusing BatchNorm with affine set to True"
        assert (
            bn.track_running_stats
        ), "Only support fusing BatchNorm with track_running_stats set to True"
        fused_module = map_to_fused_module_train.get(type(conv), None)
        if fused_module is not None:
            return fused_module(conv, bn, relu)
        else:
            raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, relu)}")
    else:
        map_to_fused_module_eval = {
            nn.Conv1d: nni.ConvReLU1d,
            nn.Conv2d: nni.ConvReLU2d,
            nn.Conv3d: nni.ConvReLU3d,
        }
        fused_module = map_to_fused_module_eval.get(type(conv), None)
        if fused_module is not None:
            fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
            return fused_module(fused_conv, relu)
        else:
            raise NotImplementedError(f"Cannot fuse eval modules: {(conv, bn, relu)}")


def fuse_linear_bn(is_qat, linear, bn):
    r"""Return the fused linear and bn modules.

    Given the linear and bn modules, fuses them and returns the fused module

    Args:
        is_qat: a flag for whether we are using quantization aware training fusion
            or post training quantization fusion
        linear: Module instance of type Linear
        bn: BatchNorm1d instance that needs to be fused with the linear layer

    Examples::

        >>> m1 = nn.Linear(20, 10)
        >>> b1 = nn.BatchNorm1d(10)
        >>> # xdoctest: +SKIP
        >>> m2 = fuse_linear_bn(False, m1, b1)
    """
    assert (
        linear.training == bn.training
    ), "Linear and BN both must be in the same mode (train or eval)."

    if is_qat:
        assert (
            bn.num_features == linear.out_features
        ), "Output features of Linear must match num_features of BatchNorm1d"
        assert bn.affine, "Only support fusing BatchNorm1d with affine set to True"
        assert (
            bn.track_running_stats
        ), "Only support fusing BatchNorm1d with track_running_stats set to True"
        return nni.LinearBn1d(linear, bn)
    else:
        return nn.utils.fusion.fuse_linear_bn_eval(linear, bn)


def fuse_convtranspose_bn(is_qat, convt, bn):
    r"""Return the fused ConvTranspose and bn modules.

    Given ConvTranspose and bn modules, fuses them and returns the fused module

    Args:
        is_qat: a flag for whether we are using quantization aware training fusion
            or post training quantization fusion
        convt: Module instance of type ConvTransposeNd
        bn: BatchNormNd instance that needs to be fused with the ConvTranspose layer.
            The BatchNorm dimension N should match the ConvTranspose dimension N

    Examples::

        >>> m1 = nn.ConvTranspose2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> # xdoctest: +SKIP
        >>> m2 = fuse_convtranspose_bn(False, m1, b1)
    """
    assert (
        convt.training == bn.training
    ), "ConvTranspose and BN both must be in the same mode (train or eval)."

    if is_qat:
        raise Exception(  # noqa: TRY002
            "Fusing ConvTranspose+BatchNorm not yet supported in QAT."
        )
    else:
        return nn.utils.fusion.fuse_conv_bn_eval(convt, bn, transpose=True)
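

# A minimal usage sketch for ``fuse_linear_bn`` (illustrative only): with
# ``is_qat=True`` the quantization aware training path is taken and an
# ``nni.LinearBn1d`` module holding both the linear and the bn is returned.
# Both modules are in train mode here, matching the assertion in the function.
#
#     >>> linear = nn.Linear(20, 10).train()
#     >>> bn = nn.BatchNorm1d(10).train()
#     >>> fused = fuse_linear_bn(True, linear, bn)
#     >>> isinstance(fused, nni.LinearBn1d)
#     True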


def _sequential_wrapper2(sequential):
    """Return a fuser method that wraps two modules in the given sequential class.

    Given a sequential class for two modules, return a function that takes
    is_qat and then two modules as arguments, ignores the is_qat flag
    and always returns the sequential that combines the two input modules
    """

    def fuser_method(is_qat, m1, m2):
        return sequential(m1, m2)

    return fuser_method


_DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
    (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
    (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
    (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
    (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv1d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU1d),
    (nn.Conv2d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU2d),
    (nn.Conv3d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU3d),
    (nn.Linear, nn.BatchNorm1d): fuse_linear_bn,
    (nn.Linear, nn.ReLU): _sequential_wrapper2(nni.LinearReLU),
    (nn.BatchNorm2d, nn.ReLU): _sequential_wrapper2(nni.BNReLU2d),
    (nn.BatchNorm3d, nn.ReLU): _sequential_wrapper2(nni.BNReLU3d),
    (nn.ConvTranspose1d, nn.BatchNorm1d): fuse_convtranspose_bn,
    (nn.ConvTranspose2d, nn.BatchNorm2d): fuse_convtranspose_bn,
    (nn.ConvTranspose3d, nn.BatchNorm3d): fuse_convtranspose_bn,
}


def get_fuser_method(op_list, additional_fuser_method_mapping=None):
    """Get fuser method for the given list of module types.

    Looks up the op list in the default mapping combined with
    ``additional_fuser_method_mapping``; raises an ``AssertionError`` if no
    fuser method exists for the given list of module types.
    """
    if additional_fuser_method_mapping is None:
        additional_fuser_method_mapping = {}
    all_mappings = get_combined_dict(
        _DEFAULT_OP_LIST_TO_FUSER_METHOD, additional_fuser_method_mapping
    )
    fuser_method = all_mappings.get(op_list, None)
    assert fuser_method is not None, f"did not find fuser method for: {op_list} "
    return fuser_method


def _reverse2(f):
    # Wrap fuser method ``f`` so its two module arguments are accepted in
    # reversed order.
    def reversed(is_qat, x, y):
        return f(is_qat, y, x)

    return reversed


def _reverse3(f):
    # Wrap fuser method ``f`` so it can be called as (is_qat, x, (y, z))
    # instead of (is_qat, z, y, x).
    def reversed(is_qat, x, w):
        y, z = w
        return f(is_qat, z, y, x)

    return reversed
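

# A minimal lookup sketch for ``get_fuser_method`` (illustrative only): the op
# list is used directly as a key into ``_DEFAULT_OP_LIST_TO_FUSER_METHOD``
# (merged with any additional mapping), so a (Linear, BatchNorm1d) pair resolves
# to ``fuse_linear_bn``. ``_reverse2`` and ``_reverse3`` simply return wrappers
# that pass the module arguments to ``f`` in the opposite order.
#
#     >>> fuser = get_fuser_method((nn.Linear, nn.BatchNorm1d))
#     >>> fuser is fuse_linear_bn
#     True
#     >>> swapped = _reverse2(fuse_linear_bn)
#     >>> # swapped(is_qat, bn, linear) calls fuse_linear_bn(is_qat, linear, bn)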


def _get_valid_patterns(op_pattern):
    """Return a list of valid patterns generated from the op_pattern.

    Since MatchAllNode can match all types of nodes, a pattern like
    (torch.nn.Conv2d, torch.add) should also be able to match keys like
    (MatchAllNode, torch.add) and (torch.nn.Conv2d, MatchAllNode)

    Example Input:
    (torch.add, (torch.nn.ReLU, torch.nn.Conv2d))

    Example Output:
    [(torch.add, (torch.nn.ReLU, torch.nn.Conv2d)),
     (torch.add, (torch.nn.ReLU, MatchAllNode)),
     (torch.add, (MatchAllNode, torch.nn.Conv2d)),
     (torch.add, (MatchAllNode, MatchAllNode)),
     (MatchAllNode, (torch.nn.ReLU, torch.nn.Conv2d)),
     (MatchAllNode, (torch.nn.ReLU, MatchAllNode)),
     (MatchAllNode, (MatchAllNode, torch.nn.Conv2d)),
     (MatchAllNode, (MatchAllNode, MatchAllNode)),
    ]
    """
    result: List[Any]
    if isinstance(op_pattern, (tuple, list)):
        sub_combs = []
        for sub_pattern in op_pattern:
            sub_combs.append(_get_valid_patterns(sub_pattern))
        result = list(itertools.product(*sub_combs))
    else:
        result = [op_pattern, MatchAllNode]
    return result


def get_fuser_method_new(
    op_pattern: Pattern,
    fuser_method_mapping: Dict[Pattern, Union[nn.Sequential, Callable]],
):
    """Get the fuser method for the given op pattern.

    This will be made the default after we deprecate ``get_fuser_method``;
    implementing it first here, with the deprecation to follow in a separate PR.
    """
    op_patterns = _get_valid_patterns(op_pattern)
    fuser_method = None
    for op_pattern in op_patterns:
        fuser_method = fuser_method_mapping.get(op_pattern, None)
        if fuser_method is not None:
            break
    assert fuser_method is not None, f"did not find fuser method for: {op_pattern} "
    return fuser_method
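

# A minimal sketch of how ``get_fuser_method_new`` resolves a pattern through
# ``MatchAllNode`` wildcards (the mapping below is hypothetical, not one of the
# default mappings in this file). ``_get_valid_patterns((nn.Conv2d,
# nn.BatchNorm2d))`` expands the query into four candidate keys, and
# ``(nn.Conv2d, MatchAllNode)`` is the first candidate present in the mapping:
#
#     >>> mapping = {(nn.Conv2d, MatchAllNode): fuse_conv_bn}
#     >>> fuser = get_fuser_method_new((nn.Conv2d, nn.BatchNorm2d), mapping)
#     >>> fuser is fuse_conv_bn
#     True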