# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import pytest

from bumble.core import UUID, BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID
from bumble.sdp import (
    DataElement,
    ServiceAttribute,
    Client,
    Server,
    SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
    SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
    SDP_PUBLIC_BROWSE_ROOT,
    SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
    SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
)
from .test_utils import TwoDevices

# -----------------------------------------------------------------------------
# pylint: disable=invalid-name
# -----------------------------------------------------------------------------

# Service record handle and service class UUID registered by sdp_records() and
# looked up by the client-side tests below.
TEST_SERVICE_HANDLE = 0x00010001
TEST_SERVICE_UUID = UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')

# Elements whose serialized form is at least this many bytes are not printed,
# to keep the test output readable (e.g. the 100 KB text string case).
_MAX_PRINTED_SIZE = 500


# -----------------------------------------------------------------------------
def basic_check(x: DataElement) -> None:
    """Assert that a DataElement survives a serialize/parse round trip.

    The element is serialized with bytes(), parsed back with
    DataElement.from_bytes(), and re-serialized. Both the re-serialized bytes
    and the str() of the parsed element must equal those of the original.
    """
    serialized = bytes(x)
    verbose = len(serialized) < _MAX_PRINTED_SIZE
    if verbose:
        print('Original:', x)
        print('Serialized:', serialized.hex())
    parsed = DataElement.from_bytes(serialized)
    if verbose:
        print('Parsed:', parsed)
    parsed_bytes = bytes(parsed)
    if verbose:
        print('Parsed Bytes:', parsed_bytes.hex())
    assert parsed_bytes == serialized
    assert str(x) == str(parsed)


# -----------------------------------------------------------------------------
def test_data_elements() -> None:
    """Round-trip every DataElement type, via the constructor and the helpers."""
    e = DataElement(DataElement.NIL, None)
    basic_check(e)

    e = DataElement(DataElement.UNSIGNED_INTEGER, 12, 1)
    basic_check(e)

    e = DataElement(DataElement.UNSIGNED_INTEGER, 1234, 2)
    basic_check(e)

    e = DataElement(DataElement.UNSIGNED_INTEGER, 0x123456, 4)
    basic_check(e)

    e = DataElement(DataElement.UNSIGNED_INTEGER, 0x123456789, 8)
    basic_check(e)

    e = DataElement(DataElement.UNSIGNED_INTEGER, 0x0000FFFF, value_size=4)
    basic_check(e)

    e = DataElement(DataElement.SIGNED_INTEGER, -12, 1)
    basic_check(e)

    e = DataElement(DataElement.SIGNED_INTEGER, -1234, 2)
    basic_check(e)

    e = DataElement(DataElement.SIGNED_INTEGER, -0x123456, 4)
    basic_check(e)

    e = DataElement(DataElement.SIGNED_INTEGER, -0x123456789, 8)
    basic_check(e)

    e = DataElement(DataElement.SIGNED_INTEGER, 0x0000FFFF, value_size=4)
    basic_check(e)

    e = DataElement(DataElement.UUID, UUID.from_16_bits(1234))
    basic_check(e)

    e = DataElement(DataElement.UUID, UUID.from_32_bits(123456789))
    basic_check(e)

    e = DataElement(DataElement.UUID, UUID('61A3512C-09BE-4DDC-A6A6-0B03667AAFC6'))
    basic_check(e)

    e = DataElement(DataElement.TEXT_STRING, b'hello')
    basic_check(e)

    # Longer strings exercise the larger length-prefix encodings.
    e = DataElement(DataElement.TEXT_STRING, b'hello' * 60)
    basic_check(e)

    e = DataElement(DataElement.TEXT_STRING, b'hello' * 20000)
    basic_check(e)

    e = DataElement(DataElement.BOOLEAN, True)
    basic_check(e)

    e = DataElement(DataElement.BOOLEAN, False)
    basic_check(e)

    e = DataElement(DataElement.SEQUENCE, [DataElement(DataElement.BOOLEAN, True)])
    basic_check(e)

    e = DataElement(
        DataElement.SEQUENCE,
        [
            DataElement(DataElement.BOOLEAN, True),
            DataElement(DataElement.TEXT_STRING, b'hello'),
        ],
    )
    basic_check(e)

    e = DataElement(DataElement.ALTERNATIVE, [DataElement(DataElement.BOOLEAN, True)])
    basic_check(e)

    e = DataElement(
        DataElement.ALTERNATIVE,
        [
            DataElement(DataElement.BOOLEAN, True),
            DataElement(DataElement.TEXT_STRING, b'hello'),
        ],
    )
    basic_check(e)

    # These two were previously constructed but never checked.
    e = DataElement(DataElement.URL, 'http://example.com')
    basic_check(e)

    e = DataElement.nil()
    basic_check(e)

    e = DataElement.unsigned_integer(1234, 2)
    basic_check(e)

    e = DataElement.signed_integer(-1234, 2)
    basic_check(e)

    e = DataElement.uuid(UUID.from_16_bits(1234))
    basic_check(e)

    e = DataElement.text_string(b'hello')
    basic_check(e)

    e = DataElement.boolean(True)
    basic_check(e)

    e = DataElement.sequence(
        [DataElement.signed_integer(0, 1), DataElement.text_string(b'hello')]
    )
    basic_check(e)

    e = DataElement.alternative(
        [DataElement.signed_integer(0, 1), DataElement.text_string(b'hello')]
    )
    basic_check(e)

    e = DataElement.url('http://foobar.com')
    basic_check(e)


# -----------------------------------------------------------------------------
def sdp_records():
    """Return a fresh mapping of one test service record handle to its attributes.

    The single record (TEST_SERVICE_HANDLE) carries the following attributes:
    - its record handle;
    - the public browse root as its browse group;
    - TEST_SERVICE_UUID as its only service class;
    - a protocol descriptor list containing only L2CAP.
    """
    return {
        TEST_SERVICE_HANDLE: [
            ServiceAttribute(
                SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
                DataElement.unsigned_integer_32(TEST_SERVICE_HANDLE),
            ),
            ServiceAttribute(
                SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
                DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
            ),
            ServiceAttribute(
                SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
                DataElement.sequence([DataElement.uuid(TEST_SERVICE_UUID)]),
            ),
            ServiceAttribute(
                SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
                DataElement.sequence(
                    [
                        DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
                    ]
                ),
            ),
        ]
    }


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_service_search() -> None:
    """A service search for TEST_SERVICE_UUID returns the registered handle."""
    # Setup connections
    devices = TwoDevices()
    await devices.setup_connection()
    assert devices.connections[0]
    assert devices.connections[1]

    # Register SDP service
    devices.devices[0].sdp_server.service_records.update(sdp_records())

    # Search for service; the async context closes the client when done.
    client = Client(devices.connections[1])
    async with client:
        services = await client.search_services([TEST_SERVICE_UUID])

    # Then
    assert services[0] == TEST_SERVICE_HANDLE


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_service_attribute() -> None:
    """Fetching the record-handle attribute returns the registered value."""
    # Setup connections
    devices = TwoDevices()
    await devices.setup_connection()

    # Register SDP service
    devices.devices[0].sdp_server.service_records.update(sdp_records())

    # Fetch attributes; the async context closes the client when done.
    client = Client(devices.connections[1])
    async with client:
        attributes = await client.get_attributes(
            TEST_SERVICE_HANDLE, [SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID]
        )

    # Then
    expected = sdp_records()[TEST_SERVICE_HANDLE][0]
    assert attributes[0].value.value == expected.value.value


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_service_search_attribute() -> None:
    """A combined search returns attributes matching the registered record."""
    # Setup connections
    devices = TwoDevices()
    await devices.setup_connection()

    # Register SDP service
    devices.devices[0].sdp_server.service_records.update(sdp_records())

    # Search for service; the async context closes the client when done.
    # NOTE(review): (0x0000FFFF, 8) is forwarded to search_attributes as-is;
    # confirm whether a 4-byte attribute-ID range (0x0000-0xFFFF) was intended,
    # otherwise the request may match nothing and the checks below be vacuous.
    client = Client(devices.connections[1])
    async with client:
        attribute_lists = await client.search_attributes(
            [TEST_SERVICE_UUID], [(0x0000FFFF, 8)]
        )

    # Then: search_attributes yields one attribute list per matching record, so
    # compare each list with the corresponding registered record. Attributes
    # are matched by ID (the server may reorder them), and values are compared
    # in serialized form so the check does not depend on DataElement equality.
    for actual_list, expected_list in zip(attribute_lists, sdp_records().values()):
        expected_by_id = {attribute.id: attribute for attribute in expected_list}
        for actual in actual_list:
            assert actual.id in expected_by_id
            assert bytes(actual.value) == bytes(expected_by_id[actual.id].value)


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_client_async_context() -> None:
    """Entering the client context opens a channel; leaving it clears it."""
    devices = TwoDevices()
    await devices.setup_connection()

    client = Client(devices.connections[1])

    async with client:
        assert client.channel is not None

    assert client.channel is None


# -----------------------------------------------------------------------------
async def run() -> None:
    """Run every test in this module sequentially, outside of pytest."""
    test_data_elements()
    await test_service_attribute()
    await test_service_search()
    await test_service_search_attribute()
    await test_client_async_context()


# -----------------------------------------------------------------------------
if __name__ == '__main__':
    logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
    asyncio.run(run())