# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tink.python.tink.streaming_aead._streaming_aead_wrapper."""

import io
from typing import BinaryIO, cast

from absl.testing import absltest
from absl.testing import parameterized
from tink.proto import aes_gcm_hkdf_streaming_pb2
from tink.proto import common_pb2
from tink.proto import tink_pb2
import tink
from tink import cleartext_keyset_handle
from tink import streaming_aead
from tink.testing import bytes_io
from tink.testing import keyset_builder


TEMPLATE = streaming_aead.streaming_aead_key_templates.AES128_GCM_HKDF_4KB
TYPE_URL = 'type.googleapis.com/google.crypto.tink.AesGcmHkdfStreamingKey'


def setUpModule():
  streaming_aead.register()


def _encrypt(primitive: streaming_aead.StreamingAead, plaintext: bytes,
             associated_data: bytes) -> bytes:
  """Encrypts `plaintext` with `primitive` and returns the full ciphertext.

  Args:
    primitive: the StreamingAead used to open the encrypting stream.
    plaintext: bytes written to the encrypting stream in a single write.
    associated_data: associated data bound to the ciphertext.

  Returns:
    The bytes the destination held when the encrypting stream closed it.
  """
  ciphertext_dest = bytes_io.BytesIOWithValueAfterClose()
  with primitive.new_encrypting_stream(ciphertext_dest, associated_data) as es:
    es.write(plaintext)
  # Leaving the `with` closes the destination, so read the value captured at
  # close time rather than calling getvalue() on a closed BytesIO.
  return ciphertext_dest.value_after_close()


class StreamingAeadWrapperTest(parameterized.TestCase):

  @parameterized.parameters(
      [b'plaintext', b'', b'smile \xf0\x9f\x98\x80', b'\xf0\x9f\x98'])
  def test_encrypt_decrypt_success(self, plaintext):
    """Round-trips `plaintext`, including empty and non-UTF-8 byte strings."""
    keyset_handle = tink.new_keyset_handle(TEMPLATE)
    primitive = keyset_handle.primitive(streaming_aead.StreamingAead)

    aad = b'associated_data'
    ciphertext_dest = bytes_io.BytesIOWithValueAfterClose()
    with primitive.new_encrypting_stream(ciphertext_dest, aad) as es:
      self.assertLen(plaintext, es.write(plaintext))
    self.assertTrue(ciphertext_dest.closed)

    ciphertext_src = io.BytesIO(ciphertext_dest.value_after_close())
    with primitive.new_decrypting_stream(ciphertext_src, aad) as ds:
      output = ds.read()
    self.assertTrue(ciphertext_src.closed)
    self.assertEqual(output, plaintext)

  def test_long_plaintext_encrypt_decrypt_success(self):
    """Round-trips a plaintext spanning many ciphertext segments."""
    keyset_handle = tink.new_keyset_handle(TEMPLATE)
    primitive = keyset_handle.primitive(streaming_aead.StreamingAead)

    long_plaintext = b' '.join(b'%d' % i for i in range(10 * 1000 * 1000))
    aad = b'associated_data'
    ciphertext_dest = bytes_io.BytesIOWithValueAfterClose()
    with primitive.new_encrypting_stream(ciphertext_dest, aad) as es:
      self.assertLen(long_plaintext, es.write(long_plaintext))
    self.assertTrue(ciphertext_dest.closed)

    ciphertext_src = io.BytesIO(ciphertext_dest.value_after_close())
    with primitive.new_decrypting_stream(ciphertext_src, aad) as ds:
      output = ds.read()
    self.assertTrue(ciphertext_src.closed)
    self.assertEqual(output, long_plaintext)

  @parameterized.parameters(
      [bytes_io.SlowBytesIO, bytes_io.SlowReadableRawBytes])
  def test_slow_encrypt_decrypt_success(self, input_stream_factory):
    """Decrypts from sources that return short reads (and sometimes None)."""
    keyset_handle = tink.new_keyset_handle(TEMPLATE)
    primitive = keyset_handle.primitive(streaming_aead.StreamingAead)
    plaintext = b' '.join(b'%d' % i for i in range(10 * 1000))
    aad = b'associated_data'
    ciphertext = _encrypt(primitive, plaintext, aad)

    # Even if the ciphertext source only returns small data chunks and sometimes
    # None, calling read() should return the whole ciphertext.
    ciphertext_src = cast(BinaryIO, input_stream_factory(ciphertext))
    with primitive.new_decrypting_stream(ciphertext_src, aad) as ds:
      output = ds.read()
    self.assertTrue(ciphertext_src.closed)
    self.assertEqual(output, plaintext)

  def test_encrypt_decrypt_bad_aad(self):
    """Decrypting with different associated data raises TinkError."""
    keyset_handle = tink.new_keyset_handle(TEMPLATE)
    primitive = keyset_handle.primitive(streaming_aead.StreamingAead)

    plaintext = b'plaintext'
    aad = b'associated_data'

    ciphertext_dest = bytes_io.BytesIOWithValueAfterClose()
    with primitive.new_encrypting_stream(ciphertext_dest, aad) as es:
      self.assertLen(plaintext, es.write(plaintext))
    self.assertTrue(ciphertext_dest.closed)

    ciphertext_src = io.BytesIO(ciphertext_dest.value_after_close())
    with primitive.new_decrypting_stream(ciphertext_src, b'bad aad') as ds:
      with self.assertRaises(tink.TinkError):
        _ = ds.read()

  def test_decrypt_unknown_key_fails(self):
    """A ciphertext from an unrelated keyset cannot be decrypted."""
    plaintext = b'plaintext'
    aad = b'associated_data'

    unknown_keyset_handle = tink.new_keyset_handle(TEMPLATE)
    unknown_primitive = unknown_keyset_handle.primitive(
        streaming_aead.StreamingAead)
    unknown_ciphertext = _encrypt(unknown_primitive, plaintext, aad)

    keyset_handle = tink.new_keyset_handle(TEMPLATE)
    primitive = keyset_handle.primitive(streaming_aead.StreamingAead)
    ciphertext_src = io.BytesIO(unknown_ciphertext)
    with primitive.new_decrypting_stream(ciphertext_src, aad) as ds:
      with self.assertRaises(tink.TinkError):
        _ = ds.read()

  @parameterized.parameters(
      [io.BytesIO, bytes_io.SlowBytesIO, bytes_io.SlowReadableRawBytes])
  def test_encrypt_decrypt_with_key_rotation(self, input_stream_factory):
    """Checks which primitives can decrypt across a key rotation.

    Four primitives are snapshotted at successive keyset states:
      p1: only the older key, which is primary.
      p2: older key (primary) plus the newer key.
      p3: both keys, newer key primary.
      p4: like p3, but with the older key disabled.
    """
    builder = keyset_builder.new_keyset_builder()
    older_key_id = builder.add_new_key(TEMPLATE)
    builder.set_primary_key(older_key_id)
    p1 = builder.keyset_handle().primitive(streaming_aead.StreamingAead)

    newer_key_id = builder.add_new_key(TEMPLATE)
    p2 = builder.keyset_handle().primitive(streaming_aead.StreamingAead)

    builder.set_primary_key(newer_key_id)
    p3 = builder.keyset_handle().primitive(streaming_aead.StreamingAead)

    builder.disable_key(older_key_id)
    p4 = builder.keyset_handle().primitive(streaming_aead.StreamingAead)

    self.assertNotEqual(older_key_id, newer_key_id)

    decrypters = [p1, p2, p3, p4]
    # Each case: (encrypting primitive, whether p1..p4 can decrypt its output).
    cases = [
        # p1 and p2 encrypt with the older key. So p1, p2 and p3 can decrypt
        # it, but not p4 (the older key is disabled there).
        (p1, [True, True, True, False]),
        (p2, [True, True, True, False]),
        # p3 and p4 encrypt with the newer key. So p2, p3 and p4 can decrypt
        # it, but not p1 (its keyset was taken before the newer key existed).
        (p3, [False, True, True, True]),
        (p4, [False, True, True, True]),
    ]
    for case_num, (encrypter, can_decrypt) in enumerate(cases, start=1):
      # Distinct plaintext and associated data per encrypter, so a ciphertext
      # can never be accepted by accident under another case's inputs.
      plaintext = b' '.join(b'%d' % i for i in range(100 * (100 + case_num)))
      aad = b'aad%d' % case_num
      ciphertext = _encrypt(encrypter, plaintext, aad)
      for dec_num, (decrypter, should_decrypt) in enumerate(
          zip(decrypters, can_decrypt), start=1):
        with self.subTest(encrypter=case_num, decrypter=dec_num):
          ciphertext_src = cast(BinaryIO, input_stream_factory(ciphertext))
          with decrypter.new_decrypting_stream(ciphertext_src, aad) as ds:
            if should_decrypt:
              self.assertEqual(ds.read(), plaintext)
            else:
              with self.assertRaises(tink.TinkError):
                ds.read()

  def test_decrypt_tink_output_prefix(self):
    """Round-trips with a primary key whose output prefix type is TINK."""
    key = aes_gcm_hkdf_streaming_pb2.AesGcmHkdfStreamingKey(
        version=0,
        params=aes_gcm_hkdf_streaming_pb2.AesGcmHkdfStreamingParams(
            ciphertext_segment_size=512,
            derived_key_size=16,
            hkdf_hash_type=common_pb2.HashType.SHA256,
        ),
        key_value=b'0123456789abcdef',
    )
    value1 = key.SerializeToString()
    key.key_value = b'ABCDEF0123456789'
    value2 = key.SerializeToString()

    # Key 1 is the primary (so it is used for encryption) and has the TINK
    # output prefix type. Key 2 is RAW, so the keyset contains at least one raw
    # key: Tink has a check that creating a new StreamingAead fails when the
    # keyset does not contain any raw keys.
    # NOTE(review): confirm that check still exists in the wrapper; if not, key
    # 2 can be dropped. The assertions below expect the round trip to SUCCEED.
    keyset = tink_pb2.Keyset(
        primary_key_id=1,
        key=[
            tink_pb2.Keyset.Key(
                key_data=tink_pb2.KeyData(
                    type_url=TYPE_URL,
                    value=value1,
                    key_material_type=tink_pb2.KeyData.SYMMETRIC,
                ),
                output_prefix_type=tink_pb2.OutputPrefixType.TINK,
                status=tink_pb2.KeyStatusType.ENABLED,
                key_id=1,
            ),
            tink_pb2.Keyset.Key(
                key_data=tink_pb2.KeyData(
                    type_url=TYPE_URL,
                    value=value2,
                    key_material_type=tink_pb2.KeyData.SYMMETRIC,
                ),
                output_prefix_type=tink_pb2.OutputPrefixType.RAW,
                status=tink_pb2.KeyStatusType.ENABLED,
                key_id=2,
            ),
        ],
    )

    keyset_handle = cleartext_keyset_handle.from_keyset(keyset)
    primitive = keyset_handle.primitive(streaming_aead.StreamingAead)

    plaintext = b'plaintext'
    associated_data = b'associated_data'

    ciphertext_dest = bytes_io.BytesIOWithValueAfterClose()
    with primitive.new_encrypting_stream(
        ciphertext_dest, associated_data
    ) as es:
      self.assertLen(plaintext, es.write(plaintext))
    self.assertTrue(ciphertext_dest.closed)

    ciphertext_src = io.BytesIO(ciphertext_dest.value_after_close())
    with primitive.new_decrypting_stream(ciphertext_src, associated_data) as ds:
      output = ds.read()
    self.assertTrue(ciphertext_src.closed)
    self.assertEqual(output, plaintext)


if __name__ == '__main__':
  absltest.main()