# mypy: allow-untyped-defs
import torch


def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
    """
    This is a PyTorch implementation of the main part of the reorder_meta()
    function from the tools/util/include/cutlass/util/host_reorder.h file
    of the CUTLASS source tree.  Furthermore, the CUTLASS template for
    sparse GEMM decides upon the layout of this matrix, and at the moment,
    for sparse GEMM executed on tensor cores, this is the layout described
    by the ColumnMajorInterleaved<2> data structure, in
    include/cutlass/layout/matrix.h of the CUTLASS source tree.  The
    reordering of the meta matrix into the meta_reordered matrix, as
    calculated according to these segments of CUTLASS code, is
    re-implemented here.  Note that this calculation produces offsets for
    scattering metadata matrix elements into reordered metadata matrix
    elements (or, equivalently, for gathering reordered metadata matrix
    elements back into metadata matrix elements).
    """
    dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
    dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)

    # Reorder the rows, then swizzle the 2x2 blocks.
    group = 32 if meta_dtype.itemsize == 2 else 16
    interweave = 4 if meta_dtype.itemsize == 2 else 2
    dst_rows = (
        dst_rows // group * group
        + (dst_rows % 8) * interweave
        + (dst_rows % group) // 8
    )

    topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
    bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
    dst_rows += topright - bottomleft
    dst_cols -= topright - bottomleft

    # It is assumed that the meta tensor is to be stored in the CUTLASS
    # InterleavedColumnMajor layout; the corresponding CUTLASS code that
    # stores values into this tensor was reverse engineered to derive the
    # offsets below.
    interleave = 2
    cols_maj = dst_cols // interleave
    cols_min = dst_cols % interleave
    return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)


def sparse_semi_structured_from_dense_cutlass(dense):
    """
    This function converts a dense matrix into its sparse semi-structured
    representation, producing a "compressed" matrix in the layout used by
    the CUTLASS backend, and the corresponding metadata matrix.
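
    For example, an illustrative sketch (shapes chosen here just to satisfy
    the size checks below):

    >>> dense = torch.zeros(32, 64, dtype=torch.half)
    >>> dense.view(32, 16, 4)[:, :, :2] = 1.0  # two non-zeros per group of four
    >>> sparse, meta = sparse_semi_structured_from_dense_cutlass(dense)
    >>> sparse.shape, meta.shape, meta.dtype
    (torch.Size([32, 32]), torch.Size([32, 4]), torch.int16)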
    """
    if dense.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"
        )

    m, k = dense.shape
    device = dense.device

    meta_dtype = torch.int8
    if dense.dtype == torch.int8:
        meta_dtype = torch.int32
    elif dense.dtype in [torch.half, torch.bfloat16, torch.float]:
        meta_dtype = torch.int16
    else:
        raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
    if quadbits_per_meta_elem not in (4, 8):
        raise RuntimeError("Invalid number of elements per meta element calculated")

    if meta_dtype == torch.int32:
        if m % 16 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 16"
            )
    else:
        if m % 32 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 32"
            )
    if k % (4 * quadbits_per_meta_elem) != 0:
        raise RuntimeError(
            f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"
        )

    if dense.dtype != torch.float:
        ksparse = 4
        dense_4 = dense.view(-1, k // ksparse, ksparse)
        m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
    else:
        ksparse = 2
        dense_2 = dense.view(-1, k // ksparse, ksparse)
        m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
    meta_ncols = k // (ksparse * quadbits_per_meta_elem)

    # Quadruples of True/False non-zero flags are encoded as follows:
    #   [True,  True,  False, False] -> 0b0100
    #   [True,  False, True,  False] -> 0b1000
    #   [False, True,  True,  False] -> 0b1001
    #   [True,  False, False, True ] -> 0b1100
    #   [False, True,  False, True ] -> 0b1101
    #   [False, False, True,  True ] -> 0b1110
    # Thus, the lower two bits of the encoding are the index of the True
    # value at the lowest index in the quadruple, and the higher two bits
    # are the index of the other True value in the quadruple.  If there are
    # fewer than two True values, then a False value or values at some
    # index or indices are treated as True for the encoding.  If there are
    # more than two True values, then the excess True value(s) at some
    # indices are treated as False for the encoding.  The exact encodings
    # used for these cases are as follows:
    #   [False, False, False, False] -> 0b1110
    #   [False, False, False, True ] -> 0b1110
    #   [False, False, True,  False] -> 0b1110
    #   [False, True,  False, False] -> 0b1001
    #   [False, True,  True,  True ] -> 0b1101
    #   [True,  False, False, False] -> 0b1000
    #   [True,  False, True,  True ] -> 0b1100
    #   [True,  True,  False, True ] -> 0b0100
    #   [True,  True,  True,  False] -> 0b0100
    #   [True,  True,  True,  True ] -> 0b0100
    # These particular encodings were chosen, with the help of the Espresso
    # logic minimizer, to minimize the Boolean functions that translate
    # non-zero flags into encoding bits.  Note also that the possible
    # choices for the first and last of these encodings were limited to
    # (0b0100, 0b1110), in order to produce valid encodings for the 1:2
    # sparsity case.
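    # As an illustrative worked example (not used by the computation below):
    # for the quadruple [False, True, False, True] the kept values are at
    # indices 1 and 3, so the bit logic below yields idxs0 == 0b01 and
    # idxs1 == 0b11, and the packed 4-bit group is 0b1101, matching the
    # table above.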

    expr0 = m0 & m1
    expr1 = ~m0 & m1
    expr2 = ~m0 & ~m1
    bit0 = expr1
    bit1 = expr2
    bit2 = expr0 | expr2 | m3
    bit3 = expr1 | ~m1
    idxs0 = bit0 | (bit1.to(torch.int64) << 1)
    idxs1 = bit2 | (bit3.to(torch.int64) << 1)

    if dense.dtype != torch.float:
        sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1))  # type: ignore[possibly-undefined]
        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
    else:
        sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2)  # type: ignore[possibly-undefined]

    meta_4 = idxs0 | (idxs1 << 2)
    meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)

    if quadbits_per_meta_elem == 4:
        meta = (
            meta_n[:, :, 0]
            | (meta_n[:, :, 1] << 4)
            | (meta_n[:, :, 2] << 8)
            | (meta_n[:, :, 3] << 12)
        )
    elif quadbits_per_meta_elem == 8:
        meta = (
            meta_n[:, :, 0]
            | (meta_n[:, :, 1] << 4)
            | (meta_n[:, :, 2] << 8)
            | (meta_n[:, :, 3] << 12)
            | (meta_n[:, :, 4] << 16)
            | (meta_n[:, :, 5] << 20)
            | (meta_n[:, :, 6] << 24)
            | (meta_n[:, :, 7] << 28)
        )

    # Reorder meta tensor elements.
    meta_reordered = meta.new_empty((m * meta_ncols,))  # type: ignore[possibly-undefined]
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device
    )
    meta_reordered.scatter_(0, meta_offsets, meta.view(-1))

    return (sparse, meta_reordered.view(m, meta_ncols))


def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
    """
    This function performs the reverse of the function above: it
    reconstructs the dense matrix from the "compressed" matrix, given in
    the layout used by the CUTLASS backend, and the accompanying metadata
    matrix.
    """
    if sparse.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor"
        )

    m, k = sparse.shape
    device = sparse.device

    if meta_reordered.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor"
        )
    if meta_reordered.device != device:
        raise RuntimeError(
            f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device"
        )

    meta_dtype = meta_reordered.dtype
    if meta_dtype not in (torch.int16, torch.int32):
        raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4

    if sparse.dtype != torch.float:
        ksparse = 4
    else:
        ksparse = 2

    meta_nrows, meta_ncols = meta_reordered.shape
    if meta_nrows != m:
        raise RuntimeError(
            f"Number of rows of meta matrix {meta_nrows} must be equal to number of rows of sparse matrix {m}"
        )
    if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
        raise RuntimeError(
            f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, "
            "expected according to the number of columns of meta matrix"
        )

    # Undo meta tensor elements reordering.
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device
    )
    meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)

    # Unpack sparse tensor back to original dense tensor, using
    # information provided by meta tensor.
    # Note that the torch.float datatype is handled pretty much the same as
    # torch.half/torch.bfloat16, as the metadata for a pair of torch.float
    # values is encoded as if the underlying 8 bytes contained four
    # torch.half/torch.bfloat16 values, where either the first two or the
    # last two are zeros.
    meta_2 = torch.empty(
        (m, meta_ncols, 2 * quadbits_per_meta_elem),
        dtype=meta_dtype,
        device=device,
    )
    if quadbits_per_meta_elem == 4:
        meta_2[:, :, 0] = meta & 0b11
        meta_2[:, :, 1] = (meta >> 2) & 0b11
        meta_2[:, :, 2] = (meta >> 4) & 0b11
        meta_2[:, :, 3] = (meta >> 6) & 0b11
        meta_2[:, :, 4] = (meta >> 8) & 0b11
        meta_2[:, :, 5] = (meta >> 10) & 0b11
        meta_2[:, :, 6] = (meta >> 12) & 0b11
        meta_2[:, :, 7] = (meta >> 14) & 0b11
    elif quadbits_per_meta_elem == 8:
        meta_2[:, :, 0] = meta & 0b11
        meta_2[:, :, 1] = (meta >> 2) & 0b11
        meta_2[:, :, 2] = (meta >> 4) & 0b11
        meta_2[:, :, 3] = (meta >> 6) & 0b11
        meta_2[:, :, 4] = (meta >> 8) & 0b11
        meta_2[:, :, 5] = (meta >> 10) & 0b11
        meta_2[:, :, 6] = (meta >> 12) & 0b11
        meta_2[:, :, 7] = (meta >> 14) & 0b11
        meta_2[:, :, 8] = (meta >> 16) & 0b11
        meta_2[:, :, 9] = (meta >> 18) & 0b11
        meta_2[:, :, 10] = (meta >> 20) & 0b11
        meta_2[:, :, 11] = (meta >> 22) & 0b11
        meta_2[:, :, 12] = (meta >> 24) & 0b11
        meta_2[:, :, 13] = (meta >> 26) & 0b11
        meta_2[:, :, 14] = (meta >> 28) & 0b11
        meta_2[:, :, 15] = (meta >> 30) & 0b11

    dense_offsets = meta_2.view(-1) + (
        torch.arange(0, 2 * m * k // ksparse, device=device) * 4
    ).view(-1, 1).repeat(1, 2).view(-1)

    dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
    if sparse.dtype != torch.float:
        dense.scatter_(0, dense_offsets, sparse.view(-1))
    else:
        dense.view(torch.half).scatter_(
            0, dense_offsets, sparse.view(torch.half).view(-1)
        )

    return dense.view(m, 2 * k)


def _sparse_semi_structured_tile(dense):
    """
    This function computes a 2:4 sparse tile by greedily taking the largest values.

    Since we take the largest values greedily, how the sorting algorithm handles
    duplicates affects the resulting sparsity pattern.

    Note that this function does not have the same sorting semantics as our CUDA
    backend, which is exposed via `torch._sparse_semi_structured_tile` and thus
    returns a different pattern.
    """

    def greedy_prune_tile(tile):
        num_kept_row = [0, 0, 0, 0]
        num_kept_col = [0, 0, 0, 0]

        for x in tile.flatten().sort(descending=True, stable=True).indices:
            r, c = x // 4, x % 4
            if num_kept_row[r] < 2 and num_kept_col[c] < 2:
                num_kept_row[r] += 1
                num_kept_col[c] += 1
            else:
                tile[r, c] = 0

    for batch in dense.unfold(0, 4, 4).unfold(1, 4, 4):
        for tile in batch:
            greedy_prune_tile(tile)

    return dense


def _compute_compressed_swizzled_bitmask(dense):
    """
    Calculates the compressed swizzled bitmask from a dense tensor
    """

    # first we need to convert the dense tensor to a bitmask
    int_bitmask = dense.bool().to(torch.uint8)

    # Each thread is responsible for an 8x8 tile, which contains 4 4x4 tiles:
    # A, B, C and D, as displayed in the following schema:
    # +---+---+
    # | A | B |
    # +---+---+
    # | C | D |
    # +---+---+

    # we first need to split into the 8x8 tiles
    bitmask_8x8_chunks = int_bitmask.unfold(0, 8, 8).unfold(1, 8, 8)

    # then we unfold again to get our individual 4x4 tiles
    bitmask_4x4_chunks = bitmask_8x8_chunks.unfold(2, 4, 4).unfold(3, 4, 4)

    # Each 4x4 bitmask defines two 8-bit integers, which encode the sparsity
    # pattern of that tile.  Note that the least significant bit is stored first.
    # [1 1 0 0]
    # [1 1 0 0]  ->  0011 0011  ->  51
    # [0 0 1 1]      1100 1100      204
    # [0 0 1 1]

    # reshape tensor to expand tiles into 8-bit vectors
    bitmask_binary_representation = bitmask_4x4_chunks.reshape(
        *bitmask_4x4_chunks.shape[:2], 4, 2, 8
    )

    # to convert from the binary representation, we can do a matmul with powers of two
    powers_of_two = 2 ** torch.arange(8, dtype=torch.float, device=dense.device)
    # To run on GPU: cast to float to do the matmul and then cast back
    compressed_swizzled_bitmask = (
        bitmask_binary_representation.to(torch.float) @ powers_of_two
    ).to(torch.uint8)

    return compressed_swizzled_bitmask
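

# The block below is an illustrative, self-contained sketch (not part of the
# public API of this module): it round-trips a small matrix that is already
# 2:4 sparse through the two CUTLASS conversion helpers above.  The shape
# (32, 16) and the torch.half dtype are arbitrary choices that satisfy the
# divisibility checks in sparse_semi_structured_from_dense_cutlass.
if __name__ == "__main__":
    # Build a dense matrix in which, out of every four consecutive elements
    # along a row, only the first two are non-zero (i.e. already 2:4 sparse).
    dense_example = torch.zeros(32, 16, dtype=torch.half)
    dense_example.view(32, 4, 4)[:, :, :2] = 1.0

    # Compress into the CUTLASS layout and reconstruct the dense matrix.
    sparse_example, meta_example = sparse_semi_structured_from_dense_cutlass(
        dense_example
    )
    dense_roundtrip = sparse_semi_structured_to_dense_cutlass(
        sparse_example, meta_example
    )

    # The compressed matrix keeps half of the columns, the metadata packs four
    # 4-bit groups into each 16-bit element, and the round trip is exact.
    assert sparse_example.shape == (32, 8)
    assert meta_example.shape == (32, 1) and meta_example.dtype == torch.int16
    assert torch.equal(dense_example, dense_roundtrip)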