1import operator
2import unittest
3
4from cachetools import LRUCache, cachedmethod, keys
5
6
class Cached:
    """Test helper whose cached methods hand out a running counter.

    ``count`` holds the next value to return.  Each call that actually
    executes ``get`` or ``get_typed`` returns the current ``count`` and then
    bumps it by one, so a repeated return value signals a cache hit and a
    fresh value signals a miss.  ``cache`` is the mapping (or ``None``) that
    ``cachedmethod`` reads from the instance.
    """

    def __init__(self, cache, count=0):
        self.count = count
        self.cache = cache

    @cachedmethod(lambda self: self.cache)
    def get(self, value):
        # Hand out the current counter value and advance it in one step.
        current, self.count = self.count, self.count + 1
        return current

    @cachedmethod(lambda self: self.cache, key=keys.typedkey)
    def get_typed(self, value):
        # Same as ``get`` but keyed with ``typedkey``, so 1 and 1.0 differ.
        current, self.count = self.count, self.count + 1
        return current

    def __hash__(self):
        # Regression guard for https://github.com/tkem/cachetools/issues/107:
        # any attempt to hash the instance (e.g. by folding ``self`` into a
        # cache key) fails loudly instead of silently succeeding.
        raise TypeError("unhashable type")
27
28
class Locked:
    """Test helper that acts as its own lock and counts lock entries.

    ``cachedmethod`` is given ``lock=lambda self: self``, so every ``with``
    block it opens on the lock calls ``__enter__`` here and bumps ``count``.
    ``get`` simply reports that count, letting tests observe how many times
    the lock was taken.
    """

    def __init__(self, cache):
        self.cache, self.count = cache, 0

    @cachedmethod(lambda self: self.cache, lock=lambda self: self)
    def get(self, value):
        # Report the number of lock entries observed so far.
        return self.count

    def __enter__(self):
        # Record one more acquisition of the "lock".
        self.count = self.count + 1

    def __exit__(self, *exc):
        # Nothing to release; returning None lets exceptions propagate.
        return None
43
44
class CachedMethodTest(unittest.TestCase):
    """Behavioral tests for ``cachetools.cachedmethod``.

    The ``Cached`` helper returns an incrementing counter from every call
    that actually runs the method body, so a repeated return value means the
    call was served from the cache and a new value means it was a miss.
    """

    def test_dict(self):
        """An unbounded dict caches by value; 1 and 1.0 share one entry."""
        cached = Cached({})

        self.assertEqual(cached.get(0), 0)
        self.assertEqual(cached.get(1), 1)
        self.assertEqual(cached.get(1), 1)
        self.assertEqual(cached.get(1.0), 1)
        self.assertEqual(cached.get(1.0), 1)

        cached.cache.clear()
        self.assertEqual(cached.get(1), 2)

    def test_typed_dict(self):
        """With ``typedkey`` and an unbounded dict, int and float keys differ."""
        # Use a plain dict (not an LRUCache) so this covers the unbounded
        # case separately from test_typed_lru.
        cached = Cached({})

        self.assertEqual(cached.get_typed(0), 0)
        self.assertEqual(cached.get_typed(1), 1)
        self.assertEqual(cached.get_typed(1), 1)
        self.assertEqual(cached.get_typed(1.0), 2)
        self.assertEqual(cached.get_typed(1.0), 2)
        self.assertEqual(cached.get_typed(0.0), 3)
        # A dict never evicts, so the int-0 entry from the first call is
        # still present (in test_typed_lru it has been evicted by now).
        self.assertEqual(cached.get_typed(0), 0)

        cached.cache.clear()
        self.assertEqual(cached.get_typed(0), 4)

    def test_lru(self):
        """An LRUCache caches by value and evicts the least recently used key."""
        cached = Cached(LRUCache(maxsize=2))

        self.assertEqual(cached.get(0), 0)
        self.assertEqual(cached.get(1), 1)
        self.assertEqual(cached.get(1), 1)
        self.assertEqual(cached.get(1.0), 1)
        self.assertEqual(cached.get(1.0), 1)

        cached.cache.clear()
        self.assertEqual(cached.get(1), 2)

        # Fill the cache to capacity (keys 1 and 2), then insert 0: the
        # least recently used key (1) must be evicted while 2 survives.
        self.assertEqual(cached.get(2), 3)
        self.assertEqual(cached.get(0), 4)
        self.assertEqual(cached.get(2), 3)
        self.assertEqual(cached.get(1), 5)

    def test_typed_lru(self):
        """With ``typedkey`` and a bounded LRUCache, old typed keys get evicted."""
        cached = Cached(LRUCache(maxsize=2))

        self.assertEqual(cached.get_typed(0), 0)
        self.assertEqual(cached.get_typed(1), 1)
        self.assertEqual(cached.get_typed(1), 1)
        self.assertEqual(cached.get_typed(1.0), 2)
        self.assertEqual(cached.get_typed(1.0), 2)
        self.assertEqual(cached.get_typed(0.0), 3)
        self.assertEqual(cached.get_typed(0), 4)

    def test_nospace(self):
        """A cache with ``maxsize=0`` stores nothing, so every call recomputes."""
        cached = Cached(LRUCache(maxsize=0))

        self.assertEqual(cached.get(0), 0)
        self.assertEqual(cached.get(1), 1)
        self.assertEqual(cached.get(1), 2)
        self.assertEqual(cached.get(1.0), 3)
        self.assertEqual(cached.get(1.0), 4)

    def test_nocache(self):
        """A ``None`` cache disables caching; every call runs the method."""
        cached = Cached(None)

        self.assertEqual(cached.get(0), 0)
        self.assertEqual(cached.get(1), 1)
        self.assertEqual(cached.get(1), 2)
        self.assertEqual(cached.get(1.0), 3)
        self.assertEqual(cached.get(1.0), 4)

    def test_weakref(self):
        """WeakValueDictionary entries disappear once results are unreferenced."""
        import fractions
        import gc
        import weakref

        # in Python 3.4, `int` does not support weak references even
        # when subclassed, but Fraction apparently does...
        class Int(fractions.Fraction):
            def __add__(self, other):
                return Int(fractions.Fraction.__add__(self, other))

        cached = Cached(weakref.WeakValueDictionary(), count=Int(0))

        # The first result is not kept alive, so after a collection the
        # entry is gone and the method runs again.
        self.assertEqual(cached.get(0), 0)
        gc.collect()
        self.assertEqual(cached.get(0), 1)

        # Holding a strong reference keeps the cached value alive.
        ref = cached.get(1)
        self.assertEqual(ref, 2)
        self.assertEqual(cached.get(1), 2)
        self.assertEqual(cached.get(1.0), 2)

        ref = cached.get_typed(1)
        self.assertEqual(ref, 3)
        self.assertEqual(cached.get_typed(1), 3)
        self.assertEqual(cached.get_typed(1.0), 4)

        cached.cache.clear()
        self.assertEqual(cached.get(1), 5)

    def test_locked_dict(self):
        """A miss enters the lock twice (lookup and store); a hit enters it once."""
        cached = Locked({})

        self.assertEqual(cached.get(0), 1)
        self.assertEqual(cached.get(1), 3)
        self.assertEqual(cached.get(1), 3)
        self.assertEqual(cached.get(1.0), 3)
        self.assertEqual(cached.get(2.0), 7)

    def test_locked_nocache(self):
        """With no cache, the lock is never entered."""
        cached = Locked(None)

        self.assertEqual(cached.get(0), 0)
        self.assertEqual(cached.get(1), 0)
        self.assertEqual(cached.get(1), 0)
        self.assertEqual(cached.get(1.0), 0)
        self.assertEqual(cached.get(1.0), 0)

    def test_locked_nospace(self):
        """With ``maxsize=0`` every call misses, entering the lock twice each time."""
        cached = Locked(LRUCache(maxsize=0))

        self.assertEqual(cached.get(0), 1)
        self.assertEqual(cached.get(1), 3)
        self.assertEqual(cached.get(1), 5)
        self.assertEqual(cached.get(1.0), 7)
        self.assertEqual(cached.get(1.0), 9)
166