# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import time
from typing import Optional

from bumble.colors import color
from bumble.hci import (
    HCI_READ_LOOPBACK_MODE_COMMAND,
    HCI_Read_Loopback_Mode_Command,
    HCI_WRITE_LOOPBACK_MODE_COMMAND,
    HCI_Write_Loopback_Mode_Command,
    LoopbackMode,
)
from bumble.host import Host
from bumble.transport import open_transport_or_link
import click


# -----------------------------------------------------------------------------
def _bytes_per_second(byte_count: int, elapsed: float) -> float:
    """Return ``byte_count / elapsed``, or infinity when no time has elapsed.

    Two consecutive clock readings can be identical (or, with a non-monotonic
    clock, out of order), so the division is guarded instead of letting a
    ZeroDivisionError escape from a packet callback.
    """
    if elapsed <= 0:
        return float('inf')
    return byte_count / elapsed


# -----------------------------------------------------------------------------
class Loopback:
    """Send and receive ACL data packets in local loopback mode

    The L2CAP channel id of each PDU is used as a sequence number: packet N is
    sent with cid N, and the receive callback checks that PDUs come back in
    that same order.
    """

    def __init__(self, packet_size: int, packet_count: int, transport: str):
        """Store the test parameters and reset the receive-side counters.

        Args:
            packet_size: payload size, in bytes, of each PDU to send.
            packet_count: number of PDUs to send and expect back.
            transport: transport specification passed to
                `open_transport_or_link`.
        """
        self.transport = transport
        self.packet_size = packet_size
        self.packet_count = packet_count
        # handle of the first connection reported by the host
        self.connection_handle: Optional[int] = None
        # set once `connection_handle` has been recorded
        self.connection_event = asyncio.Event()
        # set once `packet_count` PDUs have been received
        self.done = asyncio.Event()
        # next cid (used as a sequence number) expected by `on_l2cap_pdu`
        self.expected_cid = 0
        # payload bytes received after the first PDU
        self.bytes_received = 0
        # `time.perf_counter()` readings: only differences are meaningful
        self.start_timestamp = 0.0
        self.last_timestamp = 0.0

    def on_connection(self, connection_handle: int, *args):
        """Retrieve connection handle from new connection event"""
        if self.connection_event.is_set():
            # only the first connection is used for ACL data;
            # subsequent connections are SCO
            return

        self.connection_handle = connection_handle
        self.connection_event.set()

    def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes):
        """Calculate packet receive speed

        Args:
            connection_handle: handle of the connection the PDU arrived on.
            cid: channel id of the PDU, used here as a sequence number.
            pdu: payload of the received PDU.
        """
        now = time.perf_counter()
        print(f'<<< Received packet {cid}: {len(pdu)} bytes')

        assert (
            connection_handle == self.connection_handle
        ), 'PDU received on an unexpected connection'
        assert cid == self.expected_cid, 'PDU received out of sequence'
        self.expected_cid += 1

        if cid == 0:
            # The first PDU only starts the clock: the average speed is
            # measured from its arrival, so its bytes are not counted.
            self.start_timestamp = now
        else:
            self.bytes_received += len(pdu)
            instant_rx_speed = _bytes_per_second(
                len(pdu), now - self.last_timestamp
            )
            average_rx_speed = _bytes_per_second(
                self.bytes_received, now - self.start_timestamp
            )
            print(
                color(
                    f'@@@ RX speed: instant={instant_rx_speed:.4f},'
                    f' average={average_rx_speed:.4f}',
                    'cyan',
                )
            )

        self.last_timestamp = now

        if self.expected_cid == self.packet_count:
            print(color('@@@ Received last packet', 'green'))
            self.done.set()

    async def run(self):
        """Run a loopback throughput test

        Opens the transport, resets the host, switches the controller to local
        loopback mode, then sends `packet_count` PDUs of `packet_size` bytes
        and waits until they have all been received back. Returns early, after
        printing an error, when a precondition is not met.
        """
        print(color('>>> Connecting to HCI...', 'green'))
        async with await open_transport_or_link(self.transport) as (
            hci_source,
            hci_sink,
        ):
            print(color('>>> Connected', 'green'))

            host = Host(hci_source, hci_sink)
            await host.reset()

            # make sure data can fit in one l2cap pdu
            l2cap_header_size = 4
            packet_queue = host.acl_packet_queue or host.le_acl_packet_queue
            if packet_queue is None:
                print(color('!!! No ACL packet queue available', 'red'))
                return

            max_packet_size = packet_queue.max_packet_size - l2cap_header_size
            if self.packet_size > max_packet_size:
                print(
                    color(
                        f'!!! Packet size ({self.packet_size}) larger than max supported'
                        f' size ({max_packet_size})',
                        'red',
                    )
                )
                return

            if not host.supports_command(
                HCI_WRITE_LOOPBACK_MODE_COMMAND
            ) or not host.supports_command(HCI_READ_LOOPBACK_MODE_COMMAND):
                print(color('!!! Loopback mode not supported', 'red'))
                return

            # set event callbacks
            host.on('connection', self.on_connection)
            host.on('l2cap_pdu', self.on_l2cap_pdu)

            loopback_mode = LoopbackMode.LOCAL

            print(color('### Setting loopback mode', 'blue'))
            await host.send_command(
                HCI_Write_Loopback_Mode_Command(loopback_mode=loopback_mode),
                check_result=True,
            )

            # read the mode back to confirm that the controller applied it
            print(color('### Checking loopback mode', 'blue'))
            response = await host.send_command(
                HCI_Read_Loopback_Mode_Command(), check_result=True
            )
            if response.return_parameters.loopback_mode != loopback_mode:
                print(color('!!! Loopback mode mismatch', 'red'))
                return

            await self.connection_event.wait()
            print(color('### Connected', 'cyan'))

            print(color('=== Start sending', 'magenta'))
            start_time = time.perf_counter()
            bytes_sent = 0
            # the payload is all zeros and immutable: build it only once
            payload = bytes(self.packet_size)
            for cid in range(self.packet_count):
                # using the cid as an incremental index
                host.send_l2cap_pdu(self.connection_handle, cid, payload)
                print(
                    color(
                        f'>>> Sending packet {cid}: {self.packet_size} bytes', 'yellow'
                    )
                )
                bytes_sent += self.packet_size  # don't count L2CAP or HCI header sizes
                await asyncio.sleep(0)  # yield to allow packet receive

            await self.done.wait()
            print(color('=== Done!', 'magenta'))

            elapsed = time.perf_counter() - start_time
            average_tx_speed = _bytes_per_second(bytes_sent, elapsed)
            print(
                color(
                    f'@@@ TX speed: average={average_tx_speed:.4f} ({bytes_sent} bytes'
                    f' in {elapsed:.2f} seconds)',
                    'green',
                )
            )


# -----------------------------------------------------------------------------
# Command line entry point: parses the options, configures logging from the
# BUMBLE_LOGLEVEL environment variable and runs one loopback test.
@click.command()
@click.option(
    '--packet-size',
    '-s',
    metavar='SIZE',
    type=click.IntRange(8, 4096),
    default=500,
    help='Packet size',
)
@click.option(
    '--packet-count',
    '-c',
    metavar='COUNT',
    type=click.IntRange(1, 65535),
    default=10,
    help='Packet count',
)
@click.argument('transport')
def main(packet_size, packet_count, transport):
    logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())

    async def run_loopback():
        # Build the Loopback object inside the running loop: before Python
        # 3.10, asyncio.Event captured the current event loop when it was
        # constructed, which is not the loop that asyncio.run() creates.
        loopback = Loopback(packet_size, packet_count, transport)
        await loopback.run()

    asyncio.run(run_loopback())


# -----------------------------------------------------------------------------
if __name__ == '__main__':
    main()