import argparse
import operator
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import torch
import torch._C._te as te


class kernel_arena_scope:
    """Context manager that keeps a ``te.KernelScope`` alive inside a ``with`` block.

    A fresh ``te.KernelScope()`` is created on entry and the reference to it is
    dropped on exit. NOTE(review): presumably this scope owns the memory of NNC
    IR objects built inside the block -- confirm against the torch build used.
    """

    def __enter__(self):
        self.scope = te.KernelScope()
        return self

    def __exit__(self, typ, val, traceback):
        # Release the scope; returning None lets any in-flight exception propagate.
        self.scope = None


# (nnc_name, torch_op) pairs benchmarked element-wise.
# ``nnc_name`` is both the benchmark label and the name of the method invoked
# on the loaded NNC expression (see ``gen_unary_nnc_fun``); ``torch_op`` is the
# eager PyTorch reference. Each op appears exactly once.
unary_ops = [
    ("sin", torch.sin),
    ("cos", torch.cos),
    ("tan", torch.tan),
    ("asin", torch.asin),
    ("acos", torch.acos),
    ("atan", torch.atan),
    ("sinh", torch.sinh),
    ("cosh", torch.cosh),
    ("tanh", torch.tanh),
    ("sigmoid", torch.sigmoid),
    ("exp", torch.exp),
    ("expm1", torch.expm1),
    ("abs", torch.abs),
    ("log", torch.log),
    # NNC's fast_log approximation, compared against the regular torch.log.
    ("fast_log", torch.log),
    ("log2", torch.log2),
    ("log10", torch.log10),
    ("log1p", torch.log1p),
    ("erf", torch.erf),
    ("erfc", torch.erfc),
    ("sqrt", torch.sqrt),
    ("rsqrt", torch.rsqrt),
    ("ceil", torch.ceil),
    ("floor", torch.floor),
    ("round", torch.round),
    ("trunc", torch.trunc),
    ("lgamma", torch.lgamma),
    # ("frac", torch.frac), # seems unimplemented
    # ("isnan", torch.isnan), # no out variant
]


def gen_unary_nnc_fun(nnc_name):
    """Return an NNC kernel builder applying the expression method ``nnc_name``.

    The returned ``nnc_fun(A, B)`` ignores ``B`` and produces a ``compute(i, j)``
    callback that loads ``A[i, j]`` and calls ``<loaded expr>.<nnc_name>()``.
    """

    def nnc_fun(A, B):
        def compute(i, j):
            return getattr(A.load([i, j]), nnc_name)()

        return compute

    return nnc_fun
def gen_unary_torch_fun(torch_op):
    """Wrap a unary PyTorch op as a ``torch_fun(a, b, out)`` benchmark factory.

    The factory ignores ``b`` and returns a zero-argument callable that runs
    ``torch_op(a, out=out)`` and returns its result.
    """

    def torch_fun(a, b, out):
        def fun():
            return torch_op(a, out=out)

        return fun

    return torch_fun


def gen_binary_nnc_fun(fn):
    """Return an NNC kernel builder whose ``compute(i, j)`` is ``fn(A[i, j], B[i, j])``."""

    def nnc_fun(A, B):
        def compute(i, j):
            return fn(A.load([i, j]), B.load([i, j]))

        return compute

    return nnc_fun


def gen_binary_torch_fun(fn):
    """Wrap a binary PyTorch op as a factory returning ``lambda: fn(a, b, out=out)``."""

    def pt_fun(a, b, out):
        def fun():
            return fn(a, b, out=out)

        return fun

    return pt_fun


def gen_int_comparison_tensors(N, M):
    """Return ``(a, b, out)``: two integer tensors with values in [0, 3) and a bool buffer.

    The tiny value range makes equal elements frequent, so equality has mixed results.
    """
    return (
        torch.randint(0, 3, (N, M)),
        torch.randint(0, 3, (N, M)),
        torch.empty((N, M), dtype=torch.bool),
    )


def gen_float_comparison_tensors(N, M):
    """Return ``(a, b, out)``: two ``torch.rand`` float tensors and a bool buffer."""
    return (torch.rand(N, M), torch.rand(N, M), torch.empty((N, M), dtype=torch.bool))


te_bool = te.Dtype.Bool

# (name, nnc_expr_fn, torch_op[, input_factory]) entries. Comparisons cast the
# NNC result to Bool and supply their own input factory, since their outputs
# are bool tensors; the others use the default float inputs.
binary_ops = [
    ("add", operator.add, torch.add),
    ("mul", operator.mul, torch.mul),
    ("sub", operator.sub, torch.sub),
    ("div", operator.truediv, torch.div),
    (
        "eq",
        (lambda a, b: te.Cast.make(te_bool, a == b)),
        torch.eq,
        gen_int_comparison_tensors,
    ),
    (
        "gt",
        (lambda a, b: te.Cast.make(te_bool, a > b)),
        torch.gt,
        gen_float_comparison_tensors,
    ),
    (
        "lt",
        (lambda a, b: te.Cast.make(te_bool, a < b)),
        torch.lt,
        gen_float_comparison_tensors,
    ),
    (
        "gte",
        (lambda a, b: te.Cast.make(te_bool, a >= b)),
        torch.greater_equal,
        gen_float_comparison_tensors,
    ),
    (
        "lte",
        (lambda a, b: te.Cast.make(te_bool, a <= b)),
        torch.less_equal,
        gen_float_comparison_tensors,
    ),
    # ('neq', (lambda a, b: a != b), None)), # no one-op equivalent
    # ('&', (lambda a, b: a & b), torch.bitwise_and), # requires more work to test
]


def nnc_relu(A, B):
    """NNC ReLU builder: ``f(i, j)`` yields ``0 if A[i, j] < 0 else A[i, j]``; ``B`` is unused."""

    def f(i, j):
        return te.ifThenElse(
            A.load([i, j]) < te.ExprHandle.float(0),
            te.ExprHandle.float(0),
            A.load([i, j]),
        )

    return f


def pt_relu(a, b, c):
    """PyTorch ReLU reference; ``b`` and the output buffer ``c`` are unused."""
    return torch.relu(a)


custom_ops = [
    ("relu", nnc_relu, pt_relu),
    # ('nnc_mul_relu', nnc_mul_relu, pt_mul_relu)
    # ('manual_sigmoid', nnc_manual_sigmoid, lambda a, b, c: torch.sigmoid(a, out=c))
]


def gen_custom_torch_fun(fn):
    """Wrap ``fn(a, b, out)`` as a factory returning a zero-argument callable."""

    def pt_fun(a, b, out):
        def fun():
            return fn(a, b, out)

        return fun

    return pt_fun


def normalize_benchmarks(ops):
    """Pad 3-tuples with a ``None`` input factory so every entry has four fields."""
    return [i + (None,) if len(i) == 3 else i for i in ops]


names = []
nnc_fns = []
pt_fns = []
shape_fns = []

for nnc_name, pt_op in unary_ops:
    names.append(nnc_name)
    nnc_fns.append(gen_unary_nnc_fun(nnc_name))
    pt_fns.append(gen_unary_torch_fun(pt_op))
    shape_fns.append(None)

for name, lmbda, pt_fn, shape_fn in normalize_benchmarks(binary_ops):
    names.append(name)
    nnc_fns.append(gen_binary_nnc_fun(lmbda))
    pt_fns.append(gen_binary_torch_fun(pt_fn))
    shape_fns.append(shape_fn)

for name, lmbda, pt_fn, shape_fn in normalize_benchmarks(custom_ops):
    names.append(name)
    nnc_fns.append(lmbda)
    pt_fns.append(gen_custom_torch_fun(pt_fn))
    shape_fns.append(shape_fn)

# (name, nnc_fun, torch_fun, shape_fn) tuples consumed by run_benchmarks.
benchmarks = list(zip(names, nnc_fns, pt_fns, shape_fns))

_RESULT_COLUMNS = ["name", "N", "M", "nnc_time", "torch_time", "ratio"]
_WARMUP_ITERS = 10


def _nnc_dtype(dtype):
    """Map a torch dtype to the matching ``te.Dtype``; raise for unsupported dtypes."""
    if dtype == torch.float:
        return te.Dtype.Float
    if dtype == torch.long:
        return te.Dtype.Long
    # Previously this fell through and returned None, failing later inside NNC.
    raise ValueError(f"unsupported input dtype for NNC benchmark: {dtype}")


def _make_inputs(shape_fn, M, N):
    """Create ``(tA, tB, tX, tR)``: two inputs plus NNC and PyTorch output buffers."""
    if shape_fn is None:
        # Clamp into [0.01, 0.99] -- presumably so log/rsqrt/atanh-like ops
        # stay finite on the random inputs.
        tA = torch.rand(M, N).clamp(0.01, 0.99)
        tB = torch.rand(M, N).clamp(0.01, 0.99)
        tX = torch.empty(M, N)
        tR = torch.empty(M, N)
    else:
        tA, tB, tX = shape_fn(M, N)
        tR = tX.clone()
    return tA, tB, tX, tR


def _build_nnc_kernel(nnc_fun, dtype, M, N):
    """Compile ``nnc_fun`` to an LLVM kernel called as ``cg.call([A, B, X])`` on (M, N) buffers."""
    dM = te.ExprHandle.int(M)
    dN = te.ExprHandle.int(N)

    A = te.Placeholder("A", dtype, [dM, dN])
    B = te.Placeholder("B", dtype, [dM, dN])

    dim_args = [te.DimArg(*args) for args in [(dM, "m"), (dN, "n")]]

    compute = nnc_fun(A, B)
    X = te.Compute("X", dim_args, compute)
    loopnest = te.LoopNest([X])
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())
    return te.construct_codegen("llvm", stmt, [te.BufferArg(x) for x in [A, B, X]])


def _check_correctness(name, nnc_out, torch_out):
    """Raise AssertionError naming the benchmark if the outputs are not ``np.allclose``."""
    # Explicit raise rather than `assert`, so the check survives `python -O`.
    if not np.allclose(nnc_out, torch_out):
        raise AssertionError(f"NNC and PyTorch results differ for {name!r}")


def run_benchmarks(benchmarks, sizes):
    """Time each NNC kernel against its PyTorch reference for every size.

    Args:
        benchmarks: iterable of ``(name, nnc_fun, torch_fun, shape_fn)`` tuples.
        sizes: iterable of ``(N, M)`` pairs; tensors are allocated with shape (M, N).

    Returns:
        DataFrame with columns name, N, M, nnc_time, torch_time, ratio, where
        ``ratio = torch_time / nnc_time`` (> 1 means the NNC kernel was faster).

    Raises:
        AssertionError: if NNC and PyTorch outputs differ.
        ValueError: if an input dtype has no NNC mapping.
    """
    rows = []
    with torch.no_grad():
        for name, nnc_fun, torch_fun, shape_fn in benchmarks:
            for N, M in sizes:
                # Fewer iterations for larger sizes to bound total runtime;
                # at least one so the timings are never empty.
                iters = max(1, int(1e6 / (N + M)))
                with kernel_arena_scope():
                    tA, tB, tX, tR = _make_inputs(shape_fn, M, N)
                    cg = _build_nnc_kernel(nnc_fun, _nnc_dtype(tA.dtype), M, N)

                    for _ in range(_WARMUP_ITERS):
                        cg.call([tA, tB, tX])
                    start = time.perf_counter()
                    for _ in range(iters):
                        cg.call([tA, tB, tX])
                    nnc_time = time.perf_counter() - start

                    fn = torch_fun(tA, tB, tR)
                    for _ in range(_WARMUP_ITERS):
                        tR = fn()
                    start = time.perf_counter()
                    for _ in range(iters):
                        tR = fn()
                    torch_time = time.perf_counter() - start

                    ratio = torch_time / nnc_time if nnc_time > 0 else float("nan")
                    rows.append(
                        {
                            "name": name,
                            "N": N,
                            "M": M,
                            "nnc_time": nnc_time,
                            "torch_time": torch_time,
                            "ratio": ratio,
                        }
                    )
                    print(name, N, M)

                    print(ratio, nnc_time, torch_time)
                    print()

                    _check_correctness(name, tX, tR)
    # DataFrame.append was removed in pandas 2.0; build the frame once instead.
    return pd.DataFrame(rows, columns=_RESULT_COLUMNS)


def dump_plot(df, sizes):
    """Save a heatmap of ``ratio`` (torch_time / nnc_time) to ``nnc.png``.

    Only square (N == M) rows are plotted. Assumes ``df`` holds, per benchmark
    and in order, one square row for each entry of ``sizes`` -- the layout
    ``run_benchmarks`` produces when given square sizes.
    """
    square = df[df["N"] == df["M"]]
    keys = square["name"].tolist()[:: len(sizes)]
    vals = square["ratio"].to_numpy(dtype=float)

    sns.set(rc={"figure.figsize": (5.0, len(keys) * 0.5)})

    cmap = sns.diverging_palette(10, 120, n=9, as_cmap=True)
    np_vals = vals.reshape(-1, len(sizes))
    g = sns.heatmap(np_vals, annot=True, cmap=cmap, center=1.0, yticklabels=True)
    plt.yticks(rotation=0)
    plt.title("PyTorch performance divided by NNC performance (single core)")
    plt.xlabel("Size of NxN matrix")
    plt.ylabel("Operation")
    g.set_yticklabels(keys)
    g.set_xticklabels(sizes)

    plt.savefig("nnc.png")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Runs NNC microbenchmarks")
    parser.add_argument(
        "--multi-threaded",
        "--multi_threaded",
        action="store_true",
        help="Run with more than one thread",
    )
    args = parser.parse_args()
    if not args.multi_threaded:
        torch.set_num_threads(1)

    sizes = [1, 4, 16, 64, 256, 1024]
    df = run_benchmarks(benchmarks, [(i, i) for i in sizes])
    dump_plot(df, sizes)