xref: /aosp_15_r20/external/tink/python/tink/streaming_aead/_encrypting_stream_test.py (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for tink.python.tink.streaming_aead.encrypting_stream."""
15
16import io
17from typing import cast
18
19from absl.testing import absltest
20
21from tink import core
22from tink import streaming_aead
23from tink.streaming_aead import _raw_streaming_aead
24
# Every fixture below ends in a lone 0x80 byte, which is not valid UTF-8 on its
# own. If any code path accidentally tried to decode these values as text, the
# decode would fail instead of passing unnoticed.
B_X80 = bytes((0x80,))
B_AAD_ = b''.join((b'aa', B_X80))
B_ASSOC_ = b''.join((b'asso', B_X80))
29
30
def setUpModule():
  # Module-level fixture hook: unittest-style runners call a function with this
  # name once, before any test in the module is executed.
  # NOTE(review): register() presumably installs the streaming AEAD key
  # managers that get_raw_primitive() relies on -- confirm against
  # tink.streaming_aead.
  streaming_aead.register()
33
34
def get_raw_primitive():
  """Returns a RawStreamingAead primitive obtained from the Tink registry.

  New key data is requested from the registry for the
  AES128_CTR_HMAC_SHA256_4KB template on every call, and the registry is then
  asked for the RawStreamingAead primitive for that key data.
  """
  templates = streaming_aead.streaming_aead_key_templates
  new_key = core.Registry.new_key_data(templates.AES128_CTR_HMAC_SHA256_4KB)
  return core.Registry.primitive(new_key, _raw_streaming_aead.RawStreamingAead)
39
40
class EncryptingStreamTest(absltest.TestCase):
  """Exercises the stream returned by new_raw_encrypting_stream().

  Every test uses an in-memory io.BytesIO as the ciphertext destination.
  """

  def test_write_non_bytes(self):
    # Writing a str (disguised as bytes for the type checker) must raise
    # TypeError.
    sink = io.BytesIO()
    primitive = get_raw_primitive()
    with primitive.new_raw_encrypting_stream(sink, B_AAD_) as stream:
      not_bytes = cast(bytes, 'This is a string, not a bytes object')
      self.assertRaisesRegex(
          TypeError, 'bytes-like object is required', stream.write, not_bytes)

  def test_flush(self):
    # Smoke test: a write followed by flush() must complete without raising.
    sink = io.BytesIO()
    primitive = get_raw_primitive()
    with primitive.new_raw_encrypting_stream(sink, B_ASSOC_) as stream:
      plaintext = b'Hello world!' + B_X80
      stream.write(plaintext)
      stream.flush()

  def test_closed(self):
    # Closing the encrypting stream must also close the wrapped destination.
    sink = io.BytesIO()
    primitive = get_raw_primitive()
    stream = primitive.new_raw_encrypting_stream(sink, B_ASSOC_)
    stream.write(b'Hello world!' + B_X80)
    stream.close()

    self.assertTrue(stream.closed)
    self.assertTrue(sink.closed)

  def test_closed_methods_raise(self):
    # After close(), write(), entering the context manager and flush() must
    # each raise a ValueError whose message mentions 'closed'.
    sink = io.BytesIO()
    primitive = get_raw_primitive()
    stream = primitive.new_raw_encrypting_stream(sink, B_ASSOC_)
    stream.write(b'Hello world!' + B_X80)
    stream.close()

    self.assertRaisesRegex(
        ValueError, 'closed', stream.write, b'Goodbye world.' + B_X80)
    with self.assertRaisesRegex(ValueError, 'closed'), stream:
      pass
    self.assertRaisesRegex(ValueError, 'closed', stream.flush)

  def test_unsupported_operation(self):
    # seek(), truncate() and read() must raise io.UnsupportedOperation.
    sink = io.BytesIO()
    primitive = get_raw_primitive()
    with primitive.new_raw_encrypting_stream(sink, B_ASSOC_) as stream:
      self.assertRaises(io.UnsupportedOperation, stream.seek, 0, 2)
      self.assertRaises(io.UnsupportedOperation, stream.truncate, 0)
      self.assertRaises(io.UnsupportedOperation, stream.read, -1)

  def test_inquiries(self):
    # The stream reports itself as writable, but neither readable nor seekable.
    sink = io.BytesIO()
    primitive = get_raw_primitive()
    with primitive.new_raw_encrypting_stream(sink, B_ASSOC_) as stream:
      self.assertTrue(stream.writable())
      self.assertFalse(stream.readable())
      self.assertFalse(stream.seekable())

  def test_context_manager_exception_closes_dest_file(self):
    """Checks that leaving the context via an exception still closes the file.

    Keeping the destination open on error would be hard to implement
    consistently: the standard wrappers (e.g. io.BufferedWriter or
    io.TextIOWrapper) always close the file they wrap, even when an error was
    raised.
    """
    sink = io.BytesIO()
    with self.assertRaisesRegex(ValueError, 'raised inside'):
      primitive = get_raw_primitive()
      with primitive.new_raw_encrypting_stream(sink, B_ASSOC_) as stream:
        stream.write(b'some message' + B_X80)
        raise ValueError('Error raised inside context manager')
    self.assertTrue(sink.closed)
109
110
# Run the tests through absltest's entry point when executed as a script.
if __name__ == '__main__':
  absltest.main()
113