# mypy: allow-untyped-defs
import random
import time
from contextlib import contextmanager
from typing import Any, List, Tuple, cast

import torch
from torch.utils.benchmark import Timer


def _marker_line_prefix(preceding: str) -> str:
    """Return the text that shares a line with a marker, given the text before it.

    ``preceding`` is everything between the previous marker (or the start of
    the file) and the marker of interest.  The prefix is whatever follows the
    last line break in that text; it is empty when there is no such text or
    when the marker starts a fresh line.
    """
    pieces = preceding.splitlines(keepends=True)
    if not pieces:
        # Nothing at all precedes the marker (e.g. the file starts with it).
        return ""
    last = pieces[-1]
    # A piece that still carries its line terminator is a complete line, so
    # the marker itself sits at column 0 and has no prefix.
    if last.splitlines()[0] != last:
        return ""
    return last


def extract_ir(filename: str) -> List[str]:
    """Collect every graph dumped between GRAPH_EXPORT markers in a log file.

    Each region delimited by ``<GRAPH_EXPORT>`` and ``</GRAPH_EXPORT>`` is
    returned as one string.  Whatever text precedes the opening marker on its
    own line is treated as a per-line prefix, and that many leading characters
    are removed from every line of the region.

    Args:
        filename: Path of the log file to read.

    Returns:
        One string per exported graph, in file order.  Regions that have an
        opening marker but no closing marker are skipped.
    """
    BEGIN = "<GRAPH_EXPORT>"
    END = "</GRAPH_EXPORT>"
    with open(filename) as f:
        segments = f.read().split(BEGIN)

    graphs = []
    # segments[0] is the text before the first marker; every later segment
    # starts right after a BEGIN marker and is paired with the text before it.
    for preceding, segment in zip(segments, segments[1:]):
        end_loc = segment.find(END)
        if end_loc == -1:
            continue
        pfx_len = len(_marker_line_prefix(preceding))
        body = segment[:end_loc]
        graphs.append(
            "".join(line[pfx_len:] for line in body.splitlines(keepends=True))
        )

    return graphs


def make_tensor_from_type(inp_type: torch._C.TensorType):
    """Allocate an uninitialized tensor matching a fully specified tensor type.

    The sizes, strides, device and dtype are read from ``inp_type``; each of
    them must be present (not ``None``).
    """
    size = inp_type.sizes()
    stride = inp_type.strides()
    device = inp_type.device()
    dtype = inp_type.dtype()
    assert size is not None
    assert stride is not None
    assert device is not None
    assert dtype is not None
    return torch.empty_strided(size=size, stride=stride, device=device, dtype=dtype)


def load_graph_and_inputs(ir: str) -> Tuple[Any, List[Any]]:
    """Parse an IR string into a callable function plus generated inputs.

    A value is generated for every graph input based on its type: a random
    float in [0.1, 100] for floats, a random int in [1, 100] for ints, a
    random bool for bools, and an uninitialized tensor built by
    ``make_tensor_from_type`` for tensors.

    Args:
        ir: Textual graph IR, e.g. one element returned by ``extract_ir``.

    Returns:
        A ``(function, inputs)`` tuple.

    Raises:
        NotImplementedError: If a graph input has any other type.
    """
    graph = torch._C.parse_ir(ir, parse_tensor_constants=True)
    graph.makeMultiOutputIntoTuple()
    inputs = []
    for inp in graph.inputs():
        inp_type = inp.type()
        if isinstance(inp_type, torch._C.FloatType):
            inputs.append(random.uniform(.1, 100))
        elif isinstance(inp_type, torch._C.IntType):
            inputs.append(random.randint(1, 100))
        elif isinstance(inp_type, torch._C.TensorType):
            tensor_type = cast(torch._C.TensorType, inp_type)
            inputs.append(make_tensor_from_type(tensor_type))
        elif isinstance(inp_type, torch._C.BoolType):
            inputs.append(random.randint(0, 1) == 1)
        else:
            raise NotImplementedError(f"A default value is not implemented for type {inp_type}")

    func = torch._C._create_function_from_graph("forward", graph)
    # The inputs were built from the recorded shapes above; the shape
    # annotations are then erased from the function's graph.
    torch._C._jit_pass_erase_shape_information(func.graph)
    return (func, inputs)


def time_cuda(fn, inputs, test_runs):
    """Time ``fn(*inputs)`` with the benchmark ``Timer`` and return milliseconds.

    ``test_runs`` is accepted so the signature matches ``time_cpu`` but is not
    used: the number of runs is chosen by ``Timer.blocked_autorange``.
    """
    t = Timer(stmt="fn(*inputs)", globals={"fn": fn, "inputs": inputs})
    times = t.blocked_autorange()
    return times.median * 1000  # time in ms


def time_cpu(fn, inputs, test_runs):
    """Return the mean wall-clock time of ``fn(*inputs)`` over ``test_runs`` calls, in ms."""
    s = time.perf_counter()
    for _ in range(test_runs):
        fn(*inputs)
    e = time.perf_counter()
    return (e - s) / test_runs * 1000  # time in ms


def run_test(ir, inputs, *, warmup_runs=10, test_runs=20) -> float:
    """Build a function from ``ir``, warm it up, and time it on ``inputs``.

    The device of the first tensor in ``inputs`` selects the timing routine:
    ``time_cpu`` for CPU tensors, ``time_cuda`` otherwise.

    Args:
        ir: Textual graph IR.
        inputs: Arguments passed to the function; must contain a tensor.
        warmup_runs: Number of untimed calls made before measuring.
        test_runs: Number of timed calls (used by ``time_cpu`` only).

    Returns:
        The measured time in milliseconds.
    """
    # Only the function is needed here; the caller supplies the inputs.
    graph, _ = load_graph_and_inputs(ir)
    for _ in range(warmup_runs):
        graph(*inputs)

    is_cpu = None
    for value in inputs:
        if isinstance(value, torch.Tensor):
            is_cpu = value.device.type == "cpu"
            break
    assert is_cpu is not None, "inputs must contain at least one tensor"

    out = time_cpu(graph, inputs, test_runs) if is_cpu else time_cuda(graph, inputs, test_runs)
    return out


@contextmanager
def no_fuser(*args, **kwargs):
    """Context manager that turns graph-executor optimization off for its body.

    The previous setting is restored on exit, including when the body raises.
    Positional and keyword arguments are accepted and ignored.
    """
    # The call installs the new setting and hands back the one it replaced.
    old_optimize = torch._C._get_graph_executor_optimize(False)
    try:
        yield
    finally:
        torch._C._get_graph_executor_optimize(old_optimize)


def run_baseline_no_fusion(ir, inputs) -> float:
    """Time ``ir`` on ``inputs`` under ``no_fuser``; returns milliseconds."""
    with no_fuser():
        return run_test(ir, inputs)


def run_nnc(ir, inputs, dynamic) -> float:
    """Time ``ir`` on ``inputs`` with the "fuser1" backend; returns milliseconds.

    Args:
        ir: Textual graph IR.
        inputs: Arguments passed to the function.
        dynamic: Use the ``("DYNAMIC", 10)`` fusion strategy when true,
            ``("STATIC", 10)`` otherwise.

    The fusion strategy that was active on entry is restored on exit.
    """
    strat = [("DYNAMIC", 10)] if dynamic else [("STATIC", 10)]
    # Set the strategy before entering the try block: if this call fails
    # there is nothing to restore, and the cleanup must not hit an unbound
    # ``old_strat`` and hide the real error.
    old_strat = torch.jit.set_fusion_strategy(strat)
    try:
        with torch.jit.fuser("fuser1"):
            return run_test(ir, inputs)
    finally:
        torch.jit.set_fusion_strategy(old_strat)


def run_nvfuser(ir, inputs) -> float:
    """Time ``ir`` on ``inputs`` with the "fuser2" backend; returns milliseconds."""
    with torch.jit.fuser("fuser2"):
        return run_test(ir, inputs)