1# Copyright 2024 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import asyncio
19import logging
20import os
21import time
22from typing import Optional
23from bumble.colors import color
24from bumble.hci import (
25    HCI_READ_LOOPBACK_MODE_COMMAND,
26    HCI_Read_Loopback_Mode_Command,
27    HCI_WRITE_LOOPBACK_MODE_COMMAND,
28    HCI_Write_Loopback_Mode_Command,
29    LoopbackMode,
30)
31from bumble.host import Host
32from bumble.transport import open_transport_or_link
33import click
34
35
class Loopback:
    """Send and receive ACL data packets in local loopback mode.

    The test switches the controller to local loopback mode, waits for the
    first connection event, sends ``packet_count`` L2CAP PDUs of
    ``packet_size`` zero bytes each, and prints receive/transmit throughput
    (in bytes per second) as the PDUs come back.
    """

    def __init__(self, packet_size: int, packet_count: int, transport: str):
        """Store the test parameters and reset the receive-side counters.

        Args:
            packet_size: payload size, in bytes, of each PDU that is sent.
            packet_count: number of PDUs to send and expect back.
            transport: transport spec passed to ``open_transport_or_link``.
        """
        self.transport = transport
        self.packet_size = packet_size
        self.packet_count = packet_count
        self.connection_handle: Optional[int] = None
        self.connection_event = asyncio.Event()
        self.done = asyncio.Event()
        # The CID field is used as a sequence number: this is the next expected
        self.expected_cid = 0
        # Bytes received after the first PDU (the first one starts the clock)
        self.bytes_received = 0
        # Monotonic clock readings, only meaningful relative to each other
        self.start_timestamp = 0.0
        self.last_timestamp = 0.0

    @staticmethod
    def _speed(byte_count: int, elapsed: float) -> float:
        """Return the throughput in bytes per second.

        When no measurable time has elapsed (two samples taken within the same
        clock tick) a rate cannot be computed: ``inf`` is returned (or ``0.0``
        when there were no bytes either) instead of raising
        ``ZeroDivisionError``.
        """
        if elapsed <= 0:
            return float('inf') if byte_count else 0.0
        return byte_count / elapsed

    def on_connection(self, connection_handle: int, *args):
        """Retrieve connection handle from new connection event.

        Any extra event arguments are accepted and ignored.
        """
        if not self.connection_event.is_set():
            # save first connection handle for ACL
            # subsequent connections are SCO
            self.connection_handle = connection_handle
            self.connection_event.set()

    def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes):
        """Account for a received PDU and print the receive speed.

        PDUs are expected on the saved connection handle and in sending order
        (``cid`` 0, 1, 2, ...).  The first PDU only starts the clock; each later
        one prints the instant and average speeds.  ``self.done`` is set once
        ``packet_count`` PDUs have been received.
        """
        # perf_counter is monotonic, so elapsed times can never be negative
        now = time.perf_counter()
        print(f'<<< Received packet {cid}: {len(pdu)} bytes')
        assert (
            connection_handle == self.connection_handle
        ), f'unexpected connection handle {connection_handle}'
        assert (
            cid == self.expected_cid
        ), f'expected packet {self.expected_cid}, got {cid}'
        self.expected_cid += 1
        if cid == 0:
            self.start_timestamp = now
        else:
            self.bytes_received += len(pdu)
            instant_rx_speed = self._speed(len(pdu), now - self.last_timestamp)
            average_rx_speed = self._speed(
                self.bytes_received, now - self.start_timestamp
            )
            print(
                color(
                    f'@@@ RX speed: instant={instant_rx_speed:.4f},'
                    f' average={average_rx_speed:.4f}',
                    'cyan',
                )
            )

        self.last_timestamp = now

        if self.expected_cid == self.packet_count:
            print(color('@@@ Received last packet', 'green'))
            self.done.set()

    async def run(self):
        """Run a loopback throughput test.

        Opens the HCI transport, resets the host, enables local loopback mode,
        waits for the first connection event, then sends ``packet_count`` PDUs
        and waits until all of them have been received back.

        Returns early, after printing an error, when the packet size does not
        fit in one ACL packet, when the loopback mode commands are not
        supported, or when the mode read back differs from the one written.
        """
        print(color('>>> Connecting to HCI...', 'green'))
        async with await open_transport_or_link(self.transport) as (
            hci_source,
            hci_sink,
        ):
            print(color('>>> Connected', 'green'))

            host = Host(hci_source, hci_sink)
            await host.reset()

            # make sure data can fit in one l2cap pdu
            l2cap_header_size = 4

            max_packet_size = (
                host.acl_packet_queue
                if host.acl_packet_queue
                else host.le_acl_packet_queue
            ).max_packet_size - l2cap_header_size
            if self.packet_size > max_packet_size:
                print(
                    color(
                        f'!!! Packet size ({self.packet_size}) larger than max supported'
                        f' size ({max_packet_size})',
                        'red',
                    )
                )
                return

            if not host.supports_command(
                HCI_WRITE_LOOPBACK_MODE_COMMAND
            ) or not host.supports_command(HCI_READ_LOOPBACK_MODE_COMMAND):
                print(color('!!! Loopback mode not supported', 'red'))
                return

            # set event callbacks
            host.on('connection', self.on_connection)
            host.on('l2cap_pdu', self.on_l2cap_pdu)

            # single source of truth for the mode that is written, then verified
            loopback_mode = LoopbackMode.LOCAL

            print(color('### Setting loopback mode', 'blue'))
            await host.send_command(
                HCI_Write_Loopback_Mode_Command(loopback_mode=loopback_mode),
                check_result=True,
            )

            print(color('### Checking loopback mode', 'blue'))
            response = await host.send_command(
                HCI_Read_Loopback_Mode_Command(), check_result=True
            )
            if response.return_parameters.loopback_mode != loopback_mode:
                print(color('!!! Loopback mode mismatch', 'red'))
                return

            await self.connection_event.wait()
            print(color('### Connected', 'cyan'))

            print(color('=== Start sending', 'magenta'))
            start_time = time.perf_counter()
            bytes_sent = 0
            for cid in range(0, self.packet_count):
                # using the cid as an incremental index
                host.send_l2cap_pdu(
                    self.connection_handle, cid, bytes(self.packet_size)
                )
                print(
                    color(
                        f'>>> Sending packet {cid}: {self.packet_size} bytes', 'yellow'
                    )
                )
                bytes_sent += self.packet_size  # don't count L2CAP or HCI header sizes
                await asyncio.sleep(0)  # yield to allow packet receive

            await self.done.wait()
            print(color('=== Done!', 'magenta'))

            elapsed = time.perf_counter() - start_time
            average_tx_speed = self._speed(bytes_sent, elapsed)
            print(
                color(
                    f'@@@ TX speed: average={average_tx_speed:.4f} ({bytes_sent} bytes'
                    f' in {elapsed:.2f} seconds)',
                    'green',
                )
            )
176
177# -----------------------------------------------------------------------------
@click.command()
@click.option(
    '--packet-size',
    '-s',
    metavar='SIZE',
    type=click.IntRange(8, 4096),
    default=500,
    help='Packet size',
)
@click.option(
    '--packet-count',
    '-c',
    metavar='COUNT',
    type=click.IntRange(1, 65535),
    default=10,
    help='Packet count',
)
@click.argument('transport')
def main(packet_size, packet_count, transport):
    """Run a loopback throughput test over the HCI transport TRANSPORT."""
    # Log level comes from the BUMBLE_LOGLEVEL environment variable
    logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())

    async def run_loopback():
        # Build the Loopback object inside the running event loop: on
        # Python < 3.10 its asyncio.Event members bind to the loop that is
        # current when they are created, while asyncio.run() always starts a
        # fresh loop, so creating them earlier can fail with a
        # "attached to a different loop" error when they are awaited.
        loopback = Loopback(packet_size, packet_count, transport)
        await loopback.run()

    asyncio.run(run_loopback())
201
202
203# -----------------------------------------------------------------------------
# Invoke the click command only when this file is executed as a script
if __name__ == '__main__':
    main()
206