# Copyright 2023 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Wrappers for socket clients to log read and write data."""
from __future__ import annotations

from typing import Callable, TYPE_CHECKING

import errno
import re
import socket

from pw_console.plugins.bandwidth_toolbar import SerialBandwidthTracker

if TYPE_CHECKING:
    from _typeshed import ReadableBuffer


class SocketClient:
    """Socket transport implementation."""

    # Prefix that selects a Unix domain socket in the config string.
    FILE_SOCKET_SERVER = 'file'
    DEFAULT_SOCKET_SERVER = 'localhost'
    DEFAULT_SOCKET_PORT = 33000
    # Default number of bytes requested by read().
    PW_RPC_MAX_PACKET_SIZE = 256
    # Socket timeout, in seconds, used when the caller does not supply one.
    DEFAULT_TIMEOUT = 0.5

    _InitArgsType = tuple[
        socket.AddressFamily, int  # pylint: disable=no-member
    ]
    # Can be a string, (address, port) for AF_INET or (address, port, flowinfo,
    # scope_id) AF_INET6.
    _AddressType = str | tuple[str, int] | tuple[str, int, int, int]

    def __init__(
        self,
        config: str,
        on_disconnect: Callable[[SocketClient], None] | None = None,
        timeout: float | None = None,
    ):
        """Creates a socket connection.

        Args:
          config: The socket configuration. Accepted values and formats are:
            'default' - uses the default configuration (localhost:33000)
            'address:port' - An IPv4 address and port.
            'address' - An IPv4 address. Uses default port 33000.
            '[address]:port' - An IPv6 address and port.
            '[address]' - An IPv6 address. Uses default port 33000.
            'file:path_to_file' - A Unix socket at ``path_to_file``.
            In the formats above, ``address`` can be an actual address or a
            name that resolves to an address through name-resolution.
          on_disconnect: An optional callback called when the socket
            disconnects.
          timeout: Socket timeout in seconds. When omitted (or falsy, e.g.
            ``None`` or ``0``) ``DEFAULT_TIMEOUT`` is used.

        Raises:
          TypeError: The type of socket is not supported.
          ValueError: The socket configuration is invalid.
        """
        self.socket: socket.socket
        (
            self._socket_init_args,
            self._address,
        ) = SocketClient._parse_socket_config(config)
        self._on_disconnect = on_disconnect
        self._timeout = SocketClient.DEFAULT_TIMEOUT
        if timeout:
            self._timeout = timeout
        self._connected = False
        self.connect()

    @staticmethod
    def _parse_socket_config(
        config: str,
    ) -> tuple[SocketClient._InitArgsType, SocketClient._AddressType]:
        """Sets the variables used to create a socket given a config string.

        Args:
          config: The configuration string; see ``__init__`` for the accepted
            formats.

        Returns:
          A ``(init_args, address)`` pair where ``init_args`` is the
          ``(family, type)`` tuple passed to ``socket.socket()`` and
          ``address`` is the value passed to ``socket.connect()``.

        Raises:
          TypeError: The type of socket is not supported.
          ValueError: The socket configuration is invalid.
        """
        init_args: SocketClient._InitArgsType
        address: SocketClient._AddressType

        # Check if this is using the default settings.
        if config == 'default':
            init_args = socket.AF_INET6, socket.SOCK_STREAM
            address = (
                SocketClient.DEFAULT_SOCKET_SERVER,
                SocketClient.DEFAULT_SOCKET_PORT,
            )
            return init_args, address

        # Check if this is a UNIX socket.
        unix_socket_file_setting = f'{SocketClient.FILE_SOCKET_SERVER}:'
        if config.startswith(unix_socket_file_setting):
            # Unix socket support is available on Windows 10 since April
            # 2018. However, there is no Python support on Windows yet.
            # See https://bugs.python.org/issue33408 for more information.
            if not hasattr(socket, 'AF_UNIX'):
                raise TypeError(
                    'Unix sockets are not supported in this environment.'
                )
            init_args = (
                socket.AF_UNIX,  # pylint: disable=no-member
                socket.SOCK_STREAM,
            )
            address = config[len(unix_socket_file_setting) :]
            return init_args, address

        # Search for IPv4 or IPv6 address or name and port.
        # First, try to capture an IPv6 address as anything inside []. If there
        # are no [] capture the IPv4 address. Lastly, capture the port as the
        # numbers after :, if any.
        # The whole string must match: a partial match would silently drop
        # trailing text (e.g. a non-numeric port) and connect somewhere the
        # caller did not ask for. '-' is accepted because it is valid in host
        # names.
        match = re.fullmatch(
            r'(\[(?P<ipv6_addr>.+)\]:?|(?P<ipv4_addr>[a-zA-Z0-9\._\/\-]+):?)'
            r'(?P<port>[0-9]+)?',
            config,
        )
        invalid_config_message = (
            f'Invalid socket configuration "{config}". '
            'Accepted values are "default", "file:<file_path>", '
            '"<name_or_ipv4_address>" with optional ":<port>", and '
            '"[<name_or_ipv6_address>]" with optional ":<port>".'
        )
        if match is None:
            raise ValueError(invalid_config_message)

        info = match.groupdict()
        if info['port']:
            port = int(info['port'])
        else:
            port = SocketClient.DEFAULT_SOCKET_PORT

        if info['ipv4_addr']:
            ip_addr = info['ipv4_addr']
        elif info['ipv6_addr']:
            ip_addr = info['ipv6_addr']
        else:
            raise ValueError(invalid_config_message)

        # Resolve the name/address and use the first result to pick the socket
        # family. Resolution failures propagate from getaddrinfo().
        sock_family, sock_type, _, _, address = socket.getaddrinfo(
            ip_addr, port, type=socket.SOCK_STREAM
        )[0]
        init_args = sock_family, sock_type
        return init_args, address

    def __del__(self) -> None:
        # __init__ may have raised before `_connected` was assigned (e.g. on an
        # invalid config), so do not assume the attribute exists.
        if getattr(self, '_connected', False):
            self.socket.close()

    def write(self, data: ReadableBuffer) -> None:
        """Writes data and detects disconnects.

        A broken pipe (``EPIPE``) is treated as a disconnect: the socket is
        closed and the ``on_disconnect`` callback, if any, is invoked. Any
        other socket error is re-raised.

        Raises:
          Exception: The socket is not connected.
        """
        if not self._connected:
            raise Exception('Socket is not connected.')
        try:
            self.socket.sendall(data)
        except socket.error as exc:
            # Compare `errno` rather than `args[0]`: `args` may be empty.
            if exc.errno != errno.EPIPE:
                raise
            self._handle_disconnect()

    def read(self, num_bytes: int = PW_RPC_MAX_PACKET_SIZE) -> bytes:
        """Waits for data and reads up to num_bytes.

        The socket is created with a timeout (see ``connect``), so the
        underlying ``recv`` call can raise a timeout error instead of blocking
        indefinitely.

        Returns:
          The bytes received. An empty result means the peer closed the
          connection; in that case the disconnect is handled before returning.

        Raises:
          Exception: The socket is not connected.
        """
        if not self._connected:
            raise Exception('Socket is not connected.')
        data = self.socket.recv(num_bytes)
        # Since this is a blocking read, no data returned means the socket is
        # closed.
        if not data:
            self._handle_disconnect()
        return data

    def connect(self) -> None:
        """Connects to socket.

        Creates a new socket from the parsed configuration, applies the
        configured timeout and connects it. If any step fails the new socket is
        closed before the error is re-raised, so a failed attempt does not leak
        a file descriptor.
        """
        self.socket = socket.socket(*self._socket_init_args)
        try:
            # Enable reusing address and port for reconnections.
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            if hasattr(socket, 'SO_REUSEPORT'):
                self.socket.setsockopt(
                    socket.SOL_SOCKET, socket.SO_REUSEPORT, 1
                )
            self.socket.settimeout(self._timeout)
            self.socket.connect(self._address)
        except OSError:
            self.socket.close()
            self._connected = False
            raise
        self._connected = True

    def _handle_disconnect(self) -> None:
        """Escalates a socket disconnect to the user."""
        self.socket.close()
        self._connected = False
        if self._on_disconnect:
            self._on_disconnect(self)

    def fileno(self) -> int:
        """Returns the file descriptor of the underlying socket."""
        return self.socket.fileno()


class SocketClientWithLogging(SocketClient):
    """Socket with read and write wrappers for logging."""

    def __init__(self, *args, **kwargs):
        """Creates the socket connection and a bandwidth tracker.

        All arguments are forwarded unchanged to ``SocketClient.__init__``.
        """
        super().__init__(*args, **kwargs)
        self._bandwidth_tracker = SerialBandwidthTracker()

    def read(
        self, num_bytes: int = SocketClient.PW_RPC_MAX_PACKET_SIZE
    ) -> bytes:
        """Reads like ``SocketClient.read`` and records the received data."""
        data = super().read(num_bytes)
        self._bandwidth_tracker.track_read_data(data)
        return data

    def write(self, data: ReadableBuffer) -> None:
        """Records the outgoing data, then writes it to the socket."""
        self._bandwidth_tracker.track_write_data(data)
        super().write(data)