xref: /aosp_15_r20/external/pytorch/torch/nn/utils/prune.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2r"""Pruning methods."""
3import numbers
4from abc import ABC, abstractmethod
5from collections.abc import Iterable
6from typing import Tuple
7
8import torch
9
10
class BasePruningMethod(ABC):
    r"""Abstract base class for creation of new pruning techniques.

    Provides a skeleton for customization requiring the overriding of methods
    such as :meth:`compute_mask` and :meth:`apply`.
    """

    # Name of the parameter this method prunes; assigned in :meth:`apply`.
    _tensor_name: str

    def __call__(self, module, inputs):
        r"""Multiply the mask into original tensor and store the result.

        Multiplies the mask (stored in ``module[name + '_mask']``)
        into the original tensor (stored in ``module[name + '_orig']``)
        and stores the result into ``module[name]`` by using :meth:`apply_mask`.

        Args:
            module (nn.Module): module containing the tensor to prune
            inputs: not used.
        """
        setattr(module, self._tensor_name, self.apply_mask(module))

    @abstractmethod
    def compute_mask(self, t, default_mask):
        r"""Compute and return a mask for the input tensor ``t``.

        Starting from a base ``default_mask`` (which should be a mask of ones
        if the tensor has not been pruned yet), generate a mask to apply on
        top of the ``default_mask`` according to the specific pruning method
        recipe.

        Args:
            t (torch.Tensor): tensor representing the importance scores of the
                parameter to prune.
            default_mask (torch.Tensor): Base mask from previous pruning
                iterations, that need to be respected after the new mask is
                applied. Same dims as ``t``.

        Returns:
            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``
        """

    def apply_mask(self, module):
        r"""Simply handles the multiplication between the parameter being pruned and the generated mask.

        Fetches the mask and the original tensor from the module
        and returns the pruned version of the tensor.

        Args:
            module (nn.Module): module containing the tensor to prune

        Returns:
            pruned_tensor (torch.Tensor): pruned version of the input tensor

        Raises:
            AssertionError: if this method has not been applied to a tensor
                yet (``_tensor_name`` is unset).
        """
        # To carry out the multiplication, the mask needs to have been
        # computed, so the pruning method must know what tensor it's operating
        # on. ``_tensor_name`` is only annotated on the class, so ``getattr``
        # is needed for this assertion (rather than an AttributeError) to fire
        # before apply() has run.
        assert (
            getattr(self, "_tensor_name", None) is not None
        ), f"Module {module} has to be pruned"  # this gets set in apply()
        mask = getattr(module, self._tensor_name + "_mask")
        orig = getattr(module, self._tensor_name + "_orig")
        pruned_tensor = mask.to(dtype=orig.dtype) * orig
        return pruned_tensor

    @classmethod
    def apply(cls, module, name, *args, importance_scores=None, **kwargs):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            args: arguments passed on to a subclass of
                :class:`BasePruningMethod`
            importance_scores (torch.Tensor): tensor of importance scores (of
                same shape as module parameter) used to compute mask for pruning.
                The values in this tensor indicate the importance of the
                corresponding elements in the parameter being pruned.
                If unspecified or None, the parameter will be used in its place.
            kwargs: keyword arguments passed on to a subclass of a
                :class:`BasePruningMethod`

        Returns:
            The pruning method registered on ``module`` as a forward pre-hook:
            a new instance of ``cls``, or a :class:`PruningContainer` combining
            it with a pruning method previously applied to ``name``.

        Raises:
            AssertionError: if more than one pruning hook already targets
                ``name``, or if ``importance_scores`` does not have the same
                shape as the parameter.
        """

        def _get_composite_method(cls, module, name, *args, **kwargs):
            # Check if a pruning method has already been applied to
            # `module[name]`. If so, store that in `old_method`.
            old_method = None
            found = 0
            # there should technically be only 1 hook with
            # hook._tensor_name == name; assert this using `found`
            hooks_to_remove = []
            for k, hook in module._forward_pre_hooks.items():
                # if it exists, take existing thing, remove hook, then
                # go through normal thing
                if isinstance(hook, BasePruningMethod) and hook._tensor_name == name:
                    old_method = hook
                    hooks_to_remove.append(k)
                    found += 1
            assert found <= 1, (
                f"Avoid adding multiple pruning hooks to the same tensor {name} "
                f"of module {module}. Use a PruningContainer."
            )

            for k in hooks_to_remove:
                del module._forward_pre_hooks[k]

            # Apply the new pruning method, either from scratch or on top of
            # the previous one.
            method = cls(*args, **kwargs)  # new pruning
            # Have the pruning method remember what tensor it's been applied to
            method._tensor_name = name

            # combine `methods` with `old_method`, if `old_method` exists
            if old_method is not None:  # meaning that there was a hook
                # if the hook is already a pruning container, just add the
                # new pruning method to the container
                if isinstance(old_method, PruningContainer):
                    old_method.add_pruning_method(method)
                    method = old_method  # rename old_method --> method

                # if the hook is simply a single pruning method, create a
                # container, add the old pruning method and the new one
                elif isinstance(old_method, BasePruningMethod):
                    container = PruningContainer(old_method)
                    container.add_pruning_method(method)
                    method = container  # rename container --> method
            return method

        method = _get_composite_method(cls, module, name, *args, **kwargs)
        # at this point we have no forward_pre_hooks but we could have an
        # active reparametrization of the tensor if another pruning method
        # had been applied (in which case `method` would be a PruningContainer
        # and not a simple pruning method).

        # Pruning is to be applied to the module's tensor named `name`,
        # starting from the state it is found in prior to this iteration of
        # pruning. The pruning mask is calculated based on importances scores.
        orig = getattr(module, name)
        if importance_scores is not None:
            assert importance_scores.shape == orig.shape, (
                "importance_scores should have the same shape as parameter "
                f"{name} of {module}"
            )
        else:
            importance_scores = orig

        # If this is the first time pruning is applied, take care of moving
        # the original tensor to a new parameter called name + '_orig' and
        # deleting the original parameter
        if not isinstance(method, PruningContainer):
            # copy `module[name]` to `module[name + '_orig']`
            module.register_parameter(name + "_orig", orig)
            # temporarily delete `module[name]`
            del module._parameters[name]
            default_mask = torch.ones_like(orig)  # temp
        # If this is not the first time pruning is applied, all of the above
        # has been done before in a previous pruning iteration, so we're good
        # to go
        else:
            default_mask = (
                getattr(module, name + "_mask")
                .detach()
                .clone(memory_format=torch.contiguous_format)
            )

        # Tracks whether this call registered the mask buffer, so a failure
        # after that point can undo it.
        mask_registered = False
        # Use try/except because if anything goes wrong with the mask
        # computation etc., you'd want to roll back.
        try:
            # get the final mask, computed according to the specific method
            mask = method.compute_mask(importance_scores, default_mask=default_mask)
            # reparameterize by saving mask to `module[name + '_mask']`...
            module.register_buffer(name + "_mask", mask)
            mask_registered = True
            # ... and the new pruned tensor to `module[name]`
            setattr(module, name, method.apply_mask(module))
            # associate the pruning method to the module via a hook to
            # compute the function before every forward() (compile by run)
            module.register_forward_pre_hook(method)

        except Exception:
            if not isinstance(method, PruningContainer):
                # Undo the first-time reparametrization: drop the mask buffer
                # and any plain `name` attribute set above, then move
                # `name + '_orig'` back to `name`.
                if mask_registered:
                    del module._buffers[name + "_mask"]
                if name in vars(module):
                    delattr(module, name)
                orig = getattr(module, name + "_orig")
                module.register_parameter(name, orig)
                del module._parameters[name + "_orig"]
            # NOTE(review): when `method` is a PruningContainer, the previous
            # pruning hook was already removed in _get_composite_method and is
            # not restored here.
            raise

        return method

    def prune(self, t, default_mask=None, importance_scores=None):
        r"""Compute and return a pruned version of input tensor ``t``.

        According to the pruning rule specified in :meth:`compute_mask`.

        Args:
            t (torch.Tensor): tensor to prune (of same dimensions as
                ``default_mask``).
            importance_scores (torch.Tensor): tensor of importance scores (of
                same shape as ``t``) used to compute mask for pruning ``t``.
                The values in this tensor indicate the importance of the
                corresponding elements in the ``t`` that is being pruned.
                If unspecified or None, the tensor ``t`` will be used in its place.
            default_mask (torch.Tensor, optional): mask from previous pruning
                iteration, if any. To be considered when determining what
                portion of the tensor that pruning should act on. If None,
                default to a mask of ones.

        Returns:
            pruned version of tensor ``t``.
        """
        if importance_scores is not None:
            assert (
                importance_scores.shape == t.shape
            ), "importance_scores should have the same shape as tensor t"
        else:
            importance_scores = t
        default_mask = default_mask if default_mask is not None else torch.ones_like(t)
        return t * self.compute_mask(importance_scores, default_mask=default_mask)

    def remove(self, module):
        r"""Remove the pruning reparameterization from a module.

        The pruned parameter named ``name`` remains permanently pruned,
        and the parameter named ``name+'_orig'`` is removed from the parameter list.
        Similarly, the buffer named ``name+'_mask'`` is removed from the buffers.

        Note:
            Pruning itself is NOT undone or reversed!
        """
        # before removing pruning from a tensor, it has to have been applied
        assert getattr(self, "_tensor_name", None) is not None, (
            f"Module {module} has to be pruned before pruning can be removed"
        )  # this gets set in apply()

        # to update module[name] to latest trained weights
        weight = self.apply_mask(module)  # masked weights

        # delete and reset
        if hasattr(module, self._tensor_name):
            delattr(module, self._tensor_name)
        orig = module._parameters[self._tensor_name + "_orig"]
        orig.data = weight.data
        del module._parameters[self._tensor_name + "_orig"]
        del module._buffers[self._tensor_name + "_mask"]
        setattr(module, self._tensor_name, orig)
260
261
class PruningContainer(BasePruningMethod):
    """Container holding a sequence of pruning methods for iterative pruning.

    Keeps track of the order in which pruning methods are applied and handles
    combining successive pruning calls.

    Accepts as argument an instance of a BasePruningMethod, several of them
    passed positionally, or a single iterable of them. All contained methods
    must act on the same tensor (same ``_tensor_name``); an empty container
    adopts the tensor name of the first method added to it.
    """

    def __init__(self, *args):
        self._pruning_methods: Tuple[BasePruningMethod, ...] = ()
        # A single non-method iterable (e.g. a list of methods) is unpacked.
        # A PruningContainer is itself iterable but is also a
        # BasePruningMethod, so it is kept as a single method.
        if (
            len(args) == 1
            and not isinstance(args[0], BasePruningMethod)
            and isinstance(args[0], Iterable)
        ):
            args = tuple(args[0])
        for method in args:
            self.add_pruning_method(method)

    def add_pruning_method(self, method):
        r"""Add a child pruning ``method`` to the container.

        Args:
            method (subclass of BasePruningMethod): child pruning method
                to be added to the container.

        Raises:
            TypeError: if ``method`` is neither None nor a BasePruningMethod.
            ValueError: if ``method`` acts on a different tensor than the
                methods already in the container.
        """
        # check that we're adding a pruning method to the container
        if method is not None:
            if not isinstance(method, BasePruningMethod):
                raise TypeError(f"{type(method)} is not a BasePruningMethod subclass")
            if getattr(self, "_tensor_name", None) is None:
                # Container has no tensor yet: adopt the first method's tensor.
                self._tensor_name = method._tensor_name
            elif self._tensor_name != method._tensor_name:
                raise ValueError(
                    "Can only add pruning methods acting on "
                    f"the parameter named '{self._tensor_name}' to PruningContainer {self}."
                    + f" Found '{method._tensor_name}'"
                )
        # if all checks passed, add to _pruning_methods tuple
        self._pruning_methods += (method,)  # type: ignore[operator]

    def __len__(self):
        return len(self._pruning_methods)

    def __iter__(self):
        return iter(self._pruning_methods)

    def __getitem__(self, idx):
        return self._pruning_methods[idx]

    def compute_mask(self, t, default_mask):
        r"""Apply the latest ``method`` by computing the new partial masks and returning its combination with the ``default_mask``.

        The new partial mask should be computed on the entries or channels
        that were not zeroed out by the ``default_mask``.
        Which portions of the tensor ``t`` the new mask will be calculated from
        depends on the ``PRUNING_TYPE`` (handled by the type handler):

        * for 'unstructured', the mask will be computed from the raveled
          list of nonmasked entries;

        * for 'structured', the mask will be computed from the nonmasked
          channels in the tensor;

        * for 'global', the mask will be computed across all entries.

        ``default_mask`` itself is not modified.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
                (of same dimensions as ``default_mask``).
            default_mask (torch.Tensor): mask from previous pruning iteration.

        Returns:
            mask (torch.Tensor): new mask that combines the effects
            of the ``default_mask`` and the new mask from the current
            pruning ``method`` (of same dimensions as ``default_mask`` and
            ``t``).
        """

        def _combine_masks(method, t, mask):
            r"""Combine the masks from all pruning methods and returns a new mask.

            Args:
                method (a BasePruningMethod subclass): pruning method
                    currently being applied.
                t (torch.Tensor): tensor representing the parameter to prune
                    (of same dimensions as mask).
                mask (torch.Tensor): mask from previous pruning iteration

            Returns:
                new_mask (torch.Tensor): new mask that combines the effects
                    of the old mask and the new mask from the current
                    pruning method (of same dimensions as mask and t).
            """
            # Start from a copy of the existing mask: ``Tensor.to`` returns
            # ``mask`` itself when the dtype already matches, and the slice
            # assignment below would then modify the caller's mask in place.
            new_mask = mask.to(dtype=t.dtype, copy=True)

            # compute a slice of t onto which the new pruning method will operate
            if method.PRUNING_TYPE == "unstructured":
                # prune entries of t where the mask is 1
                slc = mask == 1

            # for struct pruning, exclude channels that have already been
            # entirely pruned
            elif method.PRUNING_TYPE == "structured":
                if not hasattr(method, "dim"):
                    raise AttributeError(
                        "Pruning methods of PRUNING_TYPE "
                        '"structured" need to have the attribute `dim` defined.'
                    )

                # find the channels to keep by removing the ones that have been
                # zeroed out already (i.e. where sum(entries) == 0)
                n_dims = t.dim()  # "is this a 2D tensor? 3D? ..."
                dim = method.dim
                # convert negative indexing
                if dim < 0:
                    dim = n_dims + dim
                # reject dims that are out of range in either direction
                if not 0 <= dim < n_dims:
                    raise IndexError(
                        f"Index is out of bounds for tensor with dimensions {n_dims}"
                    )
                # find channels along dim = dim that aren't already all zeroed out
                keep_channel = mask.sum(dim=[d for d in range(n_dims) if d != dim]) != 0
                # create slice to identify what to prune
                index = [slice(None)] * n_dims
                index[dim] = keep_channel
                slc = tuple(index)

            elif method.PRUNING_TYPE == "global":
                # operate on every entry of t
                slc = (slice(None),) * t.dim()

            else:
                raise ValueError(f"Unrecognized PRUNING_TYPE {method.PRUNING_TYPE}")

            # compute the new mask on the unpruned slice of the tensor t
            partial_mask = method.compute_mask(t[slc], default_mask=mask[slc])
            new_mask[slc] = partial_mask.to(dtype=new_mask.dtype)

            return new_mask

        method = self._pruning_methods[-1]
        mask = _combine_masks(method, t, default_mask)
        return mask
406
407
class Identity(BasePruningMethod):
    r"""Pruning method that removes nothing.

    It still sets up the full pruning reparametrization (``name + '_orig'``,
    ``name + '_mask'`` and the forward pre-hook), but the mask it produces is
    simply the incoming mask, i.e. all ones on a tensor pruned for the first time.
    """

    PRUNING_TYPE = "unstructured"

    def compute_mask(self, t, default_mask):
        # Nothing gets pruned: hand the incoming mask straight back.
        return default_mask

    @classmethod
    def apply(cls, module, name):
        r"""Attach the identity pruning reparametrization to ``module[name]``.

        Registers the forward pre-hook that recomputes the pruned tensor and
        reparametrizes it as the original tensor times the pruning mask.

        Args:
            module (nn.Module): module owning the tensor to prune
            name (str): name of the parameter inside ``module`` that the
                pruning acts on.
        """
        pruning_method = super().apply(module, name)
        return pruning_method
431
432
class RandomUnstructured(BasePruningMethod):
    r"""Prune (currently unpruned) units in a tensor at random.

    Args:
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, amount):
        # Check range of validity of pruning amount
        _validate_pruning_amount_init(amount)
        self.amount = amount

    def compute_mask(self, t, default_mask):
        # Check that the amount of units to prune is not > than the number of
        # parameters in t
        tensor_size = t.nelement()
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        # Contiguous copy, so mask.view(-1) below is always valid and indexes
        # elements in logical (row-major) order.
        mask = default_mask.clone(memory_format=torch.contiguous_format)

        if nparams_toprune != 0:  # nothing to prune otherwise
            prob = torch.rand_like(t)
            # rand_like preserves t's memory format (e.g. channels_last or a
            # transposed view), for which view(-1) raises; reshape(-1)
            # flattens in logical order, matching mask.view(-1).
            topk = torch.topk(prob.reshape(-1), k=nparams_toprune)
            mask.view(-1)[topk.indices] = 0

        return mask

    @classmethod
    def apply(cls, module, name, amount):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
        """
        return super().apply(module, name, amount=amount)
490
491
class L1Unstructured(BasePruningMethod):
    r"""Prune (currently unpruned) units in a tensor by zeroing out the ones with the lowest L1-norm.

    Args:
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, amount):
        # Check range of validity of pruning amount
        _validate_pruning_amount_init(amount)
        self.amount = amount

    def compute_mask(self, t, default_mask):
        # Check that the amount of units to prune is not > than the number of
        # parameters in t
        tensor_size = t.nelement()
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        # Contiguous copy, so mask.view(-1) below is always valid and indexes
        # elements in logical (row-major) order.
        mask = default_mask.clone(memory_format=torch.contiguous_format)

        if nparams_toprune != 0:  # nothing to prune otherwise
            # largest=True --> top k; largest=False --> bottom k
            # Prune the smallest k.
            # torch.abs keeps t's memory format (e.g. channels_last or a
            # transposed view), for which view(-1) raises; reshape(-1)
            # flattens in logical order, so the indices line up with
            # mask.view(-1).
            topk = torch.topk(
                torch.abs(t).reshape(-1), k=nparams_toprune, largest=False
            )
            # topk will have .indices and .values
            mask.view(-1)[topk.indices] = 0

        return mask

    @classmethod
    def apply(cls, module, name, amount, importance_scores=None):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
            importance_scores (torch.Tensor): tensor of importance scores (of same
                shape as module parameter) used to compute mask for pruning.
                The values in this tensor indicate the importance of the corresponding
                elements in the parameter being pruned.
                If unspecified or None, the module parameter will be used in its place.
        """
        return super().apply(
            module, name, amount=amount, importance_scores=importance_scores
        )
556
557
class RandomStructured(BasePruningMethod):
    r"""Randomly remove whole channels (not yet pruned) from a tensor.

    Args:
        amount (int or float): how much to prune. A ``float`` must lie in
            [0.0, 1.0] and is read as the fraction of parameters to prune;
            an ``int`` is the absolute number of parameters to prune.
        dim (int, optional): dimension of the tensor whose slices are treated
            as channels. Default: -1.
    """

    PRUNING_TYPE = "structured"

    def __init__(self, amount, dim=-1):
        # Reject invalid amounts before storing any configuration.
        _validate_pruning_amount_init(amount)
        self.amount = amount
        self.dim = dim

    def compute_mask(self, t, default_mask):
        r"""Build a mask that drops randomly chosen channels of ``t``.

        The new channel mask is layered on top of ``default_mask`` (all ones
        for a tensor that was never pruned), so earlier pruning is kept.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
            default_mask (torch.Tensor): mask left by earlier pruning rounds,
                which must stay in effect; same shape as ``t``.

        Returns:
            mask (torch.Tensor): mask for ``t``, with the same shape as ``t``

        Raises:
            IndexError: if ``self.dim >= len(t.shape)``
        """
        # "Channels" only make sense for tensors with more than one dimension.
        _validate_structured_pruning(t)
        # self.dim has to be a valid axis of t (IndexError otherwise).
        _validate_pruning_dim(t, self.dim)

        # The budget is counted in channels along self.dim; it may not exceed
        # the number of channels available.
        n_channels = t.shape[self.dim]
        n_prune = _compute_nparams_toprune(self.amount, n_channels)
        _validate_pruning_amount(n_prune, n_channels)

        if n_prune == 0:
            # torch.kthvalue rejects k=0; with nothing to prune the previous
            # mask is returned unchanged.
            return default_mask

        # Give every channel a uniform random score and zero out the n_prune
        # channels holding the lowest scores.
        scores = torch.rand(n_channels)
        cutoff = torch.kthvalue(scores, k=n_prune).values
        keep = scores > cutoff

        index = [slice(None)] * t.dim()
        index[self.dim] = keep
        mask = torch.zeros_like(t)
        mask[index] = 1

        # Layer the new structured mask over the previous (possibly
        # unstructured) one.
        mask *= default_mask.to(dtype=mask.dtype)
        return mask

    @classmethod
    def apply(cls, module, name, amount, dim=-1):
        r"""Attach random structured pruning to ``module[name]``.

        Registers the forward pre-hook that recomputes the pruned tensor and
        reparametrizes it as the original tensor times the pruning mask.

        Args:
            module (nn.Module): module owning the tensor to prune
            name (str): name of the parameter inside ``module`` that the
                pruning acts on.
            amount (int or float): how much to prune. A ``float`` must lie in
                [0.0, 1.0] and is read as the fraction of parameters to
                prune; an ``int`` is the absolute number of parameters to prune.
            dim (int, optional): dimension of the tensor whose slices are
                treated as channels. Default: -1.
        """
        pruning_method = super().apply(module, name, amount=amount, dim=dim)
        return pruning_method
661
662
class LnStructured(BasePruningMethod):
    r"""Prune entire (currently unpruned) channels in a tensor based on their L\ ``n``-norm.

    Args:
        amount (int or float): quantity of channels to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of channels to prune. If ``int``, it represents the
            absolute number of channels to prune.
        n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
            entries for argument ``p`` in :func:`torch.norm`.
        dim (int, optional): index of the dim along which we define
            channels to prune. Default: -1.
    """

    PRUNING_TYPE = "structured"

    def __init__(self, amount, n, dim=-1):
        # Fail fast on a wrongly-typed or out-of-range amount; the check
        # against the real number of channels happens in compute_mask.
        _validate_pruning_amount_init(amount)
        self.amount = amount
        self.n = n
        self.dim = dim

    def compute_mask(self, t, default_mask):
        r"""Compute and return a mask for the input tensor ``t``.

        Starting from a base ``default_mask`` (which should be a mask of ones
        if the tensor has not been pruned yet), generate a mask to apply on
        top of the ``default_mask`` by zeroing out the channels along the
        specified dim with the lowest L\ ``n``-norm.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
            default_mask (torch.Tensor): Base mask from previous pruning
                iterations, that need to be respected after the new mask is
                applied.  Same dims as ``t``.

        Returns:
            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``.
            When no channel is to be pruned, ``default_mask`` itself is
            returned.

        Raises:
            ValueError: if ``t`` has fewer than 2 dimensions, or if more
                channels are requested than ``t`` has along ``self.dim``.
            IndexError: if ``self.dim`` is not a valid dimension of ``t``.
        """
        # "Channels" are only meaningful for tensors with at least 2 dims.
        _validate_structured_pruning(t)
        # Raise IndexError early if self.dim does not index t.
        _validate_pruning_dim(t, self.dim)

        # The amount is interpreted relative to the number of channels along
        # self.dim, not relative to the total number of elements in t.
        tensor_size = t.shape[self.dim]
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        nparams_tokeep = tensor_size - nparams_toprune
        # Raise if asked to prune more channels than exist along self.dim.
        _validate_pruning_amount(nparams_toprune, tensor_size)

        # One L_n norm per channel along self.dim; the channels with the
        # largest norms are the ones that are kept.
        norm = _compute_norm(t, self.n, self.dim)
        topk = torch.topk(norm, k=nparams_tokeep, largest=True)

        if nparams_toprune == 0:
            # Nothing to prune: keep the previous mask unchanged.
            return default_mask

        # Binary mask of the same shape as t: 1s for every kept channel along
        # self.dim, 0s elsewhere. E.g. for a 3-D t with dim=2 the index below
        # is (slice(None), slice(None), topk.indices), i.e. mask[:, :, idx].
        # A tuple is used because indexing with a *list* of slices relies on
        # legacy sequence-as-tuple indexing semantics.
        mask = torch.zeros_like(t)
        index = [slice(None)] * len(t.shape)
        index[self.dim] = topk.indices
        mask[tuple(index)] = 1

        # Channels pruned by earlier iterations must stay pruned.
        mask *= default_mask.to(dtype=mask.dtype)
        return mask

    @classmethod
    def apply(cls, module, name, amount, n, dim, importance_scores=None):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of channels to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of channels to prune. If ``int``, it represents the
                absolute number of channels to prune.
            n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
                entries for argument ``p`` in :func:`torch.norm`.
            dim (int): index of the dim along which we define channels to
                prune.
            importance_scores (torch.Tensor): tensor of importance scores (of same
                shape as module parameter) used to compute mask for pruning.
                The values in this tensor indicate the importance of the corresponding
                elements in the parameter being pruned.
                If unspecified or None, the module parameter will be used in its place.
        """
        return super().apply(
            module,
            name,
            amount=amount,
            n=n,
            dim=dim,
            importance_scores=importance_scores,
        )
790
791
class CustomFromMask(BasePruningMethod):
    r"""Prune a tensor with a user-supplied, pre-computed mask.

    The stored ``mask`` is multiplied into whatever mask the tensor already
    carries, so entries pruned by earlier iterations remain pruned.

    Args:
        mask (torch.Tensor): mask to apply. It must have the same shape as
            the existing mask of the tensor being pruned.
    """

    PRUNING_TYPE = "global"

    def __init__(self, mask):
        self.mask = mask

    def compute_mask(self, t, default_mask):
        r"""Combine the stored mask with ``default_mask``.

        Args:
            t (torch.Tensor): tensor being pruned. It is not read, because the
                mask was fixed at construction time.
            default_mask (torch.Tensor): mask from previous pruning
                iterations. It must have the same shape as ``self.mask``.

        Returns:
            torch.Tensor: elementwise product of ``default_mask`` and
            ``self.mask`` cast to ``default_mask``'s dtype.
        """
        # NOTE: this check is an ``assert`` and is therefore skipped under
        # ``python -O``.
        assert default_mask.shape == self.mask.shape
        return default_mask * self.mask.to(dtype=default_mask.dtype)

    @classmethod
    def apply(cls, module, name, mask):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Registers the forward pre-hook that recomputes the pruned tensor on
        the fly, and reparametrizes the tensor in terms of the original
        tensor and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            mask (torch.Tensor): pre-computed mask to apply to the parameter.
        """
        return super().apply(module, name, mask=mask)
817
818
def identity(module, name):
    r"""Apply pruning reparametrization without pruning any units.

    Sets up the same reparametrization that any pruning method would for the
    parameter called ``name`` in ``module``, but no unit is actually removed.
    The module is modified in place (and is also returned) by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Note:
        The mask is a tensor of ones.

    Args:
        module (nn.Module): module containing the tensor to prune.
        name (str): parameter name within ``module`` on which pruning
                will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.identity(nn.Linear(2, 3), 'bias')
        >>> print(m.bias_mask)
        tensor([1., 1., 1.])
    """
    # The method object returned by apply() is not needed; the module has
    # already been mutated in place.
    Identity.apply(module, name)
    return module
852
853
def random_unstructured(module, name, amount):
    r"""Prune tensor by removing random (currently unpruned) units.

    Removes ``amount`` randomly chosen, still-unpruned units from the tensor
    stored as parameter ``name`` in ``module``. The module is modified in
    place (and is also returned) by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
                will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.random_unstructured(nn.Linear(2, 3), 'weight', amount=1)
        >>> torch.sum(m.weight_mask == 0)
        tensor(1)

    """
    # apply() mutates ``module`` in place; its return value is discarded.
    RandomUnstructured.apply(module, name, amount)
    return module
889
890
def l1_unstructured(module, name, amount, importance_scores=None):
    r"""Prune tensor by removing units with the lowest L1-norm.

    Removes ``amount`` still-unpruned units with the smallest L1-norm from the
    tensor stored as parameter ``name`` in ``module``. The module is modified
    in place (and is also returned) by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
                will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        importance_scores (torch.Tensor): tensor of importance scores (of same
            shape as module parameter) used to compute mask for pruning.
            The values in this tensor indicate the importance of the corresponding
            elements in the parameter being pruned.
            If unspecified or None, the module parameter will be used in its place.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.l1_unstructured(nn.Linear(2, 3), 'weight', amount=0.2)
        >>> m.state_dict().keys()
        odict_keys(['bias', 'weight_orig', 'weight_mask'])
    """
    L1Unstructured.apply(module, name, amount=amount, importance_scores=importance_scores)
    return module
933
934
def random_structured(module, name, amount, dim):
    r"""Prune tensor by removing random channels along the specified dimension.

    Removes ``amount`` randomly chosen, still-unpruned channels along ``dim``
    from the tensor stored as parameter ``name`` in ``module``. The module is
    modified in place (and is also returned) by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
                will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        dim (int): index of the dim along which we define channels to prune.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.random_structured(
        ...     nn.Linear(5, 3), 'weight', amount=3, dim=1
        ... )
        >>> columns_pruned = int(sum(torch.sum(m.weight, dim=0) == 0))
        >>> print(columns_pruned)
        3
    """
    # apply() mutates ``module`` in place; its return value is discarded.
    RandomStructured.apply(module, name, amount, dim)
    return module
974
975
def ln_structured(module, name, amount, n, dim, importance_scores=None):
    r"""Prune tensor by removing channels with the lowest L\ ``n``-norm along the specified dimension.

    Removes ``amount`` still-unpruned channels along ``dim`` whose
    L\ ``n``-norm is smallest from the tensor stored as parameter ``name`` in
    ``module``. The module is modified in place (and is also returned) by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
                will act.
        amount (int or float): quantity of channels to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of channels along ``dim`` to prune. If ``int``, it
            represents the absolute number of channels to prune.
        n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
            entries for argument ``p`` in :func:`torch.norm`.
        dim (int): index of the dim along which we define channels to prune.
        importance_scores (torch.Tensor): tensor of importance scores (of same
            shape as module parameter) used to compute mask for pruning.
            The values in this tensor indicate the importance of the corresponding
            elements in the parameter being pruned.
            If unspecified or None, the module parameter will be used in its place.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> from torch.nn.utils import prune
        >>> m = prune.ln_structured(
        ...     nn.Conv2d(5, 3, 2), 'weight', amount=0.3, dim=1, n=float('-inf')
        ... )
    """
    LnStructured.apply(
        module,
        name,
        amount=amount,
        n=n,
        dim=dim,
        importance_scores=importance_scores,
    )
    return module
1021
1022
def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs):
    r"""
    Globally prunes tensors corresponding to all parameters in ``parameters`` by applying the specified ``pruning_method``.

    Modifies modules in place by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        parameters (Iterable of (module, name) tuples): parameters of
            the model to prune in a global fashion, i.e. by aggregating all
            weights prior to deciding which ones to prune. module must be of
            type :class:`nn.Module`, and name must be a string. One-shot
            iterables such as generators are accepted.
        pruning_method (type): a valid pruning method class from this module
            (e.g. :class:`L1Unstructured`), or a custom one implemented by the
            user that satisfies the implementation guidelines and has
            ``PRUNING_TYPE='unstructured'``. It is instantiated as
            ``pruning_method(**kwargs)``.
        importance_scores (dict): a dictionary mapping (module, name) tuples to
            the corresponding parameter's importance scores tensor. The tensor
            should be the same shape as the parameter, and is used for computing
            mask for pruning.
            If unspecified or None, the parameter will be used in place of its
            importance scores.
        kwargs: other keyword arguments such as:
            amount (int or float): quantity of parameters to prune across the
            specified parameters.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.

    Raises:
        TypeError: if ``parameters`` is not an iterable, if
            ``importance_scores`` is not a ``dict``, or if
            ``PRUNING_TYPE != 'unstructured'``.

    Note:
        Since global structured pruning doesn't make much sense unless the
        norm is normalized by the size of the parameter, we now limit the
        scope of global pruning to unstructured methods.

    Examples:
        >>> from torch.nn.utils import prune
        >>> from collections import OrderedDict
        >>> net = nn.Sequential(OrderedDict([
        ...     ('first', nn.Linear(10, 4)),
        ...     ('second', nn.Linear(4, 1)),
        ... ]))
        >>> parameters_to_prune = (
        ...     (net.first, 'weight'),
        ...     (net.second, 'weight'),
        ... )
        >>> prune.global_unstructured(
        ...     parameters_to_prune,
        ...     pruning_method=prune.L1Unstructured,
        ...     amount=10,
        ... )
        >>> print(sum(torch.nn.utils.parameters_to_vector(net.buffers()) == 0))
        tensor(10)

    """
    # ensure parameters is a list or generator of tuples
    if not isinstance(parameters, Iterable):
        raise TypeError("global_unstructured(): parameters is not an Iterable")
    # `parameters` is traversed three times below (importance scores, default
    # masks, and the final per-parameter mask assignment). Materialize it once
    # so that a generator is not exhausted after the first pass, which would
    # otherwise leave the later passes silently empty.
    parameters = list(parameters)

    importance_scores = importance_scores if importance_scores is not None else {}
    if not isinstance(importance_scores, dict):
        raise TypeError("global_unstructured(): importance_scores must be of type dict")

    # flatten importance scores to consider them all at once in global pruning
    relevant_importance_scores = torch.nn.utils.parameters_to_vector(
        [
            importance_scores.get((module, name), getattr(module, name))
            for (module, name) in parameters
        ]
    )
    # similarly, flatten the masks (if they exist), or use a flattened vector
    # of 1s of the same dimensions as t
    default_mask = torch.nn.utils.parameters_to_vector(
        [
            getattr(module, name + "_mask", torch.ones_like(getattr(module, name)))
            for (module, name) in parameters
        ]
    )

    # use the canonical pruning methods to compute the new mask, even if the
    # parameter is now a flattened out version of `parameters`
    container = PruningContainer()
    container._tensor_name = "temp"  # to make it match that of `method`
    method = pruning_method(**kwargs)
    method._tensor_name = "temp"  # to make it match that of `container`
    if method.PRUNING_TYPE != "unstructured":
        raise TypeError(
            'Only "unstructured" PRUNING_TYPE supported for '
            f"the `pruning_method`. Found method {pruning_method} of type {method.PRUNING_TYPE}"
        )

    container.add_pruning_method(method)

    # use the `compute_mask` method from `PruningContainer` to combine the
    # mask computed by the new method with the pre-existing mask
    final_mask = container.compute_mask(relevant_importance_scores, default_mask)

    # Pointer for slicing the mask to match the shape of each parameter
    pointer = 0
    for module, name in parameters:
        param = getattr(module, name)
        # The length of the parameter
        num_param = param.numel()
        # Slice the mask, reshape it
        param_mask = final_mask[pointer : pointer + num_param].view_as(param)
        # Assign the correct pre-computed mask to each parameter and add it
        # to the forward_pre_hooks like any other pruning method
        custom_from_mask(module, name, mask=param_mask)

        # Increment the pointer to continue slicing the final_mask
        pointer += num_param
1140
1141
def custom_from_mask(module, name, mask):
    r"""Prune tensor corresponding to parameter called ``name`` in ``module`` by applying the pre-computed mask in ``mask``.

    The module is modified in place (and is also returned) by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        mask (Tensor): binary mask to be applied to the parameter.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> from torch.nn.utils import prune
        >>> m = prune.custom_from_mask(
        ...     nn.Linear(5, 3), name='bias', mask=torch.tensor([0, 1, 0])
        ... )
        >>> print(m.bias_mask)
        tensor([0., 1., 0.])

    """
    CustomFromMask.apply(module, name, mask=mask)
    return module
1173
1174
def remove(module, name):
    r"""Remove the pruning reparameterization from a module and the pruning method from the forward hook.

    After this call the parameter named ``name`` keeps its pruned values
    permanently. The parameter named ``name+'_orig'`` is dropped from the
    parameter list, and the buffer named ``name+'_mask'`` is dropped from the
    buffers.

    Note:
        Pruning itself is NOT undone or reversed!

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.

    Raises:
        ValueError: if no pruning hook acts on ``name`` in ``module``.

    Examples:
        >>> m = random_unstructured(nn.Linear(5, 7), name='weight', amount=0.2)
        >>> m = remove(m, name='weight')
    """
    hooks = module._forward_pre_hooks
    # Key of the first pruning hook that targets `name`, or None if absent.
    target_key = next(
        (
            key
            for key, hook in hooks.items()
            if isinstance(hook, BasePruningMethod) and hook._tensor_name == name
        ),
        None,
    )
    if target_key is None:
        raise ValueError(
            f"Parameter '{name}' of module {module} has to be pruned before pruning can be removed"
        )

    hooks[target_key].remove(module)
    del hooks[target_key]
    return module
1203
1204
def is_pruned(module):
    r"""Check if a module is pruned by looking for pruning pre-hooks.

    ``module`` counts as pruned when the module itself, or any of its
    submodules, has a ``forward_pre_hook`` that is an instance of
    :class:`BasePruningMethod`.

    Args:
        module (nn.Module): object that is either pruned or unpruned

    Returns:
        binary answer to whether ``module`` is pruned.

    Examples:
        >>> from torch.nn.utils import prune
        >>> m = nn.Linear(5, 7)
        >>> print(prune.is_pruned(m))
        False
        >>> prune.random_unstructured(m, name='weight', amount=0.2)
        >>> print(prune.is_pruned(m))
        True
    """
    return any(
        isinstance(hook, BasePruningMethod)
        for submodule in module.modules()
        for hook in submodule._forward_pre_hooks.values()
    )
1232
1233
def _validate_pruning_amount_init(amount):
    r"""Validate helper to check the range of amount at init.

    Args:
        amount (int or float): quantity of parameters to prune.
            If float, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If int, it represents the
            absolute number of parameters to prune.

    Raises:
        ValueError: if amount is a float not in [0, 1] (including NaN), or
            if it's a negative integer.
        TypeError: if amount is neither a float nor an integer.

    Note:
        This does not take into account the number of parameters in the
        tensor to be pruned, which is known only at prune.
    """
    if not isinstance(amount, numbers.Real):
        raise TypeError(f"Invalid type for amount: {amount}. Must be int or float.")

    if isinstance(amount, numbers.Integral):
        is_valid = amount >= 0
    else:
        # Written as a positive range check: NaN compares False against
        # everything, so the previous `x > 1.0 or x < 0.0` test let NaN
        # through, and it only failed later inside round().
        is_valid = 0.0 <= float(amount) <= 1.0

    if not is_valid:
        raise ValueError(
            f"amount={amount} should either be a float in the range [0, 1] or a non-negative integer"
        )
1262
1263
def _validate_pruning_amount(amount, tensor_size):
    r"""Validate that the pruning amount is meaningful wrt to the size of the data.

    Only integral amounts are checked here: an absolute count may not exceed
    ``tensor_size``. An amount equal to ``tensor_size`` (prune everything) is
    accepted.

    Args:
        amount (int or float): quantity of parameters to prune.
            If float, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If int, it represents the
            absolute number of parameters to prune.
        tensor_size (int): absolute number of parameters in the tensor
            to prune.

    Raises:
        ValueError: if ``amount`` is an integer greater than ``tensor_size``.
    """
    # TODO: consider removing this check and allowing users to specify
    # a number of units to prune that is greater than the number of units
    # left to prune. In this case, the tensor will just be fully pruned.
    too_many = isinstance(amount, numbers.Integral) and amount > tensor_size
    if not too_many:
        return

    raise ValueError(
        f"amount={amount} should be smaller than the number of parameters to prune={tensor_size}"
    )
1286
1287
def _validate_structured_pruning(t):
    r"""Validate that the tensor to be pruned is at least 2-Dimensional.

    Structured pruning removes whole "channels", which only exist when the
    tensor has more than one dimension.

    Args:
        t (torch.Tensor): tensor representing the parameter to prune

    Raises:
        ValueError: if the tensor `t` is not at least 2D.
    """
    size = t.shape
    ndims = len(size)
    if ndims > 1:
        return
    raise ValueError(
        "Structured pruning can only be applied to "
        "multidimensional tensors. Found tensor of shape "
        f"{size} with {ndims} dims"
    )
1307
1308
def _compute_nparams_toprune(amount, tensor_size):
    r"""Convert the pruning amount from a percentage to absolute value.

    ``amount`` may be either an absolute count or a fraction of the
    units/channels in a tensor. This helper normalizes both forms to an
    absolute count so the pruning code only deals with one form.

    Args:
        amount (int or float): quantity of parameters to prune.
            If float, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If int, it represents the
            absolute number of parameters to prune.
        tensor_size (int): absolute number of parameters in the tensor
            to prune.

    Returns:
        int: the number of units to prune in the tensor. Fractions are
        converted with Python's built-in ``round``, which rounds half to
        even (e.g. ``0.25 * 10 -> 2``).
    """
    # The type of `amount` was already validated in _validate_pruning_amount_init.
    is_absolute = isinstance(amount, numbers.Integral)
    return amount if is_absolute else round(amount * tensor_size)
1333
1334
def _validate_pruning_dim(t, dim):
    r"""Validate that the pruning dimension is within the bounds of the tensor dimension.

    Both non-negative and negative (counted from the end) indices are
    accepted, i.e. ``-t.dim() <= dim < t.dim()``.

    Args:
        t (torch.Tensor): tensor representing the parameter to prune
        dim (int): index of the dim along which we define channels to prune

    Raises:
        IndexError: if ``dim`` does not index a dimension of ``t``.
    """
    ndim = t.dim()
    # Previously only the upper bound was checked, so e.g. dim=-3 on a 2-D
    # tensor slipped through and failed later with an unrelated message.
    if not (-ndim <= dim < ndim):
        raise IndexError(f"Invalid index {dim} for tensor of size {t.shape}")
1344
1345
def _compute_norm(t, n, dim):
    r"""Compute the L_n-norm of a tensor along all dimensions except for the specified dimension.

    The L_n-norm is computed over all entries of `t` along every dimension
    except the one identified by `dim`, giving one value per channel.
    Example: if `t` is of shape, say, 3x2x4 and dim=2 (the last dim),
    then norm will have Size [4], and each entry will represent the
    `L_n`-norm computed using the 3x2=6 entries for each of the 4 channels.

    Args:
        t (torch.Tensor): tensor representing the parameter to prune
        n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
            entries for argument p in torch.norm
        dim (int): dim identifying the channels to prune; negative values
            count from the last dimension.

    Returns:
        norm (torch.Tensor): L_n norm computed across all dimensions except
            for `dim`. By construction, `norm.shape == (t.shape[dim],)`.

    Raises:
        IndexError: if `dim` is out of range for `t`.
    """
    ndim = t.dim()
    # Normalize a negative `dim`. Indexing the range raises IndexError for an
    # out-of-range value in either direction; the old `list.remove` path
    # raised ValueError for out-of-range positive values instead.
    dim = range(ndim)[dim]
    # Reduce over every axis except the channel axis.
    dims = [d for d in range(ndim) if d != dim]
    return torch.norm(t, p=n, dim=dims)
1374