# mypy: allow-untyped-defs
import abc
import copy
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple, Type

import torch
from torch import nn
from torch.nn.utils import parametrize
from torch.nn.utils.parametrize import type_before_parametrizations

from .utils import (
    FakeSparsity,
    get_arg_info_from_tensor_fqn,
    module_contains_param,
    module_to_fqn,
    swap_module,
)


__all__ = ["BaseSparsifier"]

SUPPORTED_MODULES = {nn.Linear}

KEYS_NOT_IN_STATE_DICT = ["module", "module_fqn", "tensor_name"]


# TODO update desc with new config args
class BaseSparsifier(abc.ABC):
    r"""Base class for all sparsifiers.

    Abstract methods that need to be implemented:

    - update_mask: Function to compute a new mask for all keys in the
        `groups`.

    Args:
        - model [nn.Module]: model to configure. The model itself is not saved
            but used for the state_dict saving / loading.
        - config [list]: list of configuration elements; each element should be a
            dict that includes the `tensor_fqn` of the tensor to sparsify
        - defaults [dict]: default configurations will be attached to the
            configuration. Only the keys that don't exist in the `config` will
            be updated.

    Example::

        >>> # xdoctest: +SKIP("Can't instantiate abstract class BaseSparsifier with abstract method update_mask")
        >>> config = [
        ...     {'tensor_fqn': 'layer1.weight'},
        ...     {'tensor_fqn': 'linear2.weight2', 'sparsity_level': 0.5},
        ... ]
        >>> defaults = {'sparsity_level': 0.7}
        >>> # model.layer1.weight will have `sparsity_level` = 0.7 (getting the default)
        >>> # model.linear2.weight2 will have `sparsity_level` = 0.5 (from its own entry)
        >>> sparsifier = BaseSparsifier(defaults)
        >>> sparsifier.prepare(model, config)
    """

    def __init__(self, defaults: Optional[Dict[str, Any]] = None):
        super().__init__()
        self.defaults: Dict[str, Any] = defaults or {}

        self.state: Dict[str, Dict] = defaultdict(dict)
        self.groups: List[Dict[str, Any]] = []
        self.enable_mask_update = True
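    # `self.groups` ends up holding one dict per sparsified tensor (the resolved
    # `module`, `module_fqn`, `tensor_name`, `tensor_fqn` plus any per-tensor
    # options merged with `self.defaults`), while `self.state` maps each
    # `tensor_fqn` to mutable state such as the mask. Both are populated by
    # `prepare` / `_prepare` below.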
    def __getstate__(self) -> Dict[str, Any]:
        return {
            "defaults": self.defaults,
            "state": self.state,
            "groups": self.groups,
        }

    def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None:
        self.__dict__.update(state)

    def __repr__(self):
        format_string = self.__class__.__name__ + " ("
        for i, sparse_args in enumerate(self.groups):
            module = sparse_args["module"]
            format_string += "\n"
            format_string += f"\tGroup {i}\n"
            format_string += f"\t    module: {module}\n"
            for key in sorted(sparse_args.keys()):
                if key == "module":
                    continue
                format_string += f"\t    {key}: {sparse_args[key]}\n"
        format_string += ")"
        return format_string

    def state_dict(self) -> Dict[str, Any]:
        r"""Returns the state of the sparsifier as a :class:`dict`.

        It contains:
        * state - current state of the sparsification.
        * groups - a list containing all sparsity configuration groups
            with the key 'tensor_fqn' specifying the path to the sparsified tensor within a model

        TODO: Need a clean way of loading the state of the "prepared" module
        """

        groups: List[Dict[str, Any]] = [
            dict(
                filter(
                    lambda key_value: key_value[0] not in KEYS_NOT_IN_STATE_DICT,
                    mg.items(),
                )
            )
            for mg in self.groups
        ]

        return {
            "state": self.state,
            "groups": groups,
        }

    def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True):
        groups = copy.deepcopy(state_dict["groups"])
        states = state_dict["state"]
        for tensor_fqn, s in states.items():
            arg_info = get_arg_info_from_tensor_fqn(self.model, tensor_fqn)
            module = arg_info["module"]
            tensor_name = arg_info["tensor_name"]
            if strict and module is None:
                raise RuntimeError(f"Error loading {tensor_fqn} into the model")

            found = False
            for p in module.parametrizations[tensor_name]:
                if isinstance(p, FakeSparsity):
                    found = True
                    break
            if not found:
                p = FakeSparsity(torch.ones(getattr(module, tensor_name).shape))
                parametrize.register_parametrization(module, tensor_name, p)
            if s.get("mask", None) is not None:
                mask = s.pop("mask")
                p.mask = mask

            for mg in groups:
                if mg["tensor_fqn"] == tensor_fqn:
                    mg.update(arg_info)
        self.__setstate__({"state": states, "groups": groups})
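    # Round-trip sketch for the two methods above (illustrative; the file name
    # is hypothetical). `load_state_dict` resolves each tensor_fqn against
    # `self.model`, so `prepare` must already have been called on the loading
    # sparsifier:
    #
    #     torch.save(sparsifier.state_dict(), "sparsifier_state.pt")
    #     ...
    #     new_sparsifier.prepare(model, config)
    #     new_sparsifier.load_state_dict(torch.load("sparsifier_state.pt"))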
    def make_config_from_model(
        self,
        model: nn.Module,
        SUPPORTED_MODULES: Set[Type] = SUPPORTED_MODULES,
    ) -> None:
        self.config = []
        stack = [model]
        while stack:
            module = stack.pop()
            for name, child in module.named_children():
                if type(child) in SUPPORTED_MODULES:
                    module_fqn = module_to_fqn(model, child)
                    assert isinstance(module_fqn, str)  # for mypy
                    self.config.append({"tensor_fqn": module_fqn + ".weight"})
                else:
                    stack.append(child)

    def prepare(self, model, config):
        r"""Prepares a model by adding the parametrizations.

        Note::

            The model is modified inplace. If you need to preserve the original
            model, use copy.deepcopy.
        """
        self.model = model  # TODO: Need to figure out how to load without this.
        self.config = config

        # If no config -- try getting all the supported layers
        if self.config is None:
            self.make_config_from_model(model)

        # TODO: Remove the configuration by reference ('module')
        for module_config in self.config:
            assert isinstance(module_config, dict), (
                "config elements should be dicts, not modules, i.e.: "
                "[{`tensor_fqn`: `foo.bar.weight`}, {`tensor_fqn`: ... }, ...]"
            )

            assert isinstance(self.defaults, Dict)  # for mypy
            local_args = copy.deepcopy(self.defaults)
            local_args.update(module_config)

            tensor_fqn = local_args.get("tensor_fqn", None)
            assert tensor_fqn is not None, (
                "tensor_fqn is a required argument in the sparsity config, which "
                "replaces the previous `module` and [module]`fqn` arguments"
            )

            # populate all information from tensor_fqn
            info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn)

            # check that whatever was put into local_args agrees with what was obtained
            # from tensor_fqn
            for key in info_from_tensor_fqn.keys():
                if key in local_args:
                    assert (
                        info_from_tensor_fqn[key] == local_args[key]
                        or (
                            key == "tensor_fqn"
                            and "." + info_from_tensor_fqn[key] == local_args[key]
                        )
                        # info_from_tensor_fqn will chop leading '.' from tensor_fqn so ignore that
                    ), f"Given both `{key}` and `tensor_fqn` in the config, they are expected to agree!"
            local_args.update(info_from_tensor_fqn)
            self.groups.append(local_args)
        self._prepare()

    def _prepare(self, *args, **kwargs):
        r"""Adds mask parametrization to the layer weight"""
        for config in self.groups:
            module = config["module"]
            tensor_name = config["tensor_name"]
            parametrization = config.get("parametrization", FakeSparsity)
            mask = config.get("mask", torch.ones_like(getattr(module, tensor_name)))
            self.state[config["tensor_fqn"]]["mask"] = mask
            parametrize.register_parametrization(
                module, tensor_name, parametrization(mask)
            )
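    # Note: with the default FakeSparsity parametrization registered above,
    # `getattr(module, tensor_name)` now returns roughly `mask * original`, and
    # the dense values live in `module.parametrizations[tensor_name].original`.
    # The mask stored in `self.state[tensor_fqn]["mask"]` is the same tensor the
    # parametrization holds, so `update_mask` implementations can modify it in
    # place.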
    def squash_mask(
        self,
        params_to_keep: Optional[Tuple[str, ...]] = None,
        params_to_keep_per_layer: Optional[Dict[str, Tuple[str, ...]]] = None,
        *args,
        **kwargs,
    ):
        r"""Squashes the sparse masks into the appropriate tensors.

        If either `params_to_keep` or `params_to_keep_per_layer` is set,
        the affected modules will have a `sparse_params` dict attached to them.

        Args:
            params_to_keep: Tuple of sparsity config keys whose values will be
                saved in the `sparse_params` dict of every sparsified module
            params_to_keep_per_layer: Dict to specify the params that should be
                saved for specific layers. The keys in the dict
                should be the module fqn, while the values should
                be a tuple of strings with the names of the config keys
                to save in the `sparse_params`

        Examples:
            >>> # xdoctest: +SKIP("locals are undefined")
            >>> # Don't save any sparse params
            >>> sparsifier.squash_mask()
            >>> hasattr(model.submodule1, 'sparse_params')
            False

            >>> # Keep sparse params per layer
            >>> sparsifier.squash_mask(
            ...     params_to_keep_per_layer={
            ...         'submodule1.linear1': ('foo', 'bar'),
            ...         'submodule2.linear42': ('baz',)
            ...     })
            >>> print(model.submodule1.linear1.sparse_params)
            {'foo': 42, 'bar': 24}
            >>> print(model.submodule2.linear42.sparse_params)
            {'baz': 0.1}

            >>> # Keep sparse params for all layers
            >>> sparsifier.squash_mask(params_to_keep=('foo', 'bar'))
            >>> print(model.submodule1.linear1.sparse_params)
            {'foo': 42, 'bar': 24}
            >>> print(model.submodule2.linear42.sparse_params)
            {'foo': 42, 'bar': 24}

            >>> # Keep some sparse params for all layers, and specific ones for
            >>> # some other layers
            >>> sparsifier.squash_mask(
            ...     params_to_keep=('foo', 'bar'),
            ...     params_to_keep_per_layer={
            ...         'submodule2.linear42': ('baz',)
            ...     })
            >>> print(model.submodule1.linear1.sparse_params)
            {'foo': 42, 'bar': 24}
            >>> print(model.submodule2.linear42.sparse_params)
            {'foo': 42, 'bar': 24, 'baz': 0.1}
        """
        for config in self.groups:
            module = config["module"]
            tensor_name = config["tensor_name"]
            parametrize.remove_parametrizations(
                module, tensor_name, leave_parametrized=True
            )
            sparse_params = {}
            if params_to_keep is not None:
                global_params = {k: config[k] for k in params_to_keep}
                sparse_params.update(global_params)
            if params_to_keep_per_layer is not None:
                params = params_to_keep_per_layer.get(config["module_fqn"], None)
                if params is not None:
                    per_layer_params = {k: config[k] for k in params}
                    sparse_params.update(per_layer_params)
            if sparse_params:
                # TODO handle multiple tensors being sparsified on a single module; where to store sparse_params?
                module.sparse_params = sparse_params

    def convert(
        self,
        module: nn.Module,
        mapping: Optional[Dict[Type[nn.Module], Type[nn.Module]]] = None,
        inplace: bool = False,
        parameterization: Type[nn.Module] = FakeSparsity,
    ):
        r"""Converts submodules of the input module to a different module type according
        to `mapping`, by calling the `from_dense` method on the target module class.

        Args:
            module: input module
            mapping: a dictionary that maps from source module type to target
                module type; can be overridden to allow swapping of user-defined
                modules
            inplace: carry out model transformations in-place; the original module
                is mutated
            parameterization: the parametrization class used to detect sparsified
                submodules (FakeSparsity by default)
        """
        if mapping is None:
            raise NotImplementedError("Need to auto generate mapping")
        if not inplace:
            module = copy.deepcopy(module)

        reassign = {}
        for name, mod in module.named_children():
            # leaf node
            if (
                module_contains_param(mod, parameterization)
                and type_before_parametrizations(mod) in mapping
            ):
                reassign[name] = swap_module(mod, mapping)
            else:
                # recurse
                reassign[name] = self.convert(
                    mod,
                    mapping=mapping,
                    inplace=True,
                    parameterization=parameterization,
                )

        for key, value in reassign.items():
            module._modules[key] = value

        return module

    def step(self, use_path: bool = True) -> None:
        if not self.enable_mask_update:
            return
        with torch.no_grad():
            for config in self.groups:
                self.update_mask(**config)

    @abc.abstractmethod
    def update_mask(self, module: nn.Module, tensor_name: str, **kwargs):
        pass
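

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module's API).
# `_ThresholdSparsifier` is a hypothetical subclass written to show how
# `update_mask` is expected to interact with `prepare`, `step` and
# `squash_mask`, assuming the default FakeSparsity parametrization and a plain
# magnitude criterion.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _ThresholdSparsifier(BaseSparsifier):
        def update_mask(self, module, tensor_name, sparsity_level, **kwargs):
            # Zero the `sparsity_level` fraction of entries with the smallest
            # magnitude; the mask tensor is shared with the FakeSparsity
            # parametrization, so updating it in place is enough.
            original = module.parametrizations[tensor_name].original
            mask = self.state[kwargs["tensor_fqn"]]["mask"]
            k = int(sparsity_level * original.numel())
            if k == 0:
                return
            threshold = original.abs().flatten().kthvalue(k).values
            mask.copy_((original.abs() > threshold).to(mask.dtype))

    model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))
    sparsifier = _ThresholdSparsifier(defaults={"sparsity_level": 0.5})
    sparsifier.prepare(model, config=None)  # None -> all supported (Linear) layers
    sparsifier.step()  # recompute masks via update_mask
    sparsifier.squash_mask()  # fold masks into the weights, drop parametrizations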