1# mypy: allow-untyped-defs 2r"""Pruning methods.""" 3import numbers 4from abc import ABC, abstractmethod 5from collections.abc import Iterable 6from typing import Tuple 7 8import torch 9 10 11class BasePruningMethod(ABC): 12 r"""Abstract base class for creation of new pruning techniques. 13 14 Provides a skeleton for customization requiring the overriding of methods 15 such as :meth:`compute_mask` and :meth:`apply`. 16 """ 17 18 _tensor_name: str 19 20 def __call__(self, module, inputs): 21 r"""Multiply the mask into original tensor and store the result. 22 23 Multiplies the mask (stored in ``module[name + '_mask']``) 24 into the original tensor (stored in ``module[name + '_orig']``) 25 and stores the result into ``module[name]`` by using :meth:`apply_mask`. 26 27 Args: 28 module (nn.Module): module containing the tensor to prune 29 inputs: not used. 30 """ 31 setattr(module, self._tensor_name, self.apply_mask(module)) 32 33 @abstractmethod 34 def compute_mask(self, t, default_mask): 35 r"""Compute and returns a mask for the input tensor ``t``. 36 37 Starting from a base ``default_mask`` (which should be a mask of ones 38 if the tensor has not been pruned yet), generate a random mask to 39 apply on top of the ``default_mask`` according to the specific pruning 40 method recipe. 41 42 Args: 43 t (torch.Tensor): tensor representing the importance scores of the 44 parameter to prune. 45 default_mask (torch.Tensor): Base mask from previous pruning 46 iterations, that need to be respected after the new mask is 47 applied. Same dims as ``t``. 48 49 Returns: 50 mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t`` 51 """ 52 53 def apply_mask(self, module): 54 r"""Simply handles the multiplication between the parameter being pruned and the generated mask. 55 56 Fetches the mask and the original tensor from the module 57 and returns the pruned version of the tensor. 58 59 Args: 60 module (nn.Module): module containing the tensor to prune 61 62 Returns: 63 pruned_tensor (torch.Tensor): pruned version of the input tensor 64 """ 65 # to carry out the multiplication, the mask needs to have been computed, 66 # so the pruning method must know what tensor it's operating on 67 assert ( 68 self._tensor_name is not None 69 ), f"Module {module} has to be pruned" # this gets set in apply() 70 mask = getattr(module, self._tensor_name + "_mask") 71 orig = getattr(module, self._tensor_name + "_orig") 72 pruned_tensor = mask.to(dtype=orig.dtype) * orig 73 return pruned_tensor 74 75 @classmethod 76 def apply(cls, module, name, *args, importance_scores=None, **kwargs): 77 r"""Add pruning on the fly and reparametrization of a tensor. 78 79 Adds the forward pre-hook that enables pruning on the fly and 80 the reparametrization of a tensor in terms of the original tensor 81 and the pruning mask. 82 83 Args: 84 module (nn.Module): module containing the tensor to prune 85 name (str): parameter name within ``module`` on which pruning 86 will act. 87 args: arguments passed on to a subclass of 88 :class:`BasePruningMethod` 89 importance_scores (torch.Tensor): tensor of importance scores (of 90 same shape as module parameter) used to compute mask for pruning. 91 The values in this tensor indicate the importance of the 92 corresponding elements in the parameter being pruned. 93 If unspecified or None, the parameter will be used in its place. 94 kwargs: keyword arguments passed on to a subclass of a 95 :class:`BasePruningMethod` 96 """ 97 98 def _get_composite_method(cls, module, name, *args, **kwargs): 99 # Check if a pruning method has already been applied to 100 # `module[name]`. If so, store that in `old_method`. 101 old_method = None 102 found = 0 103 # there should technically be only 1 hook with hook.name == name 104 # assert this using `found` 105 hooks_to_remove = [] 106 for k, hook in module._forward_pre_hooks.items(): 107 # if it exists, take existing thing, remove hook, then 108 # go through normal thing 109 if isinstance(hook, BasePruningMethod) and hook._tensor_name == name: 110 old_method = hook 111 hooks_to_remove.append(k) 112 found += 1 113 assert ( 114 found <= 1 115 ), f"Avoid adding multiple pruning hooks to the\ 116 same tensor {name} of module {module}. Use a PruningContainer." 117 118 for k in hooks_to_remove: 119 del module._forward_pre_hooks[k] 120 121 # Apply the new pruning method, either from scratch or on top of 122 # the previous one. 123 method = cls(*args, **kwargs) # new pruning 124 # Have the pruning method remember what tensor it's been applied to 125 method._tensor_name = name 126 127 # combine `methods` with `old_method`, if `old_method` exists 128 if old_method is not None: # meaning that there was a hook 129 # if the hook is already a pruning container, just add the 130 # new pruning method to the container 131 if isinstance(old_method, PruningContainer): 132 old_method.add_pruning_method(method) 133 method = old_method # rename old_method --> method 134 135 # if the hook is simply a single pruning method, create a 136 # container, add the old pruning method and the new one 137 elif isinstance(old_method, BasePruningMethod): 138 container = PruningContainer(old_method) 139 # Have the pruning method remember the name of its tensor 140 # setattr(container, '_tensor_name', name) 141 container.add_pruning_method(method) 142 method = container # rename container --> method 143 return method 144 145 method = _get_composite_method(cls, module, name, *args, **kwargs) 146 # at this point we have no forward_pre_hooks but we could have an 147 # active reparametrization of the tensor if another pruning method 148 # had been applied (in which case `method` would be a PruningContainer 149 # and not a simple pruning method). 150 151 # Pruning is to be applied to the module's tensor named `name`, 152 # starting from the state it is found in prior to this iteration of 153 # pruning. The pruning mask is calculated based on importances scores. 154 155 orig = getattr(module, name) 156 if importance_scores is not None: 157 assert ( 158 importance_scores.shape == orig.shape 159 ), f"importance_scores should have the same shape as parameter {name} of {module}" 160 else: 161 importance_scores = orig 162 163 # If this is the first time pruning is applied, take care of moving 164 # the original tensor to a new parameter called name + '_orig' and 165 # and deleting the original parameter 166 if not isinstance(method, PruningContainer): 167 # copy `module[name]` to `module[name + '_orig']` 168 module.register_parameter(name + "_orig", orig) 169 # temporarily delete `module[name]` 170 del module._parameters[name] 171 default_mask = torch.ones_like(orig) # temp 172 # If this is not the first time pruning is applied, all of the above 173 # has been done before in a previous pruning iteration, so we're good 174 # to go 175 else: 176 default_mask = ( 177 getattr(module, name + "_mask") 178 .detach() 179 .clone(memory_format=torch.contiguous_format) 180 ) 181 182 # Use try/except because if anything goes wrong with the mask 183 # computation etc., you'd want to roll back. 184 try: 185 # get the final mask, computed according to the specific method 186 mask = method.compute_mask(importance_scores, default_mask=default_mask) 187 # reparameterize by saving mask to `module[name + '_mask']`... 188 module.register_buffer(name + "_mask", mask) 189 # ... and the new pruned tensor to `module[name]` 190 setattr(module, name, method.apply_mask(module)) 191 # associate the pruning method to the module via a hook to 192 # compute the function before every forward() (compile by run) 193 module.register_forward_pre_hook(method) 194 195 except Exception as e: 196 if not isinstance(method, PruningContainer): 197 orig = getattr(module, name + "_orig") 198 module.register_parameter(name, orig) 199 del module._parameters[name + "_orig"] 200 raise e 201 202 return method 203 204 def prune(self, t, default_mask=None, importance_scores=None): 205 r"""Compute and returns a pruned version of input tensor ``t``. 206 207 According to the pruning rule specified in :meth:`compute_mask`. 208 209 Args: 210 t (torch.Tensor): tensor to prune (of same dimensions as 211 ``default_mask``). 212 importance_scores (torch.Tensor): tensor of importance scores (of 213 same shape as ``t``) used to compute mask for pruning ``t``. 214 The values in this tensor indicate the importance of the 215 corresponding elements in the ``t`` that is being pruned. 216 If unspecified or None, the tensor ``t`` will be used in its place. 217 default_mask (torch.Tensor, optional): mask from previous pruning 218 iteration, if any. To be considered when determining what 219 portion of the tensor that pruning should act on. If None, 220 default to a mask of ones. 221 222 Returns: 223 pruned version of tensor ``t``. 224 """ 225 if importance_scores is not None: 226 assert ( 227 importance_scores.shape == t.shape 228 ), "importance_scores should have the same shape as tensor t" 229 else: 230 importance_scores = t 231 default_mask = default_mask if default_mask is not None else torch.ones_like(t) 232 return t * self.compute_mask(importance_scores, default_mask=default_mask) 233 234 def remove(self, module): 235 r"""Remove the pruning reparameterization from a module. 236 237 The pruned parameter named ``name`` remains permanently pruned, 238 and the parameter named ``name+'_orig'`` is removed from the parameter list. 239 Similarly, the buffer named ``name+'_mask'`` is removed from the buffers. 240 241 Note: 242 Pruning itself is NOT undone or reversed! 243 """ 244 # before removing pruning from a tensor, it has to have been applied 245 assert ( 246 self._tensor_name is not None 247 ), f"Module {module} has to be pruned before pruning can be removed" # this gets set in apply() 248 249 # to update module[name] to latest trained weights 250 weight = self.apply_mask(module) # masked weights 251 252 # delete and reset 253 if hasattr(module, self._tensor_name): 254 delattr(module, self._tensor_name) 255 orig = module._parameters[self._tensor_name + "_orig"] 256 orig.data = weight.data 257 del module._parameters[self._tensor_name + "_orig"] 258 del module._buffers[self._tensor_name + "_mask"] 259 setattr(module, self._tensor_name, orig) 260 261 262class PruningContainer(BasePruningMethod): 263 """Container holding a sequence of pruning methods for iterative pruning. 264 265 Keeps track of the order in which pruning methods are applied and handles 266 combining successive pruning calls. 267 268 Accepts as argument an instance of a BasePruningMethod or an iterable of 269 them. 270 """ 271 272 def __init__(self, *args): 273 self._pruning_methods: Tuple[BasePruningMethod, ...] = () 274 if not isinstance(args, Iterable): # only 1 item 275 self._tensor_name = args._tensor_name 276 self.add_pruning_method(args) 277 elif len(args) == 1: # only 1 item in a tuple 278 self._tensor_name = args[0]._tensor_name 279 self.add_pruning_method(args[0]) 280 else: # manual construction from list or other iterable (or no args) 281 for method in args: 282 self.add_pruning_method(method) 283 284 def add_pruning_method(self, method): 285 r"""Add a child pruning ``method`` to the container. 286 287 Args: 288 method (subclass of BasePruningMethod): child pruning method 289 to be added to the container. 290 """ 291 # check that we're adding a pruning method to the container 292 if not isinstance(method, BasePruningMethod) and method is not None: 293 raise TypeError(f"{type(method)} is not a BasePruningMethod subclass") 294 elif method is not None and self._tensor_name != method._tensor_name: 295 raise ValueError( 296 "Can only add pruning methods acting on " 297 f"the parameter named '{self._tensor_name}' to PruningContainer {self}." 298 + f" Found '{method._tensor_name}'" 299 ) 300 # if all checks passed, add to _pruning_methods tuple 301 self._pruning_methods += (method,) # type: ignore[operator] 302 303 def __len__(self): 304 return len(self._pruning_methods) 305 306 def __iter__(self): 307 return iter(self._pruning_methods) 308 309 def __getitem__(self, idx): 310 return self._pruning_methods[idx] 311 312 def compute_mask(self, t, default_mask): 313 r"""Apply the latest ``method`` by computing the new partial masks and returning its combination with the ``default_mask``. 314 315 The new partial mask should be computed on the entries or channels 316 that were not zeroed out by the ``default_mask``. 317 Which portions of the tensor ``t`` the new mask will be calculated from 318 depends on the ``PRUNING_TYPE`` (handled by the type handler): 319 320 * for 'unstructured', the mask will be computed from the raveled 321 list of nonmasked entries; 322 323 * for 'structured', the mask will be computed from the nonmasked 324 channels in the tensor; 325 326 * for 'global', the mask will be computed across all entries. 327 328 Args: 329 t (torch.Tensor): tensor representing the parameter to prune 330 (of same dimensions as ``default_mask``). 331 default_mask (torch.Tensor): mask from previous pruning iteration. 332 333 Returns: 334 mask (torch.Tensor): new mask that combines the effects 335 of the ``default_mask`` and the new mask from the current 336 pruning ``method`` (of same dimensions as ``default_mask`` and 337 ``t``). 338 """ 339 340 def _combine_masks(method, t, mask): 341 r"""Combine the masks from all pruning methods and returns a new mask. 342 343 Args: 344 method (a BasePruningMethod subclass): pruning method 345 currently being applied. 346 t (torch.Tensor): tensor representing the parameter to prune 347 (of same dimensions as mask). 348 mask (torch.Tensor): mask from previous pruning iteration 349 350 Returns: 351 new_mask (torch.Tensor): new mask that combines the effects 352 of the old mask and the new mask from the current 353 pruning method (of same dimensions as mask and t). 354 """ 355 new_mask = mask # start off from existing mask 356 new_mask = new_mask.to(dtype=t.dtype) 357 358 # compute a slice of t onto which the new pruning method will operate 359 if method.PRUNING_TYPE == "unstructured": 360 # prune entries of t where the mask is 1 361 slc = mask == 1 362 363 # for struct pruning, exclude channels that have already been 364 # entirely pruned 365 elif method.PRUNING_TYPE == "structured": 366 if not hasattr(method, "dim"): 367 raise AttributeError( 368 "Pruning methods of PRUNING_TYPE " 369 '"structured" need to have the attribute `dim` defined.' 370 ) 371 372 # find the channels to keep by removing the ones that have been 373 # zeroed out already (i.e. where sum(entries) == 0) 374 n_dims = t.dim() # "is this a 2D tensor? 3D? ..." 375 dim = method.dim 376 # convert negative indexing 377 if dim < 0: 378 dim = n_dims + dim 379 # if dim is still negative after subtracting it from n_dims 380 if dim < 0: 381 raise IndexError( 382 f"Index is out of bounds for tensor with dimensions {n_dims}" 383 ) 384 # find channels along dim = dim that aren't already tots 0ed out 385 keep_channel = mask.sum(dim=[d for d in range(n_dims) if d != dim]) != 0 386 # create slice to identify what to prune 387 slc = [slice(None)] * n_dims 388 slc[dim] = keep_channel 389 390 elif method.PRUNING_TYPE == "global": 391 n_dims = len(t.shape) # "is this a 2D tensor? 3D? ..." 392 slc = [slice(None)] * n_dims 393 394 else: 395 raise ValueError(f"Unrecognized PRUNING_TYPE {method.PRUNING_TYPE}") 396 397 # compute the new mask on the unpruned slice of the tensor t 398 partial_mask = method.compute_mask(t[slc], default_mask=mask[slc]) 399 new_mask[slc] = partial_mask.to(dtype=new_mask.dtype) 400 401 return new_mask 402 403 method = self._pruning_methods[-1] 404 mask = _combine_masks(method, t, default_mask) 405 return mask 406 407 408class Identity(BasePruningMethod): 409 r"""Utility pruning method that does not prune any units but generates the pruning parametrization with a mask of ones.""" 410 411 PRUNING_TYPE = "unstructured" 412 413 def compute_mask(self, t, default_mask): 414 mask = default_mask 415 return mask 416 417 @classmethod 418 def apply(cls, module, name): 419 r"""Add pruning on the fly and reparametrization of a tensor. 420 421 Adds the forward pre-hook that enables pruning on the fly and 422 the reparametrization of a tensor in terms of the original tensor 423 and the pruning mask. 424 425 Args: 426 module (nn.Module): module containing the tensor to prune 427 name (str): parameter name within ``module`` on which pruning 428 will act. 429 """ 430 return super().apply(module, name) 431 432 433class RandomUnstructured(BasePruningMethod): 434 r"""Prune (currently unpruned) units in a tensor at random. 435 436 Args: 437 name (str): parameter name within ``module`` on which pruning 438 will act. 439 amount (int or float): quantity of parameters to prune. 440 If ``float``, should be between 0.0 and 1.0 and represent the 441 fraction of parameters to prune. If ``int``, it represents the 442 absolute number of parameters to prune. 443 """ 444 445 PRUNING_TYPE = "unstructured" 446 447 def __init__(self, amount): 448 # Check range of validity of pruning amount 449 _validate_pruning_amount_init(amount) 450 self.amount = amount 451 452 def compute_mask(self, t, default_mask): 453 # Check that the amount of units to prune is not > than the number of 454 # parameters in t 455 tensor_size = t.nelement() 456 # Compute number of units to prune: amount if int, 457 # else amount * tensor_size 458 nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size) 459 # This should raise an error if the number of units to prune is larger 460 # than the number of units in the tensor 461 _validate_pruning_amount(nparams_toprune, tensor_size) 462 463 mask = default_mask.clone(memory_format=torch.contiguous_format) 464 465 if nparams_toprune != 0: # k=0 not supported by torch.kthvalue 466 prob = torch.rand_like(t) 467 topk = torch.topk(prob.view(-1), k=nparams_toprune) 468 mask.view(-1)[topk.indices] = 0 469 470 return mask 471 472 @classmethod 473 def apply(cls, module, name, amount): 474 r"""Add pruning on the fly and reparametrization of a tensor. 475 476 Adds the forward pre-hook that enables pruning on the fly and 477 the reparametrization of a tensor in terms of the original tensor 478 and the pruning mask. 479 480 Args: 481 module (nn.Module): module containing the tensor to prune 482 name (str): parameter name within ``module`` on which pruning 483 will act. 484 amount (int or float): quantity of parameters to prune. 485 If ``float``, should be between 0.0 and 1.0 and represent the 486 fraction of parameters to prune. If ``int``, it represents the 487 absolute number of parameters to prune. 488 """ 489 return super().apply(module, name, amount=amount) 490 491 492class L1Unstructured(BasePruningMethod): 493 r"""Prune (currently unpruned) units in a tensor by zeroing out the ones with the lowest L1-norm. 494 495 Args: 496 amount (int or float): quantity of parameters to prune. 497 If ``float``, should be between 0.0 and 1.0 and represent the 498 fraction of parameters to prune. If ``int``, it represents the 499 absolute number of parameters to prune. 500 """ 501 502 PRUNING_TYPE = "unstructured" 503 504 def __init__(self, amount): 505 # Check range of validity of pruning amount 506 _validate_pruning_amount_init(amount) 507 self.amount = amount 508 509 def compute_mask(self, t, default_mask): 510 # Check that the amount of units to prune is not > than the number of 511 # parameters in t 512 tensor_size = t.nelement() 513 # Compute number of units to prune: amount if int, 514 # else amount * tensor_size 515 nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size) 516 # This should raise an error if the number of units to prune is larger 517 # than the number of units in the tensor 518 _validate_pruning_amount(nparams_toprune, tensor_size) 519 520 mask = default_mask.clone(memory_format=torch.contiguous_format) 521 522 if nparams_toprune != 0: # k=0 not supported by torch.kthvalue 523 # largest=True --> top k; largest=False --> bottom k 524 # Prune the smallest k 525 topk = torch.topk(torch.abs(t).view(-1), k=nparams_toprune, largest=False) 526 # topk will have .indices and .values 527 mask.view(-1)[topk.indices] = 0 528 529 return mask 530 531 @classmethod 532 def apply(cls, module, name, amount, importance_scores=None): 533 r"""Add pruning on the fly and reparametrization of a tensor. 534 535 Adds the forward pre-hook that enables pruning on the fly and 536 the reparametrization of a tensor in terms of the original tensor 537 and the pruning mask. 538 539 Args: 540 module (nn.Module): module containing the tensor to prune 541 name (str): parameter name within ``module`` on which pruning 542 will act. 543 amount (int or float): quantity of parameters to prune. 544 If ``float``, should be between 0.0 and 1.0 and represent the 545 fraction of parameters to prune. If ``int``, it represents the 546 absolute number of parameters to prune. 547 importance_scores (torch.Tensor): tensor of importance scores (of same 548 shape as module parameter) used to compute mask for pruning. 549 The values in this tensor indicate the importance of the corresponding 550 elements in the parameter being pruned. 551 If unspecified or None, the module parameter will be used in its place. 552 """ 553 return super().apply( 554 module, name, amount=amount, importance_scores=importance_scores 555 ) 556 557 558class RandomStructured(BasePruningMethod): 559 r"""Prune entire (currently unpruned) channels in a tensor at random. 560 561 Args: 562 amount (int or float): quantity of parameters to prune. 563 If ``float``, should be between 0.0 and 1.0 and represent the 564 fraction of parameters to prune. If ``int``, it represents the 565 absolute number of parameters to prune. 566 dim (int, optional): index of the dim along which we define 567 channels to prune. Default: -1. 568 """ 569 570 PRUNING_TYPE = "structured" 571 572 def __init__(self, amount, dim=-1): 573 # Check range of validity of amount 574 _validate_pruning_amount_init(amount) 575 self.amount = amount 576 self.dim = dim 577 578 def compute_mask(self, t, default_mask): 579 r"""Compute and returns a mask for the input tensor ``t``. 580 581 Starting from a base ``default_mask`` (which should be a mask of ones 582 if the tensor has not been pruned yet), generate a random mask to 583 apply on top of the ``default_mask`` by randomly zeroing out channels 584 along the specified dim of the tensor. 585 586 Args: 587 t (torch.Tensor): tensor representing the parameter to prune 588 default_mask (torch.Tensor): Base mask from previous pruning 589 iterations, that need to be respected after the new mask is 590 applied. Same dims as ``t``. 591 592 Returns: 593 mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t`` 594 595 Raises: 596 IndexError: if ``self.dim >= len(t.shape)`` 597 """ 598 # Check that tensor has structure (i.e. more than 1 dimension) such 599 # that the concept of "channels" makes sense 600 _validate_structured_pruning(t) 601 602 # Check that self.dim is a valid dim to index t, else raise IndexError 603 _validate_pruning_dim(t, self.dim) 604 605 # Check that the amount of channels to prune is not > than the number of 606 # channels in t along the dim to prune 607 tensor_size = t.shape[self.dim] 608 # Compute number of units to prune: amount if int, 609 # else amount * tensor_size 610 nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size) 611 # This should raise an error if the number of units to prune is larger 612 # than the number of units in the tensor 613 _validate_pruning_amount(nparams_toprune, tensor_size) 614 615 # Compute binary mask by initializing it to all 0s and then filling in 616 # 1s wherever topk.indices indicates, along self.dim. 617 # mask has the same shape as tensor t 618 def make_mask(t, dim, nchannels, nchannels_toprune): 619 # generate a random number in [0, 1] to associate to each channel 620 prob = torch.rand(nchannels) 621 # generate mask for each channel by 0ing out the channels that 622 # got assigned the k = nchannels_toprune lowest values in prob 623 threshold = torch.kthvalue(prob, k=nchannels_toprune).values 624 channel_mask = prob > threshold 625 626 mask = torch.zeros_like(t) 627 slc = [slice(None)] * len(t.shape) 628 slc[dim] = channel_mask 629 mask[slc] = 1 630 return mask 631 632 if nparams_toprune == 0: # k=0 not supported by torch.kthvalue 633 mask = default_mask 634 else: 635 # apply the new structured mask on top of prior (potentially 636 # unstructured) mask 637 mask = make_mask(t, self.dim, tensor_size, nparams_toprune) 638 mask *= default_mask.to(dtype=mask.dtype) 639 return mask 640 641 @classmethod 642 def apply(cls, module, name, amount, dim=-1): 643 r"""Add pruning on the fly and reparametrization of a tensor. 644 645 Adds the forward pre-hook that enables pruning on the fly and 646 the reparametrization of a tensor in terms of the original tensor 647 and the pruning mask. 648 649 Args: 650 module (nn.Module): module containing the tensor to prune 651 name (str): parameter name within ``module`` on which pruning 652 will act. 653 amount (int or float): quantity of parameters to prune. 654 If ``float``, should be between 0.0 and 1.0 and represent the 655 fraction of parameters to prune. If ``int``, it represents the 656 absolute number of parameters to prune. 657 dim (int, optional): index of the dim along which we define 658 channels to prune. Default: -1. 659 """ 660 return super().apply(module, name, amount=amount, dim=dim) 661 662 663class LnStructured(BasePruningMethod): 664 r"""Prune entire (currently unpruned) channels in a tensor based on their L\ ``n``-norm. 665 666 Args: 667 amount (int or float): quantity of channels to prune. 668 If ``float``, should be between 0.0 and 1.0 and represent the 669 fraction of parameters to prune. If ``int``, it represents the 670 absolute number of parameters to prune. 671 n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid 672 entries for argument ``p`` in :func:`torch.norm`. 673 dim (int, optional): index of the dim along which we define 674 channels to prune. Default: -1. 675 """ 676 677 PRUNING_TYPE = "structured" 678 679 def __init__(self, amount, n, dim=-1): 680 # Check range of validity of amount 681 _validate_pruning_amount_init(amount) 682 self.amount = amount 683 self.n = n 684 self.dim = dim 685 686 def compute_mask(self, t, default_mask): 687 r"""Compute and returns a mask for the input tensor ``t``. 688 689 Starting from a base ``default_mask`` (which should be a mask of ones 690 if the tensor has not been pruned yet), generate a mask to apply on 691 top of the ``default_mask`` by zeroing out the channels along the 692 specified dim with the lowest L\ ``n``-norm. 693 694 Args: 695 t (torch.Tensor): tensor representing the parameter to prune 696 default_mask (torch.Tensor): Base mask from previous pruning 697 iterations, that need to be respected after the new mask is 698 applied. Same dims as ``t``. 699 700 Returns: 701 mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t`` 702 703 Raises: 704 IndexError: if ``self.dim >= len(t.shape)`` 705 """ 706 # Check that tensor has structure (i.e. more than 1 dimension) such 707 # that the concept of "channels" makes sense 708 _validate_structured_pruning(t) 709 # Check that self.dim is a valid dim to index t, else raise IndexError 710 _validate_pruning_dim(t, self.dim) 711 712 # Check that the amount of channels to prune is not > than the number of 713 # channels in t along the dim to prune 714 tensor_size = t.shape[self.dim] 715 # Compute number of units to prune: amount if int, 716 # else amount * tensor_size 717 nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size) 718 nparams_tokeep = tensor_size - nparams_toprune 719 # This should raise an error if the number of units to prune is larger 720 # than the number of units in the tensor 721 _validate_pruning_amount(nparams_toprune, tensor_size) 722 723 # Structured pruning prunes entire channels so we need to know the 724 # L_n norm along each channel to then find the topk based on this 725 # metric 726 norm = _compute_norm(t, self.n, self.dim) 727 # largest=True --> top k; largest=False --> bottom k 728 # Keep the largest k channels along dim=self.dim 729 topk = torch.topk(norm, k=nparams_tokeep, largest=True) 730 # topk will have .indices and .values 731 732 # Compute binary mask by initializing it to all 0s and then filling in 733 # 1s wherever topk.indices indicates, along self.dim. 734 # mask has the same shape as tensor t 735 def make_mask(t, dim, indices): 736 # init mask to 0 737 mask = torch.zeros_like(t) 738 # e.g.: slc = [None, None, None], if len(t.shape) = 3 739 slc = [slice(None)] * len(t.shape) 740 # replace a None at position=dim with indices 741 # e.g.: slc = [None, None, [0, 2, 3]] if dim=2 & indices=[0,2,3] 742 slc[dim] = indices 743 # use slc to slice mask and replace all its entries with 1s 744 # e.g.: mask[:, :, [0, 2, 3]] = 1 745 mask[slc] = 1 746 return mask 747 748 if nparams_toprune == 0: # k=0 not supported by torch.kthvalue 749 mask = default_mask 750 else: 751 mask = make_mask(t, self.dim, topk.indices) 752 mask *= default_mask.to(dtype=mask.dtype) 753 754 return mask 755 756 @classmethod 757 def apply(cls, module, name, amount, n, dim, importance_scores=None): 758 r"""Add pruning on the fly and reparametrization of a tensor. 759 760 Adds the forward pre-hook that enables pruning on the fly and 761 the reparametrization of a tensor in terms of the original tensor 762 and the pruning mask. 763 764 Args: 765 module (nn.Module): module containing the tensor to prune 766 name (str): parameter name within ``module`` on which pruning 767 will act. 768 amount (int or float): quantity of parameters to prune. 769 If ``float``, should be between 0.0 and 1.0 and represent the 770 fraction of parameters to prune. If ``int``, it represents the 771 absolute number of parameters to prune. 772 n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid 773 entries for argument ``p`` in :func:`torch.norm`. 774 dim (int): index of the dim along which we define channels to 775 prune. 776 importance_scores (torch.Tensor): tensor of importance scores (of same 777 shape as module parameter) used to compute mask for pruning. 778 The values in this tensor indicate the importance of the corresponding 779 elements in the parameter being pruned. 780 If unspecified or None, the module parameter will be used in its place. 781 """ 782 return super().apply( 783 module, 784 name, 785 amount=amount, 786 n=n, 787 dim=dim, 788 importance_scores=importance_scores, 789 ) 790 791 792class CustomFromMask(BasePruningMethod): 793 PRUNING_TYPE = "global" 794 795 def __init__(self, mask): 796 self.mask = mask 797 798 def compute_mask(self, t, default_mask): 799 assert default_mask.shape == self.mask.shape 800 mask = default_mask * self.mask.to(dtype=default_mask.dtype) 801 return mask 802 803 @classmethod 804 def apply(cls, module, name, mask): 805 r"""Add pruning on the fly and reparametrization of a tensor. 806 807 Adds the forward pre-hook that enables pruning on the fly and 808 the reparametrization of a tensor in terms of the original tensor 809 and the pruning mask. 810 811 Args: 812 module (nn.Module): module containing the tensor to prune 813 name (str): parameter name within ``module`` on which pruning 814 will act. 815 """ 816 return super().apply(module, name, mask=mask) 817 818 819def identity(module, name): 820 r"""Apply pruning reparametrization without pruning any units. 821 822 Applies pruning reparametrization to the tensor corresponding to the 823 parameter called ``name`` in ``module`` without actually pruning any 824 units. Modifies module in place (and also return the modified module) 825 by: 826 827 1) adding a named buffer called ``name+'_mask'`` corresponding to the 828 binary mask applied to the parameter ``name`` by the pruning method. 829 2) replacing the parameter ``name`` by its pruned version, while the 830 original (unpruned) parameter is stored in a new parameter named 831 ``name+'_orig'``. 832 833 Note: 834 The mask is a tensor of ones. 835 836 Args: 837 module (nn.Module): module containing the tensor to prune. 838 name (str): parameter name within ``module`` on which pruning 839 will act. 840 841 Returns: 842 module (nn.Module): modified (i.e. pruned) version of the input module 843 844 Examples: 845 >>> # xdoctest: +SKIP 846 >>> m = prune.identity(nn.Linear(2, 3), 'bias') 847 >>> print(m.bias_mask) 848 tensor([1., 1., 1.]) 849 """ 850 Identity.apply(module, name) 851 return module 852 853 854def random_unstructured(module, name, amount): 855 r"""Prune tensor by removing random (currently unpruned) units. 856 857 Prunes tensor corresponding to parameter called ``name`` in ``module`` 858 by removing the specified ``amount`` of (currently unpruned) units 859 selected at random. 860 Modifies module in place (and also return the modified module) by: 861 862 1) adding a named buffer called ``name+'_mask'`` corresponding to the 863 binary mask applied to the parameter ``name`` by the pruning method. 864 2) replacing the parameter ``name`` by its pruned version, while the 865 original (unpruned) parameter is stored in a new parameter named 866 ``name+'_orig'``. 867 868 Args: 869 module (nn.Module): module containing the tensor to prune 870 name (str): parameter name within ``module`` on which pruning 871 will act. 872 amount (int or float): quantity of parameters to prune. 873 If ``float``, should be between 0.0 and 1.0 and represent the 874 fraction of parameters to prune. If ``int``, it represents the 875 absolute number of parameters to prune. 876 877 Returns: 878 module (nn.Module): modified (i.e. pruned) version of the input module 879 880 Examples: 881 >>> # xdoctest: +SKIP 882 >>> m = prune.random_unstructured(nn.Linear(2, 3), 'weight', amount=1) 883 >>> torch.sum(m.weight_mask == 0) 884 tensor(1) 885 886 """ 887 RandomUnstructured.apply(module, name, amount) 888 return module 889 890 891def l1_unstructured(module, name, amount, importance_scores=None): 892 r"""Prune tensor by removing units with the lowest L1-norm. 893 894 Prunes tensor corresponding to parameter called ``name`` in ``module`` 895 by removing the specified `amount` of (currently unpruned) units with the 896 lowest L1-norm. 897 Modifies module in place (and also return the modified module) 898 by: 899 900 1) adding a named buffer called ``name+'_mask'`` corresponding to the 901 binary mask applied to the parameter ``name`` by the pruning method. 902 2) replacing the parameter ``name`` by its pruned version, while the 903 original (unpruned) parameter is stored in a new parameter named 904 ``name+'_orig'``. 905 906 Args: 907 module (nn.Module): module containing the tensor to prune 908 name (str): parameter name within ``module`` on which pruning 909 will act. 910 amount (int or float): quantity of parameters to prune. 911 If ``float``, should be between 0.0 and 1.0 and represent the 912 fraction of parameters to prune. If ``int``, it represents the 913 absolute number of parameters to prune. 914 importance_scores (torch.Tensor): tensor of importance scores (of same 915 shape as module parameter) used to compute mask for pruning. 916 The values in this tensor indicate the importance of the corresponding 917 elements in the parameter being pruned. 918 If unspecified or None, the module parameter will be used in its place. 919 920 Returns: 921 module (nn.Module): modified (i.e. pruned) version of the input module 922 923 Examples: 924 >>> # xdoctest: +SKIP 925 >>> m = prune.l1_unstructured(nn.Linear(2, 3), 'weight', amount=0.2) 926 >>> m.state_dict().keys() 927 odict_keys(['bias', 'weight_orig', 'weight_mask']) 928 """ 929 L1Unstructured.apply( 930 module, name, amount=amount, importance_scores=importance_scores 931 ) 932 return module 933 934 935def random_structured(module, name, amount, dim): 936 r"""Prune tensor by removing random channels along the specified dimension. 937 938 Prunes tensor corresponding to parameter called ``name`` in ``module`` 939 by removing the specified ``amount`` of (currently unpruned) channels 940 along the specified ``dim`` selected at random. 941 Modifies module in place (and also return the modified module) 942 by: 943 944 1) adding a named buffer called ``name+'_mask'`` corresponding to the 945 binary mask applied to the parameter ``name`` by the pruning method. 946 2) replacing the parameter ``name`` by its pruned version, while the 947 original (unpruned) parameter is stored in a new parameter named 948 ``name+'_orig'``. 949 950 Args: 951 module (nn.Module): module containing the tensor to prune 952 name (str): parameter name within ``module`` on which pruning 953 will act. 954 amount (int or float): quantity of parameters to prune. 955 If ``float``, should be between 0.0 and 1.0 and represent the 956 fraction of parameters to prune. If ``int``, it represents the 957 absolute number of parameters to prune. 958 dim (int): index of the dim along which we define channels to prune. 959 960 Returns: 961 module (nn.Module): modified (i.e. pruned) version of the input module 962 963 Examples: 964 >>> # xdoctest: +SKIP 965 >>> m = prune.random_structured( 966 ... nn.Linear(5, 3), 'weight', amount=3, dim=1 967 ... ) 968 >>> columns_pruned = int(sum(torch.sum(m.weight, dim=0) == 0)) 969 >>> print(columns_pruned) 970 3 971 """ 972 RandomStructured.apply(module, name, amount, dim) 973 return module 974 975 976def ln_structured(module, name, amount, n, dim, importance_scores=None): 977 r"""Prune tensor by removing channels with the lowest L\ ``n``-norm along the specified dimension. 978 979 Prunes tensor corresponding to parameter called ``name`` in ``module`` 980 by removing the specified ``amount`` of (currently unpruned) channels 981 along the specified ``dim`` with the lowest L\ ``n``-norm. 982 Modifies module in place (and also return the modified module) 983 by: 984 985 1) adding a named buffer called ``name+'_mask'`` corresponding to the 986 binary mask applied to the parameter ``name`` by the pruning method. 987 2) replacing the parameter ``name`` by its pruned version, while the 988 original (unpruned) parameter is stored in a new parameter named 989 ``name+'_orig'``. 990 991 Args: 992 module (nn.Module): module containing the tensor to prune 993 name (str): parameter name within ``module`` on which pruning 994 will act. 995 amount (int or float): quantity of parameters to prune. 996 If ``float``, should be between 0.0 and 1.0 and represent the 997 fraction of parameters to prune. If ``int``, it represents the 998 absolute number of parameters to prune. 999 n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid 1000 entries for argument ``p`` in :func:`torch.norm`. 1001 dim (int): index of the dim along which we define channels to prune. 1002 importance_scores (torch.Tensor): tensor of importance scores (of same 1003 shape as module parameter) used to compute mask for pruning. 1004 The values in this tensor indicate the importance of the corresponding 1005 elements in the parameter being pruned. 1006 If unspecified or None, the module parameter will be used in its place. 1007 1008 Returns: 1009 module (nn.Module): modified (i.e. pruned) version of the input module 1010 1011 Examples: 1012 >>> from torch.nn.utils import prune 1013 >>> m = prune.ln_structured( 1014 ... nn.Conv2d(5, 3, 2), 'weight', amount=0.3, dim=1, n=float('-inf') 1015 ... ) 1016 """ 1017 LnStructured.apply( 1018 module, name, amount, n, dim, importance_scores=importance_scores 1019 ) 1020 return module 1021 1022 1023def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs): 1024 r""" 1025 Globally prunes tensors corresponding to all parameters in ``parameters`` by applying the specified ``pruning_method``. 1026 1027 Modifies modules in place by: 1028 1029 1) adding a named buffer called ``name+'_mask'`` corresponding to the 1030 binary mask applied to the parameter ``name`` by the pruning method. 1031 2) replacing the parameter ``name`` by its pruned version, while the 1032 original (unpruned) parameter is stored in a new parameter named 1033 ``name+'_orig'``. 1034 1035 Args: 1036 parameters (Iterable of (module, name) tuples): parameters of 1037 the model to prune in a global fashion, i.e. by aggregating all 1038 weights prior to deciding which ones to prune. module must be of 1039 type :class:`nn.Module`, and name must be a string. 1040 pruning_method (function): a valid pruning function from this module, 1041 or a custom one implemented by the user that satisfies the 1042 implementation guidelines and has ``PRUNING_TYPE='unstructured'``. 1043 importance_scores (dict): a dictionary mapping (module, name) tuples to 1044 the corresponding parameter's importance scores tensor. The tensor 1045 should be the same shape as the parameter, and is used for computing 1046 mask for pruning. 1047 If unspecified or None, the parameter will be used in place of its 1048 importance scores. 1049 kwargs: other keyword arguments such as: 1050 amount (int or float): quantity of parameters to prune across the 1051 specified parameters. 1052 If ``float``, should be between 0.0 and 1.0 and represent the 1053 fraction of parameters to prune. If ``int``, it represents the 1054 absolute number of parameters to prune. 1055 1056 Raises: 1057 TypeError: if ``PRUNING_TYPE != 'unstructured'`` 1058 1059 Note: 1060 Since global structured pruning doesn't make much sense unless the 1061 norm is normalized by the size of the parameter, we now limit the 1062 scope of global pruning to unstructured methods. 1063 1064 Examples: 1065 >>> from torch.nn.utils import prune 1066 >>> from collections import OrderedDict 1067 >>> net = nn.Sequential(OrderedDict([ 1068 ... ('first', nn.Linear(10, 4)), 1069 ... ('second', nn.Linear(4, 1)), 1070 ... ])) 1071 >>> parameters_to_prune = ( 1072 ... (net.first, 'weight'), 1073 ... (net.second, 'weight'), 1074 ... ) 1075 >>> prune.global_unstructured( 1076 ... parameters_to_prune, 1077 ... pruning_method=prune.L1Unstructured, 1078 ... amount=10, 1079 ... ) 1080 >>> print(sum(torch.nn.utils.parameters_to_vector(net.buffers()) == 0)) 1081 tensor(10) 1082 1083 """ 1084 # ensure parameters is a list or generator of tuples 1085 if not isinstance(parameters, Iterable): 1086 raise TypeError("global_unstructured(): parameters is not an Iterable") 1087 1088 importance_scores = importance_scores if importance_scores is not None else {} 1089 if not isinstance(importance_scores, dict): 1090 raise TypeError("global_unstructured(): importance_scores must be of type dict") 1091 1092 # flatten importance scores to consider them all at once in global pruning 1093 relevant_importance_scores = torch.nn.utils.parameters_to_vector( 1094 [ 1095 importance_scores.get((module, name), getattr(module, name)) 1096 for (module, name) in parameters 1097 ] 1098 ) 1099 # similarly, flatten the masks (if they exist), or use a flattened vector 1100 # of 1s of the same dimensions as t 1101 default_mask = torch.nn.utils.parameters_to_vector( 1102 [ 1103 getattr(module, name + "_mask", torch.ones_like(getattr(module, name))) 1104 for (module, name) in parameters 1105 ] 1106 ) 1107 1108 # use the canonical pruning methods to compute the new mask, even if the 1109 # parameter is now a flattened out version of `parameters` 1110 container = PruningContainer() 1111 container._tensor_name = "temp" # to make it match that of `method` 1112 method = pruning_method(**kwargs) 1113 method._tensor_name = "temp" # to make it match that of `container` 1114 if method.PRUNING_TYPE != "unstructured": 1115 raise TypeError( 1116 'Only "unstructured" PRUNING_TYPE supported for ' 1117 f"the `pruning_method`. Found method {pruning_method} of type {method.PRUNING_TYPE}" 1118 ) 1119 1120 container.add_pruning_method(method) 1121 1122 # use the `compute_mask` method from `PruningContainer` to combine the 1123 # mask computed by the new method with the pre-existing mask 1124 final_mask = container.compute_mask(relevant_importance_scores, default_mask) 1125 1126 # Pointer for slicing the mask to match the shape of each parameter 1127 pointer = 0 1128 for module, name in parameters: 1129 param = getattr(module, name) 1130 # The length of the parameter 1131 num_param = param.numel() 1132 # Slice the mask, reshape it 1133 param_mask = final_mask[pointer : pointer + num_param].view_as(param) 1134 # Assign the correct pre-computed mask to each parameter and add it 1135 # to the forward_pre_hooks like any other pruning method 1136 custom_from_mask(module, name, mask=param_mask) 1137 1138 # Increment the pointer to continue slicing the final_mask 1139 pointer += num_param 1140 1141 1142def custom_from_mask(module, name, mask): 1143 r"""Prune tensor corresponding to parameter called ``name`` in ``module`` by applying the pre-computed mask in ``mask``. 1144 1145 Modifies module in place (and also return the modified module) by: 1146 1147 1) adding a named buffer called ``name+'_mask'`` corresponding to the 1148 binary mask applied to the parameter ``name`` by the pruning method. 1149 2) replacing the parameter ``name`` by its pruned version, while the 1150 original (unpruned) parameter is stored in a new parameter named 1151 ``name+'_orig'``. 1152 1153 Args: 1154 module (nn.Module): module containing the tensor to prune 1155 name (str): parameter name within ``module`` on which pruning 1156 will act. 1157 mask (Tensor): binary mask to be applied to the parameter. 1158 1159 Returns: 1160 module (nn.Module): modified (i.e. pruned) version of the input module 1161 1162 Examples: 1163 >>> from torch.nn.utils import prune 1164 >>> m = prune.custom_from_mask( 1165 ... nn.Linear(5, 3), name='bias', mask=torch.tensor([0, 1, 0]) 1166 ... ) 1167 >>> print(m.bias_mask) 1168 tensor([0., 1., 0.]) 1169 1170 """ 1171 CustomFromMask.apply(module, name, mask) 1172 return module 1173 1174 1175def remove(module, name): 1176 r"""Remove the pruning reparameterization from a module and the pruning method from the forward hook. 1177 1178 The pruned parameter named ``name`` remains permanently pruned, and the parameter 1179 named ``name+'_orig'`` is removed from the parameter list. Similarly, 1180 the buffer named ``name+'_mask'`` is removed from the buffers. 1181 1182 Note: 1183 Pruning itself is NOT undone or reversed! 1184 1185 Args: 1186 module (nn.Module): module containing the tensor to prune 1187 name (str): parameter name within ``module`` on which pruning 1188 will act. 1189 1190 Examples: 1191 >>> m = random_unstructured(nn.Linear(5, 7), name='weight', amount=0.2) 1192 >>> m = remove(m, name='weight') 1193 """ 1194 for k, hook in module._forward_pre_hooks.items(): 1195 if isinstance(hook, BasePruningMethod) and hook._tensor_name == name: 1196 hook.remove(module) 1197 del module._forward_pre_hooks[k] 1198 return module 1199 1200 raise ValueError( 1201 f"Parameter '{name}' of module {module} has to be pruned before pruning can be removed" 1202 ) 1203 1204 1205def is_pruned(module): 1206 r"""Check if a module is pruned by looking for pruning pre-hooks. 1207 1208 Check whether ``module`` is pruned by looking for 1209 ``forward_pre_hooks`` in its modules that inherit from the 1210 :class:`BasePruningMethod`. 1211 1212 Args: 1213 module (nn.Module): object that is either pruned or unpruned 1214 1215 Returns: 1216 binary answer to whether ``module`` is pruned. 1217 1218 Examples: 1219 >>> from torch.nn.utils import prune 1220 >>> m = nn.Linear(5, 7) 1221 >>> print(prune.is_pruned(m)) 1222 False 1223 >>> prune.random_unstructured(m, name='weight', amount=0.2) 1224 >>> print(prune.is_pruned(m)) 1225 True 1226 """ 1227 for _, submodule in module.named_modules(): 1228 for hook in submodule._forward_pre_hooks.values(): 1229 if isinstance(hook, BasePruningMethod): 1230 return True 1231 return False 1232 1233 1234def _validate_pruning_amount_init(amount): 1235 r"""Validate helper to check the range of amount at init. 1236 1237 Args: 1238 amount (int or float): quantity of parameters to prune. 1239 If float, should be between 0.0 and 1.0 and represent the 1240 fraction of parameters to prune. If int, it represents the 1241 absolute number of parameters to prune. 1242 1243 Raises: 1244 ValueError: if amount is a float not in [0, 1], or if it's a negative 1245 integer. 1246 TypeError: if amount is neither a float nor an integer. 1247 1248 Note: 1249 This does not take into account the number of parameters in the 1250 tensor to be pruned, which is known only at prune. 1251 """ 1252 if not isinstance(amount, numbers.Real): 1253 raise TypeError(f"Invalid type for amount: {amount}. Must be int or float.") 1254 1255 if (isinstance(amount, numbers.Integral) and amount < 0) or ( 1256 not isinstance(amount, numbers.Integral) # so it's a float 1257 and (float(amount) > 1.0 or float(amount) < 0.0) 1258 ): 1259 raise ValueError( 1260 f"amount={amount} should either be a float in the range [0, 1] or a non-negative integer" 1261 ) 1262 1263 1264def _validate_pruning_amount(amount, tensor_size): 1265 r"""Validate that the pruning amount is meaningful wrt to the size of the data. 1266 1267 Validation helper to check that the amount of parameters to prune 1268 is meaningful wrt to the size of the data (`tensor_size`). 1269 1270 Args: 1271 amount (int or float): quantity of parameters to prune. 1272 If float, should be between 0.0 and 1.0 and represent the 1273 fraction of parameters to prune. If int, it represents the 1274 absolute number of parameters to prune. 1275 tensor_size (int): absolute number of parameters in the tensor 1276 to prune. 1277 """ 1278 # TODO: consider removing this check and allowing users to specify 1279 # a number of units to prune that is greater than the number of units 1280 # left to prune. In this case, the tensor will just be fully pruned. 1281 1282 if isinstance(amount, numbers.Integral) and amount > tensor_size: 1283 raise ValueError( 1284 f"amount={amount} should be smaller than the number of parameters to prune={tensor_size}" 1285 ) 1286 1287 1288def _validate_structured_pruning(t): 1289 r"""Validate that the tensor to be pruned is at least 2-Dimensional. 1290 1291 Validation helper to check that the tensor to be pruned is multi- 1292 dimensional, such that the concept of "channels" is well-defined. 1293 1294 Args: 1295 t (torch.Tensor): tensor representing the parameter to prune 1296 1297 Raises: 1298 ValueError: if the tensor `t` is not at least 2D. 1299 """ 1300 shape = t.shape 1301 if len(shape) <= 1: 1302 raise ValueError( 1303 "Structured pruning can only be applied to " 1304 "multidimensional tensors. Found tensor of shape " 1305 f"{shape} with {len(shape)} dims" 1306 ) 1307 1308 1309def _compute_nparams_toprune(amount, tensor_size): 1310 r"""Convert the pruning amount from a percentage to absolute value. 1311 1312 Since amount can be expressed either in absolute value or as a 1313 percentage of the number of units/channels in a tensor, this utility 1314 function converts the percentage to absolute value to standardize 1315 the handling of pruning. 1316 1317 Args: 1318 amount (int or float): quantity of parameters to prune. 1319 If float, should be between 0.0 and 1.0 and represent the 1320 fraction of parameters to prune. If int, it represents the 1321 absolute number of parameters to prune. 1322 tensor_size (int): absolute number of parameters in the tensor 1323 to prune. 1324 1325 Returns: 1326 int: the number of units to prune in the tensor 1327 """ 1328 # incorrect type already checked in _validate_pruning_amount_init 1329 if isinstance(amount, numbers.Integral): 1330 return amount 1331 else: 1332 return round(amount * tensor_size) 1333 1334 1335def _validate_pruning_dim(t, dim): 1336 r"""Validate that the pruning dimension is within the bounds of the tensor dimension. 1337 1338 Args: 1339 t (torch.Tensor): tensor representing the parameter to prune 1340 dim (int): index of the dim along which we define channels to prune 1341 """ 1342 if dim >= t.dim(): 1343 raise IndexError(f"Invalid index {dim} for tensor of size {t.shape}") 1344 1345 1346def _compute_norm(t, n, dim): 1347 r"""Compute the L_n-norm of a tensor along all dimensions except for the specified dimension. 1348 1349 The L_n-norm will be computed across all entries in tensor `t` along all dimension 1350 except for the one identified by dim. 1351 Example: if `t` is of shape, say, 3x2x4 and dim=2 (the last dim), 1352 then norm will have Size [4], and each entry will represent the 1353 `L_n`-norm computed using the 3x2=6 entries for each of the 4 channels. 1354 1355 Args: 1356 t (torch.Tensor): tensor representing the parameter to prune 1357 n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid 1358 entries for argument p in torch.norm 1359 dim (int): dim identifying the channels to prune 1360 1361 Returns: 1362 norm (torch.Tensor): L_n norm computed across all dimensions except 1363 for `dim`. By construction, `norm.shape = t.shape[-1]`. 1364 """ 1365 # dims = all axes, except for the one identified by `dim` 1366 dims = list(range(t.dim())) 1367 # convert negative indexing 1368 if dim < 0: 1369 dim = dims[dim] 1370 dims.remove(dim) 1371 1372 norm = torch.norm(t, p=n, dim=dims) 1373 return norm 1374