# mypy: allow-untyped-defs
import torch
import torch.nn as nn
from torch._dynamo.utils import counters
from torch._inductor import config as inductor_config
from torch.func import functional_call

from ..pattern_matcher import (
    CallFunctionVarArgs,
    CallModuleVarArgs,
    Match,
    register_graph_pattern,
)
from .pre_grad import efficient_conv_bn_eval_pass


def _fold_bn_into_conv_params(
    conv_weight,
    conv_bias,
    bn_weight,
    bn_bias,
    bn_running_mean,
    bn_running_var,
    bn_eps,
    transposed: bool,
):
    """Fold eval-mode BatchNorm statistics into conv/linear parameters.

    Computes ``(weight, bias)`` such that running the conv/linear layer with
    them equals applying BatchNorm (with its running statistics) to the
    layer's original output, i.e.
    normalize(weight conv feature) = (normalized weight) conv feature.

    Args:
        conv_weight: weight of the conv/linear layer.
        conv_bias: bias of the conv/linear layer, or ``None`` (treated as 0).
        bn_weight: BatchNorm scale, or ``None`` (treated as 1).
        bn_bias: BatchNorm shift, or ``None`` (treated as 0).
        bn_running_mean: BatchNorm running mean.
        bn_running_var: BatchNorm running variance; must not be ``None``.
        bn_eps: BatchNorm epsilon.
        transposed: ``True`` when ``conv_weight`` belongs to a transposed conv,
            whose output-channel axis is dim 1 rather than dim 0.

    Returns:
        A ``(weight, bias)`` tuple for the fused layer.
    """
    assert bn_running_var is not None

    # These lines of code are designed to deal with various cases
    # like bn without affine transform, and conv without bias
    if conv_bias is None:
        conv_bias = torch.zeros_like(bn_running_var)
    if bn_weight is None:
        bn_weight = torch.ones_like(bn_running_var)
    if bn_bias is None:
        bn_bias = torch.zeros_like(bn_running_var)

    # shape of [C_out, 1, 1, 1] in Conv2d
    target_shape = [-1] + [1] * (conv_weight.ndim - 1)
    if transposed:
        # for transposed conv, the C_out dimension is at index 1
        target_shape[:2] = [target_shape[1], target_shape[0]]
    weight_coeff = torch.rsqrt(bn_running_var + bn_eps).reshape(target_shape)
    # per-output-channel scale bn_weight / sqrt(running_var + eps),
    # shaped to broadcast against the conv weight
    coeff = bn_weight.view_as(weight_coeff) * weight_coeff

    # shape of [C_out, C_in, k, k] in Conv2d
    fused_weight = conv_weight * coeff
    # shape of [C_out] in Conv2d
    fused_bias = bn_bias + coeff.flatten() * (conv_bias - bn_running_mean)
    return fused_weight, fused_bias


def efficient_conv_bn_eval(
    bn: nn.modules.batchnorm._BatchNorm, conv: nn.modules.conv._ConvNd, x: torch.Tensor
):
    """
    Implementation based on https://arxiv.org/abs/2305.11624
    "Efficient ConvBN Blocks for Transfer Learning and Beyond"
    It leverages the associative law between convolution and affine transform,
    i.e., normalize (weight conv feature) = (normalize weight) conv feature.
    It works for Eval mode of ConvBN blocks during validation, and can be used
    for **training** as well, but only if one sets `bn.training=False`. It
    reduces memory footprint and computation cost, at the cost of slightly
    reduced numerical stability.
    Args:
        bn (nn.modules.batchnorm._BatchNorm): a BatchNorm module.
        conv (nn.modules.conv._ConvNd): a conv module
        x (torch.Tensor): Input feature map.
    Returns:
        torch.Tensor: ``bn(conv(x))`` computed with a single call of ``conv``
        using BN-folded weight and bias.
    """
    weight, bias = _fold_bn_into_conv_params(
        conv.weight,
        conv.bias,
        bn.weight,
        bn.bias,
        bn.running_mean,
        bn.running_var,
        bn.eps,
        transposed=isinstance(conv, nn.modules.conv._ConvTransposeNd),
    )
    # run the conv module's own forward, but with the folded parameters
    return functional_call(conv, {"weight": weight, "bias": bias}, x)


def efficient_conv_bn_eval_decomposed(
    bn_weight,
    bn_bias,
    bn_running_mean,
    bn_running_var,
    bn_eps,
    conv: torch._ops.OpOverload,
    conv_weight,
    conv_bias,
    x,
    conv_remainging_args,
):
    """
    Implementation based on https://arxiv.org/abs/2305.11624
    "Efficient ConvBN Blocks for Transfer Learning and Beyond"
    It leverages the associative law between convolution and affine transform,
    i.e., normalize (weight conv feature) = (normalize weight) conv feature.
    It works for Eval mode of ConvBN blocks during validation, and can be used
    for **training** as well, but only if one sets `bn.training=False`. It
    reduces memory footprint and computation cost, at the cost of slightly
    reduced numerical stability.
    Args:
        bn_weight: BatchNorm scale, or ``None``.
        bn_bias: BatchNorm shift, or ``None``.
        bn_running_mean: BatchNorm running mean.
        bn_running_var: BatchNorm running variance (must not be ``None``).
        bn_eps: BatchNorm epsilon.
        conv: the conv/linear callable (torch function or aten overload),
            invoked as ``conv(x, weight, bias, *conv_remainging_args)``.
        conv_weight: weight passed to the original conv call.
        conv_bias: bias passed to the original conv call, or ``None``.
        x: input passed to the original conv call.
        conv_remainging_args: remaining positional args of the original conv
            call (e.g. stride, padding, ...).
    Returns:
        The output of ``conv`` with BN-folded weight and bias.
    """
    weight, bias = _fold_bn_into_conv_params(
        conv_weight,
        conv_bias,
        bn_weight,
        bn_bias,
        bn_running_mean,
        bn_running_var,
        bn_eps,
        # conv_transpose{1,2,3}d targets carry "conv_transpose" in their name
        transposed="conv_transpose" in str(conv),
    )
    return conv(x, weight, bias, *conv_remainging_args)


def _fuse_functional_conv_bn(
    graph,
    bn_node,
    supported_convs,
    transposed_convs,
    bn_weight,
    bn_bias,
    bn_running_mean,
    bn_running_var,
    bn_eps,
) -> None:
    """Replace a functional ``conv -> batch_norm`` pair with one fused node.

    Shared by the torch-function and aten-op variants of the pass. The graph
    is left untouched unless ``bn_node``'s input is produced by one of
    ``supported_convs`` and the pair can be folded safely.

    Args:
        graph: the FX graph owning ``bn_node``.
        bn_node: the eval-mode batch-norm call to eliminate.
        supported_convs: call targets accepted as producer of the BN input.
        transposed_convs: the subset of ``supported_convs`` that are
            transposed convolutions.
        bn_weight, bn_bias, bn_running_mean, bn_running_var, bn_eps:
            BatchNorm arguments already extracted from ``bn_node``.
    """
    # Check if the input is Conv
    conv_node = bn_node.args[0]
    if getattr(conv_node, "op", None) != "call_function":
        return
    if not any(conv_node.target is fn for fn in supported_convs):
        return
    # Output of conv is used by other nodes, cannot optimize
    if len(conv_node.users) > 1:
        return
    # The fused call is rebuilt from the conv's positional args only; any
    # keyword argument (e.g. ``stride=2`` or ``bias=b``) would be silently
    # dropped, so only fully positional calls are handled.
    if conv_node.kwargs or len(conv_node.args) < 2:
        return
    conv_args = conv_node.args
    # Transposed-conv weights are laid out as [C_in, C_out // groups, ...]; the
    # per-output-channel scale only broadcasts onto that layout when
    # groups == 1. ``groups`` is positional arg 6 of conv_transpose{1,2,3}d.
    if any(conv_node.target is fn for fn in transposed_convs):
        groups = conv_args[6] if len(conv_args) > 6 else 1
        if groups != 1:
            return

    counters["inductor"]["efficient_conv_bn_eval"] += 1

    conv_bias = conv_args[2] if len(conv_args) >= 3 else None
    with graph.inserting_before(bn_node):
        new_node = graph.create_node(
            op="call_function",
            target=efficient_conv_bn_eval_decomposed,
            args=(  # type: ignore[arg-type]
                bn_weight,
                bn_bias,
                bn_running_mean,
                bn_running_var,
                bn_eps,
                conv_node.target,
                conv_args[1],  # conv weight
                conv_bias,
                conv_args[0],  # conv input
                conv_args[3:],  # remaining positional conv args
            ),
            name="efficient_conv_bn_eval",
        )

    # this node replaces the original conv + bn, and therefore
    # should replace the uses of bn_node
    bn_node.replace_all_uses_with(new_node)
    # bn_node consumes conv_node, so bn_node must be erased first
    graph.erase_node(bn_node)
    graph.erase_node(conv_node)


@register_graph_pattern(
    CallFunctionVarArgs(
        [
            torch.nn.functional.batch_norm,
        ]
    ),
    pass_dict=efficient_conv_bn_eval_pass,
    extra_check=lambda match: not inductor_config.freezing
    and inductor_config.efficient_conv_bn_eval_fx_passes,
)
def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs):
    """Fuse an eval-mode ``torch.nn.functional.batch_norm`` into the
    torch-level conv/linear call producing its input.

    Only the fully positional form
    ``batch_norm(input, running_mean, running_var, weight, bias, training,
    momentum, eps)`` is handled; other call forms are left untouched.
    """
    bn_node = match.nodes[0]
    if len(bn_node.args) != 8:
        return

    # We can only use efficient conv-bn for eval mode;
    # bn_node.args[5] is `training`
    if bn_node.args[5]:
        return

    _fuse_functional_conv_bn(
        match.graph,
        bn_node,
        supported_convs=(
            torch._C._nn.linear,
            torch.conv1d,
            torch.conv2d,
            torch.conv3d,
            torch.conv_transpose1d,
            torch.conv_transpose2d,
            torch.conv_transpose3d,
        ),
        transposed_convs=(
            torch.conv_transpose1d,
            torch.conv_transpose2d,
            torch.conv_transpose3d,
        ),
        bn_weight=bn_node.args[3],
        bn_bias=bn_node.args[4],
        bn_running_mean=bn_node.args[1],
        bn_running_var=bn_node.args[2],
        bn_eps=bn_node.args[7],
    )


@register_graph_pattern(
    CallFunctionVarArgs(
        [
            torch.ops.aten.batch_norm.default,
        ]
    ),
    pass_dict=efficient_conv_bn_eval_pass,
    extra_check=lambda match: not inductor_config.freezing
    and inductor_config.efficient_conv_bn_eval_fx_passes,
)
def efficient_conv_bn_eval_graph_transform_decomposed(match: Match, *args, **kwargs):
    """Fuse an eval-mode ``aten.batch_norm`` into the aten conv/linear op
    producing its input.

    Only the fully positional form
    ``batch_norm(input, weight, bias, running_mean, running_var, training,
    momentum, eps, cudnn_enabled)`` is handled; other call forms are left
    untouched.
    """
    bn_node = match.nodes[0]
    if len(bn_node.args) != 9:
        return

    # We can only use efficient conv-bn for eval mode;
    # bn_node.args[5] is `training`
    if bn_node.args[5]:
        return

    _fuse_functional_conv_bn(
        match.graph,
        bn_node,
        supported_convs=(
            torch.ops.aten.linear.default,
            torch.ops.aten.conv1d.default,
            torch.ops.aten.conv2d.default,
            torch.ops.aten.conv3d.default,
            torch.ops.aten.conv_transpose1d.default,
            torch.ops.aten.conv_transpose2d.input,
            torch.ops.aten.conv_transpose3d.input,
        ),
        transposed_convs=(
            torch.ops.aten.conv_transpose1d.default,
            torch.ops.aten.conv_transpose2d.input,
            torch.ops.aten.conv_transpose3d.input,
        ),
        bn_weight=bn_node.args[1],
        bn_bias=bn_node.args[2],
        bn_running_mean=bn_node.args[3],
        bn_running_var=bn_node.args[4],
        bn_eps=bn_node.args[7],
    )


@register_graph_pattern(
    CallModuleVarArgs(
        [
            nn.modules.batchnorm._BatchNorm,
            nn.BatchNorm1d,
            nn.BatchNorm2d,
            nn.BatchNorm3d,
            nn.SyncBatchNorm,
        ],
    ),
    pass_dict=efficient_conv_bn_eval_pass,
    extra_check=lambda match: not inductor_config.freezing
    and inductor_config.efficient_conv_bn_eval_fx_passes,
)
def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
    """Fuse an eval-mode BatchNorm module call into the conv/linear module
    call producing its input, replacing both with ``efficient_conv_bn_eval``.
    """
    # We matched a BN node
    bn_node = match.nodes[0]
    graph = match.graph
    gm = graph.owning_module
    # get_submodule also resolves dotted targets such as "layer1.bn";
    # a plain getattr only finds direct children.
    bn_mod = gm.get_submodule(bn_node.target)  # type: ignore[arg-type, union-attr]

    # We can only use efficient conv-bn for eval mode with track_running_stats
    if not bn_mod.track_running_stats or bn_mod.training:
        return

    # Check if the input is Conv
    input_node = bn_node.args[0] if bn_node.args else bn_node.kwargs["input"]
    if getattr(input_node, "op", None) != "call_module":
        return
    try:
        input_mod = gm.get_submodule(input_node.target)  # type: ignore[arg-type, union-attr]
    except AttributeError:
        return
    supported_convs = (
        nn.Linear,
        nn.Conv1d,
        nn.Conv2d,
        nn.Conv3d,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
        nn.ConvTranspose3d,
    )
    if not isinstance(input_mod, supported_convs):
        return
    # Transposed-conv weights are [C_in, C_out // groups, ...]; the
    # per-output-channel scale only broadcasts onto them when groups == 1.
    if (
        isinstance(input_mod, nn.modules.conv._ConvTransposeNd)
        and input_mod.groups != 1
    ):
        return
    conv_node = input_node
    # Output of conv is used by other nodes, cannot optimize
    if len(conv_node.users) > 1:  # type: ignore[union-attr]
        return

    # Found a pair of conv and bn computation nodes to optimize.
    counters["inductor"]["efficient_conv_bn_eval"] += 1

    with graph.inserting_before(conv_node):  # type: ignore[arg-type]
        # create `get_attr` node to access modules
        # note that we directly call `create_node` to fill the `name`
        # argument. `graph.get_attr` and
        # `graph.call_function` does not allow the `name` argument.
        conv_get_node = graph.create_node(
            op="get_attr", target=conv_node.target, name="get_conv"  # type: ignore[union-attr]
        )
        bn_get_node = graph.create_node(
            op="get_attr", target=bn_node.target, name="get_bn"
        )
        if conv_node.args:  # type: ignore[union-attr]
            conv_input = conv_node.args[0]  # type: ignore[union-attr]
        else:
            conv_input = conv_node.kwargs["input"]  # type: ignore[union-attr]
        # create a new node calling the fused function
        new_node = graph.create_node(
            op="call_function",
            target=efficient_conv_bn_eval,
            args=(bn_get_node, conv_get_node, conv_input),
            name="efficient_conv_bn_eval",
        )
        # this node replaces the original conv + bn, and therefore
        # should replace the uses of bn_node
        bn_node.replace_all_uses_with(new_node)
        # bn_node consumes conv_node, so bn_node must be erased first
        graph.erase_node(bn_node)
        graph.erase_node(conv_node)  # type: ignore[arg-type]