xref: /aosp_15_r20/external/pytorch/benchmarks/sparse/spmv.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport argparse
2*da0073e9SAndroid Build Coastguard Workerimport sys
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport torch
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerfrom .utils import Event, gen_sparse_coo, gen_sparse_coo_and_csr, gen_sparse_csr
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Worker
def test_sparse_csr(m, nnz, test_count):
    """Return the mean elapsed time of repeated CSR sparse matrix-vector products.

    A sparse matrix is requested from ``gen_sparse_csr`` with ``(m, m)`` and
    ``nnz``, and multiplied ``test_count`` times by a random double-precision
    vector of length ``m``.

    Args:
        m: Size passed for both matrix dimensions and used as the vector length.
        nnz: Non-zero count forwarded to ``gen_sparse_csr``.
        test_count: Number of timed ``matmul`` calls.

    Returns:
        The arithmetic mean of the per-call values returned by
        ``Event.elapsed_time`` (unit is whatever that helper reports).

    Raises:
        ZeroDivisionError: If no iterations are run (``test_count`` <= 0).
    """
    begin = Event(enable_timing=True)
    end = Event(enable_timing=True)

    # Keep the generation order (matrix first, then vector) so the random
    # number stream is consumed exactly as before.
    matrix = gen_sparse_csr((m, m), nnz)
    dense_vec = torch.randn(m, dtype=torch.double)

    def timed_product():
        # Only the matmul sits between the two record() calls.
        begin.record()
        matrix.matmul(dense_vec)
        end.record()
        return begin.elapsed_time(end)

    samples = [timed_product() for _ in range(test_count)]
    return sum(samples) / len(samples)
24*da0073e9SAndroid Build Coastguard Worker
25*da0073e9SAndroid Build Coastguard Worker
def test_sparse_coo(m, nnz, test_count):
    """Return the mean elapsed time of repeated COO sparse matrix-vector products.

    A sparse matrix is requested from ``gen_sparse_coo`` with ``(m, m)`` and
    ``nnz``, and multiplied ``test_count`` times by a random double-precision
    vector of length ``m``.

    Args:
        m: Size passed for both matrix dimensions and used as the vector length.
        nnz: Non-zero count forwarded to ``gen_sparse_coo``.
        test_count: Number of timed ``matmul`` calls.

    Returns:
        The arithmetic mean of the per-call values returned by
        ``Event.elapsed_time`` (unit is whatever that helper reports).

    Raises:
        ZeroDivisionError: If no iterations are run (``test_count`` <= 0).
    """
    tic = Event(enable_timing=True)
    toc = Event(enable_timing=True)

    # Matrix before vector: preserves the order in which randomness is drawn.
    sparse_matrix = gen_sparse_coo((m, m), nnz)
    rhs = torch.randn(m, dtype=torch.double)

    def iter_durations():
        # Yield one elapsed-time sample per timed product.
        for _ in range(test_count):
            tic.record()
            sparse_matrix.matmul(rhs)
            toc.record()
            yield tic.elapsed_time(toc)

    durations = list(iter_durations())
    return sum(durations) / len(durations)
41*da0073e9SAndroid Build Coastguard Worker
42*da0073e9SAndroid Build Coastguard Worker
def test_sparse_coo_and_csr(m, nnz, test_count):
    """Time matrix-vector products for a COO/CSR pair from one generator call.

    ``gen_sparse_coo_and_csr`` is called once with ``(m, m)`` and ``nnz``; both
    returned matrices are multiplied by the same random double-precision
    vector of length ``m``. The COO matrix is benchmarked first, then the CSR
    matrix, each for ``test_count`` iterations.

    Args:
        m: Size passed for both matrix dimensions and used as the vector length.
        nnz: Non-zero count forwarded to ``gen_sparse_coo_and_csr``.
        test_count: Number of timed ``matmul`` calls per format.

    Returns:
        A ``(coo_mean_time, csr_mean_time)`` tuple of mean values returned by
        ``Event.elapsed_time`` (unit is whatever that helper reports).

    Raises:
        ZeroDivisionError: If no iterations are run (``test_count`` <= 0).
    """
    begin = Event(enable_timing=True)
    end = Event(enable_timing=True)

    coo_matrix, csr_matrix = gen_sparse_coo_and_csr((m, m), nnz)
    dense_vec = torch.randn(m, dtype=torch.double)

    def mean_elapsed(matrix):
        # Shared timing loop; both formats reuse the same pair of events.
        samples = []
        for _ in range(test_count):
            begin.record()
            matrix.matmul(dense_vec)
            end.record()
            samples.append(begin.elapsed_time(end))
        return sum(samples) / len(samples)

    # COO is measured before CSR, matching the order of the returned tuple.
    coo_mean = mean_elapsed(coo_matrix)
    csr_mean = mean_elapsed(csr_matrix)
    return coo_mean, csr_mean
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Command-line driver: parse the benchmark parameters, run the selected
    # SpMV benchmark(s), and print one result line per format to the chosen
    # destination ("stdout", "stderr", or a file path opened in append mode).
    parser = argparse.ArgumentParser(description="SpMV")

    # Restricting the choices makes an unknown format a clear usage error
    # instead of a NameError when the result is printed.
    parser.add_argument(
        "--format", default="csr", type=str, choices=("csr", "coo", "both")
    )
    parser.add_argument("--m", default=1000, type=int)
    parser.add_argument("--nnz-ratio", "--nnz_ratio", default=0.1, type=float)
    parser.add_argument("--outfile", default="stdout", type=str)
    parser.add_argument("--test-count", "--test_count", default=10, type=int)

    args = parser.parse_args()

    # Only a file opened here is ours to close; the standard streams are not.
    owns_outfile = False
    if args.outfile == "stdout":
        outfile = sys.stdout
    elif args.outfile == "stderr":
        outfile = sys.stderr
    else:
        outfile = open(args.outfile, "a")
        owns_outfile = True

    try:
        test_count = args.test_count
        m = args.m
        nnz_ratio = args.nnz_ratio

        # Requested number of non-zeros for an m x m matrix at this ratio.
        nnz = int(nnz_ratio * m * m)

        # NOTE: the single-format line prints "format= <name>" while the
        # "both" lines print "format=<name>"; the existing output text is kept
        # as-is so anything parsing these lines keeps working.
        if args.format == "both":
            time_coo, time_csr = test_sparse_coo_and_csr(m, nnz, test_count)
            print(
                "format=coo",
                " nnz_ratio=",
                nnz_ratio,
                " m=",
                m,
                " time=",
                time_coo,
                file=outfile,
            )
            print(
                "format=csr",
                " nnz_ratio=",
                nnz_ratio,
                " m=",
                m,
                " time=",
                time_csr,
                file=outfile,
            )
        else:
            if args.format == "csr":
                elapsed = test_sparse_csr(m, nnz, test_count)
            else:
                elapsed = test_sparse_coo(m, nnz, test_count)
            print(
                "format=",
                args.format,
                " nnz_ratio=",
                nnz_ratio,
                " m=",
                m,
                " time=",
                elapsed,
                file=outfile,
            )
    finally:
        # Flush and release the output file even if the benchmark raised.
        if owns_outfile:
            outfile.close()
135