1"""
2Update commited CSV files used as reference points by dynamo/inductor CI.
3
4Currently only cares about graph breaks, so only saves those columns.
5
6Hardcodes a list of job names and artifacts per job, but builds the lookup
7by querying github sha and finding associated github actions workflow ID and CI jobs,
8downloading artifact zips, extracting CSVs and filtering them.
9
10Usage:
11
12python benchmarks/dynamo/ci_expected_accuracy.py <sha of pytorch commit that has completed inductor benchmark jobs>
13
14Known limitations:
15- doesn't handle 'retry' jobs in CI, if the same hash has more than one set of artifacts, gets the first one
16"""

import argparse
import json
import os
import subprocess
import sys
import urllib.error
from io import BytesIO
from itertools import product
from pathlib import Path
from urllib.request import urlopen
from zipfile import ZipFile

import pandas as pd
import requests


# Note: the public query url targets this rockset lambda:
# https://console.rockset.com/lambdas/details/commons.artifacts
ARTIFACTS_QUERY_URL = "https://api.usw2a1.rockset.com/v1/public/shared_lambdas/4ca0033e-0117-41f5-b043-59cde19eff35"
# Four parents up from benchmarks/dynamo/ci_expected_accuracy/ is the repo root.
CSV_LINTER = str(
    Path(__file__).absolute().parent.parent.parent.parent
    / "tools/linter/adapters/no_merge_conflict_csv_linter.py"
)


def query_job_sha(repo, sha):
    """Query the Rockset artifacts lambda for the CI jobs tied to a commit sha."""
    params = {
        "parameters": [
            {"name": "sha", "type": "string", "value": sha},
            {"name": "repo", "type": "string", "value": repo},
        ]
    }

    r = requests.post(url=ARTIFACTS_QUERY_URL, json=params)
    data = r.json()
    return data["results"]

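
# Illustrative shape of one entry in the "results" returned above; the field
# names are the ones get_artifacts_urls actually reads, the values are made up:
#   {
#       "workflowName": "inductor",
#       "jobName": "<build config> / test (inductor_torchbench, 1, 2, linux.g5.4xlarge.nvidia.gpu)",
#       "workflowId": 1234567890,
#       "id": 9876543210,
#       "runAttempt": 1,
#   }
# parse_job_name splits jobName on "/"; parse_test_str drops the leading
# "test (" and trailing ")" to yield (suite, shard_id, num_shards, machine, ...).
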
def parse_job_name(job_str):
    return (part.strip() for part in job_str.split("/"))


def parse_test_str(test_str):
    # Drop the leading "test (" and the trailing ")" before splitting on commas.
    return (part.strip() for part in test_str[6:].strip(")").split(","))


S3_BASE_URL = "https://gha-artifacts.s3.amazonaws.com"


def get_artifacts_urls(results, suites, repo):
    """Map (suite, shard_id) to the S3 url of the matching test-report artifact."""
    urls = {}
    for r in results:
        if (
            r["workflowName"] in ("inductor", "inductor-periodic")
            and "test" in r["jobName"]
        ):
            config_str, test_str = parse_job_name(r["jobName"])
            suite, shard_id, num_shards, machine, *_ = parse_test_str(test_str)
            workflowId = r["workflowId"]
            job_id = r["id"]
            runAttempt = r["runAttempt"]

            if suite in suites:
                artifact_filename = f"test-reports-test-{suite}-{shard_id}-{num_shards}-{machine}_{job_id}.zip"
                s3_url = f"{S3_BASE_URL}/{repo}/{workflowId}/{runAttempt}/artifact/{artifact_filename}"
                urls[(suite, int(shard_id))] = s3_url
                print(f"{suite} {shard_id}, {num_shards}: {s3_url}")
    return urls


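# A sketch of the mapping returned by get_artifacts_urls above (placeholders in
# angle brackets; real keys are (suite, int(shard_id)) tuples):
#   {("inductor_torchbench", 1):
#       "https://gha-artifacts.s3.amazonaws.com/pytorch/pytorch/<workflowId>/"
#       "<runAttempt>/artifact/test-reports-test-inductor_torchbench-1-2-<machine>_<job id>.zip"}
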
def normalize_suite_filename(suite_name):
    strs = suite_name.split("_")
    subsuite = strs[-1]
    if "timm" in subsuite:
        subsuite = subsuite.replace("timm", "timm_models")

    return subsuite


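# For example, normalize_suite_filename("inductor_timm") returns "timm_models"
# and normalize_suite_filename("cpu_inductor_huggingface") returns "huggingface",
# matching the {phase}_{subsuite}.csv report names used below.
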
def download_artifacts_and_extract_csvs(urls):
    dataframes = {}
    for (suite, shard), url in urls.items():
        try:
            resp = urlopen(url)
            subsuite = normalize_suite_filename(suite)
            artifact = ZipFile(BytesIO(resp.read()))
            for phase in ("training", "inference"):
                name = f"test/test-reports/{phase}_{subsuite}.csv"
                try:
                    df = pd.read_csv(artifact.open(name))
                    df["graph_breaks"] = df["graph_breaks"].fillna(0).astype(int)
                    prev_df = dataframes.get((suite, phase), None)
                    dataframes[(suite, phase)] = (
                        pd.concat([prev_df, df]) if prev_df is not None else df
                    )
                except KeyError:
                    print(
                        f"Warning: Unable to find {name} in artifacts file from {url}, continuing"
                    )
        except urllib.error.HTTPError:
            print(f"Unable to download {url}, perhaps the CI job isn't finished?")

    return dataframes


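# The dict built above is keyed by (suite, phase) with per-shard CSVs
# concatenated; write_filtered_csvs keeps only the name, accuracy, and
# graph_breaks columns when writing the reference files.
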
def write_filtered_csvs(root_path, dataframes):
    for (suite, phase), df in dataframes.items():
        out_fn = os.path.join(root_path, f"{suite}_{phase}.csv")
        df.to_csv(out_fn, index=False, columns=["name", "accuracy", "graph_breaks"])
        apply_lints(out_fn)


def apply_lints(filename):
    # The linter emits a JSON patch; when it suggests a fix, the patch carries
    # "original" and "replacement" strings to swap in the file.
    patch = json.loads(subprocess.check_output([sys.executable, CSV_LINTER, filename]))
    if patch.get("replacement"):
        with open(filename) as fd:
            data = fd.read().replace(patch["original"], patch["replacement"])
        with open(filename, "w") as fd:
            fd.write(data)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )

    parser.add_argument("sha")
    args = parser.parse_args()

    repo = "pytorch/pytorch"

    suites = {
        f"{a}_{b}"
        for a, b in product(
            [
                "aot_eager",
                "aot_inductor",
                "cpu_aot_inductor",
                "cpu_aot_inductor_amp_freezing",
                "cpu_aot_inductor_freezing",
                "cpu_inductor",
                "cpu_inductor_amp_freezing",
                "cpu_inductor_freezing",
                "dynamic_aot_eager",
                "dynamic_cpu_aot_inductor",
                "dynamic_cpu_aot_inductor_amp_freezing",
                "dynamic_cpu_aot_inductor_freezing",
                "dynamic_cpu_inductor",
                "dynamic_inductor",
                "dynamo_eager",
                "inductor",
            ],
            ["huggingface", "timm", "torchbench"],
        )
    }
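    # Expands to names like "inductor_huggingface", "cpu_inductor_freezing_timm",
    # and "dynamic_aot_eager_torchbench": one entry per (config, suite) pair.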

    root_path = "benchmarks/dynamo/ci_expected_accuracy/"
    assert os.path.exists(root_path), f"cd <pytorch root> and ensure {root_path} exists"

    results = query_job_sha(repo, args.sha)
    urls = get_artifacts_urls(results, suites, repo)
    dataframes = download_artifacts_and_extract_csvs(urls)
    write_filtered_csvs(root_path, dataframes)
    print("Success. Now, confirm the changes to .csvs and `git add` them if satisfied.")