xref: /aosp_15_r20/external/pytorch/tools/testing/update_slow_tests.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import json
2import os
3import subprocess
4import time
5from pathlib import Path
6from typing import Any, cast, Dict, Optional, Tuple
7
8import requests
9import rockset  # type: ignore[import]
10
11
# Directory three levels above this file's resolved location; the slow-test
# JSON file and all git commands below are resolved relative to it.
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
# Rockset SQL that produces the slow-test list. In order, it:
#   1. takes the 3 most recent pushes to refs/heads/viable/strict of
#      pytorch/pytorch,
#   2. collects the workflow runs for those commits, excluding 'periodic',
#   3. collects the jobs of those runs whose name does not contain 'asan',
#   4. sums the recorded time per (classname, name, job) over test runs that
#      have a `file` and no skipped/failure/error entry,
#   5. averages that per-job total across jobs and keeps tests whose average
#      exceeds 60 seconds.
# Each result row has `test_name`, formatted "<name> (__main__.<classname>)",
# and `avg_duration_sec`; rows are ordered by `test_name`.
QUERY = """
WITH most_recent_strict_commits AS (
    SELECT
        push.head_commit.id as sha,
    FROM
        commons.push
    WHERE
        push.ref = 'refs/heads/viable/strict'
        AND push.repository.full_name = 'pytorch/pytorch'
    ORDER BY
        push._event_time DESC
    LIMIT
        3
), workflows AS (
    SELECT
        id
    FROM
        commons.workflow_run w
        INNER JOIN most_recent_strict_commits c on w.head_sha = c.sha
    WHERE
        w.name != 'periodic'
),
job AS (
    SELECT
        j.id
    FROM
        commons.workflow_job j
        INNER JOIN workflows w on w.id = j.run_id
    WHERE
        j.name NOT LIKE '%asan%'
),
duration_per_job AS (
    SELECT
        test_run.classname,
        test_run.name,
        job.id,
        SUM(time) as time
    FROM
        commons.test_run_s3 test_run
        /* `test_run` is ginormous and `job` is small, so lookup join is essential */
        INNER JOIN job ON test_run.job_id = job.id HINT(join_strategy = lookup)
    WHERE
        /* cpp tests do not populate `file` for some reason. */
        /* Exclude them as we don't include them in our slow test infra */
        test_run.file IS NOT NULL
        /* do some more filtering to cut down on the test_run size */
        AND test_run.skipped IS NULL
        AND test_run.failure IS NULL
        AND test_run.error IS NULL
    GROUP BY
        test_run.classname,
        test_run.name,
        job.id
)
SELECT
    CONCAT(
        name,
        ' (__main__.',
        classname,
        ')'
    ) as test_name,
    AVG(time) as avg_duration_sec
FROM
    duration_per_job
GROUP BY
    CONCAT(
        name,
        ' (__main__.',
        classname,
        ')'
    )
HAVING
    AVG(time) > 60.0
ORDER BY
    test_name
"""
89
90
# GitHub API tokens, read from the environment when the module is loaded; a
# missing variable raises KeyError immediately. UPDATEBOT_TOKEN is the default
# token for git_api, while PYTORCHBOT_TOKEN is passed explicitly when
# approving a PR and when posting a comment.
UPDATEBOT_TOKEN = os.environ["UPDATEBOT_TOKEN"]
PYTORCHBOT_TOKEN = os.environ["PYTORCHBOT_TOKEN"]
93
94
def git_api(
    url: str,
    params: Dict[str, str],
    type: str = "get",
    token: str = UPDATEBOT_TOKEN,
    timeout: float = 60.0,
) -> Any:
    """Call the GitHub REST API and return the decoded JSON response.

    Args:
        url: API path appended to ``https://api.github.com`` (for example
            ``/search/issues``).
        params: Request payload. For ``"post"`` and ``"patch"`` it is sent as
            a JSON-encoded body; for any other ``type`` it is sent as URL
            query parameters.
        type: ``"post"``, ``"patch"``, or anything else for a GET request.
            The name shadows the builtin but is kept because callers pass it
            by keyword.
        token: GitHub token placed in the ``Authorization`` header.
        timeout: Seconds to wait for the server before giving up. ``requests``
            applies no timeout unless one is given, so without this a stalled
            connection would block the script indefinitely.

    Returns:
        The parsed JSON body of the response. The HTTP status is not checked,
        so error responses are returned to the caller as-is.

    Raises:
        requests.exceptions.Timeout: If the server does not respond within
            ``timeout`` seconds.
    """
    endpoint = f"https://api.github.com{url}"
    headers = {
        "Accept": "application/vnd.github.v3+json",
        "Authorization": f"token {token}",
    }
    if type in ("post", "patch"):
        # Both write methods send the payload as a JSON request body.
        send = requests.post if type == "post" else requests.patch
        response = send(
            endpoint,
            data=json.dumps(params),
            headers=headers,
            timeout=timeout,
        )
    else:
        # Everything else is a read, with the payload in the query string.
        response = requests.get(
            endpoint,
            params=params,
            headers=headers,
            timeout=timeout,
        )
    return response.json()
120
121
def make_pr(source_repo: str, params: Dict[str, Any]) -> int:
    """Open a pull request on ``source_repo`` and return its number.

    ``params`` is posted unchanged to the repository's ``/pulls`` endpoint.
    The URL of the new PR is printed before returning.
    """
    created = git_api(f"/repos/{source_repo}/pulls", params, type="post")
    pr_url = created["html_url"]
    print(f"made pr {pr_url}")
    pr_number = created["number"]
    return cast(int, pr_number)
126
127
def approve_pr(source_repo: str, pr_number: int) -> None:
    """Post an approving review on pull request ``pr_number`` of ``source_repo``."""
    reviews_endpoint = f"/repos/{source_repo}/pulls/{pr_number}/reviews"
    # The approval is submitted with the pytorchbot token rather than the
    # default token used by git_api.
    git_api(
        reviews_endpoint,
        {"event": "APPROVE"},
        type="post",
        token=PYTORCHBOT_TOKEN,
    )
137
138
def make_comment(source_repo: str, pr_number: int, msg: str) -> None:
    """Post ``msg`` as a comment on issue/PR ``pr_number`` of ``source_repo``."""
    comments_endpoint = f"/repos/{source_repo}/issues/{pr_number}/comments"
    # Posted as pytorchbot because comments from pytorchmergebot get ignored.
    git_api(
        comments_endpoint,
        {"body": msg},
        type="post",
        token=PYTORCHBOT_TOKEN,
    )
148
149
def search_for_open_pr(
    source_repo: str, search_string: str
) -> Optional[Tuple[int, str]]:
    """Look for an open pytorchupdatebot PR whose title matches ``search_string``.

    Returns ``(pr_number, head_branch_name)`` for the first search hit, or
    ``None`` when the search reports no matching PRs.
    """
    query = (
        f"is:pr is:open in:title author:pytorchupdatebot repo:{source_repo} {search_string}"
    )
    search = git_api("/search/issues", {"q": query, "sort": "created"})
    if search["total_count"] == 0:
        return None

    # At least one PR matched; use the first item of the search result.
    first_hit = search["items"][0]
    pr_num = first_hit["number"]
    link = first_hit["html_url"]
    # The search result is an issue record, so fetch the PR itself to learn
    # the name of its head branch.
    pull = git_api(f"/repos/{source_repo}/pulls/{pr_num}", {})
    branch_name = pull["head"]["ref"]
    print(
        f"pr does exist, number is {pr_num}, branch name is {branch_name}, link is {link}"
    )
    return pr_num, branch_name
169
170
if __name__ == "__main__":
    # Script flow: query Rockset for slow tests, rewrite test/slow_tests.json,
    # commit and force-push it to a branch, then open (or reuse) a PR and ask
    # pytorchbot to merge it.
    target_repo = "pytorch/pytorch"
    pr_title = "Update slow tests"

    rockset_client = rockset.RocksetClient(
        host="api.usw2a1.rockset.com", api_key=os.environ["ROCKSET_API_KEY"]
    )
    rows = rockset_client.sql(QUERY).results

    # Map each test name to its average duration in seconds.
    slow_tests = {}
    for row in rows:
        test_name = row["test_name"]
        slow_tests[test_name] = row["avg_duration_sec"]

    with (REPO_ROOT / "test" / "slow_tests.json").open("w") as out_file:
        json.dump(slow_tests, out_file, indent=2)

    # Default to a fresh timestamped branch; if an update PR is already open,
    # reuse its number and head branch instead.
    branch_name = f"update_slow_tests_{int(time.time())}"
    pr_num = None
    existing_pr = search_for_open_pr(target_repo, pr_title)
    if existing_pr is not None:
        pr_num, branch_name = existing_pr

    # The git steps run in sequence from the repo root; their exit codes are
    # not checked.
    git_commands = (
        ["git", "checkout", "-b", branch_name],
        ["git", "add", "test/slow_tests.json"],
        ["git", "commit", "-m", "Update slow tests"],
        f"git push --set-upstream origin {branch_name} -f".split(),
    )
    for git_command in git_commands:
        subprocess.run(git_command, cwd=REPO_ROOT)

    pr_body = (
        "This PR is auto-generated weekly by [this action](https://github.com/pytorch/pytorch/blob/main/"
        + ".github/workflows/weekly.yml).\nUpdate the list of slow tests."
    )
    params = {
        "title": pr_title,
        "head": branch_name,
        "base": "main",
        "body": pr_body,
    }
    if pr_num is None:
        # No PR is open yet: create one, pause briefly, then approve it.
        pr_num = make_pr(target_repo, params)
        time.sleep(5)
        approve_pr(target_repo, pr_num)
    make_comment(target_repo, pr_num, "@pytorchbot merge")
209