1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import asyncio
19import pytest
20from typing import List
21
22from . import test_utils
23from bumble import core
24from bumble.rfcomm import (
25    RFCOMM_Frame,
26    Server,
27    Client,
28    DLC,
29    make_service_sdp_records,
30    find_rfcomm_channels,
31    find_rfcomm_channel_with_uuid,
32    RFCOMM_PSM,
33)
34
35_TIMEOUT = 0.1
36
37
38# -----------------------------------------------------------------------------
def basic_frame_check(x):
    """Check that an RFCOMM frame survives a serialize/parse round trip.

    The frame is serialized with ``bytes()``, parsed back with
    ``RFCOMM_Frame.from_bytes`` and serialized again. Both the re-serialized
    bytes and the ``str()`` form of the parsed frame must match the original.
    Intermediate values are printed only for frames shorter than 500 bytes.
    """
    original_bytes = bytes(x)
    # Only dump intermediate values for small frames.
    verbose = len(original_bytes) < 500

    if verbose:
        print('Original:', x)
        print('Serialized:', original_bytes.hex())

    round_tripped = RFCOMM_Frame.from_bytes(original_bytes)
    if verbose:
        print('Parsed:', round_tripped)

    round_tripped_bytes = bytes(round_tripped)
    if verbose:
        print('Parsed Bytes:', round_tripped_bytes.hex())

    assert round_tripped_bytes == original_bytes
    assert str(x) == str(round_tripped)
54
55
56# -----------------------------------------------------------------------------
def test_frames():
    """Parse a known RFCOMM frame encoding and verify it round-trips."""
    encoded = bytes.fromhex('033f011c')
    basic_frame_check(RFCOMM_Frame.from_bytes(encoded))
61
62
63# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_connection_and_disconnection() -> None:
    """Open a DLC between two devices, exchange data both ways, then close it.

    Every wait on activity from the peer is bounded by ``_TIMEOUT``. A lost
    frame or a missing 'close' event therefore fails the test with
    ``asyncio.TimeoutError`` instead of hanging the whole test run.
    """
    devices = test_utils.TwoDevices()
    await devices.setup_connection()

    # devices[0] is the RFCOMM server; its acceptor resolves this future with
    # the responder-side DLC.
    accept_future: asyncio.Future[DLC] = asyncio.get_running_loop().create_future()
    channel = Server(devices[0]).listen(acceptor=accept_future.set_result)

    assert devices.connections[1]
    multiplexer = await Client(devices.connections[1]).start()
    # gather preserves order: dlcs[0] is the responder (server side) and
    # dlcs[1] is the initiator (client side).
    dlcs = await asyncio.gather(accept_future, multiplexer.open_dlc(channel))

    queues: List[asyncio.Queue] = [asyncio.Queue(), asyncio.Queue()]
    for dlc, queue in zip(dlcs, queues):
        dlc.sink = queue.put_nowait

    dlcs[0].write(b'The quick brown fox jumps over the lazy dog')
    assert (
        await asyncio.wait_for(queues[1].get(), timeout=_TIMEOUT)
        == b'The quick brown fox jumps over the lazy dog'
    )

    dlcs[1].write(b'Lorem ipsum dolor sit amet')
    assert (
        await asyncio.wait_for(queues[0].get(), timeout=_TIMEOUT)
        == b'Lorem ipsum dolor sit amet'
    )

    closed = asyncio.Event()
    dlcs[1].on('close', closed.set)
    await dlcs[1].disconnect()
    await asyncio.wait_for(closed.wait(), timeout=_TIMEOUT)
90
91
92# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_receive_pdu_before_open_dlc_returns() -> None:
    """Data written by the responder before ``open_dlc`` returns is not lost.

    The responder writes as soon as its DLC is accepted. The initiator's
    ``open_dlc`` runs as a separate task and is awaited only afterwards. The
    initiator attaches its sink last, yet must still receive the payload.
    """
    devices = await test_utils.TwoDevices.create_with_connection()
    payload = b'123'

    loop = asyncio.get_running_loop()
    accepted: asyncio.Future[DLC] = loop.create_future()
    channel = Server(devices[0]).listen(acceptor=accepted.set_result)

    client_connection = devices.connections[1]
    assert client_connection
    multiplexer = await Client(client_connection).start()
    pending_open = asyncio.create_task(multiplexer.open_dlc(channel))

    responder = await accepted
    responder.write(payload)

    initiator = await pending_open
    received: asyncio.Queue[bytes] = asyncio.Queue()
    initiator.sink = received.put_nowait

    assert await asyncio.wait_for(received.get(), timeout=_TIMEOUT) == payload
113
114
115# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_service_record():
    """An RFCOMM SDP record published on one device is discoverable by its peer."""
    record_handle = 2
    rfcomm_channel = 1
    service_uuid = core.UUID('00000000-0000-0000-0000-000000000001')

    devices = test_utils.TwoDevices()
    await devices.setup_connection()

    # Publish the record on devices[0], then query it from devices[1]'s side.
    devices[0].sdp_service_records[record_handle] = make_service_sdp_records(
        record_handle, rfcomm_channel, service_uuid
    )

    peer = devices.connections[1]
    discovered = await find_rfcomm_channels(peer)
    assert service_uuid in discovered[rfcomm_channel]

    matched_channel = await find_rfcomm_channel_with_uuid(peer, service_uuid)
    assert matched_channel == rfcomm_channel
134
135
136# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_context():
    """Server and Client release their L2CAP resources when their contexts exit."""
    devices = test_utils.TwoDevices()
    await devices.setup_connection()

    rfcomm_server = Server(devices[0])
    with rfcomm_server:
        # Inside the context, the server exposes an L2CAP server.
        assert rfcomm_server.l2cap_server is not None

        rfcomm_client = Client(devices.connections[1])
        async with rfcomm_client:
            assert rfcomm_client.l2cap_channel is not None
        # Leaving the client context drops its L2CAP channel.
        assert rfcomm_client.l2cap_channel is None

    # After the server context exits, the RFCOMM PSM is no longer registered.
    registered_servers = devices[0].l2cap_channel_manager.servers
    assert RFCOMM_PSM not in registered_servers
152
153
154# -----------------------------------------------------------------------------
if __name__ == '__main__':
    # Only the synchronous frame test is run directly. The other tests are
    # coroutines marked for pytest-asyncio and need that runner.
    test_frames()
157