# mypy: allow-untyped-defs
import torch


def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
    """
    This is a PyTorch implementation of the main part of the
    reorder_meta() function from the tools/util/include/cutlass/util/host_reorder.h
    file of the CUTLASS source tree.  The CUTLASS template for sparse
    GEMM decides upon the layout of this matrix; at the moment, for
    sparse GEMM executed on tensor cores, that layout is described by
    the ColumnMajorInterleaved<2> data structure in
    include/cutlass/layout/matrix.h of the CUTLASS source tree.  The
    reordering of the meta matrix into the meta_reordered matrix, as
    performed by these segments of CUTLASS code, is re-implemented
    here.  Note that this calculation produces offsets for scattering
    metadata matrix elements into reordered metadata matrix elements
    (or, equivalently, for gathering reordered metadata matrix elements
    back into metadata matrix elements).
    """
    dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
    dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)

    # Reorder the rows, then swizzle the 2x2 blocks.
    group = 32 if meta_dtype.itemsize == 2 else 16
    interweave = 4 if meta_dtype.itemsize == 2 else 2
    dst_rows = (
        dst_rows // group * group
        + (dst_rows % 8) * interweave
        + (dst_rows % group) // 8
    )
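    # For example, with 16-bit metadata elements (group=32, interweave=4),
    # source metadata row 9 is scattered to row
    # 9 // 32 * 32 + (9 % 8) * 4 + (9 % 32) // 8 = 5 (before the 2x2 block
    # swizzle applied below).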

    topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
    bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
    dst_rows += topright - bottomleft
    dst_cols -= topright - bottomleft

    # It is assumed that the meta tensor is stored in the CUTLASS
    # InterleavedColumnMajor layout; the corresponding CUTLASS code for
    # storing values into this tensor was reverse engineered here.
    interleave = 2
    cols_maj = dst_cols // interleave
    cols_min = dst_cols % interleave
    return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)

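# A minimal sketch of the scatter/gather round trip described in the docstring
# above (illustrative only; the shapes are chosen just to satisfy the
# divisibility constraints used elsewhere in this file for 16-bit metadata):
#
#     meta = torch.arange(32 * 8, dtype=torch.int16).view(32, 8)
#     offsets = _calculate_meta_reordering_scatter_offsets(32, 8, torch.int16, meta.device)
#     reordered = meta.new_empty((32 * 8,)).scatter_(0, offsets, meta.view(-1))
#     # Gathering with the same offsets recovers the original metadata.
#     assert torch.equal(torch.gather(reordered, 0, offsets).view(32, 8), meta)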

def sparse_semi_structured_from_dense_cutlass(dense):
    """
    This function converts a dense matrix into its sparse
    semi-structured representation, producing a "compressed" matrix of
    shape (m, k // 2) in the layout used by the CUTLASS backend, along
    with the corresponding metadata matrix.
    """
    if dense.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"
        )

    m, k = dense.shape
    device = dense.device

    meta_dtype = torch.int8
    if dense.dtype == torch.int8:
        meta_dtype = torch.int32
    elif dense.dtype in [torch.half, torch.bfloat16, torch.float]:
        meta_dtype = torch.int16
    else:
        raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
    if quadbits_per_meta_elem not in (4, 8):
        raise RuntimeError("Invalid number of elements per meta element calculated")

    if meta_dtype == torch.int32:
        if m % 16 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 16"
            )
    else:
        if m % 32 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 32"
            )
    if k % (4 * quadbits_per_meta_elem) != 0:
        raise RuntimeError(
            f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"
        )

    if dense.dtype != torch.float:
        ksparse = 4
        dense_4 = dense.view(-1, k // ksparse, ksparse)
        m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
    else:
        ksparse = 2
        dense_2 = dense.view(-1, k // ksparse, ksparse)
        m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
    meta_ncols = k // (ksparse * quadbits_per_meta_elem)

    # Encoding quadruples of True/False values as follows:
    #     [True,  True,  False, False] -> 0b0100
    #     [True,  False, True,  False] -> 0b1000
    #     [False, True,  True,  False] -> 0b1001
    #     [True,  False, False, True ] -> 0b1100
    #     [False, True,  False, True ] -> 0b1101
    #     [False, False, True,  True ] -> 0b1110
    # Thus, the lower two bits of the encoding are the index of the
    # True value at the lowest index in the quadruple, and the higher
    # two bits of the encoding are the index of the other True value in
    # the quadruple.  If there are fewer than two True values, then a
    # False value or values at some index or indices are treated as
    # True for the encoding.  If there are more than two True values,
    # then the excess True value(s) at some indices are treated as
    # False for the encoding.  The exact encodings used for these cases
    # are as follows:
    #     [False, False, False, False] -> 0b1110
    #     [False, False, False, True ] -> 0b1110
    #     [False, False, True,  False] -> 0b1110
    #     [False, True,  False, False] -> 0b1001
    #     [False, True,  True,  True ] -> 0b1101
    #     [True,  False, False, False] -> 0b1000
    #     [True,  False, True,  True ] -> 0b1100
    #     [True,  True,  False, True ] -> 0b0100
    #     [True,  True,  True,  False] -> 0b0100
    #     [True,  True,  True,  True ] -> 0b0100
    # These particular encodings were chosen, with the help of the
    # Espresso logic minimizer software, to minimize the Boolean
    # functions that translate non-zero flags into encoding bits.  Note
    # also that the possible choices for the first and last of these
    # encodings were limited to (0b0100, 0b1110), in order to produce
    # valid encodings for the 1:2 sparsity case.

    expr0 = m0 & m1
    expr1 = ~m0 & m1
    expr2 = ~m0 & ~m1
    bit0 = expr1
    bit1 = expr2
    bit2 = expr0 | expr2 | m3
    bit3 = expr1 | ~m1
    idxs0 = bit0 | (bit1.to(torch.int64) << 1)
    idxs1 = bit2 | (bit3.to(torch.int64) << 1)
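    # As a worked example of the encoding table above: for the quadruple
    # [True, False, True, False], expr0 = expr1 = expr2 = False, so
    # bit0 = bit1 = bit2 = False and bit3 = True, giving idxs0 = 0b00 and
    # idxs1 = 0b10, i.e. the encoding 0b1000 once combined below.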

    if dense.dtype != torch.float:
        sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1))  # type: ignore[possibly-undefined]
        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
    else:
        sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2)  # type: ignore[possibly-undefined]

    meta_4 = idxs0 | (idxs1 << 2)
    meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)

    if quadbits_per_meta_elem == 4:
        meta = (
            meta_n[:, :, 0]
            | (meta_n[:, :, 1] << 4)
            | (meta_n[:, :, 2] << 8)
            | (meta_n[:, :, 3] << 12)
        )
    elif quadbits_per_meta_elem == 8:
        meta = (
            meta_n[:, :, 0]
            | (meta_n[:, :, 1] << 4)
            | (meta_n[:, :, 2] << 8)
            | (meta_n[:, :, 3] << 12)
            | (meta_n[:, :, 4] << 16)
            | (meta_n[:, :, 5] << 20)
            | (meta_n[:, :, 6] << 24)
            | (meta_n[:, :, 7] << 28)
        )

    # Reorder meta tensor elements.
    meta_reordered = meta.new_empty((m * meta_ncols,))  # type: ignore[possibly-undefined]
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device
    )
    meta_reordered.scatter_(0, meta_offsets, meta.view(-1))

    return (sparse, meta_reordered.view(m, meta_ncols))

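# A minimal round-trip sketch for the pair of conversion functions in this file
# (illustrative only; it assumes the input is already 2:4 sparse, e.g. after
# pruning with _sparse_semi_structured_tile below, and uses shapes that satisfy
# the divisibility checks for torch.half):
#
#     dense = _sparse_semi_structured_tile(torch.randn(64, 128)).half()
#     sparse, meta = sparse_semi_structured_from_dense_cutlass(dense)
#     assert sparse.shape == (64, 64) and meta.shape == (64, 8)
#     assert torch.equal(sparse_semi_structured_to_dense_cutlass(sparse, meta), dense)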

def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
    """
    This function performs the reverse of the function above - it
    reconstructs the dense matrix from the "compressed" matrix, given
    in the layout used by the CUTLASS backend, and the accompanying
    metadata matrix.
    """
    if sparse.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor"
        )

    m, k = sparse.shape
    device = sparse.device

    if meta_reordered.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor"
        )
    if meta_reordered.device != device:
        raise RuntimeError(
            f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device"
        )

    meta_dtype = meta_reordered.dtype
    if meta_dtype not in (torch.int16, torch.int32):
        raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4

    if sparse.dtype != torch.float:
        ksparse = 4
    else:
        ksparse = 2

    meta_nrows, meta_ncols = meta_reordered.shape
    if meta_nrows != m:
        raise RuntimeError(
            f"Number of rows of meta matrix {meta_nrows} must be equal to number of rows of sparse matrix {m}"
        )
    if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
        raise RuntimeError(
            f"Number of columns of sparse matrix {k} different from {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, "
            "the value expected according to the number of columns of meta matrix"
        )

    # Undo meta tensor elements reordering.
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device
    )
    meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)

    # Unpack sparse tensor back to original dense tensor, using
    # information provided by meta tensor.  Note that the torch.float
    # datatype is handled pretty much the same as
    # torch.half/torch.bfloat16, as metadata for a pair of torch.float
    # values is encoded as if the underlying 8 bytes contained four
    # torch.half/torch.bfloat16 values, where either the first two or
    # the last two are zeros.
    meta_2 = torch.empty(
        (m, meta_ncols, 2 * quadbits_per_meta_elem),
        dtype=meta_dtype,
        device=device,
    )
    if quadbits_per_meta_elem == 4:
        meta_2[:, :, 0] = meta & 0b11
        meta_2[:, :, 1] = (meta >> 2) & 0b11
        meta_2[:, :, 2] = (meta >> 4) & 0b11
        meta_2[:, :, 3] = (meta >> 6) & 0b11
        meta_2[:, :, 4] = (meta >> 8) & 0b11
        meta_2[:, :, 5] = (meta >> 10) & 0b11
        meta_2[:, :, 6] = (meta >> 12) & 0b11
        meta_2[:, :, 7] = (meta >> 14) & 0b11
    elif quadbits_per_meta_elem == 8:
        meta_2[:, :, 0] = meta & 0b11
        meta_2[:, :, 1] = (meta >> 2) & 0b11
        meta_2[:, :, 2] = (meta >> 4) & 0b11
        meta_2[:, :, 3] = (meta >> 6) & 0b11
        meta_2[:, :, 4] = (meta >> 8) & 0b11
        meta_2[:, :, 5] = (meta >> 10) & 0b11
        meta_2[:, :, 6] = (meta >> 12) & 0b11
        meta_2[:, :, 7] = (meta >> 14) & 0b11
        meta_2[:, :, 8] = (meta >> 16) & 0b11
        meta_2[:, :, 9] = (meta >> 18) & 0b11
        meta_2[:, :, 10] = (meta >> 20) & 0b11
        meta_2[:, :, 11] = (meta >> 22) & 0b11
        meta_2[:, :, 12] = (meta >> 24) & 0b11
        meta_2[:, :, 13] = (meta >> 26) & 0b11
        meta_2[:, :, 14] = (meta >> 28) & 0b11
        meta_2[:, :, 15] = (meta >> 30) & 0b11

    dense_offsets = meta_2.view(-1) + (
        torch.arange(0, 2 * m * k // ksparse, device=device) * 4
    ).view(-1, 1).repeat(1, 2).view(-1)
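    # Each consecutive pair of two-bit values in meta_2 addresses one group of
    # four output positions: the two sparse values kept from that group are
    # scattered below to offsets 4 * g + index, where g is the corresponding
    # element of the arange above and index is the two-bit value itself.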

    dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
    if sparse.dtype != torch.float:
        dense.scatter_(0, dense_offsets, sparse.view(-1))
    else:
        dense.view(torch.half).scatter_(
            0, dense_offsets, sparse.view(torch.half).view(-1)
        )

    return dense.view(m, 2 * k)


def _sparse_semi_structured_tile(dense):
    """
    This function computes a 2:4 sparse tile by greedily taking the largest values.

    Since we take the largest values greedily, how the sorting algorithm handles duplicates affects
    the ultimate sparsity pattern.

    Note that this function does not have the same sorting semantics as our CUDA backend,
    which is exposed via `torch._sparse_semi_structured_tile` and thus returns a different pattern.
    """
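    # For instance, with ties broken by the stable sort in index order, a tile
    # of all-equal nonzero values is pruned to the block-diagonal pattern
    #     [1 1 0 0]
    #     [1 1 0 0]
    #     [0 0 1 1]
    #     [0 0 1 1]
    # (the same pattern used as the example in _compute_compressed_swizzled_bitmask below).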

    def greedy_prune_tile(tile):
        num_kept_row = [0, 0, 0, 0]
        num_kept_col = [0, 0, 0, 0]

        for x in tile.flatten().sort(descending=True, stable=True).indices:
            r, c = x // 4, x % 4
            if num_kept_row[r] < 2 and num_kept_col[c] < 2:
                num_kept_row[r] += 1
                num_kept_col[c] += 1
            else:
                tile[r, c] = 0

    for batch in dense.unfold(0, 4, 4).unfold(1, 4, 4):
        for tile in batch:
            greedy_prune_tile(tile)

    return dense


def _compute_compressed_swizzled_bitmask(dense):
    """
    Calculates the compressed swizzled bitmask from a dense tensor
    """

    # first we need to convert the dense tensor to a bitmask
    int_bitmask = dense.bool().to(torch.uint8)

    # Each thread is responsible for an 8x8 tile, which contains four 4x4 tiles:
    # A, B, C and D, as displayed in the following schema:
    # +---+---+
    # | A | B |
    # +---+---+
    # | C | D |
    # +---+---+

    # we first need to split into the 8x8 tiles
    bitmask_8x8_chunks = int_bitmask.unfold(0, 8, 8).unfold(1, 8, 8)

    # then we unfold again to get our individual 4x4 tiles
    bitmask_4x4_chunks = bitmask_8x8_chunks.unfold(2, 4, 4).unfold(3, 4, 4)

    # Each 4x4 bitmask defines two 8-bit integers, which encode the sparsity pattern
    # of that tile. Note that the least significant bit is stored first.
    # [1 1 0 0]
    # [1 1 0 0]  ->  0011 0011 ->   51
    # [0 0 1 1]      1100 1100      204
    # [0 0 1 1]

    # reshape tensor to expand tiles into 8-bit vectors
    bitmask_binary_representation = bitmask_4x4_chunks.reshape(
        *bitmask_4x4_chunks.shape[:2], 4, 2, 8
    )

    # to convert from the binary representation, we can do a matmul with powers of two
    powers_of_two = 2 ** torch.arange(8, dtype=torch.float, device="cuda")
    # To run on GPU: cast to float to do the matmul and then cast back
    compressed_swizzled_bitmask = (
        bitmask_binary_representation.to(torch.float) @ powers_of_two
    ).to(torch.uint8)

    return compressed_swizzled_bitmask

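# A minimal shape sketch for the two helpers above (illustrative only; it
# assumes a CUDA device, since the powers-of-two vector above is allocated on
# "cuda", and it relies on the stable tie-breaking of _sparse_semi_structured_tile
# for the 51/204 byte values):
#
#     pruned = _sparse_semi_structured_tile(torch.ones(32, 32, device="cuda"))
#     bitmask = _compute_compressed_swizzled_bitmask(pruned)
#     # a 4x4 grid of 8x8 tiles, each made of four 4x4 sub-tiles encoded as two bytes
#     assert bitmask.shape == (4, 4, 4, 2)
#     assert bitmask[0, 0, 0].tolist() == [51, 204]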