1import sys
2import unittest
3
4
class CacheTestMixin:
    """Reusable test cases shared by cache implementations.

    Mix this class into a ``unittest.TestCase`` subclass and set
    :attr:`Cache` to the cache class under test, for example::

        class FooCacheTest(CacheTestMixin, unittest.TestCase):
            Cache = FooCache

    The tests construct caches as ``Cache(maxsize=...)`` and
    ``Cache(maxsize=..., getsizeof=...)`` and expect the mapping protocol
    plus ``maxsize``, ``currsize``, ``getsizeof`` and ``popitem``.
    Assertions about *which* item gets evicted are deliberately loose,
    because the eviction policy differs between implementations.
    """

    # Cache class under test; concrete test cases must override this.
    Cache = None

    def test_defaults(self):
        cache = self.Cache(maxsize=1)
        self.assertEqual(0, len(cache))
        self.assertEqual(1, cache.maxsize)
        self.assertEqual(0, cache.currsize)
        # Without a getsizeof function, every value counts as size 1.
        self.assertEqual(1, cache.getsizeof(None))
        self.assertEqual(1, cache.getsizeof(""))
        self.assertEqual(1, cache.getsizeof(0))
        self.assertTrue(repr(cache).startswith(cache.__class__.__name__))

    def test_insert(self):
        cache = self.Cache(maxsize=2)

        cache.update({1: 1, 2: 2})
        self.assertEqual(2, len(cache))
        self.assertEqual(1, cache[1])
        self.assertEqual(2, cache[2])

        # Inserting into a full cache evicts some older item; which one
        # depends on the eviction policy of the cache under test.
        cache[3] = 3
        self.assertEqual(2, len(cache))
        self.assertEqual(3, cache[3])
        self.assertTrue(1 in cache or 2 in cache)

        cache[4] = 4
        self.assertEqual(2, len(cache))
        self.assertEqual(4, cache[4])
        self.assertTrue(1 in cache or 2 in cache or 3 in cache)

    def test_update(self):
        cache = self.Cache(maxsize=2)

        cache.update({1: 1, 2: 2})
        self.assertEqual(2, len(cache))
        self.assertEqual(1, cache[1])
        self.assertEqual(2, cache[2])

        # Re-inserting existing keys must not evict anything.
        cache.update({1: 1, 2: 2})
        self.assertEqual(2, len(cache))
        self.assertEqual(1, cache[1])
        self.assertEqual(2, cache[2])

        cache.update({1: "a", 2: "b"})
        self.assertEqual(2, len(cache))
        self.assertEqual("a", cache[1])
        self.assertEqual("b", cache[2])

    def test_delete(self):
        cache = self.Cache(maxsize=2)

        cache.update({1: 1, 2: 2})
        self.assertEqual(2, len(cache))
        self.assertEqual(1, cache[1])
        self.assertEqual(2, cache[2])

        del cache[2]
        self.assertEqual(1, len(cache))
        self.assertEqual(1, cache[1])
        self.assertNotIn(2, cache)

        del cache[1]
        self.assertEqual(0, len(cache))
        self.assertNotIn(1, cache)
        self.assertNotIn(2, cache)

        with self.assertRaises(KeyError):
            del cache[1]
        self.assertEqual(0, len(cache))
        self.assertNotIn(1, cache)
        self.assertNotIn(2, cache)

    def test_pop(self):
        cache = self.Cache(maxsize=2)

        cache.update({1: 1, 2: 2})
        self.assertEqual(2, cache.pop(2))
        self.assertEqual(1, len(cache))
        self.assertEqual(1, cache.pop(1))
        self.assertEqual(0, len(cache))

        with self.assertRaises(KeyError):
            cache.pop(2)
        with self.assertRaises(KeyError):
            cache.pop(1)
        with self.assertRaises(KeyError):
            cache.pop(0)

        self.assertEqual(None, cache.pop(2, None))
        self.assertEqual(None, cache.pop(1, None))
        self.assertEqual(None, cache.pop(0, None))

    def test_popitem(self):
        cache = self.Cache(maxsize=2)
        expected = {(1, 1), (2, 2)}

        cache.update({1: 1, 2: 2})
        # popitem() returns (key, value) pairs; the order is policy
        # dependent, so only check membership and that each item is
        # removed exactly once.
        first = cache.popitem()
        self.assertIn(first, expected)
        self.assertEqual(1, len(cache))
        second = cache.popitem()
        self.assertIn(second, expected)
        self.assertNotEqual(first, second)
        self.assertEqual(0, len(cache))

        with self.assertRaises(KeyError):
            cache.popitem()

    @unittest.skipUnless(sys.version_info >= (3, 7), "requires Python 3.7")
    def test_popitem_exception_context(self):
        # since Python 3.7, MutableMapping.popitem() suppresses
        # exception context as implementation detail
        with self.assertRaises(KeyError) as ctx:
            self.Cache(maxsize=2).popitem()
        exception = ctx.exception
        self.assertIsNone(exception.__cause__)
        self.assertTrue(exception.__suppress_context__)

    def test_missing(self):
        class DefaultCache(self.Cache):
            def __missing__(self, key):
                self[key] = key
                return key

        cache = DefaultCache(maxsize=2)

        self.assertEqual(0, cache.currsize)
        self.assertEqual(2, cache.maxsize)
        self.assertEqual(0, len(cache))
        self.assertEqual(1, cache[1])
        self.assertEqual(2, cache[2])
        self.assertEqual(2, len(cache))
        self.assertTrue(1 in cache and 2 in cache)

        self.assertEqual(3, cache[3])
        self.assertEqual(2, len(cache))
        self.assertTrue(3 in cache)
        self.assertTrue(1 in cache or 2 in cache)
        self.assertTrue(1 not in cache or 2 not in cache)

        self.assertEqual(4, cache[4])
        self.assertEqual(2, len(cache))
        self.assertTrue(4 in cache)
        self.assertTrue(1 in cache or 2 in cache or 3 in cache)

        # verify __missing__() is *not* called for any operations
        # besides __getitem__()

        self.assertEqual(4, cache.get(4))
        self.assertEqual(None, cache.get(5))
        self.assertEqual(5 * 5, cache.get(5, 5 * 5))
        self.assertEqual(2, len(cache))

        self.assertEqual(4, cache.pop(4))
        with self.assertRaises(KeyError):
            cache.pop(5)
        self.assertEqual(None, cache.pop(5, None))
        self.assertEqual(5 * 5, cache.pop(5, 5 * 5))
        self.assertEqual(1, len(cache))

        cache.clear()
        cache[1] = 1 + 1
        self.assertEqual(1 + 1, cache.setdefault(1))
        self.assertEqual(1 + 1, cache.setdefault(1, 1))
        self.assertEqual(1 + 1, cache[1])
        self.assertEqual(2 + 2, cache.setdefault(2, 2 + 2))
        self.assertEqual(2 + 2, cache.setdefault(2, None))
        self.assertEqual(2 + 2, cache.setdefault(2))
        self.assertEqual(2 + 2, cache[2])
        self.assertEqual(2, len(cache))
        self.assertTrue(1 in cache and 2 in cache)
        self.assertEqual(None, cache.setdefault(3))
        self.assertEqual(2, len(cache))
        self.assertTrue(3 in cache)
        self.assertTrue(1 in cache or 2 in cache)
        self.assertTrue(1 not in cache or 2 not in cache)

    def test_missing_getsizeof(self):
        class DefaultCache(self.Cache):
            def __missing__(self, key):
                try:
                    self[key] = key
                except ValueError:
                    pass  # not stored
                return key

        # Each value's size equals the value itself, so key 3 (size 3)
        # can never fit into a cache of maxsize 2.
        cache = DefaultCache(maxsize=2, getsizeof=lambda x: x)

        self.assertEqual(0, cache.currsize)
        self.assertEqual(2, cache.maxsize)

        self.assertEqual(1, cache[1])
        self.assertEqual(1, len(cache))
        self.assertEqual(1, cache.currsize)
        self.assertIn(1, cache)

        self.assertEqual(2, cache[2])
        self.assertEqual(1, len(cache))
        self.assertEqual(2, cache.currsize)
        self.assertNotIn(1, cache)
        self.assertIn(2, cache)

        self.assertEqual(3, cache[3])  # not stored
        self.assertEqual(1, len(cache))
        self.assertEqual(2, cache.currsize)
        self.assertEqual(1, cache[1])
        self.assertEqual(1, len(cache))
        self.assertEqual(1, cache.currsize)
        self.assertEqual((1, 1), cache.popitem())

    def _test_getsizeof(self, cache):
        """Shared checks for a maxsize-3 *cache* whose getsizeof(v) == v."""
        self.assertEqual(0, cache.currsize)
        self.assertEqual(3, cache.maxsize)
        self.assertEqual(1, cache.getsizeof(1))
        self.assertEqual(2, cache.getsizeof(2))
        self.assertEqual(3, cache.getsizeof(3))

        cache.update({1: 1, 2: 2})
        self.assertEqual(2, len(cache))
        self.assertEqual(3, cache.currsize)
        self.assertEqual(1, cache[1])
        self.assertEqual(2, cache[2])

        # Growing key 1 to size 2 overflows maxsize 3, forcing eviction.
        cache[1] = 2
        self.assertEqual(1, len(cache))
        self.assertEqual(2, cache.currsize)
        self.assertEqual(2, cache[1])
        self.assertNotIn(2, cache)

        cache.update({1: 1, 2: 2})
        self.assertEqual(2, len(cache))
        self.assertEqual(3, cache.currsize)
        self.assertEqual(1, cache[1])
        self.assertEqual(2, cache[2])

        cache[3] = 3
        self.assertEqual(1, len(cache))
        self.assertEqual(3, cache.currsize)
        self.assertEqual(3, cache[3])
        self.assertNotIn(1, cache)
        self.assertNotIn(2, cache)

        # Values larger than maxsize are rejected and leave the cache intact.
        with self.assertRaises(ValueError):
            cache[3] = 4
        self.assertEqual(1, len(cache))
        self.assertEqual(3, cache.currsize)
        self.assertEqual(3, cache[3])

        with self.assertRaises(ValueError):
            cache[4] = 4
        self.assertEqual(1, len(cache))
        self.assertEqual(3, cache.currsize)
        self.assertEqual(3, cache[3])

    def test_getsizeof_param(self):
        self._test_getsizeof(self.Cache(maxsize=3, getsizeof=lambda x: x))

    def test_getsizeof_subclass(self):
        class Cache(self.Cache):
            def getsizeof(self, value):
                return value

        self._test_getsizeof(Cache(maxsize=3))

    def test_pickle(self):
        import pickle

        source = self.Cache(maxsize=2)
        source.update({1: 1, 2: 2})

        cache = pickle.loads(pickle.dumps(source))
        self.assertEqual(source, cache)

        # The unpickled copy must remain a fully working cache.
        self.assertEqual(2, len(cache))
        self.assertEqual(1, cache[1])
        self.assertEqual(2, cache[2])

        cache[3] = 3
        self.assertEqual(2, len(cache))
        self.assertEqual(3, cache[3])
        self.assertTrue(1 in cache or 2 in cache)

        cache[4] = 4
        self.assertEqual(2, len(cache))
        self.assertEqual(4, cache[4])
        self.assertTrue(1 in cache or 2 in cache or 3 in cache)

        self.assertEqual(cache, pickle.loads(pickle.dumps(cache)))

    def test_pickle_maxsize(self):
        import pickle

        # test empty cache, single element, large cache (recursion limit)
        for n in [0, 1, sys.getrecursionlimit() * 2]:
            with self.subTest(n=n):
                source = self.Cache(maxsize=n)
                source.update((i, i) for i in range(n))
                cache = pickle.loads(pickle.dumps(source))
                self.assertEqual(n, len(cache))
                self.assertEqual(source, cache)