1"""
2Unit tests for CLI entry points.
3"""
4
5from __future__ import print_function
6
7import functools
8import io
9import os
10import sys
11import typing
12import unittest
13from contextlib import contextmanager, redirect_stdout, redirect_stderr
14
15import rsa
16import rsa.cli
17import rsa.util
18
19
@contextmanager
def captured_output() -> typing.Generator:
    """Redirect stdout and stderr into in-memory buffers for the duration of the block.

    Yields a ``(stdout, stderr)`` pair of ``io.StringIO`` objects. Each one also
    carries a ``.buffer`` attribute holding an ``io.BytesIO``, standing in for the
    binary layer of a real text stream so bytes written to ``sys.stdout.buffer``
    are captured as well.
    """

    def _text_with_binary_layer() -> io.StringIO:
        stream = io.StringIO()
        # mypy objects to assigning .buffer on a StringIO; it works at runtime,
        # and this is test-only code, hence the 'type: ignore'.
        stream.buffer = io.BytesIO()  # type: ignore
        return stream

    fake_stdout = _text_with_binary_layer()
    fake_stderr = _text_with_binary_layer()

    with redirect_stdout(fake_stdout), redirect_stderr(fake_stderr):
        yield fake_stdout, fake_stderr
34
35
def get_bytes_out(buf) -> bytes:
    """Return everything written to the binary ``.buffer`` of a captured stream."""
    binary_layer = buf.buffer
    return binary_layer.getvalue()
38
39
@contextmanager
def cli_args(*new_argv):
    """Updates sys.argv[1:] for a single test.

    Every argument is converted with ``str()``, so non-string values such as
    integer key sizes can be passed directly. ``sys.argv[0]`` is left untouched
    while the block runs. On exit the complete original argument list is
    restored, even if the body raises.
    """

    old_args = sys.argv[:]
    sys.argv[1:] = [str(arg) for arg in new_argv]

    try:
        yield
    finally:
        # Restore the full saved list in place. Assigning it to sys.argv[1:]
        # would re-insert the old argv[0] after the current one and make
        # sys.argv grow on every use.
        sys.argv[:] = old_args
51
52
def remove_if_exists(fname):
    """Removes a file if it exists.

    Tries the unlink directly and ignores only ``FileNotFoundError``, instead of
    pre-checking with ``os.path.exists``. This avoids a check-then-delete window
    and also removes dangling symlinks, for which ``os.path.exists`` returns
    False. Any other error, e.g. ``fname`` naming a directory, still propagates.
    """

    try:
        os.unlink(fname)
    except FileNotFoundError:
        pass
58
59
def cleanup_files(*filenames):
    """Decorator factory: delete ``filenames`` before the wrapped test runs and again afterward.

    The post-run deletion happens in a ``finally`` clause, so the files are removed
    even when the test fails or raises. The wrapped function's return value is
    passed through unchanged.
    """

    def purge():
        for name in filenames:
            remove_if_exists(name)

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            purge()
            try:
                result = func(*args, **kwargs)
            finally:
                purge()
            return result

        return wrapper

    return decorator
79
80
class AbstractCliTest(unittest.TestCase):
    """Shared fixture for the CLI tests.

    Generates one 512-bit key pair per test class. Writes it with ``save_pkcs1()``
    to ``<ClassName>.pub`` / ``<ClassName>.key`` (relative to the current working
    directory), so the CLI commands can be pointed at real key files.
    """

    @classmethod
    def setUpClass(cls):
        # Ensure there is a key to use
        cls.pub_key, cls.priv_key = rsa.newkeys(512)
        cls.pub_fname = '%s.pub' % cls.__name__
        cls.priv_fname = '%s.key' % cls.__name__

        # Public key first, then private key.
        for fname, key in ((cls.pub_fname, cls.pub_key), (cls.priv_fname, cls.priv_key)):
            with open(fname, 'wb') as keyfile:
                keyfile.write(key.save_pkcs1())

    @classmethod
    def tearDownClass(cls):
        # Only remove files whose names were actually recorded on the class.
        for attr_name in ('pub_fname', 'priv_fname'):
            if hasattr(cls, attr_name):
                remove_if_exists(getattr(cls, attr_name))

    def assertExits(self, status_code, func, *args, **kwargs):
        """Fail unless ``func(*args, **kwargs)`` raises SystemExit whose code equals ``status_code``.

        ``status_code`` is compared against ``SystemExit.code``. It may be an int
        or any other value SystemExit was constructed with, such as a message string.
        """
        try:
            func(*args, **kwargs)
        except SystemExit as exc:
            if status_code != exc.code:
                self.fail('SystemExit() raised by %r, but exited with code %r, expected %r' % (
                    func, exc.code, status_code))
        else:
            self.fail('SystemExit() not raised by %r' % func)
112
113
class KeygenTest(AbstractCliTest):
    """Tests for the ``rsa.cli.keygen`` command-line entry point."""

    def test_keygen_no_args(self):
        # With no arguments at all, keygen must exit with status 1.
        with cli_args():
            self.assertExits(1, rsa.cli.keygen)

    def test_keygen_priv_stdout(self):
        with captured_output() as (out, err):
            with cli_args(128):
                rsa.cli.keygen()

        pem = get_bytes_out(out)
        lines = pem.splitlines()
        # Check for output first, so an empty stdout gives a clear failure
        # rather than an IndexError on lines[0].
        self.assertTrue(lines, 'keygen wrote nothing to stdout')
        self.assertEqual(b'-----BEGIN RSA PRIVATE KEY-----', lines[0])
        self.assertEqual(b'-----END RSA PRIVATE KEY-----', lines[-1])

        # The armour lines alone don't prove the body is valid; it must parse.
        rsa.PrivateKey.load_pkcs1(pem)

        # The key size should be shown on stderr
        self.assertIn('128-bit key', err.getvalue())

    @cleanup_files('test_cli_privkey_out.pem')
    def test_keygen_priv_out_pem(self):
        priv_fname = 'test_cli_privkey_out.pem'
        with captured_output() as (out, err):
            with cli_args('--out=%s' % priv_fname, '--form=PEM', 128):
                rsa.cli.keygen()

        # The key size and the output file should be shown on stderr
        self.assertIn('128-bit key', err.getvalue())
        self.assertIn(priv_fname, err.getvalue())

        # If we can load the file as PEM, it's good enough.
        with open(priv_fname, 'rb') as pemfile:
            rsa.PrivateKey.load_pkcs1(pemfile.read())

    @cleanup_files('test_cli_privkey_out.der')
    def test_keygen_priv_out_der(self):
        priv_fname = 'test_cli_privkey_out.der'
        with captured_output() as (out, err):
            with cli_args('--out=%s' % priv_fname, '--form=DER', 128):
                rsa.cli.keygen()

        # The key size and the output file should be shown on stderr
        self.assertIn('128-bit key', err.getvalue())
        self.assertIn(priv_fname, err.getvalue())

        # If we can load the file as DER, it's good enough.
        with open(priv_fname, 'rb') as derfile:
            rsa.PrivateKey.load_pkcs1(derfile.read(), format='DER')

    @cleanup_files('test_cli_privkey_out.pem', 'test_cli_pubkey_out.pem')
    def test_keygen_pub_out_pem(self):
        priv_fname = 'test_cli_privkey_out.pem'
        pub_fname = 'test_cli_pubkey_out.pem'
        with captured_output() as (out, err):
            with cli_args('--out=%s' % priv_fname,
                          '--pubout=%s' % pub_fname,
                          '--form=PEM', 256):
                rsa.cli.keygen()

        # The key size and both output files should be shown on stderr
        self.assertIn('256-bit key', err.getvalue())
        self.assertIn(priv_fname, err.getvalue())
        self.assertIn(pub_fname, err.getvalue())

        # Both files must load as PEM...
        with open(pub_fname, 'rb') as pemfile:
            pub_key = rsa.PublicKey.load_pkcs1(pemfile.read())
        with open(priv_fname, 'rb') as pemfile:
            priv_key = rsa.PrivateKey.load_pkcs1(pemfile.read())

        # ...and the public key must belong to the generated private key.
        self.assertEqual(priv_key.n, pub_key.n)
        self.assertEqual(priv_key.e, pub_key.e)
181
182
class EncryptDecryptTest(AbstractCliTest):
    """Round-trip tests for the ``rsa.cli.encrypt`` / ``rsa.cli.decrypt`` commands."""

    def test_empty_decrypt(self):
        with cli_args():
            self.assertExits(1, rsa.cli.decrypt)

    def test_empty_encrypt(self):
        with cli_args():
            self.assertExits(1, rsa.cli.encrypt)

    def _write_and_encrypt(self):
        """Write the cleartext fixture, then encrypt it into ``encrypted.txt`` with the class key."""
        with open('cleartext.txt', 'wb') as cleartext:
            cleartext.write(b'Hello cleartext RSA users!')

        with cli_args('-i', 'cleartext.txt', '--out=encrypted.txt', self.pub_fname), \
                captured_output():
            rsa.cli.encrypt()

    @cleanup_files('encrypted.txt', 'cleartext.txt')
    def test_encrypt_decrypt(self):
        self._write_and_encrypt()

        with cli_args('-i', 'encrypted.txt', self.priv_fname), \
                captured_output() as (stdout, _stderr):
            rsa.cli.decrypt()

        # The recovered cleartext lands in the binary layer of the captured stdout.
        self.assertEqual(b'Hello cleartext RSA users!', get_bytes_out(stdout))

    @cleanup_files('encrypted.txt', 'cleartext.txt')
    def test_encrypt_decrypt_unhappy(self):
        self._write_and_encrypt()

        # Corrupt a few bytes of the ciphertext so decryption has to fail.
        with open('encrypted.txt', 'r+b') as ciphertext:
            ciphertext.seek(40)
            ciphertext.write(b'hahaha')

        with cli_args('-i', 'encrypted.txt', self.priv_fname), captured_output():
            with self.assertRaises(rsa.DecryptionError):
                rsa.cli.decrypt()
226
227
class SignVerifyTest(AbstractCliTest):
    """Tests for the ``rsa.cli.sign`` / ``rsa.cli.verify`` commands."""

    def test_empty_verify(self):
        with cli_args():
            self.assertExits(1, rsa.cli.verify)

    def test_empty_sign(self):
        with cli_args():
            self.assertExits(1, rsa.cli.sign)

    def _write_and_sign(self):
        """Write the cleartext fixture and sign it (SHA-256) into ``signature.txt``."""
        with open('cleartext.txt', 'wb') as outfile:
            outfile.write(b'Hello RSA users!')

        with cli_args('-i', 'cleartext.txt', '--out=signature.txt', self.priv_fname, 'SHA-256'):
            with captured_output():
                rsa.cli.sign()

    @cleanup_files('signature.txt', 'cleartext.txt')
    def test_sign_verify(self):
        self._write_and_sign()

        # A failed verification exits via SystemExit (see the unhappy test),
        # which would propagate here and fail this test.
        with cli_args('-i', 'cleartext.txt', self.pub_fname, 'signature.txt'):
            with captured_output() as (out, err):
                rsa.cli.verify()

        # Only checking that binary stdout lacks the message is nearly vacuous:
        # a silent no-op would also pass. rsa.cli.verify reports success as
        # text on stderr, so assert it is actually there.
        self.assertNotIn(b'Verification OK', get_bytes_out(out))
        self.assertIn('Verification OK', err.getvalue())

    @cleanup_files('signature.txt', 'cleartext.txt')
    def test_sign_verify_unhappy(self):
        self._write_and_sign()

        # Change a few bytes in the cleartext file so the signature no longer matches.
        with open('cleartext.txt', 'r+b') as cleartext:
            cleartext.seek(6)
            cleartext.write(b'DSA')

        with cli_args('-i', 'cleartext.txt', self.pub_fname, 'signature.txt'):
            with captured_output():
                self.assertExits('Verification failed.', rsa.cli.verify)
269
270
class PrivatePublicTest(AbstractCliTest):
    """Checks the CLI command that derives a public key from a private key file."""

    @cleanup_files('test_private_to_public.pem')
    def test_private_to_public(self):
        out_fname = 'test_private_to_public.pem'

        with cli_args('-i', self.priv_fname, '-o', out_fname), captured_output():
            rsa.util.private_to_public()

        # The written file must parse as a public key...
        with open(out_fname, 'rb') as pemfile:
            derived = rsa.PublicKey.load_pkcs1(pemfile.read())

        # ...whose modulus and exponent match the class's private key.
        for attr in ('n', 'e'):
            self.assertEqual(getattr(self.priv_key, attr), getattr(derived, attr))
287