1# Copyright 2021-2022 Google LLC 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15# ----------------------------------------------------------------------------- 16# Imports 17# ----------------------------------------------------------------------------- 18from __future__ import annotations 19import asyncio 20import logging 21import socket 22 23from .common import Transport, StreamPacketSource 24 25# ----------------------------------------------------------------------------- 26# Logging 27# ----------------------------------------------------------------------------- 28logger = logging.getLogger(__name__) 29 30 31# ----------------------------------------------------------------------------- 32 33 34# A pass-through function to ease mock testing. 35async def _create_server(*args, **kw_args): 36 await asyncio.get_running_loop().create_server(*args, **kw_args) 37 38 39async def open_tcp_server_transport(spec: str) -> Transport: 40 ''' 41 Open a TCP server transport. 42 The parameter string has this syntax: 43 <local-host>:<local-port> 44 Where <local-host> may be the address of a local network interface, or '_' 45 to accept connections on all local network interfaces. 46 47 Example: _:9001 48 ''' 49 local_host, local_port = spec.split(':') 50 return await _open_tcp_server_transport_impl( 51 host=local_host if local_host != '_' else None, port=int(local_port) 52 ) 53 54 55async def open_tcp_server_transport_with_socket(sock: socket.socket) -> Transport: 56 ''' 57 Open a TCP server transport with an existing socket. 58 59 One reason to use this variant is to let python pick an unused port. 60 ''' 61 return await _open_tcp_server_transport_impl(sock=sock) 62 63 64async def _open_tcp_server_transport_impl(**kwargs) -> Transport: 65 class TcpServerTransport(Transport): 66 async def close(self): 67 await super().close() 68 69 class TcpServerProtocol(asyncio.BaseProtocol): 70 def __init__(self, packet_source, packet_sink): 71 self.packet_source = packet_source 72 self.packet_sink = packet_sink 73 74 # Called when a new connection is established 75 def connection_made(self, transport): 76 peer_name = transport.get_extra_info('peer_name') 77 logger.debug(f'connection from {peer_name}') 78 self.packet_sink.transport = transport 79 80 # Called when the client is disconnected 81 def connection_lost(self, error): 82 logger.debug(f'connection lost: {error}') 83 self.packet_sink.transport = None 84 85 def eof_received(self): 86 logger.debug('connection end') 87 self.packet_sink.transport = None 88 89 # Called when data is received on the socket 90 def data_received(self, data): 91 self.packet_source.data_received(data) 92 93 class TcpServerPacketSink: 94 def __init__(self): 95 self.transport = None 96 97 def on_packet(self, packet): 98 if self.transport: 99 self.transport.write(packet) 100 else: 101 logger.debug('no client, dropping packet') 102 103 packet_source = StreamPacketSource() 104 packet_sink = TcpServerPacketSink() 105 await _create_server( 106 lambda: TcpServerProtocol(packet_source, packet_sink), **kwargs 107 ) 108 109 return TcpServerTransport(packet_source, packet_sink) 110