1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import asyncio
19import json
20import logging
21import pathlib
22import pytest
23import tempfile
24import os
25
26from bumble.keys import JsonKeyStore, PairingKeys
27
28
29# -----------------------------------------------------------------------------
30# Logging
31# -----------------------------------------------------------------------------
# Logger named after this module.
logger = logging.getLogger(name=__name__)
33
34
35# -----------------------------------------------------------------------------
36# Tests
37# -----------------------------------------------------------------------------
38
# Key file with a single namespace, 'my_namespace', holding one entry that
# has irk, link_key and ltk values.
JSON1 = """
        {
            "my_namespace": {
                "14:7D:DA:4E:53:A8/P": {
                    "address_type": 0,
                    "irk": {
                        "authenticated": false,
                        "value": "e7b2543b206e4e46b44f9e51dad22bd1"
                    },
                    "link_key": {
                        "authenticated": false,
                        "value": "0745dd9691e693d9dca740f7d8dfea75"
                    },
                    "ltk": {
                        "authenticated": false,
                        "value": "d1897ee10016eb1a08e4e037fd54c683"
                    }
                }
            }
        }
        """

# Key file with two empty namespaces and no '__DEFAULT__' namespace.
JSON2 = """
        {
            "my_namespace1": {
            },
            "my_namespace2": {
            }
        }
        """

# Key file with an empty 'my_namespace1' plus an explicit '__DEFAULT__'
# namespace holding one entry that has only an irk.
JSON3 = """
        {
            "my_namespace1": {
            },
            "__DEFAULT__": {
                "14:7D:DA:4E:53:A8/P": {
                    "address_type": 0,
                    "irk": {
                        "authenticated": false,
                        "value": "e7b2543b206e4e46b44f9e51dad22bd1"
                    }
                }
            }
        }
        """
85
86
87# -----------------------------------------------------------------------------
@pytest.fixture
def temporary_file():
    """Provide the path of a closed, empty temporary file for one test.

    The file is created with ``delete=False`` so it still exists after the
    handle is closed and the test can reopen it by name; the path is removed
    once the fixture resumes after the test.
    """
    handle = tempfile.NamedTemporaryFile(delete=False)
    path = handle.name
    handle.close()
    yield path
    pathlib.Path(path).unlink()
94
95
96# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_basic(temporary_file):
    """Round-trip PairingKeys through a JsonKeyStore backed by an empty file.

    Checks that a store over ``{}`` is empty, that ``update`` followed by
    ``get`` reflects the stored keys (first without, then with an LTK), and
    that the LTK is written to disk under the store's namespace.
    """
    pathlib.Path(temporary_file).write_text("{}", encoding='utf-8')

    keystore = JsonKeyStore('my_namespace', temporary_file)

    all_keys = await keystore.get_all()
    assert len(all_keys) == 0

    # Store an entry with no LTK first.
    keys = PairingKeys()
    await keystore.update('foo', keys)
    foo = await keystore.get('foo')
    assert foo is not None
    assert foo.ltk is None

    # Then add an LTK (bytes 0..15) and store the same entry again.
    ltk = bytes(range(16))
    keys.ltk = PairingKeys.Key(ltk)
    await keystore.update('foo', keys)
    foo = await keystore.get('foo')
    assert foo is not None
    assert foo.ltk is not None
    assert foo.ltk.value == ltk

    # Read the backing file by its path to confirm the update was persisted.
    with open(temporary_file, "r", encoding="utf-8") as json_file:
        json_data = json.load(json_file)
    assert 'my_namespace' in json_data
    assert 'foo' in json_data['my_namespace']
    assert 'ltk' in json_data['my_namespace']['foo']
126
127
128# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_parsing(temporary_file):
    """Load a pre-existing key file (JSON1) and read back the stored LTK."""
    pathlib.Path(temporary_file).write_text(JSON1, encoding='utf-8')

    keystore = JsonKeyStore('my_namespace', temporary_file)
    foo = await keystore.get('14:7D:DA:4E:53:A8/P')
    assert foo is not None
    # Assert presence first so a missing LTK fails clearly, not with AttributeError.
    assert foo.ltk is not None
    assert foo.ltk.value == bytes.fromhex('d1897ee10016eb1a08e4e037fd54c683')
139
140
141# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_default_namespace(temporary_file):
    """Exercise a JsonKeyStore created with no namespace (``None``).

    Three file layouts are covered:

    1. JSON1, a single namespace: ``get_all`` returns its one entry.
    2. JSON2, several namespaces and no ``__DEFAULT__``: ``update`` stores
       the new entry under a ``__DEFAULT__`` namespace in the file.
    3. JSON3, a ``__DEFAULT__`` namespace next to an empty one: ``get_all``
       returns the single ``__DEFAULT__`` entry.
    """
    path = pathlib.Path(temporary_file)

    # 1. Single namespace in the file.
    path.write_text(JSON1, encoding='utf-8')
    keystore = JsonKeyStore(None, temporary_file)
    all_keys = await keystore.get_all()
    assert len(all_keys) == 1
    name, keys = all_keys[0]
    assert name == '14:7D:DA:4E:53:A8/P'
    assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1')

    # 2. Several empty namespaces, none of them __DEFAULT__.
    path.write_text(JSON2, encoding='utf-8')
    keystore = JsonKeyStore(None, temporary_file)
    keys = PairingKeys()
    ltk = bytes(range(16))
    keys.ltk = PairingKeys.Key(ltk)
    await keystore.update('foo', keys)
    json_data = json.loads(path.read_text(encoding='utf-8'))
    assert '__DEFAULT__' in json_data
    assert 'foo' in json_data['__DEFAULT__']
    assert 'ltk' in json_data['__DEFAULT__']['foo']

    # 3. An explicit __DEFAULT__ namespace alongside another namespace.
    path.write_text(JSON3, encoding='utf-8')
    keystore = JsonKeyStore(None, temporary_file)
    all_keys = await keystore.get_all()
    assert len(all_keys) == 1
    name, keys = all_keys[0]
    assert name == '14:7D:DA:4E:53:A8/P'
    assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1')
180
181
182# -----------------------------------------------------------------------------
async def run_tests():
    """Run every test in this module without pytest.

    Each test takes a ``temporary_file`` path argument. The pytest fixture
    of that name cannot be invoked outside pytest, so an equivalent closed,
    empty temporary file is created here for each test and removed afterwards.
    """
    for test in (test_basic, test_parsing, test_default_namespace):
        handle = tempfile.NamedTemporaryFile(delete=False)
        handle.close()
        try:
            await test(handle.name)
        finally:
            pathlib.Path(handle.name).unlink()
187
188
189# -----------------------------------------------------------------------------
if __name__ == '__main__':
    # The BUMBLE_LOGLEVEL environment variable overrides the default INFO level.
    log_level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()
    logging.basicConfig(level=log_level)
    asyncio.run(run_tests())
193