xref: /aosp_15_r20/external/pigweed/pw_unit_test/py/pw_unit_test/rpc.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Utilities for running unit tests over :ref:`module-pw_rpc`."""
15
16import enum
17import abc
18from dataclasses import dataclass
19import logging
20from typing import Iterable
21
22from pw_rpc.client import Services
23from pw_rpc.callback_client import OptionalTimeout, UseDefault
24from pw_unit_test_proto import unit_test_pb2
25
26_LOG = logging.getLogger(__package__)
27
28
@dataclass(frozen=True)
class TestCase:
    """Immutable identifier for one test case reported over RPC.

    Renders as ``Suite.Name``; the file name is kept for diagnostics but is not
    part of the textual form.
    """

    # Name of the suite the test belongs to.
    suite_name: str
    # Name of the individual test within its suite.
    test_name: str
    # File name reported for the test case.
    file_name: str

    def __str__(self) -> str:
        return '{}.{}'.format(self.suite_name, self.test_name)

    def __repr__(self) -> str:
        # object.__format__ with an empty spec delegates to __str__ above.
        return 'TestCase({})'.format(self)
40
41
def _test_case(raw_test_case: unit_test_pb2.TestCaseDescriptor) -> TestCase:
    """Converts a raw ``TestCaseDescriptor`` message into a ``TestCase``.

    Args:
        raw_test_case: Descriptor carrying suite, test, and file names.

    Returns:
        A ``TestCase`` holding the same three names.
    """
    return TestCase(
        suite_name=raw_test_case.suite_name,
        test_name=raw_test_case.test_name,
        file_name=raw_test_case.file_name,
    )
48
49
@dataclass(frozen=True)
class TestExpectation:
    """Immutable record of one expectation reported during a test case.

    Renders as the expression text itself.
    """

    # Expression text, logged as "Expected" by LoggingEventHandler.
    expression: str
    # Evaluated form of the expression, logged as "Actual".
    evaluated_expression: str
    # Line number the expectation was reported at.
    line_number: int
    # Whether the expectation held.
    success: bool

    def __str__(self) -> str:
        return self.expression

    def __repr__(self) -> str:
        # object.__format__ with an empty spec delegates to __str__ above.
        return 'TestExpectation({})'.format(self)
62
63
class TestCaseResult(enum.IntEnum):
    """Overall outcome of a finished test case.

    Member values are taken from ``unit_test_pb2.TestCaseResult``, so a raw
    proto value converts directly with ``TestCaseResult(value)``.
    """

    SUCCESS = unit_test_pb2.TestCaseResult.SUCCESS
    FAILURE = unit_test_pb2.TestCaseResult.FAILURE
    SKIPPED = unit_test_pb2.TestCaseResult.SKIPPED
68
69
class EventHandler(abc.ABC):
    """Interface for observers of a unit test run driven by ``run_tests``.

    Every hook is abstract, so concrete handlers must implement all of them.
    """

    @abc.abstractmethod
    def run_all_tests_start(self) -> None:
        """Invoked when a test run starts, before its test case events."""

    @abc.abstractmethod
    def run_all_tests_end(self, passed_tests: int, failed_tests: int) -> None:
        """Invoked when the test run has finished.

        Args:
            passed_tests: Number of test cases that passed.
            failed_tests: Number of test cases that failed.
        """

    @abc.abstractmethod
    def test_case_start(self, test_case: TestCase) -> None:
        """Invoked when a test case begins executing.

        Args:
            test_case: The test case that is starting.
        """

    @abc.abstractmethod
    def test_case_end(
        self, test_case: TestCase, result: TestCaseResult
    ) -> None:
        """Invoked when a test case finishes.

        Args:
            test_case: The test case that finished.
            result: Its overall outcome.
        """

    @abc.abstractmethod
    def test_case_disabled(self, test_case: TestCase) -> None:
        """Invoked for a test case that is disabled.

        Args:
            test_case: The disabled test case.
        """

    @abc.abstractmethod
    def test_case_expect(
        self, test_case: TestCase, expectation: TestExpectation
    ) -> None:
        """Invoked for each expect or assert statement within a test case.

        Args:
            test_case: The test case the expectation belongs to.
            expectation: The reported expectation and whether it held.
        """
98
99
class LoggingEventHandler(EventHandler):
    """Event handler that logs test events using Google Test format."""

    def run_all_tests_start(self) -> None:
        """Logs the banner that opens a test run."""
        _LOG.info('[==========] Running all tests.')

    def run_all_tests_end(self, passed_tests: int, failed_tests: int) -> None:
        """Logs the run summary; the FAILED line appears only on failures."""
        _LOG.info('[==========] Done running all tests.')
        _LOG.info('[  PASSED  ] %d test(s).', passed_tests)
        if failed_tests:
            _LOG.info('[  FAILED  ] %d test(s).', failed_tests)

    def test_case_start(self, test_case: TestCase) -> None:
        """Logs the RUN line for a test case."""
        _LOG.info('[ RUN      ] %s', test_case)

    def test_case_end(
        self, test_case: TestCase, result: TestCaseResult
    ) -> None:
        """Logs the result line for a finished test case."""
        if result == TestCaseResult.SUCCESS:
            _LOG.info('[       OK ] %s', test_case)
        elif result == TestCaseResult.SKIPPED:
            # A skipped test is not a failure; Google Test reports it with a
            # dedicated SKIPPED tag instead of FAILED.
            _LOG.info('[  SKIPPED ] %s', test_case)
        else:
            _LOG.info('[  FAILED  ] %s', test_case)

    def test_case_disabled(self, test_case: TestCase) -> None:
        """Logs that a disabled test case was skipped over."""
        _LOG.info('Skipping disabled test %s', test_case)

    def test_case_expect(
        self, test_case: TestCase, expectation: TestExpectation
    ) -> None:
        """Logs an expectation; failures are logged at ERROR level."""
        result = 'Success' if expectation.success else 'Failure'
        log = _LOG.info if expectation.success else _LOG.error
        log('%s:%d: %s', test_case.file_name, expectation.line_number, result)
        log('      Expected: %s', expectation.expression)
        log('        Actual: %s', expectation.evaluated_expression)
134
135
@dataclass(frozen=True)
class TestRecord:
    """Records test results.

    A record is truthy when no test case failed, so callers can write
    ``if run_tests(...):``. Skipped and disabled tests are tracked separately
    and do not count as failures.
    """

    # Test cases that ended with TestCaseResult.SUCCESS.
    passing_tests: tuple[TestCase, ...]
    # Test cases that ended with a result other than SUCCESS or SKIPPED.
    failing_tests: tuple[TestCase, ...]
    # Test cases announced through a test_case_disabled event.
    disabled_tests: tuple[TestCase, ...]
    # Test cases that ended with TestCaseResult.SKIPPED. Defaults to empty so
    # existing three-field constructions of TestRecord keep working.
    skipped_tests: tuple[TestCase, ...] = ()

    def all_tests_passed(self) -> bool:
        """Returns True if no test case failed."""
        return not self.failing_tests

    def __bool__(self) -> bool:
        return self.all_tests_passed()


def _out_of_order_error(received: str, expected: str) -> ValueError:
    """Builds the error for a response missing its preceding message."""
    return ValueError(
        f'Received a "{received}" response from pw.unit_test.Run without a '
        f'preceding "{expected}" response. A response may have been dropped.'
    )


def run_tests(
    rpcs: Services,
    report_passed_expectations: bool = False,
    test_suites: Iterable[str] = (),
    event_handlers: Iterable[EventHandler] = (LoggingEventHandler(),),
    timeout_s: OptionalTimeout = UseDefault.VALUE,
) -> TestRecord:
    """Runs unit tests on a device over :ref:`module-pw_rpc`.

    Calls each of the provided event handlers as test events occur.

    Args:
        rpcs: Services exposing ``pw.unit_test.UnitTest``.
        report_passed_expectations: Forwarded to the ``Run`` request.
        test_suites: Forwarded to the ``Run`` request's ``test_suite`` field.
        event_handlers: Handlers notified of every event, in the given order.
            Any iterable is accepted; it is consumed once up front.
        timeout_s: Timeout passed to ``UnitTest.Run.invoke``.

    Returns:
        A ``TestRecord`` of passing, failing, disabled, and skipped tests. The
        record is truthy only if no test failed.

    Raises:
        StopIteration: If the call produced no responses at all.
        ValueError: If the response stream is malformed, e.g. because a
            ``test_run_start``, ``test_case_start``, or ``test_run_end``
            message was dropped.
    """
    # Handlers are notified once per event, so a one-shot iterator (such as a
    # generator) must be materialized or it would be exhausted after the first
    # notification.
    handlers = tuple(event_handlers)

    unit_test_service = rpcs.pw.unit_test.UnitTest  # type: ignore[attr-defined]
    request = unit_test_service.Run.request(
        report_passed_expectations=report_passed_expectations,
        test_suite=test_suites,
    )
    call = unit_test_service.Run.invoke(request, timeout_s=timeout_s)
    test_responses = iter(call)

    # Read the first response, which must be a test_run_start message.
    try:
        first_response = next(test_responses)
    except StopIteration:
        _LOG.error(
            'The "test_run_start" message was dropped! UnitTest.Run '
            'concluded with %s.',
            call.status,
        )
        raise

    if not first_response.HasField('test_run_start'):
        raise ValueError(
            'Expected a "test_run_start" response from pw.unit_test.Run, '
            'but received a different message type. A response may have been '
            'dropped.'
        )

    for event_handler in handlers:
        event_handler.run_all_tests_start()

    passing_tests: list[TestCase] = []
    failing_tests: list[TestCase] = []
    disabled_tests: list[TestCase] = []
    skipped_tests: list[TestCase] = []
    # Test case named by the latest test_case_start / test_case_disabled; the
    # end and expectation events that follow are attributed to it.
    current_test_case: TestCase | None = None
    # Set once test_run_end arrives.
    test_record: TestRecord | None = None

    for response in test_responses:
        if response.HasField('test_run_start'):
            for event_handler in handlers:
                event_handler.run_all_tests_start()
        elif response.HasField('test_run_end'):
            run_end = response.test_run_end
            for event_handler in handlers:
                event_handler.run_all_tests_end(run_end.passed, run_end.failed)
            assert len(passing_tests) == run_end.passed
            assert len(failing_tests) == run_end.failed
            test_record = TestRecord(
                passing_tests=tuple(passing_tests),
                failing_tests=tuple(failing_tests),
                disabled_tests=tuple(disabled_tests),
                skipped_tests=tuple(skipped_tests),
            )
        elif response.HasField('test_case_start'):
            current_test_case = _test_case(response.test_case_start)
            for event_handler in handlers:
                event_handler.test_case_start(current_test_case)
        elif response.HasField('test_case_end'):
            if current_test_case is None:
                raise _out_of_order_error('test_case_end', 'test_case_start')
            result = TestCaseResult(response.test_case_end)
            for event_handler in handlers:
                event_handler.test_case_end(current_test_case, result)
            if result == TestCaseResult.SUCCESS:
                passing_tests.append(current_test_case)
            elif result == TestCaseResult.SKIPPED:
                # A skipped test is not a failure. NOTE(review): this assumes
                # the device excludes skips from test_run_end.failed, as the
                # pw_unit_test framework does; confirm for custom backends.
                skipped_tests.append(current_test_case)
            else:
                failing_tests.append(current_test_case)
        elif response.HasField('test_case_disabled'):
            current_test_case = _test_case(response.test_case_disabled)
            for event_handler in handlers:
                event_handler.test_case_disabled(current_test_case)
            disabled_tests.append(current_test_case)
        elif response.HasField('test_case_expectation'):
            if current_test_case is None:
                raise _out_of_order_error(
                    'test_case_expectation', 'test_case_start'
                )
            raw_expectation = response.test_case_expectation
            expectation = TestExpectation(
                raw_expectation.expression,
                raw_expectation.evaluated_expression,
                raw_expectation.line_number,
                raw_expectation.success,
            )
            for event_handler in handlers:
                event_handler.test_case_expect(current_test_case, expectation)

    if test_record is None:
        _LOG.error(
            'The "test_run_end" message was dropped! UnitTest.Run '
            'concluded with %s.',
            call.status,
        )
        raise ValueError(
            'Did not receive a "test_run_end" response from pw.unit_test.Run. '
            'A response may have been dropped.'
        )
    return test_record
242