# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import datetime
import enum
import functools
from importlib import resources
import json
import os
import logging
import pathlib
from typing import Optional, List, cast
import weakref
import struct

import ctypes
import wasmtime
import wasmtime.loader
import liblc3  # type: ignore

import click
import aiohttp.web

import bumble
from bumble.core import AdvertisingData
from bumble.colors import color
from bumble.device import Device, DeviceConfiguration, AdvertisingParameters
from bumble.transport import open_transport
from bumble.profiles import ascs, bap, pacs
from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket

# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)

# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
DEFAULT_UI_PORT = 7654


def _sink_pac_record() -> pacs.PacRecord:
    """Build the LC3 PAC record published for the sink direction.

    Audio received on sink ASEs is decoded with `decode` and forwarded to the UI.
    """
    return pacs.PacRecord(
        coding_format=CodingFormat(CodecID.LC3),
        codec_specific_capabilities=bap.CodecSpecificCapabilities(
            supported_sampling_frequencies=(
                bap.SupportedSamplingFrequency.FREQ_8000
                | bap.SupportedSamplingFrequency.FREQ_16000
                | bap.SupportedSamplingFrequency.FREQ_24000
                | bap.SupportedSamplingFrequency.FREQ_32000
                | bap.SupportedSamplingFrequency.FREQ_48000
            ),
            supported_frame_durations=(
                bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED
            ),
            supported_audio_channel_count=[1, 2],
            min_octets_per_codec_frame=26,
            max_octets_per_codec_frame=240,
            supported_max_codec_frames_per_sdu=2,
        ),
    )


def _source_pac_record() -> pacs.PacRecord:
    """Build the LC3 PAC record published for the source direction.

    Audio sent on source ASEs is read from a WAV file by `lc3_source_task`.
    """
    return pacs.PacRecord(
        coding_format=CodingFormat(CodecID.LC3),
        codec_specific_capabilities=bap.CodecSpecificCapabilities(
            supported_sampling_frequencies=(
                bap.SupportedSamplingFrequency.FREQ_8000
                | bap.SupportedSamplingFrequency.FREQ_16000
                | bap.SupportedSamplingFrequency.FREQ_24000
                | bap.SupportedSamplingFrequency.FREQ_32000
                | bap.SupportedSamplingFrequency.FREQ_48000
            ),
            supported_frame_durations=(
                bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED
            ),
            supported_audio_channel_count=[1],
            min_octets_per_codec_frame=30,
            max_octets_per_codec_frame=100,
            supported_max_codec_frames_per_sdu=1,
        ),
    )


def _channel_count(codec_config: bap.CodecSpecificConfiguration) -> int:
    """Return the number of audio channels a codec configuration describes.

    NOTE(review): an Audio_Channel_Allocation with no location bits set
    presumably reports channel_count == 0. That would leave no codec instance
    set up and make `decode` divide by zero, so it is treated as mono.
    """
    return max(1, codec_config.audio_channel_allocation.channel_count)


# -----------------------------------------------------------------------------
# WASM - liblc3
# -----------------------------------------------------------------------------
store = wasmtime.loader.store
_memory = cast(wasmtime.Memory, liblc3.memory)
# Codec state and I/O scratch buffers live in the region starting at the old end
# of the wasm linear memory, which is grown by one page right below.
STACK_POINTER = _memory.data_len(store)
_memory.grow(store, 1)
# Mapping wasmtime memory to linear address
memory = (ctypes.c_ubyte * _memory.data_len(store)).from_address(
    ctypes.addressof(_memory.data_ptr(store).contents)  # type: ignore
)


class Liblc3PcmFormat(enum.IntEnum):
    """PCM sample formats accepted by liblc3 encode/decode calls."""

    S16 = 0
    S24 = 1
    S24_3LE = 2
    FLOAT = 3


MAX_DECODER_SIZE = liblc3.lc3_decoder_size(10000, 48000)
MAX_ENCODER_SIZE = liblc3.lc3_encoder_size(10000, 48000)
# Number of encoder/decoder state slots reserved in wasm memory below.
MAX_CODEC_CHANNELS = 2

DECODER_STACK_POINTER = STACK_POINTER
ENCODER_STACK_POINTER = DECODER_STACK_POINTER + MAX_DECODER_SIZE * MAX_CODEC_CHANNELS
DECODE_BUFFER_STACK_POINTER = ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * MAX_CODEC_CHANNELS
ENCODE_BUFFER_STACK_POINTER = DECODE_BUFFER_STACK_POINTER + 8192
DEFAULT_PCM_SAMPLE_RATE = 48000
DEFAULT_PCM_FORMAT = Liblc3PcmFormat.S16
DEFAULT_PCM_BYTES_PER_SAMPLE = 2


# Handles returned by liblc3 setup calls, indexed by channel.
encoders: List[int] = []
decoders: List[int] = []


def _check_channel_count(num_channels: int) -> None:
    """Reject channel counts that would overflow the reserved codec state slots."""
    if num_channels > MAX_CODEC_CHANNELS:
        raise ValueError(
            f"{num_channels} channels requested, at most {MAX_CODEC_CHANNELS} supported"
        )


def setup_encoders(
    sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
    """(Re)initialize one LC3 encoder per channel in wasm memory.

    Input PCM is expected at DEFAULT_PCM_SAMPLE_RATE; `sample_rate_hz` is the
    LC3 codec rate.

    Raises:
        ValueError: if more than MAX_CODEC_CHANNELS channels are requested.
    """
    _check_channel_count(num_channels)
    logger.info(
        f"setup_encoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
    )
    encoders[:num_channels] = [
        liblc3.lc3_setup_encoder(
            frame_duration_us,
            sample_rate_hz,
            DEFAULT_PCM_SAMPLE_RATE,  # Input sample rate
            ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * i,
        )
        for i in range(num_channels)
    ]


def setup_decoders(
    sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
    """(Re)initialize one LC3 decoder per channel in wasm memory.

    Output PCM is produced at DEFAULT_PCM_SAMPLE_RATE; `sample_rate_hz` is the
    LC3 codec rate.

    Raises:
        ValueError: if more than MAX_CODEC_CHANNELS channels are requested.
    """
    _check_channel_count(num_channels)
    logger.info(
        f"setup_decoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
    )
    decoders[:num_channels] = [
        liblc3.lc3_setup_decoder(
            frame_duration_us,
            sample_rate_hz,
            DEFAULT_PCM_SAMPLE_RATE,  # Output sample rate
            DECODER_STACK_POINTER + MAX_DECODER_SIZE * i,
        )
        for i in range(num_channels)
    ]


def decode(
    frame_duration_us: int,
    num_channels: int,
    input_bytes: bytes,
    frames_per_sdu: int = 1,
) -> bytes:
    """Decode one LC3 SDU into interleaved S16 PCM at DEFAULT_PCM_SAMPLE_RATE.

    Args:
        frame_duration_us: LC3 frame duration in microseconds.
        num_channels: channels in the SDU; decoders[0:num_channels] must be set up.
        input_bytes: SDU payload, laid out as `frames_per_sdu` consecutive blocks,
            each holding one equally sized codec frame per channel.
        frames_per_sdu: number of codec frame blocks in the SDU (default 1).

    Returns:
        Interleaved PCM for all blocks in time order, or b'' for empty input.

    Raises:
        ValueError: if num_channels or frames_per_sdu is less than 1.
    """
    if not input_bytes:
        return b''
    if num_channels < 1 or frames_per_sdu < 1:
        raise ValueError(
            f"invalid layout: {num_channels} channels, {frames_per_sdu} frames/SDU"
        )

    input_buffer_offset = DECODE_BUFFER_STACK_POINTER
    input_buffer_size = len(input_bytes)
    input_bytes_per_frame = input_buffer_size // (num_channels * frames_per_sdu)

    # Copy into wasm
    memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes  # type: ignore

    # Decoded PCM is written right after the input in the same scratch area.
    output_buffer_offset = input_buffer_offset + input_buffer_size
    output_block_size = (
        liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
        * DEFAULT_PCM_BYTES_PER_SAMPLE
        * num_channels
    )
    output_buffer_size = output_block_size * frames_per_sdu

    for block in range(frames_per_sdu):
        for channel in range(num_channels):
            frame_index = block * num_channels + channel
            res = liblc3.lc3_decode(
                decoders[channel],
                input_buffer_offset + input_bytes_per_frame * frame_index,
                input_bytes_per_frame,
                DEFAULT_PCM_FORMAT,
                # Interleave: channel i starts i samples into its block.
                output_buffer_offset
                + block * output_block_size
                + channel * DEFAULT_PCM_BYTES_PER_SAMPLE,
                num_channels,  # Stride
            )

            if res != 0:
                logger.error(f"Decoding failed, res={res}")

    # Extract decoded data from the output buffer
    return bytes(
        memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
    )


def encode(
    sdu_length: int,
    num_channels: int,
    stride: int,
    input_bytes: bytes,
) -> bytes:
    """Encode interleaved S16 PCM into one LC3 SDU of `sdu_length` bytes.

    Args:
        sdu_length: total SDU size; split evenly between `num_channels` frames.
        num_channels: channels to encode; encoders[0:num_channels] must be set up.
        stride: sample stride in the interleaved input (the PCM channel count).
        input_bytes: interleaved PCM for one frame duration.

    Returns:
        The encoded SDU, or b'' for empty input.
    """
    if not input_bytes:
        return b''

    input_buffer_offset = ENCODE_BUFFER_STACK_POINTER
    input_buffer_size = len(input_bytes)

    # Copy into wasm
    memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes  # type: ignore

    output_buffer_offset = input_buffer_offset + input_buffer_size
    output_buffer_size = sdu_length
    output_frame_size = output_buffer_size // num_channels

    for i in range(num_channels):
        res = liblc3.lc3_encode(
            encoders[i],
            DEFAULT_PCM_FORMAT,
            input_buffer_offset + DEFAULT_PCM_BYTES_PER_SAMPLE * i,
            stride,
            output_frame_size,
            output_buffer_offset + output_frame_size * i,
        )

        if res != 0:
            logger.error(f"Encoding failed, res={res}")

    # Extract encoded data from the output buffer
    return bytes(
        memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
    )


async def lc3_source_task(
    filename: str,
    sdu_length: int,
    frame_duration_us: int,
    device: Device,
    cis_handle: int,
    num_channels: Optional[int] = None,
) -> None:
    """Stream a WAV file, LC3-encoded, as ISO SDUs on a CIS until EOF.

    The file must be a 44-byte-header WAV with DEFAULT_PCM_SAMPLE_RATE samples of
    DEFAULT_PCM_BYTES_PER_SAMPLE bytes. One SDU is sent per frame duration, paced
    on the event loop's monotonic clock.

    Args:
        filename: path of the WAV file.
        sdu_length: size of each SDU in bytes.
        frame_duration_us: LC3 frame duration in microseconds.
        device: device whose host sends the HCI ISO packets.
        cis_handle: connection handle of the CIS.
        num_channels: channels to encode; defaults to the WAV channel count, and
            must not exceed it (the first `num_channels` WAV channels are used).

    Raises:
        ValueError: if the WAV header or format is unsupported.
    """
    wav_header_size = 44
    with open(filename, 'rb') as f:
        header = f.read(wav_header_size)
        if len(header) < wav_header_size or header[8:12] != b'WAVE':
            raise ValueError(f"{filename} is not a WAV file")

        pcm_num_channel, pcm_sample_rate, _byte_rate, _block_align, bits_per_sample = (
            struct.unpack("<HIIHH", header[22:36])
        )
        if pcm_sample_rate != DEFAULT_PCM_SAMPLE_RATE:
            raise ValueError(
                f"unsupported sample rate {pcm_sample_rate}, "
                f"expected {DEFAULT_PCM_SAMPLE_RATE}"
            )
        if bits_per_sample != DEFAULT_PCM_BYTES_PER_SAMPLE * 8:
            raise ValueError(f"unsupported sample size {bits_per_sample} bits")

        codec_channels = pcm_num_channel if num_channels is None else num_channels
        if not 1 <= codec_channels <= pcm_num_channel:
            raise ValueError(
                f"cannot encode {codec_channels} channels from a "
                f"{pcm_num_channel}-channel file"
            )

        # One frame of interleaved PCM covers every channel in the file.
        frame_bytes = (
            liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
            * DEFAULT_PCM_BYTES_PER_SAMPLE
            * pcm_num_channel
        )
        packet_sequence_number = 0
        loop = asyncio.get_running_loop()
        frame_period = frame_duration_us / 1_000_000
        # Accumulate deadlines from a fixed origin so sleep latency does not drift.
        next_round = loop.time()

        while True:
            next_round += frame_period
            pcm_data = f.read(frame_bytes)
            if not pcm_data:
                logger.info(f'end of {filename}, stopping source stream')
                return
            if len(pcm_data) < frame_bytes:
                # Pad a truncated last frame with silence so the SDU is full size.
                pcm_data += bytes(frame_bytes - len(pcm_data))
            sdu = encode(sdu_length, codec_channels, pcm_num_channel, pcm_data)

            iso_packet = HCI_IsoDataPacket(
                connection_handle=cis_handle,
                data_total_length=sdu_length + 4,
                packet_sequence_number=packet_sequence_number,
                pb_flag=0b10,
                packet_status_flag=0,
                iso_sdu_length=sdu_length,
                iso_sdu_fragment=sdu,
            )
            device.host.send_hci_packet(iso_packet)
            # The ISO packet sequence number field is 16 bits wide.
            packet_sequence_number = (packet_sequence_number + 1) & 0xFFFF
            await asyncio.sleep(max(0.0, next_round - loop.time()))


# -----------------------------------------------------------------------------
class UiServer:
    """HTTP + websocket server backing the browser UI of a `Speaker`."""

    speaker: weakref.ReferenceType[Speaker]
    port: int

    def __init__(self, speaker: Speaker, port: int) -> None:
        self.speaker = weakref.ref(speaker)
        self.port = port
        # The currently connected websocket, if any.
        self.channel_socket = None

    async def start_http(self) -> None:
        """Start the UI HTTP server."""

        app = aiohttp.web.Application()
        app.add_routes(
            [
                aiohttp.web.get('/', self.get_static),
                aiohttp.web.get('/index.html', self.get_static),
                aiohttp.web.get('/channel', self.get_channel),
            ]
        )

        runner = aiohttp.web.AppRunner(app)
        await runner.setup()
        site = aiohttp.web.TCPSite(runner, 'localhost', self.port)
        print('UI HTTP server at ' + color(f'http://127.0.0.1:{self.port}', 'green'))
        await site.start()

    async def get_static(self, request):
        """Serve a static UI file bundled in the bumble.apps.lea_unicast package."""
        path = request.path
        if path == '/':
            path = '/index.html'
        if path.endswith('.html'):
            content_type = 'text/html'
        elif path.endswith('.js'):
            content_type = 'text/javascript'
        elif path.endswith('.css'):
            content_type = 'text/css'
        elif path.endswith('.svg'):
            content_type = 'image/svg+xml'
        else:
            content_type = 'text/plain'
        text = (
            resources.files("bumble.apps.lea_unicast")
            .joinpath(pathlib.Path(path).relative_to('/'))
            .read_text(encoding="utf-8")
        )
        return aiohttp.web.Response(text=text, content_type=content_type)

    async def get_channel(self, request):
        """Handle the UI websocket, dispatching text messages until it closes."""
        ws = aiohttp.web.WebSocketResponse()
        await ws.prepare(request)

        # Process messages until the socket is closed.
        self.channel_socket = ws
        async for message in ws:
            if message.type == aiohttp.WSMsgType.TEXT:
                logger.debug(f'<<< received message: {message.data}')
                await self.on_message(message.data)
            elif message.type == aiohttp.WSMsgType.ERROR:
                logger.debug(
                    f'channel connection closed with exception {ws.exception()}'
                )

        # Don't drop a newer client that replaced this socket in the meantime.
        if self.channel_socket is ws:
            self.channel_socket = None
        logger.debug('--- channel connection closed')

        return ws

    async def on_message(self, message_str: str):
        """Parse a JSON message and dispatch it to `on_<type>_message`."""
        # Parse the message as JSON
        message = json.loads(message_str)

        # Dispatch the message
        message_type = message['type']
        message_params = message.get('params', {})
        handler = getattr(self, f'on_{message_type}_message', None)
        if handler is None:
            logger.warning(f'unsupported message type: {message_type}')
            return
        await handler(**message_params)

    async def on_hello_message(self):
        """Reply to the UI's hello with version, codec, stream and connection info."""
        speaker = self.speaker()
        if speaker is None:
            return
        # Speaker does not currently define codec/stream_state/connection, so read
        # them defensively instead of letting an AttributeError close the channel.
        codec = getattr(speaker, 'codec', None)
        stream_state = getattr(speaker, 'stream_state', None)
        await self.send_message(
            'hello',
            bumble_version=bumble.__version__,
            codec=codec,
            streamState=stream_state.name if stream_state is not None else None,
        )
        if connection := getattr(speaker, 'connection', None):
            await self.send_message(
                'connection',
                peer_address=connection.peer_address.to_string(False),
                peer_name=connection.peer_name,
            )

    async def send_message(self, message_type: str, **kwargs) -> None:
        """Send a JSON `{'type', 'params'}` message if a UI client is connected."""
        if self.channel_socket is None:
            return

        message = {'type': message_type, 'params': kwargs}
        await self.channel_socket.send_json(message)

    async def send_audio(self, data: bytes) -> None:
        """Send raw PCM bytes to the UI client; failures are logged, not raised."""
        if self.channel_socket is None:
            return

        try:
            await self.channel_socket.send_bytes(data)
        except Exception as error:
            logger.warning(f'exception while sending audio packet: {error}')


# -----------------------------------------------------------------------------
class Speaker:
    """LE Audio unicast server: plays received audio in the UI, sources a WAV file."""

    def __init__(
        self,
        device_config_path: Optional[str],
        ui_port: int,
        transport: str,
        lc3_input_file_path: str,
    ):
        self.device_config_path = device_config_path
        self.transport = transport
        self.lc3_input_file_path = lc3_input_file_path

        # Create an HTTP server for the UI
        self.ui_server = UiServer(speaker=self, port=ui_port)

    async def run(self) -> None:
        """Start the UI, bring up the device and services, serve until the transport ends."""
        await self.ui_server.start_http()

        async with await open_transport(self.transport) as hci_transport:
            # Create a device
            if self.device_config_path:
                device_config = DeviceConfiguration.from_file(self.device_config_path)
            else:
                device_config = DeviceConfiguration(
                    name="Bumble LE Headphone",
                    class_of_device=0x244418,
                    keystore="JsonKeyStore",
                    advertising_interval_min=25,
                    advertising_interval_max=25,
                    address=Address('F1:F2:F3:F4:F5:F6'),
                )

            device_config.le_enabled = True
            device_config.cis_enabled = True
            self.device = Device.from_config_with_hci(
                device_config, hci_transport.source, hci_transport.sink
            )

            self.device.add_service(
                pacs.PublishedAudioCapabilitiesService(
                    supported_source_context=bap.ContextType(0xFFFF),
                    available_source_context=bap.ContextType(0xFFFF),
                    supported_sink_context=bap.ContextType(0xFFFF),  # All context types
                    available_sink_context=bap.ContextType(0xFFFF),  # All context types
                    sink_audio_locations=(
                        bap.AudioLocation.FRONT_LEFT | bap.AudioLocation.FRONT_RIGHT
                    ),
                    sink_pac=[_sink_pac_record()],
                    source_audio_locations=bap.AudioLocation.FRONT_LEFT,
                    source_pac=[_source_pac_record()],
                )
            )

            ascs_service = ascs.AudioStreamControlService(
                self.device, sink_ase_id=[1], source_ase_id=[2]
            )
            self.device.add_service(ascs_service)

            advertising_data = bytes(
                AdvertisingData(
                    [
                        (
                            AdvertisingData.COMPLETE_LOCAL_NAME,
                            bytes(device_config.name, 'utf-8'),
                        ),
                        (
                            AdvertisingData.FLAGS,
                            bytes([AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG]),
                        ),
                        (
                            AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
                            bytes(pacs.PublishedAudioCapabilitiesService.UUID),
                        ),
                    ]
                )
            ) + bytes(bap.UnicastServerAdvertisingData())

            def on_pdu(pdu: HCI_IsoDataPacket, ase: ascs.AseStateMachine):
                # Decode a received SDU and forward the PCM to the UI.
                codec_config = ase.codec_specific_configuration
                assert isinstance(codec_config, bap.CodecSpecificConfiguration)
                pcm = decode(
                    codec_config.frame_duration.us,
                    _channel_count(codec_config),
                    pdu.iso_sdu_fragment,
                    max(1, codec_config.codec_frames_per_sdu),
                )
                self.device.abort_on('disconnection', self.ui_server.send_audio(pcm))

            def on_ase_state_change(ase: ascs.AseStateMachine) -> None:
                # Set up codecs when configured; start/route audio when streaming.
                if ase.state == ascs.AseStateMachine.State.STREAMING:
                    codec_config = ase.codec_specific_configuration
                    assert isinstance(codec_config, bap.CodecSpecificConfiguration)
                    assert ase.cis_link
                    if ase.role == ascs.AudioRole.SOURCE:
                        num_channels = _channel_count(codec_config)
                        # NOTE: encode() emits one frame per channel per SDU; the
                        # source PAC limits codec_frames_per_sdu to 1 accordingly.
                        ase.cis_link.abort_on(
                            'disconnection',
                            lc3_source_task(
                                filename=self.lc3_input_file_path,
                                sdu_length=(
                                    codec_config.codec_frames_per_sdu
                                    * codec_config.octets_per_codec_frame
                                    * num_channels
                                ),
                                frame_duration_us=codec_config.frame_duration.us,
                                device=self.device,
                                cis_handle=ase.cis_link.handle,
                                num_channels=num_channels,
                            ),
                        )
                    else:
                        ase.cis_link.sink = functools.partial(on_pdu, ase=ase)
                elif ase.state == ascs.AseStateMachine.State.CODEC_CONFIGURED:
                    codec_config = ase.codec_specific_configuration
                    assert isinstance(codec_config, bap.CodecSpecificConfiguration)
                    if ase.role == ascs.AudioRole.SOURCE:
                        setup_encoders(
                            codec_config.sampling_frequency.hz,
                            codec_config.frame_duration.us,
                            _channel_count(codec_config),
                        )
                    else:
                        setup_decoders(
                            codec_config.sampling_frequency.hz,
                            codec_config.frame_duration.us,
                            _channel_count(codec_config),
                        )

            for ase in ascs_service.ase_state_machines.values():
                ase.on('state_change', functools.partial(on_ase_state_change, ase=ase))

            await self.device.power_on()
            await self.device.create_advertising_set(
                advertising_data=advertising_data,
                auto_restart=True,
                advertising_parameters=AdvertisingParameters(
                    primary_advertising_interval_min=100,
                    primary_advertising_interval_max=100,
                ),
            )

            await hci_transport.source.terminated


@click.command()
@click.option(
    '--ui-port',
    'ui_port',
    metavar='HTTP_PORT',
    default=DEFAULT_UI_PORT,
    show_default=True,
    help='HTTP port for the UI server',
)
@click.option('--device-config', metavar='FILENAME', help='Device configuration file')
@click.argument('transport')
@click.argument('lc3_file')
def speaker(ui_port: int, device_config: str, transport: str, lc3_file: str) -> None:
    """Run the speaker."""

    asyncio.run(Speaker(device_config, ui_port, transport, lc3_file).run())


# -----------------------------------------------------------------------------
def main():
    logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
    speaker()


# -----------------------------------------------------------------------------
if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter