xref: /aosp_15_r20/external/tink/python/tink/streaming_aead/_streaming_aead_test.py (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for tink.python.tink.streaming_aead.streaming_aead."""
15
16import io
17import os
18import tempfile
19
20from absl.testing import absltest
21
22import tink
23from tink import streaming_aead
24
25
def setUpModule():
  """Module-level fixture: registers the Streaming AEAD primitives.

  absltest calls this once before any test in the module runs, so every test
  can create Streaming AEAD primitives from keyset handles.
  """
  register_streaming_aead = streaming_aead.register
  register_streaming_aead()
28
29
def get_primitive() -> streaming_aead.StreamingAead:
  """Builds a StreamingAead primitive for use in the tests below.

  Each call creates a new keyset handle from the AES128_GCM_HKDF_4KB key
  template and returns the StreamingAead primitive obtained from it.
  """
  return tink.new_keyset_handle(
      streaming_aead.streaming_aead_key_templates.AES128_GCM_HKDF_4KB
  ).primitive(streaming_aead.StreamingAead)
35
36
class StreamingAeadTest(absltest.TestCase):
  """End-to-end test of Streaming AEAD Encrypting/Decrypting Streams."""

  def _assert_round_trip(self, basename: str, buffering: int = -1) -> None:
    """Encrypts a long plaintext to a file, decrypts it back and compares.

    Also checks that the encrypting and decrypting streams close the file
    objects they wrap when the streams themselves are closed.

    Args:
      basename: Name of the ciphertext file inside a fresh temporary directory.
      buffering: Forwarded to `open()`. -1 (the `open()` default) yields a
        buffered file object; 0 yields a raw, unbuffered file object.
    """
    primitive = get_primitive()
    long_plaintext = b' '.join(b'%d' % i for i in range(100 * 1000))
    aad = b'associated_data'
    with tempfile.TemporaryDirectory() as tmpdirname:
      filename = os.path.join(tmpdirname, basename)

      # The outer `with` only guarantees the file is released if the stream
      # fails; on success the encrypting stream is expected to have closed
      # `dest` already (closing twice is harmless).
      with open(filename, 'wb', buffering=buffering) as dest:
        with primitive.new_encrypting_stream(dest, aad) as es:
          n = es.write(long_plaintext)
        self.assertTrue(dest.closed)
      self.assertLen(long_plaintext, n)

      with open(filename, 'rb', buffering=buffering) as src:
        with primitive.new_decrypting_stream(src, aad) as ds:
          output = ds.read()
        self.assertTrue(src.closed)
      self.assertEqual(output, long_plaintext)

  def test_encrypt_decrypt(self):
    self._assert_round_trip('encrypted_file')

  def test_encrypt_decrypt_raw(self):
    # buffering=0 makes open() return a raw (unbuffered) file.
    self._assert_round_trip('encrypted_file_raw', buffering=0)

  def test_encrypt_decrypt_textiowrapper(self):
    primitive = get_primitive()
    text_lines = [
        'ᚻᛖ ᚳᚹᚫᚦ ᚦᚫᛏ ᚻᛖ ᛒᚢᛞᛖ ᚩᚾ ᚦᚫᛗ ᛚᚪᚾᛞᛖ ᚾᚩᚱᚦᚹᛖᚪᚱᛞᚢᛗ ᚹᛁᚦ ᚦᚪ ᚹᛖᛥᚫ\n',
        '⡌⠁⠧⠑ ⠼⠁⠒  ⡍⠜⠇⠑⠹⠰⠎ ⡣⠕⠌\n',
        '2H₂ + O₂ ⇌ 2H₂O\n',
        # A character outside the Basic Multilingual Plane, i.e. a 4-byte
        # UTF-8 sequence; written as an escape so it survives re-encoding.
        'smile \U0001F600\n',
    ]
    aad = b'associated_data'
    with tempfile.TemporaryDirectory() as tmpdirname:
      filename = os.path.join(tmpdirname, 'encrypted_textfile')

      with open(filename, 'wb') as dest:
        with io.TextIOWrapper(
            primitive.new_encrypting_stream(dest, aad), encoding='utf8') as es:
          es.writelines(text_lines)
        self.assertTrue(dest.closed)

      with open(filename, 'rb') as src:
        with io.TextIOWrapper(
            primitive.new_decrypting_stream(src, aad), encoding='utf8') as ds:
          # Compare whole lists so that missing or extra lines fail the test
          # instead of being silently skipped.
          self.assertEqual(list(ds), text_lines)
        self.assertTrue(src.closed)

  def test_encrypt_fails_on_nonwritable_stream(self):
    primitive = get_primitive()
    with tempfile.TemporaryDirectory() as tmpdirname:
      filename = os.path.join(tmpdirname, 'file')
      with open(filename, 'wb') as f:
        f.write(b'data')
      with open(filename, 'rb') as dest:  # dest is not writable
        with self.assertRaises(ValueError):
          primitive.new_encrypting_stream(dest, b'aad')

  def test_decrypt_fails_on_nonreadable_stream(self):
    primitive = get_primitive()
    with tempfile.TemporaryDirectory() as tmpdirname:
      # src is opened write-only, so it is not readable.
      with open(os.path.join(tmpdirname, 'file2'), 'wb') as src:
        with self.assertRaises(ValueError):
          primitive.new_decrypting_stream(src, b'aad')
116
if __name__ == '__main__':
  # When run as a script, hand control to absl's test runner.
  absltest.main()
119