# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tink.python.tink.streaming_aead._file_object_adapter."""

import io

from absl.testing import absltest
from absl.testing.absltest import mock

from tink.streaming_aead import _file_object_adapter


class FileObjectAdapterTest(absltest.TestCase):
  """Tests FileObjectAdapter's write and read paths.

  Each test wraps either an in-memory io.BytesIO or a mock file object.
  Mocks are used where the underlying file object's behavior must be forced:
  returning None, raising BlockingIOError, or doing a short write.
  """

  def test_basic_write(self):
    """A single write returns its byte count and reaches the file object."""
    file_object = io.BytesIO()
    adapter = _file_object_adapter.FileObjectAdapter(file_object)

    self.assertEqual(9, adapter.write(b'something'))
    self.assertEqual(b'something', file_object.getvalue())
    adapter.close()

  def test_multiple_write(self):
    """Consecutive writes are appended in order, each returning its size."""
    file_object = io.BytesIO()
    adapter = _file_object_adapter.FileObjectAdapter(file_object)

    self.assertEqual(9, adapter.write(b'something'))
    self.assertEqual(3, adapter.write(b'123'))
    self.assertEqual(3, adapter.write(b'456'))
    self.assertEqual(b'something123456', file_object.getvalue())

  def test_write_after_close(self):
    """Writing through a closed adapter raises ValueError."""
    file_object = io.BytesIO()
    adapter = _file_object_adapter.FileObjectAdapter(file_object)

    adapter.close()
    with self.assertRaises(ValueError):
      adapter.write(b'something')

  def test_write_returns_none(self):
    """A file object whose write() returns None is reported as 0 bytes."""
    file_object = mock.Mock()
    file_object.write = mock.Mock(return_value=None)
    adapter = _file_object_adapter.FileObjectAdapter(file_object)

    self.assertEqual(0, adapter.write(b'something'))

  def test_write_raises_blocking_error(self):
    """BlockingIOError from write() maps to its characters_written count."""
    file_object = mock.Mock()
    # The third BlockingIOError argument is characters_written (here 5).
    file_object.write = mock.Mock(side_effect=io.BlockingIOError(None, None, 5))
    adapter = _file_object_adapter.FileObjectAdapter(file_object)

    self.assertEqual(5, adapter.write(b'something'))

  def test_partial_write(self):
    """A short write's count (one byte fewer than requested) is passed through."""
    file_object = mock.Mock()
    file_object.write = mock.Mock(wraps=lambda data: len(data) - 1)
    adapter = _file_object_adapter.FileObjectAdapter(file_object)

    self.assertEqual(8, adapter.write(b'something'))

  def test_basic_read(self):
    """Reading exactly the available size returns all the bytes."""
    file_object = io.BytesIO(b'something')
    adapter = _file_object_adapter.FileObjectAdapter(file_object)

    self.assertEqual(b'something', adapter.read(9))

  def test_multiple_read(self):
    """Consecutive reads return consecutive chunks of the underlying data."""
    file_object = io.BytesIO(b'something')
    adapter = _file_object_adapter.FileObjectAdapter(file_object)

    self.assertEqual(b'som', adapter.read(3))
    self.assertEqual(b'eth', adapter.read(3))
    self.assertEqual(b'ing', adapter.read(3))

  def test_read_returns_none(self):
    """A file object whose read() returns None yields b'' (not EOF)."""
    file_object = mock.Mock()
    # For Python raw streams, read() returning None means "no data available
    # yet" (non-blocking mode), unlike b'' which signals EOF in test_read_eof.
    file_object.read = mock.Mock(return_value=None)
    adapter = _file_object_adapter.FileObjectAdapter(file_object)

    self.assertEqual(b'', adapter.read(10))

  def test_read_eof(self):
    """An empty read from the file object raises EOFError."""
    file_object = mock.Mock()
    file_object.read = mock.Mock(return_value=b'')
    adapter = _file_object_adapter.FileObjectAdapter(file_object)

    with self.assertRaises(EOFError):
      adapter.read(10)

  def test_read_size_0(self):
    """A zero-size read returns b'' without raising EOFError."""
    file_object = io.BytesIO(b'something')
    adapter = _file_object_adapter.FileObjectAdapter(file_object)

    self.assertEqual(b'', adapter.read(0))

  def test_read_negative_size_fails(self):
    """A negative read size raises ValueError."""
    file_object = io.BytesIO(b'something')
    adapter = _file_object_adapter.FileObjectAdapter(file_object)
    with self.assertRaises(ValueError):
      adapter.read(-1)

  def test_read_raises_blocking_error(self):
    """BlockingIOError from read() is reported as b''."""
    file_object = mock.Mock()
    file_object.read = mock.Mock(side_effect=io.BlockingIOError(None, None))
    adapter = _file_object_adapter.FileObjectAdapter(file_object)

    self.assertEqual(b'', adapter.read(10))


if __name__ == '__main__':
  absltest.main()