1# Copyright 2024 The Pigweed Authors 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); you may not 4# use this file except in compliance with the License. You may obtain a copy of 5# the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12# License for the specific language governing permissions and limitations under 13# the License. 14"""Device instance creation.""" 15 16import argparse 17from dataclasses import dataclass 18import logging 19from pathlib import Path 20import time 21from types import ModuleType 22from typing import Any, Callable, Collection 23 24import serial 25 26from pw_console import pyserial_wrapper 27from pw_console import socket_client 28from pw_hdlc import rpc 29from pw_log.log_decoder import timestamp_parser_ms_since_boot 30from pw_tokenizer import detokenize 31 32from pw_system.find_serial_port import interactive_serial_port_select 33from pw_system import device as pw_device 34from pw_system import device_tracing as pw_device_tracing 35 36# Default proto imports: 37# pylint: disable=ungrouped-imports,wrong-import-order 38from pw_file import file_pb2 39from pw_log.proto import log_pb2 40from pw_metric_proto import metric_service_pb2 41from pw_rpc import echo_pb2 42from pw_stream import stream_readers 43from pw_thread_protos import thread_snapshot_service_pb2 44from pw_trace_protos import trace_service_pb2 45 46from pw_system_protos import device_service_pb2 47from pw_unit_test_proto import unit_test_pb2 48 49 50# pylint: enable=ungrouped-imports 51 52# Internal log for troubleshooting this tool (the console). 
_LOG = logging.getLogger(__package__)
_ROOT_LOG = logging.getLogger()

DEFAULT_DEVICE_LOGGER = logging.getLogger('rpc_device')


@dataclass
class DeviceConnection:
    """Stores a Device client along with the reader and writer.

    Use as a context manager: entering opens the reader and then the client,
    and yields the client. Exiting closes the client and then the reader.
    """

    # The RPC device client returned when the connection is entered.
    client: pw_device_tracing.DeviceWithTracing | pw_device.Device
    # Transport reader feeding incoming bytes to the client.
    reader: stream_readers.SelectableReader | stream_readers.SerialReader
    # Transport write function used to send bytes to the device.
    writer: Callable[[bytes], int | None]

    def __enter__(self):
        """Enter the reader followed by the client context managers.

        If entering the client fails, the reader is exited again before the
        error propagates, so the transport is not left open.

        Returns the device client for RPC interaction.
        """
        self.reader.__enter__()
        try:
            self.client.__enter__()
        except BaseException as error:
            self.reader.__exit__(type(error), error, error.__traceback__)
            raise
        return self.client

    def __exit__(self, *exc_info):
        """Close the device client followed by the reader.

        The reader is exited even if closing the client raises. Exception
        info is forwarded to both, per the context manager protocol.
        """
        try:
            self.client.__exit__(*exc_info)
        finally:
            self.reader.__exit__(*exc_info)


def _create_detokenizer(
    token_databases: Collection[Path],
) -> detokenize.AutoUpdatingDetokenizer | None:
    """Returns a detokenizer for ``token_databases``, or None if empty."""
    if not token_databases:
        return None
    # The '#.*' suffix loads all domains from each token database.
    detokenizer = detokenize.AutoUpdatingDetokenizer(
        *(f'{token_database}#.*' for token_database in token_databases)
    )
    detokenizer.show_errors = True
    return detokenizer


def _collect_protos(
    compiled_protos: list[ModuleType] | None,
    rpc_logging: bool,
) -> list[ModuleType | Path]:
    """Returns the proto library list handed to the Device client.

    The caller's ``compiled_protos`` list is copied, never modified.
    """
    protos: list[ModuleType | Path] = list(compiled_protos or [])

    # Append compiled log.proto library to avoid include errors when
    # manually provided, and shadowing errors due to ordering when the
    # default global search path is used.
    if rpc_logging:
        protos.append(log_pb2)
    protos.append(unit_test_pb2)
    protos.extend(
        [
            metric_service_pb2,
            thread_snapshot_service_pb2,
            file_pb2,
            echo_pb2,
            trace_service_pb2,
            device_service_pb2,
        ]
    )
    return protos


def _open_serial(
    device: str,
    baudrate: int,
    serial_debug: bool,
) -> tuple[stream_readers.SerialReader, Callable[[bytes], int | None]]:
    """Opens a serial port and returns its (reader, write function).

    If ``device`` is empty, the user is prompted to pick a serial port.
    """
    serial_impl = (
        pyserial_wrapper.SerialWithLogging if serial_debug else serial.Serial
    )

    if not device:
        device = interactive_serial_port_select()
    _ROOT_LOG.info('Using serial port: %s', device)
    serial_device = serial_impl(
        device,
        baudrate,
        # Timeout in seconds. This should be a very small value. Setting to
        # zero makes pyserial read() non-blocking which will cause the host
        # machine to busy loop and 100% CPU usage.
        # https://pythonhosted.org/pyserial/pyserial_api.html#serial.Serial
        timeout=0.1,
    )
    return stream_readers.SerialReader(serial_device, 8192), serial_device.write


def _reconnect_socket(socket_device: socket_client.SocketClient) -> None:
    """Attempts to reconnect on disconnected socket.

    Retries ``connect()`` once per second until it succeeds. Only
    ``Exception`` subclasses are retried, so KeyboardInterrupt and
    SystemExit still abort the loop.
    """
    _LOG.error('Socket disconnected. Will retry to connect.')
    while True:
        try:
            socket_device.connect()
            break
        except Exception:  # pylint: disable=broad-except
            # Ignore connection errors and retry to reconnect.
            time.sleep(1)
    _LOG.info('Successfully reconnected')


def _open_socket(
    socket_addr: str,
    serial_debug: bool,
) -> tuple[stream_readers.SelectableReader, Callable[[bytes], int | None]]:
    """Connects to ``socket_addr`` and returns its (reader, write function).

    Raises:
        ValueError: The socket client could not be initialized.
    """
    socket_impl = (
        socket_client.SocketClientWithLogging
        if serial_debug
        else socket_client.SocketClient
    )
    try:
        socket_device = socket_impl(
            socket_addr, on_disconnect=_reconnect_socket
        )
        reader = stream_readers.SelectableReader(socket_device)
        write = socket_device.write
    except ValueError as error:
        raise ValueError(
            f'Failed to initialize socket at {socket_addr}'
        ) from error
    return reader, write


def create_device_serial_or_socket_connection(
    # pylint: disable=too-many-arguments,too-many-locals
    device: str,
    baudrate: int,
    token_databases: Collection[Path],
    socket_addr: str | None = None,
    ticks_per_second: int | None = None,
    serial_debug: bool = False,
    compiled_protos: list[ModuleType] | None = None,
    rpc_logging: bool = True,
    channel_id: int = rpc.DEFAULT_CHANNEL_ID,
    hdlc_encoding: bool = True,
    device_tracing: bool = True,
    device_class: type[pw_device.Device] | None = pw_device.Device,
    device_tracing_class: type[pw_device_tracing.DeviceWithTracing]
    | None = (pw_device_tracing.DeviceWithTracing),
    timestamp_decoder: Callable[[int], str] | None = None,
    extra_frame_handlers: dict[int, Callable[[bytes, Any], Any]] | None = None,
) -> DeviceConnection:
    """Setup a pw_system.Device client connection.

    Args:
        device: Serial port to open when ``socket_addr`` is None. If empty,
            the user is prompted interactively to select a serial port.
        baudrate: Serial baud rate. Only used for serial connections.
        token_databases: Token database paths. All domains are loaded from
            each one. If empty, no detokenizer is created.
        socket_addr: If set, connect to this socket address instead of a
            serial port.
        ticks_per_second: Trace clock rate. Only passed to the client when
            ``device_tracing`` is True.
        serial_debug: Use the logging wrappers of the serial port or socket
            client.
        compiled_protos: Extra compiled proto modules. The list is copied,
            not modified.
        rpc_logging: Use pw_rpc based logging. This also adds ``log_pb2`` to
            the proto library.
        channel_id: RPC channel ID.
        hdlc_encoding: Use HDLC encoding on the transport.
        device_tracing: Create ``device_tracing_class`` instead of
            ``device_class``.
        device_class: Client class when tracing is disabled. None falls
            back to ``pw_device.Device``.
        device_tracing_class: Client class when tracing is enabled. None
            falls back to ``pw_device_tracing.DeviceWithTracing``.
        timestamp_decoder: Log timestamp decoder. For serial connections,
            None defaults to ``timestamp_parser_ms_since_boot``.
        extra_frame_handlers: Passed through to the device client.

    Returns:
        A DeviceConnection that has not been entered yet. Use it in a
        ``with`` statement to open the transport and obtain the client.

    Raises:
        ValueError: The socket at ``socket_addr`` could not be initialized.

    Full example of opening a device connection and running an RPC:

    .. code-block:: python

       import argparse

       from pw_log.log_decoder import timestamp_parser_ms_since_boot
       from pw_protobuf_protos import common_pb2
       from pw_rpc import echo_pb2
       from pw_system.device import Device
       from pw_system.device_connection import (
           add_device_args,
           create_device_serial_or_socket_connection,
       )
       from pw_system.device_tracing import DeviceWithTracing

       parser = argparse.ArgumentParser(
           prog='MyProductScript',
       )
       parser = add_device_args(parser)
       args = parser.parse_args()

       compiled_protos = [
           common_pb2,
           echo_pb2,
       ]

       device_connection = create_device_serial_or_socket_connection(
           device=args.device,
           baudrate=args.baudrate,
           token_databases=args.token_databases,
           compiled_protos=compiled_protos,
           socket_addr=args.socket_addr,
           ticks_per_second=args.ticks_per_second,
           serial_debug=args.serial_debug,
           rpc_logging=args.rpc_logging,
           hdlc_encoding=args.hdlc_encoding,
           channel_id=args.channel_id,
           device_tracing=args.device_tracing,
           device_class=Device,
           device_tracing_class=DeviceWithTracing,
           timestamp_decoder=timestamp_parser_ms_since_boot,
       )

       # Open the device connection and interact with the Device client.
       with device_connection as device:
           # Make a shortcut to the EchoService.
           echo_service = device.rpcs.pw.rpc.EchoService

           # Call some RPCs and check the results.
           result = echo_service.Echo(msg='Hello')

           if result.status.ok():
               print('The status was', result.status)
               print('The message was', result.response.msg)
           else:
               print('Uh oh, this RPC returned', result.status)

           # The result object can be split into status and payload
           # when assigned.
           status, payload = echo_service.Echo(msg='Goodbye!')

           print(f'{status}: {payload}')
    """
    detokenizer = _create_detokenizer(token_databases)
    protos = _collect_protos(compiled_protos, rpc_logging)

    reader: stream_readers.SelectableReader | stream_readers.SerialReader
    write: Callable[[bytes], int | None]
    if socket_addr is None:
        reader, write = _open_serial(device, baudrate, serial_debug)
        # Serial devices default to ms-since-boot log timestamps when the
        # caller did not supply a decoder.
        if timestamp_decoder is None:
            timestamp_decoder = timestamp_parser_ms_since_boot
    else:
        reader, write = _open_socket(socket_addr, serial_debug)

    device_args: list[Any] = [channel_id, reader, write]
    device_kwds: dict[str, Any] = {
        'proto_library': protos,
        'detokenizer': detokenizer,
        'timestamp_decoder': timestamp_decoder,
        'rpc_timeout_s': 5,
        'use_rpc_logging': rpc_logging,
        'use_hdlc_encoding': hdlc_encoding,
        'extra_frame_handlers': extra_frame_handlers,
    }

    device_client: pw_device_tracing.DeviceWithTracing | pw_device.Device
    if device_tracing:
        device_kwds['ticks_per_second'] = ticks_per_second
        tracing_cls = (
            device_tracing_class
            if device_tracing_class is not None
            else pw_device_tracing.DeviceWithTracing
        )
        device_client = tracing_cls(*device_args, **device_kwds)
    else:
        plain_cls = (
            device_class if device_class is not None else pw_device.Device
        )
        device_client = plain_cls(*device_args, **device_kwds)

    return DeviceConnection(device_client, reader, write)


def add_device_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Add device specific args required by the pw_system.Device class.

    ``--device`` and ``--socket-addr`` are mutually exclusive.

    Args:
        parser: The parser to extend.

    Returns:
        The same ``parser``, for chaining.
    """

    group = parser.add_mutually_exclusive_group(required=False)

    group.add_argument(
        '-d',
        '--device',
        help='the serial port to use',
    )
    group.add_argument(
        '-s',
        '--socket-addr',
        type=str,
        help=(
            'Socket address used to connect to server. Type "default" to use '
            'localhost:33000, pass the server address and port as '
            'address:port, or prefix the path to a forwarded socket with '
            f'"{socket_client.SocketClient.FILE_SOCKET_SERVER}:" as '
            f'{socket_client.SocketClient.FILE_SOCKET_SERVER}:path_to_file.'
        ),
    )

    parser.add_argument(
        '-b',
        '--baudrate',
        type=int,
        default=115200,
        help='the baud rate to use',
    )
    parser.add_argument(
        '--serial-debug',
        action='store_true',
        help=(
            'Enable debug log tracing of all data passed through '
            'pyserial read and write.'
        ),
    )
    parser.add_argument(
        "--token-databases",
        metavar='elf_or_token_database',
        nargs="+",
        type=Path,
        help="Path to tokenizer database csv file(s).",
    )
    parser.add_argument(
        '-f',
        '--ticks_per_second',
        type=int,
        dest='ticks_per_second',
        help=('The clock rate of the trace events.'),
    )
    parser.add_argument(
        '--rpc-logging',
        action=argparse.BooleanOptionalAction,
        default=True,
        help='Use pw_rpc based logging.',
    )
    parser.add_argument(
        '--hdlc-encoding',
        action=argparse.BooleanOptionalAction,
        default=True,
        help='Use HDLC encoding on transfer interfaces.',
    )
    parser.add_argument(
        '--channel-id',
        type=int,
        default=rpc.DEFAULT_CHANNEL_ID,
        help="Channel ID used in RPC communications.",
    )
    parser.add_argument(
        '--device-tracing',
        action=argparse.BooleanOptionalAction,
        default=True,
        help='Use device tracing.',
    )

    return parser