1from __future__ import annotations 2 3from typing import Any 4from unittest import main, TestCase 5 6from tools.alerts.create_alerts import filter_job_names, JobStatus 7 8 9JOB_NAME = "periodic / linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck / test (default, 2, 2, linux.4xlarge.nvidia.gpu)" 10MOCK_TEST_DATA = [ 11 { 12 "sha": "f02f3046571d21b48af3067e308a1e0f29b43af9", 13 "id": 7819529276, 14 "conclusion": "failure", 15 "htmlUrl": "https://github.com/pytorch/pytorch/runs/7819529276?check_suite_focus=true", 16 "logUrl": "https://ossci-raw-job-status.s3.amazonaws.com/log/7819529276", 17 "durationS": 14876, 18 "failureLine": "##[error]The action has timed out.", 19 "failureContext": "", 20 "failureCaptures": ["##[error]The action has timed out."], 21 "failureLineNumber": 83818, 22 "repo": "pytorch/pytorch", 23 }, 24 { 25 "sha": "d0d6b1f2222bf90f478796d84a525869898f55b6", 26 "id": 7818399623, 27 "conclusion": "failure", 28 "htmlUrl": "https://github.com/pytorch/pytorch/runs/7818399623?check_suite_focus=true", 29 "logUrl": "https://ossci-raw-job-status.s3.amazonaws.com/log/7818399623", 30 "durationS": 14882, 31 "failureLine": "##[error]The action has timed out.", 32 "failureContext": "", 33 "failureCaptures": ["##[error]The action has timed out."], 34 "failureLineNumber": 72821, 35 "repo": "pytorch/pytorch", 36 }, 37] 38 39 40class TestGitHubPR(TestCase): 41 # Should fail when jobs are ? ? Fail Fail 42 def test_alert(self) -> None: 43 modified_data: list[Any] = [{}] 44 modified_data.append({}) 45 modified_data.extend(MOCK_TEST_DATA) 46 status = JobStatus(JOB_NAME, modified_data) 47 self.assertTrue(status.should_alert()) 48 49 # test filter job names 50 def test_job_filter(self) -> None: 51 job_names = [ 52 "pytorch_linux_xenial_py3_6_gcc5_4_test", 53 "pytorch_linux_xenial_py3_6_gcc5_4_test2", 54 ] 55 self.assertListEqual( 56 filter_job_names(job_names, ""), 57 job_names, 58 "empty regex should match all jobs", 59 ) 60 self.assertListEqual(filter_job_names(job_names, ".*"), job_names) 61 self.assertListEqual(filter_job_names(job_names, ".*xenial.*"), job_names) 62 self.assertListEqual( 63 filter_job_names(job_names, ".*xenial.*test2"), 64 ["pytorch_linux_xenial_py3_6_gcc5_4_test2"], 65 ) 66 self.assertListEqual(filter_job_names(job_names, ".*xenial.*test3"), []) 67 self.assertRaises( 68 Exception, 69 lambda: filter_job_names(job_names, "["), 70 msg="malformed regex should throw exception", 71 ) 72 73 74if __name__ == "__main__": 75 main() 76