1# Copyright 2021-2023 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15
16# -----------------------------------------------------------------------------
17# Imports
18# -----------------------------------------------------------------------------
19from __future__ import annotations
20import enum
21import struct
22from typing import Optional, Tuple
23
24from bumble import core
25from bumble import crypto
26from bumble import device
27from bumble import gatt
28from bumble import gatt_client
29
30
31# -----------------------------------------------------------------------------
32# Constants
33# -----------------------------------------------------------------------------
# Length, in bytes, of a Set Identity Resolving Key (SIRK); the server
# constructor rejects keys of any other size.
SET_IDENTITY_RESOLVING_KEY_LENGTH: int = 16
35
36
class SirkType(enum.IntEnum):
    '''
    Type byte of the SIRK characteristic value.

    See Coordinated Set Identification Service - 5.1 Set Identity Resolving Key.
    '''

    # SIRK bytes are encrypted with sef() before being exposed.
    ENCRYPTED = 0
    # SIRK bytes are exposed as-is.
    PLAINTEXT = 1
42
43
class MemberLock(enum.IntEnum):
    '''
    Value of the Set Member Lock characteristic.

    See Coordinated Set Identification Service - 5.3 Set Member Lock.
    '''

    UNLOCKED = 1
    LOCKED = 2
49
50
51# -----------------------------------------------------------------------------
52# Crypto Toolbox
53# -----------------------------------------------------------------------------
def s1(m: bytes) -> bytes:
    '''
    Coordinated Set Identification Service - 4.3 s1 SALT generation function.

    Returns the AES-CMAC of ``m`` under an all-zero 16-byte key. The message is
    byte-reversed before the crypto.aes_cmac call and the MAC is byte-reversed
    after it. This presumably converts between the reversed byte order used by
    the rest of this module and the order aes_cmac works in (confirm against
    bumble.crypto).
    '''
    zero_key = bytes(16)
    mac = crypto.aes_cmac(m[::-1], zero_key)
    return mac[::-1]
59
60
def k1(n: bytes, salt: bytes, p: bytes) -> bytes:
    '''
    Coordinated Set Identification Service - 4.4 k1 derivation function.

    Two chained AES-CMAC computations. First, an intermediate key T is the CMAC
    of ``n`` keyed by ``salt``. Second, the result is the CMAC of ``p`` keyed
    by T. Inputs are byte-reversed before use and the final result is
    byte-reversed, matching s1(). T itself is passed through unreversed.
    '''
    intermediate_key = crypto.aes_cmac(n[::-1], salt[::-1])
    derived = crypto.aes_cmac(p[::-1], intermediate_key)
    return derived[::-1]
67
68
def sef(k: bytes, r: bytes) -> bytes:
    '''
    Coordinated Set Identification Service - 4.5 SIRK encryption function sef.

    XORs ``r`` with a mask derived by k1() from key ``k``, salt s1("SIRKenc")
    and the constant "csis". Because the operation is a plain XOR, the same
    function also serves as the SIRK decryption function sdf:
      * in encryption, ``r`` is the plaintext SIRK;
      * in decryption, ``r`` is the encrypted SIRK.
    '''
    # The ASCII constants are passed reversed because s1()/k1() reverse their
    # inputs again internally.
    salt = s1(b'SIRKenc'[::-1])
    mask = k1(k, salt, b'csis'[::-1])
    return crypto.xor(mask, r)
78
79
def sih(k: bytes, r: bytes) -> bytes:
    '''
    Coordinated Set Identification Service - 4.7 Resolvable Set Identifier hash function sih.

    Appends 13 zero bytes to ``r``, encrypts the result with crypto.e() under
    key ``k``, and keeps the first 3 bytes of the output. ``r`` is presumably
    the 3-byte prand, giving a 16-byte block (see generate_rsi); confirm for
    other callers.
    '''
    padded = r + bytes(13)
    return crypto.e(k, padded)[:3]
85
86
def generate_rsi(sirk: bytes) -> bytes:
    '''
    Coordinated Set Identification Service - 4.8 Resolvable Set Identifier generation operation.

    Draws a new prand from crypto.generate_prand() on every call and returns
    sih(sirk, prand) followed by the prand bytes.
    '''
    random_part = crypto.generate_prand()
    hash_part = sih(sirk, random_part)
    return hash_part + random_part
93
94
95# -----------------------------------------------------------------------------
96# Server
97# -----------------------------------------------------------------------------
class CoordinatedSetIdentificationService(gatt.TemplateService):
    '''
    Coordinated Set Identification Service (CSIS) server.

    Always exposes the Set Identity Resolving Key (SIRK) characteristic. The
    Coordinated Set Size, Set Member Lock and Set Member Rank characteristics
    are only added when the matching constructor argument is not None.
    '''

    UUID = gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE

    set_identity_resolving_key: bytes
    set_identity_resolving_key_type: SirkType
    set_identity_resolving_key_characteristic: gatt.Characteristic
    coordinated_set_size_characteristic: Optional[gatt.Characteristic] = None
    set_member_lock_characteristic: Optional[gatt.Characteristic] = None
    set_member_rank_characteristic: Optional[gatt.Characteristic] = None

    def __init__(
        self,
        set_identity_resolving_key: bytes,
        set_identity_resolving_key_type: SirkType,
        coordinated_set_size: Optional[int] = None,
        set_member_lock: Optional[MemberLock] = None,
        set_member_rank: Optional[int] = None,
    ) -> None:
        '''
        Args:
            set_identity_resolving_key: the SIRK, exactly
                SET_IDENTITY_RESOLVING_KEY_LENGTH bytes long.
            set_identity_resolving_key_type: whether reads return the SIRK in
                plaintext or encrypted with sef() (see on_sirk_read).
            coordinated_set_size: if not None, value of the Coordinated Set
                Size characteristic; must fit in one unsigned byte.
            set_member_lock: if not None, initial value of the Set Member Lock
                characteristic.
            set_member_rank: if not None, value of the Set Member Rank
                characteristic; must fit in one unsigned byte.

        Raises:
            core.InvalidArgumentError: if the SIRK has the wrong length.
        '''
        if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH:
            raise core.InvalidArgumentError(
                f'Invalid SIRK length {len(set_identity_resolving_key)}, expected {SET_IDENTITY_RESOLVING_KEY_LENGTH}'
            )

        characteristics = []

        self.set_identity_resolving_key = set_identity_resolving_key
        self.set_identity_resolving_key_type = set_identity_resolving_key_type
        # The SIRK value is computed per read so it can be encrypted with the
        # reading connection's key.
        self.set_identity_resolving_key_characteristic = gatt.Characteristic(
            uuid=gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC,
            properties=gatt.Characteristic.Properties.READ
            | gatt.Characteristic.Properties.NOTIFY,
            permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
            value=gatt.CharacteristicValue(read=self.on_sirk_read),
        )
        characteristics.append(self.set_identity_resolving_key_characteristic)

        if coordinated_set_size is not None:
            self.coordinated_set_size_characteristic = gatt.Characteristic(
                uuid=gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC,
                properties=gatt.Characteristic.Properties.READ
                | gatt.Characteristic.Properties.NOTIFY,
                permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
                value=struct.pack('B', coordinated_set_size),
            )
            characteristics.append(self.coordinated_set_size_characteristic)

        if set_member_lock is not None:
            self.set_member_lock_characteristic = gatt.Characteristic(
                uuid=gatt.GATT_SET_MEMBER_LOCK_CHARACTERISTIC,
                properties=gatt.Characteristic.Properties.READ
                | gatt.Characteristic.Properties.NOTIFY
                | gatt.Characteristic.Properties.WRITE,
                permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
                | gatt.Characteristic.Permissions.WRITEABLE,
                value=struct.pack('B', set_member_lock),
            )
            characteristics.append(self.set_member_lock_characteristic)

        if set_member_rank is not None:
            self.set_member_rank_characteristic = gatt.Characteristic(
                uuid=gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC,
                properties=gatt.Characteristic.Properties.READ
                | gatt.Characteristic.Properties.NOTIFY,
                permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
                value=struct.pack('B', set_member_rank),
            )
            characteristics.append(self.set_member_rank_characteristic)

        super().__init__(characteristics)

    async def on_sirk_read(self, connection: Optional[device.Connection]) -> bytes:
        '''
        Builds the SIRK characteristic value: one SirkType byte followed by
        the SIRK bytes.

        For SirkType.ENCRYPTED, the SIRK is passed through sef(). The key is
        the LTK of ``connection`` on the LE transport, or its link key on any
        other transport.

        Raises:
            core.InvalidOperationError: if the SIRK must be encrypted and there
                is no connection, or no key is available for it.
        '''
        if self.set_identity_resolving_key_type == SirkType.PLAINTEXT:
            sirk_bytes = self.set_identity_resolving_key
        else:
            # Explicit check instead of `assert`. An assert is stripped under
            # `python -O` and would surface later as an AttributeError on None.
            if connection is None:
                raise core.InvalidOperationError(
                    'A connection is required to read an encrypted SIRK'
                )

            if connection.transport == core.BT_LE_TRANSPORT:
                # NOTE(review): empty rand / zero ediv presumably select keys
                # that have no EDIV/Rand (e.g. LE Secure Connections) - confirm.
                key = await connection.device.get_long_term_key(
                    connection_handle=connection.handle, rand=b'', ediv=0
                )
            else:
                key = await connection.device.get_link_key(connection.peer_address)

            if not key:
                raise core.InvalidOperationError('LTK or LinkKey is not present')

            sirk_bytes = sef(key, self.set_identity_resolving_key)

        return bytes([self.set_identity_resolving_key_type]) + sirk_bytes

    def get_advertising_data(self) -> bytes:
        '''
        Returns serialized advertising data containing a Resolvable Set
        Identifier for this set's SIRK.

        generate_rsi() draws a new prand each time, so each call produces a
        different RSI.
        '''
        return bytes(
            core.AdvertisingData(
                [
                    (
                        core.AdvertisingData.RESOLVABLE_SET_IDENTIFIER,
                        generate_rsi(self.set_identity_resolving_key),
                    ),
                ]
            )
        )
198
199
200# -----------------------------------------------------------------------------
201# Client
202# -----------------------------------------------------------------------------
class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
    '''Client-side proxy for a peer's Coordinated Set Identification Service.'''

    SERVICE_CLASS = CoordinatedSetIdentificationService

    set_identity_resolving_key: gatt_client.CharacteristicProxy
    coordinated_set_size: Optional[gatt_client.CharacteristicProxy] = None
    set_member_lock: Optional[gatt_client.CharacteristicProxy] = None
    set_member_rank: Optional[gatt_client.CharacteristicProxy] = None

    def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
        '''
        Args:
            service_proxy: the discovered CSIS service on the peer. Its SIRK
                characteristic is required. The size, lock and rank
                characteristics are bound only if present and otherwise stay
                None.
        '''
        self.service_proxy = service_proxy

        # Indexed unconditionally: this proxy cannot work without the SIRK.
        self.set_identity_resolving_key = service_proxy.get_characteristics_by_uuid(
            gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC
        )[0]

        if characteristics := service_proxy.get_characteristics_by_uuid(
            gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC
        ):
            self.coordinated_set_size = characteristics[0]

        if characteristics := service_proxy.get_characteristics_by_uuid(
            gatt.GATT_SET_MEMBER_LOCK_CHARACTERISTIC
        ):
            self.set_member_lock = characteristics[0]

        if characteristics := service_proxy.get_characteristics_by_uuid(
            gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC
        ):
            self.set_member_rank = characteristics[0]

    async def read_set_identity_resolving_key(self) -> Tuple[SirkType, bytes]:
        '''
        Reads the peer's SIRK and decrypts it if it was sent encrypted.

        Returns:
            A ``(sirk_type, sirk)`` tuple. ``sirk_type`` is the type byte the
            peer reported. ``sirk`` is the SIRK after sef() decryption when
            that type is ENCRYPTED.

        Raises:
            core.InvalidPacketError: if the value is not 1 + 16 bytes long or
                its type byte is not a known SirkType.
            core.InvalidOperationError: if the SIRK is encrypted and no LTK or
                link key is available for the connection.
        '''
        response = await self.set_identity_resolving_key.read_value()
        if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1:
            raise core.InvalidPacketError('Invalid SIRK value')

        # An unknown type byte is malformed peer data. Report it the same way
        # as a bad length instead of leaking a bare ValueError.
        try:
            sirk_type = SirkType(response[0])
        except ValueError as error:
            raise core.InvalidPacketError(
                f'Unknown SIRK type {response[0]:#04x}'
            ) from error

        if sirk_type == SirkType.PLAINTEXT:
            return (sirk_type, response[1:])

        connection = self.service_proxy.client.connection
        # Distinct local name so the imported `device` module is not shadowed.
        local_device = connection.device
        if connection.transport == core.BT_LE_TRANSPORT:
            key = await local_device.get_long_term_key(
                connection_handle=connection.handle, rand=b'', ediv=0
            )
        else:
            key = await local_device.get_link_key(connection.peer_address)

        if not key:
            raise core.InvalidOperationError('LTK or LinkKey is not present')

        # sef() is its own inverse (XOR), so it also performs decryption (sdf).
        return (sirk_type, sef(key, response[1:]))
258