1import torch 2from torch._inductor.runtime.benchmarking import benchmarker 3 4 5def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device): 6 assert ( 7 sparsity <= 1.0 and sparsity >= 0.0 8 ), "sparsity should be a value between 0 and 1" 9 assert M % blocksize[0] == 0 10 assert N % blocksize[1] == 0 11 shape = (B, M // blocksize[0], N // blocksize[1])[int(B == 0) :] 12 A = torch.bernoulli(torch.full(shape, 1 - sparsity, dtype=dtype, device=device)) 13 expected_nnz = int((1 - sparsity) * M * N / (blocksize[0] * blocksize[1])) 14 nonzero_indices = A.flatten().nonzero() 15 actual_nnz = nonzero_indices.shape[0] 16 if actual_nnz > expected_nnz: 17 selected_nonzeros = torch.randperm(actual_nnz)[: actual_nnz - expected_nnz] 18 A.flatten()[nonzero_indices[selected_nonzeros]] = 0 19 elif actual_nnz < expected_nnz: 20 zero_indices = (A == 0).flatten().nonzero() 21 selected_zeros = torch.randperm(zero_indices.shape[0])[ 22 : expected_nnz - actual_nnz 23 ] 24 A.flatten()[zero_indices[selected_zeros]] = 1 25 A = torch.repeat_interleave(A, blocksize[0], dim=-2) 26 A = torch.repeat_interleave(A, blocksize[1], dim=-1) 27 return A 28 29 30def _test_worker(test_func): 31 ms, ms_min, ms_max = benchmarker.benchmark_gpu( 32 test_func, warmup=500, rep=100, fast_flush=False 33 ) 34 35 tflops = 2 * m * k * n * 1e-12 / (ms * 1e-3) 36 return ms, tflops 37 38 39def test_dense_dense_mm(x, y, **meta): 40 def test_func(x=x.to_dense(), y=y): 41 return torch.matmul(x, y) 42 43 return _test_worker(test_func) 44 45 46def test_torch_matmul(x, y, **meta): 47 def test_func(x=x, y=y): 48 return torch.matmul(x, y) 49 50 return _test_worker(test_func) 51 52 53def test_bsr_dense_mm(x, y, **meta): 54 from torch.sparse._triton_ops import bsr_dense_mm 55 56 def test_func(x=x, y=y): 57 return bsr_dense_mm( 58 x, y, meta=dict(GROUP_SIZE_ROW=4, num_stages=1, num_warps=4) 59 ) 60 61 return _test_worker(test_func) 62 63 64def test_bsr_dense_mm_with_meta(x, y, **meta): 65 from torch.sparse._triton_ops import bsr_dense_mm 66 67 def test_func(x=x, y=y, meta=meta): 68 return bsr_dense_mm(x, y, meta=meta) 69 70 return _test_worker(test_func) 71 72 73def test_bsr_scatter_mm2(x, y, **meta): 74 from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data 75 76 indices_data = bsr_scatter_mm_indices_data( 77 x, y, indices_format="scatter_mm", **meta 78 ) 79 80 def test_func(x=x, y=y): 81 return bsr_scatter_mm(x, y, indices_data=indices_data) 82 83 return _test_worker(test_func) 84 85 86def test_bsr_scatter_mm6(x, y, **meta): 87 from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data 88 89 indices_data = bsr_scatter_mm_indices_data( 90 x, y, indices_format="bsr_strided_mm_compressed", **meta 91 ) 92 93 def test_func(x=x, y=y): 94 return bsr_scatter_mm(x, y, indices_data=indices_data) 95 96 return _test_worker(test_func) 97 98 99def test_bsr_scatter_mm(x, y, **meta): 100 from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data 101 102 def test_func(x=x, y=y): 103 indices_data = bsr_scatter_mm_indices_data( 104 x, y, indices_format="bsr_strided_mm_compressed", **meta 105 ) 106 return bsr_scatter_mm(x, y, indices_data=indices_data) 107 108 return _test_worker(test_func) 109 110 111def test_linear(x, y, **meta): 112 import torch.nn.functional as F 113 114 def test_func(x=x, y=y.transpose(-2, -1)): 115 return F.linear(y, x) 116 117 return _test_worker(test_func) 118 119 120if __name__ == "__main__": 121 import argparse 122 import atexit 123 import itertools 124 import sys 125 126 import triton 127 128 from torch.testing import make_tensor 129 130 torch.manual_seed(0) 131 132 def integer_list(a): 133 return list(map(int, a.split(","))) 134 135 def float_list(a): 136 return list(map(float, a.split(","))) 137 138 def integer_or_float_list(a): 139 lst = [] 140 for n in a.split(","): 141 if n.count(":") == 1: 142 start, end = map(int, n.split(":")) 143 lst.extend(range(start, end)) 144 elif n.count(":") == 2: 145 start, end, step = map(int, n.split(":")) 146 lst.extend(range(start, end, step)) 147 elif "." in n: 148 lst.append(float(n)) 149 else: 150 lst.append(int(n)) 151 return lst 152 153 parser = argparse.ArgumentParser(description="SpTritonOps") 154 155 parser.add_argument( 156 "--ops", 157 default="dense_dense_mm,bsr_dense_mm,bsr_scatter_mm6", 158 type=str, 159 ) 160 parser.add_argument("--b", default="0", type=int) 161 162 parser.add_argument("--m", default="1024", type=integer_list) 163 parser.add_argument("--k", default=None, type=integer_list) 164 parser.add_argument("--n", default=None, type=integer_list) 165 parser.add_argument("--bm", default="16", type=integer_list) 166 parser.add_argument("--bk", default=None, type=integer_list) 167 parser.add_argument("--tile_m", default=None, type=integer_list) 168 parser.add_argument("--tile_n", default=None, type=integer_list) 169 parser.add_argument("--split_n", default=None, type=integer_list) 170 parser.add_argument("--group_size", default=None, type=integer_list) 171 parser.add_argument("--num_warps", default=None, type=integer_list) 172 parser.add_argument("--num_stages", default=None, type=integer_list) 173 parser.add_argument("--sparsity", default="0.5", type=integer_or_float_list) 174 parser.add_argument("--dtype", default="float16", type=str) 175 parser.add_argument("--device", default="cuda", type=str) 176 parser.add_argument("--repeat", default="1", type=int) 177 parser.add_argument("--outfile", default="stdout", type=str) 178 parser.add_argument("--star", default=False, action="store_true") 179 180 args = parser.parse_args() 181 182 if args.outfile == "stdout": 183 outfile = sys.stdout 184 elif args.outfile == "stderr": 185 outfile = sys.stderr 186 else: 187 outfile = open(args.outfile, "a") 188 189 ops = args.ops.split(",") 190 191 b = args.b 192 193 m_list = args.m or [1024] 194 n_list = args.n or [None] 195 k_list = args.k or [None] 196 bm_list = args.bm or [16] 197 bk_list = args.bk or [None] 198 split_n_list = args.split_n or [None] 199 tile_m_list = args.tile_m or [None] 200 tile_n_list = args.tile_n or [None] 201 group_size_list = args.group_size or [None] 202 num_warps_list = args.num_warps or [None] 203 num_stages_list = args.num_stages or [None] 204 sparsity_list = args.sparsity or [0.5] 205 dtype = getattr(torch, args.dtype) 206 207 if args.star > 0: 208 import torch.sparse._triton_ops 209 210 assert {len(m_list), len(n_list), len(k_list), len(bm_list), len(bk_list)} == { 211 1 212 } 213 m = m_list[0] 214 n = n_list[0] or m 215 k = k_list[0] or m 216 bm = bm_list[0] 217 bk = bk_list[0] or bm 218 if "bsr_scatter_mm6" in ops: 219 meta = torch.sparse._triton_ops.scatter_mm_meta(m, k, n, bm, bk) 220 elif "bsr_dense_mm_with_meta" in ops: 221 meta = torch.sparse._triton_ops.bsr_dense_mm_meta(m, k, n, bm, bk) 222 else: 223 raise NotImplementedError(f"--star not implemented for operations in {ops}") 224 if "bsr_scatter_mm6" in ops: 225 if split_n_list[0] is None: 226 split_n_list = [ 227 meta["SPLIT_N"] // 2, 228 meta["SPLIT_N"], 229 meta["SPLIT_N"] * 2, 230 ][int(meta["SPLIT_N"] == 1) :] 231 elif split_n_list[0] == 0: 232 split_n_list = [meta["SPLIT_N"]] 233 if tile_m_list[0] is None: 234 tile_m_list = [meta["TILE_M"] // 2, meta["TILE_M"], meta["TILE_M"] * 2][ 235 int(meta["TILE_M"] == 16) : 236 ] 237 elif tile_m_list[0] == 0: 238 tile_m_list = [meta["TILE_M"]] 239 if tile_n_list[0] is None: 240 tile_n_list = [meta["TILE_N"] // 2, meta["TILE_N"], meta["TILE_N"] * 2][ 241 int(meta["TILE_N"] == 16) : 242 ] 243 elif tile_n_list[0] == 0: 244 tile_n_list = [meta["TILE_N"]] 245 if group_size_list[0] is None: 246 group_size_list = [ 247 meta["GROUP_SIZE"] - 1, 248 meta["GROUP_SIZE"], 249 meta["GROUP_SIZE"] + 1, 250 ][int(meta["GROUP_SIZE"] == 1) :] 251 elif group_size_list[0] == 0: 252 group_size_list = [meta["GROUP_SIZE"]] 253 if "bsr_dense_mm_with_meta" in ops: 254 if group_size_list[0] is None: 255 group_size_list = [ 256 meta["GROUP_SIZE_ROW"] - 1, 257 meta["GROUP_SIZE_ROW"], 258 meta["GROUP_SIZE_ROW"] + 1, 259 ][int(meta["GROUP_SIZE_ROW"] == 1) :] 260 elif group_size_list[0] == 0: 261 group_size_list = [meta["GROUP_SIZE_ROW"]] 262 if num_warps_list[0] is None: 263 num_warps_list = [ 264 meta["num_warps"] // 2, 265 meta["num_warps"], 266 meta["num_warps"] * 2, 267 ][int(meta["num_warps"] == 1) :] 268 elif num_warps_list[0] == 0: 269 num_warps_list = [meta["num_warps"]] 270 if num_stages_list[0] is None: 271 num_stages_list = [ 272 meta["num_stages"] - 1, 273 meta["num_stages"], 274 meta["num_stages"] + 1, 275 ][int(meta["num_stages"] == 1) :] 276 elif num_stages_list[0] == 0: 277 num_stages_list = [meta["num_stages"]] 278 279 device = args.device 280 dense_dense_mm_sizes = set() 281 target_performance = None 282 performance_rtol = 1e-2 283 284 best_messages = [] 285 286 @atexit.register 287 def show_best_messages(best_messages=best_messages): 288 print("TOP 10:") 289 for m in best_messages[-10:]: 290 print(m) 291 sys.stdout.flush() 292 293 for m, k, n, bm, bk, sparsity in itertools.product( 294 m_list, k_list, n_list, bm_list, bk_list, sparsity_list 295 ): 296 k = k or m 297 n = n or m 298 bk = bk or bm 299 300 if bm > m or bk > k: 301 # Skip invalid parameter combinations 302 continue 303 304 blocksize = (bm, bk) 305 306 if isinstance(sparsity, int): 307 # integer sparsity value corresponds to desired nnz value 308 sparsity = 1 - bk * bm * sparsity / (m * k) 309 310 if sparsity > 1 or sparsity < 0: 311 continue 312 313 x = create_blocked_tensor( 314 b, m, k, blocksize, sparsity, dtype, device 315 ).to_sparse_bsr(blocksize) 316 317 # recompute sparsity 318 sparsity = 1 - bk * bm * x._nnz() / (m * k) 319 320 y = make_tensor(k, n, dtype=dtype, device=device) 321 322 bsr_size = f"{b}x{m}x{k}" if b > 0 else f"{k}x{n}" 323 324 for op in ops: 325 if op == "dense_dense_mm": 326 if (m, k, n) in dense_dense_mm_sizes: 327 # Skip already benchmarked cases 328 continue 329 dense_dense_mm_sizes.add((m, k, n)) 330 best_tflops = 0 331 for ( 332 split_n, 333 num_warps, 334 num_stages, 335 tile_m, 336 tile_n, 337 group_size, 338 ) in itertools.product( 339 split_n_list, 340 num_warps_list, 341 num_stages_list, 342 tile_m_list, 343 tile_n_list, 344 group_size_list, 345 ): 346 if ( 347 (tile_m or 0) > bm 348 or (tile_n or 0) > n // (split_n or 1) 349 or n % (split_n or 1) != 0 350 or (split_n or 0) > n 351 ): 352 # Skip invalid parameter combinations 353 continue 354 test_func = globals()["test_" + op] 355 meta = dict( 356 bsr_scatter_mm6=dict( 357 SPLIT_N=split_n, 358 TILE_M=tile_m, 359 TILE_N=tile_n, 360 GROUP_SIZE=group_size, 361 num_stages=num_stages, 362 num_warps=num_warps, 363 ), 364 bsr_dense_mm_with_meta=dict( 365 GROUP_SIZE_ROW=group_size, 366 num_stages=num_stages, 367 num_warps=num_warps, 368 ), 369 ).get(op, {}) 370 371 meta_str = ";".join( 372 f"{k}={v}" for k, v in meta.items() if v is not None 373 ) 374 time_ms_lst = [] 375 performance_tflops_lst = [] 376 for r in range(args.repeat): 377 try: 378 time_ms, performance_tflops = test_func(x, y, **meta) 379 except triton.compiler.OutOfResources as msg: 380 print( 381 f"op={op}[{meta_str}]({bsr_size},{k}x{n}) dtype={args.dtype} {sparsity=}(nnz={x._nnz()})" 382 f" blocksize={bm}x{bk} OutOfResources", 383 file=outfile, 384 ) 385 continue 386 except AssertionError: 387 raise 388 except Exception as msg: 389 msg = str(msg).split("\n", 1)[0] 390 print( 391 f"op={op}[{meta_str}]({bsr_size},{k}x{n}) dtype={args.dtype} {sparsity=}(nnz={x._nnz()})" 392 f" blocksize={bm}x{bk} {msg}", 393 file=outfile, 394 ) 395 continue 396 time_ms_lst.append(time_ms) 397 performance_tflops_lst.append(performance_tflops) 398 mark = "" 399 if op == "dense_dense_mm": 400 if target_performance is None: 401 target_performance = performance_tflops 402 elif target_performance is not None: 403 if ( 404 abs(1 - performance_tflops / target_performance) 405 < performance_rtol 406 ): 407 mark += " @@@" 408 if best_tflops < performance_tflops: 409 best_tflops = performance_tflops 410 best_message = ( 411 f"op={op}[{meta_str}]({bsr_size},x{n}) dtype={args.dtype} {sparsity=:.4f}(nnz={x._nnz()})" 412 f" blocksize={bm}x{bk} time={time_ms:.3f} ms performance={performance_tflops:.3f} TFLOPS" 413 ) 414 if best_message not in best_messages: 415 best_messages.append(best_message) 416 mark += " !!!" 417 print( 418 f"op={op}[{meta_str}]({bsr_size},x{n}) dtype={args.dtype} {sparsity=:.4f}(nnz={x._nnz()})" 419 f" blocksize={bm}x{bk}" 420 f" time={time_ms:.3f} ms performance={performance_tflops:.3f} TFLOPS{mark}", 421 file=outfile, 422 ) 423 outfile.flush() 424 if args.repeat > 1: 425 avg_time_ms = sum(time_ms_lst) / len(time_ms_lst) 426 avg_performance_tflops = sum(performance_tflops_lst) / len( 427 performance_tflops_lst 428 ) 429 print( 430 f"op={op}[{meta_str}]({bsr_size},{k}x{n}) dtype={args.dtype} {sparsity=}(nnz={x._nnz()})" 431 f" blocksize={bm}x{bk}" 432 f" time={time_ms:.3f} ms performance={performance_tflops:.3f} TFLOPS [AVERAGE]", 433 file=outfile, 434 ) 435 outfile.flush() 436 if op not in {"bsr_scatter_mm6", "bsr_dense_mm_with_meta"}: 437 # Break on operations that do not consume parameters 438 break 439