# xref: /aosp_15_r20/external/tink/python/tink/streaming_aead/_streaming_aead_key_manager_test.py (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
"""Tests for tink.python.tink.streaming_aead._streaming_aead_key_manager."""
15
16import io
17
18from absl.testing import absltest
19from absl.testing import parameterized
20from tink.proto import aes_ctr_hmac_streaming_pb2
21from tink.proto import aes_gcm_hkdf_streaming_pb2
22from tink.proto import common_pb2
23from tink.proto import tink_pb2
24import tink
25from tink import core
26from tink import streaming_aead
27from tink.streaming_aead import _raw_streaming_aead
28from tink.testing import bytes_io
29
# A lone 0x80 byte is a bare UTF-8 continuation byte and therefore never valid
# UTF-8 on its own; appending it to test data ensures nothing silently decodes
# the bytes as text.
B_X80 = bytes([0x80])
32
33
def setUpModule():
  """Registers the Streaming AEAD primitives before any test in this module."""
  register_streaming_aead = streaming_aead.register
  register_streaming_aead()
36
37
def new_raw_primitive():
  """Builds a RawStreamingAead from a fresh AES128_CTR_HMAC_SHA256_4KB key.

  Returns:
    The primitive obtained from core.Registry for newly generated key data.
  """
  templates = streaming_aead.streaming_aead_key_templates
  fresh_key_data = core.Registry.new_key_data(
      templates.AES128_CTR_HMAC_SHA256_4KB)
  raw_interface = _raw_streaming_aead.RawStreamingAead
  return core.Registry.primitive(fresh_key_data, raw_interface)
44
45
class StreamingAeadKeyManagerTest(parameterized.TestCase):
  """Tests Streaming AEAD key generation and encrypt/decrypt round trips."""

  def test_new_aes_gcm_hkdf_key_data(self):
    """AES128_GCM_HKDF_4KB yields a symmetric key with the template's params."""
    key_template = (
        streaming_aead.streaming_aead_key_templates.AES128_GCM_HKDF_4KB)
    key_data = core.Registry.new_key_data(key_template)
    self.assertEqual(key_data.type_url, key_template.type_url)
    self.assertEqual(key_data.key_material_type, tink_pb2.KeyData.SYMMETRIC)
    key = aes_gcm_hkdf_streaming_pb2.AesGcmHkdfStreamingKey.FromString(
        key_data.value)
    self.assertEqual(key.version, 0)
    self.assertLen(key.key_value, 16)
    self.assertEqual(key.params.hkdf_hash_type, common_pb2.HashType.SHA256)
    self.assertEqual(key.params.derived_key_size, 16)
    self.assertEqual(key.params.ciphertext_segment_size, 4096)

  def test_new_aes_ctr_hmac_key_data(self):
    """AES128_CTR_HMAC_SHA256_4KB yields a key with the template's params."""
    key_template = (
        streaming_aead.streaming_aead_key_templates.AES128_CTR_HMAC_SHA256_4KB)
    key_data = core.Registry.new_key_data(key_template)
    self.assertEqual(key_data.type_url, key_template.type_url)
    self.assertEqual(key_data.key_material_type, tink_pb2.KeyData.SYMMETRIC)
    key = aes_ctr_hmac_streaming_pb2.AesCtrHmacStreamingKey.FromString(
        key_data.value)
    self.assertEqual(key.version, 0)
    self.assertLen(key.key_value, 16)
    self.assertEqual(key.params.hkdf_hash_type, common_pb2.HashType.SHA256)
    self.assertEqual(key.params.derived_key_size, 16)
    self.assertEqual(key.params.hmac_params.hash, common_pb2.HashType.SHA256)
    self.assertEqual(key.params.hmac_params.tag_size, 32)
    self.assertEqual(key.params.ciphertext_segment_size, 4096)

  def test_invalid_aes_gcm_hkdf_params_throw_exception(self):
    """Key generation from invalid AES-GCM-HKDF params raises TinkError."""
    tmpls = streaming_aead.streaming_aead_key_templates
    # NOTE(review): the positional args are presumably (key size, HKDF hash,
    # derived key size, segment size); the asserted message refers to the
    # 63-byte key being smaller than the 65-byte derived key — confirm.
    key_template = tmpls.create_aes_gcm_hkdf_streaming_key_template(
        63, common_pb2.HashType.SHA1, 65, 55)
    with self.assertRaisesRegex(core.TinkError,
                                'key_size must not be smaller than'):
      core.Registry.new_key_data(key_template)

  def test_invalid_aes_ctr_hmac_params_throw_exception(self):
    """Key generation from invalid AES-CTR-HMAC params raises TinkError."""
    tmpls = streaming_aead.streaming_aead_key_templates
    # NOTE(review): as above, 63 vs. 65 is presumably key size vs. derived key
    # size, which is what the asserted message checks — confirm.
    key_template = tmpls.create_aes_ctr_hmac_streaming_key_template(
        63, common_pb2.HashType.SHA1, 65, common_pb2.HashType.SHA256, 55, 2)
    with self.assertRaisesRegex(core.TinkError,
                                'key_size must not be smaller than'):
      core.Registry.new_key_data(key_template)

  def test_raw_encrypt_decrypt_readall(self):
    """readall() recovers the plaintext; source closing follows the flag."""
    raw_primitive = new_raw_primitive()
    plaintext = b'plaintext' + B_X80
    aad = b'associated_data' + B_X80

    # Encrypt
    ct_destination = bytes_io.BytesIOWithValueAfterClose()
    with raw_primitive.new_raw_encrypting_stream(ct_destination, aad) as es:
      self.assertLen(plaintext, es.write(plaintext))
    # The context manager closes es, which also closes ct_destination.
    self.assertTrue(ct_destination.closed)
    ciphertext = ct_destination.value_after_close()

    # Decrypt, with and without close_ciphertext_source. subTest makes a
    # failure report which flag value it happened under.
    for close_ciphertext_source in [True, False]:
      with self.subTest(close_ciphertext_source=close_ciphertext_source):
        ct_source = io.BytesIO(ciphertext)
        with raw_primitive.new_raw_decrypting_stream(
            ct_source, aad,
            close_ciphertext_source=close_ciphertext_source) as ds:
          output = ds.readall()
        self.assertEqual(ct_source.closed, close_ciphertext_source)
        self.assertEqual(output, plaintext)

  def test_raw_encrypt_decrypt_read(self):
    """read(n) returns at most n bytes and then the remaining tail."""
    raw_primitive = new_raw_primitive()
    plaintext = b'plaintext'
    aad = b'aad'

    ct_destination = bytes_io.BytesIOWithValueAfterClose()
    with raw_primitive.new_raw_encrypting_stream(ct_destination, aad) as es:
      es.write(plaintext)

    ct_source = io.BytesIO(ct_destination.value_after_close())
    with raw_primitive.new_raw_decrypting_stream(
        ct_source, aad, close_ciphertext_source=True) as ds:
      self.assertEqual(ds.read(5), b'plain')
      self.assertEqual(ds.read(5), b'text')

  def test_raw_encrypt_decrypt_readinto(self):
    """readinto() fills the buffer and reports how many bytes it wrote."""
    raw_primitive = new_raw_primitive()
    plaintext = b'plaintext'
    aad = b'aad'

    ct_destination = bytes_io.BytesIOWithValueAfterClose()
    with raw_primitive.new_raw_encrypting_stream(ct_destination, aad) as es:
      es.write(plaintext)

    ct_source = io.BytesIO(ct_destination.value_after_close())
    with raw_primitive.new_raw_decrypting_stream(
        ct_source, aad, close_ciphertext_source=True) as ds:
      data = bytearray(b'xxxxx')
      n = ds.readinto(data)  # Writes 5 bytes into data.
      self.assertEqual(n, 5)
      self.assertEqual(data, b'plain')
      # Writes the remaining 4 bytes; data[4] keeps its previous value 'n'.
      n = ds.readinto(data)
      self.assertEqual(n, 4)
      self.assertEqual(data, b'textn')

  def test_raw_encrypt_decrypt_empty(self):
    """An empty plaintext with empty AAD round-trips to empty output."""
    raw_primitive = new_raw_primitive()
    plaintext = b''
    aad = b''
    ct_destination = bytes_io.BytesIOWithValueAfterClose()
    with raw_primitive.new_raw_encrypting_stream(ct_destination, aad) as es:
      es.write(plaintext)

    ct_source = io.BytesIO(ct_destination.value_after_close())
    with raw_primitive.new_raw_decrypting_stream(
        ct_source, aad, close_ciphertext_source=True) as ds:
      self.assertEqual(ds.read(5), b'')

  def test_raw_read_after_eof_returns_empty_bytes(self):
    """Reading after readall() has consumed the stream returns b''."""
    raw_primitive = new_raw_primitive()
    plaintext = b'plaintext' + B_X80
    aad = b'associated_data' + B_X80

    ct_destination = bytes_io.BytesIOWithValueAfterClose()
    with raw_primitive.new_raw_encrypting_stream(ct_destination, aad) as es:
      self.assertLen(plaintext, es.write(plaintext))

    ct_source = io.BytesIO(ct_destination.value_after_close())
    with raw_primitive.new_raw_decrypting_stream(
        ct_source, aad, close_ciphertext_source=True) as ds:
      _ = ds.readall()
      self.assertEqual(ds.read(100), b'')

  def test_raw_encrypt_decrypt_close(self):
    """Explicit close() closes the streams and, per flag, the underlying IO."""
    raw_primitive = new_raw_primitive()
    plaintext = b'plaintext' + B_X80
    aad = b'associated_data' + B_X80

    # Encrypt
    ct_destination = bytes_io.BytesIOWithValueAfterClose()
    es = raw_primitive.new_raw_encrypting_stream(ct_destination, aad)
    es.write(plaintext)
    self.assertFalse(ct_destination.closed)
    self.assertFalse(es.closed)
    es.close()
    self.assertTrue(ct_destination.closed)
    self.assertTrue(es.closed)
    ciphertext = ct_destination.value_after_close()

    # Decrypt, with and without close_ciphertext_source.
    for close_ciphertext_source in [True, False]:
      with self.subTest(close_ciphertext_source=close_ciphertext_source):
        ct_source = io.BytesIO(ciphertext)
        ds = raw_primitive.new_raw_decrypting_stream(
            ct_source, aad,
            close_ciphertext_source=close_ciphertext_source)
        self.assertFalse(ct_source.closed)
        self.assertFalse(ds.closed)
        ds.close()
        self.assertEqual(ct_source.closed, close_ciphertext_source)
        self.assertTrue(ds.closed)

  def test_raw_encrypt_decrypt_wrong_aad(self):
    """Decrypting with different associated data raises TinkError."""
    raw_primitive = new_raw_primitive()
    plaintext = b'plaintext' + B_X80
    aad = b'associated_data' + B_X80

    # Encrypt
    ct_destination = bytes_io.BytesIOWithValueAfterClose()
    with raw_primitive.new_raw_encrypting_stream(ct_destination, aad) as es:
      self.assertLen(plaintext, es.write(plaintext))
    self.assertNotEqual(ct_destination.value_after_close(), plaintext)

    # Decrypt
    ct_source = io.BytesIO(ct_destination.value_after_close())
    with raw_primitive.new_raw_decrypting_stream(
        ct_source, b'bad' + aad, close_ciphertext_source=True) as ds:
      with self.assertRaises(core.TinkError):
        ds.read()

  @parameterized.parameters([
      streaming_aead.streaming_aead_key_templates.AES128_GCM_HKDF_4KB,
      streaming_aead.streaming_aead_key_templates.AES128_GCM_HKDF_1MB,
      streaming_aead.streaming_aead_key_templates.AES256_GCM_HKDF_4KB,
      streaming_aead.streaming_aead_key_templates.AES256_GCM_HKDF_1MB,
      streaming_aead.streaming_aead_key_templates.AES128_CTR_HMAC_SHA256_4KB,
      streaming_aead.streaming_aead_key_templates.AES128_CTR_HMAC_SHA256_1MB,
      streaming_aead.streaming_aead_key_templates.AES256_CTR_HMAC_SHA256_4KB,
      streaming_aead.streaming_aead_key_templates.AES256_CTR_HMAC_SHA256_1MB
  ])
  def test_encrypt_decrypt_success(self, template):
    """A keyset-level StreamingAead round-trips data for every template.

    Args:
      template: the key template used to create the keyset.
    """
    keyset_handle = tink.new_keyset_handle(template)
    primitive = keyset_handle.primitive(streaming_aead.StreamingAead)

    # B_X80 keeps the data invalid UTF-8, matching the raw-primitive tests.
    plaintext = b'plaintext' + B_X80
    associated_data = b'associated_data' + B_X80

    # Encrypt
    ciphertext_destination = bytes_io.BytesIOWithValueAfterClose()
    with primitive.new_encrypting_stream(ciphertext_destination,
                                         associated_data) as encryption_stream:
      encryption_stream.write(plaintext)

    ciphertext = ciphertext_destination.value_after_close()
    self.assertNotEqual(ciphertext, plaintext)

    # Decrypt
    ciphertext_source = io.BytesIO(ciphertext)
    with primitive.new_decrypting_stream(ciphertext_source,
                                         associated_data) as decryption_stream:
      decrypted = decryption_stream.read()

    self.assertEqual(decrypted, plaintext)
257
if __name__ == '__main__':
  # Only start the absltest runner when executed as a script, not on import.
  run_all_tests = absltest.main
  run_all_tests()
260