"""SpMV micro-benchmark: time sparse (COO / CSR) matrix x dense vector ``matmul``.

Run as a module, e.g. ``python -m <package>.spmv --format both --m 1000``.
Results are printed one line per format to stdout, stderr, or an appended file.
"""
import argparse
import contextlib
import sys

import torch

from .utils import Event, gen_sparse_coo, gen_sparse_coo_and_csr, gen_sparse_csr


def _validate_test_count(test_count):
    """Raise ``ValueError`` unless ``test_count`` is a positive number of repetitions.

    Checked before any matrix is generated so a bad value fails fast instead of
    dividing by zero after the (possibly expensive) setup.
    """
    if test_count < 1:
        raise ValueError(f"test_count must be >= 1, got {test_count}")


def _mean_matmul_time(matrix, vector, test_count):
    """Return the mean of ``test_count`` timings of ``matrix.matmul(vector)``.

    Each timing is ``start.elapsed_time(stop)`` from a pair of ``.utils.Event``
    objects bracketing one call. The time unit is whatever ``Event.elapsed_time``
    reports (not visible in this file).
    """
    start = Event(enable_timing=True)
    stop = Event(enable_timing=True)
    times = []
    for _ in range(test_count):
        start.record()
        matrix.matmul(vector)
        stop.record()
        times.append(start.elapsed_time(stop))
    return sum(times) / len(times)


def test_sparse_csr(m, nnz, test_count):
    """Benchmark an ``(m, m)`` CSR matrix times a length-``m`` float64 vector.

    Args:
        m: Side length of the square sparse matrix and length of the vector.
        nnz: Passed to ``gen_sparse_csr``; presumably the number of stored
            non-zeros (TODO confirm against ``.utils``).
        test_count: Number of timed ``matmul`` calls to average; must be >= 1.

    Returns:
        Mean elapsed time per call, in ``Event.elapsed_time`` units.

    Raises:
        ValueError: If ``test_count`` < 1.
    """
    _validate_test_count(test_count)
    csr = gen_sparse_csr((m, m), nnz)
    vector = torch.randn(m, dtype=torch.double)
    return _mean_matmul_time(csr, vector, test_count)


def test_sparse_coo(m, nnz, test_count):
    """Benchmark an ``(m, m)`` COO matrix times a length-``m`` float64 vector.

    Args and return value are as for ``test_sparse_csr``, but the matrix comes
    from ``gen_sparse_coo``.

    Raises:
        ValueError: If ``test_count`` < 1.
    """
    _validate_test_count(test_count)
    coo = gen_sparse_coo((m, m), nnz)
    vector = torch.randn(m, dtype=torch.double)
    return _mean_matmul_time(coo, vector, test_count)


def test_sparse_coo_and_csr(m, nnz, test_count):
    """Benchmark COO and CSR matrices from a single ``gen_sparse_coo_and_csr`` call.

    Both formats multiply the same float64 vector, so their timings are directly
    comparable. The COO timings run first, then the CSR timings.

    Returns:
        ``(coo_mean_time, csr_mean_time)`` in ``Event.elapsed_time`` units.

    Raises:
        ValueError: If ``test_count`` < 1.
    """
    _validate_test_count(test_count)
    coo, csr = gen_sparse_coo_and_csr((m, m), nnz)
    vector = torch.randn(m, dtype=torch.double)
    coo_mean_time = _mean_matmul_time(coo, vector, test_count)
    csr_mean_time = _mean_matmul_time(csr, vector, test_count)
    return coo_mean_time, csr_mean_time


def _parse_args(argv=None):
    """Parse the command line; invalid ``--format``/``--test-count`` exit with a usage error."""
    parser = argparse.ArgumentParser(description="SpMV")

    # ``choices`` rejects unknown formats up front; previously they fell through
    # every branch and crashed later with a NameError on the result variable.
    parser.add_argument("--format", default="csr", type=str, choices=("csr", "coo", "both"))
    parser.add_argument("--m", default=1000, type=int)
    parser.add_argument("--nnz-ratio", "--nnz_ratio", default=0.1, type=float)
    parser.add_argument("--outfile", default="stdout", type=str)
    parser.add_argument("--test-count", "--test_count", default=10, type=int)

    args = parser.parse_args(argv)
    if args.test_count < 1:
        parser.error("--test-count must be at least 1")
    return args


def _open_outfile(name):
    """Return a context manager yielding the output stream for ``--outfile``.

    ``"stdout"``/``"stderr"`` map to the process streams, which must not be closed.
    Any other value is a path opened in append mode and closed on exit.
    """
    if name == "stdout":
        return contextlib.nullcontext(sys.stdout)
    if name == "stderr":
        return contextlib.nullcontext(sys.stderr)
    return open(name, "a")


def _report(outfile, format_fields, nnz_ratio, m, elapsed):
    """Print one result line.

    ``format_fields`` is spliced into ``print`` as separate arguments so the
    output spacing matches the historical format exactly ("format= csr" for the
    single-format runs, "format=coo" / "format=csr" for ``--format both``).
    """
    print(
        *format_fields,
        " nnz_ratio=",
        nnz_ratio,
        " m=",
        m,
        " time=",
        elapsed,
        file=outfile,
    )


def main(argv=None):
    """Command-line entry point: run the selected benchmark(s) and report mean times."""
    args = _parse_args(argv)
    m = args.m
    nnz_ratio = args.nnz_ratio
    test_count = args.test_count
    nnz = int(nnz_ratio * m * m)

    # Opened before benchmarking so a bad output path fails fast. The ``with``
    # closes a real file afterwards (previously it was never closed).
    with _open_outfile(args.outfile) as outfile:
        if args.format == "both":
            time_coo, time_csr = test_sparse_coo_and_csr(m, nnz, test_count)
            _report(outfile, ("format=coo",), nnz_ratio, m, time_coo)
            _report(outfile, ("format=csr",), nnz_ratio, m, time_csr)
        else:
            benchmark = test_sparse_csr if args.format == "csr" else test_sparse_coo
            elapsed = benchmark(m, nnz, test_count)
            _report(outfile, ("format=", args.format), nnz_ratio, m, elapsed)


if __name__ == "__main__":
    main()