from __future__ import annotations

import copy
from typing import Optional, Tuple, TypeVar

import torch


__all__ = [
    "fuse_conv_bn_eval",
    "fuse_conv_bn_weights",
    "fuse_linear_bn_eval",
    "fuse_linear_bn_weights",
]

ConvT = TypeVar("ConvT", bound="torch.nn.modules.conv._ConvNd")
LinearT = TypeVar("LinearT", bound="torch.nn.Linear")


def fuse_conv_bn_eval(
    conv: ConvT,
    bn: torch.nn.modules.batchnorm._BatchNorm,
    transpose: bool = False,
) -> ConvT:
    r"""Fuse a convolutional module and a BatchNorm module into a single, new convolutional module.

    Args:
        conv (torch.nn.modules.conv._ConvNd): A convolutional module.
        bn (torch.nn.modules.batchnorm._BatchNorm): A BatchNorm module.
        transpose (bool, optional): If True, transpose the convolutional weight. Defaults to False.

    Returns:
        torch.nn.modules.conv._ConvNd: The fused convolutional module.

    .. note::
        Both ``conv`` and ``bn`` must be in eval mode, and ``bn`` must have its running buffers computed.
    """
    assert not (conv.training or bn.training), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    assert bn.running_mean is not None and bn.running_var is not None
    fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights(
        fused_conv.weight,
        fused_conv.bias,
        bn.running_mean,
        bn.running_var,
        bn.eps,
        bn.weight,
        bn.bias,
        transpose,
    )

    return fused_conv


def fuse_conv_bn_weights(
    conv_w: torch.Tensor,
    conv_b: Optional[torch.Tensor],
    bn_rm: torch.Tensor,
    bn_rv: torch.Tensor,
    bn_eps: float,
    bn_w: Optional[torch.Tensor],
    bn_b: Optional[torch.Tensor],
    transpose: bool = False,
) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]:
    r"""Fuse convolutional module parameters and BatchNorm module parameters into new convolutional module parameters.

    Args:
        conv_w (torch.Tensor): Convolutional weight.
        conv_b (Optional[torch.Tensor]): Convolutional bias.
        bn_rm (torch.Tensor): BatchNorm running mean.
        bn_rv (torch.Tensor): BatchNorm running variance.
        bn_eps (float): BatchNorm epsilon.
        bn_w (Optional[torch.Tensor]): BatchNorm weight.
        bn_b (Optional[torch.Tensor]): BatchNorm bias.
        transpose (bool, optional): If True, transpose the conv weight. Defaults to False.

    Returns:
        Tuple[torch.nn.Parameter, torch.nn.Parameter]: Fused convolutional weight and bias.
    """
    conv_weight_dtype = conv_w.dtype
    conv_bias_dtype = conv_b.dtype if conv_b is not None else conv_weight_dtype
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)

    if transpose:
        shape = [1, -1] + [1] * (len(conv_w.shape) - 2)
    else:
        shape = [-1, 1] + [1] * (len(conv_w.shape) - 2)

    fused_conv_w = (conv_w * (bn_w * bn_var_rsqrt).reshape(shape)).to(
        dtype=conv_weight_dtype
    )
    fused_conv_b = ((conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b).to(
        dtype=conv_bias_dtype
    )

    return (
        torch.nn.Parameter(fused_conv_w, conv_w.requires_grad),
        torch.nn.Parameter(fused_conv_b, conv_b.requires_grad),
    )


def fuse_linear_bn_eval(
    linear: LinearT,
    bn: torch.nn.modules.batchnorm._BatchNorm,
) -> LinearT:
    r"""Fuse a linear module and a BatchNorm module into a single, new linear module.

    Args:
        linear (torch.nn.Linear): A Linear module.
        bn (torch.nn.modules.batchnorm._BatchNorm): A BatchNorm module.

    Returns:
        torch.nn.Linear: The fused linear module.

    .. note::
        Both ``linear`` and ``bn`` must be in eval mode, and ``bn`` must have its running buffers computed.
    """
    assert not (linear.training or bn.training), "Fusion only for eval!"
    fused_linear = copy.deepcopy(linear)

    """
    Linear-BN needs to be fused while preserving the shapes of linear weight/bias.
    To preserve the shapes of linear weight/bias, the channel dim of bn needs to be broadcastable
    with the last dim of linear, because bn operates over the channel dim, (N, C_in, H, W),
    while linear operates over the last dim, (*, H_in).
    To be broadcastable, the number of features in bn and
    the number of output features from linear must satisfy the following condition:
    1. they are equal, or
    2. the number of features in bn is 1
    Otherwise, skip the folding path.
    """
    assert (
        linear.out_features == bn.num_features or bn.num_features == 1
    ), "To fuse, linear.out_features == bn.num_features or bn.num_features == 1"

    assert bn.running_mean is not None and bn.running_var is not None
    fused_linear.weight, fused_linear.bias = fuse_linear_bn_weights(
        fused_linear.weight,
        fused_linear.bias,
        bn.running_mean,
        bn.running_var,
        bn.eps,
        bn.weight,
        bn.bias,
    )

    return fused_linear


def fuse_linear_bn_weights(
    linear_w: torch.Tensor,
    linear_b: Optional[torch.Tensor],
    bn_rm: torch.Tensor,
    bn_rv: torch.Tensor,
    bn_eps: float,
    bn_w: torch.Tensor,
    bn_b: torch.Tensor,
) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]:
    r"""Fuse linear module parameters and BatchNorm module parameters into new linear module parameters.

    Args:
        linear_w (torch.Tensor): Linear weight.
        linear_b (Optional[torch.Tensor]): Linear bias.
        bn_rm (torch.Tensor): BatchNorm running mean.
        bn_rv (torch.Tensor): BatchNorm running variance.
        bn_eps (float): BatchNorm epsilon.
        bn_w (torch.Tensor): BatchNorm weight.
        bn_b (torch.Tensor): BatchNorm bias.

    Returns:
        Tuple[torch.nn.Parameter, torch.nn.Parameter]: Fused linear weight and bias.
    """
    linear_weight_dtype = linear_w.dtype
    linear_bias_dtype = linear_b.dtype if linear_b is not None else linear_weight_dtype
    if linear_b is None:
        linear_b = torch.zeros_like(bn_rm)
    bn_scale = bn_w * torch.rsqrt(bn_rv + bn_eps)

    fused_w = linear_w * bn_scale.unsqueeze(-1).to(dtype=linear_weight_dtype)
    fused_b = ((linear_b - bn_rm) * bn_scale + bn_b).to(dtype=linear_bias_dtype)

    return torch.nn.Parameter(fused_w, linear_w.requires_grad), torch.nn.Parameter(
        fused_b, linear_b.requires_grad
    )
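

# The block below is a minimal usage sketch, not part of the original module.
# It assumes made-up shapes for a Conv2d/BatchNorm2d pair and a Linear/BatchNorm1d
# pair, folds each BatchNorm into the preceding layer with the helpers above, and
# checks that the fused module reproduces the unfused eval-mode output.
#
# The folding math implemented by the functions above is:
#     scale = bn_w / sqrt(bn_rv + bn_eps)
#     W'    = W * scale        (scale broadcast over the output-channel dim)
#     b'    = (b - bn_rm) * scale + bn_b
if __name__ == "__main__":
    torch.manual_seed(0)

    # Conv + BN: both modules must be in eval mode with running stats populated.
    conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
    bn2d = torch.nn.BatchNorm2d(8)
    bn2d.running_mean.uniform_(-1.0, 1.0)  # stand-in for learned statistics
    bn2d.running_var.uniform_(0.5, 2.0)
    conv.eval()
    bn2d.eval()

    x = torch.randn(2, 3, 16, 16)
    fused_conv = fuse_conv_bn_eval(conv, bn2d)
    assert torch.allclose(fused_conv(x), bn2d(conv(x)), atol=1e-5)

    # Linear + BN: bn.num_features must equal linear.out_features (or be 1).
    linear = torch.nn.Linear(10, 5)
    bn1d = torch.nn.BatchNorm1d(5)
    bn1d.running_mean.uniform_(-1.0, 1.0)
    bn1d.running_var.uniform_(0.5, 2.0)
    linear.eval()
    bn1d.eval()

    y = torch.randn(4, 10)
    fused_linear = fuse_linear_bn_eval(linear, bn1d)
    assert torch.allclose(fused_linear(y), bn1d(linear(y)), atol=1e-5)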