xref: /aosp_15_r20/external/pigweed/pw_transfer/py/pw_transfer/client.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2024 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Client for the pw_transfer service, which transmits data over pw_rpc."""
15
16import asyncio
17import ctypes
18import logging
19import threading
20from typing import Any, Callable
21
22from pw_rpc.callback_client import BidirectionalStreamingCall
23from pw_status import Status
24
25from pw_transfer.transfer import (
26    ProgressCallback,
27    ProtocolVersion,
28    ReadTransfer,
29    Transfer,
30    WriteTransfer,
31)
32from pw_transfer.chunk import Chunk
33from pw_transfer import transfer_pb2
34
# Module logger; records are emitted under this package's name.
_LOG: logging.Logger = logging.getLogger(__package__)

# In-flight transfers, keyed by the transfer's resource ID.
_TransferDict = dict[int, Transfer]
38
39
class _TransferStream:
    """Wraps one pw_rpc bidirectional stream shared by transfers of one type.

    Messages received on the stream are forwarded to ``chunk_handler``, and
    stream errors are forwarded to ``error_handler`` after the stream state
    has been updated. A stream that fails with FAILED_PRECONDITION is
    re-opened automatically, at most ``max_reopen_attempts`` consecutive
    times; the count is reset whenever a chunk is received, the stream is
    closed, or the stream is given up on.
    """

    def __init__(
        self,
        method,
        chunk_handler: Callable[[Chunk], Any],
        error_handler: Callable[[Status], Any],
        max_reopen_attempts=3,
    ):
        """Creates an unopened stream.

        Args:
          method: pw_rpc streaming method stub; ``method.invoke`` is called to
              open the stream.
          chunk_handler: called with each message received on the stream.
          error_handler: called with the RPC status whenever the stream
              reports an error.
          max_reopen_attempts: maximum number of consecutive automatic
              re-opens after FAILED_PRECONDITION before the stream is treated
              as shut down.
        """
        self._method = method
        self._chunk_handler = chunk_handler
        self._error_handler = error_handler
        self._call: BidirectionalStreamingCall | None = None
        self._reopen_attempts = 0
        self._max_reopen_attempts = max_reopen_attempts

    def is_open(self) -> bool:
        """Returns True if an RPC call is currently associated with the stream."""
        return self._call is not None

    def open(self, force: bool = False) -> None:
        """Invokes the RPC if no call exists, or unconditionally if ``force``."""
        if force or self._call is None:
            self._call = self._method.invoke(
                lambda _, chunk: self._on_chunk_received(chunk),
                on_error=lambda _, status: self._on_stream_error(status),
            )

    def close(self) -> None:
        """Cancels the active call, if any, and resets the reopen budget."""
        if self._call is not None:
            self._call.cancel()
            self._call = None
        self._reopen_attempts = 0

    def send(self, chunk: Chunk) -> None:
        """Sends ``chunk`` on the open stream; the stream must be open."""
        assert self._call is not None
        self._call.send(chunk.to_message())

    def _on_chunk_received(self, chunk: Chunk) -> None:
        # A received message proves the stream works, so the budget of
        # consecutive automatic re-opens starts over.
        self._reopen_attempts = 0
        self._chunk_handler(chunk)

    def _on_stream_error(self, rpc_status: Status) -> None:
        if (
            rpc_status is Status.FAILED_PRECONDITION
            and self._reopen_attempts < self._max_reopen_attempts
        ):
            # FAILED_PRECONDITION indicates that the stream packet was not
            # recognized as the stream is not open. Attempt to re-open the
            # stream automatically, but bound the number of consecutive
            # attempts so a server that keeps rejecting the stream is not
            # re-invoked indefinitely.
            self._reopen_attempts += 1
            self.open(force=True)
        else:
            # Other errors (or an exhausted reopen budget) are unrecoverable;
            # clear the stream so the next open() starts a fresh call.
            _LOG.error('Transfer stream shut down with status %s', rpc_status)
            self._call = None
            self._reopen_attempts = 0

        self._error_handler(rpc_status)
90
91
class Manager:  # pylint: disable=too-many-instance-attributes
    """A manager for transmitting data through an RPC TransferService.

    This should be initialized with an active TransferService client on an RPC
    channel. Only one instance of this class should exist for a configured RPC
    TransferService -- the Manager supports multiple simultaneous transfers.

    When created, a Manager starts a separate thread in which transfer
    communications and events are handled.
    """

    # Seconds to wait between registering a new transfer and handing it to
    # the event loop to begin. Kept from the original implementation; the
    # reason for the delay is not documented here, so confirm with the owners
    # before changing it.
    _TRANSFER_START_DELAY_S = 1

    def __init__(
        self,
        rpc_transfer_service,
        *,
        default_response_timeout_s: float = 2.0,
        initial_response_timeout_s: float = 4.0,
        max_retries: int = 3,
        max_lifetime_retries: int = 1500,
        max_chunk_size_bytes: int = 1024,
        default_protocol_version=ProtocolVersion.VERSION_TWO,
    ):
        """Initializes a Manager on top of a TransferService.

        Args:
          rpc_transfer_service: the pw_rpc transfer service client
          default_response_timeout_s: max time to wait between receiving
              packets
          initial_response_timeout_s: timeout for the first packet; may be
              longer to account for transfer handler initialization
          max_retries: number of times to retry a single packet after a
              timeout
          max_lifetime_retries: Cumulative maximum number of times to retry
              over the course of the transfer before giving up.
          max_chunk_size_bytes: In a read transfer, the maximum size of data
              the server should send within a single packet.
          default_protocol_version: Version of the pw_transfer protocol to
              use. Defaults to the latest, but can be set to legacy for
              projects which use legacy devices.
        """
        self._service: Any = rpc_transfer_service
        self._default_response_timeout_s = default_response_timeout_s
        self._initial_response_timeout_s = initial_response_timeout_s
        self.max_retries = max_retries
        self.max_lifetime_retries = max_lifetime_retries
        self._max_chunk_size_bytes = max_chunk_size_bytes
        self._default_protocol_version = default_protocol_version

        # Ongoing transfers in the service by resource ID.
        self._read_transfers: _TransferDict = {}
        self._write_transfers: _TransferDict = {}
        self._next_session_id = ctypes.c_uint32(1)

        # Strong references to in-flight chunk-handling tasks; see
        # _spawn_chunk_handler. Only touched on the event loop thread.
        self._chunk_tasks: set[asyncio.Task] = set()

        self._loop = asyncio.new_event_loop()
        # Set the event loop for the current thread.
        asyncio.set_event_loop(self._loop)

        # Queues are used for communication between the Manager context and the
        # dedicated asyncio transfer thread. NOTE: _start_event_loop_thread
        # replaces these objects on the loop thread, so code running on other
        # threads must not bind them directly (see the callbacks below).
        self._new_transfer_queue: asyncio.Queue = asyncio.Queue()
        self._read_chunk_queue: asyncio.Queue = asyncio.Queue()
        self._write_chunk_queue: asyncio.Queue = asyncio.Queue()
        self._quit_event = asyncio.Event()

        self._thread = threading.Thread(
            target=self._start_event_loop_thread, daemon=True
        )

        # RPC streams for read and write transfers. These are shareable by
        # multiple transfers of the same type. Chunks are delivered by pw_rpc
        # callbacks outside the event loop thread; the queue is looked up
        # inside the loop-thread callback so that it is always the queue the
        # event loop is actually reading from.
        self._read_stream = _TransferStream(
            self._service.Read,
            lambda chunk: self._loop.call_soon_threadsafe(
                lambda: self._read_chunk_queue.put_nowait(chunk)
            ),
            self._on_read_error,
        )
        self._write_stream = _TransferStream(
            self._service.Write,
            lambda chunk: self._loop.call_soon_threadsafe(
                lambda: self._write_chunk_queue.put_nowait(chunk)
            ),
            self._on_write_error,
        )

        self._thread.start()

    def __del__(self):
        # Notify the thread that the transfer manager is being destroyed and
        # wait for it to exit. getattr() guards against __init__ having raised
        # before the thread attribute was assigned.
        thread = getattr(self, '_thread', None)
        if thread is not None and thread.is_alive():
            self._loop.call_soon_threadsafe(self._set_quit_event)
            thread.join()

    def read(
        self,
        resource_id: int,
        progress_callback: ProgressCallback | None = None,
        protocol_version: ProtocolVersion | None = None,
        chunk_timeout_s: float | None = None,
        initial_timeout_s: float | None = None,
        initial_offset: int = 0,
    ) -> bytes:
        """Receives ("downloads") data from the server.

        Args:
          resource_id: ID of the resource from which to read.
          progress_callback: Optional callback periodically invoked throughout
              the transfer with the transfer state. Can be used to provide user-
              facing status updates such as progress bars.
          protocol_version: The desired protocol version to use for this
              transfer. Defaults to the version the manager was initialized
              with (typically VERSION_TWO).
          chunk_timeout_s: Timeout for any individual chunk.
          initial_timeout_s: Timeout for the first chunk, overrides
              chunk_timeout_s.
          initial_offset: Initial offset to start reading from. Must be
              supported by the transfer handler. All transfers support starting
              from 0, the default. Returned bytes will not have any padding
              related to this initial offset. No seeking is done in the transfer
              operation on the client side.

        Raises:
          ValueError: a read of this resource is already in progress, or a
              nonzero offset was requested with the legacy protocol.
          Error: the transfer failed to complete
        """

        if resource_id in self._read_transfers:
            raise ValueError(
                f'Read transfer for resource {resource_id} already exists'
            )

        if protocol_version is None:
            protocol_version = self._default_protocol_version

        if protocol_version == ProtocolVersion.LEGACY and initial_offset != 0:
            raise ValueError(
                f'Unsupported transfer with offset {initial_offset} started '
                + 'with legacy protocol'
            )

        # Legacy transfers are identified by resource ID; newer protocol
        # versions get a unique session ID.
        session_id = (
            resource_id
            if protocol_version is ProtocolVersion.LEGACY
            else self.assign_session_id()
        )

        if chunk_timeout_s is None:
            chunk_timeout_s = self._default_response_timeout_s

        if initial_timeout_s is None:
            initial_timeout_s = self._initial_response_timeout_s

        transfer = ReadTransfer(
            session_id,
            resource_id,
            self._read_stream.send,
            self._end_read_transfer,
            chunk_timeout_s,
            initial_timeout_s,
            self.max_retries,
            self.max_lifetime_retries,
            protocol_version,
            max_chunk_size=self._max_chunk_size_bytes,
            progress_callback=progress_callback,
            initial_offset=initial_offset,
        )
        self._start_read_transfer(transfer)

        # Blocks the calling thread until the event loop finishes the transfer.
        transfer.done.wait()

        if not transfer.status.ok():
            raise Error(transfer.resource_id, transfer.status)

        return transfer.data

    def write(
        self,
        resource_id: int,
        data: bytes | str,
        progress_callback: ProgressCallback | None = None,
        protocol_version: ProtocolVersion | None = None,
        chunk_timeout_s: float | None = None,
        initial_timeout_s: float | None = None,
        initial_offset: int = 0,
    ) -> None:
        """Transmits ("uploads") data to the server.

        Args:
          resource_id: ID of the resource to which to write.
          data: Data to send to the server. A str is encoded with str.encode()
              (UTF-8).
          progress_callback: Optional callback periodically invoked throughout
              the transfer with the transfer state. Can be used to provide user-
              facing status updates such as progress bars.
          protocol_version: The desired protocol version to use for this
              transfer. Defaults to the version the manager was initialized
              with (typically VERSION_TWO).
          chunk_timeout_s: Timeout for any individual chunk.
          initial_timeout_s: Timeout for the first chunk, overrides
              chunk_timeout_s.
          initial_offset: Initial offset to start writing to. Must be supported
              by the transfer handler. All transfers support starting from 0,
              the default. data arg should start with the data you want to see
              starting at this initial offset on the server. No seeking is done
              in the transfer operation on the client side.

        Raises:
          ValueError: a write to this resource is already in progress, or a
              nonzero offset was requested with a protocol other than
              VERSION_TWO.
          Error: the transfer failed to complete
        """

        if isinstance(data, str):
            data = data.encode()

        if resource_id in self._write_transfers:
            raise ValueError(
                f'Write transfer for resource {resource_id} already exists'
            )

        if protocol_version is None:
            protocol_version = self._default_protocol_version

        # NOTE(review): stricter than read(), which only rejects LEGACY here.
        if (
            protocol_version != ProtocolVersion.VERSION_TWO
            and initial_offset != 0
        ):
            raise ValueError(
                f'Unsupported transfer with offset {initial_offset} started '
                + 'with legacy protocol'
            )

        session_id = (
            resource_id
            if protocol_version is ProtocolVersion.LEGACY
            else self.assign_session_id()
        )

        if chunk_timeout_s is None:
            chunk_timeout_s = self._default_response_timeout_s

        if initial_timeout_s is None:
            initial_timeout_s = self._initial_response_timeout_s

        transfer = WriteTransfer(
            session_id,
            resource_id,
            data,
            self._write_stream.send,
            self._end_write_transfer,
            chunk_timeout_s,
            initial_timeout_s,
            self.max_retries,
            self.max_lifetime_retries,
            protocol_version,
            progress_callback=progress_callback,
            initial_offset=initial_offset,
        )
        self._start_write_transfer(transfer)

        # Blocks the calling thread until the event loop finishes the transfer.
        transfer.done.wait()

        if not transfer.status.ok():
            raise Error(transfer.resource_id, transfer.status)

    def assign_session_id(self) -> int:
        """Returns the next session ID, wrapping within uint32 and skipping 0."""
        new_id = self._next_session_id.value

        self._next_session_id = ctypes.c_uint32(self._next_session_id.value + 1)
        if self._next_session_id.value == 0:
            self._next_session_id = ctypes.c_uint32(1)

        return new_id

    def _start_event_loop_thread(self):
        """Entry point for event loop thread that starts an asyncio context."""
        asyncio.set_event_loop(self._loop)

        # Recreate the async communication channels in the context of the
        # running event loop. Callbacks scheduled from other threads only run
        # once run_forever() starts below, so they always see these objects.
        self._new_transfer_queue = asyncio.Queue()
        self._read_chunk_queue = asyncio.Queue()
        self._write_chunk_queue = asyncio.Queue()
        self._quit_event = asyncio.Event()

        self._loop.create_task(self._transfer_event_loop())
        self._loop.run_forever()

    def _set_quit_event(self) -> None:
        """Sets the quit event; runs on the event loop thread.

        Resolving ``self._quit_event`` here (rather than in the caller's
        thread) guarantees it is the event created by _start_event_loop_thread.
        """
        self._quit_event.set()

    async def _transfer_event_loop(self):
        """Main async event loop."""
        exit_thread = self._loop.create_task(self._quit_event.wait())
        new_transfer = self._loop.create_task(self._new_transfer_queue.get())
        read_chunk = self._loop.create_task(self._read_chunk_queue.get())
        write_chunk = self._loop.create_task(self._write_chunk_queue.get())

        while not self._quit_event.is_set():
            # Perform a select(2)-like wait for one of several events to occur.
            done, _ = await asyncio.wait(
                (exit_thread, new_transfer, read_chunk, write_chunk),
                return_when=asyncio.FIRST_COMPLETED,
            )

            if exit_thread in done:
                break

            if new_transfer in done:
                await self._begin_transfer(new_transfer.result())
                new_transfer = self._loop.create_task(
                    self._new_transfer_queue.get()
                )

            if read_chunk in done:
                self._spawn_chunk_handler(
                    self._read_transfers, read_chunk.result()
                )
                read_chunk = self._loop.create_task(
                    self._read_chunk_queue.get()
                )

            if write_chunk in done:
                self._spawn_chunk_handler(
                    self._write_transfers, write_chunk.result()
                )
                write_chunk = self._loop.create_task(
                    self._write_chunk_queue.get()
                )

        # Cancel the waiters that are still pending and let the cancellations
        # complete, so they are not left pending when the loop stops.
        pending = [
            task
            for task in (exit_thread, new_transfer, read_chunk, write_chunk)
            if not task.done()
        ]
        for task in pending:
            task.cancel()
        await asyncio.gather(*pending, return_exceptions=True)

        self._loop.stop()

    async def _begin_transfer(self, transfer: Transfer) -> None:
        """Starts a queued transfer without letting a failure end the loop."""
        if transfer.done.is_set():
            # The transfer was already terminated (for example by a stream
            # error) while it waited out the start delay; nothing to begin.
            return

        try:
            await transfer.begin()
        except Exception:  # pylint: disable=broad-except
            # An exception escaping here would end _transfer_event_loop and
            # leave every later transfer waiting forever, so fail only this
            # transfer and keep the loop running.
            _LOG.exception('Failed to start transfer %d', transfer.id)
            transfer.finish(Status.INTERNAL)

    def _spawn_chunk_handler(
        self, transfers: _TransferDict, message: transfer_pb2.Chunk
    ) -> None:
        """Handles ``message`` in its own task, keeping the task referenced."""
        task = self._loop.create_task(self._handle_chunk(transfers, message))
        # The event loop keeps only weak references to tasks; hold a strong
        # one until the task completes so it cannot be garbage collected
        # mid-execution.
        self._chunk_tasks.add(task)
        task.add_done_callback(self._chunk_tasks.discard)

    @staticmethod
    async def _handle_chunk(
        transfers: _TransferDict, message: transfer_pb2.Chunk
    ) -> None:
        """Processes an incoming chunk from a stream.

        The chunk is dispatched to an active transfer based on its ID (the
        resource ID for legacy chunks, the session ID otherwise). Chunks for
        unknown transfers are logged and dropped; any completion handling is
        left to the transfer itself.
        """

        chunk = Chunk.from_message(message)

        # Find a transfer for the chunk in the list of active transfers.
        try:
            if chunk.protocol_version is ProtocolVersion.LEGACY:
                transfer = next(
                    t
                    for t in transfers.values()
                    if t.resource_id == chunk.session_id
                )
            else:
                transfer = next(
                    t for t in transfers.values() if t.id == chunk.id()
                )
        except StopIteration:
            _LOG.error(
                'TransferManager received chunk for unknown transfer %d',
                chunk.id(),
            )
            # TODO(frolv): What should be done here, if anything?
            return

        await transfer.handle_chunk(chunk)

    @staticmethod
    def _abort_transfers(transfers: _TransferDict) -> None:
        """Fails every transfer in ``transfers`` with INTERNAL and drops them."""
        # Iterate over a snapshot: this runs from a pw_rpc callback, outside
        # the event loop thread, which may remove finished transfers from the
        # same dict concurrently.
        for transfer in list(transfers.values()):
            transfer.finish(Status.INTERNAL, skip_callback=True)
        transfers.clear()

    def _on_read_error(self, status: Status) -> None:
        """Callback for an RPC error in the read stream."""
        self._abort_transfers(self._read_transfers)
        _LOG.error('Read stream shut down: %s', status)

    def _on_write_error(self, status: Status) -> None:
        """Callback for an RPC error in the write stream."""
        self._abort_transfers(self._write_transfers)
        _LOG.error('Write stream shut down: %s', status)

    def _schedule_transfer_start(self, transfer: Transfer) -> None:
        """Queues ``transfer`` for the event loop after the start delay.

        Runs on the event loop thread, so the queue resolved here is the one
        created by _start_event_loop_thread.
        """
        self._loop.call_later(
            self._TRANSFER_START_DELAY_S,
            self._new_transfer_queue.put_nowait,
            transfer,
        )

    def _start_read_transfer(self, transfer: Transfer) -> None:
        """Begins a new read transfer, opening the stream if it isn't."""

        self._read_transfers[transfer.resource_id] = transfer
        self._read_stream.open()

        _LOG.debug('Starting new read transfer %d', transfer.id)
        self._loop.call_soon_threadsafe(self._schedule_transfer_start, transfer)

    def _end_read_transfer(self, transfer: Transfer) -> None:
        """Completes a read transfer."""
        # The transfer may already have been dropped by _on_read_error (and
        # possibly replaced by a newer transfer for the same resource), so
        # only remove it if it is still the registered one.
        if self._read_transfers.get(transfer.resource_id) is transfer:
            del self._read_transfers[transfer.resource_id]

        if not transfer.status.ok():
            _LOG.error(
                'Read transfer %d terminated with status %s',
                transfer.id,
                transfer.status,
            )

    def _start_write_transfer(self, transfer: Transfer) -> None:
        """Begins a new write transfer, opening the stream if it isn't."""

        self._write_transfers[transfer.resource_id] = transfer
        self._write_stream.open()

        _LOG.debug('Starting new write transfer %d', transfer.id)
        self._loop.call_soon_threadsafe(self._schedule_transfer_start, transfer)

    def _end_write_transfer(self, transfer: Transfer) -> None:
        """Completes a write transfer."""
        # See _end_read_transfer: tolerate removal by _on_write_error.
        if self._write_transfers.get(transfer.resource_id) is transfer:
            del self._write_transfers[transfer.resource_id]

        if not transfer.status.ok():
            _LOG.error(
                'Write transfer %d terminated with status %s',
                transfer.id,
                transfer.status,
            )
524
525
class Error(Exception):
    """Raised when a pw_transfer read or write does not complete successfully.

    Attributes:
      resource_id: ID of the resource the failed transfer was targeting.
      status: the status the transfer finished with.
    """

    def __init__(self, resource_id: int, status: Status):
        message = f'Transfer {resource_id} failed with status {status}'
        super().__init__(message)
        self.resource_id = resource_id
        self.status = status
536