#!/usr/bin/env python3
import argparse
import inspect
import sys

import numpy as np
import tabulate
import torch
import torch._inductor
from torch._dynamo.backends.cudagraphs import cudagraphs_inner
from torch._dynamo.testing import same
from torch._inductor.compile_fx import compile_fx
from torch._inductor.utils import timed

aten = torch.ops.aten

try:
    import test.test_torchinductor as tti
except ImportError:
    tti = None


def compute_speedups(args, models, example_inputs):
    # models[0] is the baseline; check the others against it before timing
    expected = models[0](*example_inputs)
    for model in models[1:]:
        actual = model(*example_inputs)
        assert same(actual, expected), expected[0] - actual[0]

    timings = np.zeros((args.repeat, len(models)), np.float64)
    for rep in range(args.repeat):
        # interleave the runs to handle frequency scaling and load changes
        for m, model in enumerate(models):
            timings[rep, m] = timed(model, example_inputs)
    median = np.median(timings, axis=0)
    # speedup of each model relative to models[0] (the eager baseline)
    return (median[0] / median[1:]).tolist()


def microbenchmark(args, model, example_inputs):
    compiled_fn = compile_fx(torch.fx.symbolic_trace(model), example_inputs)
    cudagraphs_eager = cudagraphs_inner(model, example_inputs, copy_outputs=False)
    cudagraphs_jit = cudagraphs_inner(
        torch.jit.trace(model, example_inputs), example_inputs, copy_outputs=False
    )
    return compute_speedups(
        args,
        [cudagraphs_eager, cudagraphs_jit, compiled_fn],
        example_inputs,
    )


class MyModel1(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(1024, 1024),
            torch.nn.ReLU(),
        )

    def forward(self, input):
        # return (self.model(input) + 1,)
        return (self.model(input),)


class MyModel2(torch.nn.Module):
    def forward(self, x, y):
        # return x / (torch.abs(x) + 1.0),
        return (x + y,)


class MicroBenchmarks:
    @staticmethod
    def add(a, b):
        return (a + b,)

    @staticmethod
    def scale(x, m, d):
        return ((x - m) / torch.clip(d, 1e-4),)

    @staticmethod
    def abs_norm(x):
        return (x / (torch.abs(x) + 1),)

    @staticmethod
    def add_relu_softmax(x, a):
        return (torch.softmax(torch.relu(x + a), -1),)

    @staticmethod
    def sum(a, b):
        return ((a + b).sum(),)

    @staticmethod
    def view(x):
        return (aten.alias(x),)
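# Illustrative note (not part of the harness above): every benchmark returns a
# 1-tuple so that compute_speedups() can compare outputs with same() and the
# FX/jit traces see a consistent output structure. A hypothetical standalone
# sanity check, assuming nothing beyond the definitions above, might look like:
#
#     x, y = torch.rand(8, 8), torch.rand(8, 8)
#     (out,) = MicroBenchmarks.add(x, y)
#     assert torch.allclose(out, x + y)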
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--filter", "-k", action="append", help="filter benchmarks with regexp"
    )
    parser.add_argument(
        "--exclude", "-x", action="append", help="filter benchmarks with regexp"
    )
    parser.add_argument("--devices", "-d", action="append", help="cpu or cuda")
    parser.add_argument(
        "--size", "-s", action="append", help="size(s) to benchmark (n for an n x n input)"
    )
    parser.add_argument(
        "--repeat", "-n", type=int, default=30, help="number of timing runs"
    )
    parser.add_argument(
        "--threads", "-t", type=int, help="number of threads to use for eager"
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true", help="enable verbose debug printouts"
    )
    parser.add_argument(
        "--nvfuser", action="store_true", help="enable nvfuser globally"
    )
    parser.add_argument("--transpose", action="store_true", help="transpose one input")
    parser.add_argument("--broadcast", action="store_true", help="broadcast one input")
    args = parser.parse_args()

    # defaults
    args.devices = args.devices or ["cpu", "cuda"]
    args.filter = args.filter or [r"."]
    args.exclude = args.exclude or [r"^$"]
    args.size = args.size or [64, 256, 1024, 4096, 8192]

    if args.nvfuser:
        # route TorchScript fusion through nvFuser instead of the TensorExpr fuser
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
    else:
        # default: TensorExpr fuser (CPU fusion only when LLVM is available)
        torch._C._jit_override_can_fuse_on_cpu(torch._C._llvm_enabled())
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(True)
        if torch.cuda.is_available():
            torch._C._jit_set_nvfuser_enabled(False)

    if args.threads:
        torch.set_num_threads(args.threads)
        torch._inductor.config.cpp.threads = args.threads

    if args.verbose:
        torch._inductor.config.debug = True

    torch._inductor.config.triton.autotune_pointwise = True

    rows = []
    for model in (MicroBenchmarks.sum, MicroBenchmarks.view):
        nargs = len(inspect.signature(model).parameters)
        for device in args.devices:
            for n in args.size:
                n = int(n)
                sys.stdout.write(f"{model.__name__:10} {device:4} {n:5} ")
                sys.stdout.flush()
                inputs = [torch.rand((n, n), device=device) for _ in range(nargs)]
                if args.broadcast:
                    inputs[-1] = torch.rand((1, n), device=device)
                if args.transpose:
                    inputs[-1] = inputs[-1].transpose(0, 1)
                result = microbenchmark(args, model, inputs)
                rows.append([model.__name__, device, str(n)] + result)
                print(" ".join(f"{v:.2f}x" for v in result))

    print(
        tabulate.tabulate(
            rows,
            headers=[
                "model",
                "dev",
                "n",
                "ts",
                "inductor",
            ],
        )
    )


if __name__ == "__main__":
    main()
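# Example invocations (illustrative; the file name "microbench.py" is an
# assumption, the flags match the argparse definitions above; --size appends,
# so pass it once per size):
#
#     python microbench.py --devices cuda --size 1024 --size 4096 --repeat 50
#     python microbench.py -d cpu -t 8 --broadcast --verbose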