xref: /aosp_15_r20/external/pigweed/pw_system/py/pw_system/device_connection.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2024 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Device instance creation."""
15
16import argparse
17from dataclasses import dataclass
18import logging
19from pathlib import Path
20import time
21from types import ModuleType
22from typing import Any, Callable, Collection
23
24import serial
25
26from pw_console import pyserial_wrapper
27from pw_console import socket_client
28from pw_hdlc import rpc
29from pw_log.log_decoder import timestamp_parser_ms_since_boot
30from pw_tokenizer import detokenize
31
32from pw_system.find_serial_port import interactive_serial_port_select
33from pw_system import device as pw_device
34from pw_system import device_tracing as pw_device_tracing
35
36# Default proto imports:
37# pylint: disable=ungrouped-imports,wrong-import-order
38from pw_file import file_pb2
39from pw_log.proto import log_pb2
40from pw_metric_proto import metric_service_pb2
41from pw_rpc import echo_pb2
42from pw_stream import stream_readers
43from pw_thread_protos import thread_snapshot_service_pb2
44from pw_trace_protos import trace_service_pb2
45
46from pw_system_protos import device_service_pb2
47from pw_unit_test_proto import unit_test_pb2
48
49
50# pylint: enable=ungrouped-imports
51
# Internal log for troubleshooting this tool (the console); _ROOT_LOG is the
# process-wide root logger.
_LOG = logging.getLogger(name=__package__)
_ROOT_LOG = logging.getLogger()

# Logger named 'rpc_device'; not referenced in this section of the module,
# presumably intended for callers that route device output to it.
DEFAULT_DEVICE_LOGGER = logging.getLogger(name='rpc_device')
57
58
@dataclass
class DeviceConnection:
    """Stores a Device client along with the reader and writer.

    Intended to be used as a context manager: entering enters the reader and
    then the client and yields the client; exiting exits them in reverse
    order.

    Attributes:
        client: The device client used for RPC interaction.
        reader: The stream reader that was handed to the client.
        writer: Callable that sends bytes to the device.
    """

    client: pw_device_tracing.DeviceWithTracing | pw_device.Device
    reader: stream_readers.SelectableReader | stream_readers.SerialReader
    writer: Callable[[bytes], int | None]

    def __enter__(self):
        """Enter the reader followed by the client context managers.

        If entering the client fails, the already-entered reader is exited
        before the error propagates so it is not leaked.

        Returns:
            The device client for RPC interaction.
        """
        self.reader.__enter__()
        try:
            self.client.__enter__()
        except BaseException:
            # Undo the reader's __enter__ so a failed client setup does not
            # leave the reader open, then surface the original error.
            self.reader.__exit__()
            raise
        return self.client

    def __exit__(self, *exc_info):
        """Close the device client followed by the reader.

        The reader is always exited, even if closing the client raises.
        """
        try:
            self.client.__exit__()
        finally:
            self.reader.__exit__()
80
81
def _build_detokenizer(
    token_databases: Collection[Path],
) -> detokenize.AutoUpdatingDetokenizer | None:
    """Create a detokenizer that loads every domain of each token database.

    Returns None when no token databases are given (empty or None).
    """
    if not token_databases:
        return None

    # The "#.*" suffix selects all domains from each token database.
    detokenizer = detokenize.AutoUpdatingDetokenizer(
        *(f'{token_database}#.*' for token_database in token_databases)
    )
    detokenizer.show_errors = True
    return detokenizer


def _build_proto_library(
    compiled_protos: list[ModuleType] | None,
    rpc_logging: bool,
) -> list[ModuleType | Path]:
    """Return the caller's protos followed by the default pw_system protos.

    A new list is always returned; ``compiled_protos`` is never modified.
    """
    protos: list[ModuleType | Path] = list(compiled_protos or ())

    # Append compiled log.proto library to avoid include errors when
    # manually provided, and shadowing errors due to ordering when the
    # default global search path is used.
    if rpc_logging:
        protos.append(log_pb2)
    protos.extend(
        (
            unit_test_pb2,
            metric_service_pb2,
            thread_snapshot_service_pb2,
            file_pb2,
            echo_pb2,
            trace_service_pb2,
            device_service_pb2,
        )
    )
    return protos


def create_device_serial_or_socket_connection(
    # pylint: disable=too-many-arguments,too-many-locals
    device: str,
    baudrate: int,
    token_databases: Collection[Path],
    socket_addr: str | None = None,
    ticks_per_second: int | None = None,
    serial_debug: bool = False,
    compiled_protos: list[ModuleType] | None = None,
    rpc_logging: bool = True,
    channel_id: int = rpc.DEFAULT_CHANNEL_ID,
    hdlc_encoding: bool = True,
    device_tracing: bool = True,
    device_class: type[pw_device.Device] | None = pw_device.Device,
    device_tracing_class: type[pw_device_tracing.DeviceWithTracing]
    | None = (pw_device_tracing.DeviceWithTracing),
    timestamp_decoder: Callable[[int], str] | None = None,
    extra_frame_handlers: dict[int, Callable[[bytes, Any], Any]] | None = None,
) -> DeviceConnection:
    """Set up a pw_system.Device client connection.

    Args:
        device: Serial port to open. If empty and ``socket_addr`` is None, a
            port is chosen with ``interactive_serial_port_select()``.
        baudrate: Baud rate for the serial port.
        token_databases: Token database paths used to build an
            ``AutoUpdatingDetokenizer``; all domains of each are loaded. No
            detokenizer is created when this is empty or None.
        socket_addr: If set, connect through a socket client at this address
            instead of a serial port.
        ticks_per_second: Trace clock rate; only passed to the client when
            ``device_tracing`` is True.
        serial_debug: Use ``SerialWithLogging`` / ``SocketClientWithLogging``
            instead of the plain serial/socket implementations.
        compiled_protos: Extra compiled proto modules for the client. The
            given list is not modified; the default pw_system protos are
            appended to a copy.
        rpc_logging: Include ``log_pb2`` in the proto library and pass
            ``use_rpc_logging`` to the client.
        channel_id: RPC channel ID passed to the client.
        hdlc_encoding: Passed to the client as ``use_hdlc_encoding``.
        device_tracing: Instantiate ``device_tracing_class`` instead of
            ``device_class``.
        device_class: Client class used when tracing is off; None means
            ``pw_device.Device``.
        device_tracing_class: Client class used when tracing is on; None
            means ``pw_device_tracing.DeviceWithTracing``.
        timestamp_decoder: Callable converting integer timestamps to strings,
            passed to the client. For serial connections it defaults to
            ``timestamp_parser_ms_since_boot`` when None.
        extra_frame_handlers: Passed through to the client unchanged.

    Returns:
        A DeviceConnection holding the client, reader and write function.
        The client is created with a fixed 5 second RPC timeout.

    Raises:
        ValueError: If the socket client cannot be initialized for
            ``socket_addr``.

    Full example of opening a device connection and running an RPC:

    .. code-block:: python

       import argparse

       from pw_log.log_decoder import timestamp_parser_ms_since_boot
       from pw_system.device import Device
       from pw_system.device_connection import (
           add_device_args,
           create_device_serial_or_socket_connection,
       )
       from pw_system.device_tracing import DeviceWithTracing

       from pw_protobuf_protos import common_pb2
       from pw_rpc import echo_pb2

       parser = argparse.ArgumentParser(
           prog='MyProductScript',
       )
       parser = add_device_args(parser)
       args = parser.parse_args()

       compiled_protos = [
           common_pb2,
           echo_pb2,
       ]

       device_connection = create_device_serial_or_socket_connection(
           device=args.device,
           baudrate=args.baudrate,
           token_databases=args.token_databases,
           compiled_protos=compiled_protos,
           socket_addr=args.socket_addr,
           ticks_per_second=args.ticks_per_second,
           serial_debug=args.serial_debug,
           rpc_logging=args.rpc_logging,
           hdlc_encoding=args.hdlc_encoding,
           channel_id=args.channel_id,
           device_tracing=args.device_tracing,
           device_class=Device,
           device_tracing_class=DeviceWithTracing,
           timestamp_decoder=timestamp_parser_ms_since_boot,
       )


       # Open the device connection and interact with the Device client.
       with device_connection as device:
           # Make a shortcut to the EchoService.
           echo_service = device.rpcs.pw.rpc.EchoService

           # Call some RPCs and check the results.
           result = echo_service.Echo(msg='Hello')

           if result.status.ok():
               print('The status was', result.status)
               print('The message was', result.response.msg)
           else:
               print('Uh oh, this RPC returned', result.status)

           # The result object can be split into status and payload
           # when assigned.
           status, payload = echo_service.Echo(msg='Goodbye!')

           print(f'{status}: {payload}')
    """
    detokenizer = _build_detokenizer(token_databases)
    protos = _build_proto_library(compiled_protos, rpc_logging)

    reader: stream_readers.SelectableReader | stream_readers.SerialReader

    if socket_addr is None:
        serial_impl = (
            pyserial_wrapper.SerialWithLogging
            if serial_debug
            else serial.Serial
        )

        if not device:
            device = interactive_serial_port_select()
        _ROOT_LOG.info('Using serial port: %s', device)
        serial_device = serial_impl(
            device,
            baudrate,
            # Timeout in seconds. This should be a very small value. Setting to
            # zero makes pyserial read() non-blocking which will cause the host
            # machine to busy loop and 100% CPU usage.
            # https://pythonhosted.org/pyserial/pyserial_api.html#serial.Serial
            timeout=0.1,
        )
        reader = stream_readers.SerialReader(serial_device, 8192)
        write = serial_device.write

        # Serial devices default to the ms-since-boot timestamp decoder when
        # the caller did not provide one.
        if timestamp_decoder is None:
            timestamp_decoder = timestamp_parser_ms_since_boot
    else:
        socket_impl = (
            socket_client.SocketClientWithLogging
            if serial_debug
            else socket_client.SocketClient
        )

        def disconnect_handler(
            socket_device: socket_client.SocketClient,
        ) -> None:
            """Attempts to reconnect on disconnected socket."""
            _LOG.error('Socket disconnected. Will retry to connect.')
            while True:
                try:
                    socket_device.connect()
                    break
                except Exception:  # pylint: disable=broad-except
                    # Ignore errors and retry to reconnect. KeyboardInterrupt
                    # and SystemExit are intentionally not swallowed.
                    time.sleep(1)
            _LOG.info('Successfully reconnected')

        try:
            socket_device = socket_impl(
                socket_addr, on_disconnect=disconnect_handler
            )
            reader = stream_readers.SelectableReader(socket_device)
            write = socket_device.write
        except ValueError as error:
            raise ValueError(
                f'Failed to initialize socket at {socket_addr}'
            ) from error

    device_args: list[Any] = [channel_id, reader, write]
    device_kwds: dict[str, Any] = {
        'proto_library': protos,
        'detokenizer': detokenizer,
        'timestamp_decoder': timestamp_decoder,
        'rpc_timeout_s': 5,
        'use_rpc_logging': rpc_logging,
        'use_hdlc_encoding': hdlc_encoding,
        'extra_frame_handlers': extra_frame_handlers,
    }

    device_client: pw_device_tracing.DeviceWithTracing | pw_device.Device
    if device_tracing:
        # Only the tracing client receives the trace clock rate.
        device_kwds['ticks_per_second'] = ticks_per_second
        if device_tracing_class is None:
            device_tracing_class = pw_device_tracing.DeviceWithTracing
        device_client = device_tracing_class(*device_args, **device_kwds)
    else:
        if device_class is None:
            device_class = pw_device.Device
        device_client = device_class(*device_args, **device_kwds)

    return DeviceConnection(device_client, reader, write)
278
279
def add_device_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Add device specific args required by the pw_system.Device class.

    ``--device`` and ``--socket-addr`` are added to a mutually exclusive
    group, so at most one of them may be given.

    Args:
        parser: The parser to add arguments to; it is modified in place.

    Returns:
        The same ``parser`` instance, to allow chaining.
    """

    group = parser.add_mutually_exclusive_group(required=False)

    group.add_argument(
        '-d',
        '--device',
        help='the serial port to use',
    )
    group.add_argument(
        '-s',
        '--socket-addr',
        type=str,
        help=(
            'Socket address used to connect to server. Type "default" to use '
            'localhost:33000, pass the server address and port as '
            'address:port, or prefix the path to a forwarded socket with '
            f'"{socket_client.SocketClient.FILE_SOCKET_SERVER}:" as '
            f'{socket_client.SocketClient.FILE_SOCKET_SERVER}:path_to_file.'
        ),
    )

    parser.add_argument(
        '-b',
        '--baudrate',
        type=int,
        default=115200,
        help='the baud rate to use',
    )
    parser.add_argument(
        '--serial-debug',
        action='store_true',
        help=(
            # Adjacent literals are concatenated, so the trailing space is
            # required to avoid rendering "throughpyserial".
            'Enable debug log tracing of all data passed through '
            'pyserial read and write.'
        ),
    )
    parser.add_argument(
        "--token-databases",
        metavar='elf_or_token_database',
        nargs="+",
        type=Path,
        help="Path to tokenizer database csv file(s).",
    )
    parser.add_argument(
        '-f',
        '--ticks_per_second',
        type=int,
        dest='ticks_per_second',
        help='The clock rate of the trace events.',
    )
    parser.add_argument(
        '--rpc-logging',
        action=argparse.BooleanOptionalAction,
        default=True,
        help='Use pw_rpc based logging.',
    )
    parser.add_argument(
        '--hdlc-encoding',
        action=argparse.BooleanOptionalAction,
        default=True,
        help='Use HDLC encoding on transfer interfaces.',
    )
    parser.add_argument(
        '--channel-id',
        type=int,
        default=rpc.DEFAULT_CHANNEL_ID,
        help="Channel ID used in RPC communications.",
    )
    parser.add_argument(
        '--device-tracing',
        action=argparse.BooleanOptionalAction,
        default=True,
        help='Use device tracing.',
    )

    return parser
358