#!/usr/bin/env python3 # Copyright 2022 The Pigweed Authors # # Licensed under the Apache License, Version 2.0 (the "License"); you may not # use this file except in compliance with the License. You may obtain a copy of # the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations under # the License. """Proxy for transfer integration testing. This module contains a proxy for transfer intergation testing. It is capable of introducing various link failures into the connection between the client and server. """ import abc import argparse import asyncio from enum import Enum import logging import random import socket import sys import time from typing import Awaitable, Callable, Iterable, NamedTuple from google.protobuf import text_format from pw_rpc.internal import packet_pb2 from pw_transfer import transfer_pb2 from pw_transfer.integration_test import config_pb2 from pw_hdlc import decode from pw_transfer.chunk import Chunk _LOG = logging.getLogger('pw_transfer_intergration_test_proxy') # This is the maximum size of the socket receive buffers. Ideally, this is set # to the lowest allowed value to minimize buffering between the proxy and # clients so rate limiting causes the client to block and wait for the # integration test proxy to drain rather than allowing OS buffers to backlog # large quantities of data. # # Note that the OS may chose to not strictly follow this requested buffer size. # Still, setting this value to be relatively small does reduce bufer sizes # significantly enough to better reflect typical inter-device communication. # # For this to be effective, clients should also configure their sockets to a # smaller send buffer size. _RECEIVE_BUFFER_SIZE = 2048 class EventType(Enum): TRANSFER_START = 1 PARAMETERS_RETRANSMIT = 2 PARAMETERS_CONTINUE = 3 START_ACK_CONFIRMATION = 4 class Event(NamedTuple): type: EventType chunk: Chunk class Filter(abc.ABC): """An abstract interface for manipulating a stream of data. ``Filter``s are used to implement various transforms to simulate real world link properties. Some examples include: data corruption, packet loss, packet reordering, rate limiting, latency modeling. A ``Filter`` implementation should implement the ``process`` method and call ``self.send_data()`` when it has data to send. """ def __init__(self, send_data: Callable[[bytes], Awaitable[None]]): self.send_data = send_data @abc.abstractmethod async def process(self, data: bytes) -> None: """Processes incoming data. Implementations of this method may send arbitrary data, or none, using the ``self.send_data()`` handler. """ async def __call__(self, data: bytes) -> None: await self.process(data) class HdlcPacketizer(Filter): """A filter which aggregates data into complete HDLC packets. Since the proxy transport (SOCK_STREAM) has no framing and we want some filters to operates on whole frames, this filter can be used so that downstream filters see whole frames. """ def __init__(self, send_data: Callable[[bytes], Awaitable[None]]): super().__init__(send_data) self.decoder = decode.FrameDecoder() async def process(self, data: bytes) -> None: for frame in self.decoder.process(data): await self.send_data(frame.raw_encoded) class DataDropper(Filter): """A filter which drops some data. DataDropper will drop data passed through ``process()`` at the specified ``rate``. """ def __init__( self, send_data: Callable[[bytes], Awaitable[None]], name: str, rate: float, seed: int | None = None, ): super().__init__(send_data) self._rate = rate self._name = name if seed == None: seed = time.time_ns() self._rng = random.Random(seed) _LOG.info(f'{name} DataDropper initialized with seed {seed}') async def process(self, data: bytes) -> None: if self._rng.uniform(0.0, 1.0) < self._rate: _LOG.info(f'{self._name} dropped {len(data)} bytes of data') else: await self.send_data(data) class KeepDropQueue(Filter): """A filter which alternates between sending packets and dropping packets. A KeepDropQueue filter will alternate between keeping packets and dropping chunks of data based on a keep/drop queue provided during its creation. The queue is looped over unless a negative element is found. A negative number is effectively the same as a value of infinity. This filter is typically most practical when used with a packetizer so data can be dropped as distinct packets. Examples: keep_drop_queue = [3, 2]: Keeps 3 packets, Drops 2 packets, Keeps 3 packets, Drops 2 packets, ... [loops indefinitely] keep_drop_queue = [5, 99, 1, -1]: Keeps 5 packets, Drops 99 packets, Keeps 1 packet, Drops all further packets. """ def __init__( self, send_data: Callable[[bytes], Awaitable[None]], name: str, keep_drop_queue: Iterable[int], only_consider_transfer_chunks: bool = False, ): super().__init__(send_data) self._keep_drop_queue = list(keep_drop_queue) self._loop_idx = 0 self._current_count = self._keep_drop_queue[0] self._keep = True self._name = name self._only_consider_transfer_chunks = only_consider_transfer_chunks async def process(self, data: bytes) -> None: if self._only_consider_transfer_chunks: try: _extract_transfer_chunk(data) except Exception: await self.send_data(data) return # Move forward through the queue if needed. while self._current_count == 0: self._loop_idx += 1 self._current_count = self._keep_drop_queue[ self._loop_idx % len(self._keep_drop_queue) ] self._keep = not self._keep if self._current_count > 0: self._current_count -= 1 if self._keep: await self.send_data(data) _LOG.info(f'{self._name} forwarded {len(data)} bytes of data') else: _LOG.info(f'{self._name} dropped {len(data)} bytes of data') class RateLimiter(Filter): """A filter which limits transmission rate. This filter delays transmission of data by len(data)/rate. """ def __init__( self, send_data: Callable[[bytes], Awaitable[None]], rate: float ): super().__init__(send_data) self._rate = rate async def process(self, data: bytes) -> None: delay = len(data) / self._rate await asyncio.sleep(delay) await self.send_data(data) class DataTransposer(Filter): """A filter which occasionally transposes two chunks of data. This filter transposes data at the specified rate. It does this by holding a chunk to transpose until another chunk arrives. The filter will not hold a chunk longer than ``timeout`` seconds. """ def __init__( self, send_data: Callable[[bytes], Awaitable[None]], name: str, rate: float, timeout: float, seed: int, ): super().__init__(send_data) self._name = name self._rate = rate self._timeout = timeout self._data_queue = asyncio.Queue() self._rng = random.Random(seed) self._transpose_task = asyncio.create_task(self._transpose_handler()) _LOG.info(f'{name} DataTranspose initialized with seed {seed}') def __del__(self): _LOG.info(f'{self._name} cleaning up transpose task.') self._transpose_task.cancel() async def _transpose_handler(self): """Async task that handles the packet transposition and timeouts""" held_data: bytes | None = None while True: # Only use timeout if we have data held for transposition timeout = None if held_data is None else self._timeout try: data = await asyncio.wait_for( self._data_queue.get(), timeout=timeout ) if held_data is not None: # If we have held data, send it out of order. await self.send_data(data) await self.send_data(held_data) held_data = None else: # Otherwise decide if we should transpose the current data. if self._rng.uniform(0.0, 1.0) < self._rate: _LOG.info( f'{self._name} transposing {len(data)} bytes of data' ) held_data = data else: await self.send_data(data) except asyncio.TimeoutError: _LOG.info(f'{self._name} sending data in order due to timeout') await self.send_data(held_data) held_data = None async def process(self, data: bytes) -> None: # Queue data for processing by the transpose task. await self._data_queue.put(data) class ServerFailure(Filter): """A filter to simulate the server stopping sending packets. ServerFailure takes a list of numbers of packets to send before dropping all subsequent packets until a TRANSFER_START packet is seen. This process is repeated for each element in packets_before_failure. After that list is exhausted, ServerFailure will send all packets. This filter should be instantiated in the same filter stack as an HdlcPacketizer so that EventFilter can decode complete packets. """ def __init__( self, send_data: Callable[[bytes], Awaitable[None]], name: str, packets_before_failure_list: list[int], start_immediately: bool = False, only_consider_transfer_chunks: bool = False, ): super().__init__(send_data) self._name = name self._relay_packets = True self._packets_before_failure_list = packets_before_failure_list self._packets_before_failure = None self._only_consider_transfer_chunks = only_consider_transfer_chunks if start_immediately: self.advance_packets_before_failure() def advance_packets_before_failure(self): if len(self._packets_before_failure_list) > 0: self._packets_before_failure = ( self._packets_before_failure_list.pop(0) ) else: self._packets_before_failure = None async def process(self, data: bytes) -> None: if self._only_consider_transfer_chunks: try: _extract_transfer_chunk(data) except Exception: await self.send_data(data) return if self._packets_before_failure is None: await self.send_data(data) elif self._packets_before_failure > 0: self._packets_before_failure -= 1 await self.send_data(data) def handle_event(self, event: Event) -> None: if event.type is EventType.TRANSFER_START: self.advance_packets_before_failure() class WindowPacketDropper(Filter): """A filter to allow the same packet in each window to be dropped. WindowPacketDropper with drop the nth packet in each window as specified by window_packet_to_drop. This process will happen indefinitely for each window. This filter should be instantiated in the same filter stack as an HdlcPacketizer so that EventFilter can decode complete packets. """ def __init__( self, send_data: Callable[[bytes], Awaitable[None]], name: str, window_packet_to_drop: int, ): super().__init__(send_data) self._name = name self._relay_packets = True self._window_packet_to_drop = window_packet_to_drop self._next_window_start_offset: int | None = 0 self._window_packet = 0 async def process(self, data: bytes) -> None: data_chunk = None try: chunk = _extract_transfer_chunk(data) if chunk.type is Chunk.Type.DATA: data_chunk = chunk except Exception: # Invalid / non-chunk data (e.g. text logs); ignore. pass # Only count transfer data chunks as part of a window. if data_chunk is not None: if data_chunk.offset == self._next_window_start_offset: # If a new window has been requested, wait until the first # chunk matching its requested offset to begin counting window # chunks. Any in-flight chunks from the previous window are # allowed through. self._window_packet = 0 self._next_window_start_offset = None if self._window_packet != self._window_packet_to_drop: await self.send_data(data) self._window_packet += 1 else: await self.send_data(data) def handle_event(self, event: Event) -> None: if event.type in ( EventType.PARAMETERS_RETRANSMIT, EventType.PARAMETERS_CONTINUE, EventType.START_ACK_CONFIRMATION, ): # A new transmission window has been requested, starting at the # offset specified in the chunk. The receiver may already have data # from the previous window in-flight, so don't immediately reset # the window packet counter. self._next_window_start_offset = event.chunk.offset class EventFilter(Filter): """A filter that inspects packets and send events to other filters. This filter should be instantiated in the same filter stack as an HdlcPacketizer so that it can decode complete packets. """ def __init__( self, send_data: Callable[[bytes], Awaitable[None]], name: str, event_queue: asyncio.Queue, ): super().__init__(send_data) self._name = name self._queue = event_queue async def process(self, data: bytes) -> None: try: chunk = _extract_transfer_chunk(data) if chunk.type is Chunk.Type.START: await self._queue.put(Event(EventType.TRANSFER_START, chunk)) if chunk.type is Chunk.Type.START_ACK_CONFIRMATION: await self._queue.put( Event(EventType.START_ACK_CONFIRMATION, chunk) ) elif chunk.type is Chunk.Type.PARAMETERS_RETRANSMIT: await self._queue.put( Event(EventType.PARAMETERS_RETRANSMIT, chunk) ) elif chunk.type is Chunk.Type.PARAMETERS_CONTINUE: await self._queue.put( Event(EventType.PARAMETERS_CONTINUE, chunk) ) except: # Silently ignore invalid packets pass await self.send_data(data) def _extract_transfer_chunk(data: bytes) -> Chunk: """Gets a transfer Chunk from an HDLC frame containing an RPC packet. Raises an exception if a valid chunk does not exist. """ decoder = decode.FrameDecoder() for frame in decoder.process(data): packet = packet_pb2.RpcPacket() packet.ParseFromString(frame.data) if packet.payload: raw_chunk = transfer_pb2.Chunk() raw_chunk.ParseFromString(packet.payload) return Chunk.from_message(raw_chunk) # The incoming data is expected to be HDLC-packetized, so only one # frame should exist. break raise ValueError("Invalid transfer chunk frame") async def _handle_simplex_events( event_queue: asyncio.Queue, handlers: list[Callable[[Event], None]] ): while True: event = await event_queue.get() for handler in handlers: handler(event) async def _handle_simplex_connection( name: str, filter_stack_config: list[config_pb2.FilterConfig], reader: asyncio.StreamReader, writer: asyncio.StreamWriter, inbound_event_queue: asyncio.Queue, outbound_event_queue: asyncio.Queue, ) -> None: """Handle a single direction of a bidirectional connection between server and client.""" async def send(data: bytes): writer.write(data) await writer.drain() filter_stack = EventFilter(send, name, outbound_event_queue) event_handlers: list[Callable[[Event], None]] = [] # Build the filter stack from the bottom up for config in reversed(filter_stack_config): filter_name = config.WhichOneof("filter") if filter_name == "hdlc_packetizer": filter_stack = HdlcPacketizer(filter_stack) elif filter_name == "data_dropper": data_dropper = config.data_dropper filter_stack = DataDropper( filter_stack, name, data_dropper.rate, data_dropper.seed ) elif filter_name == "rate_limiter": filter_stack = RateLimiter(filter_stack, config.rate_limiter.rate) elif filter_name == "data_transposer": transposer = config.data_transposer filter_stack = DataTransposer( filter_stack, name, transposer.rate, transposer.timeout, transposer.seed, ) elif filter_name == "server_failure": server_failure = config.server_failure filter_stack = ServerFailure( filter_stack, name, server_failure.packets_before_failure, server_failure.start_immediately, server_failure.only_consider_transfer_chunks, ) event_handlers.append(filter_stack.handle_event) elif filter_name == "keep_drop_queue": keep_drop_queue = config.keep_drop_queue filter_stack = KeepDropQueue( filter_stack, name, keep_drop_queue.keep_drop_queue, keep_drop_queue.only_consider_transfer_chunks, ) elif filter_name == "window_packet_dropper": window_packet_dropper = config.window_packet_dropper filter_stack = WindowPacketDropper( filter_stack, name, window_packet_dropper.window_packet_to_drop ) event_handlers.append(filter_stack.handle_event) else: sys.exit(f'Unknown filter {filter_name}') event_task = asyncio.create_task( _handle_simplex_events(inbound_event_queue, event_handlers) ) while True: # Arbitrarily chosen "page sized" read. data = await reader.read(4096) # An empty data indicates that the connection is closed. if not data: _LOG.info(f'{name} connection closed.') return await filter_stack.process(data) async def _handle_connection( server_port: int, config: config_pb2.ProxyConfig, client_reader: asyncio.StreamReader, client_writer: asyncio.StreamWriter, ) -> None: """Handle a connection between server and client.""" client_addr = client_writer.get_extra_info('peername') _LOG.info(f'New client connection from {client_addr}') # Open a new connection to the server for each client connection. # # TODO(konkers): catch exception and close client writer server_reader, server_writer = await asyncio.open_connection( 'localhost', server_port ) _LOG.info('New connection opened to server') # Queues for the simplex connections to pass events to each other. server_event_queue = asyncio.Queue() client_event_queue = asyncio.Queue() # Instantiate two simplex handler one for each direction of the connection. _, pending = await asyncio.wait( [ asyncio.create_task( _handle_simplex_connection( "client", config.client_filter_stack, client_reader, server_writer, server_event_queue, client_event_queue, ) ), asyncio.create_task( _handle_simplex_connection( "server", config.server_filter_stack, server_reader, client_writer, client_event_queue, server_event_queue, ) ), ], return_when=asyncio.FIRST_COMPLETED, ) # When one side terminates the connection, also terminate the other side for task in pending: task.cancel() for stream in [client_writer, server_writer]: stream.close() def _parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter, ) parser.add_argument( '--server-port', type=int, required=True, help='Port of the integration test server. The proxy will forward connections to this port', ) parser.add_argument( '--client-port', type=int, required=True, help='Port on which to listen for connections from integration test client.', ) return parser.parse_args() def _init_logging(level: int) -> None: _LOG.setLevel(logging.DEBUG) log_to_stderr = logging.StreamHandler() log_to_stderr.setLevel(level) log_to_stderr.setFormatter( logging.Formatter( fmt='%(asctime)s.%(msecs)03d-%(levelname)s: %(message)s', datefmt='%H:%M:%S', ) ) _LOG.addHandler(log_to_stderr) async def _main(server_port: int, client_port: int) -> None: _init_logging(logging.DEBUG) # Load config from stdin using synchronous IO text_config = sys.stdin.buffer.read() config = text_format.Parse(text_config, config_pb2.ProxyConfig()) # Instantiate the TCP server. server_socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) server_socket.setsockopt( socket.SOL_SOCKET, socket.SO_RCVBUF, _RECEIVE_BUFFER_SIZE ) server_socket.bind(('', client_port)) server = await asyncio.start_server( lambda reader, writer: _handle_connection( server_port, config, reader, writer ), limit=_RECEIVE_BUFFER_SIZE, sock=server_socket, ) addrs = ', '.join(str(sock.getsockname()) for sock in server.sockets) _LOG.info(f'Listening for client connection on {addrs}') # Run the TCP server. async with server: await server.serve_forever() if __name__ == '__main__': asyncio.run(_main(**vars(_parse_args())))