xref: /aosp_15_r20/external/pigweed/pw_console/py/pw_console/socket_client.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2023 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Wrappers for socket clients to log read and write data."""
15from __future__ import annotations
16
17from typing import Callable, TYPE_CHECKING
18
19import errno
20import re
21import socket
22
23from pw_console.plugins.bandwidth_toolbar import SerialBandwidthTracker
24
25if TYPE_CHECKING:
26    from _typeshed import ReadableBuffer
27
28
29class SocketClient:
30    """Socket transport implementation."""
31
32    FILE_SOCKET_SERVER = 'file'
33    DEFAULT_SOCKET_SERVER = 'localhost'
34    DEFAULT_SOCKET_PORT = 33000
35    PW_RPC_MAX_PACKET_SIZE = 256
36    DEFAULT_TIMEOUT = 0.5
37
38    _InitArgsType = tuple[
39        socket.AddressFamily, int  # pylint: disable=no-member
40    ]
41    # Can be a string, (address, port) for AF_INET or (address, port, flowinfo,
42    # scope_id) AF_INET6.
43    _AddressType = str | tuple[str, int] | tuple[str, int, int, int]
44
45    def __init__(
46        self,
47        config: str,
48        on_disconnect: Callable[[SocketClient], None] | None = None,
49        timeout: float | None = None,
50    ):
51        """Creates a socket connection.
52
53        Args:
54          config: The socket configuration. Accepted values and formats are:
55            'default' - uses the default configuration (localhost:33000)
56            'address:port' - An IPv4 address and port.
57            'address' - An IPv4 address. Uses default port 33000.
58            '[address]:port' - An IPv6 address and port.
59            '[address]' - An IPv6 address. Uses default port 33000.
60            'file:path_to_file' - A Unix socket at ``path_to_file``.
61            In the formats above,``address`` can be an actual address or a name
62            that resolves to an address through name-resolution.
63          on_disconnect: An optional callback called when the socket
64            disconnects.
65
66        Raises:
67          TypeError: The type of socket is not supported.
68          ValueError: The socket configuration is invalid.
69        """
70        self.socket: socket.socket
71        (
72            self._socket_init_args,
73            self._address,
74        ) = SocketClient._parse_socket_config(config)
75        self._on_disconnect = on_disconnect
76        self._timeout = SocketClient.DEFAULT_TIMEOUT
77        if timeout:
78            self._timeout = timeout
79        self._connected = False
80        self.connect()
81
82    @staticmethod
83    def _parse_socket_config(
84        config: str,
85    ) -> tuple[SocketClient._InitArgsType, SocketClient._AddressType]:
86        """Sets the variables used to create a socket given a config string.
87
88        Raises:
89          TypeError: The type of socket is not supported.
90          ValueError: The socket configuration is invalid.
91        """
92        init_args: SocketClient._InitArgsType
93        address: SocketClient._AddressType
94
95        # Check if this is using the default settings.
96        if config == 'default':
97            init_args = socket.AF_INET6, socket.SOCK_STREAM
98            address = (
99                SocketClient.DEFAULT_SOCKET_SERVER,
100                SocketClient.DEFAULT_SOCKET_PORT,
101            )
102            return init_args, address
103
104        # Check if this is a UNIX socket.
105        unix_socket_file_setting = f'{SocketClient.FILE_SOCKET_SERVER}:'
106        if config.startswith(unix_socket_file_setting):
107            # Unix socket support is available on Windows 10 since April
108            # 2018. However, there is no Python support on Windows yet.
109            # See https://bugs.python.org/issue33408 for more information.
110            if not hasattr(socket, 'AF_UNIX'):
111                raise TypeError(
112                    'Unix sockets are not supported in this environment.'
113                )
114            init_args = (
115                socket.AF_UNIX,  # pylint: disable=no-member
116                socket.SOCK_STREAM,
117            )
118            address = config[len(unix_socket_file_setting) :]
119            return init_args, address
120
121        # Search for IPv4 or IPv6 address or name and port.
122        # First, try to capture an IPv6 address as anything inside []. If there
123        # are no [] capture the IPv4 address. Lastly, capture the port as the
124        # numbers after :, if any.
125        match = re.match(
126            r'(\[(?P<ipv6_addr>.+)\]:?|(?P<ipv4_addr>[a-zA-Z0-9\._\/]+):?)'
127            r'(?P<port>[0-9]+)?',
128            config,
129        )
130        invalid_config_message = (
131            f'Invalid socket configuration "{config}"'
132            'Accepted values are "default", "file:<file_path>", '
133            '"<name_or_ipv4_address>" with optional ":<port>", and '
134            '"[<name_or_ipv6_address>]" with optional ":<port>".'
135        )
136        if match is None:
137            raise ValueError(invalid_config_message)
138
139        info = match.groupdict()
140        if info['port']:
141            port = int(info['port'])
142        else:
143            port = SocketClient.DEFAULT_SOCKET_PORT
144
145        if info['ipv4_addr']:
146            ip_addr = info['ipv4_addr']
147        elif info['ipv6_addr']:
148            ip_addr = info['ipv6_addr']
149        else:
150            raise ValueError(invalid_config_message)
151
152        sock_family, sock_type, _, _, address = socket.getaddrinfo(
153            ip_addr, port, type=socket.SOCK_STREAM
154        )[0]
155        init_args = sock_family, sock_type
156        return init_args, address
157
158    def __del__(self):
159        if self._connected:
160            self.socket.close()
161
162    def write(self, data: ReadableBuffer) -> None:
163        """Writes data and detects disconnects."""
164        if not self._connected:
165            raise Exception('Socket is not connected.')
166        try:
167            self.socket.sendall(data)
168        except socket.error as exc:
169            if isinstance(exc.args, tuple) and exc.args[0] == errno.EPIPE:
170                self._handle_disconnect()
171            else:
172                raise exc
173
174    def read(self, num_bytes: int = PW_RPC_MAX_PACKET_SIZE) -> bytes:
175        """Blocks until data is ready and reads up to num_bytes."""
176        if not self._connected:
177            raise Exception('Socket is not connected.')
178        data = self.socket.recv(num_bytes)
179        # Since this is a blocking read, no data returned means the socket is
180        # closed.
181        if not data:
182            self._handle_disconnect()
183        return data
184
185    def connect(self) -> None:
186        """Connects to socket."""
187        self.socket = socket.socket(*self._socket_init_args)
188
189        # Enable reusing address and port for reconnections.
190        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
191        if hasattr(socket, 'SO_REUSEPORT'):
192            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
193        self.socket.settimeout(self._timeout)
194        self.socket.connect(self._address)
195        self._connected = True
196
197    def _handle_disconnect(self):
198        """Escalates a socket disconnect to the user."""
199        self.socket.close()
200        self._connected = False
201        if self._on_disconnect:
202            self._on_disconnect(self)
203
204    def fileno(self) -> int:
205        return self.socket.fileno()
206
207
class SocketClientWithLogging(SocketClient):
    """``SocketClient`` that reports read and written data to a tracker.

    Every ``read`` and ``write`` call is forwarded to ``SocketClient``; the
    data involved is also passed to a ``SerialBandwidthTracker``.
    """

    def __init__(self, *args, **kwargs):
        """Forwards all arguments to ``SocketClient``, then adds the tracker."""
        super().__init__(*args, **kwargs)
        self._bandwidth_tracker = SerialBandwidthTracker()

    def read(
        self, num_bytes: int = SocketClient.PW_RPC_MAX_PACKET_SIZE
    ) -> bytes:
        """Reads up to ``num_bytes`` and records the received data."""
        received = super().read(num_bytes)
        self._bandwidth_tracker.track_read_data(received)
        return received

    def write(self, data: ReadableBuffer) -> None:
        """Records ``data`` with the tracker, then writes it to the socket."""
        self._bandwidth_tracker.track_write_data(data)
        super().write(data)
225