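# Microbenchmarks for NNC (PyTorch's TensorExpr backend, exposed via
# torch._C._te): each op below is compiled into an elementwise NNC kernel and
# timed against the equivalent eager-mode torch op; the torch/NNC time ratios
# are then rendered as a heatmap (see dump_plot).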
import argparse
import operator
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import torch
import torch._C._te as te


class kernel_arena_scope:
    """Keep a te.KernelScope alive for the duration of a `with` block.

    NNC expressions are allocated in a kernel arena; the arena (and every
    expression built inside it) is freed when the scope is dropped.
    """

    def __enter__(self):
        self.scope = te.KernelScope()

    def __exit__(self, typ, val, traceback):
        self.scope = None


unary_ops = [
    ("sin", torch.sin),
    ("cos", torch.cos),
    ("tan", torch.tan),
    ("asin", torch.asin),
    ("acos", torch.acos),
    ("atan", torch.atan),
    ("sinh", torch.sinh),
    ("cosh", torch.cosh),
    ("tanh", torch.tanh),
    ("sigmoid", torch.sigmoid),
    ("exp", torch.exp),
    ("expm1", torch.expm1),
    ("abs", torch.abs),
    ("log", torch.log),
    ("fast_log", torch.log),  # NNC's approximate log, checked against torch.log
    ("log2", torch.log2),
    ("log10", torch.log10),
    ("log1p", torch.log1p),
    ("erf", torch.erf),
    ("erfc", torch.erfc),
    ("sqrt", torch.sqrt),
    ("rsqrt", torch.rsqrt),
    ("ceil", torch.ceil),
    ("floor", torch.floor),
    ("round", torch.round),
    ("trunc", torch.trunc),
    ("lgamma", torch.lgamma),
    # ("frac", torch.frac), # seems unimplemented
    # ("isnan", torch.isnan), # no out variant
]


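# NNC side of a unary benchmark: te.ExprHandle exposes each unary op as a
# method, so the compute(i, j) body below becomes e.g. A.load([i, j]).sin().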
def gen_unary_nnc_fun(nnc_name):
    def nnc_fun(A, B):
        def compute(i, j):
            return getattr(A.load([i, j]), nnc_name)()

        return compute

    return nnc_fun


def gen_unary_torch_fun(torch_op):
    def torch_fun(a, b, out):
        def fun():
            return torch_op(a, out=out)

        return fun

    return torch_fun


def gen_binary_nnc_fun(fn):
    def nnc_fun(A, B):
        def compute(i, j):
            return fn(A.load([i, j]), B.load([i, j]))

        return compute

    return nnc_fun


def gen_binary_torch_fun(fn):
    def pt_fun(a, b, out):
        def fun():
            return fn(a, b, out=out)

        return fun

    return pt_fun


def gen_int_comparison_tensors(N, M):
    return (
        torch.randint(0, 3, (N, M)),
        torch.randint(0, 3, (N, M)),
        torch.empty((N, M), dtype=torch.bool),
    )


def gen_float_comparison_tensors(N, M):
    return (torch.rand(N, M), torch.rand(N, M), torch.empty((N, M), dtype=torch.bool))


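# Binary benchmarks are (name, NNC expression builder, torch op[, input
# generator]) tuples. Comparisons cast the NNC result to Bool explicitly so the
# generated kernel matches the torch.bool output tensor that their input
# generators allocate.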
te_bool = te.Dtype.Bool
binary_ops = [
    ("add", operator.add, torch.add),
    ("mul", operator.mul, torch.mul),
    ("sub", operator.sub, torch.sub),
    ("div", operator.truediv, torch.div),
    (
        "eq",
        (lambda a, b: te.Cast.make(te_bool, a == b)),
        torch.eq,
        gen_int_comparison_tensors,
    ),
    (
        "gt",
        (lambda a, b: te.Cast.make(te_bool, a > b)),
        torch.gt,
        gen_float_comparison_tensors,
    ),
    (
        "lt",
        (lambda a, b: te.Cast.make(te_bool, a < b)),
        torch.lt,
        gen_float_comparison_tensors,
    ),
    (
        "gte",
        (lambda a, b: te.Cast.make(te_bool, a >= b)),
        torch.greater_equal,
        gen_float_comparison_tensors,
    ),
    (
        "lte",
        (lambda a, b: te.Cast.make(te_bool, a <= b)),
        torch.less_equal,
        gen_float_comparison_tensors,
    ),
    # ('neq', (lambda a, b: a != b), None), # no one-op equivalent
    # ('&', (lambda a, b: a & b), torch.bitwise_and), # requires more work to test
]


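# relu spelled out by hand with NNC's ifThenElse select:
# X[i, j] = 0.0 if A[i, j] < 0.0 else A[i, j].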
def nnc_relu(A, B):
    def f(i, j):
        return torch._C._te.ifThenElse(
            A.load([i, j]) < torch._C._te.ExprHandle.float(0),
            torch._C._te.ExprHandle.float(0),
            A.load([i, j]),
        )

    return f


def pt_relu(a, b, c):
    return torch.relu(a)


custom_ops = [
    ("relu", nnc_relu, pt_relu),
    # ('nnc_mul_relu', nnc_mul_relu, pt_mul_relu)
    # ('manual_sigmoid', nnc_manual_sigmoid, lambda a, b, c: torch.sigmoid(a, out=c))
]


def gen_custom_torch_fun(fn):
    def pt_fun(a, b, out):
        def fun():
            return fn(a, b, out)

        return fun

    return pt_fun


def normalize_benchmarks(ops):
    # Pad 3-tuples (name, nnc_fn, pt_fn) out to 4-tuples with shape_fn=None.
    return [i + (None,) if len(i) == 3 else i for i in ops]


names = []
nnc_fns = []
pt_fns = []
shape_fns = []

for nnc_name, pt_op in unary_ops:
    names.append(nnc_name)
    nnc_fns.append(gen_unary_nnc_fun(nnc_name))
    pt_fns.append(gen_unary_torch_fun(pt_op))
    shape_fns.append(None)

for name, lmbda, pt_fn, shape_fn in normalize_benchmarks(binary_ops):
    names.append(name)
    nnc_fns.append(gen_binary_nnc_fun(lmbda))
    pt_fns.append(gen_binary_torch_fun(pt_fn))
    shape_fns.append(shape_fn)

for name, lmbda, pt_fn, shape_fn in normalize_benchmarks(custom_ops):
    names.append(name)
    nnc_fns.append(lmbda)
    pt_fns.append(gen_custom_torch_fun(pt_fn))
    shape_fns.append(shape_fn)

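# Each benchmark is a (name, nnc_fun, torch_fun, shape_fn) tuple; shape_fn is
# None for ops that run on the default float inputs.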
benchmarks = list(zip(names, nnc_fns, pt_fns, shape_fns))


def run_benchmarks(benchmarks, sizes):
    # Collect rows and build the DataFrame once at the end
    # (pandas removed DataFrame.append in 2.0).
    rows = []
    with torch.no_grad():
        for name, nnc_fun, torch_fun, shape_fn in benchmarks:
            for N, M in sizes:
                # Scale the iteration count down as the tensors get bigger.
                iters = int(1e6 / (N + M))
                with kernel_arena_scope():
                    if shape_fn is None:
                        # Keep inputs strictly inside (0, 1) so ops like log,
                        # asin, and rsqrt stay finite.
                        tA = torch.rand(M, N).clamp(0.01, 0.99)
                        tB = torch.rand(M, N).clamp(0.01, 0.99)
                        tX = torch.empty(M, N)
                        tR = torch.empty(M, N)
                    else:
                        tA, tB, tX = shape_fn(M, N)
                        tR = tX.clone()

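                    # Translate the input dtype to NNC's dtype enum; float
                    # covers the default inputs, long covers the randint
                    # tensors used by the comparison benchmarks.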
                    def get_nnc_type(dtype):
                        if dtype == torch.float:
                            return torch._C._te.Dtype.Float
                        elif dtype == torch.long:
                            return torch._C._te.Dtype.Long

                    dtype = get_nnc_type(tA.dtype)

                    dM = torch._C._te.ExprHandle.int(M)
                    dN = torch._C._te.ExprHandle.int(N)

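                    # NNC lowering pipeline: declare the inputs as Placeholders,
                    # define X as an elementwise Compute over (m, n), lower and
                    # simplify the loop nest, then JIT it with the LLVM codegen.
                    # cg.call() runs the compiled kernel on the torch tensors.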
                    A = torch._C._te.Placeholder("A", dtype, [dM, dN])
                    B = torch._C._te.Placeholder("B", dtype, [dM, dN])

                    dim_args = [
                        torch._C._te.DimArg(*args) for args in [(dM, "m"), (dN, "n")]
                    ]

                    compute = nnc_fun(A, B)
                    X = torch._C._te.Compute("X", dim_args, compute)
                    loopnest = torch._C._te.LoopNest([X])
                    loopnest.prepare_for_codegen()
                    stmt = torch._C._te.simplify(loopnest.root_stmt())
                    cg = torch._C._te.construct_codegen(
                        "llvm", stmt, [torch._C._te.BufferArg(x) for x in [A, B, X]]
                    )

                    # warmup
                    for _ in range(10):
                        cg.call([tA, tB, tX])
                    start = time.time()
                    for _ in range(iters):
                        cg.call([tA, tB, tX])
                    time1 = time.time() - start

                    fn = torch_fun(tA, tB, tR)
                    # warmup
                    for _ in range(10):
                        tR = fn()
                    start = time.time()
                    for _ in range(iters):
                        tR = fn()
                    time2 = time.time() - start

                    rows.append(
                        {
                            "name": name,
                            "N": N,
                            "M": M,
                            "nnc_time": time1,
                            "torch_time": time2,
                            # >1.0 means the NNC kernel was faster than eager.
                            "ratio": time2 / time1,
                        }
                    )
                    print(name, N, M)
                    print(time2 / time1, time1, time2)
                    print()

                    def check_correctness(a, b):
                        if not np.allclose(a, b):
                            print(name)
                            assert np.allclose(a, b)

                    check_correctness(tX, tR)
    return pd.DataFrame(
        rows, columns=["name", "N", "M", "nnc_time", "torch_time", "ratio"]
    )


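# Heatmap of torch_time / nnc_time for the square (N == M) runs: cells above
# 1.0 are sizes where the NNC kernel beat eager PyTorch. Saved to nnc.png.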
def dump_plot(df, sizes):
    keys = []
    vals = []
    indexed = df[df["N"] == df["M"]]
    for _, row in indexed.iterrows():
        keys.append(row["name"])
        vals.append(row["ratio"])

    keys = keys[:: len(sizes)]
    sns.set(rc={"figure.figsize": (5.0, len(keys) * 0.5)})

    cmap = sns.diverging_palette(10, 120, n=9, as_cmap=True)
    np_vals = np.array([vals]).reshape(-1, len(sizes))
    g = sns.heatmap(np_vals, annot=True, cmap=cmap, center=1.0, yticklabels=True)
    plt.yticks(rotation=0)
    plt.title("PyTorch performance divided by NNC performance (single core)")
    plt.xlabel("Size of NxN matrix")
    plt.ylabel("Operation")
    g.set_yticklabels(keys)
    g.set_xticklabels(sizes)

    plt.savefig("nnc.png")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Runs NNC microbenchmarks")
    parser.add_argument(
        "--multi-threaded",
        "--multi_threaded",
        action="store_true",
        help="Run with more than one thread",
    )
    args = parser.parse_args()
    if not args.multi_threaded:
        torch.set_num_threads(1)

    sizes = [1, 4, 16, 64, 256, 1024]
    df = run_benchmarks(benchmarks, [(i, i) for i in sizes])
    dump_plot(df, sizes)
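
# Example (single-threaded by default):
#   python benchmarks/tensorexpr/microbenchmarks.py
# Pass --multi-threaded to let torch use every available core.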