xref: /aosp_15_r20/external/pytorch/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import operator
3from functools import reduce
4from typing import Any, List, Optional, Tuple
5
6import torch
7from torch.nn import functional as F
8
9from .base_data_sparsifier import BaseDataSparsifier
10
11
12__all__ = ["DataNormSparsifier"]
13
14
class DataNormSparsifier(BaseDataSparsifier):
    r"""Norm-based data sparsifier (L1 or L2).

    This sparsifier scores every sparse block by the mean of the absolute
    values (``norm="L1"``) or of the squared values (``norm="L2"``) of its
    elements and "zeroes-out" the blocks with the lowest score. The level of
    sparsity defines how many of the blocks are removed.

    This sparsifier is controlled by three variables:

    1. `sparsity_level` defines the fraction of *sparse blocks* that are
       zeroed-out.
    2. `sparse_block_shape` defines the shape of the sparse blocks. Note that
       the sparse blocks originate at the zero-index of the tensor.
    3. `zeros_per_block` is the number of zeros that we are expecting in each
       sparse block. By default we assume that all elements within a block are
       zeroed-out. However, setting this variable sets the target number of
       zeros per block. The zeros within each block are chosen as the
       *smallest values of the element-wise norm*.

    Args:
        data_list: Optional list of ``(name, data)`` pairs forwarded to the
            base class constructor.
        sparsity_level: The target level of sparsity
        sparse_block_shape: The shape of a sparse block
        zeros_per_block: Number of zeros in a sparse block. ``None`` means
            "every element of the block".
        norm: Either ``"L1"`` or ``"L2"``.

    Note::
        All arguments to the DataNormSparsifier constructor are "default"
        arguments and could be overridden by the configuration provided in the
        `add_data` step.
    """

    def __init__(
        self,
        data_list: Optional[List[Tuple[str, Any]]] = None,
        sparsity_level: float = 0.5,
        sparse_block_shape: Tuple[int, int] = (1, 4),
        zeros_per_block: Optional[int] = None,
        norm: str = "L1",
    ):
        # By default a selected block is zeroed out completely.
        if zeros_per_block is None:
            zeros_per_block = reduce(operator.mul, sparse_block_shape)

        assert norm in ["L1", "L2"], "only L1 and L2 norm supported at the moment"

        defaults = {
            "sparsity_level": sparsity_level,
            "sparse_block_shape": sparse_block_shape,
            "zeros_per_block": zeros_per_block,
        }
        # `norm` is kept on the instance (not in `defaults`); it is read by
        # `update_mask`.
        self.norm = norm
        super().__init__(data_list=data_list, **defaults)

    def __get_scatter_folded_mask(
        self, data, dim, indices, output_size, sparse_block_shape
    ):
        """Scatter zeros into a block-column layout and fold it back to 2-D.

        Args:
            data: Tensor in unfolded (block-column) layout; only its shape,
                dtype and device are used, via ``torch.ones_like``.
            dim: Dimension along which ``indices`` are scattered.
            indices: Positions (along ``dim``) that are set to 0.
            output_size: Spatial size passed to ``F.fold``.
            sparse_block_shape: Used as both kernel size and stride of the
                fold, i.e. the blocks do not overlap.

        Returns:
            The folded mask cast to ``torch.int8``.
        """
        mask = torch.ones_like(data)
        mask.scatter_(dim=dim, index=indices, value=0)  # zeroing out
        mask = F.fold(
            mask,
            output_size=output_size,
            kernel_size=sparse_block_shape,
            stride=sparse_block_shape,
        )
        return mask.to(torch.int8)

    def __get_block_level_mask(self, data, sparse_block_shape, zeros_per_block):
        """Return a mask that zeroes `zeros_per_block` entries in every block.

        Within each block the entries that sort first (ascending) in ``data``
        are the ones set to 0.

        Args:
            data: 2-D tensor of element-wise norms (assumed already squeezed).
            sparse_block_shape: ``(block_height, block_width)``.
            zeros_per_block: Number of entries to zero in each block.

        Returns:
            An int8 mask with the same height and width as ``data``.
        """
        # Assume data is a squeezed tensor
        height, width = data.shape[-2], data.shape[-1]
        block_height, block_width = sparse_block_shape
        values_per_block = block_height * block_width

        # just return zeros if zeroing all elements in block
        if values_per_block == zeros_per_block:
            return torch.zeros_like(data, dtype=torch.int8)

        # extra rows/columns needed to make the shape a multiple of the block
        dh = (block_height - height % block_height) % block_height
        dw = (block_width - width % block_width) % block_width

        # create a new padded tensor like data (to match the block_shape)
        padded_data = torch.ones(
            height + dh, width + dw, dtype=data.dtype, device=data.device
        )
        # NaN fill for the padded cells; can also be replaced with 0 to stop
        # the removal of edge data.
        # NOTE(review): relies on where torch.sort places NaN - confirm.
        padded_data = padded_data * torch.nan
        padded_data[0:height, 0:width] = data
        unfolded_data = F.unfold(
            padded_data[None, None, :],
            kernel_size=sparse_block_shape,
            stride=sparse_block_shape,
        )

        # dim=1 enumerates the entries of one block
        _, sorted_idx = torch.sort(unfolded_data, dim=1)
        # zero out zeros_per_block number of elements
        sorted_idx = sorted_idx[:, :zeros_per_block, :]

        mask = self.__get_scatter_folded_mask(
            data=unfolded_data,
            dim=1,
            indices=sorted_idx,
            output_size=padded_data.shape,
            sparse_block_shape=sparse_block_shape,
        )

        # remove padding and make contiguous
        return mask.squeeze(0).squeeze(0)[:height, :width].contiguous()

    def __get_data_level_mask(self, data, sparsity_level, sparse_block_shape):
        """Return a mask that zeroes whole blocks with the lowest mean value.

        Args:
            data: 2-D tensor of element-wise norms.
            sparsity_level: Fraction of blocks to zero; the count is
                ``round(sparsity_level * num_blocks)``.
            sparse_block_shape: ``(block_height, block_width)``.

        Returns:
            An int8 mask with the same height and width as ``data``; selected
            blocks are 0, all other entries are 1.
        """
        height, width = data.shape[-2], data.shape[-1]
        block_height, block_width = sparse_block_shape
        dh = (block_height - height % block_height) % block_height
        dw = (block_width - width % block_width) % block_width

        # one score per block; ceil_mode also yields a value for partial
        # blocks at the bottom/right edge
        data_norm = F.avg_pool2d(
            data[None, None, :],
            kernel_size=sparse_block_shape,
            stride=sparse_block_shape,
            ceil_mode=True,
        )

        values_per_block = reduce(operator.mul, sparse_block_shape)

        data_norm = data_norm.flatten()
        num_blocks = len(data_norm)

        # get similar shape after unfold: (1, values_per_block, num_blocks)
        data_norm = data_norm.repeat(1, values_per_block, 1)
        _, sorted_idx = torch.sort(data_norm, dim=2)

        threshold_idx = round(sparsity_level * num_blocks)  # number of blocks to remove
        sorted_idx = sorted_idx[:, :, :threshold_idx]

        mask = self.__get_scatter_folded_mask(
            data=data_norm,
            dim=2,
            indices=sorted_idx,
            output_size=(height + dh, width + dw),
            sparse_block_shape=sparse_block_shape,
        )

        # squeeze only the first 2 dimensions, then drop the padding
        return mask.squeeze(0).squeeze(0)[:height, :width]

    def update_mask(
        self, name, data, sparsity_level, sparse_block_shape, zeros_per_block, **kwargs
    ):
        """Recompute the mask registered under `name` from `data`.

        Args:
            name: Key used to fetch the mask through ``self.get_mask``.
            data: Tensor to sparsify; after ``squeeze()`` it must have at
                most 2 dimensions.
            sparsity_level: Fraction of blocks to zero out.
            sparse_block_shape: ``(block_height, block_width)``.
            zeros_per_block: Number of entries zeroed inside a selected block.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``zeros_per_block`` is negative or larger than the
                block size, or if the squeezed data has more than 2 dims.

        Note:
            In the degenerate cases (nothing to prune, or everything to
            prune) the mask is filled with ones/zeros and keeps its current
            shape and dtype. Otherwise ``mask.data`` is replaced by the int8
            result of the block computation.
        """
        values_per_block = reduce(operator.mul, sparse_block_shape)
        if zeros_per_block > values_per_block:
            raise ValueError(
                "Number of zeros per block cannot be more than "
                "the total number of elements in that block."
            )
        if zeros_per_block < 0:
            raise ValueError("Number of zeros per block should be positive.")

        if self.norm == "L1":
            data_norm = torch.abs(data).squeeze()  # absolute value based (L1)
        else:
            data_norm = (data * data).squeeze()  # square every element for L2

        if len(data_norm.shape) > 2:  # only supports 2 dimensional data at the moment
            raise ValueError("only supports 2-D at the moment")
        elif len(data_norm.shape) == 1:  # in case the data is bias (or 1D)
            data_norm = data_norm[None, :]

        mask = self.get_mask(name)

        # Degenerate configurations are resolved here and must not fall
        # through: with a negative sparsity_level the block count below
        # becomes a negative slice bound and would prune almost every block.
        if sparsity_level <= 0 or zeros_per_block == 0:
            mask.data = torch.ones_like(mask)
            return
        if sparsity_level >= 1.0 and zeros_per_block == values_per_block:
            mask.data = torch.zeros_like(mask)
            return

        # Fetch the high level mask that zeros out entire blocks
        data_lvl_mask = self.__get_data_level_mask(
            data=data_norm,
            sparsity_level=sparsity_level,
            sparse_block_shape=sparse_block_shape,
        )

        # Fetch block level mask that zeros out 'zeros_per_block' number of elements in every block
        block_lvl_mask = self.__get_block_level_mask(
            data=data_norm,
            sparse_block_shape=sparse_block_shape,
            zeros_per_block=zeros_per_block,
        )

        # Entries of kept blocks (data-level mask == 1) stay 1; entries of
        # selected blocks take the value of the block-level mask.
        mask.data = torch.where(data_lvl_mask == 1, data_lvl_mask, block_lvl_mask)
204