# mypy: allow-untyped-defs
import operator
from functools import reduce
from typing import Callable, Optional, Tuple, Union

import torch
import torch.nn.functional as F

from .base_sparsifier import BaseSparsifier


__all__ = ["WeightNormSparsifier"]


def _flat_idx_to_2d(idx, shape):
    r"""Convert a flat index into (row, col) coordinates for a 2d tensor of
    the given shape, e.g. index 5 in shape (3, 4) maps to (1, 1)."""
    rows = idx // shape[1]
    cols = idx % shape[1]
    return rows, cols


class WeightNormSparsifier(BaseSparsifier):
    r"""Weight-Norm Sparsifier

    This sparsifier computes the norm of every sparse block and "zeroes-out" the
    ones with the lowest norm. The level of sparsity defines how many of the
    blocks are removed.

    This sparsifier is controlled by three variables:
    1. `sparsity_level` defines the fraction of *sparse blocks* that are zeroed-out
    2. `sparse_block_shape` defines the shape of the sparse blocks. Note that
        the sparse blocks originate at the zero-index of the tensor.
    3. `zeros_per_block` is the number of zeros that we expect in each sparse
        block. By default all elements within a block are zeroed-out. However,
        setting this variable changes the target number of zeros per block.
        The zeros within each block are chosen as the *smallest absolute
        values*.

    Args:
        sparsity_level: The target level of sparsity
        sparse_block_shape: The shape of a sparse block (see note below)
        zeros_per_block: Number of zeros in a sparse block
        norm: Norm to use. Can be either an `int` or a callable.
            If `int`, only L1 and L2 are implemented.

    Note::
        The `sparse_block_shape` is a tuple representing (block_ROWS, block_COLS),
        irrespective of what the rows / cols mean in the data tensor. That means,
        if you were to sparsify a weight tensor in an nn.Linear, which has a
        weight shape `(Cout, Cin)`, the `block_ROWS` would refer to the output
        channels, while the `block_COLS` would refer to the input channels.

    Note::
        All arguments to the WeightNormSparsifier constructor are "default"
        arguments and can be overridden by the configuration provided in the
        `prepare` step.
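
    Example (a minimal sketch; the toy model and the config below are
    illustrative, not prescribed by this module)::

        >>> import torch.nn as nn
        >>> from torch.ao.pruning import WeightNormSparsifier
        >>> model = nn.Sequential(nn.Linear(8, 8))
        >>> sparsifier = WeightNormSparsifier(sparsity_level=0.5, sparse_block_shape=(1, 4))
        >>> sparsifier.prepare(model, config=[{"tensor_fqn": "0.weight"}])
        >>> sparsifier.step()  # compute and apply the masks
        >>> sparsifier.squash_mask()  # fold the masks into the weights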
    """

    def __init__(
        self,
        sparsity_level: float = 0.5,
        sparse_block_shape: Tuple[int, int] = (1, 4),
        zeros_per_block: Optional[int] = None,
        norm: Optional[Union[Callable, int]] = None,
    ):
        if zeros_per_block is None:
            # Default: zero out every element of each selected block
            zeros_per_block = reduce(operator.mul, sparse_block_shape)
        defaults = {
            "sparsity_level": sparsity_level,
            "sparse_block_shape": sparse_block_shape,
            "zeros_per_block": zeros_per_block,
        }
        if norm is None:
            norm = 2
        if callable(norm):
            self.norm_fn = norm
        elif norm == 1:
            self.norm_fn = lambda T: T.abs()  # elementwise L1 score
        elif norm == 2:
            self.norm_fn = lambda T: T * T  # elementwise squared-L2 score
        else:
            raise NotImplementedError(f"L-{norm} is not yet implemented.")
        super().__init__(defaults=defaults)

    def _scatter_fold_block_mask(
        self,
        output_shape,
        dim,
        indices,
        block_shape,
        mask=None,
        input_shape=None,
        device=None,
    ):
        r"""Creates patches of size `block_shape` after scattering the indices."""
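        # Mechanics sketch (shapes here are illustrative): `mask` is expected
        # in the unfolded layout (N, prod(block_shape), num_blocks), as
        # produced by F.unfold or an equivalent reshape. scatter_ writes zeros
        # at `indices` along `dim`, and F.fold then reassembles the columns
        # into a 2d mask of `output_shape`, so a fully zeroed column becomes a
        # fully zeroed block.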
        if mask is None:
            assert input_shape is not None
            mask = torch.ones(input_shape, device=device)
        mask.scatter_(dim=dim, index=indices, value=0)
        mask.data = F.fold(
            mask, output_size=output_shape, kernel_size=block_shape, stride=block_shape
        )
        return mask

    def _make_tensor_mask(
        self, data, input_shape, sparsity_level, sparse_block_shape, mask=None
    ):
        r"""Creates a tensor-level mask.

        A tensor-level mask is a mask whose smallest granularity of
        sparsification is the sparse_block_shape: for a given mask and
        sparse_block_shape, the smallest "patch" of zeros/ones is of the size
        sparse_block_shape.

        In this context, `sparsity_level` describes the fraction of sparse patches.
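
        Example (an illustrative sketch; the exact values below assume a 4x4
        input, 2x2 blocks, and 50% sparsity)::

            >>> import torch
            >>> sp = WeightNormSparsifier()
            >>> data = torch.arange(16.0).reshape(4, 4)
            >>> sp._make_tensor_mask(data, data.shape, 0.5, (2, 2))
            tensor([[0., 0., 0., 0.],
                    [0., 0., 0., 0.],
                    [1., 1., 1., 1.],
                    [1., 1., 1., 1.]])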
        """
        h, w = data.shape[-2:]
        block_h, block_w = sparse_block_shape
        # Padding needed to make (h, w) multiples of the block dims
        dh = (block_h - h % block_h) % block_h
        dw = (block_w - w % block_w) % block_w

        if mask is None:
            mask = torch.ones(h + dh, w + dw, device=data.device)

        if sparsity_level >= 1.0:
            mask.data = torch.zeros_like(mask)
            return mask
        elif sparsity_level <= 0.0:
            mask.data = torch.ones_like(mask)
            return mask

        values_per_block = reduce(operator.mul, sparse_block_shape)
        if values_per_block > 1:
            # Reduce each block to a single score (its mean)
            data = F.avg_pool2d(
                data[None, None, :],
                kernel_size=sparse_block_shape,
                stride=sparse_block_shape,
                ceil_mode=True,
            )
        data = data.flatten()
        num_blocks = len(data)

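        # Replicate the per-block scores to shape (1, values_per_block,
        # num_blocks) so that topk/scatter below zero out the same block
        # indices in every row; after folding, all elements of a selected
        # block are zero.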
        data = data.repeat(1, values_per_block, 1)

        threshold_idx = int(round(sparsity_level * num_blocks))
        threshold_idx = max(0, min(num_blocks - 1, threshold_idx))  # Sanity check
        # Indices of the `threshold_idx` blocks with the smallest norms
        _, sorted_idx = torch.topk(data, k=threshold_idx, dim=2, largest=False)

        # Temp reshape for mask
        mask_reshape = mask.reshape(data.shape)  # data might be reshaped
        self._scatter_fold_block_mask(
            dim=2,
            output_shape=(h + dh, w + dw),
            indices=sorted_idx,
            block_shape=sparse_block_shape,
            mask=mask_reshape,
        )
        mask.data = mask_reshape.squeeze().reshape(mask.shape)[:h, :w].contiguous()
        return mask

    def _make_block_mask(self, data, sparse_block_shape, zeros_per_block, mask=None):
        r"""Creates a block-level mask.

        A block-level mask is a mask whose largest granularity of
        sparsification is the sparse_block_shape: for a given mask and
        sparse_block_shape, the sparsity is computed only within a patch of
        size sparse_block_shape.

        In this context, `zeros_per_block` describes the number of zeroed-out elements within a patch.
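
        Example (an illustrative sketch; the exact values below assume (1, 4)
        blocks with 2 zeros per block)::

            >>> import torch
            >>> sp = WeightNormSparsifier()
            >>> data = torch.tensor([[4.0, 1.0, 2.0, 3.0], [8.0, 5.0, 6.0, 7.0]])
            >>> sp._make_block_mask(data, (1, 4), 2)
            tensor([[1., 0., 0., 1.],
                    [1., 0., 0., 1.]])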
        """
        h, w = data.shape[-2:]
        block_h, block_w = sparse_block_shape
        dh = (block_h - h % block_h) % block_h
        dw = (block_w - w % block_w) % block_w
        values_per_block = reduce(operator.mul, sparse_block_shape)

        if mask is None:
            mask = torch.ones((h + dh, w + dw), device=data.device)

        if values_per_block == zeros_per_block:
            # Everything should be sparsified
            mask.data = torch.zeros_like(mask)
            return mask

        # Create a new padded tensor like data (to match the block_shape)
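        # The padding is NaN-filled: torch.topk with largest=False sorts NaN
        # after real numbers, so padding cells are not selected among the
        # zeros_per_block smallest and the zeroing lands on real elements.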
        padded_data = torch.ones(h + dh, w + dw, dtype=data.dtype, device=data.device)
        padded_data.fill_(torch.nan)
        padded_data[:h, :w] = data
        # Unfold into (1, values_per_block, num_blocks): one column per block
        unfolded_data = F.unfold(
            padded_data[None, None, :],
            kernel_size=sparse_block_shape,
            stride=sparse_block_shape,
        )

        # Temp reshape for mask
        mask_reshape = mask.reshape(unfolded_data.shape)
        # Indices of the `zeros_per_block` smallest elements within each block
        _, sorted_idx = torch.topk(
            unfolded_data, k=zeros_per_block, dim=1, largest=False
        )

        self._scatter_fold_block_mask(
            dim=1,
            indices=sorted_idx,
            output_shape=padded_data.shape,
            block_shape=sparse_block_shape,
            mask=mask_reshape,
        )

        mask.data = mask_reshape.squeeze().reshape(mask.shape).contiguous()
        return mask

    def update_mask(
        self,
        module,
        tensor_name,
        sparsity_level,
        sparse_block_shape,
        zeros_per_block,
        **kwargs,
    ):
        values_per_block = reduce(operator.mul, sparse_block_shape)
        if zeros_per_block > values_per_block:
            raise ValueError(
                "Number of zeros per block cannot be more than the total number of elements in that block."
            )
        if zeros_per_block < 0:
            raise ValueError("Number of zeros per block must be non-negative.")

        mask = getattr(module.parametrizations, tensor_name)[0].mask
        if sparsity_level <= 0 or zeros_per_block == 0:
            # Nothing to sparsify
            mask.data = torch.ones_like(mask)
        elif sparsity_level >= 1.0 and (zeros_per_block == values_per_block):
            # Everything is sparsified
            mask.data = torch.zeros_like(mask)
        else:
            # Score the weights with the configured norm before masking
            ww = self.norm_fn(getattr(module, tensor_name))
            tensor_mask = self._make_tensor_mask(
                data=ww,
                input_shape=ww.shape,
                sparsity_level=sparsity_level,
                sparse_block_shape=sparse_block_shape,
            )
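            # Combine with a block-level mask: an element survives if either
            # mask keeps it (logical_or below), so inside each block selected
            # by the tensor-level mask only the `zeros_per_block` smallest
            # entries remain zero.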
            if values_per_block != zeros_per_block:
                block_mask = self._make_block_mask(
                    data=ww,
                    sparse_block_shape=sparse_block_shape,
                    zeros_per_block=zeros_per_block,
                )
                tensor_mask = torch.logical_or(tensor_mask, block_mask)
            mask.data = tensor_mask
249