# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Streaming AEAD wrapper."""

import io
from typing import cast, BinaryIO, Optional, Type

from tink import core
from tink.streaming_aead import _raw_streaming_aead
from tink.streaming_aead import _rewindable_input_stream
from tink.streaming_aead import _streaming_aead


class _DecryptingStreamWrapper(io.RawIOBase):
  """A file-like object which decrypts reads from an underlying object.

  It uses a primitive set of streaming AEADs, and decrypts the stream with the
  matching key in the keyset. Closing this wrapper also closes
  ciphertext_source.
  """

  def __init__(self, primitive_set: core.PrimitiveSet,
               ciphertext_source: BinaryIO, associated_data: bytes):
    """Create a new _DecryptingStreamWrapper.

    Args:
      primitive_set: The primitive set of StreamingAead primitives.
      ciphertext_source: A readable file-like object from which ciphertext bytes
        will be read.
      associated_data: The associated data to use for decryption.

    Raises:
      ValueError: If ciphertext_source is not readable, or if primitive_set
        contains no primitive.
    """
    super().__init__()
    if not ciphertext_source.readable():
      raise ValueError('ciphertext_source must be readable')
    # The source is wrapped so that it can be rewound to its start whenever
    # a key turns out not to match and the next key has to be tried.
    self._ciphertext_source = _rewindable_input_stream.RewindableInputStream(
        ciphertext_source)
    self._associated_data = associated_data
    # Set once a stream has successfully returned data; from then on all reads
    # are served by it.
    self._matching_stream = None
    self._remaining_primitives = []
    # For legacy reasons (Tink always encrypted with non-RAW keys) we use all
    # primitives, even those which have output_prefix_type != RAW.
    for entry_list in primitive_set.all():
      for e in entry_list:
        self._remaining_primitives.append(e.primitive)
    # The stream currently being tried; no data has been read from it yet.
    self._attempting_stream = self._next_decrypting_stream()

  def _next_decrypting_stream(self) -> io.RawIOBase:
    """Takes the next remaining primitive and returns a decrypting stream.

    Returns:
      A raw decrypting stream created by the first remaining primitive, which
      is removed from the list of remaining primitives.

    Raises:
      ValueError: If there is no primitive left to try.
    """
    if not self._remaining_primitives:
      raise ValueError('No primitive remaining.')
    # ciphertext_source should never be closed by any of the raw decrypting
    # streams, to be able to use it for another decrypting stream.
    # ciphertext_source will be closed in close().
    # self._ciphertext_source needs to be at the starting position.
    return self._remaining_primitives.pop(0).new_raw_decrypting_stream(
        self._ciphertext_source,
        self._associated_data,
        close_ciphertext_source=False)

  def read(self, size=-1) -> Optional[bytes]:
    """Read and return up to size bytes, where size is an int.

    Args:
      size: Maximum number of bytes to read. As a convenience, if size is
        unspecified or -1, all bytes until EOF are returned.

    Returns:
      Bytes read. An empty bytes object is returned if the stream is already at
      EOF. None is returned if no data is available at the moment.

    Raises:
      TinkError if there was a permanent error, including when none of the
        primitives could decrypt the stream.
      ValueError if the file is closed.
    """
    if self.closed:  # pylint:disable=using-constant-test
      raise ValueError('read on closed file.')
    if size == 0:
      return bytes()
    if self._matching_stream:
      return self._matching_stream.read(size)
    # if self._matching_stream is not set, we are currently reading from
    # self._attempting_stream but no data has been read successfully yet.
    while True:
      try:
        data = self._attempting_stream.read(size)
        if data is None:
          # No data at the moment. Not clear if decryption was successful.
          # Try again with the same stream next time.
          return None
        # Any value other than None means that decryption was successful.
        # (b'' indicates that the plaintext is an empty string.)
        self._matching_stream = self._attempting_stream
        self._attempting_stream = None
        # No other key will be tried anymore, so rewinding is not needed.
        self._ciphertext_source.disable_rewind()
        return data
      except core.TinkError:
        if not self._remaining_primitives:
          raise core.TinkError(
              'No matching key found for the ciphertext in the stream')
        # Try another key.
        self._ciphertext_source.rewind()
        self._attempting_stream = self._next_decrypting_stream()

  def readinto(self, b: bytearray) -> Optional[int]:
    """Read bytes into a pre-allocated bytes-like object b.

    Args:
      b: The buffer to fill; at most len(b) bytes are requested from read().

    Returns:
      The number of bytes written to the start of b, or None if read()
      returned None.
    """
    data = self.read(len(b))
    if data is None:
      return None
    n = len(data)
    b[:n] = data
    return n

  def close(self) -> None:
    """Closes the wrapper, its decrypting streams and the ciphertext source.

    Calling close() on an already closed wrapper does nothing. If closing one
    of the decrypting streams raises, the ciphertext source is still closed and
    the wrapper is still marked as closed before the exception propagates.
    """
    if self.closed:  # pylint:disable=using-constant-test
      return
    # Look the attributes up before entering the try block, so that the
    # cleanup in the finally clauses only runs when all of them exist.
    matching_stream = self._matching_stream
    attempting_stream = self._attempting_stream
    ciphertext_source = self._ciphertext_source
    try:
      if matching_stream:
        matching_stream.close()
      if attempting_stream:
        attempting_stream.close()
    finally:
      # The raw decrypting streams are created with
      # close_ciphertext_source=False, so the source is closed here.
      try:
        ciphertext_source.close()
      finally:
        super().close()

  def readable(self) -> bool:
    """Returns True, as this stream is always readable."""
    return True


class _WrappedStreamingAead(_streaming_aead.StreamingAead):
  """Implements StreamingAead by wrapping a set of RawStreamingAead."""

  def __init__(self, primitives_set: core.PrimitiveSet):
    """Stores the primitive set that is used to create the streams."""
    self._primitive_set = primitives_set

  def new_encrypting_stream(self, ciphertext_destination: BinaryIO,
                            associated_data: bytes) -> BinaryIO:
    """Returns a buffered encrypting stream that uses the primary primitive.

    Args:
      ciphertext_destination: The file-like object passed to the primary
        primitive's new_raw_encrypting_stream.
      associated_data: The associated data passed to the primary primitive.

    Returns:
      An io.BufferedWriter around the raw encrypting stream.
    """
    raw = self._primitive_set.primary().primitive.new_raw_encrypting_stream(
        ciphertext_destination, associated_data)
    return cast(BinaryIO, io.BufferedWriter(raw))

  def new_decrypting_stream(self, ciphertext_source: BinaryIO,
                            associated_data: bytes) -> BinaryIO:
    """Returns a buffered decrypting stream that tries all primitives.

    Args:
      ciphertext_source: A readable file-like object from which ciphertext
        bytes will be read.
      associated_data: The associated data to use for decryption.

    Returns:
      An io.BufferedReader around a _DecryptingStreamWrapper.
    """
    raw = _DecryptingStreamWrapper(self._primitive_set, ciphertext_source,
                                   associated_data)
    return cast(BinaryIO, io.BufferedReader(raw))


class StreamingAeadWrapper(
    core.PrimitiveWrapper[_raw_streaming_aead.RawStreamingAead,
                          _streaming_aead.StreamingAead]):
  """StreamingAeadWrapper is the PrimitiveWrapper for StreamingAead."""

  def wrap(self,
           primitives_set: core.PrimitiveSet) -> _streaming_aead.StreamingAead:
    """Returns a StreamingAead backed by the given primitive set."""
    return _WrappedStreamingAead(primitives_set)

  def primitive_class(self) -> Type[_streaming_aead.StreamingAead]:
    """Returns the class of the primitive produced by wrap()."""
    return _streaming_aead.StreamingAead

  def input_primitive_class(
      self) -> Type[_raw_streaming_aead.RawStreamingAead]:
    """Returns the class of the primitives that are being wrapped."""
    return _raw_streaming_aead.RawStreamingAead