# Copyright 2023 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.

"""RPC log stream handler tests."""

from dataclasses import dataclass
from typing import Any, Callable
from unittest import TestCase, main, mock

from pw_log.log_decoder import Log, LogStreamDecoder
from pw_log.proto import log_pb2
from pw_log_rpc.rpc_log_stream import LogStreamHandler
from pw_rpc import callback_client, client
from pw_rpc.descriptors import RpcIds
from pw_rpc import packets
from pw_status import Status


class _CallableWithCounter:
    """Wraps a function and records every call made to it.

    The wrapped function is still invoked, so wrapping a bound method keeps
    its real behavior while letting a test count and inspect the calls.
    """

    @dataclass
    class CallParams:
        """Positional and keyword arguments of a single recorded call."""

        args: tuple[Any, ...]
        kwargs: dict[str, Any]

    def __init__(self, func: Callable[..., Any]):
        self._func = func
        self.calls: list[_CallableWithCounter.CallParams] = []

    def call_count(self) -> int:
        """Returns how many times the wrapper has been called."""
        return len(self.calls)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Records the call, then forwards it and returns the result."""
        # Record before forwarding so a call that raises is still counted.
        self.calls.append(_CallableWithCounter.CallParams(args, kwargs))
        return self._func(*args, **kwargs)


class TestRpcLogStreamHandler(TestCase):
    """Tests for LogStreamHandler."""

    def setUp(self) -> None:
        """Sets up an RPC client, a log decoder, and the stream handler."""
        self._channel_id = 1
        self.client = client.Client.from_modules(
            callback_client.Impl(),
            [client.Channel(self._channel_id, lambda _: None)],
            [log_pb2],
        )

        self.captured_logs: list[Log] = []

        def decoded_log_handler(log: Log) -> None:
            self.captured_logs.append(log)

        log_decoder = LogStreamDecoder(
            decoded_log_handler=decoded_log_handler,
            source_name='source',
        )
        self.log_stream_handler = LogStreamHandler(
            self.client.channel(self._channel_id).rpcs, log_decoder
        )

    def _get_rpc_ids(self) -> RpcIds:
        """Returns the RPC IDs used for server packets on the log stream."""
        # NOTE(review): assumes the first service/method registered from
        # log_pb2 is the log streaming RPC; confirm if log_pb2 gains methods.
        service = next(iter(self.client.services))
        method = next(iter(service.methods))

        # To handle unrequested log streams, packets' call Ids are set to
        # kOpenCallId.
        return RpcIds(
            self._channel_id, service.id, method.id, client.OPEN_CALL_ID
        )

    def _process_packet(self, packet: bytes) -> None:
        """Feeds a server packet to the client and asserts it was accepted."""
        self.assertIs(self.client.process_packet(packet), Status.OK)

    def _send_log_entries(
        self, first_entry_sequence_id: int, messages: list[bytes]
    ) -> None:
        """Streams one LogEntries packet carrying the given messages.

        Args:
            first_entry_sequence_id: Sequence ID of the first entry.
            messages: Raw message payloads, one LogEntry per item.
        """
        self._process_packet(
            packets.encode_server_stream(
                self._get_rpc_ids(),
                log_pb2.LogEntries(
                    first_entry_sequence_id=first_entry_sequence_id,
                    entries=[
                        log_pb2.LogEntry(message=message)
                        for message in messages
                    ],
                ),
            )
        )

    def test_start_logging_subsequent_calls(self):
        """Tests that consecutive log packets on one stream are all decoded."""
        self.log_stream_handler.handle_log_stream_error = mock.Mock()
        self.log_stream_handler.handle_log_stream_completed = mock.Mock()
        self.log_stream_handler.start_logging()

        self._send_log_entries(0, [b'message0', b'message1'])
        self.assertFalse(self.log_stream_handler.handle_log_stream_error.called)
        self.assertFalse(
            self.log_stream_handler.handle_log_stream_completed.called
        )
        self.assertEqual(len(self.captured_logs), 2)

        # A subsequent RPC packet should be handled successfully.
        self._send_log_entries(2, [b'message2', b'message3'])
        self.assertFalse(self.log_stream_handler.handle_log_stream_error.called)
        self.assertFalse(
            self.log_stream_handler.handle_log_stream_completed.called
        )
        self.assertEqual(len(self.captured_logs), 4)

    def test_log_stream_cancelled(self):
        """Tests that a cancelled log stream is not restarted."""
        self.log_stream_handler.handle_log_stream_completed = mock.Mock()

        # Wrap the real error handler instead of mocking it out. With a Mock
        # in its place nothing could restart the stream, so the start-count
        # check below would pass regardless of the error status.
        error_handler = _CallableWithCounter(
            self.log_stream_handler.handle_log_stream_error
        )
        self.log_stream_handler.handle_log_stream_error = error_handler

        start_function = _CallableWithCounter(
            self.log_stream_handler.start_logging
        )
        self.log_stream_handler.start_logging = start_function
        self.log_stream_handler.start_logging()

        # Send logs prior to cancellation.
        self._send_log_entries(0, [b'message0', b'message1'])
        self._process_packet(
            packets.encode_server_error(self._get_rpc_ids(), Status.CANCELLED)
        )

        self.assertFalse(
            self.log_stream_handler.handle_log_stream_completed.called
        )
        self.assertEqual(len(self.captured_logs), 2)
        self.assertEqual(error_handler.call_count(), 1)
        self.assertEqual(error_handler.calls[0].args, (Status.CANCELLED,))
        self.assertEqual(start_function.call_count(), 1)

    def test_log_stream_error_stream_restarted(self):
        """Tests that an error on the log stream restarts the stream."""
        self.log_stream_handler.handle_log_stream_completed = mock.Mock()

        error_handler = _CallableWithCounter(
            self.log_stream_handler.handle_log_stream_error
        )
        self.log_stream_handler.handle_log_stream_error = error_handler

        start_function = _CallableWithCounter(
            self.log_stream_handler.start_logging
        )
        self.log_stream_handler.start_logging = start_function
        self.log_stream_handler.start_logging()

        # Send logs prior to the stream error.
        self._send_log_entries(0, [b'message0', b'message1'])
        self._process_packet(
            packets.encode_server_error(self._get_rpc_ids(), Status.UNKNOWN)
        )

        self.assertFalse(
            self.log_stream_handler.handle_log_stream_completed.called
        )
        self.assertEqual(len(self.captured_logs), 2)
        self.assertEqual(start_function.call_count(), 2)
        self.assertEqual(error_handler.call_count(), 1)
        self.assertEqual(error_handler.calls[0].args, (Status.UNKNOWN,))

    def test_log_stream_completed_ok_stream_restarted(self):
        """Tests that the stream is restarted after it completes with OK."""
        self.log_stream_handler.handle_log_stream_error = mock.Mock()

        completion_handler = _CallableWithCounter(
            self.log_stream_handler.handle_log_stream_completed
        )
        self.log_stream_handler.handle_log_stream_completed = completion_handler

        start_function = _CallableWithCounter(
            self.log_stream_handler.start_logging
        )
        self.log_stream_handler.start_logging = start_function
        self.log_stream_handler.start_logging()

        # Send logs prior to stream completion.
        self._send_log_entries(0, [b'message0', b'message1'])
        self._process_packet(
            packets.encode_response(self._get_rpc_ids(), status=Status.OK)
        )

        self.assertFalse(self.log_stream_handler.handle_log_stream_error.called)
        self.assertEqual(len(self.captured_logs), 2)
        self.assertEqual(start_function.call_count(), 2)
        self.assertEqual(completion_handler.call_count(), 1)
        self.assertEqual(completion_handler.calls[0].args, (Status.OK,))

    def test_log_stream_completed_with_error_stream_restarted(self):
        """Tests that the stream is restarted after it completes with an error
        status."""
        self.log_stream_handler.handle_log_stream_error = mock.Mock()

        completion_handler = _CallableWithCounter(
            self.log_stream_handler.handle_log_stream_completed
        )
        self.log_stream_handler.handle_log_stream_completed = completion_handler

        start_function = _CallableWithCounter(
            self.log_stream_handler.start_logging
        )
        self.log_stream_handler.start_logging = start_function
        self.log_stream_handler.start_logging()

        # Send logs prior to stream completion.
        self._send_log_entries(0, [b'message0', b'message1'])
        self._process_packet(
            packets.encode_response(self._get_rpc_ids(), status=Status.UNKNOWN)
        )

        self.assertFalse(self.log_stream_handler.handle_log_stream_error.called)
        self.assertEqual(len(self.captured_logs), 2)
        self.assertEqual(start_function.call_count(), 2)
        self.assertEqual(completion_handler.call_count(), 1)
        self.assertEqual(completion_handler.calls[0].args, (Status.UNKNOWN,))


if __name__ == '__main__':
    main()