xref: /aosp_15_r20/external/pigweed/pw_log_rpc/py/rpc_log_stream_test.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2023 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14
15"""RPC log stream handler tests."""
16
17from dataclasses import dataclass
18from typing import Any, Callable
19from unittest import TestCase, main, mock
20
21from pw_log.log_decoder import Log, LogStreamDecoder
22from pw_log.proto import log_pb2
23from pw_log_rpc.rpc_log_stream import LogStreamHandler
24from pw_rpc import callback_client, client
25from pw_rpc.descriptors import RpcIds
26from pw_rpc import packets
27from pw_status import Status
28
29
30class _CallableWithCounter:
31    """Wraps a function and counts how many time it was called."""
32
33    @dataclass
34    class CallParams:
35        args: Any
36        kwargs: Any
37
38    def __init__(self, func: Callable[[Any], Any]):
39        self._func = func
40        self.calls: list[_CallableWithCounter.CallParams] = []
41
42    def call_count(self) -> int:
43        return len(self.calls)
44
45    def __call__(self, *args, **kwargs) -> None:
46        self.calls.append(_CallableWithCounter.CallParams(args, kwargs))
47        self._func(*args, **kwargs)
48
49
class TestRpcLogStreamHandler(TestCase):
    """Tests for LogStreamHandler."""

    def setUp(self) -> None:
        """Creates an RPC client, a log decoder, and the handler under test."""
        self._channel_id = 1
        self.client = client.Client.from_modules(
            callback_client.Impl(),
            [client.Channel(self._channel_id, lambda _: None)],
            [log_pb2],
        )

        # Every log the decoder produces is appended here for inspection.
        self.captured_logs: list[Log] = []

        def decoded_log_handler(log: Log) -> None:
            self.captured_logs.append(log)

        log_decoder = LogStreamDecoder(
            decoded_log_handler=decoded_log_handler,
            source_name='source',
        )
        self.log_stream_handler = LogStreamHandler(
            self.client.channel(self._channel_id).rpcs, log_decoder
        )

    def _get_rpc_ids(self) -> RpcIds:
        """Returns the RPC IDs that the simulated server packets target.

        Uses the first method of the first service registered on the client.
        NOTE(review): this assumes that method is the log stream RPC that
        start_logging() opens. Confirm if log_pb2 gains services or methods.
        """
        service = next(iter(self.client.services))
        method = next(iter(service.methods))

        # To handle unrequested log streams, packets' call IDs are set to
        # kOpenCallId.
        return RpcIds(
            self._channel_id, service.id, method.id, client.OPEN_CALL_ID
        )

    def _process_packet_ok(self, packet: bytes) -> None:
        """Feeds a server packet to the client and asserts it was accepted."""
        self.assertIs(self.client.process_packet(packet), Status.OK)

    def _send_log_entries(
        self, first_entry_sequence_id: int, messages: list[bytes]
    ) -> None:
        """Sends one LogEntries server-stream packet and asserts acceptance.

        Args:
          first_entry_sequence_id: Sequence ID of the first entry in the batch.
          messages: Message payloads, one LogEntry per item, in order.
        """
        self._process_packet_ok(
            packets.encode_server_stream(
                self._get_rpc_ids(),
                log_pb2.LogEntries(
                    first_entry_sequence_id=first_entry_sequence_id,
                    entries=[
                        log_pb2.LogEntry(message=message)
                        for message in messages
                    ],
                ),
            )
        )

    def _wrap_start_logging(self) -> _CallableWithCounter:
        """Replaces start_logging with a counting wrapper and returns it.

        The wrapper is assigned as an instance attribute. It therefore also
        intercepts calls the handler makes through self.start_logging, which
        is how the tests below count stream restarts.
        """
        start_function = _CallableWithCounter(
            self.log_stream_handler.start_logging
        )
        self.log_stream_handler.start_logging = start_function
        return start_function

    def test_start_logging_subsequent_calls(self):
        """Tests that consecutive log packets on one stream are all decoded."""
        self.log_stream_handler.handle_log_stream_error = mock.Mock()
        self.log_stream_handler.handle_log_stream_completed = mock.Mock()
        self.log_stream_handler.start_logging()

        self._send_log_entries(0, [b'message0', b'message1'])
        self.assertFalse(self.log_stream_handler.handle_log_stream_error.called)
        self.assertFalse(
            self.log_stream_handler.handle_log_stream_completed.called
        )
        self.assertEqual(len(self.captured_logs), 2)

        # A subsequent RPC packet should be handled successfully.
        self._send_log_entries(2, [b'message2', b'message3'])
        self.assertFalse(self.log_stream_handler.handle_log_stream_error.called)
        self.assertFalse(
            self.log_stream_handler.handle_log_stream_completed.called
        )
        self.assertEqual(len(self.captured_logs), 4)

    def test_log_stream_cancelled(self):
        """Tests that a cancelled log stream is not restarted."""
        self.log_stream_handler.handle_log_stream_error = mock.Mock()
        self.log_stream_handler.handle_log_stream_completed = mock.Mock()

        start_function = self._wrap_start_logging()
        self.log_stream_handler.start_logging()

        # Send logs prior to cancellation.
        self._send_log_entries(0, [b'message0', b'message1'])
        self._process_packet_ok(
            packets.encode_server_error(
                self._get_rpc_ids(), Status.CANCELLED
            )
        )

        self.log_stream_handler.handle_log_stream_error.assert_called_once_with(
            Status.CANCELLED
        )
        self.assertFalse(
            self.log_stream_handler.handle_log_stream_completed.called
        )
        self.assertEqual(len(self.captured_logs), 2)
        self.assertEqual(start_function.call_count(), 1)

    def test_log_stream_error_stream_restarted(self):
        """Tests that an error on the log stream restarts the stream."""
        self.log_stream_handler.handle_log_stream_completed = mock.Mock()

        error_handler = _CallableWithCounter(
            self.log_stream_handler.handle_log_stream_error
        )
        self.log_stream_handler.handle_log_stream_error = error_handler

        start_function = self._wrap_start_logging()
        self.log_stream_handler.start_logging()

        # Send logs, then fail the stream with a non-cancellation error.
        self._send_log_entries(0, [b'message0', b'message1'])
        self._process_packet_ok(
            packets.encode_server_error(self._get_rpc_ids(), Status.UNKNOWN)
        )

        self.assertFalse(
            self.log_stream_handler.handle_log_stream_completed.called
        )
        self.assertEqual(len(self.captured_logs), 2)
        self.assertEqual(start_function.call_count(), 2)
        self.assertEqual(error_handler.call_count(), 1)
        self.assertEqual(error_handler.calls[0].args, (Status.UNKNOWN,))

    def test_log_stream_completed_ok_stream_restarted(self):
        """Tests that completing the log stream with OK restarts the stream."""
        self.log_stream_handler.handle_log_stream_error = mock.Mock()

        completion_handler = _CallableWithCounter(
            self.log_stream_handler.handle_log_stream_completed
        )
        self.log_stream_handler.handle_log_stream_completed = completion_handler

        start_function = self._wrap_start_logging()
        self.log_stream_handler.start_logging()

        # Send logs, then complete the stream successfully.
        self._send_log_entries(0, [b'message0', b'message1'])
        self._process_packet_ok(
            packets.encode_response(self._get_rpc_ids(), status=Status.OK)
        )

        self.assertFalse(self.log_stream_handler.handle_log_stream_error.called)
        self.assertEqual(len(self.captured_logs), 2)
        self.assertEqual(start_function.call_count(), 2)
        self.assertEqual(completion_handler.call_count(), 1)
        self.assertEqual(completion_handler.calls[0].args, (Status.OK,))

    def test_log_stream_completed_with_error_stream_restarted(self):
        """Tests that completing with a non-OK status restarts the stream.

        The completion handler, not the error handler, must receive the
        status.
        """
        self.log_stream_handler.handle_log_stream_error = mock.Mock()

        completion_handler = _CallableWithCounter(
            self.log_stream_handler.handle_log_stream_completed
        )
        self.log_stream_handler.handle_log_stream_completed = completion_handler

        start_function = self._wrap_start_logging()
        self.log_stream_handler.start_logging()

        # Send logs, then complete the stream with a non-OK status.
        self._send_log_entries(0, [b'message0', b'message1'])
        self._process_packet_ok(
            packets.encode_response(
                self._get_rpc_ids(), status=Status.UNKNOWN
            )
        )

        self.assertFalse(self.log_stream_handler.handle_log_stream_error.called)
        self.assertEqual(len(self.captured_logs), 2)
        self.assertEqual(start_function.call_count(), 2)
        self.assertEqual(completion_handler.call_count(), 1)
        self.assertEqual(completion_handler.calls[0].args, (Status.UNKNOWN,))
313
314
if __name__ == '__main__':
    # Run every TestCase in this module through unittest's command-line runner.
    main(module='__main__')
317