xref: /aosp_15_r20/external/tink/python/tink/streaming_aead/_rewindable_input_stream_test.py (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
"""Tests for tink.python.tink.streaming_aead._rewindable_input_stream."""
15
16import io
17import tempfile
18from typing import BinaryIO, cast
19from absl.testing import absltest
20from absl.testing import parameterized
21
22from tink.streaming_aead import _rewindable_input_stream
23from tink.testing import bytes_io
24
25
class NonSeekableBytesIO(io.BytesIO):
  """In-memory byte stream that claims it cannot seek.

  Only `seekable()` is overridden; reading behaves exactly like io.BytesIO.
  """

  def seekable(self) -> bool:
    """Reports the stream as non-seekable, regardless of its contents."""
    del self  # Unused.
    return False
30
31
def _rewindable(data: bytes,
                seekable: bool) -> _rewindable_input_stream.RewindableInputStream:
  """Wraps `data` in a RewindableInputStream backed by an in-memory source.

  Args:
    data: the bytes the underlying source will yield.
    seekable: when truthy the source is a plain io.BytesIO; otherwise it is a
      NonSeekableBytesIO, whose seekable() returns False.

  Returns:
    A RewindableInputStream reading from the chosen source.
  """
  source_type = io.BytesIO if seekable else NonSeekableBytesIO
  source = cast(BinaryIO, source_type(data))
  return _rewindable_input_stream.RewindableInputStream(source)
39
40
class RewindableInputStreamTest(parameterized.TestCase):
  """Tests RewindableInputStream over seekable and non-seekable sources.

  Each parameterized test runs twice, with `seekable` False and True, which
  selects the kind of in-memory source `_rewindable` wraps.
  """

  @parameterized.parameters([False, True])
  def test_read(self, seekable):
    """Fixed-size reads return consecutive chunks, then b'' at end of data."""
    expected_chunks = [b'The q', b'uick ', b'brown', b' fox', b'', b'']
    with _rewindable(b'The quick brown fox', seekable) as stream:
      for expected in expected_chunks:
        self.assertEqual(expected, stream.read(5))

  @parameterized.parameters([False, True])
  def test_read_no_argument(self, seekable):
    """read() with no size argument returns all of the data."""
    with _rewindable(b'The quick brown fox', seekable) as stream:
      content = stream.read()
      self.assertEqual(b'The quick brown fox', content)

  @parameterized.parameters([False, True])
  def test_read_minus_one(self, seekable):
    """read(-1) returns all of the data."""
    with _rewindable(b'The quick brown fox', seekable) as stream:
      content = stream.read(-1)
      self.assertEqual(b'The quick brown fox', content)

  @parameterized.parameters([False, True])
  def test_readall(self, seekable):
    """readall() returns all of the data."""
    with _rewindable(b'The quick brown fox', seekable) as stream:
      content = stream.readall()
      self.assertEqual(b'The quick brown fox', content)

  @parameterized.parameters([False, True])
  def test_rewind_read(self, seekable):
    """After rewind(), sized reads are first served from the buffered prefix."""
    with _rewindable(b'The quick brown fox', seekable) as stream:
      self.assertEqual(b'The quick', stream.read(9))
      stream.rewind()
      self.assertEqual(b'The ', stream.read(4))
      # A large read returns only what is left of the 9 buffered bytes; it
      # does not continue into the unread remainder of the source.
      self.assertEqual(b'quick', stream.read(100))
      self.assertEqual(b' brown fox', stream.read())

  @parameterized.parameters([False, True])
  def test_rewind_readall(self, seekable):
    """read() after rewind() returns the buffered prefix plus the rest."""
    with _rewindable(b'The quick brown fox', seekable) as stream:
      self.assertEqual(b'The q', stream.read(5))
      stream.rewind()
      # An unbounded read must cover the whole data, not just the buffer.
      self.assertEqual(b'The quick brown fox', stream.read())

  @parameterized.parameters([False, True])
  def test_rewind_twice(self, seekable):
    """rewind() may be called more than once, each time restarting at 0."""
    with _rewindable(b'The quick brown fox', seekable) as stream:
      self.assertEqual(b'The q', stream.read(5))
      stream.rewind()
      self.assertEqual(b'The q', stream.read(5))
      self.assertEqual(b'uick ', stream.read(5))
      stream.rewind()
      self.assertEqual(b'The quick brown fox', stream.read())

  @parameterized.parameters([False, True])
  def test_disable_rewind(self, seekable):
    """After disable_rewind(), buffered bytes are drained and not re-kept."""
    with _rewindable(b'The quick brown fox', seekable) as stream:
      self.assertEqual(b'The q', stream.read(5))
      stream.rewind()
      stream.disable_rewind()
      # The first read only returns what is still buffered.
      self.assertEqual(b'The q', stream.read(100))
      self.assertEqual(b'u', stream.read(1))
      # Reading past the old prefix must leave the internal buffer empty.
      self.assertEmpty(stream._buffer)
      self.assertEqual(b'ick brown fox', stream.read())

  @parameterized.parameters([False, True])
  def test_disable_rewind_readall(self, seekable):
    """read() after disable_rewind() still returns the whole data."""
    with _rewindable(b'The quick brown fox', seekable) as stream:
      self.assertEqual(b'The q', stream.read(5))
      stream.rewind()
      stream.disable_rewind()
      self.assertEqual(b'The quick brown fox', stream.read())

  def test_nonreadable_input_fail(self):
    """Wrapping a write-only file raises ValueError."""
    with tempfile.TemporaryFile('wb') as write_only, self.assertRaises(
        ValueError):
      _ = _rewindable_input_stream.RewindableInputStream(
          cast(BinaryIO, write_only))
120
121
class RewindableInputStreamSlowTest(parameterized.TestCase):
  """Tests RewindableInputStream over "slow" sources.

  A slow source's read may return None or raise BlockingIOError. Normally,
  this should not happen in blocking streams. The wrapper is expected to
  surface such reads as None.
  """

  def _check_slow_read10_sequence(self, source):
    """Checks successive read(10) results on a slow 'The quick brown fox'."""
    # None marks a read that is expected to make no progress.
    expected = [None, b'The q', b'uick ', None, b'brown', b' fox', None, b'']
    with _rewindable_input_stream.RewindableInputStream(
        cast(BinaryIO, source)) as stream:
      for want in expected:
        if want is None:
          self.assertIsNone(stream.read(10))
        else:
          self.assertEqual(want, stream.read(10))

  @parameterized.parameters([False, True])
  def test_read_slow(self, seekable):
    """Sized reads over a slow buffered source."""
    source = bytes_io.SlowBytesIO(b'The quick brown fox', seekable)
    self._check_slow_read10_sequence(source)

  @parameterized.parameters([False, True])
  def test_read_slow_raw(self, seekable):
    """Sized reads over a slow raw source."""
    source = bytes_io.SlowReadableRawBytes(b'The quick brown fox', seekable)
    self._check_slow_read10_sequence(source)

  @parameterized.parameters([False, True])
  def test_read_slow_raw_readall(self, seekable):
    """readall() over a slow raw source returns None until data arrives."""
    source = bytes_io.SlowReadableRawBytes(b'The quick brown fox', seekable)
    with _rewindable_input_stream.RewindableInputStream(
        cast(BinaryIO, source)) as stream:
      self.assertIsNone(stream.readall())
      self.assertEqual(b'The quick ', stream.readall())
      self.assertEqual(b'brown fox', stream.readall())
      self.assertEqual(b'', stream.readall())
167
168
if __name__ == '__main__':
  # Runs the test cases in this module through absl's test runner when the
  # file is executed directly.
  absltest.main()
171