1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18from __future__ import annotations
19import logging
20import struct
21from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING
22from typing_extensions import Self
23
24from . import core, l2cap
25from .colors import color
26from .core import InvalidStateError, InvalidArgumentError, InvalidPacketError
27from .hci import HCI_Object, name_or_number, key_with_value
28
29if TYPE_CHECKING:
30    from .device import Device, Connection
31
32# -----------------------------------------------------------------------------
33# Logging
34# -----------------------------------------------------------------------------
35logger = logging.getLogger(__name__)
36
37
38# -----------------------------------------------------------------------------
39# Constants
40# -----------------------------------------------------------------------------
41# fmt: off
42# pylint: disable=line-too-long
43
SDP_CONTINUATION_WATCHDOG = 64  # Maximum number of continuations we're willing to do

# L2CAP PSM used to reach the SDP server (see Client.connect)
SDP_PSM = 0x0001

# SDP PDU IDs (Vol 3, Part B - 4.2 PROTOCOL DATA UNIT FORMAT)
SDP_ERROR_RESPONSE                    = 0x01
SDP_SERVICE_SEARCH_REQUEST            = 0x02
SDP_SERVICE_SEARCH_RESPONSE           = 0x03
SDP_SERVICE_ATTRIBUTE_REQUEST         = 0x04
SDP_SERVICE_ATTRIBUTE_RESPONSE        = 0x05
SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST  = 0x06
SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE = 0x07

# PDU ID -> name. SDP_PDU.subclass looks class names up in this table to
# find their PDU ID, so every PDU class must have an entry here.
SDP_PDU_NAMES = {
    SDP_ERROR_RESPONSE:                    'SDP_ERROR_RESPONSE',
    SDP_SERVICE_SEARCH_REQUEST:            'SDP_SERVICE_SEARCH_REQUEST',
    SDP_SERVICE_SEARCH_RESPONSE:           'SDP_SERVICE_SEARCH_RESPONSE',
    SDP_SERVICE_ATTRIBUTE_REQUEST:         'SDP_SERVICE_ATTRIBUTE_REQUEST',
    SDP_SERVICE_ATTRIBUTE_RESPONSE:        'SDP_SERVICE_ATTRIBUTE_RESPONSE',
    SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST:  'SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST',
    SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE: 'SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE'
}

# Error codes carried by an SDP_ErrorResponse (Vol 3, Part B - 4.4.1)
SDP_INVALID_SDP_VERSION_ERROR                       = 0x0001
SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR             = 0x0002
SDP_INVALID_REQUEST_SYNTAX_ERROR                    = 0x0003
SDP_INVALID_PDU_SIZE_ERROR                          = 0x0004
SDP_INVALID_CONTINUATION_STATE_ERROR                = 0x0005
SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR = 0x0006

# Error code -> name, used by SDP_PDU.error_name
SDP_ERROR_NAMES = {
    SDP_INVALID_SDP_VERSION_ERROR:                       'SDP_INVALID_SDP_VERSION_ERROR',
    SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR:             'SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR',
    SDP_INVALID_REQUEST_SYNTAX_ERROR:                    'SDP_INVALID_REQUEST_SYNTAX_ERROR',
    SDP_INVALID_PDU_SIZE_ERROR:                          'SDP_INVALID_PDU_SIZE_ERROR',
    SDP_INVALID_CONTINUATION_STATE_ERROR:                'SDP_INVALID_CONTINUATION_STATE_ERROR',
    SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR: 'SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR'
}

# Attribute ID offsets for the human-readable name/description/provider
# attributes; per the spec they are added to a language base attribute ID
# (cf. SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID)
SDP_SERVICE_NAME_ATTRIBUTE_ID_OFFSET        = 0x0000
SDP_SERVICE_DESCRIPTION_ATTRIBUTE_ID_OFFSET = 0x0001
SDP_PROVIDER_NAME_ATTRIBUTE_ID_OFFSET       = 0x0002

# Universal attribute IDs (Vol 3, Part B - 5.1)
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID               = 0X0000
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID               = 0X0001
SDP_SERVICE_RECORD_STATE_ATTRIBUTE_ID                = 0X0002
SDP_SERVICE_ID_ATTRIBUTE_ID                          = 0X0003
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID            = 0X0004
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID                   = 0X0005
SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID     = 0X0006
SDP_SERVICE_INFO_TIME_TO_LIVE_ATTRIBUTE_ID           = 0X0007
SDP_SERVICE_AVAILABILITY_ATTRIBUTE_ID                = 0X0008
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID   = 0X0009
SDP_DOCUMENTATION_URL_ATTRIBUTE_ID                   = 0X000A
SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID               = 0X000B
SDP_ICON_URL_ATTRIBUTE_ID                            = 0X000C
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID = 0X000D


# Profile-specific Attribute Identifiers (cf. Assigned Numbers for Service Discovery)
# used by AVRCP, HFP and A2DP
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID = 0x0311

# Attribute ID -> name, used by ServiceAttribute.id_name
SDP_ATTRIBUTE_ID_NAMES = {
    SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID:               'SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID',
    SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID:               'SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID',
    SDP_SERVICE_RECORD_STATE_ATTRIBUTE_ID:                'SDP_SERVICE_RECORD_STATE_ATTRIBUTE_ID',
    SDP_SERVICE_ID_ATTRIBUTE_ID:                          'SDP_SERVICE_ID_ATTRIBUTE_ID',
    SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID:            'SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID',
    SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID:                   'SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID',
    SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID:     'SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID',
    SDP_SERVICE_INFO_TIME_TO_LIVE_ATTRIBUTE_ID:           'SDP_SERVICE_INFO_TIME_TO_LIVE_ATTRIBUTE_ID',
    SDP_SERVICE_AVAILABILITY_ATTRIBUTE_ID:                'SDP_SERVICE_AVAILABILITY_ATTRIBUTE_ID',
    SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID:   'SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID',
    SDP_DOCUMENTATION_URL_ATTRIBUTE_ID:                   'SDP_DOCUMENTATION_URL_ATTRIBUTE_ID',
    SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID:               'SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID',
    SDP_ICON_URL_ATTRIBUTE_ID:                            'SDP_ICON_URL_ATTRIBUTE_ID',
    SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: 'SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID',
    SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID:                  'SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID',
}

# UUID of the public browse group root (16-bit UUID 0x1002)
SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot')

# To be used in searches where an attribute ID list allows a range to be specified
# (a 32-bit value: range start in the high 16 bits, range end in the low 16 bits)
SDP_ALL_ATTRIBUTES_RANGE = (0x0000FFFF, 4)  # Express this as tuple so we can convey the desired encoding size
128
129# fmt: on
130# pylint: enable=line-too-long
131# pylint: disable=invalid-name
132
133
134# -----------------------------------------------------------------------------
class DataElement:
    '''
    SDP data element: a 5-bit type descriptor, a 3-bit size index and a value.

    See Bluetooth spec @ Vol 3, Part B - 3 DATA REPRESENTATION

    Elements are built with the static constructors below or parsed with
    `from_bytes`, and serialized with `bytes(element)`. The serialized form is
    cached in `self.bytes`. Parsed elements therefore re-serialize to an exact
    replica of their input. Changing `value` after the first serialization has
    no effect on the emitted bytes.
    '''

    NIL = 0
    UNSIGNED_INTEGER = 1
    SIGNED_INTEGER = 2
    UUID = 3
    TEXT_STRING = 4
    BOOLEAN = 5
    SEQUENCE = 6
    ALTERNATIVE = 7
    URL = 8

    TYPE_NAMES = {
        NIL: 'NIL',
        UNSIGNED_INTEGER: 'UNSIGNED_INTEGER',
        SIGNED_INTEGER: 'SIGNED_INTEGER',
        UUID: 'UUID',
        TEXT_STRING: 'TEXT_STRING',
        BOOLEAN: 'BOOLEAN',
        SEQUENCE: 'SEQUENCE',
        ALTERNATIVE: 'ALTERNATIVE',
        URL: 'URL',
    }

    # Factories used by `from_bytes`, keyed by element type. They receive the
    # raw value bytes. The integer factories also receive the value size, so
    # it is preserved for re-serialization.
    type_constructors = {
        NIL: lambda x: DataElement(DataElement.NIL, None),
        UNSIGNED_INTEGER: lambda x, y: DataElement(
            DataElement.UNSIGNED_INTEGER,
            DataElement.unsigned_integer_from_bytes(x),
            value_size=y,
        ),
        SIGNED_INTEGER: lambda x, y: DataElement(
            DataElement.SIGNED_INTEGER,
            DataElement.signed_integer_from_bytes(x),
            value_size=y,
        ),
        UUID: lambda x: DataElement(
            DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))
        ),
        TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x),
        BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1),
        SEQUENCE: lambda x: DataElement(
            DataElement.SEQUENCE, DataElement.list_from_bytes(x)
        ),
        ALTERNATIVE: lambda x: DataElement(
            DataElement.ALTERNATIVE, DataElement.list_from_bytes(x)
        ),
        URL: lambda x: DataElement(DataElement.URL, x.decode('utf8')),
    }

    def __init__(self, element_type, value, value_size=None):
        '''
        Args:
            element_type: one of the type constants above (NIL, UNSIGNED_INTEGER, ...).
            value: the Python value for the element (int, core.UUID, bytes or
                str, bool, or a list of DataElement for SEQUENCE/ALTERNATIVE).
            value_size: size in bytes of an integer value (1, 2, 4, 8 or 16);
                required for UNSIGNED_INTEGER and SIGNED_INTEGER.

        Raises:
            InvalidArgumentError: if an integer type is given no value_size.
        '''
        self.type = element_type
        self.value = value
        self.value_size = value_size
        # Used as a cache when parsing from bytes so we can emit a byte-for-byte replica
        self.bytes = None
        if element_type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
            if value_size is None:
                raise InvalidArgumentError(
                    'integer types must have a value size specified'
                )

    @staticmethod
    def nil() -> DataElement:
        return DataElement(DataElement.NIL, None)

    @staticmethod
    def unsigned_integer(value: int, value_size: int) -> DataElement:
        '''Unsigned integer encoded on `value_size` bytes (1, 2, 4, 8 or 16).'''
        return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size)

    @staticmethod
    def unsigned_integer_8(value: int) -> DataElement:
        return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=1)

    @staticmethod
    def unsigned_integer_16(value: int) -> DataElement:
        return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=2)

    @staticmethod
    def unsigned_integer_32(value: int) -> DataElement:
        return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=4)

    @staticmethod
    def signed_integer(value: int, value_size: int) -> DataElement:
        '''Signed integer encoded on `value_size` bytes (1, 2, 4, 8 or 16).'''
        return DataElement(DataElement.SIGNED_INTEGER, value, value_size)

    @staticmethod
    def signed_integer_8(value: int) -> DataElement:
        return DataElement(DataElement.SIGNED_INTEGER, value, value_size=1)

    @staticmethod
    def signed_integer_16(value: int) -> DataElement:
        return DataElement(DataElement.SIGNED_INTEGER, value, value_size=2)

    @staticmethod
    def signed_integer_32(value: int) -> DataElement:
        return DataElement(DataElement.SIGNED_INTEGER, value, value_size=4)

    @staticmethod
    def uuid(value: core.UUID) -> DataElement:
        return DataElement(DataElement.UUID, value)

    @staticmethod
    def text_string(value: Union[bytes, str]) -> DataElement:
        '''Text string; a `str` value is encoded as UTF-8 when serialized.'''
        return DataElement(DataElement.TEXT_STRING, value)

    @staticmethod
    def boolean(value: bool) -> DataElement:
        return DataElement(DataElement.BOOLEAN, value)

    @staticmethod
    def sequence(value: List[DataElement]) -> DataElement:
        return DataElement(DataElement.SEQUENCE, value)

    @staticmethod
    def alternative(value: List[DataElement]) -> DataElement:
        return DataElement(DataElement.ALTERNATIVE, value)

    @staticmethod
    def url(value: str) -> DataElement:
        return DataElement(DataElement.URL, value)

    @staticmethod
    def unsigned_integer_from_bytes(data):
        '''
        Decode a big-endian unsigned integer of 1, 2, 4, 8 or 16 bytes.

        Raises:
            InvalidPacketError: for any other length.
        '''
        if len(data) == 1:
            return data[0]

        if len(data) == 2:
            return struct.unpack('>H', data)[0]

        if len(data) == 4:
            return struct.unpack('>I', data)[0]

        if len(data) == 8:
            return struct.unpack('>Q', data)[0]

        if len(data) == 16:
            # 128-bit integers (size index 4) have no struct format code
            return int.from_bytes(data, 'big')

        raise InvalidPacketError(f'invalid integer length {len(data)}')

    @staticmethod
    def signed_integer_from_bytes(data):
        '''
        Decode a big-endian two's complement integer of 1, 2, 4, 8 or 16 bytes.

        Raises:
            InvalidPacketError: for any other length.
        '''
        if len(data) == 1:
            return struct.unpack('b', data)[0]

        if len(data) == 2:
            return struct.unpack('>h', data)[0]

        if len(data) == 4:
            return struct.unpack('>i', data)[0]

        if len(data) == 8:
            return struct.unpack('>q', data)[0]

        if len(data) == 16:
            # 128-bit integers (size index 4) have no struct format code
            return int.from_bytes(data, 'big', signed=True)

        raise InvalidPacketError(f'invalid integer length {len(data)}')

    @staticmethod
    def list_from_bytes(data):
        '''Parse consecutive data elements until `data` is exhausted.'''
        elements = []
        while data:
            element = DataElement.from_bytes(data)
            elements.append(element)
            # bytes(element) is the cached slice consumed by from_bytes
            data = data[len(bytes(element)) :]
        return elements

    @staticmethod
    def parse_from_bytes(data, offset):
        '''Field parser: return (offset after the element, element).'''
        element = DataElement.from_bytes(data[offset:])
        return offset + len(bytes(element)), element

    @staticmethod
    def from_bytes(data):
        '''
        Parse one data element from the start of `data`.

        Trailing bytes after the element are ignored. The consumed bytes are
        kept in `result.bytes`, so `len(bytes(result))` is the number of bytes
        consumed.

        Raises:
            InvalidPacketError: if `data` is empty or shorter than the size
                field or value it announces.
        '''
        if len(data) < 1:
            raise InvalidPacketError('empty data element')

        element_type = data[0] >> 3
        size_index = data[0] & 7
        value_offset = 0
        if size_index == 0:
            value_size = 0 if element_type == DataElement.NIL else 1
        elif size_index <= 4:
            # Size indices 1 to 4 denote fixed sizes of 2, 4, 8 and 16 bytes
            value_size = 1 << size_index
        else:
            # Size indices 5 to 7: the value size follows the header byte, as a
            # big-endian unsigned integer of 1, 2 or 4 bytes respectively
            value_offset = 1 << (size_index - 5)
            if len(data) < 1 + value_offset:
                raise InvalidPacketError('truncated data element size')
            value_size = int.from_bytes(data[1 : 1 + value_offset], 'big')

        element_size = 1 + value_offset + value_size
        if len(data) < element_size:
            raise InvalidPacketError(
                f'truncated data element ({len(data)} bytes, expected {element_size})'
            )

        value_data = data[1 + value_offset : element_size]
        constructor = DataElement.type_constructors.get(element_type)
        if constructor:
            if element_type in (
                DataElement.UNSIGNED_INTEGER,
                DataElement.SIGNED_INTEGER,
            ):
                result = constructor(value_data, value_size)
            else:
                result = constructor(value_data)
        else:
            # Unknown (reserved) type: keep the raw value bytes
            result = DataElement(element_type, value_data)
        # Keep a copy so we can re-serialize to an exact replica
        result.bytes = data[:element_size]
        return result

    def to_bytes(self):
        return bytes(self)

    def __bytes__(self):
        '''
        Serialize the element: header byte, optional size field, then value.

        Returns the cached serialization when there is one.

        Raises:
            InvalidArgumentError: if the value cannot be encoded for this type
                (negative unsigned integer, unsupported value_size or data
                size, or unsupported element type).
        '''
        # Return early if we have a cache
        if self.bytes:
            return self.bytes

        if self.type == DataElement.NIL:
            data = b''
        elif self.type == DataElement.UNSIGNED_INTEGER:
            if self.value < 0:
                raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative')

            if self.value_size == 1:
                data = struct.pack('B', self.value)
            elif self.value_size == 2:
                data = struct.pack('>H', self.value)
            elif self.value_size == 4:
                data = struct.pack('>I', self.value)
            elif self.value_size == 8:
                data = struct.pack('>Q', self.value)
            elif self.value_size == 16:
                # No struct format code for 128-bit integers
                data = self.value.to_bytes(16, 'big')
            else:
                raise InvalidArgumentError('invalid value_size')
        elif self.type == DataElement.SIGNED_INTEGER:
            if self.value_size == 1:
                data = struct.pack('b', self.value)
            elif self.value_size == 2:
                data = struct.pack('>h', self.value)
            elif self.value_size == 4:
                data = struct.pack('>i', self.value)
            elif self.value_size == 8:
                data = struct.pack('>q', self.value)
            elif self.value_size == 16:
                # No struct format code for 128-bit integers
                data = self.value.to_bytes(16, 'big', signed=True)
            else:
                raise InvalidArgumentError('invalid value_size')
        elif self.type == DataElement.UUID:
            # core.UUID bytes are in the reverse order of the wire format
            data = bytes(reversed(bytes(self.value)))
        elif self.type == DataElement.URL:
            data = self.value.encode('utf8')
        elif self.type == DataElement.BOOLEAN:
            data = bytes([1 if self.value else 0])
        elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
            data = b''.join([bytes(element) for element in self.value])
        elif self.type == DataElement.TEXT_STRING and isinstance(self.value, str):
            # Accept str for convenience; encode it as UTF-8, like URL values
            data = self.value.encode('utf8')
        else:
            data = self.value

        size = len(data)
        size_bytes = b''
        if self.type == DataElement.NIL:
            if size != 0:
                raise InvalidArgumentError('NIL must be empty')
            size_index = 0
        elif self.type in (
            DataElement.UNSIGNED_INTEGER,
            DataElement.SIGNED_INTEGER,
            DataElement.UUID,
        ):
            if size <= 1:
                size_index = 0
            elif size == 2:
                size_index = 1
            elif size == 4:
                size_index = 2
            elif size == 8:
                size_index = 3
            elif size == 16:
                size_index = 4
            else:
                raise InvalidArgumentError('invalid data size')
        elif self.type in (
            DataElement.TEXT_STRING,
            DataElement.SEQUENCE,
            DataElement.ALTERNATIVE,
            DataElement.URL,
        ):
            # Variable-length types: pick the smallest explicit size field
            if size <= 0xFF:
                size_index = 5
                size_bytes = bytes([size])
            elif size <= 0xFFFF:
                size_index = 6
                size_bytes = struct.pack('>H', size)
            elif size <= 0xFFFFFFFF:
                size_index = 7
                size_bytes = struct.pack('>I', size)
            else:
                raise InvalidArgumentError('invalid data size')
        elif self.type == DataElement.BOOLEAN:
            if size != 1:
                raise InvalidArgumentError('boolean must be 1 byte')
            size_index = 0
        else:
            # Previously fell through with `size_index` unbound
            raise InvalidArgumentError(f'unsupported element type {self.type}')

        self.bytes = bytes([self.type << 3 | size_index]) + size_bytes + data
        return self.bytes

    def to_string(self, pretty=False, indentation=0):
        '''
        Human-readable form, e.g. `UNSIGNED_INTEGER(5#2)`.

        With `pretty`, the elements of a sequence/alternative are put on
        separate lines, indented two spaces per nesting level.
        '''
        prefix = '  ' * indentation
        type_name = name_or_number(self.TYPE_NAMES, self.type)
        if self.type == DataElement.NIL:
            value_string = ''
        elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
            container_separator = '\n' if pretty else ''
            element_separator = '\n' if pretty else ','
            elements = [
                element.to_string(pretty, indentation + 1 if pretty else 0)
                for element in self.value
            ]
            value_string = (
                f'[{container_separator}'
                f'{element_separator.join(elements)}'
                f'{container_separator}{prefix}]'
            )
        elif self.type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
            value_string = f'{self.value}#{self.value_size}'
        elif isinstance(self.value, DataElement):
            value_string = self.value.to_string(pretty, indentation)
        else:
            value_string = str(self.value)
        return f'{prefix}{type_name}({value_string})'

    def __str__(self):
        return self.to_string()
468
469
470# -----------------------------------------------------------------------------
class ServiceAttribute:
    '''
    One (attribute ID, attribute value) pair of an SDP service record.
    '''

    def __init__(self, attribute_id: int, value: DataElement) -> None:
        self.id = attribute_id
        self.value = value

    @staticmethod
    def list_from_data_elements(elements: List[DataElement]) -> List[ServiceAttribute]:
        '''
        Build attributes from a flat [id, value, id, value, ...] element list.

        Pairs whose ID element is not an UNSIGNED_INTEGER are skipped with a
        warning. A trailing unpaired element is ignored.
        '''
        attributes = []
        for id_element, value_element in zip(elements[0::2], elements[1::2]):
            if id_element.type == DataElement.UNSIGNED_INTEGER:
                attributes.append(ServiceAttribute(id_element.value, value_element))
            else:
                logger.warning('attribute ID element is not an integer')
        return attributes

    @staticmethod
    def find_attribute_in_list(
        attribute_list: List[ServiceAttribute], attribute_id: int
    ) -> Optional[DataElement]:
        '''Return the value of the first attribute with `attribute_id`, or None.'''
        for candidate in attribute_list:
            if candidate.id == attribute_id:
                return candidate.value
        return None

    @staticmethod
    def id_name(id_code):
        '''Symbolic name of a well-known attribute ID, or the number itself.'''
        return name_or_number(SDP_ATTRIBUTE_ID_NAMES, id_code)

    @staticmethod
    def is_uuid_in_value(uuid: core.UUID, value: DataElement) -> bool:
        '''
        Tell whether `value` is `uuid`, or a SEQUENCE containing it at any depth.
        '''
        if value.type == DataElement.UUID:
            return value.value == uuid
        if value.type != DataElement.SEQUENCE:
            return False
        return any(
            ServiceAttribute.is_uuid_in_value(uuid, child) for child in value.value
        )

    def to_string(self, with_colors=False):
        '''Render as `Attribute(id=...,value=...)`, optionally coloring the ID.'''
        id_text = self.id_name(self.id)
        if with_colors:
            id_text = color(id_text, "magenta")
        return f'Attribute(id={id_text},value={self.value})'

    def __str__(self):
        return self.to_string()
530
531
532# -----------------------------------------------------------------------------
class SDP_PDU:
    '''
    See Bluetooth spec @ Vol 3, Part B - 4.2 PROTOCOL DATA UNIT FORMAT

    Base class for SDP PDUs. Concrete PDU types are declared with the
    `SDP_PDU.subclass(fields)` decorator. It registers them in
    `sdp_pdu_classes` so that `SDP_PDU.from_bytes` can instantiate them.
    '''

    sdp_pdu_classes: Dict[int, Type[SDP_PDU]] = {}
    name = None
    pdu_id = 0

    @staticmethod
    def from_bytes(pdu):
        '''
        Parse a complete SDP PDU.

        The 5-byte header (PDU ID, Transaction ID, Parameter Length) is decoded
        first. If a subclass is registered for the PDU ID, its fields are parsed
        starting at offset 5. Otherwise a generic SDP_PDU carrying the raw bytes
        is returned. The Parameter Length field is not checked.

        Raises:
            InvalidPacketError: if `pdu` is shorter than the 5-byte header.
        '''
        if len(pdu) < 5:
            raise InvalidPacketError(f'SDP PDU too short ({len(pdu)} bytes)')
        pdu_id, transaction_id, _parameters_length = struct.unpack_from('>BHH', pdu, 0)

        cls = SDP_PDU.sdp_pdu_classes.get(pdu_id)
        if cls is None:
            instance = SDP_PDU(pdu)
            instance.name = SDP_PDU.pdu_name(pdu_id)
            instance.pdu_id = pdu_id
            instance.transaction_id = transaction_id
            return instance
        # Bypass the subclass constructor: the PDU bytes are already known, only
        # the fields need to be decoded from them.
        self = cls.__new__(cls)
        SDP_PDU.__init__(self, pdu, transaction_id)
        if hasattr(self, 'fields'):
            self.init_from_bytes(pdu, 5)
        return self

    @staticmethod
    def parse_service_record_handle_list_preceded_by_count(
        data: bytes, offset: int
    ) -> Tuple[int, List[int]]:
        '''
        Field parser for a list of 32-bit handles.

        The handle count is read from the 2 bytes just before `offset`, i.e.
        from the previously parsed field. Returns (offset after the list, handles).
        '''
        count = struct.unpack_from('>H', data, offset - 2)[0]
        handle_list = [
            struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)
        ]
        return offset + count * 4, handle_list

    @staticmethod
    def parse_bytes_preceded_by_length(data, offset):
        '''
        Field parser for a byte string whose 16-bit length is the 2 bytes just
        before `offset`. Returns (offset after the bytes, bytes).
        '''
        length = struct.unpack_from('>H', data, offset - 2)[0]
        return offset + length, data[offset : offset + length]

    @staticmethod
    def error_name(error_code):
        return name_or_number(SDP_ERROR_NAMES, error_code)

    @staticmethod
    def pdu_name(code):
        return name_or_number(SDP_PDU_NAMES, code)

    @staticmethod
    def subclass(fields):
        '''
        Class decorator declaring a PDU type with the given field list.

        The PDU name is derived from the class name (`SDP_FooBar` ->
        `SDP_FOO_BAR`) and looked up in SDP_PDU_NAMES to get the PDU ID.

        Raises:
            KeyError: if the derived name is not in SDP_PDU_NAMES.
        '''

        def inner(cls):
            name = cls.__name__

            # add a _ character before every uppercase letter, except the SDP_ prefix
            location = len(name) - 1
            while location > 4:
                if not name[location].isupper():
                    location -= 1
                    continue
                name = name[:location] + '_' + name[location:]
                location -= 1

            cls.name = name.upper()
            cls.pdu_id = key_with_value(SDP_PDU_NAMES, cls.name)
            if cls.pdu_id is None:
                raise KeyError(f'PDU name {cls.name} not found in SDP_PDU_NAMES')
            cls.fields = fields

            # Register a factory for this class
            SDP_PDU.sdp_pdu_classes[cls.pdu_id] = cls

            return cls

        return inner

    def __init__(self, pdu=None, transaction_id=0, **kwargs):
        '''
        Either wrap existing `pdu` bytes or, when `pdu` is None, serialize the
        keyword field values behind a 5-byte header.

        Keyword field values are handed to HCI_Object.init_from_fields,
        presumably to set them as attributes.
        '''
        if hasattr(self, 'fields') and kwargs:
            HCI_Object.init_from_fields(self, self.fields, kwargs)
        if pdu is None:
            parameters = HCI_Object.dict_to_bytes(kwargs, self.fields)
            pdu = (
                struct.pack('>BHH', self.pdu_id, transaction_id, len(parameters))
                + parameters
            )
        self.pdu = pdu
        self.transaction_id = transaction_id

    def init_from_bytes(self, pdu, offset):
        return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)

    def to_bytes(self):
        return self.pdu

    def __bytes__(self):
        return self.to_bytes()

    def __str__(self):
        result = f'{color(self.name, "blue")} [TID={self.transaction_id}]'
        if fields := getattr(self, 'fields', None):
            result += ':\n' + HCI_Object.format_fields(self.__dict__, fields, '  ')
        elif len(self.pdu) > 1:
            result += f': {self.pdu.hex()}'
        return result
637
638
639# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
    [
        # 2-byte error code, named through SDP_PDU.error_name
        ('error_code', dict(size=2, mapper=SDP_PDU.error_name)),
    ]
)
class SDP_ErrorResponse(SDP_PDU):
    '''
    Error reply to an SDP request, carrying a single error code (normally one
    of the SDP_*_ERROR constants).

    See Bluetooth spec @ Vol 3, Part B - 4.4.1 SDP_ErrorResponse PDU
    '''

    error_code: int
645
646
647# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
    [
        # Data element (a sequence of UUIDs when built by Client)
        ('service_search_pattern', DataElement.parse_from_bytes),
        # Upper bound on the number of handles the server may return
        ('maximum_service_record_count', '>2'),
        # Continuation state echoed back from the previous response
        ('continuation_state', '*'),
    ]
)
class SDP_ServiceSearchRequest(SDP_PDU):
    '''
    Ask the server for the handles of the service records matching a pattern.

    See Bluetooth spec @ Vol 3, Part B - 4.5.1 SDP_ServiceSearchRequest PDU
    '''

    service_search_pattern: DataElement
    maximum_service_record_count: int
    continuation_state: bytes
663
664
665# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
    [
        ('total_service_record_count', '>2'),
        # Number of handles in this PDU; the handle list parser reads it back
        # as its count, so it must directly precede the list
        ('current_service_record_count', '>2'),
        (
            'service_record_handle_list',
            SDP_PDU.parse_service_record_handle_list_preceded_by_count,
        ),
        ('continuation_state', '*'),
    ]
)
class SDP_ServiceSearchResponse(SDP_PDU):
    '''
    Reply to an SDP_ServiceSearchRequest carrying (a chunk of) the matching
    32-bit service record handles.

    See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU
    '''

    service_record_handle_list: List[int]
    total_service_record_count: int
    current_service_record_count: int
    continuation_state: bytes
686
687
688# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
    [
        # Handle of the service record to query
        ('service_record_handle', '>4'),
        ('maximum_attribute_byte_count', '>2'),
        # Data element listing attribute IDs and/or ID ranges
        ('attribute_id_list', DataElement.parse_from_bytes),
        ('continuation_state', '*'),
    ]
)
class SDP_ServiceAttributeRequest(SDP_PDU):
    '''
    Ask the server for selected attributes of a single service record.

    See Bluetooth spec @ Vol 3, Part B - 4.6.1 SDP_ServiceAttributeRequest PDU
    '''

    service_record_handle: int
    maximum_attribute_byte_count: int
    attribute_id_list: DataElement
    continuation_state: bytes
706
707
708# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
    [
        # Also the length prefix read back by parse_bytes_preceded_by_length
        ('attribute_list_byte_count', '>2'),
        # Raw (possibly partial) attribute list bytes
        ('attribute_list', SDP_PDU.parse_bytes_preceded_by_length),
        ('continuation_state', '*'),
    ]
)
class SDP_ServiceAttributeResponse(SDP_PDU):
    '''
    Reply to an SDP_ServiceAttributeRequest carrying (a chunk of) the
    serialized attribute list.

    See Bluetooth spec @ Vol 3, Part B - 4.6.2 SDP_ServiceAttributeResponse PDU
    '''

    attribute_list_byte_count: int
    attribute_list: bytes
    continuation_state: bytes
724
725
726# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
    [
        # Data element (a sequence of UUIDs when built by Client)
        ('service_search_pattern', DataElement.parse_from_bytes),
        ('maximum_attribute_byte_count', '>2'),
        # Data element listing attribute IDs and/or ID ranges
        ('attribute_id_list', DataElement.parse_from_bytes),
        ('continuation_state', '*'),
    ]
)
class SDP_ServiceSearchAttributeRequest(SDP_PDU):
    '''
    Search for matching service records and fetch selected attributes of
    each of them in one transaction.

    See Bluetooth spec @ Vol 3, Part B - 4.7.1 SDP_ServiceSearchAttributeRequest PDU
    '''

    service_search_pattern: DataElement
    maximum_attribute_byte_count: int
    attribute_id_list: DataElement
    continuation_state: bytes
744
745
746# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
    [
        ('attribute_lists_byte_count', '>2'),
        ('attribute_lists', SDP_PDU.parse_bytes_preceded_by_length),
        ('continuation_state', '*'),
    ]
)
class SDP_ServiceSearchAttributeResponse(SDP_PDU):
    '''
    Reply to an SDP_ServiceSearchAttributeRequest carrying (a chunk of) the
    serialized attribute lists.

    See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU
    '''

    # These annotations must use the field names declared above (plural
    # "attribute_lists"): that is what the parsed attributes are called
    # (Client reads `response.attribute_lists`).
    attribute_lists_byte_count: int
    attribute_lists: bytes
    continuation_state: bytes
762
763
764# -----------------------------------------------------------------------------
765class Client:
766    channel: Optional[l2cap.ClassicChannel]
767
768    def __init__(self, connection: Connection) -> None:
769        self.connection = connection
770        self.pending_request = None
771        self.channel = None
772
773    async def connect(self) -> None:
774        self.channel = await self.connection.create_l2cap_channel(
775            spec=l2cap.ClassicChannelSpec(SDP_PSM)
776        )
777
778    async def disconnect(self) -> None:
779        if self.channel:
780            await self.channel.disconnect()
781            self.channel = None
782
783    async def search_services(self, uuids: List[core.UUID]) -> List[int]:
784        if self.pending_request is not None:
785            raise InvalidStateError('request already pending')
786        if self.channel is None:
787            raise InvalidStateError('L2CAP not connected')
788
789        service_search_pattern = DataElement.sequence(
790            [DataElement.uuid(uuid) for uuid in uuids]
791        )
792
793        # Request and accumulate until there's no more continuation
794        service_record_handle_list = []
795        continuation_state = bytes([0])
796        watchdog = SDP_CONTINUATION_WATCHDOG
797        while watchdog > 0:
798            response_pdu = await self.channel.send_request(
799                SDP_ServiceSearchRequest(
800                    transaction_id=0,  # Transaction ID TODO: pick a real value
801                    service_search_pattern=service_search_pattern,
802                    maximum_service_record_count=0xFFFF,
803                    continuation_state=continuation_state,
804                )
805            )
806            response = SDP_PDU.from_bytes(response_pdu)
807            logger.debug(f'<<< Response: {response}')
808            service_record_handle_list += response.service_record_handle_list
809            continuation_state = response.continuation_state
810            if len(continuation_state) == 1 and continuation_state[0] == 0:
811                break
812            logger.debug(f'continuation: {continuation_state.hex()}')
813            watchdog -= 1
814
815        return service_record_handle_list
816
817    async def search_attributes(
818        self, uuids: List[core.UUID], attribute_ids: List[Union[int, Tuple[int, int]]]
819    ) -> List[List[ServiceAttribute]]:
820        if self.pending_request is not None:
821            raise InvalidStateError('request already pending')
822        if self.channel is None:
823            raise InvalidStateError('L2CAP not connected')
824
825        service_search_pattern = DataElement.sequence(
826            [DataElement.uuid(uuid) for uuid in uuids]
827        )
828        attribute_id_list = DataElement.sequence(
829            [
830                (
831                    DataElement.unsigned_integer(
832                        attribute_id[0], value_size=attribute_id[1]
833                    )
834                    if isinstance(attribute_id, tuple)
835                    else DataElement.unsigned_integer_16(attribute_id)
836                )
837                for attribute_id in attribute_ids
838            ]
839        )
840
841        # Request and accumulate until there's no more continuation
842        accumulator = b''
843        continuation_state = bytes([0])
844        watchdog = SDP_CONTINUATION_WATCHDOG
845        while watchdog > 0:
846            response_pdu = await self.channel.send_request(
847                SDP_ServiceSearchAttributeRequest(
848                    transaction_id=0,  # Transaction ID TODO: pick a real value
849                    service_search_pattern=service_search_pattern,
850                    maximum_attribute_byte_count=0xFFFF,
851                    attribute_id_list=attribute_id_list,
852                    continuation_state=continuation_state,
853                )
854            )
855            response = SDP_PDU.from_bytes(response_pdu)
856            logger.debug(f'<<< Response: {response}')
857            accumulator += response.attribute_lists
858            continuation_state = response.continuation_state
859            if len(continuation_state) == 1 and continuation_state[0] == 0:
860                break
861            logger.debug(f'continuation: {continuation_state.hex()}')
862            watchdog -= 1
863
864        # Parse the result into attribute lists
865        attribute_lists_sequences = DataElement.from_bytes(accumulator)
866        if attribute_lists_sequences.type != DataElement.SEQUENCE:
867            logger.warning('unexpected data type')
868            return []
869
870        return [
871            ServiceAttribute.list_from_data_elements(sequence.value)
872            for sequence in attribute_lists_sequences.value
873            if sequence.type == DataElement.SEQUENCE
874        ]
875
876    async def get_attributes(
877        self,
878        service_record_handle: int,
879        attribute_ids: List[Union[int, Tuple[int, int]]],
880    ) -> List[ServiceAttribute]:
881        if self.pending_request is not None:
882            raise InvalidStateError('request already pending')
883        if self.channel is None:
884            raise InvalidStateError('L2CAP not connected')
885
886        attribute_id_list = DataElement.sequence(
887            [
888                (
889                    DataElement.unsigned_integer(
890                        attribute_id[0], value_size=attribute_id[1]
891                    )
892                    if isinstance(attribute_id, tuple)
893                    else DataElement.unsigned_integer_16(attribute_id)
894                )
895                for attribute_id in attribute_ids
896            ]
897        )
898
899        # Request and accumulate until there's no more continuation
900        accumulator = b''
901        continuation_state = bytes([0])
902        watchdog = SDP_CONTINUATION_WATCHDOG
903        while watchdog > 0:
904            response_pdu = await self.channel.send_request(
905                SDP_ServiceAttributeRequest(
906                    transaction_id=0,  # Transaction ID TODO: pick a real value
907                    service_record_handle=service_record_handle,
908                    maximum_attribute_byte_count=0xFFFF,
909                    attribute_id_list=attribute_id_list,
910                    continuation_state=continuation_state,
911                )
912            )
913            response = SDP_PDU.from_bytes(response_pdu)
914            logger.debug(f'<<< Response: {response}')
915            accumulator += response.attribute_list
916            continuation_state = response.continuation_state
917            if len(continuation_state) == 1 and continuation_state[0] == 0:
918                break
919            logger.debug(f'continuation: {continuation_state.hex()}')
920            watchdog -= 1
921
922        # Parse the result into a list of attributes
923        attribute_list_sequence = DataElement.from_bytes(accumulator)
924        if attribute_list_sequence.type != DataElement.SEQUENCE:
925            logger.warning('unexpected data type')
926            return []
927
928        return ServiceAttribute.list_from_data_elements(attribute_list_sequence.value)
929
930    async def __aenter__(self) -> Self:
931        await self.connect()
932        return self
933
934    async def __aexit__(self, *args) -> None:
935        await self.disconnect()
936
937
938# -----------------------------------------------------------------------------
class Server:
    """SDP server answering requests received on a classic L2CAP channel.

    Service records are kept in ``service_records``, which maps a record handle
    to a list of ``ServiceAttribute``. A response larger than the maximum the
    client asked for is sent in chunks. The unsent remainder is kept in
    ``current_response`` and released when the client comes back with the
    continuation state it was given.
    """

    # Continuation state sent while part of a response is still pending: an
    # info-length byte (1) followed by one byte of opaque information.
    CONTINUATION_STATE = bytes([0x01, 0x43])
    channel: Optional[l2cap.ClassicChannel]
    Service = NewType('Service', List[ServiceAttribute])
    service_records: Dict[int, Service]
    # Unsent part of the current response: bytes for the attribute requests,
    # or (total record count, remaining handles) for service search requests.
    current_response: Union[None, bytes, Tuple[int, List[int]]]

    def __init__(self, device: Device) -> None:
        self.device = device
        self.service_records = {}  # Service records maps, by record handle
        self.channel = None
        self.current_response = None

    def register(self, l2cap_channel_manager: l2cap.ChannelManager) -> None:
        """Accept incoming classic L2CAP connections on the SDP PSM."""
        l2cap_channel_manager.create_classic_server(
            spec=l2cap.ClassicChannelSpec(psm=SDP_PSM), handler=self.on_connection
        )

    def send_response(self, response):
        """Log an SDP response PDU and send it on the current channel."""
        logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}')
        self.channel.send_pdu(response)

    def match_services(self, search_pattern: DataElement) -> Dict[int, Service]:
        """Return the services, keyed by record handle, matching a search pattern.

        A service matches when EVERY UUID in the pattern appears in at least one
        of its attribute values. The value search recurses into sequences. An
        empty pattern matches nothing.
        """
        uuids = [element.value for element in search_pattern.value]
        if not uuids:
            return {}

        return {
            handle: service
            for handle, service in self.service_records.items()
            if all(
                any(
                    ServiceAttribute.is_uuid_in_value(uuid, attribute.value)
                    for attribute in service
                )
                for uuid in uuids
            )
        }

    def on_connection(self, channel: l2cap.ClassicChannel) -> None:
        """Adopt a new L2CAP channel (replacing any previous one) and route its
        incoming PDUs to ``on_pdu``."""
        self.channel = channel
        self.channel.sink = self.on_pdu

    def on_pdu(self, pdu: bytes) -> None:
        """Parse an incoming SDP request and dispatch it to its handler.

        The handler is the method named ``on_`` plus the lowercased PDU name
        (for example ``on_sdp_service_search_request``). Unparseable PDUs, PDUs
        without a handler, and handler failures are answered with an SDP Error
        Response.
        """
        try:
            sdp_pdu = SDP_PDU.from_bytes(pdu)
        except Exception as error:
            logger.warning(color(f'failed to parse SDP Request PDU: {error}', 'red'))
            # Echo the request's transaction ID when the header is long enough to
            # carry one (1-byte PDU ID, then a big-endian 16-bit transaction ID).
            transaction_id = (
                struct.unpack_from('>H', pdu, 1)[0] if len(pdu) >= 3 else 0
            )
            self.send_response(
                SDP_ErrorResponse(
                    transaction_id=transaction_id,
                    error_code=SDP_INVALID_REQUEST_SYNTAX_ERROR,
                )
            )
            # Nothing to dispatch: the PDU could not be decoded
            return

        logger.debug(f'{color("<<< Received SDP Request", "green")}: {sdp_pdu}')

        # Find the handler method
        handler_name = f'on_{sdp_pdu.name.lower()}'
        handler = getattr(self, handler_name, None)
        if handler:
            try:
                handler(sdp_pdu)
            except Exception as error:
                logger.exception(f'{color("!!! Exception in handler:", "red")} {error}')
                self.send_response(
                    SDP_ErrorResponse(
                        transaction_id=sdp_pdu.transaction_id,
                        error_code=SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR,
                    )
                )
        else:
            logger.error(color('SDP Request not handled???', 'red'))
            self.send_response(
                SDP_ErrorResponse(
                    transaction_id=sdp_pdu.transaction_id,
                    error_code=SDP_INVALID_REQUEST_SYNTAX_ERROR,
                )
            )

    def get_next_response_payload(self, maximum_size):
        """Take up to ``maximum_size`` bytes from the pending byte response.

        Returns a ``(payload, continuation_state)`` tuple. The continuation
        state is ``CONTINUATION_STATE`` while bytes remain. Once everything has
        been handed out it is a single zero byte, and ``current_response`` is
        cleared.
        """
        if len(self.current_response) > maximum_size:
            payload = self.current_response[:maximum_size]
            continuation_state = Server.CONTINUATION_STATE
            self.current_response = self.current_response[maximum_size:]
        else:
            payload = self.current_response
            continuation_state = bytes([0])
            self.current_response = None

        return (payload, continuation_state)

    @staticmethod
    def get_service_attributes(
        service: Service, attribute_ids: List[DataElement]
    ) -> DataElement:
        """Select the attributes of ``service`` requested by ``attribute_ids``.

        An ID element with ``value_size == 4`` is an inclusive range: the start
        is in the high 16 bits and the end in the low 16 bits. Any other element
        is a single attribute ID. Each attribute is included at most once, even
        when requested ranges overlap.

        Returns:
            A sequence DataElement that alternates each attribute's 16-bit ID
            with its value, sorted by attribute ID.
        """
        # Inclusive (start, end) attribute ID ranges
        id_ranges = []
        for attribute_id in attribute_ids:
            if attribute_id.value_size == 4:
                # Attribute ID range
                id_ranges.append(
                    (attribute_id.value >> 16, attribute_id.value & 0xFFFF)
                )
            else:
                id_ranges.append((attribute_id.value, attribute_id.value))

        # Return the matching attributes, sorted by attribute id
        attributes = sorted(
            (
                attribute
                for attribute in service
                if any(start <= attribute.id <= end for start, end in id_ranges)
            ),
            key=lambda attribute: attribute.id,
        )
        attribute_list = DataElement.sequence([])
        for attribute in attributes:
            attribute_list.value.append(DataElement.unsigned_integer_16(attribute.id))
            attribute_list.value.append(attribute.value)

        return attribute_list

    def on_sdp_service_search_request(self, request: SDP_ServiceSearchRequest) -> None:
        """Answer a Service Search request with the matching record handles.

        At most ``maximum_service_record_count`` handles are returned. The
        reported total is capped at that same maximum.
        """
        pending: Tuple[int, List[int]]
        # Check if this is a continuation
        if len(request.continuation_state) > 1:
            # Only valid while a service search response is still pending
            if not isinstance(self.current_response, tuple):
                self.send_response(
                    SDP_ErrorResponse(
                        transaction_id=request.transaction_id,
                        error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
                    )
                )
                return
            pending = self.current_response
        else:
            # Cleanup any partial response leftover
            self.current_response = None

            # Find the matching services
            matching_services = self.match_services(request.service_search_pattern)
            service_record_handles = list(matching_services.keys())
            logger.debug(f'Service Record Handles: {service_record_handles}')

            # Only return up to the maximum requested. The total count reported
            # to the client must not exceed that maximum either.
            service_record_handles_subset = service_record_handles[
                : request.maximum_service_record_count
            ]
            pending = (
                len(service_record_handles_subset),
                service_record_handles_subset,
            )

        # Respond, keeping any unsent handles for a later continuation
        total_count, pending_handles = pending
        maximum = request.maximum_service_record_count
        service_record_handles = pending_handles[:maximum]
        remaining_handles = pending_handles[maximum:]
        if remaining_handles:
            self.current_response = (total_count, remaining_handles)
            continuation_state = Server.CONTINUATION_STATE
        else:
            self.current_response = None
            continuation_state = bytes([0])

        service_record_handle_list = b''.join(
            struct.pack('>I', handle) for handle in service_record_handles
        )
        self.send_response(
            SDP_ServiceSearchResponse(
                transaction_id=request.transaction_id,
                total_service_record_count=total_count,
                current_service_record_count=len(service_record_handles),
                service_record_handle_list=service_record_handle_list,
                continuation_state=continuation_state,
            )
        )

    def on_sdp_service_attribute_request(
        self, request: SDP_ServiceAttributeRequest
    ) -> None:
        """Answer a Service Attribute request for a single service record.

        The serialized attribute list is sent in chunks of at most
        ``maximum_attribute_byte_count`` bytes.
        """
        # Check if this is a continuation
        if len(request.continuation_state) > 1:
            # Only valid while attribute bytes are still pending
            if not isinstance(self.current_response, bytes):
                self.send_response(
                    SDP_ErrorResponse(
                        transaction_id=request.transaction_id,
                        error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
                    )
                )
                return
        else:
            # Cleanup any partial response leftover
            self.current_response = None

            # Check that the service exists
            service = self.service_records.get(request.service_record_handle)
            if service is None:
                self.send_response(
                    SDP_ErrorResponse(
                        transaction_id=request.transaction_id,
                        error_code=SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR,
                    )
                )
                return

            # Get the attributes for the service
            attribute_list = Server.get_service_attributes(
                service, request.attribute_id_list.value
            )

            # Serialize to a byte array
            logger.debug(f'Attributes: {attribute_list}')
            self.current_response = bytes(attribute_list)

        # Respond with the next chunk only, keeping the rest for later
        attribute_list_response, continuation_state = self.get_next_response_payload(
            request.maximum_attribute_byte_count
        )
        self.send_response(
            SDP_ServiceAttributeResponse(
                transaction_id=request.transaction_id,
                attribute_list_byte_count=len(attribute_list_response),
                attribute_list=attribute_list_response,
                continuation_state=continuation_state,
            )
        )

    def on_sdp_service_search_attribute_request(
        self, request: SDP_ServiceSearchAttributeRequest
    ) -> None:
        """Answer a Service Search Attribute request.

        The response carries the requested attributes of every matching
        service. Services with none of the requested attributes are left out.
        The serialized result is sent in chunks of at most
        ``maximum_attribute_byte_count`` bytes.
        """
        # Check if this is a continuation
        if len(request.continuation_state) > 1:
            # Only valid while attribute bytes are still pending
            if not isinstance(self.current_response, bytes):
                self.send_response(
                    SDP_ErrorResponse(
                        transaction_id=request.transaction_id,
                        error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
                    )
                )
                return
        else:
            # Cleanup any partial response leftover
            self.current_response = None

            # Find the matching services
            matching_services = self.match_services(
                request.service_search_pattern
            ).values()

            # Filter the required attributes
            attribute_lists = DataElement.sequence([])
            for service in matching_services:
                attribute_list = Server.get_service_attributes(
                    service, request.attribute_id_list.value
                )
                if attribute_list.value:
                    attribute_lists.value.append(attribute_list)

            # Serialize to a byte array
            logger.debug(f'Search response: {attribute_lists}')
            self.current_response = bytes(attribute_lists)

        # Respond with the next chunk only, keeping the rest for later
        attribute_lists_response, continuation_state = self.get_next_response_payload(
            request.maximum_attribute_byte_count
        )
        self.send_response(
            SDP_ServiceSearchAttributeResponse(
                transaction_id=request.transaction_id,
                attribute_lists_byte_count=len(attribute_lists_response),
                attribute_lists=attribute_lists_response,
                continuation_state=continuation_state,
            )
        )
1210