xref: /aosp_15_r20/external/pytorch/benchmarks/sparse/triton_ops.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import torch
2from torch._inductor.runtime.benchmarking import benchmarker
3
4
def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device):
    """Return a dense 0/1 tensor with a random block-sparse pattern.

    The result has shape ``(M, N)`` when ``B == 0`` and ``(B, M, N)``
    otherwise. It is expanded from a mask of
    ``(M // blocksize[0], N // blocksize[1])`` blocks. Each block is entirely
    ones or entirely zeros, so ``result.to_sparse_bsr(blocksize)`` stores
    exactly the nonzero blocks.

    Every matrix (every batch element when ``B > 0``) gets exactly
    ``int((1 - sparsity) * M * N / (blocksize[0] * blocksize[1]))`` nonzero
    blocks.

    Args:
        B: batch size; ``0`` means "no batch dimension".
        M: rows per matrix; must be divisible by ``blocksize[0]``.
        N: columns per matrix; must be divisible by ``blocksize[1]``.
        blocksize: ``(block_rows, block_cols)``.
        sparsity: fraction of zero blocks, in ``[0, 1]``.
        dtype: dtype of the returned tensor.
        device: device of the returned tensor.

    Raises:
        AssertionError: if ``sparsity`` is out of range or the sizes are not
            multiples of the block size.
    """
    assert (
        sparsity <= 1.0 and sparsity >= 0.0
    ), "sparsity should be a value between 0 and 1"
    assert M % blocksize[0] == 0
    assert N % blocksize[1] == 0
    shape = (B, M // blocksize[0], N // blocksize[1])[int(B == 0) :]
    A = torch.bernoulli(torch.full(shape, 1 - sparsity, dtype=dtype, device=device))
    # Target number of nonzero blocks for *each* matrix in the batch.
    expected_nnz = int((1 - sparsity) * M * N / (blocksize[0] * blocksize[1]))
    # Bernoulli sampling only reaches expected_nnz on average, so correct each
    # matrix individually. Enforcing the count per batch element, rather than
    # over the whole batch, also gives every element the same nnz, which a
    # batched BSR conversion requires.
    masks = A.unsqueeze(0) if B == 0 else A
    for mask in masks:
        # The unbind slices of a contiguous tensor are contiguous views, so
        # the writes through `flat` below update A in place.
        flat = mask.view(-1)
        nonzero_indices = flat.nonzero()
        actual_nnz = nonzero_indices.shape[0]
        if actual_nnz > expected_nnz:
            selected_nonzeros = torch.randperm(actual_nnz)[: actual_nnz - expected_nnz]
            flat[nonzero_indices[selected_nonzeros]] = 0
        elif actual_nnz < expected_nnz:
            zero_indices = (flat == 0).nonzero()
            selected_zeros = torch.randperm(zero_indices.shape[0])[
                : expected_nnz - actual_nnz
            ]
            flat[zero_indices[selected_zeros]] = 1
    # Expand every mask entry into a blocksize[0] x blocksize[1] tile.
    A = torch.repeat_interleave(A, blocksize[0], dim=-2)
    A = torch.repeat_interleave(A, blocksize[1], dim=-1)
    return A
28
29
def _test_worker(test_func, flops=None):
    """Time ``test_func`` on the GPU and report throughput.

    Args:
        test_func: zero-argument callable that performs one operation.
        flops: number of floating-point operations per call. When ``None``
            (the historical behaviour), ``2 * m * k * n`` is used, where
            ``m``, ``k``, ``n`` are module-level names assigned by this
            script's ``__main__`` driver loop.

    Returns:
        ``(ms, tflops)``: the timing reported by the benchmarker, in
        milliseconds, and the derived TFLOP/s.
    """
    result = benchmarker.benchmark_gpu(
        test_func, warmup=500, rep=100, fast_flush=False
    )
    # NOTE(review): depending on the torch/triton version, benchmark_gpu
    # returns either a single timing or a sequence whose first entry is the
    # main timing. Accept both instead of assuming a 3-tuple.
    ms = result[0] if isinstance(result, (tuple, list)) else result

    if flops is None:
        # Backward-compatible default: the problem size comes from the
        # module-level m, k, n set by the __main__ loop. This fallback only
        # works when the file is run as a script.
        module_globals = globals()
        flops = 2 * module_globals["m"] * module_globals["k"] * module_globals["n"]

    tflops = flops * 1e-12 / (ms * 1e-3)
    return ms, tflops
37
38
def test_dense_dense_mm(x, y, **meta):
    """Benchmark a dense @ dense matmul using a densified copy of ``x``.

    The dense conversion runs once, outside the timed callable. ``meta`` is
    accepted for interface uniformity and ignored.
    """
    x_dense = x.to_dense()
    return _test_worker(lambda: torch.matmul(x_dense, y))
44
45
def test_torch_matmul(x, y, **meta):
    """Benchmark ``torch.matmul`` applied to ``x`` exactly as given.

    ``x`` keeps whatever layout the caller passed. ``meta`` is ignored.
    """
    return _test_worker(lambda: torch.matmul(x, y))
51
52
def test_bsr_dense_mm(x, y, **meta):
    """Benchmark triton ``bsr_dense_mm`` with a fixed kernel configuration.

    The caller's ``meta`` is ignored. The kernel parameters are hard-coded in
    the timed callable.
    """
    from torch.sparse._triton_ops import bsr_dense_mm

    def run():
        fixed_meta = dict(GROUP_SIZE_ROW=4, num_stages=1, num_warps=4)
        return bsr_dense_mm(x, y, meta=fixed_meta)

    return _test_worker(run)
62
63
def test_bsr_dense_mm_with_meta(x, y, **meta):
    """Benchmark triton ``bsr_dense_mm`` with the caller-supplied ``meta``."""
    from torch.sparse._triton_ops import bsr_dense_mm

    return _test_worker(lambda: bsr_dense_mm(x, y, meta=meta))
71
72
def test_bsr_scatter_mm2(x, y, **meta):
    """Benchmark ``bsr_scatter_mm`` with precomputed ``"scatter_mm"`` indices.

    The index data is built once from ``meta``, before timing starts.
    """
    from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data

    precomputed = bsr_scatter_mm_indices_data(
        x, y, indices_format="scatter_mm", **meta
    )
    return _test_worker(lambda: bsr_scatter_mm(x, y, indices_data=precomputed))
84
85
def test_bsr_scatter_mm6(x, y, **meta):
    """Benchmark ``bsr_scatter_mm`` with precomputed compressed strided indices.

    The ``"bsr_strided_mm_compressed"`` index data is built once from
    ``meta``, before timing starts.
    """
    from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data

    precomputed = bsr_scatter_mm_indices_data(
        x, y, indices_format="bsr_strided_mm_compressed", **meta
    )
    return _test_worker(lambda: bsr_scatter_mm(x, y, indices_data=precomputed))
97
98
def test_bsr_scatter_mm(x, y, **meta):
    """Benchmark ``bsr_scatter_mm`` including the cost of building its indices.

    The ``"bsr_strided_mm_compressed"`` index data is recomputed inside the
    timed callable on every invocation.
    """
    from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data

    def run():
        fresh_indices = bsr_scatter_mm_indices_data(
            x, y, indices_format="bsr_strided_mm_compressed", **meta
        )
        return bsr_scatter_mm(x, y, indices_data=fresh_indices)

    return _test_worker(run)
110
def test_linear(x, y, **meta):
    """Benchmark ``F.linear`` with ``x`` as the weight and ``y``ᵀ as the input.

    ``F.linear(yᵀ, x)`` computes ``yᵀ @ xᵀ``. The transpose of ``y`` is
    taken once, before timing. ``meta`` is ignored.
    """
    import torch.nn.functional as F

    y_t = y.transpose(-2, -1)
    return _test_worker(lambda: F.linear(y_t, x))
118
119
if __name__ == "__main__":
    import argparse
    import atexit
    import itertools
    import sys

    import triton

    from torch.testing import make_tensor

    torch.manual_seed(0)

    def integer_list(a):
        """Parse a comma-separated list of ints, e.g. ``"16,32"``."""
        return list(map(int, a.split(",")))

    def float_list(a):
        """Parse a comma-separated list of floats."""
        return list(map(float, a.split(",")))

    def integer_or_float_list(a):
        """Parse a comma-separated list of ints, floats and ranges.

        Each item is ``start:end`` or ``start:end:step`` (expanded with
        ``range``), a float (if it contains ``"."``), or an int.
        """
        lst = []
        for n in a.split(","):
            if n.count(":") == 1:
                start, end = map(int, n.split(":"))
                lst.extend(range(start, end))
            elif n.count(":") == 2:
                start, end, step = map(int, n.split(":"))
                lst.extend(range(start, end, step))
            elif "." in n:
                lst.append(float(n))
            else:
                lst.append(int(n))
        return lst

    parser = argparse.ArgumentParser(description="SpTritonOps")

    parser.add_argument(
        "--ops",
        default="dense_dense_mm,bsr_dense_mm,bsr_scatter_mm6",
        type=str,
    )
    parser.add_argument("--b", default="0", type=int)

    parser.add_argument("--m", default="1024", type=integer_list)
    parser.add_argument("--k", default=None, type=integer_list)
    parser.add_argument("--n", default=None, type=integer_list)
    parser.add_argument("--bm", default="16", type=integer_list)
    parser.add_argument("--bk", default=None, type=integer_list)
    parser.add_argument("--tile_m", default=None, type=integer_list)
    parser.add_argument("--tile_n", default=None, type=integer_list)
    parser.add_argument("--split_n", default=None, type=integer_list)
    parser.add_argument("--group_size", default=None, type=integer_list)
    parser.add_argument("--num_warps", default=None, type=integer_list)
    parser.add_argument("--num_stages", default=None, type=integer_list)
    parser.add_argument("--sparsity", default="0.5", type=integer_or_float_list)
    parser.add_argument("--dtype", default="float16", type=str)
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--repeat", default="1", type=int)
    parser.add_argument("--outfile", default="stdout", type=str)
    parser.add_argument("--star", default=False, action="store_true")

    args = parser.parse_args()

    ops = args.ops.split(",")
    # Each op name is dispatched to the module-level function "test_<op>".
    # Reject unknown names up front instead of failing with a KeyError
    # part-way through the benchmark.
    unknown_ops = [op for op in ops if "test_" + op not in globals()]
    if unknown_ops:
        parser.error(f"unknown --ops value(s): {','.join(unknown_ops)}")

    if args.outfile == "stdout":
        outfile = sys.stdout
    elif args.outfile == "stderr":
        outfile = sys.stderr
    else:
        outfile = open(args.outfile, "a")
        # Close (and flush) the log file when the interpreter exits.
        atexit.register(outfile.close)

    b = args.b

    m_list = args.m or [1024]
    n_list = args.n or [None]
    k_list = args.k or [None]
    bm_list = args.bm or [16]
    bk_list = args.bk or [None]
    split_n_list = args.split_n or [None]
    tile_m_list = args.tile_m or [None]
    tile_n_list = args.tile_n or [None]
    group_size_list = args.group_size or [None]
    num_warps_list = args.num_warps or [None]
    num_stages_list = args.num_stages or [None]
    sparsity_list = args.sparsity or [0.5]
    dtype = getattr(torch, args.dtype)

    if args.star:
        # "Star" search: start from the tuned meta for a single problem size
        # and try neighbouring values of every kernel parameter. A parameter
        # given as None is expanded to {half/minus-one, value, double/plus-one};
        # a parameter given as 0 is pinned to the tuned value.
        import torch.sparse._triton_ops

        assert {len(m_list), len(n_list), len(k_list), len(bm_list), len(bk_list)} == {
            1
        }
        m = m_list[0]
        n = n_list[0] or m
        k = k_list[0] or m
        bm = bm_list[0]
        bk = bk_list[0] or bm
        if "bsr_scatter_mm6" in ops:
            meta = torch.sparse._triton_ops.scatter_mm_meta(m, k, n, bm, bk)
        elif "bsr_dense_mm_with_meta" in ops:
            meta = torch.sparse._triton_ops.bsr_dense_mm_meta(m, k, n, bm, bk)
        else:
            raise NotImplementedError(f"--star not implemented for operations in {ops}")
        if "bsr_scatter_mm6" in ops:
            if split_n_list[0] is None:
                split_n_list = [
                    meta["SPLIT_N"] // 2,
                    meta["SPLIT_N"],
                    meta["SPLIT_N"] * 2,
                ][int(meta["SPLIT_N"] == 1) :]
            elif split_n_list[0] == 0:
                split_n_list = [meta["SPLIT_N"]]
            if tile_m_list[0] is None:
                tile_m_list = [meta["TILE_M"] // 2, meta["TILE_M"], meta["TILE_M"] * 2][
                    int(meta["TILE_M"] == 16) :
                ]
            elif tile_m_list[0] == 0:
                tile_m_list = [meta["TILE_M"]]
            if tile_n_list[0] is None:
                tile_n_list = [meta["TILE_N"] // 2, meta["TILE_N"], meta["TILE_N"] * 2][
                    int(meta["TILE_N"] == 16) :
                ]
            elif tile_n_list[0] == 0:
                tile_n_list = [meta["TILE_N"]]
            if group_size_list[0] is None:
                group_size_list = [
                    meta["GROUP_SIZE"] - 1,
                    meta["GROUP_SIZE"],
                    meta["GROUP_SIZE"] + 1,
                ][int(meta["GROUP_SIZE"] == 1) :]
            elif group_size_list[0] == 0:
                group_size_list = [meta["GROUP_SIZE"]]
        if "bsr_dense_mm_with_meta" in ops:
            if group_size_list[0] is None:
                group_size_list = [
                    meta["GROUP_SIZE_ROW"] - 1,
                    meta["GROUP_SIZE_ROW"],
                    meta["GROUP_SIZE_ROW"] + 1,
                ][int(meta["GROUP_SIZE_ROW"] == 1) :]
            elif group_size_list[0] == 0:
                group_size_list = [meta["GROUP_SIZE_ROW"]]
        if num_warps_list[0] is None:
            num_warps_list = [
                meta["num_warps"] // 2,
                meta["num_warps"],
                meta["num_warps"] * 2,
            ][int(meta["num_warps"] == 1) :]
        elif num_warps_list[0] == 0:
            num_warps_list = [meta["num_warps"]]
        if num_stages_list[0] is None:
            num_stages_list = [
                meta["num_stages"] - 1,
                meta["num_stages"],
                meta["num_stages"] + 1,
            ][int(meta["num_stages"] == 1) :]
        elif num_stages_list[0] == 0:
            num_stages_list = [meta["num_stages"]]

    device = args.device
    dense_dense_mm_sizes = set()
    # The first dense_dense_mm result becomes the reference; other ops within
    # performance_rtol of it are marked with " @@@".
    target_performance = None
    performance_rtol = 1e-2

    best_messages = []

    @atexit.register
    def show_best_messages(best_messages=best_messages):
        # NOTE: a message is appended each time an (op, shape) sets a new best,
        # so this prints the ten most recent records, not a global ranking.
        print("TOP 10:")
        for message in best_messages[-10:]:
            print(message)
        sys.stdout.flush()

    # m, k, n are module-level here; _test_worker reads them by default.
    for m, k, n, bm, bk, sparsity in itertools.product(
        m_list, k_list, n_list, bm_list, bk_list, sparsity_list
    ):
        k = k or m
        n = n or m
        bk = bk or bm

        if bm > m or bk > k:
            # Skip invalid parameter combinations
            continue

        blocksize = (bm, bk)

        if isinstance(sparsity, int):
            # integer sparsity value corresponds to desired nnz value
            sparsity = 1 - bk * bm * sparsity / (m * k)

        if sparsity > 1 or sparsity < 0:
            continue

        x = create_blocked_tensor(
            b, m, k, blocksize, sparsity, dtype, device
        ).to_sparse_bsr(blocksize)

        # recompute sparsity
        nnz = x._nnz()
        sparsity = 1 - bk * bm * nnz / (m * k)

        y = make_tensor(k, n, dtype=dtype, device=device)

        # Shape of the sparse operand x: (b, m, k), or (m, k) without a batch.
        bsr_size = f"{b}x{m}x{k}" if b > 0 else f"{m}x{k}"

        for op in ops:
            if op == "dense_dense_mm":
                if (m, k, n) in dense_dense_mm_sizes:
                    # Skip already benchmarked cases
                    continue
                dense_dense_mm_sizes.add((m, k, n))
            best_tflops = 0
            for (
                split_n,
                num_warps,
                num_stages,
                tile_m,
                tile_n,
                group_size,
            ) in itertools.product(
                split_n_list,
                num_warps_list,
                num_stages_list,
                tile_m_list,
                tile_n_list,
                group_size_list,
            ):
                if (
                    (tile_m or 0) > bm
                    or (tile_n or 0) > n // (split_n or 1)
                    or n % (split_n or 1) != 0
                    or (split_n or 0) > n
                ):
                    # Skip invalid parameter combinations
                    continue
                test_func = globals()["test_" + op]
                meta = dict(
                    bsr_scatter_mm6=dict(
                        SPLIT_N=split_n,
                        TILE_M=tile_m,
                        TILE_N=tile_n,
                        GROUP_SIZE=group_size,
                        num_stages=num_stages,
                        num_warps=num_warps,
                    ),
                    bsr_dense_mm_with_meta=dict(
                        GROUP_SIZE_ROW=group_size,
                        num_stages=num_stages,
                        num_warps=num_warps,
                    ),
                ).get(op, {})

                meta_str = ";".join(
                    f"{key}={value}" for key, value in meta.items() if value is not None
                )
                # Common prefix shared by every report line for this config.
                prefix = (
                    f"op={op}[{meta_str}]({bsr_size},{k}x{n}) dtype={args.dtype}"
                    f" {sparsity=:.4f}(nnz={nnz}) blocksize={bm}x{bk}"
                )
                time_ms_lst = []
                performance_tflops_lst = []
                for r in range(args.repeat):
                    try:
                        time_ms, performance_tflops = test_func(x, y, **meta)
                    except triton.compiler.OutOfResources:
                        print(f"{prefix} OutOfResources", file=outfile)
                        continue
                    except AssertionError:
                        raise
                    except Exception as exc:
                        # Report only the first line of the error and move on.
                        msg = str(exc).split("\n", 1)[0]
                        print(f"{prefix} {msg}", file=outfile)
                        continue
                    time_ms_lst.append(time_ms)
                    performance_tflops_lst.append(performance_tflops)
                    mark = ""
                    if op == "dense_dense_mm":
                        if target_performance is None:
                            target_performance = performance_tflops
                    elif target_performance is not None:
                        if (
                            abs(1 - performance_tflops / target_performance)
                            < performance_rtol
                        ):
                            mark += " @@@"
                    result_str = (
                        f"time={time_ms:.3f} ms"
                        f" performance={performance_tflops:.3f} TFLOPS"
                    )
                    if best_tflops < performance_tflops:
                        best_tflops = performance_tflops
                        best_message = f"{prefix} {result_str}"
                        if best_message not in best_messages:
                            best_messages.append(best_message)
                        mark += " !!!"
                    print(f"{prefix} {result_str}{mark}", file=outfile)
                    outfile.flush()
                # Skip the average when every repeat failed (nothing to average).
                if args.repeat > 1 and time_ms_lst:
                    avg_time_ms = sum(time_ms_lst) / len(time_ms_lst)
                    avg_performance_tflops = sum(performance_tflops_lst) / len(
                        performance_tflops_lst
                    )
                    print(
                        f"{prefix} time={avg_time_ms:.3f} ms"
                        f" performance={avg_performance_tflops:.3f} TFLOPS [AVERAGE]",
                        file=outfile,
                    )
                    outfile.flush()
                if op not in {"bsr_scatter_mm6", "bsr_dense_mm_with_meta"}:
                    # Break on operations that do not consume parameters
                    break
438                    break
439