"""Benchmark sparse (COO / CSR) x dense matrix multiplication (SpMM).

Each ``test_sparse_*`` function obtains an ``(m, k)`` sparse matrix with
``nnz`` stored entries from the ``utils.gen_sparse_*`` helpers, multiplies it
by a ``torch.randn`` dense ``(k, n)`` double matrix ``test_count`` times, and
returns the mean elapsed time as reported by ``utils.Event.elapsed_time``.
"""
import argparse
import contextlib
import sys

from utils import Event, gen_sparse_coo, gen_sparse_coo_and_csr, gen_sparse_csr

import torch


_FORMATS = ("csr", "coo", "both")


def _mean_matmul_time(sparse, mat, test_count):
    """Return the mean time of ``sparse.matmul(mat)`` over ``test_count`` runs.

    The unit is whatever ``utils.Event.elapsed_time`` reports.

    Raises:
        ValueError: if ``test_count`` is less than 1 (the mean is undefined).
    """
    if test_count < 1:
        raise ValueError(f"test_count must be >= 1, got {test_count}")

    start_timer = Event(enable_timing=True)
    stop_timer = Event(enable_timing=True)

    times = []
    for _ in range(test_count):
        start_timer.record()
        sparse.matmul(mat)
        stop_timer.record()
        times.append(start_timer.elapsed_time(stop_timer))

    return sum(times) / len(times)


def test_sparse_csr(m, n, k, nnz, test_count):
    """Return the mean time of a CSR ``(m, k)`` x dense ``(k, n)`` matmul."""
    csr = gen_sparse_csr((m, k), nnz)
    mat = torch.randn(k, n, dtype=torch.double)
    return _mean_matmul_time(csr, mat, test_count)


def test_sparse_coo(m, n, k, nnz, test_count):
    """Return the mean time of a COO ``(m, k)`` x dense ``(k, n)`` matmul."""
    coo = gen_sparse_coo((m, k), nnz)
    mat = torch.randn(k, n, dtype=torch.double)
    return _mean_matmul_time(coo, mat, test_count)


def test_sparse_coo_and_csr(m, n, k, nnz, test_count):
    """Time COO and CSR matmuls against the same dense ``(k, n)`` operand.

    Returns:
        ``(coo_mean_time, csr_mean_time)``.
    """
    coo, csr = gen_sparse_coo_and_csr((m, k), nnz)
    mat = torch.randn((k, n), dtype=torch.double)

    coo_mean_time = _mean_matmul_time(coo, mat, test_count)
    csr_mean_time = _mean_matmul_time(csr, mat, test_count)
    return coo_mean_time, csr_mean_time


def _parse_args(argv=None):
    """Parse command-line options (``argv=None`` means ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser(description="SpMM")

    # `choices` turns an unknown format into a clean usage error instead of
    # falling through with no result to report.
    parser.add_argument("--format", default="csr", type=str, choices=_FORMATS)
    parser.add_argument("--m", default=1000, type=int)
    parser.add_argument("--n", default=1000, type=int)
    parser.add_argument("--k", default=1000, type=int)
    parser.add_argument("--nnz-ratio", "--nnz_ratio", default=0.1, type=float)
    parser.add_argument("--outfile", default="stdout", type=str)
    parser.add_argument("--test-count", "--test_count", default=10, type=int)

    args = parser.parse_args(argv)
    if args.test_count < 1:
        parser.error("--test-count must be at least 1")
    return args


def _open_outfile(name):
    """Return a context manager yielding the output stream named *name*.

    ``"stdout"`` / ``"stderr"`` map to the process streams and are left open
    on exit; any other value is a path opened in append mode and closed on
    exit.
    """
    if name == "stdout":
        return contextlib.nullcontext(sys.stdout)
    if name == "stderr":
        return contextlib.nullcontext(sys.stderr)
    return open(name, "a")


def _print_result(format_fields, nnz_ratio, m, n, k, elapsed, outfile):
    """Write one result line; the field layout matches the historical output."""
    print(
        *format_fields,
        " nnz_ratio=",
        nnz_ratio,
        " m=",
        m,
        " n=",
        n,
        " k=",
        k,
        " time=",
        elapsed,
        file=outfile,
    )


def main(argv=None):
    """Run the benchmark selected on the command line and report the timing."""
    args = _parse_args(argv)

    test_count = args.test_count
    m, n, k = args.m, args.n, args.k
    nnz_ratio = args.nnz_ratio
    nnz = int(nnz_ratio * m * k)

    with _open_outfile(args.outfile) as outfile:
        if args.format == "both":
            time_coo, time_csr = test_sparse_coo_and_csr(m, n, k, nnz, test_count)
            _print_result(("format=coo",), nnz_ratio, m, n, k, time_coo, outfile)
            _print_result(("format=csr",), nnz_ratio, m, n, k, time_csr, outfile)
        else:
            benchmark = test_sparse_csr if args.format == "csr" else test_sparse_coo
            elapsed = benchmark(m, n, k, nnz, test_count)
            _print_result(
                ("format=", args.format), nnz_ratio, m, n, k, elapsed, outfile
            )


if __name__ == "__main__":
    main()