# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
# Standard library.
import asyncio
import functools
import logging
import os

# NOTE(review): LambdaType does not appear to be used in this module; confirm
# before removing it.
from types import LambdaType
from unittest import mock

# Third-party.
import pytest

# Project.
from bumble.core import (
    BT_BR_EDR_TRANSPORT,
    BT_LE_TRANSPORT,
    BT_PERIPHERAL_ROLE,
    ConnectionParameters,
)
from bumble.device import AdvertisingParameters, Connection, Device
from bumble.host import AclPacketQueue, Host
from bumble.hci import (
    HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
    HCI_COMMAND_STATUS_PENDING,
    HCI_CREATE_CONNECTION_COMMAND,
    HCI_SUCCESS,
    HCI_CONNECTION_FAILED_TO_BE_ESTABLISHED_ERROR,
    Address,
    OwnAddressType,
    HCI_Command_Complete_Event,
    HCI_Command_Status_Event,
    HCI_Connection_Complete_Event,
    HCI_Connection_Request_Event,
    HCI_Error,
    HCI_Packet,
)
from bumble.gatt import (
    GATT_GENERIC_ACCESS_SERVICE,
    GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
    GATT_DEVICE_NAME_CHARACTERISTIC,
    GATT_APPEARANCE_CHARACTERISTIC,
)

# Local test helpers.
from .test_utils import TwoDevices, async_barrier

# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# Timeout, in seconds, passed to asyncio.wait_for() by tests that expect an
# operation to fail promptly rather than hang.
_TIMEOUT = 0.1

# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
class Sink:
    """Packet sink that feeds every outgoing HCI packet into a generator flow.

    The flow generator is advanced once with ``next()`` on construction, so it
    is suspended at its first ``yield`` and ready to receive packets via
    ``send()``.
    """

    def __init__(self, flow):
        self.flow = flow
        # Prime the generator so the first send() delivers the first packet.
        next(self.flow)

    def on_packet(self, packet):
        # Resume the flow with the raw packet bytes.
        self.flow.send(packet)


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_device_connect_parallel():
    """d0 connects to d1 and d2 concurrently over BR/EDR while both accept.

    The HCI traffic of each device is scripted by a generator "flow". Each flow
    checks the command the host sends and then injects the HCI events that
    drive the connection forward on the relevant hosts.
    """
    d0 = Device(host=Host(None, None))
    d1 = Device(host=Host(None, None))
    d2 = Device(host=Host(None, None))

    def _send(packet):
        # Outgoing ACL data is discarded; this test only scripts HCI commands.
        pass

    d0.host.acl_packet_queue = AclPacketQueue(0, 0, _send)
    d1.host.acl_packet_queue = AclPacketQueue(0, 0, _send)
    d2.host.acl_packet_queue = AclPacketQueue(0, 0, _send)

    # enable classic
    d0.classic_enabled = True
    d1.classic_enabled = True
    d2.classic_enabled = True

    # set public addresses
    d0.public_address = Address(
        'F0:F1:F2:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS
    )
    d1.public_address = Address(
        'F5:F4:F3:F2:F1:F0', address_type=Address.PUBLIC_DEVICE_ADDRESS
    )
    d2.public_address = Address(
        'F5:F4:F3:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS
    )

    def d0_flow():
        # First command from d0: create a connection to d1.
        packet = HCI_Packet.from_bytes((yield))
        assert packet.name == 'HCI_CREATE_CONNECTION_COMMAND'
        assert packet.bd_addr == d1.public_address

        d0.host.on_hci_packet(
            HCI_Command_Status_Event(
                status=HCI_COMMAND_STATUS_PENDING,
                num_hci_command_packets=1,
                command_opcode=HCI_CREATE_CONNECTION_COMMAND,
            )
        )

        # Deliver the matching incoming connection request to d1.
        d1.host.on_hci_packet(
            HCI_Connection_Request_Event(
                bd_addr=d0.public_address,
                class_of_device=0,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
            )
        )

        # Second command from d0: create a connection to d2.
        packet = HCI_Packet.from_bytes((yield))
        assert packet.name == 'HCI_CREATE_CONNECTION_COMMAND'
        assert packet.bd_addr == d2.public_address

        d0.host.on_hci_packet(
            HCI_Command_Status_Event(
                status=HCI_COMMAND_STATUS_PENDING,
                num_hci_command_packets=1,
                command_opcode=HCI_CREATE_CONNECTION_COMMAND,
            )
        )

        # Deliver the matching incoming connection request to d2.
        d2.host.on_hci_packet(
            HCI_Connection_Request_Event(
                bd_addr=d0.public_address,
                class_of_device=0,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
            )
        )

        # Any further packet sent to this flow fails the assertion.
        assert (yield) is None

    def d1_flow():
        packet = HCI_Packet.from_bytes((yield))
        assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'

        d1.host.on_hci_packet(
            HCI_Command_Complete_Event(
                num_hci_command_packets=1,
                command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
                return_parameters=b"\x00",
            )
        )

        # Complete the d0 <-> d1 link on both sides with handle 0x100.
        d1.host.on_hci_packet(
            HCI_Connection_Complete_Event(
                status=HCI_SUCCESS,
                connection_handle=0x100,
                bd_addr=d0.public_address,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
                encryption_enabled=True,
            )
        )

        d0.host.on_hci_packet(
            HCI_Connection_Complete_Event(
                status=HCI_SUCCESS,
                connection_handle=0x100,
                bd_addr=d1.public_address,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
                encryption_enabled=True,
            )
        )

        # Any further packet sent to this flow fails the assertion.
        assert (yield) is None

    def d2_flow():
        packet = HCI_Packet.from_bytes((yield))
        assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'

        d2.host.on_hci_packet(
            HCI_Command_Complete_Event(
                num_hci_command_packets=1,
                command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
                return_parameters=b"\x00",
            )
        )

        # Complete the d0 <-> d2 link on both sides with handle 0x101.
        d2.host.on_hci_packet(
            HCI_Connection_Complete_Event(
                status=HCI_SUCCESS,
                connection_handle=0x101,
                bd_addr=d0.public_address,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
                encryption_enabled=True,
            )
        )

        d0.host.on_hci_packet(
            HCI_Connection_Complete_Event(
                status=HCI_SUCCESS,
                connection_handle=0x101,
                bd_addr=d2.public_address,
                link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
                encryption_enabled=True,
            )
        )

        # Any further packet sent to this flow fails the assertion.
        assert (yield) is None

    d0.host.set_packet_sink(Sink(d0_flow()))
    d1.host.set_packet_sink(Sink(d1_flow()))
    d2.host.set_packet_sink(Sink(d2_flow()))

    d1_accept_task = asyncio.create_task(d1.accept(peer_address=d0.public_address))
    d2_accept_task = asyncio.create_task(d2.accept())

    # Ensure that the accept tasks have started.
    await async_barrier()

    [c01, c02, a10, a20] = await asyncio.gather(
        *[
            asyncio.create_task(
                d0.connect(d1.public_address, transport=BT_BR_EDR_TRANSPORT)
            ),
            asyncio.create_task(
                d0.connect(d2.public_address, transport=BT_BR_EDR_TRANSPORT)
            ),
            d1_accept_task,
            d2_accept_task,
        ]
    )

    assert isinstance(c01, Connection)
    assert isinstance(c02, Connection)
    assert isinstance(a10, Connection)
    assert isinstance(a20, Connection)

    # Each initiator/acceptor pair must agree on the handle the flows injected.
    assert c01.handle == a10.handle and c01.handle == 0x100
    assert c02.handle == a20.handle and c02.handle == 0x101


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_flush():
    """Flushing the host cancels tasks registered with abort_on('flush')."""
    d0 = Device(host=Host(None, None))
    task = d0.abort_on('flush', asyncio.sleep(10000))
    await d0.host.flush()
    with pytest.raises(asyncio.CancelledError):
        await task


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_legacy_advertising():
    """Legacy advertising can be started and stopped against a mocked host."""
    device = Device(host=mock.AsyncMock(Host))

    # Start advertising
    await device.start_advertising()
    assert device.is_advertising

    # Stop advertising
    await device.stop_advertising()
    assert not device.is_advertising

# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
    'auto_restart,',
    (True, False),
)
@pytest.mark.asyncio
async def test_legacy_advertising_disconnection(auto_restart):
    """After a disconnection, advertising is active iff auto_restart was set."""
    host_mock = mock.AsyncMock(spec=Host)
    device = Device(host=host_mock)
    peer = Address('F0:F1:F2:F3:F4:F5')
    await device.start_advertising(auto_restart=auto_restart)

    # Simulate an LE connection on handle 0x0001 where we are the peripheral.
    device.on_connection(
        0x0001,
        BT_LE_TRANSPORT,
        peer,
        None,
        None,
        BT_PERIPHERAL_ROLE,
        ConnectionParameters(0, 0, 0),
    )

    # The legacy advertising set terminates because of that connection.
    legacy_handle = device.legacy_advertising_set.advertising_handle
    device.on_advertising_set_termination(HCI_SUCCESS, legacy_handle, 0x0001, 0)

    device.on_disconnection(0x0001, 0)

    # Yield to the loop twice so work scheduled by the disconnection can run
    # (presumably two rounds are needed for the restart path — confirm).
    for _ in range(2):
        await async_barrier()

    assert bool(device.is_advertising) == auto_restart


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_extended_advertising():
    """An extended advertising set is enabled on creation and disabled on stop."""
    device = Device(host=mock.AsyncMock(Host))

    # Creating the set registers it on the device and enables it.
    adv_set = await device.create_advertising_set()
    assert device.extended_advertising_sets
    assert adv_set.enabled

    # Stopping the set disables it.
    await adv_set.stop()
    assert not adv_set.enabled


def _expected_self_address(device, own_address_type):
    """Return the device address matching own_address_type (public or random)."""
    return (
        device.public_address
        if own_address_type == OwnAddressType.PUBLIC
        else device.random_address
    )


# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
    'own_address_type,',
    (OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
)
@pytest.mark.asyncio
async def test_extended_advertising_connection(own_address_type):
    """Connection reported before set termination gets the right self address."""
    device = Device(host=mock.AsyncMock(spec=Host))
    peer = Address('F0:F1:F2:F3:F4:F5')
    params = AdvertisingParameters(own_address_type=own_address_type)
    adv_set = await device.create_advertising_set(advertising_parameters=params)

    # Connection event first ...
    device.on_connection(
        0x0001,
        BT_LE_TRANSPORT,
        peer,
        None,
        None,
        BT_PERIPHERAL_ROLE,
        ConnectionParameters(0, 0, 0),
    )
    # ... then the advertising set termination for the same handle.
    device.on_advertising_set_termination(
        HCI_SUCCESS,
        adv_set.advertising_handle,
        0x0001,
        0,
    )

    connection = device.lookup_connection(0x0001)
    assert connection.self_address == _expected_self_address(
        device, own_address_type
    )

    await async_barrier()


# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
    'own_address_type,',
    (OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
)
@pytest.mark.asyncio
async def test_extended_advertising_connection_out_of_order(own_address_type):
    """Set termination reported before the connection still yields the right address."""
    device = Device(host=mock.AsyncMock(spec=Host))
    peer = Address('F0:F1:F2:F3:F4:F5')
    params = AdvertisingParameters(own_address_type=own_address_type)
    adv_set = await device.create_advertising_set(advertising_parameters=params)

    # Advertising set termination first ...
    device.on_advertising_set_termination(
        HCI_SUCCESS,
        adv_set.advertising_handle,
        0x0001,
        0,
    )
    # ... then the connection event for the same handle.
    device.on_connection(
        0x0001,
        BT_LE_TRANSPORT,
        peer,
        None,
        None,
        BT_PERIPHERAL_ROLE,
        ConnectionParameters(0, 0, 0),
    )

    connection = device.lookup_connection(0x0001)
    assert connection.self_address == _expected_self_address(
        device, own_address_type
    )

    await async_barrier()


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_remote_le_features():
    """Reading remote LE features over a live connection returns a value."""
    pair = TwoDevices()
    await pair.setup_connection()

    features = await pair.connections[0].get_remote_le_features()
    assert features is not None


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_remote_le_features_failed():
    """A 'le_remote_features_failure' emission surfaces as HCI_Error."""
    pair = TwoDevices()
    await pair.setup_connection()

    def fail_remote_features(event):
        # Instead of the normal completion handling, report a failure for
        # the event's connection handle.
        pair[0].host.emit(
            'le_remote_features_failure',
            event.connection_handle,
            HCI_CONNECTION_FAILED_TO_BE_ESTABLISHED_ERROR,
        )

    pair[0].host.on_hci_le_read_remote_features_complete_event = (
        fail_remote_features
    )

    with pytest.raises(HCI_Error):
        await asyncio.wait_for(
            pair.connections[0].get_remote_le_features(), _TIMEOUT
        )


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_cis():
    """Two CISes can be set up, established on both sides and disconnected."""
    pair = TwoDevices()
    await pair.setup_connection()

    # CIS handle -> future resolved when device 1 reports the CIS established.
    peripheral_cis_futures = {}

    def on_cis_request(
        acl_connection: Connection,
        cis_handle: int,
        _cig_id: int,
        _cis_id: int,
    ):
        # Accept the request, aborting the accept if the ACL link drops.
        acl_connection.abort_on(
            'disconnection', pair[1].accept_cis_request(cis_handle)
        )
        peripheral_cis_futures[cis_handle] = asyncio.get_running_loop().create_future()

    def on_cis_establishment(cis_link):
        peripheral_cis_futures[cis_link.handle].set_result(None)

    pair[1].on('cis_request', on_cis_request)
    pair[1].on('cis_establishment', on_cis_establishment)

    handles = await pair[0].setup_cig(
        cig_id=1,
        cis_id=[2, 3],
        sdu_interval=(0, 0),
        framing=0,
        max_sdu=(0, 0),
        retransmission_number=0,
        max_transport_latency=(0, 0),
    )
    assert len(handles) == 2

    acl_handle = pair.connections[0].handle
    links = await pair[0].create_cis(
        [
            (handles[0], acl_handle),
            (handles[1], acl_handle),
        ]
    )
    await asyncio.gather(*peripheral_cis_futures.values())
    assert len(links) == 2

    for link in links:
        await link.disconnect()


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_cis_setup_failure():
    """A failed CIS establishment raises HCI_Error on both accept and create."""
    devices = TwoDevices()
    await devices.setup_connection()

    # Handles of CIS requests received by device 1.
    cis_requests = asyncio.Queue()

    def on_cis_request(
        acl_connection: Connection,
        cis_handle: int,
        cig_id: int,
        cis_id: int,
    ):
        del acl_connection, cig_id, cis_id
        cis_requests.put_nowait(cis_handle)

    devices[1].on('cis_request', on_cis_request)

    cis_handles = await devices[0].setup_cig(
        cig_id=1,
        cis_id=[2],
        sdu_interval=(0, 0),
        framing=0,
        max_sdu=(0, 0),
        retransmission_number=0,
        max_transport_latency=(0, 0),
    )
    assert len(cis_handles) == 1

    cis_create_task = asyncio.create_task(
        devices[0].create_cis(
            [
                (cis_handles[0], devices.connections[0].handle),
            ]
        )
    )

    def on_hci_le_cis_established_event(host, event):
        # Report every CIS establishment as a failure for that handle.
        host.emit(
            'cis_establishment_failure',
            event.connection_handle,
            HCI_CONNECTION_FAILED_TO_BE_ESTABLISHED_ERROR,
        )

    # Installed on both hosts before the next await, so it is in place before
    # the create_cis task gets a chance to run.
    for device in devices:
        device.host.on_hci_le_cis_established_event = functools.partial(
            on_hci_le_cis_established_event, device.host
        )

    cis_request = await asyncio.wait_for(cis_requests.get(), _TIMEOUT)

    with pytest.raises(HCI_Error):
        await asyncio.wait_for(devices[1].accept_cis_request(cis_request), _TIMEOUT)

    with pytest.raises(HCI_Error):
        await asyncio.wait_for(cis_create_task, _TIMEOUT)


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_power_on_default_static_address_should_not_be_any():
    """Powering on replaces an ANY_RANDOM static address with a real one."""
    devices = TwoDevices()
    devices[0].static_address = devices[0].random_address = Address.ANY_RANDOM
    await devices[0].power_on()

    assert devices[0].static_address != Address.ANY_RANDOM


# -----------------------------------------------------------------------------
def test_gatt_services_with_gas():
    """By default the GATT server exposes only the Generic Access Service."""
    device = Device(host=Host(None, None))

    # One service and two characteristics, therefore 5 attributes: the service
    # declaration, then a declaration + value for Device Name and Appearance.
    assert len(device.gatt_server.attributes) == 5
    assert device.gatt_server.attributes[0].uuid == GATT_GENERIC_ACCESS_SERVICE
    assert device.gatt_server.attributes[1].type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
    assert device.gatt_server.attributes[2].uuid == GATT_DEVICE_NAME_CHARACTERISTIC
    assert device.gatt_server.attributes[3].type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
    assert device.gatt_server.attributes[4].uuid == GATT_APPEARANCE_CHARACTERISTIC


# -----------------------------------------------------------------------------
def test_gatt_services_without_gas():
    """With generic_access_service=False the GATT server has no attributes."""
    device = Device(host=Host(None, None), generic_access_service=False)

    # there should be no services
    assert len(device.gatt_server.attributes) == 0


# -----------------------------------------------------------------------------
async def run_test_device():
    """Run a subset of this module's tests without pytest.

    The GATT tests are plain synchronous functions; they are called directly
    because awaiting their ``None`` return value raises ``TypeError``.
    """
    await test_device_connect_parallel()
    await test_flush()
    test_gatt_services_with_gas()
    test_gatt_services_without_gas()


# -----------------------------------------------------------------------------
if __name__ == '__main__':
    logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
    asyncio.run(run_test_device())