xref: /aosp_15_r20/external/pytorch/tools/testing/test_run.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom copy import copy
4*da0073e9SAndroid Build Coastguard Workerfrom functools import total_ordering
5*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Iterable
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker
class TestRun:
    """
    TestRun defines the set of tests that should be run together in a single pytest invocation.
    It'll either be a whole test file or a subset of a test file.

    This class assumes that we won't always know the full set of TestClasses in a test file.
    So it's designed to include or exclude explicitly requested TestClasses, while accepting
    that there will be an ambiguous set of "unknown" test classes that are not explicitly called out.
    Those manifest as tests that haven't been explicitly excluded.

    A run is in exactly one of these states:
      * empty      -- ``test_file`` is "" (nothing to run; ``|`` treats it as the identity)
      * full file  -- no inclusions and no exclusions
      * inclusive  -- only the names in ``_included`` run
      * exclusive  -- everything except the names in ``_excluded`` runs
    """

    test_file: str
    _excluded: frozenset[str]  # Tests that should be excluded from this test run
    _included: frozenset[str]  # If non-empty, only these tests should be run in this test run

    def __init__(
        self,
        name: str,
        excluded: Iterable[str] | None = None,
        included: Iterable[str] | None = None,
    ) -> None:
        """
        Args:
            name: The test file, optionally suffixed with ``::TestClass`` to select
                a single class from that file.
            excluded: Test names to skip. Mutually exclusive with ``included``.
            included: The only test names to run. Mutually exclusive with ``excluded``.

        Raises:
            ValueError: if both ``included`` and ``excluded`` are non-empty.
        """
        if excluded and included:
            raise ValueError("Can't specify both included and excluded")

        ins = set(included or [])
        exs = set(excluded or [])

        if "::" in name:
            assert (
                not included and not excluded
            ), "Can't specify included or excluded tests when specifying a test class in the file name"
            # "file::Class" is shorthand for including just that class.
            self.test_file, test_class = name.split("::")
            ins.add(test_class)
        else:
            self.test_file = name

        self._excluded = frozenset(exs)
        self._included = frozenset(ins)

    @staticmethod
    def empty() -> TestRun:
        """Return a run with no test file, i.e. nothing to run."""
        return TestRun("")

    def is_empty(self) -> bool:
        # Lack of a test_file means that this is an empty run,
        # which means there is nothing to run. It's the zero.
        return not self.test_file

    def is_full_file(self) -> bool:
        """Return True if no inclusions or exclusions restrict this run."""
        return not self._included and not self._excluded

    def included(self) -> frozenset[str]:
        return self._included

    def excluded(self) -> frozenset[str]:
        return self._excluded

    def get_pytest_filter(self) -> str:
        """
        Build a pytest ``-k`` expression for this run.

        Returns "" for a full-file (or empty) run. Names are sorted so the
        expression is deterministic.
        """
        if self._included:
            return " or ".join(sorted(self._included))
        elif self._excluded:
            return f"not ({' or '.join(sorted(self._excluded))})"
        else:
            return ""

    def contains(self, test: TestRun) -> bool:
        """
        Return True if every test that ``test`` would run is also run by ``self``.

        Runs for different files never contain each other. Because the full set of
        test classes in a file is unknown, an inclusive run can never contain an
        exclusive or full-file run.
        """
        if self.test_file != test.test_file:
            return False

        if self.is_full_file():
            return True  # self contains all tests

        if test.is_full_file():
            return False  # test contains all tests, but self doesn't

        if test._excluded:
            # test runs an open-ended set ("everything except E_test"). A finite
            # inclusion list can't cover that. An exclusive self covers it only if
            # self excludes a subset of what test excludes (E_self <= E_test).
            if self._included:
                return False
            return self._excluded.issubset(test._excluded)

        # From here on, test is inclusive.
        # Does self include everything test includes?
        if self._included:
            return test._included.issubset(self._included)

        # Getting to here means that test includes and self excludes.
        # Does self exclude anything test includes? If not, we're good.
        return not self._excluded.intersection(test._included)

    def __copy__(self) -> TestRun:
        return TestRun(self.test_file, excluded=self._excluded, included=self._included)

    def __bool__(self) -> bool:
        return not self.is_empty()

    def __repr__(self) -> str:
        r: str = f"TestRun({self.test_file}"
        r += f", included: {self._included}" if self._included else ""
        r += f", excluded: {self._excluded}" if self._excluded else ""
        r += ")"
        return r

    def __str__(self) -> str:
        if self.is_empty():
            return "Empty"

        pytest_filter = self.get_pytest_filter()
        if pytest_filter:
            return self.test_file + ", " + pytest_filter
        return self.test_file

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TestRun):
            return False

        ret = self.test_file == other.test_file
        ret = ret and self._included == other._included
        ret = ret and self._excluded == other._excluded
        return ret

    def __hash__(self) -> int:
        # Must agree with __eq__: same file, inclusions and exclusions.
        return hash((self.test_file, self._included, self._excluded))

    def __or__(self, other: TestRun) -> TestRun:
        """
        To OR/Union test runs means to run all the tests that either of the two runs specify.

        An empty run is the identity. Otherwise both runs must target the same file.
        """

        # Is any file empty? Always hand back a fresh object, never an operand.
        if self.is_empty():
            return copy(other)
        if other.is_empty():
            return copy(self)

        # If not, ensure we have the same file
        assert (
            self.test_file == other.test_file
        ), f"Can't union {other} with {self} because they're not the same test file"

        # 4 possible cases:

        # 1. Either file is the full file, so union is everything
        if self.is_full_file() or other.is_full_file():
            # The union is the whole file
            return TestRun(self.test_file)

        # 2. Both files only run what's in _included, so union is the union of the two sets
        if self._included and other._included:
            return TestRun(
                self.test_file, included=self._included.union(other._included)
            )

        # 3. Both files only exclude what's in _excluded, so union is the intersection of the two sets
        if self._excluded and other._excluded:
            return TestRun(
                self.test_file, excluded=self._excluded.intersection(other._excluded)
            )

        # 4. One file includes and the other excludes, so we then continue excluding the _excluded set minus
        #    whatever is in the _included set
        included = self._included | other._included
        excluded = self._excluded | other._excluded
        return TestRun(self.test_file, excluded=excluded - included)

    def __sub__(self, other: TestRun) -> TestRun:
        """
        To subtract test runs means to run all the tests in the first run except for what the second run specifies.

        Subtracting a run for a different file (or an empty run) leaves ``self`` unchanged.
        """

        # Is any file empty?
        if self.is_empty():
            return TestRun.empty()
        if other.is_empty():
            return copy(self)

        # Are you trying to subtract tests that don't even exist in this test run?
        if self.test_file != other.test_file:
            return copy(self)

        # You're subtracting everything?
        if other.is_full_file():
            return TestRun.empty()

        def return_inclusions_or_empty(inclusions: frozenset[str]) -> TestRun:
            # An inclusive run with nothing included would mean "whole file",
            # so an empty result must be the empty run instead.
            if inclusions:
                return TestRun(self.test_file, included=inclusions)
            return TestRun.empty()

        if other._included:
            if self._included:
                # I_self - I_other
                return return_inclusions_or_empty(self._included - other._included)
            else:
                # (All - E_self) - I_other == All - (E_self | I_other)
                return TestRun(
                    self.test_file, excluded=self._excluded | other._included
                )
        else:
            if self._included:
                # I_self - (All - E_other) == I_self & E_other
                return return_inclusions_or_empty(self._included & other._excluded)
            else:
                # (All - E_self) - (All - E_other) == E_other - E_self
                return return_inclusions_or_empty(other._excluded - self._excluded)

    def __and__(self, other: TestRun) -> TestRun:
        """
        To AND/intersect test runs means to run only the tests both runs specify.

        Runs for different files have nothing in common, so the result is empty.
        """
        if self.test_file != other.test_file:
            return TestRun.empty()

        # A & B == (A | B) - (A - B) - (B - A), expressed with the operators above.
        return (self | other) - (self - other) - (other - self)

    def to_json(self) -> dict[str, Any]:
        """
        Serialize to a JSON-compatible dict (inverse of ``from_json``).

        Test names are sorted so the output is deterministic across processes.
        """
        r: dict[str, Any] = {
            "test_file": self.test_file,
        }
        if self._included:
            r["included"] = sorted(self._included)
        if self._excluded:
            r["excluded"] = sorted(self._excluded)
        return r

    @staticmethod
    def from_json(json: dict[str, Any]) -> TestRun:
        """Rebuild a TestRun from the dict produced by ``to_json``."""
        return TestRun(
            json["test_file"],
            included=json.get("included", []),
            excluded=json.get("excluded", []),
        )
232*da0073e9SAndroid Build Coastguard Worker
233*da0073e9SAndroid Build Coastguard Worker
@total_ordering
class ShardedTest:
    """
    A TestRun assigned to shard ``shard`` of ``num_shards``, with an optional
    duration in seconds that participates in ordering.
    """

    test: TestRun
    shard: int
    num_shards: int
    time: float | None  # In seconds

    def __init__(
        self,
        test: TestRun | str,
        shard: int,
        num_shards: int,
        time: float | None = None,
    ) -> None:
        # A bare string is shorthand for TestRun(string), e.g. a whole file or "file::Class".
        if isinstance(test, str):
            test = TestRun(test)
        self.test = test
        self.shard = shard
        self.num_shards = num_shards
        self.time = time

    @property
    def name(self) -> str:
        """The test file this shard runs."""
        return self.test.test_file

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ShardedTest):
            return False
        return (
            self.test == other.test
            and self.shard == other.shard
            and self.num_shards == other.num_shards
            and self.time == other.time
        )

    def __repr__(self) -> str:
        ret = f"{self.test} {self.shard}/{self.num_shards}"
        # Compare against None so a measured time of 0.0s is still shown.
        if self.time is not None:
            ret += f" ({self.time}s)"

        return ret

    def __lt__(self, other: object) -> bool:
        """
        Order by (name, shard, num_shards, time), with an unknown (None) time
        sorting before any known time.
        """
        if not isinstance(other, ShardedTest):
            # Let Python try the reflected comparison / raise TypeError.
            return NotImplemented

        # This is how the list was implicitly sorted when it was a NamedTuple
        if self.name != other.name:
            return self.name < other.name
        if self.shard != other.shard:
            return self.shard < other.shard
        if self.num_shards != other.num_shards:
            return self.num_shards < other.num_shards

        # None is the smallest value. Two Nones are equal, so neither is "less";
        # returning True there would make a < a and break total_ordering.
        if self.time is None or other.time is None:
            return self.time is None and other.time is not None
        return self.time < other.time

    def __str__(self) -> str:
        return f"{self.test} {self.shard}/{self.num_shards}"

    def get_time(self, default: float = 0) -> float:
        """Return the recorded time in seconds, or ``default`` if unknown."""
        return self.time if self.time is not None else default

    def get_pytest_args(self) -> list[str]:
        """Return the ``-k`` filter arguments for pytest, or [] for a full-file run."""
        pytest_filter = self.test.get_pytest_filter()
        if pytest_filter:
            return ["-k", pytest_filter]
        return []
306