xref: /aosp_15_r20/external/pytorch/tools/test/test_cmake.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport contextlib
4*da0073e9SAndroid Build Coastguard Workerimport os
5*da0073e9SAndroid Build Coastguard Workerimport typing
6*da0073e9SAndroid Build Coastguard Workerimport unittest
7*da0073e9SAndroid Build Coastguard Workerimport unittest.mock
8*da0073e9SAndroid Build Coastguard Workerfrom typing import Iterator, Sequence
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Workerimport tools.setup_helpers.cmake
11*da0073e9SAndroid Build Coastguard Workerimport tools.setup_helpers.env  # noqa: F401 unused but resolves circular import
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker
# Generic element type for TestCMake.assert_contains_sequence; it ties the
# element types of that helper's two sequence arguments together.
T = typing.TypeVar("T")
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Worker
class TestCMake(unittest.TestCase):
    @unittest.mock.patch("multiprocessing.cpu_count")
    def test_build_jobs(self, mock_cpu_count: unittest.mock.MagicMock) -> None:
        """Tests that the number of build jobs comes out correctly.

        For each combination of the ``MAX_JOBS`` environment variable and the
        ``USE_NINJA`` / ``IS_WINDOWS`` attributes of
        ``tools.setup_helpers.cmake``, call ``CMake.build`` with ``CMake.run``
        mocked out. Then check which parallelism arguments would have been
        passed to it. ``multiprocessing.cpu_count`` is pinned to 13 so the
        expected fallback job count is deterministic.
        """
        mock_cpu_count.return_value = 13
        cases = [
            # MAX_JOBS, USE_NINJA, IS_WINDOWS,         want
            (("8", True, False), ["-j", "8"]),  # noqa: E201,E241
            ((None, True, False), None),  # noqa: E201,E241
            (("7", False, False), ["-j", "7"]),  # noqa: E201,E241
            ((None, False, False), ["-j", "13"]),  # noqa: E201,E241
            (("6", True, True), ["-j", "6"]),  # noqa: E201,E241
            ((None, True, True), None),  # noqa: E201,E241
            (("11", False, True), ["/p:CL_MPCount=11"]),  # noqa: E201,E241
            ((None, False, True), ["/p:CL_MPCount=13"]),  # noqa: E201,E241
        ]
        for (max_jobs, use_ninja, is_windows), want in cases:
            with self.subTest(
                MAX_JOBS=max_jobs, USE_NINJA=use_ninja, IS_WINDOWS=is_windows
            ):
                with contextlib.ExitStack() as stack:
                    stack.enter_context(env_var("MAX_JOBS", max_jobs))
                    stack.enter_context(
                        unittest.mock.patch.object(
                            tools.setup_helpers.cmake, "USE_NINJA", use_ninja
                        )
                    )
                    stack.enter_context(
                        unittest.mock.patch.object(
                            tools.setup_helpers.cmake, "IS_WINDOWS", is_windows
                        )
                    )

                    cmake = tools.setup_helpers.cmake.CMake()

                    # Stub out run() so nothing is executed; we only inspect
                    # the arguments build() hands to it.
                    with unittest.mock.patch.object(cmake, "run") as cmake_run:
                        cmake.build({})

                    cmake_run.assert_called_once()
                    (call,) = cmake_run.mock_calls
                    # Expect exactly two positional args; the first is the
                    # list of build arguments under test.
                    build_args, _ = call.args

                if want is None:
                    # No explicit parallelism expected: neither the "-j" flag
                    # nor the MSBuild "/p:CL_MPCount=" property may appear.
                    self.assertNotIn("-j", build_args)
                    self.assertFalse(
                        [
                            arg
                            for arg in build_args
                            if str(arg).startswith("/p:CL_MPCount=")
                        ],
                        build_args,
                    )
                else:
                    self.assert_contains_sequence(build_args, want)

    @staticmethod
    def assert_contains_sequence(
        sequence: Sequence[T], subsequence: Sequence[T]
    ) -> None:
        """Raises an assertion if the subsequence is not contained in the sequence.

        "Contained" means ``subsequence`` occurs as a contiguous run of
        elements of ``sequence``. Elements are compared pairwise, so the two
        arguments need not share a concrete type. For example, a list may be
        searched for a tuple.

        Raises:
            AssertionError: if no contiguous window of ``sequence`` equals
                ``subsequence``.
        """
        # Normalize both sides to tuples. Slicing a list yields a list, and
        # list == tuple is always False, so comparing raw slices would miss
        # real matches whenever the two argument types differ.
        haystack = tuple(sequence)
        needle = tuple(subsequence)
        width = len(needle)
        if width == 0:
            return  # all sequences contain the empty subsequence

        # Slide a window of len(needle) across haystack; stop at the first
        # window that matches.
        for start in range(len(haystack) - width + 1):
            if haystack[start : start + width] == needle:
                return  # found it
        raise AssertionError(f"{subsequence} not found in {sequence}")
80*da0073e9SAndroid Build Coastguard Worker
81*da0073e9SAndroid Build Coastguard Worker
@contextlib.contextmanager
def env_var(key: str, value: str | None) -> Iterator[None]:
    """Temporarily set (or remove, when ``value`` is None) an environment variable.

    The value seen before entry is captured first. On exit from the ``with``
    body, whether normal or via an exception, ``set_env_var`` is called with
    that captured value to restore it.
    """
    saved = os.environ.get(key)
    set_env_var(key, value)
    # The callback is registered only after the override succeeded. It fires
    # when the stack unwinds, including when an exception is thrown in at the
    # yield point.
    with contextlib.ExitStack() as restore:
        restore.callback(set_env_var, key, saved)
        yield
93*da0073e9SAndroid Build Coastguard Worker
94*da0073e9SAndroid Build Coastguard Worker
def set_env_var(key: str, value: str | None) -> None:
    """Store ``value`` in ``os.environ[key]``, or delete the key when ``value`` is None.

    Deleting a key that is not present is a silent no-op.
    """
    if value is not None:
        os.environ[key] = value
        return
    os.environ.pop(key, None)
101*da0073e9SAndroid Build Coastguard Worker
102*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Executing this file directly runs the tests defined in this module.
    unittest.main()
105