xref: /aosp_15_r20/external/pytorch/tools/test/test_create_alerts.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3from typing import Any
4from unittest import main, TestCase
5
6from tools.alerts.create_alerts import filter_job_names, JobStatus
7
8
JOB_NAME = "periodic / linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck / test (default, 2, 2, linux.4xlarge.nvidia.gpu)"

# Failure text shared by both mock records below.
_TIMEOUT_ERROR = "##[error]The action has timed out."


def _timed_out_failure(
    *, sha: str, job_id: int, duration_s: int, failure_line_number: int
) -> dict[str, Any]:
    """Build one mock job-status record whose failure is a timeout.

    Both URLs are derived from ``job_id``; every other field that does not
    vary between the mock records is fixed here.
    """
    return {
        "sha": sha,
        "id": job_id,
        "conclusion": "failure",
        "htmlUrl": f"https://github.com/pytorch/pytorch/runs/{job_id}?check_suite_focus=true",
        "logUrl": f"https://ossci-raw-job-status.s3.amazonaws.com/log/{job_id}",
        "durationS": duration_s,
        "failureLine": _TIMEOUT_ERROR,
        "failureContext": "",
        # A fresh list per record, so the records never share a mutable value.
        "failureCaptures": [_TIMEOUT_ERROR],
        "failureLineNumber": failure_line_number,
        "repo": "pytorch/pytorch",
    }


# Two failing records for the mock job, both ending in a timeout.
MOCK_TEST_DATA = [
    _timed_out_failure(
        sha="f02f3046571d21b48af3067e308a1e0f29b43af9",
        job_id=7819529276,
        duration_s=14876,
        failure_line_number=83818,
    ),
    _timed_out_failure(
        sha="d0d6b1f2222bf90f478796d84a525869898f55b6",
        job_id=7818399623,
        duration_s=14882,
        failure_line_number=72821,
    ),
]
38
39
class TestGitHubPR(TestCase):
    """Unit tests for ``JobStatus.should_alert`` and ``filter_job_names``."""

    def test_alert(self) -> None:
        """``should_alert`` is true for a history shaped ``[?, ?, Fail, Fail]``.

        The two empty dicts stand in for entries with no status information.
        They are followed by the two failing records from ``MOCK_TEST_DATA``.
        """
        job_history: list[Any] = [{}, {}, *MOCK_TEST_DATA]
        status = JobStatus(JOB_NAME, job_history)
        self.assertTrue(status.should_alert())

    def test_job_filter(self) -> None:
        """``filter_job_names`` keeps only matching names and rejects bad regexes."""
        job_names = [
            "pytorch_linux_xenial_py3_6_gcc5_4_test",
            "pytorch_linux_xenial_py3_6_gcc5_4_test2",
        ]
        self.assertListEqual(
            filter_job_names(job_names, ""),
            job_names,
            "empty regex should match all jobs",
        )
        self.assertListEqual(filter_job_names(job_names, ".*"), job_names)
        self.assertListEqual(filter_job_names(job_names, ".*xenial.*"), job_names)
        self.assertListEqual(
            filter_job_names(job_names, ".*xenial.*test2"),
            ["pytorch_linux_xenial_py3_6_gcc5_4_test2"],
        )
        self.assertListEqual(filter_job_names(job_names, ".*xenial.*test3"), [])
        # Use the context-manager form. In the callable form, extra keyword
        # arguments (including ``msg``) are forwarded to the callable, so a
        # zero-arg lambda raises TypeError itself. That TypeError satisfied
        # ``Exception`` without ever exercising filter_job_names.
        with self.assertRaises(
            Exception, msg="malformed regex should throw exception"
        ):
            filter_job_names(job_names, "[")
72
73
if __name__ == "__main__":
    # Hand control to unittest's command-line runner, which collects the
    # TestCase subclasses defined in this module and runs them.
    main()
76