"""Weekly automation: regenerate test/slow_tests.json from Rockset test-duration
data, push the result to a branch, and open (or refresh) a PR on pytorch/pytorch
that pytorchbot then auto-merges.

Requires env vars: ROCKSET_API_KEY, UPDATEBOT_TOKEN, PYTORCHBOT_TOKEN.
"""

import json
import os
import subprocess
import time
from pathlib import Path
from typing import Any, cast, Dict, Optional, Tuple

import requests
import rockset  # type: ignore[import]


# This file lives three directories below the repo root (e.g. tools/x/y/).
REPO_ROOT = Path(__file__).resolve().parent.parent.parent

# Average per-test duration (seconds) over jobs on the last 3 viable/strict
# commits, excluding periodic workflows, asan jobs, and skipped/failed runs.
# Only tests averaging over 60s are reported.
QUERY = """
WITH most_recent_strict_commits AS (
    SELECT
        push.head_commit.id as sha,
    FROM
        commons.push
    WHERE
        push.ref = 'refs/heads/viable/strict'
        AND push.repository.full_name = 'pytorch/pytorch'
    ORDER BY
        push._event_time DESC
    LIMIT
        3
), workflows AS (
    SELECT
        id
    FROM
        commons.workflow_run w
        INNER JOIN most_recent_strict_commits c on w.head_sha = c.sha
    WHERE
        w.name != 'periodic'
),
job AS (
    SELECT
        j.id
    FROM
        commons.workflow_job j
        INNER JOIN workflows w on w.id = j.run_id
    WHERE
        j.name NOT LIKE '%asan%'
),
duration_per_job AS (
    SELECT
        test_run.classname,
        test_run.name,
        job.id,
        SUM(time) as time
    FROM
        commons.test_run_s3 test_run
        /* `test_run` is ginormous and `job` is small, so lookup join is essential */
        INNER JOIN job ON test_run.job_id = job.id HINT(join_strategy = lookup)
    WHERE
        /* cpp tests do not populate `file` for some reason. */
        /* Exclude them as we don't include them in our slow test infra */
        test_run.file IS NOT NULL
        /* do some more filtering to cut down on the test_run size */
        AND test_run.skipped IS NULL
        AND test_run.failure IS NULL
        AND test_run.error IS NULL
    GROUP BY
        test_run.classname,
        test_run.name,
        job.id
)
SELECT
    CONCAT(
        name,
        ' (__main__.',
        classname,
        ')'
    ) as test_name,
    AVG(time) as avg_duration_sec
FROM
    duration_per_job
GROUP BY
    CONCAT(
        name,
        ' (__main__.',
        classname,
        ')'
    )
HAVING
    AVG(time) > 60.0
ORDER BY
    test_name
"""


UPDATEBOT_TOKEN = os.environ["UPDATEBOT_TOKEN"]
PYTORCHBOT_TOKEN = os.environ["PYTORCHBOT_TOKEN"]

# Fail fast instead of letting a stalled GitHub API connection hang the
# weekly job forever (requests has no default timeout).
REQUEST_TIMEOUT_SEC = 60


def git_api(
    url: str, params: Dict[str, str], type: str = "get", token: str = UPDATEBOT_TOKEN
) -> Any:
    """Issue a GitHub REST API request and return the decoded JSON response.

    Args:
        url: Path portion of the API endpoint, e.g. "/repos/owner/name/pulls".
        params: Request body (post/patch) or query parameters (get).
        type: HTTP verb: "post", "patch", or anything else for GET.
            NOTE: shadows the `type` builtin; kept for keyword-call
            backward compatibility.
        token: GitHub token to authenticate as (defaults to the update bot).
    """
    headers = {
        "Accept": "application/vnd.github.v3+json",
        "Authorization": f"token {token}",
    }
    if type == "post":
        return requests.post(
            f"https://api.github.com{url}",
            data=json.dumps(params),
            headers=headers,
            timeout=REQUEST_TIMEOUT_SEC,
        ).json()
    elif type == "patch":
        return requests.patch(
            f"https://api.github.com{url}",
            data=json.dumps(params),
            headers=headers,
            timeout=REQUEST_TIMEOUT_SEC,
        ).json()
    else:
        return requests.get(
            f"https://api.github.com{url}",
            params=params,
            headers=headers,
            timeout=REQUEST_TIMEOUT_SEC,
        ).json()


def make_pr(source_repo: str, params: Dict[str, Any]) -> int:
    """Open a new pull request on `source_repo` and return its PR number."""
    response = git_api(f"/repos/{source_repo}/pulls", params, type="post")
    print(f"made pr {response['html_url']}")
    return cast(int, response["number"])


def approve_pr(source_repo: str, pr_number: int) -> None:
    """Submit an approving review on the PR."""
    params = {"event": "APPROVE"}
    # use pytorchbot to approve the pr
    git_api(
        f"/repos/{source_repo}/pulls/{pr_number}/reviews",
        params,
        type="post",
        token=PYTORCHBOT_TOKEN,
    )


def make_comment(source_repo: str, pr_number: int, msg: str) -> None:
    """Post `msg` as an issue comment on the PR."""
    params = {"body": msg}
    # comment with pytorchbot because pytorchmergebot gets ignored
    git_api(
        f"/repos/{source_repo}/issues/{pr_number}/comments",
        params,
        type="post",
        token=PYTORCHBOT_TOKEN,
    )


def search_for_open_pr(
    source_repo: str, search_string: str
) -> Optional[Tuple[int, str]]:
    """Find the most recent open PR by pytorchupdatebot matching `search_string`.

    Returns:
        (pr_number, head_branch_name) if one exists, else None.
    """
    params = {
        "q": f"is:pr is:open in:title author:pytorchupdatebot repo:{source_repo} {search_string}",
        "sort": "created",
    }
    response = git_api("/search/issues", params)
    if response["total_count"] != 0:
        # pr does exist
        pr_num = response["items"][0]["number"]
        link = response["items"][0]["html_url"]
        # search results omit the head branch; fetch the PR for it
        response = git_api(f"/repos/{source_repo}/pulls/{pr_num}", {})
        branch_name = response["head"]["ref"]
        print(
            f"pr does exist, number is {pr_num}, branch name is {branch_name}, link is {link}"
        )
        return pr_num, branch_name
    return None


if __name__ == "__main__":
    rs_client = rockset.RocksetClient(
        host="api.usw2a1.rockset.com", api_key=os.environ["ROCKSET_API_KEY"]
    )

    results = rs_client.sql(QUERY).results
    slow_tests = {row["test_name"]: row["avg_duration_sec"] for row in results}

    with open(REPO_ROOT / "test" / "slow_tests.json", "w") as f:
        json.dump(slow_tests, f, indent=2)

    branch_name = f"update_slow_tests_{int(time.time())}"
    pr_num = None

    # Reuse an existing open update PR (and its branch) if there is one.
    open_pr = search_for_open_pr("pytorch/pytorch", "Update slow tests")
    if open_pr is not None:
        pr_num, branch_name = open_pr

    # NOTE(review): return codes are deliberately not checked (no check=True):
    # `git commit` exits non-zero when slow_tests.json is unchanged, and the
    # force-push keeps an existing branch in sync. A genuinely failed push is
    # silently ignored, though -- TODO confirm whether that should abort.
    subprocess.run(["git", "checkout", "-b", branch_name], cwd=REPO_ROOT)
    subprocess.run(["git", "add", "test/slow_tests.json"], cwd=REPO_ROOT)
    subprocess.run(["git", "commit", "-m", "Update slow tests"], cwd=REPO_ROOT)
    subprocess.run(
        f"git push --set-upstream origin {branch_name} -f".split(), cwd=REPO_ROOT
    )

    params = {
        "title": "Update slow tests",
        "head": branch_name,
        "base": "main",
        "body": "This PR is auto-generated weekly by [this action](https://github.com/pytorch/pytorch/blob/main/"
        + ".github/workflows/weekly.yml).\nUpdate the list of slow tests.",
    }
    if pr_num is None:
        # no existing pr, so make a new one and approve it
        pr_num = make_pr("pytorch/pytorch", params)
        time.sleep(5)
    approve_pr("pytorch/pytorch", pr_num)
    make_comment("pytorch/pytorch", pr_num, "@pytorchbot merge")