1# Test the support for SSL and sockets
2
3import sys
4import unittest
5import unittest.mock
6from test import support
7from test.support import import_helper
8from test.support import os_helper
9from test.support import socket_helper
10from test.support import threading_helper
11from test.support import warnings_helper
12import socket
13import select
14import time
15import enum
16import gc
17import os
18import errno
19import pprint
20import urllib.request
21import threading
22import traceback
23import weakref
24import platform
25import sysconfig
26import functools
27try:
28    import ctypes
29except ImportError:
30    ctypes = None
31
32
33asyncore = warnings_helper.import_deprecated('asyncore')
34
35
36ssl = import_helper.import_module("ssl")
37import _ssl
38
39from ssl import TLSVersion, _TLSContentType, _TLSMessageType, _TLSAlertType
40
# True on debug builds of CPython (sys.gettotalrefcount is only exposed there).
Py_DEBUG = hasattr(sys, 'gettotalrefcount')
Py_DEBUG_WIN32 = Py_DEBUG and sys.platform == 'win32'

# Sorted list of the ssl.PROTOCOL_* methods known to this ssl build.
PROTOCOLS = sorted(ssl._PROTOCOL_NAMES)
HOST = socket_helper.HOST
IS_OPENSSL_3_0_0 = ssl.OPENSSL_VERSION_INFO >= (3, 0, 0)
# Build-time cipher default; sysconfig returns None when it was not set.
PY_SSL_DEFAULT_CIPHERS = sysconfig.get_config_var('PY_SSL_DEFAULT_CIPHERS')

# Map selected ssl.PROTOCOL_* methods to an ssl.TLSVersion member. Pairs whose
# protocol constant or TLSVersion member is missing from this build are
# skipped. PROTOCOL_SSLv23 is an alias of PROTOCOL_TLS (see test_constants)
# and is mapped to SSLv3 -- NOTE(review): presumably the lowest version to
# permit for the auto-negotiating method; confirm against users of this dict.
# NOTE: the loop leaves `proto` and `ver` bound as module-level names.
PROTOCOL_TO_TLS_VERSION = {}
for proto, ver in (
    ("PROTOCOL_SSLv23", "SSLv3"),
    ("PROTOCOL_TLSv1", "TLSv1"),
    ("PROTOCOL_TLSv1_1", "TLSv1_1"),
):
    try:
        proto = getattr(ssl, proto)
        ver = getattr(ssl.TLSVersion, ver)
    except AttributeError:
        continue
    PROTOCOL_TO_TLS_VERSION[proto] = ver
61
def data_file(*name):
    """Return the path of a test fixture stored next to this module.

    The positional parts are joined onto this file's directory with
    os.path.join, so ``data_file("capath", "x.0")`` names ``x.0`` inside
    the ``capath`` subdirectory.
    """
    base_dir = os.path.dirname(__file__)
    return os.path.join(base_dir, *name)
64
# The custom key and certificate files used in test_ssl are generated
# using Lib/test/make_ssl_certs.py.
# Other certificates are simply fetched from the internet servers they
# are meant to authenticate.

# Combined key + certificate, and the same material split into separate
# cert-only / key-only files (plus bytes-path variants).
CERTFILE = data_file("keycert.pem")
BYTES_CERTFILE = os.fsencode(CERTFILE)
ONLYCERT = data_file("ssl_cert.pem")
ONLYKEY = data_file("ssl_key.pem")
BYTES_ONLYCERT = os.fsencode(ONLYCERT)
BYTES_ONLYKEY = os.fsencode(ONLYKEY)
# Password-protected variants; KEY_PASSWORD is presumably their passphrase
# (TODO confirm against the tests that load them).
CERTFILE_PROTECTED = data_file("keycert.passwd.pem")
ONLYKEY_PROTECTED = data_file("ssl_key.passwd.pem")
KEY_PASSWORD = "somepass"
# Directory of CA certificates; the file names look like OpenSSL
# subject-hash names (assumption, not checked here).
CAPATH = data_file("capath")
BYTES_CAPATH = os.fsencode(CAPATH)
CAFILE_NEURONIO = data_file("capath", "4e1295a3.0")
CAFILE_CACERT = data_file("capath", "5ed36f99.0")

# Expected ssl._ssl._test_decode_cert(CERTFILE) result (see test_parse_cert).
CERTFILE_INFO = {
    'issuer': ((('countryName', 'XY'),),
               (('localityName', 'Castle Anthrax'),),
               (('organizationName', 'Python Software Foundation'),),
               (('commonName', 'localhost'),)),
    'notAfter': 'Aug 26 14:23:15 2028 GMT',
    'notBefore': 'Aug 29 14:23:15 2018 GMT',
    'serialNumber': '98A7CF88C74A32ED',
    'subject': ((('countryName', 'XY'),),
             (('localityName', 'Castle Anthrax'),),
             (('organizationName', 'Python Software Foundation'),),
             (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}

# empty CRL
CRLFILE = data_file("revocation.crl")

# Two keys and certs signed by the same CA (for SNI tests): SIGNED_CERTFILE
# here and SIGNED_CERTFILE2 further below.
SIGNED_CERTFILE = data_file("keycert3.pem")
SIGNED_CERTFILE_HOSTNAME = 'localhost'

# Expected ssl._ssl._test_decode_cert(SIGNED_CERTFILE) result.
SIGNED_CERTFILE_INFO = {
    'OCSP': ('http://testca.pythontest.net/testca/ocsp/',),
    'caIssuers': ('http://testca.pythontest.net/testca/pycacert.cer',),
    'crlDistributionPoints': ('http://testca.pythontest.net/testca/revocation.crl',),
    'issuer': ((('countryName', 'XY'),),
            (('organizationName', 'Python Software Foundation CA'),),
            (('commonName', 'our-ca-server'),)),
    'notAfter': 'Oct 28 14:23:16 2037 GMT',
    'notBefore': 'Aug 29 14:23:16 2018 GMT',
    'serialNumber': 'CB2D80995A69525C',
    'subject': ((('countryName', 'XY'),),
             (('localityName', 'Castle Anthrax'),),
             (('organizationName', 'Python Software Foundation'),),
             (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}

SIGNED_CERTFILE2 = data_file("keycert4.pem")
SIGNED_CERTFILE2_HOSTNAME = 'fakehostname'
SIGNED_CERTFILE_ECC = data_file("keycertecc.pem")
SIGNED_CERTFILE_ECC_HOSTNAME = 'localhost-ecc'

# Same certificate as pycacert.pem, but without extra text in file
SIGNING_CA = data_file("capath", "ceff1710.0")
# cert with all kinds of subject alt names
ALLSANFILE = data_file("allsans.pem")
# presumably a cert with internationalized (IDN) SANs -- TODO confirm
IDNSANSFILE = data_file("idnsans.pem")
# cert without any subjectAltName; testing_context() pairs it with
# NOSAN_HOSTNAME
NOSANFILE = data_file("nosan.pem")
NOSAN_HOSTNAME = 'localhost'

REMOTE_HOST = "self-signed.pythontest.net"

# Deliberately broken inputs: an empty cert, a malformed cert/key, and a
# path that must not exist (wrapping with it is expected to fail with ENOENT).
EMPTYCERT = data_file("nullcert.pem")
BADCERT = data_file("badcert.pem")
NONEXISTINGCERT = data_file("XXXnonexisting.pem")
BADKEY = data_file("badkey.pem")
# Regression fixtures: issue #13034 (NOKIACERT), CVE-2013-4238 (NULLBYTECERT)
# and CVE-2019-5010 / TALOS-2019-0758 (TALOS_INVALID_CRLDP).
NOKIACERT = data_file("nokia.pem")
NULLBYTECERT = data_file("nullbytecert.pem")
TALOS_INVALID_CRLDP = data_file("talos-2019-0758.pem")

# Diffie-Hellman parameters file (3072-bit FFDH judging by the file name).
DHFILE = data_file("ffdh3072.pem")
BYTES_DHFILE = os.fsencode(DHFILE)

# Not defined in all versions of OpenSSL; fall back to 0 (no option bits).
OP_NO_COMPRESSION = getattr(ssl, "OP_NO_COMPRESSION", 0)
OP_SINGLE_DH_USE = getattr(ssl, "OP_SINGLE_DH_USE", 0)
OP_SINGLE_ECDH_USE = getattr(ssl, "OP_SINGLE_ECDH_USE", 0)
OP_CIPHER_SERVER_PREFERENCE = getattr(ssl, "OP_CIPHER_SERVER_PREFERENCE", 0)
OP_ENABLE_MIDDLEBOX_COMPAT = getattr(ssl, "OP_ENABLE_MIDDLEBOX_COMPAT", 0)
157
# Ubuntu has patched OpenSSL and changed behavior of security level 2
# see https://bugs.python.org/issue41561#msg389003
def is_ubuntu(os_release="/etc/os-release"):
    """Return True if the os-release file mentions "ubuntu".

    :param os_release: path of the os-release file to inspect; defaults to
        the system file, the argument exists mainly for testing.
    :return: bool -- False when the file is missing, unreadable (e.g. a
        permission error or a directory) or not valid UTF-8, so importing
        this module never fails because of the probe.
    """
    try:
        # Assume that any reference to "ubuntu" implies an Ubuntu-like distro.
        # The workaround is not required for 18.04, but doesn't hurt either.
        with open(os_release, encoding="utf-8") as f:
            return "ubuntu" in f.read()
    except (OSError, UnicodeDecodeError):
        # FileNotFoundError alone let other I/O and decoding failures crash
        # the module import; treat any of them as "not Ubuntu".
        return False
168
if is_ubuntu():
    def seclevel_workaround(*ctxs):
        """Lower security level to '1' and allow all ciphers for TLS 1.0/1.1.

        Only contexts whose ``minimum_version`` is TLS 1.1 or lower are
        touched; their cipher string is replaced with "@SECLEVEL=1:ALL".
        Contexts without a ``minimum_version`` attribute are skipped.
        """
        for ctx in ctxs:
            if (
                hasattr(ctx, "minimum_version") and
                ctx.minimum_version <= ssl.TLSVersion.TLSv1_1
            ):
                ctx.set_ciphers("@SECLEVEL=1:ALL")
else:
    def seclevel_workaround(*ctxs):
        """Do nothing: the security-level workaround is only applied on Ubuntu."""
181
182
def has_tls_protocol(protocol):
    """Tell whether an ssl.PROTOCOL_* method can be used in this build.

    :param protocol: enum ssl._SSLMethod member or its name (which must
        start with "PROTOCOL_"); unknown names yield False
    :return: bool -- always True for the auto-negotiating methods,
        otherwise the verdict of has_tls_version() for the matching version
    """
    if isinstance(protocol, str):
        assert protocol.startswith('PROTOCOL_')
        resolved = getattr(ssl, protocol, None)
        if resolved is None:
            return False
        protocol = resolved

    # auto-negotiate protocols are always available
    negotiating = {
        ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS_SERVER, ssl.PROTOCOL_TLS_CLIENT,
    }
    if protocol in negotiating:
        return True

    version_name = protocol.name[len('PROTOCOL_'):]
    return has_tls_version(version_name)
202
203
@functools.lru_cache
def has_tls_version(version):
    """Report whether a TLS/SSL protocol version is usable here.

    A version counts as usable when it is compiled in (ssl.HAS_<name>), is
    not one of the pre-1.2 versions on OpenSSL 3.x, and lies within the
    default client context's minimum/maximum bounds (which reflect runtime
    crypto policy). Results are memoized.

    :param version: TLS version name or ssl.TLSVersion member
    :return: bool
    """
    # SSLv2 was never supported and is not even in the TLSVersion enum.
    if version == "SSLv2":
        return False

    tls_version = (
        ssl.TLSVersion.__members__[version]
        if isinstance(version, str) else version
    )

    # compile time flags like ssl.HAS_TLSv1_2
    if not getattr(ssl, 'HAS_' + tls_version.name):
        return False

    # bpo43791: 3.0.0-alpha14 fails with TLSV1_ALERT_INTERNAL_ERROR
    if IS_OPENSSL_3_0_0 and tls_version < ssl.TLSVersion.TLSv1_2:
        return False

    # A version may be compiled in yet disabled by a policy or config
    # option; the default client context exposes the effective bounds.
    probe = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    lowest = getattr(probe, 'minimum_version', None)
    if (lowest is not None
            and lowest != ssl.TLSVersion.MINIMUM_SUPPORTED
            and tls_version < lowest):
        return False
    highest = getattr(probe, 'maximum_version', None)
    if (highest is not None
            and highest != ssl.TLSVersion.MAXIMUM_SUPPORTED
            and tls_version > highest):
        return False

    return True
243
244
def requires_tls_version(version):
    """Build a decorator that skips a test when a TLS version is unusable.

    :param version: TLS version name or ssl.TLSVersion member
    :return: decorator; the wrapped callable raises unittest.SkipTest when
        has_tls_version(version) is false, otherwise it calls through and
        returns the original result. Availability is checked at call time.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kw):
            if has_tls_version(version):
                return func(*args, **kw)
            raise unittest.SkipTest(f"{version} is not available.")
        return wrapper
    return decorator
260
261
def handle_error(prefix):
    """Write the exception currently being handled to stdout in verbose mode.

    The traceback is rendered from sys.exc_info(); its lines are joined with
    single spaces and written after *prefix*. Nothing is written unless
    support.verbose is set.
    """
    rendered = ' '.join(traceback.format_exception(*sys.exc_info()))
    if not support.verbose:
        return
    sys.stdout.write(prefix + rendered)
266
267
def utc_offset():  # NOTE: ignore issues like #1647654
    """Return the local UTC offset in seconds (local time = UTC + offset).

    Uses the DST offset (time.altzone) when the zone has DST and it is in
    effect right now, otherwise the standard offset (time.timezone).
    """
    in_dst = time.daylight and time.localtime().tm_isdst > 0
    return -(time.altzone if in_dst else time.timezone)
273
274
# Test-method decorator (see its uses below) that silences DeprecationWarning,
# built from test.support.warnings_helper.ignore_warnings.
ignore_deprecation = warnings_helper.ignore_warnings(category=DeprecationWarning)
278
279
def test_wrap_socket(sock, *,
                     cert_reqs=ssl.CERT_NONE, ca_certs=None,
                     ciphers=None, certfile=None, keyfile=None,
                     **kwargs):
    """Wrap *sock* with TLS using a freshly built client or server context.

    A server context is used when ``server_side`` is truthy in *kwargs*;
    otherwise a client context is used and ``server_hostname`` is always
    set to SIGNED_CERTFILE_HOSTNAME. The remaining keyword arguments tune
    the context (``None`` leaves a setting at its default) and *kwargs* is
    forwarded to SSLContext.wrap_socket().
    """
    if kwargs.get("server_side"):
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    else:
        kwargs["server_hostname"] = SIGNED_CERTFILE_HOSTNAME
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

    if cert_reqs is not None:
        # Hostname checking is switched off before verify_mode is relaxed;
        # the ssl docs say CERT_NONE is refused while check_hostname is on.
        if cert_reqs == ssl.CERT_NONE:
            context.check_hostname = False
        context.verify_mode = cert_reqs

    if ca_certs is not None:
        context.load_verify_locations(ca_certs)
    if not (certfile is None and keyfile is None):
        context.load_cert_chain(certfile, keyfile)
    if ciphers is not None:
        context.set_ciphers(ciphers)

    return context.wrap_socket(sock, **kwargs)
300
301
def testing_context(server_cert=SIGNED_CERTFILE, *, server_chain=True):
    """Create a matching client/server context pair for *server_cert*.

    Usage::

        client_context, server_context, hostname = testing_context()

    The client trusts SIGNING_CA; the server presents *server_cert* and, if
    *server_chain* is true, also loads SIGNING_CA as a verify location.
    *hostname* is the name the chosen certificate is issued for.

    :raises ValueError: if *server_cert* is not one of the known test certs
    """
    known_hostnames = (
        (SIGNED_CERTFILE, SIGNED_CERTFILE_HOSTNAME),
        (SIGNED_CERTFILE2, SIGNED_CERTFILE2_HOSTNAME),
        (NOSANFILE, NOSAN_HOSTNAME),
    )
    for cert_path, cert_hostname in known_hostnames:
        if server_cert == cert_path:
            hostname = cert_hostname
            break
    else:
        raise ValueError(server_cert)

    client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    client_context.load_verify_locations(SIGNING_CA)

    server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    server_context.load_cert_chain(server_cert)
    if server_chain:
        server_context.load_verify_locations(SIGNING_CA)

    return client_context, server_context, hostname
325
326
327class BasicSocketTests(unittest.TestCase):
328
329    def test_constants(self):
330        ssl.CERT_NONE
331        ssl.CERT_OPTIONAL
332        ssl.CERT_REQUIRED
333        ssl.OP_CIPHER_SERVER_PREFERENCE
334        ssl.OP_SINGLE_DH_USE
335        ssl.OP_SINGLE_ECDH_USE
336        ssl.OP_NO_COMPRESSION
337        self.assertEqual(ssl.HAS_SNI, True)
338        self.assertEqual(ssl.HAS_ECDH, True)
339        self.assertEqual(ssl.HAS_TLSv1_2, True)
340        self.assertEqual(ssl.HAS_TLSv1_3, True)
341        ssl.OP_NO_SSLv2
342        ssl.OP_NO_SSLv3
343        ssl.OP_NO_TLSv1
344        ssl.OP_NO_TLSv1_3
345        ssl.OP_NO_TLSv1_1
346        ssl.OP_NO_TLSv1_2
347        self.assertEqual(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv23)
348
349    def test_ssl_types(self):
350        ssl_types = [
351            _ssl._SSLContext,
352            _ssl._SSLSocket,
353            _ssl.MemoryBIO,
354            _ssl.Certificate,
355            _ssl.SSLSession,
356            _ssl.SSLError,
357        ]
358        for ssl_type in ssl_types:
359            with self.subTest(ssl_type=ssl_type):
360                with self.assertRaisesRegex(TypeError, "immutable type"):
361                    ssl_type.value = None
362        support.check_disallow_instantiation(self, _ssl.Certificate)
363
364    def test_private_init(self):
365        with self.assertRaisesRegex(TypeError, "public constructor"):
366            with socket.socket() as s:
367                ssl.SSLSocket(s)
368
369    def test_str_for_enums(self):
370        # Make sure that the PROTOCOL_* constants have enum-like string
371        # reprs.
372        proto = ssl.PROTOCOL_TLS_CLIENT
373        self.assertEqual(repr(proto), '<_SSLMethod.PROTOCOL_TLS_CLIENT: %r>' % proto.value)
374        self.assertEqual(str(proto), str(proto.value))
375        ctx = ssl.SSLContext(proto)
376        self.assertIs(ctx.protocol, proto)
377
378    def test_random(self):
379        v = ssl.RAND_status()
380        if support.verbose:
381            sys.stdout.write("\n RAND_status is %d (%s)\n"
382                             % (v, (v and "sufficient randomness") or
383                                "insufficient randomness"))
384
385        with warnings_helper.check_warnings():
386            data, is_cryptographic = ssl.RAND_pseudo_bytes(16)
387        self.assertEqual(len(data), 16)
388        self.assertEqual(is_cryptographic, v == 1)
389        if v:
390            data = ssl.RAND_bytes(16)
391            self.assertEqual(len(data), 16)
392        else:
393            self.assertRaises(ssl.SSLError, ssl.RAND_bytes, 16)
394
395        # negative num is invalid
396        self.assertRaises(ValueError, ssl.RAND_bytes, -5)
397        with warnings_helper.check_warnings():
398            self.assertRaises(ValueError, ssl.RAND_pseudo_bytes, -5)
399
400        ssl.RAND_add("this is a random string", 75.0)
401        ssl.RAND_add(b"this is a random bytes object", 75.0)
402        ssl.RAND_add(bytearray(b"this is a random bytearray object"), 75.0)
403
404    def test_parse_cert(self):
405        # note that this uses an 'unofficial' function in _ssl.c,
406        # provided solely for this test, to exercise the certificate
407        # parsing code
408        self.assertEqual(
409            ssl._ssl._test_decode_cert(CERTFILE),
410            CERTFILE_INFO
411        )
412        self.assertEqual(
413            ssl._ssl._test_decode_cert(SIGNED_CERTFILE),
414            SIGNED_CERTFILE_INFO
415        )
416
417        # Issue #13034: the subjectAltName in some certificates
418        # (notably projects.developer.nokia.com:443) wasn't parsed
419        p = ssl._ssl._test_decode_cert(NOKIACERT)
420        if support.verbose:
421            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
422        self.assertEqual(p['subjectAltName'],
423                         (('DNS', 'projects.developer.nokia.com'),
424                          ('DNS', 'projects.forum.nokia.com'))
425                        )
426        # extra OCSP and AIA fields
427        self.assertEqual(p['OCSP'], ('http://ocsp.verisign.com',))
428        self.assertEqual(p['caIssuers'],
429                         ('http://SVRIntl-G3-aia.verisign.com/SVRIntlG3.cer',))
430        self.assertEqual(p['crlDistributionPoints'],
431                         ('http://SVRIntl-G3-crl.verisign.com/SVRIntlG3.crl',))
432
433    def test_parse_cert_CVE_2019_5010(self):
434        p = ssl._ssl._test_decode_cert(TALOS_INVALID_CRLDP)
435        if support.verbose:
436            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
437        self.assertEqual(
438            p,
439            {
440                'issuer': (
441                    (('countryName', 'UK'),), (('commonName', 'cody-ca'),)),
442                'notAfter': 'Jun 14 18:00:58 2028 GMT',
443                'notBefore': 'Jun 18 18:00:58 2018 GMT',
444                'serialNumber': '02',
445                'subject': ((('countryName', 'UK'),),
446                            (('commonName',
447                              'codenomicon-vm-2.test.lal.cisco.com'),)),
448                'subjectAltName': (
449                    ('DNS', 'codenomicon-vm-2.test.lal.cisco.com'),),
450                'version': 3
451            }
452        )
453
454    def test_parse_cert_CVE_2013_4238(self):
455        p = ssl._ssl._test_decode_cert(NULLBYTECERT)
456        if support.verbose:
457            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
458        subject = ((('countryName', 'US'),),
459                   (('stateOrProvinceName', 'Oregon'),),
460                   (('localityName', 'Beaverton'),),
461                   (('organizationName', 'Python Software Foundation'),),
462                   (('organizationalUnitName', 'Python Core Development'),),
463                   (('commonName', 'null.python.org\x00example.org'),),
464                   (('emailAddress', '[email protected]'),))
465        self.assertEqual(p['subject'], subject)
466        self.assertEqual(p['issuer'], subject)
467        if ssl._OPENSSL_API_VERSION >= (0, 9, 8):
468            san = (('DNS', 'altnull.python.org\x00example.com'),
469                   ('email', '[email protected]\[email protected]'),
470                   ('URI', 'http://null.python.org\x00http://example.org'),
471                   ('IP Address', '192.0.2.1'),
472                   ('IP Address', '2001:DB8:0:0:0:0:0:1'))
473        else:
474            # OpenSSL 0.9.7 doesn't support IPv6 addresses in subjectAltName
475            san = (('DNS', 'altnull.python.org\x00example.com'),
476                   ('email', '[email protected]\[email protected]'),
477                   ('URI', 'http://null.python.org\x00http://example.org'),
478                   ('IP Address', '192.0.2.1'),
479                   ('IP Address', '<invalid>'))
480
481        self.assertEqual(p['subjectAltName'], san)
482
483    def test_parse_all_sans(self):
484        p = ssl._ssl._test_decode_cert(ALLSANFILE)
485        self.assertEqual(p['subjectAltName'],
486            (
487                ('DNS', 'allsans'),
488                ('othername', '<unsupported>'),
489                ('othername', '<unsupported>'),
490                ('email', '[email protected]'),
491                ('DNS', 'www.example.org'),
492                ('DirName',
493                    ((('countryName', 'XY'),),
494                    (('localityName', 'Castle Anthrax'),),
495                    (('organizationName', 'Python Software Foundation'),),
496                    (('commonName', 'dirname example'),))),
497                ('URI', 'https://www.python.org/'),
498                ('IP Address', '127.0.0.1'),
499                ('IP Address', '0:0:0:0:0:0:0:1'),
500                ('Registered ID', '1.2.3.4.5')
501            )
502        )
503
504    def test_DER_to_PEM(self):
505        with open(CAFILE_CACERT, 'r') as f:
506            pem = f.read()
507        d1 = ssl.PEM_cert_to_DER_cert(pem)
508        p2 = ssl.DER_cert_to_PEM_cert(d1)
509        d2 = ssl.PEM_cert_to_DER_cert(p2)
510        self.assertEqual(d1, d2)
511        if not p2.startswith(ssl.PEM_HEADER + '\n'):
512            self.fail("DER-to-PEM didn't include correct header:\n%r\n" % p2)
513        if not p2.endswith('\n' + ssl.PEM_FOOTER + '\n'):
514            self.fail("DER-to-PEM didn't include correct footer:\n%r\n" % p2)
515
516    def test_openssl_version(self):
517        n = ssl.OPENSSL_VERSION_NUMBER
518        t = ssl.OPENSSL_VERSION_INFO
519        s = ssl.OPENSSL_VERSION
520        self.assertIsInstance(n, int)
521        self.assertIsInstance(t, tuple)
522        self.assertIsInstance(s, str)
523        # Some sanity checks follow
524        # >= 1.1.1
525        self.assertGreaterEqual(n, 0x10101000)
526        # < 4.0
527        self.assertLess(n, 0x40000000)
528        major, minor, fix, patch, status = t
529        self.assertGreaterEqual(major, 1)
530        self.assertLess(major, 4)
531        self.assertGreaterEqual(minor, 0)
532        self.assertLess(minor, 256)
533        self.assertGreaterEqual(fix, 0)
534        self.assertLess(fix, 256)
535        self.assertGreaterEqual(patch, 0)
536        self.assertLessEqual(patch, 63)
537        self.assertGreaterEqual(status, 0)
538        self.assertLessEqual(status, 15)
539
540        libressl_ver = f"LibreSSL {major:d}"
541        if major >= 3:
542            # 3.x uses 0xMNN00PP0L
543            openssl_ver = f"OpenSSL {major:d}.{minor:d}.{patch:d}"
544        else:
545            openssl_ver = f"OpenSSL {major:d}.{minor:d}.{fix:d}"
546        self.assertTrue(
547            s.startswith((openssl_ver, libressl_ver)),
548            (s, t, hex(n))
549        )
550
551    @support.cpython_only
552    def test_refcycle(self):
553        # Issue #7943: an SSL object doesn't create reference cycles with
554        # itself.
555        s = socket.socket(socket.AF_INET)
556        ss = test_wrap_socket(s)
557        wr = weakref.ref(ss)
558        with warnings_helper.check_warnings(("", ResourceWarning)):
559            del ss
560        self.assertEqual(wr(), None)
561
562    def test_wrapped_unconnected(self):
563        # Methods on an unconnected SSLSocket propagate the original
564        # OSError raise by the underlying socket object.
565        s = socket.socket(socket.AF_INET)
566        with test_wrap_socket(s) as ss:
567            self.assertRaises(OSError, ss.recv, 1)
568            self.assertRaises(OSError, ss.recv_into, bytearray(b'x'))
569            self.assertRaises(OSError, ss.recvfrom, 1)
570            self.assertRaises(OSError, ss.recvfrom_into, bytearray(b'x'), 1)
571            self.assertRaises(OSError, ss.send, b'x')
572            self.assertRaises(OSError, ss.sendto, b'x', ('0.0.0.0', 0))
573            self.assertRaises(NotImplementedError, ss.dup)
574            self.assertRaises(NotImplementedError, ss.sendmsg,
575                              [b'x'], (), 0, ('0.0.0.0', 0))
576            self.assertRaises(NotImplementedError, ss.recvmsg, 100)
577            self.assertRaises(NotImplementedError, ss.recvmsg_into,
578                              [bytearray(100)])
579
580    def test_timeout(self):
581        # Issue #8524: when creating an SSL socket, the timeout of the
582        # original socket should be retained.
583        for timeout in (None, 0.0, 5.0):
584            s = socket.socket(socket.AF_INET)
585            s.settimeout(timeout)
586            with test_wrap_socket(s) as ss:
587                self.assertEqual(timeout, ss.gettimeout())
588
589    def test_openssl111_deprecations(self):
590        options = [
591            ssl.OP_NO_TLSv1,
592            ssl.OP_NO_TLSv1_1,
593            ssl.OP_NO_TLSv1_2,
594            ssl.OP_NO_TLSv1_3
595        ]
596        protocols = [
597            ssl.PROTOCOL_TLSv1,
598            ssl.PROTOCOL_TLSv1_1,
599            ssl.PROTOCOL_TLSv1_2,
600            ssl.PROTOCOL_TLS
601        ]
602        versions = [
603            ssl.TLSVersion.SSLv3,
604            ssl.TLSVersion.TLSv1,
605            ssl.TLSVersion.TLSv1_1,
606        ]
607
608        for option in options:
609            with self.subTest(option=option):
610                ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
611                with self.assertWarns(DeprecationWarning) as cm:
612                    ctx.options |= option
613                self.assertEqual(
614                    'ssl.OP_NO_SSL*/ssl.OP_NO_TLS* options are deprecated',
615                    str(cm.warning)
616                )
617
618        for protocol in protocols:
619            if not has_tls_protocol(protocol):
620                continue
621            with self.subTest(protocol=protocol):
622                with self.assertWarns(DeprecationWarning) as cm:
623                    ssl.SSLContext(protocol)
624                self.assertEqual(
625                    f'ssl.{protocol.name} is deprecated',
626                    str(cm.warning)
627                )
628
629        for version in versions:
630            if not has_tls_version(version):
631                continue
632            with self.subTest(version=version):
633                ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
634                with self.assertWarns(DeprecationWarning) as cm:
635                    ctx.minimum_version = version
636                version_text = '%s.%s' % (version.__class__.__name__, version.name)
637                self.assertEqual(
638                    f'ssl.{version_text} is deprecated',
639                    str(cm.warning)
640                )
641
642    @ignore_deprecation
643    def test_errors_sslwrap(self):
644        sock = socket.socket()
645        self.assertRaisesRegex(ValueError,
646                        "certfile must be specified",
647                        ssl.wrap_socket, sock, keyfile=CERTFILE)
648        self.assertRaisesRegex(ValueError,
649                        "certfile must be specified for server-side operations",
650                        ssl.wrap_socket, sock, server_side=True)
651        self.assertRaisesRegex(ValueError,
652                        "certfile must be specified for server-side operations",
653                         ssl.wrap_socket, sock, server_side=True, certfile="")
654        with ssl.wrap_socket(sock, server_side=True, certfile=CERTFILE) as s:
655            self.assertRaisesRegex(ValueError, "can't connect in server-side mode",
656                                     s.connect, (HOST, 8080))
657        with self.assertRaises(OSError) as cm:
658            with socket.socket() as sock:
659                ssl.wrap_socket(sock, certfile=NONEXISTINGCERT)
660        self.assertEqual(cm.exception.errno, errno.ENOENT)
661        with self.assertRaises(OSError) as cm:
662            with socket.socket() as sock:
663                ssl.wrap_socket(sock,
664                    certfile=CERTFILE, keyfile=NONEXISTINGCERT)
665        self.assertEqual(cm.exception.errno, errno.ENOENT)
666        with self.assertRaises(OSError) as cm:
667            with socket.socket() as sock:
668                ssl.wrap_socket(sock,
669                    certfile=NONEXISTINGCERT, keyfile=NONEXISTINGCERT)
670        self.assertEqual(cm.exception.errno, errno.ENOENT)
671
672    def bad_cert_test(self, certfile):
673        """Check that trying to use the given client certificate fails"""
674        certfile = os.path.join(os.path.dirname(__file__) or os.curdir,
675                                   certfile)
676        sock = socket.socket()
677        self.addCleanup(sock.close)
678        with self.assertRaises(ssl.SSLError):
679            test_wrap_socket(sock,
680                             certfile=certfile)
681
682    def test_empty_cert(self):
683        """Wrapping with an empty cert file"""
684        self.bad_cert_test("nullcert.pem")
685
686    def test_malformed_cert(self):
687        """Wrapping with a badly formatted certificate (syntax error)"""
688        self.bad_cert_test("badcert.pem")
689
690    def test_malformed_key(self):
691        """Wrapping with a badly formatted key (syntax error)"""
692        self.bad_cert_test("badkey.pem")
693
694    @ignore_deprecation
695    def test_match_hostname(self):
696        def ok(cert, hostname):
697            ssl.match_hostname(cert, hostname)
698        def fail(cert, hostname):
699            self.assertRaises(ssl.CertificateError,
700                              ssl.match_hostname, cert, hostname)
701
702        # -- Hostname matching --
703
704        cert = {'subject': ((('commonName', 'example.com'),),)}
705        ok(cert, 'example.com')
706        ok(cert, 'ExAmple.cOm')
707        fail(cert, 'www.example.com')
708        fail(cert, '.example.com')
709        fail(cert, 'example.org')
710        fail(cert, 'exampleXcom')
711
712        cert = {'subject': ((('commonName', '*.a.com'),),)}
713        ok(cert, 'foo.a.com')
714        fail(cert, 'bar.foo.a.com')
715        fail(cert, 'a.com')
716        fail(cert, 'Xa.com')
717        fail(cert, '.a.com')
718
719        # only match wildcards when they are the only thing
720        # in left-most segment
721        cert = {'subject': ((('commonName', 'f*.com'),),)}
722        fail(cert, 'foo.com')
723        fail(cert, 'f.com')
724        fail(cert, 'bar.com')
725        fail(cert, 'foo.a.com')
726        fail(cert, 'bar.foo.com')
727
728        # NULL bytes are bad, CVE-2013-4073
729        cert = {'subject': ((('commonName',
730                              'null.python.org\x00example.org'),),)}
731        ok(cert, 'null.python.org\x00example.org') # or raise an error?
732        fail(cert, 'example.org')
733        fail(cert, 'null.python.org')
734
735        # error cases with wildcards
736        cert = {'subject': ((('commonName', '*.*.a.com'),),)}
737        fail(cert, 'bar.foo.a.com')
738        fail(cert, 'a.com')
739        fail(cert, 'Xa.com')
740        fail(cert, '.a.com')
741
742        cert = {'subject': ((('commonName', 'a.*.com'),),)}
743        fail(cert, 'a.foo.com')
744        fail(cert, 'a..com')
745        fail(cert, 'a.com')
746
747        # wildcard doesn't match IDNA prefix 'xn--'
748        idna = 'püthon.python.org'.encode("idna").decode("ascii")
749        cert = {'subject': ((('commonName', idna),),)}
750        ok(cert, idna)
751        cert = {'subject': ((('commonName', 'x*.python.org'),),)}
752        fail(cert, idna)
753        cert = {'subject': ((('commonName', 'xn--p*.python.org'),),)}
754        fail(cert, idna)
755
756        # wildcard in first fragment and  IDNA A-labels in sequent fragments
757        # are supported.
758        idna = 'www*.pythön.org'.encode("idna").decode("ascii")
759        cert = {'subject': ((('commonName', idna),),)}
760        fail(cert, 'www.pythön.org'.encode("idna").decode("ascii"))
761        fail(cert, 'www1.pythön.org'.encode("idna").decode("ascii"))
762        fail(cert, 'ftp.pythön.org'.encode("idna").decode("ascii"))
763        fail(cert, 'pythön.org'.encode("idna").decode("ascii"))
764
765        # Slightly fake real-world example
766        cert = {'notAfter': 'Jun 26 21:41:46 2011 GMT',
767                'subject': ((('commonName', 'linuxfrz.org'),),),
768                'subjectAltName': (('DNS', 'linuxfr.org'),
769                                   ('DNS', 'linuxfr.com'),
770                                   ('othername', '<unsupported>'))}
771        ok(cert, 'linuxfr.org')
772        ok(cert, 'linuxfr.com')
773        # Not a "DNS" entry
774        fail(cert, '<unsupported>')
775        # When there is a subjectAltName, commonName isn't used
776        fail(cert, 'linuxfrz.org')
777
778        # A pristine real-world example
779        cert = {'notAfter': 'Dec 18 23:59:59 2011 GMT',
780                'subject': ((('countryName', 'US'),),
781                            (('stateOrProvinceName', 'California'),),
782                            (('localityName', 'Mountain View'),),
783                            (('organizationName', 'Google Inc'),),
784                            (('commonName', 'mail.google.com'),))}
785        ok(cert, 'mail.google.com')
786        fail(cert, 'gmail.com')
787        # Only commonName is considered
788        fail(cert, 'California')
789
790        # -- IPv4 matching --
791        cert = {'subject': ((('commonName', 'example.com'),),),
792                'subjectAltName': (('DNS', 'example.com'),
793                                   ('IP Address', '10.11.12.13'),
794                                   ('IP Address', '14.15.16.17'),
795                                   ('IP Address', '127.0.0.1'))}
796        ok(cert, '10.11.12.13')
797        ok(cert, '14.15.16.17')
798        # socket.inet_ntoa(socket.inet_aton('127.1')) == '127.0.0.1'
799        fail(cert, '127.1')
800        fail(cert, '14.15.16.17 ')
801        fail(cert, '14.15.16.17 extra data')
802        fail(cert, '14.15.16.18')
803        fail(cert, 'example.net')
804
805        # -- IPv6 matching --
806        if socket_helper.IPV6_ENABLED:
807            cert = {'subject': ((('commonName', 'example.com'),),),
808                    'subjectAltName': (
809                        ('DNS', 'example.com'),
810                        ('IP Address', '2001:0:0:0:0:0:0:CAFE\n'),
811                        ('IP Address', '2003:0:0:0:0:0:0:BABA\n'))}
812            ok(cert, '2001::cafe')
813            ok(cert, '2003::baba')
814            fail(cert, '2003::baba ')
815            fail(cert, '2003::baba extra data')
816            fail(cert, '2003::bebe')
817            fail(cert, 'example.net')
818
819        # -- Miscellaneous --
820
821        # Neither commonName nor subjectAltName
822        cert = {'notAfter': 'Dec 18 23:59:59 2011 GMT',
823                'subject': ((('countryName', 'US'),),
824                            (('stateOrProvinceName', 'California'),),
825                            (('localityName', 'Mountain View'),),
826                            (('organizationName', 'Google Inc'),))}
827        fail(cert, 'mail.google.com')
828
829        # No DNS entry in subjectAltName but a commonName
830        cert = {'notAfter': 'Dec 18 23:59:59 2099 GMT',
831                'subject': ((('countryName', 'US'),),
832                            (('stateOrProvinceName', 'California'),),
833                            (('localityName', 'Mountain View'),),
834                            (('commonName', 'mail.google.com'),)),
835                'subjectAltName': (('othername', 'blabla'), )}
836        ok(cert, 'mail.google.com')
837
838        # No DNS entry subjectAltName and no commonName
839        cert = {'notAfter': 'Dec 18 23:59:59 2099 GMT',
840                'subject': ((('countryName', 'US'),),
841                            (('stateOrProvinceName', 'California'),),
842                            (('localityName', 'Mountain View'),),
843                            (('organizationName', 'Google Inc'),)),
844                'subjectAltName': (('othername', 'blabla'),)}
845        fail(cert, 'google.com')
846
847        # Empty cert / no cert
848        self.assertRaises(ValueError, ssl.match_hostname, None, 'example.com')
849        self.assertRaises(ValueError, ssl.match_hostname, {}, 'example.com')
850
851        # Issue #17980: avoid denials of service by refusing more than one
852        # wildcard per fragment.
853        cert = {'subject': ((('commonName', 'a*b.example.com'),),)}
854        with self.assertRaisesRegex(
855                ssl.CertificateError,
856                "partial wildcards in leftmost label are not supported"):
857            ssl.match_hostname(cert, 'axxb.example.com')
858
859        cert = {'subject': ((('commonName', 'www.*.example.com'),),)}
860        with self.assertRaisesRegex(
861                ssl.CertificateError,
862                "wildcard can only be present in the leftmost label"):
863            ssl.match_hostname(cert, 'www.sub.example.com')
864
865        cert = {'subject': ((('commonName', 'a*b*.example.com'),),)}
866        with self.assertRaisesRegex(
867                ssl.CertificateError,
868                "too many wildcards"):
869            ssl.match_hostname(cert, 'axxbxxc.example.com')
870
871        cert = {'subject': ((('commonName', '*'),),)}
872        with self.assertRaisesRegex(
873                ssl.CertificateError,
874                "sole wildcard without additional labels are not support"):
875            ssl.match_hostname(cert, 'host')
876
877        cert = {'subject': ((('commonName', '*.com'),),)}
878        with self.assertRaisesRegex(
879                ssl.CertificateError,
880                r"hostname 'com' doesn't match '\*.com'"):
881            ssl.match_hostname(cert, 'com')
882
883        # extra checks for _inet_paton()
884        for invalid in ['1', '', '1.2.3', '256.0.0.1', '127.0.0.1/24']:
885            with self.assertRaises(ValueError):
886                ssl._inet_paton(invalid)
887        for ipaddr in ['127.0.0.1', '192.168.0.1']:
888            self.assertTrue(ssl._inet_paton(ipaddr))
889        if socket_helper.IPV6_ENABLED:
890            for ipaddr in ['::1', '2001:db8:85a3::8a2e:370:7334']:
891                self.assertTrue(ssl._inet_paton(ipaddr))
892
893    def test_server_side(self):
894        # server_hostname doesn't work for server sockets
895        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
896        with socket.socket() as sock:
897            self.assertRaises(ValueError, ctx.wrap_socket, sock, True,
898                              server_hostname="some.hostname")
899
900    def test_unknown_channel_binding(self):
901        # should raise ValueError for unknown type
902        s = socket.create_server(('127.0.0.1', 0))
903        c = socket.socket(socket.AF_INET)
904        c.connect(s.getsockname())
905        with test_wrap_socket(c, do_handshake_on_connect=False) as ss:
906            with self.assertRaises(ValueError):
907                ss.get_channel_binding("unknown-type")
908        s.close()
909
910    @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
911                         "'tls-unique' channel binding not available")
912    def test_tls_unique_channel_binding(self):
913        # unconnected should return None for known type
914        s = socket.socket(socket.AF_INET)
915        with test_wrap_socket(s) as ss:
916            self.assertIsNone(ss.get_channel_binding("tls-unique"))
917        # the same for server-side
918        s = socket.socket(socket.AF_INET)
919        with test_wrap_socket(s, server_side=True, certfile=CERTFILE) as ss:
920            self.assertIsNone(ss.get_channel_binding("tls-unique"))
921
922    def test_dealloc_warn(self):
923        ss = test_wrap_socket(socket.socket(socket.AF_INET))
924        r = repr(ss)
925        with self.assertWarns(ResourceWarning) as cm:
926            ss = None
927            support.gc_collect()
928        self.assertIn(r, str(cm.warning.args[0]))
929
930    def test_get_default_verify_paths(self):
931        paths = ssl.get_default_verify_paths()
932        self.assertEqual(len(paths), 6)
933        self.assertIsInstance(paths, ssl.DefaultVerifyPaths)
934
935        with os_helper.EnvironmentVarGuard() as env:
936            env["SSL_CERT_DIR"] = CAPATH
937            env["SSL_CERT_FILE"] = CERTFILE
938            paths = ssl.get_default_verify_paths()
939            self.assertEqual(paths.cafile, CERTFILE)
940            self.assertEqual(paths.capath, CAPATH)
941
942    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
943    def test_enum_certificates(self):
944        self.assertTrue(ssl.enum_certificates("CA"))
945        self.assertTrue(ssl.enum_certificates("ROOT"))
946
947        self.assertRaises(TypeError, ssl.enum_certificates)
948        self.assertRaises(WindowsError, ssl.enum_certificates, "")
949
950        trust_oids = set()
951        for storename in ("CA", "ROOT"):
952            store = ssl.enum_certificates(storename)
953            self.assertIsInstance(store, list)
954            for element in store:
955                self.assertIsInstance(element, tuple)
956                self.assertEqual(len(element), 3)
957                cert, enc, trust = element
958                self.assertIsInstance(cert, bytes)
959                self.assertIn(enc, {"x509_asn", "pkcs_7_asn"})
960                self.assertIsInstance(trust, (frozenset, set, bool))
961                if isinstance(trust, (frozenset, set)):
962                    trust_oids.update(trust)
963
964        serverAuth = "1.3.6.1.5.5.7.3.1"
965        self.assertIn(serverAuth, trust_oids)
966
967    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
968    def test_enum_crls(self):
969        self.assertTrue(ssl.enum_crls("CA"))
970        self.assertRaises(TypeError, ssl.enum_crls)
971        self.assertRaises(WindowsError, ssl.enum_crls, "")
972
973        crls = ssl.enum_crls("CA")
974        self.assertIsInstance(crls, list)
975        for element in crls:
976            self.assertIsInstance(element, tuple)
977            self.assertEqual(len(element), 2)
978            self.assertIsInstance(element[0], bytes)
979            self.assertIn(element[1], {"x509_asn", "pkcs_7_asn"})
980
981
982    def test_asn1object(self):
983        expected = (129, 'serverAuth', 'TLS Web Server Authentication',
984                    '1.3.6.1.5.5.7.3.1')
985
986        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.1')
987        self.assertEqual(val, expected)
988        self.assertEqual(val.nid, 129)
989        self.assertEqual(val.shortname, 'serverAuth')
990        self.assertEqual(val.longname, 'TLS Web Server Authentication')
991        self.assertEqual(val.oid, '1.3.6.1.5.5.7.3.1')
992        self.assertIsInstance(val, ssl._ASN1Object)
993        self.assertRaises(ValueError, ssl._ASN1Object, 'serverAuth')
994
995        val = ssl._ASN1Object.fromnid(129)
996        self.assertEqual(val, expected)
997        self.assertIsInstance(val, ssl._ASN1Object)
998        self.assertRaises(ValueError, ssl._ASN1Object.fromnid, -1)
999        with self.assertRaisesRegex(ValueError, "unknown NID 100000"):
1000            ssl._ASN1Object.fromnid(100000)
1001        for i in range(1000):
1002            try:
1003                obj = ssl._ASN1Object.fromnid(i)
1004            except ValueError:
1005                pass
1006            else:
1007                self.assertIsInstance(obj.nid, int)
1008                self.assertIsInstance(obj.shortname, str)
1009                self.assertIsInstance(obj.longname, str)
1010                self.assertIsInstance(obj.oid, (str, type(None)))
1011
1012        val = ssl._ASN1Object.fromname('TLS Web Server Authentication')
1013        self.assertEqual(val, expected)
1014        self.assertIsInstance(val, ssl._ASN1Object)
1015        self.assertEqual(ssl._ASN1Object.fromname('serverAuth'), expected)
1016        self.assertEqual(ssl._ASN1Object.fromname('1.3.6.1.5.5.7.3.1'),
1017                         expected)
1018        with self.assertRaisesRegex(ValueError, "unknown object 'serverauth'"):
1019            ssl._ASN1Object.fromname('serverauth')
1020
1021    def test_purpose_enum(self):
1022        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.1')
1023        self.assertIsInstance(ssl.Purpose.SERVER_AUTH, ssl._ASN1Object)
1024        self.assertEqual(ssl.Purpose.SERVER_AUTH, val)
1025        self.assertEqual(ssl.Purpose.SERVER_AUTH.nid, 129)
1026        self.assertEqual(ssl.Purpose.SERVER_AUTH.shortname, 'serverAuth')
1027        self.assertEqual(ssl.Purpose.SERVER_AUTH.oid,
1028                              '1.3.6.1.5.5.7.3.1')
1029
1030        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.2')
1031        self.assertIsInstance(ssl.Purpose.CLIENT_AUTH, ssl._ASN1Object)
1032        self.assertEqual(ssl.Purpose.CLIENT_AUTH, val)
1033        self.assertEqual(ssl.Purpose.CLIENT_AUTH.nid, 130)
1034        self.assertEqual(ssl.Purpose.CLIENT_AUTH.shortname, 'clientAuth')
1035        self.assertEqual(ssl.Purpose.CLIENT_AUTH.oid,
1036                              '1.3.6.1.5.5.7.3.2')
1037
1038    def test_unsupported_dtls(self):
1039        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
1040        self.addCleanup(s.close)
1041        with self.assertRaises(NotImplementedError) as cx:
1042            test_wrap_socket(s, cert_reqs=ssl.CERT_NONE)
1043        self.assertEqual(str(cx.exception), "only stream sockets are supported")
1044        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1045        with self.assertRaises(NotImplementedError) as cx:
1046            ctx.wrap_socket(s)
1047        self.assertEqual(str(cx.exception), "only stream sockets are supported")
1048
1049    def cert_time_ok(self, timestring, timestamp):
1050        self.assertEqual(ssl.cert_time_to_seconds(timestring), timestamp)
1051
1052    def cert_time_fail(self, timestring):
1053        with self.assertRaises(ValueError):
1054            ssl.cert_time_to_seconds(timestring)
1055
1056    @unittest.skipUnless(utc_offset(),
1057                         'local time needs to be different from UTC')
1058    def test_cert_time_to_seconds_timezone(self):
1059        # Issue #19940: ssl.cert_time_to_seconds() returns wrong
1060        #               results if local timezone is not UTC
1061        self.cert_time_ok("May  9 00:00:00 2007 GMT", 1178668800.0)
1062        self.cert_time_ok("Jan  5 09:34:43 2018 GMT", 1515144883.0)
1063
1064    def test_cert_time_to_seconds(self):
1065        timestring = "Jan  5 09:34:43 2018 GMT"
1066        ts = 1515144883.0
1067        self.cert_time_ok(timestring, ts)
1068        # accept keyword parameter, assert its name
1069        self.assertEqual(ssl.cert_time_to_seconds(cert_time=timestring), ts)
1070        # accept both %e and %d (space or zero generated by strftime)
1071        self.cert_time_ok("Jan 05 09:34:43 2018 GMT", ts)
1072        # case-insensitive
1073        self.cert_time_ok("JaN  5 09:34:43 2018 GmT", ts)
1074        self.cert_time_fail("Jan  5 09:34 2018 GMT")     # no seconds
1075        self.cert_time_fail("Jan  5 09:34:43 2018")      # no GMT
1076        self.cert_time_fail("Jan  5 09:34:43 2018 UTC")  # not GMT timezone
1077        self.cert_time_fail("Jan 35 09:34:43 2018 GMT")  # invalid day
1078        self.cert_time_fail("Jon  5 09:34:43 2018 GMT")  # invalid month
1079        self.cert_time_fail("Jan  5 24:00:00 2018 GMT")  # invalid hour
1080        self.cert_time_fail("Jan  5 09:60:43 2018 GMT")  # invalid minute
1081
1082        newyear_ts = 1230768000.0
1083        # leap seconds
1084        self.cert_time_ok("Dec 31 23:59:60 2008 GMT", newyear_ts)
1085        # same timestamp
1086        self.cert_time_ok("Jan  1 00:00:00 2009 GMT", newyear_ts)
1087
1088        self.cert_time_ok("Jan  5 09:34:59 2018 GMT", 1515144899)
1089        #  allow 60th second (even if it is not a leap second)
1090        self.cert_time_ok("Jan  5 09:34:60 2018 GMT", 1515144900)
1091        #  allow 2nd leap second for compatibility with time.strptime()
1092        self.cert_time_ok("Jan  5 09:34:61 2018 GMT", 1515144901)
1093        self.cert_time_fail("Jan  5 09:34:62 2018 GMT")  # invalid seconds
1094
1095        # no special treatment for the special value:
1096        #   99991231235959Z (rfc 5280)
1097        self.cert_time_ok("Dec 31 23:59:59 9999 GMT", 253402300799.0)
1098
1099    @support.run_with_locale('LC_ALL', '')
1100    def test_cert_time_to_seconds_locale(self):
1101        # `cert_time_to_seconds()` should be locale independent
1102
1103        def local_february_name():
1104            return time.strftime('%b', (1, 2, 3, 4, 5, 6, 0, 0, 0))
1105
1106        if local_february_name().lower() == 'feb':
1107            self.skipTest("locale-specific month name needs to be "
1108                          "different from C locale")
1109
1110        # locale-independent
1111        self.cert_time_ok("Feb  9 00:00:00 2007 GMT", 1170979200.0)
1112        self.cert_time_fail(local_february_name() + "  9 00:00:00 2007 GMT")
1113
1114    def test_connect_ex_error(self):
1115        server = socket.socket(socket.AF_INET)
1116        self.addCleanup(server.close)
1117        port = socket_helper.bind_port(server)  # Reserve port but don't listen
1118        s = test_wrap_socket(socket.socket(socket.AF_INET),
1119                            cert_reqs=ssl.CERT_REQUIRED)
1120        self.addCleanup(s.close)
1121        rc = s.connect_ex((HOST, port))
1122        # Issue #19919: Windows machines or VMs hosted on Windows
1123        # machines sometimes return EWOULDBLOCK.
1124        errors = (
1125            errno.ECONNREFUSED, errno.EHOSTUNREACH, errno.ETIMEDOUT,
1126            errno.EWOULDBLOCK,
1127        )
1128        self.assertIn(rc, errors)
1129
1130    def test_read_write_zero(self):
1131        # empty reads and writes now work, bpo-42854, bpo-31711
1132        client_context, server_context, hostname = testing_context()
1133        server = ThreadedEchoServer(context=server_context)
1134        with server:
1135            with client_context.wrap_socket(socket.socket(),
1136                                            server_hostname=hostname) as s:
1137                s.connect((HOST, server.port))
1138                self.assertEqual(s.recv(0), b"")
1139                self.assertEqual(s.send(b""), 0)
1140
1141
1142class ContextTests(unittest.TestCase):
1143
1144    def test_constructor(self):
1145        for protocol in PROTOCOLS:
1146            if has_tls_protocol(protocol):
1147                with warnings_helper.check_warnings():
1148                    ctx = ssl.SSLContext(protocol)
1149                self.assertEqual(ctx.protocol, protocol)
1150        with warnings_helper.check_warnings():
1151            ctx = ssl.SSLContext()
1152        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1153        self.assertRaises(ValueError, ssl.SSLContext, -1)
1154        self.assertRaises(ValueError, ssl.SSLContext, 42)
1155
1156    def test_ciphers(self):
1157        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1158        ctx.set_ciphers("ALL")
1159        ctx.set_ciphers("DEFAULT")
1160        with self.assertRaisesRegex(ssl.SSLError, "No cipher can be selected"):
1161            ctx.set_ciphers("^$:,;?*'dorothyx")
1162
1163    @unittest.skipUnless(PY_SSL_DEFAULT_CIPHERS == 1,
1164                         "Test applies only to Python default ciphers")
1165    def test_python_ciphers(self):
1166        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1167        ciphers = ctx.get_ciphers()
1168        for suite in ciphers:
1169            name = suite['name']
1170            self.assertNotIn("PSK", name)
1171            self.assertNotIn("SRP", name)
1172            self.assertNotIn("MD5", name)
1173            self.assertNotIn("RC4", name)
1174            self.assertNotIn("3DES", name)
1175
1176    def test_get_ciphers(self):
1177        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1178        ctx.set_ciphers('AESGCM')
1179        names = set(d['name'] for d in ctx.get_ciphers())
1180        expected = {
1181            'AES128-GCM-SHA256',
1182            'ECDHE-ECDSA-AES128-GCM-SHA256',
1183            'ECDHE-RSA-AES128-GCM-SHA256',
1184            'DHE-RSA-AES128-GCM-SHA256',
1185            'AES256-GCM-SHA384',
1186            'ECDHE-ECDSA-AES256-GCM-SHA384',
1187            'ECDHE-RSA-AES256-GCM-SHA384',
1188            'DHE-RSA-AES256-GCM-SHA384',
1189        }
1190        intersection = names.intersection(expected)
1191        self.assertGreaterEqual(
1192            len(intersection), 2, f"\ngot: {sorted(names)}\nexpected: {sorted(expected)}"
1193        )
1194
1195    def test_options(self):
1196        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1197        # OP_ALL | OP_NO_SSLv2 | OP_NO_SSLv3 is the default value
1198        default = (ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
1199        # SSLContext also enables these by default
1200        default |= (OP_NO_COMPRESSION | OP_CIPHER_SERVER_PREFERENCE |
1201                    OP_SINGLE_DH_USE | OP_SINGLE_ECDH_USE |
1202                    OP_ENABLE_MIDDLEBOX_COMPAT)
1203        self.assertEqual(default, ctx.options)
1204        with warnings_helper.check_warnings():
1205            ctx.options |= ssl.OP_NO_TLSv1
1206        self.assertEqual(default | ssl.OP_NO_TLSv1, ctx.options)
1207        with warnings_helper.check_warnings():
1208            ctx.options = (ctx.options & ~ssl.OP_NO_TLSv1)
1209        self.assertEqual(default, ctx.options)
1210        ctx.options = 0
1211        # Ubuntu has OP_NO_SSLv3 forced on by default
1212        self.assertEqual(0, ctx.options & ~ssl.OP_NO_SSLv3)
1213
1214    def test_verify_mode_protocol(self):
1215        with warnings_helper.check_warnings():
1216            ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
1217        # Default value
1218        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1219        ctx.verify_mode = ssl.CERT_OPTIONAL
1220        self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
1221        ctx.verify_mode = ssl.CERT_REQUIRED
1222        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1223        ctx.verify_mode = ssl.CERT_NONE
1224        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1225        with self.assertRaises(TypeError):
1226            ctx.verify_mode = None
1227        with self.assertRaises(ValueError):
1228            ctx.verify_mode = 42
1229
1230        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1231        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1232        self.assertFalse(ctx.check_hostname)
1233
1234        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1235        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1236        self.assertTrue(ctx.check_hostname)
1237
1238    def test_hostname_checks_common_name(self):
1239        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1240        self.assertTrue(ctx.hostname_checks_common_name)
1241        if ssl.HAS_NEVER_CHECK_COMMON_NAME:
1242            ctx.hostname_checks_common_name = True
1243            self.assertTrue(ctx.hostname_checks_common_name)
1244            ctx.hostname_checks_common_name = False
1245            self.assertFalse(ctx.hostname_checks_common_name)
1246            ctx.hostname_checks_common_name = True
1247            self.assertTrue(ctx.hostname_checks_common_name)
1248        else:
1249            with self.assertRaises(AttributeError):
1250                ctx.hostname_checks_common_name = True
1251
1252    @ignore_deprecation
1253    def test_min_max_version(self):
1254        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1255        # OpenSSL default is MINIMUM_SUPPORTED, however some vendors like
1256        # Fedora override the setting to TLS 1.0.
1257        minimum_range = {
1258            # stock OpenSSL
1259            ssl.TLSVersion.MINIMUM_SUPPORTED,
1260            # Fedora 29 uses TLS 1.0 by default
1261            ssl.TLSVersion.TLSv1,
1262            # RHEL 8 uses TLS 1.2 by default
1263            ssl.TLSVersion.TLSv1_2
1264        }
1265        maximum_range = {
1266            # stock OpenSSL
1267            ssl.TLSVersion.MAXIMUM_SUPPORTED,
1268            # Fedora 32 uses TLS 1.3 by default
1269            ssl.TLSVersion.TLSv1_3
1270        }
1271
1272        self.assertIn(
1273            ctx.minimum_version, minimum_range
1274        )
1275        self.assertIn(
1276            ctx.maximum_version, maximum_range
1277        )
1278
1279        ctx.minimum_version = ssl.TLSVersion.TLSv1_1
1280        ctx.maximum_version = ssl.TLSVersion.TLSv1_2
1281        self.assertEqual(
1282            ctx.minimum_version, ssl.TLSVersion.TLSv1_1
1283        )
1284        self.assertEqual(
1285            ctx.maximum_version, ssl.TLSVersion.TLSv1_2
1286        )
1287
1288        ctx.minimum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
1289        ctx.maximum_version = ssl.TLSVersion.TLSv1
1290        self.assertEqual(
1291            ctx.minimum_version, ssl.TLSVersion.MINIMUM_SUPPORTED
1292        )
1293        self.assertEqual(
1294            ctx.maximum_version, ssl.TLSVersion.TLSv1
1295        )
1296
1297        ctx.maximum_version = ssl.TLSVersion.MAXIMUM_SUPPORTED
1298        self.assertEqual(
1299            ctx.maximum_version, ssl.TLSVersion.MAXIMUM_SUPPORTED
1300        )
1301
1302        ctx.maximum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
1303        self.assertIn(
1304            ctx.maximum_version,
1305            {ssl.TLSVersion.TLSv1, ssl.TLSVersion.TLSv1_1, ssl.TLSVersion.SSLv3}
1306        )
1307
1308        ctx.minimum_version = ssl.TLSVersion.MAXIMUM_SUPPORTED
1309        self.assertIn(
1310            ctx.minimum_version,
1311            {ssl.TLSVersion.TLSv1_2, ssl.TLSVersion.TLSv1_3}
1312        )
1313
1314        with self.assertRaises(ValueError):
1315            ctx.minimum_version = 42
1316
1317        if has_tls_protocol(ssl.PROTOCOL_TLSv1_1):
1318            ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_1)
1319
1320            self.assertIn(
1321                ctx.minimum_version, minimum_range
1322            )
1323            self.assertEqual(
1324                ctx.maximum_version, ssl.TLSVersion.MAXIMUM_SUPPORTED
1325            )
1326            with self.assertRaises(ValueError):
1327                ctx.minimum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
1328            with self.assertRaises(ValueError):
1329                ctx.maximum_version = ssl.TLSVersion.TLSv1
1330
1331    @unittest.skipUnless(
1332        hasattr(ssl.SSLContext, 'security_level'),
1333        "requires OpenSSL >= 1.1.0"
1334    )
1335    def test_security_level(self):
1336        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1337        # The default security callback allows for levels between 0-5
1338        # with OpenSSL defaulting to 1, however some vendors override the
1339        # default value (e.g. Debian defaults to 2)
1340        security_level_range = {
1341            0,
1342            1, # OpenSSL default
1343            2, # Debian
1344            3,
1345            4,
1346            5,
1347        }
1348        self.assertIn(ctx.security_level, security_level_range)
1349
1350    def test_verify_flags(self):
1351        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1352        # default value
1353        tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0)
1354        self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT | tf)
1355        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF
1356        self.assertEqual(ctx.verify_flags, ssl.VERIFY_CRL_CHECK_LEAF)
1357        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_CHAIN
1358        self.assertEqual(ctx.verify_flags, ssl.VERIFY_CRL_CHECK_CHAIN)
1359        ctx.verify_flags = ssl.VERIFY_DEFAULT
1360        self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT)
1361        ctx.verify_flags = ssl.VERIFY_ALLOW_PROXY_CERTS
1362        self.assertEqual(ctx.verify_flags, ssl.VERIFY_ALLOW_PROXY_CERTS)
1363        # supports any value
1364        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF | ssl.VERIFY_X509_STRICT
1365        self.assertEqual(ctx.verify_flags,
1366                         ssl.VERIFY_CRL_CHECK_LEAF | ssl.VERIFY_X509_STRICT)
1367        with self.assertRaises(TypeError):
1368            ctx.verify_flags = None
1369
1370    def test_load_cert_chain(self):
1371        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1372        # Combined key and cert in a single file
1373        ctx.load_cert_chain(CERTFILE, keyfile=None)
1374        ctx.load_cert_chain(CERTFILE, keyfile=CERTFILE)
1375        self.assertRaises(TypeError, ctx.load_cert_chain, keyfile=CERTFILE)
1376        with self.assertRaises(OSError) as cm:
1377            ctx.load_cert_chain(NONEXISTINGCERT)
1378        self.assertEqual(cm.exception.errno, errno.ENOENT)
1379        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
1380            ctx.load_cert_chain(BADCERT)
1381        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
1382            ctx.load_cert_chain(EMPTYCERT)
1383        # Separate key and cert
1384        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1385        ctx.load_cert_chain(ONLYCERT, ONLYKEY)
1386        ctx.load_cert_chain(certfile=ONLYCERT, keyfile=ONLYKEY)
1387        ctx.load_cert_chain(certfile=BYTES_ONLYCERT, keyfile=BYTES_ONLYKEY)
1388        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
1389            ctx.load_cert_chain(ONLYCERT)
1390        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
1391            ctx.load_cert_chain(ONLYKEY)
1392        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
1393            ctx.load_cert_chain(certfile=ONLYKEY, keyfile=ONLYCERT)
1394        # Mismatching key and cert
1395        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1396        with self.assertRaisesRegex(ssl.SSLError, "key values mismatch"):
1397            ctx.load_cert_chain(CAFILE_CACERT, ONLYKEY)
1398        # Password protected key and cert
1399        ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD)
1400        ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD.encode())
1401        ctx.load_cert_chain(CERTFILE_PROTECTED,
1402                            password=bytearray(KEY_PASSWORD.encode()))
1403        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD)
1404        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD.encode())
1405        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED,
1406                            bytearray(KEY_PASSWORD.encode()))
1407        with self.assertRaisesRegex(TypeError, "should be a string"):
1408            ctx.load_cert_chain(CERTFILE_PROTECTED, password=True)
1409        with self.assertRaises(ssl.SSLError):
1410            ctx.load_cert_chain(CERTFILE_PROTECTED, password="badpass")
1411        with self.assertRaisesRegex(ValueError, "cannot be longer"):
1412            # openssl has a fixed limit on the password buffer.
1413            # PEM_BUFSIZE is generally set to 1kb.
1414            # Return a string larger than this.
1415            ctx.load_cert_chain(CERTFILE_PROTECTED, password=b'a' * 102400)
1416        # Password callback
1417        def getpass_unicode():
1418            return KEY_PASSWORD
1419        def getpass_bytes():
1420            return KEY_PASSWORD.encode()
1421        def getpass_bytearray():
1422            return bytearray(KEY_PASSWORD.encode())
1423        def getpass_badpass():
1424            return "badpass"
1425        def getpass_huge():
1426            return b'a' * (1024 * 1024)
1427        def getpass_bad_type():
1428            return 9
1429        def getpass_exception():
1430            raise Exception('getpass error')
1431        class GetPassCallable:
1432            def __call__(self):
1433                return KEY_PASSWORD
1434            def getpass(self):
1435                return KEY_PASSWORD
1436        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_unicode)
1437        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytes)
1438        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytearray)
1439        ctx.load_cert_chain(CERTFILE_PROTECTED, password=GetPassCallable())
1440        ctx.load_cert_chain(CERTFILE_PROTECTED,
1441                            password=GetPassCallable().getpass)
1442        with self.assertRaises(ssl.SSLError):
1443            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_badpass)
1444        with self.assertRaisesRegex(ValueError, "cannot be longer"):
1445            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_huge)
1446        with self.assertRaisesRegex(TypeError, "must return a string"):
1447            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bad_type)
1448        with self.assertRaisesRegex(Exception, "getpass error"):
1449            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_exception)
1450        # Make sure the password function isn't called if it isn't needed
1451        ctx.load_cert_chain(CERTFILE, password=getpass_exception)
1452
1453    def test_load_verify_locations(self):
1454        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1455        ctx.load_verify_locations(CERTFILE)
1456        ctx.load_verify_locations(cafile=CERTFILE, capath=None)
1457        ctx.load_verify_locations(BYTES_CERTFILE)
1458        ctx.load_verify_locations(cafile=BYTES_CERTFILE, capath=None)
1459        self.assertRaises(TypeError, ctx.load_verify_locations)
1460        self.assertRaises(TypeError, ctx.load_verify_locations, None, None, None)
1461        with self.assertRaises(OSError) as cm:
1462            ctx.load_verify_locations(NONEXISTINGCERT)
1463        self.assertEqual(cm.exception.errno, errno.ENOENT)
1464        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
1465            ctx.load_verify_locations(BADCERT)
1466        ctx.load_verify_locations(CERTFILE, CAPATH)
1467        ctx.load_verify_locations(CERTFILE, capath=BYTES_CAPATH)
1468
1469        # Issue #10989: crash if the second argument type is invalid
1470        self.assertRaises(TypeError, ctx.load_verify_locations, None, True)
1471
1472    def test_load_verify_cadata(self):
1473        # test cadata
1474        with open(CAFILE_CACERT) as f:
1475            cacert_pem = f.read()
1476        cacert_der = ssl.PEM_cert_to_DER_cert(cacert_pem)
1477        with open(CAFILE_NEURONIO) as f:
1478            neuronio_pem = f.read()
1479        neuronio_der = ssl.PEM_cert_to_DER_cert(neuronio_pem)
1480
1481        # test PEM
1482        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1483        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 0)
1484        ctx.load_verify_locations(cadata=cacert_pem)
1485        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 1)
1486        ctx.load_verify_locations(cadata=neuronio_pem)
1487        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
1488        # cert already in hash table
1489        ctx.load_verify_locations(cadata=neuronio_pem)
1490        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
1491
1492        # combined
1493        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1494        combined = "\n".join((cacert_pem, neuronio_pem))
1495        ctx.load_verify_locations(cadata=combined)
1496        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
1497
1498        # with junk around the certs
1499        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1500        combined = ["head", cacert_pem, "other", neuronio_pem, "again",
1501                    neuronio_pem, "tail"]
1502        ctx.load_verify_locations(cadata="\n".join(combined))
1503        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
1504
1505        # test DER
1506        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1507        ctx.load_verify_locations(cadata=cacert_der)
1508        ctx.load_verify_locations(cadata=neuronio_der)
1509        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
1510        # cert already in hash table
1511        ctx.load_verify_locations(cadata=cacert_der)
1512        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
1513
1514        # combined
1515        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1516        combined = b"".join((cacert_der, neuronio_der))
1517        ctx.load_verify_locations(cadata=combined)
1518        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
1519
1520        # error cases
1521        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1522        self.assertRaises(TypeError, ctx.load_verify_locations, cadata=object)
1523
1524        with self.assertRaisesRegex(
1525            ssl.SSLError,
1526            "no start line: cadata does not contain a certificate"
1527        ):
1528            ctx.load_verify_locations(cadata="broken")
1529        with self.assertRaisesRegex(
1530            ssl.SSLError,
1531            "not enough data: cadata does not contain a certificate"
1532        ):
1533            ctx.load_verify_locations(cadata=b"broken")
1534
1535    @unittest.skipIf(Py_DEBUG_WIN32, "Avoid mixing debug/release CRT on Windows")
1536    def test_load_dh_params(self):
1537        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1538        ctx.load_dh_params(DHFILE)
1539        if os.name != 'nt':
1540            ctx.load_dh_params(BYTES_DHFILE)
1541        self.assertRaises(TypeError, ctx.load_dh_params)
1542        self.assertRaises(TypeError, ctx.load_dh_params, None)
1543        with self.assertRaises(FileNotFoundError) as cm:
1544            ctx.load_dh_params(NONEXISTINGCERT)
1545        self.assertEqual(cm.exception.errno, errno.ENOENT)
1546        with self.assertRaises(ssl.SSLError) as cm:
1547            ctx.load_dh_params(CERTFILE)
1548
1549    def test_session_stats(self):
1550        for proto in {ssl.PROTOCOL_TLS_CLIENT, ssl.PROTOCOL_TLS_SERVER}:
1551            ctx = ssl.SSLContext(proto)
1552            self.assertEqual(ctx.session_stats(), {
1553                'number': 0,
1554                'connect': 0,
1555                'connect_good': 0,
1556                'connect_renegotiate': 0,
1557                'accept': 0,
1558                'accept_good': 0,
1559                'accept_renegotiate': 0,
1560                'hits': 0,
1561                'misses': 0,
1562                'timeouts': 0,
1563                'cache_full': 0,
1564            })
1565
1566    def test_set_default_verify_paths(self):
1567        # There's not much we can do to test that it acts as expected,
1568        # so just check it doesn't crash or raise an exception.
1569        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1570        ctx.set_default_verify_paths()
1571
1572    @unittest.skipUnless(ssl.HAS_ECDH, "ECDH disabled on this OpenSSL build")
1573    def test_set_ecdh_curve(self):
1574        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1575        ctx.set_ecdh_curve("prime256v1")
1576        ctx.set_ecdh_curve(b"prime256v1")
1577        self.assertRaises(TypeError, ctx.set_ecdh_curve)
1578        self.assertRaises(TypeError, ctx.set_ecdh_curve, None)
1579        self.assertRaises(ValueError, ctx.set_ecdh_curve, "foo")
1580        self.assertRaises(ValueError, ctx.set_ecdh_curve, b"foo")
1581
1582    def test_sni_callback(self):
1583        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1584
1585        # set_servername_callback expects a callable, or None
1586        self.assertRaises(TypeError, ctx.set_servername_callback)
1587        self.assertRaises(TypeError, ctx.set_servername_callback, 4)
1588        self.assertRaises(TypeError, ctx.set_servername_callback, "")
1589        self.assertRaises(TypeError, ctx.set_servername_callback, ctx)
1590
1591        def dummycallback(sock, servername, ctx):
1592            pass
1593        ctx.set_servername_callback(None)
1594        ctx.set_servername_callback(dummycallback)
1595
1596    def test_sni_callback_refcycle(self):
1597        # Reference cycles through the servername callback are detected
1598        # and cleared.
1599        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1600        def dummycallback(sock, servername, ctx, cycle=ctx):
1601            pass
1602        ctx.set_servername_callback(dummycallback)
1603        wr = weakref.ref(ctx)
1604        del ctx, dummycallback
1605        gc.collect()
1606        self.assertIs(wr(), None)
1607
1608    def test_cert_store_stats(self):
1609        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1610        self.assertEqual(ctx.cert_store_stats(),
1611            {'x509_ca': 0, 'crl': 0, 'x509': 0})
1612        ctx.load_cert_chain(CERTFILE)
1613        self.assertEqual(ctx.cert_store_stats(),
1614            {'x509_ca': 0, 'crl': 0, 'x509': 0})
1615        ctx.load_verify_locations(CERTFILE)
1616        self.assertEqual(ctx.cert_store_stats(),
1617            {'x509_ca': 0, 'crl': 0, 'x509': 1})
1618        ctx.load_verify_locations(CAFILE_CACERT)
1619        self.assertEqual(ctx.cert_store_stats(),
1620            {'x509_ca': 1, 'crl': 0, 'x509': 2})
1621
1622    def test_get_ca_certs(self):
1623        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1624        self.assertEqual(ctx.get_ca_certs(), [])
1625        # CERTFILE is not flagged as X509v3 Basic Constraints: CA:TRUE
1626        ctx.load_verify_locations(CERTFILE)
1627        self.assertEqual(ctx.get_ca_certs(), [])
1628        # but CAFILE_CACERT is a CA cert
1629        ctx.load_verify_locations(CAFILE_CACERT)
1630        self.assertEqual(ctx.get_ca_certs(),
1631            [{'issuer': ((('organizationName', 'Root CA'),),
1632                         (('organizationalUnitName', 'http://www.cacert.org'),),
1633                         (('commonName', 'CA Cert Signing Authority'),),
1634                         (('emailAddress', '[email protected]'),)),
1635              'notAfter': 'Mar 29 12:29:49 2033 GMT',
1636              'notBefore': 'Mar 30 12:29:49 2003 GMT',
1637              'serialNumber': '00',
1638              'crlDistributionPoints': ('https://www.cacert.org/revoke.crl',),
1639              'subject': ((('organizationName', 'Root CA'),),
1640                          (('organizationalUnitName', 'http://www.cacert.org'),),
1641                          (('commonName', 'CA Cert Signing Authority'),),
1642                          (('emailAddress', '[email protected]'),)),
1643              'version': 3}])
1644
1645        with open(CAFILE_CACERT) as f:
1646            pem = f.read()
1647        der = ssl.PEM_cert_to_DER_cert(pem)
1648        self.assertEqual(ctx.get_ca_certs(True), [der])
1649
1650    def test_load_default_certs(self):
1651        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1652        ctx.load_default_certs()
1653
1654        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1655        ctx.load_default_certs(ssl.Purpose.SERVER_AUTH)
1656        ctx.load_default_certs()
1657
1658        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1659        ctx.load_default_certs(ssl.Purpose.CLIENT_AUTH)
1660
1661        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1662        self.assertRaises(TypeError, ctx.load_default_certs, None)
1663        self.assertRaises(TypeError, ctx.load_default_certs, 'SERVER_AUTH')
1664
1665    @unittest.skipIf(sys.platform == "win32", "not-Windows specific")
1666    def test_load_default_certs_env(self):
1667        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1668        with os_helper.EnvironmentVarGuard() as env:
1669            env["SSL_CERT_DIR"] = CAPATH
1670            env["SSL_CERT_FILE"] = CERTFILE
1671            ctx.load_default_certs()
1672            self.assertEqual(ctx.cert_store_stats(), {"crl": 0, "x509": 1, "x509_ca": 0})
1673
1674    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
1675    @unittest.skipIf(hasattr(sys, "gettotalrefcount"), "Debug build does not share environment between CRTs")
1676    def test_load_default_certs_env_windows(self):
1677        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1678        ctx.load_default_certs()
1679        stats = ctx.cert_store_stats()
1680
1681        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1682        with os_helper.EnvironmentVarGuard() as env:
1683            env["SSL_CERT_DIR"] = CAPATH
1684            env["SSL_CERT_FILE"] = CERTFILE
1685            ctx.load_default_certs()
1686            stats["x509"] += 1
1687            self.assertEqual(ctx.cert_store_stats(), stats)
1688
1689    def _assert_context_options(self, ctx):
1690        self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2)
1691        if OP_NO_COMPRESSION != 0:
1692            self.assertEqual(ctx.options & OP_NO_COMPRESSION,
1693                             OP_NO_COMPRESSION)
1694        if OP_SINGLE_DH_USE != 0:
1695            self.assertEqual(ctx.options & OP_SINGLE_DH_USE,
1696                             OP_SINGLE_DH_USE)
1697        if OP_SINGLE_ECDH_USE != 0:
1698            self.assertEqual(ctx.options & OP_SINGLE_ECDH_USE,
1699                             OP_SINGLE_ECDH_USE)
1700        if OP_CIPHER_SERVER_PREFERENCE != 0:
1701            self.assertEqual(ctx.options & OP_CIPHER_SERVER_PREFERENCE,
1702                             OP_CIPHER_SERVER_PREFERENCE)
1703
1704    def test_create_default_context(self):
1705        ctx = ssl.create_default_context()
1706
1707        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS_CLIENT)
1708        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1709        self.assertTrue(ctx.check_hostname)
1710        self._assert_context_options(ctx)
1711
1712        with open(SIGNING_CA) as f:
1713            cadata = f.read()
1714        ctx = ssl.create_default_context(cafile=SIGNING_CA, capath=CAPATH,
1715                                         cadata=cadata)
1716        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS_CLIENT)
1717        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1718        self._assert_context_options(ctx)
1719
1720        ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
1721        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS_SERVER)
1722        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1723        self._assert_context_options(ctx)
1724
1725    def test__create_stdlib_context(self):
1726        ctx = ssl._create_stdlib_context()
1727        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS_CLIENT)
1728        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1729        self.assertFalse(ctx.check_hostname)
1730        self._assert_context_options(ctx)
1731
1732        if has_tls_protocol(ssl.PROTOCOL_TLSv1):
1733            with warnings_helper.check_warnings():
1734                ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1)
1735            self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1)
1736            self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1737            self._assert_context_options(ctx)
1738
1739        with warnings_helper.check_warnings():
1740            ctx = ssl._create_stdlib_context(
1741                ssl.PROTOCOL_TLSv1_2,
1742                cert_reqs=ssl.CERT_REQUIRED,
1743                check_hostname=True
1744            )
1745        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1_2)
1746        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1747        self.assertTrue(ctx.check_hostname)
1748        self._assert_context_options(ctx)
1749
1750        ctx = ssl._create_stdlib_context(purpose=ssl.Purpose.CLIENT_AUTH)
1751        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS_SERVER)
1752        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1753        self._assert_context_options(ctx)
1754
1755    def test_check_hostname(self):
1756        with warnings_helper.check_warnings():
1757            ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
1758        self.assertFalse(ctx.check_hostname)
1759        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1760
1761        # Auto set CERT_REQUIRED
1762        ctx.check_hostname = True
1763        self.assertTrue(ctx.check_hostname)
1764        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1765        ctx.check_hostname = False
1766        ctx.verify_mode = ssl.CERT_REQUIRED
1767        self.assertFalse(ctx.check_hostname)
1768        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1769
1770        # Changing verify_mode does not affect check_hostname
1771        ctx.check_hostname = False
1772        ctx.verify_mode = ssl.CERT_NONE
1773        ctx.check_hostname = False
1774        self.assertFalse(ctx.check_hostname)
1775        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1776        # Auto set
1777        ctx.check_hostname = True
1778        self.assertTrue(ctx.check_hostname)
1779        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1780
1781        ctx.check_hostname = False
1782        ctx.verify_mode = ssl.CERT_OPTIONAL
1783        ctx.check_hostname = False
1784        self.assertFalse(ctx.check_hostname)
1785        self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
1786        # keep CERT_OPTIONAL
1787        ctx.check_hostname = True
1788        self.assertTrue(ctx.check_hostname)
1789        self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
1790
1791        # Cannot set CERT_NONE with check_hostname enabled
1792        with self.assertRaises(ValueError):
1793            ctx.verify_mode = ssl.CERT_NONE
1794        ctx.check_hostname = False
1795        self.assertFalse(ctx.check_hostname)
1796        ctx.verify_mode = ssl.CERT_NONE
1797        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1798
1799    def test_context_client_server(self):
1800        # PROTOCOL_TLS_CLIENT has sane defaults
1801        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1802        self.assertTrue(ctx.check_hostname)
1803        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1804
1805        # PROTOCOL_TLS_SERVER has different but also sane defaults
1806        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1807        self.assertFalse(ctx.check_hostname)
1808        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1809
1810    def test_context_custom_class(self):
1811        class MySSLSocket(ssl.SSLSocket):
1812            pass
1813
1814        class MySSLObject(ssl.SSLObject):
1815            pass
1816
1817        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1818        ctx.sslsocket_class = MySSLSocket
1819        ctx.sslobject_class = MySSLObject
1820
1821        with ctx.wrap_socket(socket.socket(), server_side=True) as sock:
1822            self.assertIsInstance(sock, MySSLSocket)
1823        obj = ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(), server_side=True)
1824        self.assertIsInstance(obj, MySSLObject)
1825
1826    def test_num_tickest(self):
1827        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1828        self.assertEqual(ctx.num_tickets, 2)
1829        ctx.num_tickets = 1
1830        self.assertEqual(ctx.num_tickets, 1)
1831        ctx.num_tickets = 0
1832        self.assertEqual(ctx.num_tickets, 0)
1833        with self.assertRaises(ValueError):
1834            ctx.num_tickets = -1
1835        with self.assertRaises(TypeError):
1836            ctx.num_tickets = None
1837
1838        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1839        self.assertEqual(ctx.num_tickets, 2)
1840        with self.assertRaises(ValueError):
1841            ctx.num_tickets = 1
1842
1843
class SSLErrorTests(unittest.TestCase):
    """Attributes and string forms of ssl.SSLError and its subclasses."""

    def test_str(self):
        # The str() of a SSLError doesn't include the errno; the same holds
        # for a subclass.
        for exc_type in (ssl.SSLError, ssl.SSLZeroReturnError):
            e = exc_type(1, "foo")
            self.assertEqual(str(e), "foo")
            self.assertEqual(e.errno, 1)

    @unittest.skipIf(Py_DEBUG_WIN32, "Avoid mixing debug/release CRT on Windows")
    def test_lib_reason(self):
        # Test the library and reason attributes
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        with self.assertRaises(ssl.SSLError) as cm:
            ctx.load_dh_params(CERTFILE)
        self.assertEqual(cm.exception.library, 'PEM')
        self.assertEqual(cm.exception.reason, 'NO_START_LINE')
        msg = str(cm.exception)
        self.assertTrue(msg.startswith("[PEM: NO_START_LINE] no start line"),
                        msg)

    def test_subclass(self):
        # Check that the appropriate SSLError subclass is raised
        # (this only tests one of them)
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE
        with socket.create_server(("127.0.0.1", 0)) as server:
            # The raw socket has its own `with` so that it is closed even if
            # wrap_socket() fails. On success wrap_socket() detaches it, and
            # the outer close becomes a no-op.
            with socket.create_connection(server.getsockname()) as raw:
                raw.setblocking(False)
                with ctx.wrap_socket(raw, server_side=False,
                                     do_handshake_on_connect=False) as client:
                    # Nothing is sent by the listening socket, so a
                    # non-blocking handshake must ask to read more data.
                    with self.assertRaises(ssl.SSLWantReadError) as cm:
                        client.do_handshake()
                    msg = str(cm.exception)
                    self.assertTrue(
                        msg.startswith("The operation did not complete (read)"),
                        msg)
                    # For compatibility
                    self.assertEqual(cm.exception.errno,
                                     ssl.SSL_ERROR_WANT_READ)

    def test_bad_server_hostname(self):
        # Empty, leading-dot and NUL-containing hostnames are rejected.
        ctx = ssl.create_default_context()
        with self.assertRaises(ValueError):
            ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(),
                         server_hostname="")
        with self.assertRaises(ValueError):
            ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(),
                         server_hostname=".example.org")
        with self.assertRaises(TypeError):
            ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(),
                         server_hostname="example.org\x00evil.com")
1896
1897
class MemoryBIOTests(unittest.TestCase):
    """Behaviour of the in-memory ssl.MemoryBIO buffer."""

    def test_read_write(self):
        bio = ssl.MemoryBIO()
        # An unsized read drains everything written so far.
        bio.write(b'foo')
        self.assertEqual(bio.read(), b'foo')
        self.assertEqual(bio.read(), b'')
        # Consecutive writes are concatenated.
        for chunk in (b'foo', b'bar'):
            bio.write(chunk)
        self.assertEqual(bio.read(), b'foobar')
        self.assertEqual(bio.read(), b'')
        # A sized read returns at most the requested number of bytes.
        bio.write(b'baz')
        for size, expected in ((2, b'ba'), (1, b'z'), (1, b'')):
            self.assertEqual(bio.read(size), expected)

    def test_eof(self):
        bio = ssl.MemoryBIO()
        self.assertFalse(bio.eof)
        self.assertEqual(bio.read(), b'')
        self.assertFalse(bio.eof)
        bio.write(b'foo')
        self.assertFalse(bio.eof)
        # write_eof() alone does not set eof while data is still buffered;
        # eof turns true once the last byte has been read.
        bio.write_eof()
        self.assertFalse(bio.eof)
        self.assertEqual(bio.read(2), b'fo')
        self.assertFalse(bio.eof)
        self.assertEqual(bio.read(1), b'o')
        self.assertTrue(bio.eof)
        self.assertEqual(bio.read(), b'')
        self.assertTrue(bio.eof)

    def test_pending(self):
        bio = ssl.MemoryBIO()
        self.assertEqual(bio.pending, 0)
        bio.write(b'foo')
        self.assertEqual(bio.pending, 3)
        # Drain one byte at a time ...
        for remaining in (2, 1, 0):
            bio.read(1)
            self.assertEqual(bio.pending, remaining)
        # ... then refill one byte at a time.
        for queued in (1, 2, 3):
            bio.write(b'x')
            self.assertEqual(bio.pending, queued)
        bio.read()
        self.assertEqual(bio.pending, 0)

    def test_buffer_types(self):
        bio = ssl.MemoryBIO()
        # bytes, bytearray and memoryview are all accepted by write().
        for payload in (b'foo', bytearray(b'bar'), memoryview(b'baz')):
            bio.write(payload)
            self.assertEqual(bio.read(), bytes(payload))

    def test_error_types(self):
        bio = ssl.MemoryBIO()
        # Non-bytes-like objects are rejected.
        for invalid in ('foo', None, True, 1):
            self.assertRaises(TypeError, bio.write, invalid)
1959
1960
class SSLObjectTests(unittest.TestCase):
    def test_private_init(self):
        # SSLObject has no public constructor; instantiating it directly
        # raises TypeError.
        bio = ssl.MemoryBIO()
        with self.assertRaisesRegex(TypeError, "public constructor"):
            ssl.SSLObject(bio, bio)

    def test_unwrap(self):
        client_ctx, server_ctx, hostname = testing_context()
        client_in, client_out = ssl.MemoryBIO(), ssl.MemoryBIO()
        server_in, server_out = ssl.MemoryBIO(), ssl.MemoryBIO()
        client = client_ctx.wrap_bio(client_in, client_out,
                                     server_hostname=hostname)
        server = server_ctx.wrap_bio(server_in, server_out, server_side=True)

        def step(endpoint, outgoing, peer_incoming):
            # Advance one side of the handshake, then hand whatever it
            # produced to the other side.
            try:
                endpoint.do_handshake()
            except ssl.SSLWantReadError:
                pass
            if outgoing.pending:
                peer_incoming.write(outgoing.read())

        # Loop on the handshake for a bit to get it settled
        for _ in range(5):
            step(client, client_out, server_in)
            step(server, server_out, client_in)
        # Now the handshakes should be complete (don't raise WantReadError)
        client.do_handshake()
        server.do_handshake()

        # Now if we unwrap one side unilaterally, it should send close-notify
        # and raise WantReadError:
        with self.assertRaises(ssl.SSLWantReadError):
            client.unwrap()

        # But server.unwrap() does not raise, because it reads the client's
        # close-notify:
        server_in.write(client_out.read())
        server.unwrap()

        # And now that the client gets the server's close-notify, it doesn't
        # raise either.
        client_in.write(server_out.read())
        client.unwrap()
2008
2009class SimpleBackgroundTests(unittest.TestCase):
2010    """Tests that connect to a simple server running in the background"""
2011
2012    def setUp(self):
2013        self.server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
2014        self.server_context.load_cert_chain(SIGNED_CERTFILE)
2015        server = ThreadedEchoServer(context=self.server_context)
2016        self.enterContext(server)
2017        self.server_addr = (HOST, server.port)
2018
2019    def test_connect(self):
2020        with test_wrap_socket(socket.socket(socket.AF_INET),
2021                            cert_reqs=ssl.CERT_NONE) as s:
2022            s.connect(self.server_addr)
2023            self.assertEqual({}, s.getpeercert())
2024            self.assertFalse(s.server_side)
2025
2026        # this should succeed because we specify the root cert
2027        with test_wrap_socket(socket.socket(socket.AF_INET),
2028                            cert_reqs=ssl.CERT_REQUIRED,
2029                            ca_certs=SIGNING_CA) as s:
2030            s.connect(self.server_addr)
2031            self.assertTrue(s.getpeercert())
2032            self.assertFalse(s.server_side)
2033
2034    def test_connect_fail(self):
2035        # This should fail because we have no verification certs. Connection
2036        # failure crashes ThreadedEchoServer, so run this in an independent
2037        # test method.
2038        s = test_wrap_socket(socket.socket(socket.AF_INET),
2039                            cert_reqs=ssl.CERT_REQUIRED)
2040        self.addCleanup(s.close)
2041        self.assertRaisesRegex(ssl.SSLError, "certificate verify failed",
2042                               s.connect, self.server_addr)
2043
2044    def test_connect_ex(self):
2045        # Issue #11326: check connect_ex() implementation
2046        s = test_wrap_socket(socket.socket(socket.AF_INET),
2047                            cert_reqs=ssl.CERT_REQUIRED,
2048                            ca_certs=SIGNING_CA)
2049        self.addCleanup(s.close)
2050        self.assertEqual(0, s.connect_ex(self.server_addr))
2051        self.assertTrue(s.getpeercert())
2052
2053    def test_non_blocking_connect_ex(self):
2054        # Issue #11326: non-blocking connect_ex() should allow handshake
2055        # to proceed after the socket gets ready.
2056        s = test_wrap_socket(socket.socket(socket.AF_INET),
2057                            cert_reqs=ssl.CERT_REQUIRED,
2058                            ca_certs=SIGNING_CA,
2059                            do_handshake_on_connect=False)
2060        self.addCleanup(s.close)
2061        s.setblocking(False)
2062        rc = s.connect_ex(self.server_addr)
2063        # EWOULDBLOCK under Windows, EINPROGRESS elsewhere
2064        self.assertIn(rc, (0, errno.EINPROGRESS, errno.EWOULDBLOCK))
2065        # Wait for connect to finish
2066        select.select([], [s], [], 5.0)
2067        # Non-blocking handshake
2068        while True:
2069            try:
2070                s.do_handshake()
2071                break
2072            except ssl.SSLWantReadError:
2073                select.select([s], [], [], 5.0)
2074            except ssl.SSLWantWriteError:
2075                select.select([], [s], [], 5.0)
2076        # SSL established
2077        self.assertTrue(s.getpeercert())
2078
2079    def test_connect_with_context(self):
2080        # Same as test_connect, but with a separately created context
2081        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2082        ctx.check_hostname = False
2083        ctx.verify_mode = ssl.CERT_NONE
2084        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
2085            s.connect(self.server_addr)
2086            self.assertEqual({}, s.getpeercert())
2087        # Same with a server hostname
2088        with ctx.wrap_socket(socket.socket(socket.AF_INET),
2089                            server_hostname="dummy") as s:
2090            s.connect(self.server_addr)
2091        ctx.verify_mode = ssl.CERT_REQUIRED
2092        # This should succeed because we specify the root cert
2093        ctx.load_verify_locations(SIGNING_CA)
2094        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
2095            s.connect(self.server_addr)
2096            cert = s.getpeercert()
2097            self.assertTrue(cert)
2098
2099    def test_connect_with_context_fail(self):
2100        # This should fail because we have no verification certs. Connection
2101        # failure crashes ThreadedEchoServer, so run this in an independent
2102        # test method.
2103        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2104        s = ctx.wrap_socket(
2105            socket.socket(socket.AF_INET),
2106            server_hostname=SIGNED_CERTFILE_HOSTNAME
2107        )
2108        self.addCleanup(s.close)
2109        self.assertRaisesRegex(ssl.SSLError, "certificate verify failed",
2110                                s.connect, self.server_addr)
2111
2112    def test_connect_capath(self):
2113        # Verify server certificates using the `capath` argument
2114        # NOTE: the subject hashing algorithm has been changed between
2115        # OpenSSL 0.9.8n and 1.0.0, as a result the capath directory must
2116        # contain both versions of each certificate (same content, different
2117        # filename) for this test to be portable across OpenSSL releases.
2118        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2119        ctx.load_verify_locations(capath=CAPATH)
2120        with ctx.wrap_socket(socket.socket(socket.AF_INET),
2121                             server_hostname=SIGNED_CERTFILE_HOSTNAME) as s:
2122            s.connect(self.server_addr)
2123            cert = s.getpeercert()
2124            self.assertTrue(cert)
2125
2126        # Same with a bytes `capath` argument
2127        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2128        ctx.load_verify_locations(capath=BYTES_CAPATH)
2129        with ctx.wrap_socket(socket.socket(socket.AF_INET),
2130                             server_hostname=SIGNED_CERTFILE_HOSTNAME) as s:
2131            s.connect(self.server_addr)
2132            cert = s.getpeercert()
2133            self.assertTrue(cert)
2134
2135    def test_connect_cadata(self):
2136        with open(SIGNING_CA) as f:
2137            pem = f.read()
2138        der = ssl.PEM_cert_to_DER_cert(pem)
2139        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2140        ctx.load_verify_locations(cadata=pem)
2141        with ctx.wrap_socket(socket.socket(socket.AF_INET),
2142                             server_hostname=SIGNED_CERTFILE_HOSTNAME) as s:
2143            s.connect(self.server_addr)
2144            cert = s.getpeercert()
2145            self.assertTrue(cert)
2146
2147        # same with DER
2148        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2149        ctx.load_verify_locations(cadata=der)
2150        with ctx.wrap_socket(socket.socket(socket.AF_INET),
2151                             server_hostname=SIGNED_CERTFILE_HOSTNAME) as s:
2152            s.connect(self.server_addr)
2153            cert = s.getpeercert()
2154            self.assertTrue(cert)
2155
2156    @unittest.skipIf(os.name == "nt", "Can't use a socket as a file under Windows")
2157    def test_makefile_close(self):
2158        # Issue #5238: creating a file-like object with makefile() shouldn't
2159        # delay closing the underlying "real socket" (here tested with its
2160        # file descriptor, hence skipping the test under Windows).
2161        ss = test_wrap_socket(socket.socket(socket.AF_INET))
2162        ss.connect(self.server_addr)
2163        fd = ss.fileno()
2164        f = ss.makefile()
2165        f.close()
2166        # The fd is still open
2167        os.read(fd, 0)
2168        # Closing the SSL socket should close the fd too
2169        ss.close()
2170        gc.collect()
2171        with self.assertRaises(OSError) as e:
2172            os.read(fd, 0)
2173        self.assertEqual(e.exception.errno, errno.EBADF)
2174
2175    def test_non_blocking_handshake(self):
2176        s = socket.socket(socket.AF_INET)
2177        s.connect(self.server_addr)
2178        s.setblocking(False)
2179        s = test_wrap_socket(s,
2180                            cert_reqs=ssl.CERT_NONE,
2181                            do_handshake_on_connect=False)
2182        self.addCleanup(s.close)
2183        count = 0
2184        while True:
2185            try:
2186                count += 1
2187                s.do_handshake()
2188                break
2189            except ssl.SSLWantReadError:
2190                select.select([s], [], [])
2191            except ssl.SSLWantWriteError:
2192                select.select([], [s], [])
2193        if support.verbose:
2194            sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)
2195
2196    def test_get_server_certificate(self):
2197        _test_get_server_certificate(self, *self.server_addr, cert=SIGNING_CA)
2198
2199    def test_get_server_certificate_sni(self):
2200        host, port = self.server_addr
2201        server_names = []
2202
2203        # We store servername_cb arguments to make sure they match the host
2204        def servername_cb(ssl_sock, server_name, initial_context):
2205            server_names.append(server_name)
2206        self.server_context.set_servername_callback(servername_cb)
2207
2208        pem = ssl.get_server_certificate((host, port))
2209        if not pem:
2210            self.fail("No server certificate on %s:%s!" % (host, port))
2211
2212        pem = ssl.get_server_certificate((host, port), ca_certs=SIGNING_CA)
2213        if not pem:
2214            self.fail("No server certificate on %s:%s!" % (host, port))
2215        if support.verbose:
2216            sys.stdout.write("\nVerified certificate for %s:%s is\n%s\n" % (host, port, pem))
2217
2218        self.assertEqual(server_names, [host, host])
2219
2220    def test_get_server_certificate_fail(self):
2221        # Connection failure crashes ThreadedEchoServer, so run this in an
2222        # independent test method
2223        _test_get_server_certificate_fail(self, *self.server_addr)
2224
2225    def test_get_server_certificate_timeout(self):
2226        def servername_cb(ssl_sock, server_name, initial_context):
2227            time.sleep(0.2)
2228        self.server_context.set_servername_callback(servername_cb)
2229
2230        with self.assertRaises(socket.timeout):
2231            ssl.get_server_certificate(self.server_addr, ca_certs=SIGNING_CA,
2232                                       timeout=0.1)
2233
2234    def test_ciphers(self):
2235        with test_wrap_socket(socket.socket(socket.AF_INET),
2236                             cert_reqs=ssl.CERT_NONE, ciphers="ALL") as s:
2237            s.connect(self.server_addr)
2238        with test_wrap_socket(socket.socket(socket.AF_INET),
2239                             cert_reqs=ssl.CERT_NONE, ciphers="DEFAULT") as s:
2240            s.connect(self.server_addr)
2241        # Error checking can happen at instantiation or when connecting
2242        with self.assertRaisesRegex(ssl.SSLError, "No cipher can be selected"):
2243            with socket.socket(socket.AF_INET) as sock:
2244                s = test_wrap_socket(sock,
2245                                    cert_reqs=ssl.CERT_NONE, ciphers="^$:,;?*'dorothyx")
2246                s.connect(self.server_addr)
2247
2248    def test_get_ca_certs_capath(self):
2249        # capath certs are loaded on request
2250        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2251        ctx.load_verify_locations(capath=CAPATH)
2252        self.assertEqual(ctx.get_ca_certs(), [])
2253        with ctx.wrap_socket(socket.socket(socket.AF_INET),
2254                             server_hostname='localhost') as s:
2255            s.connect(self.server_addr)
2256            cert = s.getpeercert()
2257            self.assertTrue(cert)
2258        self.assertEqual(len(ctx.get_ca_certs()), 1)
2259
2260    def test_context_setget(self):
2261        # Check that the context of a connected socket can be replaced.
2262        ctx1 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2263        ctx1.load_verify_locations(capath=CAPATH)
2264        ctx2 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2265        ctx2.load_verify_locations(capath=CAPATH)
2266        s = socket.socket(socket.AF_INET)
2267        with ctx1.wrap_socket(s, server_hostname='localhost') as ss:
2268            ss.connect(self.server_addr)
2269            self.assertIs(ss.context, ctx1)
2270            self.assertIs(ss._sslobj.context, ctx1)
2271            ss.context = ctx2
2272            self.assertIs(ss.context, ctx2)
2273            self.assertIs(ss._sslobj.context, ctx2)
2274
2275    def ssl_io_loop(self, sock, incoming, outgoing, func, *args, **kwargs):
2276        # A simple IO loop. Call func(*args) depending on the error we get
2277        # (WANT_READ or WANT_WRITE) move data between the socket and the BIOs.
2278        timeout = kwargs.get('timeout', support.SHORT_TIMEOUT)
2279        deadline = time.monotonic() + timeout
2280        count = 0
2281        while True:
2282            if time.monotonic() > deadline:
2283                self.fail("timeout")
2284            errno = None
2285            count += 1
2286            try:
2287                ret = func(*args)
2288            except ssl.SSLError as e:
2289                if e.errno not in (ssl.SSL_ERROR_WANT_READ,
2290                                   ssl.SSL_ERROR_WANT_WRITE):
2291                    raise
2292                errno = e.errno
2293            # Get any data from the outgoing BIO irrespective of any error, and
2294            # send it to the socket.
2295            buf = outgoing.read()
2296            sock.sendall(buf)
2297            # If there's no error, we're done. For WANT_READ, we need to get
2298            # data from the socket and put it in the incoming BIO.
2299            if errno is None:
2300                break
2301            elif errno == ssl.SSL_ERROR_WANT_READ:
2302                buf = sock.recv(32768)
2303                if buf:
2304                    incoming.write(buf)
2305                else:
2306                    incoming.write_eof()
2307        if support.verbose:
2308            sys.stdout.write("Needed %d calls to complete %s().\n"
2309                             % (count, func.__name__))
2310        return ret
2311
2312    def test_bio_handshake(self):
2313        sock = socket.socket(socket.AF_INET)
2314        self.addCleanup(sock.close)
2315        sock.connect(self.server_addr)
2316        incoming = ssl.MemoryBIO()
2317        outgoing = ssl.MemoryBIO()
2318        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2319        self.assertTrue(ctx.check_hostname)
2320        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
2321        ctx.load_verify_locations(SIGNING_CA)
2322        sslobj = ctx.wrap_bio(incoming, outgoing, False,
2323                              SIGNED_CERTFILE_HOSTNAME)
2324        self.assertIs(sslobj._sslobj.owner, sslobj)
2325        self.assertIsNone(sslobj.cipher())
2326        self.assertIsNone(sslobj.version())
2327        self.assertIsNone(sslobj.shared_ciphers())
2328        self.assertRaises(ValueError, sslobj.getpeercert)
2329        if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
2330            self.assertIsNone(sslobj.get_channel_binding('tls-unique'))
2331        self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
2332        self.assertTrue(sslobj.cipher())
2333        self.assertIsNone(sslobj.shared_ciphers())
2334        self.assertIsNotNone(sslobj.version())
2335        self.assertTrue(sslobj.getpeercert())
2336        if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
2337            self.assertTrue(sslobj.get_channel_binding('tls-unique'))
2338        try:
2339            self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
2340        except ssl.SSLSyscallError:
2341            # If the server shuts down the TCP connection without sending a
2342            # secure shutdown message, this is reported as SSL_ERROR_SYSCALL
2343            pass
2344        self.assertRaises(ssl.SSLError, sslobj.write, b'foo')
2345
2346    def test_bio_read_write_data(self):
2347        sock = socket.socket(socket.AF_INET)
2348        self.addCleanup(sock.close)
2349        sock.connect(self.server_addr)
2350        incoming = ssl.MemoryBIO()
2351        outgoing = ssl.MemoryBIO()
2352        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2353        ctx.check_hostname = False
2354        ctx.verify_mode = ssl.CERT_NONE
2355        sslobj = ctx.wrap_bio(incoming, outgoing, False)
2356        self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
2357        req = b'FOO\n'
2358        self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req)
2359        buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024)
2360        self.assertEqual(buf, b'foo\n')
2361        self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
2362
2363    def test_transport_eof(self):
2364        client_context, server_context, hostname = testing_context()
2365        with socket.socket(socket.AF_INET) as sock:
2366            sock.connect(self.server_addr)
2367            incoming = ssl.MemoryBIO()
2368            outgoing = ssl.MemoryBIO()
2369            sslobj = client_context.wrap_bio(incoming, outgoing,
2370                                             server_hostname=hostname)
2371            self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
2372
2373            # Simulate EOF from the transport.
2374            incoming.write_eof()
2375            self.assertRaises(ssl.SSLEOFError, sslobj.read)
2376
2377
@support.requires_resource('network')
class NetworkedTests(unittest.TestCase):
    """SSL tests that connect to real hosts on the internet."""

    def test_timeout_connect_ex(self):
        # Issue #12065: when connect_ex() times out on an SSL socket it must
        # return the original errno, as a plain socket would.
        with socket_helper.transient_internet(REMOTE_HOST):
            sslsock = test_wrap_socket(socket.socket(socket.AF_INET),
                                       cert_reqs=ssl.CERT_REQUIRED,
                                       do_handshake_on_connect=False)
            self.addCleanup(sslsock.close)
            # A timeout this tiny should stop the connect from finishing.
            sslsock.settimeout(0.0000001)
            rc = sslsock.connect_ex((REMOTE_HOST, 443))
            skip_reasons = {
                0: "REMOTE_HOST responded too quickly",
                errno.ENETUNREACH: "Network unreachable.",
            }
            if rc in skip_reasons:
                self.skipTest(skip_reasons[rc])
            self.assertIn(rc, (errno.EAGAIN, errno.EWOULDBLOCK))

    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'Needs IPv6')
    def test_get_server_certificate_ipv6(self):
        host, port = 'ipv6.google.com', 443
        with socket_helper.transient_internet(host):
            _test_get_server_certificate(self, host, port)
            _test_get_server_certificate_fail(self, host, port)
2402
2403
def _test_get_server_certificate(test, host, port, cert=None):
    """Fetch the certificate of (host, port) twice and check it is non-empty.

    The first fetch passes no CA bundle; the second passes *cert* as
    ``ca_certs``.  Failures are reported through the TestCase *test*.
    """
    address = (host, port)
    for ca_certs in (None, cert):
        pem = ssl.get_server_certificate(address, ca_certs=ca_certs)
        if not pem:
            test.fail("No server certificate on %s:%s!" % (host, port))
    if support.verbose:
        sys.stdout.write("\nVerified certificate for %s:%s is\n%s\n"
                         % (host, port, pem))
2414
def _test_get_server_certificate_fail(test, host, port):
    """Check that fetching (host, port) with CERTFILE as CA bundle fails.

    CERTFILE is expected not to validate the server's certificate, so the
    fetch should raise ssl.SSLError.  Getting a certificate back is
    reported as a failure on the TestCase *test*.
    """
    try:
        pem = ssl.get_server_certificate((host, port), ca_certs=CERTFILE)
    except ssl.SSLError as exc:
        # This is the expected outcome.
        if support.verbose:
            sys.stdout.write("%s\n" % exc)
        return
    test.fail("Got server certificate %s for %s:%s!" % (pem, host, port))
2424
2425
2426from test.ssl_servers import make_https_server
2427
class ThreadedEchoServer(threading.Thread):
    """Background echo server used by the SSL tests.

    The server listens on ``self.port`` (chosen by socket_helper.bind_port)
    and serves one client at a time.  Each accepted connection is handled
    by a ConnectionHandler thread, which the server joins before accepting
    the next connection.  Received lines are echoed back lower-cased; a few
    lines are interpreted as commands (see ConnectionHandler.run).

    Use it as a context manager: __enter__ starts the thread and waits
    until the socket is listening; __exit__ stops and joins it.
    """

    class ConnectionHandler(threading.Thread):

        """A mildly complicated class, because we want it to work both
        with and without the SSL wrapper around the socket connection, so
        that we can test the STARTTLS functionality."""

        def __init__(self, server, connsock, addr):
            self.server = server
            self.running = False
            self.sock = connsock
            self.addr = addr
            self.sock.setblocking(True)
            self.sslconn = None
            threading.Thread.__init__(self)
            self.daemon = True

        def wrap_conn(self):
            """Wrap the connection in TLS using the server's context.

            Returns True if the handshake succeeded and False otherwise.
            Failures are recorded as strings in server.conn_errors.
            """
            try:
                self.sslconn = self.server.context.wrap_socket(
                    self.sock, server_side=True)
                self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol())
            except (ConnectionResetError, BrokenPipeError, ConnectionAbortedError) as e:
                # We treat ConnectionResetError as though it were an
                # SSLError - OpenSSL on Ubuntu abruptly closes the
                # connection when asked to use an unsupported protocol.
                #
                # BrokenPipeError is raised in TLS 1.3 mode, when OpenSSL
                # tries to send session tickets after handshake.
                # https://github.com/openssl/openssl/issues/6342
                #
                # ConnectionAbortedError is raised in TLS 1.3 mode, when OpenSSL
                # tries to send session tickets after handshake when using WinSock.
                self.server.conn_errors.append(str(e))
                if self.server.chatty:
                    handle_error("\n server:  bad connection attempt from " + repr(self.addr) + ":\n")
                self.running = False
                self.close()
                return False
            except (ssl.SSLError, OSError) as e:
                # OSError may occur with wrong protocols, e.g. both
                # sides use PROTOCOL_TLS_SERVER.
                #
                # XXX Various errors can have happened here, for example
                # a mismatching protocol version, an invalid certificate,
                # or a low-level bug. This should be made more discriminating.
                #
                # bpo-31323: Store the exception as string to prevent
                # a reference leak: server -> conn_errors -> exception
                # -> traceback -> self (ConnectionHandler) -> server
                self.server.conn_errors.append(str(e))
                if self.server.chatty:
                    handle_error("\n server:  bad connection attempt from " + repr(self.addr) + ":\n")

                # bpo-44229, bpo-43855, bpo-44237, and bpo-33450:
                # Ignore spurious EPROTOTYPE returned by write() on macOS.
                # See also http://erickt.github.io/blog/2014/11/19/adventures-in-debugging-a-potential-osx-kernel-bug/
                # Only EPROTOTYPE *on macOS* keeps the server running; every
                # other failure shuts it down, on every platform.
                spurious_macos_eprototype = (e.errno == errno.EPROTOTYPE
                                             and sys.platform == "darwin")
                if not spurious_macos_eprototype:
                    self.running = False
                    self.server.stop()
                    self.close()
                return False
            else:
                self.server.shared_ciphers.append(self.sslconn.shared_ciphers())
                if self.server.context.verify_mode == ssl.CERT_REQUIRED:
                    cert = self.sslconn.getpeercert()
                    if support.verbose and self.server.chatty:
                        sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n")
                    cert_binary = self.sslconn.getpeercert(True)
                    if support.verbose and self.server.chatty:
                        if cert_binary is None:
                            sys.stdout.write(" client did not provide a cert\n")
                        else:
                            sys.stdout.write(f" cert binary is {len(cert_binary)}b\n")
                cipher = self.sslconn.cipher()
                if support.verbose and self.server.chatty:
                    sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
                return True

        def read(self):
            # Read from the TLS layer when one is active, else the raw socket.
            if self.sslconn:
                return self.sslconn.read()
            else:
                return self.sock.recv(1024)

        def write(self, bytes):
            # Write through the TLS layer when one is active, else raw.
            if self.sslconn:
                return self.sslconn.write(bytes)
            else:
                return self.sock.send(bytes)

        def close(self):
            # Close whichever socket object currently owns the connection.
            if self.sslconn:
                self.sslconn.close()
            else:
                self.sock.close()

        def run(self):
            """Serve the connection until EOF, 'over', or an error.

            Lines are echoed back lower-cased, except for these commands:
            'over' (close), 'STARTTLS' / 'ENDTLS' (only when the server was
            created with starttls_server=True), 'CB tls-unique', 'PHA',
            'HASCERT', 'GETCERT', 'VERIFIEDCHAIN' and 'UNVERIFIEDCHAIN'.
            """
            self.running = True
            if not self.server.starttls_server:
                if not self.wrap_conn():
                    return
            while self.running:
                try:
                    msg = self.read()
                    stripped = msg.strip()
                    if not stripped:
                        # eof, so quit this handler
                        self.running = False
                        # In STARTTLS mode the connection may be plain text
                        # (before STARTTLS or after ENDTLS); there is then
                        # no TLS layer to unwrap.
                        if self.sslconn is not None:
                            try:
                                self.sock = self.sslconn.unwrap()
                            except OSError:
                                # Many tests shut the TCP connection down
                                # without an SSL shutdown. This causes
                                # unwrap() to raise OSError with errno=0!
                                pass
                            else:
                                self.sslconn = None
                        self.close()
                    elif stripped == b'over':
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: client closed connection\n")
                        self.close()
                        return
                    elif (self.server.starttls_server and
                          stripped == b'STARTTLS'):
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: read STARTTLS from client, sending OK...\n")
                        self.write(b"OK\n")
                        if not self.wrap_conn():
                            return
                    elif (self.server.starttls_server and self.sslconn
                          and stripped == b'ENDTLS'):
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: read ENDTLS from client, sending OK...\n")
                        self.write(b"OK\n")
                        self.sock = self.sslconn.unwrap()
                        self.sslconn = None
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: connection is now unencrypted...\n")
                    elif stripped == b'CB tls-unique':
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n")
                        data = self.sslconn.get_channel_binding("tls-unique")
                        self.write(repr(data).encode("us-ascii") + b"\n")
                    elif stripped == b'PHA':
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: initiating post handshake auth\n")
                        try:
                            self.sslconn.verify_client_post_handshake()
                        except ssl.SSLError as e:
                            self.write(repr(e).encode("us-ascii") + b"\n")
                        else:
                            self.write(b"OK\n")
                    elif stripped == b'HASCERT':
                        if self.sslconn.getpeercert() is not None:
                            self.write(b'TRUE\n')
                        else:
                            self.write(b'FALSE\n')
                    elif stripped == b'GETCERT':
                        cert = self.sslconn.getpeercert()
                        self.write(repr(cert).encode("us-ascii") + b"\n")
                    elif stripped == b'VERIFIEDCHAIN':
                        certs = self.sslconn._sslobj.get_verified_chain()
                        self.write(len(certs).to_bytes(1, "big") + b"\n")
                    elif stripped == b'UNVERIFIEDCHAIN':
                        certs = self.sslconn._sslobj.get_unverified_chain()
                        self.write(len(certs).to_bytes(1, "big") + b"\n")
                    else:
                        if (support.verbose and
                            self.server.connectionchatty):
                            ctype = (self.sslconn and "encrypted") or "unencrypted"
                            sys.stdout.write(" server: read %r (%s), sending back %r (%s)...\n"
                                             % (msg, ctype, msg.lower(), ctype))
                        self.write(msg.lower())
                except OSError as e:
                    # handles SSLError and socket errors
                    if self.server.chatty and support.verbose:
                        if isinstance(e, ConnectionError):
                            # OpenSSL 1.1.1 sometimes raises
                            # ConnectionResetError when connection is not
                            # shut down gracefully.
                            print(
                                f" Connection reset by peer: {self.addr}"
                            )
                        else:
                            handle_error("Test server failure:\n")
                    try:
                        self.write(b"ERROR\n")
                    except OSError:
                        pass
                    self.close()
                    self.running = False

                    # normally, we'd just stop here, but for the test
                    # harness, we want to stop the server
                    self.server.stop()

    def __init__(self, certificate=None, ssl_version=None,
                 certreqs=None, cacerts=None,
                 chatty=True, connectionchatty=False, starttls_server=False,
                 alpn_protocols=None,
                 ciphers=None, context=None):
        # An explicit *context* wins; otherwise build a server context from
        # the individual settings (PROTOCOL_TLS_SERVER and CERT_NONE unless
        # overridden).
        if context:
            self.context = context
        else:
            self.context = ssl.SSLContext(ssl_version
                                          if ssl_version is not None
                                          else ssl.PROTOCOL_TLS_SERVER)
            self.context.verify_mode = (certreqs if certreqs is not None
                                        else ssl.CERT_NONE)
            if cacerts:
                self.context.load_verify_locations(cacerts)
            if certificate:
                self.context.load_cert_chain(certificate)
            if alpn_protocols:
                self.context.set_alpn_protocols(alpn_protocols)
            if ciphers:
                self.context.set_ciphers(ciphers)
        self.chatty = chatty
        self.connectionchatty = connectionchatty
        self.starttls_server = starttls_server
        self.sock = socket.socket()
        self.port = socket_helper.bind_port(self.sock)
        self.flag = None
        self.active = False
        # Per-connection results collected by ConnectionHandler.wrap_conn().
        self.selected_alpn_protocols = []
        self.shared_ciphers = []
        self.conn_errors = []
        threading.Thread.__init__(self)
        self.daemon = True

    def __enter__(self):
        self.start(threading.Event())
        self.flag.wait()
        return self

    def __exit__(self, *args):
        self.stop()
        self.join()

    def start(self, flag=None):
        # *flag*, if given, is set by run() once the socket is listening.
        self.flag = flag
        threading.Thread.start(self)

    def run(self):
        # The 1-second accept() timeout lets the loop notice stop().
        self.sock.settimeout(1.0)
        self.sock.listen(5)
        self.active = True
        if self.flag:
            # signal an event
            self.flag.set()
        while self.active:
            try:
                newconn, connaddr = self.sock.accept()
                if support.verbose and self.chatty:
                    sys.stdout.write(' server:  new connection from '
                                     + repr(connaddr) + '\n')
                handler = self.ConnectionHandler(self, newconn, connaddr)
                handler.start()
                # Serve connections one at a time.
                handler.join()
            except TimeoutError as e:
                if support.verbose:
                    sys.stdout.write(f' connection timeout {e!r}\n')
            except KeyboardInterrupt:
                self.stop()
            except BaseException as e:
                if support.verbose and self.chatty:
                    sys.stdout.write(
                        ' connection handling failed: ' + repr(e) + '\n')

        self.close()

    def close(self):
        if self.sock is not None:
            self.sock.close()
            self.sock = None

    def stop(self):
        # run() exits after its current accept() returns or times out.
        self.active = False
2709
class AsyncoreEchoServer(threading.Thread):
    """Echo server driven by the asyncore event loop on its own thread.

    The listening dispatcher (EchoServer) is created in __init__; run()
    then pumps asyncore.loop() until stop() is called.  Each accepted
    connection is wrapped in TLS on the server side and has received data
    echoed back lower-cased.  Use as a context manager.
    """

    # this one's based on asyncore.dispatcher

    class EchoServer (asyncore.dispatcher):
        """Listening dispatcher that creates a ConnectionHandler per client."""

        class ConnectionHandler(asyncore.dispatcher_with_send):
            """Per-connection dispatcher: non-blocking TLS handshake, then echo."""

            def __init__(self, conn, certfile):
                # Wrap the accepted socket as a TLS server using *certfile*.
                # The handshake is deferred (do_handshake_on_connect=False)
                # so the event loop can drive it step by step.
                self.socket = test_wrap_socket(conn, server_side=True,
                                              certfile=certfile,
                                              do_handshake_on_connect=False)
                # Registers this handler (and its socket) with asyncore's
                # socket map.
                asyncore.dispatcher_with_send.__init__(self, self.socket)
                self._ssl_accepting = True
                # Try the handshake right away; if it needs more I/O it is
                # resumed from handle_read().
                self._do_ssl_handshake()

            def readable(self):
                # Process bytes the SSL layer has already decrypted and
                # buffered (counted by pending()).  select() on the raw
                # socket cannot see them.  Always report readable.
                if isinstance(self.socket, ssl.SSLSocket):
                    while self.socket.pending() > 0:
                        self.handle_read_event()
                return True

            def _do_ssl_handshake(self):
                """Advance the server-side TLS handshake by one step.

                Clears ``_ssl_accepting`` once the handshake completes.
                Returns early when OpenSSL wants more I/O (retried on the
                next read event), closes the connection on SSLEOFError or
                ECONNABORTED, and re-raises any other SSLError.
                """
                try:
                    self.socket.do_handshake()
                except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
                    return
                except ssl.SSLEOFError:
                    return self.handle_close()
                except ssl.SSLError:
                    raise
                except OSError as err:
                    # NOTE(review): OSErrors other than ECONNABORTED are
                    # silently ignored here and _ssl_accepting stays True,
                    # so the handshake is retried on the next read event.
                    if err.args[0] == errno.ECONNABORTED:
                        return self.handle_close()
                else:
                    self._ssl_accepting = False

            def handle_read(self):
                # Until the handshake has completed, read events only drive
                # the handshake forward.
                if self._ssl_accepting:
                    self._do_ssl_handshake()
                else:
                    data = self.recv(1024)
                    if support.verbose:
                        sys.stdout.write(" server:  read %s from client\n" % repr(data))
                    if not data:
                        # Empty read: treat the connection as finished.
                        self.close()
                    else:
                        self.send(data.lower())

            def handle_close(self):
                self.close()
                if support.verbose:
                    sys.stdout.write(" server:  closed connection %s\n" % self.socket)

            def handle_error(self):
                # The bare raise re-raises the exception currently being
                # handled.  It relies on asyncore calling handle_error()
                # from inside an except clause.
                raise

        def __init__(self, certfile):
            # Certificate file passed to every ConnectionHandler.
            self.certfile = certfile
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            # socket_helper.bind_port() picks the port (host '') and
            # returns it.
            self.port = socket_helper.bind_port(sock, '')
            asyncore.dispatcher.__init__(self, sock)
            self.listen(5)

        def handle_accepted(self, sock_obj, addr):
            # The handler is not stored here; it registers itself with
            # asyncore's socket map, which is cleaned up in
            # AsyncoreEchoServer.__exit__().
            if support.verbose:
                sys.stdout.write(" server:  new connection from %s:%s\n" %addr)
            self.ConnectionHandler(sock_obj, self.certfile)

        def handle_error(self):
            # Re-raise the exception asyncore is currently handling.
            raise

    def __init__(self, certfile):
        # *certfile* is used for every accepted TLS connection.
        self.flag = None
        self.active = False
        self.server = self.EchoServer(certfile)
        self.port = self.server.port
        threading.Thread.__init__(self)
        self.daemon = True

    def __str__(self):
        return "<%s %s>" % (self.__class__.__name__, self.server)

    def __enter__(self):
        # Start the thread and wait until run() signals it is active.
        self.start(threading.Event())
        self.flag.wait()
        return self

    def __exit__(self, *args):
        if support.verbose:
            sys.stdout.write(" cleanup: stopping server.\n")
        self.stop()
        if support.verbose:
            sys.stdout.write(" cleanup: joining server thread.\n")
        self.join()
        if support.verbose:
            sys.stdout.write(" cleanup: successfully joined.\n")
        # make sure that ConnectionHandler is removed from socket_map
        asyncore.close_all(ignore_all=True)

    def start (self, flag=None):
        # *flag*, if given, is set by run() once the server is active.
        self.flag = flag
        threading.Thread.start(self)

    def run(self):
        self.active = True
        if self.flag:
            self.flag.set()
        while self.active:
            try:
                # NOTE: without a count argument, asyncore.loop() keeps
                # polling (1 s timeout per poll) while its socket map is
                # non-empty.  self.active is re-checked only when it returns.
                asyncore.loop(1)
            except:
                # Best effort: any exception escaping the loop (including
                # ones re-raised by the handle_error() methods above) is
                # ignored, and polling resumes.
                pass

    def stop(self):
        # Clear the flag checked by run() and close the listening
        # dispatcher.
        self.active = False
        self.server.close()
2827
def server_params_test(client_context, server_context, indata=b"FOO\n",
                       chatty=True, connectionchatty=False, sni_name=None,
                       session=None):
    """Run an echo round trip between two SSL contexts and collect stats.

    Starts a ThreadedEchoServer with *server_context*.  A client socket
    wrapped with *client_context* (server_hostname=*sni_name*, resuming
    *session* if given) sends *indata* as bytes, bytearray and memoryview.
    Each reply must equal ``indata.lower()``; otherwise AssertionError is
    raised.  *connectionchatty* only controls client-side verbose output.

    Returns a dict holding the client socket's compression, cipher,
    peercert, client_alpn_protocol, version, session_reused and session,
    plus the server's recorded ALPN protocols and shared ciphers.
    """
    def log(message):
        if connectionchatty and support.verbose:
            sys.stdout.write(message)

    expected = indata.lower()
    server = ThreadedEchoServer(context=server_context,
                                chatty=chatty,
                                connectionchatty=False)
    with server:
        with client_context.wrap_socket(socket.socket(),
                                        server_hostname=sni_name,
                                        session=session) as s:
            s.connect((HOST, server.port))
            payloads = (indata, bytearray(indata), memoryview(indata))
            for payload in payloads:
                log(" client:  sending %r...\n" % indata)
                s.write(payload)
                outdata = s.read()
                log(" client:  read %r\n" % outdata)
                if outdata != expected:
                    raise AssertionError(
                        "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
                        % (outdata[:20], len(outdata),
                           indata[:20].lower(), len(indata)))
            s.write(b"over\n")
            log(" client:  closing connection.\n")
            stats = {
                'compression': s.compression(),
                'cipher': s.cipher(),
                'peercert': s.getpeercert(),
                'client_alpn_protocol': s.selected_alpn_protocol(),
                'version': s.version(),
                'session_reused': s.session_reused,
                'session': s.session,
            }
            s.close()
        stats['server_alpn_protocols'] = server.selected_alpn_protocols
        stats['server_shared_ciphers'] = server.shared_ciphers
    return stats
2875
def try_protocol_combo(server_protocol, client_protocol, expect_success,
                       certsreqs=None, server_options=0, client_options=0):
    """
    Try to SSL-connect using *client_protocol* to *server_protocol*.
    If *expect_success* is true, assert that the connection succeeds,
    if it's false, assert that the connection fails.
    Also, if *expect_success* is a string, assert that it is the protocol
    version actually used by the connection.

    *certsreqs* (default ``ssl.CERT_NONE``) becomes the verify_mode of both
    contexts.  *server_options* and *client_options* are OR-ed into the
    respective context's options.  A refused connection is recognised
    either as ssl.SSLError or as an OSError with errno ECONNRESET; any
    other OSError is re-raised.
    """
    if certsreqs is None:
        certsreqs = ssl.CERT_NONE
    certtype = {
        ssl.CERT_NONE: "CERT_NONE",
        ssl.CERT_OPTIONAL: "CERT_OPTIONAL",
        ssl.CERT_REQUIRED: "CERT_REQUIRED",
    }[certsreqs]
    if support.verbose:
        # Combinations that are expected to fail are printed in braces.
        formatstr = " %s->%s %s\n" if expect_success else " {%s->%s} %s\n"
        sys.stdout.write(formatstr %
                         (ssl.get_protocol_name(client_protocol),
                          ssl.get_protocol_name(server_protocol),
                          certtype))

    with warnings_helper.check_warnings():
        # ignore Deprecation warnings
        client_context = ssl.SSLContext(client_protocol)
        client_context.options |= client_options
        server_context = ssl.SSLContext(server_protocol)
        server_context.options |= server_options

    min_version = PROTOCOL_TO_TLS_VERSION.get(client_protocol, None)
    if (min_version is not None
        # SSLContext.minimum_version is only available on recent OpenSSL
        # (setter added in OpenSSL 1.1.0, getter added in OpenSSL 1.1.1)
        and hasattr(server_context, 'minimum_version')
        and server_protocol == ssl.PROTOCOL_TLS
        and server_context.minimum_version > min_version
    ):
        # If OpenSSL configuration is strict and requires more recent TLS
        # version, we have to change the minimum to test old TLS versions.
        with warnings_helper.check_warnings():
            server_context.minimum_version = min_version

    # NOTE: we must enable "ALL" ciphers on the client, otherwise an
    # SSLv23 client will send an SSLv3 hello (rather than SSLv2)
    # starting from OpenSSL 1.0.0 (see issue #8322).
    if client_context.protocol == ssl.PROTOCOL_TLS:
        client_context.set_ciphers("ALL")

    seclevel_workaround(server_context, client_context)

    for ctx in (client_context, server_context):
        ctx.verify_mode = certsreqs
        ctx.load_cert_chain(SIGNED_CERTFILE)
        ctx.load_verify_locations(SIGNING_CA)
    try:
        stats = server_params_test(client_context, server_context,
                                   chatty=False, connectionchatty=False)
    # Protocol mismatch can result in either an SSLError, or a
    # "Connection reset by peer" error.
    except ssl.SSLError:
        if expect_success:
            raise
    except OSError as e:
        if expect_success or e.errno != errno.ECONNRESET:
            raise
    else:
        if not expect_success:
            raise AssertionError(
                "Client protocol %s succeeded with server protocol %s!"
                % (ssl.get_protocol_name(client_protocol),
                   ssl.get_protocol_name(server_protocol)))
        elif (expect_success is not True
              and expect_success != stats['version']):
            raise AssertionError("version mismatch: expected %r, got %r"
                                 % (expect_success, stats['version']))
2952
2953
2954class ThreadedTests(unittest.TestCase):
2955
2956    def test_echo(self):
2957        """Basic test of an SSL client connecting to a server"""
2958        if support.verbose:
2959            sys.stdout.write("\n")
2960
2961        client_context, server_context, hostname = testing_context()
2962
2963        with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_SERVER):
2964            server_params_test(client_context=client_context,
2965                               server_context=server_context,
2966                               chatty=True, connectionchatty=True,
2967                               sni_name=hostname)
2968
2969        client_context.check_hostname = False
2970        with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_CLIENT):
2971            with self.assertRaises(ssl.SSLError) as e:
2972                server_params_test(client_context=server_context,
2973                                   server_context=client_context,
2974                                   chatty=True, connectionchatty=True,
2975                                   sni_name=hostname)
2976            self.assertIn(
2977                'Cannot create a client socket with a PROTOCOL_TLS_SERVER context',
2978                str(e.exception)
2979            )
2980
2981        with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_SERVER):
2982            with self.assertRaises(ssl.SSLError) as e:
2983                server_params_test(client_context=server_context,
2984                                   server_context=server_context,
2985                                   chatty=True, connectionchatty=True)
2986            self.assertIn(
2987                'Cannot create a client socket with a PROTOCOL_TLS_SERVER context',
2988                str(e.exception)
2989            )
2990
2991        with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_CLIENT):
2992            with self.assertRaises(ssl.SSLError) as e:
2993                server_params_test(client_context=server_context,
2994                                   server_context=client_context,
2995                                   chatty=True, connectionchatty=True)
2996            self.assertIn(
2997                'Cannot create a client socket with a PROTOCOL_TLS_SERVER context',
2998                str(e.exception))
2999
3000    def test_getpeercert(self):
3001        if support.verbose:
3002            sys.stdout.write("\n")
3003
3004        client_context, server_context, hostname = testing_context()
3005        server = ThreadedEchoServer(context=server_context, chatty=False)
3006        with server:
3007            with client_context.wrap_socket(socket.socket(),
3008                                            do_handshake_on_connect=False,
3009                                            server_hostname=hostname) as s:
3010                s.connect((HOST, server.port))
3011                # getpeercert() raise ValueError while the handshake isn't
3012                # done.
3013                with self.assertRaises(ValueError):
3014                    s.getpeercert()
3015                s.do_handshake()
3016                cert = s.getpeercert()
3017                self.assertTrue(cert, "Can't get peer certificate.")
3018                cipher = s.cipher()
3019                if support.verbose:
3020                    sys.stdout.write(pprint.pformat(cert) + '\n')
3021                    sys.stdout.write("Connection cipher is " + str(cipher) + '.\n')
3022                if 'subject' not in cert:
3023                    self.fail("No subject field in certificate: %s." %
3024                              pprint.pformat(cert))
3025                if ((('organizationName', 'Python Software Foundation'),)
3026                    not in cert['subject']):
3027                    self.fail(
3028                        "Missing or invalid 'organizationName' field in certificate subject; "
3029                        "should be 'Python Software Foundation'.")
3030                self.assertIn('notBefore', cert)
3031                self.assertIn('notAfter', cert)
3032                before = ssl.cert_time_to_seconds(cert['notBefore'])
3033                after = ssl.cert_time_to_seconds(cert['notAfter'])
3034                self.assertLess(before, after)
3035
3036    def test_crl_check(self):
3037        if support.verbose:
3038            sys.stdout.write("\n")
3039
3040        client_context, server_context, hostname = testing_context()
3041
3042        tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0)
3043        self.assertEqual(client_context.verify_flags, ssl.VERIFY_DEFAULT | tf)
3044
3045        # VERIFY_DEFAULT should pass
3046        server = ThreadedEchoServer(context=server_context, chatty=True)
3047        with server:
3048            with client_context.wrap_socket(socket.socket(),
3049                                            server_hostname=hostname) as s:
3050                s.connect((HOST, server.port))
3051                cert = s.getpeercert()
3052                self.assertTrue(cert, "Can't get peer certificate.")
3053
3054        # VERIFY_CRL_CHECK_LEAF without a loaded CRL file fails
3055        client_context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF
3056
3057        server = ThreadedEchoServer(context=server_context, chatty=True)
3058        with server:
3059            with client_context.wrap_socket(socket.socket(),
3060                                            server_hostname=hostname) as s:
3061                with self.assertRaisesRegex(ssl.SSLError,
3062                                            "certificate verify failed"):
3063                    s.connect((HOST, server.port))
3064
3065        # now load a CRL file. The CRL file is signed by the CA.
3066        client_context.load_verify_locations(CRLFILE)
3067
3068        server = ThreadedEchoServer(context=server_context, chatty=True)
3069        with server:
3070            with client_context.wrap_socket(socket.socket(),
3071                                            server_hostname=hostname) as s:
3072                s.connect((HOST, server.port))
3073                cert = s.getpeercert()
3074                self.assertTrue(cert, "Can't get peer certificate.")
3075
3076    def test_check_hostname(self):
3077        if support.verbose:
3078            sys.stdout.write("\n")
3079
3080        client_context, server_context, hostname = testing_context()
3081
3082        # correct hostname should verify
3083        server = ThreadedEchoServer(context=server_context, chatty=True)
3084        with server:
3085            with client_context.wrap_socket(socket.socket(),
3086                                            server_hostname=hostname) as s:
3087                s.connect((HOST, server.port))
3088                cert = s.getpeercert()
3089                self.assertTrue(cert, "Can't get peer certificate.")
3090
3091        # incorrect hostname should raise an exception
3092        server = ThreadedEchoServer(context=server_context, chatty=True)
3093        with server:
3094            with client_context.wrap_socket(socket.socket(),
3095                                            server_hostname="invalid") as s:
3096                with self.assertRaisesRegex(
3097                        ssl.CertificateError,
3098                        "Hostname mismatch, certificate is not valid for 'invalid'."):
3099                    s.connect((HOST, server.port))
3100
3101        # missing server_hostname arg should cause an exception, too
3102        server = ThreadedEchoServer(context=server_context, chatty=True)
3103        with server:
3104            with socket.socket() as s:
3105                with self.assertRaisesRegex(ValueError,
3106                                            "check_hostname requires server_hostname"):
3107                    client_context.wrap_socket(s)
3108
3109    @unittest.skipUnless(
3110        ssl.HAS_NEVER_CHECK_COMMON_NAME, "test requires hostname_checks_common_name"
3111    )
3112    def test_hostname_checks_common_name(self):
3113        client_context, server_context, hostname = testing_context()
3114        assert client_context.hostname_checks_common_name
3115        client_context.hostname_checks_common_name = False
3116
3117        # default cert has a SAN
3118        server = ThreadedEchoServer(context=server_context, chatty=True)
3119        with server:
3120            with client_context.wrap_socket(socket.socket(),
3121                                            server_hostname=hostname) as s:
3122                s.connect((HOST, server.port))
3123
3124        client_context, server_context, hostname = testing_context(NOSANFILE)
3125        client_context.hostname_checks_common_name = False
3126        server = ThreadedEchoServer(context=server_context, chatty=True)
3127        with server:
3128            with client_context.wrap_socket(socket.socket(),
3129                                            server_hostname=hostname) as s:
3130                with self.assertRaises(ssl.SSLCertVerificationError):
3131                    s.connect((HOST, server.port))
3132
3133    def test_ecc_cert(self):
3134        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3135        client_context.load_verify_locations(SIGNING_CA)
3136        client_context.set_ciphers('ECDHE:ECDSA:!NULL:!aRSA')
3137        hostname = SIGNED_CERTFILE_ECC_HOSTNAME
3138
3139        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
3140        # load ECC cert
3141        server_context.load_cert_chain(SIGNED_CERTFILE_ECC)
3142
3143        # correct hostname should verify
3144        server = ThreadedEchoServer(context=server_context, chatty=True)
3145        with server:
3146            with client_context.wrap_socket(socket.socket(),
3147                                            server_hostname=hostname) as s:
3148                s.connect((HOST, server.port))
3149                cert = s.getpeercert()
3150                self.assertTrue(cert, "Can't get peer certificate.")
3151                cipher = s.cipher()[0].split('-')
3152                self.assertTrue(cipher[:2], ('ECDHE', 'ECDSA'))
3153
3154    def test_dual_rsa_ecc(self):
3155        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3156        client_context.load_verify_locations(SIGNING_CA)
3157        # TODO: fix TLSv1.3 once SSLContext can restrict signature
3158        #       algorithms.
3159        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3160        # only ECDSA certs
3161        client_context.set_ciphers('ECDHE:ECDSA:!NULL:!aRSA')
3162        hostname = SIGNED_CERTFILE_ECC_HOSTNAME
3163
3164        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
3165        # load ECC and RSA key/cert pairs
3166        server_context.load_cert_chain(SIGNED_CERTFILE_ECC)
3167        server_context.load_cert_chain(SIGNED_CERTFILE)
3168
3169        # correct hostname should verify
3170        server = ThreadedEchoServer(context=server_context, chatty=True)
3171        with server:
3172            with client_context.wrap_socket(socket.socket(),
3173                                            server_hostname=hostname) as s:
3174                s.connect((HOST, server.port))
3175                cert = s.getpeercert()
3176                self.assertTrue(cert, "Can't get peer certificate.")
3177                cipher = s.cipher()[0].split('-')
3178                self.assertTrue(cipher[:2], ('ECDHE', 'ECDSA'))
3179
3180    def test_check_hostname_idn(self):
3181        if support.verbose:
3182            sys.stdout.write("\n")
3183
3184        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
3185        server_context.load_cert_chain(IDNSANSFILE)
3186
3187        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3188        context.verify_mode = ssl.CERT_REQUIRED
3189        context.check_hostname = True
3190        context.load_verify_locations(SIGNING_CA)
3191
3192        # correct hostname should verify, when specified in several
3193        # different ways
3194        idn_hostnames = [
3195            ('könig.idn.pythontest.net',
3196             'xn--knig-5qa.idn.pythontest.net'),
3197            ('xn--knig-5qa.idn.pythontest.net',
3198             'xn--knig-5qa.idn.pythontest.net'),
3199            (b'xn--knig-5qa.idn.pythontest.net',
3200             'xn--knig-5qa.idn.pythontest.net'),
3201
3202            ('königsgäßchen.idna2003.pythontest.net',
3203             'xn--knigsgsschen-lcb0w.idna2003.pythontest.net'),
3204            ('xn--knigsgsschen-lcb0w.idna2003.pythontest.net',
3205             'xn--knigsgsschen-lcb0w.idna2003.pythontest.net'),
3206            (b'xn--knigsgsschen-lcb0w.idna2003.pythontest.net',
3207             'xn--knigsgsschen-lcb0w.idna2003.pythontest.net'),
3208
3209            # ('königsgäßchen.idna2008.pythontest.net',
3210            #  'xn--knigsgchen-b4a3dun.idna2008.pythontest.net'),
3211            ('xn--knigsgchen-b4a3dun.idna2008.pythontest.net',
3212             'xn--knigsgchen-b4a3dun.idna2008.pythontest.net'),
3213            (b'xn--knigsgchen-b4a3dun.idna2008.pythontest.net',
3214             'xn--knigsgchen-b4a3dun.idna2008.pythontest.net'),
3215
3216        ]
3217        for server_hostname, expected_hostname in idn_hostnames:
3218            server = ThreadedEchoServer(context=server_context, chatty=True)
3219            with server:
3220                with context.wrap_socket(socket.socket(),
3221                                         server_hostname=server_hostname) as s:
3222                    self.assertEqual(s.server_hostname, expected_hostname)
3223                    s.connect((HOST, server.port))
3224                    cert = s.getpeercert()
3225                    self.assertEqual(s.server_hostname, expected_hostname)
3226                    self.assertTrue(cert, "Can't get peer certificate.")
3227
3228        # incorrect hostname should raise an exception
3229        server = ThreadedEchoServer(context=server_context, chatty=True)
3230        with server:
3231            with context.wrap_socket(socket.socket(),
3232                                     server_hostname="python.example.org") as s:
3233                with self.assertRaises(ssl.CertificateError):
3234                    s.connect((HOST, server.port))
3235
3236    def test_wrong_cert_tls12(self):
3237        """Connecting when the server rejects the client's certificate
3238
3239        Launch a server with CERT_REQUIRED, and check that trying to
3240        connect to it with a wrong client certificate fails.
3241        """
3242        client_context, server_context, hostname = testing_context()
3243        # load client cert that is not signed by trusted CA
3244        client_context.load_cert_chain(CERTFILE)
3245        # require TLS client authentication
3246        server_context.verify_mode = ssl.CERT_REQUIRED
3247        # TLS 1.3 has different handshake
3248        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3249
3250        server = ThreadedEchoServer(
3251            context=server_context, chatty=True, connectionchatty=True,
3252        )
3253
3254        with server, \
3255                client_context.wrap_socket(socket.socket(),
3256                                           server_hostname=hostname) as s:
3257            try:
3258                # Expect either an SSL error about the server rejecting
3259                # the connection, or a low-level connection reset (which
3260                # sometimes happens on Windows)
3261                s.connect((HOST, server.port))
3262            except ssl.SSLError as e:
3263                if support.verbose:
3264                    sys.stdout.write("\nSSLError is %r\n" % e)
3265            except OSError as e:
3266                if e.errno != errno.ECONNRESET:
3267                    raise
3268                if support.verbose:
3269                    sys.stdout.write("\nsocket.error is %r\n" % e)
3270            else:
3271                self.fail("Use of invalid cert should have failed!")
3272
3273    @requires_tls_version('TLSv1_3')
3274    def test_wrong_cert_tls13(self):
3275        client_context, server_context, hostname = testing_context()
3276        # load client cert that is not signed by trusted CA
3277        client_context.load_cert_chain(CERTFILE)
3278        server_context.verify_mode = ssl.CERT_REQUIRED
3279        server_context.minimum_version = ssl.TLSVersion.TLSv1_3
3280        client_context.minimum_version = ssl.TLSVersion.TLSv1_3
3281
3282        server = ThreadedEchoServer(
3283            context=server_context, chatty=True, connectionchatty=True,
3284        )
3285        with server, \
3286             client_context.wrap_socket(socket.socket(),
3287                                        server_hostname=hostname,
3288                                        suppress_ragged_eofs=False) as s:
3289            s.connect((HOST, server.port))
3290            with self.assertRaisesRegex(
3291                ssl.SSLError,
3292                'alert unknown ca|EOF occurred'
3293            ):
3294                # TLS 1.3 perform client cert exchange after handshake
3295                s.write(b'data')
3296                s.read(1000)
3297                s.write(b'should have failed already')
3298                s.read(1000)
3299
3300    def test_rude_shutdown(self):
3301        """A brutal shutdown of an SSL server should raise an OSError
3302        in the client when attempting handshake.
3303        """
3304        listener_ready = threading.Event()
3305        listener_gone = threading.Event()
3306
3307        s = socket.socket()
3308        port = socket_helper.bind_port(s, HOST)
3309
3310        # `listener` runs in a thread.  It sits in an accept() until
3311        # the main thread connects.  Then it rudely closes the socket,
3312        # and sets Event `listener_gone` to let the main thread know
3313        # the socket is gone.
3314        def listener():
3315            s.listen()
3316            listener_ready.set()
3317            newsock, addr = s.accept()
3318            newsock.close()
3319            s.close()
3320            listener_gone.set()
3321
3322        def connector():
3323            listener_ready.wait()
3324            with socket.socket() as c:
3325                c.connect((HOST, port))
3326                listener_gone.wait()
3327                try:
3328                    ssl_sock = test_wrap_socket(c)
3329                except OSError:
3330                    pass
3331                else:
3332                    self.fail('connecting to closed SSL socket should have failed')
3333
3334        t = threading.Thread(target=listener)
3335        t.start()
3336        try:
3337            connector()
3338        finally:
3339            t.join()
3340
3341    def test_ssl_cert_verify_error(self):
3342        if support.verbose:
3343            sys.stdout.write("\n")
3344
3345        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
3346        server_context.load_cert_chain(SIGNED_CERTFILE)
3347
3348        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3349
3350        server = ThreadedEchoServer(context=server_context, chatty=True)
3351        with server:
3352            with context.wrap_socket(socket.socket(),
3353                                     server_hostname=SIGNED_CERTFILE_HOSTNAME) as s:
3354                try:
3355                    s.connect((HOST, server.port))
3356                except ssl.SSLError as e:
3357                    msg = 'unable to get local issuer certificate'
3358                    self.assertIsInstance(e, ssl.SSLCertVerificationError)
3359                    self.assertEqual(e.verify_code, 20)
3360                    self.assertEqual(e.verify_message, msg)
3361                    self.assertIn(msg, repr(e))
3362                    self.assertIn('certificate verify failed', repr(e))
3363
3364    @requires_tls_version('SSLv2')
3365    def test_protocol_sslv2(self):
3366        """Connecting to an SSLv2 server with various client options"""
3367        if support.verbose:
3368            sys.stdout.write("\n")
3369        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True)
3370        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL)
3371        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED)
3372        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False)
3373        if has_tls_version('SSLv3'):
3374            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False)
3375        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False)
3376        # SSLv23 client with specific SSL options
3377        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False,
3378                           client_options=ssl.OP_NO_SSLv3)
3379        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False,
3380                           client_options=ssl.OP_NO_TLSv1)
3381
3382    def test_PROTOCOL_TLS(self):
3383        """Connecting to an SSLv23 server with various client options"""
3384        if support.verbose:
3385            sys.stdout.write("\n")
3386        if has_tls_version('SSLv2'):
3387            try:
3388                try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv2, True)
3389            except OSError as x:
3390                # this fails on some older versions of OpenSSL (0.9.7l, for instance)
3391                if support.verbose:
3392                    sys.stdout.write(
3393                        " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n"
3394                        % str(x))
3395        if has_tls_version('SSLv3'):
3396            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False)
3397        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True)
3398        if has_tls_version('TLSv1'):
3399            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, 'TLSv1')
3400
3401        if has_tls_version('SSLv3'):
3402            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False, ssl.CERT_OPTIONAL)
3403        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True, ssl.CERT_OPTIONAL)
3404        if has_tls_version('TLSv1'):
3405            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
3406
3407        if has_tls_version('SSLv3'):
3408            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False, ssl.CERT_REQUIRED)
3409        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True, ssl.CERT_REQUIRED)
3410        if has_tls_version('TLSv1'):
3411            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
3412
3413        # Server with specific SSL options
3414        if has_tls_version('SSLv3'):
3415            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False,
3416                           server_options=ssl.OP_NO_SSLv3)
3417        # Will choose TLSv1
3418        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True,
3419                           server_options=ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
3420        if has_tls_version('TLSv1'):
3421            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, False,
3422                               server_options=ssl.OP_NO_TLSv1)
3423
3424    @requires_tls_version('SSLv3')
3425    def test_protocol_sslv3(self):
3426        """Connecting to an SSLv3 server with various client options"""
3427        if support.verbose:
3428            sys.stdout.write("\n")
3429        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3')
3430        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL)
3431        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED)
3432        if has_tls_version('SSLv2'):
3433            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False)
3434        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLS, False,
3435                           client_options=ssl.OP_NO_SSLv3)
3436        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False)
3437
3438    @requires_tls_version('TLSv1')
3439    def test_protocol_tlsv1(self):
3440        """Connecting to a TLSv1 server with various client options"""
3441        if support.verbose:
3442            sys.stdout.write("\n")
3443        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1')
3444        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
3445        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
3446        if has_tls_version('SSLv2'):
3447            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False)
3448        if has_tls_version('SSLv3'):
3449            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False)
3450        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLS, False,
3451                           client_options=ssl.OP_NO_TLSv1)
3452
3453    @requires_tls_version('TLSv1_1')
3454    def test_protocol_tlsv1_1(self):
3455        """Connecting to a TLSv1.1 server with various client options.
3456           Testing against older TLS versions."""
3457        if support.verbose:
3458            sys.stdout.write("\n")
3459        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
3460        if has_tls_version('SSLv2'):
3461            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False)
3462        if has_tls_version('SSLv3'):
3463            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False)
3464        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLS, False,
3465                           client_options=ssl.OP_NO_TLSv1_1)
3466
3467        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
3468        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False)
3469        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False)
3470
3471    @requires_tls_version('TLSv1_2')
3472    def test_protocol_tlsv1_2(self):
3473        """Connecting to a TLSv1.2 server with various client options.
3474           Testing against older TLS versions."""
3475        if support.verbose:
3476            sys.stdout.write("\n")
3477        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2',
3478                           server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,
3479                           client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,)
3480        if has_tls_version('SSLv2'):
3481            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv2, False)
3482        if has_tls_version('SSLv3'):
3483            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv3, False)
3484        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLS, False,
3485                           client_options=ssl.OP_NO_TLSv1_2)
3486
3487        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2')
3488        if has_tls_protocol(ssl.PROTOCOL_TLSv1):
3489            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False)
3490            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False)
3491        if has_tls_protocol(ssl.PROTOCOL_TLSv1_1):
3492            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False)
3493            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False)
3494
3495    def test_starttls(self):
3496        """Switching from clear text to encrypted and back again."""
3497        msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6")
3498
3499        server = ThreadedEchoServer(CERTFILE,
3500                                    starttls_server=True,
3501                                    chatty=True,
3502                                    connectionchatty=True)
3503        wrapped = False
3504        with server:
3505            s = socket.socket()
3506            s.setblocking(True)
3507            s.connect((HOST, server.port))
3508            if support.verbose:
3509                sys.stdout.write("\n")
3510            for indata in msgs:
3511                if support.verbose:
3512                    sys.stdout.write(
3513                        " client:  sending %r...\n" % indata)
3514                if wrapped:
3515                    conn.write(indata)
3516                    outdata = conn.read()
3517                else:
3518                    s.send(indata)
3519                    outdata = s.recv(1024)
3520                msg = outdata.strip().lower()
3521                if indata == b"STARTTLS" and msg.startswith(b"ok"):
3522                    # STARTTLS ok, switch to secure mode
3523                    if support.verbose:
3524                        sys.stdout.write(
3525                            " client:  read %r from server, starting TLS...\n"
3526                            % msg)
3527                    conn = test_wrap_socket(s)
3528                    wrapped = True
3529                elif indata == b"ENDTLS" and msg.startswith(b"ok"):
3530                    # ENDTLS ok, switch back to clear text
3531                    if support.verbose:
3532                        sys.stdout.write(
3533                            " client:  read %r from server, ending TLS...\n"
3534                            % msg)
3535                    s = conn.unwrap()
3536                    wrapped = False
3537                else:
3538                    if support.verbose:
3539                        sys.stdout.write(
3540                            " client:  read %r from server\n" % msg)
3541            if support.verbose:
3542                sys.stdout.write(" client:  closing connection.\n")
3543            if wrapped:
3544                conn.write(b"over\n")
3545            else:
3546                s.send(b"over\n")
3547            if wrapped:
3548                conn.close()
3549            else:
3550                s.close()
3551
3552    def test_socketserver(self):
3553        """Using socketserver to create and manage SSL connections."""
3554        server = make_https_server(self, certfile=SIGNED_CERTFILE)
3555        # try to connect
3556        if support.verbose:
3557            sys.stdout.write('\n')
3558        with open(CERTFILE, 'rb') as f:
3559            d1 = f.read()
3560        d2 = ''
3561        # now fetch the same data from the HTTPS server
3562        url = 'https://localhost:%d/%s' % (
3563            server.port, os.path.split(CERTFILE)[1])
3564        context = ssl.create_default_context(cafile=SIGNING_CA)
3565        f = urllib.request.urlopen(url, context=context)
3566        try:
3567            dlen = f.info().get("content-length")
3568            if dlen and (int(dlen) > 0):
3569                d2 = f.read(int(dlen))
3570                if support.verbose:
3571                    sys.stdout.write(
3572                        " client: read %d bytes from remote server '%s'\n"
3573                        % (len(d2), server))
3574        finally:
3575            f.close()
3576        self.assertEqual(d1, d2)
3577
3578    def test_asyncore_server(self):
3579        """Check the example asyncore integration."""
3580        if support.verbose:
3581            sys.stdout.write("\n")
3582
3583        indata = b"FOO\n"
3584        server = AsyncoreEchoServer(CERTFILE)
3585        with server:
3586            s = test_wrap_socket(socket.socket())
3587            s.connect(('127.0.0.1', server.port))
3588            if support.verbose:
3589                sys.stdout.write(
3590                    " client:  sending %r...\n" % indata)
3591            s.write(indata)
3592            outdata = s.read()
3593            if support.verbose:
3594                sys.stdout.write(" client:  read %r\n" % outdata)
3595            if outdata != indata.lower():
3596                self.fail(
3597                    "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
3598                    % (outdata[:20], len(outdata),
3599                       indata[:20].lower(), len(indata)))
3600            s.write(b"over\n")
3601            if support.verbose:
3602                sys.stdout.write(" client:  closing connection.\n")
3603            s.close()
3604            if support.verbose:
3605                sys.stdout.write(" client:  connection closed.\n")
3606
3607    def test_recv_send(self):
3608        """Test recv(), send() and friends."""
3609        if support.verbose:
3610            sys.stdout.write("\n")
3611
3612        server = ThreadedEchoServer(CERTFILE,
3613                                    certreqs=ssl.CERT_NONE,
3614                                    ssl_version=ssl.PROTOCOL_TLS_SERVER,
3615                                    cacerts=CERTFILE,
3616                                    chatty=True,
3617                                    connectionchatty=False)
3618        with server:
3619            s = test_wrap_socket(socket.socket(),
3620                                server_side=False,
3621                                certfile=CERTFILE,
3622                                ca_certs=CERTFILE,
3623                                cert_reqs=ssl.CERT_NONE)
3624            s.connect((HOST, server.port))
3625            # helper methods for standardising recv* method signatures
3626            def _recv_into():
3627                b = bytearray(b"\0"*100)
3628                count = s.recv_into(b)
3629                return b[:count]
3630
3631            def _recvfrom_into():
3632                b = bytearray(b"\0"*100)
3633                count, addr = s.recvfrom_into(b)
3634                return b[:count]
3635
3636            # (name, method, expect success?, *args, return value func)
3637            send_methods = [
3638                ('send', s.send, True, [], len),
3639                ('sendto', s.sendto, False, ["some.address"], len),
3640                ('sendall', s.sendall, True, [], lambda x: None),
3641            ]
3642            # (name, method, whether to expect success, *args)
3643            recv_methods = [
3644                ('recv', s.recv, True, []),
3645                ('recvfrom', s.recvfrom, False, ["some.address"]),
3646                ('recv_into', _recv_into, True, []),
3647                ('recvfrom_into', _recvfrom_into, False, []),
3648            ]
3649            data_prefix = "PREFIX_"
3650
3651            for (meth_name, send_meth, expect_success, args,
3652                    ret_val_meth) in send_methods:
3653                indata = (data_prefix + meth_name).encode('ascii')
3654                try:
3655                    ret = send_meth(indata, *args)
3656                    msg = "sending with {}".format(meth_name)
3657                    self.assertEqual(ret, ret_val_meth(indata), msg=msg)
3658                    outdata = s.read()
3659                    if outdata != indata.lower():
3660                        self.fail(
3661                            "While sending with <<{name:s}>> bad data "
3662                            "<<{outdata:r}>> ({nout:d}) received; "
3663                            "expected <<{indata:r}>> ({nin:d})\n".format(
3664                                name=meth_name, outdata=outdata[:20],
3665                                nout=len(outdata),
3666                                indata=indata[:20], nin=len(indata)
3667                            )
3668                        )
3669                except ValueError as e:
3670                    if expect_success:
3671                        self.fail(
3672                            "Failed to send with method <<{name:s}>>; "
3673                            "expected to succeed.\n".format(name=meth_name)
3674                        )
3675                    if not str(e).startswith(meth_name):
3676                        self.fail(
3677                            "Method <<{name:s}>> failed with unexpected "
3678                            "exception message: {exp:s}\n".format(
3679                                name=meth_name, exp=e
3680                            )
3681                        )
3682
3683            for meth_name, recv_meth, expect_success, args in recv_methods:
3684                indata = (data_prefix + meth_name).encode('ascii')
3685                try:
3686                    s.send(indata)
3687                    outdata = recv_meth(*args)
3688                    if outdata != indata.lower():
3689                        self.fail(
3690                            "While receiving with <<{name:s}>> bad data "
3691                            "<<{outdata:r}>> ({nout:d}) received; "
3692                            "expected <<{indata:r}>> ({nin:d})\n".format(
3693                                name=meth_name, outdata=outdata[:20],
3694                                nout=len(outdata),
3695                                indata=indata[:20], nin=len(indata)
3696                            )
3697                        )
3698                except ValueError as e:
3699                    if expect_success:
3700                        self.fail(
3701                            "Failed to receive with method <<{name:s}>>; "
3702                            "expected to succeed.\n".format(name=meth_name)
3703                        )
3704                    if not str(e).startswith(meth_name):
3705                        self.fail(
3706                            "Method <<{name:s}>> failed with unexpected "
3707                            "exception message: {exp:s}\n".format(
3708                                name=meth_name, exp=e
3709                            )
3710                        )
3711                    # consume data
3712                    s.read()
3713
3714            # read(-1, buffer) is supported, even though read(-1) is not
3715            data = b"data"
3716            s.send(data)
3717            buffer = bytearray(len(data))
3718            self.assertEqual(s.read(-1, buffer), len(data))
3719            self.assertEqual(buffer, data)
3720
3721            # sendall accepts bytes-like objects
3722            if ctypes is not None:
3723                ubyte = ctypes.c_ubyte * len(data)
3724                byteslike = ubyte.from_buffer_copy(data)
3725                s.sendall(byteslike)
3726                self.assertEqual(s.read(), data)
3727
3728            # Make sure sendmsg et al are disallowed to avoid
3729            # inadvertent disclosure of data and/or corruption
3730            # of the encrypted data stream
3731            self.assertRaises(NotImplementedError, s.dup)
3732            self.assertRaises(NotImplementedError, s.sendmsg, [b"data"])
3733            self.assertRaises(NotImplementedError, s.recvmsg, 100)
3734            self.assertRaises(NotImplementedError,
3735                              s.recvmsg_into, [bytearray(100)])
3736            s.write(b"over\n")
3737
3738            self.assertRaises(ValueError, s.recv, -1)
3739            self.assertRaises(ValueError, s.read, -1)
3740
3741            s.close()
3742
3743    def test_recv_zero(self):
3744        server = ThreadedEchoServer(CERTFILE)
3745        self.enterContext(server)
3746        s = socket.create_connection((HOST, server.port))
3747        self.addCleanup(s.close)
3748        s = test_wrap_socket(s, suppress_ragged_eofs=False)
3749        self.addCleanup(s.close)
3750
3751        # recv/read(0) should return no data
3752        s.send(b"data")
3753        self.assertEqual(s.recv(0), b"")
3754        self.assertEqual(s.read(0), b"")
3755        self.assertEqual(s.read(), b"data")
3756
3757        # Should not block if the other end sends no data
3758        s.setblocking(False)
3759        self.assertEqual(s.recv(0), b"")
3760        self.assertEqual(s.recv_into(bytearray()), 0)
3761
3762    def test_nonblocking_send(self):
3763        server = ThreadedEchoServer(CERTFILE,
3764                                    certreqs=ssl.CERT_NONE,
3765                                    ssl_version=ssl.PROTOCOL_TLS_SERVER,
3766                                    cacerts=CERTFILE,
3767                                    chatty=True,
3768                                    connectionchatty=False)
3769        with server:
3770            s = test_wrap_socket(socket.socket(),
3771                                server_side=False,
3772                                certfile=CERTFILE,
3773                                ca_certs=CERTFILE,
3774                                cert_reqs=ssl.CERT_NONE)
3775            s.connect((HOST, server.port))
3776            s.setblocking(False)
3777
3778            # If we keep sending data, at some point the buffers
3779            # will be full and the call will block
3780            buf = bytearray(8192)
3781            def fill_buffer():
3782                while True:
3783                    s.send(buf)
3784            self.assertRaises((ssl.SSLWantWriteError,
3785                               ssl.SSLWantReadError), fill_buffer)
3786
3787            # Now read all the output and discard it
3788            s.setblocking(True)
3789            s.close()
3790
3791    def test_handshake_timeout(self):
3792        # Issue #5103: SSL handshake must respect the socket timeout
3793        server = socket.socket(socket.AF_INET)
3794        host = "127.0.0.1"
3795        port = socket_helper.bind_port(server)
3796        started = threading.Event()
3797        finish = False
3798
3799        def serve():
3800            server.listen()
3801            started.set()
3802            conns = []
3803            while not finish:
3804                r, w, e = select.select([server], [], [], 0.1)
3805                if server in r:
3806                    # Let the socket hang around rather than having
3807                    # it closed by garbage collection.
3808                    conns.append(server.accept()[0])
3809            for sock in conns:
3810                sock.close()
3811
3812        t = threading.Thread(target=serve)
3813        t.start()
3814        started.wait()
3815
3816        try:
3817            try:
3818                c = socket.socket(socket.AF_INET)
3819                c.settimeout(0.2)
3820                c.connect((host, port))
3821                # Will attempt handshake and time out
3822                self.assertRaisesRegex(TimeoutError, "timed out",
3823                                       test_wrap_socket, c)
3824            finally:
3825                c.close()
3826            try:
3827                c = socket.socket(socket.AF_INET)
3828                c = test_wrap_socket(c)
3829                c.settimeout(0.2)
3830                # Will attempt handshake and time out
3831                self.assertRaisesRegex(TimeoutError, "timed out",
3832                                       c.connect, (host, port))
3833            finally:
3834                c.close()
3835        finally:
3836            finish = True
3837            t.join()
3838            server.close()
3839
3840    def test_server_accept(self):
3841        # Issue #16357: accept() on a SSLSocket created through
3842        # SSLContext.wrap_socket().
3843        client_ctx, server_ctx, hostname = testing_context()
3844        server = socket.socket(socket.AF_INET)
3845        host = "127.0.0.1"
3846        port = socket_helper.bind_port(server)
3847        server = server_ctx.wrap_socket(server, server_side=True)
3848        self.assertTrue(server.server_side)
3849
3850        evt = threading.Event()
3851        remote = None
3852        peer = None
3853        def serve():
3854            nonlocal remote, peer
3855            server.listen()
3856            # Block on the accept and wait on the connection to close.
3857            evt.set()
3858            remote, peer = server.accept()
3859            remote.send(remote.recv(4))
3860
3861        t = threading.Thread(target=serve)
3862        t.start()
3863        # Client wait until server setup and perform a connect.
3864        evt.wait()
3865        client = client_ctx.wrap_socket(
3866            socket.socket(), server_hostname=hostname
3867        )
3868        client.connect((hostname, port))
3869        client.send(b'data')
3870        client.recv()
3871        client_addr = client.getsockname()
3872        client.close()
3873        t.join()
3874        remote.close()
3875        server.close()
3876        # Sanity checks.
3877        self.assertIsInstance(remote, ssl.SSLSocket)
3878        self.assertEqual(peer, client_addr)
3879
3880    def test_getpeercert_enotconn(self):
3881        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3882        context.check_hostname = False
3883        with context.wrap_socket(socket.socket()) as sock:
3884            with self.assertRaises(OSError) as cm:
3885                sock.getpeercert()
3886            self.assertEqual(cm.exception.errno, errno.ENOTCONN)
3887
3888    def test_do_handshake_enotconn(self):
3889        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3890        context.check_hostname = False
3891        with context.wrap_socket(socket.socket()) as sock:
3892            with self.assertRaises(OSError) as cm:
3893                sock.do_handshake()
3894            self.assertEqual(cm.exception.errno, errno.ENOTCONN)
3895
3896    def test_no_shared_ciphers(self):
3897        client_context, server_context, hostname = testing_context()
3898        # OpenSSL enables all TLS 1.3 ciphers, enforce TLS 1.2 for test
3899        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3900        # Force different suites on client and server
3901        client_context.set_ciphers("AES128")
3902        server_context.set_ciphers("AES256")
3903        with ThreadedEchoServer(context=server_context) as server:
3904            with client_context.wrap_socket(socket.socket(),
3905                                            server_hostname=hostname) as s:
3906                with self.assertRaises(OSError):
3907                    s.connect((HOST, server.port))
3908        self.assertIn("no shared cipher", server.conn_errors[0])
3909
3910    def test_version_basic(self):
3911        """
3912        Basic tests for SSLSocket.version().
3913        More tests are done in the test_protocol_*() methods.
3914        """
3915        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3916        context.check_hostname = False
3917        context.verify_mode = ssl.CERT_NONE
3918        with ThreadedEchoServer(CERTFILE,
3919                                ssl_version=ssl.PROTOCOL_TLS_SERVER,
3920                                chatty=False) as server:
3921            with context.wrap_socket(socket.socket()) as s:
3922                self.assertIs(s.version(), None)
3923                self.assertIs(s._sslobj, None)
3924                s.connect((HOST, server.port))
3925                self.assertEqual(s.version(), 'TLSv1.3')
3926            self.assertIs(s._sslobj, None)
3927            self.assertIs(s.version(), None)
3928
3929    @requires_tls_version('TLSv1_3')
3930    def test_tls1_3(self):
3931        client_context, server_context, hostname = testing_context()
3932        client_context.minimum_version = ssl.TLSVersion.TLSv1_3
3933        with ThreadedEchoServer(context=server_context) as server:
3934            with client_context.wrap_socket(socket.socket(),
3935                                            server_hostname=hostname) as s:
3936                s.connect((HOST, server.port))
3937                self.assertIn(s.cipher()[0], {
3938                    'TLS_AES_256_GCM_SHA384',
3939                    'TLS_CHACHA20_POLY1305_SHA256',
3940                    'TLS_AES_128_GCM_SHA256',
3941                })
3942                self.assertEqual(s.version(), 'TLSv1.3')
3943
3944    @requires_tls_version('TLSv1_2')
3945    @requires_tls_version('TLSv1')
3946    @ignore_deprecation
3947    def test_min_max_version_tlsv1_2(self):
3948        client_context, server_context, hostname = testing_context()
3949        # client TLSv1.0 to 1.2
3950        client_context.minimum_version = ssl.TLSVersion.TLSv1
3951        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3952        # server only TLSv1.2
3953        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
3954        server_context.maximum_version = ssl.TLSVersion.TLSv1_2
3955
3956        with ThreadedEchoServer(context=server_context) as server:
3957            with client_context.wrap_socket(socket.socket(),
3958                                            server_hostname=hostname) as s:
3959                s.connect((HOST, server.port))
3960                self.assertEqual(s.version(), 'TLSv1.2')
3961
3962    @requires_tls_version('TLSv1_1')
3963    @ignore_deprecation
3964    def test_min_max_version_tlsv1_1(self):
3965        client_context, server_context, hostname = testing_context()
3966        # client 1.0 to 1.2, server 1.0 to 1.1
3967        client_context.minimum_version = ssl.TLSVersion.TLSv1
3968        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3969        server_context.minimum_version = ssl.TLSVersion.TLSv1
3970        server_context.maximum_version = ssl.TLSVersion.TLSv1_1
3971        seclevel_workaround(client_context, server_context)
3972
3973        with ThreadedEchoServer(context=server_context) as server:
3974            with client_context.wrap_socket(socket.socket(),
3975                                            server_hostname=hostname) as s:
3976                s.connect((HOST, server.port))
3977                self.assertEqual(s.version(), 'TLSv1.1')
3978
3979    @requires_tls_version('TLSv1_2')
3980    @requires_tls_version('TLSv1')
3981    @ignore_deprecation
3982    def test_min_max_version_mismatch(self):
3983        client_context, server_context, hostname = testing_context()
3984        # client 1.0, server 1.2 (mismatch)
3985        server_context.maximum_version = ssl.TLSVersion.TLSv1_2
3986        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
3987        client_context.maximum_version = ssl.TLSVersion.TLSv1
3988        client_context.minimum_version = ssl.TLSVersion.TLSv1
3989        seclevel_workaround(client_context, server_context)
3990
3991        with ThreadedEchoServer(context=server_context) as server:
3992            with client_context.wrap_socket(socket.socket(),
3993                                            server_hostname=hostname) as s:
3994                with self.assertRaises(ssl.SSLError) as e:
3995                    s.connect((HOST, server.port))
3996                self.assertIn("alert", str(e.exception))
3997
3998    @requires_tls_version('SSLv3')
3999    def test_min_max_version_sslv3(self):
4000        client_context, server_context, hostname = testing_context()
4001        server_context.minimum_version = ssl.TLSVersion.SSLv3
4002        client_context.minimum_version = ssl.TLSVersion.SSLv3
4003        client_context.maximum_version = ssl.TLSVersion.SSLv3
4004        seclevel_workaround(client_context, server_context)
4005
4006        with ThreadedEchoServer(context=server_context) as server:
4007            with client_context.wrap_socket(socket.socket(),
4008                                            server_hostname=hostname) as s:
4009                s.connect((HOST, server.port))
4010                self.assertEqual(s.version(), 'SSLv3')
4011
4012    def test_default_ecdh_curve(self):
4013        # Issue #21015: elliptic curve-based Diffie Hellman key exchange
4014        # should be enabled by default on SSL contexts.
4015        client_context, server_context, hostname = testing_context()
4016        # TLSv1.3 defaults to PFS key agreement and no longer has KEA in
4017        # cipher name.
4018        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
4019        # Prior to OpenSSL 1.0.0, ECDH ciphers have to be enabled
4020        # explicitly using the 'ECCdraft' cipher alias.  Otherwise,
4021        # our default cipher list should prefer ECDH-based ciphers
4022        # automatically.
4023        with ThreadedEchoServer(context=server_context) as server:
4024            with client_context.wrap_socket(socket.socket(),
4025                                            server_hostname=hostname) as s:
4026                s.connect((HOST, server.port))
4027                self.assertIn("ECDH", s.cipher()[0])
4028
4029    @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
4030                         "'tls-unique' channel binding not available")
4031    def test_tls_unique_channel_binding(self):
4032        """Test tls-unique channel binding."""
4033        if support.verbose:
4034            sys.stdout.write("\n")
4035
4036        client_context, server_context, hostname = testing_context()
4037
4038        server = ThreadedEchoServer(context=server_context,
4039                                    chatty=True,
4040                                    connectionchatty=False)
4041
4042        with server:
4043            with client_context.wrap_socket(
4044                    socket.socket(),
4045                    server_hostname=hostname) as s:
4046                s.connect((HOST, server.port))
4047                # get the data
4048                cb_data = s.get_channel_binding("tls-unique")
4049                if support.verbose:
4050                    sys.stdout.write(
4051                        " got channel binding data: {0!r}\n".format(cb_data))
4052
4053                # check if it is sane
4054                self.assertIsNotNone(cb_data)
4055                if s.version() == 'TLSv1.3':
4056                    self.assertEqual(len(cb_data), 48)
4057                else:
4058                    self.assertEqual(len(cb_data), 12)  # True for TLSv1
4059
4060                # and compare with the peers version
4061                s.write(b"CB tls-unique\n")
4062                peer_data_repr = s.read().strip()
4063                self.assertEqual(peer_data_repr,
4064                                 repr(cb_data).encode("us-ascii"))
4065
4066            # now, again
4067            with client_context.wrap_socket(
4068                    socket.socket(),
4069                    server_hostname=hostname) as s:
4070                s.connect((HOST, server.port))
4071                new_cb_data = s.get_channel_binding("tls-unique")
4072                if support.verbose:
4073                    sys.stdout.write(
4074                        "got another channel binding data: {0!r}\n".format(
4075                            new_cb_data)
4076                    )
4077                # is it really unique
4078                self.assertNotEqual(cb_data, new_cb_data)
4079                self.assertIsNotNone(cb_data)
4080                if s.version() == 'TLSv1.3':
4081                    self.assertEqual(len(cb_data), 48)
4082                else:
4083                    self.assertEqual(len(cb_data), 12)  # True for TLSv1
4084                s.write(b"CB tls-unique\n")
4085                peer_data_repr = s.read().strip()
4086                self.assertEqual(peer_data_repr,
4087                                 repr(new_cb_data).encode("us-ascii"))
4088
4089    def test_compression(self):
4090        client_context, server_context, hostname = testing_context()
4091        stats = server_params_test(client_context, server_context,
4092                                   chatty=True, connectionchatty=True,
4093                                   sni_name=hostname)
4094        if support.verbose:
4095            sys.stdout.write(" got compression: {!r}\n".format(stats['compression']))
4096        self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' })
4097
4098    @unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'),
4099                         "ssl.OP_NO_COMPRESSION needed for this test")
4100    def test_compression_disabled(self):
4101        client_context, server_context, hostname = testing_context()
4102        client_context.options |= ssl.OP_NO_COMPRESSION
4103        server_context.options |= ssl.OP_NO_COMPRESSION
4104        stats = server_params_test(client_context, server_context,
4105                                   chatty=True, connectionchatty=True,
4106                                   sni_name=hostname)
4107        self.assertIs(stats['compression'], None)
4108
4109    @unittest.skipIf(Py_DEBUG_WIN32, "Avoid mixing debug/release CRT on Windows")
4110    def test_dh_params(self):
4111        # Check we can get a connection with ephemeral Diffie-Hellman
4112        client_context, server_context, hostname = testing_context()
4113        # test scenario needs TLS <= 1.2
4114        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
4115        server_context.load_dh_params(DHFILE)
4116        server_context.set_ciphers("kEDH")
4117        server_context.maximum_version = ssl.TLSVersion.TLSv1_2
4118        stats = server_params_test(client_context, server_context,
4119                                   chatty=True, connectionchatty=True,
4120                                   sni_name=hostname)
4121        cipher = stats["cipher"][0]
4122        parts = cipher.split("-")
4123        if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:
4124            self.fail("Non-DH cipher: " + cipher[0])
4125
4126    def test_ecdh_curve(self):
4127        # server secp384r1, client auto
4128        client_context, server_context, hostname = testing_context()
4129
4130        server_context.set_ecdh_curve("secp384r1")
4131        server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
4132        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
4133        stats = server_params_test(client_context, server_context,
4134                                   chatty=True, connectionchatty=True,
4135                                   sni_name=hostname)
4136
4137        # server auto, client secp384r1
4138        client_context, server_context, hostname = testing_context()
4139        client_context.set_ecdh_curve("secp384r1")
4140        server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
4141        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
4142        stats = server_params_test(client_context, server_context,
4143                                   chatty=True, connectionchatty=True,
4144                                   sni_name=hostname)
4145
4146        # server / client curve mismatch
4147        client_context, server_context, hostname = testing_context()
4148        client_context.set_ecdh_curve("prime256v1")
4149        server_context.set_ecdh_curve("secp384r1")
4150        server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
4151        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
4152        with self.assertRaises(ssl.SSLError):
4153            server_params_test(client_context, server_context,
4154                               chatty=True, connectionchatty=True,
4155                               sni_name=hostname)
4156
4157    def test_selected_alpn_protocol(self):
4158        # selected_alpn_protocol() is None unless ALPN is used.
4159        client_context, server_context, hostname = testing_context()
4160        stats = server_params_test(client_context, server_context,
4161                                   chatty=True, connectionchatty=True,
4162                                   sni_name=hostname)
4163        self.assertIs(stats['client_alpn_protocol'], None)
4164
4165    def test_selected_alpn_protocol_if_server_uses_alpn(self):
4166        # selected_alpn_protocol() is None unless ALPN is used by the client.
4167        client_context, server_context, hostname = testing_context()
4168        server_context.set_alpn_protocols(['foo', 'bar'])
4169        stats = server_params_test(client_context, server_context,
4170                                   chatty=True, connectionchatty=True,
4171                                   sni_name=hostname)
4172        self.assertIs(stats['client_alpn_protocol'], None)
4173
4174    def test_alpn_protocols(self):
4175        server_protocols = ['foo', 'bar', 'milkshake']
4176        protocol_tests = [
4177            (['foo', 'bar'], 'foo'),
4178            (['bar', 'foo'], 'foo'),
4179            (['milkshake'], 'milkshake'),
4180            (['http/3.0', 'http/4.0'], None)
4181        ]
4182        for client_protocols, expected in protocol_tests:
4183            client_context, server_context, hostname = testing_context()
4184            server_context.set_alpn_protocols(server_protocols)
4185            client_context.set_alpn_protocols(client_protocols)
4186
4187            try:
4188                stats = server_params_test(client_context,
4189                                           server_context,
4190                                           chatty=True,
4191                                           connectionchatty=True,
4192                                           sni_name=hostname)
4193            except ssl.SSLError as e:
4194                stats = e
4195
4196            msg = "failed trying %s (s) and %s (c).\n" \
4197                "was expecting %s, but got %%s from the %%s" \
4198                    % (str(server_protocols), str(client_protocols),
4199                        str(expected))
4200            client_result = stats['client_alpn_protocol']
4201            self.assertEqual(client_result, expected,
4202                             msg % (client_result, "client"))
4203            server_result = stats['server_alpn_protocols'][-1] \
4204                if len(stats['server_alpn_protocols']) else 'nothing'
4205            self.assertEqual(server_result, expected,
4206                             msg % (server_result, "server"))
4207
4208    def test_npn_protocols(self):
4209        assert not ssl.HAS_NPN
4210
4211    def sni_contexts(self):
4212        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
4213        server_context.load_cert_chain(SIGNED_CERTFILE)
4214        other_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
4215        other_context.load_cert_chain(SIGNED_CERTFILE2)
4216        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
4217        client_context.load_verify_locations(SIGNING_CA)
4218        return server_context, other_context, client_context
4219
4220    def check_common_name(self, stats, name):
4221        cert = stats['peercert']
4222        self.assertIn((('commonName', name),), cert['subject'])
4223
4224    def test_sni_callback(self):
4225        calls = []
4226        server_context, other_context, client_context = self.sni_contexts()
4227
4228        client_context.check_hostname = False
4229
4230        def servername_cb(ssl_sock, server_name, initial_context):
4231            calls.append((server_name, initial_context))
4232            if server_name is not None:
4233                ssl_sock.context = other_context
4234        server_context.set_servername_callback(servername_cb)
4235
4236        stats = server_params_test(client_context, server_context,
4237                                   chatty=True,
4238                                   sni_name='supermessage')
4239        # The hostname was fetched properly, and the certificate was
4240        # changed for the connection.
4241        self.assertEqual(calls, [("supermessage", server_context)])
4242        # CERTFILE4 was selected
4243        self.check_common_name(stats, 'fakehostname')
4244
4245        calls = []
4246        # The callback is called with server_name=None
4247        stats = server_params_test(client_context, server_context,
4248                                   chatty=True,
4249                                   sni_name=None)
4250        self.assertEqual(calls, [(None, server_context)])
4251        self.check_common_name(stats, SIGNED_CERTFILE_HOSTNAME)
4252
4253        # Check disabling the callback
4254        calls = []
4255        server_context.set_servername_callback(None)
4256
4257        stats = server_params_test(client_context, server_context,
4258                                   chatty=True,
4259                                   sni_name='notfunny')
4260        # Certificate didn't change
4261        self.check_common_name(stats, SIGNED_CERTFILE_HOSTNAME)
4262        self.assertEqual(calls, [])
4263
4264    def test_sni_callback_alert(self):
4265        # Returning a TLS alert is reflected to the connecting client
4266        server_context, other_context, client_context = self.sni_contexts()
4267
4268        def cb_returning_alert(ssl_sock, server_name, initial_context):
4269            return ssl.ALERT_DESCRIPTION_ACCESS_DENIED
4270        server_context.set_servername_callback(cb_returning_alert)
4271        with self.assertRaises(ssl.SSLError) as cm:
4272            stats = server_params_test(client_context, server_context,
4273                                       chatty=False,
4274                                       sni_name='supermessage')
4275        self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_ACCESS_DENIED')
4276
4277    def test_sni_callback_raising(self):
4278        # Raising fails the connection with a TLS handshake failure alert.
4279        server_context, other_context, client_context = self.sni_contexts()
4280
4281        def cb_raising(ssl_sock, server_name, initial_context):
4282            1/0
4283        server_context.set_servername_callback(cb_raising)
4284
4285        with support.catch_unraisable_exception() as catch:
4286            with self.assertRaises(ssl.SSLError) as cm:
4287                stats = server_params_test(client_context, server_context,
4288                                           chatty=False,
4289                                           sni_name='supermessage')
4290
4291            self.assertEqual(cm.exception.reason,
4292                             'SSLV3_ALERT_HANDSHAKE_FAILURE')
4293            self.assertEqual(catch.unraisable.exc_type, ZeroDivisionError)
4294
4295    def test_sni_callback_wrong_return_type(self):
4296        # Returning the wrong return type terminates the TLS connection
4297        # with an internal error alert.
4298        server_context, other_context, client_context = self.sni_contexts()
4299
4300        def cb_wrong_return_type(ssl_sock, server_name, initial_context):
4301            return "foo"
4302        server_context.set_servername_callback(cb_wrong_return_type)
4303
4304        with support.catch_unraisable_exception() as catch:
4305            with self.assertRaises(ssl.SSLError) as cm:
4306                stats = server_params_test(client_context, server_context,
4307                                           chatty=False,
4308                                           sni_name='supermessage')
4309
4310
4311            self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR')
4312            self.assertEqual(catch.unraisable.exc_type, TypeError)
4313
4314    def test_shared_ciphers(self):
4315        client_context, server_context, hostname = testing_context()
4316        client_context.set_ciphers("AES128:AES256")
4317        server_context.set_ciphers("AES256:eNULL")
4318        expected_algs = [
4319            "AES256", "AES-256",
4320            # TLS 1.3 ciphers are always enabled
4321            "TLS_CHACHA20", "TLS_AES",
4322        ]
4323
4324        stats = server_params_test(client_context, server_context,
4325                                   sni_name=hostname)
4326        ciphers = stats['server_shared_ciphers'][0]
4327        self.assertGreater(len(ciphers), 0)
4328        for name, tls_version, bits in ciphers:
4329            if not any(alg in name for alg in expected_algs):
4330                self.fail(name)
4331
4332    def test_read_write_after_close_raises_valuerror(self):
4333        client_context, server_context, hostname = testing_context()
4334        server = ThreadedEchoServer(context=server_context, chatty=False)
4335
4336        with server:
4337            s = client_context.wrap_socket(socket.socket(),
4338                                           server_hostname=hostname)
4339            s.connect((HOST, server.port))
4340            s.close()
4341
4342            self.assertRaises(ValueError, s.read, 1024)
4343            self.assertRaises(ValueError, s.write, b'hello')
4344
4345    def test_sendfile(self):
4346        TEST_DATA = b"x" * 512
4347        with open(os_helper.TESTFN, 'wb') as f:
4348            f.write(TEST_DATA)
4349        self.addCleanup(os_helper.unlink, os_helper.TESTFN)
4350        client_context, server_context, hostname = testing_context()
4351        server = ThreadedEchoServer(context=server_context, chatty=False)
4352        with server:
4353            with client_context.wrap_socket(socket.socket(),
4354                                            server_hostname=hostname) as s:
4355                s.connect((HOST, server.port))
4356                with open(os_helper.TESTFN, 'rb') as file:
4357                    s.sendfile(file)
4358                    self.assertEqual(s.recv(1024), TEST_DATA)
4359
4360    def test_session(self):
4361        client_context, server_context, hostname = testing_context()
4362        # TODO: sessions aren't compatible with TLSv1.3 yet
4363        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
4364
4365        # first connection without session
4366        stats = server_params_test(client_context, server_context,
4367                                   sni_name=hostname)
4368        session = stats['session']
4369        self.assertTrue(session.id)
4370        self.assertGreater(session.time, 0)
4371        self.assertGreater(session.timeout, 0)
4372        self.assertTrue(session.has_ticket)
4373        self.assertGreater(session.ticket_lifetime_hint, 0)
4374        self.assertFalse(stats['session_reused'])
4375        sess_stat = server_context.session_stats()
4376        self.assertEqual(sess_stat['accept'], 1)
4377        self.assertEqual(sess_stat['hits'], 0)
4378
4379        # reuse session
4380        stats = server_params_test(client_context, server_context,
4381                                   session=session, sni_name=hostname)
4382        sess_stat = server_context.session_stats()
4383        self.assertEqual(sess_stat['accept'], 2)
4384        self.assertEqual(sess_stat['hits'], 1)
4385        self.assertTrue(stats['session_reused'])
4386        session2 = stats['session']
4387        self.assertEqual(session2.id, session.id)
4388        self.assertEqual(session2, session)
4389        self.assertIsNot(session2, session)
4390        self.assertGreaterEqual(session2.time, session.time)
4391        self.assertGreaterEqual(session2.timeout, session.timeout)
4392
4393        # another one without session
4394        stats = server_params_test(client_context, server_context,
4395                                   sni_name=hostname)
4396        self.assertFalse(stats['session_reused'])
4397        session3 = stats['session']
4398        self.assertNotEqual(session3.id, session.id)
4399        self.assertNotEqual(session3, session)
4400        sess_stat = server_context.session_stats()
4401        self.assertEqual(sess_stat['accept'], 3)
4402        self.assertEqual(sess_stat['hits'], 1)
4403
4404        # reuse session again
4405        stats = server_params_test(client_context, server_context,
4406                                   session=session, sni_name=hostname)
4407        self.assertTrue(stats['session_reused'])
4408        session4 = stats['session']
4409        self.assertEqual(session4.id, session.id)
4410        self.assertEqual(session4, session)
4411        self.assertGreaterEqual(session4.time, session.time)
4412        self.assertGreaterEqual(session4.timeout, session.timeout)
4413        sess_stat = server_context.session_stats()
4414        self.assertEqual(sess_stat['accept'], 4)
4415        self.assertEqual(sess_stat['hits'], 2)
4416
4417    def test_session_handling(self):
4418        client_context, server_context, hostname = testing_context()
4419        client_context2, _, _ = testing_context()
4420
4421        # TODO: session reuse does not work with TLSv1.3
4422        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
4423        client_context2.maximum_version = ssl.TLSVersion.TLSv1_2
4424
4425        server = ThreadedEchoServer(context=server_context, chatty=False)
4426        with server:
4427            with client_context.wrap_socket(socket.socket(),
4428                                            server_hostname=hostname) as s:
4429                # session is None before handshake
4430                self.assertEqual(s.session, None)
4431                self.assertEqual(s.session_reused, None)
4432                s.connect((HOST, server.port))
4433                session = s.session
4434                self.assertTrue(session)
4435                with self.assertRaises(TypeError) as e:
4436                    s.session = object
4437                self.assertEqual(str(e.exception), 'Value is not a SSLSession.')
4438
4439            with client_context.wrap_socket(socket.socket(),
4440                                            server_hostname=hostname) as s:
4441                s.connect((HOST, server.port))
4442                # cannot set session after handshake
4443                with self.assertRaises(ValueError) as e:
4444                    s.session = session
4445                self.assertEqual(str(e.exception),
4446                                 'Cannot set session after handshake.')
4447
4448            with client_context.wrap_socket(socket.socket(),
4449                                            server_hostname=hostname) as s:
4450                # can set session before handshake and before the
4451                # connection was established
4452                s.session = session
4453                s.connect((HOST, server.port))
4454                self.assertEqual(s.session.id, session.id)
4455                self.assertEqual(s.session, session)
4456                self.assertEqual(s.session_reused, True)
4457
4458            with client_context2.wrap_socket(socket.socket(),
4459                                             server_hostname=hostname) as s:
4460                # cannot re-use session with a different SSLContext
4461                with self.assertRaises(ValueError) as e:
4462                    s.session = session
4463                    s.connect((HOST, server.port))
4464                self.assertEqual(str(e.exception),
4465                                 'Session refers to a different SSLContext.')
4466
4467
@unittest.skipUnless(has_tls_version('TLSv1_3'), "Test needs TLS 1.3")
class TestPostHandshakeAuth(unittest.TestCase):
    """Tests for TLS 1.3 post-handshake client authentication (PHA).

    Most tests send short text commands (b'HASCERT', b'PHA', b'GETCERT',
    b'VERIFIEDCHAIN', b'UNVERIFIEDCHAIN') to a ThreadedEchoServer and check
    the replies.  The command handling lives in the server's connection
    handler, which is not shown here; the expected replies below are what
    these tests assert.
    """
    def test_pha_setter(self):
        # post_handshake_auth defaults to False on both client and server
        # contexts, and toggling it neither changes nor is changed by
        # verify_mode.
        protocols = [
            ssl.PROTOCOL_TLS_SERVER, ssl.PROTOCOL_TLS_CLIENT
        ]
        for protocol in protocols:
            ctx = ssl.SSLContext(protocol)
            self.assertEqual(ctx.post_handshake_auth, False)

            ctx.post_handshake_auth = True
            self.assertEqual(ctx.post_handshake_auth, True)

            ctx.verify_mode = ssl.CERT_REQUIRED
            self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
            self.assertEqual(ctx.post_handshake_auth, True)

            ctx.post_handshake_auth = False
            self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
            self.assertEqual(ctx.post_handshake_auth, False)

            ctx.verify_mode = ssl.CERT_OPTIONAL
            ctx.post_handshake_auth = True
            self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
            self.assertEqual(ctx.post_handshake_auth, True)

    def test_pha_required(self):
        # Server requires a client cert but enables PHA, so the initial
        # handshake carries no client cert.  The PHA command obtains it.
        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')
                # PHA method just returns true when cert is already available
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'GETCERT')
                cert_text = s.recv(4096).decode('us-ascii')
                self.assertIn('Python Software Foundation CA', cert_text)

    def test_pha_required_nocert(self):
        # Server requires a client cert via PHA, but the client has none
        # loaded.  The connection must be aborted with an SSLError.
        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.post_handshake_auth = True

        # Message callback used only to print TLS alerts in verbose runs.
        def msg_cb(conn, direction, version, content_type, msg_type, data):
            if support.verbose and content_type == _TLSContentType.ALERT:
                info = (conn, direction, version, content_type, msg_type, data)
                sys.stdout.write(f"TLS: {info!r}\n")

        server_context._msg_callback = msg_cb
        client_context._msg_callback = msg_cb

        server = ThreadedEchoServer(context=server_context, chatty=True)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname,
                                            suppress_ragged_eofs=False) as s:
                s.connect((HOST, server.port))
                s.write(b'PHA')
                # test sometimes fails with EOF error. Test passes as long as
                # server aborts connection with an error.
                with self.assertRaisesRegex(
                    ssl.SSLError,
                    '(certificate required|EOF occurred)'
                ):
                    # receive CertificateRequest
                    data = s.recv(1024)
                    self.assertEqual(data, b'OK\n')

                    # send empty Certificate + Finish
                    s.write(b'HASCERT')

                    # receive alert
                    s.recv(1024)

    def test_pha_optional(self):
        # Same flow as test_pha_required, but with CERT_OPTIONAL on the
        # server.  A client that has a cert still presents it after PHA.
        if support.verbose:
            sys.stdout.write("\n")

        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        # check CERT_OPTIONAL
        server_context.verify_mode = ssl.CERT_OPTIONAL
        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')

    def test_pha_optional_nocert(self):
        # CERT_OPTIONAL + PHA with no client cert: PHA still succeeds and
        # the server simply has no peer cert afterwards.
        if support.verbose:
            sys.stdout.write("\n")

        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_OPTIONAL
        client_context.post_handshake_auth = True

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                # optional doesn't fail when client does not have a cert
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')

    def test_pha_no_pha_client(self):
        # The client did not enable PHA.  Calling
        # verify_client_post_handshake() on a client socket is rejected,
        # and a server-side PHA request reports the missing extension.
        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                with self.assertRaisesRegex(ssl.SSLError, 'not server'):
                    s.verify_client_post_handshake()
                s.write(b'PHA')
                self.assertIn(b'extension not received', s.recv(1024))

    def test_pha_no_pha_server(self):
        # server doesn't have PHA enabled, cert is requested in handshake
        client_context, server_context, hostname = testing_context()
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')
                # PHA doesn't fail if there is already a cert
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')

    def test_pha_not_tls13(self):
        # TLS 1.2
        # The client caps the protocol at TLS 1.2, so a PHA request is
        # answered with a WRONG_SSL_VERSION error.
        client_context, server_context, hostname = testing_context()
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                # PHA fails for TLS != 1.3
                s.write(b'PHA')
                self.assertIn(b'WRONG_SSL_VERSION', s.recv(1024))

    def test_bpo37428_pha_cert_none(self):
        # verify that post_handshake_auth does not implicitly enable cert
        # validation.
        hostname = SIGNED_CERTFILE_HOSTNAME
        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)
        # no cert validation and CA on client side
        client_context.check_hostname = False
        client_context.verify_mode = ssl.CERT_NONE

        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        server_context.load_cert_chain(SIGNED_CERTFILE)
        server_context.load_verify_locations(SIGNING_CA)
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')
                # server cert has not been validated
                self.assertEqual(s.getpeercert(), {})

    def test_internal_chain_client(self):
        # The server sends only its leaf cert (server_chain=False).  The
        # client's verified chain is leaf + CA (2 certs), while the
        # unverified chain is just the leaf.  Also checks equality, hash,
        # repr and PEM/DER export of the internal certificate objects.
        client_context, server_context, hostname = testing_context(
            server_chain=False
        )
        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(
                socket.socket(),
                server_hostname=hostname
            ) as s:
                s.connect((HOST, server.port))
                vc = s._sslobj.get_verified_chain()
                self.assertEqual(len(vc), 2)
                ee, ca = vc
                uvc = s._sslobj.get_unverified_chain()
                self.assertEqual(len(uvc), 1)

                self.assertEqual(ee, uvc[0])
                self.assertEqual(hash(ee), hash(uvc[0]))
                self.assertEqual(repr(ee), repr(uvc[0]))

                self.assertNotEqual(ee, ca)
                self.assertNotEqual(hash(ee), hash(ca))
                self.assertNotEqual(repr(ee), repr(ca))
                self.assertNotEqual(ee.get_info(), ca.get_info())
                self.assertIn("CN=localhost", repr(ee))
                self.assertIn("CN=our-ca-server", repr(ca))

                pem = ee.public_bytes(_ssl.ENCODING_PEM)
                der = ee.public_bytes(_ssl.ENCODING_DER)
                self.assertIsInstance(pem, str)
                self.assertIn("-----BEGIN CERTIFICATE-----", pem)
                self.assertIsInstance(der, bytes)
                self.assertEqual(
                    ssl.PEM_cert_to_DER_cert(pem), der
                )

    def test_internal_chain_server(self):
        # Server-side view of the client's chain over TLS 1.2.  The server
        # handler is expected to answer the VERIFIEDCHAIN and
        # UNVERIFIEDCHAIN commands with b'\x02\n'.  Presumably this is the
        # chain length as a raw byte; confirm against the handler.
        client_context, server_context, hostname = testing_context()
        client_context.load_cert_chain(SIGNED_CERTFILE)
        server_context.verify_mode = ssl.CERT_REQUIRED
        server_context.maximum_version = ssl.TLSVersion.TLSv1_2

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(
                socket.socket(),
                server_hostname=hostname
            ) as s:
                s.connect((HOST, server.port))
                s.write(b'VERIFIEDCHAIN\n')
                res = s.recv(1024)
                self.assertEqual(res, b'\x02\n')
                s.write(b'UNVERIFIEDCHAIN\n')
                res = s.recv(1024)
                self.assertEqual(res, b'\x02\n')
4742
4743
# True when this build's SSLContext exposes the ``keylog_filename`` attribute.
HAS_KEYLOG = hasattr(ssl.SSLContext, "keylog_filename")
# Decorator that skips a test when key logging is unavailable.
requires_keylog = unittest.skipUnless(
    HAS_KEYLOG,
    "test requires OpenSSL 1.1.1 with keylog callback",
)
4747
4748class TestSSLDebug(unittest.TestCase):
4749
4750    def keylog_lines(self, fname=os_helper.TESTFN):
4751        with open(fname) as f:
4752            return len(list(f))
4753
4754    @requires_keylog
4755    @unittest.skipIf(Py_DEBUG_WIN32, "Avoid mixing debug/release CRT on Windows")
4756    def test_keylog_defaults(self):
4757        self.addCleanup(os_helper.unlink, os_helper.TESTFN)
4758        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
4759        self.assertEqual(ctx.keylog_filename, None)
4760
4761        self.assertFalse(os.path.isfile(os_helper.TESTFN))
4762        ctx.keylog_filename = os_helper.TESTFN
4763        self.assertEqual(ctx.keylog_filename, os_helper.TESTFN)
4764        self.assertTrue(os.path.isfile(os_helper.TESTFN))
4765        self.assertEqual(self.keylog_lines(), 1)
4766
4767        ctx.keylog_filename = None
4768        self.assertEqual(ctx.keylog_filename, None)
4769
4770        with self.assertRaises((IsADirectoryError, PermissionError)):
4771            # Windows raises PermissionError
4772            ctx.keylog_filename = os.path.dirname(
4773                os.path.abspath(os_helper.TESTFN))
4774
4775        with self.assertRaises(TypeError):
4776            ctx.keylog_filename = 1
4777
4778    @requires_keylog
4779    @unittest.skipIf(Py_DEBUG_WIN32, "Avoid mixing debug/release CRT on Windows")
4780    def test_keylog_filename(self):
4781        self.addCleanup(os_helper.unlink, os_helper.TESTFN)
4782        client_context, server_context, hostname = testing_context()
4783
4784        client_context.keylog_filename = os_helper.TESTFN
4785        server = ThreadedEchoServer(context=server_context, chatty=False)
4786        with server:
4787            with client_context.wrap_socket(socket.socket(),
4788                                            server_hostname=hostname) as s:
4789                s.connect((HOST, server.port))
4790        # header, 5 lines for TLS 1.3
4791        self.assertEqual(self.keylog_lines(), 6)
4792
4793        client_context.keylog_filename = None
4794        server_context.keylog_filename = os_helper.TESTFN
4795        server = ThreadedEchoServer(context=server_context, chatty=False)
4796        with server:
4797            with client_context.wrap_socket(socket.socket(),
4798                                            server_hostname=hostname) as s:
4799                s.connect((HOST, server.port))
4800        self.assertGreaterEqual(self.keylog_lines(), 11)
4801
4802        client_context.keylog_filename = os_helper.TESTFN
4803        server_context.keylog_filename = os_helper.TESTFN
4804        server = ThreadedEchoServer(context=server_context, chatty=False)
4805        with server:
4806            with client_context.wrap_socket(socket.socket(),
4807                                            server_hostname=hostname) as s:
4808                s.connect((HOST, server.port))
4809        self.assertGreaterEqual(self.keylog_lines(), 21)
4810
4811        client_context.keylog_filename = None
4812        server_context.keylog_filename = None
4813
4814    @requires_keylog
4815    @unittest.skipIf(sys.flags.ignore_environment,
4816                     "test is not compatible with ignore_environment")
4817    @unittest.skipIf(Py_DEBUG_WIN32, "Avoid mixing debug/release CRT on Windows")
4818    def test_keylog_env(self):
4819        self.addCleanup(os_helper.unlink, os_helper.TESTFN)
4820        with unittest.mock.patch.dict(os.environ):
4821            os.environ['SSLKEYLOGFILE'] = os_helper.TESTFN
4822            self.assertEqual(os.environ['SSLKEYLOGFILE'], os_helper.TESTFN)
4823
4824            ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
4825            self.assertEqual(ctx.keylog_filename, None)
4826
4827            ctx = ssl.create_default_context()
4828            self.assertEqual(ctx.keylog_filename, os_helper.TESTFN)
4829
4830            ctx = ssl._create_stdlib_context()
4831            self.assertEqual(ctx.keylog_filename, os_helper.TESTFN)
4832
4833    def test_msg_callback(self):
4834        client_context, server_context, hostname = testing_context()
4835
4836        def msg_cb(conn, direction, version, content_type, msg_type, data):
4837            pass
4838
4839        self.assertIs(client_context._msg_callback, None)
4840        client_context._msg_callback = msg_cb
4841        self.assertIs(client_context._msg_callback, msg_cb)
4842        with self.assertRaises(TypeError):
4843            client_context._msg_callback = object()
4844
4845    def test_msg_callback_tls12(self):
4846        client_context, server_context, hostname = testing_context()
4847        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
4848
4849        msg = []
4850
4851        def msg_cb(conn, direction, version, content_type, msg_type, data):
4852            self.assertIsInstance(conn, ssl.SSLSocket)
4853            self.assertIsInstance(data, bytes)
4854            self.assertIn(direction, {'read', 'write'})
4855            msg.append((direction, version, content_type, msg_type))
4856
4857        client_context._msg_callback = msg_cb
4858
4859        server = ThreadedEchoServer(context=server_context, chatty=False)
4860        with server:
4861            with client_context.wrap_socket(socket.socket(),
4862                                            server_hostname=hostname) as s:
4863                s.connect((HOST, server.port))
4864
4865        self.assertIn(
4866            ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
4867             _TLSMessageType.SERVER_KEY_EXCHANGE),
4868            msg
4869        )
4870        self.assertIn(
4871            ("write", TLSVersion.TLSv1_2, _TLSContentType.CHANGE_CIPHER_SPEC,
4872             _TLSMessageType.CHANGE_CIPHER_SPEC),
4873            msg
4874        )
4875
4876    def test_msg_callback_deadlock_bpo43577(self):
4877        client_context, server_context, hostname = testing_context()
4878        server_context2 = testing_context()[1]
4879
4880        def msg_cb(conn, direction, version, content_type, msg_type, data):
4881            pass
4882
4883        def sni_cb(sock, servername, ctx):
4884            sock.context = server_context2
4885
4886        server_context._msg_callback = msg_cb
4887        server_context.sni_callback = sni_cb
4888
4889        server = ThreadedEchoServer(context=server_context, chatty=False)
4890        with server:
4891            with client_context.wrap_socket(socket.socket(),
4892                                            server_hostname=hostname) as s:
4893                s.connect((HOST, server.port))
4894            with client_context.wrap_socket(socket.socket(),
4895                                            server_hostname=hostname) as s:
4896                s.connect((HOST, server.port))
4897
4898
4899class TestEnumerations(unittest.TestCase):
4900
4901    def test_tlsversion(self):
4902        class CheckedTLSVersion(enum.IntEnum):
4903            MINIMUM_SUPPORTED = _ssl.PROTO_MINIMUM_SUPPORTED
4904            SSLv3 = _ssl.PROTO_SSLv3
4905            TLSv1 = _ssl.PROTO_TLSv1
4906            TLSv1_1 = _ssl.PROTO_TLSv1_1
4907            TLSv1_2 = _ssl.PROTO_TLSv1_2
4908            TLSv1_3 = _ssl.PROTO_TLSv1_3
4909            MAXIMUM_SUPPORTED = _ssl.PROTO_MAXIMUM_SUPPORTED
4910        enum._test_simple_enum(CheckedTLSVersion, TLSVersion)
4911
4912    def test_tlscontenttype(self):
4913        class Checked_TLSContentType(enum.IntEnum):
4914            """Content types (record layer)
4915
4916            See RFC 8446, section B.1
4917            """
4918            CHANGE_CIPHER_SPEC = 20
4919            ALERT = 21
4920            HANDSHAKE = 22
4921            APPLICATION_DATA = 23
4922            # pseudo content types
4923            HEADER = 0x100
4924            INNER_CONTENT_TYPE = 0x101
4925        enum._test_simple_enum(Checked_TLSContentType, _TLSContentType)
4926
4927    def test_tlsalerttype(self):
4928        class Checked_TLSAlertType(enum.IntEnum):
4929            """Alert types for TLSContentType.ALERT messages
4930
4931            See RFC 8466, section B.2
4932            """
4933            CLOSE_NOTIFY = 0
4934            UNEXPECTED_MESSAGE = 10
4935            BAD_RECORD_MAC = 20
4936            DECRYPTION_FAILED = 21
4937            RECORD_OVERFLOW = 22
4938            DECOMPRESSION_FAILURE = 30
4939            HANDSHAKE_FAILURE = 40
4940            NO_CERTIFICATE = 41
4941            BAD_CERTIFICATE = 42
4942            UNSUPPORTED_CERTIFICATE = 43
4943            CERTIFICATE_REVOKED = 44
4944            CERTIFICATE_EXPIRED = 45
4945            CERTIFICATE_UNKNOWN = 46
4946            ILLEGAL_PARAMETER = 47
4947            UNKNOWN_CA = 48
4948            ACCESS_DENIED = 49
4949            DECODE_ERROR = 50
4950            DECRYPT_ERROR = 51
4951            EXPORT_RESTRICTION = 60
4952            PROTOCOL_VERSION = 70
4953            INSUFFICIENT_SECURITY = 71
4954            INTERNAL_ERROR = 80
4955            INAPPROPRIATE_FALLBACK = 86
4956            USER_CANCELED = 90
4957            NO_RENEGOTIATION = 100
4958            MISSING_EXTENSION = 109
4959            UNSUPPORTED_EXTENSION = 110
4960            CERTIFICATE_UNOBTAINABLE = 111
4961            UNRECOGNIZED_NAME = 112
4962            BAD_CERTIFICATE_STATUS_RESPONSE = 113
4963            BAD_CERTIFICATE_HASH_VALUE = 114
4964            UNKNOWN_PSK_IDENTITY = 115
4965            CERTIFICATE_REQUIRED = 116
4966            NO_APPLICATION_PROTOCOL = 120
4967        enum._test_simple_enum(Checked_TLSAlertType, _TLSAlertType)
4968
4969    def test_tlsmessagetype(self):
4970        class Checked_TLSMessageType(enum.IntEnum):
4971            """Message types (handshake protocol)
4972
4973            See RFC 8446, section B.3
4974            """
4975            HELLO_REQUEST = 0
4976            CLIENT_HELLO = 1
4977            SERVER_HELLO = 2
4978            HELLO_VERIFY_REQUEST = 3
4979            NEWSESSION_TICKET = 4
4980            END_OF_EARLY_DATA = 5
4981            HELLO_RETRY_REQUEST = 6
4982            ENCRYPTED_EXTENSIONS = 8
4983            CERTIFICATE = 11
4984            SERVER_KEY_EXCHANGE = 12
4985            CERTIFICATE_REQUEST = 13
4986            SERVER_DONE = 14
4987            CERTIFICATE_VERIFY = 15
4988            CLIENT_KEY_EXCHANGE = 16
4989            FINISHED = 20
4990            CERTIFICATE_URL = 21
4991            CERTIFICATE_STATUS = 22
4992            SUPPLEMENTAL_DATA = 23
4993            KEY_UPDATE = 24
4994            NEXT_PROTO = 67
4995            MESSAGE_HASH = 254
4996            CHANGE_CIPHER_SPEC = 0x0101
4997        enum._test_simple_enum(Checked_TLSMessageType, _TLSMessageType)
4998
4999    def test_sslmethod(self):
5000        Checked_SSLMethod = enum._old_convert_(
5001                enum.IntEnum, '_SSLMethod', 'ssl',
5002                lambda name: name.startswith('PROTOCOL_') and name != 'PROTOCOL_SSLv23',
5003                source=ssl._ssl,
5004                )
5005        # This member is assigned dynamically in `ssl.py`:
5006        Checked_SSLMethod.PROTOCOL_SSLv23 = Checked_SSLMethod.PROTOCOL_TLS
5007        enum._test_simple_enum(Checked_SSLMethod, ssl._SSLMethod)
5008
5009    def test_options(self):
5010        CheckedOptions = enum._old_convert_(
5011                enum.IntFlag, 'Options', 'ssl',
5012                lambda name: name.startswith('OP_'),
5013                source=ssl._ssl,
5014                )
5015        enum._test_simple_enum(CheckedOptions, ssl.Options)
5016
5017    def test_alertdescription(self):
5018        CheckedAlertDescription = enum._old_convert_(
5019                enum.IntEnum, 'AlertDescription', 'ssl',
5020                lambda name: name.startswith('ALERT_DESCRIPTION_'),
5021                source=ssl._ssl,
5022                )
5023        enum._test_simple_enum(CheckedAlertDescription, ssl.AlertDescription)
5024
5025    def test_sslerrornumber(self):
5026        Checked_SSLErrorNumber = enum._old_convert_(
5027                enum.IntEnum, 'SSLErrorNumber', 'ssl',
5028                lambda name: name.startswith('SSL_ERROR_'),
5029                source=ssl._ssl,
5030                )
5031        enum._test_simple_enum(Checked_SSLErrorNumber, ssl.SSLErrorNumber)
5032
5033    def test_verifyflags(self):
5034        CheckedVerifyFlags = enum._old_convert_(
5035                enum.IntFlag, 'VerifyFlags', 'ssl',
5036                lambda name: name.startswith('VERIFY_'),
5037                source=ssl._ssl,
5038                )
5039        enum._test_simple_enum(CheckedVerifyFlags, ssl.VerifyFlags)
5040
5041    def test_verifymode(self):
5042        CheckedVerifyMode = enum._old_convert_(
5043                enum.IntEnum, 'VerifyMode', 'ssl',
5044                lambda name: name.startswith('CERT_'),
5045                source=ssl._ssl,
5046                )
5047        enum._test_simple_enum(CheckedVerifyMode, ssl.VerifyMode)
5048
5049
def setUpModule():
    """Module-level fixture for test_ssl.

    In verbose mode, print the OpenSSL version, the host platform and a few
    ssl flags.  Then verify that every certificate/key data file the tests
    rely on exists, raising support.TestFailed otherwise, and register
    threading_helper's thread-leak check as a module cleanup.
    """
    if support.verbose:
        plats = {
            'Mac': platform.mac_ver,
            'Windows': platform.win32_ver,
        }
        for name, func in plats.items():
            plat = func()
            if plat and plat[0]:
                plat = '%s %r' % (name, plat)
                break
        else:
            plat = repr(platform.platform())
        print("test_ssl: testing with %r %r" %
            (ssl.OPENSSL_VERSION, ssl.OPENSSL_VERSION_INFO))
        print("          under %s" % plat)
        print("          HAS_SNI = %r" % ssl.HAS_SNI)
        # %08x rather than %8x: zero-pad so values below 0x10000000 don't
        # print with embedded spaces (e.g. "0x     400").
        print("          OP_ALL = 0x%08x" % ssl.OP_ALL)
        # OP_NO_TLSv1_1 may be absent from some ssl builds; skip it quietly.
        op_no_tlsv1_1 = getattr(ssl, 'OP_NO_TLSv1_1', None)
        if op_no_tlsv1_1 is not None:
            print("          OP_NO_TLSv1_1 = 0x%08x" % op_no_tlsv1_1)

    for filename in [
        CERTFILE, BYTES_CERTFILE,
        ONLYCERT, ONLYKEY, BYTES_ONLYCERT, BYTES_ONLYKEY,
        SIGNED_CERTFILE, SIGNED_CERTFILE2, SIGNING_CA,
        BADCERT, BADKEY, EMPTYCERT]:
        if not os.path.exists(filename):
            raise support.TestFailed("Can't read certificate file %r" % filename)

    thread_info = threading_helper.threading_setup()
    unittest.addModuleCleanup(threading_helper.threading_cleanup, *thread_info)
5083
5084
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
5087