1# Copyright 2021-2023 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import asyncio
19import os
20import pytest
21import struct
22import logging
23from unittest import mock
24
25from bumble import device
26from bumble.profiles import csip
27from .test_utils import TwoDevices
28
29# -----------------------------------------------------------------------------
30# Logging
31# -----------------------------------------------------------------------------
# Module-level logger named after this module; none of the tests shown here
# write to it, but it is available for ad-hoc debugging output.
logger = logging.getLogger(__name__)
33
34
35# -----------------------------------------------------------------------------
def test_s1():
    """Check csip.s1 against a fixed vector for the 'SIRKenc' input."""
    # Vectors are written most-significant byte first and are byte-reversed
    # before being handed to / compared with the csip helpers.
    message = bytes(reversed(b'SIRKenc'))
    expected_salt = bytes.fromhex('6901983f 18149e82 3c7d133a 7d774572')[::-1]
    assert csip.s1(message) == expected_salt
41
42
43# -----------------------------------------------------------------------------
def test_k1():
    """Check csip.k1 derives the expected value from a known K / salt / P triple."""
    # The salt is produced by csip.s1 on the same 'SIRKenc' input used in test_s1.
    key = bytes.fromhex('676e1b9b d448696f 061ec622 3ce5ced9')[::-1]
    salt = csip.s1(b'SIRKenc'[::-1])
    p_value = b'csis'[::-1]
    expected = bytes.fromhex('5277453c c094d982 b0e8ee53 2f2d1f8b')[::-1]

    assert csip.k1(key, salt, p_value) == expected
52
53
54# -----------------------------------------------------------------------------
def test_sih():
    """Check csip.sih maps a known SIRK and 3-byte prand to the expected 3-byte hash."""
    sirk = bytes.fromhex('457d7d09 21a1fd22 cecd8c86 dd72cccd')[::-1]
    prand = bytes.fromhex('69f563')[::-1]
    expected_hash = bytes.fromhex('1948da')[::-1]

    assert csip.sih(sirk, prand) == expected_hash
59
60
61# -----------------------------------------------------------------------------
def test_sef():
    """Check csip.sef applied to a known key K and SIRK yields the expected 16 bytes."""
    key = bytes.fromhex('676e1b9b d448696f 061ec622 3ce5ced9')[::-1]
    sirk = bytes.fromhex('457d7d09 21a1fd22 cecd8c86 dd72cccd')[::-1]
    expected = bytes.fromhex('170a3835 e13524a0 7e2562d5 f25fd346')[::-1]

    assert csip.sef(key, sirk) == expected
68
69
70# -----------------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'sirk_type', [csip.SirkType.ENCRYPTED, csip.SirkType.PLAINTEXT]
)
async def test_csis(sirk_type):
    """Read every CSIS characteristic through a client proxy and compare with the server.

    Device 0 hosts a CoordinatedSetIdentificationService; device 1 discovers
    it and reads back the SIRK (with its type), set size, lock and rank.
    """
    sirk = bytes.fromhex('2f62c8ae41867d1bb619e788a2605faa')
    ltk = bytes.fromhex('2f62c8ae41867d1bb619e788a2605faa')

    def u8(value):
        # Single unsigned byte, the encoding the one-byte characteristics are compared against.
        return struct.pack('B', value)

    devices = TwoDevices()
    service = csip.CoordinatedSetIdentificationService(
        set_identity_resolving_key=sirk,
        set_identity_resolving_key_type=sirk_type,
        coordinated_set_size=2,
        set_member_lock=csip.MemberLock.UNLOCKED,
        set_member_rank=0,
    )
    devices[0].add_service(service)

    await devices.setup_connection()

    # Fake an encrypted link: flag both connections as encrypted and make both
    # devices hand back the same LTK (presumably needed for the encrypted-SIRK
    # path — confirm against csip).
    for side in (0, 1):
        devices.connections[side].encryption = 1
        devices[side].get_long_term_key = mock.AsyncMock(return_value=ltk)

    client = await device.Peer(
        devices.connections[1]
    ).discover_service_and_create_proxy(csip.CoordinatedSetIdentificationProxy)

    assert await client.read_set_identity_resolving_key() == (sirk_type, sirk)
    assert await client.coordinated_set_size.read_value() == u8(2)
    assert await client.set_member_lock.read_value() == u8(
        csip.MemberLock.UNLOCKED
    )
    assert await client.set_member_rank.read_value() == u8(0)
109
110
111# -----------------------------------------------------------------------------
async def run():
    """Run this module's tests without pytest, e.g. via ``python -m``.

    ``test_csis`` takes a ``sirk_type`` argument that pytest normally supplies
    through ``parametrize``. Here it is called once per SIRK type explicitly.
    Previously it was called with no argument, which raised ``TypeError``.
    """
    test_s1()
    test_k1()
    test_sih()
    test_sef()
    # Mirror the pytest parametrization of test_csis.
    for sirk_type in (csip.SirkType.ENCRYPTED, csip.SirkType.PLAINTEXT):
        await test_csis(sirk_type)
115
116
117# -----------------------------------------------------------------------------
if __name__ == '__main__':
    # The log level comes from BUMBLE_LOGLEVEL (case-insensitive), defaulting to INFO.
    level_name = os.environ.get('BUMBLE_LOGLEVEL', 'INFO')
    logging.basicConfig(level=level_name.upper())
    asyncio.run(run())
121