xref: /aosp_15_r20/external/pigweed/pw_system/py/pw_system/device.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2021 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Device classes to interact with targets via RPC."""
15
16
17import logging
18import os
19from pathlib import Path
20import tempfile
21from types import ModuleType
22from collections.abc import Iterable
23from typing import Any, Callable, Sequence
24
25from pw_file import file_pb2
26from pw_hdlc import rpc
27from pw_hdlc.decode import Frame
28from pw_log import log_decoder
29from pw_log_rpc import rpc_log_stream
30from pw_metric import metric_parser
31import pw_rpc
32from pw_rpc import callback_client, console_tools, client_utils
33import pw_transfer
34from pw_transfer import transfer_pb2
35from pw_stream import stream_readers
36from pw_system import snapshot
37from pw_thread import thread_analyzer
38from pw_thread_protos import thread_pb2
39from pw_tokenizer import detokenize
40from pw_tokenizer.proto import decode_optionally_tokenized
41from pw_unit_test.rpc import run_tests as pw_unit_test_run_tests, TestRecord
42
43
# Logger for diagnosing this tool itself (the console), kept separate from
# the logger that carries the device's own output.
_LOG = logging.getLogger(__package__)

# Logger that device output goes to when the caller does not supply one.
DEFAULT_DEVICE_LOGGER = logging.getLogger('rpc_device')
48
49
50class Device:
51    """Represents an RPC Client for a device running a Pigweed target.
52
53    The target must have RPC support for the following services:
54     - logging
55     - file
56     - transfer
57
58    Note: use this class as a base for specialized device representations.
59    """
60
61    def __init__(
62        # pylint: disable=too-many-arguments
63        self,
64        channel_id: int,
65        reader: stream_readers.CancellableReader,
66        write: Callable[[bytes], Any],
67        proto_library: Iterable[ModuleType | Path],
68        detokenizer: detokenize.Detokenizer | None = None,
69        timestamp_decoder: Callable[[int], str] | None = None,
70        rpc_timeout_s: float = 5,
71        use_rpc_logging: bool = True,
72        use_hdlc_encoding: bool = True,
73        logger: logging.Logger | logging.LoggerAdapter = DEFAULT_DEVICE_LOGGER,
74        extra_frame_handlers: dict[int, Callable[[bytes, Any], Any]]
75        | None = None,
76    ):
77        self.channel_id = channel_id
78        self.protos = list(proto_library)
79        self.detokenizer = detokenizer
80        self.rpc_timeout_s = rpc_timeout_s
81
82        self.logger = logger
83        self.logger.setLevel(logging.DEBUG)  # Allow all device logs through.
84
85        callback_client_impl = callback_client.Impl(
86            default_unary_timeout_s=self.rpc_timeout_s,
87            default_stream_timeout_s=None,
88        )
89
90        def detokenize_and_log_output(data: bytes, _detokenizer=None):
91            log_messages = data.decode(
92                encoding='utf-8', errors='surrogateescape'
93            )
94
95            if self.detokenizer:
96                log_messages = decode_optionally_tokenized(
97                    self.detokenizer, data
98                )
99
100            for line in log_messages.splitlines():
101                self.logger.info(line)
102
103        # Device has a hard dependency on transfer_pb2, so ensure it's
104        # always been added to the list of compiled protos, rather than
105        # requiring all clients to include it.
106        if transfer_pb2 not in self.protos:
107            self.protos.append(transfer_pb2)
108
109        self.client: client_utils.RpcClient
110        if use_hdlc_encoding:
111            channels = [
112                pw_rpc.Channel(self.channel_id, rpc.channel_output(write))
113            ]
114
115            def create_frame_handler_wrapper(
116                handler: Callable[[bytes, Any], Any]
117            ) -> Callable[[Frame], Any]:
118                def handler_wrapper(frame: Frame):
119                    handler(frame.data, self)
120
121                return handler_wrapper
122
123            extra_frame_handlers_wrapper: rpc.FrameHandlers = {}
124            if extra_frame_handlers is not None:
125                for address, handler in extra_frame_handlers.items():
126                    extra_frame_handlers_wrapper[
127                        address
128                    ] = create_frame_handler_wrapper(handler)
129
130            self.client = rpc.HdlcRpcClient(
131                reader,
132                self.protos,
133                channels,
134                detokenize_and_log_output,
135                client_impl=callback_client_impl,
136                extra_frame_handlers=extra_frame_handlers_wrapper,
137            )
138        else:
139            channel = pw_rpc.Channel(self.channel_id, write)
140            self.client = client_utils.NoEncodingSingleChannelRpcClient(
141                reader,
142                self.protos,
143                channel,
144                client_impl=callback_client_impl,
145            )
146
147        if use_rpc_logging:
148            # Create the log decoder used by the LogStreamHandler.
149
150            def decoded_log_handler(log: log_decoder.Log) -> None:
151                log_decoder.log_decoded_log(log, self.logger)
152
153            self._log_decoder = log_decoder.LogStreamDecoder(
154                decoded_log_handler=decoded_log_handler,
155                detokenizer=self.detokenizer,
156                source_name='RpcDevice',
157                timestamp_parser=(
158                    timestamp_decoder
159                    if timestamp_decoder
160                    else log_decoder.timestamp_parser_ns_since_boot
161                ),
162            )
163
164            # Start listening to logs as soon as possible.
165            self.log_stream_handler = rpc_log_stream.LogStreamHandler(
166                self.rpcs, self._log_decoder
167            )
168            self.log_stream_handler.start_logging()
169
170        # Create the transfer manager
171        self.transfer_service = self.rpcs.pw.transfer.Transfer
172        self.transfer_manager = pw_transfer.Manager(
173            self.transfer_service,
174            default_response_timeout_s=self.rpc_timeout_s,
175            initial_response_timeout_s=self.rpc_timeout_s,
176            default_protocol_version=pw_transfer.ProtocolVersion.LATEST,
177        )
178
179    def __enter__(self):
180        return self
181
182    def __exit__(self, *exc_info):
183        self.close()
184
185    def close(self) -> None:
186        self.client.close()
187
188    def info(self) -> console_tools.ClientInfo:
189        return console_tools.ClientInfo('device', self.rpcs, self.client.client)
190
191    @property
192    def rpcs(self) -> Any:
193        """Returns an object for accessing services on the specified channel."""
194        return next(iter(self.client.client.channels())).rpcs
195
196    def run_tests(self, timeout_s: float | None = 5) -> TestRecord:
197        """Runs the unit tests on this device."""
198        return pw_unit_test_run_tests(self.rpcs, timeout_s=timeout_s)
199
200    def echo(self, msg: str) -> str:
201        """Sends a string to the device and back, returning the result."""
202        return self.rpcs.pw.rpc.EchoService.Echo(msg=msg).unwrap_or_raise().msg
203
204    def reboot(self):
205        """Triggers a reboot to run asynchronously on the device.
206
207        This function *does not* wait for the reboot to complete."""
208        # `invoke` rather than call in order to ignore the result. No result
209        # will be sent when the device reboots.
210        self.rpcs.pw.system.proto.DeviceService.Reboot.invoke()
211
212    def crash(self):
213        """Triggers a crash to run asynchronously on the device.
214
215        This function *does not* wait for the crash to complete."""
216        # `invoke` rather than call in order to ignore the result. No result
217        # will be sent when the device crashes.
218        self.rpcs.pw.system.proto.DeviceService.Crash.invoke()
219
220    def get_and_log_metrics(self) -> dict:
221        """Retrieves the parsed metrics and logs them to the console."""
222        metrics = metric_parser.parse_metrics(
223            self.rpcs, self.detokenizer, self.rpc_timeout_s
224        )
225
226        def print_metrics(metrics, path):
227            """Traverses dictionaries, until a non-dict value is reached."""
228            for path_name, metric in metrics.items():
229                if isinstance(metric, dict):
230                    print_metrics(metric, path + '/' + path_name)
231                else:
232                    _LOG.info('%s/%s: %s', path, path_name, str(metric))
233
234        print_metrics(metrics, '')
235        return metrics
236
237    def snapshot_peak_stack_usage(self, thread_name: str | None = None):
238        snapshot_service = self.rpcs.pw.thread.proto.ThreadSnapshotService
239        _, rsp = snapshot_service.GetPeakStackUsage(name=thread_name)
240
241        thread_info = thread_pb2.SnapshotThreadInfo()
242        for thread_info_block in rsp:
243            for thread in thread_info_block.threads:
244                thread_info.threads.append(thread)
245        for line in str(
246            thread_analyzer.ThreadSnapshotAnalyzer(thread_info)
247        ).splitlines():
248            _LOG.info('%s', line)
249
250    def list_files(self) -> Sequence[file_pb2.ListResponse]:
251        """Lists all files on this device.
252        Returns:
253            A sequence of responses from the List() RPC.
254        """
255        fs_service = self.rpcs.pw.file.FileSystem
256        stream_response = fs_service.List()
257        if not stream_response.status.ok():
258            _LOG.error('Failed to list files %s', stream_response.status)
259            return []
260
261        return stream_response.responses
262
263    def delete_file(self, path: str) -> bool:
264        """Delete a file on this device.
265        Args:
266            path: The path of the file to delete.
267        Returns:
268            True on successful deletion, False on failure.
269        """
270
271        fs_service = self.rpcs.pw.file.FileSystem
272        req = file_pb2.DeleteRequest(path=path)
273        stream_response = fs_service.Delete(req)
274        if not stream_response.status.ok():
275            _LOG.error(
276                'Failed to delete file %s file: %s',
277                path,
278                stream_response.status,
279            )
280            return False
281
282        return True
283
284    def transfer_file(self, file_id: int, dest_path: str) -> bool:
285        """Transfer a file on this device to the host.
286        Args:
287            file_id: The file_id of the file to transfer from device.
288            dest_path: The destination path to save the file to on the host.
289        Returns:
290            True on successful transfer, False on failure.
291        Raises:
292            pw_transfer.Error the transfer failed.
293        """
294        try:
295            data = self.transfer_manager.read(file_id)
296            with open(dest_path, "wb") as bin_file:
297                bin_file.write(data)
298            _LOG.info(
299                'Successfully wrote file to %s', os.path.abspath(dest_path)
300            )
301        except pw_transfer.Error:
302            _LOG.exception('Failed to transfer file_id %i', file_id)
303            return False
304
305        return True
306
307    def get_crash_snapshots(self, crash_log_path: str | None = None) -> bool:
308        r"""Transfer any crash snapshots on this device to the host.
309        Args:
310            crash_log_path: The host path to store the crash files.
311              If not specified, defaults to `/tmp` or `C:\TEMP` on Windows.
312        Returns:
313            True on successful download of snapshot, or no snapshots
314            on device.  False on failure to download snapshot.
315        """
316        if crash_log_path is None:
317            crash_log_path = tempfile.gettempdir()
318
319        snapshot_paths: list[file_pb2.Path] = []
320        for response in self.list_files():
321            for snapshot_path in response.paths:
322                if snapshot_path.path.startswith('/snapshots/crash_'):
323                    snapshot_paths.append(snapshot_path)
324
325        if len(snapshot_paths) == 0:
326            _LOG.info('No crash snapshot on the device.')
327            return True
328
329        for snapshot_path in snapshot_paths:
330            dest_snapshot = os.path.join(
331                crash_log_path, os.path.basename(snapshot_path.path)
332            )
333            if not self.transfer_file(snapshot_path.file_id, dest_snapshot):
334                return False
335
336            decoded_snapshot: str
337            with open(dest_snapshot, 'rb') as f:
338                decoded_snapshot = snapshot.decode_snapshot(
339                    self.detokenizer, f.read()
340                )
341
342            dest_text_snapshot = dest_snapshot.replace(".snapshot", ".txt")
343            with open(dest_text_snapshot, 'w') as f:
344                f.write(decoded_snapshot)
345            _LOG.info('Wrote crash snapshot to: %s', dest_text_snapshot)
346
347            if not self.delete_file(snapshot_path.path):
348                return False
349
350        return True
351