1#!/usr/bin/env python3 2""" 3This lint verifies that every Python test file (file that matches test_*.py or 4*_test.py in the test folder) has a main block which raises an exception or 5calls run_tests to ensure that the test will be run in OSS CI. 6 7Takes ~2 minuters to run without the multiprocessing, probably overkill. 8""" 9 10from __future__ import annotations 11 12import argparse 13import json 14import multiprocessing as mp 15from enum import Enum 16from typing import NamedTuple 17 18import libcst as cst 19import libcst.matchers as m 20 21 22LINTER_CODE = "TEST_HAS_MAIN" 23 24 25class HasMainVisiter(cst.CSTVisitor): 26 def __init__(self) -> None: 27 super().__init__() 28 self.found = False 29 30 def visit_Module(self, node: cst.Module) -> bool: 31 name = m.Name("__name__") 32 main = m.SimpleString('"__main__"') | m.SimpleString("'__main__'") 33 run_test_call = m.Call( 34 func=m.Name("run_tests") | m.Attribute(attr=m.Name("run_tests")) 35 ) 36 # Distributed tests (i.e. MultiProcContinuousTest) calls `run_rank` 37 # instead of `run_tests` in main 38 run_rank_call = m.Call( 39 func=m.Name("run_rank") | m.Attribute(attr=m.Name("run_rank")) 40 ) 41 raise_block = m.Raise() 42 43 # name == main or main == name 44 if_main1 = m.Comparison( 45 name, 46 [m.ComparisonTarget(m.Equal(), main)], 47 ) 48 if_main2 = m.Comparison( 49 main, 50 [m.ComparisonTarget(m.Equal(), name)], 51 ) 52 for child in node.children: 53 if m.matches(child, m.If(test=if_main1 | if_main2)): 54 if m.findall(child, raise_block | run_test_call | run_rank_call): 55 self.found = True 56 break 57 58 return False 59 60 61class LintSeverity(str, Enum): 62 ERROR = "error" 63 WARNING = "warning" 64 ADVICE = "advice" 65 DISABLED = "disabled" 66 67 68class LintMessage(NamedTuple): 69 path: str | None 70 line: int | None 71 char: int | None 72 code: str 73 severity: LintSeverity 74 name: str 75 original: str | None 76 replacement: str | None 77 description: str | None 78 79 80def check_file(filename: str) -> list[LintMessage]: 81 lint_messages = [] 82 83 with open(filename) as f: 84 file = f.read() 85 v = HasMainVisiter() 86 cst.parse_module(file).visit(v) 87 if not v.found: 88 message = ( 89 "Test files need to have a main block which either calls run_tests " 90 + "(to ensure that the tests are run during OSS CI) or raises an exception " 91 + "and added to the blocklist in test/run_test.py" 92 ) 93 lint_messages.append( 94 LintMessage( 95 path=filename, 96 line=None, 97 char=None, 98 code=LINTER_CODE, 99 severity=LintSeverity.ERROR, 100 name="[no-main]", 101 original=None, 102 replacement=None, 103 description=message, 104 ) 105 ) 106 return lint_messages 107 108 109def main() -> None: 110 parser = argparse.ArgumentParser( 111 description="test files should have main block linter", 112 fromfile_prefix_chars="@", 113 ) 114 parser.add_argument( 115 "filenames", 116 nargs="+", 117 help="paths to lint", 118 ) 119 120 args = parser.parse_args() 121 122 pool = mp.Pool(8) 123 lint_messages = pool.map(check_file, args.filenames) 124 pool.close() 125 pool.join() 126 127 flat_lint_messages = [] 128 for sublist in lint_messages: 129 flat_lint_messages.extend(sublist) 130 131 for lint_message in flat_lint_messages: 132 print(json.dumps(lint_message._asdict()), flush=True) 133 134 135if __name__ == "__main__": 136 main() 137