xref: /aosp_15_r20/external/pytorch/tools/stats/upload_test_stat_aggregates.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import argparse
4import ast
5import datetime
6import json
7import os
8import re
9from typing import Any
10
11import rockset  # type: ignore[import]
12
13from tools.stats.upload_stats_lib import upload_to_s3
14
15
def get_oncall_from_testfile(testfile: str) -> list[str] | None:
    """Return the owners declared on a test file's ``# Owner(s):`` line.

    The file is read from ``test/<testfile>`` relative to the current working
    directory; ``.py`` is appended when the resulting path does not already
    end with it. The first line that starts with ``# Owner(s): `` is expected
    to contain a Python list literal, which is parsed with
    ``ast.literal_eval``.

    Args:
        testfile: Test file name relative to ``test/``, with or without the
            ``.py`` suffix.

    Returns:
        * The parsed owners as a list when the owner line is well formed.
        * A one-element fallback list when the file cannot be read or the
          owner line cannot be parsed: ``["module: <prefix>"]`` where
          ``<prefix>`` is the part of ``testfile`` before its first ``.``,
          or ``["module: unmarked"]`` when ``testfile`` contains no ``.``.
        * ``None`` when the file was read but has no ``# Owner(s): `` line.
    """
    path = f"test/{testfile}"
    if not path.endswith(".py"):
        path += ".py"

    # Label used whenever the owners cannot be determined from the file.
    if "." in testfile:
        fallback = [f"module: {testfile.split('.')[0]}"]
    else:
        fallback = ["module: unmarked"]

    try:
        # Decode explicitly as UTF-8 so the result does not depend on the
        # locale's default encoding.
        with open(path, encoding="utf-8") as f:
            for line in f:
                if not line.startswith("# Owner(s): "):
                    continue
                # The pattern is greedy, so it spans from the first "[" to the
                # last "]" on the line: there is at most one match per line.
                match = re.search(r"\[.*\]", line)
                if match is None:
                    # Owner line present but it carries no list.
                    return fallback
                return list(ast.literal_eval(match.group(0)))
    except Exception:
        # Deliberately best-effort: an unreadable file or an unparsable owner
        # list must not abort the caller, so fall back to a generic label.
        return fallback
    return None
40
41
def get_test_stat_aggregates(date: datetime.date) -> Any:
    """Fetch the test stat aggregates for ``date`` and tag each row with oncalls.

    Executes the ``test_insights_per_daily_upload`` Rockset query lambda (at a
    pinned version) with ``startTime`` set to the ISO form of ``date``, then
    adds an ``"oncalls"`` entry to every result row by looking up the row's
    ``"test_file"`` with ``get_oncall_from_testfile``.

    Args:
        date: Day passed to the query lambda as its ``startTime`` parameter.

    Returns:
        The query results after a JSON round trip, so the returned value
        contains only JSON-compatible types; anything ``json`` cannot
        serialize natively is converted with ``str``.

    Raises:
        KeyError: If the ``ROCKSET_API_KEY`` environment variable is not set.
    """
    rockset_api_key = os.environ["ROCKSET_API_KEY"]
    rs = rockset.RocksetClient(host="api.usw2a1.rockset.com", api_key=rockset_api_key)

    query_parameters = [
        rockset.models.QueryParameter(
            name="startTime", type="string", value=date.isoformat()
        )
    ]
    api_response = rs.QueryLambdas.execute_query_lambda(
        query_lambda="test_insights_per_daily_upload",
        version="692684fa5b37177f",
        parameters=query_parameters,
    )

    # Bind the result list once so the rows that get annotated are exactly the
    # rows that get serialized below.
    results = api_response["results"]
    for row in results:
        row["oncalls"] = get_oncall_from_testfile(row["test_file"])

    # Round-trip through JSON to normalize the rows: keys come back sorted and
    # non-serializable values are stringified via ``default=str``.
    return json.loads(json.dumps(results, sort_keys=True, default=str))
66
67
if __name__ == "__main__":
    # Command-line entry point: fetch the aggregates for the requested day and
    # upload them to the "torchci-aggregated-stats" bucket.
    parser = argparse.ArgumentParser(
        description="Upload test stat aggregates to Rockset."
    )
    parser.add_argument(
        "--date",
        required=True,
        type=datetime.date.fromisoformat,
        help="Date to upload test stat aggregates for (YYYY-MM-DD). Must be in the last 30 days",
    )
    args = parser.parse_args()

    # Only a lower bound is enforced: dates more than 30 days before today
    # (naive local time) are rejected.
    today = datetime.datetime.now().date()
    oldest_allowed = today - datetime.timedelta(days=30)
    if args.date < oldest_allowed:
        raise ValueError("date must be in the last 30 days")

    # One object per day, keyed by the ISO date.
    s3_key = f"test_data_aggregates/{str(args.date)}"
    aggregates = get_test_stat_aggregates(date=args.date)
    upload_to_s3(
        bucket_name="torchci-aggregated-stats",
        key=s3_key,
        docs=aggregates,
    )
87