xref: /aosp_15_r20/external/pytorch/benchmarks/tensorexpr/microbenchmarks.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport argparse
2*da0073e9SAndroid Build Coastguard Workerimport operator
3*da0073e9SAndroid Build Coastguard Workerimport time
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerimport matplotlib.pyplot as plt
6*da0073e9SAndroid Build Coastguard Workerimport numpy as np
7*da0073e9SAndroid Build Coastguard Workerimport pandas as pd
8*da0073e9SAndroid Build Coastguard Workerimport seaborn as sns
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Workerimport torch
11*da0073e9SAndroid Build Coastguard Workerimport torch._C._te as te
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker
class kernel_arena_scope:
    """Context manager that holds a ``te.KernelScope`` for the ``with`` body.

    A scope object is created on entry and stored on the instance; the
    reference is dropped again on exit. ``__enter__`` returns nothing, so
    ``with kernel_arena_scope() as x`` binds ``None``.
    """

    def __enter__(self):
        # Keep the scope referenced from the instance while the block runs.
        scope = te.KernelScope()
        self.scope = scope

    def __exit__(self, typ, val, traceback):
        # Release our reference. The implicit ``None`` return lets any
        # exception raised inside the block propagate.
        self.scope = None
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker
# Unary benchmarks as (nnc_name, torch_reference) pairs.
#
# ``nnc_name`` is looked up as a method on the loaded NNC expression (see
# ``gen_unary_nnc_fun``), while the second element is the PyTorch function
# the NNC kernel is timed and checked against. Names must be unique: each
# entry becomes one benchmark and one heatmap row.
unary_ops = [
    ("sin", torch.sin),
    ("cos", torch.cos),
    ("tan", torch.tan),
    ("asin", torch.asin),
    ("acos", torch.acos),
    ("atan", torch.atan),
    ("sinh", torch.sinh),
    ("cosh", torch.cosh),
    ("tanh", torch.tanh),
    ("sigmoid", torch.sigmoid),
    ("exp", torch.exp),
    ("expm1", torch.expm1),
    ("abs", torch.abs),
    ("log", torch.log),
    # NNC-side "fast_log" is compared against the regular torch.log.
    ("fast_log", torch.log),
    ("log2", torch.log2),
    ("log10", torch.log10),
    ("log1p", torch.log1p),
    ("erf", torch.erf),
    ("erfc", torch.erfc),
    ("sqrt", torch.sqrt),
    ("rsqrt", torch.rsqrt),
    ("ceil", torch.ceil),
    ("floor", torch.floor),
    ("round", torch.round),
    ("trunc", torch.trunc),
    ("lgamma", torch.lgamma),
    # ("frac", torch.frac), # seems unimplemented
    # ("isnan", torch.isnan), # no out variant
]
54*da0073e9SAndroid Build Coastguard Worker
55*da0073e9SAndroid Build Coastguard Worker
def gen_unary_nnc_fun(nnc_name):
    """Build an NNC compute factory for the unary op named ``nnc_name``.

    The returned callable takes the two input buffers ``(A, B)`` and yields
    a function of the indices ``(i, j)`` that loads ``A[i, j]`` and invokes
    the method called ``nnc_name`` on the loaded expression. ``B`` is
    accepted but never read.
    """

    def make_compute(A, B):
        def elementwise(i, j):
            loaded = A.load([i, j])
            method = getattr(loaded, nnc_name)
            return method()

        return elementwise

    return make_compute
64*da0073e9SAndroid Build Coastguard Worker
65*da0073e9SAndroid Build Coastguard Worker
def gen_unary_torch_fun(torch_op):
    """Build a PyTorch reference factory for the unary op ``torch_op``.

    The returned callable takes ``(a, b, out)`` and yields a zero-argument
    function that evaluates ``torch_op(a, out=out)``. ``b`` is accepted but
    never read.
    """

    def bind_inputs(a, b, out):
        return lambda: torch_op(a, out=out)

    return bind_inputs
74*da0073e9SAndroid Build Coastguard Worker
75*da0073e9SAndroid Build Coastguard Worker
def gen_binary_nnc_fun(fn):
    """Build an NNC compute factory that combines two buffers with ``fn``.

    The returned callable takes the input buffers ``(A, B)`` and yields a
    function of ``(i, j)`` that returns ``fn(A[i, j], B[i, j])``, where each
    element is obtained through the buffer's ``load`` method.
    """

    def make_compute(A, B):
        def elementwise(i, j):
            lhs = A.load([i, j])
            rhs = B.load([i, j])
            return fn(lhs, rhs)

        return elementwise

    return make_compute
84*da0073e9SAndroid Build Coastguard Worker
85*da0073e9SAndroid Build Coastguard Worker
def gen_binary_torch_fun(fn):
    """Build a PyTorch reference factory for the binary op ``fn``.

    The returned callable takes ``(a, b, out)`` and yields a zero-argument
    function that evaluates ``fn(a, b, out=out)``.
    """

    def bind_inputs(a, b, out):
        return lambda: fn(a, b, out=out)

    return bind_inputs
94*da0073e9SAndroid Build Coastguard Worker
95*da0073e9SAndroid Build Coastguard Worker
def gen_int_comparison_tensors(N, M):
    """Return ``(lhs, rhs, out)`` tensors for an integer comparison benchmark.

    ``lhs`` and ``rhs`` are ``(N, M)`` tensors of random integers drawn from
    ``[0, 3)``; ``out`` is an uninitialized ``(N, M)`` boolean tensor.
    """
    shape = (N, M)
    lhs = torch.randint(0, 3, shape)
    rhs = torch.randint(0, 3, shape)
    out = torch.empty(shape, dtype=torch.bool)
    return (lhs, rhs, out)
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Worker
def gen_float_comparison_tensors(N, M):
    """Return ``(lhs, rhs, out)`` tensors for a float comparison benchmark.

    ``lhs`` and ``rhs`` are ``(N, M)`` tensors from ``torch.rand``; ``out``
    is an uninitialized ``(N, M)`` boolean tensor.
    """
    lhs = torch.rand(N, M)
    rhs = torch.rand(N, M)
    out = torch.empty((N, M), dtype=torch.bool)
    return (lhs, rhs, out)
106*da0073e9SAndroid Build Coastguard Worker
107*da0073e9SAndroid Build Coastguard Worker
# NNC dtype that every comparison result is cast to.
te_bool = te.Dtype.Bool


def _cast_to_te_bool(compare):
    """Wrap ``compare`` so that its NNC result is cast to ``te_bool``."""
    return lambda a, b: te.Cast.make(te_bool, compare(a, b))


# Binary benchmarks as (name, nnc_fn, torch_fn[, input_generator]).
# Entries without an input generator use the default random float inputs.
binary_ops = [
    ("add", operator.add, torch.add),
    ("mul", operator.mul, torch.mul),
    ("sub", operator.sub, torch.sub),
    ("div", operator.truediv, torch.div),
    ("eq", _cast_to_te_bool(operator.eq), torch.eq, gen_int_comparison_tensors),
    ("gt", _cast_to_te_bool(operator.gt), torch.gt, gen_float_comparison_tensors),
    ("lt", _cast_to_te_bool(operator.lt), torch.lt, gen_float_comparison_tensors),
    (
        "gte",
        _cast_to_te_bool(operator.ge),
        torch.greater_equal,
        gen_float_comparison_tensors,
    ),
    (
        "lte",
        _cast_to_te_bool(operator.le),
        torch.less_equal,
        gen_float_comparison_tensors,
    ),
    # ('neq', (lambda a, b: a != b), None)), # no one-op equivalent
    # ('&', (lambda a, b: a & b), torch.bitwise_and), # requires more work to test
]
147*da0073e9SAndroid Build Coastguard Worker
148*da0073e9SAndroid Build Coastguard Worker
def nnc_relu(A, B):
    """Return an NNC compute function implementing ReLU over buffer ``A``.

    The returned function of ``(i, j)`` builds an ``ifThenElse`` expression
    that yields a float zero where ``A[i, j] < 0`` and ``A[i, j]`` otherwise.
    ``B`` is accepted but never read.
    """
    nnc = torch._C._te

    def relu_at(i, j):
        # Build the three operands as separate expressions, in the same
        # order the condition / then / else arguments are evaluated.
        is_negative = A.load([i, j]) < nnc.ExprHandle.float(0)
        zero = nnc.ExprHandle.float(0)
        unchanged = A.load([i, j])
        return nnc.ifThenElse(is_negative, zero, unchanged)

    return relu_at
158*da0073e9SAndroid Build Coastguard Worker
159*da0073e9SAndroid Build Coastguard Worker
def pt_relu(a, b, c):
    """PyTorch reference for the ReLU benchmark.

    Returns ``torch.relu(a)`` as a new tensor; ``b`` and ``c`` are accepted
    but never read.
    """
    activated = torch.relu(a)
    return activated
162*da0073e9SAndroid Build Coastguard Worker
163*da0073e9SAndroid Build Coastguard Worker
# Hand-written benchmarks as (name, nnc_compute_factory, torch_reference).
# Unlike the unary/binary tables, the NNC side here is used as-is rather
# than being wrapped by a generator.
custom_ops = [("relu", nnc_relu, pt_relu)]
# Disabled entries:
# ('nnc_mul_relu', nnc_mul_relu, pt_mul_relu)
# ('manual_sigmoid', nnc_manual_sigmoid, lambda a, b, c: torch.sigmoid(a, out=c))
169*da0073e9SAndroid Build Coastguard Worker
170*da0073e9SAndroid Build Coastguard Worker
def gen_custom_torch_fun(fn):
    """Build a PyTorch reference factory for a custom three-argument ``fn``.

    The returned callable takes ``(a, b, out)`` and yields a zero-argument
    function that evaluates ``fn(a, b, out)`` with all three passed
    positionally.
    """

    def bind_inputs(a, b, out):
        return lambda: fn(a, b, out)

    return bind_inputs
179*da0073e9SAndroid Build Coastguard Worker
180*da0073e9SAndroid Build Coastguard Worker
def normalize_benchmarks(ops):
    """Return ``ops`` as a list with every 3-tuple padded by a trailing ``None``.

    Entries whose length is not exactly 3 are passed through unchanged.
    """
    normalized = []
    for entry in ops:
        if len(entry) == 3:
            entry = entry + (None,)
        normalized.append(entry)
    return normalized
183*da0073e9SAndroid Build Coastguard Worker
184*da0073e9SAndroid Build Coastguard Worker
# Parallel lists describing every benchmark; they are zipped into
# ``benchmarks`` once all three op tables have been registered.
names = []
nnc_fns = []
pt_fns = []
shape_fns = []


def _register(bench_name, nnc_factory, torch_factory, input_factory):
    """Append one benchmark to the four parallel lists."""
    names.append(bench_name)
    nnc_fns.append(nnc_factory)
    pt_fns.append(torch_factory)
    shape_fns.append(input_factory)


# Unary ops: the NNC side is looked up by method name; default inputs.
for nnc_name, pt_op in unary_ops:
    _register(nnc_name, gen_unary_nnc_fun(nnc_name), gen_unary_torch_fun(pt_op), None)

# Binary ops: both sides are wrapped; an input generator may be supplied.
for name, lmbda, pt_fn, shape_fn in normalize_benchmarks(binary_ops):
    _register(name, gen_binary_nnc_fun(lmbda), gen_binary_torch_fun(pt_fn), shape_fn)

# Custom ops: the NNC compute factory is used directly, unwrapped.
for name, lmbda, pt_fn, shape_fn in normalize_benchmarks(custom_ops):
    _register(name, lmbda, gen_custom_torch_fun(pt_fn), shape_fn)

benchmarks = list(zip(names, nnc_fns, pt_fns, shape_fns))
209*da0073e9SAndroid Build Coastguard Worker
210*da0073e9SAndroid Build Coastguard Worker
def _nnc_dtype(dtype):
    """Map a torch dtype to the NNC dtype used for the input placeholders.

    Raises:
        TypeError: if ``dtype`` is neither ``torch.float`` nor ``torch.long``.
    """
    if dtype == torch.float:
        return torch._C._te.Dtype.Float
    if dtype == torch.long:
        return torch._C._te.Dtype.Long
    # Fail here with a clear message instead of handing ``None`` to NNC.
    raise TypeError(f"no NNC dtype mapping for {dtype}")


def run_benchmarks(benchmarks, sizes):
    """Time every benchmark with NNC and with PyTorch for each size.

    Args:
        benchmarks: iterable of ``(name, nnc_fun, torch_fun, shape_fn)``.
            ``nnc_fun(A, B)`` returns the NNC compute function,
            ``torch_fun(a, b, out)`` returns a zero-argument callable, and
            ``shape_fn`` (or ``None`` for default random float inputs)
            returns the ``(a, b, out)`` tensors for a given shape.
        sizes: iterable of ``(N, M)`` pairs.

    Returns:
        A ``pandas.DataFrame`` with columns ``name``, ``N``, ``M``,
        ``nnc_time``, ``torch_time`` and ``ratio`` (``torch_time /
        nnc_time``), one row per benchmark and size.

    Raises:
        AssertionError: if the NNC and PyTorch outputs are not close.
        TypeError: if the input dtype has no NNC equivalent.
    """
    columns = ["name", "N", "M", "nnc_time", "torch_time", "ratio"]
    # Rows are collected in a list and turned into a frame once at the end;
    # DataFrame.append no longer exists in pandas >= 2.0.
    rows = []
    with torch.no_grad():
        for name, nnc_fun, torch_fun, shape_fn in benchmarks:
            for N, M in sizes:
                # Fewer iterations for larger inputs.
                iters = int(1e6 / (N + M))
                with kernel_arena_scope():
                    if shape_fn is None:
                        tA = torch.rand(M, N).clamp(0.01, 0.99)
                        tB = torch.rand(M, N).clamp(0.01, 0.99)
                        tX = torch.empty(M, N)
                        tR = torch.empty(M, N)
                    else:
                        tA, tB, tX = shape_fn(M, N)
                        tR = tX.clone()

                    dtype = _nnc_dtype(tA.dtype)

                    dM = torch._C._te.ExprHandle.int(M)
                    dN = torch._C._te.ExprHandle.int(N)

                    A = torch._C._te.Placeholder("A", dtype, [dM, dN])
                    B = torch._C._te.Placeholder("B", dtype, [dM, dN])

                    dim_args = [
                        torch._C._te.DimArg(*args) for args in [(dM, "m"), (dN, "n")]
                    ]

                    # Build, simplify and compile the NNC kernel writing X.
                    compute = nnc_fun(A, B)
                    X = torch._C._te.Compute("X", dim_args, compute)
                    loopnest = torch._C._te.LoopNest([X])
                    loopnest.prepare_for_codegen()
                    stmt = torch._C._te.simplify(loopnest.root_stmt())
                    cg = torch._C._te.construct_codegen(
                        "llvm", stmt, [torch._C._te.BufferArg(x) for x in [A, B, X]]
                    )

                    # warmup
                    for _ in range(10):
                        cg.call([tA, tB, tX])
                    # perf_counter is monotonic and high-resolution, unlike
                    # the wall clock returned by time.time().
                    start = time.perf_counter()
                    for _ in range(iters):
                        cg.call([tA, tB, tX])
                    time1 = time.perf_counter() - start

                    fn = torch_fun(tA, tB, tR)
                    # warmup
                    for _ in range(10):
                        tR = fn()
                    start = time.perf_counter()
                    for _ in range(iters):
                        tR = fn()
                    time2 = time.perf_counter() - start

                    # Guard against a zero-length NNC measurement.
                    ratio = time2 / time1 if time1 > 0 else float("nan")

                    rows.append(
                        {
                            "name": name,
                            "N": N,
                            "M": M,
                            "nnc_time": time1,
                            "torch_time": time2,
                            "ratio": ratio,
                        }
                    )
                    print(name, N, M)

                    print(ratio, time1, time2)
                    print()

                    # Explicit raise so the check survives ``python -O``.
                    if not np.allclose(tX, tR):
                        print(name)
                        raise AssertionError(
                            f"NNC and PyTorch results differ for {name} at N={N}, M={M}"
                        )
    return pd.DataFrame(rows, columns=columns)
294*da0073e9SAndroid Build Coastguard Worker
295*da0073e9SAndroid Build Coastguard Worker
def dump_plot(df, sizes, filename="nnc.png"):
    """Render the square-size benchmark ratios as a heatmap image.

    Args:
        df: frame produced by ``run_benchmarks``; only rows with
            ``N == M`` are plotted.
        sizes: the square sizes that were benchmarked, in run order. Rows
            are expected to be grouped per benchmark with one row per size.
        filename: path the figure is saved to.
    """
    square = df[df["N"] == df["M"]]
    ratios = square["ratio"].tolist()
    # One label per benchmark: every len(sizes)-th row starts a new one.
    keys = square["name"].tolist()[:: len(sizes)]
    sns.set(rc={"figure.figsize": (5.0, len(keys) * 0.5)})

    cmap = sns.diverging_palette(10, 120, n=9, as_cmap=True)
    np_vals = np.array(ratios, dtype=float).reshape(-1, len(sizes))
    g = sns.heatmap(np_vals, annot=True, cmap=cmap, center=1.0, yticklabels=True)
    plt.yticks(rotation=0)
    # The plotted value is torch_time / nnc_time (above 1 means NNC took
    # less time); report the thread count actually in use rather than
    # always claiming a single core.
    threads = torch.get_num_threads()
    threading = "single core" if threads == 1 else "{} threads".format(threads)
    plt.title("PyTorch time divided by NNC time ({})".format(threading))
    plt.xlabel("Size of NxN matrix")
    plt.ylabel("Operation")
    g.set_yticklabels(keys)
    g.set_xticklabels(sizes)

    plt.savefig(filename)
318*da0073e9SAndroid Build Coastguard Worker
319*da0073e9SAndroid Build Coastguard Worker
def _main():
    """Parse the command line, run all benchmarks on square sizes, and plot."""
    cli = argparse.ArgumentParser(description="Runs NNC microbenchmarks")
    cli.add_argument(
        "--multi-threaded",
        "--multi_threaded",
        action="store_true",
        help="Run with more than one thread",
    )
    options = cli.parse_args()
    # Pin PyTorch to one thread unless multi-threading was requested.
    if not options.multi_threaded:
        torch.set_num_threads(1)

    sizes = [1, 4, 16, 64, 256, 1024]
    square_shapes = [(edge, edge) for edge in sizes]
    results = run_benchmarks(benchmarks, square_shapes)
    dump_plot(results, sizes)


if __name__ == "__main__":
    _main()
335