1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for exception (non-)propagation in Pybind11PythonFileObjectAdapter."""
15
16import io
17
18from absl.testing import absltest
19
20import tink
21from tink import streaming_aead
22
23
def setUpModule():
  """Module-level fixture run by the test framework before any test here.

  Calls `streaming_aead.register()` so that keysets built from streaming AEAD
  key templates can be turned into primitives in the tests below.
  """
  streaming_aead.register()
26
27
def get_primitive() -> streaming_aead.StreamingAead:
  """Builds a StreamingAead primitive from a freshly generated keyset.

  The keyset is created from the AES128_GCM_HKDF_4KB key template, so every
  call yields a primitive bound to a new, independent key.

  Returns:
    The `streaming_aead.StreamingAead` primitive of the new keyset handle.
  """
  templates = streaming_aead.streaming_aead_key_templates
  handle = tink.new_keyset_handle(templates.AES128_GCM_HKDF_4KB)
  return handle.primitive(streaming_aead.StreamingAead)
33
34
class BytesIOThatThrowsExceptionsOnReadWrite(io.BytesIO):
  """An io.BytesIO whose read() and write() always raise tink.TinkError.

  close() is overridden as a no-op, so closing this object never raises and
  never reaches io.BytesIO.close().
  """

  def write(self, data):
    raise tink.TinkError('Called write!')

  def read(self, num=-1):
    # `num` defaults to -1 to match the signature of io.BytesIO.read, so a
    # caller invoking read() with no argument still gets the TinkError this
    # fake is meant to produce, rather than a TypeError about arity.
    raise tink.TinkError('Called read!')

  def close(self):
    pass
45
46
class BytesIOThatThrowsExceptionsOnClose(io.BytesIO):
  """An io.BytesIO that accepts writes and reads as empty, but fails on close.

  write() reports every byte as consumed without storing anything, read()
  always returns b'' and close() always raises tink.TinkError.
  """

  def write(self, data):
    # Claim the whole buffer was written; the bytes themselves are discarded.
    return len(data)

  def read(self, _=-1):
    # The -1 default matches io.BytesIO.read, so read() with no argument
    # behaves like every other call and returns b'' instead of a TypeError.
    return b''

  def close(self):
    # NOTE(review): since close() never succeeds, the io.IOBase finalizer's
    # own close() call on this object will raise too (logged in Python dev
    # mode); harmless for these tests.
    raise tink.TinkError('Called close!')
57
58
class Pybind11PythonFileObjectAdaterTest(absltest.TestCase):
  """Checks how errors raised by Python file objects surface through streams.

  None of these tests enter the encrypting/decrypting stream with a `with`
  statement, for two reasons:
  1. Consistency with `test_close_throws()`: there, leaving a `with` block
     would call the stream's `close()` again after the assertion had already
     seen it raise. That extra exception would fail the test.
  2. Avoiding similar unexpected side effects in the other tests.
  """

  def test_write_throws(self):
    primitive = get_primitive()
    destination = BytesIOThatThrowsExceptionsOnReadWrite()
    encrypting_stream = primitive.new_encrypting_stream(
        destination, b'associated_data')

    # The destination's write() raises, but that error is swallowed on the
    # way and does not reach this call.
    _ = encrypting_stream.write(b'plaintext')

    # Closing the stream, however, does surface a TinkError.
    with self.assertRaises(tink.TinkError):
      encrypting_stream.close()

  def test_read_throws(self):
    primitive = get_primitive()
    source = BytesIOThatThrowsExceptionsOnReadWrite()
    decrypting_stream = primitive.new_decrypting_stream(
        source, b'associated_data')

    # A failing read() on the source is reported to the caller as TinkError.
    with self.assertRaises(tink.TinkError):
      decrypting_stream.read()
    decrypting_stream.close()

  def test_close_throws(self):
    primitive = get_primitive()
    destination = BytesIOThatThrowsExceptionsOnClose()
    encrypting_stream = primitive.new_encrypting_stream(
        destination, b'associated_data')

    # The destination only raises from close(); closing the stream must
    # report that failure as TinkError.
    with self.assertRaises(tink.TinkError):
      encrypting_stream.close()
95
96
# Run the tests in this module when it is executed directly.
if __name__ == '__main__':
  absltest.main()
99