xref: /aosp_15_r20/external/pytorch/benchmarks/sparse/spmv.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import argparse
2import sys
3
4import torch
5
6from .utils import Event, gen_sparse_coo, gen_sparse_coo_and_csr, gen_sparse_csr
7
8
def test_sparse_csr(m, nnz, test_count):
    """Return the mean measured time of a CSR matrix-vector product.

    Builds an ``m`` x ``m`` matrix with ``gen_sparse_csr`` (requested ``nnz``
    non-zeros) and a random double-precision vector of length ``m``, then
    times ``matrix.matmul(vector)`` ``test_count`` times.

    Args:
        m: Side length of the square matrix and length of the vector.
        nnz: Non-zero count passed to ``gen_sparse_csr``.
        test_count: Number of timed repetitions. A value that yields no
            repetitions makes the final division raise ``ZeroDivisionError``.

    Returns:
        The arithmetic mean of the per-run values reported by
        ``Event.elapsed_time`` (units are whatever ``Event`` from ``.utils``
        reports -- not visible from this file).
    """
    begin = Event(enable_timing=True)
    end = Event(enable_timing=True)

    matrix = gen_sparse_csr((m, m), nnz)
    rhs = torch.randn(m, dtype=torch.double)

    def timed_run():
        # One measured multiplication; the product itself is discarded.
        begin.record()
        matrix.matmul(rhs)
        end.record()
        return begin.elapsed_time(end)

    samples = [timed_run() for _ in range(test_count)]
    return sum(samples) / len(samples)
24
25
def test_sparse_coo(m, nnz, test_count):
    """Return the mean measured time of a COO matrix-vector product.

    Builds an ``m`` x ``m`` matrix with ``gen_sparse_coo`` (requested ``nnz``
    non-zeros) and a random double-precision vector of length ``m``, then
    times ``matrix.matmul(vector)`` ``test_count`` times.

    Args:
        m: Side length of the square matrix and length of the vector.
        nnz: Non-zero count passed to ``gen_sparse_coo``.
        test_count: Number of timed repetitions. A value that yields no
            repetitions makes the final division raise ``ZeroDivisionError``.

    Returns:
        The arithmetic mean of the per-run values reported by
        ``Event.elapsed_time`` (units are whatever ``Event`` from ``.utils``
        reports -- not visible from this file).
    """
    begin = Event(enable_timing=True)
    end = Event(enable_timing=True)

    matrix = gen_sparse_coo((m, m), nnz)
    rhs = torch.randn(m, dtype=torch.double)

    def timed_run():
        # One measured multiplication; the product itself is discarded.
        begin.record()
        matrix.matmul(rhs)
        end.record()
        return begin.elapsed_time(end)

    samples = [timed_run() for _ in range(test_count)]
    return sum(samples) / len(samples)
41
42
def test_sparse_coo_and_csr(m, nnz, test_count):
    """Return the mean measured matmul times for a COO and a CSR matrix.

    ``gen_sparse_coo_and_csr`` supplies the pair of ``m`` x ``m`` matrices
    (requested ``nnz`` non-zeros); both are multiplied by the same random
    double-precision vector of length ``m``. All COO runs are timed first,
    followed by all CSR runs, reusing one pair of ``Event`` timers.

    Args:
        m: Side length of the square matrices and length of the vector.
        nnz: Non-zero count passed to ``gen_sparse_coo_and_csr``.
        test_count: Number of timed repetitions per format. A value that
            yields no repetitions makes the averaging raise
            ``ZeroDivisionError``.

    Returns:
        A ``(coo_mean_time, csr_mean_time)`` tuple of arithmetic means of the
        values reported by ``Event.elapsed_time``.
    """
    timer_begin = Event(enable_timing=True)
    timer_end = Event(enable_timing=True)

    coo, csr = gen_sparse_coo_and_csr((m, m), nnz)
    vector = torch.randn(m, dtype=torch.double)

    def average_matmul_time(matrix):
        # Time ``test_count`` multiplications of ``matrix`` by the shared vector.
        samples = []
        for _ in range(test_count):
            timer_begin.record()
            matrix.matmul(vector)
            timer_end.record()
            samples.append(timer_begin.elapsed_time(timer_end))
        return sum(samples) / len(samples)

    # Tuple elements are evaluated left to right, so COO is measured before CSR.
    return average_matmul_time(coo), average_matmul_time(csr)
70
71
if __name__ == "__main__":
    # Command-line driver: parse the benchmark parameters, run the selected
    # SpMV benchmark(s), and print one result line per sparse format.
    parser = argparse.ArgumentParser(description="SpMV")

    # Restricting --format to the supported values lets argparse reject a typo
    # up front; previously an unknown format skipped every benchmark branch
    # and crashed with a NameError when the result was printed.
    parser.add_argument(
        "--format", default="csr", type=str, choices=("csr", "coo", "both")
    )
    parser.add_argument("--m", default=1000, type=int)
    parser.add_argument("--nnz-ratio", "--nnz_ratio", default=0.1, type=float)
    parser.add_argument("--outfile", default="stdout", type=str)
    parser.add_argument("--test-count", "--test_count", default=10, type=int)

    args = parser.parse_args()

    # The benchmark functions average over the collected runs, so fewer than
    # one run would end in a division by zero; report it as a usage error.
    if args.test_count < 1:
        parser.error("--test-count must be at least 1")

    test_count = args.test_count
    m = args.m
    nnz_ratio = args.nnz_ratio
    nnz = int(nnz_ratio * m * m)

    # "stdout" and "stderr" select the standard streams; anything else is
    # treated as a path that results are appended to. Only a file opened here
    # is closed afterwards -- the standard streams are left open.
    if args.outfile == "stdout":
        outfile = sys.stdout
        owns_outfile = False
    elif args.outfile == "stderr":
        outfile = sys.stderr
        owns_outfile = False
    else:
        outfile = open(args.outfile, "a")
        owns_outfile = True

    try:
        if args.format == "both":
            time_coo, time_csr = test_sparse_coo_and_csr(m, nnz, test_count)
            # NOTE: the "both" lines print "format=coo" / "format=csr" with no
            # space after "=", unlike the single-format line below. The output
            # is kept exactly as it was so existing consumers are unaffected.
            print(
                "format=coo",
                " nnz_ratio=",
                nnz_ratio,
                " m=",
                m,
                " time=",
                time_coo,
                file=outfile,
            )
            print(
                "format=csr",
                " nnz_ratio=",
                nnz_ratio,
                " m=",
                m,
                " time=",
                time_csr,
                file=outfile,
            )
        else:
            if args.format == "csr":
                elapsed = test_sparse_csr(m, nnz, test_count)
            else:  # "coo" -- the only remaining value permitted by `choices`
                elapsed = test_sparse_coo(m, nnz, test_count)
            print(
                "format=",
                args.format,
                " nnz_ratio=",
                nnz_ratio,
                " m=",
                m,
                " time=",
                elapsed,
                file=outfile,
            )
    finally:
        # Release the file handle even if a benchmark raises.
        if owns_outfile:
            outfile.close()
135