# xref: /aosp_15_r20/external/tink/python/tink/streaming_aead/_streaming_aead_wrapper.py (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""Streaming AEAD wrapper."""
15
16import io
17from typing import cast, BinaryIO, Optional, Type
18
19from tink import core
20from tink.streaming_aead import _raw_streaming_aead
21from tink.streaming_aead import _rewindable_input_stream
22from tink.streaming_aead import _streaming_aead
23
24
class _DecryptingStreamWrapper(io.RawIOBase):
  """A file-like object which decrypts reads from an underlying object.

  It uses a primitive set of streaming AEADs, and decrypts the stream with the
  matching key in the keyset. The primitives are tried one after another on
  the same (rewindable) ciphertext source. The first one whose read succeeds
  becomes the matching stream, and only that stream is used afterwards.
  Closing this wrapper also closes ciphertext_source.
  """

  def __init__(self, primitive_set: core.PrimitiveSet,
               ciphertext_source: BinaryIO, associated_data: bytes):
    """Create a new _DecryptingStreamWrapper.

    Args:
      primitive_set: The primitive set of StreamingAead primitives.
      ciphertext_source: A readable file-like object from which ciphertext bytes
        will be read.
      associated_data: The associated data to use for decryption.

    Raises:
      ValueError: if ciphertext_source is not readable, or if primitive_set
        holds no primitives at all.
    """
    super().__init__()
    if not ciphertext_source.readable():
      raise ValueError('ciphertext_source must be readable')
    # Wrapped so it can be rewound to its start each time a key fails.
    self._ciphertext_source = _rewindable_input_stream.RewindableInputStream(
        ciphertext_source)
    self._associated_data = associated_data
    # Stream that successfully produced data; None until a key matches.
    self._matching_stream = None
    # For legacy reasons (Tink always encrypted with non-RAW keys) we use all
    # primitives, even those which have output_prefix_type != RAW.
    # Untried candidates, in the order produced by primitive_set.all().
    self._remaining_primitives = [
        entry.primitive
        for entry_list in primitive_set.all()
        for entry in entry_list
    ]
    # Stream currently being tried; None once a matching stream is found.
    self._attempting_stream = self._next_decrypting_stream()

  def _next_decrypting_stream(self) -> io.RawIOBase:
    """Takes the next remaining primitive and returns a decrypting stream.

    Raises:
      ValueError: if every primitive has already been tried.
    """
    if not self._remaining_primitives:
      raise ValueError('No primitive remaining.')
    # ciphertext_source should never be closed by any of the raw decrypting
    # streams, to be able to use it for another decrypting stream.
    # ciphertext_source will be closed in close().
    # self._ciphertext_source needs to be at the starting position.
    return self._remaining_primitives.pop(0).new_raw_decrypting_stream(
        self._ciphertext_source,
        self._associated_data,
        close_ciphertext_source=False)

  def read(self, size=-1) -> Optional[bytes]:
    """Read and return up to size bytes, where size is an int.

    Args:
      size: Maximum number of bytes to read. As a convenience, if size is
        unspecified or -1, all bytes until EOF are returned.

    Returns:
      Bytes read. An empty bytes object is returned if the stream is already at
      EOF. None is returned if no data is available at the moment.

    Raises:
      TinkError if there was a permanent error.
      ValueError if the file is closed.
    """
    if self.closed:  # pylint:disable=using-constant-test
      raise ValueError('read on closed file.')
    if size == 0:
      return bytes()
    if self._matching_stream is not None:
      return self._matching_stream.read(size)
    # if self._matching_stream is not set, we are currently reading from
    # self._attempting_stream but no data has been read successfully yet.
    while True:
      try:
        data = self._attempting_stream.read(size)
      except core.TinkError as e:
        if not self._remaining_primitives:
          # The last failed stream stays in _attempting_stream, so later reads
          # fail the same way and close() still releases it.
          raise core.TinkError(
              'No matching key found for the ciphertext in the stream') from e
        # Try another key. Release the failed stream first. It was created
        # with close_ciphertext_source=False, so closing it leaves the shared
        # ciphertext source open for the next attempt.
        self._attempting_stream.close()
        self._ciphertext_source.rewind()
        self._attempting_stream = self._next_decrypting_stream()
        continue
      if data is None:
        # No data at the moment. Not clear if decryption was successful.
        # Try again with the same stream next time.
        return None
      # Any value other than None means that decryption was successful.
      # (b'' indicates that the plaintext is an empty string.)
      self._matching_stream = self._attempting_stream
      self._attempting_stream = None
      self._ciphertext_source.disable_rewind()
      return data

  def readinto(self, b: bytearray) -> Optional[int]:
    """Read bytes into a pre-allocated bytes-like object b.

    Returns:
      The number of bytes copied into b, or None if read() reported that no
      data is available at the moment.
    """
    data = self.read(len(b))
    if data is None:
      return None
    n = len(data)
    b[:n] = data
    return n

  def close(self) -> None:
    """Closes the decrypting stream in use and the ciphertext source.

    The ciphertext source is closed, and this wrapper is marked closed, even if
    closing the decrypting stream raises.
    """
    if self.closed:  # pylint:disable=using-constant-test
      return
    # These attributes are read before the try block on purpose. If __init__
    # failed before setting them, this raises AttributeError without closing
    # ciphertext_source, which the caller still holds.
    open_streams = [
        stream
        for stream in (self._matching_stream, self._attempting_stream)
        if stream is not None
    ]
    try:
      for stream in open_streams:
        stream.close()
    finally:
      try:
        self._ciphertext_source.close()
      finally:
        super().close()

  def readable(self) -> bool:
    return True
136
137
class _WrappedStreamingAead(_streaming_aead.StreamingAead):
  """StreamingAead implemented on top of a set of RawStreamingAead primitives.

  Encryption goes through the primary primitive of the set. Decryption is
  delegated to _DecryptingStreamWrapper, which tries the primitives of the
  set in turn.
  """

  def __init__(self, primitives_set: core.PrimitiveSet):
    self._primitive_set = primitives_set

  def new_encrypting_stream(self, ciphertext_destination: BinaryIO,
                            associated_data: bytes) -> BinaryIO:
    """Returns a buffered writer around the primary's raw encrypting stream.

    The raw encrypting stream is created over ciphertext_destination and wrapped
    in an io.BufferedWriter.
    """
    primary_raw_aead = self._primitive_set.primary().primitive
    raw_encrypting_stream = primary_raw_aead.new_raw_encrypting_stream(
        ciphertext_destination, associated_data)
    buffered_writer = io.BufferedWriter(raw_encrypting_stream)
    return cast(BinaryIO, buffered_writer)

  def new_decrypting_stream(self, ciphertext_source: BinaryIO,
                            associated_data: bytes) -> BinaryIO:
    """Returns a buffered reader over a key-trying decrypting stream.

    A _DecryptingStreamWrapper is built over ciphertext_source using every
    primitive in the set, then wrapped in an io.BufferedReader.
    """
    raw_decrypting_stream = _DecryptingStreamWrapper(
        primitive_set=self._primitive_set,
        ciphertext_source=ciphertext_source,
        associated_data=associated_data)
    buffered_reader = io.BufferedReader(raw_decrypting_stream)
    return cast(BinaryIO, buffered_reader)
155
156
class StreamingAeadWrapper(
    core.PrimitiveWrapper[_raw_streaming_aead.RawStreamingAead,
                          _streaming_aead.StreamingAead]):
  """PrimitiveWrapper that turns RawStreamingAead sets into a StreamingAead."""

  def wrap(self,
           primitives_set: core.PrimitiveSet) -> _streaming_aead.StreamingAead:
    """Returns a _WrappedStreamingAead backed by primitives_set."""
    wrapped_aead = _WrappedStreamingAead(primitives_set)
    return wrapped_aead

  def primitive_class(self) -> Type[_streaming_aead.StreamingAead]:
    """Returns the primitive class that wrap() produces."""
    output_class = _streaming_aead.StreamingAead
    return output_class

  def input_primitive_class(
      self) -> Type[_raw_streaming_aead.RawStreamingAead]:
    """Returns the primitive class that wrap() consumes."""
    input_class = _raw_streaming_aead.RawStreamingAead
    return input_class
172