"""Microbenchmarks that time NNC (tensor-expression) kernels against PyTorch ops."""
import argparse
import operator
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import torch
import torch._C._te as te


class kernel_arena_scope:
    """Context manager that keeps a ``te.KernelScope`` referenced while active.

    The scope object is created on entry and the reference is dropped on exit.
    """

    def __enter__(self):
        self.scope = te.KernelScope()
        return self

    def __exit__(self, typ, val, traceback):
        self.scope = None


# (NNC method name, reference torch op) pairs.  The NNC name is looked up with
# ``getattr`` on the loaded expression (see ``gen_unary_nnc_fun``), so it may
# differ from the torch name (e.g. "fast_log" is checked against ``torch.log``).
# Every name is listed exactly once so each op is benchmarked a single time.
unary_ops = [
    ("sin", torch.sin),
    ("cos", torch.cos),
    ("tan", torch.tan),
    ("asin", torch.asin),
    ("acos", torch.acos),
    ("atan", torch.atan),
    ("sinh", torch.sinh),
    ("cosh", torch.cosh),
    ("tanh", torch.tanh),
    ("sigmoid", torch.sigmoid),
    ("exp", torch.exp),
    ("expm1", torch.expm1),
    ("abs", torch.abs),
    ("log", torch.log),
    ("fast_log", torch.log),
    ("log2", torch.log2),
    ("log10", torch.log10),
    ("log1p", torch.log1p),
    ("erf", torch.erf),
    ("erfc", torch.erfc),
    ("sqrt", torch.sqrt),
    ("rsqrt", torch.rsqrt),
    ("ceil", torch.ceil),
    ("floor", torch.floor),
    ("round", torch.round),
    ("trunc", torch.trunc),
    ("lgamma", torch.lgamma),
    # ("frac", torch.frac), # seems unimplemented
    # ("isnan", torch.isnan), # no out variant
]


def gen_unary_nnc_fun(nnc_name):
    """Return a factory ``(A, B) -> compute(i, j)`` applying ``nnc_name`` to ``A``.

    ``compute`` loads ``A[i, j]`` and calls the method called ``nnc_name`` on
    the loaded value.  ``B`` is accepted only so that unary and binary
    factories share one signature; it is not used.
    """

    def nnc_fun(A, B):
        def compute(i, j):
            return getattr(A.load([i, j]), nnc_name)()

        return compute

    return nnc_fun


def gen_unary_torch_fun(torch_op):
    """Return a factory ``(a, b, out) -> fun()`` that runs ``torch_op(a, out=out)``.

    ``b`` is ignored; it exists to match the binary factory signature.
    """

    def torch_fun(a, b, out):
        def fun():
            return torch_op(a, out=out)

        return fun

    return torch_fun


def gen_binary_nnc_fun(fn):
    """Return a factory ``(A, B) -> compute(i, j)`` combining both operands.

    ``compute`` evaluates ``fn(A[i, j], B[i, j])`` on the loaded values.
    """

    def nnc_fun(A, B):
        def compute(i, j):
            return fn(A.load([i, j]), B.load([i, j]))

        return compute

    return nnc_fun


def gen_binary_torch_fun(fn):
    """Return a factory ``(a, b, out) -> fun()`` that runs ``fn(a, b, out=out)``."""

    def pt_fun(a, b, out):
        def fun():
            return fn(a, b, out=out)

        return fun

    return pt_fun


def gen_int_comparison_tensors(N, M):
    """Return two random integer (N, M) operands and an empty bool (N, M) output.

    Operand values are drawn from ``{0, 1, 2}`` via ``torch.randint(0, 3, ...)``.
    """
    return (
        torch.randint(0, 3, (N, M)),
        torch.randint(0, 3, (N, M)),
        torch.empty((N, M), dtype=torch.bool),
    )


def gen_float_comparison_tensors(N, M):
    """Return two random float (N, M) operands and an empty bool (N, M) output."""
    return (torch.rand(N, M), torch.rand(N, M), torch.empty((N, M), dtype=torch.bool))


te_bool = te.Dtype.Bool

# Entries are (name, NNC expression builder, torch op[, input generator]).
# Entries without an input generator are padded with ``None`` by
# ``normalize_benchmarks``.
binary_ops = [
    ("add", operator.add, torch.add),
    ("mul", operator.mul, torch.mul),
    ("sub", operator.sub, torch.sub),
    ("div", operator.truediv, torch.div),
    (
        "eq",
        (lambda a, b: te.Cast.make(te_bool, a == b)),
        torch.eq,
        gen_int_comparison_tensors,
    ),
    (
        "gt",
        (lambda a, b: te.Cast.make(te_bool, a > b)),
        torch.gt,
        gen_float_comparison_tensors,
    ),
    (
        "lt",
        (lambda a, b: te.Cast.make(te_bool, a < b)),
        torch.lt,
        gen_float_comparison_tensors,
    ),
    (
        "gte",
        (lambda a, b: te.Cast.make(te_bool, a >= b)),
        torch.greater_equal,
        gen_float_comparison_tensors,
    ),
    (
        "lte",
        (lambda a, b: te.Cast.make(te_bool, a <= b)),
        torch.less_equal,
        gen_float_comparison_tensors,
    ),
    # ('neq', (lambda a, b: a != b), None)), # no one-op equivalent
    # ('&', (lambda a, b: a & b), torch.bitwise_and), # requires more work to test
]


def nnc_relu(A, B):
    """Return ``compute(i, j)`` yielding ``0`` where ``A[i, j] < 0`` else ``A[i, j]``.

    ``B`` is unused; it is accepted to match the common ``(A, B)`` signature.
    """

    def f(i, j):
        return te.ifThenElse(
            A.load([i, j]) < te.ExprHandle.float(0),
            te.ExprHandle.float(0),
            A.load([i, j]),
        )

    return f


def pt_relu(a, b, c):
    """Reference implementation for ``nnc_relu``: ``torch.relu(a)``.

    ``b`` and ``c`` are ignored; the result is returned as a new tensor rather
    than written into ``c``.
    """
    return torch.relu(a)


custom_ops = [
    ("relu", nnc_relu, pt_relu),
    # ('nnc_mul_relu', nnc_mul_relu, pt_mul_relu)
    # ('manual_sigmoid', nnc_manual_sigmoid, lambda a, b, c: torch.sigmoid(a, out=c))
]


def gen_custom_torch_fun(fn):
    """Return a factory ``(a, b, out) -> fun()`` that runs ``fn(a, b, out)``."""

    def pt_fun(a, b, out):
        def fun():
            return fn(a, b, out)

        return fun

    return pt_fun


def normalize_benchmarks(ops):
    """Pad every 3-tuple in ``ops`` with a trailing ``None`` (no input generator)."""
    return [i + (None,) if len(i) == 3 else i for i in ops]


# Parallel lists describing every benchmark; zipped into ``benchmarks`` below.
names = []
nnc_fns = []
pt_fns = []
shape_fns = []

for nnc_name, pt_op in unary_ops:
    names.append(nnc_name)
    nnc_fns.append(gen_unary_nnc_fun(nnc_name))
    pt_fns.append(gen_unary_torch_fun(pt_op))
    shape_fns.append(None)

for name, lmbda, pt_fn, shape_fn in normalize_benchmarks(binary_ops):
    names.append(name)
    nnc_fns.append(gen_binary_nnc_fun(lmbda))
    pt_fns.append(gen_binary_torch_fun(pt_fn))
    shape_fns.append(shape_fn)

for name, lmbda, pt_fn, shape_fn in normalize_benchmarks(custom_ops):
    names.append(name)
    # Custom NNC functions already have the (A, B) -> compute shape.
    nnc_fns.append(lmbda)
    pt_fns.append(gen_custom_torch_fun(pt_fn))
    shape_fns.append(shape_fn)

benchmarks = list(zip(names, nnc_fns, pt_fns, shape_fns))

# Column order of the frame returned by ``run_benchmarks``.
_RESULT_COLUMNS = ["name", "N", "M", "nnc_time", "torch_time", "ratio"]


def _nnc_dtype(dtype):
    """Map a torch input dtype to the matching ``te.Dtype``.

    Raises:
        ValueError: if ``dtype`` is neither ``torch.float`` nor ``torch.long``.
    """
    if dtype == torch.float:
        return te.Dtype.Float
    if dtype == torch.long:
        return te.Dtype.Long
    raise ValueError(f"unsupported input dtype for NNC benchmark: {dtype}")


def run_benchmarks(benchmarks, sizes):
    """Time every benchmark at every size and return the results as a DataFrame.

    Args:
        benchmarks: iterable of ``(name, nnc_fun, torch_fun, shape_fn)`` tuples.
            ``shape_fn`` is either ``None`` (random float inputs clamped to
            ``[0.01, 0.99]`` are used) or a callable returning
            ``(input_a, input_b, output)`` tensors.
        sizes: iterable of ``(N, M)`` pairs.

    Returns:
        A DataFrame with columns ``name, N, M, nnc_time, torch_time, ratio``,
        where ``ratio`` is ``torch_time / nnc_time``.

    Raises:
        AssertionError: if the NNC and torch outputs are not close.
        ValueError: if the generated inputs have an unsupported dtype.
    """
    # Rows are collected in a list and turned into a frame once at the end;
    # DataFrame.append no longer exists in pandas >= 2.0.
    rows = []
    with torch.no_grad():
        for name, nnc_fun, torch_fun, shape_fn in benchmarks:
            for N, M in sizes:
                # Fewer iterations for larger inputs.
                iters = int(1e6 / (N + M))
                with kernel_arena_scope():
                    if shape_fn is None:
                        tA = torch.rand(M, N).clamp(0.01, 0.99)
                        tB = torch.rand(M, N).clamp(0.01, 0.99)
                        tX = torch.empty(M, N)
                        tR = torch.empty(M, N)
                    else:
                        tA, tB, tX = shape_fn(M, N)
                        tR = tX.clone()

                    dtype = _nnc_dtype(tA.dtype)

                    dM = te.ExprHandle.int(M)
                    dN = te.ExprHandle.int(N)

                    A = te.Placeholder("A", dtype, [dM, dN])
                    B = te.Placeholder("B", dtype, [dM, dN])

                    dim_args = [te.DimArg(*args) for args in [(dM, "m"), (dN, "n")]]

                    compute = nnc_fun(A, B)
                    X = te.Compute("X", dim_args, compute)
                    loopnest = te.LoopNest([X])
                    loopnest.prepare_for_codegen()
                    stmt = te.simplify(loopnest.root_stmt())
                    cg = te.construct_codegen(
                        "llvm", stmt, [te.BufferArg(x) for x in [A, B, X]]
                    )

                    # warmup
                    for _ in range(10):
                        cg.call([tA, tB, tX])
                    # perf_counter is the monotonic, high-resolution clock
                    # intended for measuring short intervals.
                    start = time.perf_counter()
                    for _ in range(iters):
                        cg.call([tA, tB, tX])
                    time1 = time.perf_counter() - start

                    fn = torch_fun(tA, tB, tR)
                    # warmup
                    for _ in range(10):
                        tR = fn()
                    start = time.perf_counter()
                    for _ in range(iters):
                        tR = fn()
                    time2 = time.perf_counter() - start

                    rows.append(
                        {
                            "name": name,
                            "N": N,
                            "M": M,
                            "nnc_time": time1,
                            "torch_time": time2,
                            "ratio": time2 / time1,
                        }
                    )
                    print(name, N, M)

                    print(time2 / time1, time1, time2)
                    print()

                    # Raised explicitly so the check also runs under ``python -O``.
                    if not np.allclose(tX, tR):
                        print(name)
                        raise AssertionError(
                            f"NNC and PyTorch results differ for {name} at N={N}, M={M}"
                        )
    return pd.DataFrame(rows, columns=_RESULT_COLUMNS)


def dump_plot(df, sizes):
    """Save a heatmap of ``ratio`` for the square (N == M) runs to ``nnc.png``.

    Rows of ``df`` are expected to be grouped by benchmark with one row per
    entry of ``sizes``, in the order produced by ``run_benchmarks``.
    """
    indexed = df[df["N"] == df["M"]]
    vals = indexed["ratio"].tolist()
    # One label per benchmark: every len(sizes)-th row starts a new benchmark.
    keys = indexed["name"].tolist()[:: len(sizes)]

    sns.set(rc={"figure.figsize": (5.0, len(keys) * 0.5)})

    cmap = sns.diverging_palette(10, 120, n=9, as_cmap=True)
    np_vals = np.array([vals]).reshape(-1, len(sizes))
    g = sns.heatmap(np_vals, annot=True, cmap=cmap, center=1.0, yticklabels=True)
    plt.yticks(rotation=0)
    plt.title("PyTorch performance divided by NNC performance (single core)")
    plt.xlabel("Size of NxN matrix")
    plt.ylabel("Operation")
    g.set_yticklabels(keys)
    g.set_xticklabels(sizes)

    plt.savefig("nnc.png")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Runs NNC microbenchmarks")
    parser.add_argument(
        "--multi-threaded",
        "--multi_threaded",
        action="store_true",
        help="Run with more than one thread",
    )
    args = parser.parse_args()
    if not args.multi_threaded:
        torch.set_num_threads(1)

    sizes = [1, 4, 16, 64, 256, 1024]
    df = run_benchmarks(benchmarks, [(i, i) for i in sizes])
    dump_plot(df, sizes)