# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tink.python.tink.streaming_aead._rewindable_input_stream."""

import io
import tempfile
from typing import BinaryIO, cast
from absl.testing import absltest
from absl.testing import parameterized

from tink.streaming_aead import _rewindable_input_stream
from tink.testing import bytes_io


class NonSeekableBytesIO(io.BytesIO):
  """An in-memory byte stream whose seekable() always returns False."""

  def seekable(self) -> bool:
    return False


def _rewindable(
    data: bytes,
    seekable: bool) -> _rewindable_input_stream.RewindableInputStream:
  """Wraps `data` in a RewindableInputStream for testing.

  Args:
    data: The bytes the underlying stream yields.
    seekable: If True, the underlying stream is a plain io.BytesIO; otherwise
      it is a NonSeekableBytesIO, whose seekable() returns False.

  Returns:
    A RewindableInputStream reading from the underlying stream.
  """
  if seekable:
    b = cast(BinaryIO, io.BytesIO(data))
  else:
    b = cast(BinaryIO, NonSeekableBytesIO(data))
  return _rewindable_input_stream.RewindableInputStream(b)


class RewindableInputStreamTest(parameterized.TestCase):

  @parameterized.parameters([False, True])
  def test_read(self, seekable):
    with _rewindable(b'The quick brown fox', seekable) as f:
      self.assertEqual(b'The q', f.read(5))
      self.assertEqual(b'uick ', f.read(5))
      self.assertEqual(b'brown', f.read(5))
      self.assertEqual(b' fox', f.read(5))
      self.assertEqual(b'', f.read(5))
      self.assertEqual(b'', f.read(5))

  @parameterized.parameters([False, True])
  def test_read_no_argument(self, seekable):
    with _rewindable(b'The quick brown fox', seekable) as f:
      self.assertEqual(b'The quick brown fox', f.read())

  @parameterized.parameters([False, True])
  def test_read_minus_one(self, seekable):
    with _rewindable(b'The quick brown fox', seekable) as f:
      self.assertEqual(b'The quick brown fox', f.read(-1))

  @parameterized.parameters([False, True])
  def test_readall(self, seekable):
    with _rewindable(b'The quick brown fox', seekable) as f:
      self.assertEqual(b'The quick brown fox', f.readall())

  @parameterized.parameters([False, True])
  def test_rewind_read(self, seekable):
    with _rewindable(b'The quick brown fox', seekable) as f:
      self.assertEqual(b'The quick', f.read(9))
      f.rewind()
      self.assertEqual(b'The ', f.read(4))
      # this only reads the rest of current buffer content.
      self.assertEqual(b'quick', f.read(100))
      self.assertEqual(b' brown fox', f.read())

  @parameterized.parameters([False, True])
  def test_rewind_readall(self, seekable):
    with _rewindable(b'The quick brown fox', seekable) as f:
      self.assertEqual(b'The q', f.read(5))
      f.rewind()
      # this must read the whole file.
      self.assertEqual(b'The quick brown fox', f.read())

  @parameterized.parameters([False, True])
  def test_rewind_twice(self, seekable):
    with _rewindable(b'The quick brown fox', seekable) as f:
      self.assertEqual(b'The q', f.read(5))
      f.rewind()
      self.assertEqual(b'The q', f.read(5))
      self.assertEqual(b'uick ', f.read(5))
      f.rewind()
      self.assertEqual(b'The quick brown fox', f.read())

  @parameterized.parameters([False, True])
  def test_disable_rewind(self, seekable):
    with _rewindable(b'The quick brown fox', seekable) as f:
      self.assertEqual(b'The q', f.read(5))
      f.rewind()
      f.disable_rewind()
      # this only reads the current buffer content.
      self.assertEqual(b'The q', f.read(100))
      self.assertEqual(b'u', f.read(1))
      # Once reads have moved past the previously buffered bytes, the
      # internal buffer is expected to be empty.
      self.assertEmpty(f._buffer)
      self.assertEqual(b'ick brown fox', f.read())

  @parameterized.parameters([False, True])
  def test_disable_rewind_readall(self, seekable):
    with _rewindable(b'The quick brown fox', seekable) as f:
      self.assertEqual(b'The q', f.read(5))
      f.rewind()
      f.disable_rewind()
      self.assertEqual(b'The quick brown fox', f.read())

  def test_nonreadable_input_fail(self):
    # A file opened write-only ('wb') is not readable, so wrapping it must
    # raise ValueError.
    with tempfile.TemporaryFile('wb') as f:
      with self.assertRaises(ValueError):
        _ = _rewindable_input_stream.RewindableInputStream(cast(BinaryIO, f))


class RewindableInputStreamSlowTest(parameterized.TestCase):
  """Tests "slow" input streams where read returns None or BlockingIOError.

  Normally, this should not happen in blocking streams.
  """

  @parameterized.parameters([False, True])
  def test_read_slow(self, seekable):
    input_stream = bytes_io.SlowBytesIO(b'The quick brown fox', seekable)
    with _rewindable_input_stream.RewindableInputStream(
        cast(BinaryIO, input_stream)) as f:
      self.assertIsNone(f.read(10))
      self.assertEqual(b'The q', f.read(10))
      self.assertEqual(b'uick ', f.read(10))
      self.assertIsNone(f.read(10))
      self.assertEqual(b'brown', f.read(10))
      self.assertEqual(b' fox', f.read(10))
      self.assertIsNone(f.read(10))
      self.assertEqual(b'', f.read(10))

  @parameterized.parameters([False, True])
  def test_read_slow_raw(self, seekable):
    input_stream = bytes_io.SlowReadableRawBytes(b'The quick brown fox',
                                                 seekable)
    with _rewindable_input_stream.RewindableInputStream(
        cast(BinaryIO, input_stream)) as f:
      self.assertIsNone(f.read(10))
      self.assertEqual(b'The q', f.read(10))
      self.assertEqual(b'uick ', f.read(10))
      self.assertIsNone(f.read(10))
      self.assertEqual(b'brown', f.read(10))
      self.assertEqual(b' fox', f.read(10))
      self.assertIsNone(f.read(10))
      self.assertEqual(b'', f.read(10))

  @parameterized.parameters([False, True])
  def test_read_slow_raw_readall(self, seekable):
    input_stream = bytes_io.SlowReadableRawBytes(b'The quick brown fox',
                                                 seekable)
    with _rewindable_input_stream.RewindableInputStream(
        cast(BinaryIO, input_stream)) as f:
      self.assertIsNone(f.readall())
      self.assertEqual(b'The quick ', f.readall())
      self.assertEqual(b'brown fox', f.readall())
      self.assertEqual(b'', f.readall())


if __name__ == '__main__':
  absltest.main()