"""
Usage:
1. Run your script and redirect its output into a log file:
    PYTORCH_JIT_LOG_LEVEL=">>graph_fuser" python3 my_test.py &> log.txt
2. Run log_extract:
    log_extract.py log.txt --nvfuser --nnc-dynamic --nnc-static

You can also print the list of extracted IR graphs:
    log_extract.py log.txt --output

Passing --graphs 0 2 will only run graphs 0 and 2.
"""

import argparse
import functools
import traceback
from typing import Callable, List, Optional, Tuple

from torch.utils.jit.log_extract import (
    extract_ir,
    load_graph_and_inputs,
    run_baseline_no_fusion,
    run_nnc,
    run_nvfuser,
)


def test_runners(
    graphs: List[str],
    runners: List[Tuple[str, Callable]],
    graph_set: Optional[List[int]],
):
    """Benchmark each selected graph with each runner and print the timings.

    Args:
        graphs: IR strings, as returned by ``extract_ir``.
        runners: ``(name, fn)`` pairs. Each ``fn(ir, inputs)`` is expected to
            return a time in milliseconds (it is printed with an ``ms`` suffix).
        graph_set: Indices of the graphs to run; ``None`` or empty runs all.

    Every runner after the first successful one is also reported as a percentage
    improvement over the most recent runner that succeeded. A runner raising
    ``RuntimeError`` is reported with its traceback and does not become the
    comparison point for the next runner.
    """
    # Build the membership set once instead of scanning the list per graph.
    selected = set(graph_set) if graph_set else None
    for i, ir in enumerate(graphs):
        # Filter BEFORE loading: loading the graph and its inputs is wasted
        # work (and may raise) for graphs the user did not select.
        if selected is not None and i not in selected:
            continue
        _, inputs = load_graph_and_inputs(ir)

        print(f"Running Graph {i}")
        prev_result = None
        prev_runner_name = None
        for runner_name, runner_fn in runners:
            try:
                result = runner_fn(ir, inputs)
            except RuntimeError:
                print(f" Graph {i} failed for {runner_name} :", traceback.format_exc())
                continue
            # Compare against `None` explicitly so a 0.0 ms previous result is
            # still compared; skip the ratio when `result` is 0 to avoid a
            # ZeroDivisionError.
            if prev_result is not None and result != 0:
                improvement = (prev_result / result - 1) * 100
                print(
                    f"{runner_name} : {result:.6f} ms improvement over {prev_runner_name}: improvement: {improvement:.2f}%"
                )
            else:
                print(f"{runner_name} : {result:.6f} ms")
            prev_result = result
            prev_runner_name = runner_name


# Despite its ``test_`` prefix this is not a pytest test; keep pytest from
# collecting it (and failing on missing fixtures) if it is imported into a
# test module.
test_runners.__test__ = False


def _add_toggle(parser, flag, dest, enable_help, disable_help):
    """Register a ``--<flag>`` / ``--no-<flag>`` pair writing to ``dest`` (default False)."""
    parser.add_argument(f"--{flag}", dest=dest, action="store_true", help=enable_help)
    parser.add_argument(
        f"--no-{flag}", dest=dest, action="store_false", help=disable_help
    )
    parser.set_defaults(**{dest: False})


def run():
    """Command-line entry point.

    Extracts IR graphs from the given log file, benchmarks the selected graphs
    with every requested runner (only when at least one runner is requested),
    and, with ``--output``, prints the selected graphs as a Python list of
    triple-quoted strings.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Extracts torchscript IR from log files and, optionally, "
            "benchmarks it or outputs the IR"
        ),
        # Show the module's usage notes at the end of --help, unwrapped.
        epilog=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("filename", help="Filename of log file")
    _add_toggle(parser, "nvfuser", "nvfuser", "benchmark nvfuser", "DON'T benchmark nvfuser")
    _add_toggle(
        parser,
        "nnc-static",
        "nnc_static",
        "benchmark nnc static",
        "DON'T benchmark nnc static",
    )
    _add_toggle(
        parser,
        "nnc-dynamic",
        "nnc_dynamic",
        "nnc with dynamic shapes",
        "DON'T benchmark nnc with dynamic shapes",
    )
    _add_toggle(parser, "baseline", "baseline", "benchmark baseline", "DON'T benchmark baseline")
    _add_toggle(parser, "output", "output", "Output graph IR", "DON'T output graph IR")
    parser.add_argument(
        "--graphs", nargs="+", type=int, help="Run only specified graph indices"
    )

    args = parser.parse_args()
    graphs = extract_ir(args.filename)

    # A missing --graphs option means "all graphs".
    graph_set = args.graphs or None

    options = []
    if args.baseline:
        options.append(("Baseline no fusion", run_baseline_no_fusion))
    if args.nnc_dynamic:
        options.append(("NNC Dynamic", functools.partial(run_nnc, dynamic=True)))
    if args.nnc_static:
        options.append(("NNC Static", functools.partial(run_nnc, dynamic=False)))
    if args.nvfuser:
        options.append(("NVFuser", run_nvfuser))

    # Without any runner there is nothing to benchmark; don't load every graph
    # just to print "Running Graph i" headers.
    if options:
        test_runners(graphs, options, graph_set)

    if args.output:
        quoted = [
            '"""' + ir + '"""'
            for i, ir in enumerate(graphs)
            if graph_set is None or i in graph_set
        ]
        print("[" + ", ".join(quoted) + "]")


if __name__ == "__main__":
    run()