# Microbenchmarks comparing eager execution with three TorchScript fusers
# (NNC/TensorExpr, nvFuser, and the legacy fuser) on pointwise operators
# across a grid of input shapes. Output is CSV:
# fuser,device,operator,shape,time (microseconds).
import inspect
import itertools
import sys
import time

import click

import torch


# Single-threaded CPU execution and no fusion-group inlining, so the fused
# kernels themselves are what get measured.
torch.set_num_threads(1)
torch._C._debug_set_fusion_group_inlining(False)


def rand(*shape):
    # Uniform values in [1, 17) keep log() and division safely away from zero.
    return torch.rand(*shape).mul(16).add(1)


# ------------------------------------------------------------------------------
# Shape test cases
# ------------------------------------------------------------------------------
def scalar():
    return (rand(1), rand(1))


def small():
    return (rand(32), rand(32))


def small_2d():
    return (rand(1, 32), rand(1, 32))


def small_broadcast():
    return (rand(4, 32), rand(32))


def medium():
    return (rand(32, 12, 64, 64), rand(32, 12, 64, 64))


def medium_sliced():
    # Non-contiguous inputs: every other element along the last dimension.
    return (rand(32, 12, 64, 64)[..., ::2], rand(32, 12, 64, 64)[..., ::2])


def medium_transpose():
    # Non-contiguous inputs via transposed strides.
    return (
        rand(32, 12, 64, 64).transpose(-1, -2),
        rand(32, 12, 64, 64).transpose(-1, -2),
    )


def medium2():
    return (rand(32, 3, 224, 224), rand(32, 3, 224, 224))


def medium3d():
    return (rand(16, 32, 64), rand(16, 32, 64))


def medium_channels_last():
    return (
        rand(32, 3, 224, 224).to(memory_format=torch.channels_last),
        rand(32, 3, 224, 224).to(memory_format=torch.channels_last),
    )


def medium_broadcast():
    return (rand(32, 12, 64, 64), rand(64))


def medium_broadcast_channels_last():
    return (rand(32, 3, 223, 223).to(memory_format=torch.channels_last), rand(3, 1, 1))


def large():
    return (rand(8192, 8192), rand(8192, 8192))


def large_transpose():
    return (rand(8192, 8192).transpose(0, 1), rand(8192, 8192).transpose(0, 1))


def large_channels_last():
    return (
        rand(32, 32, 256, 256).to(memory_format=torch.channels_last),
        rand(32, 32, 256, 256).to(memory_format=torch.channels_last),
    )


# Regression shapes; the numeric suffixes reference upstream PyTorch issues.
def broadcast_narrow_57611():
    return (rand(1, 32, 32, 2), rand(1024, 1, 1, 2))


def large_broadcast_66816():
    return (rand(64, 8, 256, 162), rand(256, 162))


# ------------------------------------------------------------------------------
# Operator test cases
# ------------------------------------------------------------------------------
# A scalar multiply is folded into most operators so every test presents a
# multi-op graph for the fuser to capture.
def add(a, b):
    return 3 * a + b


def sub(a, b):
    return 3 * a - b


def mul(a, b):
    return 3 * a * b


def div(a, b):
    return 3 * a / b


def relu(a):
    return (3 * a).relu()


def sigmoid(a):
    return (3 * a).sigmoid()


def tanh(a):
    return (3 * a).tanh()


def log(a):
    return (3 * a).log()


def exp(a):
    return (3 * a).exp()


def square(a):
    return (3 * a) ** 2


def fma(a, b):
    return a * b + b


def mul_mul_add_66816(a, b, c):
    return (a * b) + (a * c)


def hardswish_int(a):
    # clamp with int bounds
    return a * (a + 3).clamp(0, 6) / 6


def hardswish(a):
    # clamp with float bounds
    return a * (a + 3).clamp(0.0, 6.0) / 6


def native_hardswish(a):
    # ATen's fused hardswish kernel, for comparison with the composed versions.
    return torch._C._nn.hardswish(a * 3)


def softplus(a):
    return (a * 1.0).exp().log1p() / 1.0


def mish(a):
    return a * ((a * 1.0).exp().log1p() / 1.0).tanh()


SHAPES = [
    scalar,
    small,
    small_2d,
    small_broadcast,
    medium,
    medium2,
    medium3d,
    medium_sliced,
    medium_transpose,
    medium_channels_last,
    medium_broadcast,
    medium_broadcast_channels_last,
    large,
    large_transpose,
    large_channels_last,
    broadcast_narrow_57611,
    large_broadcast_66816,
]

OPERATORS = [
    add,
    sub,
    mul,
    div,
    relu,
    sigmoid,
    tanh,
    log,
    exp,
    square,
    fma,
    mul_mul_add_66816,
    hardswish_int,
    hardswish,
    native_hardswish,
    softplus,
    mish,
]


def time_cpu(fn, args, iters):
    # Wall-clock seconds for `iters` calls on CPU.
    s = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    e = time.perf_counter()
    return e - s


def time_cuda(fn, args, iters):
    # CUDA-event timing; elapsed_time() is in milliseconds, so divide by 1e3
    # to return seconds, matching time_cpu.
    torch.cuda.synchronize()  # don't attribute previously queued work to the timed region
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 1e3


def benchmark_with_timer(fn, args, timer):
    # Warm up (and let the profiling executor specialize), then pick an
    # iteration count targeting ~1 second of total runtime.
    timer(fn, args, 3)
    calibration = timer(fn, args, 1)
    # max() guards against iters == 0 when one iteration already takes >= 1s.
    iters = max(int(1.0 / calibration), 1)
    return timer(fn, args, iters) / iters


def benchmark(fn, args):
    timer = time_cpu if args[0].device.type == "cpu" else time_cuda
    return benchmark_with_timer(fn, args, timer)


def micros(s):
    return f"{s * 1e6:.1f}"


def with_nvfuser():
    # Route fusion to nvFuser under the profiling executor.
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(True)
    torch._C._jit_set_profiling_executor(True)
    torch._C._jit_set_profiling_mode(True)


def with_nnc():
    # Route fusion to the TensorExpr (NNC) fuser under the profiling executor.
    torch._C._jit_override_can_fuse_on_cpu(True)
    torch._C._jit_override_can_fuse_on_gpu(True)
    torch._C._jit_set_texpr_fuser_enabled(True)
    torch._C._jit_set_nvfuser_enabled(False)
    torch._C._jit_set_profiling_executor(True)
    torch._C._jit_set_profiling_mode(True)


def with_legacy():
    # Fall back to the legacy fuser under the non-profiling executor.
    torch._C._jit_override_can_fuse_on_cpu(True)
    torch._C._jit_override_can_fuse_on_gpu(True)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
    torch._C._jit_set_profiling_executor(False)
    torch._C._jit_set_profiling_mode(False)


@click.command()
@click.option("--operators", default=None)
@click.option("--shapes", default=None)
def run_benchmarks(operators, shapes):
    # Comma-separated names select a subset; the default is the full grid.
    if operators is None:
        operators = OPERATORS
    else:
        operators = [globals()[k] for k in operators.split(",")]
    if shapes is None:
        shapes = SHAPES
    else:
        shapes = [globals()[k] for k in shapes.split(",")]

    print("fuser,device,operator,shape,time")
    for shape, operator in itertools.product(shapes, operators):
        nargs = len(inspect.signature(operator).parameters)
        args = shape()
        # Match the shape's tensors to the operator's arity: repeat the last
        # tensor if the operator takes more arguments, truncate if fewer.
        if nargs > len(args):
            args = list(args)
            args += [args[-1]] * (nargs - len(args))
        args = args[:nargs]
        # This script benchmarks on CUDA.
        args = [arg.to("cuda") for arg in args]

        result = benchmark(operator, args)
        print(
            ",".join(
                [
                    "eager",
                    args[0].device.type,
                    operator.__name__,
                    shape.__name__,
                    micros(result),
                ]
            )
        )

        def bench(name):
            # Re-trace under the currently selected fuser configuration so its
            # fusion passes run on a fresh graph.
            jit_op = torch.jit.trace(operator, args)
            result = benchmark(jit_op, args)
            print(
                ",".join(
                    [
                        name,
                        args[0].device.type,
                        operator.__name__,
                        shape.__name__,
                        micros(result),
                    ]
                )
            )
            sys.stdout.flush()

        with_nnc()
        bench("nnc")
        with_nvfuser()
        bench("nvfuser")
        with_legacy()
        bench("legacy")


if __name__ == "__main__":
    run_benchmarks()
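
# Example invocation (the script filename here is illustrative; requires a
# CUDA-enabled PyTorch build and the `click` package):
#
#   python run_benchmarks.py --operators add,mul --shapes small,medium
#
# --operators and --shapes take comma-separated names of the functions defined
# above; omitting both runs the full OPERATORS x SHAPES grid.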