1import abc 2import builtins 3import collections 4import collections.abc 5import copy 6from itertools import permutations 7import pickle 8from random import choice 9import sys 10from test import support 11import threading 12import time 13import typing 14import unittest 15import unittest.mock 16import os 17import weakref 18import gc 19from weakref import proxy 20import contextlib 21 22from test.support import import_helper 23from test.support import threading_helper 24from test.support.script_helper import assert_python_ok 25 26import functools 27 28py_functools = import_helper.import_fresh_module('functools', 29 blocked=['_functools']) 30c_functools = import_helper.import_fresh_module('functools') 31 32decimal = import_helper.import_fresh_module('decimal', fresh=['_decimal']) 33 34@contextlib.contextmanager 35def replaced_module(name, replacement): 36 original_module = sys.modules[name] 37 sys.modules[name] = replacement 38 try: 39 yield 40 finally: 41 sys.modules[name] = original_module 42 43def capture(*args, **kw): 44 """capture all positional and keyword arguments""" 45 return args, kw 46 47 48def signature(part): 49 """ return the signature of a partial object """ 50 return (part.func, part.args, part.keywords, part.__dict__) 51 52class MyTuple(tuple): 53 pass 54 55class BadTuple(tuple): 56 def __add__(self, other): 57 return list(self) + list(other) 58 59class MyDict(dict): 60 pass 61 62 63class TestPartial: 64 65 def test_basic_examples(self): 66 p = self.partial(capture, 1, 2, a=10, b=20) 67 self.assertTrue(callable(p)) 68 self.assertEqual(p(3, 4, b=30, c=40), 69 ((1, 2, 3, 4), dict(a=10, b=30, c=40))) 70 p = self.partial(map, lambda x: x*10) 71 self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40]) 72 73 def test_attributes(self): 74 p = self.partial(capture, 1, 2, a=10, b=20) 75 # attributes should be readable 76 self.assertEqual(p.func, capture) 77 self.assertEqual(p.args, (1, 2)) 78 self.assertEqual(p.keywords, dict(a=10, b=20)) 79 80 def 
test_argument_checking(self): 81 self.assertRaises(TypeError, self.partial) # need at least a func arg 82 try: 83 self.partial(2)() 84 except TypeError: 85 pass 86 else: 87 self.fail('First arg not checked for callability') 88 89 def test_protection_of_callers_dict_argument(self): 90 # a caller's dictionary should not be altered by partial 91 def func(a=10, b=20): 92 return a 93 d = {'a':3} 94 p = self.partial(func, a=5) 95 self.assertEqual(p(**d), 3) 96 self.assertEqual(d, {'a':3}) 97 p(b=7) 98 self.assertEqual(d, {'a':3}) 99 100 def test_kwargs_copy(self): 101 # Issue #29532: Altering a kwarg dictionary passed to a constructor 102 # should not affect a partial object after creation 103 d = {'a': 3} 104 p = self.partial(capture, **d) 105 self.assertEqual(p(), ((), {'a': 3})) 106 d['a'] = 5 107 self.assertEqual(p(), ((), {'a': 3})) 108 109 def test_arg_combinations(self): 110 # exercise special code paths for zero args in either partial 111 # object or the caller 112 p = self.partial(capture) 113 self.assertEqual(p(), ((), {})) 114 self.assertEqual(p(1,2), ((1,2), {})) 115 p = self.partial(capture, 1, 2) 116 self.assertEqual(p(), ((1,2), {})) 117 self.assertEqual(p(3,4), ((1,2,3,4), {})) 118 119 def test_kw_combinations(self): 120 # exercise special code paths for no keyword args in 121 # either the partial object or the caller 122 p = self.partial(capture) 123 self.assertEqual(p.keywords, {}) 124 self.assertEqual(p(), ((), {})) 125 self.assertEqual(p(a=1), ((), {'a':1})) 126 p = self.partial(capture, a=1) 127 self.assertEqual(p.keywords, {'a':1}) 128 self.assertEqual(p(), ((), {'a':1})) 129 self.assertEqual(p(b=2), ((), {'a':1, 'b':2})) 130 # keyword args in the call override those in the partial object 131 self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2})) 132 133 def test_positional(self): 134 # make sure positional arguments are captured correctly 135 for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]: 136 p = self.partial(capture, *args) 137 expected = args + 
('x',) 138 got, empty = p('x') 139 self.assertTrue(expected == got and empty == {}) 140 141 def test_keyword(self): 142 # make sure keyword arguments are captured correctly 143 for a in ['a', 0, None, 3.5]: 144 p = self.partial(capture, a=a) 145 expected = {'a':a,'x':None} 146 empty, got = p(x=None) 147 self.assertTrue(expected == got and empty == ()) 148 149 def test_no_side_effects(self): 150 # make sure there are no side effects that affect subsequent calls 151 p = self.partial(capture, 0, a=1) 152 args1, kw1 = p(1, b=2) 153 self.assertTrue(args1 == (0,1) and kw1 == {'a':1,'b':2}) 154 args2, kw2 = p() 155 self.assertTrue(args2 == (0,) and kw2 == {'a':1}) 156 157 def test_error_propagation(self): 158 def f(x, y): 159 x / y 160 self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0)) 161 self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0) 162 self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0) 163 self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1) 164 165 def test_weakref(self): 166 f = self.partial(int, base=16) 167 p = proxy(f) 168 self.assertEqual(f.func, p.func) 169 f = None 170 support.gc_collect() # For PyPy or other GCs. 
171 self.assertRaises(ReferenceError, getattr, p, 'func') 172 173 def test_with_bound_and_unbound_methods(self): 174 data = list(map(str, range(10))) 175 join = self.partial(str.join, '') 176 self.assertEqual(join(data), '0123456789') 177 join = self.partial(''.join) 178 self.assertEqual(join(data), '0123456789') 179 180 def test_nested_optimization(self): 181 partial = self.partial 182 inner = partial(signature, 'asdf') 183 nested = partial(inner, bar=True) 184 flat = partial(signature, 'asdf', bar=True) 185 self.assertEqual(signature(nested), signature(flat)) 186 187 def test_nested_partial_with_attribute(self): 188 # see issue 25137 189 partial = self.partial 190 191 def foo(bar): 192 return bar 193 194 p = partial(foo, 'first') 195 p2 = partial(p, 'second') 196 p2.new_attr = 'spam' 197 self.assertEqual(p2.new_attr, 'spam') 198 199 def test_repr(self): 200 args = (object(), object()) 201 args_repr = ', '.join(repr(a) for a in args) 202 kwargs = {'a': object(), 'b': object()} 203 kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs), 204 'b={b!r}, a={a!r}'.format_map(kwargs)] 205 if self.partial in (c_functools.partial, py_functools.partial): 206 name = 'functools.partial' 207 else: 208 name = self.partial.__name__ 209 210 f = self.partial(capture) 211 self.assertEqual(f'{name}({capture!r})', repr(f)) 212 213 f = self.partial(capture, *args) 214 self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f)) 215 216 f = self.partial(capture, **kwargs) 217 self.assertIn(repr(f), 218 [f'{name}({capture!r}, {kwargs_repr})' 219 for kwargs_repr in kwargs_reprs]) 220 221 f = self.partial(capture, *args, **kwargs) 222 self.assertIn(repr(f), 223 [f'{name}({capture!r}, {args_repr}, {kwargs_repr})' 224 for kwargs_repr in kwargs_reprs]) 225 226 def test_recursive_repr(self): 227 if self.partial in (c_functools.partial, py_functools.partial): 228 name = 'functools.partial' 229 else: 230 name = self.partial.__name__ 231 232 f = self.partial(capture) 233 f.__setstate__((f, (), 
{}, {})) 234 try: 235 self.assertEqual(repr(f), '%s(...)' % (name,)) 236 finally: 237 f.__setstate__((capture, (), {}, {})) 238 239 f = self.partial(capture) 240 f.__setstate__((capture, (f,), {}, {})) 241 try: 242 self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,)) 243 finally: 244 f.__setstate__((capture, (), {}, {})) 245 246 f = self.partial(capture) 247 f.__setstate__((capture, (), {'a': f}, {})) 248 try: 249 self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,)) 250 finally: 251 f.__setstate__((capture, (), {}, {})) 252 253 def test_pickle(self): 254 with self.AllowPickle(): 255 f = self.partial(signature, ['asdf'], bar=[True]) 256 f.attr = [] 257 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 258 f_copy = pickle.loads(pickle.dumps(f, proto)) 259 self.assertEqual(signature(f_copy), signature(f)) 260 261 def test_copy(self): 262 f = self.partial(signature, ['asdf'], bar=[True]) 263 f.attr = [] 264 f_copy = copy.copy(f) 265 self.assertEqual(signature(f_copy), signature(f)) 266 self.assertIs(f_copy.attr, f.attr) 267 self.assertIs(f_copy.args, f.args) 268 self.assertIs(f_copy.keywords, f.keywords) 269 270 def test_deepcopy(self): 271 f = self.partial(signature, ['asdf'], bar=[True]) 272 f.attr = [] 273 f_copy = copy.deepcopy(f) 274 self.assertEqual(signature(f_copy), signature(f)) 275 self.assertIsNot(f_copy.attr, f.attr) 276 self.assertIsNot(f_copy.args, f.args) 277 self.assertIsNot(f_copy.args[0], f.args[0]) 278 self.assertIsNot(f_copy.keywords, f.keywords) 279 self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar']) 280 281 def test_setstate(self): 282 f = self.partial(signature) 283 f.__setstate__((capture, (1,), dict(a=10), dict(attr=[]))) 284 285 self.assertEqual(signature(f), 286 (capture, (1,), dict(a=10), dict(attr=[]))) 287 self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20})) 288 289 f.__setstate__((capture, (1,), dict(a=10), None)) 290 291 self.assertEqual(signature(f), (capture, (1,), dict(a=10), {})) 292 self.assertEqual(f(2, 
b=20), ((1, 2), {'a': 10, 'b': 20})) 293 294 f.__setstate__((capture, (1,), None, None)) 295 #self.assertEqual(signature(f), (capture, (1,), {}, {})) 296 self.assertEqual(f(2, b=20), ((1, 2), {'b': 20})) 297 self.assertEqual(f(2), ((1, 2), {})) 298 self.assertEqual(f(), ((1,), {})) 299 300 f.__setstate__((capture, (), {}, None)) 301 self.assertEqual(signature(f), (capture, (), {}, {})) 302 self.assertEqual(f(2, b=20), ((2,), {'b': 20})) 303 self.assertEqual(f(2), ((2,), {})) 304 self.assertEqual(f(), ((), {})) 305 306 def test_setstate_errors(self): 307 f = self.partial(signature) 308 self.assertRaises(TypeError, f.__setstate__, (capture, (), {})) 309 self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None)) 310 self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None]) 311 self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None)) 312 self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None)) 313 self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None)) 314 self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None)) 315 316 def test_setstate_subclasses(self): 317 f = self.partial(signature) 318 f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None)) 319 s = signature(f) 320 self.assertEqual(s, (capture, (1,), dict(a=10), {})) 321 self.assertIs(type(s[1]), tuple) 322 self.assertIs(type(s[2]), dict) 323 r = f() 324 self.assertEqual(r, ((1,), {'a': 10})) 325 self.assertIs(type(r[0]), tuple) 326 self.assertIs(type(r[1]), dict) 327 328 f.__setstate__((capture, BadTuple((1,)), {}, None)) 329 s = signature(f) 330 self.assertEqual(s, (capture, (1,), {}, {})) 331 self.assertIs(type(s[1]), tuple) 332 r = f(2) 333 self.assertEqual(r, ((1, 2), {})) 334 self.assertIs(type(r[0]), tuple) 335 336 def test_recursive_pickle(self): 337 with self.AllowPickle(): 338 f = self.partial(capture) 339 f.__setstate__((f, (), {}, {})) 340 try: 341 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 342 with 
self.assertRaises(RecursionError): 343 pickle.dumps(f, proto) 344 finally: 345 f.__setstate__((capture, (), {}, {})) 346 347 f = self.partial(capture) 348 f.__setstate__((capture, (f,), {}, {})) 349 try: 350 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 351 f_copy = pickle.loads(pickle.dumps(f, proto)) 352 try: 353 self.assertIs(f_copy.args[0], f_copy) 354 finally: 355 f_copy.__setstate__((capture, (), {}, {})) 356 finally: 357 f.__setstate__((capture, (), {}, {})) 358 359 f = self.partial(capture) 360 f.__setstate__((capture, (), {'a': f}, {})) 361 try: 362 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 363 f_copy = pickle.loads(pickle.dumps(f, proto)) 364 try: 365 self.assertIs(f_copy.keywords['a'], f_copy) 366 finally: 367 f_copy.__setstate__((capture, (), {}, {})) 368 finally: 369 f.__setstate__((capture, (), {}, {})) 370 371 # Issue 6083: Reference counting bug 372 def test_setstate_refcount(self): 373 class BadSequence: 374 def __len__(self): 375 return 4 376 def __getitem__(self, key): 377 if key == 0: 378 return max 379 elif key == 1: 380 return tuple(range(1000000)) 381 elif key in (2, 3): 382 return {} 383 raise IndexError 384 385 f = self.partial(object) 386 self.assertRaises(TypeError, f.__setstate__, BadSequence()) 387 388@unittest.skipUnless(c_functools, 'requires the C _functools module') 389class TestPartialC(TestPartial, unittest.TestCase): 390 if c_functools: 391 partial = c_functools.partial 392 393 class AllowPickle: 394 def __enter__(self): 395 return self 396 def __exit__(self, type, value, tb): 397 return False 398 399 def test_attributes_unwritable(self): 400 # attributes should not be writable 401 p = self.partial(capture, 1, 2, a=10, b=20) 402 self.assertRaises(AttributeError, setattr, p, 'func', map) 403 self.assertRaises(AttributeError, setattr, p, 'args', (1, 2)) 404 self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2)) 405 406 p = self.partial(hex) 407 try: 408 del p.__dict__ 409 except TypeError: 410 pass 411 
else: 412 self.fail('partial object allowed __dict__ to be deleted') 413 414 def test_manually_adding_non_string_keyword(self): 415 p = self.partial(capture) 416 # Adding a non-string/unicode keyword to partial kwargs 417 p.keywords[1234] = 'value' 418 r = repr(p) 419 self.assertIn('1234', r) 420 self.assertIn("'value'", r) 421 with self.assertRaises(TypeError): 422 p() 423 424 def test_keystr_replaces_value(self): 425 p = self.partial(capture) 426 427 class MutatesYourDict(object): 428 def __str__(self): 429 p.keywords[self] = ['sth2'] 430 return 'astr' 431 432 # Replacing the value during key formatting should keep the original 433 # value alive (at least long enough). 434 p.keywords[MutatesYourDict()] = ['sth'] 435 r = repr(p) 436 self.assertIn('astr', r) 437 self.assertIn("['sth']", r) 438 439 440class TestPartialPy(TestPartial, unittest.TestCase): 441 partial = py_functools.partial 442 443 class AllowPickle: 444 def __init__(self): 445 self._cm = replaced_module("functools", py_functools) 446 def __enter__(self): 447 return self._cm.__enter__() 448 def __exit__(self, type, value, tb): 449 return self._cm.__exit__(type, value, tb) 450 451if c_functools: 452 class CPartialSubclass(c_functools.partial): 453 pass 454 455class PyPartialSubclass(py_functools.partial): 456 pass 457 458@unittest.skipUnless(c_functools, 'requires the C _functools module') 459class TestPartialCSubclass(TestPartialC): 460 if c_functools: 461 partial = CPartialSubclass 462 463 # partial subclasses are not optimized for nested calls 464 test_nested_optimization = None 465 466class TestPartialPySubclass(TestPartialPy): 467 partial = PyPartialSubclass 468 469class TestPartialMethod(unittest.TestCase): 470 471 class A(object): 472 nothing = functools.partialmethod(capture) 473 positional = functools.partialmethod(capture, 1) 474 keywords = functools.partialmethod(capture, a=2) 475 both = functools.partialmethod(capture, 3, b=4) 476 spec_keywords = functools.partialmethod(capture, self=1, 
func=2) 477 478 nested = functools.partialmethod(positional, 5) 479 480 over_partial = functools.partialmethod(functools.partial(capture, c=6), 7) 481 482 static = functools.partialmethod(staticmethod(capture), 8) 483 cls = functools.partialmethod(classmethod(capture), d=9) 484 485 a = A() 486 487 def test_arg_combinations(self): 488 self.assertEqual(self.a.nothing(), ((self.a,), {})) 489 self.assertEqual(self.a.nothing(5), ((self.a, 5), {})) 490 self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6})) 491 self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6})) 492 493 self.assertEqual(self.a.positional(), ((self.a, 1), {})) 494 self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {})) 495 self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6})) 496 self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6})) 497 498 self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2})) 499 self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2})) 500 self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6})) 501 self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6})) 502 503 self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4})) 504 self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4})) 505 self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6})) 506 self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6})) 507 508 self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6})) 509 510 self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2})) 511 512 def test_nested(self): 513 self.assertEqual(self.a.nested(), ((self.a, 1, 5), {})) 514 self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {})) 515 self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7})) 516 self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7})) 517 518 self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7})) 519 520 def 
test_over_partial(self): 521 self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6})) 522 self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6})) 523 self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8})) 524 self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8})) 525 526 self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8})) 527 528 def test_bound_method_introspection(self): 529 obj = self.a 530 self.assertIs(obj.both.__self__, obj) 531 self.assertIs(obj.nested.__self__, obj) 532 self.assertIs(obj.over_partial.__self__, obj) 533 self.assertIs(obj.cls.__self__, self.A) 534 self.assertIs(self.A.cls.__self__, self.A) 535 536 def test_unbound_method_retrieval(self): 537 obj = self.A 538 self.assertFalse(hasattr(obj.both, "__self__")) 539 self.assertFalse(hasattr(obj.nested, "__self__")) 540 self.assertFalse(hasattr(obj.over_partial, "__self__")) 541 self.assertFalse(hasattr(obj.static, "__self__")) 542 self.assertFalse(hasattr(self.a.static, "__self__")) 543 544 def test_descriptors(self): 545 for obj in [self.A, self.a]: 546 with self.subTest(obj=obj): 547 self.assertEqual(obj.static(), ((8,), {})) 548 self.assertEqual(obj.static(5), ((8, 5), {})) 549 self.assertEqual(obj.static(d=8), ((8,), {'d': 8})) 550 self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8})) 551 552 self.assertEqual(obj.cls(), ((self.A,), {'d': 9})) 553 self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9})) 554 self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9})) 555 self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9})) 556 557 def test_overriding_keywords(self): 558 self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3})) 559 self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3})) 560 561 def test_invalid_args(self): 562 with self.assertRaises(TypeError): 563 class B(object): 564 method = functools.partialmethod(None, 1) 565 with 
self.assertRaises(TypeError): 566 class B: 567 method = functools.partialmethod() 568 with self.assertRaises(TypeError): 569 class B: 570 method = functools.partialmethod(func=capture, a=1) 571 572 def test_repr(self): 573 self.assertEqual(repr(vars(self.A)['both']), 574 'functools.partialmethod({}, 3, b=4)'.format(capture)) 575 576 def test_abstract(self): 577 class Abstract(abc.ABCMeta): 578 579 @abc.abstractmethod 580 def add(self, x, y): 581 pass 582 583 add5 = functools.partialmethod(add, 5) 584 585 self.assertTrue(Abstract.add.__isabstractmethod__) 586 self.assertTrue(Abstract.add5.__isabstractmethod__) 587 588 for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]: 589 self.assertFalse(getattr(func, '__isabstractmethod__', False)) 590 591 def test_positional_only(self): 592 def f(a, b, /): 593 return a + b 594 595 p = functools.partial(f, 1) 596 self.assertEqual(p(2), f(1, 2)) 597 598 599class TestUpdateWrapper(unittest.TestCase): 600 601 def check_wrapper(self, wrapper, wrapped, 602 assigned=functools.WRAPPER_ASSIGNMENTS, 603 updated=functools.WRAPPER_UPDATES): 604 # Check attributes were assigned 605 for name in assigned: 606 self.assertIs(getattr(wrapper, name), getattr(wrapped, name)) 607 # Check attributes were updated 608 for name in updated: 609 wrapper_attr = getattr(wrapper, name) 610 wrapped_attr = getattr(wrapped, name) 611 for key in wrapped_attr: 612 if name == "__dict__" and key == "__wrapped__": 613 # __wrapped__ is overwritten by the update code 614 continue 615 self.assertIs(wrapped_attr[key], wrapper_attr[key]) 616 # Check __wrapped__ 617 self.assertIs(wrapper.__wrapped__, wrapped) 618 619 620 def _default_update(self): 621 def f(a:'This is a new annotation'): 622 """This is a test""" 623 pass 624 f.attr = 'This is also a test' 625 f.__wrapped__ = "This is a bald faced lie" 626 def wrapper(b:'This is the prior annotation'): 627 pass 628 functools.update_wrapper(wrapper, f) 629 return wrapper, f 630 631 def 
test_default_update(self): 632 wrapper, f = self._default_update() 633 self.check_wrapper(wrapper, f) 634 self.assertIs(wrapper.__wrapped__, f) 635 self.assertEqual(wrapper.__name__, 'f') 636 self.assertEqual(wrapper.__qualname__, f.__qualname__) 637 self.assertEqual(wrapper.attr, 'This is also a test') 638 self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation') 639 self.assertNotIn('b', wrapper.__annotations__) 640 641 @unittest.skipIf(sys.flags.optimize >= 2, 642 "Docstrings are omitted with -O2 and above") 643 def test_default_update_doc(self): 644 wrapper, f = self._default_update() 645 self.assertEqual(wrapper.__doc__, 'This is a test') 646 647 def test_no_update(self): 648 def f(): 649 """This is a test""" 650 pass 651 f.attr = 'This is also a test' 652 def wrapper(): 653 pass 654 functools.update_wrapper(wrapper, f, (), ()) 655 self.check_wrapper(wrapper, f, (), ()) 656 self.assertEqual(wrapper.__name__, 'wrapper') 657 self.assertNotEqual(wrapper.__qualname__, f.__qualname__) 658 self.assertEqual(wrapper.__doc__, None) 659 self.assertEqual(wrapper.__annotations__, {}) 660 self.assertFalse(hasattr(wrapper, 'attr')) 661 662 def test_selective_update(self): 663 def f(): 664 pass 665 f.attr = 'This is a different test' 666 f.dict_attr = dict(a=1, b=2, c=3) 667 def wrapper(): 668 pass 669 wrapper.dict_attr = {} 670 assign = ('attr',) 671 update = ('dict_attr',) 672 functools.update_wrapper(wrapper, f, assign, update) 673 self.check_wrapper(wrapper, f, assign, update) 674 self.assertEqual(wrapper.__name__, 'wrapper') 675 self.assertNotEqual(wrapper.__qualname__, f.__qualname__) 676 self.assertEqual(wrapper.__doc__, None) 677 self.assertEqual(wrapper.attr, 'This is a different test') 678 self.assertEqual(wrapper.dict_attr, f.dict_attr) 679 680 def test_missing_attributes(self): 681 def f(): 682 pass 683 def wrapper(): 684 pass 685 wrapper.dict_attr = {} 686 assign = ('attr',) 687 update = ('dict_attr',) 688 # Missing attributes on wrapped object 
are ignored 689 functools.update_wrapper(wrapper, f, assign, update) 690 self.assertNotIn('attr', wrapper.__dict__) 691 self.assertEqual(wrapper.dict_attr, {}) 692 # Wrapper must have expected attributes for updating 693 del wrapper.dict_attr 694 with self.assertRaises(AttributeError): 695 functools.update_wrapper(wrapper, f, assign, update) 696 wrapper.dict_attr = 1 697 with self.assertRaises(AttributeError): 698 functools.update_wrapper(wrapper, f, assign, update) 699 700 @support.requires_docstrings 701 @unittest.skipIf(sys.flags.optimize >= 2, 702 "Docstrings are omitted with -O2 and above") 703 def test_builtin_update(self): 704 # Test for bug #1576241 705 def wrapper(): 706 pass 707 functools.update_wrapper(wrapper, max) 708 self.assertEqual(wrapper.__name__, 'max') 709 self.assertTrue(wrapper.__doc__.startswith('max(')) 710 self.assertEqual(wrapper.__annotations__, {}) 711 712 713class TestWraps(TestUpdateWrapper): 714 715 def _default_update(self): 716 def f(): 717 """This is a test""" 718 pass 719 f.attr = 'This is also a test' 720 f.__wrapped__ = "This is still a bald faced lie" 721 @functools.wraps(f) 722 def wrapper(): 723 pass 724 return wrapper, f 725 726 def test_default_update(self): 727 wrapper, f = self._default_update() 728 self.check_wrapper(wrapper, f) 729 self.assertEqual(wrapper.__name__, 'f') 730 self.assertEqual(wrapper.__qualname__, f.__qualname__) 731 self.assertEqual(wrapper.attr, 'This is also a test') 732 733 @unittest.skipIf(sys.flags.optimize >= 2, 734 "Docstrings are omitted with -O2 and above") 735 def test_default_update_doc(self): 736 wrapper, _ = self._default_update() 737 self.assertEqual(wrapper.__doc__, 'This is a test') 738 739 def test_no_update(self): 740 def f(): 741 """This is a test""" 742 pass 743 f.attr = 'This is also a test' 744 @functools.wraps(f, (), ()) 745 def wrapper(): 746 pass 747 self.check_wrapper(wrapper, f, (), ()) 748 self.assertEqual(wrapper.__name__, 'wrapper') 749 
self.assertNotEqual(wrapper.__qualname__, f.__qualname__) 750 self.assertEqual(wrapper.__doc__, None) 751 self.assertFalse(hasattr(wrapper, 'attr')) 752 753 def test_selective_update(self): 754 def f(): 755 pass 756 f.attr = 'This is a different test' 757 f.dict_attr = dict(a=1, b=2, c=3) 758 def add_dict_attr(f): 759 f.dict_attr = {} 760 return f 761 assign = ('attr',) 762 update = ('dict_attr',) 763 @functools.wraps(f, assign, update) 764 @add_dict_attr 765 def wrapper(): 766 pass 767 self.check_wrapper(wrapper, f, assign, update) 768 self.assertEqual(wrapper.__name__, 'wrapper') 769 self.assertNotEqual(wrapper.__qualname__, f.__qualname__) 770 self.assertEqual(wrapper.__doc__, None) 771 self.assertEqual(wrapper.attr, 'This is a different test') 772 self.assertEqual(wrapper.dict_attr, f.dict_attr) 773 774 775class TestReduce: 776 def test_reduce(self): 777 class Squares: 778 def __init__(self, max): 779 self.max = max 780 self.sofar = [] 781 782 def __len__(self): 783 return len(self.sofar) 784 785 def __getitem__(self, i): 786 if not 0 <= i < self.max: raise IndexError 787 n = len(self.sofar) 788 while n <= i: 789 self.sofar.append(n*n) 790 n += 1 791 return self.sofar[i] 792 def add(x, y): 793 return x + y 794 self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc') 795 self.assertEqual( 796 self.reduce(add, [['a', 'c'], [], ['d', 'w']], []), 797 ['a','c','d','w'] 798 ) 799 self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040) 800 self.assertEqual( 801 self.reduce(lambda x, y: x*y, range(2,21), 1), 802 2432902008176640000 803 ) 804 self.assertEqual(self.reduce(add, Squares(10)), 285) 805 self.assertEqual(self.reduce(add, Squares(10), 0), 285) 806 self.assertEqual(self.reduce(add, Squares(0), 0), 0) 807 self.assertRaises(TypeError, self.reduce) 808 self.assertRaises(TypeError, self.reduce, 42, 42) 809 self.assertRaises(TypeError, self.reduce, 42, 42, 42) 810 self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item 811 
self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item 812 self.assertRaises(TypeError, self.reduce, 42, (42, 42)) 813 self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value 814 self.assertRaises(TypeError, self.reduce, add, "") 815 self.assertRaises(TypeError, self.reduce, add, ()) 816 self.assertRaises(TypeError, self.reduce, add, object()) 817 818 class TestFailingIter: 819 def __iter__(self): 820 raise RuntimeError 821 self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter()) 822 823 self.assertEqual(self.reduce(add, [], None), None) 824 self.assertEqual(self.reduce(add, [], 42), 42) 825 826 class BadSeq: 827 def __getitem__(self, index): 828 raise ValueError 829 self.assertRaises(ValueError, self.reduce, 42, BadSeq()) 830 831 # Test reduce()'s use of iterators. 832 def test_iterator_usage(self): 833 class SequenceClass: 834 def __init__(self, n): 835 self.n = n 836 def __getitem__(self, i): 837 if 0 <= i < self.n: 838 return i 839 else: 840 raise IndexError 841 842 from operator import add 843 self.assertEqual(self.reduce(add, SequenceClass(5)), 10) 844 self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52) 845 self.assertRaises(TypeError, self.reduce, add, SequenceClass(0)) 846 self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42) 847 self.assertEqual(self.reduce(add, SequenceClass(1)), 0) 848 self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42) 849 850 d = {"one": 1, "two": 2, "three": 3} 851 self.assertEqual(self.reduce(add, d), "".join(d.keys())) 852 853 854@unittest.skipUnless(c_functools, 'requires the C _functools module') 855class TestReduceC(TestReduce, unittest.TestCase): 856 if c_functools: 857 reduce = c_functools.reduce 858 859 860class TestReducePy(TestReduce, unittest.TestCase): 861 reduce = staticmethod(py_functools.reduce) 862 863 864class TestCmpToKey: 865 866 def test_cmp_to_key(self): 867 def cmp1(x, y): 868 return (x > y) - (x 
< y) 869 key = self.cmp_to_key(cmp1) 870 self.assertEqual(key(3), key(3)) 871 self.assertGreater(key(3), key(1)) 872 self.assertGreaterEqual(key(3), key(3)) 873 874 def cmp2(x, y): 875 return int(x) - int(y) 876 key = self.cmp_to_key(cmp2) 877 self.assertEqual(key(4.0), key('4')) 878 self.assertLess(key(2), key('35')) 879 self.assertLessEqual(key(2), key('35')) 880 self.assertNotEqual(key(2), key('35')) 881 882 def test_cmp_to_key_arguments(self): 883 def cmp1(x, y): 884 return (x > y) - (x < y) 885 key = self.cmp_to_key(mycmp=cmp1) 886 self.assertEqual(key(obj=3), key(obj=3)) 887 self.assertGreater(key(obj=3), key(obj=1)) 888 with self.assertRaises((TypeError, AttributeError)): 889 key(3) > 1 # rhs is not a K object 890 with self.assertRaises((TypeError, AttributeError)): 891 1 < key(3) # lhs is not a K object 892 with self.assertRaises(TypeError): 893 key = self.cmp_to_key() # too few args 894 with self.assertRaises(TypeError): 895 key = self.cmp_to_key(cmp1, None) # too many args 896 key = self.cmp_to_key(cmp1) 897 with self.assertRaises(TypeError): 898 key() # too few args 899 with self.assertRaises(TypeError): 900 key(None, None) # too many args 901 902 def test_bad_cmp(self): 903 def cmp1(x, y): 904 raise ZeroDivisionError 905 key = self.cmp_to_key(cmp1) 906 with self.assertRaises(ZeroDivisionError): 907 key(3) > key(1) 908 909 class BadCmp: 910 def __lt__(self, other): 911 raise ZeroDivisionError 912 def cmp1(x, y): 913 return BadCmp() 914 with self.assertRaises(ZeroDivisionError): 915 key(3) > key(1) 916 917 def test_obj_field(self): 918 def cmp1(x, y): 919 return (x > y) - (x < y) 920 key = self.cmp_to_key(mycmp=cmp1) 921 self.assertEqual(key(50).obj, 50) 922 923 def test_sort_int(self): 924 def mycmp(x, y): 925 return y - x 926 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)), 927 [4, 3, 2, 1, 0]) 928 929 def test_sort_int_str(self): 930 def mycmp(x, y): 931 x, y = int(x), int(y) 932 return (x > y) - (x < y) 933 values = [5, '3', 7, 2, '0', 
'1', 4, '10', 1] 934 values = sorted(values, key=self.cmp_to_key(mycmp)) 935 self.assertEqual([int(value) for value in values], 936 [0, 1, 1, 2, 3, 4, 5, 7, 10]) 937 938 def test_hash(self): 939 def mycmp(x, y): 940 return y - x 941 key = self.cmp_to_key(mycmp) 942 k = key(10) 943 self.assertRaises(TypeError, hash, k) 944 self.assertNotIsInstance(k, collections.abc.Hashable) 945 946 947@unittest.skipUnless(c_functools, 'requires the C _functools module') 948class TestCmpToKeyC(TestCmpToKey, unittest.TestCase): 949 if c_functools: 950 cmp_to_key = c_functools.cmp_to_key 951 952 @support.cpython_only 953 def test_disallow_instantiation(self): 954 # Ensure that the type disallows instantiation (bpo-43916) 955 support.check_disallow_instantiation( 956 self, type(c_functools.cmp_to_key(None)) 957 ) 958 959 960class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase): 961 cmp_to_key = staticmethod(py_functools.cmp_to_key) 962 963 964class TestTotalOrdering(unittest.TestCase): 965 966 def test_total_ordering_lt(self): 967 @functools.total_ordering 968 class A: 969 def __init__(self, value): 970 self.value = value 971 def __lt__(self, other): 972 return self.value < other.value 973 def __eq__(self, other): 974 return self.value == other.value 975 self.assertTrue(A(1) < A(2)) 976 self.assertTrue(A(2) > A(1)) 977 self.assertTrue(A(1) <= A(2)) 978 self.assertTrue(A(2) >= A(1)) 979 self.assertTrue(A(2) <= A(2)) 980 self.assertTrue(A(2) >= A(2)) 981 self.assertFalse(A(1) > A(2)) 982 983 def test_total_ordering_le(self): 984 @functools.total_ordering 985 class A: 986 def __init__(self, value): 987 self.value = value 988 def __le__(self, other): 989 return self.value <= other.value 990 def __eq__(self, other): 991 return self.value == other.value 992 self.assertTrue(A(1) < A(2)) 993 self.assertTrue(A(2) > A(1)) 994 self.assertTrue(A(1) <= A(2)) 995 self.assertTrue(A(2) >= A(1)) 996 self.assertTrue(A(2) <= A(2)) 997 self.assertTrue(A(2) >= A(2)) 998 self.assertFalse(A(1) >= A(2)) 999 
1000 def test_total_ordering_gt(self): 1001 @functools.total_ordering 1002 class A: 1003 def __init__(self, value): 1004 self.value = value 1005 def __gt__(self, other): 1006 return self.value > other.value 1007 def __eq__(self, other): 1008 return self.value == other.value 1009 self.assertTrue(A(1) < A(2)) 1010 self.assertTrue(A(2) > A(1)) 1011 self.assertTrue(A(1) <= A(2)) 1012 self.assertTrue(A(2) >= A(1)) 1013 self.assertTrue(A(2) <= A(2)) 1014 self.assertTrue(A(2) >= A(2)) 1015 self.assertFalse(A(2) < A(1)) 1016 1017 def test_total_ordering_ge(self): 1018 @functools.total_ordering 1019 class A: 1020 def __init__(self, value): 1021 self.value = value 1022 def __ge__(self, other): 1023 return self.value >= other.value 1024 def __eq__(self, other): 1025 return self.value == other.value 1026 self.assertTrue(A(1) < A(2)) 1027 self.assertTrue(A(2) > A(1)) 1028 self.assertTrue(A(1) <= A(2)) 1029 self.assertTrue(A(2) >= A(1)) 1030 self.assertTrue(A(2) <= A(2)) 1031 self.assertTrue(A(2) >= A(2)) 1032 self.assertFalse(A(2) <= A(1)) 1033 1034 def test_total_ordering_no_overwrite(self): 1035 # new methods should not overwrite existing 1036 @functools.total_ordering 1037 class A(int): 1038 pass 1039 self.assertTrue(A(1) < A(2)) 1040 self.assertTrue(A(2) > A(1)) 1041 self.assertTrue(A(1) <= A(2)) 1042 self.assertTrue(A(2) >= A(1)) 1043 self.assertTrue(A(2) <= A(2)) 1044 self.assertTrue(A(2) >= A(2)) 1045 1046 def test_no_operations_defined(self): 1047 with self.assertRaises(ValueError): 1048 @functools.total_ordering 1049 class A: 1050 pass 1051 1052 def test_notimplemented(self): 1053 # Verify NotImplemented results are correctly handled 1054 @functools.total_ordering 1055 class ImplementsLessThan: 1056 def __init__(self, value): 1057 self.value = value 1058 def __eq__(self, other): 1059 if isinstance(other, ImplementsLessThan): 1060 return self.value == other.value 1061 return False 1062 def __lt__(self, other): 1063 if isinstance(other, ImplementsLessThan): 1064 return 
self.value < other.value 1065 return NotImplemented 1066 1067 @functools.total_ordering 1068 class ImplementsLessThanEqualTo: 1069 def __init__(self, value): 1070 self.value = value 1071 def __eq__(self, other): 1072 if isinstance(other, ImplementsLessThanEqualTo): 1073 return self.value == other.value 1074 return False 1075 def __le__(self, other): 1076 if isinstance(other, ImplementsLessThanEqualTo): 1077 return self.value <= other.value 1078 return NotImplemented 1079 1080 @functools.total_ordering 1081 class ImplementsGreaterThan: 1082 def __init__(self, value): 1083 self.value = value 1084 def __eq__(self, other): 1085 if isinstance(other, ImplementsGreaterThan): 1086 return self.value == other.value 1087 return False 1088 def __gt__(self, other): 1089 if isinstance(other, ImplementsGreaterThan): 1090 return self.value > other.value 1091 return NotImplemented 1092 1093 @functools.total_ordering 1094 class ImplementsGreaterThanEqualTo: 1095 def __init__(self, value): 1096 self.value = value 1097 def __eq__(self, other): 1098 if isinstance(other, ImplementsGreaterThanEqualTo): 1099 return self.value == other.value 1100 return False 1101 def __ge__(self, other): 1102 if isinstance(other, ImplementsGreaterThanEqualTo): 1103 return self.value >= other.value 1104 return NotImplemented 1105 1106 self.assertIs(ImplementsLessThan(1).__le__(1), NotImplemented) 1107 self.assertIs(ImplementsLessThan(1).__gt__(1), NotImplemented) 1108 self.assertIs(ImplementsLessThan(1).__ge__(1), NotImplemented) 1109 self.assertIs(ImplementsLessThanEqualTo(1).__lt__(1), NotImplemented) 1110 self.assertIs(ImplementsLessThanEqualTo(1).__gt__(1), NotImplemented) 1111 self.assertIs(ImplementsLessThanEqualTo(1).__ge__(1), NotImplemented) 1112 self.assertIs(ImplementsGreaterThan(1).__lt__(1), NotImplemented) 1113 self.assertIs(ImplementsGreaterThan(1).__gt__(1), NotImplemented) 1114 self.assertIs(ImplementsGreaterThan(1).__ge__(1), NotImplemented) 1115 
self.assertIs(ImplementsGreaterThanEqualTo(1).__lt__(1), NotImplemented) 1116 self.assertIs(ImplementsGreaterThanEqualTo(1).__le__(1), NotImplemented) 1117 self.assertIs(ImplementsGreaterThanEqualTo(1).__gt__(1), NotImplemented) 1118 1119 def test_type_error_when_not_implemented(self): 1120 # bug 10042; ensure stack overflow does not occur 1121 # when decorated types return NotImplemented 1122 @functools.total_ordering 1123 class ImplementsLessThan: 1124 def __init__(self, value): 1125 self.value = value 1126 def __eq__(self, other): 1127 if isinstance(other, ImplementsLessThan): 1128 return self.value == other.value 1129 return False 1130 def __lt__(self, other): 1131 if isinstance(other, ImplementsLessThan): 1132 return self.value < other.value 1133 return NotImplemented 1134 1135 @functools.total_ordering 1136 class ImplementsGreaterThan: 1137 def __init__(self, value): 1138 self.value = value 1139 def __eq__(self, other): 1140 if isinstance(other, ImplementsGreaterThan): 1141 return self.value == other.value 1142 return False 1143 def __gt__(self, other): 1144 if isinstance(other, ImplementsGreaterThan): 1145 return self.value > other.value 1146 return NotImplemented 1147 1148 @functools.total_ordering 1149 class ImplementsLessThanEqualTo: 1150 def __init__(self, value): 1151 self.value = value 1152 def __eq__(self, other): 1153 if isinstance(other, ImplementsLessThanEqualTo): 1154 return self.value == other.value 1155 return False 1156 def __le__(self, other): 1157 if isinstance(other, ImplementsLessThanEqualTo): 1158 return self.value <= other.value 1159 return NotImplemented 1160 1161 @functools.total_ordering 1162 class ImplementsGreaterThanEqualTo: 1163 def __init__(self, value): 1164 self.value = value 1165 def __eq__(self, other): 1166 if isinstance(other, ImplementsGreaterThanEqualTo): 1167 return self.value == other.value 1168 return False 1169 def __ge__(self, other): 1170 if isinstance(other, ImplementsGreaterThanEqualTo): 1171 return self.value >= 
other.value 1172 return NotImplemented 1173 1174 @functools.total_ordering 1175 class ComparatorNotImplemented: 1176 def __init__(self, value): 1177 self.value = value 1178 def __eq__(self, other): 1179 if isinstance(other, ComparatorNotImplemented): 1180 return self.value == other.value 1181 return False 1182 def __lt__(self, other): 1183 return NotImplemented 1184 1185 with self.subTest("LT < 1"), self.assertRaises(TypeError): 1186 ImplementsLessThan(-1) < 1 1187 1188 with self.subTest("LT < LE"), self.assertRaises(TypeError): 1189 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0) 1190 1191 with self.subTest("LT < GT"), self.assertRaises(TypeError): 1192 ImplementsLessThan(1) < ImplementsGreaterThan(1) 1193 1194 with self.subTest("LE <= LT"), self.assertRaises(TypeError): 1195 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2) 1196 1197 with self.subTest("LE <= GE"), self.assertRaises(TypeError): 1198 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3) 1199 1200 with self.subTest("GT > GE"), self.assertRaises(TypeError): 1201 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4) 1202 1203 with self.subTest("GT > LT"), self.assertRaises(TypeError): 1204 ImplementsGreaterThan(5) > ImplementsLessThan(5) 1205 1206 with self.subTest("GE >= GT"), self.assertRaises(TypeError): 1207 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6) 1208 1209 with self.subTest("GE >= LE"), self.assertRaises(TypeError): 1210 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7) 1211 1212 with self.subTest("GE when equal"): 1213 a = ComparatorNotImplemented(8) 1214 b = ComparatorNotImplemented(8) 1215 self.assertEqual(a, b) 1216 with self.assertRaises(TypeError): 1217 a >= b 1218 1219 with self.subTest("LE when equal"): 1220 a = ComparatorNotImplemented(9) 1221 b = ComparatorNotImplemented(9) 1222 self.assertEqual(a, b) 1223 with self.assertRaises(TypeError): 1224 a <= b 1225 1226 def test_pickle(self): 1227 for proto in 
range(pickle.HIGHEST_PROTOCOL + 1): 1228 for name in '__lt__', '__gt__', '__le__', '__ge__': 1229 with self.subTest(method=name, proto=proto): 1230 method = getattr(Orderable_LT, name) 1231 method_copy = pickle.loads(pickle.dumps(method, proto)) 1232 self.assertIs(method_copy, method) 1233 1234 1235 def test_total_ordering_for_metaclasses_issue_44605(self): 1236 1237 @functools.total_ordering 1238 class SortableMeta(type): 1239 def __new__(cls, name, bases, ns): 1240 return super().__new__(cls, name, bases, ns) 1241 1242 def __lt__(self, other): 1243 if not isinstance(other, SortableMeta): 1244 pass 1245 return self.__name__ < other.__name__ 1246 1247 def __eq__(self, other): 1248 if not isinstance(other, SortableMeta): 1249 pass 1250 return self.__name__ == other.__name__ 1251 1252 class B(metaclass=SortableMeta): 1253 pass 1254 1255 class A(metaclass=SortableMeta): 1256 pass 1257 1258 self.assertTrue(A < B) 1259 self.assertFalse(A > B) 1260 1261 1262@functools.total_ordering 1263class Orderable_LT: 1264 def __init__(self, value): 1265 self.value = value 1266 def __lt__(self, other): 1267 return self.value < other.value 1268 def __eq__(self, other): 1269 return self.value == other.value 1270 1271 1272class TestCache: 1273 # This tests that the pass-through is working as designed. 1274 # The underlying functionality is tested in TestLRU. 
# TestCache methods (continued), then the lru_cache test suites.
    def test_cache(self):
        # functools.cache is an unbounded lru_cache: maxsize reports as None.
        # fib(0..15) makes 16 distinct calls (misses); the recursive lookups
        # of already-computed values account for the 28 hits.
        @self.module.cache
        def fib(n):
            if n < 2:
                return n
            return fib(n-1) + fib(n-2)
        self.assertEqual([fib(n) for n in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))


# Mixin of lru_cache tests.  Concrete subclasses (TestLRUPy, TestLRUC below)
# supply `module` (pure-Python or C-accelerated functools) plus the
# cached_func / cached_meth / cached_staticmeth fixtures.
class TestLRU:

    def test_lru(self):
        # Hit/miss/currsize accounting for maxsize 20, 0, 1 and 2.
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=20)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(maxsize, 20)
        self.assertEqual(currsize, 0)
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)

        # 25 possible (x, y) pairs and 1000 calls: the cache ends up full
        # and hits must dominate.
        domain = range(5)
        for i in range(1000):
            x, y = choice(domain), choice(domain)
            actual = f(x, y)
            expected = orig(x, y)
            self.assertEqual(actual, expected)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertTrue(hits > misses)
        self.assertEqual(hits + misses, 1000)
        self.assertEqual(currsize, 20)

        f.cache_clear()   # test clearing
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)
        self.assertEqual(currsize, 0)
        f(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # Test bypassing the cache
        self.assertIs(f.__wrapped__, orig)
        f.__wrapped__(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size zero (which means "never-cache")
        @self.module.lru_cache(0)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 0)
        f_cnt = 0
        for i in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 5)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 5)
        self.assertEqual(currsize, 0)

        # test size one
        @self.module.lru_cache(1)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 1)
        f_cnt = 0
        for i in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 1)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 4)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size two
        @self.module.lru_cache(2)
        def f(x):
            nonlocal f_cnt
            f_cnt += 1
            return x*10
        self.assertEqual(f.cache_info().maxsize, 2)
        f_cnt = 0
        for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
            #    *  *              *                          *   <- misses
            self.assertEqual(f(x), x*10)
        self.assertEqual(f_cnt, 4)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 12)
        self.assertEqual(misses, 4)
        self.assertEqual(currsize, 2)

    def test_lru_no_args(self):
        # Bare @lru_cache (no parentheses) uses the default maxsize of 128.
        @self.module.lru_cache
        def square(x):
            return x ** 2

        self.assertEqual(list(map(square, [10, 20, 10])),
                         [100, 400, 100])
        self.assertEqual(square.cache_info().hits, 1)
        self.assertEqual(square.cache_info().misses, 2)
        self.assertEqual(square.cache_info().maxsize, 128)
        self.assertEqual(square.cache_info().currsize, 2)

    def test_lru_bug_35780(self):
        # C version of the lru_cache was not checking to see if
        # the user function call has already modified the cache
        # (this arises in recursive calls and in multi-threading).
        # This caused the cache to have orphan links not referenced
        # by the cache dictionary.

        once = True                 # Modified by f(x) below

        @self.module.lru_cache(maxsize=10)
        def f(x):
            nonlocal once
            rv = f'.{x}.'
            if x == 20 and once:
                once = False
                rv = f(x)
            return rv

        # Fill the cache
        for x in range(15):
            self.assertEqual(f(x), f'.{x}.')
        self.assertEqual(f.cache_info().currsize, 10)

        # Make a recursive call and make sure the cache remains full
        self.assertEqual(f(20), '.20.')
        self.assertEqual(f.cache_info().currsize, 10)

    def test_lru_bug_36650(self):
        # C version of lru_cache was treating a call with an empty **kwargs
        # dictionary as being distinct from a call with no keywords at all.
        # This did not result in an incorrect answer, but it did trigger
        # an unexpected cache miss.

        @self.module.lru_cache()
        def f(x):
            pass

        f(0)
        f(0, **{})
        self.assertEqual(f.cache_info().hits, 1)

    def test_lru_hash_only_once(self):
        # To protect against weird reentrancy bugs and to improve
        # efficiency when faced with slow __hash__ methods, the
        # LRU cache guarantees that it will call __hash__
        # only once per use as an argument to the cached function.

        @self.module.lru_cache(maxsize=1)
        def f(x, y):
            return x * 3 + y

        # Simulate the integer 5
        mock_int = unittest.mock.Mock()
        mock_int.__mul__ = unittest.mock.Mock(return_value=15)
        mock_int.__hash__ = unittest.mock.Mock(return_value=999)

        # Add to cache:  One use as an argument gives one call
        self.assertEqual(f(mock_int, 1), 16)
        self.assertEqual(mock_int.__hash__.call_count, 1)
        self.assertEqual(f.cache_info(), (0, 1, 1, 1))

        # Cache hit: One use as an argument gives one additional call
        self.assertEqual(f(mock_int, 1), 16)
        self.assertEqual(mock_int.__hash__.call_count, 2)
        self.assertEqual(f.cache_info(), (1, 1, 1, 1))

        # Cache eviction: No use as an argument gives no additional call
        self.assertEqual(f(6, 2), 20)
        self.assertEqual(mock_int.__hash__.call_count, 2)
        self.assertEqual(f.cache_info(), (1, 2, 1, 1))

        # Cache miss: One use as an argument gives one additional call
        self.assertEqual(f(mock_int, 1), 16)
        self.assertEqual(mock_int.__hash__.call_count, 3)
        self.assertEqual(f.cache_info(), (1, 3, 1, 1))

    def test_lru_reentrancy_with_len(self):
        # Test to make sure the LRU cache code isn't thrown-off by
        # caching the built-in len() function.  Since len() can be
        # cached, we shouldn't use it inside the lru code itself.
        old_len = builtins.len
        try:
            builtins.len = self.module.lru_cache(4)(len)
            for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
                self.assertEqual(len('abcdefghijklmn'[:i]), i)
        finally:
            # Always restore the real builtin, even if an assertion failed.
            builtins.len = old_len

    def test_lru_star_arg_handling(self):
        # Test regression that arose in ea064ff3c10f: a single tuple
        # argument must not be confused with the same values passed
        # as separate positional arguments.
        @self.module.lru_cache()
        def f(*args):
            return args

        self.assertEqual(f(1, 2), (1, 2))
        self.assertEqual(f((1, 2)), ((1, 2),))

    def test_lru_type_error(self):
        # Regression test for issue #28653.
        # lru_cache was leaking when one of the arguments
        # wasn't cacheable.

        @self.module.lru_cache(maxsize=None)
        def infinite_cache(o):
            pass

        @self.module.lru_cache(maxsize=10)
        def limited_cache(o):
            pass

        with self.assertRaises(TypeError):
            infinite_cache([])

        with self.assertRaises(TypeError):
            limited_cache([])

    def test_lru_with_maxsize_none(self):
        # maxsize=None: unbounded cache, every distinct n stays cached.
        @self.module.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n-1) + fib(n-2)
        self.assertEqual([fib(n) for n in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))

    def test_lru_with_maxsize_negative(self):
        # A negative maxsize is reported as 0: nothing is cached and every
        # one of the 2 * 150 calls is a miss.
        @self.module.lru_cache(maxsize=-10)
        def eq(n):
            return n
        for i in (0, 1):
            self.assertEqual([eq(n) for n in range(150)], list(range(150)))
        self.assertEqual(eq.cache_info(),
            self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))

    def test_lru_with_exceptions(self):
        # Verify that user_function exceptions get passed through without
        # creating a hard-to-read chained exception.
        # http://bugs.python.org/issue13177
        for maxsize in (None, 128):
            @self.module.lru_cache(maxsize)
            def func(i):
                return 'abc'[i]
            self.assertEqual(func(0), 'a')
            with self.assertRaises(IndexError) as cm:
                func(15)
            self.assertIsNone(cm.exception.__context__)
            # Verify that the previous exception did not result in a cached entry
            with self.assertRaises(IndexError):
                func(15)

    def test_lru_with_types(self):
        # typed=True caches 3 and 3.0 as separate entries (4 misses), for
        # both positional and keyword calls, and returns the right type.
        for maxsize in (None, 128):
            @self.module.lru_cache(maxsize=maxsize, typed=True)
            def square(x):
                return x * x
            self.assertEqual(square(3), 9)
            self.assertEqual(type(square(3)), type(9))
            self.assertEqual(square(3.0), 9.0)
            self.assertEqual(type(square(3.0)), type(9.0))
            self.assertEqual(square(x=3), 9)
            self.assertEqual(type(square(x=3)), type(9))
            self.assertEqual(square(x=3.0), 9.0)
            self.assertEqual(type(square(x=3.0)), type(9.0))
            self.assertEqual(square.cache_info().hits, 4)
            self.assertEqual(square.cache_info().misses, 4)

    def test_lru_cache_typed_is_not_recursive(self):
        # typed=True distinguishes the types of the arguments themselves,
        # not of items nested inside them: (1,), (True,) and (1.0,) share a
        # cache entry, so all return the repr cached for the first one.
        cached = self.module.lru_cache(typed=True)(repr)

        self.assertEqual(cached(1), '1')
        self.assertEqual(cached(True), 'True')
        self.assertEqual(cached(1.0), '1.0')
        self.assertEqual(cached(0), '0')
        self.assertEqual(cached(False), 'False')
        self.assertEqual(cached(0.0), '0.0')

        self.assertEqual(cached((1,)), '(1,)')
        self.assertEqual(cached((True,)), '(1,)')
        self.assertEqual(cached((1.0,)), '(1,)')
        self.assertEqual(cached((0,)), '(0,)')
        self.assertEqual(cached((False,)), '(0,)')
        self.assertEqual(cached((0.0,)), '(0,)')

        class T(tuple):
            pass

        self.assertEqual(cached(T((1,))), '(1,)')
        self.assertEqual(cached(T((True,))), '(1,)')
        self.assertEqual(cached(T((1.0,))), '(1,)')
        self.assertEqual(cached(T((0,))), '(0,)')
        self.assertEqual(cached(T((False,))), '(0,)')
        self.assertEqual(cached(T((0.0,))), '(0,)')

    def test_lru_with_keyword_args(self):
        # Calls made purely by keyword; default maxsize (128) is reported.
        @self.module.lru_cache()
        def fib(n):
            if n < 2:
                return n
            return fib(n=n-1) + fib(n=n-2)
        self.assertEqual(
            [fib(n=number) for number in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
        )
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))

    def test_lru_with_keyword_args_maxsize_none(self):
        # Keyword calls against an unbounded cache.
        @self.module.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n=n-1) + fib(n=n-2)
        self.assertEqual([fib(n=number) for number in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))

    def test_kwargs_order(self):
        # PEP 468: Preserving Keyword Argument Order
        # The same keywords in a different order are distinct cache entries
        # (two misses), and each call sees its own ordering.
        @self.module.lru_cache(maxsize=10)
        def f(**kwargs):
            return list(kwargs.items())
        self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
        self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
        self.assertEqual(f.cache_info(),
            self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))

    def test_lru_cache_decoration(self):
        # The wrapper copies every attribute named in WRAPPER_ASSIGNMENTS
        # from the wrapped function.
        def f(zomg: 'zomg_annotation'):
            """f doc string"""
            return 42
        g = self.module.lru_cache()(f)
        for attr in self.module.WRAPPER_ASSIGNMENTS:
            self.assertEqual(getattr(g, attr), getattr(f, attr))

    @threading_helper.requires_working_threading()
    def test_lru_cache_threaded(self):
        # Concurrent fills and concurrent cache_clear() calls must keep the
        # cache consistent; a tiny switch interval forces frequent switches.
        n, m = 5, 11
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=n*m)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(currsize, 0)

        start = threading.Event()
        def full(k):
            start.wait(10)
            for _ in range(m):
                self.assertEqual(f(k, 0), orig(k, 0))

        def clear():
            start.wait(10)
            for _ in range(2*m):
                f.cache_clear()

        orig_si = sys.getswitchinterval()
        support.setswitchinterval(1e-6)
        try:
            # create n threads in order to fill cache
            threads = [threading.Thread(target=full, args=[k])
                       for k in range(n)]
            with threading_helper.start_threads(threads):
                start.set()

            hits, misses, maxsize, currsize = f.cache_info()
            if self.module is py_functools:
                # XXX: Why can these counts be unequal for the pure-Python
                # implementation?  Only upper bounds are checked here.
                self.assertLessEqual(misses, n)
                self.assertLessEqual(hits, m*n - misses)
            else:
                self.assertEqual(misses, n)
                self.assertEqual(hits, m*n - misses)
            self.assertEqual(currsize, n)

            # create n threads in order to fill cache and 1 to clear it
            threads = [threading.Thread(target=clear)]
            threads += [threading.Thread(target=full, args=[k])
                        for k in range(n)]
            start.clear()
            with threading_helper.start_threads(threads):
                start.set()
        finally:
            sys.setswitchinterval(orig_si)

    @threading_helper.requires_working_threading()
    def test_lru_cache_threaded2(self):
        # Simultaneous call with the same arguments
        # All n worker threads block inside f on the `pause` barrier (shared
        # with the main thread), so every call of a round is in flight at
        # once and each is counted as a miss: n misses per round, 0 hits.
        n, m = 5, 7
        start = threading.Barrier(n+1)
        pause = threading.Barrier(n+1)
        stop = threading.Barrier(n+1)
        @self.module.lru_cache(maxsize=m*n)
        def f(x):
            pause.wait(10)
            return 3 * x
        self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
        def test():
            for i in range(m):
                start.wait(10)
                self.assertEqual(f(i), 3 * i)
                stop.wait(10)
        threads = [threading.Thread(target=test) for k in range(n)]
        with threading_helper.start_threads(threads):
            for i in range(m):
                start.wait(10)
                stop.reset()
                pause.wait(10)
                start.reset()
                stop.wait(10)
                pause.reset()
                self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))

    @threading_helper.requires_working_threading()
    def test_lru_cache_threaded3(self):
        # Small cache, repeated arguments across threads: results must be
        # correct regardless of interleaving.
        @self.module.lru_cache(maxsize=2)
        def f(x):
            time.sleep(.01)
            return 3 * x
        def test(i, x):
            with self.subTest(thread=i):
                self.assertEqual(f(x), 3 * x, i)
        threads = [threading.Thread(target=test, args=(i, v))
                   for i, v in enumerate([1, 2, 2, 3, 2])]
        with threading_helper.start_threads(threads):
            pass

    def test_need_for_rlock(self):
        # This will deadlock on an LRU cache that uses a regular lock

        @self.module.lru_cache(maxsize=10)
        def test_func(x):
            'Used to demonstrate a reentrant lru_cache call within a single thread'
            return x

        class DoubleEq:
            'Demonstrate a reentrant lru_cache call within a single thread'
            def __init__(self, x):
                self.x = x
            def __hash__(self):
                return self.x
            def __eq__(self, other):
                if self.x == 2:
                    test_func(DoubleEq(1))
                return self.x == other.x

        test_func(DoubleEq(1))                      # Load the cache
        test_func(DoubleEq(2))                      # Load the cache
        self.assertEqual(test_func(DoubleEq(2)),    # Trigger a re-entrant __eq__ call
                         DoubleEq(2))               # Verify the correct return value

    def test_lru_method(self):
        # lru_cache on a method: the cache belongs to the class-level
        # function, so X.f and every instance report the same cache_info().
        class X(int):
            f_cnt = 0
            @self.module.lru_cache(2)
            def f(self, x):
                self.f_cnt += 1
                return x*10+self
        a = X(5)
        b = X(5)
        c = X(7)
        self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))

        for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
            self.assertEqual(a.f(x), x*10 + 5)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
        self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))

        for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
            self.assertEqual(b.f(x), x*10 + 5)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
        self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))

        for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
            self.assertEqual(c.f(x), x*10 + 7)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
        self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))

        self.assertEqual(a.f.cache_info(), X.f.cache_info())
        self.assertEqual(b.f.cache_info(), X.f.cache_info())
        self.assertEqual(c.f.cache_info(), X.f.cache_info())

    def test_pickle(self):
        # Cached functions/methods pickle by reference (same object back).
        cls = self.__class__
        for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                with self.subTest(proto=proto, func=f):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    self.assertIs(f_copy, f)

    def test_copy(self):
        # copy.copy() of an lru_cache wrapper returns the wrapper itself,
        # including when it wraps a partial object.
        cls = self.__class__
        def orig(x, y):
            return 3 * x + y
        part = self.module.partial(orig, 2)
        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
                 self.module.lru_cache(2)(part))
        for f in funcs:
            with self.subTest(func=f):
                f_copy = copy.copy(f)
                self.assertIs(f_copy, f)

    def test_deepcopy(self):
        # Same as test_copy, but for copy.deepcopy().
        cls = self.__class__
        def orig(x, y):
            return 3 * x + y
        part = self.module.partial(orig, 2)
        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
                 self.module.lru_cache(2)(part))
        for f in funcs:
            with self.subTest(func=f):
                f_copy = copy.deepcopy(f)
                self.assertIs(f_copy, f)

    def test_lru_cache_parameters(self):
        # cache_parameters() reports the maxsize/typed the cache was built with.
        @self.module.lru_cache(maxsize=2)
        def f():
            return 1
        self.assertEqual(f.cache_parameters(), {'maxsize': 2, "typed": False})

        @self.module.lru_cache(maxsize=1000, typed=True)
        def f():
            return 1
        self.assertEqual(f.cache_parameters(), {'maxsize': 1000, "typed": True})

    def test_lru_cache_weakrefable(self):
        # Wrappers support weak references and are collected once nothing
        # else refers to them.
        @self.module.lru_cache
        def test_function(x):
            return x

        class A:
            @self.module.lru_cache
            def test_method(self, x):
                return (self, x)

            @staticmethod
            @self.module.lru_cache
            def test_staticmethod(x):
                return (self, x)

        refs = [weakref.ref(test_function),
                weakref.ref(A.test_method),
                weakref.ref(A.test_staticmethod)]

        for ref in refs:
            self.assertIsNotNone(ref())

        del A
        del test_function
        gc.collect()

        for ref in refs:
            self.assertIsNone(ref())


# Module-level cached functions used by TestLRU.test_pickle/test_copy;
# presumably defined here so pickle can find them by qualified name.
@py_functools.lru_cache()
def py_cached_func(x, y):
    return 3 * x + y

@c_functools.lru_cache()
def c_cached_func(x, y):
    return 3 * x + y


# TestLRU against the pure-Python lru_cache.
class TestLRUPy(TestLRU, unittest.TestCase):
    module = py_functools
    # NOTE(review): stored in a 1-tuple and read as cached_func[0];
    # presumably to stop the function being bound when looked up through
    # an instance — confirm.
    cached_func = py_cached_func,

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y


# TestLRU against the C-accelerated lru_cache.
class TestLRUC(TestLRU, unittest.TestCase):
    module = c_functools
    cached_func = c_cached_func,

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y


# Tests for functools.singledispatch.
class TestSingleDispatch(unittest.TestCase):
    def test_simple_overloads(self):
        # register(cls, func) form; unregistered types use the base function.
        @functools.singledispatch
        def g(obj):
            return "base"
        def g_int(i):
            return "integer"
        g.register(int, g_int)
        self.assertEqual(g("str"), "base")
        self.assertEqual(g(1), "integer")
        self.assertEqual(g([1,2,3]), "base")

    def test_mro(self):
        # Dispatch follows the MRO: D's MRO is D, C, B, A, so B's
        # implementation wins over A's for D instances.
        @functools.singledispatch
        def g(obj):
            return "base"
        class A:
            pass
        class C(A):
            pass
        class B(A):
            pass
        class D(C, B):
            pass
        def g_A(a):
            return "A"
        def g_B(b):
            return "B"
        g.register(A, g_A)
        g.register(B, g_B)
1926 self.assertEqual(g(A()), "A") 1927 self.assertEqual(g(B()), "B") 1928 self.assertEqual(g(C()), "A") 1929 self.assertEqual(g(D()), "B") 1930 1931 def test_register_decorator(self): 1932 @functools.singledispatch 1933 def g(obj): 1934 return "base" 1935 @g.register(int) 1936 def g_int(i): 1937 return "int %s" % (i,) 1938 self.assertEqual(g(""), "base") 1939 self.assertEqual(g(12), "int 12") 1940 self.assertIs(g.dispatch(int), g_int) 1941 self.assertIs(g.dispatch(object), g.dispatch(str)) 1942 # Note: in the assert above this is not g. 1943 # @singledispatch returns the wrapper. 1944 1945 def test_wrapping_attributes(self): 1946 @functools.singledispatch 1947 def g(obj): 1948 "Simple test" 1949 return "Test" 1950 self.assertEqual(g.__name__, "g") 1951 if sys.flags.optimize < 2: 1952 self.assertEqual(g.__doc__, "Simple test") 1953 1954 @unittest.skipUnless(decimal, 'requires _decimal') 1955 @support.cpython_only 1956 def test_c_classes(self): 1957 @functools.singledispatch 1958 def g(obj): 1959 return "base" 1960 @g.register(decimal.DecimalException) 1961 def _(obj): 1962 return obj.args 1963 subn = decimal.Subnormal("Exponent < Emin") 1964 rnd = decimal.Rounded("Number got rounded") 1965 self.assertEqual(g(subn), ("Exponent < Emin",)) 1966 self.assertEqual(g(rnd), ("Number got rounded",)) 1967 @g.register(decimal.Subnormal) 1968 def _(obj): 1969 return "Too small to care." 1970 self.assertEqual(g(subn), "Too small to care.") 1971 self.assertEqual(g(rnd), ("Number got rounded",)) 1972 1973 def test_compose_mro(self): 1974 # None of the examples in this test depend on haystack ordering. 
1975 c = collections.abc 1976 mro = functools._compose_mro 1977 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set] 1978 for haystack in permutations(bases): 1979 m = mro(dict, haystack) 1980 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, 1981 c.Collection, c.Sized, c.Iterable, 1982 c.Container, object]) 1983 bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict] 1984 for haystack in permutations(bases): 1985 m = mro(collections.ChainMap, haystack) 1986 self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping, 1987 c.Collection, c.Sized, c.Iterable, 1988 c.Container, object]) 1989 1990 # If there's a generic function with implementations registered for 1991 # both Sized and Container, passing a defaultdict to it results in an 1992 # ambiguous dispatch which will cause a RuntimeError (see 1993 # test_mro_conflicts). 1994 bases = [c.Container, c.Sized, str] 1995 for haystack in permutations(bases): 1996 m = mro(collections.defaultdict, [c.Sized, c.Container, str]) 1997 self.assertEqual(m, [collections.defaultdict, dict, c.Sized, 1998 c.Container, object]) 1999 2000 # MutableSequence below is registered directly on D. In other words, it 2001 # precedes MutableMapping which means single dispatch will always 2002 # choose MutableSequence here. 2003 class D(collections.defaultdict): 2004 pass 2005 c.MutableSequence.register(D) 2006 bases = [c.MutableSequence, c.MutableMapping] 2007 for haystack in permutations(bases): 2008 m = mro(D, bases) 2009 self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible, 2010 collections.defaultdict, dict, c.MutableMapping, c.Mapping, 2011 c.Collection, c.Sized, c.Iterable, c.Container, 2012 object]) 2013 2014 # Container and Callable are registered on different base classes and 2015 # a generic function supporting both should always pick the Callable 2016 # implementation if a C instance is passed. 
2017 class C(collections.defaultdict): 2018 def __call__(self): 2019 pass 2020 bases = [c.Sized, c.Callable, c.Container, c.Mapping] 2021 for haystack in permutations(bases): 2022 m = mro(C, haystack) 2023 self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping, 2024 c.Collection, c.Sized, c.Iterable, 2025 c.Container, object]) 2026 2027 def test_register_abc(self): 2028 c = collections.abc 2029 d = {"a": "b"} 2030 l = [1, 2, 3] 2031 s = {object(), None} 2032 f = frozenset(s) 2033 t = (1, 2, 3) 2034 @functools.singledispatch 2035 def g(obj): 2036 return "base" 2037 self.assertEqual(g(d), "base") 2038 self.assertEqual(g(l), "base") 2039 self.assertEqual(g(s), "base") 2040 self.assertEqual(g(f), "base") 2041 self.assertEqual(g(t), "base") 2042 g.register(c.Sized, lambda obj: "sized") 2043 self.assertEqual(g(d), "sized") 2044 self.assertEqual(g(l), "sized") 2045 self.assertEqual(g(s), "sized") 2046 self.assertEqual(g(f), "sized") 2047 self.assertEqual(g(t), "sized") 2048 g.register(c.MutableMapping, lambda obj: "mutablemapping") 2049 self.assertEqual(g(d), "mutablemapping") 2050 self.assertEqual(g(l), "sized") 2051 self.assertEqual(g(s), "sized") 2052 self.assertEqual(g(f), "sized") 2053 self.assertEqual(g(t), "sized") 2054 g.register(collections.ChainMap, lambda obj: "chainmap") 2055 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered 2056 self.assertEqual(g(l), "sized") 2057 self.assertEqual(g(s), "sized") 2058 self.assertEqual(g(f), "sized") 2059 self.assertEqual(g(t), "sized") 2060 g.register(c.MutableSequence, lambda obj: "mutablesequence") 2061 self.assertEqual(g(d), "mutablemapping") 2062 self.assertEqual(g(l), "mutablesequence") 2063 self.assertEqual(g(s), "sized") 2064 self.assertEqual(g(f), "sized") 2065 self.assertEqual(g(t), "sized") 2066 g.register(c.MutableSet, lambda obj: "mutableset") 2067 self.assertEqual(g(d), "mutablemapping") 2068 self.assertEqual(g(l), "mutablesequence") 2069 self.assertEqual(g(s), 
"mutableset") 2070 self.assertEqual(g(f), "sized") 2071 self.assertEqual(g(t), "sized") 2072 g.register(c.Mapping, lambda obj: "mapping") 2073 self.assertEqual(g(d), "mutablemapping") # not specific enough 2074 self.assertEqual(g(l), "mutablesequence") 2075 self.assertEqual(g(s), "mutableset") 2076 self.assertEqual(g(f), "sized") 2077 self.assertEqual(g(t), "sized") 2078 g.register(c.Sequence, lambda obj: "sequence") 2079 self.assertEqual(g(d), "mutablemapping") 2080 self.assertEqual(g(l), "mutablesequence") 2081 self.assertEqual(g(s), "mutableset") 2082 self.assertEqual(g(f), "sized") 2083 self.assertEqual(g(t), "sequence") 2084 g.register(c.Set, lambda obj: "set") 2085 self.assertEqual(g(d), "mutablemapping") 2086 self.assertEqual(g(l), "mutablesequence") 2087 self.assertEqual(g(s), "mutableset") 2088 self.assertEqual(g(f), "set") 2089 self.assertEqual(g(t), "sequence") 2090 g.register(dict, lambda obj: "dict") 2091 self.assertEqual(g(d), "dict") 2092 self.assertEqual(g(l), "mutablesequence") 2093 self.assertEqual(g(s), "mutableset") 2094 self.assertEqual(g(f), "set") 2095 self.assertEqual(g(t), "sequence") 2096 g.register(list, lambda obj: "list") 2097 self.assertEqual(g(d), "dict") 2098 self.assertEqual(g(l), "list") 2099 self.assertEqual(g(s), "mutableset") 2100 self.assertEqual(g(f), "set") 2101 self.assertEqual(g(t), "sequence") 2102 g.register(set, lambda obj: "concrete-set") 2103 self.assertEqual(g(d), "dict") 2104 self.assertEqual(g(l), "list") 2105 self.assertEqual(g(s), "concrete-set") 2106 self.assertEqual(g(f), "set") 2107 self.assertEqual(g(t), "sequence") 2108 g.register(frozenset, lambda obj: "frozen-set") 2109 self.assertEqual(g(d), "dict") 2110 self.assertEqual(g(l), "list") 2111 self.assertEqual(g(s), "concrete-set") 2112 self.assertEqual(g(f), "frozen-set") 2113 self.assertEqual(g(t), "sequence") 2114 g.register(tuple, lambda obj: "tuple") 2115 self.assertEqual(g(d), "dict") 2116 self.assertEqual(g(l), "list") 2117 self.assertEqual(g(s), 
"concrete-set") 2118 self.assertEqual(g(f), "frozen-set") 2119 self.assertEqual(g(t), "tuple") 2120 2121 def test_c3_abc(self): 2122 c = collections.abc 2123 mro = functools._c3_mro 2124 class A(object): 2125 pass 2126 class B(A): 2127 def __len__(self): 2128 return 0 # implies Sized 2129 @c.Container.register 2130 class C(object): 2131 pass 2132 class D(object): 2133 pass # unrelated 2134 class X(D, C, B): 2135 def __call__(self): 2136 pass # implies Callable 2137 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object] 2138 for abcs in permutations([c.Sized, c.Callable, c.Container]): 2139 self.assertEqual(mro(X, abcs=abcs), expected) 2140 # unrelated ABCs don't appear in the resulting MRO 2141 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable] 2142 self.assertEqual(mro(X, abcs=many_abcs), expected) 2143 2144 def test_false_meta(self): 2145 # see issue23572 2146 class MetaA(type): 2147 def __len__(self): 2148 return 0 2149 class A(metaclass=MetaA): 2150 pass 2151 class AA(A): 2152 pass 2153 @functools.singledispatch 2154 def fun(a): 2155 return 'base A' 2156 @fun.register(A) 2157 def _(a): 2158 return 'fun A' 2159 aa = AA() 2160 self.assertEqual(fun(aa), 'fun A') 2161 2162 def test_mro_conflicts(self): 2163 c = collections.abc 2164 @functools.singledispatch 2165 def g(arg): 2166 return "base" 2167 class O(c.Sized): 2168 def __len__(self): 2169 return 0 2170 o = O() 2171 self.assertEqual(g(o), "base") 2172 g.register(c.Iterable, lambda arg: "iterable") 2173 g.register(c.Container, lambda arg: "container") 2174 g.register(c.Sized, lambda arg: "sized") 2175 g.register(c.Set, lambda arg: "set") 2176 self.assertEqual(g(o), "sized") 2177 c.Iterable.register(O) 2178 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__ 2179 c.Container.register(O) 2180 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__ 2181 c.Set.register(O) 2182 self.assertEqual(g(o), "set") # because c.Set is a subclass of 2183 # c.Sized and 
c.Container 2184 class P: 2185 pass 2186 p = P() 2187 self.assertEqual(g(p), "base") 2188 c.Iterable.register(P) 2189 self.assertEqual(g(p), "iterable") 2190 c.Container.register(P) 2191 with self.assertRaises(RuntimeError) as re_one: 2192 g(p) 2193 self.assertIn( 2194 str(re_one.exception), 2195 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 2196 "or <class 'collections.abc.Iterable'>"), 2197 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> " 2198 "or <class 'collections.abc.Container'>")), 2199 ) 2200 class Q(c.Sized): 2201 def __len__(self): 2202 return 0 2203 q = Q() 2204 self.assertEqual(g(q), "sized") 2205 c.Iterable.register(Q) 2206 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__ 2207 c.Set.register(Q) 2208 self.assertEqual(g(q), "set") # because c.Set is a subclass of 2209 # c.Sized and c.Iterable 2210 @functools.singledispatch 2211 def h(arg): 2212 return "base" 2213 @h.register(c.Sized) 2214 def _(arg): 2215 return "sized" 2216 @h.register(c.Container) 2217 def _(arg): 2218 return "container" 2219 # Even though Sized and Container are explicit bases of MutableMapping, 2220 # this ABC is implicitly registered on defaultdict which makes all of 2221 # MutableMapping's bases implicit as well from defaultdict's 2222 # perspective. 
2223 with self.assertRaises(RuntimeError) as re_two: 2224 h(collections.defaultdict(lambda: 0)) 2225 self.assertIn( 2226 str(re_two.exception), 2227 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 2228 "or <class 'collections.abc.Sized'>"), 2229 ("Ambiguous dispatch: <class 'collections.abc.Sized'> " 2230 "or <class 'collections.abc.Container'>")), 2231 ) 2232 class R(collections.defaultdict): 2233 pass 2234 c.MutableSequence.register(R) 2235 @functools.singledispatch 2236 def i(arg): 2237 return "base" 2238 @i.register(c.MutableMapping) 2239 def _(arg): 2240 return "mapping" 2241 @i.register(c.MutableSequence) 2242 def _(arg): 2243 return "sequence" 2244 r = R() 2245 self.assertEqual(i(r), "sequence") 2246 class S: 2247 pass 2248 class T(S, c.Sized): 2249 def __len__(self): 2250 return 0 2251 t = T() 2252 self.assertEqual(h(t), "sized") 2253 c.Container.register(T) 2254 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO 2255 class U: 2256 def __len__(self): 2257 return 0 2258 u = U() 2259 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred 2260 # from the existence of __len__() 2261 c.Container.register(U) 2262 # There is no preference for registered versus inferred ABCs. 
2263 with self.assertRaises(RuntimeError) as re_three: 2264 h(u) 2265 self.assertIn( 2266 str(re_three.exception), 2267 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 2268 "or <class 'collections.abc.Sized'>"), 2269 ("Ambiguous dispatch: <class 'collections.abc.Sized'> " 2270 "or <class 'collections.abc.Container'>")), 2271 ) 2272 class V(c.Sized, S): 2273 def __len__(self): 2274 return 0 2275 @functools.singledispatch 2276 def j(arg): 2277 return "base" 2278 @j.register(S) 2279 def _(arg): 2280 return "s" 2281 @j.register(c.Container) 2282 def _(arg): 2283 return "container" 2284 v = V() 2285 self.assertEqual(j(v), "s") 2286 c.Container.register(V) 2287 self.assertEqual(j(v), "container") # because it ends up right after 2288 # Sized in the MRO 2289 2290 def test_cache_invalidation(self): 2291 from collections import UserDict 2292 import weakref 2293 2294 class TracingDict(UserDict): 2295 def __init__(self, *args, **kwargs): 2296 super(TracingDict, self).__init__(*args, **kwargs) 2297 self.set_ops = [] 2298 self.get_ops = [] 2299 def __getitem__(self, key): 2300 result = self.data[key] 2301 self.get_ops.append(key) 2302 return result 2303 def __setitem__(self, key, value): 2304 self.set_ops.append(key) 2305 self.data[key] = value 2306 def clear(self): 2307 self.data.clear() 2308 2309 td = TracingDict() 2310 with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td): 2311 c = collections.abc 2312 @functools.singledispatch 2313 def g(arg): 2314 return "base" 2315 d = {} 2316 l = [] 2317 self.assertEqual(len(td), 0) 2318 self.assertEqual(g(d), "base") 2319 self.assertEqual(len(td), 1) 2320 self.assertEqual(td.get_ops, []) 2321 self.assertEqual(td.set_ops, [dict]) 2322 self.assertEqual(td.data[dict], g.registry[object]) 2323 self.assertEqual(g(l), "base") 2324 self.assertEqual(len(td), 2) 2325 self.assertEqual(td.get_ops, []) 2326 self.assertEqual(td.set_ops, [dict, list]) 2327 self.assertEqual(td.data[dict], g.registry[object]) 2328 
self.assertEqual(td.data[list], g.registry[object]) 2329 self.assertEqual(td.data[dict], td.data[list]) 2330 self.assertEqual(g(l), "base") 2331 self.assertEqual(g(d), "base") 2332 self.assertEqual(td.get_ops, [list, dict]) 2333 self.assertEqual(td.set_ops, [dict, list]) 2334 g.register(list, lambda arg: "list") 2335 self.assertEqual(td.get_ops, [list, dict]) 2336 self.assertEqual(len(td), 0) 2337 self.assertEqual(g(d), "base") 2338 self.assertEqual(len(td), 1) 2339 self.assertEqual(td.get_ops, [list, dict]) 2340 self.assertEqual(td.set_ops, [dict, list, dict]) 2341 self.assertEqual(td.data[dict], 2342 functools._find_impl(dict, g.registry)) 2343 self.assertEqual(g(l), "list") 2344 self.assertEqual(len(td), 2) 2345 self.assertEqual(td.get_ops, [list, dict]) 2346 self.assertEqual(td.set_ops, [dict, list, dict, list]) 2347 self.assertEqual(td.data[list], 2348 functools._find_impl(list, g.registry)) 2349 class X: 2350 pass 2351 c.MutableMapping.register(X) # Will not invalidate the cache, 2352 # not using ABCs yet. 
2353 self.assertEqual(g(d), "base") 2354 self.assertEqual(g(l), "list") 2355 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2356 self.assertEqual(td.set_ops, [dict, list, dict, list]) 2357 g.register(c.Sized, lambda arg: "sized") 2358 self.assertEqual(len(td), 0) 2359 self.assertEqual(g(d), "sized") 2360 self.assertEqual(len(td), 1) 2361 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2362 self.assertEqual(td.set_ops, [dict, list, dict, list, dict]) 2363 self.assertEqual(g(l), "list") 2364 self.assertEqual(len(td), 2) 2365 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2366 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2367 self.assertEqual(g(l), "list") 2368 self.assertEqual(g(d), "sized") 2369 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict]) 2370 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2371 g.dispatch(list) 2372 g.dispatch(dict) 2373 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict, 2374 list, dict]) 2375 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2376 c.MutableSet.register(X) # Will invalidate the cache. 2377 self.assertEqual(len(td), 2) # Stale cache. 
2378 self.assertEqual(g(l), "list") 2379 self.assertEqual(len(td), 1) 2380 g.register(c.MutableMapping, lambda arg: "mutablemapping") 2381 self.assertEqual(len(td), 0) 2382 self.assertEqual(g(d), "mutablemapping") 2383 self.assertEqual(len(td), 1) 2384 self.assertEqual(g(l), "list") 2385 self.assertEqual(len(td), 2) 2386 g.register(dict, lambda arg: "dict") 2387 self.assertEqual(g(d), "dict") 2388 self.assertEqual(g(l), "list") 2389 g._clear_cache() 2390 self.assertEqual(len(td), 0) 2391 2392 def test_annotations(self): 2393 @functools.singledispatch 2394 def i(arg): 2395 return "base" 2396 @i.register 2397 def _(arg: collections.abc.Mapping): 2398 return "mapping" 2399 @i.register 2400 def _(arg: "collections.abc.Sequence"): 2401 return "sequence" 2402 self.assertEqual(i(None), "base") 2403 self.assertEqual(i({"a": 1}), "mapping") 2404 self.assertEqual(i([1, 2, 3]), "sequence") 2405 self.assertEqual(i((1, 2, 3)), "sequence") 2406 self.assertEqual(i("str"), "sequence") 2407 2408 # Registering classes as callables doesn't work with annotations, 2409 # you need to pass the type explicitly. 
2410 @i.register(str) 2411 class _: 2412 def __init__(self, arg): 2413 self.arg = arg 2414 2415 def __eq__(self, other): 2416 return self.arg == other 2417 self.assertEqual(i("str"), "str") 2418 2419 def test_method_register(self): 2420 class A: 2421 @functools.singledispatchmethod 2422 def t(self, arg): 2423 self.arg = "base" 2424 @t.register(int) 2425 def _(self, arg): 2426 self.arg = "int" 2427 @t.register(str) 2428 def _(self, arg): 2429 self.arg = "str" 2430 a = A() 2431 2432 a.t(0) 2433 self.assertEqual(a.arg, "int") 2434 aa = A() 2435 self.assertFalse(hasattr(aa, 'arg')) 2436 a.t('') 2437 self.assertEqual(a.arg, "str") 2438 aa = A() 2439 self.assertFalse(hasattr(aa, 'arg')) 2440 a.t(0.0) 2441 self.assertEqual(a.arg, "base") 2442 aa = A() 2443 self.assertFalse(hasattr(aa, 'arg')) 2444 2445 def test_staticmethod_register(self): 2446 class A: 2447 @functools.singledispatchmethod 2448 @staticmethod 2449 def t(arg): 2450 return arg 2451 @t.register(int) 2452 @staticmethod 2453 def _(arg): 2454 return isinstance(arg, int) 2455 @t.register(str) 2456 @staticmethod 2457 def _(arg): 2458 return isinstance(arg, str) 2459 a = A() 2460 2461 self.assertTrue(A.t(0)) 2462 self.assertTrue(A.t('')) 2463 self.assertEqual(A.t(0.0), 0.0) 2464 2465 def test_classmethod_register(self): 2466 class A: 2467 def __init__(self, arg): 2468 self.arg = arg 2469 2470 @functools.singledispatchmethod 2471 @classmethod 2472 def t(cls, arg): 2473 return cls("base") 2474 @t.register(int) 2475 @classmethod 2476 def _(cls, arg): 2477 return cls("int") 2478 @t.register(str) 2479 @classmethod 2480 def _(cls, arg): 2481 return cls("str") 2482 2483 self.assertEqual(A.t(0).arg, "int") 2484 self.assertEqual(A.t('').arg, "str") 2485 self.assertEqual(A.t(0.0).arg, "base") 2486 2487 def test_callable_register(self): 2488 class A: 2489 def __init__(self, arg): 2490 self.arg = arg 2491 2492 @functools.singledispatchmethod 2493 @classmethod 2494 def t(cls, arg): 2495 return cls("base") 2496 2497 
@A.t.register(int) 2498 @classmethod 2499 def _(cls, arg): 2500 return cls("int") 2501 @A.t.register(str) 2502 @classmethod 2503 def _(cls, arg): 2504 return cls("str") 2505 2506 self.assertEqual(A.t(0).arg, "int") 2507 self.assertEqual(A.t('').arg, "str") 2508 self.assertEqual(A.t(0.0).arg, "base") 2509 2510 def test_abstractmethod_register(self): 2511 class Abstract(metaclass=abc.ABCMeta): 2512 2513 @functools.singledispatchmethod 2514 @abc.abstractmethod 2515 def add(self, x, y): 2516 pass 2517 2518 self.assertTrue(Abstract.add.__isabstractmethod__) 2519 self.assertTrue(Abstract.__dict__['add'].__isabstractmethod__) 2520 2521 with self.assertRaises(TypeError): 2522 Abstract() 2523 2524 def test_type_ann_register(self): 2525 class A: 2526 @functools.singledispatchmethod 2527 def t(self, arg): 2528 return "base" 2529 @t.register 2530 def _(self, arg: int): 2531 return "int" 2532 @t.register 2533 def _(self, arg: str): 2534 return "str" 2535 a = A() 2536 2537 self.assertEqual(a.t(0), "int") 2538 self.assertEqual(a.t(''), "str") 2539 self.assertEqual(a.t(0.0), "base") 2540 2541 def test_staticmethod_type_ann_register(self): 2542 class A: 2543 @functools.singledispatchmethod 2544 @staticmethod 2545 def t(arg): 2546 return arg 2547 @t.register 2548 @staticmethod 2549 def _(arg: int): 2550 return isinstance(arg, int) 2551 @t.register 2552 @staticmethod 2553 def _(arg: str): 2554 return isinstance(arg, str) 2555 a = A() 2556 2557 self.assertTrue(A.t(0)) 2558 self.assertTrue(A.t('')) 2559 self.assertEqual(A.t(0.0), 0.0) 2560 2561 def test_classmethod_type_ann_register(self): 2562 class A: 2563 def __init__(self, arg): 2564 self.arg = arg 2565 2566 @functools.singledispatchmethod 2567 @classmethod 2568 def t(cls, arg): 2569 return cls("base") 2570 @t.register 2571 @classmethod 2572 def _(cls, arg: int): 2573 return cls("int") 2574 @t.register 2575 @classmethod 2576 def _(cls, arg: str): 2577 return cls("str") 2578 2579 self.assertEqual(A.t(0).arg, "int") 2580 
self.assertEqual(A.t('').arg, "str") 2581 self.assertEqual(A.t(0.0).arg, "base") 2582 2583 def test_method_wrapping_attributes(self): 2584 class A: 2585 @functools.singledispatchmethod 2586 def func(self, arg: int) -> str: 2587 """My function docstring""" 2588 return str(arg) 2589 @functools.singledispatchmethod 2590 @classmethod 2591 def cls_func(cls, arg: int) -> str: 2592 """My function docstring""" 2593 return str(arg) 2594 @functools.singledispatchmethod 2595 @staticmethod 2596 def static_func(arg: int) -> str: 2597 """My function docstring""" 2598 return str(arg) 2599 2600 for meth in ( 2601 A.func, 2602 A().func, 2603 A.cls_func, 2604 A().cls_func, 2605 A.static_func, 2606 A().static_func 2607 ): 2608 with self.subTest(meth=meth): 2609 self.assertEqual(meth.__doc__, 'My function docstring') 2610 self.assertEqual(meth.__annotations__['arg'], int) 2611 2612 self.assertEqual(A.func.__name__, 'func') 2613 self.assertEqual(A().func.__name__, 'func') 2614 self.assertEqual(A.cls_func.__name__, 'cls_func') 2615 self.assertEqual(A().cls_func.__name__, 'cls_func') 2616 self.assertEqual(A.static_func.__name__, 'static_func') 2617 self.assertEqual(A().static_func.__name__, 'static_func') 2618 2619 def test_double_wrapped_methods(self): 2620 def classmethod_friendly_decorator(func): 2621 wrapped = func.__func__ 2622 @classmethod 2623 @functools.wraps(wrapped) 2624 def wrapper(*args, **kwargs): 2625 return wrapped(*args, **kwargs) 2626 return wrapper 2627 2628 class WithoutSingleDispatch: 2629 @classmethod 2630 @contextlib.contextmanager 2631 def cls_context_manager(cls, arg: int) -> str: 2632 try: 2633 yield str(arg) 2634 finally: 2635 return 'Done' 2636 2637 @classmethod_friendly_decorator 2638 @classmethod 2639 def decorated_classmethod(cls, arg: int) -> str: 2640 return str(arg) 2641 2642 class WithSingleDispatch: 2643 @functools.singledispatchmethod 2644 @classmethod 2645 @contextlib.contextmanager 2646 def cls_context_manager(cls, arg: int) -> str: 2647 """My 
function docstring""" 2648 try: 2649 yield str(arg) 2650 finally: 2651 return 'Done' 2652 2653 @functools.singledispatchmethod 2654 @classmethod_friendly_decorator 2655 @classmethod 2656 def decorated_classmethod(cls, arg: int) -> str: 2657 """My function docstring""" 2658 return str(arg) 2659 2660 # These are sanity checks 2661 # to test the test itself is working as expected 2662 with WithoutSingleDispatch.cls_context_manager(5) as foo: 2663 without_single_dispatch_foo = foo 2664 2665 with WithSingleDispatch.cls_context_manager(5) as foo: 2666 single_dispatch_foo = foo 2667 2668 self.assertEqual(without_single_dispatch_foo, single_dispatch_foo) 2669 self.assertEqual(single_dispatch_foo, '5') 2670 2671 self.assertEqual( 2672 WithoutSingleDispatch.decorated_classmethod(5), 2673 WithSingleDispatch.decorated_classmethod(5) 2674 ) 2675 2676 self.assertEqual(WithSingleDispatch.decorated_classmethod(5), '5') 2677 2678 # Behavioural checks now follow 2679 for method_name in ('cls_context_manager', 'decorated_classmethod'): 2680 with self.subTest(method=method_name): 2681 self.assertEqual( 2682 getattr(WithSingleDispatch, method_name).__name__, 2683 getattr(WithoutSingleDispatch, method_name).__name__ 2684 ) 2685 2686 self.assertEqual( 2687 getattr(WithSingleDispatch(), method_name).__name__, 2688 getattr(WithoutSingleDispatch(), method_name).__name__ 2689 ) 2690 2691 for meth in ( 2692 WithSingleDispatch.cls_context_manager, 2693 WithSingleDispatch().cls_context_manager, 2694 WithSingleDispatch.decorated_classmethod, 2695 WithSingleDispatch().decorated_classmethod 2696 ): 2697 with self.subTest(meth=meth): 2698 self.assertEqual(meth.__doc__, 'My function docstring') 2699 self.assertEqual(meth.__annotations__['arg'], int) 2700 2701 self.assertEqual( 2702 WithSingleDispatch.cls_context_manager.__name__, 2703 'cls_context_manager' 2704 ) 2705 self.assertEqual( 2706 WithSingleDispatch().cls_context_manager.__name__, 2707 'cls_context_manager' 2708 ) 2709 self.assertEqual( 
2710 WithSingleDispatch.decorated_classmethod.__name__, 2711 'decorated_classmethod' 2712 ) 2713 self.assertEqual( 2714 WithSingleDispatch().decorated_classmethod.__name__, 2715 'decorated_classmethod' 2716 ) 2717 2718 def test_invalid_registrations(self): 2719 msg_prefix = "Invalid first argument to `register()`: " 2720 msg_suffix = ( 2721 ". Use either `@register(some_class)` or plain `@register` on an " 2722 "annotated function." 2723 ) 2724 @functools.singledispatch 2725 def i(arg): 2726 return "base" 2727 with self.assertRaises(TypeError) as exc: 2728 @i.register(42) 2729 def _(arg): 2730 return "I annotated with a non-type" 2731 self.assertTrue(str(exc.exception).startswith(msg_prefix + "42")) 2732 self.assertTrue(str(exc.exception).endswith(msg_suffix)) 2733 with self.assertRaises(TypeError) as exc: 2734 @i.register 2735 def _(arg): 2736 return "I forgot to annotate" 2737 self.assertTrue(str(exc.exception).startswith(msg_prefix + 2738 "<function TestSingleDispatch.test_invalid_registrations.<locals>._" 2739 )) 2740 self.assertTrue(str(exc.exception).endswith(msg_suffix)) 2741 2742 with self.assertRaises(TypeError) as exc: 2743 @i.register 2744 def _(arg: typing.Iterable[str]): 2745 # At runtime, dispatching on generics is impossible. 2746 # When registering implementations with singledispatch, avoid 2747 # types from `typing`. Instead, annotate with regular types 2748 # or ABCs. 2749 return "I annotated with a generic collection" 2750 self.assertTrue(str(exc.exception).startswith( 2751 "Invalid annotation for 'arg'." 2752 )) 2753 self.assertTrue(str(exc.exception).endswith( 2754 'typing.Iterable[str] is not a class.' 2755 )) 2756 2757 with self.assertRaises(TypeError) as exc: 2758 @i.register 2759 def _(arg: typing.Union[int, typing.Iterable[str]]): 2760 return "Invalid Union" 2761 self.assertTrue(str(exc.exception).startswith( 2762 "Invalid annotation for 'arg'." 
2763 )) 2764 self.assertTrue(str(exc.exception).endswith( 2765 'typing.Union[int, typing.Iterable[str]] not all arguments are classes.' 2766 )) 2767 2768 def test_invalid_positional_argument(self): 2769 @functools.singledispatch 2770 def f(*args): 2771 pass 2772 msg = 'f requires at least 1 positional argument' 2773 with self.assertRaisesRegex(TypeError, msg): 2774 f() 2775 2776 def test_union(self): 2777 @functools.singledispatch 2778 def f(arg): 2779 return "default" 2780 2781 @f.register 2782 def _(arg: typing.Union[str, bytes]): 2783 return "typing.Union" 2784 2785 @f.register 2786 def _(arg: int | float): 2787 return "types.UnionType" 2788 2789 self.assertEqual(f([]), "default") 2790 self.assertEqual(f(""), "typing.Union") 2791 self.assertEqual(f(b""), "typing.Union") 2792 self.assertEqual(f(1), "types.UnionType") 2793 self.assertEqual(f(1.0), "types.UnionType") 2794 2795 def test_union_conflict(self): 2796 @functools.singledispatch 2797 def f(arg): 2798 return "default" 2799 2800 @f.register 2801 def _(arg: typing.Union[str, bytes]): 2802 return "typing.Union" 2803 2804 @f.register 2805 def _(arg: int | str): 2806 return "types.UnionType" 2807 2808 self.assertEqual(f([]), "default") 2809 self.assertEqual(f(""), "types.UnionType") # last one wins 2810 self.assertEqual(f(b""), "typing.Union") 2811 self.assertEqual(f(1), "types.UnionType") 2812 2813 def test_union_None(self): 2814 @functools.singledispatch 2815 def typing_union(arg): 2816 return "default" 2817 2818 @typing_union.register 2819 def _(arg: typing.Union[str, None]): 2820 return "typing.Union" 2821 2822 self.assertEqual(typing_union(1), "default") 2823 self.assertEqual(typing_union(""), "typing.Union") 2824 self.assertEqual(typing_union(None), "typing.Union") 2825 2826 @functools.singledispatch 2827 def types_union(arg): 2828 return "default" 2829 2830 @types_union.register 2831 def _(arg: int | None): 2832 return "types.UnionType" 2833 2834 self.assertEqual(types_union(""), "default") 2835 
self.assertEqual(types_union(1), "types.UnionType") 2836 self.assertEqual(types_union(None), "types.UnionType") 2837 2838 def test_register_genericalias(self): 2839 @functools.singledispatch 2840 def f(arg): 2841 return "default" 2842 2843 with self.assertRaisesRegex(TypeError, "Invalid first argument to "): 2844 f.register(list[int], lambda arg: "types.GenericAlias") 2845 with self.assertRaisesRegex(TypeError, "Invalid first argument to "): 2846 f.register(typing.List[int], lambda arg: "typing.GenericAlias") 2847 with self.assertRaisesRegex(TypeError, "Invalid first argument to "): 2848 f.register(list[int] | str, lambda arg: "types.UnionTypes(types.GenericAlias)") 2849 with self.assertRaisesRegex(TypeError, "Invalid first argument to "): 2850 f.register(typing.List[float] | bytes, lambda arg: "typing.Union[typing.GenericAlias]") 2851 2852 self.assertEqual(f([1]), "default") 2853 self.assertEqual(f([1.0]), "default") 2854 self.assertEqual(f(""), "default") 2855 self.assertEqual(f(b""), "default") 2856 2857 def test_register_genericalias_decorator(self): 2858 @functools.singledispatch 2859 def f(arg): 2860 return "default" 2861 2862 with self.assertRaisesRegex(TypeError, "Invalid first argument to "): 2863 f.register(list[int]) 2864 with self.assertRaisesRegex(TypeError, "Invalid first argument to "): 2865 f.register(typing.List[int]) 2866 with self.assertRaisesRegex(TypeError, "Invalid first argument to "): 2867 f.register(list[int] | str) 2868 with self.assertRaisesRegex(TypeError, "Invalid first argument to "): 2869 f.register(typing.List[int] | str) 2870 2871 def test_register_genericalias_annotation(self): 2872 @functools.singledispatch 2873 def f(arg): 2874 return "default" 2875 2876 with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): 2877 @f.register 2878 def _(arg: list[int]): 2879 return "types.GenericAlias" 2880 with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): 2881 @f.register 2882 def _(arg: typing.List[float]): 
2883 return "typing.GenericAlias" 2884 with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): 2885 @f.register 2886 def _(arg: list[int] | str): 2887 return "types.UnionType(types.GenericAlias)" 2888 with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): 2889 @f.register 2890 def _(arg: typing.List[float] | bytes): 2891 return "typing.Union[typing.GenericAlias]" 2892 2893 self.assertEqual(f([1]), "default") 2894 self.assertEqual(f([1.0]), "default") 2895 self.assertEqual(f(""), "default") 2896 self.assertEqual(f(b""), "default") 2897 2898 2899class CachedCostItem: 2900 _cost = 1 2901 2902 def __init__(self): 2903 self.lock = py_functools.RLock() 2904 2905 @py_functools.cached_property 2906 def cost(self): 2907 """The cost of the item.""" 2908 with self.lock: 2909 self._cost += 1 2910 return self._cost 2911 2912 2913class OptionallyCachedCostItem: 2914 _cost = 1 2915 2916 def get_cost(self): 2917 """The cost of the item.""" 2918 self._cost += 1 2919 return self._cost 2920 2921 cached_cost = py_functools.cached_property(get_cost) 2922 2923 2924class CachedCostItemWait: 2925 2926 def __init__(self, event): 2927 self._cost = 1 2928 self.lock = py_functools.RLock() 2929 self.event = event 2930 2931 @py_functools.cached_property 2932 def cost(self): 2933 self.event.wait(1) 2934 with self.lock: 2935 self._cost += 1 2936 return self._cost 2937 2938 2939class CachedCostItemWithSlots: 2940 __slots__ = ('_cost') 2941 2942 def __init__(self): 2943 self._cost = 1 2944 2945 @py_functools.cached_property 2946 def cost(self): 2947 raise RuntimeError('never called, slots not supported') 2948 2949 2950class TestCachedProperty(unittest.TestCase): 2951 def test_cached(self): 2952 item = CachedCostItem() 2953 self.assertEqual(item.cost, 2) 2954 self.assertEqual(item.cost, 2) # not 3 2955 2956 def test_cached_attribute_name_differs_from_func_name(self): 2957 item = OptionallyCachedCostItem() 2958 self.assertEqual(item.get_cost(), 2) 2959 
self.assertEqual(item.cached_cost, 3) 2960 self.assertEqual(item.get_cost(), 4) 2961 self.assertEqual(item.cached_cost, 3) 2962 2963 @threading_helper.requires_working_threading() 2964 def test_threaded(self): 2965 go = threading.Event() 2966 item = CachedCostItemWait(go) 2967 2968 num_threads = 3 2969 2970 orig_si = sys.getswitchinterval() 2971 sys.setswitchinterval(1e-6) 2972 try: 2973 threads = [ 2974 threading.Thread(target=lambda: item.cost) 2975 for k in range(num_threads) 2976 ] 2977 with threading_helper.start_threads(threads): 2978 go.set() 2979 finally: 2980 sys.setswitchinterval(orig_si) 2981 2982 self.assertEqual(item.cost, 2) 2983 2984 def test_object_with_slots(self): 2985 item = CachedCostItemWithSlots() 2986 with self.assertRaisesRegex( 2987 TypeError, 2988 "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property.", 2989 ): 2990 item.cost 2991 2992 def test_immutable_dict(self): 2993 class MyMeta(type): 2994 @py_functools.cached_property 2995 def prop(self): 2996 return True 2997 2998 class MyClass(metaclass=MyMeta): 2999 pass 3000 3001 with self.assertRaisesRegex( 3002 TypeError, 3003 "The '__dict__' attribute on 'MyMeta' instance does not support item assignment for caching 'prop' property.", 3004 ): 3005 MyClass.prop 3006 3007 def test_reuse_different_names(self): 3008 """Disallow this case because decorated function a would not be cached.""" 3009 with self.assertRaises(RuntimeError) as ctx: 3010 class ReusedCachedProperty: 3011 @py_functools.cached_property 3012 def a(self): 3013 pass 3014 3015 b = a 3016 3017 self.assertEqual( 3018 str(ctx.exception.__context__), 3019 str(TypeError("Cannot assign the same cached_property to two different names ('a' and 'b').")) 3020 ) 3021 3022 def test_reuse_same_name(self): 3023 """Reusing a cached_property on different classes under the same name is OK.""" 3024 counter = 0 3025 3026 @py_functools.cached_property 3027 def _cp(_self): 3028 nonlocal counter 3029 counter += 1 3030 
return counter 3031 3032 class A: 3033 cp = _cp 3034 3035 class B: 3036 cp = _cp 3037 3038 a = A() 3039 b = B() 3040 3041 self.assertEqual(a.cp, 1) 3042 self.assertEqual(b.cp, 2) 3043 self.assertEqual(a.cp, 1) 3044 3045 def test_set_name_not_called(self): 3046 cp = py_functools.cached_property(lambda s: None) 3047 class Foo: 3048 pass 3049 3050 Foo.cp = cp 3051 3052 with self.assertRaisesRegex( 3053 TypeError, 3054 "Cannot use cached_property instance without calling __set_name__ on it.", 3055 ): 3056 Foo().cp 3057 3058 def test_access_from_class(self): 3059 self.assertIsInstance(CachedCostItem.cost, py_functools.cached_property) 3060 3061 def test_doc(self): 3062 self.assertEqual(CachedCostItem.cost.__doc__, "The cost of the item.") 3063 3064 3065if __name__ == '__main__': 3066 unittest.main() 3067