xref: /aosp_15_r20/external/pytorch/tools/linter/adapters/test_has_main_linter.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2"""
3This lint verifies that every Python test file (file that matches test_*.py or
4*_test.py in the test folder) has a main block which raises an exception or
5calls run_tests to ensure that the test will be run in OSS CI.
6
Takes ~2 minutes to run without the multiprocessing, probably overkill.
8"""
9
10from __future__ import annotations
11
12import argparse
13import json
14import multiprocessing as mp
15from enum import Enum
16from typing import NamedTuple
17
18import libcst as cst
19import libcst.matchers as m
20
21
22LINTER_CODE = "TEST_HAS_MAIN"
23
24
class HasMainVisiter(cst.CSTVisitor):
    """Visitor that records whether a module has a usable ``__main__`` block.

    After the visitor has been run over a module, ``found`` is True when one
    of the module's direct children is an ``if __name__ == "__main__":``
    statement (either operand order, either quote style) that contains a
    ``raise``, a call to ``run_tests`` or a call to ``run_rank``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.found = False

    def visit_Module(self, node: cst.Module) -> bool:
        """Inspect the module's direct children and set ``found`` on a match.

        Always returns False: every check is done here with matchers, so the
        visitor does not need to be called for anything below the module.
        """
        dunder_name = m.Name("__name__")
        dunder_main = m.SimpleString('"__main__"') | m.SimpleString("'__main__'")

        # The guard may be spelled `__name__ == "__main__"` or the other way
        # round; both spellings are accepted.
        main_guard = m.If(
            test=m.Comparison(
                left=dunder_name,
                comparisons=[
                    m.ComparisonTarget(operator=m.Equal(), comparator=dunder_main)
                ],
            )
            | m.Comparison(
                left=dunder_main,
                comparisons=[
                    m.ComparisonTarget(operator=m.Equal(), comparator=dunder_name)
                ],
            )
        )

        # `run_tests(...)` either as a bare name or as an attribute access
        # such as `something.run_tests(...)`.
        runs_tests = m.Call(
            func=m.Name("run_tests") | m.Attribute(attr=m.Name("run_tests"))
        )
        # Distributed tests (i.e. MultiProcContinuousTest) call `run_rank`
        # instead of `run_tests` in their main block.
        runs_rank = m.Call(
            func=m.Name("run_rank") | m.Attribute(attr=m.Name("run_rank"))
        )
        acceptable_content = m.Raise() | runs_tests | runs_rank

        # Only ever flip `found` to True; a non-matching module leaves the
        # current value untouched.
        if any(
            m.matches(stmt, main_guard) and m.findall(stmt, acceptable_content)
            for stmt in node.children
        ):
            self.found = True

        return False
59
60
class LintSeverity(str, Enum):
    """Severity levels that can be attached to a ``LintMessage``.

    The ``str`` mixin makes every member a plain string as well, so a member
    can be handed straight to ``json.dumps`` when a message is printed.
    """

    ERROR = "error"
    WARNING = "warning"
    ADVICE = "advice"
    DISABLED = "disabled"
66
67
class LintMessage(NamedTuple):
    """One lint finding; ``main`` prints each message as a JSON object.

    The field names become the JSON keys (via ``_asdict()``), so their names
    and order are part of the emitted output.
    """

    # Path of the file the finding refers to.
    path: str | None
    # Position of the finding; this linter reports file-level findings and
    # leaves both as None.
    line: int | None
    char: int | None
    # Identifier of the linter that produced the message (LINTER_CODE here).
    code: str
    severity: LintSeverity
    # Short tag for the kind of finding, e.g. "[no-main]".
    name: str
    # NOTE(review): presumably the original text and a suggested replacement
    # for auto-fixes - this linter always sets both to None; confirm against
    # the consumer of the JSON output.
    original: str | None
    replacement: str | None
    # Human-readable explanation of the finding.
    description: str | None
78
79
def check_file(filename: str) -> list[LintMessage]:
    """Lint a single test file for a usable top-level ``__main__`` block.

    Args:
        filename: Path of the Python test file to check.

    Returns:
        An empty list when ``HasMainVisiter`` finds an acceptable main block,
        otherwise a list holding one ERROR-severity ``LintMessage`` for the
        file (with no line/column information).

    Raises:
        OSError: If the file cannot be opened or read.
        Any parse error raised by ``cst.parse_module`` is propagated as-is.
    """
    # Read raw bytes and let libcst do the decoding: parse_module accepts
    # bytes and picks the encoding from the file itself (BOM / PEP 263
    # declaration, UTF-8 otherwise).  A text-mode open() without an explicit
    # encoding would instead use the platform's locale encoding, which can
    # fail on - or silently mis-decode - UTF-8 sources on non-UTF-8 systems.
    with open(filename, "rb") as f:
        source = f.read()

    # The file handle is already closed here; parsing needs only the bytes.
    visitor = HasMainVisiter()
    cst.parse_module(source).visit(visitor)
    if visitor.found:
        return []

    message = (
        "Test files need to have a main block which either calls run_tests "
        "(to ensure that the tests are run during OSS CI) or raises an exception "
        "and added to the blocklist in test/run_test.py"
    )
    return [
        LintMessage(
            path=filename,
            line=None,
            char=None,
            code=LINTER_CODE,
            severity=LintSeverity.ERROR,
            name="[no-main]",
            original=None,
            replacement=None,
            description=message,
        )
    ]
107
108
def main() -> None:
    """Command-line entry point: lint every given file and print the findings.

    File names are taken from the command line; an argument starting with
    ``@`` names a file from which further arguments are read.  Each lint
    message is written to stdout as one JSON object per line, in the order
    the file names were given.
    """
    parser = argparse.ArgumentParser(
        description="test files should have main block linter",
        fromfile_prefix_chars="@",
    )
    parser.add_argument(
        "filenames",
        nargs="+",
        help="paths to lint",
    )

    args = parser.parse_args()

    # Use the pool as a context manager so the worker processes are always
    # cleaned up, even when check_file raises inside pool.map (previously the
    # close()/join() calls were skipped on error and the pool leaked).
    # map() only returns once every file has been processed, so nothing is
    # still in flight when the pool is torn down on exit.
    with mp.Pool(8) as pool:
        per_file_messages = pool.map(check_file, args.filenames)

    # pool.map preserves input order; emit the messages file by file.
    for file_messages in per_file_messages:
        for lint_message in file_messages:
            print(json.dumps(lint_message._asdict()), flush=True)
133
134
135if __name__ == "__main__":
136    main()
137