1import abc
2import builtins
3import collections
4import collections.abc
5import copy
6from itertools import permutations
7import pickle
8from random import choice
9import sys
10from test import support
11import threading
12import time
13import typing
14import unittest
15import unittest.mock
16import os
17import weakref
18import gc
19from weakref import proxy
20import contextlib
21
22from test.support import import_helper
23from test.support import threading_helper
24from test.support.script_helper import assert_python_ok
25
26import functools
27
# Two independent copies of functools are exercised side by side:
# py_functools is imported with the _functools C accelerator blocked, so it
# is the pure-Python implementation; c_functools is a plain fresh import.
# NOTE(review): c_functools is treated as possibly falsy elsewhere in this
# file (skipUnless / `if c_functools:` guards) -- confirm when that happens.
py_functools = import_helper.import_fresh_module('functools',
                                                 blocked=['_functools'])
c_functools = import_helper.import_fresh_module('functools')

# A freshly imported decimal module (re-importing _decimal as well); it is not
# referenced in this part of the file -- presumably used by later tests.
decimal = import_helper.import_fresh_module('decimal', fresh=['_decimal'])
33
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily make ``sys.modules[name]`` refer to *replacement*.

    The previous entry is restored when the ``with`` block exits, whether it
    finishes normally or raises.  A ``KeyError`` is raised on entry if *name*
    is not currently in ``sys.modules``.
    """
    # The right-hand side is evaluated first, so a missing name raises before
    # anything in sys.modules is modified.
    saved, sys.modules[name] = sys.modules[name], replacement
    try:
        yield
    finally:
        sys.modules[name] = saved
42
def capture(*args, **kw):
    """Echo a call back to the caller.

    Returns the pair ``(args, kw)``: the positional arguments as a tuple and
    the keyword arguments as a dict.
    """
    echoed = args, kw
    return echoed
46
47
def signature(part):
    """Describe partial object *part* as ``(func, args, keywords, __dict__)``."""
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
51
class MyTuple(tuple):
    """Behaviour-free tuple subclass; lets tests detect non-exact tuples."""
54
class BadTuple(tuple):
    """Tuple subclass whose ``+`` returns a list rather than a tuple."""

    def __add__(self, other):
        return [*self, *other]
58
class MyDict(dict):
    """Behaviour-free dict subclass; lets tests detect non-exact dicts."""
61
62
class TestPartial:
    """Tests shared by every ``partial`` implementation under test.

    This is a mixin, not a TestCase.  Concrete subclasses combine it with
    ``unittest.TestCase`` and provide two class attributes:

    * ``partial`` -- the partial type (or a subclass of it) being tested;
    * ``AllowPickle`` -- a context-manager class; tests that pickle wrap the
      pickling in ``with self.AllowPickle():``.
    """

    def _expected_repr_name(self):
        """Return the type name that ``repr()`` of a partial should start with.

        The stock C and pure-Python types both present themselves as
        ``functools.partial``; any other type (a subclass) uses its own
        ``__name__``.
        """
        stock_types = [py_functools.partial]
        if c_functools:
            stock_types.append(c_functools.partial)
        if self.partial in stock_types:
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial is built or, at the latest, when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                got, kw = p('x')
                # Separate assertions report which half of the result differs.
                self.assertEqual(got, args + ('x',))
                self.assertEqual(kw, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                args, got = p(x=None)
                self.assertEqual(got, {'a': a, 'x': None})
                self.assertEqual(args, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        support.gc_collect()  # For PyPy or other GCs.
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        # A partial of a partial is expected to flatten into a single level.
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Keyword order in the repr is not pinned down, so accept either.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._expected_repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._expected_repr_name()

        # Each case builds a self-referencing partial and resets it in
        # `finally` so the reference cycle is broken even on failure.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        # A shallow copy shares its components with the original.
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        # A deep copy duplicates its components down to the nested lists.
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        # State is (func, args, keywords, __dict__); None is accepted for the
        # keywords and __dict__ entries.
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # NOTE(review): this check is disabled; confirm whether both
        # implementations now agree before re-enabling it.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        # tuple/dict subclasses in the state are normalised to exact types.
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        # A sequence-like (not a tuple) state must be rejected with TypeError.
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
387
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the C implementation."""
    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager: this type needs no set-up to be pickled.

        Unlike the pure-Python test class, nothing in ``sys.modules`` is
        swapped here.
        """
        def __enter__(self):
            return self
        def __exit__(self, exc_type, exc_value, tb):
            # Returning False lets any exception from the block propagate.
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        # repr() tolerates the bad key, but calling must reject it.
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
438
439
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the pure-Python implementation."""
    partial = py_functools.partial

    class AllowPickle:
        """Context manager that installs ``py_functools`` as ``functools``.

        While active, ``sys.modules['functools']`` is the pure-Python module,
        so code that looks ``functools`` up by name (presumably pickle, when
        resolving ``functools.partial``) finds the type under test.
        """
        def __init__(self):
            self._cm = replaced_module("functools", py_functools)
        def __enter__(self):
            return self._cm.__enter__()
        def __exit__(self, exc_type, exc_value, tb):
            return self._cm.__exit__(exc_type, exc_value, tb)
450
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Subclass of the C partial type that adds no behaviour of its own."""
454
class PyPartialSubclass(py_functools.partial):
    """Subclass of the pure-Python partial type that adds no behaviour."""
457
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Re-run the C partial tests using a subclass of the C partial type."""
    if c_functools:
        partial = CPartialSubclass

    # Subclasses do not get the nested-partial flattening that the base type
    # applies, so the inherited test is replaced with None to keep it from
    # running here.
    test_nested_optimization = None
465
class TestPartialPySubclass(TestPartialPy):
    """Re-run the pure-Python partial tests using a subclass of that type."""
    partial = PyPartialSubclass
468
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod used as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)
        # 'self' and 'func' must be usable as ordinary pre-bound keywords.
        spec_keywords = functools.partialmethod(capture, self=1, func=2)

        # partialmethod wrapping another partialmethod.
        nested = functools.partialmethod(positional, 5)

        # partialmethod wrapping a plain functools.partial.
        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        # partialmethod wrapping other descriptors.
        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod()
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod(func=capture, a=1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Abstract is an ordinary class whose *metaclass* is ABCMeta.
        # Subclassing abc.ABCMeta directly would make Abstract itself a
        # metaclass, which is not the situation this test is about.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))

    def test_positional_only(self):
        def f(a, b, /):
            return a + b

        p = functools.partial(f, 1)
        self.assertEqual(p(2), f(1, 2))
597
598
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper().

    Subclasses override _default_update() to build the wrapper differently,
    so check_wrapper() and _default_update() keep their names and signatures.
    """

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* carries the metadata taken from *wrapped*.

        Every attribute named in *assigned* must be the very same object on
        both.  Every mapping named in *updated* must hold each of the wrapped
        object's entries, except ``__dict__['__wrapped__']``.  Finally,
        ``wrapper.__wrapped__`` must be *wrapped* itself.
        """
        for attr in assigned:
            self.assertIs(getattr(wrapper, attr), getattr(wrapped, attr))
        for attr in updated:
            target = getattr(wrapper, attr)
            source = getattr(wrapped, attr)
            for key in source:
                # update_wrapper sets __wrapped__ itself after merging
                # __dict__, so the wrapped object's own value is not expected.
                if attr != "__dict__" or key != "__wrapped__":
                    self.assertIs(source[key], target[key])
        self.assertIs(wrapper.__wrapped__, wrapped)


    def _default_update(self):
        # The names `f`, `wrapper`, `a` and `b` are asserted on by callers.
        def f(a:'This is a new annotation'):
            """This is a test"""
        f.attr = 'This is also a test'
        # A misleading __wrapped__ that update_wrapper must overwrite.
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        def wrapper():
            pass
        # Empty assigned/updated tuples: only __wrapped__ is set.
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign_names = ('attr',)
        update_names = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign_names, update_names)
        self.check_wrapper(wrapper, f, assign_names, update_names)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign_names = ('attr',)
        update_names = ('dict_attr',)
        # Attributes the wrapped object lacks are silently skipped.
        functools.update_wrapper(wrapper, f, assign_names, update_names)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # The wrapper, however, must already own each "updated" mapping...
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign_names, update_names)
        # ...and it must be something that can be updated.
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign_names, update_names)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241: wrapping a builtin function.
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
711
712
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper() checks through the functools.wraps decorator.

    check_wrapper(), test_missing_attributes() and test_builtin_update() are
    inherited unchanged.
    """

    def _default_update(self):
        # The names `f` and `wrapper` are asserted on by callers.
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        # A misleading __wrapped__ that wraps() must overwrite.
        f.__wrapped__ = "This is still a bald faced lie"
        @functools.wraps(f)
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        # Empty assigned/updated tuples: only __wrapped__ is set.
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def add_dict_attr(func):
            # Give the wrapper an empty mapping for wraps() to update.
            func.dict_attr = {}
            return func
        assign_names = ('attr',)
        update_names = ('dict_attr',)
        @functools.wraps(f, assign_names, update_names)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign_names, update_names)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
773
774
775class TestReduce:
776    def test_reduce(self):
777        class Squares:
778            def __init__(self, max):
779                self.max = max
780                self.sofar = []
781
782            def __len__(self):
783                return len(self.sofar)
784
785            def __getitem__(self, i):
786                if not 0 <= i < self.max: raise IndexError
787                n = len(self.sofar)
788                while n <= i:
789                    self.sofar.append(n*n)
790                    n += 1
791                return self.sofar[i]
792        def add(x, y):
793            return x + y
794        self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc')
795        self.assertEqual(
796            self.reduce(add, [['a', 'c'], [], ['d', 'w']], []),
797            ['a','c','d','w']
798        )
799        self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040)
800        self.assertEqual(
801            self.reduce(lambda x, y: x*y, range(2,21), 1),
802            2432902008176640000
803        )
804        self.assertEqual(self.reduce(add, Squares(10)), 285)
805        self.assertEqual(self.reduce(add, Squares(10), 0), 285)
806        self.assertEqual(self.reduce(add, Squares(0), 0), 0)
807        self.assertRaises(TypeError, self.reduce)
808        self.assertRaises(TypeError, self.reduce, 42, 42)
809        self.assertRaises(TypeError, self.reduce, 42, 42, 42)
810        self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item
811        self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item
812        self.assertRaises(TypeError, self.reduce, 42, (42, 42))
813        self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value
814        self.assertRaises(TypeError, self.reduce, add, "")
815        self.assertRaises(TypeError, self.reduce, add, ())
816        self.assertRaises(TypeError, self.reduce, add, object())
817
818        class TestFailingIter:
819            def __iter__(self):
820                raise RuntimeError
821        self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter())
822
823        self.assertEqual(self.reduce(add, [], None), None)
824        self.assertEqual(self.reduce(add, [], 42), 42)
825
826        class BadSeq:
827            def __getitem__(self, index):
828                raise ValueError
829        self.assertRaises(ValueError, self.reduce, 42, BadSeq())
830
831    # Test reduce()'s use of iterators.
832    def test_iterator_usage(self):
833        class SequenceClass:
834            def __init__(self, n):
835                self.n = n
836            def __getitem__(self, i):
837                if 0 <= i < self.n:
838                    return i
839                else:
840                    raise IndexError
841
842        from operator import add
843        self.assertEqual(self.reduce(add, SequenceClass(5)), 10)
844        self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52)
845        self.assertRaises(TypeError, self.reduce, add, SequenceClass(0))
846        self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42)
847        self.assertEqual(self.reduce(add, SequenceClass(1)), 0)
848        self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42)
849
850        d = {"one": 1, "two": 2, "three": 3}
851        self.assertEqual(self.reduce(add, d), "".join(d.keys()))
852
853
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    """Run the shared reduce() tests against the C-accelerated functools."""

    # A builtin function does not bind as a method, so no staticmethod()
    # wrapper is needed here.
    if c_functools is not None:
        reduce = c_functools.reduce
858
859
class TestReducePy(TestReduce, unittest.TestCase):
    """Run the shared reduce() tests against the pure-Python functools
    (imported with the _functools accelerator blocked)."""
    # staticmethod() keeps self.reduce from binding the test case instance
    # as the first argument.
    reduce = staticmethod(py_functools.reduce)
862
863
864class TestCmpToKey:
865
866    def test_cmp_to_key(self):
867        def cmp1(x, y):
868            return (x > y) - (x < y)
869        key = self.cmp_to_key(cmp1)
870        self.assertEqual(key(3), key(3))
871        self.assertGreater(key(3), key(1))
872        self.assertGreaterEqual(key(3), key(3))
873
874        def cmp2(x, y):
875            return int(x) - int(y)
876        key = self.cmp_to_key(cmp2)
877        self.assertEqual(key(4.0), key('4'))
878        self.assertLess(key(2), key('35'))
879        self.assertLessEqual(key(2), key('35'))
880        self.assertNotEqual(key(2), key('35'))
881
882    def test_cmp_to_key_arguments(self):
883        def cmp1(x, y):
884            return (x > y) - (x < y)
885        key = self.cmp_to_key(mycmp=cmp1)
886        self.assertEqual(key(obj=3), key(obj=3))
887        self.assertGreater(key(obj=3), key(obj=1))
888        with self.assertRaises((TypeError, AttributeError)):
889            key(3) > 1    # rhs is not a K object
890        with self.assertRaises((TypeError, AttributeError)):
891            1 < key(3)    # lhs is not a K object
892        with self.assertRaises(TypeError):
893            key = self.cmp_to_key()             # too few args
894        with self.assertRaises(TypeError):
895            key = self.cmp_to_key(cmp1, None)   # too many args
896        key = self.cmp_to_key(cmp1)
897        with self.assertRaises(TypeError):
898            key()                                    # too few args
899        with self.assertRaises(TypeError):
900            key(None, None)                          # too many args
901
902    def test_bad_cmp(self):
903        def cmp1(x, y):
904            raise ZeroDivisionError
905        key = self.cmp_to_key(cmp1)
906        with self.assertRaises(ZeroDivisionError):
907            key(3) > key(1)
908
909        class BadCmp:
910            def __lt__(self, other):
911                raise ZeroDivisionError
912        def cmp1(x, y):
913            return BadCmp()
914        with self.assertRaises(ZeroDivisionError):
915            key(3) > key(1)
916
917    def test_obj_field(self):
918        def cmp1(x, y):
919            return (x > y) - (x < y)
920        key = self.cmp_to_key(mycmp=cmp1)
921        self.assertEqual(key(50).obj, 50)
922
923    def test_sort_int(self):
924        def mycmp(x, y):
925            return y - x
926        self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
927                         [4, 3, 2, 1, 0])
928
929    def test_sort_int_str(self):
930        def mycmp(x, y):
931            x, y = int(x), int(y)
932            return (x > y) - (x < y)
933        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
934        values = sorted(values, key=self.cmp_to_key(mycmp))
935        self.assertEqual([int(value) for value in values],
936                         [0, 1, 1, 2, 3, 4, 5, 7, 10])
937
938    def test_hash(self):
939        def mycmp(x, y):
940            return y - x
941        key = self.cmp_to_key(mycmp)
942        k = key(10)
943        self.assertRaises(TypeError, hash, k)
944        self.assertNotIsInstance(k, collections.abc.Hashable)
945
946
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key() tests against the C-accelerated functools."""

    if c_functools is not None:
        cmp_to_key = c_functools.cmp_to_key

    @support.cpython_only
    def test_disallow_instantiation(self):
        # The key type returned by cmp_to_key() must refuse direct
        # instantiation (bpo-43916).
        key_type = type(c_functools.cmp_to_key(None))
        support.check_disallow_instantiation(self, key_type)
958
959
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key() tests against the pure-Python functools
    (imported with the _functools accelerator blocked)."""
    # staticmethod() keeps self.cmp_to_key from binding the test case
    # instance as the first argument.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
962
963
964class TestTotalOrdering(unittest.TestCase):
965
966    def test_total_ordering_lt(self):
967        @functools.total_ordering
968        class A:
969            def __init__(self, value):
970                self.value = value
971            def __lt__(self, other):
972                return self.value < other.value
973            def __eq__(self, other):
974                return self.value == other.value
975        self.assertTrue(A(1) < A(2))
976        self.assertTrue(A(2) > A(1))
977        self.assertTrue(A(1) <= A(2))
978        self.assertTrue(A(2) >= A(1))
979        self.assertTrue(A(2) <= A(2))
980        self.assertTrue(A(2) >= A(2))
981        self.assertFalse(A(1) > A(2))
982
983    def test_total_ordering_le(self):
984        @functools.total_ordering
985        class A:
986            def __init__(self, value):
987                self.value = value
988            def __le__(self, other):
989                return self.value <= other.value
990            def __eq__(self, other):
991                return self.value == other.value
992        self.assertTrue(A(1) < A(2))
993        self.assertTrue(A(2) > A(1))
994        self.assertTrue(A(1) <= A(2))
995        self.assertTrue(A(2) >= A(1))
996        self.assertTrue(A(2) <= A(2))
997        self.assertTrue(A(2) >= A(2))
998        self.assertFalse(A(1) >= A(2))
999
1000    def test_total_ordering_gt(self):
1001        @functools.total_ordering
1002        class A:
1003            def __init__(self, value):
1004                self.value = value
1005            def __gt__(self, other):
1006                return self.value > other.value
1007            def __eq__(self, other):
1008                return self.value == other.value
1009        self.assertTrue(A(1) < A(2))
1010        self.assertTrue(A(2) > A(1))
1011        self.assertTrue(A(1) <= A(2))
1012        self.assertTrue(A(2) >= A(1))
1013        self.assertTrue(A(2) <= A(2))
1014        self.assertTrue(A(2) >= A(2))
1015        self.assertFalse(A(2) < A(1))
1016
1017    def test_total_ordering_ge(self):
1018        @functools.total_ordering
1019        class A:
1020            def __init__(self, value):
1021                self.value = value
1022            def __ge__(self, other):
1023                return self.value >= other.value
1024            def __eq__(self, other):
1025                return self.value == other.value
1026        self.assertTrue(A(1) < A(2))
1027        self.assertTrue(A(2) > A(1))
1028        self.assertTrue(A(1) <= A(2))
1029        self.assertTrue(A(2) >= A(1))
1030        self.assertTrue(A(2) <= A(2))
1031        self.assertTrue(A(2) >= A(2))
1032        self.assertFalse(A(2) <= A(1))
1033
1034    def test_total_ordering_no_overwrite(self):
1035        # new methods should not overwrite existing
1036        @functools.total_ordering
1037        class A(int):
1038            pass
1039        self.assertTrue(A(1) < A(2))
1040        self.assertTrue(A(2) > A(1))
1041        self.assertTrue(A(1) <= A(2))
1042        self.assertTrue(A(2) >= A(1))
1043        self.assertTrue(A(2) <= A(2))
1044        self.assertTrue(A(2) >= A(2))
1045
1046    def test_no_operations_defined(self):
1047        with self.assertRaises(ValueError):
1048            @functools.total_ordering
1049            class A:
1050                pass
1051
1052    def test_notimplemented(self):
1053        # Verify NotImplemented results are correctly handled
1054        @functools.total_ordering
1055        class ImplementsLessThan:
1056            def __init__(self, value):
1057                self.value = value
1058            def __eq__(self, other):
1059                if isinstance(other, ImplementsLessThan):
1060                    return self.value == other.value
1061                return False
1062            def __lt__(self, other):
1063                if isinstance(other, ImplementsLessThan):
1064                    return self.value < other.value
1065                return NotImplemented
1066
1067        @functools.total_ordering
1068        class ImplementsLessThanEqualTo:
1069            def __init__(self, value):
1070                self.value = value
1071            def __eq__(self, other):
1072                if isinstance(other, ImplementsLessThanEqualTo):
1073                    return self.value == other.value
1074                return False
1075            def __le__(self, other):
1076                if isinstance(other, ImplementsLessThanEqualTo):
1077                    return self.value <= other.value
1078                return NotImplemented
1079
1080        @functools.total_ordering
1081        class ImplementsGreaterThan:
1082            def __init__(self, value):
1083                self.value = value
1084            def __eq__(self, other):
1085                if isinstance(other, ImplementsGreaterThan):
1086                    return self.value == other.value
1087                return False
1088            def __gt__(self, other):
1089                if isinstance(other, ImplementsGreaterThan):
1090                    return self.value > other.value
1091                return NotImplemented
1092
1093        @functools.total_ordering
1094        class ImplementsGreaterThanEqualTo:
1095            def __init__(self, value):
1096                self.value = value
1097            def __eq__(self, other):
1098                if isinstance(other, ImplementsGreaterThanEqualTo):
1099                    return self.value == other.value
1100                return False
1101            def __ge__(self, other):
1102                if isinstance(other, ImplementsGreaterThanEqualTo):
1103                    return self.value >= other.value
1104                return NotImplemented
1105
1106        self.assertIs(ImplementsLessThan(1).__le__(1), NotImplemented)
1107        self.assertIs(ImplementsLessThan(1).__gt__(1), NotImplemented)
1108        self.assertIs(ImplementsLessThan(1).__ge__(1), NotImplemented)
1109        self.assertIs(ImplementsLessThanEqualTo(1).__lt__(1), NotImplemented)
1110        self.assertIs(ImplementsLessThanEqualTo(1).__gt__(1), NotImplemented)
1111        self.assertIs(ImplementsLessThanEqualTo(1).__ge__(1), NotImplemented)
1112        self.assertIs(ImplementsGreaterThan(1).__lt__(1), NotImplemented)
1113        self.assertIs(ImplementsGreaterThan(1).__gt__(1), NotImplemented)
1114        self.assertIs(ImplementsGreaterThan(1).__ge__(1), NotImplemented)
1115        self.assertIs(ImplementsGreaterThanEqualTo(1).__lt__(1), NotImplemented)
1116        self.assertIs(ImplementsGreaterThanEqualTo(1).__le__(1), NotImplemented)
1117        self.assertIs(ImplementsGreaterThanEqualTo(1).__gt__(1), NotImplemented)
1118
1119    def test_type_error_when_not_implemented(self):
1120        # bug 10042; ensure stack overflow does not occur
1121        # when decorated types return NotImplemented
1122        @functools.total_ordering
1123        class ImplementsLessThan:
1124            def __init__(self, value):
1125                self.value = value
1126            def __eq__(self, other):
1127                if isinstance(other, ImplementsLessThan):
1128                    return self.value == other.value
1129                return False
1130            def __lt__(self, other):
1131                if isinstance(other, ImplementsLessThan):
1132                    return self.value < other.value
1133                return NotImplemented
1134
1135        @functools.total_ordering
1136        class ImplementsGreaterThan:
1137            def __init__(self, value):
1138                self.value = value
1139            def __eq__(self, other):
1140                if isinstance(other, ImplementsGreaterThan):
1141                    return self.value == other.value
1142                return False
1143            def __gt__(self, other):
1144                if isinstance(other, ImplementsGreaterThan):
1145                    return self.value > other.value
1146                return NotImplemented
1147
1148        @functools.total_ordering
1149        class ImplementsLessThanEqualTo:
1150            def __init__(self, value):
1151                self.value = value
1152            def __eq__(self, other):
1153                if isinstance(other, ImplementsLessThanEqualTo):
1154                    return self.value == other.value
1155                return False
1156            def __le__(self, other):
1157                if isinstance(other, ImplementsLessThanEqualTo):
1158                    return self.value <= other.value
1159                return NotImplemented
1160
1161        @functools.total_ordering
1162        class ImplementsGreaterThanEqualTo:
1163            def __init__(self, value):
1164                self.value = value
1165            def __eq__(self, other):
1166                if isinstance(other, ImplementsGreaterThanEqualTo):
1167                    return self.value == other.value
1168                return False
1169            def __ge__(self, other):
1170                if isinstance(other, ImplementsGreaterThanEqualTo):
1171                    return self.value >= other.value
1172                return NotImplemented
1173
1174        @functools.total_ordering
1175        class ComparatorNotImplemented:
1176            def __init__(self, value):
1177                self.value = value
1178            def __eq__(self, other):
1179                if isinstance(other, ComparatorNotImplemented):
1180                    return self.value == other.value
1181                return False
1182            def __lt__(self, other):
1183                return NotImplemented
1184
1185        with self.subTest("LT < 1"), self.assertRaises(TypeError):
1186            ImplementsLessThan(-1) < 1
1187
1188        with self.subTest("LT < LE"), self.assertRaises(TypeError):
1189            ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
1190
1191        with self.subTest("LT < GT"), self.assertRaises(TypeError):
1192            ImplementsLessThan(1) < ImplementsGreaterThan(1)
1193
1194        with self.subTest("LE <= LT"), self.assertRaises(TypeError):
1195            ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
1196
1197        with self.subTest("LE <= GE"), self.assertRaises(TypeError):
1198            ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
1199
1200        with self.subTest("GT > GE"), self.assertRaises(TypeError):
1201            ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
1202
1203        with self.subTest("GT > LT"), self.assertRaises(TypeError):
1204            ImplementsGreaterThan(5) > ImplementsLessThan(5)
1205
1206        with self.subTest("GE >= GT"), self.assertRaises(TypeError):
1207            ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
1208
1209        with self.subTest("GE >= LE"), self.assertRaises(TypeError):
1210            ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
1211
1212        with self.subTest("GE when equal"):
1213            a = ComparatorNotImplemented(8)
1214            b = ComparatorNotImplemented(8)
1215            self.assertEqual(a, b)
1216            with self.assertRaises(TypeError):
1217                a >= b
1218
1219        with self.subTest("LE when equal"):
1220            a = ComparatorNotImplemented(9)
1221            b = ComparatorNotImplemented(9)
1222            self.assertEqual(a, b)
1223            with self.assertRaises(TypeError):
1224                a <= b
1225
1226    def test_pickle(self):
1227        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1228            for name in '__lt__', '__gt__', '__le__', '__ge__':
1229                with self.subTest(method=name, proto=proto):
1230                    method = getattr(Orderable_LT, name)
1231                    method_copy = pickle.loads(pickle.dumps(method, proto))
1232                    self.assertIs(method_copy, method)
1233
1234
1235    def test_total_ordering_for_metaclasses_issue_44605(self):
1236
1237        @functools.total_ordering
1238        class SortableMeta(type):
1239            def __new__(cls, name, bases, ns):
1240                return super().__new__(cls, name, bases, ns)
1241
1242            def __lt__(self, other):
1243                if not isinstance(other, SortableMeta):
1244                    pass
1245                return self.__name__ < other.__name__
1246
1247            def __eq__(self, other):
1248                if not isinstance(other, SortableMeta):
1249                    pass
1250                return self.__name__ == other.__name__
1251
1252        class B(metaclass=SortableMeta):
1253            pass
1254
1255        class A(metaclass=SortableMeta):
1256            pass
1257
1258        self.assertTrue(A < B)
1259        self.assertFalse(A > B)
1260
1261
@functools.total_ordering
class Orderable_LT:
    """Orderable wrapper that defines only __eq__ and __lt__.

    It is defined at module level because TestTotalOrdering.test_pickle
    pickles its comparison methods, and pickle locates functions by their
    qualified name.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1270
1271
1272class TestCache:
1273    # This tests that the pass-through is working as designed.
1274    # The underlying functionality is tested in TestLRU.
1275
1276    def test_cache(self):
1277        @self.module.cache
1278        def fib(n):
1279            if n < 2:
1280                return n
1281            return fib(n-1) + fib(n-2)
1282        self.assertEqual([fib(n) for n in range(16)],
1283            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1284        self.assertEqual(fib.cache_info(),
1285            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1286        fib.cache_clear()
1287        self.assertEqual(fib.cache_info(),
1288            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1289
1290
1291class TestLRU:
1292
1293    def test_lru(self):
1294        def orig(x, y):
1295            return 3 * x + y
1296        f = self.module.lru_cache(maxsize=20)(orig)
1297        hits, misses, maxsize, currsize = f.cache_info()
1298        self.assertEqual(maxsize, 20)
1299        self.assertEqual(currsize, 0)
1300        self.assertEqual(hits, 0)
1301        self.assertEqual(misses, 0)
1302
1303        domain = range(5)
1304        for i in range(1000):
1305            x, y = choice(domain), choice(domain)
1306            actual = f(x, y)
1307            expected = orig(x, y)
1308            self.assertEqual(actual, expected)
1309        hits, misses, maxsize, currsize = f.cache_info()
1310        self.assertTrue(hits > misses)
1311        self.assertEqual(hits + misses, 1000)
1312        self.assertEqual(currsize, 20)
1313
1314        f.cache_clear()   # test clearing
1315        hits, misses, maxsize, currsize = f.cache_info()
1316        self.assertEqual(hits, 0)
1317        self.assertEqual(misses, 0)
1318        self.assertEqual(currsize, 0)
1319        f(x, y)
1320        hits, misses, maxsize, currsize = f.cache_info()
1321        self.assertEqual(hits, 0)
1322        self.assertEqual(misses, 1)
1323        self.assertEqual(currsize, 1)
1324
1325        # Test bypassing the cache
1326        self.assertIs(f.__wrapped__, orig)
1327        f.__wrapped__(x, y)
1328        hits, misses, maxsize, currsize = f.cache_info()
1329        self.assertEqual(hits, 0)
1330        self.assertEqual(misses, 1)
1331        self.assertEqual(currsize, 1)
1332
1333        # test size zero (which means "never-cache")
1334        @self.module.lru_cache(0)
1335        def f():
1336            nonlocal f_cnt
1337            f_cnt += 1
1338            return 20
1339        self.assertEqual(f.cache_info().maxsize, 0)
1340        f_cnt = 0
1341        for i in range(5):
1342            self.assertEqual(f(), 20)
1343        self.assertEqual(f_cnt, 5)
1344        hits, misses, maxsize, currsize = f.cache_info()
1345        self.assertEqual(hits, 0)
1346        self.assertEqual(misses, 5)
1347        self.assertEqual(currsize, 0)
1348
1349        # test size one
1350        @self.module.lru_cache(1)
1351        def f():
1352            nonlocal f_cnt
1353            f_cnt += 1
1354            return 20
1355        self.assertEqual(f.cache_info().maxsize, 1)
1356        f_cnt = 0
1357        for i in range(5):
1358            self.assertEqual(f(), 20)
1359        self.assertEqual(f_cnt, 1)
1360        hits, misses, maxsize, currsize = f.cache_info()
1361        self.assertEqual(hits, 4)
1362        self.assertEqual(misses, 1)
1363        self.assertEqual(currsize, 1)
1364
1365        # test size two
1366        @self.module.lru_cache(2)
1367        def f(x):
1368            nonlocal f_cnt
1369            f_cnt += 1
1370            return x*10
1371        self.assertEqual(f.cache_info().maxsize, 2)
1372        f_cnt = 0
1373        for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1374            #    *  *              *                          *
1375            self.assertEqual(f(x), x*10)
1376        self.assertEqual(f_cnt, 4)
1377        hits, misses, maxsize, currsize = f.cache_info()
1378        self.assertEqual(hits, 12)
1379        self.assertEqual(misses, 4)
1380        self.assertEqual(currsize, 2)
1381
1382    def test_lru_no_args(self):
1383        @self.module.lru_cache
1384        def square(x):
1385            return x ** 2
1386
1387        self.assertEqual(list(map(square, [10, 20, 10])),
1388                         [100, 400, 100])
1389        self.assertEqual(square.cache_info().hits, 1)
1390        self.assertEqual(square.cache_info().misses, 2)
1391        self.assertEqual(square.cache_info().maxsize, 128)
1392        self.assertEqual(square.cache_info().currsize, 2)
1393
1394    def test_lru_bug_35780(self):
1395        # C version of the lru_cache was not checking to see if
1396        # the user function call has already modified the cache
1397        # (this arises in recursive calls and in multi-threading).
1398        # This cause the cache to have orphan links not referenced
1399        # by the cache dictionary.
1400
1401        once = True                 # Modified by f(x) below
1402
1403        @self.module.lru_cache(maxsize=10)
1404        def f(x):
1405            nonlocal once
1406            rv = f'.{x}.'
1407            if x == 20 and once:
1408                once = False
1409                rv = f(x)
1410            return rv
1411
1412        # Fill the cache
1413        for x in range(15):
1414            self.assertEqual(f(x), f'.{x}.')
1415        self.assertEqual(f.cache_info().currsize, 10)
1416
1417        # Make a recursive call and make sure the cache remains full
1418        self.assertEqual(f(20), '.20.')
1419        self.assertEqual(f.cache_info().currsize, 10)
1420
1421    def test_lru_bug_36650(self):
1422        # C version of lru_cache was treating a call with an empty **kwargs
1423        # dictionary as being distinct from a call with no keywords at all.
1424        # This did not result in an incorrect answer, but it did trigger
1425        # an unexpected cache miss.
1426
1427        @self.module.lru_cache()
1428        def f(x):
1429            pass
1430
1431        f(0)
1432        f(0, **{})
1433        self.assertEqual(f.cache_info().hits, 1)
1434
1435    def test_lru_hash_only_once(self):
1436        # To protect against weird reentrancy bugs and to improve
1437        # efficiency when faced with slow __hash__ methods, the
1438        # LRU cache guarantees that it will only call __hash__
1439        # only once per use as an argument to the cached function.
1440
1441        @self.module.lru_cache(maxsize=1)
1442        def f(x, y):
1443            return x * 3 + y
1444
1445        # Simulate the integer 5
1446        mock_int = unittest.mock.Mock()
1447        mock_int.__mul__ = unittest.mock.Mock(return_value=15)
1448        mock_int.__hash__ = unittest.mock.Mock(return_value=999)
1449
1450        # Add to cache:  One use as an argument gives one call
1451        self.assertEqual(f(mock_int, 1), 16)
1452        self.assertEqual(mock_int.__hash__.call_count, 1)
1453        self.assertEqual(f.cache_info(), (0, 1, 1, 1))
1454
1455        # Cache hit: One use as an argument gives one additional call
1456        self.assertEqual(f(mock_int, 1), 16)
1457        self.assertEqual(mock_int.__hash__.call_count, 2)
1458        self.assertEqual(f.cache_info(), (1, 1, 1, 1))
1459
1460        # Cache eviction: No use as an argument gives no additional call
1461        self.assertEqual(f(6, 2), 20)
1462        self.assertEqual(mock_int.__hash__.call_count, 2)
1463        self.assertEqual(f.cache_info(), (1, 2, 1, 1))
1464
1465        # Cache miss: One use as an argument gives one additional call
1466        self.assertEqual(f(mock_int, 1), 16)
1467        self.assertEqual(mock_int.__hash__.call_count, 3)
1468        self.assertEqual(f.cache_info(), (1, 3, 1, 1))
1469
1470    def test_lru_reentrancy_with_len(self):
1471        # Test to make sure the LRU cache code isn't thrown-off by
1472        # caching the built-in len() function.  Since len() can be
1473        # cached, we shouldn't use it inside the lru code itself.
1474        old_len = builtins.len
1475        try:
1476            builtins.len = self.module.lru_cache(4)(len)
1477            for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1478                self.assertEqual(len('abcdefghijklmn'[:i]), i)
1479        finally:
1480            builtins.len = old_len
1481
1482    def test_lru_star_arg_handling(self):
1483        # Test regression that arose in ea064ff3c10f
1484        @self.module.lru_cache()
1485        def f(*args):
1486            return args
1487
1488        self.assertEqual(f(1, 2), (1, 2))
1489        self.assertEqual(f((1, 2)), ((1, 2),))
1490
1491    def test_lru_type_error(self):
1492        # Regression test for issue #28653.
1493        # lru_cache was leaking when one of the arguments
1494        # wasn't cacheable.
1495
1496        @self.module.lru_cache(maxsize=None)
1497        def infinite_cache(o):
1498            pass
1499
1500        @self.module.lru_cache(maxsize=10)
1501        def limited_cache(o):
1502            pass
1503
1504        with self.assertRaises(TypeError):
1505            infinite_cache([])
1506
1507        with self.assertRaises(TypeError):
1508            limited_cache([])
1509
1510    def test_lru_with_maxsize_none(self):
1511        @self.module.lru_cache(maxsize=None)
1512        def fib(n):
1513            if n < 2:
1514                return n
1515            return fib(n-1) + fib(n-2)
1516        self.assertEqual([fib(n) for n in range(16)],
1517            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1518        self.assertEqual(fib.cache_info(),
1519            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1520        fib.cache_clear()
1521        self.assertEqual(fib.cache_info(),
1522            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1523
1524    def test_lru_with_maxsize_negative(self):
1525        @self.module.lru_cache(maxsize=-10)
1526        def eq(n):
1527            return n
1528        for i in (0, 1):
1529            self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1530        self.assertEqual(eq.cache_info(),
1531            self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
1532
1533    def test_lru_with_exceptions(self):
1534        # Verify that user_function exceptions get passed through without
1535        # creating a hard-to-read chained exception.
1536        # http://bugs.python.org/issue13177
1537        for maxsize in (None, 128):
1538            @self.module.lru_cache(maxsize)
1539            def func(i):
1540                return 'abc'[i]
1541            self.assertEqual(func(0), 'a')
1542            with self.assertRaises(IndexError) as cm:
1543                func(15)
1544            self.assertIsNone(cm.exception.__context__)
1545            # Verify that the previous exception did not result in a cached entry
1546            with self.assertRaises(IndexError):
1547                func(15)
1548
1549    def test_lru_with_types(self):
1550        for maxsize in (None, 128):
1551            @self.module.lru_cache(maxsize=maxsize, typed=True)
1552            def square(x):
1553                return x * x
1554            self.assertEqual(square(3), 9)
1555            self.assertEqual(type(square(3)), type(9))
1556            self.assertEqual(square(3.0), 9.0)
1557            self.assertEqual(type(square(3.0)), type(9.0))
1558            self.assertEqual(square(x=3), 9)
1559            self.assertEqual(type(square(x=3)), type(9))
1560            self.assertEqual(square(x=3.0), 9.0)
1561            self.assertEqual(type(square(x=3.0)), type(9.0))
1562            self.assertEqual(square.cache_info().hits, 4)
1563            self.assertEqual(square.cache_info().misses, 4)
1564
1565    def test_lru_cache_typed_is_not_recursive(self):
1566        cached = self.module.lru_cache(typed=True)(repr)
1567
1568        self.assertEqual(cached(1), '1')
1569        self.assertEqual(cached(True), 'True')
1570        self.assertEqual(cached(1.0), '1.0')
1571        self.assertEqual(cached(0), '0')
1572        self.assertEqual(cached(False), 'False')
1573        self.assertEqual(cached(0.0), '0.0')
1574
1575        self.assertEqual(cached((1,)), '(1,)')
1576        self.assertEqual(cached((True,)), '(1,)')
1577        self.assertEqual(cached((1.0,)), '(1,)')
1578        self.assertEqual(cached((0,)), '(0,)')
1579        self.assertEqual(cached((False,)), '(0,)')
1580        self.assertEqual(cached((0.0,)), '(0,)')
1581
1582        class T(tuple):
1583            pass
1584
1585        self.assertEqual(cached(T((1,))), '(1,)')
1586        self.assertEqual(cached(T((True,))), '(1,)')
1587        self.assertEqual(cached(T((1.0,))), '(1,)')
1588        self.assertEqual(cached(T((0,))), '(0,)')
1589        self.assertEqual(cached(T((False,))), '(0,)')
1590        self.assertEqual(cached(T((0.0,))), '(0,)')
1591
1592    def test_lru_with_keyword_args(self):
1593        @self.module.lru_cache()
1594        def fib(n):
1595            if n < 2:
1596                return n
1597            return fib(n=n-1) + fib(n=n-2)
1598        self.assertEqual(
1599            [fib(n=number) for number in range(16)],
1600            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1601        )
1602        self.assertEqual(fib.cache_info(),
1603            self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
1604        fib.cache_clear()
1605        self.assertEqual(fib.cache_info(),
1606            self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
1607
1608    def test_lru_with_keyword_args_maxsize_none(self):
1609        @self.module.lru_cache(maxsize=None)
1610        def fib(n):
1611            if n < 2:
1612                return n
1613            return fib(n=n-1) + fib(n=n-2)
1614        self.assertEqual([fib(n=number) for number in range(16)],
1615            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1616        self.assertEqual(fib.cache_info(),
1617            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1618        fib.cache_clear()
1619        self.assertEqual(fib.cache_info(),
1620            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1621
1622    def test_kwargs_order(self):
1623        # PEP 468: Preserving Keyword Argument Order
1624        @self.module.lru_cache(maxsize=10)
1625        def f(**kwargs):
1626            return list(kwargs.items())
1627        self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1628        self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1629        self.assertEqual(f.cache_info(),
1630            self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1631
1632    def test_lru_cache_decoration(self):
1633        def f(zomg: 'zomg_annotation'):
1634            """f doc string"""
1635            return 42
1636        g = self.module.lru_cache()(f)
1637        for attr in self.module.WRAPPER_ASSIGNMENTS:
1638            self.assertEqual(getattr(g, attr), getattr(f, attr))
1639
    @threading_helper.requires_working_threading()
    def test_lru_cache_threaded(self):
        """Use one lru_cache from several threads, then clear it concurrently.

        Phase 1: n threads each call f(k, 0) m times for their own k, after
        which the cache statistics are checked.  Phase 2: the same workload
        runs while an extra thread calls cache_clear() repeatedly; only the
        per-call results checked inside the worker threads are asserted there.
        """
        n, m = 5, 11
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=n*m)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(currsize, 0)

        # All workers block on this event so they start at (about) the same time.
        start = threading.Event()
        def full(k):
            start.wait(10)
            for _ in range(m):
                self.assertEqual(f(k, 0), orig(k, 0))

        def clear():
            start.wait(10)
            for _ in range(2*m):
                f.cache_clear()

        orig_si = sys.getswitchinterval()
        # A tiny switch interval forces frequent thread switches, making
        # interleavings inside the cache implementation more likely.
        support.setswitchinterval(1e-6)
        try:
            # create n threads in order to fill cache
            threads = [threading.Thread(target=full, args=[k])
                       for k in range(n)]
            # The workers are started by start_threads() and released by
            # start.set(); cache_info() is read only after this block, so the
            # helper presumably joins them on exit.
            with threading_helper.start_threads(threads):
                start.set()

            hits, misses, maxsize, currsize = f.cache_info()
            if self.module is py_functools:
                # XXX: with the pure-Python implementation these counts are
                # not always exact (cause unknown), so only bounds are checked.
                self.assertLessEqual(misses, n)
                self.assertLessEqual(hits, m*n - misses)
            else:
                self.assertEqual(misses, n)
                self.assertEqual(hits, m*n - misses)
            self.assertEqual(currsize, n)

            # create n threads in order to fill cache and 1 to clear it
            threads = [threading.Thread(target=clear)]
            threads += [threading.Thread(target=full, args=[k])
                        for k in range(n)]
            start.clear()
            with threading_helper.start_threads(threads):
                start.set()
        finally:
            # Restore the interval saved above.
            sys.setswitchinterval(orig_si)
1688
    @threading_helper.requires_working_threading()
    def test_lru_cache_threaded2(self):
        """n threads call f(i) with the same argument at the same moment.

        f() itself blocks on the ``pause`` barrier, so all n calls for a given
        i are in flight together and none can find a cached result.  Each
        round therefore adds n misses, no hits and a single cache entry.
        """
        # Simultaneous call with the same arguments
        n, m = 5, 7
        # Each barrier has n+1 parties: the n worker threads plus this thread.
        start = threading.Barrier(n+1)
        pause = threading.Barrier(n+1)
        stop = threading.Barrier(n+1)
        @self.module.lru_cache(maxsize=m*n)
        def f(x):
            pause.wait(10)
            return 3 * x
        self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
        def test():
            for i in range(m):
                start.wait(10)
                self.assertEqual(f(i), 3 * i)
                stop.wait(10)
        threads = [threading.Thread(target=test) for k in range(n)]
        with threading_helper.start_threads(threads):
            for i in range(m):
                # Each reset() below happens at a point where no thread can
                # be waiting on that barrier (the others are held elsewhere).
                start.wait(10)
                stop.reset()
                pause.wait(10)   # lets all n in-flight f(i) calls proceed
                start.reset()
                stop.wait(10)    # all workers have finished round i
                pause.reset()
                self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1716
1717    @threading_helper.requires_working_threading()
1718    def test_lru_cache_threaded3(self):
1719        @self.module.lru_cache(maxsize=2)
1720        def f(x):
1721            time.sleep(.01)
1722            return 3 * x
1723        def test(i, x):
1724            with self.subTest(thread=i):
1725                self.assertEqual(f(x), 3 * x, i)
1726        threads = [threading.Thread(target=test, args=(i, v))
1727                   for i, v in enumerate([1, 2, 2, 3, 2])]
1728        with threading_helper.start_threads(threads):
1729            pass
1730
1731    def test_need_for_rlock(self):
1732        # This will deadlock on an LRU cache that uses a regular lock
1733
1734        @self.module.lru_cache(maxsize=10)
1735        def test_func(x):
1736            'Used to demonstrate a reentrant lru_cache call within a single thread'
1737            return x
1738
1739        class DoubleEq:
1740            'Demonstrate a reentrant lru_cache call within a single thread'
1741            def __init__(self, x):
1742                self.x = x
1743            def __hash__(self):
1744                return self.x
1745            def __eq__(self, other):
1746                if self.x == 2:
1747                    test_func(DoubleEq(1))
1748                return self.x == other.x
1749
1750        test_func(DoubleEq(1))                      # Load the cache
1751        test_func(DoubleEq(2))                      # Load the cache
1752        self.assertEqual(test_func(DoubleEq(2)),    # Trigger a re-entrant __eq__ call
1753                         DoubleEq(2))               # Verify the correct return value
1754
1755    def test_lru_method(self):
1756        class X(int):
1757            f_cnt = 0
1758            @self.module.lru_cache(2)
1759            def f(self, x):
1760                self.f_cnt += 1
1761                return x*10+self
1762        a = X(5)
1763        b = X(5)
1764        c = X(7)
1765        self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1766
1767        for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1768            self.assertEqual(a.f(x), x*10 + 5)
1769        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1770        self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1771
1772        for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1773            self.assertEqual(b.f(x), x*10 + 5)
1774        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1775        self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1776
1777        for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1778            self.assertEqual(c.f(x), x*10 + 7)
1779        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1780        self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1781
1782        self.assertEqual(a.f.cache_info(), X.f.cache_info())
1783        self.assertEqual(b.f.cache_info(), X.f.cache_info())
1784        self.assertEqual(c.f.cache_info(), X.f.cache_info())
1785
1786    def test_pickle(self):
1787        cls = self.__class__
1788        for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1789            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1790                with self.subTest(proto=proto, func=f):
1791                    f_copy = pickle.loads(pickle.dumps(f, proto))
1792                    self.assertIs(f_copy, f)
1793
1794    def test_copy(self):
1795        cls = self.__class__
1796        def orig(x, y):
1797            return 3 * x + y
1798        part = self.module.partial(orig, 2)
1799        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1800                 self.module.lru_cache(2)(part))
1801        for f in funcs:
1802            with self.subTest(func=f):
1803                f_copy = copy.copy(f)
1804                self.assertIs(f_copy, f)
1805
1806    def test_deepcopy(self):
1807        cls = self.__class__
1808        def orig(x, y):
1809            return 3 * x + y
1810        part = self.module.partial(orig, 2)
1811        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1812                 self.module.lru_cache(2)(part))
1813        for f in funcs:
1814            with self.subTest(func=f):
1815                f_copy = copy.deepcopy(f)
1816                self.assertIs(f_copy, f)
1817
1818    def test_lru_cache_parameters(self):
1819        @self.module.lru_cache(maxsize=2)
1820        def f():
1821            return 1
1822        self.assertEqual(f.cache_parameters(), {'maxsize': 2, "typed": False})
1823
1824        @self.module.lru_cache(maxsize=1000, typed=True)
1825        def f():
1826            return 1
1827        self.assertEqual(f.cache_parameters(), {'maxsize': 1000, "typed": True})
1828
1829    def test_lru_cache_weakrefable(self):
1830        @self.module.lru_cache
1831        def test_function(x):
1832            return x
1833
1834        class A:
1835            @self.module.lru_cache
1836            def test_method(self, x):
1837                return (self, x)
1838
1839            @staticmethod
1840            @self.module.lru_cache
1841            def test_staticmethod(x):
1842                return (self, x)
1843
1844        refs = [weakref.ref(test_function),
1845                weakref.ref(A.test_method),
1846                weakref.ref(A.test_staticmethod)]
1847
1848        for ref in refs:
1849            self.assertIsNotNone(ref())
1850
1851        del A
1852        del test_function
1853        gc.collect()
1854
1855        for ref in refs:
1856            self.assertIsNone(ref())
1857
1858
@py_functools.lru_cache()
def py_cached_func(x, y):
    """Return 3*x + y; module-level function cached by the py_functools lru_cache."""
    scaled = 3 * x
    return scaled + y
1862
@c_functools.lru_cache()
def c_cached_func(x, y):
    """Return 3*x + y; module-level function cached by the c_functools lru_cache."""
    scaled = 3 * x
    return scaled + y
1866
1867
class TestLRUPy(TestLRU, unittest.TestCase):
    """Run the TestLRU cases against py_functools (the _functools-free import)."""

    module = py_functools
    # A 1-tuple; the shared tests read it back as cached_func[0].
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        scaled = 3 * x
        return scaled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        scaled = 3 * x
        return scaled + y
1880
1881
class TestLRUC(TestLRU, unittest.TestCase):
    """Run the TestLRU cases against c_functools (the default import)."""

    module = c_functools
    # A 1-tuple; the shared tests read it back as cached_func[0].
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        scaled = 3 * x
        return scaled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        scaled = 3 * x
        return scaled + y
1894
1895
1896class TestSingleDispatch(unittest.TestCase):
1897    def test_simple_overloads(self):
1898        @functools.singledispatch
1899        def g(obj):
1900            return "base"
1901        def g_int(i):
1902            return "integer"
1903        g.register(int, g_int)
1904        self.assertEqual(g("str"), "base")
1905        self.assertEqual(g(1), "integer")
1906        self.assertEqual(g([1,2,3]), "base")
1907
1908    def test_mro(self):
1909        @functools.singledispatch
1910        def g(obj):
1911            return "base"
1912        class A:
1913            pass
1914        class C(A):
1915            pass
1916        class B(A):
1917            pass
1918        class D(C, B):
1919            pass
1920        def g_A(a):
1921            return "A"
1922        def g_B(b):
1923            return "B"
1924        g.register(A, g_A)
1925        g.register(B, g_B)
1926        self.assertEqual(g(A()), "A")
1927        self.assertEqual(g(B()), "B")
1928        self.assertEqual(g(C()), "A")
1929        self.assertEqual(g(D()), "B")
1930
1931    def test_register_decorator(self):
1932        @functools.singledispatch
1933        def g(obj):
1934            return "base"
1935        @g.register(int)
1936        def g_int(i):
1937            return "int %s" % (i,)
1938        self.assertEqual(g(""), "base")
1939        self.assertEqual(g(12), "int 12")
1940        self.assertIs(g.dispatch(int), g_int)
1941        self.assertIs(g.dispatch(object), g.dispatch(str))
1942        # Note: in the assert above this is not g.
1943        # @singledispatch returns the wrapper.
1944
1945    def test_wrapping_attributes(self):
1946        @functools.singledispatch
1947        def g(obj):
1948            "Simple test"
1949            return "Test"
1950        self.assertEqual(g.__name__, "g")
1951        if sys.flags.optimize < 2:
1952            self.assertEqual(g.__doc__, "Simple test")
1953
1954    @unittest.skipUnless(decimal, 'requires _decimal')
1955    @support.cpython_only
1956    def test_c_classes(self):
1957        @functools.singledispatch
1958        def g(obj):
1959            return "base"
1960        @g.register(decimal.DecimalException)
1961        def _(obj):
1962            return obj.args
1963        subn = decimal.Subnormal("Exponent < Emin")
1964        rnd = decimal.Rounded("Number got rounded")
1965        self.assertEqual(g(subn), ("Exponent < Emin",))
1966        self.assertEqual(g(rnd), ("Number got rounded",))
1967        @g.register(decimal.Subnormal)
1968        def _(obj):
1969            return "Too small to care."
1970        self.assertEqual(g(subn), "Too small to care.")
1971        self.assertEqual(g(rnd), ("Number got rounded",))
1972
1973    def test_compose_mro(self):
1974        # None of the examples in this test depend on haystack ordering.
1975        c = collections.abc
1976        mro = functools._compose_mro
1977        bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1978        for haystack in permutations(bases):
1979            m = mro(dict, haystack)
1980            self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
1981                                 c.Collection, c.Sized, c.Iterable,
1982                                 c.Container, object])
1983        bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
1984        for haystack in permutations(bases):
1985            m = mro(collections.ChainMap, haystack)
1986            self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
1987                                 c.Collection, c.Sized, c.Iterable,
1988                                 c.Container, object])
1989
1990        # If there's a generic function with implementations registered for
1991        # both Sized and Container, passing a defaultdict to it results in an
1992        # ambiguous dispatch which will cause a RuntimeError (see
1993        # test_mro_conflicts).
1994        bases = [c.Container, c.Sized, str]
1995        for haystack in permutations(bases):
1996            m = mro(collections.defaultdict, [c.Sized, c.Container, str])
1997            self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
1998                                 c.Container, object])
1999
2000        # MutableSequence below is registered directly on D. In other words, it
2001        # precedes MutableMapping which means single dispatch will always
2002        # choose MutableSequence here.
2003        class D(collections.defaultdict):
2004            pass
2005        c.MutableSequence.register(D)
2006        bases = [c.MutableSequence, c.MutableMapping]
2007        for haystack in permutations(bases):
2008            m = mro(D, bases)
2009            self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
2010                                 collections.defaultdict, dict, c.MutableMapping, c.Mapping,
2011                                 c.Collection, c.Sized, c.Iterable, c.Container,
2012                                 object])
2013
2014        # Container and Callable are registered on different base classes and
2015        # a generic function supporting both should always pick the Callable
2016        # implementation if a C instance is passed.
2017        class C(collections.defaultdict):
2018            def __call__(self):
2019                pass
2020        bases = [c.Sized, c.Callable, c.Container, c.Mapping]
2021        for haystack in permutations(bases):
2022            m = mro(C, haystack)
2023            self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
2024                                 c.Collection, c.Sized, c.Iterable,
2025                                 c.Container, object])
2026
2027    def test_register_abc(self):
2028        c = collections.abc
2029        d = {"a": "b"}
2030        l = [1, 2, 3]
2031        s = {object(), None}
2032        f = frozenset(s)
2033        t = (1, 2, 3)
2034        @functools.singledispatch
2035        def g(obj):
2036            return "base"
2037        self.assertEqual(g(d), "base")
2038        self.assertEqual(g(l), "base")
2039        self.assertEqual(g(s), "base")
2040        self.assertEqual(g(f), "base")
2041        self.assertEqual(g(t), "base")
2042        g.register(c.Sized, lambda obj: "sized")
2043        self.assertEqual(g(d), "sized")
2044        self.assertEqual(g(l), "sized")
2045        self.assertEqual(g(s), "sized")
2046        self.assertEqual(g(f), "sized")
2047        self.assertEqual(g(t), "sized")
2048        g.register(c.MutableMapping, lambda obj: "mutablemapping")
2049        self.assertEqual(g(d), "mutablemapping")
2050        self.assertEqual(g(l), "sized")
2051        self.assertEqual(g(s), "sized")
2052        self.assertEqual(g(f), "sized")
2053        self.assertEqual(g(t), "sized")
2054        g.register(collections.ChainMap, lambda obj: "chainmap")
2055        self.assertEqual(g(d), "mutablemapping")  # irrelevant ABCs registered
2056        self.assertEqual(g(l), "sized")
2057        self.assertEqual(g(s), "sized")
2058        self.assertEqual(g(f), "sized")
2059        self.assertEqual(g(t), "sized")
2060        g.register(c.MutableSequence, lambda obj: "mutablesequence")
2061        self.assertEqual(g(d), "mutablemapping")
2062        self.assertEqual(g(l), "mutablesequence")
2063        self.assertEqual(g(s), "sized")
2064        self.assertEqual(g(f), "sized")
2065        self.assertEqual(g(t), "sized")
2066        g.register(c.MutableSet, lambda obj: "mutableset")
2067        self.assertEqual(g(d), "mutablemapping")
2068        self.assertEqual(g(l), "mutablesequence")
2069        self.assertEqual(g(s), "mutableset")
2070        self.assertEqual(g(f), "sized")
2071        self.assertEqual(g(t), "sized")
2072        g.register(c.Mapping, lambda obj: "mapping")
2073        self.assertEqual(g(d), "mutablemapping")  # not specific enough
2074        self.assertEqual(g(l), "mutablesequence")
2075        self.assertEqual(g(s), "mutableset")
2076        self.assertEqual(g(f), "sized")
2077        self.assertEqual(g(t), "sized")
2078        g.register(c.Sequence, lambda obj: "sequence")
2079        self.assertEqual(g(d), "mutablemapping")
2080        self.assertEqual(g(l), "mutablesequence")
2081        self.assertEqual(g(s), "mutableset")
2082        self.assertEqual(g(f), "sized")
2083        self.assertEqual(g(t), "sequence")
2084        g.register(c.Set, lambda obj: "set")
2085        self.assertEqual(g(d), "mutablemapping")
2086        self.assertEqual(g(l), "mutablesequence")
2087        self.assertEqual(g(s), "mutableset")
2088        self.assertEqual(g(f), "set")
2089        self.assertEqual(g(t), "sequence")
2090        g.register(dict, lambda obj: "dict")
2091        self.assertEqual(g(d), "dict")
2092        self.assertEqual(g(l), "mutablesequence")
2093        self.assertEqual(g(s), "mutableset")
2094        self.assertEqual(g(f), "set")
2095        self.assertEqual(g(t), "sequence")
2096        g.register(list, lambda obj: "list")
2097        self.assertEqual(g(d), "dict")
2098        self.assertEqual(g(l), "list")
2099        self.assertEqual(g(s), "mutableset")
2100        self.assertEqual(g(f), "set")
2101        self.assertEqual(g(t), "sequence")
2102        g.register(set, lambda obj: "concrete-set")
2103        self.assertEqual(g(d), "dict")
2104        self.assertEqual(g(l), "list")
2105        self.assertEqual(g(s), "concrete-set")
2106        self.assertEqual(g(f), "set")
2107        self.assertEqual(g(t), "sequence")
2108        g.register(frozenset, lambda obj: "frozen-set")
2109        self.assertEqual(g(d), "dict")
2110        self.assertEqual(g(l), "list")
2111        self.assertEqual(g(s), "concrete-set")
2112        self.assertEqual(g(f), "frozen-set")
2113        self.assertEqual(g(t), "sequence")
2114        g.register(tuple, lambda obj: "tuple")
2115        self.assertEqual(g(d), "dict")
2116        self.assertEqual(g(l), "list")
2117        self.assertEqual(g(s), "concrete-set")
2118        self.assertEqual(g(f), "frozen-set")
2119        self.assertEqual(g(t), "tuple")
2120
2121    def test_c3_abc(self):
2122        c = collections.abc
2123        mro = functools._c3_mro
2124        class A(object):
2125            pass
2126        class B(A):
2127            def __len__(self):
2128                return 0   # implies Sized
2129        @c.Container.register
2130        class C(object):
2131            pass
2132        class D(object):
2133            pass   # unrelated
2134        class X(D, C, B):
2135            def __call__(self):
2136                pass   # implies Callable
2137        expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
2138        for abcs in permutations([c.Sized, c.Callable, c.Container]):
2139            self.assertEqual(mro(X, abcs=abcs), expected)
2140        # unrelated ABCs don't appear in the resulting MRO
2141        many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
2142        self.assertEqual(mro(X, abcs=many_abcs), expected)
2143
2144    def test_false_meta(self):
2145        # see issue23572
2146        class MetaA(type):
2147            def __len__(self):
2148                return 0
2149        class A(metaclass=MetaA):
2150            pass
2151        class AA(A):
2152            pass
2153        @functools.singledispatch
2154        def fun(a):
2155            return 'base A'
2156        @fun.register(A)
2157        def _(a):
2158            return 'fun A'
2159        aa = AA()
2160        self.assertEqual(fun(aa), 'fun A')
2161
    def test_mro_conflicts(self):
        """Dispatch results as ABCs are registered on classes step by step.

        Depending on whether ABCs appear explicitly in a class's __mro__ or
        only through register()/__len__ inference, dispatch either picks one
        implementation or raises RuntimeError("Ambiguous dispatch ...").
        """
        c = collections.abc
        @functools.singledispatch
        def g(arg):
            return "base"
        # O lists Sized explicitly as a base class.
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        # P has no explicit ABC bases: once it is registered as both Iterable
        # and Container, neither is preferred and dispatch is ambiguous.
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        # The two ABCs may be reported in either order.
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(collections.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        # An ABC registered directly on R (MutableSequence) wins over one that
        # R only gets through its defaultdict base (MutableMapping).
        class R(collections.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        # V lists Sized before S, so a later-registered Container lands ahead
        # of S in the composed MRO.
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO
2289
    def test_cache_invalidation(self):
        """Track how singledispatch fills and invalidates its dispatch cache.

        weakref.WeakKeyDictionary is swapped for a factory returning one
        TracingDict while g is created, so that dict becomes g's dispatch
        cache.  Its recorded reads and writes show when dispatch results are
        cached, reused, and thrown away (after register() or an ABC
        registration).
        """
        from collections import UserDict
        import weakref

        # Dict that records every key read (get_ops) and written (set_ops).
        class TracingDict(UserDict):
            def __init__(self, *args, **kwargs):
                super(TracingDict, self).__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            # Clears the underlying dict directly, presumably so that clearing
            # the cache is not recorded as reads through __getitem__.
            def clear(self):
                self.data.clear()

        td = TracingDict()
        with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
            c = collections.abc
            @functools.singledispatch
            def g(arg):
                return "base"
            d = {}
            l = []
            # First calls miss: results are stored (set_ops), nothing is read.
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(g(l), "base")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict, list])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(td.data[list], g.registry[object])
            self.assertEqual(td.data[dict], td.data[list])
            # Repeated calls are served from the cache (get_ops grow only).
            self.assertEqual(g(l), "base")
            self.assertEqual(g(d), "base")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list])
            # register() empties the cache.
            g.register(list, lambda arg: "list")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict])
            self.assertEqual(td.data[dict],
                             functools._find_impl(dict, g.registry))
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            self.assertEqual(td.data[list],
                             functools._find_impl(list, g.registry))
            class X:
                pass
            c.MutableMapping.register(X)   # Will not invalidate the cache,
                                           # not using ABCs yet.
            self.assertEqual(g(d), "base")
            self.assertEqual(g(l), "list")
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            # Registering an ABC implementation clears the cache again.
            g.register(c.Sized, lambda arg: "sized")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "sized")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            self.assertEqual(g(l), "list")
            self.assertEqual(g(d), "sized")
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            # g.dispatch() reads the same cache as a call does.
            g.dispatch(list)
            g.dispatch(dict)
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                          list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            c.MutableSet.register(X)       # Will invalidate the cache.
            # The stale entries are only dropped on the next dispatch.
            self.assertEqual(len(td), 2)   # Stale cache.
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 1)
            g.register(c.MutableMapping, lambda arg: "mutablemapping")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "mutablemapping")
            self.assertEqual(len(td), 1)
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            g.register(dict, lambda arg: "dict")
            self.assertEqual(g(d), "dict")
            self.assertEqual(g(l), "list")
            # _clear_cache() empties the dispatch cache explicitly.
            g._clear_cache()
            self.assertEqual(len(td), 0)
2391
2392    def test_annotations(self):
2393        @functools.singledispatch
2394        def i(arg):
2395            return "base"
2396        @i.register
2397        def _(arg: collections.abc.Mapping):
2398            return "mapping"
2399        @i.register
2400        def _(arg: "collections.abc.Sequence"):
2401            return "sequence"
2402        self.assertEqual(i(None), "base")
2403        self.assertEqual(i({"a": 1}), "mapping")
2404        self.assertEqual(i([1, 2, 3]), "sequence")
2405        self.assertEqual(i((1, 2, 3)), "sequence")
2406        self.assertEqual(i("str"), "sequence")
2407
2408        # Registering classes as callables doesn't work with annotations,
2409        # you need to pass the type explicitly.
2410        @i.register(str)
2411        class _:
2412            def __init__(self, arg):
2413                self.arg = arg
2414
2415            def __eq__(self, other):
2416                return self.arg == other
2417        self.assertEqual(i("str"), "str")
2418
2419    def test_method_register(self):
2420        class A:
2421            @functools.singledispatchmethod
2422            def t(self, arg):
2423                self.arg = "base"
2424            @t.register(int)
2425            def _(self, arg):
2426                self.arg = "int"
2427            @t.register(str)
2428            def _(self, arg):
2429                self.arg = "str"
2430        a = A()
2431
2432        a.t(0)
2433        self.assertEqual(a.arg, "int")
2434        aa = A()
2435        self.assertFalse(hasattr(aa, 'arg'))
2436        a.t('')
2437        self.assertEqual(a.arg, "str")
2438        aa = A()
2439        self.assertFalse(hasattr(aa, 'arg'))
2440        a.t(0.0)
2441        self.assertEqual(a.arg, "base")
2442        aa = A()
2443        self.assertFalse(hasattr(aa, 'arg'))
2444
2445    def test_staticmethod_register(self):
2446        class A:
2447            @functools.singledispatchmethod
2448            @staticmethod
2449            def t(arg):
2450                return arg
2451            @t.register(int)
2452            @staticmethod
2453            def _(arg):
2454                return isinstance(arg, int)
2455            @t.register(str)
2456            @staticmethod
2457            def _(arg):
2458                return isinstance(arg, str)
2459        a = A()
2460
2461        self.assertTrue(A.t(0))
2462        self.assertTrue(A.t(''))
2463        self.assertEqual(A.t(0.0), 0.0)
2464
2465    def test_classmethod_register(self):
2466        class A:
2467            def __init__(self, arg):
2468                self.arg = arg
2469
2470            @functools.singledispatchmethod
2471            @classmethod
2472            def t(cls, arg):
2473                return cls("base")
2474            @t.register(int)
2475            @classmethod
2476            def _(cls, arg):
2477                return cls("int")
2478            @t.register(str)
2479            @classmethod
2480            def _(cls, arg):
2481                return cls("str")
2482
2483        self.assertEqual(A.t(0).arg, "int")
2484        self.assertEqual(A.t('').arg, "str")
2485        self.assertEqual(A.t(0.0).arg, "base")
2486
2487    def test_callable_register(self):
2488        class A:
2489            def __init__(self, arg):
2490                self.arg = arg
2491
2492            @functools.singledispatchmethod
2493            @classmethod
2494            def t(cls, arg):
2495                return cls("base")
2496
2497        @A.t.register(int)
2498        @classmethod
2499        def _(cls, arg):
2500            return cls("int")
2501        @A.t.register(str)
2502        @classmethod
2503        def _(cls, arg):
2504            return cls("str")
2505
2506        self.assertEqual(A.t(0).arg, "int")
2507        self.assertEqual(A.t('').arg, "str")
2508        self.assertEqual(A.t(0.0).arg, "base")
2509
2510    def test_abstractmethod_register(self):
2511        class Abstract(metaclass=abc.ABCMeta):
2512
2513            @functools.singledispatchmethod
2514            @abc.abstractmethod
2515            def add(self, x, y):
2516                pass
2517
2518        self.assertTrue(Abstract.add.__isabstractmethod__)
2519        self.assertTrue(Abstract.__dict__['add'].__isabstractmethod__)
2520
2521        with self.assertRaises(TypeError):
2522            Abstract()
2523
2524    def test_type_ann_register(self):
2525        class A:
2526            @functools.singledispatchmethod
2527            def t(self, arg):
2528                return "base"
2529            @t.register
2530            def _(self, arg: int):
2531                return "int"
2532            @t.register
2533            def _(self, arg: str):
2534                return "str"
2535        a = A()
2536
2537        self.assertEqual(a.t(0), "int")
2538        self.assertEqual(a.t(''), "str")
2539        self.assertEqual(a.t(0.0), "base")
2540
2541    def test_staticmethod_type_ann_register(self):
2542        class A:
2543            @functools.singledispatchmethod
2544            @staticmethod
2545            def t(arg):
2546                return arg
2547            @t.register
2548            @staticmethod
2549            def _(arg: int):
2550                return isinstance(arg, int)
2551            @t.register
2552            @staticmethod
2553            def _(arg: str):
2554                return isinstance(arg, str)
2555        a = A()
2556
2557        self.assertTrue(A.t(0))
2558        self.assertTrue(A.t(''))
2559        self.assertEqual(A.t(0.0), 0.0)
2560
2561    def test_classmethod_type_ann_register(self):
2562        class A:
2563            def __init__(self, arg):
2564                self.arg = arg
2565
2566            @functools.singledispatchmethod
2567            @classmethod
2568            def t(cls, arg):
2569                return cls("base")
2570            @t.register
2571            @classmethod
2572            def _(cls, arg: int):
2573                return cls("int")
2574            @t.register
2575            @classmethod
2576            def _(cls, arg: str):
2577                return cls("str")
2578
2579        self.assertEqual(A.t(0).arg, "int")
2580        self.assertEqual(A.t('').arg, "str")
2581        self.assertEqual(A.t(0.0).arg, "base")
2582
2583    def test_method_wrapping_attributes(self):
2584        class A:
2585            @functools.singledispatchmethod
2586            def func(self, arg: int) -> str:
2587                """My function docstring"""
2588                return str(arg)
2589            @functools.singledispatchmethod
2590            @classmethod
2591            def cls_func(cls, arg: int) -> str:
2592                """My function docstring"""
2593                return str(arg)
2594            @functools.singledispatchmethod
2595            @staticmethod
2596            def static_func(arg: int) -> str:
2597                """My function docstring"""
2598                return str(arg)
2599
2600        for meth in (
2601            A.func,
2602            A().func,
2603            A.cls_func,
2604            A().cls_func,
2605            A.static_func,
2606            A().static_func
2607        ):
2608            with self.subTest(meth=meth):
2609                self.assertEqual(meth.__doc__, 'My function docstring')
2610                self.assertEqual(meth.__annotations__['arg'], int)
2611
2612        self.assertEqual(A.func.__name__, 'func')
2613        self.assertEqual(A().func.__name__, 'func')
2614        self.assertEqual(A.cls_func.__name__, 'cls_func')
2615        self.assertEqual(A().cls_func.__name__, 'cls_func')
2616        self.assertEqual(A.static_func.__name__, 'static_func')
2617        self.assertEqual(A().static_func.__name__, 'static_func')
2618
2619    def test_double_wrapped_methods(self):
2620        def classmethod_friendly_decorator(func):
2621            wrapped = func.__func__
2622            @classmethod
2623            @functools.wraps(wrapped)
2624            def wrapper(*args, **kwargs):
2625                return wrapped(*args, **kwargs)
2626            return wrapper
2627
2628        class WithoutSingleDispatch:
2629            @classmethod
2630            @contextlib.contextmanager
2631            def cls_context_manager(cls, arg: int) -> str:
2632                try:
2633                    yield str(arg)
2634                finally:
2635                    return 'Done'
2636
2637            @classmethod_friendly_decorator
2638            @classmethod
2639            def decorated_classmethod(cls, arg: int) -> str:
2640                return str(arg)
2641
2642        class WithSingleDispatch:
2643            @functools.singledispatchmethod
2644            @classmethod
2645            @contextlib.contextmanager
2646            def cls_context_manager(cls, arg: int) -> str:
2647                """My function docstring"""
2648                try:
2649                    yield str(arg)
2650                finally:
2651                    return 'Done'
2652
2653            @functools.singledispatchmethod
2654            @classmethod_friendly_decorator
2655            @classmethod
2656            def decorated_classmethod(cls, arg: int) -> str:
2657                """My function docstring"""
2658                return str(arg)
2659
2660        # These are sanity checks
2661        # to test the test itself is working as expected
2662        with WithoutSingleDispatch.cls_context_manager(5) as foo:
2663            without_single_dispatch_foo = foo
2664
2665        with WithSingleDispatch.cls_context_manager(5) as foo:
2666            single_dispatch_foo = foo
2667
2668        self.assertEqual(without_single_dispatch_foo, single_dispatch_foo)
2669        self.assertEqual(single_dispatch_foo, '5')
2670
2671        self.assertEqual(
2672            WithoutSingleDispatch.decorated_classmethod(5),
2673            WithSingleDispatch.decorated_classmethod(5)
2674        )
2675
2676        self.assertEqual(WithSingleDispatch.decorated_classmethod(5), '5')
2677
2678        # Behavioural checks now follow
2679        for method_name in ('cls_context_manager', 'decorated_classmethod'):
2680            with self.subTest(method=method_name):
2681                self.assertEqual(
2682                    getattr(WithSingleDispatch, method_name).__name__,
2683                    getattr(WithoutSingleDispatch, method_name).__name__
2684                )
2685
2686                self.assertEqual(
2687                    getattr(WithSingleDispatch(), method_name).__name__,
2688                    getattr(WithoutSingleDispatch(), method_name).__name__
2689                )
2690
2691        for meth in (
2692            WithSingleDispatch.cls_context_manager,
2693            WithSingleDispatch().cls_context_manager,
2694            WithSingleDispatch.decorated_classmethod,
2695            WithSingleDispatch().decorated_classmethod
2696        ):
2697            with self.subTest(meth=meth):
2698                self.assertEqual(meth.__doc__, 'My function docstring')
2699                self.assertEqual(meth.__annotations__['arg'], int)
2700
2701        self.assertEqual(
2702            WithSingleDispatch.cls_context_manager.__name__,
2703            'cls_context_manager'
2704        )
2705        self.assertEqual(
2706            WithSingleDispatch().cls_context_manager.__name__,
2707            'cls_context_manager'
2708        )
2709        self.assertEqual(
2710            WithSingleDispatch.decorated_classmethod.__name__,
2711            'decorated_classmethod'
2712        )
2713        self.assertEqual(
2714            WithSingleDispatch().decorated_classmethod.__name__,
2715            'decorated_classmethod'
2716        )
2717
2718    def test_invalid_registrations(self):
2719        msg_prefix = "Invalid first argument to `register()`: "
2720        msg_suffix = (
2721            ". Use either `@register(some_class)` or plain `@register` on an "
2722            "annotated function."
2723        )
2724        @functools.singledispatch
2725        def i(arg):
2726            return "base"
2727        with self.assertRaises(TypeError) as exc:
2728            @i.register(42)
2729            def _(arg):
2730                return "I annotated with a non-type"
2731        self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
2732        self.assertTrue(str(exc.exception).endswith(msg_suffix))
2733        with self.assertRaises(TypeError) as exc:
2734            @i.register
2735            def _(arg):
2736                return "I forgot to annotate"
2737        self.assertTrue(str(exc.exception).startswith(msg_prefix +
2738            "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
2739        ))
2740        self.assertTrue(str(exc.exception).endswith(msg_suffix))
2741
2742        with self.assertRaises(TypeError) as exc:
2743            @i.register
2744            def _(arg: typing.Iterable[str]):
2745                # At runtime, dispatching on generics is impossible.
2746                # When registering implementations with singledispatch, avoid
2747                # types from `typing`. Instead, annotate with regular types
2748                # or ABCs.
2749                return "I annotated with a generic collection"
2750        self.assertTrue(str(exc.exception).startswith(
2751            "Invalid annotation for 'arg'."
2752        ))
2753        self.assertTrue(str(exc.exception).endswith(
2754            'typing.Iterable[str] is not a class.'
2755        ))
2756
2757        with self.assertRaises(TypeError) as exc:
2758            @i.register
2759            def _(arg: typing.Union[int, typing.Iterable[str]]):
2760                return "Invalid Union"
2761        self.assertTrue(str(exc.exception).startswith(
2762            "Invalid annotation for 'arg'."
2763        ))
2764        self.assertTrue(str(exc.exception).endswith(
2765            'typing.Union[int, typing.Iterable[str]] not all arguments are classes.'
2766        ))
2767
2768    def test_invalid_positional_argument(self):
2769        @functools.singledispatch
2770        def f(*args):
2771            pass
2772        msg = 'f requires at least 1 positional argument'
2773        with self.assertRaisesRegex(TypeError, msg):
2774            f()
2775
2776    def test_union(self):
2777        @functools.singledispatch
2778        def f(arg):
2779            return "default"
2780
2781        @f.register
2782        def _(arg: typing.Union[str, bytes]):
2783            return "typing.Union"
2784
2785        @f.register
2786        def _(arg: int | float):
2787            return "types.UnionType"
2788
2789        self.assertEqual(f([]), "default")
2790        self.assertEqual(f(""), "typing.Union")
2791        self.assertEqual(f(b""), "typing.Union")
2792        self.assertEqual(f(1), "types.UnionType")
2793        self.assertEqual(f(1.0), "types.UnionType")
2794
2795    def test_union_conflict(self):
2796        @functools.singledispatch
2797        def f(arg):
2798            return "default"
2799
2800        @f.register
2801        def _(arg: typing.Union[str, bytes]):
2802            return "typing.Union"
2803
2804        @f.register
2805        def _(arg: int | str):
2806            return "types.UnionType"
2807
2808        self.assertEqual(f([]), "default")
2809        self.assertEqual(f(""), "types.UnionType")  # last one wins
2810        self.assertEqual(f(b""), "typing.Union")
2811        self.assertEqual(f(1), "types.UnionType")
2812
2813    def test_union_None(self):
2814        @functools.singledispatch
2815        def typing_union(arg):
2816            return "default"
2817
2818        @typing_union.register
2819        def _(arg: typing.Union[str, None]):
2820            return "typing.Union"
2821
2822        self.assertEqual(typing_union(1), "default")
2823        self.assertEqual(typing_union(""), "typing.Union")
2824        self.assertEqual(typing_union(None), "typing.Union")
2825
2826        @functools.singledispatch
2827        def types_union(arg):
2828            return "default"
2829
2830        @types_union.register
2831        def _(arg: int | None):
2832            return "types.UnionType"
2833
2834        self.assertEqual(types_union(""), "default")
2835        self.assertEqual(types_union(1), "types.UnionType")
2836        self.assertEqual(types_union(None), "types.UnionType")
2837
2838    def test_register_genericalias(self):
2839        @functools.singledispatch
2840        def f(arg):
2841            return "default"
2842
2843        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2844            f.register(list[int], lambda arg: "types.GenericAlias")
2845        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2846            f.register(typing.List[int], lambda arg: "typing.GenericAlias")
2847        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2848            f.register(list[int] | str, lambda arg: "types.UnionTypes(types.GenericAlias)")
2849        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2850            f.register(typing.List[float] | bytes, lambda arg: "typing.Union[typing.GenericAlias]")
2851
2852        self.assertEqual(f([1]), "default")
2853        self.assertEqual(f([1.0]), "default")
2854        self.assertEqual(f(""), "default")
2855        self.assertEqual(f(b""), "default")
2856
2857    def test_register_genericalias_decorator(self):
2858        @functools.singledispatch
2859        def f(arg):
2860            return "default"
2861
2862        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2863            f.register(list[int])
2864        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2865            f.register(typing.List[int])
2866        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2867            f.register(list[int] | str)
2868        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2869            f.register(typing.List[int] | str)
2870
2871    def test_register_genericalias_annotation(self):
2872        @functools.singledispatch
2873        def f(arg):
2874            return "default"
2875
2876        with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
2877            @f.register
2878            def _(arg: list[int]):
2879                return "types.GenericAlias"
2880        with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
2881            @f.register
2882            def _(arg: typing.List[float]):
2883                return "typing.GenericAlias"
2884        with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
2885            @f.register
2886            def _(arg: list[int] | str):
2887                return "types.UnionType(types.GenericAlias)"
2888        with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
2889            @f.register
2890            def _(arg: typing.List[float] | bytes):
2891                return "typing.Union[typing.GenericAlias]"
2892
2893        self.assertEqual(f([1]), "default")
2894        self.assertEqual(f([1.0]), "default")
2895        self.assertEqual(f(""), "default")
2896        self.assertEqual(f(b""), "default")
2897
2898
class CachedCostItem:
    # Fixture: every run of the ``cost`` getter adds one to the cost, so a
    # cached value stays at 2 while a recomputation would yield 3, 4, ...

    # Class-level starting value; the first computation shadows it with an
    # instance attribute.
    _cost = 1

    def __init__(self):
        self.lock = py_functools.RLock()

    @py_functools.cached_property
    def cost(self):
        """The cost of the item."""
        with self.lock:
            self._cost = self._cost + 1
        return self._cost
2911
2912
class OptionallyCachedCostItem:
    # Fixture: one getter exposed both as a plain method (recomputed on every
    # call) and as a cached_property under a different attribute name.

    # Shared starting value; get_cost() shadows it per instance.
    _cost = 1

    def get_cost(self):
        """The cost of the item."""
        self._cost = self._cost + 1
        return self._cost

    # Cached under "cached_cost", not under the function's own name.
    cached_cost = py_functools.cached_property(get_cost)
2922
2923
class CachedCostItemWait:
    # Fixture like CachedCostItem, except that the getter first blocks on an
    # event (for at most one second) before bumping the cost.

    def __init__(self, event):
        self.event = event
        self.lock = py_functools.RLock()
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        # Presumably holds concurrent readers inside the getter until the
        # test sets the event.
        self.event.wait(1)
        with self.lock:
            self._cost = self._cost + 1
        return self._cost
2937
2938
class CachedCostItemWithSlots:
    # Fixture: a slotted class (no per-instance ``__dict__``) carrying a
    # cached_property. cached_property needs the instance ``__dict__`` to
    # store its value, so reading ``cost`` is expected to raise TypeError
    # before the getter is ever called.

    # A one-element tuple. ``('_cost')`` is merely a parenthesised string,
    # which only worked because __slots__ also accepts a single bare name.
    __slots__ = ('_cost',)

    def __init__(self):
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        raise RuntimeError('never called, slots not supported')
2948
2949
class TestCachedProperty(unittest.TestCase):
    """Behavioural tests for the pure-Python ``cached_property``."""

    def test_cached(self):
        item = CachedCostItem()
        # The getter bumps the cost on every run; both reads must return the
        # first computed value (2, never 3).
        for _ in range(2):
            self.assertEqual(item.cost, 2)

    def test_cached_attribute_name_differs_from_func_name(self):
        item = OptionallyCachedCostItem()
        # get_cost() recomputes on every call, while cached_cost keeps the
        # value from its first access even though both share one function.
        readings = (
            (item.get_cost, 2),
            (lambda: item.cached_cost, 3),
            (item.get_cost, 4),
            (lambda: item.cached_cost, 3),
        )
        for read, expected in readings:
            self.assertEqual(read(), expected)

    @threading_helper.requires_working_threading()
    def test_threaded(self):
        go = threading.Event()
        item = CachedCostItemWait(go)
        thread_count = 3

        def read_cost():
            return item.cost

        # Force very frequent thread switches so the readers overlap inside
        # the getter, then restore the interpreter setting.
        saved_interval = sys.getswitchinterval()
        sys.setswitchinterval(1e-6)
        try:
            workers = [threading.Thread(target=read_cost)
                       for _ in range(thread_count)]
            with threading_helper.start_threads(workers):
                go.set()
        finally:
            sys.setswitchinterval(saved_interval)

        # Each getter run adds one to the cost, so 2 is the result of the
        # first computation.
        self.assertEqual(item.cost, 2)

    def test_object_with_slots(self):
        item = CachedCostItemWithSlots()
        expected_message = (
            "No '__dict__' attribute on 'CachedCostItemWithSlots' "
            "instance to cache 'cost' property."
        )
        with self.assertRaisesRegex(TypeError, expected_message):
            item.cost

    def test_immutable_dict(self):
        # The instance here is a class, whose __dict__ does not support item
        # assignment, so a cached_property on the metaclass cannot store
        # its value.
        class MyMeta(type):
            @py_functools.cached_property
            def prop(self):
                return True

        class MyClass(metaclass=MyMeta):
            pass

        expected_message = (
            "The '__dict__' attribute on 'MyMeta' instance does not support "
            "item assignment for caching 'prop' property."
        )
        with self.assertRaisesRegex(TypeError, expected_message):
            MyClass.prop

    def test_reuse_different_names(self):
        """Disallow this case because decorated function a would not be cached."""
        with self.assertRaises(RuntimeError) as ctx:
            class ReusedCachedProperty:
                @py_functools.cached_property
                def a(self):
                    pass

                b = a

        # The TypeError from __set_name__ is the context of the RuntimeError
        # raised while the class is being created.
        expected = TypeError(
            "Cannot assign the same cached_property to two different names "
            "('a' and 'b')."
        )
        self.assertEqual(str(ctx.exception.__context__), str(expected))

    def test_reuse_same_name(self):
        """Reusing a cached_property on different classes under the same name is OK."""
        calls = 0

        @py_functools.cached_property
        def _cp(_self):
            nonlocal calls
            calls += 1
            return calls

        class A:
            cp = _cp

        class B:
            cp = _cp

        a, b = A(), B()
        # Each instance caches its own value; re-reading a does not recompute.
        for obj, expected in ((a, 1), (b, 2), (a, 1)):
            self.assertEqual(obj.cp, expected)

    def test_set_name_not_called(self):
        cp = py_functools.cached_property(lambda s: None)

        class Foo:
            pass

        # Attaching the descriptor after the class body has run means
        # __set_name__ is never invoked for it.
        Foo.cp = cp

        with self.assertRaisesRegex(
                TypeError,
                "Cannot use cached_property instance without calling __set_name__ on it.",
        ):
            Foo().cp

    def test_access_from_class(self):
        # Class-level access returns the descriptor object itself.
        descriptor = CachedCostItem.cost
        self.assertIsInstance(descriptor, py_functools.cached_property)

    def test_doc(self):
        # The descriptor exposes the wrapped getter's docstring.
        self.assertEqual(CachedCostItem.cost.__doc__, "The cost of the item.")
3063
3064
# Run the tests in this module when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
3067