1import unittest
2
3import cachetools
4import cachetools.keys
5
6
class DecoratorTestMixin:
    """Shared tests for ``cachetools.cached``, run against several cache types.

    Concrete test cases combine this mixin with ``unittest.TestCase`` and
    override :meth:`cache` to supply the cache object under test.
    """

    def cache(self, minsize):
        """Return a fresh, empty cache to exercise the decorator with.

        *minsize* is the number of entries the tests expect to fit; concrete
        subclasses may pass it on as a ``maxsize`` or ignore it.  The mixin
        itself provides no cache type, so it must be overridden.
        """
        raise NotImplementedError

    def func(self, *args, **kwargs):
        """Return how many earlier calls reached this method (0, 1, 2, ...).

        All arguments are ignored.  Because the result only advances on a
        real call, a repeated value proves the wrapper served it from cache.
        """
        if hasattr(self, "count"):
            self.count += 1
        else:
            self.count = 0
        return self.count

    def test_decorator(self):
        cache = self.cache(2)
        wrapper = cachetools.cached(cache)(self.func)

        self.assertEqual(len(cache), 0)
        self.assertEqual(wrapper.__wrapped__, self.func)

        self.assertEqual(wrapper(0), 0)
        self.assertEqual(len(cache), 1)
        self.assertIn(cachetools.keys.hashkey(0), cache)
        self.assertNotIn(cachetools.keys.hashkey(1), cache)
        self.assertNotIn(cachetools.keys.hashkey(1.0), cache)

        self.assertEqual(wrapper(1), 1)
        self.assertEqual(len(cache), 2)
        self.assertIn(cachetools.keys.hashkey(0), cache)
        self.assertIn(cachetools.keys.hashkey(1), cache)
        # hashkey is untyped: 1 and 1.0 compare equal, so they share an entry
        self.assertIn(cachetools.keys.hashkey(1.0), cache)

        self.assertEqual(wrapper(1), 1)
        self.assertEqual(len(cache), 2)

        self.assertEqual(wrapper(1.0), 1)
        self.assertEqual(len(cache), 2)

        self.assertEqual(wrapper(1.0), 1)
        self.assertEqual(len(cache), 2)

    def test_decorator_typed(self):
        cache = self.cache(3)
        key = cachetools.keys.typedkey
        wrapper = cachetools.cached(cache, key=key)(self.func)

        self.assertEqual(len(cache), 0)
        self.assertEqual(wrapper.__wrapped__, self.func)

        self.assertEqual(wrapper(0), 0)
        self.assertEqual(len(cache), 1)
        self.assertIn(cachetools.keys.typedkey(0), cache)
        self.assertNotIn(cachetools.keys.typedkey(1), cache)
        self.assertNotIn(cachetools.keys.typedkey(1.0), cache)

        self.assertEqual(wrapper(1), 1)
        self.assertEqual(len(cache), 2)
        self.assertIn(cachetools.keys.typedkey(0), cache)
        self.assertIn(cachetools.keys.typedkey(1), cache)
        # typedkey distinguishes int from float, so 1.0 is still a miss
        self.assertNotIn(cachetools.keys.typedkey(1.0), cache)

        self.assertEqual(wrapper(1), 1)
        self.assertEqual(len(cache), 2)

        self.assertEqual(wrapper(1.0), 2)
        self.assertEqual(len(cache), 3)
        self.assertIn(cachetools.keys.typedkey(0), cache)
        self.assertIn(cachetools.keys.typedkey(1), cache)
        self.assertIn(cachetools.keys.typedkey(1.0), cache)

        self.assertEqual(wrapper(1.0), 2)
        self.assertEqual(len(cache), 3)

    def test_decorator_lock(self):
        class Lock:
            """Context-manager stand-in counting acquisitions and releases."""

            def __init__(self):
                self.entered = 0
                self.exited = 0

            def __enter__(self):
                self.entered += 1

            def __exit__(self, *exc):
                self.exited += 1
                # Returning None lets any exception raised inside the
                # ``with`` body propagate, as a real lock would.

        lock = Lock()
        cache = self.cache(2)
        wrapper = cachetools.cached(cache, lock=lock)(self.func)

        self.assertEqual(len(cache), 0)
        self.assertEqual(wrapper.__wrapped__, self.func)

        # A miss acquires the lock twice (presumably lookup, then store).
        self.assertEqual(wrapper(0), 0)
        self.assertEqual(lock.entered, 2)
        self.assertEqual(lock.exited, lock.entered)

        self.assertEqual(wrapper(1), 1)
        self.assertEqual(lock.entered, 4)
        self.assertEqual(lock.exited, lock.entered)

        # A hit acquires the lock only once.
        self.assertEqual(wrapper(1), 1)
        self.assertEqual(lock.entered, 5)
        # Every acquisition must have been released again.
        self.assertEqual(lock.exited, lock.entered)
        self.assertEqual(len(cache), 2)
100
101
class CacheWrapperTest(unittest.TestCase, DecoratorTestMixin):
    """Decorator tests backed by ``cachetools.Cache``, plus zero-size cases."""

    def cache(self, minsize):
        # *minsize* becomes the cache's maxsize.
        return cachetools.Cache(maxsize=minsize)

    def test_zero_size_cache_decorator(self):
        # A maxsize=0 cache never retains a value: the wrapper must still
        # return the function result, and a repeat call must recompute.
        cache = self.cache(0)
        wrapper = cachetools.cached(cache)(self.func)

        self.assertEqual(len(cache), 0)
        self.assertEqual(wrapper.__wrapped__, self.func)

        self.assertEqual(wrapper(0), 0)
        self.assertEqual(len(cache), 0)

        # Same argument again: nothing was stored, so func runs a second time.
        self.assertEqual(wrapper(0), 1)
        self.assertEqual(len(cache), 0)

    def test_zero_size_cache_decorator_lock(self):
        class Lock:
            """Context-manager stand-in counting acquisitions and releases."""

            def __init__(self):
                self.entered = 0
                self.exited = 0

            def __enter__(self):
                self.entered += 1

            def __exit__(self, *exc):
                self.exited += 1
                # Returning None lets exceptions raised inside the ``with``
                # body propagate to the wrapper, as a real lock would.

        lock = Lock()
        cache = self.cache(0)
        wrapper = cachetools.cached(cache, lock=lock)(self.func)

        self.assertEqual(len(cache), 0)
        self.assertEqual(wrapper.__wrapped__, self.func)

        self.assertEqual(wrapper(0), 0)
        self.assertEqual(len(cache), 0)
        self.assertEqual(lock.entered, 2)
        # The failed store must not leave the lock held.
        self.assertEqual(lock.exited, 2)
136
137
class DictWrapperTest(unittest.TestCase, DecoratorTestMixin):
    """Run the shared decorator tests with a plain ``dict`` as the cache."""

    def cache(self, minsize):
        # A dict has no capacity limit, so *minsize* is deliberately unused.
        return {}
141
142
class NoneWrapperTest(unittest.TestCase):
    """``cachetools.cached(None)`` wraps a function and passes calls through."""

    def func(self, *args, **kwargs):
        """Echo positional arguments followed by ``(name, value)`` pairs."""
        return (*args, *kwargs.items())

    def test_decorator(self):
        decorated = cachetools.cached(None)(self.func)
        self.assertEqual(decorated.__wrapped__, self.func)

        expectations = [
            ((0,), {}, (0,)),
            ((1,), {}, (1,)),
            ((1,), {"foo": "bar"}, (1, ("foo", "bar"))),
        ]
        for call_args, call_kwargs, expected in expectations:
            self.assertEqual(decorated(*call_args, **call_kwargs), expected)
154