"""Benchmark sparse matrix-vector multiplication (SpMV) in COO and CSR formats.

Run as a script. Each run writes one result line per benchmarked format to
stdout, stderr, or a file opened in append mode (``--outfile``).
"""
import argparse
import sys

import torch

from .utils import Event, gen_sparse_coo, gen_sparse_coo_and_csr, gen_sparse_csr


def _mean_matmul_time(matrix, vector, test_count):
    """Return the mean elapsed time of ``matrix.matmul(vector)`` over ``test_count`` runs.

    The unit is whatever ``Event.elapsed_time`` reports. ``Event`` is defined
    in ``.utils`` and is not visible here.

    Raises:
        ValueError: if ``test_count`` is less than 1 (the mean is undefined).
    """
    if test_count < 1:
        raise ValueError(f"test_count must be at least 1, got {test_count}")

    # NOTE(review): if .utils.Event wraps torch.cuda.Event, elapsed_time()
    # needs the stop event to have completed (e.g. stop_timer.synchronize())
    # before it is read. Confirm against the utils implementation.
    start_timer = Event(enable_timing=True)
    stop_timer = Event(enable_timing=True)

    times = []
    for _ in range(test_count):
        start_timer.record()
        matrix.matmul(vector)
        stop_timer.record()
        times.append(start_timer.elapsed_time(stop_timer))

    return sum(times) / len(times)


def test_sparse_csr(m, nnz, test_count):
    """Time CSR SpMV for an ``(m, m)`` matrix from ``gen_sparse_csr``.

    Args:
        m: Matrix side length; also the length of the dense float64 vector.
        nnz: Non-zero count, passed straight to ``gen_sparse_csr``.
        test_count: Number of timed ``matmul`` calls (must be >= 1).

    Returns:
        The mean time per ``matmul`` call.
    """
    csr = gen_sparse_csr((m, m), nnz)
    vector = torch.randn(m, dtype=torch.double)
    return _mean_matmul_time(csr, vector, test_count)


def test_sparse_coo(m, nnz, test_count):
    """Time COO SpMV for an ``(m, m)`` matrix from ``gen_sparse_coo``.

    Args:
        m: Matrix side length; also the length of the dense float64 vector.
        nnz: Non-zero count, passed straight to ``gen_sparse_coo``.
        test_count: Number of timed ``matmul`` calls (must be >= 1).

    Returns:
        The mean time per ``matmul`` call.
    """
    coo = gen_sparse_coo((m, m), nnz)
    vector = torch.randn(m, dtype=torch.double)
    return _mean_matmul_time(coo, vector, test_count)


def test_sparse_coo_and_csr(m, nnz, test_count):
    """Time COO and CSR SpMV on the pair returned by ``gen_sparse_coo_and_csr``.

    Both formats are multiplied by the same dense float64 vector. Presumably
    the two matrices hold the same entries; confirm in ``.utils``.

    Args:
        m: Matrix side length; also the length of the dense vector.
        nnz: Non-zero count, passed straight to ``gen_sparse_coo_and_csr``.
        test_count: Number of timed ``matmul`` calls per format (must be >= 1).

    Returns:
        A tuple ``(coo_mean_time, csr_mean_time)``.
    """
    coo, csr = gen_sparse_coo_and_csr((m, m), nnz)
    vector = torch.randn(m, dtype=torch.double)

    coo_mean_time = _mean_matmul_time(coo, vector, test_count)
    csr_mean_time = _mean_matmul_time(csr, vector, test_count)

    return coo_mean_time, csr_mean_time


def main(argv=None):
    """Parse CLI arguments, run the selected benchmark(s), and write the results.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="SpMV")

    # `choices` rejects unknown formats up front. Previously an unknown value
    # left the timing result unbound and crashed later with a NameError.
    parser.add_argument(
        "--format", default="csr", type=str, choices=("csr", "coo", "both")
    )
    parser.add_argument("--m", default=1000, type=int)
    parser.add_argument("--nnz-ratio", "--nnz_ratio", default=0.1, type=float)
    parser.add_argument("--outfile", default="stdout", type=str)
    parser.add_argument("--test-count", "--test_count", default=10, type=int)

    args = parser.parse_args(argv)

    if args.test_count < 1:
        parser.error("--test-count must be at least 1")

    # Open the output before benchmarking, so a bad path fails fast. Only a
    # file opened here is closed at the end; stdout and stderr are left alone.
    if args.outfile == "stdout":
        outfile, owns_outfile = sys.stdout, False
    elif args.outfile == "stderr":
        outfile, owns_outfile = sys.stderr, False
    else:
        outfile, owns_outfile = open(args.outfile, "a", encoding="utf-8"), True

    try:
        test_count = args.test_count
        m = args.m
        nnz_ratio = args.nnz_ratio
        nnz = int(nnz_ratio * m * m)

        # Each result is (leading print tokens, mean time). The leading tokens
        # reproduce the historical output exactly: "format= csr" for a single
        # format vs "format=coo" in "both" mode. Existing result files may be
        # parsed with that layout.
        if args.format == "both":
            time_coo, time_csr = test_sparse_coo_and_csr(m, nnz, test_count)
            results = [(("format=coo",), time_coo), (("format=csr",), time_csr)]
        else:
            bench = test_sparse_csr if args.format == "csr" else test_sparse_coo
            results = [(("format=", args.format), bench(m, nnz, test_count))]

        for label_tokens, mean_time in results:
            print(
                *label_tokens,
                " nnz_ratio=",
                nnz_ratio,
                " m=",
                m,
                " time=",
                mean_time,
                file=outfile,
            )
    finally:
        if owns_outfile:
            outfile.close()


if __name__ == "__main__":
    main()