1# Copyright 2021-2023 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import asyncio
19import enum
20import logging
21import os
22import struct
23import time
24
25import click
26
27from bumble import l2cap
28from bumble.core import (
29    BT_BR_EDR_TRANSPORT,
30    BT_LE_TRANSPORT,
31    BT_L2CAP_PROTOCOL_ID,
32    BT_RFCOMM_PROTOCOL_ID,
33    UUID,
34    CommandTimeoutError,
35)
36from bumble.colors import color
37from bumble.device import Connection, ConnectionParametersPreferences, Device, Peer
38from bumble.gatt import Characteristic, CharacteristicValue, Service
39from bumble.hci import (
40    HCI_LE_1M_PHY,
41    HCI_LE_2M_PHY,
42    HCI_LE_CODED_PHY,
43    HCI_CENTRAL_ROLE,
44    HCI_PERIPHERAL_ROLE,
45    HCI_Constant,
46    HCI_Error,
47    HCI_StatusError,
48)
49from bumble.sdp import (
50    SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
51    SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
52    SDP_PUBLIC_BROWSE_ROOT,
53    SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
54    SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
55    DataElement,
56    ServiceAttribute,
57)
58from bumble.transport import open_transport_or_link
59import bumble.rfcomm
60import bumble.core
61from bumble.utils import AsyncRunner
62from bumble.pairing import PairingConfig
63
64
65# -----------------------------------------------------------------------------
66# Logging
67# -----------------------------------------------------------------------------
# Module-level logger named after this module.
# NOTE(review): the code visible in this part of the file logs through the
# root `logging` functions rather than through this object - confirm whether
# that is intentional.
logger = logging.getLogger(__name__)
69
70
71# -----------------------------------------------------------------------------
72# Constants
73# -----------------------------------------------------------------------------
# Default addresses and names for the two roles.
# NOTE(review): not referenced in the visible part of the file; presumably
# applied when the device configuration does not provide its own - confirm.
DEFAULT_CENTRAL_ADDRESS = 'F0:F0:F0:F0:F0:F0'
DEFAULT_CENTRAL_NAME = 'Speed Central'
DEFAULT_PERIPHERAL_ADDRESS = 'F1:F1:F1:F1:F1:F1'
DEFAULT_PERIPHERAL_NAME = 'Speed Peripheral'

# GATT service used by GattClient / GattServer. The TX characteristic is the
# one the client writes to; the RX characteristic carries server notifications.
SPEED_SERVICE_UUID = '50DB505C-8AC4-4738-8448-3B1D9CC09CC5'
SPEED_TX_UUID = 'E789C754-41A1-45F4-A948-A0A1A90DBA53'
SPEED_RX_UUID = '016A2CC7-E14B-4819-935F-1F56EAE4098D'

# Service class UUID placed in the SDP record built by make_sdp_records().
DEFAULT_RFCOMM_UUID = 'E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'
# Default L2CAP credit-based channel settings for L2capClient / L2capServer.
DEFAULT_L2CAP_PSM = 128
DEFAULT_L2CAP_MAX_CREDITS = 128
DEFAULT_L2CAP_MTU = 1024
DEFAULT_L2CAP_MPS = 1024

# NOTE(review): not referenced in the visible part of the file; presumably
# expressed in seconds - confirm against the callers.
DEFAULT_LINGER_TIME = 1.0
DEFAULT_POST_CONNECTION_WAIT_TIME = 1.0

# NOTE(review): not referenced in the visible part of the file; presumably the
# defaults for the RFCOMM mode options - confirm against the callers.
DEFAULT_RFCOMM_CHANNEL = 8
DEFAULT_RFCOMM_MTU = 2048
94
95
96# -----------------------------------------------------------------------------
97# Utils
98# -----------------------------------------------------------------------------
def parse_packet(packet):
    """Split a raw packet into its type and its payload.

    Returns a `(PacketType, payload)` tuple, the payload being everything
    after the first byte.

    Raises `ValueError` (after logging the reason) when the packet is empty or
    when its first byte is not a valid `PacketType` value.
    """
    size = len(packet)
    if size < 1:
        logging.info(
            color(f'!!! Packet too short (got {size} bytes, need >= 1)', 'red')
        )
        raise ValueError('packet too short')

    type_byte = packet[0]
    try:
        kind = PacketType(type_byte)
    except ValueError:
        # Log the offending value, then let the caller decide what to do.
        logging.info(color(f'!!! Invalid packet type 0x{type_byte:02X}', 'red'))
        raise
    else:
        return (kind, packet[1:])
113
114
def parse_packet_sequence(packet_data):
    """Decode the flags and index found at the start of a packet payload.

    The first 5 bytes are read as a big-endian signed byte (the flags)
    followed by an unsigned 32-bit integer (the packet index); any bytes after
    that are ignored.

    Returns a `(flags, index)` tuple. Raises `ValueError` (after logging) when
    fewer than 5 bytes are available.
    """
    available = len(packet_data)
    if available >= 5:
        return struct.unpack_from('>bI', packet_data)

    logging.info(
        color(
            f'!!!Packet too short (got {available} bytes, need >= 5)',
            'red',
        )
    )
    raise ValueError('packet too short')
125
126
def le_phy_name(phy_id):
    """Return a short label for an LE PHY identifier.

    The three PHYs listed below get a compact label; any other value falls
    back to whatever `HCI_Constant.le_phy_name()` returns for it.
    """
    # Computed up front (as the fallback) regardless of whether it is needed.
    fallback = HCI_Constant.le_phy_name(phy_id)
    short_names = {
        HCI_LE_1M_PHY: '1M',
        HCI_LE_2M_PHY: '2M',
        HCI_LE_CODED_PHY: 'CODED',
    }
    return short_names.get(phy_id, fallback)
131
132
def print_connection(connection):
    """Log a one-line summary of a connection.

    For an LE connection the summary lists the PHYs, the four data-length
    values, the connection parameters and the ATT MTU; for any other transport
    only the role is shown.
    """
    if connection.transport == BT_LE_TRANSPORT:
        phy = connection.phy
        data_length = connection.data_length
        parameters = connection.parameters
        # NOTE(review): the 1.25 and 10 factors presumably convert controller
        # units to milliseconds - confirm against the HCI definitions.
        interval = parameters.connection_interval * 1.25
        timeout = parameters.supervision_timeout * 10
        details = [
            f'PHY=TX:{le_phy_name(phy.tx_phy)}/RX:{le_phy_name(phy.rx_phy)}',
            (
                f'DL=(TX:{data_length[0]}/{data_length[1]},'
                f'RX:{data_length[2]}/{data_length[3]})'
            ),
            f'Parameters={interval:.2f}/{parameters.peripheral_latency}/{timeout} ',
            f'MTU={connection.att_mtu}',
        ]
    else:
        details = [f'Role={HCI_Constant.role_name(connection.role)}']

    logging.info(color('@@@ Connection: ', 'yellow') + ' '.join(details))
162
163
def make_sdp_records(channel):
    """Build the SDP record table describing the RFCOMM service.

    Returns a dict with a single entry, keyed by the record handle, whose
    value is the list of service attributes: record handle, browse group list,
    service class ID list (DEFAULT_RFCOMM_UUID) and a protocol descriptor list
    naming L2CAP and RFCOMM with the given `channel` number.
    """
    record_handle = 0x00010001

    l2cap_descriptor = DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)])
    rfcomm_descriptor = DataElement.sequence(
        [
            DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
            DataElement.unsigned_integer_8(channel),
        ]
    )

    attributes = [
        ServiceAttribute(
            SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
            DataElement.unsigned_integer_32(record_handle),
        ),
        ServiceAttribute(
            SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
            DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
        ),
        ServiceAttribute(
            SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
            DataElement.sequence([DataElement.uuid(UUID(DEFAULT_RFCOMM_UUID))]),
        ),
        ServiceAttribute(
            SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
            DataElement.sequence([l2cap_descriptor, rfcomm_descriptor]),
        ),
    ]
    return {record_handle: attributes}
195
196
def log_stats(title, stats):
    """Log the minimum, maximum and mean of a collection of samples.

    Args:
        title: label inserted in the log line (`### <title> stats: ...`).
        stats: sized collection of numbers.

    Nothing is logged when `stats` is empty: there is no minimum, maximum or
    mean to report, and computing them would raise.
    """
    if len(stats) == 0:
        # min()/max() raise ValueError on an empty sequence and the mean would
        # divide by zero, so there is nothing meaningful to log.
        return

    stats_min = min(stats)
    stats_max = max(stats)
    stats_avg = sum(stats) / len(stats)
    logging.info(
        color(
            (
                f'### {title} stats: '
                f'min={stats_min:.2f}, '
                f'max={stats_max:.2f}, '
                f'average={stats_avg:.2f}'
            ),
            'cyan',
        )
    )
212
213
async def switch_roles(connection, role):
    """Switch the connection to the requested role if it is not already in it.

    `role` is the string "central" for the central role; any other value
    selects the peripheral role. A failed switch (`HCI_Error`) is logged and
    otherwise ignored.
    """
    if role == "central":
        wanted_role = HCI_CENTRAL_ROLE
    else:
        wanted_role = HCI_PERIPHERAL_ROLE

    if connection.role == wanted_role:
        return

    logging.info(f'{color("### Switching roles to:", "cyan")} {role}')
    try:
        await connection.switch_role(wanted_role)
    except HCI_Error as error:
        logging.info(f'{color("### Role switch failed:", "red")} {error}')
    else:
        logging.info(color('### Role switch complete', 'cyan'))
223
224
class PacketType(enum.IntEnum):
    """Packet kind, carried in the first byte of every packet (see parse_packet)."""

    # Sent by Sender and Ping at the start of each run; Receiver and Pong
    # reset their state when they receive it.
    RESET = 0
    # Packet carrying a flags byte and a 32-bit packet index, followed by
    # zero-filled padding.
    SEQUENCE = 1
    # Reply sent by Receiver (after the last packet) and by Pong (for each
    # packet), echoing the flags and index of the packet it answers.
    ACK = 2


# Flag bit set on the final SEQUENCE packet of a run.
PACKET_FLAG_LAST = 1
232
233
234# -----------------------------------------------------------------------------
235# Sender
236# -----------------------------------------------------------------------------
class Sender:
    """Transmitting end of the throughput test.

    Each run sends a RESET packet followed by `packet_count` SEQUENCE packets
    through `packet_io`, then waits for the peer's ACK. The average transmit
    speed of each run (bytes sent divided by elapsed time) is appended to
    `stats`.
    """

    def __init__(
        self,
        packet_io,
        start_delay,
        repeat,
        repeat_delay,
        pace,
        packet_size,
        packet_count,
    ):
        # Register as the listener so that incoming packets reach
        # on_packet_received().
        packet_io.packet_listener = self
        self.packet_io = packet_io

        # Test parameters
        self.tx_start_delay = start_delay
        self.tx_packet_size = packet_size
        self.tx_packet_count = packet_count
        self.repeat = repeat
        self.repeat_delay = repeat_delay
        self.pace = pace

        # Per-run state and accumulated results
        self.start_time = 0
        self.bytes_sent = 0
        self.stats = []
        self.done = asyncio.Event()

    def reset(self):
        """Nothing to reset on the sending side."""

    async def run(self):
        """Wait for the I/O to be ready, then execute `repeat + 1` runs."""
        logging.info(color('--- Waiting for I/O to be ready...', 'blue'))
        await self.packet_io.ready.wait()
        logging.info(color('--- Go!', 'blue'))

        for run in range(self.repeat + 1):
            self.done.clear()

            await self._delay_before_run(run)
            await self._transmit_run()

            # The ACK handler sets `done` once the peer has answered.
            await self.done.wait()

            if self.repeat:
                run_counter = f'[{run + 1} of {self.repeat + 1}]'
            else:
                run_counter = ''
            logging.info(color(f'=== {run_counter} Done!', 'magenta'))

            if self.repeat:
                log_stats('Run', self.stats)

        if self.repeat:
            logging.info(color('--- End of runs', 'blue'))

    async def _delay_before_run(self, run):
        """Apply the repeat delay (not before the first run) and the start delay."""
        if run > 0 and self.repeat and self.repeat_delay:
            logging.info(color(f'*** Repeat delay: {self.repeat_delay}', 'green'))
            await asyncio.sleep(self.repeat_delay)

        if self.tx_start_delay:
            logging.info(color(f'*** Startup delay: {self.tx_start_delay}', 'blue'))
            await asyncio.sleep(self.tx_start_delay)

    async def _transmit_run(self):
        """Send the RESET packet followed by all the SEQUENCE packets of a run."""
        logging.info(color('=== Sending RESET', 'magenta'))
        await self.packet_io.send_packet(bytes([PacketType.RESET]))
        self.start_time = time.time()
        self.bytes_sent = 0

        final_index = self.tx_packet_count - 1
        for index in range(self.tx_packet_count):
            packet = self._make_sequence_packet(index, index == final_index)
            logging.info(
                color(f'Sending packet {index}: {self.tx_packet_size} bytes', 'yellow')
            )
            self.bytes_sent += len(packet)
            await self.packet_io.send_packet(packet)

            # pace: None = no pacing, > 0 = pause (in milliseconds) between
            # packets, anything else = drain the I/O after each packet.
            if self.pace is not None:
                if self.pace > 0:
                    await asyncio.sleep(self.pace / 1000)
                else:
                    await self.packet_io.drain()

    def _make_sequence_packet(self, index, is_last):
        """Build a SEQUENCE packet: 6-byte header plus zero-filled padding."""
        flags = PACKET_FLAG_LAST if is_last else 0
        header = struct.pack('>bbI', PacketType.SEQUENCE, flags, index)
        # The padding accounts for the header and for the I/O's own overhead.
        padding = bytes(self.tx_packet_size - 6 - self.packet_io.overhead_size)
        return header + padding

    def on_packet_received(self, packet):
        """Handle a packet from the peer; an ACK completes the current run."""
        try:
            packet_type, _ = parse_packet(packet)
        except ValueError:
            return

        if packet_type != PacketType.ACK:
            return

        elapsed = time.time() - self.start_time
        average_tx_speed = self.bytes_sent / elapsed
        self.stats.append(average_tx_speed)
        logging.info(
            color(
                f'@@@ Received ACK. Speed: average={average_tx_speed:.4f}'
                f' ({self.bytes_sent} bytes in {elapsed:.2f} seconds)',
                'green',
            )
        )
        self.done.set()
339
340
341# -----------------------------------------------------------------------------
342# Receiver
343# -----------------------------------------------------------------------------
class Receiver:
    """Receiving end of the throughput test.

    Registers itself as the listener of `packet_io`, logs instant, windowed
    and average receive speeds for every packet received, and sends an ACK
    when a packet carrying PACKET_FLAG_LAST arrives. Unless `linger` is set,
    receiving that last packet also completes `run()`.
    """

    expected_packet_index: int
    start_timestamp: float
    last_timestamp: float

    def __init__(self, packet_io, linger):
        self.reset()
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.linger = linger
        self.done = asyncio.Event()

    def reset(self):
        """Restart the packet numbering and the speed measurements."""
        self.expected_packet_index = 0
        # List of (timestamp, byte count); the first entry marks the start.
        self.measurements = [(time.time(), 0)]
        self.total_bytes_received = 0

    @staticmethod
    def _speed(byte_count, elapsed):
        """Return `byte_count / elapsed`, or 0.0 when `elapsed` is not positive.

        Two packets can be timestamped within the same clock tick (or the
        wall clock can step backwards), in which case no meaningful rate can
        be computed and a plain division would raise or go negative.
        """
        return byte_count / elapsed if elapsed > 0 else 0.0

    def on_packet_received(self, packet):
        """Process one incoming packet.

        Malformed packets are ignored (the parsers log the reason). A RESET
        packet restarts the measurements; any other packet is treated as a
        sequence packet.
        """
        try:
            packet_type, packet_data = parse_packet(packet)
        except ValueError:
            return

        if packet_type == PacketType.RESET:
            logging.info(color('=== Received RESET', 'magenta'))
            self.reset()
            return

        try:
            packet_flags, packet_index = parse_packet_sequence(packet_data)
        except ValueError:
            return
        logging.info(
            f'<<< Received packet {packet_index}: '
            f'flags=0x{packet_flags:02X}, '
            f'{len(packet) + self.packet_io.overhead_size} bytes'
        )

        if packet_index != self.expected_packet_index:
            # Reported in red like the other '!!!' error messages.
            logging.info(
                color(
                    f'!!! Unexpected packet, expected {self.expected_packet_index} '
                    f'but received {packet_index}',
                    'red',
                )
            )

        now = time.time()
        elapsed_since_start = now - self.measurements[0][0]
        elapsed_since_last = now - self.measurements[-1][0]
        self.measurements.append((now, len(packet)))
        self.total_bytes_received += len(packet)
        instant_rx_speed = self._speed(len(packet), elapsed_since_last)
        average_rx_speed = self._speed(self.total_bytes_received, elapsed_since_start)
        # Windowed speed over the most recent measurements: the first entry of
        # the window only provides the starting timestamp.
        window = self.measurements[-64:]
        windowed_rx_speed = self._speed(
            sum(measurement[1] for measurement in window[1:]),
            window[-1][0] - window[0][0],
        )
        logging.info(
            color(
                'Speed: '
                f'instant={instant_rx_speed:.4f}, '
                f'windowed={windowed_rx_speed:.4f}, '
                f'average={average_rx_speed:.4f}',
                'yellow',
            )
        )

        self.expected_packet_index = packet_index + 1

        if packet_flags & PACKET_FLAG_LAST:
            # Acknowledge the end of the run, echoing the flags and index.
            AsyncRunner.spawn(
                self.packet_io.send_packet(
                    struct.pack('>bbI', PacketType.ACK, packet_flags, packet_index)
                )
            )
            logging.info(color('@@@ Received last packet', 'green'))
            if not self.linger:
                self.done.set()

    async def run(self):
        """Wait until the last packet has been received (never, when lingering)."""
        await self.done.wait()
        logging.info(color('=== Done!', 'magenta'))
426
427
428# -----------------------------------------------------------------------------
429# Ping
430# -----------------------------------------------------------------------------
class Ping:
    """Initiating end of the latency test.

    Each run sends a RESET packet, then `packet_count` SEQUENCE packets one at
    a time: the next packet is only sent once the ACK for the previous one has
    been received. The time between sending a packet and receiving its ACK is
    recorded, in milliseconds, in `latencies`.
    """

    def __init__(
        self,
        packet_io,
        start_delay,
        repeat,
        repeat_delay,
        pace,
        packet_size,
        packet_count,
    ):
        self.tx_start_delay = start_delay
        self.tx_packet_size = packet_size
        self.tx_packet_count = packet_count
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.repeat = repeat
        self.repeat_delay = repeat_delay
        self.pace = pace
        self.done = asyncio.Event()
        self.current_packet_index = 0
        self.ping_sent_time = 0.0
        # Latencies of the current run, and min/max/average of each past run.
        self.latencies = []
        self.min_stats = []
        self.max_stats = []
        self.avg_stats = []

    def reset(self):
        """Nothing to reset on the initiating side."""

    async def run(self):
        """Wait for the I/O to be ready, then execute `repeat + 1` runs."""
        logging.info(color('--- Waiting for I/O to be ready...', 'blue'))
        await self.packet_io.ready.wait()
        logging.info(color('--- Go!', 'blue'))

        for run in range(self.repeat + 1):
            self.done.clear()

            if run > 0 and self.repeat and self.repeat_delay:
                logging.info(color(f'*** Repeat delay: {self.repeat_delay}', 'green'))
                await asyncio.sleep(self.repeat_delay)

            if self.tx_start_delay:
                logging.info(color(f'*** Startup delay: {self.tx_start_delay}', 'blue'))
                await asyncio.sleep(self.tx_start_delay)

            logging.info(color('=== Sending RESET', 'magenta'))
            await self.packet_io.send_packet(bytes([PacketType.RESET]))

            self.current_packet_index = 0
            self.latencies = []
            await self.send_next_ping()

            # `done` is only set from an ACK, after its latency was recorded,
            # so `latencies` is not empty past this point.
            await self.done.wait()

            min_latency = min(self.latencies)
            max_latency = max(self.latencies)
            avg_latency = sum(self.latencies) / len(self.latencies)
            logging.info(
                color(
                    '@@@ Latencies: '
                    f'min={min_latency:.2f}, '
                    f'max={max_latency:.2f}, '
                    f'average={avg_latency:.2f}'
                )
            )

            self.min_stats.append(min_latency)
            self.max_stats.append(max_latency)
            self.avg_stats.append(avg_latency)

            run_counter = f'[{run + 1} of {self.repeat + 1}]' if self.repeat else ''
            logging.info(color(f'=== {run_counter} Done!', 'magenta'))

            if self.repeat:
                log_stats('Min Latency', self.min_stats)
                log_stats('Max Latency', self.max_stats)
                log_stats('Average Latency', self.avg_stats)

        if self.repeat:
            logging.info(color('--- End of runs', 'blue'))

    async def send_next_ping(self):
        """Send the SEQUENCE packet for `current_packet_index`.

        When `pace` is set, first sleeps for that many milliseconds. The packet
        is flagged as last when its index is `packet_count - 1`.
        """
        if self.pace:
            await asyncio.sleep(self.pace / 1000)

        packet = struct.pack(
            '>bbI',
            PacketType.SEQUENCE,
            (
                PACKET_FLAG_LAST
                if self.current_packet_index == self.tx_packet_count - 1
                else 0
            ),
            self.current_packet_index,
        ) + bytes(self.tx_packet_size - 6)
        logging.info(color(f'Sending packet {self.current_packet_index}', 'yellow'))
        self.ping_sent_time = time.time()
        await self.packet_io.send_packet(packet)

    def on_packet_received(self, packet):
        """Handle a packet from the peer.

        An ACK records the latency of the ping in flight, then either ends the
        run (when flagged as last) or triggers the next ping. Malformed
        packets and packets that are not ACKs are ignored.
        """
        # Take the timestamp first so that parsing does not add to the latency.
        elapsed = time.time() - self.ping_sent_time

        try:
            packet_type, packet_data = parse_packet(packet)
        except ValueError:
            return

        if packet_type != PacketType.ACK:
            # Only an ACK may advance the sequence: letting other packets fall
            # through could end the run before any latency was recorded, or
            # trigger a ping that nothing asked for.
            logging.info(
                color(f'!!! Ignoring unexpected {packet_type.name} packet', 'red')
            )
            return

        try:
            packet_flags, packet_index = parse_packet_sequence(packet_data)
        except ValueError:
            return

        latency = elapsed * 1000
        self.latencies.append(latency)
        logging.info(
            color(
                f'<<< Received ACK [{packet_index}], latency={latency:.2f}ms',
                'green',
            )
        )

        if packet_index == self.current_packet_index:
            self.current_packet_index += 1
        else:
            # Reported in red like the other '!!!' error messages.
            logging.info(
                color(
                    f'!!! Unexpected packet, expected {self.current_packet_index} '
                    f'but received {packet_index}',
                    'red',
                )
            )

        if packet_flags & PACKET_FLAG_LAST:
            self.done.set()
            return

        AsyncRunner.spawn(self.send_next_ping())
569
570
571# -----------------------------------------------------------------------------
572# Pong
573# -----------------------------------------------------------------------------
class Pong:
    """Responding end of the latency test.

    Registers itself as the listener of `packet_io` and answers every
    sequence packet with an ACK echoing its flags and index. Unless `linger`
    is set, a packet carrying PACKET_FLAG_LAST also completes `run()`.
    """

    expected_packet_index: int

    def __init__(self, packet_io, linger):
        self.reset()
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.linger = linger
        self.done = asyncio.Event()

    def reset(self):
        """Restart the expected packet numbering at 0."""
        self.expected_packet_index = 0

    def on_packet_received(self, packet):
        """Process one incoming packet.

        Malformed packets are ignored (the parsers log the reason). A RESET
        packet restarts the numbering; any other packet is acknowledged.
        """
        try:
            packet_type, packet_data = parse_packet(packet)
        except ValueError:
            return

        if packet_type == PacketType.RESET:
            logging.info(color('=== Received RESET', 'magenta'))
            self.reset()
            return

        try:
            packet_flags, packet_index = parse_packet_sequence(packet_data)
        except ValueError:
            return
        logging.info(
            color(
                f'<<< Received packet {packet_index}: '
                f'flags=0x{packet_flags:02X}, {len(packet)} bytes',
                'green',
            )
        )

        if packet_index != self.expected_packet_index:
            # Reported in red like the other '!!!' error messages.
            logging.info(
                color(
                    f'!!! Unexpected packet, expected {self.expected_packet_index} '
                    f'but received {packet_index}',
                    'red',
                )
            )

        # Resynchronize on whatever index was actually received.
        self.expected_packet_index = packet_index + 1

        AsyncRunner.spawn(
            self.packet_io.send_packet(
                struct.pack('>bbI', PacketType.ACK, packet_flags, packet_index)
            )
        )

        if packet_flags & PACKET_FLAG_LAST and not self.linger:
            self.done.set()

    async def run(self):
        """Wait until the last packet has been received (never, when lingering)."""
        await self.done.wait()
        logging.info(color('=== Done!', 'magenta'))
632
633
634# -----------------------------------------------------------------------------
635# GattClient
636# -----------------------------------------------------------------------------
class GattClient:
    """Packet I/O over GATT, client side.

    Packets are sent by writing the Speed TX characteristic of the peer and
    received through notifications of its Speed RX characteristic. `ready` is
    set once both characteristics were found and RX was subscribed to.
    """

    def __init__(self, _device, att_mtu=None):
        self.overhead_size = 0
        self.ready = asyncio.Event()
        self.packet_listener = None
        self.att_mtu = att_mtu
        # Remote characteristics, resolved in on_connection()
        self.speed_tx = None
        self.speed_rx = None

    async def on_connection(self, connection):
        """Discover the Speed service on the peer and subscribe to RX.

        Returns early, without setting `ready`, when the service or one of its
        two characteristics cannot be found.
        """
        peer = Peer(connection)

        if self.att_mtu:
            logging.info(color(f'*** Requesting MTU update: {self.att_mtu}', 'blue'))
            await peer.request_mtu(self.att_mtu)

        logging.info(color('*** Discovering services...', 'blue'))
        await peer.discover_services()

        found_services = peer.get_services_by_uuid(SPEED_SERVICE_UUID)
        if not found_services:
            logging.info(color('!!! Speed Service not found', 'red'))
            return
        service = found_services[0]

        logging.info(color('*** Discovering characteristics...', 'blue'))
        await service.discover_characteristics()

        tx_candidates = service.get_characteristics_by_uuid(SPEED_TX_UUID)
        if not tx_candidates:
            logging.info(color('!!! Speed TX not found', 'red'))
            return
        self.speed_tx = tx_candidates[0]

        rx_candidates = service.get_characteristics_by_uuid(SPEED_RX_UUID)
        if not rx_candidates:
            logging.info(color('!!! Speed RX not found', 'red'))
            return
        self.speed_rx = rx_candidates[0]

        logging.info(color('*** Subscribing to RX', 'blue'))
        await self.speed_rx.subscribe(self.on_packet_received)

        logging.info(color('*** Discovery complete', 'blue'))

        connection.on('disconnection', self.on_disconnection)
        self.ready.set()

    def on_disconnection(self, _):
        """Mark the I/O as not ready once the connection is gone."""
        self.ready.clear()

    def on_packet_received(self, packet):
        """Forward a received packet to the listener, if there is one."""
        listener = self.packet_listener
        if listener:
            listener.on_packet_received(packet)

    async def send_packet(self, packet):
        """Send a packet by writing it to the TX characteristic."""
        await self.speed_tx.write_value(packet)

    async def drain(self):
        """Nothing to drain for this I/O."""
696
697
698# -----------------------------------------------------------------------------
699# GattServer
700# -----------------------------------------------------------------------------
class GattServer:
    """Packet I/O over GATT, server side.

    Exposes the Speed service on `device`: packets are received through
    writes to the TX characteristic and sent as notifications of the RX
    characteristic. `ready` follows the notification subscription to RX.
    """

    def __init__(self, device):
        self.overhead_size = 0
        self.ready = asyncio.Event()
        self.packet_listener = None
        self.device = device

        # Characteristic written by the peer: every write is an incoming packet
        tx_value = CharacteristicValue(write=self.on_tx_write)
        self.speed_tx = Characteristic(
            SPEED_TX_UUID,
            Characteristic.Properties.WRITE,
            Characteristic.WRITEABLE,
            tx_value,
        )

        # Characteristic used to notify outgoing packets to the peer
        self.speed_rx = Characteristic(
            SPEED_RX_UUID, Characteristic.Properties.NOTIFY, 0
        )

        # Publish both characteristics in a single service
        device.add_services(
            [Service(SPEED_SERVICE_UUID, [self.speed_tx, self.speed_rx])]
        )

        self.speed_rx.on('subscription', self.on_rx_subscription)

    async def on_connection(self, connection):
        """Start watching the new connection for its disconnection."""
        connection.on('disconnection', self.on_disconnection)

    def on_disconnection(self, _):
        """Mark the I/O as not ready once the connection is gone."""
        self.ready.clear()

    def on_rx_subscription(self, _connection, notify_enabled, _indicate_enabled):
        """Set or clear `ready` according to the notification subscription."""
        if not notify_enabled:
            logging.info(color('*** RX un-subscription', 'blue'))
            self.ready.clear()
            return

        logging.info(color('*** RX subscription', 'blue'))
        self.ready.set()

    def on_tx_write(self, _, value):
        """Forward a value written to TX to the listener, if there is one."""
        listener = self.packet_listener
        if listener:
            listener.on_packet_received(value)

    async def send_packet(self, packet):
        """Send a packet by notifying it to the subscribers of RX."""
        await self.device.notify_subscribers(self.speed_rx, packet)

    async def drain(self):
        """Nothing to drain for this I/O."""
750
751
752# -----------------------------------------------------------------------------
753# StreamedPacketIO
754# -----------------------------------------------------------------------------
class StreamedPacketIO:
    """Packet framing over a byte stream.

    Each packet travels as a 2-byte big-endian length followed by that many
    bytes. Incoming data is fed to `on_packet()` in arbitrary chunks; each
    reassembled packet is passed to `packet_listener.on_packet_received()`.
    Outgoing packets are framed by `send_packet()` and handed to `io_sink`.

    `overhead_size` is the size of the length header added to each packet.
    """

    def __init__(self):
        self.packet_listener = None
        self.io_sink = None
        self.rx_packet = b''
        self.rx_packet_header = b''
        self.rx_packet_need = 0
        self.overhead_size = 2

    def on_packet(self, packet):
        """Consume a chunk of incoming bytes, delivering any completed packets.

        The chunk may contain a partial packet, exactly one packet or several
        packets; a header or a payload may be split across successive chunks.
        A frame with a length of zero is delivered as an empty packet.
        """
        while packet:
            if self.rx_packet_need:
                # In the middle of a payload: take as much as is needed
                chunk = packet[: self.rx_packet_need]
                self.rx_packet += chunk
                packet = packet[len(chunk) :]
                self.rx_packet_need -= len(chunk)
                if not self.rx_packet_need:
                    self._deliver_packet()
            else:
                # Expect the next packet: accumulate its 2-byte header
                header_bytes_needed = 2 - len(self.rx_packet_header)
                header_bytes = packet[:header_bytes_needed]
                self.rx_packet_header += header_bytes
                if len(self.rx_packet_header) != 2:
                    return
                packet = packet[len(header_bytes) :]
                self.rx_packet_need = struct.unpack('>H', self.rx_packet_header)[0]
                if not self.rx_packet_need:
                    # Zero-length frame: it is already complete. Without this
                    # the stale header would be parsed again and again, and
                    # the loop would never make progress.
                    self._deliver_packet()

    def _deliver_packet(self):
        """Reset the reassembly state, then hand the packet to the listener."""
        completed = self.rx_packet
        # Reset first, so that a listener that raises cannot leave a stale
        # header behind for the next chunk.
        self.rx_packet = b''
        self.rx_packet_header = b''
        if self.packet_listener:
            self.packet_listener.on_packet_received(completed)

    async def send_packet(self, packet):
        """Frame a packet with its length and write it to the sink.

        The packet is dropped (with a log message) when no sink is set.
        `struct.error` is raised when the packet is too long for its length
        to fit in the 2-byte header.
        """
        if not self.io_sink:
            logging.info(color('!!! No sink, dropping packet', 'red'))
            return

        # pylint: disable-next=not-callable
        self.io_sink(struct.pack('>H', len(packet)) + packet)
795
796
797# -----------------------------------------------------------------------------
798# L2capClient
799# -----------------------------------------------------------------------------
class L2capClient(StreamedPacketIO):
    """Streamed packet I/O over an L2CAP credit-based channel, client side.

    The channel is opened towards the peer's PSM when the connection is
    established; `ready` is set once the channel is open.
    """

    def __init__(
        self,
        _device,
        psm=DEFAULT_L2CAP_PSM,
        max_credits=DEFAULT_L2CAP_MAX_CREDITS,
        mtu=DEFAULT_L2CAP_MTU,
        mps=DEFAULT_L2CAP_MPS,
    ):
        super().__init__()
        self.ready = asyncio.Event()
        self.l2cap_channel = None
        # Settings used for the channel spec in on_connection()
        self.psm, self.max_credits = psm, max_credits
        self.mtu, self.mps = mtu, mps

    async def on_connection(self, connection: Connection) -> None:
        """Open the L2CAP channel and wire it to the packet framing.

        Any failure to open the channel is logged, and `ready` is then left
        unset.
        """
        connection.on('disconnection', self.on_disconnection)

        # Connect a new L2CAP channel
        logging.info(color(f'>>> Opening L2CAP channel on PSM = {self.psm}', 'yellow'))
        try:
            channel_spec = l2cap.LeCreditBasedChannelSpec(
                psm=self.psm,
                max_credits=self.max_credits,
                mtu=self.mtu,
                mps=self.mps,
            )
            channel = await connection.create_l2cap_channel(spec=channel_spec)
            logging.info(color(f'*** L2CAP channel: {channel}', 'cyan'))
        except Exception as error:
            logging.info(color(f'!!! Connection failed: {error}', 'red'))
            return

        # Outgoing frames go to the channel, incoming data to the reassembler
        self.l2cap_channel = channel
        self.io_sink = channel.write
        channel.sink = self.on_packet
        channel.on('close', self.on_l2cap_close)

        self.ready.set()

    def on_disconnection(self, _):
        """Nothing to do when the connection goes away."""

    def on_l2cap_close(self):
        """Log the closing of the channel."""
        logging.info(color('*** L2CAP channel closed', 'red'))

    async def drain(self):
        """Wait for the channel to drain its pending output."""
        assert self.l2cap_channel
        await self.l2cap_channel.drain()
852
853
854# -----------------------------------------------------------------------------
855# L2capServer
856# -----------------------------------------------------------------------------
class L2capServer(StreamedPacketIO):
    """Streamed packet I/O over an L2CAP credit-based channel, server side.

    Registers an L2CAP server on `device` at construction time; `ready` is set
    when a peer opens a channel to it.
    """

    def __init__(
        self,
        device: Device,
        psm=DEFAULT_L2CAP_PSM,
        max_credits=DEFAULT_L2CAP_MAX_CREDITS,
        mtu=DEFAULT_L2CAP_MTU,
        mps=DEFAULT_L2CAP_MPS,
    ):
        super().__init__()
        self.ready = asyncio.Event()
        self.l2cap_channel = None

        # Listen for incoming L2CAP connections
        server_spec = l2cap.LeCreditBasedChannelSpec(
            psm=psm, mtu=mtu, mps=mps, max_credits=max_credits
        )
        device.create_l2cap_server(spec=server_spec, handler=self.on_l2cap_channel)
        logging.info(
            color(f'### Listening for L2CAP connection on PSM {psm}', 'yellow')
        )

    async def on_connection(self, connection):
        """Start watching the new connection for its disconnection."""
        connection.on('disconnection', self.on_disconnection)

    def on_disconnection(self, _):
        """Nothing to do when the connection goes away."""

    def on_l2cap_channel(self, l2cap_channel):
        """Wire a newly opened channel to the packet framing and signal `ready`."""
        logging.info(color(f'*** L2CAP channel: {l2cap_channel}', 'cyan'))

        # Outgoing frames go to the channel, incoming data to the reassembler
        self.l2cap_channel = l2cap_channel
        self.io_sink = l2cap_channel.write
        l2cap_channel.sink = self.on_packet
        l2cap_channel.on('close', self.on_l2cap_close)

        self.ready.set()

    def on_l2cap_close(self):
        """Log the closing of the channel and forget it."""
        logging.info(color('*** L2CAP channel closed', 'red'))
        self.l2cap_channel = None

    async def drain(self):
        """Wait for the channel to drain its pending output."""
        assert self.l2cap_channel
        await self.l2cap_channel.drain()
904
905
906# -----------------------------------------------------------------------------
907# RfcommClient
908# -----------------------------------------------------------------------------
class RfcommClient(StreamedPacketIO):
    """Packet I/O endpoint that opens an RFComm DLC to the peer.

    The session is opened from `on_connection`. When it succeeds, received
    data is routed to `on_packet`, outgoing data is written to the session,
    and `ready` is set.
    """

    def __init__(
        self,
        device,
        channel,
        uuid,
        l2cap_mtu,
        max_frame_size,
        initial_credits,
        max_credits,
        credits_threshold,
    ):
        """Store the RFComm settings; nothing is opened until a connection exists.

        Args:
            device: stored on the instance.
            channel: RFComm channel number; 0 means "discover it from `uuid`".
            uuid: service UUID used for channel discovery when `channel` is 0.
            l2cap_mtu: passed to the RFComm client when truthy.
            max_frame_size: passed to `open_dlc` when not None.
            initial_credits: passed to `open_dlc` when not None.
            max_credits: set as `rx_max_credits` on the session when not None.
            credits_threshold: set as `rx_credits_threshold` on the session
                when not None.
        """
        super().__init__()
        self.device = device

        # Channel selection
        self.channel = channel
        self.uuid = uuid

        # Transport tuning
        self.l2cap_mtu = l2cap_mtu
        self.max_frame_size = max_frame_size
        self.initial_credits = initial_credits
        self.max_credits = max_credits
        self.credits_threshold = credits_threshold

        # Runtime state
        self.rfcomm_session = None
        self.ready = asyncio.Event()

    async def on_connection(self, connection):
        """Open the RFComm session over `connection`.

        Returns early, without setting `ready`, when no channel can be found
        for the UUID (the connection is then disconnected) or when opening the
        DLC fails with a connection error (the multiplexer is then
        disconnected).
        """
        connection.on('disconnection', self.on_disconnection)

        # Resolve the channel number, via discovery if it was not specified
        channel_number = self.channel
        if channel_number == 0:
            logging.info(
                color(f'@@@ Discovering channel number from UUID {self.uuid}', 'cyan')
            )
            channel_number = await bumble.rfcomm.find_rfcomm_channel_with_uuid(
                connection, self.uuid
            )
            logging.info(color(f'@@@ Channel number = {channel_number}', 'cyan'))
            if channel_number == 0:
                logging.info(color('!!! No RFComm service with this UUID found', 'red'))
                await connection.disconnect()
                return

        # Start the RFComm client to get a multiplexer
        logging.info(color('*** Starting RFCOMM client...', 'blue'))
        client_kwargs = {'l2cap_mtu': self.l2cap_mtu} if self.l2cap_mtu else {}
        multiplexer = await bumble.rfcomm.Client(connection, **client_kwargs).start()
        logging.info(color('*** Started', 'blue'))

        logging.info(
            color(f'### Opening session for channel {channel_number}...', 'yellow')
        )
        # Only forward the DLC settings that were explicitly specified
        requested = (
            ('max_frame_size', self.max_frame_size),
            ('initial_credits', self.initial_credits),
        )
        dlc_kwargs = {name: value for name, value in requested if value is not None}
        try:
            session = await multiplexer.open_dlc(channel_number, **dlc_kwargs)
            logging.info(color(f'### Session open: {session}', 'yellow'))
            if self.max_credits is not None:
                session.rx_max_credits = self.max_credits
            if self.credits_threshold is not None:
                session.rx_credits_threshold = self.credits_threshold
        except bumble.core.ConnectionError as error:
            logging.info(color(f'!!! Session open failed: {error}', 'red'))
            await multiplexer.disconnect()
            return

        # Hook the session up to the packet I/O
        session.sink = self.on_packet
        self.io_sink = session.write
        self.rfcomm_session = session

        self.ready.set()

    def on_disconnection(self, _):
        """Disconnection callback; nothing to do for this mode."""

    async def drain(self):
        """Wait until the session has drained; a session must be open."""
        session = self.rfcomm_session
        assert session
        await session.drain()
991
992
993# -----------------------------------------------------------------------------
994# RfcommServer
995# -----------------------------------------------------------------------------
class RfcommServer(StreamedPacketIO):
    """Packet I/O endpoint that waits for an incoming RFComm DLC.

    An RFComm server is created on the device at construction time and the
    channel it listens on is published through the device's SDP records. Once
    a peer opens a DLC, received data is routed to `on_packet`, outgoing data
    is written to the DLC, and `ready` is set.
    """

    def __init__(
        self,
        device,
        channel,
        l2cap_mtu,
        max_frame_size,
        initial_credits,
        max_credits,
        credits_threshold,
    ):
        """Create the RFComm server and start listening.

        Args:
            device: device on which the server is created; its
                `sdp_service_records` attribute is replaced.
            channel: channel number passed to the server's `listen`.
            l2cap_mtu: passed to the RFComm server when truthy.
            max_frame_size: passed to `listen` when not None.
            initial_credits: passed to `listen` when not None.
            max_credits: set as `rx_max_credits` on incoming DLCs when not None.
            credits_threshold: set as `rx_credits_threshold` on incoming DLCs
                when not None.
        """
        super().__init__()
        self.max_credits = max_credits
        self.credits_threshold = credits_threshold
        self.dlc = None
        self.ready = asyncio.Event()

        # Create and register a server
        server_options = {}
        if l2cap_mtu:
            server_options['l2cap_mtu'] = l2cap_mtu
        rfcomm_server = bumble.rfcomm.Server(device, **server_options)

        # Listen for incoming DLC connections
        dlc_options = {}
        if max_frame_size is not None:
            dlc_options['max_frame_size'] = max_frame_size
        if initial_credits is not None:
            dlc_options['initial_credits'] = initial_credits
        channel_number = rfcomm_server.listen(self.on_dlc, channel, **dlc_options)

        # Setup the SDP to advertise this channel
        device.sdp_service_records = make_sdp_records(channel_number)

        logging.info(
            color(
                f'### Listening for RFComm connection on channel {channel_number}',
                'yellow',
            )
        )

    async def on_connection(self, connection):
        """Subscribe to the disconnection event of a new connection."""
        connection.on('disconnection', self.on_disconnection)

    def on_disconnection(self, _):
        """Disconnection callback; nothing to do for this mode."""

    def on_dlc(self, dlc):
        """Adopt a newly connected DLC and signal that I/O is ready."""
        logging.info(color(f'*** DLC connected: {dlc}', 'blue'))

        # Apply the receive credit settings before data starts flowing
        if self.max_credits is not None:
            dlc.rx_max_credits = self.max_credits
        if self.credits_threshold is not None:
            dlc.rx_credits_threshold = self.credits_threshold

        dlc.sink = self.on_packet
        self.io_sink = dlc.write
        self.dlc = dlc

        # Like the other packet I/O classes, signal that the channel is usable
        # so that anything waiting on `ready` can proceed.
        self.ready.set()

    async def drain(self):
        """Wait until the DLC has drained; a DLC must be connected."""
        assert self.dlc
        await self.dlc.drain()
1060
1061
1062# -----------------------------------------------------------------------------
1063# Central
1064# -----------------------------------------------------------------------------
class Central(Connection.Listener):
    """Benchmark endpoint that initiates the connection to the peripheral.

    `run` opens the HCI transport, connects to the peripheral, applies the
    requested link settings (role switch, data length, authentication,
    encryption, PHY), then hands the connection to the mode object and runs
    the role object.
    """

    def __init__(
        self,
        transport,
        peripheral_address,
        classic,
        role_factory,
        mode_factory,
        connection_interval,
        phy,
        authenticate,
        encrypt,
        extended_data_length,
        role_switch,
    ):
        """Store the settings used by `run`.

        Args:
            transport: transport spec passed to `open_transport_or_link`.
            peripheral_address: address (or name) passed to `Device.connect`.
            classic: truthy to connect over BR/EDR, otherwise over LE.
            role_factory: callable that takes the mode object and returns the
                role object (which must provide an async `run`).
            mode_factory: callable that takes the device and returns the mode
                object (which must provide an async `on_connection`).
            connection_interval: when truthy, used as both the minimum and
                maximum preferred connection interval.
            phy: one of '1m', '2m', 'coded', or a falsy value for no preference.
            authenticate: truthy to authenticate after connecting; this also
                turns encryption on.
            encrypt: truthy to enable encryption after connecting.
            extended_data_length: falsy, or a sequence whose first two items
                are passed to `Connection.set_data_length`.
            role_switch: when truthy, passed to `switch_roles`.

        Raises:
            KeyError: if `phy` is truthy but not one of the supported names.
        """
        super().__init__()
        self.transport = transport
        self.peripheral_address = peripheral_address
        self.classic = classic
        self.role_factory = role_factory
        self.mode_factory = mode_factory
        self.authenticate = authenticate
        self.encrypt = encrypt or authenticate
        self.extended_data_length = extended_data_length
        self.role_switch = role_switch
        self.device = None
        self.connection = None

        if phy:
            self.phy = {
                '1m': HCI_LE_1M_PHY,
                '2m': HCI_LE_2M_PHY,
                'coded': HCI_LE_CODED_PHY,
            }[phy]
        else:
            self.phy = None

        if connection_interval:
            connection_parameter_preferences = ConnectionParametersPreferences()
            connection_parameter_preferences.connection_interval_min = (
                connection_interval
            )
            connection_parameter_preferences.connection_interval_max = (
                connection_interval
            )

            # Preferences for the 1M PHY are always set.
            self.connection_parameter_preferences = {
                HCI_LE_1M_PHY: connection_parameter_preferences,
            }

            if self.phy not in (None, HCI_LE_1M_PHY):
                # Add a connection parameters entry for this PHY.
                self.connection_parameter_preferences[self.phy] = (
                    connection_parameter_preferences
                )
        else:
            self.connection_parameter_preferences = None

    async def run(self):
        """Connect to the peripheral and run the benchmark role.

        Returns early, after logging the reason, if the connection attempt
        times out or fails.
        """
        logging.info(color('>>> Connecting to HCI...', 'green'))
        async with await open_transport_or_link(self.transport) as (
            hci_source,
            hci_sink,
        ):
            logging.info(color('>>> Connected', 'green'))

            central_address = DEFAULT_CENTRAL_ADDRESS
            self.device = Device.with_hci(
                DEFAULT_CENTRAL_NAME, central_address, hci_source, hci_sink
            )
            mode = self.mode_factory(self.device)
            role = self.role_factory(mode)
            self.device.classic_enabled = self.classic

            # Set up a pairing config factory with minimal requirements.
            self.device.pairing_config_factory = lambda _: PairingConfig(
                sc=False, mitm=False, bonding=False
            )

            await self.device.power_on()

            if self.classic:
                await self.device.set_discoverable(False)
                await self.device.set_connectable(False)

            logging.info(
                color(f'### Connecting to {self.peripheral_address}...', 'cyan')
            )
            try:
                self.connection = await self.device.connect(
                    self.peripheral_address,
                    connection_parameters_preferences=self.connection_parameter_preferences,
                    transport=BT_BR_EDR_TRANSPORT if self.classic else BT_LE_TRANSPORT,
                )
            except CommandTimeoutError:
                logging.info(color('!!! Connection timed out', 'red'))
                return
            except bumble.core.ConnectionError as error:
                logging.info(color(f'!!! Connection error: {error}', 'red'))
                return
            except HCI_StatusError as error:
                logging.info(
                    color(f'!!! Connection failed: {error.error_name}', 'red')
                )
                return
            logging.info(color('### Connected', 'cyan'))
            self.connection.listener = self
            print_connection(self.connection)

            # Switch roles if needed.
            if self.role_switch:
                await switch_roles(self.connection, self.role_switch)

            # Wait a bit after the connection, some controllers aren't very good when
            # we start sending data right away while some connection parameters are
            # updated post connection
            await asyncio.sleep(DEFAULT_POST_CONNECTION_WAIT_TIME)

            # Request a new data length if requested
            if self.extended_data_length:
                logging.info(color('+++ Requesting extended data length', 'cyan'))
                await self.connection.set_data_length(
                    self.extended_data_length[0], self.extended_data_length[1]
                )

            # Authenticate if requested
            if self.authenticate:
                # Request authentication
                logging.info(color('*** Authenticating...', 'cyan'))
                await self.connection.authenticate()
                logging.info(color('*** Authenticated', 'cyan'))

            # Encrypt if requested
            if self.encrypt:
                # Enable encryption
                logging.info(color('*** Enabling encryption...', 'cyan'))
                await self.connection.encrypt()
                logging.info(color('*** Encryption on', 'cyan'))

            # Set the PHY if requested
            if self.phy is not None:
                try:
                    await self.connection.set_phy(
                        tx_phys=[self.phy], rx_phys=[self.phy]
                    )
                except HCI_Error as error:
                    # Not fatal: keep going with whatever PHY is in use
                    logging.info(
                        color(
                            f'!!! Unable to set the PHY: {error.error_name}', 'yellow'
                        )
                    )

            await mode.on_connection(self.connection)

            await role.run()
            await asyncio.sleep(DEFAULT_LINGER_TIME)

            # on_disconnection() clears self.connection, so there is nothing
            # left to disconnect if the link already went down.
            if self.connection is not None:
                await self.connection.disconnect()

    def on_disconnection(self, reason):
        """Log the disconnection and forget the connection."""
        logging.info(color(f'!!! Disconnection: reason={reason}', 'red'))
        self.connection = None

    def on_connection_parameters_update(self):
        """Print the connection after a connection parameters update."""
        print_connection(self.connection)

    def on_connection_phy_update(self):
        """Print the connection after a PHY update."""
        print_connection(self.connection)

    def on_connection_att_mtu_update(self):
        """Print the connection after an ATT MTU update."""
        print_connection(self.connection)

    def on_connection_data_length_change(self):
        """Print the connection after a data length change."""
        print_connection(self.connection)

    def on_role_change(self):
        """Print the connection after a role change."""
        print_connection(self.connection)
1240
1241
1242# -----------------------------------------------------------------------------
1243# Peripheral
1244# -----------------------------------------------------------------------------
class Peripheral(Device.Listener, Connection.Listener):
    """Benchmark endpoint that waits for the central to connect.

    `run` opens the HCI transport, makes the device connectable (Classic) or
    starts advertising (LE), waits for a connection, then hands it to the mode
    object and runs the role object.
    """

    def __init__(
        self,
        transport,
        role_factory,
        mode_factory,
        classic,
        extended_data_length,
        role_switch,
    ):
        """Store the settings used by `run`.

        Args:
            transport: transport spec passed to `open_transport_or_link`.
            role_factory: callable that takes the mode object and returns the
                role object (which must provide `reset` and an async `run`).
            mode_factory: callable that takes the device and returns the mode
                object (which must provide an async `on_connection`).
            classic: truthy to accept BR/EDR connections, otherwise LE.
            extended_data_length: falsy, or a sequence whose first two items
                are passed to `Connection.set_data_length` (LE only).
            role_switch: when truthy, passed to `switch_roles`.
        """
        super().__init__()
        self.transport = transport
        self.classic = classic
        self.role_factory = role_factory
        self.mode_factory = mode_factory
        self.extended_data_length = extended_data_length
        self.role_switch = role_switch
        self.role = None
        self.mode = None
        self.device = None
        self.connection = None
        self.connected = asyncio.Event()

    async def run(self):
        """Wait for a connection and run the benchmark role."""
        logging.info(color('>>> Connecting to HCI...', 'green'))
        async with await open_transport_or_link(self.transport) as (
            hci_source,
            hci_sink,
        ):
            logging.info(color('>>> Connected', 'green'))

            peripheral_address = DEFAULT_PERIPHERAL_ADDRESS
            self.device = Device.with_hci(
                DEFAULT_PERIPHERAL_NAME, peripheral_address, hci_source, hci_sink
            )
            self.device.listener = self
            self.mode = self.mode_factory(self.device)
            self.role = self.role_factory(self.mode)
            self.device.classic_enabled = self.classic

            # Set up a pairing config factory with minimal requirements.
            self.device.pairing_config_factory = lambda _: PairingConfig(
                sc=False, mitm=False, bonding=False
            )

            await self.device.power_on()

            if self.classic:
                await self.device.set_discoverable(True)
                await self.device.set_connectable(True)
            else:
                await self.device.start_advertising(auto_restart=True)

            if self.classic:
                logging.info(
                    color(
                        '### Waiting for connection on'
                        f' {self.device.public_address}...',
                        'cyan',
                    )
                )
            else:
                logging.info(
                    color(
                        f'### Waiting for connection on {peripheral_address}...',
                        'cyan',
                    )
                )

            # Set by on_connection() when a peer connects
            await self.connected.wait()
            logging.info(color('### Connected', 'cyan'))
            print_connection(self.connection)

            await self.mode.on_connection(self.connection)
            await self.role.run()
            await asyncio.sleep(DEFAULT_LINGER_TIME)

    def on_connection(self, connection):
        """Record the new connection and kick off post-connection requests."""
        connection.listener = self
        self.connection = connection
        self.connected.set()

        # Stop being discoverable and connectable
        if self.classic:
            AsyncRunner.spawn(self.device.set_discoverable(False))
            AsyncRunner.spawn(self.device.set_connectable(False))

        # Request a new data length if needed
        if not self.classic and self.extended_data_length:
            logging.info(color('+++ Requesting extended data length', 'cyan'))
            AsyncRunner.spawn(
                connection.set_data_length(
                    self.extended_data_length[0], self.extended_data_length[1]
                )
            )

        # Switch roles if needed.
        if self.role_switch:
            AsyncRunner.spawn(switch_roles(connection, self.role_switch))

    def on_disconnection(self, reason):
        """Log the disconnection, reset the role and accept connections again."""
        logging.info(color(f'!!! Disconnection: reason={reason}', 'red'))
        self.connection = None
        self.role.reset()

        if self.classic:
            AsyncRunner.spawn(self.device.set_discoverable(True))
            AsyncRunner.spawn(self.device.set_connectable(True))

    def on_connection_parameters_update(self):
        """Print the connection after a connection parameters update."""
        print_connection(self.connection)

    def on_connection_phy_update(self):
        """Print the connection after a PHY update."""
        print_connection(self.connection)

    def on_connection_att_mtu_update(self):
        """Print the connection after an ATT MTU update."""
        print_connection(self.connection)

    def on_connection_data_length_change(self):
        """Print the connection after a data length change."""
        print_connection(self.connection)

    def on_role_change(self):
        """Print the connection after a role change."""
        print_connection(self.connection)
1367
1368
1369# -----------------------------------------------------------------------------
def create_mode_factory(ctx, default_mode):
    """Return a function that builds the mode (packet I/O) object for a device.

    The mode name is taken from `ctx.obj['mode']`, falling back to
    `default_mode` when that is None. Other options are read from `ctx.obj`
    only when the returned function is called.

    The returned function raises ValueError('invalid mode') for an unknown
    mode name.
    """
    selected = ctx.obj['mode']
    if selected is None:
        selected = default_mode

    def l2cap_options():
        # Settings shared by the L2CAP client and server
        return {
            'psm': ctx.obj['l2cap_psm'],
            'mtu': ctx.obj['l2cap_mtu'],
            'mps': ctx.obj['l2cap_mps'],
            'max_credits': ctx.obj['l2cap_max_credits'],
        }

    def rfcomm_tuning():
        # Settings shared by the RFComm client and server (channel excluded)
        return {
            'l2cap_mtu': ctx.obj['rfcomm_l2cap_mtu'],
            'max_frame_size': ctx.obj['rfcomm_max_frame_size'],
            'initial_credits': ctx.obj['rfcomm_initial_credits'],
            'max_credits': ctx.obj['rfcomm_max_credits'],
            'credits_threshold': ctx.obj['rfcomm_credits_threshold'],
        }

    def create_mode(device):
        # One lazy builder per mode name; only the selected one is evaluated
        builders = {
            'gatt-client': lambda: GattClient(device, att_mtu=ctx.obj['att_mtu']),
            'gatt-server': lambda: GattServer(device),
            'l2cap-client': lambda: L2capClient(device, **l2cap_options()),
            'l2cap-server': lambda: L2capServer(device, **l2cap_options()),
            'rfcomm-client': lambda: RfcommClient(
                device,
                channel=ctx.obj['rfcomm_channel'],
                uuid=ctx.obj['rfcomm_uuid'],
                **rfcomm_tuning(),
            ),
            'rfcomm-server': lambda: RfcommServer(
                device,
                channel=ctx.obj['rfcomm_channel'],
                **rfcomm_tuning(),
            ),
        }

        build = builders.get(selected)
        if build is None:
            raise ValueError('invalid mode')
        return build()

    return create_mode
1426
1427
1428# -----------------------------------------------------------------------------
def create_role_factory(ctx, default_role):
    """Return a function that builds the role object for a packet I/O object.

    The role name is taken from `ctx.obj['role']`, falling back to
    `default_role` when that is None. Other options are read from `ctx.obj`
    only when the returned function is called.

    The returned function raises ValueError('invalid role') for an unknown
    role name.
    """
    selected = ctx.obj['role']
    if selected is None:
        selected = default_role

    # Options taken by the roles that generate traffic (sender and ping)
    traffic_option_names = (
        'start_delay',
        'repeat',
        'repeat_delay',
        'pace',
        'packet_size',
        'packet_count',
    )

    def create_role(packet_io):
        if selected in ('sender', 'ping'):
            role_class = Sender if selected == 'sender' else Ping
            options = {name: ctx.obj[name] for name in traffic_option_names}
            return role_class(packet_io, **options)

        if selected in ('receiver', 'pong'):
            role_class = Receiver if selected == 'receiver' else Pong
            return role_class(packet_io, ctx.obj['linger'])

        raise ValueError('invalid role')

    return create_role
1466
1467
1468# -----------------------------------------------------------------------------
1469# Main
1470# -----------------------------------------------------------------------------
@click.group()
@click.option('--device-config', metavar='FILENAME', help='Device configuration file')
@click.option('--role', type=click.Choice(['sender', 'receiver', 'ping', 'pong']))
@click.option(
    '--mode',
    type=click.Choice(
        [
            'gatt-client',
            'gatt-server',
            'l2cap-client',
            'l2cap-server',
            'rfcomm-client',
            'rfcomm-server',
        ]
    ),
)
@click.option(
    '--att-mtu',
    metavar='MTU',
    type=click.IntRange(23, 517),
    help='GATT MTU (gatt-client mode)',
)
@click.option(
    '--extended-data-length',
    help='Request a data length upon connection, specified as tx_octets/tx_time',
)
@click.option(
    '--role-switch',
    type=click.Choice(['central', 'peripheral']),
    help='Request role switch upon connection (central or peripheral)',
)
@click.option(
    '--rfcomm-channel',
    type=int,
    default=DEFAULT_RFCOMM_CHANNEL,
    help='RFComm channel to use',
)
@click.option(
    '--rfcomm-uuid',
    default=DEFAULT_RFCOMM_UUID,
    help='RFComm service UUID to use (ignored if --rfcomm-channel is not 0)',
)
@click.option(
    '--rfcomm-l2cap-mtu',
    type=int,
    help='RFComm L2CAP MTU',
)
@click.option(
    '--rfcomm-max-frame-size',
    type=int,
    help='RFComm maximum frame size',
)
@click.option(
    '--rfcomm-initial-credits',
    type=int,
    help='RFComm initial credits',
)
@click.option(
    '--rfcomm-max-credits',
    type=int,
    help='RFComm max credits',
)
@click.option(
    '--rfcomm-credits-threshold',
    type=int,
    help='RFComm credits threshold',
)
@click.option(
    '--l2cap-psm',
    type=int,
    default=DEFAULT_L2CAP_PSM,
    help='L2CAP PSM to use',
)
@click.option(
    '--l2cap-mtu',
    type=int,
    default=DEFAULT_L2CAP_MTU,
    help='L2CAP MTU to use',
)
@click.option(
    '--l2cap-mps',
    type=int,
    default=DEFAULT_L2CAP_MPS,
    help='L2CAP MPS to use',
)
@click.option(
    '--l2cap-max-credits',
    type=int,
    default=DEFAULT_L2CAP_MAX_CREDITS,
    help='L2CAP maximum number of credits allowed for the peer',
)
@click.option(
    '--packet-size',
    '-s',
    metavar='SIZE',
    type=click.IntRange(8, 8192),
    default=500,
    help='Packet size (client or ping role)',
)
@click.option(
    '--packet-count',
    '-c',
    metavar='COUNT',
    type=int,
    default=10,
    help='Packet count (client or ping role)',
)
@click.option(
    '--start-delay',
    '-sd',
    metavar='SECONDS',
    type=int,
    default=1,
    help='Start delay (client or ping role)',
)
@click.option(
    '--repeat',
    metavar='N',
    type=int,
    default=0,
    help=(
        'Repeat the run N times (client and ping roles) '
        '(0, which is the default, to run just once)'
    ),
)
@click.option(
    '--repeat-delay',
    metavar='SECONDS',
    type=int,
    default=1,
    help=('Delay, in seconds, between repeats'),
)
@click.option(
    '--pace',
    metavar='MILLISECONDS',
    type=int,
    default=0,
    help=(
        'Wait N milliseconds between packets '
        '(0, which is the default, to send as fast as possible)'
    ),
)
@click.option(
    '--linger',
    is_flag=True,
    help="Don't exit at the end of a run (server and pong roles)",
)
@click.pass_context
def bench(
    ctx,
    device_config,
    role,
    mode,
    att_mtu,
    extended_data_length,
    role_switch,
    packet_size,
    packet_count,
    start_delay,
    repeat,
    repeat_delay,
    pace,
    linger,
    rfcomm_channel,
    rfcomm_uuid,
    rfcomm_l2cap_mtu,
    rfcomm_max_frame_size,
    rfcomm_initial_credits,
    rfcomm_max_credits,
    rfcomm_credits_threshold,
    l2cap_psm,
    l2cap_mtu,
    l2cap_mps,
    l2cap_max_credits,
):
    """Run a benchmark as a sender/receiver or ping/pong over GATT, L2CAP or RFComm"""
    # NOTE: the docstring above is shown by click as the group's help text.
    # Every option is stored in ctx.obj so the sub-commands and the mode/role
    # factories can read it.
    ctx.ensure_object(dict)
    ctx.obj['device_config'] = device_config
    ctx.obj['role'] = role
    ctx.obj['mode'] = mode
    ctx.obj['att_mtu'] = att_mtu
    ctx.obj['rfcomm_channel'] = rfcomm_channel
    ctx.obj['rfcomm_uuid'] = rfcomm_uuid
    ctx.obj['rfcomm_l2cap_mtu'] = rfcomm_l2cap_mtu
    ctx.obj['rfcomm_max_frame_size'] = rfcomm_max_frame_size
    ctx.obj['rfcomm_initial_credits'] = rfcomm_initial_credits
    ctx.obj['rfcomm_max_credits'] = rfcomm_max_credits
    ctx.obj['rfcomm_credits_threshold'] = rfcomm_credits_threshold
    ctx.obj['l2cap_psm'] = l2cap_psm
    ctx.obj['l2cap_mtu'] = l2cap_mtu
    ctx.obj['l2cap_mps'] = l2cap_mps
    ctx.obj['l2cap_max_credits'] = l2cap_max_credits
    ctx.obj['packet_size'] = packet_size
    ctx.obj['packet_count'] = packet_count
    ctx.obj['start_delay'] = start_delay
    ctx.obj['repeat'] = repeat
    ctx.obj['repeat_delay'] = repeat_delay
    ctx.obj['pace'] = pace
    ctx.obj['linger'] = linger

    # Parse "tx_octets/tx_time" up front so that a malformed value is reported
    # as a usage error rather than failing later, once a connection is up.
    if extended_data_length:
        try:
            tx_octets, tx_time = (int(x) for x in extended_data_length.split('/'))
        except ValueError as error:
            raise click.BadParameter(
                f'{extended_data_length!r} is not in the form tx_octets/tx_time',
                param_hint='--extended-data-length',
            ) from error
        ctx.obj['extended_data_length'] = [tx_octets, tx_time]
    else:
        ctx.obj['extended_data_length'] = None

    ctx.obj['role_switch'] = role_switch

    # The RFComm modes run over Bluetooth Classic
    ctx.obj['classic'] = mode in ('rfcomm-client', 'rfcomm-server')
1676
1677
@bench.command()
@click.argument('transport')
@click.option(
    '--peripheral',
    'peripheral_address',
    metavar='ADDRESS_OR_NAME',
    default=DEFAULT_PERIPHERAL_ADDRESS,
    help='Address or name to connect to',
)
@click.option(
    '--connection-interval',
    '--ci',
    metavar='CONNECTION_INTERVAL',
    type=int,
    help='Connection interval (in ms)',
)
@click.option('--phy', type=click.Choice(['1m', '2m', 'coded']), help='PHY to use')
@click.option('--authenticate', is_flag=True, help='Authenticate (RFComm only)')
@click.option('--encrypt', is_flag=True, help='Encrypt the connection (RFComm only)')
@click.pass_context
def central(
    ctx, transport, peripheral_address, connection_interval, phy, authenticate, encrypt
):
    """Run as a central (initiates the connection)"""
    # Defaults for a central: act as the sender, over a GATT client
    role_factory = create_role_factory(ctx, 'sender')
    mode_factory = create_mode_factory(ctx, 'gatt-client')
    classic = ctx.obj['classic']

    async def run_central():
        # Built inside the coroutine, so construction happens with the event
        # loop already running.
        runner = Central(
            transport=transport,
            peripheral_address=peripheral_address,
            classic=classic,
            role_factory=role_factory,
            mode_factory=mode_factory,
            connection_interval=connection_interval,
            phy=phy,
            authenticate=authenticate,
            # Authentication implies encryption
            encrypt=encrypt or authenticate,
            extended_data_length=ctx.obj['extended_data_length'],
            role_switch=ctx.obj['role_switch'],
        )
        await runner.run()

    asyncio.run(run_central())
1722
1723
@bench.command()
@click.argument('transport')
@click.pass_context
def peripheral(ctx, transport):
    """Run as a peripheral (waits for a connection)"""
    # Defaults for a peripheral: act as the receiver, over a GATT server
    role_factory = create_role_factory(ctx, 'receiver')
    mode_factory = create_mode_factory(ctx, 'gatt-server')

    async def run_peripheral():
        # Built inside the coroutine, so the asyncio.Event created by the
        # constructor is made with the event loop already running.
        runner = Peripheral(
            transport=transport,
            role_factory=role_factory,
            mode_factory=mode_factory,
            classic=ctx.obj['classic'],
            extended_data_length=ctx.obj['extended_data_length'],
            role_switch=ctx.obj['role_switch'],
        )
        await runner.run()

    asyncio.run(run_peripheral())
1743
1744
def main():
    """Configure logging from the environment and run the command-line tool.

    The log level is read from the BUMBLE_LOGLEVEL environment variable and
    defaults to INFO.
    """
    log_level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()
    logging.basicConfig(level=log_level)
    bench()
1748
1749
1750# -----------------------------------------------------------------------------
# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter
1753