# mypy: allow-untyped-defs
import operator
from functools import reduce
from typing import Callable, Optional, Tuple, Union

import torch
import torch.nn.functional as F

from .base_sparsifier import BaseSparsifier


__all__ = ["WeightNormSparsifier"]


def _flat_idx_to_2d(idx, shape):
    """Split flat indices into ``(rows, cols)`` for a 2D ``shape``.

    Only ``shape[1]`` (the number of columns) is used: the row is the integer
    quotient and the column is the remainder of ``idx`` divided by it.
    """
    rows = idx // shape[1]
    cols = idx % shape[1]
    return rows, cols


class WeightNormSparsifier(BaseSparsifier):
    r"""Weight-Norm Sparsifier

    This sparsifier computes the norm of every sparse block and "zeroes-out" the
    ones with the lowest norm. The level of sparsity defines how many of the
    blocks are removed.

    This sparsifier is controlled by three variables:
    1. `sparsity_level` defines the number of *sparse blocks* that are zeroed-out
    2. `sparse_block_shape` defines the shape of the sparse blocks. Note that
        the sparse blocks originate at the zero-index of the tensor.
    3. `zeros_per_block` is the number of zeros that we are expecting in each
        sparse block. By default we assume that all elements within a block are
        zeroed-out. However, setting this variable sets the target number of
        zeros per block. The zeros within each block are chosen as the *smallest
        absolute values*.

    Args:

        sparsity_level: The target level of sparsity
        sparse_block_shape: The shape of a sparse block (see note below)
        zeros_per_block: Number of zeros in a sparse block
        norm: Norm to use. Could be either `int` or a callable.
            If `int`, only L1 and L2 are implemented.

    Note::
        The `sparse_block_shape` is tuple representing (block_ROWS, block_COLS),
        irrespective of what the rows / cols mean in the data tensor. That means,
        if you were to sparsify a weight tensor in the nn.Linear, which has a
        weight shape `(Cout, Cin)`, the `block_ROWS` would refer to the output
        channels, while the `block_COLS` would refer to the input channels.

    Note::
        All arguments to the WeightNormSparsifier constructor are "default"
        arguments and could be overridden by the configuration provided in the
        `prepare` step.
    """

    def __init__(
        self,
        sparsity_level: float = 0.5,
        sparse_block_shape: Tuple[int, int] = (1, 4),
        zeros_per_block: Optional[int] = None,
        norm: Optional[Union[Callable, int]] = None,
    ):
        """Store the default configuration and select the element-wise norm function.

        Raises:
            NotImplementedError: if ``norm`` is neither callable, ``None``, 1 nor 2.
        """
        if zeros_per_block is None:
            # By default every element of a block is zeroed out.
            zeros_per_block = reduce(operator.mul, sparse_block_shape)
        defaults = {
            "sparsity_level": sparsity_level,
            "sparse_block_shape": sparse_block_shape,
            "zeros_per_block": zeros_per_block,
        }
        if norm is None:
            norm = 2
        if callable(norm):
            self.norm_fn = norm
        elif norm == 1:
            self.norm_fn = lambda T: T.abs()
        elif norm == 2:
            # Squared magnitude: the square root is skipped, only ordering matters
            # for the top-k selection done later.
            self.norm_fn = lambda T: T * T
        else:
            raise NotImplementedError(f"L-{norm} is not yet implemented.")
        super().__init__(defaults=defaults)

    def _scatter_fold_block_mask(
        self,
        output_shape,
        dim,
        indices,
        block_shape,
        mask=None,
        input_shape=None,
        device=None,
    ):
        r"""Creates patches of size `block_shape` after scattering the indices.

        Zeros are written in place into ``mask`` at ``indices`` along ``dim``;
        the data of ``mask`` is then replaced by the result of ``F.fold`` with
        kernel and stride both equal to ``block_shape``.

        Args:
            output_shape: ``output_size`` passed to ``F.fold``.
            dim: dimension along which ``indices`` are scattered.
            indices: index tensor of positions to zero out.
            block_shape: kernel size and stride used for folding.
            mask: tensor to update; when ``None`` a tensor of ones of shape
                ``input_shape`` is created on ``device``.
            input_shape: shape of the mask to create when ``mask`` is ``None``.
            device: device of the mask to create when ``mask`` is ``None``.

        Returns:
            The (updated) ``mask`` tensor.
        """
        if mask is None:
            assert input_shape is not None
            mask = torch.ones(input_shape, device=device)
        mask.scatter_(dim=dim, index=indices, value=0)
        mask.data = F.fold(
            mask, output_size=output_shape, kernel_size=block_shape, stride=block_shape
        )
        return mask

    def _make_tensor_mask(
        self, data, input_shape, sparsity_level, sparse_block_shape, mask=None
    ):
        r"""Creates a tensor-level mask.

        Tensor-level mask is described as a mask, where the granularity of sparsification of the
        smallest patch is the sparse_block_shape. That means, that for a given mask and a
        sparse_block_shape, the smallest "patch" of zeros/ones could be the sparse_block_shape.

        In this context, `sparsity_level` describes the fraction of sparse patches.

        Note:
            ``input_shape`` is accepted but not used by this method.
            On the regular path the returned mask is cropped to the last two
            dimensions of ``data``; on the early-return paths
            (``sparsity_level >= 1.0`` or ``<= 0.0``) the mask is returned with
            whatever shape it had, which is the padded shape when it was
            created here.
        """
        h, w = data.shape[-2:]
        block_h, block_w = sparse_block_shape
        # Padding needed to make each dimension a multiple of the block size.
        dh = (block_h - h % block_h) % block_h
        dw = (block_w - w % block_w) % block_w

        if mask is None:
            mask = torch.ones(h + dh, w + dw, device=data.device)

        if sparsity_level >= 1.0:
            mask.data = torch.zeros_like(mask)
            return mask
        elif sparsity_level <= 0.0:
            mask.data = torch.ones_like(mask)
            return mask

        values_per_block = reduce(operator.mul, sparse_block_shape)
        if values_per_block > 1:
            # Reduce the data
            data = F.avg_pool2d(
                data[None, None, :],
                kernel_size=sparse_block_shape,
                stride=sparse_block_shape,
                ceil_mode=True,
            )
        data = data.flatten()
        num_blocks = len(data)

        # One copy of the per-block score for every element of a block, laid out
        # as (1, values_per_block, num_blocks) so that it can be folded back.
        data = data.repeat(1, values_per_block, 1)

        threshold_idx = int(round(sparsity_level * num_blocks))
        threshold_idx = max(0, min(num_blocks - 1, threshold_idx))  # Sanity check
        _, sorted_idx = torch.topk(data, k=threshold_idx, dim=2, largest=False)

        # Temp reshape for mask
        mask_reshape = mask.reshape(data.shape)  # data might be reshaped
        # The helper replaces the data of `mask_reshape`; its return value is the
        # same tensor, so it is not needed here.
        self._scatter_fold_block_mask(
            dim=2,
            output_shape=(h + dh, w + dw),
            indices=sorted_idx,
            block_shape=sparse_block_shape,
            mask=mask_reshape,
        )
        mask.data = mask_reshape.squeeze().reshape(mask.shape)[:h, :w].contiguous()
        return mask

    def _make_block_mask(self, data, sparse_block_shape, zeros_per_block, mask=None):
        r"""Creates a block-level mask.

        Block-level mask is described as a mask, where the granularity of sparsification of the
        largest patch is the sparse_block_shape. That means that for a given mask and a
        sparse_block_shape, the sparsity is computed only within a patch of a size sparse_block_shape.

        In this context the `zeros_per_block` describes the number of zeroed-out elements within a patch.

        Note:
            The returned mask is NOT cropped: when it is created here it has the
            padded shape ``(h + dh, w + dw)``, where the padding rounds each of
            the last two dimensions of ``data`` up to a multiple of the block size.
        """
        h, w = data.shape[-2:]
        block_h, block_w = sparse_block_shape
        dh = (block_h - h % block_h) % block_h
        dw = (block_w - w % block_w) % block_w
        values_per_block = reduce(operator.mul, sparse_block_shape)

        if mask is None:
            mask = torch.ones((h + dh, w + dw), device=data.device)

        if values_per_block == zeros_per_block:
            # Everything should be sparsified
            mask.data = torch.zeros_like(mask)
            return mask

        # create a new padded tensor like data (to match the block_shape)
        padded_data = torch.ones(h + dh, w + dw, dtype=data.dtype, device=data.device)
        padded_data.fill_(torch.nan)
        padded_data[:h, :w] = data
        # Each column of `unfolded_data` holds the elements of one block.
        unfolded_data = F.unfold(
            padded_data[None, None, :],
            kernel_size=sparse_block_shape,
            stride=sparse_block_shape,
        )

        # Temp reshape for mask
        mask_reshape = mask.reshape(unfolded_data.shape)
        _, sorted_idx = torch.topk(
            unfolded_data, k=zeros_per_block, dim=1, largest=False
        )

        self._scatter_fold_block_mask(
            dim=1,
            indices=sorted_idx,
            output_shape=padded_data.shape,
            block_shape=sparse_block_shape,
            mask=mask_reshape,
        )

        mask.data = mask_reshape.squeeze().reshape(mask.shape).contiguous()
        return mask

    def update_mask(
        self,
        module,
        tensor_name,
        sparsity_level,
        sparse_block_shape,
        zeros_per_block,
        **kwargs,
    ):
        """Recompute the mask stored in the parametrization of ``module.<tensor_name>``.

        The mask is read from ``module.parametrizations.<tensor_name>[0].mask``
        and its data is replaced in place.

        Args:
            module: module owning the parametrized tensor.
            tensor_name: attribute name of the tensor on ``module``.
            sparsity_level: target fraction of sparse blocks.
            sparse_block_shape: ``(block_rows, block_cols)`` of a sparse block.
            zeros_per_block: number of zeros wanted inside each block.
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if ``zeros_per_block`` is negative or larger than the
                number of elements in a block.
        """
        values_per_block = reduce(operator.mul, sparse_block_shape)
        if zeros_per_block > values_per_block:
            raise ValueError(
                "Number of zeros per block cannot be more than the total number of elements in that block."
            )
        if zeros_per_block < 0:
            # Zero is a valid value (handled below), so only negatives are rejected.
            raise ValueError("Number of zeros per block should be non-negative.")

        mask = getattr(module.parametrizations, tensor_name)[0].mask
        if sparsity_level <= 0 or zeros_per_block == 0:
            mask.data = torch.ones_like(mask)
        elif sparsity_level >= 1.0 and (zeros_per_block == values_per_block):
            mask.data = torch.zeros_like(mask)
        else:
            ww = self.norm_fn(getattr(module, tensor_name))
            h, w = ww.shape[-2:]
            # The helpers may return masks padded up to a multiple of the block
            # shape (always for the block mask, and for the tensor mask on its
            # early-return path). Crop both to the tensor's own size so they can
            # be combined and so the stored mask matches the tensor.
            tensor_mask = self._make_tensor_mask(
                data=ww,
                input_shape=ww.shape,
                sparsity_level=sparsity_level,
                sparse_block_shape=sparse_block_shape,
            )[:h, :w]
            if values_per_block != zeros_per_block:
                block_mask = self._make_block_mask(
                    data=ww,
                    sparse_block_shape=sparse_block_shape,
                    zeros_per_block=zeros_per_block,
                )[:h, :w]
                # An element is kept if either mask keeps it.
                tensor_mask = torch.logical_or(tensor_mask, block_mask)
            mask.data = tensor_mask.contiguous()