xref: /aosp_15_r20/external/tink/python/tink/streaming_aead/_rewindable_input_stream.py (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""A Raw Input stream wrapper that supports rewinding."""
15
16import io
17from typing import Optional, BinaryIO
18
19
class RewindableInputStream(io.RawIOBase):
  """Implements a readable io.RawIOBase wrapper that supports rewinding.

  The wrapped input_stream can either be a io.RawIOBase or io.BufferedIOBase.

  While rewinding is enabled, every byte read from input_stream is kept in an
  internal buffer so that rewind() can replay the stream from its beginning.
  After disable_rewind() is called, the remaining buffered bytes are still
  returned by subsequent reads, after which the buffer is released and reads
  go straight to input_stream.
  """

  def __init__(self, input_stream: BinaryIO):
    super().__init__()
    if not input_stream.readable():
      raise ValueError('input_stream must be readable')
    self._input_stream = input_stream
    # Bytes read from input_stream so far (retained while rewinding is on).
    self._buffer = bytearray()
    # Index in self._buffer of the next byte to hand out.
    self._pos = 0
    self._rewindable = True

  def read(self, size: Optional[int] = -1) -> Optional[bytes]:
    """Read and return up to size bytes when size >= 0.

    If input_stream.read returns None to indicate "No data at the moment", this
    function may return None as well. But it will eventually return
    some data, or return b'' if EOF is reached.

    Args:
      size: Maximum number of bytes to be returned, if >= 0. If size is smaller
        than 0 or None, return the whole content of the file.
    Returns:
      bytes read. b'' is returned on EOF, and None if there is currently
      no data available, but EOF is not reached yet.
    Raises:
      ValueError: if the stream has been closed.
    """
    if self.closed:  # pylint:disable=using-constant-test
      raise ValueError('I/O operation on closed file.')
    if size is None or size < 0:
      return self.readall()  # implemented in io.RawIOBase
    if self._pos < len(self._buffer):
      # buffer has some data left. Return up to 'size' bytes from the buffer
      new_pos = min(len(self._buffer), self._pos + size)
      b = self._buffer[self._pos:new_pos]
      self._pos = new_pos
      return bytes(b)
    # no data left in buffer
    if not self._rewindable and self._buffer:
      # buffer is not needed anymore
      self._buffer = bytearray()
      self._pos = 0
    try:
      data = self._input_stream.read(size)
    except BlockingIOError:
      # self._input_stream is a BufferedIOBase and has currently no data
      return None
    if data is None:
      # self._input_stream is a RawIOBase and has currently no data
      return None
    if self._rewindable:
      # Remember the data so that a later rewind() can replay it.
      self._buffer.extend(data)
      self._pos += len(data)
    return data

  def readinto(self, b) -> Optional[int]:
    """Read bytes into the pre-allocated, writable bytes-like object b.

    This is the primitive of the io.RawIOBase interface (the inherited default
    raises NotImplementedError). It is implemented on top of read(), so it
    honors the rewind buffer, and it allows this stream to be wrapped by
    io.BufferedReader.

    Args:
      b: a writable bytes-like object, e.g. a bytearray or memoryview.
    Returns:
      The number of bytes written into b (0 on EOF or if b is empty), or None
      if no data is currently available but EOF has not been reached.
    Raises:
      ValueError: if the stream has been closed.
    """
    # Release the views explicitly so no buffer export of b outlives the call
    # (callers such as io.BufferedReader may resize or release b afterwards).
    with memoryview(b) as view, view.cast('B') as byte_view:
      data = self.read(len(byte_view))
      if data is None:
        return None
      n = len(data)
      byte_view[:n] = data
      return n

  def rewind(self) -> None:
    """Restart reading from the first byte read from input_stream.

    Raises:
      ValueError: if disable_rewind() has been called.
    """
    if not self._rewindable:
      raise ValueError('rewind is disabled')
    self._pos = 0

  def disable_rewind(self) -> None:
    """Stop retaining read data; rewind() will raise from now on."""
    self._rewindable = False

  def readable(self) -> bool:
    """Returns True: this stream always supports reading."""
    return True

  def close(self) -> None:
    """Close the stream and the wrapped input_stream.

    Only the first call has an effect. The rewind buffer is released and the
    stream is marked closed even if closing input_stream raises.
    """
    if self.closed:  # pylint:disable=using-constant-test
      return
    # If __init__ raised before _input_stream was assigned, io.IOBase's
    # finalizer still calls close() on the partially constructed object.
    input_stream = getattr(self, '_input_stream', None)
    try:
      if input_stream is not None:
        input_stream.close()
    finally:
      # Drop the (possibly large) rewind buffer eagerly.
      self._buffer = bytearray()
      self._pos = 0
      super().close()
92