# mypy: allow-untyped-defs
import copy
from typing import Any, Dict

import torch


__all__ = [
    "set_module_weight",
    "set_module_bias",
    "has_bias",
    "get_module_weight",
    "get_module_bias",
    "max_over_ndim",
    "min_over_ndim",
    "channel_range",
    "get_name_by_module",
    "cross_layer_equalization",
    "process_paired_modules_list_to_name",
    "expand_groups_in_paired_modules_list",
    "equalize",
    "converged",
]

_supported_types = {torch.nn.Conv2d, torch.nn.Linear, torch.nn.Conv1d}
_supported_intrinsic_types = {
    torch.ao.nn.intrinsic.ConvReLU2d,
    torch.ao.nn.intrinsic.LinearReLU,
    torch.ao.nn.intrinsic.ConvReLU1d,
}
_all_supported_types = _supported_types.union(_supported_intrinsic_types)


def set_module_weight(module, weight) -> None:
    if type(module) in _supported_types:
        module.weight = torch.nn.Parameter(weight)
    else:
        # intrinsic (fused) modules hold the conv/linear submodule at index 0
        module[0].weight = torch.nn.Parameter(weight)


def set_module_bias(module, bias) -> None:
    if type(module) in _supported_types:
        module.bias = torch.nn.Parameter(bias)
    else:
        module[0].bias = torch.nn.Parameter(bias)


def has_bias(module) -> bool:
    if type(module) in _supported_types:
        return module.bias is not None
    else:
        return module[0].bias is not None


def get_module_weight(module):
    if type(module) in _supported_types:
        return module.weight
    else:
        return module[0].weight


def get_module_bias(module):
    if type(module) in _supported_types:
        return module.bias
    else:
        return module[0].bias


def max_over_ndim(input, axis_list, keepdim=False):
    """Apply 'torch.max' over the given axes."""
    axis_list.sort(reverse=True)
    for axis in axis_list:
        input, _ = input.max(axis, keepdim)
    return input


def min_over_ndim(input, axis_list, keepdim=False):
    """Apply 'torch.min' over the given axes."""
    axis_list.sort(reverse=True)
    for axis in axis_list:
        input, _ = input.min(axis, keepdim)
    return input


def channel_range(input, axis=0):
    """Find the range of weights associated with a specific channel."""
    size_of_tensor_dim = input.ndim
    axis_list = list(range(size_of_tensor_dim))
    axis_list.remove(axis)

    mins = min_over_ndim(input, axis_list)
    maxs = max_over_ndim(input, axis_list)

    assert mins.size(0) == input.size(
        axis
    ), "Size of the resulting channel range does not match the size of the requested axis"
    return maxs - mins


def get_name_by_module(model, module):
    """Get the name of a module within a model.

    Args:
        model: a model (nn.Module) that equalization is to be applied on
        module: a module within the model

    Returns:
        name: the name of the module within the model
    """
    for name, m in model.named_modules():
        if m is module:
            return name
    raise ValueError("module is not in the model")
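
# A minimal, illustrative sketch of what ``channel_range`` computes. The helper
# below is hypothetical (added here for documentation only, not part of the
# original API): for a Conv2d-style weight of shape
# [out_channels, in_channels, kH, kW], reducing over every axis except axis 0
# leaves one (max - min) value per output channel.
def _channel_range_example() -> torch.Tensor:
    weight = torch.randn(8, 3, 3, 3)  # e.g. the weight of a Conv2d(3, 8, 3)
    per_channel = channel_range(weight, axis=0)
    assert per_channel.shape == (8,)  # one range value per output channel
    return per_channel
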
def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
    """Scale the range of module1's output channels to equal that of module2's input channels.

    Given two adjacent modules, the weights are scaled such that the ranges of
    the first module's output channels are equal to the ranges of the second
    module's input channels.
    """
    if (
        type(module1) not in _all_supported_types
        or type(module2) not in _all_supported_types
    ):
        raise ValueError(
            f"module type not supported: {type(module1)} {type(module2)}"
        )

    conv1_has_bias = has_bias(module1)
    bias = None

    weight1 = get_module_weight(module1)
    weight2 = get_module_weight(module2)

    if weight1.size(output_axis) != weight2.size(input_axis):
        raise TypeError(
            "Number of output channels of the first arg does not match the "
            "number of input channels of the second arg"
        )

    if conv1_has_bias:
        bias = get_module_bias(module1)

    weight1_range = channel_range(weight1, output_axis)
    weight2_range = channel_range(weight2, input_axis)

    # produce the scaling factors to be applied; the epsilon guards against
    # division by zero for channels whose weights are all equal
    weight2_range += 1e-9
    scaling_factors = torch.sqrt(weight1_range / weight2_range)
    inverse_scaling_factors = torch.reciprocal(scaling_factors)

    if conv1_has_bias:
        bias = bias * inverse_scaling_factors

    # reshape the (1D) scaling tensors so they broadcast against the weights:
    # every axis is padded to size 1 except the axis being scaled
    size1 = [1] * weight1.ndim
    size1[output_axis] = weight1.size(output_axis)
    size2 = [1] * weight2.ndim
    size2[input_axis] = weight2.size(input_axis)

    scaling_factors = torch.reshape(scaling_factors, size2)
    inverse_scaling_factors = torch.reshape(inverse_scaling_factors, size1)

    weight1 = weight1 * inverse_scaling_factors
    weight2 = weight2 * scaling_factors

    set_module_weight(module1, weight1)
    if conv1_has_bias:
        set_module_bias(module1, bias)
    set_module_weight(module2, weight2)


def process_paired_modules_list_to_name(model, paired_modules_list):
    """Process a list of paired modules into a list of names of paired modules."""
    for group in paired_modules_list:
        for i, item in enumerate(group):
            if isinstance(item, torch.nn.Module):
                group[i] = get_name_by_module(model, item)
            elif not isinstance(item, str):
                raise TypeError("item must be a nn.Module or a string")
    return paired_modules_list


def expand_groups_in_paired_modules_list(paired_modules_list):
    """Expand module pair groups larger than two into groups of two modules."""
    new_list = []

    for group in paired_modules_list:
        if len(group) < 2:
            raise ValueError("Group must have at least two modules")
        elif len(group) == 2:
            new_list.append(group)
        else:
            # a group [a, b, c] expands to consecutive pairs [a, b] and [b, c]
            for i in range(len(group) - 1):
                new_list.append([group[i], group[i + 1]])

    return new_list
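
# A minimal, illustrative sketch of a single cross-layer-equalization step (the
# helper name is hypothetical, added for documentation only). Two adjacent
# Linear layers are rescaled so the per-channel range of the first layer's
# output channels matches the per-channel range of the second layer's input
# channels; with a piece-wise-linear activation such as ReLU between them, the
# composed function is preserved up to floating-point error.
def _cross_layer_equalization_example() -> None:
    fc1 = torch.nn.Linear(8, 16)
    fc2 = torch.nn.Linear(16, 4)
    cross_layer_equalization(fc1, fc2)
    # After the call the two per-channel ranges agree.
    assert torch.allclose(
        channel_range(fc1.weight, axis=0), channel_range(fc2.weight, axis=1)
    )
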
def equalize(model, paired_modules_list, threshold=1e-4, inplace=True):
    """Equalize modules until convergence is achieved.

    Given a list of adjacent modules within a model, equalization is applied
    between each pair; this is repeated until convergence is achieved.

    Keeps a copy of the changing modules from the previous iteration; if the
    copies are not meaningfully different from the current modules (as
    determined by ``converged``), then the modules have converged enough that
    further equalization is not necessary.

    Reference is section 4.1 of this paper: https://arxiv.org/pdf/1906.04721.pdf

    Args:
        model: a model (nn.Module) that equalization is to be applied on
        paired_modules_list (List[List[Union[nn.Module, str]]]): a list of lists
            where each sublist is a pair of two submodules found in the model.
            For each pair the two modules have to be adjacent in the model,
            with only piece-wise-linear functions such as a (P)ReLU or
            LeakyReLU in between, to get the expected results.
            The list can contain either modules or names of modules in the model.
            If you pass more than two modules in the same sublist, they are
            equalized together as consecutive pairs.
        threshold (float): a number used by the ``converged`` function to
            determine what degree of similarity between models is necessary
            for them to be called equivalent
        inplace (bool): determines whether the equalization is done in place or
            on a deep copy of the model
    """
    paired_modules_list = process_paired_modules_list_to_name(
        model, paired_modules_list
    )

    if not inplace:
        model = copy.deepcopy(model)

    paired_modules_list = expand_groups_in_paired_modules_list(paired_modules_list)

    name_to_module: Dict[str, torch.nn.Module] = {}
    previous_name_to_module: Dict[str, Any] = {}
    name_set = {name for pair in paired_modules_list for name in pair}

    for name, module in model.named_modules():
        if name in name_set:
            name_to_module[name] = module
            previous_name_to_module[name] = None

    while not converged(name_to_module, previous_name_to_module, threshold):
        for pair in paired_modules_list:
            previous_name_to_module[pair[0]] = copy.deepcopy(name_to_module[pair[0]])
            previous_name_to_module[pair[1]] = copy.deepcopy(name_to_module[pair[1]])

            cross_layer_equalization(name_to_module[pair[0]], name_to_module[pair[1]])

    return model


def converged(curr_modules, prev_modules, threshold=1e-4):
    """Test whether modules are converged to a specified threshold.

    Tests for the summed norm of the differences between each pair of modules
    being less than the given threshold.

    Takes two dictionaries mapping names to modules; the sets of names in the
    two dictionaries must be the same. For each name, the difference between
    the weights of the associated modules is accumulated.
    """
    if curr_modules.keys() != prev_modules.keys():
        raise ValueError(
            "The keys to the given mappings must have the same set of names of modules"
        )

    summed_norms = torch.tensor(0.0)
    if None in prev_modules.values():
        return False
    for name in curr_modules.keys():
        curr_weight = get_module_weight(curr_modules[name])
        prev_weight = get_module_weight(prev_modules[name])

        difference = curr_weight.sub(prev_weight)
        summed_norms += torch.norm(difference)
    return bool(summed_norms < threshold)
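

# A small, self-contained usage sketch (illustrative only; this module does not
# otherwise define an entry point). Two adjacent Linear layers, separated by a
# piece-wise-linear ReLU, are equalized until the weights change by less than
# the threshold between iterations.
if __name__ == "__main__":
    model = torch.nn.Sequential(
        torch.nn.Linear(8, 16),
        torch.nn.ReLU(),
        torch.nn.Linear(16, 4),
    )
    # Submodules can be referenced by name ("0" and "2" in a Sequential) or
    # passed directly as modules.
    equalized = equalize(model, [["0", "2"]], threshold=1e-4, inplace=False)
    print(equalized)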