# Tests for the pickle module.
#
# Covers the pure-Python implementation (pickle._Pickler, pickle._Unpickler
# and the pickle._dump/_dumps/_load/_loads functions), the _pickle C
# accelerator when it can be imported, mixed C/Python pickler/unpickler
# pairings, and the Python 2 <-> Python 3 name tables in _compat_pickle.
# Most test methods live in the Abstract*Tests bases from test.pickletester;
# the classes here mainly bind those bases to a concrete implementation
# through class attributes.
#
# NOTE: load_tests() below runs this module's docstrings as doctests, so do
# not put interactive examples in docstrings here unless they should run.
from _compat_pickle import (IMPORT_MAPPING, REVERSE_IMPORT_MAPPING,
                            NAME_MAPPING, REVERSE_NAME_MAPPING)
import builtins
import pickle
import io
import collections
import struct
import sys
import warnings
import weakref

import doctest
import unittest
from test import support
from test.support import import_helper

from test.pickletester import AbstractHookTests
from test.pickletester import AbstractUnpickleTests
from test.pickletester import AbstractPickleTests
from test.pickletester import AbstractPickleModuleTests
from test.pickletester import AbstractPersistentPicklerTests
from test.pickletester import AbstractIdentityPersistentPicklerTests
from test.pickletester import AbstractPicklerUnpicklerObjectTests
from test.pickletester import AbstractDispatchTableTests
from test.pickletester import AbstractCustomPicklerClass
from test.pickletester import BigmemPickleTests

# The C accelerator is optional; has_c_implementation gates every
# _pickle-backed test class defined further down.
try:
    import _pickle
    has_c_implementation = True
except ImportError:
    has_c_implementation = False


class PyPickleTests(AbstractPickleModuleTests, unittest.TestCase):
    """Module-level API tests run against the pure-Python functions."""
    # staticmethod() keeps these plain Python functions from being bound as
    # methods if they are looked up through the test instance.
    dump = staticmethod(pickle._dump)
    dumps = staticmethod(pickle._dumps)
    load = staticmethod(pickle._load)
    loads = staticmethod(pickle._loads)
    Pickler = pickle._Pickler
    Unpickler = pickle._Unpickler


class PyUnpicklerTests(AbstractUnpickleTests, unittest.TestCase):
    """Unpickling tests for the pure-Python pickle._Unpickler."""

    unpickler = pickle._Unpickler
    # Exception types the shared unpickling tests accept for malformed input.
    # The names suggest stack-underflow and truncated-data cases
    # respectively; see test.pickletester to confirm how each is used.
    bad_stack_errors = (IndexError,)
    truncated_errors = (pickle.UnpicklingError, EOFError,
                        AttributeError, ValueError,
                        struct.error, IndexError, ImportError)

    def loads(self, buf, **kwds):
        """Unpickle *buf* with self.unpickler; *kwds* go to its constructor."""
        f = io.BytesIO(buf)
        u = self.unpickler(f, **kwds)
        return u.load()


class PyPicklerTests(AbstractPickleTests, unittest.TestCase):
    """Round-trip tests using self.pickler / self.unpickler classes."""

    pickler = pickle._Pickler
    unpickler = pickle._Unpickler

    def dumps(self, arg, proto=None, **kwargs):
        """Pickle *arg* into an in-memory stream and return the bytes.

        *proto* and *kwargs* are forwarded to the self.pickler constructor.
        """
        f = io.BytesIO()
        p = self.pickler(f, proto, **kwargs)
        p.dump(arg)
        # Rewind and read the whole stream back as bytes.
        f.seek(0)
        return bytes(f.read())

    def loads(self, buf, **kwds):
        """Unpickle *buf* with self.unpickler; *kwds* go to its constructor."""
        f = io.BytesIO(buf)
        u = self.unpickler(f, **kwds)
        return u.load()


class InMemoryPickleTests(AbstractPickleTests, AbstractUnpickleTests,
                          BigmemPickleTests, unittest.TestCase):
    """Pickling and unpickling through the top-level pickle.dumps/loads."""

    # Unlike PyUnpicklerTests, UnpicklingError is also accepted for
    # bad-stack input here.
    bad_stack_errors = (pickle.UnpicklingError, IndexError)
    truncated_errors = (pickle.UnpicklingError, EOFError,
                        AttributeError, ValueError,
                        struct.error, IndexError, ImportError)

    def dumps(self, arg, protocol=None, **kwargs):
        """Return pickle.dumps(arg, protocol, **kwargs)."""
        return pickle.dumps(arg, protocol, **kwargs)

    def loads(self, buf, **kwds):
        """Return pickle.loads(buf, **kwds)."""
        return pickle.loads(buf, **kwds)

    # Shadow the inherited test with a non-callable so unittest does not
    # collect it for this class. Presumably it needs a caller-supplied file
    # object, which these dumps()/loads() wrappers never use.
    test_framed_write_sizes_with_delayed_writer = None


class PersistentPicklerUnpicklerMixin(object):
    """dumps()/loads() that delegate persistent IDs to the test case.

    Subclasses set ``pickler`` and ``unpickler``. The test case itself must
    provide persistent_id() and persistent_load() (presumably supplied by
    the Abstract*PersistentPicklerTests bases -- TODO confirm).
    """

    def dumps(self, arg, proto=None):
        """Pickle *arg* with *proto*, routing persistent_id to the test."""
        # Inside the nested class, ``subself`` is the pickler and ``self``
        # is the enclosing test case, captured by the closure.
        class PersPickler(self.pickler):
            def persistent_id(subself, obj):
                return self.persistent_id(obj)
        f = io.BytesIO()
        p = PersPickler(f, proto)
        p.dump(arg)
        return f.getvalue()

    def loads(self, buf, **kwds):
        """Unpickle *buf*, routing persistent_load to the test case."""
        # ``subself`` is the unpickler; ``self`` is the test case.
        class PersUnpickler(self.unpickler):
            def persistent_load(subself, obj):
                return self.persistent_load(obj)
        f = io.BytesIO(buf)
        u = PersUnpickler(f, **kwds)
        return u.load()


class PyPersPicklerTests(AbstractPersistentPicklerTests,
                         PersistentPicklerUnpicklerMixin, unittest.TestCase):
    """Persistent-ID tests for the pure-Python pickler and unpickler."""

    pickler = pickle._Pickler
    unpickler = pickle._Unpickler


class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests,
                           PersistentPicklerUnpicklerMixin, unittest.TestCase):
    """Identity persistent-ID tests plus reference-cycle/leak checks."""

    pickler = pickle._Pickler
    unpickler = pickle._Unpickler

    @support.cpython_only
    def test_pickler_reference_cycle(self):
        # A pickler whose persistent_id is a plain method, a classmethod or
        # a staticmethod must pickle correctly at every protocol and return
        # its argument from persistent_id(). It must also be freed as soon as
        # its last reference is dropped: r() is checked right after ``del``
        # with no gc_collect(), so any reference cycle fails the test.
        def check(Pickler):
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f = io.BytesIO()
                pickler = Pickler(f, proto)
                pickler.dump('abc')
                self.assertEqual(self.loads(f.getvalue()), 'abc')
            pickler = Pickler(io.BytesIO())
            self.assertEqual(pickler.persistent_id('def'), 'def')
            r = weakref.ref(pickler)
            del pickler
            self.assertIsNone(r())

        class PersPickler(self.pickler):
            def persistent_id(subself, obj):
                return obj
        check(PersPickler)

        class PersPickler(self.pickler):
            @classmethod
            def persistent_id(cls, obj):
                return obj
        check(PersPickler)

        class PersPickler(self.pickler):
            @staticmethod
            def persistent_id(obj):
                return obj
        check(PersPickler)

    @support.cpython_only
    def test_custom_pickler_dispatch_table_memleak(self):
        # See https://github.com/python/cpython/issues/89988
        # A dispatch_table attribute set on the instance before the base
        # __init__ runs must not be kept alive once both the pickler and the
        # table are deleted and garbage has been collected.

        class Pickler(self.pickler):
            def __init__(self, *args, **kwargs):
                self.dispatch_table = table
                super().__init__(*args, **kwargs)

        class DispatchTable:
            pass

        table = DispatchTable()
        pickler = Pickler(io.BytesIO())
        self.assertIs(pickler.dispatch_table, table)
        table_ref = weakref.ref(table)
        self.assertIsNotNone(table_ref())
        del pickler
        del table
        support.gc_collect()
        self.assertIsNone(table_ref())


    @support.cpython_only
    def test_unpickler_reference_cycle(self):
        # Unpickler counterpart of test_pickler_reference_cycle: each
        # persistent_load flavour must unpickle correctly at every protocol,
        # return its argument, and be freed immediately on ``del`` (no GC
        # pass).
        def check(Unpickler):
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                unpickler = Unpickler(io.BytesIO(self.dumps('abc', proto)))
                self.assertEqual(unpickler.load(), 'abc')
            unpickler = Unpickler(io.BytesIO())
            self.assertEqual(unpickler.persistent_load('def'), 'def')
            r = weakref.ref(unpickler)
            del unpickler
            self.assertIsNone(r())

        class PersUnpickler(self.unpickler):
            def persistent_load(subself, pid):
                return pid
        check(PersUnpickler)

        class PersUnpickler(self.unpickler):
            @classmethod
            def persistent_load(cls, pid):
                return pid
        check(PersUnpickler)

        class PersUnpickler(self.unpickler):
            @staticmethod
            def persistent_load(pid):
                return pid
        check(PersUnpickler)


class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests, unittest.TestCase):
    """Pickler/Unpickler object-behaviour tests for the Python classes."""

    pickler_class = pickle._Pickler
    unpickler_class = pickle._Unpickler


class PyDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
    """dispatch_table tests for the Python pickler, using a plain dict copy."""

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        """Return a fresh copy of the global pickle.dispatch_table."""
        return pickle.dispatch_table.copy()


class PyChainDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
    """dispatch_table tests for the Python pickler, using a ChainMap."""

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        """Return an empty ChainMap layer over pickle.dispatch_table."""
        return collections.ChainMap({}, pickle.dispatch_table)


class PyPicklerHookTests(AbstractHookTests, unittest.TestCase):
    """Hook tests for a custom subclass of the Python pickler."""
    # Combines the Python pickler with the shared custom-pickler mixin.
    class CustomPyPicklerClass(pickle._Pickler,
                               AbstractCustomPicklerClass):
        pass
    pickler_class = CustomPyPicklerClass


# C-accelerator variants: mostly the Python test classes above, re-bound to
# _pickle classes, plus C-only crash and __sizeof__ checks.
if has_c_implementation:
    class CPickleTests(AbstractPickleModuleTests, unittest.TestCase):
        """Module-level API tests run against the _pickle functions."""
        # A class-body import binds these names as class attributes. They
        # are built-in functions, which (unlike the Python functions in
        # PyPickleTests) do not bind as methods, so no staticmethod() is
        # needed.
        from _pickle import dump, dumps, load, loads, Pickler, Unpickler

    class CUnpicklerTests(PyUnpicklerTests):
        # The C unpickler is expected to report both kinds of malformed
        # input as UnpicklingError only.
        unpickler = _pickle.Unpickler
        bad_stack_errors = (pickle.UnpicklingError,)
        truncated_errors = (pickle.UnpicklingError,)

    class CPicklerTests(PyPicklerTests):
        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    class CPersPicklerTests(PyPersPicklerTests):
        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    class CIdPersPicklerTests(PyIdPersPicklerTests):
        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    # Cross-implementation compatibility: C pickler output read by the
    # Python unpickler ...
    class CDumpPickle_LoadPickle(PyPicklerTests):
        pickler = _pickle.Pickler
        unpickler = pickle._Unpickler

    # ... and Python pickler output read by the C unpickler.
    class DumpPickle_CLoadPickle(PyPicklerTests):
        pickler = pickle._Pickler
        unpickler = _pickle.Unpickler

    class CPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests, unittest.TestCase):
        pickler_class = _pickle.Pickler
        unpickler_class = _pickle.Unpickler

        def test_issue18339(self):
            # Assigning an invalid memo must raise instead of crashing:
            # a non-dict gives TypeError, and a dict with a negative key
            # gives ValueError. A non-negative int key is accepted.
            unpickler = self.unpickler_class(io.BytesIO())
            with self.assertRaises(TypeError):
                unpickler.memo = object
            # used to cause a segfault
            with self.assertRaises(ValueError):
                unpickler.memo = {-1: None}
            unpickler.memo = {1: None}

    # NOTE(review): the two dispatch-table classes below use the public
    # pickle.Pickler rather than _pickle.Pickler like the other C classes;
    # presumably pickle.Pickler is the C class whenever _pickle imports --
    # confirm.
    class CDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
        pickler_class = pickle.Pickler
        def get_dispatch_table(self):
            """Return a fresh copy of the global pickle.dispatch_table."""
            return pickle.dispatch_table.copy()

    class CChainDispatchTableTests(AbstractDispatchTableTests, unittest.TestCase):
        pickler_class = pickle.Pickler
        def get_dispatch_table(self):
            """Return an empty ChainMap layer over pickle.dispatch_table."""
            return collections.ChainMap({}, pickle.dispatch_table)

    class CPicklerHookTests(AbstractHookTests, unittest.TestCase):
        # Combines the C pickler with the shared custom-pickler mixin.
        class CustomCPicklerClass(_pickle.Pickler, AbstractCustomPicklerClass):
            pass
        pickler_class = CustomCPicklerClass

    @support.cpython_only
    class SizeofTests(unittest.TestCase):
        """__sizeof__ checks for the C Pickler and Unpickler objects.

        The struct format strings presumably mirror the C object layouts
        and must be kept in sync with them -- TODO confirm against the C
        source.
        """
        check_sizeof = support.check_sizeof

        def test_pickler(self):
            # basesize is the fixed object size; __sizeof__ additionally
            # counts the memo table (header + entries) and the write buffer.
            basesize = support.calcobjsize('7P2n3i2n3i2P')
            p = _pickle.Pickler(io.BytesIO())
            self.assertEqual(object.__sizeof__(p), basesize)
            MT_size = struct.calcsize('3nP0n')
            ME_size = struct.calcsize('Pn0P')
            check = self.check_sizeof
            check(p, basesize +
                MT_size + 8 * ME_size +  # Minimal memo table size.
                sys.getsizeof(b'x'*4096))  # Minimal write buffer size.
            for i in range(6):
                p.dump(chr(i))
            check(p, basesize +
                MT_size + 32 * ME_size +  # Size of memo table required to
                                          # save references to 6 objects.
                0)  # Write buffer is cleared after every dump().

        def test_unpickler(self):
            # __sizeof__ = fixed object size + memo table (P per entry) +
            # mark stack (n per entry) + the stored encoding/errors strings
            # (len + 1 each, presumably for a terminating NUL).
            basesize = support.calcobjsize('2P2n2P 2P2n2i5P 2P3n8P2n2i')
            unpickler = _pickle.Unpickler
            P = struct.calcsize('P')  # Size of memo table entry.
            n = struct.calcsize('n')  # Size of mark table entry.
            check = self.check_sizeof
            for encoding in 'ASCII', 'UTF-16', 'latin-1':
                for errors in 'strict', 'replace':
                    u = unpickler(io.BytesIO(),
                                  encoding=encoding, errors=errors)
                    self.assertEqual(object.__sizeof__(u), basesize)
                    check(u, basesize +
                             32 * P +  # Minimal memo table size.
                             len(encoding) + 1 + len(errors) + 1)

            stdsize = basesize + len('ASCII') + 1 + len('strict') + 1
            def check_unpickler(data, memo_size, marks_size):
                # Load a pickle of *data*, then expect the given memo-table
                # and mark-stack entry counts on top of stdsize.
                dump = pickle.dumps(data)
                u = unpickler(io.BytesIO(dump),
                              encoding='ASCII', errors='strict')
                u.load()
                check(u, stdsize + memo_size * P + marks_size * n)

            check_unpickler(0, 32, 0)
            # 20 is minimal non-empty mark stack size.
            check_unpickler([0] * 100, 32, 20)
            # 128 is memo table size required to save references to 100 objects.
            check_unpickler([chr(i) for i in range(100)], 128, 20)
            def recurse(deep):
                # Build *deep* levels of nested two-element lists.
                data = 0
                for i in range(deep):
                    data = [data, data]
                return data
            check_unpickler(recurse(0), 32, 0)
            check_unpickler(recurse(1), 32, 20)
            check_unpickler(recurse(20), 32, 20)
            check_unpickler(recurse(50), 64, 60)
            check_unpickler(recurse(100), 128, 140)

            # After a protocol-0 load, 2 + 1 extra bytes are expected; what
            # they account for is not visible here -- TODO confirm.
            u = unpickler(io.BytesIO(pickle.dumps('a', 0)),
                          encoding='ASCII', errors='strict')
            u.load()
            check(u, stdsize + 32 * P + 2 + 1)


# (module2, module3) entries of IMPORT_MAPPING that need no exact reverse
# entry; test_reverse_import_mapping skips the reverse-lookup requirement
# for these pairs.
ALT_IMPORT_MAPPING = {
    ('_elementtree', 'xml.etree.ElementTree'),
    ('cPickle', 'pickle'),
    ('StringIO', 'io'),
    ('cStringIO', 'io'),
}

# (module2, name2, module3, name3) entries of NAME_MAPPING whose reverse
# mapping may differ from the original pair; test_reverse_name_mapping
# skips the round-trip equality check for these.
ALT_NAME_MAPPING = {
    ('__builtin__', 'basestring', 'builtins', 'str'),
    ('exceptions', 'StandardError', 'builtins', 'Exception'),
    ('UserDict', 'UserDict', 'collections', 'UserDict'),
    ('socket', '_socketobject', 'socket', 'SocketType'),
}

def mapping(module, name):
    """Map a Python 2 (module, name) pair to its Python 3 location.

    An exact NAME_MAPPING entry wins; otherwise only the module is renamed
    through IMPORT_MAPPING. Unmapped pairs are returned unchanged.
    """
    if (module, name) in NAME_MAPPING:
        module, name = NAME_MAPPING[(module, name)]
    elif module in IMPORT_MAPPING:
        module = IMPORT_MAPPING[module]
    return module, name

def reverse_mapping(module, name):
    """Map a Python 3 (module, name) pair back to its Python 2 location.

    An exact REVERSE_NAME_MAPPING entry wins; otherwise only the module is
    renamed through REVERSE_IMPORT_MAPPING. Unmapped pairs are returned
    unchanged.
    """
    if (module, name) in REVERSE_NAME_MAPPING:
        module, name = REVERSE_NAME_MAPPING[(module, name)]
    elif module in REVERSE_IMPORT_MAPPING:
        module = REVERSE_IMPORT_MAPPING[module]
    return module, name

def getmodule(module):
    """Return the module named *module*, importing it if necessary.

    DeprecationWarnings raised during the import are always shown in verbose
    mode and ignored otherwise. Raises ImportError if the import fails; an
    AttributeError raised during the import is also reported as ImportError.
    In verbose mode the failure is printed first.
    """
    try:
        return sys.modules[module]
    except KeyError:
        try:
            with warnings.catch_warnings():
                action = 'always' if support.verbose else 'ignore'
                warnings.simplefilter(action, DeprecationWarning)
                __import__(module)
        except AttributeError as exc:
            if support.verbose:
                print("Can't import module %r: %s" % (module, exc))
            # Callers only catch ImportError, so normalize to it.
            raise ImportError
        except ImportError as exc:
            if support.verbose:
                print(exc)
            raise
        return sys.modules[module]

def getattribute(module, name):
    """Return the object at dotted path *name* inside module *module*.

    Raises ImportError if the module cannot be imported and AttributeError
    if any component of *name* is missing.
    """
    obj = getmodule(module)
    for n in name.split('.'):
        obj = getattr(obj, n)
    return obj

def get_exceptions(mod):
    """Yield (name, class) for every exception class listed in dir(*mod*)."""
    for name in dir(mod):
        attr = getattr(mod, name)
        if isinstance(attr, type) and issubclass(attr, BaseException):
            yield name, attr

class CompatPickleTests(unittest.TestCase):
    """Consistency checks for the _compat_pickle 2 <-> 3 mapping tables."""

    def test_import(self):
        # Every module named anywhere in the tables must import or fail
        # with ImportError; any other exception propagates and fails.
        modules = set(IMPORT_MAPPING.values())
        modules |= set(REVERSE_IMPORT_MAPPING)
        modules |= {module for module, name in REVERSE_NAME_MAPPING}
        modules |= {module for module, name in NAME_MAPPING.values()}
        for module in modules:
            try:
                getmodule(module)
            except ImportError:
                pass

    def test_import_mapping(self):
        # For each reverse module mapping whose Python 3 module is public
        # (no leading underscore), the forward mapping must round-trip.
        for module3, module2 in REVERSE_IMPORT_MAPPING.items():
            with self.subTest((module3, module2)):
                try:
                    getmodule(module3)
                except ImportError:
                    pass
                if module3[:1] != '_':
                    self.assertIn(module2, IMPORT_MAPPING)
                    self.assertEqual(IMPORT_MAPPING[module2], module3)

    def test_name_mapping(self):
        # Names reverse-mapped to exceptions.OSError / exceptions.ImportError
        # must be subclasses of those. Every other entry must round-trip
        # through mapping() (for public modules) and, when importable,
        # resolve to the same object under both names.
        for (module3, name3), (module2, name2) in REVERSE_NAME_MAPPING.items():
            with self.subTest(((module3, name3), (module2, name2))):
                if (module2, name2) == ('exceptions', 'OSError'):
                    attr = getattribute(module3, name3)
                    self.assertTrue(issubclass(attr, OSError))
                elif (module2, name2) == ('exceptions', 'ImportError'):
                    attr = getattribute(module3, name3)
                    self.assertTrue(issubclass(attr, ImportError))
                else:
                    module, name = mapping(module2, name2)
                    if module3[:1] != '_':
                        self.assertEqual((module, name), (module3, name3))
                    try:
                        attr = getattribute(module3, name3)
                    except ImportError:
                        pass
                    else:
                        self.assertEqual(getattribute(module, name), attr)

    def test_reverse_import_mapping(self):
        # Every IMPORT_MAPPING entry must be reversible. It passes if it is
        # listed in ALT_IMPORT_MAPPING, if REVERSE_IMPORT_MAPPING maps it
        # back, or if some REVERSE_NAME_MAPPING entry covers the same module
        # pair. Reverse-then-forward mapping must return module3 unchanged.
        # Import failures are only printed (in verbose mode), not fatal.
        for module2, module3 in IMPORT_MAPPING.items():
            with self.subTest((module2, module3)):
                try:
                    getmodule(module3)
                except ImportError as exc:
                    if support.verbose:
                        print(exc)
                if ((module2, module3) not in ALT_IMPORT_MAPPING and
                    REVERSE_IMPORT_MAPPING.get(module3, None) != module2):
                    for (m3, n3), (m2, n2) in REVERSE_NAME_MAPPING.items():
                        if (module3, module2) == (m3, m2):
                            break
                    else:
                        self.fail('No reverse mapping from %r to %r' %
                                  (module3, module2))
                module = REVERSE_IMPORT_MAPPING.get(module3, module3)
                module = IMPORT_MAPPING.get(module, module)
                self.assertEqual(module, module3)

    def test_reverse_name_mapping(self):
        # Each NAME_MAPPING entry must reverse to its Python 2 pair (unless
        # listed in ALT_NAME_MAPPING) and map forward again to itself.
        for (module2, name2), (module3, name3) in NAME_MAPPING.items():
            with self.subTest(((module2, name2), (module3, name3))):
                # NOTE(review): ``attr`` is never used. The lookup only makes
                # the test fail (AttributeError propagates) when an
                # importable module lacks name3.
                try:
                    attr = getattribute(module3, name3)
                except ImportError:
                    pass
                module, name = reverse_mapping(module3, name3)
                if (module2, name2, module3, name3) not in ALT_NAME_MAPPING:
                    self.assertEqual((module, name), (module2, name2))
                module, name = mapping(module, name)
                self.assertEqual((module, name), (module3, name3))

    def test_exceptions(self):
        # Spot checks, then a sweep over every builtin exception class.
        self.assertEqual(mapping('exceptions', 'StandardError'),
                         ('builtins', 'Exception'))
        self.assertEqual(mapping('exceptions', 'Exception'),
                         ('builtins', 'Exception'))
        self.assertEqual(reverse_mapping('builtins', 'Exception'),
                         ('exceptions', 'Exception'))
        self.assertEqual(mapping('exceptions', 'OSError'),
                         ('builtins', 'OSError'))
        self.assertEqual(reverse_mapping('builtins', 'OSError'),
                         ('exceptions', 'OSError'))

        for name, exc in get_exceptions(builtins):
            with self.subTest(name):
                # Excluded classes -- presumably ones with no Python 2
                # counterpart in the exceptions module.
                if exc in (BlockingIOError,
                           ResourceWarning,
                           StopAsyncIteration,
                           RecursionError,
                           EncodingWarning,
                           BaseExceptionGroup,
                           ExceptionGroup):
                    continue
                # OSError subclasses reverse to exceptions.OSError. ImportError
                # subclasses reverse to exceptions.ImportError and have no
                # forward entry. All others round-trip builtins <-> exceptions.
                if exc is not OSError and issubclass(exc, OSError):
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', 'OSError'))
                elif exc is not ImportError and issubclass(exc, ImportError):
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', 'ImportError'))
                    self.assertEqual(mapping('exceptions', name),
                                     ('exceptions', name))
                else:
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', name))
                    self.assertEqual(mapping('exceptions', name),
                                     ('builtins', name))

    def test_multiprocessing_exceptions(self):
        # Exception classes visible in multiprocessing.context map to the
        # top-level multiprocessing package and back. import_module()
        # presumably skips the test if the module is unavailable.
        module = import_helper.import_module('multiprocessing.context')
        for name, exc in get_exceptions(module):
            with self.subTest(name):
                self.assertEqual(reverse_mapping('multiprocessing.context', name),
                                 ('multiprocessing', name))
                self.assertEqual(mapping('multiprocessing', name),
                                 ('multiprocessing.context', name))


# unittest's load_tests protocol hook: add this module's doctests (the
# argument-less DocTestSuite() collects from the calling module) to the
# suite that unittest discovered.
def load_tests(loader, tests, pattern):
    tests.addTest(doctest.DocTestSuite())
    return tests


if __name__ == "__main__":
    unittest.main()