1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import asyncio
19import functools
20import logging
21import os
22from types import LambdaType
23import pytest
24from unittest import mock
25
26from bumble.core import (
27    BT_BR_EDR_TRANSPORT,
28    BT_LE_TRANSPORT,
29    BT_PERIPHERAL_ROLE,
30    ConnectionParameters,
31)
32from bumble.device import AdvertisingParameters, Connection, Device
33from bumble.host import AclPacketQueue, Host
34from bumble.hci import (
35    HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
36    HCI_COMMAND_STATUS_PENDING,
37    HCI_CREATE_CONNECTION_COMMAND,
38    HCI_SUCCESS,
39    HCI_CONNECTION_FAILED_TO_BE_ESTABLISHED_ERROR,
40    Address,
41    OwnAddressType,
42    HCI_Command_Complete_Event,
43    HCI_Command_Status_Event,
44    HCI_Connection_Complete_Event,
45    HCI_Connection_Request_Event,
46    HCI_Error,
47    HCI_Packet,
48)
49from bumble.gatt import (
50    GATT_GENERIC_ACCESS_SERVICE,
51    GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
52    GATT_DEVICE_NAME_CHARACTERISTIC,
53    GATT_APPEARANCE_CHARACTERISTIC,
54)
55
56from .test_utils import TwoDevices, async_barrier
57
58# -----------------------------------------------------------------------------
59# Constants
60# -----------------------------------------------------------------------------
# Upper bound, in seconds, for awaits that are expected to resolve promptly.
_TIMEOUT: float = 0.1

# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
67
68
69# -----------------------------------------------------------------------------
class Sink:
    """Packet sink that forwards every packet into a generator-based script.

    The generator (``flow``) is advanced to its first ``yield`` when the sink
    is built. After that, each call to ``on_packet`` resumes it with the packet
    as the value of the pending ``yield`` expression.
    """

    def __init__(self, flow):
        self.flow = flow
        # Prime the script so it is parked on its first ``yield``.
        next(flow)

    def on_packet(self, packet):
        """Resume the script, handing it ``packet``."""
        flow = self.flow
        flow.send(packet)
77
78
79# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_device_connect_parallel():
    """Check that one device can run two BR/EDR connection attempts concurrently.

    ``d0`` connects to ``d1`` and ``d2`` at the same time while both peers are
    in ``accept``. No controller is attached. Each host's outgoing HCI packets
    go to a generator "flow" (through ``Sink``). The flow asserts on the
    expected command and then injects, by hand, the HCI events a controller
    would emit into the appropriate hosts.
    """
    d0 = Device(host=Host(None, None))
    d1 = Device(host=Host(None, None))
    d2 = Device(host=Host(None, None))

    def _send(packet):
        pass

    # Stub ACL packet queues whose send callback discards everything.
    d0.host.acl_packet_queue = AclPacketQueue(0, 0, _send)
    d1.host.acl_packet_queue = AclPacketQueue(0, 0, _send)
    d2.host.acl_packet_queue = AclPacketQueue(0, 0, _send)

    # Enable BR/EDR (Classic) on all three devices.
    d0.classic_enabled = True
    d1.classic_enabled = True
    d2.classic_enabled = True

    # Give each device a distinct public address.
    d0.public_address = Address(
        'F0:F1:F2:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS
    )
    d1.public_address = Address(
        'F5:F4:F3:F2:F1:F0', address_type=Address.PUBLIC_DEVICE_ADDRESS
    )
    d2.public_address = Address(
        'F5:F4:F3:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS
    )

    def d0_flow():
        # First Create Connection command: towards d1.
        packet = HCI_Packet.from_bytes((yield))
        assert packet.name == 'HCI_CREATE_CONNECTION_COMMAND'
        assert packet.bd_addr == d1.public_address

        d0.host.on_hci_packet(
            HCI_Command_Status_Event(
                status=HCI_COMMAND_STATUS_PENDING,
                num_hci_command_packets=1,
                command_opcode=HCI_CREATE_CONNECTION_COMMAND,
            )
        )

        # Simulate d1's controller receiving the incoming connection.
        d1.host.on_hci_packet(
            HCI_Connection_Request_Event(
                bd_addr=d0.public_address,
                class_of_device=0,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
            )
        )

        # Second Create Connection command: towards d2.
        packet = HCI_Packet.from_bytes((yield))
        assert packet.name == 'HCI_CREATE_CONNECTION_COMMAND'
        assert packet.bd_addr == d2.public_address

        d0.host.on_hci_packet(
            HCI_Command_Status_Event(
                status=HCI_COMMAND_STATUS_PENDING,
                num_hci_command_packets=1,
                command_opcode=HCI_CREATE_CONNECTION_COMMAND,
            )
        )

        d2.host.on_hci_packet(
            HCI_Connection_Request_Event(
                bd_addr=d0.public_address,
                class_of_device=0,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
            )
        )

        # Any further packet sent by d0 is unexpected and fails here.
        assert (yield) is None

    def d1_flow():
        packet = HCI_Packet.from_bytes((yield))
        assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'

        d1.host.on_hci_packet(
            HCI_Command_Complete_Event(
                num_hci_command_packets=1,
                command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
                return_parameters=b"\x00",
            )
        )

        # Complete the link on both ends with handle 0x100.
        d1.host.on_hci_packet(
            HCI_Connection_Complete_Event(
                status=HCI_SUCCESS,
                connection_handle=0x100,
                bd_addr=d0.public_address,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
                encryption_enabled=True,
            )
        )

        d0.host.on_hci_packet(
            HCI_Connection_Complete_Event(
                status=HCI_SUCCESS,
                connection_handle=0x100,
                bd_addr=d1.public_address,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
                encryption_enabled=True,
            )
        )

        # Any further packet sent by d1 is unexpected and fails here.
        assert (yield) is None

    def d2_flow():
        packet = HCI_Packet.from_bytes((yield))
        assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'

        d2.host.on_hci_packet(
            HCI_Command_Complete_Event(
                num_hci_command_packets=1,
                command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
                return_parameters=b"\x00",
            )
        )

        # Complete the link on both ends with handle 0x101.
        d2.host.on_hci_packet(
            HCI_Connection_Complete_Event(
                status=HCI_SUCCESS,
                connection_handle=0x101,
                bd_addr=d0.public_address,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
                encryption_enabled=True,
            )
        )

        d0.host.on_hci_packet(
            HCI_Connection_Complete_Event(
                status=HCI_SUCCESS,
                connection_handle=0x101,
                bd_addr=d2.public_address,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
                encryption_enabled=True,
            )
        )

        # Any further packet sent by d2 is unexpected and fails here.
        assert (yield) is None

    d0.host.set_packet_sink(Sink(d0_flow()))
    d1.host.set_packet_sink(Sink(d1_flow()))
    d2.host.set_packet_sink(Sink(d2_flow()))

    d1_accept_task = asyncio.create_task(d1.accept(peer_address=d0.public_address))
    d2_accept_task = asyncio.create_task(d2.accept())

    # Ensure that the accept tasks have started.
    await async_barrier()

    [c01, c02, a10, a20] = await asyncio.gather(
        *[
            asyncio.create_task(
                d0.connect(d1.public_address, transport=BT_BR_EDR_TRANSPORT)
            ),
            asyncio.create_task(
                d0.connect(d2.public_address, transport=BT_BR_EDR_TRANSPORT)
            ),
            d1_accept_task,
            d2_accept_task,
        ]
    )

    for connection in (c01, c02, a10, a20):
        assert isinstance(connection, Connection)

    assert c01.handle == a10.handle == 0x100
    assert c02.handle == a20.handle == 0x101
250
251
252# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_flush():
    """Flushing the host must cancel work registered with ``abort_on('flush')``."""
    d0 = Device(host=Host(None, None))
    task = d0.abort_on('flush', asyncio.sleep(10000))
    await d0.host.flush()
    # pytest.raises, unlike a bare ``assert False``, still fails the test
    # under ``python -O`` when the task completes instead of being cancelled.
    with pytest.raises(asyncio.CancelledError):
        await task
263
264
265# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_legacy_advertising():
    """Legacy advertising can be switched on and then off again."""
    mocked_host = mock.AsyncMock(Host)
    device = Device(host=mocked_host)

    # Turn advertising on.
    await device.start_advertising()
    assert device.is_advertising

    # Turn advertising off.
    await device.stop_advertising()
    assert not device.is_advertising
277
278
279# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
    'auto_restart,',
    (True, False),
)
@pytest.mark.asyncio
async def test_legacy_advertising_disconnection(auto_restart):
    """After a disconnection, legacy advertising resumes iff ``auto_restart``."""
    connection_handle = 0x0001
    device = Device(host=mock.AsyncMock(spec=Host))
    peer = Address('F0:F1:F2:F3:F4:F5')
    await device.start_advertising(auto_restart=auto_restart)

    # An incoming LE connection, with this device as peripheral, ends the
    # legacy advertising set.
    device.on_connection(
        connection_handle,
        BT_LE_TRANSPORT,
        peer,
        None,
        None,
        BT_PERIPHERAL_ROLE,
        ConnectionParameters(0, 0, 0),
    )
    legacy_handle = device.legacy_advertising_set.advertising_handle
    device.on_advertising_set_termination(
        HCI_SUCCESS, legacy_handle, connection_handle, 0
    )

    device.on_disconnection(connection_handle, 0)
    # Give the event loop two turns to process any restart that was scheduled.
    for _ in range(2):
        await async_barrier()

    if not auto_restart:
        assert not device.is_advertising
    else:
        assert device.is_advertising
311
312
313# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_extended_advertising():
    """An extended advertising set is enabled on creation and disabled by stop()."""
    device = Device(host=mock.AsyncMock(Host))

    # Creating the set starts advertising.
    adv_set = await device.create_advertising_set()
    assert device.extended_advertising_sets
    assert adv_set.enabled

    # Stopping the set ends advertising.
    await adv_set.stop()
    assert not adv_set.enabled
326
327
328# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
    'own_address_type,',
    (OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
)
@pytest.mark.asyncio
async def test_extended_advertising_connection(own_address_type):
    """The connection's own address follows the advertising set's address type.

    In this test the connection event arrives before the advertising-set
    termination event.
    """
    device = Device(host=mock.AsyncMock(spec=Host))
    peer = Address('F0:F1:F2:F3:F4:F5')
    params = AdvertisingParameters(own_address_type=own_address_type)
    adv_set = await device.create_advertising_set(advertising_parameters=params)

    device.on_connection(
        0x0001,
        BT_LE_TRANSPORT,
        peer,
        None,
        None,
        BT_PERIPHERAL_ROLE,
        ConnectionParameters(0, 0, 0),
    )
    device.on_advertising_set_termination(
        HCI_SUCCESS,
        adv_set.advertising_handle,
        0x0001,
        0,
    )

    connection = device.lookup_connection(0x0001)
    expected_self_address = (
        device.public_address
        if own_address_type == OwnAddressType.PUBLIC
        else device.random_address
    )
    assert connection.self_address == expected_self_address

    await async_barrier()
362
363
364# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
    'own_address_type,',
    (OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
)
@pytest.mark.asyncio
async def test_extended_advertising_connection_out_of_order(own_address_type):
    """Same check as the in-order variant, with the events reversed.

    Here the advertising-set termination event arrives before the connection
    event. The connection's own address must still follow the advertising
    set's address type.
    """
    device = Device(host=mock.AsyncMock(spec=Host))
    peer = Address('F0:F1:F2:F3:F4:F5')
    params = AdvertisingParameters(own_address_type=own_address_type)
    adv_set = await device.create_advertising_set(advertising_parameters=params)

    device.on_advertising_set_termination(
        HCI_SUCCESS,
        adv_set.advertising_handle,
        0x0001,
        0,
    )
    device.on_connection(
        0x0001,
        BT_LE_TRANSPORT,
        peer,
        None,
        None,
        BT_PERIPHERAL_ROLE,
        ConnectionParameters(0, 0, 0),
    )

    connection = device.lookup_connection(0x0001)
    expected_self_address = (
        device.public_address
        if own_address_type == OwnAddressType.PUBLIC
        else device.random_address
    )
    assert connection.self_address == expected_self_address

    await async_barrier()
398
399
400# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_remote_le_features():
    """Reading the peer's LE features over a live connection yields a value."""
    devices = TwoDevices()
    await devices.setup_connection()

    connection = devices.connections[0]
    remote_features = await connection.get_remote_le_features()
    assert remote_features is not None
407
408
409# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_remote_le_features_failed():
    """A remote-features failure event makes get_remote_le_features raise HCI_Error."""
    devices = TwoDevices()
    await devices.setup_connection()
    host = devices[0].host

    # Replace the completion handler so every read is reported as a failure.
    def report_features_failure(event):
        host.emit(
            'le_remote_features_failure',
            event.connection_handle,
            HCI_CONNECTION_FAILED_TO_BE_ESTABLISHED_ERROR,
        )

    host.on_hci_le_read_remote_features_complete_event = report_features_failure

    pending_read = devices.connections[0].get_remote_le_features()
    with pytest.raises(HCI_Error):
        await asyncio.wait_for(pending_read, _TIMEOUT)
430
431
432# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_cis():
    """Two CIS links can be set up over one ACL connection and then torn down."""
    devices = TwoDevices()
    await devices.setup_connection()
    central, peripheral = devices[0], devices[1]

    # One future per CIS handle, resolved once the peripheral sees it established.
    established = {}

    def on_cis_request(
        acl_connection: Connection,
        cis_handle: int,
        _cig_id: int,
        _cis_id: int,
    ):
        acl_connection.abort_on(
            'disconnection', peripheral.accept_cis_request(cis_handle)
        )
        established[cis_handle] = asyncio.get_running_loop().create_future()

    def on_cis_establishment(cis_link):
        established[cis_link.handle].set_result(None)

    peripheral.on('cis_request', on_cis_request)
    peripheral.on('cis_establishment', on_cis_establishment)

    cis_handles = await central.setup_cig(
        cig_id=1,
        cis_id=[2, 3],
        sdu_interval=(0, 0),
        framing=0,
        max_sdu=(0, 0),
        retransmission_number=0,
        max_transport_latency=(0, 0),
    )
    assert len(cis_handles) == 2

    acl_handle = devices.connections[0].handle
    cis_links = await central.create_cis(
        [(cis_handle, acl_handle) for cis_handle in cis_handles[:2]]
    )
    await asyncio.gather(*established.values())
    assert len(cis_links) == 2

    for link in cis_links[:2]:
        await link.disconnect()
478
479
480# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_cis_setup_failure():
    """A CIS establishment failure surfaces as HCI_Error on both sides."""
    devices = TwoDevices()
    await devices.setup_connection()

    # CIS handles requested by the central, as seen by the peripheral.
    pending_requests = asyncio.Queue()

    def on_cis_request(
        _acl_connection: Connection,
        cis_handle: int,
        _cig_id: int,
        _cis_id: int,
    ):
        pending_requests.put_nowait(cis_handle)

    devices[1].on('cis_request', on_cis_request)

    handles = await devices[0].setup_cig(
        cig_id=1,
        cis_id=[2],
        sdu_interval=(0, 0),
        framing=0,
        max_sdu=(0, 0),
        retransmission_number=0,
        max_transport_latency=(0, 0),
    )
    assert len(handles) == 1

    create_task = asyncio.create_task(
        devices[0].create_cis([(handles[0], devices.connections[0].handle)])
    )

    def failing_cis_established_handler(host):
        """Build a handler that reports every CIS establishment as failed."""

        def handler(event):
            host.emit(
                'cis_establishment_failure',
                event.connection_handle,
                HCI_CONNECTION_FAILED_TO_BE_ESTABLISHED_ERROR,
            )

        return handler

    # Patch both hosts before yielding to the loop, so the real
    # "CIS established" handling never runs.
    for device in devices:
        device.host.on_hci_le_cis_established_event = (
            failing_cis_established_handler(device.host)
        )

    requested_handle = await asyncio.wait_for(pending_requests.get(), _TIMEOUT)

    with pytest.raises(HCI_Error):
        await asyncio.wait_for(
            devices[1].accept_cis_request(requested_handle), _TIMEOUT
        )

    with pytest.raises(HCI_Error):
        await asyncio.wait_for(create_task, _TIMEOUT)
537
538
539# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_power_on_default_static_address_should_not_be_any():
    """Powering on must replace an ANY_RANDOM static address with a real one."""
    devices = TwoDevices()
    device = devices[0]
    # Both addresses start out as the "any" placeholder.
    device.static_address = device.random_address = Address.ANY_RANDOM

    await device.power_on()

    assert device.static_address != Address.ANY_RANDOM
547
548
549# -----------------------------------------------------------------------------
def test_gatt_services_with_gas():
    """By default the GATT server holds exactly the Generic Access Service.

    That is one service declaration followed by two characteristics (Device
    Name and Appearance). Each characteristic is a declaration attribute plus
    a value attribute, for 5 attributes in total.
    """
    device = Device(host=Host(None, None))
    attributes = device.gatt_server.attributes

    # (attribute field to inspect, expected value), in server order.
    expected_layout = [
        ('uuid', GATT_GENERIC_ACCESS_SERVICE),
        ('type', GATT_CHARACTERISTIC_ATTRIBUTE_TYPE),
        ('uuid', GATT_DEVICE_NAME_CHARACTERISTIC),
        ('type', GATT_CHARACTERISTIC_ATTRIBUTE_TYPE),
        ('uuid', GATT_APPEARANCE_CHARACTERISTIC),
    ]

    assert len(attributes) == len(expected_layout)
    for index, (field, value) in enumerate(expected_layout):
        assert getattr(attributes[index], field) == value
560
561
562# -----------------------------------------------------------------------------
def test_gatt_services_without_gas():
    """With the Generic Access Service disabled, the GATT server starts empty."""
    device = Device(host=Host(None, None), generic_access_service=False)
    attributes = device.gatt_server.attributes

    assert len(attributes) == 0
568
569
570# -----------------------------------------------------------------------------
async def run_test_device():
    """Run the argument-free device tests outside pytest (see ``__main__``)."""
    await test_device_connect_parallel()
    await test_flush()
    # These two tests are plain (synchronous) functions that return None.
    # Awaiting their result would raise TypeError, so call them directly.
    test_gatt_services_with_gas()
    test_gatt_services_without_gas()
576
577
578# -----------------------------------------------------------------------------
if __name__ == '__main__':
    # Log verbosity comes from BUMBLE_LOGLEVEL (defaults to INFO).
    log_level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()
    logging.basicConfig(level=log_level)
    asyncio.run(run_test_device())
582