xref: /aosp_15_r20/external/pytorch/scripts/jit/log_extract.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import argparse
2import functools
3import traceback
4from typing import Callable, List, Optional, Tuple
5
6from torch.utils.jit.log_extract import (
7    extract_ir,
8    load_graph_and_inputs,
9    run_baseline_no_fusion,
10    run_nnc,
11    run_nvfuser,
12)
13
14
15"""
Usage:
1. Run your script with JIT graph-fuser logging enabled and redirect its output into a log file:
  PYTORCH_JIT_LOG_LEVEL=">>graph_fuser" python3 my_test.py &> log.txt
2. Run log_extract on that log, choosing which backends to benchmark:
  log_extract.py log.txt --nvfuser --nnc-dynamic --nnc-static

You can also print the list of extracted IR graphs:
  log_extract.py log.txt --output

Passing --graphs 0 2 restricts both benchmarking and --output to graphs 0 and 2.
26"""
27
28
def test_runners(
    graphs: List[str],
    runners: List[Tuple[str, Callable]],
    graph_set: Optional[List[int]],
):
    """Run every selected graph through each runner and print the timings.

    Args:
        graphs: IR strings; a graph's index in this list is its "graph number".
        runners: ``(name, fn)`` pairs. Each ``fn`` is called as ``fn(ir, inputs)``
            and its return value is printed as a time in ms. Runners are run in
            order; every successful result after the first is compared against
            the previous *successful* runner's result.
        graph_set: indices of the graphs to run. ``None`` or an empty list means
            "run every graph".

    A ``RuntimeError`` raised by a runner is reported together with its
    traceback. It does not stop the remaining runners or graphs.
    """
    # Build the selection once so that per-graph membership tests are O(1).
    selected = set(graph_set) if graph_set else None

    for i, ir in enumerate(graphs):
        # Filter *before* loading. Loading parses the IR and builds inputs, which
        # is wasted work (and a possible failure) for graphs that were not requested.
        if selected is not None and i not in selected:
            continue

        _, inputs = load_graph_and_inputs(ir)

        print(f"Running Graph {i}")
        prev_result: Optional[float] = None
        prev_runner_name: Optional[str] = None
        for runner_name, runner_fn in runners:
            try:
                result = runner_fn(ir, inputs)
            except RuntimeError:
                print(f"  Graph {i} failed for {runner_name} :", traceback.format_exc())
                continue

            # Use an explicit None check: a legitimate 0.0 timing must still be
            # treated as a previous result.
            if prev_result is not None:
                # Positive means this runner is faster than the previous one.
                improvement = (prev_result / result - 1) * 100
                print(
                    f"{runner_name} : {result:.6f} ms improvement over {prev_runner_name}: improvement: {improvement:.2f}%"
                )
            else:
                print(f"{runner_name} : {result:.6f} ms")
            prev_result = result
            prev_runner_name = runner_name
57
58
def _add_toggle(
    parser: argparse.ArgumentParser,
    flag: str,
    on_help: str,
    off_help: str,
) -> None:
    """Register a ``--<flag>`` / ``--no-<flag>`` boolean pair that defaults to False.

    Both options write to the same destination (``flag`` with dashes turned
    into underscores), so the last one given on the command line wins.
    """
    dest = flag.replace("-", "_")
    parser.add_argument(f"--{flag}", dest=dest, action="store_true", help=on_help)
    parser.add_argument(
        f"--no-{flag}", dest=dest, action="store_false", help=off_help
    )
    parser.set_defaults(**{dest: False})


def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Extracts torchscript IR from log files and, optionally, benchmarks it or outputs the IR"
    )
    parser.add_argument("filename", help="Filename of log file")
    _add_toggle(parser, "nvfuser", "benchmark nvfuser", "DON'T benchmark nvfuser")
    _add_toggle(
        parser, "nnc-static", "benchmark nnc static", "DON'T benchmark nnc static"
    )
    _add_toggle(
        parser,
        "nnc-dynamic",
        "nnc with dynamic shapes",
        "DON'T benchmark nnc with dynamic shapes",
    )
    _add_toggle(parser, "baseline", "benchmark baseline", "DON'T benchmark baseline")
    _add_toggle(parser, "output", "Output graph IR", "DON'T output graph IR")
    parser.add_argument(
        "--graphs", nargs="+", type=int, help="Run only specified graph indices"
    )
    return parser


def run():
    """Command-line entry point.

    Steps:
    1. Parse the arguments.
    2. Extract the IR graphs from the given log file.
    3. Optionally benchmark the selected graphs with the requested runners.
    4. Optionally print the selected IR as a list of triple-quoted strings.
    """
    args = _build_parser().parse_args()
    graphs = extract_ir(args.filename)

    # An absent (or empty) --graphs selection means "all graphs".
    graph_set = args.graphs or None

    options: List[Tuple[str, Callable]] = []
    if args.baseline:
        options.append(("Baseline no fusion", run_baseline_no_fusion))
    if args.nnc_dynamic:
        options.append(("NNC Dynamic", functools.partial(run_nnc, dynamic=True)))
    if args.nnc_static:
        options.append(("NNC Static", functools.partial(run_nnc, dynamic=False)))
    if args.nvfuser:
        options.append(("NVFuser", run_nvfuser))

    # Only benchmark when a runner was requested. Otherwise test_runners would
    # still load every selected graph and print "Running Graph N" headers for
    # nothing (e.g. with --output alone), cluttering the printed IR list.
    if options:
        test_runners(graphs, options, graph_set)

    if args.output:
        # NOTE(review): the IR is wrapped verbatim in triple quotes without
        # escaping, so an IR containing '"""' would not form a valid literal.
        quoted = [
            '"""' + ir + '"""'
            for i, ir in enumerate(graphs)
            if graph_set is None or i in graph_set
        ]
        print("[" + ", ".join(quoted) + "]")
150
151
if __name__ == "__main__":
    # Act only when executed as a script; importing this module just defines functions.
    run()
154