# Copyright 2024 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Client for the pw_transfer service, which transmits data over pw_rpc."""

import asyncio
import ctypes
import logging
import threading
from typing import Any, Callable

from pw_rpc.callback_client import BidirectionalStreamingCall
from pw_status import Status

from pw_transfer.transfer import (
    ProgressCallback,
    ProtocolVersion,
    ReadTransfer,
    Transfer,
    WriteTransfer,
)
from pw_transfer.chunk import Chunk
from pw_transfer import transfer_pb2

_LOG = logging.getLogger(__package__)

_TransferDict = dict[int, Transfer]


class _TransferStream:
    """A bidirectional pw_rpc stream shared by all transfers of one type.

    Every message received on the stream is passed to ``chunk_handler`` and
    every RPC error to ``error_handler``. When the stream fails with
    FAILED_PRECONDITION it is re-opened automatically, at most
    ``max_reopen_attempts`` consecutive times without an intervening received
    chunk; after that the error is treated as unrecoverable.
    """

    def __init__(
        self,
        method,
        chunk_handler: Callable[[transfer_pb2.Chunk], Any],
        error_handler: Callable[[Status], Any],
        max_reopen_attempts=3,
    ):
        self._method = method
        self._chunk_handler = chunk_handler
        self._error_handler = error_handler
        self._call: BidirectionalStreamingCall | None = None
        # Automatic re-opens performed since the last chunk was received.
        self._reopen_attempts = 0
        self._max_reopen_attempts = max_reopen_attempts

    def is_open(self) -> bool:
        """Returns True if an RPC call is currently active on the stream."""
        return self._call is not None

    def open(self, force: bool = False) -> None:
        """Invokes the RPC if no call is active, or unconditionally if force."""
        if force or self._call is None:
            self._call = self._method.invoke(
                lambda _, chunk: self._on_chunk_received(chunk),
                on_error=lambda _, status: self._on_stream_error(status),
            )

    def close(self) -> None:
        """Cancels the active call, if any, and marks the stream closed."""
        if self._call is not None:
            self._call.cancel()
            self._call = None

    def send(self, chunk: Chunk) -> None:
        """Sends a chunk on the active call. The stream must be open."""
        assert self._call is not None
        self._call.send(chunk.to_message())

    def _on_chunk_received(self, chunk: transfer_pb2.Chunk) -> None:
        # Receiving data shows the stream is healthy again, so the budget of
        # automatic re-open attempts is replenished.
        self._reopen_attempts = 0
        self._chunk_handler(chunk)

    def _on_stream_error(self, rpc_status: Status) -> None:
        if (
            rpc_status is Status.FAILED_PRECONDITION
            and self._reopen_attempts < self._max_reopen_attempts
        ):
            # FAILED_PRECONDITION indicates that the stream packet was not
            # recognized as the stream is not open. Attempt to re-open the
            # stream automatically, bounded so that a persistently failing
            # stream cannot be re-opened forever.
            self._reopen_attempts += 1
            self.open(force=True)
        else:
            # Other errors (or exhausted re-open attempts) are unrecoverable;
            # clear the stream.
            _LOG.error('Transfer stream shut down with status %s', rpc_status)
            self._call = None

        self._error_handler(rpc_status)


class Manager:  # pylint: disable=too-many-instance-attributes
    """A manager for transmitting data through an RPC TransferService.

    This should be initialized with a pw_rpc TransferService client on an
    active RPC channel. Only one instance of this class should exist for a
    configured RPC TransferService -- the Manager supports multiple
    simultaneous transfers.

    When created, a Manager starts a separate thread in which transfer
    communications and events are handled.
    """

    def __init__(
        self,
        rpc_transfer_service,
        *,
        default_response_timeout_s: float = 2.0,
        initial_response_timeout_s: float = 4.0,
        max_retries: int = 3,
        max_lifetime_retries: int = 1500,
        max_chunk_size_bytes: int = 1024,
        default_protocol_version=ProtocolVersion.VERSION_TWO,
    ):
        """Initializes a Manager on top of a TransferService.

        Args:
          rpc_transfer_service: the pw_rpc transfer service client
          default_response_timeout_s: max time to wait between receiving packets
          initial_response_timeout_s: timeout for the first packet; may be
              longer to account for transfer handler initialization
          max_retries: number of times to retry a single packet after a timeout
          max_lifetime_retries: Cumulative maximum number of times to retry over
              the course of the transfer before giving up.
          max_chunk_size_bytes: In a read transfer, the maximum size of data the
              server should send within a single packet.
          default_protocol_version: Version of the pw_transfer protocol to use.
              Defaults to the latest, but can be set to legacy for projects
              which use legacy devices.
        """
        self._service: Any = rpc_transfer_service
        self._default_response_timeout_s = default_response_timeout_s
        self._initial_response_timeout_s = initial_response_timeout_s
        self.max_retries = max_retries
        self.max_lifetime_retries = max_lifetime_retries
        self._max_chunk_size_bytes = max_chunk_size_bytes
        self._default_protocol_version = default_protocol_version

        # Ongoing transfers in the service by resource ID.
        self._read_transfers: _TransferDict = {}
        self._write_transfers: _TransferDict = {}
        self._next_session_id = ctypes.c_uint32(1)

        self._loop = asyncio.new_event_loop()
        # Set the event loop for the current thread.
        asyncio.set_event_loop(self._loop)

        # Queues are used for communication between the Manager context and the
        # dedicated asyncio transfer thread. They are recreated on the event
        # loop thread, so other threads must never bind to these objects
        # directly; they schedule the _enqueue_* / _signal_quit helpers below,
        # which look the attributes up when they run on the loop thread.
        self._new_transfer_queue: asyncio.Queue = asyncio.Queue()
        self._read_chunk_queue: asyncio.Queue = asyncio.Queue()
        self._write_chunk_queue: asyncio.Queue = asyncio.Queue()
        self._quit_event = asyncio.Event()

        self._thread = threading.Thread(
            target=self._start_event_loop_thread, daemon=True
        )

        # RPC streams for read and write transfers. These are shareable by
        # multiple transfers of the same type.
        self._read_stream = _TransferStream(
            self._service.Read,
            lambda chunk: self._loop.call_soon_threadsafe(
                self._enqueue_read_chunk, chunk
            ),
            self._on_read_error,
        )
        self._write_stream = _TransferStream(
            self._service.Write,
            lambda chunk: self._loop.call_soon_threadsafe(
                self._enqueue_write_chunk, chunk
            ),
            self._on_write_error,
        )

        self._thread.start()

    def __del__(self):
        # Notify the thread that the transfer manager is being destroyed and
        # wait for it to exit.
        # NOTE(review): the thread's target is a bound method of this Manager,
        # so a running thread appears to keep the Manager alive; this is
        # unlikely to run while the thread is alive. Confirm the intended
        # shutdown path.
        # getattr guards against a partially constructed instance.
        thread = getattr(self, '_thread', None)
        if thread is not None and thread.is_alive():
            self._loop.call_soon_threadsafe(self._signal_quit)
            thread.join()

    def read(
        self,
        resource_id: int,
        progress_callback: ProgressCallback | None = None,
        protocol_version: ProtocolVersion | None = None,
        chunk_timeout_s: float | None = None,
        initial_timeout_s: float | None = None,
        initial_offset: int = 0,
    ) -> bytes:
        """Receives ("downloads") data from the server.

        Blocks the calling thread until the transfer finishes.

        Args:
          resource_id: ID of the resource from which to read.
          progress_callback: Optional callback periodically invoked throughout
              the transfer with the transfer state. Can be used to provide user-
              facing status updates such as progress bars.
          protocol_version: The desired protocol version to use for this
              transfer. Defaults to the version the manager was initialized
              (typically VERSION_TWO).
          chunk_timeout_s: Timeout for any individual chunk.
          initial_timeout_s: Timeout for the first chunk, overrides
              chunk_timeout_s.
          initial_offset: Initial offset to start reading from. Must be
            supported by the transfer handler. All transfers support starting
            from 0, the default. Returned bytes will not have any padding
            related to this initial offset. No seeking is done in the transfer
            operation on the client side.

        Returns:
          The data received from the server.

        Raises:
          ValueError: a read of this resource is already in progress, or a
              nonzero initial_offset was requested with the legacy protocol.
          Error: the transfer failed to complete
        """

        if resource_id in self._read_transfers:
            raise ValueError(
                f'Read transfer for resource {resource_id} already exists'
            )

        if protocol_version is None:
            protocol_version = self._default_protocol_version

        if protocol_version == ProtocolVersion.LEGACY and initial_offset != 0:
            raise ValueError(
                f'Unsupported transfer with offset {initial_offset} started '
                + 'with legacy protocol'
            )

        # Legacy transfers are identified by their resource ID.
        session_id = (
            resource_id
            if protocol_version is ProtocolVersion.LEGACY
            else self.assign_session_id()
        )

        if chunk_timeout_s is None:
            chunk_timeout_s = self._default_response_timeout_s

        if initial_timeout_s is None:
            initial_timeout_s = self._initial_response_timeout_s

        transfer = ReadTransfer(
            session_id,
            resource_id,
            self._read_stream.send,
            self._end_read_transfer,
            chunk_timeout_s,
            initial_timeout_s,
            self.max_retries,
            self.max_lifetime_retries,
            protocol_version,
            max_chunk_size=self._max_chunk_size_bytes,
            progress_callback=progress_callback,
            initial_offset=initial_offset,
        )
        self._start_read_transfer(transfer)

        transfer.done.wait()

        if not transfer.status.ok():
            raise Error(transfer.resource_id, transfer.status)

        return transfer.data

    def write(
        self,
        resource_id: int,
        data: bytes | str,
        progress_callback: ProgressCallback | None = None,
        protocol_version: ProtocolVersion | None = None,
        chunk_timeout_s: float | None = None,
        initial_timeout_s: float | None = None,
        initial_offset: int = 0,
    ) -> None:
        """Transmits ("uploads") data to the server.

        Blocks the calling thread until the transfer finishes.

        Args:
          resource_id: ID of the resource to which to write.
          data: Data to send to the server. A str is encoded with str.encode()
              before sending.
          progress_callback: Optional callback periodically invoked throughout
              the transfer with the transfer state. Can be used to provide user-
              facing status updates such as progress bars.
          protocol_version: The desired protocol version to use for this
              transfer. Defaults to the version the manager was initialized
              (typically VERSION_TWO).
          chunk_timeout_s: Timeout for any individual chunk.
          initial_timeout_s: Timeout for the first chunk, overrides
              chunk_timeout_s.
          initial_offset: Initial offset to start writing to. Must be supported
              by the transfer handler. All transfers support starting from 0,
              the default. data arg should start with the data you want to see
              starting at this initial offset on the server. No seeking is done
              in the transfer operation on the client side.

        Raises:
          ValueError: a write to this resource is already in progress, or a
              nonzero initial_offset was requested with a protocol version
              other than VERSION_TWO.
          Error: the transfer failed to complete
        """

        if isinstance(data, str):
            data = data.encode()

        if resource_id in self._write_transfers:
            raise ValueError(
                f'Write transfer for resource {resource_id} already exists'
            )

        if protocol_version is None:
            protocol_version = self._default_protocol_version

        if (
            protocol_version != ProtocolVersion.VERSION_TWO
            and initial_offset != 0
        ):
            raise ValueError(
                f'Unsupported transfer with offset {initial_offset} started '
                + 'with legacy protocol'
            )

        # Legacy transfers are identified by their resource ID.
        session_id = (
            resource_id
            if protocol_version is ProtocolVersion.LEGACY
            else self.assign_session_id()
        )

        if chunk_timeout_s is None:
            chunk_timeout_s = self._default_response_timeout_s

        if initial_timeout_s is None:
            initial_timeout_s = self._initial_response_timeout_s

        transfer = WriteTransfer(
            session_id,
            resource_id,
            data,
            self._write_stream.send,
            self._end_write_transfer,
            chunk_timeout_s,
            initial_timeout_s,
            self.max_retries,
            self.max_lifetime_retries,
            protocol_version,
            progress_callback=progress_callback,
            initial_offset=initial_offset,
        )
        self._start_write_transfer(transfer)

        transfer.done.wait()

        if not transfer.status.ok():
            raise Error(transfer.resource_id, transfer.status)

    def assign_session_id(self) -> int:
        """Returns the next session ID; IDs are 32-bit and never 0."""
        new_id = self._next_session_id.value

        # c_uint32 wraps at 2**32; skip 0 when the counter wraps around.
        self._next_session_id = ctypes.c_uint32(self._next_session_id.value + 1)
        if self._next_session_id.value == 0:
            self._next_session_id = ctypes.c_uint32(1)

        return new_id

    def _start_event_loop_thread(self):
        """Entry point for event loop thread that starts an asyncio context."""
        asyncio.set_event_loop(self._loop)

        # Recreate the async communication channels in the context of the
        # running event loop.
        self._new_transfer_queue = asyncio.Queue()
        self._read_chunk_queue = asyncio.Queue()
        self._write_chunk_queue = asyncio.Queue()
        self._quit_event = asyncio.Event()

        self._loop.create_task(self._transfer_event_loop())
        self._loop.run_forever()

    def _enqueue_new_transfer(self, transfer: Transfer) -> None:
        """Queues a transfer to begin; must run on the event loop thread."""
        self._new_transfer_queue.put_nowait(transfer)

    def _enqueue_read_chunk(self, message: transfer_pb2.Chunk) -> None:
        """Queues a read-stream message; must run on the event loop thread."""
        self._read_chunk_queue.put_nowait(message)

    def _enqueue_write_chunk(self, message: transfer_pb2.Chunk) -> None:
        """Queues a write-stream message; must run on the event loop thread."""
        self._write_chunk_queue.put_nowait(message)

    def _signal_quit(self) -> None:
        """Asks the transfer event loop to exit; runs on the loop thread."""
        self._quit_event.set()

    async def _transfer_event_loop(self):
        """Main async event loop."""
        exit_thread = self._loop.create_task(self._quit_event.wait())
        new_transfer = self._loop.create_task(self._new_transfer_queue.get())
        read_chunk = self._loop.create_task(self._read_chunk_queue.get())
        write_chunk = self._loop.create_task(self._write_chunk_queue.get())

        while not self._quit_event.is_set():
            # Perform a select(2)-like wait for one of several events to occur.
            done, _ = await asyncio.wait(
                (exit_thread, new_transfer, read_chunk, write_chunk),
                return_when=asyncio.FIRST_COMPLETED,
            )

            if exit_thread in done:
                break

            if new_transfer in done:
                await new_transfer.result().begin()
                new_transfer = self._loop.create_task(
                    self._new_transfer_queue.get()
                )

            if read_chunk in done:
                self._loop.create_task(
                    self._handle_chunk(
                        self._read_transfers, read_chunk.result()
                    )
                )
                read_chunk = self._loop.create_task(
                    self._read_chunk_queue.get()
                )

            if write_chunk in done:
                self._loop.create_task(
                    self._handle_chunk(
                        self._write_transfers, write_chunk.result()
                    )
                )
                write_chunk = self._loop.create_task(
                    self._write_chunk_queue.get()
                )

        self._loop.stop()

    @staticmethod
    async def _handle_chunk(
        transfers: _TransferDict, message: transfer_pb2.Chunk
    ) -> None:
        """Processes an incoming chunk from a stream.

        The chunk is dispatched to an active transfer based on its ID. If the
        transfer indicates that it is complete, the provided completion callback
        is invoked.
        """

        chunk = Chunk.from_message(message)

        # Find a transfer for the chunk in the list of active transfers.
        try:
            if chunk.protocol_version is ProtocolVersion.LEGACY:
                transfer = next(
                    t
                    for t in transfers.values()
                    if t.resource_id == chunk.session_id
                )
            else:
                transfer = next(
                    t for t in transfers.values() if t.id == chunk.id()
                )
        except StopIteration:
            _LOG.error(
                'TransferManager received chunk for unknown transfer %d',
                chunk.id(),
            )
            # TODO(frolv): What should be done here, if anything?
            return

        await transfer.handle_chunk(chunk)

    def _on_read_error(self, status: Status) -> None:
        """Callback for an RPC error in the read stream."""

        # skip_callback=True prevents _end_read_transfer from mutating the
        # dict while it is being iterated; it is cleared afterwards instead.
        for transfer in self._read_transfers.values():
            transfer.finish(Status.INTERNAL, skip_callback=True)
        self._read_transfers.clear()

        _LOG.error('Read stream shut down: %s', status)

    def _on_write_error(self, status: Status) -> None:
        """Callback for an RPC error in the write stream."""

        # skip_callback=True prevents _end_write_transfer from mutating the
        # dict while it is being iterated; it is cleared afterwards instead.
        for transfer in self._write_transfers.values():
            transfer.finish(Status.INTERNAL, skip_callback=True)
        self._write_transfers.clear()

        _LOG.error('Write stream shut down: %s', status)

    def _schedule_new_transfer(self, transfer: Transfer) -> None:
        """Hands a transfer to the event loop thread after a fixed delay."""
        # NOTE(review): transfers begin only after a fixed 1 second delay;
        # presumably this gives a freshly opened stream time to settle --
        # confirm. The enqueue helper is scheduled (rather than binding the
        # queue here) so the queue is looked up on the loop thread.
        delay = 1
        self._loop.call_soon_threadsafe(
            self._loop.call_later,
            delay,
            self._enqueue_new_transfer,
            transfer,
        )

    def _start_read_transfer(self, transfer: Transfer) -> None:
        """Begins a new read transfer, opening the stream if it isn't."""

        self._read_transfers[transfer.resource_id] = transfer
        self._read_stream.open()

        _LOG.debug('Starting new read transfer %d', transfer.id)
        self._schedule_new_transfer(transfer)

    def _end_read_transfer(self, transfer: Transfer) -> None:
        """Completes a read transfer."""
        del self._read_transfers[transfer.resource_id]

        if not transfer.status.ok():
            _LOG.error(
                'Read transfer %d terminated with status %s',
                transfer.id,
                transfer.status,
            )

    def _start_write_transfer(self, transfer: Transfer) -> None:
        """Begins a new write transfer, opening the stream if it isn't."""

        self._write_transfers[transfer.resource_id] = transfer
        self._write_stream.open()

        _LOG.debug('Starting new write transfer %d', transfer.id)
        self._schedule_new_transfer(transfer)

    def _end_write_transfer(self, transfer: Transfer) -> None:
        """Completes a write transfer."""
        del self._write_transfers[transfer.resource_id]

        if not transfer.status.ok():
            _LOG.error(
                'Write transfer %d terminated with status %s',
                transfer.id,
                transfer.status,
            )


class Error(Exception):
    """Exception raised when a transfer fails.

    Stores the ID of the failed transfer resource and the error that occurred.
    """

    def __init__(self, resource_id: int, status: Status):
        super().__init__(f'Transfer {resource_id} failed with status {status}')
        self.resource_id = resource_id
        self.status = status