"""
Update committed CSV files used as reference points by dynamo/inductor CI.

Currently only cares about graph breaks, so only saves those columns.

Hardcodes a list of job names and artifacts per job, but builds the lookup
by querying the github sha to find the associated github actions workflow ID and CI jobs,
then downloading artifact zips, extracting CSVs, and filtering them.

Usage:

python benchmarks/dynamo/ci_expected_accuracy/update_expected.py <sha of pytorch commit that has completed inductor benchmark jobs>

Known limitations:
- doesn't handle 'retry' jobs in CI; if the same sha has more than one set of artifacts, the first one is used
"""

import argparse
import json
import os
import subprocess
import sys
import urllib
from io import BytesIO
from itertools import product
from pathlib import Path
from urllib.request import urlopen
from zipfile import ZipFile

import pandas as pd
import requests


# Note: the public query url targets this rockset lambda:
# https://console.rockset.com/lambdas/details/commons.artifacts
ARTIFACTS_QUERY_URL = "https://api.usw2a1.rockset.com/v1/public/shared_lambdas/4ca0033e-0117-41f5-b043-59cde19eff35"
CSV_LINTER = str(
    Path(__file__).absolute().parent.parent.parent.parent
    / "tools/linter/adapters/no_merge_conflict_csv_linter.py"
)


def query_job_sha(repo, sha):
    # Query the rockset lambda for all CI jobs/artifacts associated with this commit sha.
    params = {
        "parameters": [
            {"name": "sha", "type": "string", "value": sha},
            {"name": "repo", "type": "string", "value": repo},
        ]
    }

    r = requests.post(url=ARTIFACTS_QUERY_URL, json=params)
    data = r.json()
    return data["results"]


def parse_job_name(job_str):
    return (part.strip() for part in job_str.split("/"))


def parse_test_str(test_str):
    return (part.strip() for part in test_str[6:].strip(")").split(","))


S3_BASE_URL = "https://gha-artifacts.s3.amazonaws.com"


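# Illustrative example of the job name format that parse_job_name/parse_test_str and
# get_artifacts_urls() below assume (the build/config string here is made up for
# illustration; the real value varies by workflow):
#   "linux-focal-cuda12.1-py3.10-gcc9-sm86 / test (inductor_huggingface, 1, 1, linux.g5.4xlarge.nvidia.gpu)"
# parse_job_name() splits on "/" into the config part and the "test (...)" part;
# parse_test_str() strips the leading "test (" and trailing ")" and yields
#   ("inductor_huggingface", "1", "1", "linux.g5.4xlarge.nvidia.gpu")
# which get_artifacts_urls() uses to reconstruct each artifact file name and S3 url.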
def get_artifacts_urls(results, suites):
    # Note: `repo` is a module-level global set in the __main__ block below.
    urls = {}
    for r in results:
        if (
            r["workflowName"] in ("inductor", "inductor-periodic")
            and "test" in r["jobName"]
        ):
            config_str, test_str = parse_job_name(r["jobName"])
            suite, shard_id, num_shards, machine, *_ = parse_test_str(test_str)
            workflowId = r["workflowId"]
            id = r["id"]
            runAttempt = r["runAttempt"]

            if suite in suites:
                artifact_filename = f"test-reports-test-{suite}-{shard_id}-{num_shards}-{machine}_{id}.zip"
                s3_url = f"{S3_BASE_URL}/{repo}/{workflowId}/{runAttempt}/artifact/{artifact_filename}"
                urls[(suite, int(shard_id))] = s3_url
                print(f"{suite} {shard_id}, {num_shards}: {s3_url}")
    return urls


def normalize_suite_filename(suite_name):
    # Map a suite name like "cpu_inductor_timm" to the subsuite used in the csv
    # filename, e.g. "timm_models".
    strs = suite_name.split("_")
    subsuite = strs[-1]
    if "timm" in subsuite:
        subsuite = subsuite.replace("timm", "timm_models")

    return subsuite


def download_artifacts_and_extract_csvs(urls):
    dataframes = {}
    for (suite, shard), url in urls.items():
        try:
            resp = urlopen(url)
            subsuite = normalize_suite_filename(suite)
            artifact = ZipFile(BytesIO(resp.read()))
            for phase in ("training", "inference"):
                name = f"test/test-reports/{phase}_{subsuite}.csv"
                try:
                    df = pd.read_csv(artifact.open(name))
                    df["graph_breaks"] = df["graph_breaks"].fillna(0).astype(int)
                    prev_df = dataframes.get((suite, phase), None)
                    dataframes[(suite, phase)] = (
                        pd.concat([prev_df, df]) if prev_df is not None else df
                    )
                except KeyError:
                    print(
                        f"Warning: Unable to find {name} in artifacts file from {url}, continuing"
                    )
        except urllib.error.HTTPError:
            print(f"Unable to download {url}, perhaps the CI job isn't finished?")

    return dataframes


def write_filtered_csvs(root_path, dataframes):
    for (suite, phase), df in dataframes.items():
        out_fn = os.path.join(root_path, f"{suite}_{phase}.csv")
        df.to_csv(out_fn, index=False, columns=["name", "accuracy", "graph_breaks"])
        apply_lints(out_fn)


def apply_lints(filename):
    # The linter prints a JSON patch; apply it with a simple string replacement.
    patch = json.loads(subprocess.check_output([sys.executable, CSV_LINTER, filename]))
    if patch.get("replacement"):
        with open(filename) as fd:
            data = fd.read().replace(patch["original"], patch["replacement"])
        with open(filename, "w") as fd:
            fd.write(data)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )

    parser.add_argument("sha")
    args = parser.parse_args()

    repo = "pytorch/pytorch"

    suites = {
        f"{a}_{b}"
        for a, b in product(
            [
                "aot_eager",
                "aot_inductor",
                "cpu_aot_inductor",
                "cpu_aot_inductor_amp_freezing",
                "cpu_aot_inductor_freezing",
                "cpu_inductor",
                "cpu_inductor_amp_freezing",
                "cpu_inductor_freezing",
                "dynamic_aot_eager",
                "dynamic_cpu_aot_inductor",
                "dynamic_cpu_aot_inductor_amp_freezing",
                "dynamic_cpu_aot_inductor_freezing",
                "dynamic_cpu_inductor",
                "dynamic_inductor",
                "dynamo_eager",
                "inductor",
            ],
            ["huggingface", "timm", "torchbench"],
        )
    }

    root_path = "benchmarks/dynamo/ci_expected_accuracy/"
    assert os.path.exists(root_path), f"cd <pytorch root> and ensure {root_path} exists"

    results = query_job_sha(repo, args.sha)
    urls = get_artifacts_urls(results, suites)
    dataframes = download_artifacts_and_extract_csvs(urls)
    write_filtered_csvs(root_path, dataframes)
    print("Success. Now, confirm the changes to .csvs and `git add` them if satisfied.")