import json
import locale
import os
import re
import subprocess
from collections import namedtuple
from dataclasses import dataclass
from pathlib import Path

import requests


@dataclass
class CategoryGroup:
    name: str
    categories: list


frontend_categories = [
    "meta",
    "nn",
    "linalg",
    "cpp",
    "python",
    "complex",
    "vmap",
    "autograd",
    "build",
    "memory_format",
    "foreach",
    "dataloader",
    "sparse",
    "nested tensor",
    "optimizer",
]

pytorch_2_categories = [
    "dynamo",
    "inductor",
]

# These will all get mapped to quantization
quantization = CategoryGroup(
    name="quantization",
    categories=[
        "quantization",
        "AO frontend",
        "AO Pruning",
    ],
)

# Distributed has a number of release note labels we want to map to one
distributed = CategoryGroup(
    name="distributed",
    categories=[
        "distributed",
        "distributed (c10d)",
        "distributed (composable)",
        "distributed (ddp)",
        "distributed (fsdp)",
        "distributed (rpc)",
        "distributed (sharded)",
    ],
)

categories = (
    [
        "Uncategorized",
        "lazy",
        "hub",
        "mobile",
        "jit",
        "visualization",
        "onnx",
        "caffe2",
        "amd",
        "rocm",
        "cuda",
        "cpu",
        "cudnn",
        "xla",
        "benchmark",
        "profiler",
        "performance_as_product",
        "package",
        "dispatcher",
        "releng",
        "fx",
        "code_coverage",
        "vulkan",
        "skip",
        "composability",
        # 2.0 release
        "mps",
        "intel",
        "functorch",
        "gnn",
        "distributions",
        "serialization",
    ]
    + [f"{category}_frontend" for category in frontend_categories]
    + pytorch_2_categories
    + [quantization.name]
    + [distributed.name]
)
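
# Hedged illustration of the expression above: each frontend category is suffixed
# with "_frontend" (e.g. "nn" becomes "nn_frontend"), while the grouped quantization
# and distributed labels collapse to the single names "quantization" and "distributed".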


topics = [
    "bc breaking",
    "deprecation",
    "new features",
    "improvements",
    "bug fixes",
    "performance",
    "docs",
    "devs",
    "Untopiced",
    "not user facing",
    "security",
]


Features = namedtuple(
    "Features",
    ["title", "body", "pr_number", "files_changed", "labels", "author", "accepters"],
)


def dict_to_features(dct):
    return Features(
        title=dct["title"],
        body=dct["body"],
        pr_number=dct["pr_number"],
        files_changed=dct["files_changed"],
        labels=dct["labels"],
        author=dct["author"],
        accepters=tuple(dct["accepters"]),
    )


def features_to_dict(features):
    return dict(features._asdict())
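
# Hedged example of the Features <-> dict round trip used by the on-disk cache
# below; every value here is invented for illustration:
#
#   f = Features(title="Fix foo", body="...", pr_number="12345",
#                files_changed=["a.py"], labels=["release notes: nn"],
#                author="someuser", accepters=("reviewer1",))
#   features_to_dict(f)                          # plain dict, safe for json.dump
#   dict_to_features(features_to_dict(f)) == f   # True: converts back losslessly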


def run(command):
    """Run a shell command and return (return_code, stdout, stderr)."""
    p = subprocess.Popen(
        command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True
    )
    output, err = p.communicate()
    rc = p.returncode
    enc = locale.getpreferredencoding()
    output = output.decode(enc)
    err = err.decode(enc)
    return rc, output.strip(), err.strip()
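
# Hedged usage sketch: `run` shells out and does not raise on a non-zero exit,
# so callers are expected to check the returned code themselves, e.g.:
#
#   rc, out, err = run("git rev-parse HEAD")
#   if rc != 0:
#       print(err)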


def commit_body(commit_hash):
    cmd = f"git log -n 1 --pretty=format:%b {commit_hash}"
    ret, out, err = run(cmd)
    return out if ret == 0 else None


def commit_title(commit_hash):
    cmd = f"git log -n 1 --pretty=format:%s {commit_hash}"
    ret, out, err = run(cmd)
    return out if ret == 0 else None


def commit_files_changed(commit_hash):
    cmd = f"git diff-tree --no-commit-id --name-only -r {commit_hash}"
    ret, out, err = run(cmd)
    return out.split("\n") if ret == 0 else None
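
# Hedged note: the three helpers above shell out to git, so they only work from
# inside a pytorch checkout; e.g. commit_files_changed("HEAD") returns the list of
# paths touched by the commit, or None if the git command fails.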


def parse_pr_number(body, commit_hash, title):
    regex = r"Pull Request resolved: https://github.com/pytorch/pytorch/pull/([0-9]+)"
    matches = re.findall(regex, body)
    if len(matches) == 0:
        if "revert" not in title.lower() and "updating submodules" not in title.lower():
            print(f"[{commit_hash}: {title}] Could not parse PR number, ignoring PR")
        return None
    if len(matches) > 1:
        print(f"[{commit_hash}: {title}] Found multiple PR numbers, using the first one")
    return matches[0]
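
# The regex above targets the "Pull Request resolved:" trailer that PyTorch's merge
# tooling appends to commit messages. A hypothetical body containing
#
#   Pull Request resolved: https://github.com/pytorch/pytorch/pull/12345
#
# would yield "12345"; bodies without the trailer return None (silently for revert
# and submodule-update commits).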


def get_ghstack_token():
    pattern = "github_oauth = (.*)"
    with open(Path("~/.ghstackrc").expanduser(), "r+") as f:
        config = f.read()
    matches = re.findall(pattern, config)
    if len(matches) == 0:
        raise RuntimeError("Can't find a github oauth token in ~/.ghstackrc")
    return matches[0]


def get_token():
    env_token = os.environ.get("GITHUB_TOKEN")
    if env_token is not None:
        print("using GITHUB_TOKEN from environment variable")
        return env_token
    else:
        return get_ghstack_token()
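
# Hedged note on token resolution: the GITHUB_TOKEN environment variable wins, and
# ~/.ghstackrc is only consulted as a fallback, so a hypothetical invocation could be:
#
#   GITHUB_TOKEN=<personal access token> python scripts/release_notes/<script>.py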


token = get_token()

headers = {"Authorization": f"token {token}"}


def run_query(query):
    """POST a GraphQL query to the GitHub API and return the parsed JSON response."""
    request = requests.post(
        "https://api.github.com/graphql", json={"query": query}, headers=headers
    )
    if request.status_code == 200:
        return request.json()
    else:
        raise Exception(  # noqa: TRY002
            f"Query failed with status code {request.status_code}: {request.json()}"
        )
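
# Minimal usage sketch (the query text here is just an example, not one used by this module):
#
#   result = run_query('{ repository(owner: "pytorch", name: "pytorch") { name } }')
#   result["data"]["repository"]["name"]   # -> "pytorch"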


# Stash GraphQL errors here; give up once _MAX_ERROR_LEN of them have accumulated.
_ERRORS = []
_MAX_ERROR_LEN = 20


def github_data(pr_number):
    """Fetch labels, author, and accepters for a pull request via the GitHub GraphQL API."""
    query = (
        """
    {
      repository(owner: "pytorch", name: "pytorch") {
        pullRequest(number: %s ) {
          author {
            login
          }
          reviews(last: 5, states: APPROVED) {
            nodes {
              author {
                login
              }
            }
          }
          labels(first: 10) {
            edges {
              node {
                name
              }
            }
          }
        }
      }
    }
    """  # noqa: UP031
        % pr_number
    )
    query = run_query(query)
    if query.get("errors"):
        global _ERRORS
        _ERRORS.append(query.get("errors"))
        if len(_ERRORS) < _MAX_ERROR_LEN:
            return [], "None", ()
        else:
            raise Exception(  # noqa: TRY002
                f"Got {_MAX_ERROR_LEN} errors: {_ERRORS}, please check if"
                " there is something wrong"
            )
    edges = query["data"]["repository"]["pullRequest"]["labels"]["edges"]
    labels = [edge["node"]["name"] for edge in edges]
    author = query["data"]["repository"]["pullRequest"]["author"]["login"]
    nodes = query["data"]["repository"]["pullRequest"]["reviews"]["nodes"]

    # Use a set to dedup multiple approvals from the same accepter
    accepters = {node["author"]["login"] for node in nodes}
    accepters = tuple(sorted(accepters))

    return labels, author, accepters
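
# Hedged example of the shape github_data returns (all values invented):
#
#   labels, author, accepters = github_data("12345")
#   labels     # e.g. ["release notes: nn", "topic: bug fixes"]
#   author     # e.g. "someuser"
#   accepters  # e.g. ("reviewer1", "reviewer2"), deduplicated and sorted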


def get_features(commit_hash):
    title, body, files_changed = (
        commit_title(commit_hash),
        commit_body(commit_hash),
        commit_files_changed(commit_hash),
    )
    pr_number = parse_pr_number(body, commit_hash, title)
    labels = []
    author = ""
    accepters = ()
    if pr_number is not None:
        labels, author, accepters = github_data(pr_number)
    result = Features(title, body, pr_number, files_changed, labels, author, accepters)
    return result
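
# Hedged note: when parse_pr_number finds no PR, the GitHub lookup is skipped and the
# returned Features carries empty labels, an empty author, and an empty accepters tuple.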


_commit_data_cache = None


def get_commit_data_cache(path="results/data.json"):
    global _commit_data_cache
    if _commit_data_cache is None:
        _commit_data_cache = _CommitDataCache(path)
    return _commit_data_cache


class _CommitDataCache:
    def __init__(self, path):
        self.path = path
        self.data = {}
        if os.path.exists(path):
            self.data = self.read_from_disk()
        else:
            os.makedirs(Path(path).parent, exist_ok=True)

    def get(self, commit):
        if commit not in self.data:
            # Fetch and cache the data
            self.data[commit] = get_features(commit)
            self.write_to_disk()
        return self.data[commit]

    def read_from_disk(self):
        with open(self.path) as f:
            data = json.load(f)
            data = {commit: dict_to_features(dct) for commit, dct in data.items()}
        return data

    def write_to_disk(self):
        data = {commit: features._asdict() for commit, features in self.data.items()}
        with open(self.path, "w") as f:
            json.dump(data, f)
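
# Hedged usage sketch, assuming the module is run from a pytorch checkout with a token:
#
#   cache = get_commit_data_cache()          # backed by results/data.json
#   features = cache.get("<commit sha>")     # fetched once, then served from disk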