# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import random
import pytest

from bumble.core import ProtocolError
from bumble.l2cap import (
    L2CAP_Connection_Request,
    ClassicChannelSpec,
    LeCreditBasedChannelSpec,
)
from .test_utils import TwoDevices


# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
def test_helpers():
    """Check PSM serialization/parsing and a connection-request round trip."""
    # Serialization: least-significant byte first; a value that fits in one
    # byte is still emitted as two bytes.
    psm = L2CAP_Connection_Request.serialize_psm(0x01)
    assert psm == bytes([0x01, 0x00])

    psm = L2CAP_Connection_Request.serialize_psm(0x1023)
    assert psm == bytes([0x23, 0x10])

    psm = L2CAP_Connection_Request.serialize_psm(0x242311)
    assert psm == bytes([0x11, 0x23, 0x24])

    # Parsing starts at offset 1 (skipping the leading 0x00) and returns the
    # offset just past the PSM; the trailing 0x44 must not be consumed.
    (offset, psm) = L2CAP_Connection_Request.parse_psm(
        bytes([0x00, 0x01, 0x00, 0x44]), 1
    )
    assert offset == 3
    assert psm == 0x01

    (offset, psm) = L2CAP_Connection_Request.parse_psm(
        bytes([0x00, 0x23, 0x10, 0x44]), 1
    )
    assert offset == 3
    assert psm == 0x1023

    (offset, psm) = L2CAP_Connection_Request.parse_psm(
        bytes([0x00, 0x11, 0x23, 0x24, 0x44]), 1
    )
    assert offset == 4
    assert psm == 0x242311

    # A request must survive a bytes round trip with its fields intact.
    rq = L2CAP_Connection_Request(psm=0x01, source_cid=0x44)
    brq = bytes(rq)
    srq = L2CAP_Connection_Request.from_bytes(brq)
    assert srq.psm == rq.psm
    assert srq.source_cid == rq.source_cid


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_basic_connection():
    """Connect an LE credit-based channel, send data, then close it.

    Covers three things: connecting without a listener raises
    ``ProtocolError``; data written by the client reaches the server-side
    sink unchanged; disconnecting fires ``close`` on both channel ends.
    """
    devices = TwoDevices()
    await devices.setup_connection()
    psm = 1234

    # Check that if there's no one listening, we can't connect
    with pytest.raises(ProtocolError):
        await devices.connections[0].create_l2cap_channel(
            spec=LeCreditBasedChannelSpec(psm)
        )

    # Now add a listener
    incoming_channel = None
    received = []

    def on_coc(channel):
        nonlocal incoming_channel
        incoming_channel = channel

        def on_data(data):
            received.append(data)

        channel.sink = on_data

    # Listen on the same PSM the client connects to.
    devices.devices[1].create_l2cap_server(
        spec=LeCreditBasedChannelSpec(psm=psm), handler=on_coc
    )
    l2cap_channel = await devices.connections[0].create_l2cap_channel(
        spec=LeCreditBasedChannelSpec(psm)
    )

    messages = (bytes([1, 2, 3]), bytes([4, 5, 6]), bytes(10000))
    for message in messages:
        l2cap_channel.write(message)
        # Yield to the event loop between writes.
        await asyncio.sleep(0)

    await l2cap_channel.drain()

    # Test closing
    closed = [False, False]
    closed_event = asyncio.Event()

    def on_close(which, event):
        closed[which] = True
        if event:
            event.set()

    # Fail with a clear message if the server handler never ran.
    assert incoming_channel is not None, 'server handler was not invoked'
    l2cap_channel.on('close', lambda: on_close(0, None))
    incoming_channel.on('close', lambda: on_close(1, closed_event))
    await l2cap_channel.disconnect()
    assert closed == [True, True]
    await closed_event.wait()

    sent_bytes = b''.join(messages)
    received_bytes = b''.join(received)
    assert sent_bytes == received_bytes


# -----------------------------------------------------------------------------
async def transfer_payload(max_credits, mtu, mps):
    """Send several messages over an LE credit-based channel and verify them.

    Args:
        max_credits: ``max_credits`` for the server's channel spec.
        mtu: ``mtu`` for the server's channel spec.
        mps: ``mps`` for the server's channel spec.

    Raises:
        AssertionError: if the bytes received by the server differ from the
            concatenation of the messages that were sent.
    """
    devices = TwoDevices()
    await devices.setup_connection()

    received = []

    def on_coc(channel):
        def on_data(data):
            received.append(data)

        channel.sink = on_data

    # No PSM is given for the server; the client reads it back from
    # `server.psm`.
    server = devices.devices[1].create_l2cap_server(
        spec=LeCreditBasedChannelSpec(max_credits=max_credits, mtu=mtu, mps=mps),
        handler=on_coc,
    )
    l2cap_channel = await devices.connections[0].create_l2cap_channel(
        spec=LeCreditBasedChannelSpec(server.psm)
    )

    # Payloads of 21, 70, 700 and 5523 bytes.
    messages = [bytes([1, 2, 3, 4, 5, 6, 7]) * x for x in (3, 10, 100, 789)]
    for message in messages:
        l2cap_channel.write(message)
        await asyncio.sleep(0)
        # Occasionally drain mid-stream so that both the "drain between
        # writes" and "queue several writes" paths get exercised.
        if random.randint(0, 5) == 1:
            await l2cap_channel.drain()

    await l2cap_channel.drain()
    await l2cap_channel.disconnect()

    sent_bytes = b''.join(messages)
    received_bytes = b''.join(received)
    assert sent_bytes == received_bytes


@pytest.mark.asyncio
async def test_transfer():
    """Run `transfer_payload` over a grid of credit, MTU and MPS values."""
    for max_credits in (1, 10, 100, 10000):
        for mtu in (50, 255, 256, 1000):
            for mps in (50, 255, 256, 1000):
                # Record the combination so a failing one can be identified.
                logger.debug(
                    'transfer: max_credits=%s, mtu=%s, mps=%s',
                    max_credits,
                    mtu,
                    mps,
                )
                await transfer_payload(max_credits, mtu, mps)


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_bidirectional_transfer():
    """Exchange the same messages in both directions over one channel."""
    devices = TwoDevices()
    await devices.setup_connection()

    client_received = []
    server_received = []
    server_channel = None

    def on_server_coc(channel):
        nonlocal server_channel
        server_channel = channel

        def on_server_data(data):
            server_received.append(data)

        channel.sink = on_server_data

    def on_client_data(data):
        client_received.append(data)

    server = devices.devices[1].create_l2cap_server(
        spec=LeCreditBasedChannelSpec(), handler=on_server_coc
    )
    client_channel = await devices.connections[0].create_l2cap_channel(
        spec=LeCreditBasedChannelSpec(server.psm)
    )
    client_channel.sink = on_client_data

    # Fail with a clear message if the server handler never ran.
    assert server_channel is not None, 'server handler was not invoked'

    messages = [bytes([1, 2, 3, 4, 5, 6, 7]) * x for x in (3, 10, 100)]
    for message in messages:
        # Client -> server, then echo the same payload server -> client.
        client_channel.write(message)
        await client_channel.drain()
        await asyncio.sleep(0)
        server_channel.write(message)
        await server_channel.drain()

    await client_channel.disconnect()

    message_bytes = b''.join(messages)
    client_received_bytes = b''.join(client_received)
    server_received_bytes = b''.join(server_received)
    assert client_received_bytes == message_bytes
    assert server_received_bytes == message_bytes


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_mtu():
    """Check that each end of a classic channel sees the other end's MTU."""
    devices = TwoDevices()
    await devices.setup_connection()

    def on_channel_open(channel):
        # Server side: the peer is the client, which asked for MTU 456.
        # NOTE(review): this assertion runs inside an event callback; confirm
        # that a failure raised here actually propagates to the test.
        assert channel.peer_mtu == 456

    def on_channel(channel):
        channel.on('open', lambda: on_channel_open(channel))

    server = devices.devices[1].create_l2cap_server(
        spec=ClassicChannelSpec(mtu=345), handler=on_channel
    )
    client_channel = await devices.connections[0].create_l2cap_channel(
        spec=ClassicChannelSpec(server.psm, mtu=456)
    )
    # Client side: the peer is the server, which was set up with MTU 345.
    assert client_channel.peer_mtu == 345


# -----------------------------------------------------------------------------
async def run():
    """Run every test in this module sequentially, without pytest."""
    test_helpers()
    await test_basic_connection()
    await test_transfer()
    await test_bidirectional_transfer()
    await test_mtu()


# -----------------------------------------------------------------------------
if __name__ == '__main__':
    logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
    asyncio.run(run())