import torch from torch._inductor.runtime.benchmarking import benchmarker def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device): assert ( sparsity <= 1.0 and sparsity >= 0.0 ), "sparsity should be a value between 0 and 1" assert M % blocksize[0] == 0 assert N % blocksize[1] == 0 shape = (B, M // blocksize[0], N // blocksize[1])[int(B == 0) :] A = torch.bernoulli(torch.full(shape, 1 - sparsity, dtype=dtype, device=device)) expected_nnz = int((1 - sparsity) * M * N / (blocksize[0] * blocksize[1])) nonzero_indices = A.flatten().nonzero() actual_nnz = nonzero_indices.shape[0] if actual_nnz > expected_nnz: selected_nonzeros = torch.randperm(actual_nnz)[: actual_nnz - expected_nnz] A.flatten()[nonzero_indices[selected_nonzeros]] = 0 elif actual_nnz < expected_nnz: zero_indices = (A == 0).flatten().nonzero() selected_zeros = torch.randperm(zero_indices.shape[0])[ : expected_nnz - actual_nnz ] A.flatten()[zero_indices[selected_zeros]] = 1 A = torch.repeat_interleave(A, blocksize[0], dim=-2) A = torch.repeat_interleave(A, blocksize[1], dim=-1) return A def _test_worker(test_func): ms, ms_min, ms_max = benchmarker.benchmark_gpu( test_func, warmup=500, rep=100, fast_flush=False ) tflops = 2 * m * k * n * 1e-12 / (ms * 1e-3) return ms, tflops def test_dense_dense_mm(x, y, **meta): def test_func(x=x.to_dense(), y=y): return torch.matmul(x, y) return _test_worker(test_func) def test_torch_matmul(x, y, **meta): def test_func(x=x, y=y): return torch.matmul(x, y) return _test_worker(test_func) def test_bsr_dense_mm(x, y, **meta): from torch.sparse._triton_ops import bsr_dense_mm def test_func(x=x, y=y): return bsr_dense_mm( x, y, meta=dict(GROUP_SIZE_ROW=4, num_stages=1, num_warps=4) ) return _test_worker(test_func) def test_bsr_dense_mm_with_meta(x, y, **meta): from torch.sparse._triton_ops import bsr_dense_mm def test_func(x=x, y=y, meta=meta): return bsr_dense_mm(x, y, meta=meta) return _test_worker(test_func) def test_bsr_scatter_mm2(x, y, **meta): from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data indices_data = bsr_scatter_mm_indices_data( x, y, indices_format="scatter_mm", **meta ) def test_func(x=x, y=y): return bsr_scatter_mm(x, y, indices_data=indices_data) return _test_worker(test_func) def test_bsr_scatter_mm6(x, y, **meta): from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data indices_data = bsr_scatter_mm_indices_data( x, y, indices_format="bsr_strided_mm_compressed", **meta ) def test_func(x=x, y=y): return bsr_scatter_mm(x, y, indices_data=indices_data) return _test_worker(test_func) def test_bsr_scatter_mm(x, y, **meta): from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data def test_func(x=x, y=y): indices_data = bsr_scatter_mm_indices_data( x, y, indices_format="bsr_strided_mm_compressed", **meta ) return bsr_scatter_mm(x, y, indices_data=indices_data) return _test_worker(test_func) def test_linear(x, y, **meta): import torch.nn.functional as F def test_func(x=x, y=y.transpose(-2, -1)): return F.linear(y, x) return _test_worker(test_func) if __name__ == "__main__": import argparse import atexit import itertools import sys import triton from torch.testing import make_tensor torch.manual_seed(0) def integer_list(a): return list(map(int, a.split(","))) def float_list(a): return list(map(float, a.split(","))) def integer_or_float_list(a): lst = [] for n in a.split(","): if n.count(":") == 1: start, end = map(int, n.split(":")) lst.extend(range(start, end)) elif n.count(":") == 2: start, end, step = map(int, n.split(":")) lst.extend(range(start, end, step)) elif "." in n: lst.append(float(n)) else: lst.append(int(n)) return lst parser = argparse.ArgumentParser(description="SpTritonOps") parser.add_argument( "--ops", default="dense_dense_mm,bsr_dense_mm,bsr_scatter_mm6", type=str, ) parser.add_argument("--b", default="0", type=int) parser.add_argument("--m", default="1024", type=integer_list) parser.add_argument("--k", default=None, type=integer_list) parser.add_argument("--n", default=None, type=integer_list) parser.add_argument("--bm", default="16", type=integer_list) parser.add_argument("--bk", default=None, type=integer_list) parser.add_argument("--tile_m", default=None, type=integer_list) parser.add_argument("--tile_n", default=None, type=integer_list) parser.add_argument("--split_n", default=None, type=integer_list) parser.add_argument("--group_size", default=None, type=integer_list) parser.add_argument("--num_warps", default=None, type=integer_list) parser.add_argument("--num_stages", default=None, type=integer_list) parser.add_argument("--sparsity", default="0.5", type=integer_or_float_list) parser.add_argument("--dtype", default="float16", type=str) parser.add_argument("--device", default="cuda", type=str) parser.add_argument("--repeat", default="1", type=int) parser.add_argument("--outfile", default="stdout", type=str) parser.add_argument("--star", default=False, action="store_true") args = parser.parse_args() if args.outfile == "stdout": outfile = sys.stdout elif args.outfile == "stderr": outfile = sys.stderr else: outfile = open(args.outfile, "a") ops = args.ops.split(",") b = args.b m_list = args.m or [1024] n_list = args.n or [None] k_list = args.k or [None] bm_list = args.bm or [16] bk_list = args.bk or [None] split_n_list = args.split_n or [None] tile_m_list = args.tile_m or [None] tile_n_list = args.tile_n or [None] group_size_list = args.group_size or [None] num_warps_list = args.num_warps or [None] num_stages_list = args.num_stages or [None] sparsity_list = args.sparsity or [0.5] dtype = getattr(torch, args.dtype) if args.star > 0: import torch.sparse._triton_ops assert {len(m_list), len(n_list), len(k_list), len(bm_list), len(bk_list)} == { 1 } m = m_list[0] n = n_list[0] or m k = k_list[0] or m bm = bm_list[0] bk = bk_list[0] or bm if "bsr_scatter_mm6" in ops: meta = torch.sparse._triton_ops.scatter_mm_meta(m, k, n, bm, bk) elif "bsr_dense_mm_with_meta" in ops: meta = torch.sparse._triton_ops.bsr_dense_mm_meta(m, k, n, bm, bk) else: raise NotImplementedError(f"--star not implemented for operations in {ops}") if "bsr_scatter_mm6" in ops: if split_n_list[0] is None: split_n_list = [ meta["SPLIT_N"] // 2, meta["SPLIT_N"], meta["SPLIT_N"] * 2, ][int(meta["SPLIT_N"] == 1) :] elif split_n_list[0] == 0: split_n_list = [meta["SPLIT_N"]] if tile_m_list[0] is None: tile_m_list = [meta["TILE_M"] // 2, meta["TILE_M"], meta["TILE_M"] * 2][ int(meta["TILE_M"] == 16) : ] elif tile_m_list[0] == 0: tile_m_list = [meta["TILE_M"]] if tile_n_list[0] is None: tile_n_list = [meta["TILE_N"] // 2, meta["TILE_N"], meta["TILE_N"] * 2][ int(meta["TILE_N"] == 16) : ] elif tile_n_list[0] == 0: tile_n_list = [meta["TILE_N"]] if group_size_list[0] is None: group_size_list = [ meta["GROUP_SIZE"] - 1, meta["GROUP_SIZE"], meta["GROUP_SIZE"] + 1, ][int(meta["GROUP_SIZE"] == 1) :] elif group_size_list[0] == 0: group_size_list = [meta["GROUP_SIZE"]] if "bsr_dense_mm_with_meta" in ops: if group_size_list[0] is None: group_size_list = [ meta["GROUP_SIZE_ROW"] - 1, meta["GROUP_SIZE_ROW"], meta["GROUP_SIZE_ROW"] + 1, ][int(meta["GROUP_SIZE_ROW"] == 1) :] elif group_size_list[0] == 0: group_size_list = [meta["GROUP_SIZE_ROW"]] if num_warps_list[0] is None: num_warps_list = [ meta["num_warps"] // 2, meta["num_warps"], meta["num_warps"] * 2, ][int(meta["num_warps"] == 1) :] elif num_warps_list[0] == 0: num_warps_list = [meta["num_warps"]] if num_stages_list[0] is None: num_stages_list = [ meta["num_stages"] - 1, meta["num_stages"], meta["num_stages"] + 1, ][int(meta["num_stages"] == 1) :] elif num_stages_list[0] == 0: num_stages_list = [meta["num_stages"]] device = args.device dense_dense_mm_sizes = set() target_performance = None performance_rtol = 1e-2 best_messages = [] @atexit.register def show_best_messages(best_messages=best_messages): print("TOP 10:") for m in best_messages[-10:]: print(m) sys.stdout.flush() for m, k, n, bm, bk, sparsity in itertools.product( m_list, k_list, n_list, bm_list, bk_list, sparsity_list ): k = k or m n = n or m bk = bk or bm if bm > m or bk > k: # Skip invalid parameter combinations continue blocksize = (bm, bk) if isinstance(sparsity, int): # integer sparsity value corresponds to desired nnz value sparsity = 1 - bk * bm * sparsity / (m * k) if sparsity > 1 or sparsity < 0: continue x = create_blocked_tensor( b, m, k, blocksize, sparsity, dtype, device ).to_sparse_bsr(blocksize) # recompute sparsity sparsity = 1 - bk * bm * x._nnz() / (m * k) y = make_tensor(k, n, dtype=dtype, device=device) bsr_size = f"{b}x{m}x{k}" if b > 0 else f"{k}x{n}" for op in ops: if op == "dense_dense_mm": if (m, k, n) in dense_dense_mm_sizes: # Skip already benchmarked cases continue dense_dense_mm_sizes.add((m, k, n)) best_tflops = 0 for ( split_n, num_warps, num_stages, tile_m, tile_n, group_size, ) in itertools.product( split_n_list, num_warps_list, num_stages_list, tile_m_list, tile_n_list, group_size_list, ): if ( (tile_m or 0) > bm or (tile_n or 0) > n // (split_n or 1) or n % (split_n or 1) != 0 or (split_n or 0) > n ): # Skip invalid parameter combinations continue test_func = globals()["test_" + op] meta = dict( bsr_scatter_mm6=dict( SPLIT_N=split_n, TILE_M=tile_m, TILE_N=tile_n, GROUP_SIZE=group_size, num_stages=num_stages, num_warps=num_warps, ), bsr_dense_mm_with_meta=dict( GROUP_SIZE_ROW=group_size, num_stages=num_stages, num_warps=num_warps, ), ).get(op, {}) meta_str = ";".join( f"{k}={v}" for k, v in meta.items() if v is not None ) time_ms_lst = [] performance_tflops_lst = [] for r in range(args.repeat): try: time_ms, performance_tflops = test_func(x, y, **meta) except triton.compiler.OutOfResources as msg: print( f"op={op}[{meta_str}]({bsr_size},{k}x{n}) dtype={args.dtype} {sparsity=}(nnz={x._nnz()})" f" blocksize={bm}x{bk} OutOfResources", file=outfile, ) continue except AssertionError: raise except Exception as msg: msg = str(msg).split("\n", 1)[0] print( f"op={op}[{meta_str}]({bsr_size},{k}x{n}) dtype={args.dtype} {sparsity=}(nnz={x._nnz()})" f" blocksize={bm}x{bk} {msg}", file=outfile, ) continue time_ms_lst.append(time_ms) performance_tflops_lst.append(performance_tflops) mark = "" if op == "dense_dense_mm": if target_performance is None: target_performance = performance_tflops elif target_performance is not None: if ( abs(1 - performance_tflops / target_performance) < performance_rtol ): mark += " @@@" if best_tflops < performance_tflops: best_tflops = performance_tflops best_message = ( f"op={op}[{meta_str}]({bsr_size},x{n}) dtype={args.dtype} {sparsity=:.4f}(nnz={x._nnz()})" f" blocksize={bm}x{bk} time={time_ms:.3f} ms performance={performance_tflops:.3f} TFLOPS" ) if best_message not in best_messages: best_messages.append(best_message) mark += " !!!" print( f"op={op}[{meta_str}]({bsr_size},x{n}) dtype={args.dtype} {sparsity=:.4f}(nnz={x._nnz()})" f" blocksize={bm}x{bk}" f" time={time_ms:.3f} ms performance={performance_tflops:.3f} TFLOPS{mark}", file=outfile, ) outfile.flush() if args.repeat > 1: avg_time_ms = sum(time_ms_lst) / len(time_ms_lst) avg_performance_tflops = sum(performance_tflops_lst) / len( performance_tflops_lst ) print( f"op={op}[{meta_str}]({bsr_size},{k}x{n}) dtype={args.dtype} {sparsity=}(nnz={x._nnz()})" f" blocksize={bm}x{bk}" f" time={time_ms:.3f} ms performance={performance_tflops:.3f} TFLOPS [AVERAGE]", file=outfile, ) outfile.flush() if op not in {"bsr_scatter_mm6", "bsr_dense_mm_with_meta"}: # Break on operations that do not consume parameters break