# mypy: allow-untyped-defs
"""Example of Timer and Compare APIs:

$ python -m examples.sparse.compare
"""

import pickle
import sys
import time

import torch
import torch.utils.benchmark as benchmark_utils


class FauxTorch:
    """Emulate different versions of pytorch.

    In normal circumstances this would be done with multiple processes
    writing serialized measurements, but this simplifies that model to
    make the example clearer.
    """
    def __init__(self, real_torch, extra_ns_per_element):
        self._real_torch = real_torch
        self._extra_ns_per_element = extra_ns_per_element

    @property
    def sparse(self):
        return self.Sparse(self._real_torch, self._extra_ns_per_element)

    class Sparse:
        def __init__(self, real_torch, extra_ns_per_element):
            self._real_torch = real_torch
            self._extra_ns_per_element = extra_ns_per_element

        def extra_overhead(self, result):
            # time.sleep has a ~65 us overhead, so only fake the extra
            # per-element cost when the result is large enough to matter.
            size = sum(result.size())
            if size > 5000:
                time.sleep(size * self._extra_ns_per_element * 1e-9)
            return result

        def mm(self, *args, **kwargs):
            return self.extra_overhead(self._real_torch.sparse.mm(*args, **kwargs))


def generate_coo_data(size, sparse_dim, nnz, dtype, device):
    """Generate random indices and values for a sparse COO tensor.

    Parameters
    ----------
    size : tuple
    sparse_dim : int
    nnz : int
    dtype : torch.dtype
    device : str

    Returns
    -------
    indices : torch.Tensor
    values : torch.Tensor
    """
    if dtype is None:
        dtype = torch.float32

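    # Draw random fractional coordinates in [0, 1), scale them by the size of
    # each sparse dimension, then truncate to integer (long) indices.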
    indices = torch.rand(sparse_dim, nnz, device=device)
    indices.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(indices))
    indices = indices.to(torch.long)
    values = torch.rand([nnz, ], dtype=dtype, device=device)
    return indices, values


def gen_sparse(size, density, dtype, device='cpu'):
    sparse_dim = len(size)
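    # nnz is computed for the 2-D sizes used in this example: rows * cols * density.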
    nnz = int(size[0] * size[1] * density)
    indices, values = generate_coo_data(size, sparse_dim, nnz, dtype, device)
    return torch.sparse_coo_tensor(indices, values, size, dtype=dtype, device=device)


def main():
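    # Each task is (label, sub_label, stmt). The second statement performs the
    # same sparse matmul and then adds a zero scalar tensor, giving a second
    # variant to compare against the plain matmul.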
    tasks = [
        ("matmul", "x @ y", "torch.sparse.mm(x, y)"),
        ("matmul", "x @ y + 0", "torch.sparse.mm(x, y) + zero"),
    ]

    serialized_results = []
    repeats = 2
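    # Build one Timer per combination of branch, task, density, size, and
    # thread count. The non-master branches route `torch` through FauxTorch
    # to simulate per-element regressions of 1 ns and 10 ns.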
    timers = [
        benchmark_utils.Timer(
            stmt=stmt,
            globals={
                "torch": torch if branch == "master" else FauxTorch(torch, overhead_ns),
                "x": gen_sparse(size=size, density=density, dtype=torch.float32),
                "y": torch.rand(size, dtype=torch.float32),
                "zero": torch.zeros(()),
            },
            label=label,
            sub_label=sub_label,
            description=f"size: {size}",
            env=branch,
            num_threads=num_threads,
        )
        for branch, overhead_ns in [("master", None), ("my_branch", 1), ("severe_regression", 10)]
        for label, sub_label, stmt in tasks
        for density in [0.05, 0.1]
        for size in [(8, 8), (32, 32), (64, 64), (128, 128)]
        for num_threads in [1, 4]
    ]

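    # Run every Timer `repeats` times and pickle each Measurement, mimicking
    # results written out by separate benchmark processes.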
    for i, timer in enumerate(timers * repeats):
        serialized_results.append(pickle.dumps(
            timer.blocked_autorange(min_run_time=0.05)
        ))
        print(f"\r{i + 1} / {len(timers) * repeats}", end="")
        sys.stdout.flush()
    print()

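    # Deserialize the Measurement objects and hand them all to Compare, which
    # aggregates them into comparison tables.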
    comparison = benchmark_utils.Compare([
        pickle.loads(i) for i in serialized_results
    ])

    print("== Unformatted " + "=" * 80 + "\n" + "/" * 95 + "\n")
    comparison.print()

    print("== Formatted " + "=" * 80 + "\n" + "/" * 93 + "\n")
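    # trim_significant_figures() drops digits below the measurement noise, and
    # colorize() highlights relative differences across environments.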
    comparison.trim_significant_figures()
    comparison.colorize()
    comparison.print()


if __name__ == "__main__":
    main()