xref: /aosp_15_r20/external/tink/python/tink/streaming_aead/_streaming_aead_wrapper_test.py (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for tink.python.tink.streaming_aead._streaming_aead_wrapper."""
15
16import io
17from typing import BinaryIO, cast
18
19from absl.testing import absltest
20from absl.testing import parameterized
21from tink.proto import aes_gcm_hkdf_streaming_pb2
22from tink.proto import common_pb2
23from tink.proto import tink_pb2
24import tink
25from tink import cleartext_keyset_handle
26from tink import streaming_aead
27from tink.testing import bytes_io
28from tink.testing import keyset_builder
29
30
# Key template used by the tests that generate fresh keys (the AES128-GCM-HKDF
# streaming template, 4KB variant).
TEMPLATE = streaming_aead.streaming_aead_key_templates.AES128_GCM_HKDF_4KB
# Type URL of AesGcmHkdfStreamingKey; needed when a keyset proto is assembled by
# hand instead of from a template.
TYPE_URL = ('type.googleapis.com/'
            'google.crypto.tink.AesGcmHkdfStreamingKey')
33
34
def setUpModule():
  """Runs once before any test in this module (unittest module fixture).

  Calls streaming_aead.register(); presumably this registers the streaming
  AEAD key managers and wrapper that keyset_handle.primitive() relies on in
  the tests below -- confirm against the tink registry.
  """
  streaming_aead.register()
37
38
def _encrypt(primitive: streaming_aead.StreamingAead, plaintext: bytes,
             associated_data: bytes) -> bytes:
  """Encrypts `plaintext` with a streaming primitive and returns the result.

  Args:
    primitive: the StreamingAead used for encryption.
    plaintext: bytes written to the encrypting stream in a single write().
    associated_data: associated data bound to the ciphertext.

  Returns:
    The complete ciphertext written to the destination stream.
  """
  sink = bytes_io.BytesIOWithValueAfterClose()
  with primitive.new_encrypting_stream(sink, associated_data) as stream:
    stream.write(plaintext)
  # Leaving the `with` block closes the encrypting stream, which also closes
  # `sink`, so its contents must be fetched through value_after_close().
  return sink.value_after_close()
45
46
class StreamingAeadWrapperTest(parameterized.TestCase):
  """Tests the StreamingAead primitive obtained from a keyset handle.

  Covers encrypt/decrypt round trips, ciphertext sources that return data in
  small chunks, wrong associated data, ciphertexts from an unknown key, key
  rotation, and a keyset mixing TINK and RAW output prefix types.
  """

  def _assert_decryption_results(self, primitives, ciphertext,
                                 associated_data, plaintext,
                                 input_stream_factory, can_decrypt,
                                 cannot_decrypt):
    """Checks which of the named primitives can decrypt `ciphertext`.

    Args:
      primitives: dict mapping a display name to a StreamingAead primitive.
      ciphertext: the ciphertext to decrypt.
      associated_data: associated data passed to every decrypting stream.
      plaintext: the plaintext each primitive in `can_decrypt` must recover.
      input_stream_factory: callable wrapping the ciphertext bytes in a
        readable stream; a fresh stream is created for every check.
      can_decrypt: names of primitives that must decrypt to `plaintext`.
      cannot_decrypt: names of primitives whose read() must raise
        tink.TinkError.
    """
    for name in can_decrypt:
      ciphertext_src = cast(BinaryIO, input_stream_factory(ciphertext))
      with primitives[name].new_decrypting_stream(ciphertext_src,
                                                  associated_data) as ds:
        self.assertEqual(
            ds.read(), plaintext,
            msg='%s did not return the expected plaintext' % name)
    for name in cannot_decrypt:
      ciphertext_src = cast(BinaryIO, input_stream_factory(ciphertext))
      with primitives[name].new_decrypting_stream(ciphertext_src,
                                                  associated_data) as ds:
        with self.assertRaises(tink.TinkError,
                               msg='%s unexpectedly decrypted' % name):
          ds.read()

  @parameterized.parameters(
      [b'plaintext', b'', b'smile \xf0\x9f\x98\x80', b'\xf0\x9f\x98'])
  def test_encrypt_decrypt_success(self, plaintext):
    """Round-trips plaintexts, including empty and invalid-UTF-8 byte strings."""
    keyset_handle = tink.new_keyset_handle(TEMPLATE)
    primitive = keyset_handle.primitive(streaming_aead.StreamingAead)

    aad = b'associated_data'
    ciphertext_dest = bytes_io.BytesIOWithValueAfterClose()
    with primitive.new_encrypting_stream(ciphertext_dest, aad) as es:
      self.assertLen(plaintext, es.write(plaintext))
    # Closing the encrypting stream must also close the destination.
    self.assertTrue(ciphertext_dest.closed)

    ciphertext_src = io.BytesIO(ciphertext_dest.value_after_close())
    with primitive.new_decrypting_stream(ciphertext_src, aad) as ds:
      output = ds.read()
    # Closing the decrypting stream must also close the source.
    self.assertTrue(ciphertext_src.closed)
    self.assertEqual(output, plaintext)

  def test_long_plaintext_encrypt_decrypt_success(self):
    """Round-trips a large (tens of megabytes) plaintext."""
    keyset_handle = tink.new_keyset_handle(TEMPLATE)
    primitive = keyset_handle.primitive(streaming_aead.StreamingAead)

    long_plaintext = b' '.join(b'%d' % i for i in range(10 * 1000 * 1000))
    aad = b'associated_data'
    ciphertext_dest = bytes_io.BytesIOWithValueAfterClose()
    with primitive.new_encrypting_stream(ciphertext_dest, aad) as es:
      self.assertLen(long_plaintext, es.write(long_plaintext))
    self.assertTrue(ciphertext_dest.closed)

    ciphertext_src = io.BytesIO(ciphertext_dest.value_after_close())
    with primitive.new_decrypting_stream(ciphertext_src, aad) as ds:
      output = ds.read()
    self.assertTrue(ciphertext_src.closed)
    self.assertEqual(output, long_plaintext)

  @parameterized.parameters(
      [bytes_io.SlowBytesIO, bytes_io.SlowReadableRawBytes])
  def test_slow_encrypt_decrypt_success(self, input_stream_factory):
    """Decrypts correctly from sources that return data in small chunks."""
    keyset_handle = tink.new_keyset_handle(TEMPLATE)
    primitive = keyset_handle.primitive(streaming_aead.StreamingAead)
    plaintext = b' '.join(b'%d' % i for i in range(10 * 1000))
    aad = b'associated_data'
    ciphertext = _encrypt(primitive, plaintext, aad)

    # Even if the ciphertext source only returns small data chunks and sometimes
    # None, calling read() should return the whole ciphertext.
    ciphertext_src = cast(BinaryIO, input_stream_factory(ciphertext))
    with primitive.new_decrypting_stream(ciphertext_src, aad) as ds:
      output = ds.read()
    self.assertTrue(ciphertext_src.closed)
    self.assertEqual(output, plaintext)

  def test_encrypt_decrypt_bad_aad(self):
    """Decryption with different associated data raises TinkError."""
    keyset_handle = tink.new_keyset_handle(TEMPLATE)
    primitive = keyset_handle.primitive(streaming_aead.StreamingAead)

    plaintext = b'plaintext'
    aad = b'associated_data'

    ciphertext_dest = bytes_io.BytesIOWithValueAfterClose()
    with primitive.new_encrypting_stream(ciphertext_dest, aad) as es:
      self.assertLen(plaintext, es.write(plaintext))
    self.assertTrue(ciphertext_dest.closed)

    ciphertext_src = io.BytesIO(ciphertext_dest.value_after_close())
    with primitive.new_decrypting_stream(ciphertext_src, b'bad aad') as ds:
      with self.assertRaises(tink.TinkError):
        _ = ds.read()

  def test_decrypt_unknown_key_fails(self):
    """A ciphertext produced under an unrelated keyset cannot be decrypted."""
    plaintext = b'plaintext'
    aad = b'associated_data'

    unknown_keyset_handle = tink.new_keyset_handle(TEMPLATE)
    unknown_primitive = unknown_keyset_handle.primitive(
        streaming_aead.StreamingAead)
    unknown_ciphertext_dest = bytes_io.BytesIOWithValueAfterClose()
    with unknown_primitive.new_encrypting_stream(unknown_ciphertext_dest,
                                                 aad) as es:
      es.write(plaintext)

    keyset_handle = tink.new_keyset_handle(TEMPLATE)
    primitive = keyset_handle.primitive(streaming_aead.StreamingAead)
    ciphertext_src = io.BytesIO(unknown_ciphertext_dest.value_after_close())
    with primitive.new_decrypting_stream(ciphertext_src, aad) as ds:
      with self.assertRaises(tink.TinkError):
        _ = ds.read()

  @parameterized.parameters(
      [io.BytesIO, bytes_io.SlowBytesIO, bytes_io.SlowReadableRawBytes])
  def test_encrypt_decrypt_with_key_rotation(self, input_stream_factory):
    """Checks decryption across the stages of a key rotation.

    The four primitives are built from successive states of one keyset:
      p1: {older (primary)}
      p2: {older (primary), newer}
      p3: {older, newer (primary)}
      p4: {older (disabled), newer (primary)}
    """
    builder = keyset_builder.new_keyset_builder()
    older_key_id = builder.add_new_key(TEMPLATE)
    builder.set_primary_key(older_key_id)
    p1 = builder.keyset_handle().primitive(streaming_aead.StreamingAead)

    newer_key_id = builder.add_new_key(TEMPLATE)
    p2 = builder.keyset_handle().primitive(streaming_aead.StreamingAead)

    builder.set_primary_key(newer_key_id)
    p3 = builder.keyset_handle().primitive(streaming_aead.StreamingAead)

    builder.disable_key(older_key_id)
    p4 = builder.keyset_handle().primitive(streaming_aead.StreamingAead)

    self.assertNotEqual(older_key_id, newer_key_id)
    primitives = {'p1': p1, 'p2': p2, 'p3': p3, 'p4': p4}

    # p1 encrypts with the older key. So p1, p2 and p3 can decrypt it,
    # but not p4.
    plaintext1 = b' '.join(b'%d' % i for i in range(100 * 101))
    ciphertext1 = _encrypt(p1, plaintext1, b'aad1')
    self._assert_decryption_results(
        primitives, ciphertext1, b'aad1', plaintext1, input_stream_factory,
        can_decrypt=('p1', 'p2', 'p3'), cannot_decrypt=('p4',))

    # p2 encrypts with the older key. So p1, p2 and p3 can decrypt it,
    # but not p4.
    plaintext2 = b' '.join(b'%d' % i for i in range(100 * 102))
    ciphertext2 = _encrypt(p2, plaintext2, b'aad2')
    self._assert_decryption_results(
        primitives, ciphertext2, b'aad2', plaintext2, input_stream_factory,
        can_decrypt=('p1', 'p2', 'p3'), cannot_decrypt=('p4',))

    # p3 encrypts with the newer key. So p2, p3 and p4 can decrypt it,
    # but not p1.
    plaintext3 = b' '.join(b'%d' % i for i in range(100 * 103))
    ciphertext3 = _encrypt(p3, plaintext3, b'aad3')
    self._assert_decryption_results(
        primitives, ciphertext3, b'aad3', plaintext3, input_stream_factory,
        can_decrypt=('p2', 'p3', 'p4'), cannot_decrypt=('p1',))

    # p4 encrypts with the newer key. So p2, p3 and p4 can decrypt it,
    # but not p1.
    plaintext4 = b' '.join(b'%d' % i for i in range(100 * 104))
    ciphertext4 = _encrypt(p4, plaintext4, b'aad4')
    self._assert_decryption_results(
        primitives, ciphertext4, b'aad4', plaintext4, input_stream_factory,
        can_decrypt=('p2', 'p3', 'p4'), cannot_decrypt=('p1',))

  def test_decrypt_tink_output_prefix(self):
    """Round-trips with a TINK-prefixed primary key in a mixed keyset."""
    key = aes_gcm_hkdf_streaming_pb2.AesGcmHkdfStreamingKey(
        version=0,
        params=aes_gcm_hkdf_streaming_pb2.AesGcmHkdfStreamingParams(
            ciphertext_segment_size=512,
            derived_key_size=16,
            hkdf_hash_type=common_pb2.HashType.SHA256,
        ),
        key_value=b'0123456789abcdef',
    )
    value1 = key.SerializeToString()
    key.key_value = b'ABCDEF0123456789'
    value2 = key.SerializeToString()

    # The primary key (id 1) uses the TINK output prefix. A second key with the
    # RAW output prefix is added because, per this test's original rationale,
    # creating a StreamingAead from a keyset without any RAW key is rejected.
    # NOTE(review): confirm that check still exists. The assertions below
    # expect a complete encrypt/decrypt round trip using the TINK-prefixed
    # primary to succeed.
    keyset = tink_pb2.Keyset(
        primary_key_id=1,
        key=[
            tink_pb2.Keyset.Key(
                key_data=tink_pb2.KeyData(
                    type_url=TYPE_URL,
                    value=value1,
                    key_material_type=tink_pb2.KeyData.SYMMETRIC,
                ),
                output_prefix_type=tink_pb2.OutputPrefixType.TINK,
                status=tink_pb2.KeyStatusType.ENABLED,
                key_id=1,
            ),
            tink_pb2.Keyset.Key(
                key_data=tink_pb2.KeyData(
                    type_url=TYPE_URL,
                    value=value2,
                    key_material_type=tink_pb2.KeyData.SYMMETRIC,
                ),
                output_prefix_type=tink_pb2.OutputPrefixType.RAW,
                status=tink_pb2.KeyStatusType.ENABLED,
                key_id=2,
            ),
        ],
    )

    keyset_handle = cleartext_keyset_handle.from_keyset(keyset)
    primitive = keyset_handle.primitive(streaming_aead.StreamingAead)

    plaintext = b'plaintext'
    associated_data = b'associated_data'

    ciphertext_dest = bytes_io.BytesIOWithValueAfterClose()
    with primitive.new_encrypting_stream(
        ciphertext_dest, associated_data
    ) as es:
      self.assertLen(plaintext, es.write(plaintext))
    self.assertTrue(ciphertext_dest.closed)

    ciphertext_src = io.BytesIO(ciphertext_dest.value_after_close())
    with primitive.new_decrypting_stream(ciphertext_src, associated_data) as ds:
      output = ds.read()
    self.assertTrue(ciphertext_src.closed)
    self.assertEqual(output, plaintext)
289
290
291if __name__ == '__main__':
292  absltest.main()
293