1import concurrent.futures
2import contextvars
3import functools
4import gc
5import random
6import time
7import unittest
8import weakref
9from test.support import threading_helper
10
11try:
12    from _testcapi import hamt
13except ImportError:
14    hamt = None
15
16
def isolated_context(func):
    """Decorate *func* so every call runs inside a brand-new ``Context``.

    Needed to make reftracking test mode work: each call of the decorated
    function executes via ``Context.run`` on a fresh, empty context.
    Context variables it sets are therefore not visible in the caller's
    context afterwards.  Positional and keyword arguments are forwarded
    unchanged, and the wrapped function's return value is passed back.
    """
    @functools.wraps(func)
    def run_in_fresh_context(*args, **kwargs):
        fresh = contextvars.Context()
        return fresh.run(func, *args, **kwargs)

    return run_in_fresh_context
24
25
class ContextTest(unittest.TestCase):
    """Tests for the public ``contextvars`` API.

    Covers the following:

    * ``ContextVar`` construction and ``repr``.
    * ``Context`` construction and ``Context.run()`` argument forwarding
      and isolation.
    * ``ContextVar.get``/``set``/``reset`` with ``Token`` objects.
    * ``Context.copy()``.
    * Per-thread independence of context variable values.

    Several tests are wrapped with ``isolated_context`` so they run in a
    fresh, empty ``Context``.
    """

    def test_context_var_new_1(self):
        with self.assertRaisesRegex(TypeError, 'takes exactly 1'):
            contextvars.ContextVar()

        with self.assertRaisesRegex(TypeError, 'must be a str'):
            contextvars.ContextVar(1)

        c = contextvars.ContextVar('aaa')
        self.assertEqual(c.name, 'aaa')

        # The name is read-only once the variable has been created.
        with self.assertRaises(AttributeError):
            c.name = 'bbb'

        # A ContextVar does not simply hash like its name string.
        self.assertNotEqual(hash(c), hash('aaa'))

    @isolated_context
    def test_context_var_repr_1(self):
        c = contextvars.ContextVar('a')
        self.assertIn('a', repr(c))

        c = contextvars.ContextVar('a', default=123)
        self.assertIn('123', repr(c))

        # A default that (indirectly) contains the variable itself must not
        # make repr() recurse forever.
        lst = []
        c = contextvars.ContextVar('a', default=lst)
        lst.append(c)
        self.assertIn('...', repr(c))
        self.assertIn('...', repr(lst))

        # A Token's repr embeds its variable's repr and reports ' used '
        # only after the token has been passed to reset().
        t = c.set(1)
        self.assertIn(repr(c), repr(t))
        self.assertNotIn(' used ', repr(t))
        c.reset(t)
        self.assertIn(' used ', repr(t))

    def test_context_subclassing_1(self):
        with self.assertRaisesRegex(TypeError, 'not an acceptable base type'):
            class MyContextVar(contextvars.ContextVar):
                # Potentially we might want ContextVars to be subclassable.
                pass

        with self.assertRaisesRegex(TypeError, 'not an acceptable base type'):
            class MyContext(contextvars.Context):
                pass

        with self.assertRaisesRegex(TypeError, 'not an acceptable base type'):
            class MyToken(contextvars.Token):
                pass

    def test_context_new_1(self):
        with self.assertRaisesRegex(TypeError, 'any arguments'):
            contextvars.Context(1)
        with self.assertRaisesRegex(TypeError, 'any arguments'):
            contextvars.Context(1, a=1)
        with self.assertRaisesRegex(TypeError, 'any arguments'):
            contextvars.Context(a=1)
        # An empty ** expansion passes no arguments, so it is accepted.
        contextvars.Context(**{})

    def test_context_typerrors_1(self):
        ctx = contextvars.Context()

        with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'):
            ctx[1]
        with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'):
            1 in ctx
        with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'):
            ctx.get(1)

    def test_context_get_context_1(self):
        ctx = contextvars.copy_context()
        self.assertIsInstance(ctx, contextvars.Context)

    def test_context_run_1(self):
        ctx = contextvars.Context()

        with self.assertRaisesRegex(TypeError, 'missing 1 required'):
            ctx.run()

    def test_context_run_2(self):
        ctx = contextvars.Context()

        def func(*args, **kwargs):
            kwargs['spam'] = 'foo'
            args += ('bar',)
            return args, kwargs

        # Run every check both with a plain function and through a
        # functools.partial wrapper.  Historically partial did not support
        # FASTCALL, so the second pass covered a different calling
        # convention into ctx.run().
        for f in (func, functools.partial(func)):
            self.assertEqual(ctx.run(f), (('bar',), {'spam': 'foo'}))
            self.assertEqual(ctx.run(f, 1), ((1, 'bar'), {'spam': 'foo'}))

            self.assertEqual(
                ctx.run(f, a=2),
                (('bar',), {'a': 2, 'spam': 'foo'}))

            self.assertEqual(
                ctx.run(f, 11, a=2),
                ((11, 'bar'), {'a': 2, 'spam': 'foo'}))

            # The callee mutating its kwargs must not leak into the
            # caller's dict.
            a = {}
            self.assertEqual(
                ctx.run(f, 11, **a),
                ((11, 'bar'), {'spam': 'foo'}))
            self.assertEqual(a, {})

    def test_context_run_3(self):
        ctx = contextvars.Context()

        def func(*args, **kwargs):
            1 / 0

        with self.assertRaises(ZeroDivisionError):
            ctx.run(func)
        with self.assertRaises(ZeroDivisionError):
            ctx.run(func, 1, 2)
        with self.assertRaises(ZeroDivisionError):
            ctx.run(func, 1, 2, a=123)

    @isolated_context
    def test_context_run_4(self):
        ctx1 = contextvars.Context()
        ctx2 = contextvars.Context()
        var = contextvars.ContextVar('var')

        def func2():
            # ctx2 does not see the value set while running in ctx1.
            self.assertIsNone(var.get(None))

        def func1():
            self.assertIsNone(var.get(None))
            var.set('spam')
            ctx2.run(func2)
            self.assertEqual(var.get(None), 'spam')

            cur = contextvars.copy_context()
            self.assertEqual(len(cur), 1)
            self.assertEqual(cur[var], 'spam')
            return cur

        returned_ctx = ctx1.run(func1)
        self.assertEqual(ctx1, returned_ctx)
        self.assertEqual(returned_ctx[var], 'spam')
        self.assertIn(var, returned_ctx)

    def test_context_run_5(self):
        ctx = contextvars.Context()
        var = contextvars.ContextVar('var')

        def func():
            self.assertIsNone(var.get(None))
            var.set('spam')
            1 / 0

        with self.assertRaises(ZeroDivisionError):
            ctx.run(func)

        # The value set inside ctx.run() is not visible to the caller,
        # even though the callee raised.
        self.assertIsNone(var.get(None))

    def test_context_run_6(self):
        ctx = contextvars.Context()
        c = contextvars.ContextVar('a', default=0)

        def fun():
            # The variable's default is not stored in the context itself.
            self.assertEqual(c.get(), 0)
            self.assertIsNone(ctx.get(c))

            c.set(42)
            self.assertEqual(c.get(), 42)
            self.assertEqual(ctx.get(c), 42)

        ctx.run(fun)

    def test_context_run_7(self):
        ctx = contextvars.Context()

        def fun():
            # Re-entering a context that is currently running is an error.
            with self.assertRaisesRegex(RuntimeError, 'is already entered'):
                ctx.run(fun)

        ctx.run(fun)

    @isolated_context
    def test_context_getset_1(self):
        c = contextvars.ContextVar('c')
        with self.assertRaises(LookupError):
            c.get()

        self.assertIsNone(c.get(None))

        t0 = c.set(42)
        self.assertEqual(c.get(), 42)
        self.assertEqual(c.get(None), 42)
        self.assertIs(t0.old_value, t0.MISSING)
        self.assertIs(t0.old_value, contextvars.Token.MISSING)
        self.assertIs(t0.var, c)

        t = c.set('spam')
        self.assertEqual(c.get(), 'spam')
        self.assertEqual(c.get(None), 'spam')
        self.assertEqual(t.old_value, 42)
        c.reset(t)

        self.assertEqual(c.get(), 42)
        self.assertEqual(c.get(None), 42)

        # A token can be used for reset() only once.
        c.set('spam2')
        with self.assertRaisesRegex(RuntimeError, 'has already been used'):
            c.reset(t)
        self.assertEqual(c.get(), 'spam2')

        ctx1 = contextvars.copy_context()
        self.assertIn(c, ctx1)

        c.reset(t0)
        with self.assertRaisesRegex(RuntimeError, 'has already been used'):
            c.reset(t0)
        self.assertIsNone(c.get(None))

        # The snapshot taken before reset(t0) is unaffected by it.
        self.assertIn(c, ctx1)
        self.assertEqual(ctx1[c], 'spam2')
        self.assertEqual(ctx1.get(c, 'aa'), 'spam2')
        self.assertEqual(len(ctx1), 1)
        self.assertEqual(list(ctx1.items()), [(c, 'spam2')])
        self.assertEqual(list(ctx1.values()), ['spam2'])
        self.assertEqual(list(ctx1.keys()), [c])
        self.assertEqual(list(ctx1), [c])

        ctx2 = contextvars.copy_context()
        self.assertNotIn(c, ctx2)
        with self.assertRaises(KeyError):
            ctx2[c]
        self.assertEqual(ctx2.get(c, 'aa'), 'aa')
        self.assertEqual(len(ctx2), 0)
        self.assertEqual(list(ctx2), [])

    @isolated_context
    def test_context_getset_2(self):
        v1 = contextvars.ContextVar('v1')
        v2 = contextvars.ContextVar('v2')

        t1 = v1.set(42)
        with self.assertRaisesRegex(ValueError, 'by a different'):
            v2.reset(t1)

    @isolated_context
    def test_context_getset_3(self):
        c = contextvars.ContextVar('c', default=42)
        ctx = contextvars.Context()

        def fun():
            self.assertEqual(c.get(), 42)
            with self.assertRaises(KeyError):
                ctx[c]
            self.assertIsNone(ctx.get(c))
            self.assertEqual(ctx.get(c, 'spam'), 'spam')
            self.assertNotIn(c, ctx)
            self.assertEqual(list(ctx.keys()), [])

            t = c.set(1)
            self.assertEqual(list(ctx.keys()), [c])
            self.assertEqual(ctx[c], 1)

            # Resetting to a "missing" old value removes the key again.
            c.reset(t)
            self.assertEqual(list(ctx.keys()), [])
            with self.assertRaises(KeyError):
                ctx[c]

        ctx.run(fun)

    @isolated_context
    def test_context_getset_4(self):
        c = contextvars.ContextVar('c', default=42)
        ctx = contextvars.Context()

        tok = ctx.run(c.set, 1)

        # A token created in another context cannot be used here.
        with self.assertRaisesRegex(ValueError, 'different Context'):
            c.reset(tok)

    @isolated_context
    def test_context_getset_5(self):
        c = contextvars.ContextVar('c', default=42)
        c.set([])

        def fun():
            c.set([])
            c.get().append(42)
            self.assertEqual(c.get(), [42])

        contextvars.copy_context().run(fun)
        self.assertEqual(c.get(), [])

    def test_context_copy_1(self):
        ctx1 = contextvars.Context()
        c = contextvars.ContextVar('c', default=42)

        def ctx1_fun():
            c.set(10)

            ctx2 = ctx1.copy()
            self.assertEqual(ctx2[c], 10)

            # Later changes in ctx1 do not propagate into the copy...
            c.set(20)
            self.assertEqual(ctx1[c], 20)
            self.assertEqual(ctx2[c], 10)

            # ...and changes made while running in the copy stay there.
            ctx2.run(ctx2_fun)
            self.assertEqual(ctx1[c], 20)
            self.assertEqual(ctx2[c], 30)

        def ctx2_fun():
            self.assertEqual(c.get(), 10)
            c.set(30)
            self.assertEqual(c.get(), 30)

        ctx1.run(ctx1_fun)

    @isolated_context
    @threading_helper.requires_working_threading()
    def test_context_threads_1(self):
        cvar = contextvars.ContextVar('cvar')

        def sub(num):
            # Each worker repeatedly sets and re-reads cvar with random
            # pauses, so a value leaking between threads would be noticed.
            for i in range(10):
                cvar.set(num + i)
                time.sleep(random.uniform(0.001, 0.05))
                self.assertEqual(cvar.get(), num + i)
            return num

        # The executor's context manager shuts the pool down (waiting for
        # the workers) even if map() raises.
        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as tp:
            results = list(tp.map(sub, range(10)))
        self.assertEqual(results, list(range(10)))
362
363
364# HAMT Tests
365
366
class HashKey:
    """Hashable test key whose hash value is chosen by the caller.

    Two keys compare equal only when both their *name* and *hash* match.
    Keys that share a hash but differ in name therefore collide without
    being equal.

    While a ``HaskKeyCrasher`` is active it is stored in the class
    attribute ``_crasher``.  ``__hash__`` and ``__eq__`` then raise
    ``HashingError`` / ``EqError`` as that crasher requests.

    *error_on_eq_to*, if given, is another key.  Comparing this key with
    that exact object (in either direction) raises ``ValueError``.
    """

    _crasher = None

    def __init__(self, hash, name, *, error_on_eq_to=None):
        # NOTE(review): -1 is presumably excluded because CPython maps a
        # __hash__ result of -1 to -2, which would break tests relying on
        # exact hash values.
        assert hash != -1
        self.error_on_eq_to = error_on_eq_to
        self.hash = hash
        self.name = name

    def __repr__(self):
        return f'<Key name:{self.name} hash:{self.hash}>'

    def __hash__(self):
        crasher = self._crasher
        if crasher is not None and crasher.error_on_hash:
            raise HashingError
        return self.hash

    def __eq__(self, other):
        if isinstance(other, HashKey):
            crasher = self._crasher
            if crasher is not None and crasher.error_on_eq:
                raise EqError

            # Check self -> other first, then other -> self.  `b` is always
            # a HashKey here, so an unset (None) error_on_eq_to never
            # matches it.
            for a, b in ((self, other), (other, self)):
                if a.error_on_eq_to is b:
                    raise ValueError(f'cannot compare {a!r} to {b!r}')

            return (self.name, self.hash) == (other.name, other.hash)

        return NotImplemented
398
399
class KeyStr(str):
    """``str`` subclass that honours the crasher installed on ``HashKey``.

    With no crasher active (``HashKey._crasher`` is None), or one that does
    not request a failure, it hashes and compares exactly like ``str``.
    Otherwise ``__hash__`` raises ``HashingError`` and ``__eq__`` raises
    ``EqError``, as the active crasher requests.
    """

    def __hash__(self):
        crasher = HashKey._crasher
        if crasher is not None and crasher.error_on_hash:
            raise HashingError
        return super().__hash__()

    def __eq__(self, other):
        crasher = HashKey._crasher
        if crasher is not None and crasher.error_on_eq:
            raise EqError
        return super().__eq__(other)
410
411
class HaskKeyCrasher:
    """Context manager that makes HashKey/KeyStr hashing or equality fail.

    On entry the instance is stored in ``HashKey._crasher``, and it is
    cleared again on exit.  While it is installed, the *error_on_hash* and
    *error_on_eq* flags select which operations raise.  Nesting crashers
    raises ``RuntimeError``.  ``__enter__`` returns None, and exceptions
    raised inside the ``with`` body are not suppressed.
    """

    def __init__(self, *, error_on_hash=False, error_on_eq=False):
        self.error_on_eq = error_on_eq
        self.error_on_hash = error_on_hash

    def __enter__(self):
        if HashKey._crasher is None:
            HashKey._crasher = self
        else:
            raise RuntimeError('cannot nest crashers')

    def __exit__(self, *exc):
        HashKey._crasher = None
424
425
class HashingError(Exception):
    """Raised by HashKey/KeyStr ``__hash__`` when the active crasher asks."""
428
429
class EqError(Exception):
    """Raised by HashKey/KeyStr ``__eq__`` when the active crasher asks."""
432
433
@unittest.skipIf(hamt is None, '_testcapi lacks "hamt()" function')
class HamtTest(unittest.TestCase):
    """Tests for the immutable HAMT mapping exposed by ``_testcapi.hamt()``.

    ``HashKey`` lets a test choose exact hash values.  That forces specific
    tree shapes and hash collisions.  The tree dumps in the comments sketch
    the internal node layout the inserted keys are expected to produce.
    """

    def test_hashkey_helper_1(self):
        k1 = HashKey(10, 'aaa')
        k2 = HashKey(10, 'bbb')

        self.assertNotEqual(k1, k2)
        self.assertEqual(hash(k1), hash(k2))

        d = dict()
        d[k1] = 'a'
        d[k2] = 'b'

        self.assertEqual(d[k1], 'a')
        self.assertEqual(d[k2], 'b')

    def test_hamt_basics_1(self):
        h = hamt()
        h = None  # NoQA

    def test_hamt_basics_2(self):
        h = hamt()
        self.assertEqual(len(h), 0)

        # set() returns a new mapping and leaves the original untouched.
        h2 = h.set('a', 'b')
        self.assertIsNot(h, h2)
        self.assertEqual(len(h), 0)
        self.assertEqual(len(h2), 1)

        self.assertIsNone(h.get('a'))
        self.assertEqual(h.get('a', 42), 42)

        self.assertEqual(h2.get('a'), 'b')

        h3 = h2.set('b', 10)
        self.assertIsNot(h2, h3)
        self.assertEqual(len(h), 0)
        self.assertEqual(len(h2), 1)
        self.assertEqual(len(h3), 2)
        self.assertEqual(h3.get('a'), 'b')
        self.assertEqual(h3.get('b'), 10)

        self.assertIsNone(h.get('b'))
        self.assertIsNone(h2.get('b'))

        self.assertIsNone(h.get('a'))
        self.assertEqual(h2.get('a'), 'b')

        h = h2 = h3 = None

    def test_hamt_basics_3(self):
        # Setting a key to the identical value it already has returns the
        # same mapping object.
        h = hamt()
        o = object()
        h1 = h.set('1', o)
        h2 = h1.set('1', o)
        self.assertIs(h1, h2)

    def test_hamt_basics_4(self):
        # An equal-but-distinct value still produces a new mapping.
        h = hamt()
        h1 = h.set('key', [])
        h2 = h1.set('key', [])
        self.assertIsNot(h1, h2)
        self.assertEqual(len(h1), 1)
        self.assertEqual(len(h2), 1)
        self.assertIsNot(h1.get('key'), h2.get('key'))

    def test_hamt_collision_1(self):
        k1 = HashKey(10, 'aaa')
        k2 = HashKey(10, 'bbb')
        k3 = HashKey(10, 'ccc')

        h = hamt()
        h2 = h.set(k1, 'a')
        h3 = h2.set(k2, 'b')

        self.assertEqual(h.get(k1), None)
        self.assertEqual(h.get(k2), None)

        self.assertEqual(h2.get(k1), 'a')
        self.assertEqual(h2.get(k2), None)

        self.assertEqual(h3.get(k1), 'a')
        self.assertEqual(h3.get(k2), 'b')

        h4 = h3.set(k2, 'cc')
        h5 = h4.set(k3, 'aa')

        self.assertEqual(h3.get(k1), 'a')
        self.assertEqual(h3.get(k2), 'b')
        self.assertEqual(h4.get(k1), 'a')
        self.assertEqual(h4.get(k2), 'cc')
        self.assertEqual(h4.get(k3), None)
        self.assertEqual(h5.get(k1), 'a')
        self.assertEqual(h5.get(k2), 'cc')
        self.assertEqual(h5.get(k3), 'aa')

        self.assertEqual(len(h), 0)
        self.assertEqual(len(h2), 1)
        self.assertEqual(len(h3), 2)
        self.assertEqual(len(h4), 2)
        self.assertEqual(len(h5), 3)

    def test_hamt_collision_3(self):
        # Test that iteration works with the deepest tree possible.
        # https://github.com/python/cpython/issues/93065

        C = HashKey(0b10000000_00000000_00000000_00000000, 'C')
        D = HashKey(0b10000000_00000000_00000000_00000000, 'D')

        E = HashKey(0b00000000_00000000_00000000_00000000, 'E')

        h = hamt()
        h = h.set(C, 'C')
        h = h.set(D, 'D')
        h = h.set(E, 'E')

        # BitmapNode(size=2 count=1 bitmap=0b1):
        #   NULL:
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #       NULL:
        #         BitmapNode(size=2 count=1 bitmap=0b1):
        #           NULL:
        #             BitmapNode(size=2 count=1 bitmap=0b1):
        #               NULL:
        #                 BitmapNode(size=2 count=1 bitmap=0b1):
        #                   NULL:
        #                     BitmapNode(size=2 count=1 bitmap=0b1):
        #                       NULL:
        #                         BitmapNode(size=4 count=2 bitmap=0b101):
        #                           <Key name:E hash:0>: 'E'
        #                           NULL:
        #                             CollisionNode(size=4 id=0x107a24520):
        #                               <Key name:C hash:2147483648>: 'C'
        #                               <Key name:D hash:2147483648>: 'D'

        self.assertEqual({k.name for k in h.keys()}, {'C', 'D', 'E'})

    def test_hamt_stress(self):
        # Mirror every HAMT operation in a plain dict and compare the two.
        COLLECTION_SIZE = 7000
        TEST_ITERS_EVERY = 647   # how often to compare full iteration
        CRASH_HASH_EVERY = 97    # how often to inject a __hash__ failure
        CRASH_EQ_EVERY = 11      # how often to inject an __eq__ failure
        RUN_XTIMES = 3

        for _ in range(RUN_XTIMES):
            h = hamt()
            d = dict()

            for i in range(COLLECTION_SIZE):
                key = KeyStr(i)

                if not (i % CRASH_HASH_EVERY):
                    with HaskKeyCrasher(error_on_hash=True):
                        with self.assertRaises(HashingError):
                            h.set(key, i)

                h = h.set(key, i)

                if not (i % CRASH_EQ_EVERY):
                    with HaskKeyCrasher(error_on_eq=True):
                        with self.assertRaises(EqError):
                            h.get(KeyStr(i))  # really trigger __eq__

                d[key] = i
                self.assertEqual(len(d), len(h))

                if not (i % TEST_ITERS_EVERY):
                    self.assertEqual(set(h.items()), set(d.items()))
                    self.assertEqual(len(h.items()), len(d.items()))

            self.assertEqual(len(h), COLLECTION_SIZE)

            for key in range(COLLECTION_SIZE):
                self.assertEqual(h.get(KeyStr(key), 'not found'), key)

            keys_to_delete = list(range(COLLECTION_SIZE))
            random.shuffle(keys_to_delete)
            for iter_i, i in enumerate(keys_to_delete):
                key = KeyStr(i)

                if not (iter_i % CRASH_HASH_EVERY):
                    with HaskKeyCrasher(error_on_hash=True):
                        with self.assertRaises(HashingError):
                            h.delete(key)

                if not (iter_i % CRASH_EQ_EVERY):
                    with HaskKeyCrasher(error_on_eq=True):
                        with self.assertRaises(EqError):
                            h.delete(KeyStr(i))

                h = h.delete(key)
                self.assertEqual(h.get(key, 'not found'), 'not found')
                del d[key]
                self.assertEqual(len(d), len(h))

                # Snapshot the half-emptied mapping for the second phase.
                if iter_i == COLLECTION_SIZE // 2:
                    hm = h
                    dm = d.copy()

                if not (iter_i % TEST_ITERS_EVERY):
                    self.assertEqual(set(h.keys()), set(d.keys()))
                    self.assertEqual(len(h.keys()), len(d.keys()))

            self.assertEqual(len(d), 0)
            self.assertEqual(len(h), 0)

            # ============
            # Second phase: drain the snapshot (hm, dm) using plain ``str``
            # keys.  These compare and hash equal to the KeyStr keys the
            # snapshot holds.  Deleting an absent key leaves the mapping
            # unchanged.  All checks below must target hm/dm; h and d are
            # already empty at this point.

            for key in dm:
                self.assertEqual(hm.get(str(key)), dm[key])
            self.assertEqual(len(dm), len(hm))

            for i, key in enumerate(keys_to_delete):
                hm = hm.delete(str(key))
                self.assertEqual(hm.get(str(key), 'not found'), 'not found')
                dm.pop(str(key), None)
                self.assertEqual(len(dm), len(hm))

                if not (i % TEST_ITERS_EVERY):
                    self.assertEqual(set(hm.values()), set(dm.values()))
                    self.assertEqual(len(hm.values()), len(dm.values()))

            self.assertEqual(len(dm), 0)
            self.assertEqual(len(hm), 0)
            self.assertEqual(list(hm.items()), [])

    def test_hamt_delete_1(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(102, 'C')
        D = HashKey(103, 'D')
        E = HashKey(104, 'E')
        Z = HashKey(-100, 'Z')

        Er = HashKey(103, 'Er', error_on_eq_to=D)

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=10 bitmap=0b111110000 id=0x10eadc618):
        #     <Key name:A hash:100>: 'a'
        #     <Key name:B hash:101>: 'b'
        #     <Key name:C hash:102>: 'c'
        #     <Key name:D hash:103>: 'd'
        #     <Key name:E hash:104>: 'e'

        h = h.delete(C)
        self.assertEqual(len(h), orig_len - 1)

        # An error raised by __eq__ during delete() propagates.
        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h.delete(Er)

        h = h.delete(D)
        self.assertEqual(len(h), orig_len - 2)

        # Deleting an absent key returns the very same mapping.
        h2 = h.delete(Z)
        self.assertIs(h2, h)

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 3)

        self.assertEqual(h.get(A, 42), 42)
        self.assertEqual(h.get(B), 'b')
        self.assertEqual(h.get(E), 'e')

    def test_hamt_delete_2(self):
        A = HashKey(100, 'A')
        B = HashKey(201001, 'B')
        C = HashKey(101001, 'C')
        D = HashKey(103, 'D')
        E = HashKey(104, 'E')
        Z = HashKey(-100, 'Z')

        Er = HashKey(201001, 'Er', error_on_eq_to=B)

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=8 bitmap=0b1110010000):
        #     <Key name:A hash:100>: 'a'
        #     <Key name:D hash:103>: 'd'
        #     <Key name:E hash:104>: 'e'
        #     NULL:
        #         BitmapNode(size=4 bitmap=0b100000000001000000000):
        #             <Key name:B hash:201001>: 'b'
        #             <Key name:C hash:101001>: 'c'

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h.delete(Er)

        h = h.delete(Z)
        self.assertEqual(len(h), orig_len)

        h = h.delete(C)
        self.assertEqual(len(h), orig_len - 1)

        h = h.delete(B)
        self.assertEqual(len(h), orig_len - 2)

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 3)

        self.assertEqual(h.get(D), 'd')
        self.assertEqual(h.get(E), 'e')

        h = h.delete(A)
        h = h.delete(B)
        h = h.delete(D)
        h = h.delete(E)
        self.assertEqual(len(h), 0)

    def test_hamt_delete_3(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(104, 'E')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=6 bitmap=0b100110000):
        #     NULL:
        #         BitmapNode(size=4 bitmap=0b1000000000000000000001000):
        #             <Key name:A hash:100>: 'a'
        #             NULL:
        #                 CollisionNode(size=4 id=0x108572410):
        #                     <Key name:C hash:100100>: 'c'
        #                     <Key name:D hash:100100>: 'd'
        #     <Key name:B hash:101>: 'b'
        #     <Key name:E hash:104>: 'e'

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 1)

        h = h.delete(E)
        self.assertEqual(len(h), orig_len - 2)

        self.assertEqual(h.get(C), 'c')
        self.assertEqual(h.get(B), 'b')

    def test_hamt_delete_4(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(100100, 'E')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=4 bitmap=0b110000):
        #     NULL:
        #         BitmapNode(size=4 bitmap=0b1000000000000000000001000):
        #             <Key name:A hash:100>: 'a'
        #             NULL:
        #                 CollisionNode(size=6 id=0x10515ef30):
        #                     <Key name:C hash:100100>: 'c'
        #                     <Key name:D hash:100100>: 'd'
        #                     <Key name:E hash:100100>: 'e'
        #     <Key name:B hash:101>: 'b'

        h = h.delete(D)
        self.assertEqual(len(h), orig_len - 1)

        h = h.delete(E)
        self.assertEqual(len(h), orig_len - 2)

        h = h.delete(C)
        self.assertEqual(len(h), orig_len - 3)

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 4)

        h = h.delete(B)
        self.assertEqual(len(h), 0)

    def test_hamt_delete_5(self):
        h = hamt()

        keys = []
        for i in range(17):
            key = HashKey(i, str(i))
            keys.append(key)
            h = h.set(key, f'val-{i}')

        collision_key16 = HashKey(16, '18')
        h = h.set(collision_key16, 'collision')

        # ArrayNode(id=0x10f8b9318):
        #     0::
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #         <Key name:0 hash:0>: 'val-0'
        #
        # ... 14 more BitmapNodes ...
        #
        #     15::
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #         <Key name:15 hash:15>: 'val-15'
        #
        #     16::
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #         NULL:
        #             CollisionNode(size=4 id=0x10f2f5af8):
        #                 <Key name:16 hash:16>: 'val-16'
        #                 <Key name:18 hash:16>: 'collision'

        self.assertEqual(len(h), 18)

        h = h.delete(keys[2])
        self.assertEqual(len(h), 17)

        h = h.delete(collision_key16)
        self.assertEqual(len(h), 16)
        h = h.delete(keys[16])
        self.assertEqual(len(h), 15)

        # Deleting the same key twice: the second delete is a no-op.
        h = h.delete(keys[1])
        self.assertEqual(len(h), 14)
        h = h.delete(keys[1])
        self.assertEqual(len(h), 14)

        for key in keys:
            h = h.delete(key)
        self.assertEqual(len(h), 0)

    def test_hamt_items_1(self):
        A = HashKey(100, 'A')
        B = HashKey(201001, 'B')
        C = HashKey(101001, 'C')
        D = HashKey(103, 'D')
        E = HashKey(104, 'E')
        F = HashKey(110, 'F')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')
        h = h.set(F, 'f')

        it = h.items()
        self.assertEqual(
            set(list(it)),
            {(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')})

    def test_hamt_items_2(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(100100, 'E')
        F = HashKey(110, 'F')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')
        h = h.set(F, 'f')

        it = h.items()
        self.assertEqual(
            set(list(it)),
            {(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')})

    def test_hamt_keys_1(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(100100, 'E')
        F = HashKey(110, 'F')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')
        h = h.set(F, 'f')

        self.assertEqual(set(list(h.keys())), {A, B, C, D, E, F})
        self.assertEqual(set(list(h)), {A, B, C, D, E, F})

    def test_hamt_items_3(self):
        h = hamt()
        self.assertEqual(len(h.items()), 0)
        self.assertEqual(list(h.items()), [])

    def test_hamt_eq_1(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(120, 'E')

        h1 = hamt()
        h1 = h1.set(A, 'a')
        h1 = h1.set(B, 'b')
        h1 = h1.set(C, 'c')
        h1 = h1.set(D, 'd')

        h2 = hamt()
        h2 = h2.set(A, 'a')

        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(B, 'b')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(C, 'c')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        # Same keys but a different value: still unequal.
        h2 = h2.set(D, 'd2')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(D, 'd')
        self.assertTrue(h1 == h2)
        self.assertFalse(h1 != h2)

        h2 = h2.set(E, 'e')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.delete(D)
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        # Same size and same values, but a different key set.
        h2 = h2.set(E, 'd')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

    def test_hamt_eq_2(self):
        A = HashKey(100, 'A')
        Er = HashKey(100, 'Er', error_on_eq_to=A)

        h1 = hamt()
        h1 = h1.set(A, 'a')

        h2 = hamt()
        h2 = h2.set(Er, 'a')

        # Key comparison errors propagate out of mapping comparison.
        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h1 == h2

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h1 != h2

    def test_hamt_gc_1(self):
        A = HashKey(100, 'A')

        h = hamt()
        h = h.set(0, 0)  # empty HAMT node is memoized in hamt.c
        ref = weakref.ref(h)

        # Build a reference cycle that passes through the HAMT; it must be
        # reclaimable by the cyclic garbage collector.
        a = []
        a.append(a)
        a.append(h)
        b = []
        a.append(b)
        b.append(a)
        h = h.set(A, b)

        del h, a, b

        gc.collect()
        gc.collect()
        gc.collect()

        self.assertIsNone(ref())

    def test_hamt_gc_2(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(A, h)

        # A partially consumed iterator must not keep the cycle alive.
        ref = weakref.ref(h)
        hi = h.items()
        next(hi)

        del h, hi

        gc.collect()
        gc.collect()
        gc.collect()

        self.assertIsNone(ref())

    def test_hamt_in_1(self):
        A = HashKey(100, 'A')
        AA = HashKey(100, 'A')

        B = HashKey(101, 'B')

        h = hamt()
        h = h.set(A, 1)

        self.assertTrue(A in h)
        self.assertFalse(B in h)

        with self.assertRaises(EqError):
            with HaskKeyCrasher(error_on_eq=True):
                AA in h

        with self.assertRaises(HashingError):
            with HaskKeyCrasher(error_on_hash=True):
                AA in h

    def test_hamt_getitem_1(self):
        A = HashKey(100, 'A')
        AA = HashKey(100, 'A')

        B = HashKey(101, 'B')

        h = hamt()
        h = h.set(A, 1)

        self.assertEqual(h[A], 1)
        self.assertEqual(h[AA], 1)

        with self.assertRaises(KeyError):
            h[B]

        with self.assertRaises(EqError):
            with HaskKeyCrasher(error_on_eq=True):
                h[AA]

        with self.assertRaises(HashingError):
            with HaskKeyCrasher(error_on_hash=True):
                h[AA]
1100
1101
if __name__ == "__main__":
    # Executed directly (not imported by a test runner): discover and run
    # the TestCase classes defined in this module.
    unittest.main(module="__main__")
1104