# Sparse benchmarks

# This benchmark measures sparse matmul performance.
# It compares the sparse matrix routines `sparse @ vector`, `sparse @ sparse`
# and `sparse @ dense` across backends (CPU/CUDA) and against other
# frameworks such as scipy.

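# Example invocation (illustrative only; the dataset location is a placeholder, and the
# command assumes it is run from the `benchmarks/sparse` directory so that the relative
# `.utils` import resolves):
#   python -m dlmc.matmul_bench --path /path/to/dlmc --dataset magnitude_pruning \
#       --operation sparse@sparse --hidden-size 2048 --with-cuda
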
import argparse
import os
import sys

from scipy.sparse import isspmatrix

import torch
import torch.utils.benchmark as benchmark_utils

from .utils import load_dlmc_dataset


# Multiply two scipy matrices; when both operands are sparse, return the product in COO format.
def scipy_matmul(mat1, mat2):
    if isspmatrix(mat1) and isspmatrix(mat2):
        return mat1.dot(mat2).tocoo()
    return mat1.dot(mat2)


# Dense matmul plus backward pass (dense baseline for the backward test).
def matmul_backward(a_dense, b_dense, grad_output):
    r1 = a_dense.matmul(b_dense)
    r1.backward(grad_output)


# Sparse matmul plus backward pass.
def sparse_matmul_backward(a, b, grad_output):
    c = torch.sparse.mm(a, b)
    c.backward(grad_output)


OPS_MAP = {
    "sparse@sparse": "torch.sparse.mm",
    "sparse@dense": "torch.matmul",
    "sparse@vector": "torch.matmul",
}
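# The values above are interpolated into the Timer statements built by get_tasks();
# e.g. the "sparse@dense" operation runs "torch.matmul(x, y)" for the sparse operand
# and "torch.matmul(dx, dy)" for its dense baseline.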


# Parse the benchmark configuration from the command line using `argparse`.
def parse_args():
    parser = argparse.ArgumentParser(description="matmul benchmark")
    parser.add_argument("--path", type=str, help="DLMC dataset path")
    parser.add_argument("--dataset", type=str, default="magnitude_pruning")
    parser.add_argument("--hidden-size", "--hidden_size", default=2048, type=int)
    parser.add_argument("--backward-test", "--backward_test", action="store_true")
    parser.add_argument(
        "--operation",
        type=str,
        help="|".join(OPS_MAP.keys()),
        default=next(iter(OPS_MAP)),
    )
    parser.add_argument("--with-cuda", "--with_cuda", action="store_true")
    parser.add_argument(
        "--timer-min-run-time", "--timer_min_run_time", default=1, type=float
    )
    return parser


# Build the (label, device, sub_label, stmt) tuples benchmarked for `op`:
# forward tests compare the torch sparse kernel against its dense counterpart
# (and against scipy on CPU); backward tests compare the sparse and dense
# backward passes.
def get_tasks(op, backward_test, device):
    def filter_ops(operation):
        if backward_test:
            test_name = device + ":matmul-backward"
            return [
                (
                    test_name,
                    device,
                    "torch:" + operation.replace("sparse", "dense"),
                    "matmul_backward(dx, dy, grad_output)",
                ),
                (
                    test_name,
                    device,
                    "torch:" + operation,
                    "sparse_matmul_backward(x, y, sparse_grad_output)",
                ),
            ]
        else:
            test_name = device + ":matmul-forward"
            return list(
                filter(
                    None,
                    [
                        (
                            test_name,
                            device,
                            "torch:" + operation.replace("sparse", "dense"),
                            f"{OPS_MAP[operation]}(dx, dy)",
                        ),
                        (
                            test_name,
                            device,
                            "torch:" + operation,
                            f"{OPS_MAP[operation]}(x, y)",
                        ),
                        (
                            test_name,
                            device,
                            "scipy:" + operation,
                            "scipy_matmul(sx, sy)",
                        )
                        if device == "cpu"
                        else None,
                    ],
                )
            )

    all_operations = {
        "sparse@sparse": filter_ops("sparse@sparse"),
        "sparse@dense": filter_ops("sparse@dense"),
        "sparse@vector": filter_ops("sparse@vector"),
    }
    return all_operations[op]


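# Benchmark driver: builds one Timer per (sparsity, task, dataset sample) combination,
# runs each timer `repeats` times with blocked_autorange, and prints a Compare table.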
if __name__ == "__main__":
    parser = parse_args()
    args = parser.parse_args()

    if args.with_cuda and not torch.cuda.is_available():
        raise RuntimeError("No CUDA available")

    dataset_path = args.path
    dataset_name = args.dataset
    dataset_path = os.path.join(dataset_path, dataset_name)
    device = "cuda" if args.with_cuda else "cpu"

    tasks = get_tasks(args.operation, args.backward_test, device)
    repeats = 3
    timers = [
        benchmark_utils.Timer(
            stmt=stmt,
            globals={
                "scipy_matmul": scipy_matmul,
                "matmul_backward": matmul_backward,
                "sparse_matmul_backward": sparse_matmul_backward,
                **variables,
            },
            label=label,
            sub_label=sub_label,
            description=f"{sparsity}",
            env=device,
        )
        for sparsity in [0.5, 0.7, 0.8, 0.9, 0.95, 0.98]
        for label, device, sub_label, stmt in tasks
        for variables in load_dlmc_dataset(
            dataset_path,
            args.operation,
            args.hidden_size,
            sparsity,
            device,
            args.backward_test,
        )
    ]
    measurements = []

    for i, timer in enumerate(timers * repeats):
        m = timer.blocked_autorange(min_run_time=args.timer_min_run_time)
        m.metadata = {"device": "cuda" if m.task_spec.env.find("cuda") >= 0 else "cpu"}
        measurements.append(m)
        print(f"\r{i + 1} / {len(timers) * repeats}", end="")
        sys.stdout.flush()
    print()

    comparison = benchmark_utils.Compare(measurements)

    print("== Results " + "=" * 80 + "\n" + "/" * 95 + "\n")
    comparison.print()
172