# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements a variant of BytesIO that lets you read the value after close().

This class can be used when an interface that writes to a stream and closes it
at the end needs to be transformed into a function that returns a value.

An example is the implementation of the normal AEAD encryption interface using
the streaming AEAD encryption interface.

The module also provides SlowBytesIO and SlowReadableRawBytes: readable streams
that only deliver data on some calls, for exercising code that must cope with
non-blocking reads.
"""

import errno
import io
from typing import Optional

# Number of counted read calls after which the slow streams below assume the
# caller is looping forever and fail loudly instead of hanging the test.
_MAX_READ_CALLS = 10000000

# Largest number of bytes a single successful "slow" read returns.
_MAX_CHUNK = 5


class BytesIOWithValueAfterClose(io.BytesIO):
  """A BytesIO that lets you read the written value after close()."""

  def __init__(self, initial_bytes=None):
    """Creates the stream, optionally pre-filled with initial_bytes.

    Args:
      initial_bytes: Optional initial buffer contents. Falsy values (None or
        empty) create an empty stream.
    """
    # Snapshot of the buffer taken by close(); None until the stream is closed.
    self._value_after_close: Optional[bytes] = None
    if initial_bytes:
      super().__init__(initial_bytes)
    else:
      super().__init__()

  def close(self) -> None:
    """Saves the buffer contents, then closes the stream.

    Calling close() on an already-closed stream keeps the first snapshot.
    """
    if not self.closed:
      # getvalue() is unavailable once BytesIO is closed, so snapshot first.
      self._value_after_close = self.getvalue()
    super().close()

  def value_after_close(self) -> bytes:
    """Returns the buffer contents captured when the stream was closed.

    Raises:
      ValueError: if the stream has not been closed yet.
    """
    if not self.closed:
      raise ValueError('call to value_after_close before close()')
    return self._value_after_close


class SlowBytesIO(io.BytesIO):
  """A readable BytesIO that raises BlockingIOError on some calls to read."""

  def __init__(self, data: bytes, seekable: bool = False):
    """Creates the stream.

    Args:
      data: The bytes the stream will yield.
      seekable: Whether seek() is allowed; if False, seek() raises
        io.UnsupportedOperation.
    """
    super().__init__(data)
    self._seekable = seekable
    # Counts read calls with a positive size; starts at -1 so the first such
    # call sees 0 and blocks.
    self._state = -1

  def read(self, size: Optional[int] = -1) -> bytes:
    """Reads from the stream, sometimes blocking.

    For size > 0, every third call (the 1st, 4th, 7th, ...) raises
    BlockingIOError and the other calls return at most 5 bytes. A negative or
    None size reads the remaining data directly, without counting the call.

    Raises:
      BlockingIOError: on every third read with a positive size.
      AssertionError: if read is called an implausibly large number of times.
    """
    if size is not None and size > 0:
      self._state += 1
      if self._state > _MAX_READ_CALLS:
        raise AssertionError('too many read. Is there an infinite loop?')
      if self._state % 3 == 0:  # block on every third call.
        raise io.BlockingIOError(
            errno.EAGAIN,
            'read could not complete without blocking', 0)
      return super().read(min(size, _MAX_CHUNK))
    return super().read(size)

  def seek(self, pos: int, whence: int = 0) -> int:
    """Seeks if the stream was created seekable, else raises."""
    if self._seekable:
      return super().seek(pos, whence)
    raise io.UnsupportedOperation('seek')

  def seekable(self) -> bool:
    """Returns the seekable flag given at construction."""
    return self._seekable


class SlowReadableRawBytes(io.RawIOBase):
  """A readable io.RawIOBase stream that only sometimes returns data."""

  def __init__(self, data: bytes, seekable: bool = False):
    """Creates the stream.

    Args:
      data: The bytes the stream will yield.
      seekable: Whether seek() is allowed; if False, seek() raises
        io.UnsupportedOperation.
    """
    super().__init__()
    self._bytes_io = io.BytesIO(data)
    self._seekable = seekable
    # Counts readinto calls; starts at -1 so the first call sees 0 and returns
    # None.
    self._state = -1

  def readinto(self, b: bytearray) -> Optional[int]:
    """Reads into the writable buffer b, sometimes returning no data.

    Every third call (the 1st, 4th, 7th, ...) returns None, which RawIOBase
    uses to signal "no data available right now". Other calls copy at most
    min(len(b), 5) bytes into the front of b and return the count (0 at EOF).
    b may be a bytearray or a writable memoryview; it is never resized.

    Raises:
      AssertionError: if readinto is called an implausibly large number of
        times.
    """
    self._state += 1
    if self._state > _MAX_READ_CALLS:
      raise AssertionError('too many read. Is there an infinite loop?')
    if self._state % 3 == 0:  # return None on every third call.
      return None
    try:
      # Never read more than the caller's buffer can hold; otherwise slice
      # assignment would grow a bytearray or fail on a memoryview.
      chunk = self._bytes_io.read(min(len(b), _MAX_CHUNK))
    except io.BlockingIOError as e:
      raise ValueError('io.BytesIO should not raise BlockingIOError') from e
    n = len(chunk)
    b[:n] = chunk
    return n

  def readable(self) -> bool:
    """Always True: this stream supports reading."""
    return True

  def seek(self, pos: int, whence: int = 0) -> int:
    """Seeks if the stream was created seekable, else raises."""
    if self._seekable:
      return self._bytes_io.seek(pos, whence)
    raise io.UnsupportedOperation('seek')

  def seekable(self) -> bool:
    """Returns the seekable flag given at construction."""
    return self._seekable