1from _compat_pickle import (IMPORT_MAPPING, REVERSE_IMPORT_MAPPING,
2                            NAME_MAPPING, REVERSE_NAME_MAPPING)
3import builtins
4import pickle
5import io
6import collections
7import struct
8import sys
9import warnings
10import weakref
11
12import doctest
13import unittest
14from test import support
15from test.support import import_helper
16
17from test.pickletester import AbstractHookTests
18from test.pickletester import AbstractUnpickleTests
19from test.pickletester import AbstractPickleTests
20from test.pickletester import AbstractPickleModuleTests
21from test.pickletester import AbstractPersistentPicklerTests
22from test.pickletester import AbstractIdentityPersistentPicklerTests
23from test.pickletester import AbstractPicklerUnpicklerObjectTests
24from test.pickletester import AbstractDispatchTableTests
25from test.pickletester import AbstractCustomPicklerClass
26from test.pickletester import BigmemPickleTests
27
# Probe for the C accelerator module; the C-specific test classes below are
# only defined when it can be imported.
try:
    import _pickle
except ImportError:
    has_c_implementation = False
else:
    has_c_implementation = True
33
34
class PyPickleTests(AbstractPickleModuleTests, unittest.TestCase):
    """Module-level API tests bound to the pure-Python pickle functions."""

    # Classes under test.
    Pickler = pickle._Pickler
    Unpickler = pickle._Unpickler

    # Module-level helpers; staticmethod keeps them from binding to self.
    dump = staticmethod(pickle._dump)
    load = staticmethod(pickle._load)
    dumps = staticmethod(pickle._dumps)
    loads = staticmethod(pickle._loads)
42
43
class PyUnpicklerTests(AbstractUnpickleTests, unittest.TestCase):
    """Unpickling tests for the pure-Python Unpickler."""

    unpickler = pickle._Unpickler
    # Exceptions the pure-Python unpickler may raise on malformed input.
    bad_stack_errors = (IndexError,)
    truncated_errors = (pickle.UnpicklingError, EOFError,
                        AttributeError, ValueError,
                        struct.error, IndexError, ImportError)

    def loads(self, buf, **kwds):
        """Unpickle the bytes *buf* using ``self.unpickler``."""
        return self.unpickler(io.BytesIO(buf), **kwds).load()
56
57
class PyPicklerTests(AbstractPickleTests, unittest.TestCase):
    """Round-trip tests for the pure-Python Pickler/Unpickler pair."""

    # Implementations under test; subclasses swap in other classes.
    pickler = pickle._Pickler
    unpickler = pickle._Unpickler

    def dumps(self, arg, proto=None, **kwargs):
        """Pickle *arg* with ``self.pickler`` and return the bytes."""
        sink = io.BytesIO()
        self.pickler(sink, proto, **kwargs).dump(arg)
        return sink.getvalue()

    def loads(self, buf, **kwds):
        """Unpickle the bytes *buf* using ``self.unpickler``."""
        return self.unpickler(io.BytesIO(buf), **kwds).load()
74
75
class InMemoryPickleTests(AbstractPickleTests, AbstractUnpickleTests,
                          BigmemPickleTests, unittest.TestCase):
    """Tests driven through the module-level pickle.dumps()/pickle.loads()."""

    # Exceptions unpickling malformed data is allowed to raise.
    truncated_errors = (pickle.UnpicklingError, EOFError,
                        AttributeError, ValueError,
                        struct.error, IndexError, ImportError)
    bad_stack_errors = (pickle.UnpicklingError, IndexError)

    # Disable this inherited test for the in-memory helpers (presumably
    # because dumps() has no caller-supplied file whose writes could be
    # delayed).
    test_framed_write_sizes_with_delayed_writer = None

    def dumps(self, arg, protocol=None, **kwargs):
        """Delegate to pickle.dumps()."""
        return pickle.dumps(arg, protocol, **kwargs)

    def loads(self, buf, **kwds):
        """Delegate to pickle.loads()."""
        return pickle.loads(buf, **kwds)
91
92
class PersistentPicklerUnpicklerMixin(object):
    """Provide dumps()/loads() that route persistent IDs through the test.

    The host class must define ``pickler`` and ``unpickler`` classes plus
    ``persistent_id(obj)`` and ``persistent_load(pid)`` methods. Throwaway
    subclasses created per call forward the pickle hooks to those methods.
    """

    def dumps(self, arg, proto=None, **kwargs):
        """Pickle *arg*, delegating persistent IDs to ``self.persistent_id``.

        Extra keyword arguments (e.g. ``fix_imports``) are passed to the
        pickler constructor. This matches ``PyPicklerTests.dumps``, so
        shared tests can pass pickler options through either helper.
        """
        class PersPickler(self.pickler):
            def persistent_id(subself, obj):
                return self.persistent_id(obj)
        f = io.BytesIO()
        p = PersPickler(f, proto, **kwargs)
        p.dump(arg)
        return f.getvalue()

    def loads(self, buf, **kwds):
        """Unpickle *buf*, resolving persistent IDs via ``self.persistent_load``."""
        class PersUnpickler(self.unpickler):
            def persistent_load(subself, obj):
                return self.persistent_load(obj)
        f = io.BytesIO(buf)
        u = PersUnpickler(f, **kwds)
        return u.load()
111
112
class PyPersPicklerTests(AbstractPersistentPicklerTests,
                         PersistentPicklerUnpicklerMixin, unittest.TestCase):
    """Persistent-ID tests for the pure-Python pickler and unpickler."""

    unpickler = pickle._Unpickler
    pickler = pickle._Pickler
118
119
class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests,
                           PersistentPicklerUnpicklerMixin, unittest.TestCase):
    """Identity persistent-ID tests plus reference-cycle / leak checks for
    the pure-Python pickler and unpickler."""

    pickler = pickle._Pickler
    unpickler = pickle._Unpickler

    @support.cpython_only
    def test_pickler_reference_cycle(self):
        # Overriding persistent_id as an instance, class, or static method
        # must not keep the pickler alive. Once the last strong reference is
        # deleted, the weak reference must already be dead, with no explicit
        # gc_collect() call.
        def check(Pickler):
            # Each variant returns obj itself as the persistent id, so every
            # object is stored by persistent id and must still round-trip.
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f = io.BytesIO()
                pickler = Pickler(f, proto)
                pickler.dump('abc')
                self.assertEqual(self.loads(f.getvalue()), 'abc')
            pickler = Pickler(io.BytesIO())
            self.assertEqual(pickler.persistent_id('def'), 'def')
            r = weakref.ref(pickler)
            del pickler
            # Relies on immediate refcount-based freeing (presumably why this
            # test is CPython-only).
            self.assertIsNone(r())

        class PersPickler(self.pickler):
            def persistent_id(subself, obj):
                return obj
        check(PersPickler)

        class PersPickler(self.pickler):
            @classmethod
            def persistent_id(cls, obj):
                return obj
        check(PersPickler)

        class PersPickler(self.pickler):
            @staticmethod
            def persistent_id(obj):
                return obj
        check(PersPickler)

    @support.cpython_only
    def test_custom_pickler_dispatch_table_memleak(self):
        # See https://github.com/python/cpython/issues/89988
        # A dispatch_table assigned on the instance (before the base
        # __init__ runs) must be released once the pickler is gone and a
        # garbage collection has run.

        class Pickler(self.pickler):
            def __init__(self, *args, **kwargs):
                self.dispatch_table = table
                super().__init__(*args, **kwargs)

        class DispatchTable:
            pass

        table = DispatchTable()
        pickler = Pickler(io.BytesIO())
        self.assertIs(pickler.dispatch_table, table)
        table_ref = weakref.ref(table)
        self.assertIsNotNone(table_ref())
        del pickler
        del table
        support.gc_collect()
        self.assertIsNone(table_ref())


    @support.cpython_only
    def test_unpickler_reference_cycle(self):
        # Mirror of test_pickler_reference_cycle for persistent_load on the
        # unpickler: the override must not keep the unpickler alive.
        def check(Unpickler):
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                unpickler = Unpickler(io.BytesIO(self.dumps('abc', proto)))
                self.assertEqual(unpickler.load(), 'abc')
            unpickler = Unpickler(io.BytesIO())
            self.assertEqual(unpickler.persistent_load('def'), 'def')
            r = weakref.ref(unpickler)
            del unpickler
            # Must be freed immediately, without an explicit gc_collect().
            self.assertIsNone(r())

        class PersUnpickler(self.unpickler):
            def persistent_load(subself, pid):
                return pid
        check(PersUnpickler)

        class PersUnpickler(self.unpickler):
            @classmethod
            def persistent_load(cls, pid):
                return pid
        check(PersUnpickler)

        class PersUnpickler(self.unpickler):
            @staticmethod
            def persistent_load(pid):
                return pid
        check(PersUnpickler)
208
209
class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests, unittest.TestCase):
    """Pickler/Unpickler object-level tests for the pure-Python classes."""

    unpickler_class = pickle._Unpickler
    pickler_class = pickle._Pickler
214
215
class PyDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
    """Dispatch-table tests using a private copy of the global table."""

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        """Return a shallow copy so tests can mutate it freely."""
        return dict(pickle.dispatch_table)
222
223
class PyChainDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
    """Dispatch-table tests using a ChainMap layered over the global table."""

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        """Return a ChainMap whose writes land in a fresh front mapping."""
        overrides = {}
        return collections.ChainMap(overrides, pickle.dispatch_table)
230
231
class PyPicklerHookTests(AbstractHookTests, unittest.TestCase):
    """Pickler hook tests against a pure-Python pickler subclass."""

    # Mixes the custom-hook test behaviour into the pure-Python pickler.
    class CustomPyPicklerClass(pickle._Pickler, AbstractCustomPicklerClass):
        pass

    pickler_class = CustomPyPicklerClass
237
238
# Test classes for the C accelerator (_pickle). They are only defined when
# the module imported successfully, and mostly reuse the pure-Python test
# classes with the C Pickler/Unpickler swapped in.
if has_c_implementation:
    class CPickleTests(AbstractPickleModuleTests, unittest.TestCase):
        # Binds the module-level C functions/classes as class attributes.
        from _pickle import dump, dumps, load, loads, Pickler, Unpickler

    class CUnpicklerTests(PyUnpicklerTests):
        unpickler = _pickle.Unpickler
        # The C unpickler is expected to report these failures only as
        # UnpicklingError.
        bad_stack_errors = (pickle.UnpicklingError,)
        truncated_errors = (pickle.UnpicklingError,)

    class CPicklerTests(PyPicklerTests):
        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    class CPersPicklerTests(PyPersPicklerTests):
        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    class CIdPersPicklerTests(PyIdPersPicklerTests):
        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    # Cross-implementation round trips: C pickler -> Python unpickler.
    class CDumpPickle_LoadPickle(PyPicklerTests):
        pickler = _pickle.Pickler
        unpickler = pickle._Unpickler

    # Cross-implementation round trips: Python pickler -> C unpickler.
    class DumpPickle_CLoadPickle(PyPicklerTests):
        pickler = pickle._Pickler
        unpickler = _pickle.Unpickler

    class CPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests, unittest.TestCase):
        pickler_class = _pickle.Pickler
        unpickler_class = _pickle.Unpickler

        def test_issue18339(self):
            # Assigning an invalid memo must raise rather than crash.
            unpickler = self.unpickler_class(io.BytesIO())
            with self.assertRaises(TypeError):
                unpickler.memo = object
            # used to cause a segfault
            with self.assertRaises(ValueError):
                unpickler.memo = {-1: None}
            unpickler.memo = {1: None}

    # NOTE(review): pickle.Pickler presumably resolves to _pickle.Pickler
    # here, since this branch only runs when the accelerator is available.
    class CDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
        pickler_class = pickle.Pickler
        def get_dispatch_table(self):
            return pickle.dispatch_table.copy()

    class CChainDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
        pickler_class = pickle.Pickler
        def get_dispatch_table(self):
            return collections.ChainMap({}, pickle.dispatch_table)

    class CPicklerHookTests(AbstractHookTests, unittest.TestCase):
        class CustomCPicklerClass(_pickle.Pickler, AbstractCustomPicklerClass):
            pass
        pickler_class = CustomCPicklerClass

    # Checks __sizeof__ of the C Pickler/Unpickler against hand-computed
    # sizes. NOTE(review): the calcobjsize/calcsize format strings appear to
    # mirror the C object layouts in _pickle and must be updated whenever
    # those structs change -- confirm against Modules/_pickle.c.
    @support.cpython_only
    class SizeofTests(unittest.TestCase):
        check_sizeof = support.check_sizeof

        def test_pickler(self):
            basesize = support.calcobjsize('7P2n3i2n3i2P')
            p = _pickle.Pickler(io.BytesIO())
            self.assertEqual(object.__sizeof__(p), basesize)
            # Memo table header and per-entry sizes.
            MT_size = struct.calcsize('3nP0n')
            ME_size = struct.calcsize('Pn0P')
            check = self.check_sizeof
            check(p, basesize +
                MT_size + 8 * ME_size +  # Minimal memo table size.
                sys.getsizeof(b'x'*4096))  # Minimal write buffer size.
            for i in range(6):
                p.dump(chr(i))
            check(p, basesize +
                MT_size + 32 * ME_size +  # Size of memo table required to
                                          # save references to 6 objects.
                0)  # Write buffer is cleared after every dump().

        def test_unpickler(self):
            basesize = support.calcobjsize('2P2n2P 2P2n2i5P 2P3n8P2n2i')
            unpickler = _pickle.Unpickler
            P = struct.calcsize('P')  # Size of memo table entry.
            n = struct.calcsize('n')  # Size of mark table entry.
            check = self.check_sizeof
            for encoding in 'ASCII', 'UTF-16', 'latin-1':
                for errors in 'strict', 'replace':
                    u = unpickler(io.BytesIO(),
                                  encoding=encoding, errors=errors)
                    self.assertEqual(object.__sizeof__(u), basesize)
                    # The "+ 1" per string presumably covers a stored NUL
                    # terminator for the encoding/errors copies.
                    check(u, basesize +
                             32 * P +  # Minimal memo table size.
                             len(encoding) + 1 + len(errors) + 1)

            stdsize = basesize + len('ASCII') + 1 + len('strict') + 1
            # Load *data* and check the unpickler size for the given memo
            # table and mark stack capacities (in entries).
            def check_unpickler(data, memo_size, marks_size):
                dump = pickle.dumps(data)
                u = unpickler(io.BytesIO(dump),
                              encoding='ASCII', errors='strict')
                u.load()
                check(u, stdsize + memo_size * P + marks_size * n)

            check_unpickler(0, 32, 0)
            # 20 is minimal non-empty mark stack size.
            check_unpickler([0] * 100, 32, 20)
            # 128 is memo table size required to save references to 100 objects.
            check_unpickler([chr(i) for i in range(100)], 128, 20)
            # Build a list nested *deep* levels to grow the mark stack.
            def recurse(deep):
                data = 0
                for i in range(deep):
                    data = [data, data]
                return data
            check_unpickler(recurse(0), 32, 0)
            check_unpickler(recurse(1), 32, 20)
            check_unpickler(recurse(20), 32, 20)
            check_unpickler(recurse(50), 64, 60)
            check_unpickler(recurse(100), 128, 140)

            # NOTE(review): the extra "2 + 1" bytes presumably account for a
            # buffer retained after reading the protocol-0 text pickle of 'a'
            # -- confirm against _pickle.c.
            u = unpickler(io.BytesIO(pickle.dumps('a', 0)),
                          encoding='ASCII', errors='strict')
            u.load()
            check(u, stdsize + 32 * P + 2 + 1)
360
361
# (python2_module, python3_module) pairs from IMPORT_MAPPING whose reverse
# mapping is not expected to give back the Python 2 module. The
# reverse-import check skips these.
ALT_IMPORT_MAPPING = {
    ('StringIO', 'io'),
    ('cStringIO', 'io'),
    ('cPickle', 'pickle'),
    ('_elementtree', 'xml.etree.ElementTree'),
}

# (py2_module, py2_name, py3_module, py3_name) entries from NAME_MAPPING
# whose reverse mapping is not expected to give back the Python 2 name.
ALT_NAME_MAPPING = {
    ('UserDict', 'UserDict', 'collections', 'UserDict'),
    ('__builtin__', 'basestring', 'builtins', 'str'),
    ('exceptions', 'StandardError', 'builtins', 'Exception'),
    ('socket', '_socketobject', 'socket', 'SocketType'),
}
375
def mapping(module, name):
    """Translate a Python 2 (module, name) pair to Python 3.

    A full-name entry in NAME_MAPPING wins. Otherwise only the module is
    renamed via IMPORT_MAPPING. Unknown pairs are returned unchanged.
    """
    key = (module, name)
    if key in NAME_MAPPING:
        new_module, new_name = NAME_MAPPING[key]
        return new_module, new_name
    return IMPORT_MAPPING.get(module, module), name
382
def reverse_mapping(module, name):
    """Translate a Python 3 (module, name) pair back to Python 2.

    A full-name entry in REVERSE_NAME_MAPPING wins. Otherwise only the
    module is renamed via REVERSE_IMPORT_MAPPING. Unknown pairs are returned
    unchanged.
    """
    key = (module, name)
    if key in REVERSE_NAME_MAPPING:
        old_module, old_name = REVERSE_NAME_MAPPING[key]
        return old_module, old_name
    return REVERSE_IMPORT_MAPPING.get(module, module), name
389
def getmodule(module):
    """Return the module object named *module*, importing it if needed.

    ``sys.modules`` is consulted first. Otherwise the module is imported
    with DeprecationWarning ignored, or shown on every occurrence when
    ``support.verbose`` is set.

    Raises ImportError if the module cannot be imported. An AttributeError
    raised during import is converted into an ImportError that names the
    module and is chained to the original error, so callers only need to
    catch ImportError.
    """
    try:
        return sys.modules[module]
    except KeyError:
        pass
    # Import outside the KeyError handler so import failures do not carry a
    # misleading KeyError context in their traceback.
    try:
        with warnings.catch_warnings():
            action = 'always' if support.verbose else 'ignore'
            warnings.simplefilter(action, DeprecationWarning)
            __import__(module)
    except AttributeError as exc:
        message = "Can't import module %r: %s" % (module, exc)
        if support.verbose:
            print(message)
        raise ImportError(message, name=module) from exc
    except ImportError as exc:
        if support.verbose:
            print(exc)
        raise
    return sys.modules[module]
408
def getattribute(module, name):
    """Resolve the dotted attribute path *name* inside module *module*.

    Raises ImportError if the module cannot be imported, and AttributeError
    if any component of the path is missing.
    """
    result = getmodule(module)
    remaining = name.split('.')
    while remaining:
        result = getattr(result, remaining.pop(0))
    return result
414
def get_exceptions(mod):
    """Yield (name, class) for each exception class visible in dir(mod)."""
    candidates = ((attr_name, getattr(mod, attr_name)) for attr_name in dir(mod))
    for attr_name, obj in candidates:
        if isinstance(obj, type) and issubclass(obj, BaseException):
            yield attr_name, obj
420
class CompatPickleTests(unittest.TestCase):
    """Consistency checks for the Python 2 <-> Python 3 name tables in
    _compat_pickle (IMPORT_MAPPING, NAME_MAPPING and their reverses)."""

    def test_import(self):
        # Try importing every Python 3 module named anywhere in the tables.
        # ImportError is ignored (presumably for optional modules missing on
        # this build), so only other errors fail the test.
        modules = set(IMPORT_MAPPING.values())
        modules |= set(REVERSE_IMPORT_MAPPING)
        modules |= {module for module, name in REVERSE_NAME_MAPPING}
        modules |= {module for module, name in NAME_MAPPING.values()}
        for module in modules:
            try:
                getmodule(module)
            except ImportError:
                pass

    def test_import_mapping(self):
        # Every public (non-underscore) Python 3 module in the reverse table
        # must map back to the same Python 2 module via IMPORT_MAPPING.
        for module3, module2 in REVERSE_IMPORT_MAPPING.items():
            with self.subTest((module3, module2)):
                try:
                    getmodule(module3)
                except ImportError:
                    pass
                if module3[:1] != '_':
                    self.assertIn(module2, IMPORT_MAPPING)
                    self.assertEqual(IMPORT_MAPPING[module2], module3)

    def test_name_mapping(self):
        # Reverse name entries must round-trip through mapping(). Python 2
        # 'exceptions.OSError' / 'exceptions.ImportError' are many-to-one
        # targets, so for those only subclassing is checked.
        for (module3, name3), (module2, name2) in REVERSE_NAME_MAPPING.items():
            with self.subTest(((module3, name3), (module2, name2))):
                if (module2, name2) == ('exceptions', 'OSError'):
                    attr = getattribute(module3, name3)
                    self.assertTrue(issubclass(attr, OSError))
                elif (module2, name2) == ('exceptions', 'ImportError'):
                    attr = getattribute(module3, name3)
                    self.assertTrue(issubclass(attr, ImportError))
                else:
                    module, name = mapping(module2, name2)
                    if module3[:1] != '_':
                        self.assertEqual((module, name), (module3, name3))
                    try:
                        attr = getattribute(module3, name3)
                    except ImportError:
                        pass
                    else:
                        self.assertEqual(getattribute(module, name), attr)

    def test_reverse_import_mapping(self):
        # Each forward module mapping must be listed in ALT_IMPORT_MAPPING,
        # be reversible through REVERSE_IMPORT_MAPPING, or be covered by at
        # least one REVERSE_NAME_MAPPING entry. Applying the reverse then the
        # forward module mapping must return module3 unchanged.
        for module2, module3 in IMPORT_MAPPING.items():
            with self.subTest((module2, module3)):
                try:
                    getmodule(module3)
                except ImportError as exc:
                    if support.verbose:
                        print(exc)
                if ((module2, module3) not in ALT_IMPORT_MAPPING and
                    REVERSE_IMPORT_MAPPING.get(module3, None) != module2):
                    for (m3, n3), (m2, n2) in REVERSE_NAME_MAPPING.items():
                        if (module3, module2) == (m3, m2):
                            break
                    else:
                        self.fail('No reverse mapping from %r to %r' %
                                  (module3, module2))
                module = REVERSE_IMPORT_MAPPING.get(module3, module3)
                module = IMPORT_MAPPING.get(module, module)
                self.assertEqual(module, module3)

    def test_reverse_name_mapping(self):
        # Each forward name mapping must reverse to its Python 2 origin
        # (unless listed in ALT_NAME_MAPPING), and mapping the reversed pair
        # forward again must give the same Python 3 pair.
        for (module2, name2), (module3, name3) in NAME_MAPPING.items():
            with self.subTest(((module2, name2), (module3, name3))):
                try:
                    # Only checks that the target is importable; attr is
                    # not used afterwards.
                    attr = getattribute(module3, name3)
                except ImportError:
                    pass
                module, name = reverse_mapping(module3, name3)
                if (module2, name2, module3, name3) not in ALT_NAME_MAPPING:
                    self.assertEqual((module, name), (module2, name2))
                module, name = mapping(module, name)
                self.assertEqual((module, name), (module3, name3))

    def test_exceptions(self):
        # Spot checks, then a sweep over every builtin exception class.
        self.assertEqual(mapping('exceptions', 'StandardError'),
                         ('builtins', 'Exception'))
        self.assertEqual(mapping('exceptions', 'Exception'),
                         ('builtins', 'Exception'))
        self.assertEqual(reverse_mapping('builtins', 'Exception'),
                         ('exceptions', 'Exception'))
        self.assertEqual(mapping('exceptions', 'OSError'),
                         ('builtins', 'OSError'))
        self.assertEqual(reverse_mapping('builtins', 'OSError'),
                         ('exceptions', 'OSError'))

        for name, exc in get_exceptions(builtins):
            with self.subTest(name):
                # Skipped: presumably these have no one-to-one Python 2
                # 'exceptions' counterpart. NOTE(review): builtin exceptions
                # added in newer Python versions may need to join this list.
                if exc in (BlockingIOError,
                           ResourceWarning,
                           StopAsyncIteration,
                           RecursionError,
                           EncodingWarning,
                           BaseExceptionGroup,
                           ExceptionGroup):
                    continue
                # OSError subclasses collapse to Python 2 OSError.
                if exc is not OSError and issubclass(exc, OSError):
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', 'OSError'))
                # ImportError subclasses collapse to Python 2 ImportError.
                elif exc is not ImportError and issubclass(exc, ImportError):
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', 'ImportError'))
                    self.assertEqual(mapping('exceptions', name),
                                     ('exceptions', name))
                # Everything else maps 1:1 between builtins and exceptions.
                else:
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', name))
                    self.assertEqual(mapping('exceptions', name),
                                     ('builtins', name))

    def test_multiprocessing_exceptions(self):
        # Exceptions defined in multiprocessing.context must reverse-map to
        # the public 'multiprocessing' package and map forward again.
        module = import_helper.import_module('multiprocessing.context')
        for name, exc in get_exceptions(module):
            with self.subTest(name):
                self.assertEqual(reverse_mapping('multiprocessing.context', name),
                                 ('multiprocessing', name))
                self.assertEqual(mapping('multiprocessing', name),
                                 ('multiprocessing.context', name))
541
542
def load_tests(loader, tests, pattern):
    """unittest load_tests hook: add this module's doctests to *tests*."""
    doctest_suite = doctest.DocTestSuite()
    tests.addTest(doctest_suite)
    return tests
546
547
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
550