1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18
19import pytest
20from unittest import mock
21
22from bumble import smp
23from bumble import pairing
24from bumble.crypto import EccKey, aes_cmac, ah, c1, f4, f5, f6, g2, h6, h7, s1
25from bumble.pairing import OobData, OobSharedData, LeRole
26from bumble.hci import Address
27from bumble.core import AdvertisingData
28from bumble.device import Device
29
30from typing import Optional
31
32# -----------------------------------------------------------------------------
33# pylint: disable=invalid-name
34# -----------------------------------------------------------------------------
35
36
37# -----------------------------------------------------------------------------
def reversed_hex(hex_str: str) -> bytes:
    """Decode ``hex_str`` and return the resulting bytes in reverse order.

    Whitespace between hex digit pairs is accepted, because ``bytes.fromhex``
    skips it. Most tests below pass their vectors through this helper. Presumably
    that is because the functions under test take byte strings in the opposite
    order to how the vectors are printed.
    """
    decoded = bytes.fromhex(hex_str)
    return bytes(reversed(decoded))
40
41
42# -----------------------------------------------------------------------------
def test_ecc():
    """Check key generation sizes and Diffie-Hellman against spec vectors."""
    # A freshly generated key exposes 32-byte X and Y coordinates.
    fresh = EccKey.generate()
    for coordinate in (fresh.x, fresh.y):
        assert len(coordinate) == 32

    # DH test vectors from the spec, per party: (private key, public X, public Y).
    vectors = {
        'A': (
            '3f49f6d4 a3c55f38 74c9b3e3 d2103f50 4aff607b eb40b799 5899b8a6 cd3c1abd',
            '20b003d2 f297be2c 5e2c83a7 e9f9a5b9 eff49111 acf4fddb cc030148 0e359de6',
            'dc809c49 652aeb6d 63329abf 5a52155c 766345c2 8fed3024 741c8ed0 1589d28b',
        ),
        'B': (
            '55188b3d 32f6bb9a 900afcfb eed4e72a 59cb9ac2 f19d7cfb 6b4fdd49 f47fc5fd',
            '1ea1f0f0 1faf1d96 09592284 f19e4c00 47b58afd 8615a69f 559077b2 2faaa190',
            '4c55f33e 429dad37 7356703a 9ab85160 472d1130 e28e3676 5f89aff9 15b1214a',
        ),
    }
    expected = bytes.fromhex(
        'ec0234a3 57c8ad05 341010a6 0a397d9b 99796b13 b4f866f1 868d34f3 73bfa698'
    )

    # Each side combines its own private key with the peer's public key;
    # both directions must yield the same shared secret.
    for own, peer in (('A', 'B'), ('B', 'A')):
        private_hex, public_x_hex, public_y_hex = vectors[own]
        _, peer_x_hex, peer_y_hex = vectors[peer]
        own_key = EccKey.from_private_key_bytes(
            bytes.fromhex(private_hex),
            bytes.fromhex(public_x_hex),
            bytes.fromhex(public_y_hex),
        )
        shared = own_key.dh(bytes.fromhex(peer_x_hex), bytes.fromhex(peer_y_hex))
        assert shared == expected
83
84
85# -----------------------------------------------------------------------------
def test_c1():
    """Known-answer test for the c1 confirm function."""
    expected = reversed_hex('1e1e3fef878988ead2a74dc5bef13b86')
    confirm = c1(
        bytes(16),  # k: all-zero key
        reversed_hex('5783D52156AD6F0E6388274EC6702EE0'),  # r
        reversed_hex('07071000000101'),  # preq
        reversed_hex('05000800000302'),  # pres
        1,  # iat
        0,  # rat
        reversed_hex('A1A2A3A4A5A6'),  # ia
        reversed_hex('B1B2B3B4B5B6'),  # ra
    )
    assert confirm == expected
97
98
99# -----------------------------------------------------------------------------
def test_s1():
    """Known-answer test for s1 with an all-zero key."""
    zero_key = bytes(16)
    rand_1, rand_2 = (
        reversed_hex(digits)
        for digits in (
            '000F0E0D0C0B0A091122334455667788',
            '010203040506070899AABBCCDDEEFF00',
        )
    )
    expected = reversed_hex('9a1fe1f0e8b0f49b5b4216ae796da062')
    assert s1(zero_key, rand_1, rand_2) == expected
106
107
108# -----------------------------------------------------------------------------
def test_aes_cmac():
    """AES-CMAC known-answer tests for empty, 16-, 40- and 64-byte messages."""
    key = bytes.fromhex('2b7e1516 28aed2a6 abf71588 09cf4f3c')

    # The longer messages share a common prefix.
    block_1 = '6bc1bee2 2e409f96 e93d7e11 7393172a'
    block_2 = 'ae2d8a57 1e03ac9c 9eb76fac 45af8e51'

    # (message hex, expected tag hex)
    cases = [
        ('', 'bb1d6929 e9593728 7fa37d12 9b756746'),
        (block_1, '070a16b4 6b4d4144 f79bdd9d d04a287c'),
        (
            block_1 + block_2 + '30c81c46 a35ce411',
            'dfa66747 de9ae630 30ca3261 1497c827',
        ),
        (
            block_1
            + block_2
            + '30c81c46 a35ce411 e5fbc119 1a0a52ef'
            + 'f69f2445 df4f9b17 ad2b417b e66c3710',
            '51f0bebf 7e3b9d92 fc497417 79363cfe',
        ),
    ]
    for message_hex, tag_hex in cases:
        assert aes_cmac(bytes.fromhex(message_hex), key) == bytes.fromhex(tag_hex)
135
136
137# -----------------------------------------------------------------------------
def test_f4():
    """Known-answer test for f4 with a single zero byte as ``z``."""
    u_hex = '20b003d2 f297be2c 5e2c83a7 e9f9a5b9 eff49111 acf4fddb cc030148 0e359de6'
    v_hex = '55188b3d 32f6bb9a 900afcfb eed4e72a 59cb9ac2 f19d7cfb 6b4fdd49 f47fc5fd'
    u, v = reversed_hex(u_hex), reversed_hex(v_hex)
    x = reversed_hex('d5cb8454 d177733e ffffb2ec 712baeab')

    expected = reversed_hex('f2c916f1 07a9bd1c f1eda1be a974872d')
    assert f4(u, v, x, bytes([0])) == expected
149
150
151# -----------------------------------------------------------------------------
def test_f5():
    """Known-answer test for f5; the first two items of its result are checked."""
    # Same value as the DH key in test_ecc, passed through reversed_hex.
    dhkey = reversed_hex(
        'ec0234a3 57c8ad05 341010a6 0a397d9b 99796b13 b4f866f1 868d34f3 73bfa698'
    )
    nonce_1 = reversed_hex('d5cb8454 d177733e ffffb2ec 712baeab')
    nonce_2 = reversed_hex('a6e8e7cc 25a75f6e 216583f7 ff3dc4cf')
    addr_1 = reversed_hex('00561237 37bfce')
    addr_2 = reversed_hex('00a71370 2dcfc1')

    result = f5(dhkey, nonce_1, nonce_2, addr_1, addr_2)

    expected = (
        reversed_hex('2965f176 a1084a02 fd3f6a20 ce636e20'),
        reversed_hex('69867911 69d7cd23 980522b5 94750a38'),
    )
    for index, want in enumerate(expected):
        assert result[index] == want
163
164
165# -----------------------------------------------------------------------------
def test_f6():
    """Known-answer test for f6; arguments are listed in call order."""
    args = [
        reversed_hex(digits)
        for digits in (
            '2965f176 a1084a02 fd3f6a20 ce636e20',  # mac_key
            'd5cb8454 d177733e ffffb2ec 712baeab',  # n1
            'a6e8e7cc 25a75f6e 216583f7 ff3dc4cf',  # n2
            '12a3343b b453bb54 08da42d2 0c2d0fc8',  # r
            '010102',  # io_cap
            '00561237 37bfce',  # a1
            '00a71370 2dcfc1',  # a2
        )
    ]
    assert f6(*args) == reversed_hex('e3c47398 9cd0e8c5 d26c0b09 da958f61')
176
177
178# -----------------------------------------------------------------------------
def test_g2():
    """Known-answer test for g2, whose result is compared with an integer."""
    u, v, x, y = map(
        reversed_hex,
        (
            '20b003d2 f297be2c 5e2c83a7 e9f9a5b9 eff49111 acf4fddb cc030148 0e359de6',
            '55188b3d 32f6bb9a 900afcfb eed4e72a 59cb9ac2 f19d7cfb 6b4fdd49 f47fc5fd',
            'd5cb8454 d177733e ffffb2ec 712baeab',
            'a6e8e7cc 25a75f6e 216583f7 ff3dc4cf',
        ),
    )
    expected = 0x2F9ED5BA
    assert g2(u, v, x, y) == expected
190
191
192# -----------------------------------------------------------------------------
def test_h6():
    """Known-answer test for h6."""
    key = reversed_hex('ec0234a3 57c8ad05 341010a6 0a397d9b')
    key_id = bytes.fromhex('6c656272')  # ASCII "lebr", not byte-reversed
    expected = reversed_hex('2d9ae102 e76dc91c e8d3a9e2 80b16399')
    derived = h6(key, key_id)
    assert derived == expected
197
198
199# -----------------------------------------------------------------------------
def test_h7():
    """Known-answer test for h7 (note: salt is the first argument)."""
    # 12 zero bytes followed by ASCII "tmp1"; not byte-reversed.
    salt = bytes.fromhex('00000000 00000000 00000000 746D7031')
    key = reversed_hex('ec0234a3 57c8ad05 341010a6 0a397d9b')
    expected = reversed_hex('fb173597 c6a3c0ec d2998c2a 75a57011')
    derived = h7(salt, key)
    assert derived == expected
204
205
206# -----------------------------------------------------------------------------
def test_ah():
    """Known-answer test for ah with a 3-byte ``prand``."""
    assert ah(
        reversed_hex('ec0234a3 57c8ad05 341010a6 0a397d9b'),
        reversed_hex('708194'),
    ) == reversed_hex('0dfbaa')
213
214
215# -----------------------------------------------------------------------------
def test_oob_data():
    """OobData must round-trip through advertising-data bytes unchanged.

    Path exercised: ``to_ad()`` -> ``bytes`` -> ``AdvertisingData.from_bytes``
    -> ``OobData.from_ad``.
    """
    original = OobData(
        address=Address("F0:F1:F2:F3:F4:F5"),
        role=LeRole.BOTH_PERIPHERAL_PREFERRED,
        shared_data=OobSharedData(c=b'12', r=b'34'),
    )
    serialized = bytes(original.to_ad())
    decoded = OobData.from_ad(AdvertisingData.from_bytes(serialized))

    assert decoded.address == original.address
    assert decoded.role == original.role
    for field_name in ('c', 'r'):
        assert getattr(decoded.shared_data, field_name) == getattr(
            original.shared_data, field_name
        )
230
231
232# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
    'ct2, expected',
    [
        (False, 'bc1ca4ef 633fc1bd 0d8230af ee388fb0'),
        (True, '287ad379 dca40253 0a39f1f4 3047b835'),
    ],
)
def test_ltk_to_link_key(ct2: bool, expected: str):
    """``smp.Session.derive_link_key`` maps a fixed LTK to the expected link key.

    The expected value depends on the ``ct2`` flag, so both settings are covered.
    """
    ltk = reversed_hex('368df9bc e3264b58 bd066c33 334fbf64')
    link_key = smp.Session.derive_link_key(ltk, ct2)
    assert link_key == reversed_hex(expected)
243
244
245# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
    'ct2, expected',
    [
        (False, 'a813fb72 f1a3dfa1 8a2c9a43 f10d0a30'),
        (True, 'e85e09eb 5eccb3e2 69418a13 3211bc79'),
    ],
)
def test_link_key_to_ltk(ct2: bool, expected: str):
    """``smp.Session.derive_ltk`` maps a fixed link key to the expected LTK.

    The expected value depends on the ``ct2`` flag, so both settings are covered.
    """
    link_key = reversed_hex('05040302 01000908 07060504 03020100')
    ltk = smp.Session.derive_ltk(link_key, ct2)
    assert ltk == reversed_hex(expected)
256
257
258# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
    'identity_address_type, public_address, random_address, expected_identity_address',
    [
        (
            None,
            Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
            Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
            Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
        ),
        (
            None,
            Address.ANY,
            Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
            Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
        ),
        (
            pairing.PairingConfig.AddressType.PUBLIC,
            Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
            Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
            Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
        ),
        (
            pairing.PairingConfig.AddressType.RANDOM,
            Address("00:11:22:33:44:55", Address.PUBLIC_DEVICE_ADDRESS),
            Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
            Address("EE:EE:EE:EE:EE:EE", Address.RANDOM_DEVICE_ADDRESS),
        ),
    ],
)
@pytest.mark.asyncio
async def test_send_identity_address_command(
    identity_address_type: Optional[pairing.PairingConfig.AddressType],
    public_address: Address,
    random_address: Address,
    expected_identity_address: Address,
):
    """Check which address ``send_identity_address_command`` distributes.

    Cases in the parametrize table:
    - No explicit ``identity_address_type``: the public address is expected,
      unless it is ``Address.ANY``; then the static (random) address is expected.
    - An explicit PUBLIC or RANDOM type selects that address directly.
    """
    device = Device()
    device.public_address = public_address
    device.static_address = random_address
    pairing_config = pairing.PairingConfig(identity_address_type=identity_address_type)
    # NOTE(review): the MagicMock stands in for the session's connection, and the
    # trailing True is presumably the initiator flag. Confirm against
    # smp.Session.__init__.
    session = smp.Session(device.smp_manager, mock.MagicMock(), pairing_config, True)

    with mock.patch.object(session, 'send_command') as mock_method:
        session.send_identity_address_command()

    # Fail with a clear assertion if no command (or several) was sent, instead of
    # an AttributeError from reading `.args` on a None `call_args`.
    mock_method.assert_called_once()
    actual_command = mock_method.call_args.args[0]
    assert actual_command.addr_type == expected_identity_address.address_type
    assert actual_command.bd_addr == expected_identity_address
307
308
309# -----------------------------------------------------------------------------
if __name__ == '__main__':
    # Run the synchronous tests directly, without pytest.
    test_ecc()
    test_c1()
    test_s1()
    test_aes_cmac()
    test_f4()
    test_f5()
    test_f6()
    test_g2()
    test_h6()
    test_h7()
    test_ah()
    test_oob_data()

    # Parametrized tests are called once per case. These cases mirror their
    # @pytest.mark.parametrize tables and must be kept in sync with them.
    for ct2, expected in (
        (False, 'bc1ca4ef 633fc1bd 0d8230af ee388fb0'),
        (True, '287ad379 dca40253 0a39f1f4 3047b835'),
    ):
        test_ltk_to_link_key(ct2, expected)
    for ct2, expected in (
        (False, 'a813fb72 f1a3dfa1 8a2c9a43 f10d0a30'),
        (True, 'e85e09eb 5eccb3e2 69418a13 3211bc79'),
    ):
        test_link_key_to_ltk(ct2, expected)

    # test_send_identity_address_command is a coroutine marked for
    # pytest-asyncio; run this file under pytest to exercise it.
323