xref: /aosp_15_r20/external/pytorch/benchmarks/sparse/spmm.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import argparse
2import sys
3
4from utils import Event, gen_sparse_coo, gen_sparse_coo_and_csr, gen_sparse_csr
5
6import torch
7
8
def test_sparse_csr(m, n, k, nnz, test_count):
    """Time sparse-CSR x dense matrix multiplication.

    Args:
        m, k: shape tuple ``(m, k)`` handed to ``gen_sparse_csr`` for the
            sparse operand.
        n: number of columns of the dense ``(k, n)`` float64 operand drawn
            from ``torch.randn``.
        nnz: count forwarded to ``gen_sparse_csr`` (presumably the number of
            stored non-zeros -- see ``utils``).
        test_count: number of timed ``matmul`` calls.

    Returns:
        The arithmetic mean of the per-call ``Event.elapsed_time`` readings
        (units are whatever ``Event`` reports).

    Raises:
        ZeroDivisionError: if ``test_count`` <= 0, since there are no samples
            to average.
    """
    begin = Event(enable_timing=True)
    end = Event(enable_timing=True)

    # Build operands in the same order as always (generator first, then the
    # dense matrix) so RNG consumption is unchanged.
    sparse = gen_sparse_csr((m, k), nnz)
    dense = torch.randn(k, n, dtype=torch.double)

    samples = []
    for _run in range(test_count):
        begin.record()
        sparse.matmul(dense)
        end.record()
        samples.append(begin.elapsed_time(end))

    total = sum(samples)
    return total / len(samples)
24
25
def test_sparse_coo(m, n, k, nnz, test_count):
    """Time sparse-COO x dense matrix multiplication.

    Args:
        m, k: shape tuple ``(m, k)`` handed to ``gen_sparse_coo`` for the
            sparse operand.
        n: number of columns of the dense ``(k, n)`` float64 operand drawn
            from ``torch.randn``.
        nnz: count forwarded to ``gen_sparse_coo`` (presumably the number of
            stored non-zeros -- see ``utils``).
        test_count: number of timed ``matmul`` calls.

    Returns:
        The arithmetic mean of the per-call ``Event.elapsed_time`` readings
        (units are whatever ``Event`` reports).

    Raises:
        ZeroDivisionError: if ``test_count`` <= 0, since there are no samples
            to average.
    """
    tic = Event(enable_timing=True)
    toc = Event(enable_timing=True)

    sparse = gen_sparse_coo((m, k), nnz)
    dense = torch.randn(k, n, dtype=torch.double)

    def timed_matmul():
        # One measured multiplication bracketed by the two events.
        tic.record()
        sparse.matmul(dense)
        toc.record()
        return tic.elapsed_time(toc)

    durations = [timed_matmul() for _ in range(test_count)]
    return sum(durations) / len(durations)
41
42
def test_sparse_coo_and_csr(m, n, k, nnz, test_count):
    """Time COO and CSR sparse x dense matrix multiplication side by side.

    Both sparse operands come from a single ``gen_sparse_coo_and_csr((m, k),
    nnz)`` call (presumably the same matrix in both layouts -- see ``utils``)
    and are multiplied by the same dense ``(k, n)`` float64 tensor.  The COO
    runs are timed first, then the CSR runs, each in its own independent
    series of ``test_count`` calls.

    Args:
        m, k: shape tuple ``(m, k)`` handed to the sparse generator.
        n: number of columns of the dense operand.
        nnz: count forwarded to the generator (presumably stored non-zeros).
        test_count: number of timed ``matmul`` calls per format.

    Returns:
        ``(coo_mean_time, csr_mean_time)``: the mean ``Event.elapsed_time``
        reading for each format (units are whatever ``Event`` reports).

    Raises:
        ValueError: if ``test_count`` < 1 (no samples could be averaged).
    """
    if test_count < 1:
        raise ValueError(f"test_count must be at least 1, got {test_count}")

    start = Event(enable_timing=True)
    stop = Event(enable_timing=True)

    coo, csr = gen_sparse_coo_and_csr((m, k), nnz)
    mat = torch.randn((k, n), dtype=torch.double)

    def _mean_matmul_time(sparse):
        # Each format gets a fresh sample list so the two series never mix.
        times = []
        for _ in range(test_count):
            start.record()
            sparse.matmul(mat)
            stop.record()
            times.append(start.elapsed_time(stop))
        return sum(times) / len(times)

    coo_mean_time = _mean_matmul_time(coo)
    csr_mean_time = _mean_matmul_time(csr)
    return coo_mean_time, csr_mean_time
70
71
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="SpMM")

    # Defaults are written as values of the declared ``type``; argparse would
    # convert string defaults anyway, so the parsed values are unchanged.
    # ``choices`` rejects unknown formats up front instead of failing later.
    parser.add_argument(
        "--format", default="csr", type=str, choices=["csr", "coo", "both"]
    )
    parser.add_argument("--m", default=1000, type=int)
    parser.add_argument("--n", default=1000, type=int)
    parser.add_argument("--k", default=1000, type=int)
    parser.add_argument("--nnz-ratio", "--nnz_ratio", default=0.1, type=float)
    parser.add_argument("--outfile", default="stdout", type=str)
    parser.add_argument("--test-count", "--test_count", default=10, type=int)

    args = parser.parse_args()

    # "stdout" / "stderr" select the standard streams; any other value is a
    # path that results are appended to.
    if args.outfile == "stdout":
        outfile = sys.stdout
    elif args.outfile == "stderr":
        outfile = sys.stderr
    else:
        outfile = open(args.outfile, "a")

    test_count = args.test_count
    m = args.m
    n = args.n
    k = args.k
    nnz_ratio = args.nnz_ratio

    # Requested element count for the (m, k) sparse operand.
    nnz = int(nnz_ratio * m * k)

    def report(*format_fields, mean_time):
        """Write one result line; field layout matches the historical output."""
        print(
            *format_fields,
            " nnz_ratio=",
            nnz_ratio,
            " m=",
            m,
            " n=",
            n,
            " k=",
            k,
            " time=",
            mean_time,
            file=outfile,
        )

    try:
        if args.format == "both":
            time_coo, time_csr = test_sparse_coo_and_csr(m, n, k, nnz, test_count)
            report("format=coo", mean_time=time_coo)
            report("format=csr", mean_time=time_csr)
        else:
            runner = {"csr": test_sparse_csr, "coo": test_sparse_coo}[args.format]
            report(
                "format=",
                args.format,
                mean_time=runner(m, n, k, nnz, test_count),
            )
    finally:
        # Close only a file this script opened, never the standard streams.
        if outfile is not sys.stdout and outfile is not sys.stderr:
            outfile.close()
151