xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/parse_logs.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import csv
2import os
3import re
4import sys
5
6
7# This script takes the logs produced by the benchmark scripts (e.g.,
# torchbench.py) and parses them into a CSV file that summarizes what
9# is failing and why.  It is kept separate from the benchmark script
10# emitting a more structured output as it is often more convenient
11# to iterate quickly on log files offline instead of having to make
12# a change to the benchmark script and then do a full sweep to see
13# the updates.
14#
15# This script is not very well written, feel free to rewrite it as necessary
16
# Usage: python parse_logs.py <logfile>
assert len(sys.argv) == 2

# Read the whole log up front.  Use a context manager so the file handle is
# closed deterministically (the original `open(...).read()` leaked it).
with open(sys.argv[1]) as log_file:
    full_log = log_file.read()

# If the log contains a gist URL, extract it so we can include it in the CSV
gist_url = ""
m = re.search(r"https://gist.github.com/[a-f0-9]+", full_log)
if m is not None:
    gist_url = m.group(0)

# Split the log into an entry per benchmark.  Because the pattern has two
# capture groups, re.split interleaves them into the output:
# [name, name2, payload, name, name2, payload, ...] (one group is None per hit).
entries = re.split(
    r"(?:cuda (?:train|eval) +([^ ]+)|WARNING:root:([^ ]+) failed to load)", full_log
)[1:]
# Entries schema example:
# `['hf_Bert', None, '
#  PASS\nTIMING: entire_frame_compile:1.80925 backend_compile:6e-05\nDynamo produced 1 graph(s) covering 367 ops\n']`
34
35
def chunker(seq, size):
    """Yield consecutive ``size``-length slices of ``seq``.

    The final slice may be shorter when ``len(seq)`` is not a multiple of
    ``size``.  Lazy, like the original generator expression.
    """
    for start in range(0, len(seq), size):
        yield seq[start : start + size]
38
39
# Running tallies: `c` counts entries we fail to classify, `i` counts
# benchmark entries processed.
c = 0
i = 0

# Schema of the summary CSV, one row per benchmark entry.
_COLUMNS = [
    "bench",
    "name",
    "result",
    "component",
    "context",
    "explain",
    "frame_time",
    "backend_time",
    "graph_count",
    "op_count",
    "graph_breaks",
    "unique_graph_breaks",
]

out = csv.DictWriter(sys.stdout, _COLUMNS, dialect="excel")
out.writeheader()
# Emit the gist URL (if any) as its own row so it shows up in the sheet.
out.writerow({"explain": gist_url})
63
64
# Sometimes backtraces will be in third party code, which results
# in very long file names.  Delete the absolute path in this case.
def normalize_file(f):
    """Shorten a traceback file path for display in the CSV.

    Paths inside an installed package keep only the part after the first
    "site-packages/"; anything else is made relative to the cwd.
    """
    if "site-packages/" in f:
        # maxsplit=1 (was 2): with maxsplit=2 and index [1], a path that
        # contained "site-packages/" twice was truncated at the second
        # occurrence instead of keeping everything after the first.
        return f.split("site-packages/", 1)[1]
    else:
        return os.path.relpath(f)
72
73
# Assume we run torchbench, huggingface, timm_models in that order
# (as output doesn't say which suite the benchmark is part of)
# TODO: make this more robust

# Current suite label; flipped when we see the first model of the next suite.
bench = "torchbench"

# 3 = 1 + number of matches in the entries split regex
for name, name2, log in chunker(entries, 3):
    # Exactly one of the two split capture groups matched; prefer the first
    # (the "cuda train/eval" form) and fall back to the "failed to load" one.
    if name is None:
        name = name2
    # Suite-boundary heuristic: "Albert*" is assumed to be the first
    # huggingface model and "adv_inc*" the first timm model — TODO confirm
    # this matches the actual benchmark run order.
    if name.startswith("Albert"):
        bench = "huggingface"
    elif name.startswith("adv_inc"):
        bench = "timm_models"

    # Payload that will go into the csv
    r = "UNKNOWN"  # result classification for this entry
    explain = ""  # short failure reason (exception line)
    component = ""  # "file:line" where the failure was raised
    context = ""  # the offending source line from the traceback

    # Coarse classification from marker substrings.  Deliberately ordered:
    # later checks overwrite earlier ones, so TIMEOUT beats PASS and
    # ACCURACY beats both when multiple markers appear in one payload.
    if "PASS" in log:
        r = "PASS"
    if "TIMEOUT" in log:
        r = "FAIL TIMEOUT"
    if "Accuracy failed" in log:
        r = "FAIL ACCURACY"

    # Attempt to extract out useful information from the traceback

    # Progressively trim the payload so the regexes below only see the
    # innermost/most relevant traceback:
    # 1. drop everything after a chained-exception banner,
    # 2. keep only what follows the first "Traceback" header (if any),
    # 3. drop any "Original traceback:" tail.
    log = log.split(
        "The above exception was the direct cause of the following exception"
    )[0]
    split = log.split("Traceback (most recent call last)", maxsplit=1)
    if len(split) == 2:
        log = split[1]
    log = log.split("Original traceback:")[0]
    # Match the last frame of a traceback followed by the exception line,
    # e.g.:  File "x.py", line 3, in f\n    code\nSomeError: message
    m = re.search(
        r'File "([^"]+)", line ([0-9]+), in .+\n +(.+)\n([A-Za-z]+(?:Error|Exception|NotImplementedError): ?.*)',
        log,
    )

    if m is not None:
        r = "FAIL"
        component = f"{normalize_file(m.group(1))}:{m.group(2)}"
        context = m.group(3)
        explain = f"{m.group(4)}"
    else:
        # Bare "AssertionError" with no message doesn't match the pattern
        # above (it requires a ": " tail), so handle it separately.
        m = re.search(
            r'File "([^"]+)", line ([0-9]+), in .+\n +(.+)\nAssertionError', log
        )
        if m is not None:
            r = "FAIL"
            component = f"{normalize_file(m.group(1))}:{m.group(2)}"
            context = m.group(3)
            explain = "AssertionError"

    # Sometimes, the benchmark will say FAIL without any useful info
    # See https://github.com/pytorch/torchdynamo/issues/1910
    if "FAIL" in log:
        r = "FAIL"

    # Count entries that matched none of the markers so we can report how
    # much of the log this script failed to understand.
    if r == "UNKNOWN":
        c += 1

    # Parse "TIMING: entire_frame_compile:<float> backend_compile:<float>".
    backend_time = None
    frame_time = None
    if "TIMING:" in log:
        result = re.search("TIMING:(.*)\n", log).group(1)
        split_str = result.split("backend_compile:")
        if len(split_str) == 2:
            backend_time = float(split_str[1])
            frame_time = float(split_str[0].split("entire_frame_compile:")[1])

    # STATS lines are recognized but not yet emitted into the CSV.
    if "STATS:" in log:
        result = re.search("STATS:(.*)\n", log).group(1)
        # call_* op count: 970 | FakeTensor.__torch_dispatch__:35285 | ProxyTorchDispatchMode.__torch_dispatch__:13339
        split_all = result.split("|")
        # TODO: rewrite this to work with arbitrarily many stats

    # Graph/op/break counts from the Dynamo summary line (kept as strings;
    # the CSV doesn't need them numeric).
    graph_count = None
    op_count = None
    graph_breaks = None
    unique_graph_breaks = None
    if m := re.search(
        r"Dynamo produced (\d+) graphs covering (\d+) ops with (\d+) graph breaks \((\d+) unique\)",
        log,
    ):
        graph_count = m.group(1)
        op_count = m.group(2)
        graph_breaks = m.group(3)
        unique_graph_breaks = m.group(4)

    # If the context string is too long, don't put it in the CSV.
    # This is a hack to try to make it more likely that Google Sheets will
    # offer to split columns
    if len(context) > 78:
        context = ""

    # Temporary file names are meaningless, report it's generated code in this
    # case
    if "/tmp/" in component:
        component = "generated code"
        context = ""

    out.writerow(
        {
            "bench": bench,
            "name": name,
            "result": r,
            "component": component,
            "context": context,
            "explain": explain,
            "frame_time": frame_time,
            "backend_time": backend_time,
            "graph_count": graph_count,
            "op_count": op_count,
            "graph_breaks": graph_breaks,
            "unique_graph_breaks": unique_graph_breaks,
        }
    )
    i += 1

# Report unclassified entries on stderr (stdout carries the CSV).
if c:
    print(f"failed to classify {c} entries", file=sys.stderr)
199