1# Copyright 2020 Google LLC 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14"""Tests for tink.python.tink.streaming_aead_key_manager.""" 15 16import io 17 18from absl.testing import absltest 19from absl.testing import parameterized 20from tink.proto import aes_ctr_hmac_streaming_pb2 21from tink.proto import aes_gcm_hkdf_streaming_pb2 22from tink.proto import common_pb2 23from tink.proto import tink_pb2 24import tink 25from tink import core 26from tink import streaming_aead 27from tink.streaming_aead import _raw_streaming_aead 28from tink.testing import bytes_io 29 30# Using malformed UTF-8 sequences to ensure there is no accidental decoding. 
B_X80 = b'\x80'


def setUpModule():
  """Registers the streaming AEAD key managers before any test runs."""
  streaming_aead.register()


def new_raw_primitive():
  """Returns a RawStreamingAead for a freshly generated key.

  The key is created from the AES128_CTR_HMAC_SHA256_4KB template, so the
  ciphertext segment size is 4KB.
  """
  key_data = core.Registry.new_key_data(
      streaming_aead.streaming_aead_key_templates
      .AES128_CTR_HMAC_SHA256_4KB)
  return core.Registry.primitive(key_data,
                                 _raw_streaming_aead.RawStreamingAead)


class StreamingAeadKeyManagerTest(parameterized.TestCase):
  """Tests key generation and encrypt/decrypt round trips for streaming AEAD.

  Covers the raw primitive (RawStreamingAead) as well as the keyset-wrapped
  StreamingAead primitive for every predefined key template.
  """

  def test_new_aes_gcm_hkdf_key_data(self):
    """AES128_GCM_HKDF_4KB produces a key with the expected parameters."""
    key_template = (
        streaming_aead.streaming_aead_key_templates.AES128_GCM_HKDF_4KB)
    key_data = core.Registry.new_key_data(key_template)
    self.assertEqual(key_data.type_url, key_template.type_url)
    self.assertEqual(key_data.key_material_type, tink_pb2.KeyData.SYMMETRIC)
    key = aes_gcm_hkdf_streaming_pb2.AesGcmHkdfStreamingKey.FromString(
        key_data.value)
    self.assertEqual(key.version, 0)
    self.assertLen(key.key_value, 16)
    self.assertEqual(key.params.hkdf_hash_type, common_pb2.HashType.SHA256)
    self.assertEqual(key.params.derived_key_size, 16)
    self.assertEqual(key.params.ciphertext_segment_size, 4096)

  def test_new_aes_ctr_hmac_key_data(self):
    """AES128_CTR_HMAC_SHA256_4KB produces a key with the expected params."""
    key_template = (
        streaming_aead.streaming_aead_key_templates.AES128_CTR_HMAC_SHA256_4KB)
    key_data = core.Registry.new_key_data(key_template)
    self.assertEqual(key_data.type_url, key_template.type_url)
    self.assertEqual(key_data.key_material_type, tink_pb2.KeyData.SYMMETRIC)
    key = aes_ctr_hmac_streaming_pb2.AesCtrHmacStreamingKey.FromString(
        key_data.value)
    self.assertEqual(key.version, 0)
    self.assertLen(key.key_value, 16)
    self.assertEqual(key.params.hkdf_hash_type, common_pb2.HashType.SHA256)
    self.assertEqual(key.params.derived_key_size, 16)
    self.assertEqual(key.params.hmac_params.hash, common_pb2.HashType.SHA256)
    self.assertEqual(key.params.hmac_params.tag_size, 32)
    self.assertEqual(key.params.ciphertext_segment_size, 4096)

  def test_invalid_aes_gcm_hkdf_params_throw_exception(self):
    """A key_size smaller than derived_key_size is rejected."""
    tmpls = streaming_aead.streaming_aead_key_templates
    # key_size=63 is smaller than derived_key_size=65.
    key_template = tmpls.create_aes_gcm_hkdf_streaming_key_template(
        63, common_pb2.HashType.SHA1, 65, 55)
    with self.assertRaisesRegex(core.TinkError,
                                'key_size must not be smaller than'):
      core.Registry.new_key_data(key_template)

  def test_invalid_aes_ctr_hmac_params_throw_exception(self):
    """A key_size smaller than derived_key_size is rejected."""
    tmpls = streaming_aead.streaming_aead_key_templates
    # key_size=63 is smaller than derived_key_size=65.
    key_template = tmpls.create_aes_ctr_hmac_streaming_key_template(
        63, common_pb2.HashType.SHA1, 65, common_pb2.HashType.SHA256, 55, 2)
    with self.assertRaisesRegex(core.TinkError,
                                'key_size must not be smaller than'):
      core.Registry.new_key_data(key_template)

  def test_raw_encrypt_decrypt_readall(self):
    """readall() returns the full plaintext; source closing is honored."""
    raw_primitive = new_raw_primitive()
    plaintext = b'plaintext' + B_X80
    aad = b'associated_data' + B_X80

    # Encrypt
    ct_destination = bytes_io.BytesIOWithValueAfterClose()
    with raw_primitive.new_raw_encrypting_stream(ct_destination, aad) as es:
      self.assertLen(plaintext, es.write(plaintext))
    # The context manager closes es, which also closes ct_destination.
    self.assertTrue(ct_destination.closed)

    # Decrypt, with and without close_ciphertext_source
    for close_ciphertext_source in [True, False]:
      ct_source = io.BytesIO(ct_destination.value_after_close())
      with raw_primitive.new_raw_decrypting_stream(
          ct_source, aad,
          close_ciphertext_source=close_ciphertext_source) as ds:
        output = ds.readall()
      self.assertEqual(ct_source.closed, close_ciphertext_source)
      self.assertEqual(output, plaintext)

  def test_raw_encrypt_decrypt_read(self):
    """read(n) returns at most n bytes and a short read at the end."""
    raw_primitive = new_raw_primitive()
    plaintext = b'plaintext'
    aad = b'aad'

    ct_destination = bytes_io.BytesIOWithValueAfterClose()
    with raw_primitive.new_raw_encrypting_stream(ct_destination, aad) as es:
      es.write(plaintext)

    ct_source = io.BytesIO(ct_destination.value_after_close())
    with raw_primitive.new_raw_decrypting_stream(
        ct_source, aad, close_ciphertext_source=True) as ds:
      self.assertEqual(ds.read(5), b'plain')
      self.assertEqual(ds.read(5), b'text')

  def test_raw_encrypt_decrypt_readinto(self):
    """readinto() fills the buffer and reports the number of bytes written."""
    raw_primitive = new_raw_primitive()
    plaintext = b'plaintext'
    aad = b'aad'

    ct_destination = bytes_io.BytesIOWithValueAfterClose()
    with raw_primitive.new_raw_encrypting_stream(ct_destination, aad) as es:
      es.write(plaintext)

    ct_source = io.BytesIO(ct_destination.value_after_close())
    with raw_primitive.new_raw_decrypting_stream(
        ct_source, aad, close_ciphertext_source=True) as ds:
      data = bytearray(b'xxxxx')
      n = ds.readinto(data)  # Writes 5 bytes into data.
      self.assertEqual(n, 5)
      self.assertEqual(data, b'plain')
      # Writes the remaining 4 bytes; the 5th byte ('n' from 'plain') is
      # left untouched, hence b'textn'.
      n = ds.readinto(data)
      self.assertEqual(n, 4)
      self.assertEqual(data, b'textn')

  def test_raw_encrypt_decrypt_empty(self):
    """An empty plaintext with empty associated data round-trips."""
    raw_primitive = new_raw_primitive()
    plaintext = b''
    aad = b''
    ct_destination = bytes_io.BytesIOWithValueAfterClose()
    with raw_primitive.new_raw_encrypting_stream(ct_destination, aad) as es:
      es.write(plaintext)

    ct_source = io.BytesIO(ct_destination.value_after_close())
    with raw_primitive.new_raw_decrypting_stream(
        ct_source, aad, close_ciphertext_source=True) as ds:
      self.assertEqual(ds.read(5), b'')

  def test_raw_read_after_eof_returns_empty_bytes(self):
    """Reading past the end of the plaintext returns b''."""
    raw_primitive = new_raw_primitive()
    plaintext = b'plaintext' + B_X80
    aad = b'associated_data' + B_X80

    ct_destination = bytes_io.BytesIOWithValueAfterClose()
    with raw_primitive.new_raw_encrypting_stream(ct_destination, aad) as es:
      self.assertLen(plaintext, es.write(plaintext))

    ct_source = io.BytesIO(ct_destination.value_after_close())
    with raw_primitive.new_raw_decrypting_stream(
        ct_source, aad, close_ciphertext_source=True) as ds:
      _ = ds.readall()
      self.assertEqual(ds.read(100), b'')

  def test_raw_encrypt_decrypt_close(self):
    """Explicit close() closes the stream and, optionally, the underlying IO."""
    raw_primitive = new_raw_primitive()
    plaintext = b'plaintext' + B_X80
    aad = b'associated_data' + B_X80

    # Encrypt
    ct_destination = bytes_io.BytesIOWithValueAfterClose()
    es = raw_primitive.new_raw_encrypting_stream(ct_destination, aad)
    es.write(plaintext)
    self.assertFalse(ct_destination.closed)
    self.assertFalse(es.closed)
    es.close()
    self.assertTrue(ct_destination.closed)
    self.assertTrue(es.closed)

    # Decrypt, with and without close_ciphertext_source
    for close_ciphertext_source in [True, False]:
      ct_source = io.BytesIO(ct_destination.value_after_close())
      ds = raw_primitive.new_raw_decrypting_stream(
          ct_source, aad,
          close_ciphertext_source=close_ciphertext_source)
      self.assertFalse(ct_source.closed)
      self.assertFalse(ds.closed)
      ds.close()
      self.assertEqual(ct_source.closed, close_ciphertext_source)
      self.assertTrue(ds.closed)

  def test_raw_encrypt_decrypt_wrong_aad(self):
    """Decrypting with different associated data raises TinkError."""
    raw_primitive = new_raw_primitive()
    plaintext = b'plaintext' + B_X80
    aad = b'associated_data' + B_X80

    # Encrypt
    ct_destination = bytes_io.BytesIOWithValueAfterClose()
    with raw_primitive.new_raw_encrypting_stream(ct_destination, aad) as es:
      self.assertLen(plaintext, es.write(plaintext))
    # value_after_close() is only meaningful once the stream (and therefore
    # ct_destination) has been closed by the context manager above.
    self.assertNotEqual(ct_destination.value_after_close(), plaintext)

    # Decrypt
    ct_source = io.BytesIO(ct_destination.value_after_close())
    with raw_primitive.new_raw_decrypting_stream(
        ct_source, b'bad' + aad, close_ciphertext_source=True) as ds:
      with self.assertRaises(core.TinkError):
        ds.read()

  @parameterized.parameters([
      streaming_aead.streaming_aead_key_templates.AES128_GCM_HKDF_4KB,
      streaming_aead.streaming_aead_key_templates.AES128_GCM_HKDF_1MB,
      streaming_aead.streaming_aead_key_templates.AES256_GCM_HKDF_4KB,
      streaming_aead.streaming_aead_key_templates.AES256_GCM_HKDF_1MB,
      streaming_aead.streaming_aead_key_templates.AES128_CTR_HMAC_SHA256_4KB,
      streaming_aead.streaming_aead_key_templates.AES128_CTR_HMAC_SHA256_1MB,
      streaming_aead.streaming_aead_key_templates.AES256_CTR_HMAC_SHA256_4KB,
      streaming_aead.streaming_aead_key_templates.AES256_CTR_HMAC_SHA256_1MB
  ])
  def test_encrypt_decrypt_success(self, template):
    """Round-trips plaintexts of several sizes through the wrapped primitive.

    The sizes range from empty to ~20KB, so with the 4KB templates the
    ciphertext spans several segments and multi-segment handling is
    exercised, not just a single short segment.

    Args:
      template: the key template used to generate the keyset.
    """
    keyset_handle = tink.new_keyset_handle(template)
    primitive = keyset_handle.primitive(streaming_aead.StreamingAead)
    associated_data = b'associated_data'

    for repeats in (0, 1, 500, 2000):
      plaintext = (b'plaintext' + B_X80) * repeats
      with self.subTest(plaintext_size=len(plaintext)):
        # Encrypt
        ciphertext_destination = bytes_io.BytesIOWithValueAfterClose()
        with primitive.new_encrypting_stream(
            ciphertext_destination, associated_data) as encryption_stream:
          self.assertLen(plaintext, encryption_stream.write(plaintext))
        ciphertext = ciphertext_destination.value_after_close()

        # Decrypt
        with primitive.new_decrypting_stream(
            io.BytesIO(ciphertext), associated_data) as decryption_stream:
          decrypted = decryption_stream.read()

        self.assertEqual(decrypted, plaintext)


if __name__ == '__main__':
  absltest.main()