1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18from __future__ import annotations
19import asyncio
20import logging
21import socket
22
23from .common import Transport, StreamPacketSource
24
25# -----------------------------------------------------------------------------
26# Logging
27# -----------------------------------------------------------------------------
28logger = logging.getLogger(__name__)
29
30
31# -----------------------------------------------------------------------------
32
33
# A pass-through function to ease mock testing.
async def _create_server(*args, **kw_args):
    '''
    Create a server on the running event loop.

    All positional and keyword arguments are forwarded unchanged to the
    running loop's `create_server()` method.

    Returns the server object produced by the loop, so that the caller can keep
    a reference to it and close it when it is no longer needed.
    '''
    # The result must be returned: dropping it leaves the caller with no way to
    # stop the listener.
    return await asyncio.get_running_loop().create_server(*args, **kw_args)
37
38
async def open_tcp_server_transport(spec: str) -> Transport:
    '''
    Open a TCP server transport.
    The parameter string has this syntax:
    <local-host>:<local-port>
    Where <local-host> may be the address of a local network interface, or '_'
    to accept connections on all local network interfaces.
    An IPv6 address may be written bare or enclosed in square brackets.

    Examples:
      _:9001
      127.0.0.1:9001
      [::1]:9001

    Raises ValueError if the spec contains no ':' separator, or if the port is
    not a valid integer.
    '''
    # Split on the last ':' only: an IPv6 host contains ':' characters itself,
    # so only the final one separates the host from the port.
    local_host, local_port = spec.rsplit(':', 1)

    # Drop the URL-style brackets around an IPv6 literal; the host is passed on
    # as a bare address.
    if local_host.startswith('[') and local_host.endswith(']'):
        local_host = local_host[1:-1]

    return await _open_tcp_server_transport_impl(
        host=local_host if local_host != '_' else None, port=int(local_port)
    )
53
54
async def open_tcp_server_transport_with_socket(sock: socket.socket) -> Transport:
    '''
    Open a TCP server transport on top of a socket supplied by the caller.

    The socket is passed along, untouched, as the `sock` keyword argument of
    the server creation call.

    This variant is handy, for instance, when the caller wants python to pick
    an unused port.
    '''
    server_transport = await _open_tcp_server_transport_impl(sock=sock)
    return server_transport
62
63
async def _open_tcp_server_transport_impl(**kwargs) -> Transport:
    '''
    Create a listening TCP server and wrap it in a Transport.

    The keyword arguments are forwarded to `_create_server()`, after the
    protocol factory (for example `host` and `port`, or `sock`).

    Data received from a connected client is fed to a StreamPacketSource.
    Packets given to the sink are written to the most recently connected
    client, or dropped (with a debug log) when no client is connected.

    Closing the returned transport also stops the listening server, when a
    server object is available.
    '''

    class TcpServerTransport(Transport):
        def __init__(self, source, sink, server):
            super().__init__(source, sink)
            # May be None if `_create_server()` did not hand back a server
            # object, in which case there is nothing to close.
            self.server = server

        async def close(self):
            await super().close()
            if self.server is not None:
                # Stop accepting new connections and release the listening
                # socket. `wait_closed()` is deliberately not awaited: on recent
                # Python versions it also waits for active client connections
                # to end, which could block this call indefinitely.
                self.server.close()

    # `asyncio.Protocol` is the base class for streaming protocols, which is
    # what the `data_received` / `eof_received` callbacks below implement.
    class TcpServerProtocol(asyncio.Protocol):
        def __init__(self, packet_source, packet_sink):
            self.packet_source = packet_source
            self.packet_sink = packet_sink

        # Called when a new connection is established
        def connection_made(self, transport):
            # The asyncio key for the remote address is 'peername'; an unknown
            # key makes `get_extra_info` return None.
            peer_name = transport.get_extra_info('peername')
            logger.debug('connection from %s', peer_name)
            # From now on, outgoing packets go to this client
            self.packet_sink.transport = transport

        # Called when the client is disconnected
        def connection_lost(self, error):
            logger.debug('connection lost: %s', error)
            self.packet_sink.transport = None

        # Called when the client signals that it won't send any more data
        def eof_received(self):
            logger.debug('connection end')
            self.packet_sink.transport = None

        # Called when data is received on the socket
        def data_received(self, data):
            self.packet_source.data_received(data)

    class TcpServerPacketSink:
        def __init__(self):
            # The asyncio transport of the current client, or None when no
            # client is connected
            self.transport = None

        def on_packet(self, packet):
            if self.transport:
                self.transport.write(packet)
            else:
                logger.debug('no client, dropping packet')

    packet_source = StreamPacketSource()
    packet_sink = TcpServerPacketSink()

    # A new protocol instance is created for each incoming connection; they all
    # share the same packet source and sink.
    server = await _create_server(
        lambda: TcpServerProtocol(packet_source, packet_sink), **kwargs
    )

    return TcpServerTransport(packet_source, packet_sink, server)
110