# mypy: allow-untyped-defs
import operator
from functools import reduce
from typing import Any, List, Optional, Tuple

import torch
from torch.nn import functional as F

from .base_data_sparsifier import BaseDataSparsifier


__all__ = ["DataNormSparsifier"]


class DataNormSparsifier(BaseDataSparsifier):
    r"""Norm-based data sparsifier (L1 or L2).

    The sparsifier scores every sparse block by the average of the element-wise
    norm values inside it (absolute values for ``"L1"``, squares for ``"L2"``)
    and selects the lowest-scoring blocks for sparsification. Inside every
    selected block the ``zeros_per_block`` elements with the smallest norm
    values are zeroed-out; blocks that are not selected are kept untouched.

    This sparsifier is controlled by three variables:

    1. `sparsity_level` defines the fraction of *sparse blocks* that are
       selected for sparsification.
    2. `sparse_block_shape` defines the shape of the sparse blocks. Note that
       the sparse blocks originate at the zero-index of the tensor.
    3. `zeros_per_block` is the number of zeros that we are expecting in each
       sparse block. By default we assume that all elements within a block are
       zeroed-out. However, setting this variable sets the target number of
       zeros per block. The zeros within each block are chosen as the *smallest
       absolute values*.

    Args:
        data_list: Optional list of ``(name, data)`` tuples; it is forwarded
            unchanged to the base class constructor.
        sparsity_level: The target level of sparsity
        sparse_block_shape: The shape of a sparse block
        zeros_per_block: Number of zeros in a sparse block. ``None`` means
            "every element of the block".
        norm: Which norm to use, either ``"L1"`` or ``"L2"``.

    Note:
        All arguments to the DataNormSparsifier constructor are "default"
        arguments and could be overridden by the configuration provided in the
        `add_data` step.
    """

    def __init__(
        self,
        data_list: Optional[List[Tuple[str, Any]]] = None,
        sparsity_level: float = 0.5,
        sparse_block_shape: Tuple[int, int] = (1, 4),
        zeros_per_block: Optional[int] = None,
        norm: str = "L1",
    ):
        # By default every element of a selected block is zeroed-out.
        if zeros_per_block is None:
            zeros_per_block = reduce(operator.mul, sparse_block_shape)

        # NOTE: kept as an assert so that callers observe the same exception
        # type (AssertionError) as before.
        assert norm in ["L1", "L2"], "only L1 and L2 norm supported at the moment"

        defaults = {
            "sparsity_level": sparsity_level,
            "sparse_block_shape": sparse_block_shape,
            "zeros_per_block": zeros_per_block,
        }
        self.norm = norm
        super().__init__(data_list=data_list, **defaults)

    def __get_scatter_folded_mask(
        self, data, dim, indices, output_size, sparse_block_shape
    ):
        """Build an int8 mask from indices expressed in the unfolded layout.

        A tensor of ones shaped like ``data`` is created, the positions given by
        ``indices`` along ``dim`` are set to zero, and the result is folded
        back to ``output_size`` using non-overlapping blocks of
        ``sparse_block_shape``.
        """
        mask = torch.ones_like(data)
        mask.scatter_(dim=dim, index=indices, value=0)  # zeroing out
        mask = F.fold(
            mask,
            output_size=output_size,
            kernel_size=sparse_block_shape,
            stride=sparse_block_shape,
        )
        return mask.to(torch.int8)

    def __get_block_level_mask(self, data, sparse_block_shape, zeros_per_block):
        """Return a mask that zeroes ``zeros_per_block`` entries in every block.

        ``data`` is expected to be the (squeezed) 2-D tensor of norm values.
        Within each block the entries that sort first (smallest values) are
        the ones that get zeroed.
        """
        # Assume data is a squeezed tensor
        height, width = data.shape[-2], data.shape[-1]
        block_height, block_width = sparse_block_shape
        values_per_block = block_height * block_width

        # just return zeros if zeroing all elements in block
        if values_per_block == zeros_per_block:
            return torch.zeros_like(data, dtype=torch.int8)

        # creating additional height and width to support padding
        dh = (block_height - height % block_height) % block_height
        dw = (block_width - width % block_width) % block_width

        # create a new padded tensor like data (to match the block_shape)
        padded_data = torch.ones(
            height + dh, width + dw, dtype=data.dtype, device=data.device
        )
        # NaN padding; can also be replaced with 0 to stop the removal of edge
        # data. NOTE(review): this presumably relies on torch.sort placing NaN
        # after all real values -- confirm against the torch.sort docs.
        padded_data = padded_data * torch.nan
        padded_data[0:height, 0:width] = data
        unfolded_data = F.unfold(
            padded_data[None, None, :],
            kernel_size=sparse_block_shape,
            stride=sparse_block_shape,
        )

        _, sorted_idx = torch.sort(unfolded_data, dim=1)
        # zero out zeros_per_block number of elements
        sorted_idx = sorted_idx[:, :zeros_per_block, :]

        mask = self.__get_scatter_folded_mask(
            data=unfolded_data,
            dim=1,
            indices=sorted_idx,
            output_size=padded_data.shape,
            sparse_block_shape=sparse_block_shape,
        )

        # remove padding and make contiguous
        return mask.squeeze(0).squeeze(0)[:height, :width].contiguous()

    def __get_data_level_mask(self, data, sparsity_level, sparse_block_shape):
        """Return a mask that zeroes whole blocks with the lowest average norm.

        ``round(sparsity_level * num_blocks)`` blocks are zeroed, where the
        blocks are ranked by the average of ``data`` over each block.
        """
        height, width = data.shape[-2], data.shape[-1]
        block_height, block_width = sparse_block_shape
        dh = (block_height - height % block_height) % block_height
        dw = (block_width - width % block_width) % block_width

        # ceil_mode=True so that partial blocks at the edges get a score too
        data_norm = F.avg_pool2d(
            data[None, None, :],
            kernel_size=sparse_block_shape,
            stride=sparse_block_shape,
            ceil_mode=True,
        )

        values_per_block = reduce(operator.mul, sparse_block_shape)

        data_norm = data_norm.flatten()
        num_blocks = len(data_norm)

        # get similar shape after unfold
        data_norm = data_norm.repeat(1, values_per_block, 1)
        _, sorted_idx = torch.sort(data_norm, dim=2)

        threshold_idx = round(sparsity_level * num_blocks)  # number of blocks to remove
        sorted_idx = sorted_idx[:, :, :threshold_idx]

        mask = self.__get_scatter_folded_mask(
            data=data_norm,
            dim=2,
            indices=sorted_idx,
            output_size=(height + dh, width + dw),
            sparse_block_shape=sparse_block_shape,
        )

        # squeeze only the first 2 dimension, then drop the padding
        return mask.squeeze(0).squeeze(0)[:height, :width]

    def update_mask(
        self, name, data, sparsity_level, sparse_block_shape, zeros_per_block, **kwargs
    ):
        """Recompute the mask stored under ``name`` from ``data``.

        Args:
            name: Key used to fetch the mask via ``self.get_mask``.
            data: Tensor to sparsify; after ``squeeze()`` it must be 1-D or 2-D.
            sparsity_level: Fraction of blocks to sparsify.
            sparse_block_shape: ``(block_height, block_width)`` of a block.
            zeros_per_block: Number of elements zeroed inside a sparsified block.
            **kwargs: Ignored.

        Raises:
            ValueError: If ``zeros_per_block`` is negative or larger than the
                block size, or if the squeezed data has more than 2 dimensions.
        """
        values_per_block = reduce(operator.mul, sparse_block_shape)
        if zeros_per_block > values_per_block:
            raise ValueError(
                "Number of zeros per block cannot be more than "
                "the total number of elements in that block."
            )
        if zeros_per_block < 0:
            raise ValueError("Number of zeros per block should be positive.")

        if self.norm == "L1":
            data_norm = torch.abs(data).squeeze()  # absolute value based (L1)
        else:
            data_norm = (data * data).squeeze()  # square every element for L2

        if len(data_norm.shape) > 2:  # only supports 2 dimensional data at the moment
            raise ValueError("only supports 2-D at the moment")

        elif len(data_norm.shape) == 1:  # in case the data is bias (or 1D)
            data_norm = data_norm[None, :]

        mask = self.get_mask(name)

        # Degenerate configurations are resolved directly. Returning here is
        # essential: falling through would overwrite the result, and a negative
        # sparsity_level would turn the block-count slice in the data-level
        # mask into an "all but the last k" slice, zeroing blocks by mistake.
        if sparsity_level <= 0 or zeros_per_block == 0:
            mask.data = torch.ones_like(mask)
            return
        if sparsity_level >= 1.0 and (zeros_per_block == values_per_block):
            mask.data = torch.zeros_like(mask)
            return

        # Fetch the high level mask that zeros out entire blocks
        data_lvl_mask = self.__get_data_level_mask(
            data=data_norm,
            sparsity_level=sparsity_level,
            sparse_block_shape=sparse_block_shape,
        )

        # Fetch block level mask that zeros out 'zeros_per_block' number of elements in every block
        block_lvl_mask = self.__get_block_level_mask(
            data=data_norm,
            sparse_block_shape=sparse_block_shape,
            zeros_per_block=zeros_per_block,
        )

        # zero out the entries inside those blocks whose block is sparsified
        mask.data = torch.where(data_lvl_mask == 1, data_lvl_mask, block_lvl_mask)