# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import time
from typing import Optional
from bumble.colors import color
from bumble.hci import (
    HCI_READ_LOOPBACK_MODE_COMMAND,
    HCI_Read_Loopback_Mode_Command,
    HCI_WRITE_LOOPBACK_MODE_COMMAND,
    HCI_Write_Loopback_Mode_Command,
    LoopbackMode,
)
from bumble.host import Host
from bumble.transport import open_transport_or_link
import click


# -----------------------------------------------------------------------------
def _rate(num_bytes: int, seconds: float) -> float:
    """Return ``num_bytes / seconds``, or infinity when no time has elapsed.

    Guards the speed computations against a zero (or negative) interval,
    which would otherwise raise ZeroDivisionError.
    """
    return num_bytes / seconds if seconds > 0 else float('inf')


class Loopback:
    """Send and receive ACL data packets in local loopback mode"""

    def __init__(self, packet_size: int, packet_count: int, transport: str):
        """Configure a test of `packet_count` PDUs of `packet_size` bytes.

        Args:
            packet_size: payload size of each L2CAP PDU sent.
            packet_count: number of PDUs to send and expect back.
            transport: transport spec passed to ``open_transport_or_link``.
        """
        self.transport = transport
        self.packet_size = packet_size
        self.packet_count = packet_count
        self._reset_state()

    def _reset_state(self) -> None:
        """(Re)create the per-run events and counters.

        Called from ``__init__`` and again at the start of ``run()``. On
        Python < 3.10 an ``asyncio.Event`` binds to the event loop that is
        current when it is constructed, so events created before
        ``asyncio.run()`` would belong to a different loop than the one that
        later awaits them.
        """
        self.connection_handle: Optional[int] = None
        self.connection_event = asyncio.Event()
        self.done = asyncio.Event()
        # The CID doubles as an incrementing packet index (see run()).
        self.expected_cid = 0
        self.bytes_received = 0
        self.start_timestamp = 0.0
        self.last_timestamp = 0.0

    def on_connection(self, connection_handle: int, *args):
        """Retrieve connection handle from new connection event"""
        if not self.connection_event.is_set():
            # save first connection handle for ACL
            # subsequent connections are SCO
            self.connection_handle = connection_handle
            self.connection_event.set()

    def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes):
        """Record a looped-back PDU and print the receive speed.

        PDUs are expected in order on the handle saved by ``on_connection``,
        with CIDs 0, 1, 2, ... (``run()`` uses the CID as a packet index).
        Sets ``self.done`` once ``packet_count`` PDUs have been received.
        """
        # perf_counter is monotonic and high-resolution; time.time() is
        # neither guaranteed monotonic nor fine-grained, which could yield a
        # zero or negative interval in the speed computations below.
        now = time.perf_counter()
        print(f'<<< Received packet {cid}: {len(pdu)} bytes')
        assert connection_handle == self.connection_handle
        assert cid == self.expected_cid
        self.expected_cid += 1
        if cid == 0:
            # The first packet only starts the clock; its bytes are not
            # counted, since no interval precedes it.
            self.start_timestamp = now
        else:
            elapsed_since_start = now - self.start_timestamp
            elapsed_since_last = now - self.last_timestamp
            self.bytes_received += len(pdu)
            instant_rx_speed = _rate(len(pdu), elapsed_since_last)
            average_rx_speed = _rate(self.bytes_received, elapsed_since_start)
            print(
                color(
                    f'@@@ RX speed: instant={instant_rx_speed:.4f},'
                    f' average={average_rx_speed:.4f}',
                    'cyan',
                )
            )

        self.last_timestamp = now

        if self.expected_cid == self.packet_count:
            print(color('@@@ Received last packet', 'green'))
            self.done.set()

    async def run(self):
        """Run a loopback throughput test.

        Resets the controller, enables local loopback mode, waits for the
        loopback connection, sends ``packet_count`` PDUs and waits until all
        of them have been received back, then prints the average TX speed.
        Returns early (after printing an error) if the packet size is too
        large, loopback mode is unsupported, or the mode read back differs.
        """
        # Create events inside the running loop (see _reset_state).
        self._reset_state()

        print(color('>>> Connecting to HCI...', 'green'))
        async with await open_transport_or_link(self.transport) as (
            hci_source,
            hci_sink,
        ):
            print(color('>>> Connected', 'green'))

            host = Host(hci_source, hci_sink)
            await host.reset()

            # make sure data can fit in one l2cap pdu
            l2cap_header_size = 4

            # Prefer the BR/EDR ACL queue, fall back to the LE one.
            packet_queue = host.acl_packet_queue or host.le_acl_packet_queue
            if packet_queue is None:
                print(color('!!! No ACL packet queue available', 'red'))
                return
            max_packet_size = packet_queue.max_packet_size - l2cap_header_size
            if self.packet_size > max_packet_size:
                print(
                    color(
                        f'!!! Packet size ({self.packet_size}) larger than max supported'
                        f' size ({max_packet_size})',
                        'red',
                    )
                )
                return

            if not host.supports_command(
                HCI_WRITE_LOOPBACK_MODE_COMMAND
            ) or not host.supports_command(HCI_READ_LOOPBACK_MODE_COMMAND):
                print(color('!!! Loopback mode not supported', 'red'))
                return

            # set event callbacks
            host.on('connection', self.on_connection)
            host.on('l2cap_pdu', self.on_l2cap_pdu)

            loopback_mode = LoopbackMode.LOCAL

            print(color('### Setting loopback mode', 'blue'))
            await host.send_command(
                HCI_Write_Loopback_Mode_Command(loopback_mode=loopback_mode),
                check_result=True,
            )

            print(color('### Checking loopback mode', 'blue'))
            response = await host.send_command(
                HCI_Read_Loopback_Mode_Command(), check_result=True
            )
            if response.return_parameters.loopback_mode != loopback_mode:
                print(color('!!! Loopback mode mismatch', 'red'))
                return

            await self.connection_event.wait()
            print(color('### Connected', 'cyan'))

            print(color('=== Start sending', 'magenta'))
            start_time = time.perf_counter()
            bytes_sent = 0
            for cid in range(0, self.packet_count):
                # using the cid as an incremental index
                host.send_l2cap_pdu(
                    self.connection_handle, cid, bytes(self.packet_size)
                )
                print(
                    color(
                        f'>>> Sending packet {cid}: {self.packet_size} bytes', 'yellow'
                    )
                )
                bytes_sent += self.packet_size  # don't count L2CAP or HCI header sizes
                await asyncio.sleep(0)  # yield to allow packet receive

            await self.done.wait()
            print(color('=== Done!', 'magenta'))

            # Elapsed time spans sending all packets until the last one is
            # received back.
            elapsed = time.perf_counter() - start_time
            average_tx_speed = _rate(bytes_sent, elapsed)
            print(
                color(
                    f'@@@ TX speed: average={average_tx_speed:.4f} ({bytes_sent} bytes'
                    f' in {elapsed:.2f} seconds)',
                    'green',
                )
            )


# -----------------------------------------------------------------------------
@click.command()
@click.option(
    '--packet-size',
    '-s',
    metavar='SIZE',
    type=click.IntRange(8, 4096),
    default=500,
    help='Packet size',
)
@click.option(
    '--packet-count',
    '-c',
    metavar='COUNT',
    type=click.IntRange(1, 65535),
    default=10,
    help='Packet count',
)
@click.argument('transport')
def main(packet_size, packet_count, transport):
    # CLI entry point: configure logging from BUMBLE_LOGLEVEL (default
    # WARNING) and run one loopback throughput test.
    logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())

    loopback = Loopback(packet_size, packet_count, transport)
    asyncio.run(loopback.run())


# -----------------------------------------------------------------------------
if __name__ == '__main__':
    main()