1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import asyncio
19import logging
20import os
21import random
22import pytest
23
24from bumble.core import ProtocolError
25from bumble.l2cap import (
26    L2CAP_Connection_Request,
27    ClassicChannelSpec,
28    LeCreditBasedChannelSpec,
29)
30from .test_utils import TwoDevices
31
32
33# -----------------------------------------------------------------------------
34# Logging
35# -----------------------------------------------------------------------------
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
37
38
39# -----------------------------------------------------------------------------
40
41
42# -----------------------------------------------------------------------------
def test_helpers():
    """Exercise the L2CAP_Connection_Request PSM helpers and a byte round-trip."""
    # serialize_psm: the expected encodings below are low byte first.
    serialize_vectors = (
        (0x01, bytes([0x01, 0x00])),
        (0x1023, bytes([0x23, 0x10])),
        (0x242311, bytes([0x11, 0x23, 0x24])),
    )
    for value, expected in serialize_vectors:
        assert L2CAP_Connection_Request.serialize_psm(value) == expected

    # parse_psm: decode starting at offset 1 (skipping the leading 0x00) and
    # get back the offset just past the PSM plus its value; the trailing 0x44
    # byte must be left unconsumed.
    parse_vectors = (
        (bytes([0x00, 0x01, 0x00, 0x44]), 3, 0x01),
        (bytes([0x00, 0x23, 0x10, 0x44]), 3, 0x1023),
        (bytes([0x00, 0x11, 0x23, 0x24, 0x44]), 4, 0x242311),
    )
    for data, expected_offset, expected_psm in parse_vectors:
        (offset, psm) = L2CAP_Connection_Request.parse_psm(data, 1)
        assert offset == expected_offset
        assert psm == expected_psm

    # A serialized request must parse back to the same PSM and source CID.
    original = L2CAP_Connection_Request(psm=0x01, source_cid=0x44)
    parsed = L2CAP_Connection_Request.from_bytes(bytes(original))
    assert parsed.psm == original.psm
    assert parsed.source_cid == original.source_cid
76
77
78# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_basic_connection():
    """Open an LE credit-based channel, send data, and close it.

    First checks that connecting to a PSM with no registered server raises
    ProtocolError. Then it registers a server on the same PSM, connects, writes
    a few messages, and disconnects. Finally it checks that both ends reported
    'close' and that the server received every byte in order.
    """
    devices = TwoDevices()
    await devices.setup_connection()
    psm = 1234

    # Check that if there's no one listening, we can't connect
    with pytest.raises(ProtocolError):
        await devices.connections[0].create_l2cap_channel(
            spec=LeCreditBasedChannelSpec(psm)
        )

    # Now add a listener
    incoming_channel = None
    received = []

    def on_coc(channel):
        nonlocal incoming_channel
        incoming_channel = channel

        def on_data(data):
            received.append(data)

        channel.sink = on_data

    # Listen on the same `psm` the client connects to (not a repeated literal).
    devices.devices[1].create_l2cap_server(
        spec=LeCreditBasedChannelSpec(psm=psm), handler=on_coc
    )
    l2cap_channel = await devices.connections[0].create_l2cap_channel(
        spec=LeCreditBasedChannelSpec(psm)
    )

    messages = (bytes([1, 2, 3]), bytes([4, 5, 6]), bytes(10000))
    for message in messages:
        l2cap_channel.write(message)
        await asyncio.sleep(0)

    await l2cap_channel.drain()

    # Test closing
    closed = [False, False]
    closed_event = asyncio.Event()

    def on_close(which, event):
        closed[which] = True
        if event:
            event.set()

    l2cap_channel.on('close', lambda: on_close(0, None))
    incoming_channel.on('close', lambda: on_close(1, closed_event))
    await l2cap_channel.disconnect()
    # Wait for the incoming channel's close notification before checking the
    # flags; asserting first would race with that event.
    await closed_event.wait()
    assert closed == [True, True]

    sent_bytes = b''.join(messages)
    received_bytes = b''.join(received)
    assert sent_bytes == received_bytes
136
137
138# -----------------------------------------------------------------------------
async def transfer_payload(max_credits, mtu, mps):
    """Send several messages over an LE CoC channel and check they arrive intact.

    The server is created with the given ``max_credits``, ``mtu`` and ``mps``.
    The client connects to the server's PSM with an otherwise default spec.
    Drains are randomly interleaved between writes.
    """
    devices = TwoDevices()
    await devices.setup_connection()

    chunks = []

    def on_coc(channel):
        channel.sink = chunks.append

    server_spec = LeCreditBasedChannelSpec(
        max_credits=max_credits, mtu=mtu, mps=mps
    )
    server = devices.devices[1].create_l2cap_server(
        spec=server_spec, handler=on_coc
    )
    channel = await devices.connections[0].create_l2cap_channel(
        spec=LeCreditBasedChannelSpec(server.psm)
    )

    pattern = bytes([1, 2, 3, 4, 5, 6, 7])
    messages = [pattern * repeat for repeat in (3, 10, 100, 789)]
    for message in messages:
        channel.write(message)
        await asyncio.sleep(0)
        # Occasionally drain mid-stream to vary how writes and credits interleave.
        if random.randint(0, 5) == 1:
            await channel.drain()

    await channel.drain()
    await channel.disconnect()

    assert b''.join(chunks) == b''.join(messages)
172
173
@pytest.mark.asyncio
async def test_transfer():
    """Run transfer_payload over every combination of credits, MTU and MPS."""
    for max_credits in (1, 10, 100, 10000):
        for mtu in (50, 255, 256, 1000):
            for mps in (50, 255, 256, 1000):
                # Replaces a commented-out print: visible with BUMBLE_LOGLEVEL=DEBUG.
                logger.debug(
                    'transfer: max_credits=%d mtu=%d mps=%d', max_credits, mtu, mps
                )
                await transfer_payload(max_credits, mtu, mps)
181
182
183# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_bidirectional_transfer():
    """Send the same messages client->server and server->client on one channel."""
    devices = TwoDevices()
    await devices.setup_connection()

    from_server = []  # chunks delivered to the client
    from_client = []  # chunks delivered to the server
    server_channel = None

    def on_server_coc(channel):
        nonlocal server_channel
        server_channel = channel
        channel.sink = from_client.append

    server = devices.devices[1].create_l2cap_server(
        spec=LeCreditBasedChannelSpec(), handler=on_server_coc
    )
    client_channel = await devices.connections[0].create_l2cap_channel(
        spec=LeCreditBasedChannelSpec(server.psm)
    )
    client_channel.sink = from_server.append

    messages = [bytes([1, 2, 3, 4, 5, 6, 7]) * count for count in (3, 10, 100)]
    for message in messages:
        # client -> server
        client_channel.write(message)
        await client_channel.drain()
        await asyncio.sleep(0)
        # server -> client
        server_channel.write(message)
        await server_channel.drain()

    await client_channel.disconnect()

    expected = b''.join(messages)
    assert b''.join(from_server) == expected
    assert b''.join(from_client) == expected
228
229
230# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_mtu():
    """Check that each side of a Classic channel learns the other side's MTU.

    The server listens with mtu=345 and the client connects with mtu=456.
    The client should therefore see peer_mtu == 345 and the server's channel
    should see peer_mtu == 456.
    """
    devices = TwoDevices()
    await devices.setup_connection()

    # Capture the server channel's peer MTU when it opens and assert it from
    # the test body. An assert inside the event callback is raised in the
    # emitter rather than in this coroutine, so it may not fail the test. It
    # would also be silently skipped if 'open' never fired.
    server_peer_mtu = asyncio.get_running_loop().create_future()

    def on_channel(channel):
        def on_open():
            if not server_peer_mtu.done():
                server_peer_mtu.set_result(channel.peer_mtu)

        channel.on('open', on_open)

    server = devices.devices[1].create_l2cap_server(
        spec=ClassicChannelSpec(mtu=345), handler=on_channel
    )
    client_channel = await devices.connections[0].create_l2cap_channel(
        spec=ClassicChannelSpec(server.psm, mtu=456)
    )
    assert client_channel.peer_mtu == 345
    # Bounded wait so a missing 'open' event fails instead of hanging.
    assert await asyncio.wait_for(server_peer_mtu, timeout=5.0) == 456
249
250
251# -----------------------------------------------------------------------------
async def run():
    """Run every test in this module sequentially, without pytest."""
    test_helpers()
    for async_test in (
        test_basic_connection,
        test_transfer,
        test_bidirectional_transfer,
        test_mtu,
    ):
        await async_test()
258
259
260# -----------------------------------------------------------------------------
if __name__ == '__main__':
    # Log level comes from BUMBLE_LOGLEVEL (case-insensitive), defaulting to INFO.
    level_name = os.environ.get('BUMBLE_LOGLEVEL', 'INFO')
    logging.basicConfig(level=level_name.upper())
    asyncio.run(run())
264