# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Utilities for running unit tests over :ref:`module-pw_rpc`."""

import enum
import abc
from dataclasses import dataclass
import logging
from typing import Iterable, Optional

from pw_rpc.client import Services
from pw_rpc.callback_client import OptionalTimeout, UseDefault
from pw_unit_test_proto import unit_test_pb2

_LOG = logging.getLogger(__package__)


@dataclass(frozen=True)
class TestCase:
    """A single test case, identified by its suite, name, and source file."""

    suite_name: str
    test_name: str
    file_name: str

    def __str__(self) -> str:
        return f'{self.suite_name}.{self.test_name}'

    def __repr__(self) -> str:
        return f'TestCase({str(self)})'


def _test_case(raw_test_case: unit_test_pb2.TestCaseDescriptor) -> TestCase:
    """Converts a ``TestCaseDescriptor`` proto message into a TestCase."""
    return TestCase(
        raw_test_case.suite_name,
        raw_test_case.test_name,
        raw_test_case.file_name,
    )


@dataclass(frozen=True)
class TestExpectation:
    """One expect/assert statement result reported by the device."""

    # Source text of the checked expression.
    expression: str
    # The expression as reported after evaluation on the device.
    evaluated_expression: str
    # Line of the statement within the test case's file.
    line_number: int
    # Whether the expectation held.
    success: bool

    def __str__(self) -> str:
        return self.expression

    def __repr__(self) -> str:
        return f'TestExpectation({str(self)})'


class TestCaseResult(enum.IntEnum):
    """Overall result of a test case; mirrors ``unit_test_pb2.TestCaseResult``."""

    SUCCESS = unit_test_pb2.TestCaseResult.SUCCESS
    FAILURE = unit_test_pb2.TestCaseResult.FAILURE
    SKIPPED = unit_test_pb2.TestCaseResult.SKIPPED


class EventHandler(abc.ABC):
    """Interface for receiving test events produced by :func:`run_tests`."""

    @abc.abstractmethod
    def run_all_tests_start(self) -> None:
        """Called before all tests are run."""

    @abc.abstractmethod
    def run_all_tests_end(self, passed_tests: int, failed_tests: int) -> None:
        """Called after the test run is complete."""

    @abc.abstractmethod
    def test_case_start(self, test_case: TestCase) -> None:
        """Called when a new test case is started."""

    @abc.abstractmethod
    def test_case_end(
        self, test_case: TestCase, result: TestCaseResult
    ) -> None:
        """Called when a test case completes with its overall result."""

    @abc.abstractmethod
    def test_case_disabled(self, test_case: TestCase) -> None:
        """Called when a disabled test case is encountered."""

    @abc.abstractmethod
    def test_case_expect(
        self, test_case: TestCase, expectation: TestExpectation
    ) -> None:
        """Called after each expect or assert statement within a test case."""


class LoggingEventHandler(EventHandler):
    """Event handler that logs test events using Google Test format."""

    def run_all_tests_start(self) -> None:
        _LOG.info('[==========] Running all tests.')

    def run_all_tests_end(self, passed_tests: int, failed_tests: int) -> None:
        _LOG.info('[==========] Done running all tests.')
        _LOG.info('[  PASSED  ] %d test(s).', passed_tests)
        if failed_tests:
            _LOG.info('[  FAILED  ] %d test(s).', failed_tests)

    def test_case_start(self, test_case: TestCase) -> None:
        _LOG.info('[ RUN      ] %s', test_case)

    def test_case_end(
        self, test_case: TestCase, result: TestCaseResult
    ) -> None:
        # Anything other than SUCCESS (including SKIPPED) is logged as FAILED.
        if result == TestCaseResult.SUCCESS:
            _LOG.info('[       OK ] %s', test_case)
        else:
            _LOG.info('[  FAILED  ] %s', test_case)

    def test_case_disabled(self, test_case: TestCase) -> None:
        _LOG.info('Skipping disabled test %s', test_case)

    def test_case_expect(
        self, test_case: TestCase, expectation: TestExpectation
    ) -> None:
        # Failed expectations are logged at error level so they stand out.
        result = 'Success' if expectation.success else 'Failure'
        log = _LOG.info if expectation.success else _LOG.error
        log('%s:%d: %s', test_case.file_name, expectation.line_number, result)
        log('  Expected: %s', expectation.expression)
        log('    Actual: %s', expectation.evaluated_expression)


@dataclass(frozen=True)
class TestRecord:
    """Records test results.

    A TestRecord is truthy exactly when no test case failed.
    """

    passing_tests: tuple[TestCase, ...]
    failing_tests: tuple[TestCase, ...]
    disabled_tests: tuple[TestCase, ...]

    def all_tests_passed(self) -> bool:
        return not self.failing_tests

    def __bool__(self) -> bool:
        return self.all_tests_passed()


def _require_test_case(
    test_case: Optional[TestCase], message_type: str
) -> TestCase:
    """Returns test_case, raising ValueError if no test case has started."""
    if test_case is None:
        raise ValueError(
            f'Received a "{message_type}" response from pw.unit_test.Run '
            'before any test case was started. A response may have been '
            'dropped.'
        )
    return test_case


def run_tests(
    rpcs: Services,
    report_passed_expectations: bool = False,
    test_suites: Iterable[str] = (),
    event_handlers: Iterable[EventHandler] = (LoggingEventHandler(),),
    timeout_s: OptionalTimeout = UseDefault.VALUE,
) -> TestRecord:
    """Runs unit tests on a device over :ref:`module-pw_rpc`.

    Calls each of the provided event handlers as test events occur, and returns
    a :class:`TestRecord` of the run. The record is truthy only if all tests
    passed.

    Args:
      rpcs: RPC services for the device; ``rpcs.pw.unit_test.UnitTest`` is
        invoked.
      report_passed_expectations: Forwarded as the ``Run`` request's
        ``report_passed_expectations`` field.
      test_suites: Forwarded as the ``Run`` request's ``test_suite`` field.
      event_handlers: Handlers notified of every test event. The iterable is
        consumed once up front, so a generator may be passed.
      timeout_s: Timeout forwarded to ``UnitTest.Run.invoke``.

    Raises:
      StopIteration: The stream ended before any response was received.
      ValueError: The first response was not ``test_run_start``; a test case
        event arrived before any test case was started; or the stream ended
        without a ``test_run_end`` response.
    """
    # Materialize once: handlers are iterated for every event, which would
    # silently exhaust a one-shot iterable after the first event.
    handlers = tuple(event_handlers)

    unit_test_service = rpcs.pw.unit_test.UnitTest  # type: ignore[attr-defined]
    request = unit_test_service.Run.request(
        report_passed_expectations=report_passed_expectations,
        test_suite=test_suites,
    )
    call = unit_test_service.Run.invoke(request, timeout_s=timeout_s)
    test_responses = iter(call)

    # Read the first response, which must be a test_run_start message.
    try:
        first_response = next(test_responses)
    except StopIteration:
        _LOG.error(
            'The "test_run_start" message was dropped! UnitTest.Run '
            'concluded with %s.',
            call.status,
        )
        raise

    if not first_response.HasField('test_run_start'):
        raise ValueError(
            'Expected a "test_run_start" response from pw.unit_test.Run, '
            'but received a different message type. A response may have been '
            'dropped.'
        )

    for event_handler in handlers:
        event_handler.run_all_tests_start()

    passing_tests: list[TestCase] = []
    failing_tests: list[TestCase] = []
    disabled_tests: list[TestCase] = []
    # Set by test_case_start / test_case_disabled; stays None if the device's
    # first test case message was dropped.
    current_test_case: Optional[TestCase] = None
    # Only assigned once a test_run_end response arrives.
    test_record: Optional[TestRecord] = None

    for response in test_responses:
        if response.HasField('test_run_start'):
            for event_handler in handlers:
                event_handler.run_all_tests_start()
        elif response.HasField('test_run_end'):
            for event_handler in handlers:
                event_handler.run_all_tests_end(
                    response.test_run_end.passed, response.test_run_end.failed
                )
            assert len(passing_tests) == response.test_run_end.passed
            assert len(failing_tests) == response.test_run_end.failed
            test_record = TestRecord(
                passing_tests=tuple(passing_tests),
                failing_tests=tuple(failing_tests),
                disabled_tests=tuple(disabled_tests),
            )
        elif response.HasField('test_case_start'):
            current_test_case = _test_case(response.test_case_start)
            for event_handler in handlers:
                event_handler.test_case_start(current_test_case)
        elif response.HasField('test_case_end'):
            test_case = _require_test_case(current_test_case, 'test_case_end')
            result = TestCaseResult(response.test_case_end)
            for event_handler in handlers:
                event_handler.test_case_end(test_case, result)
            if result == TestCaseResult.SUCCESS:
                passing_tests.append(test_case)
            else:
                failing_tests.append(test_case)
        elif response.HasField('test_case_disabled'):
            current_test_case = _test_case(response.test_case_disabled)
            for event_handler in handlers:
                event_handler.test_case_disabled(current_test_case)
            disabled_tests.append(current_test_case)
        elif response.HasField('test_case_expectation'):
            test_case = _require_test_case(
                current_test_case, 'test_case_expectation'
            )
            raw_expectation = response.test_case_expectation
            expectation = TestExpectation(
                raw_expectation.expression,
                raw_expectation.evaluated_expression,
                raw_expectation.line_number,
                raw_expectation.success,
            )
            for event_handler in handlers:
                event_handler.test_case_expect(test_case, expectation)

    if test_record is None:
        # Previously this fell through to an UnboundLocalError.
        raise ValueError(
            'Expected a "test_run_end" response from pw.unit_test.Run, but '
            f'the stream concluded with {call.status} without one. A '
            'response may have been dropped.'
        )
    return test_record