# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tink.python.tink.streaming_aead.streaming_aead."""

import io
import os
import tempfile

from absl.testing import absltest

import tink
from tink import streaming_aead


def setUpModule():
  streaming_aead.register()


def get_primitive() -> streaming_aead.StreamingAead:
  """Returns a StreamingAead primitive for a fresh AES128_GCM_HKDF_4KB keyset."""
  key_template = streaming_aead.streaming_aead_key_templates.AES128_GCM_HKDF_4KB
  keyset_handle = tink.new_keyset_handle(key_template)
  primitive = keyset_handle.primitive(streaming_aead.StreamingAead)
  return primitive


class StreamingAeadTest(absltest.TestCase):
  """End-to-end test of Streaming AEAD Encrypting/Decrypting Streams."""

  def _assert_file_roundtrip(self, filename, buffering=-1):
    """Encrypts a long plaintext into `filename`, decrypts it, and compares.

    Shared body of the buffered and raw round-trip tests.

    Args:
      filename: Path of the ciphertext file to create.
      buffering: Passed through to `open()`. -1 (the `open()` default) yields
        buffered file objects; 0 yields raw, unbuffered file objects.
    """
    primitive = get_primitive()
    long_plaintext = b' '.join(b'%d' % i for i in range(100 * 1000))
    aad = b'associated_data'

    dest = open(filename, 'wb', buffering=buffering)
    with primitive.new_encrypting_stream(dest, aad) as es:
      n = es.write(long_plaintext)
    # Leaving the encrypting stream's context must also close the wrapped file.
    self.assertTrue(dest.closed)
    self.assertLen(long_plaintext, n)

    src = open(filename, 'rb', buffering=buffering)
    with primitive.new_decrypting_stream(src, aad) as ds:
      output = ds.read()
    # Likewise, the decrypting stream closes the file it was given.
    self.assertTrue(src.closed)
    self.assertEqual(output, long_plaintext)

  def test_encrypt_decrypt(self):
    with tempfile.TemporaryDirectory() as tmpdirname:
      self._assert_file_roundtrip(os.path.join(tmpdirname, 'encrypted_file'))

  def test_encrypt_decrypt_raw(self):
    with tempfile.TemporaryDirectory() as tmpdirname:
      # buffering=0 makes open() return a raw (unbuffered) file object.
      self._assert_file_roundtrip(
          os.path.join(tmpdirname, 'encrypted_file_raw'), buffering=0)

  def test_encrypt_decrypt_textiowrapper(self):
    primitive = get_primitive()
    text_lines = [
        'ᚻᛖ ᚳᚹᚫᚦ ᚦᚫᛏ ᚻᛖ ᛒᚢᛞᛖ ᚩᚾ ᚦᚫᛗ ᛚᚪᚾᛞᛖ ᚾᚩᚱᚦᚹᛖᚪᚱᛞᚢᛗ ᚹᛁᚦ ᚦᚪ ᚹᛖᛥᚫ\n',
        '⡌⠁⠧⠑ ⠼⠁⠒ ⡍⠜⠇⠑⠹⠰⠎ ⡣⠕⠌\n',
        '2H₂ + O₂ ⇌ 2H₂O\n',
        'smile \n']
    aad = b'associated_data'
    with tempfile.TemporaryDirectory() as tmpdirname:
      filename = os.path.join(tmpdirname, 'encrypted_textfile')
      dest = open(filename, 'wb')
      with io.TextIOWrapper(
          primitive.new_encrypting_stream(dest, aad), encoding='utf8') as es:
        es.writelines(text_lines)
      self.assertTrue(dest.closed)

      src = open(filename, 'rb')
      with io.TextIOWrapper(
          primitive.new_decrypting_stream(src, aad), encoding='utf8') as ds:
        # Materialize every line so that a truncated (or over-long) decryption
        # fails the comparison; a per-index loop would miss missing lines.
        decrypted_lines = list(ds)
      self.assertTrue(src.closed)
      self.assertEqual(decrypted_lines, text_lines)

  def test_encrypt_fails_on_nonwritable_stream(self):
    primitive = get_primitive()
    with tempfile.TemporaryDirectory() as tmpdirname:
      filename = os.path.join(tmpdirname, 'file')
      with open(filename, 'wb') as f:
        f.write(b'data')
      with open(filename, 'rb') as dest:  # dest is not writable
        with self.assertRaises(ValueError):
          primitive.new_encrypting_stream(dest, b'aad')

  def test_decrypt_fails_on_nonreadable_stream(self):
    primitive = get_primitive()
    with tempfile.TemporaryDirectory() as tmpdirname:
      # src not readable
      with open(os.path.join(tmpdirname, 'file2'), 'wb') as src:
        with self.assertRaises(ValueError):
          primitive.new_decrypting_stream(src, b'aad')


if __name__ == '__main__':
  absltest.main()