1# Test iterators.
2
3import sys
4import unittest
5from test.support import cpython_only
6from test.support.os_helper import TESTFN, unlink
7from test.support import check_free_after_iterating, ALWAYS_EQ, NEVER_EQ
8import pickle
9import collections.abc
10import functools
11import contextlib
12import builtins
13
# Expected output of a triple nested loop over range(3).  It is spelled out
# literally (rather than computed) so the iteration tests below are not
# checked against the very machinery they exercise.
TRIPLETS = [
    # first element 0
    (0, 0, 0), (0, 0, 1), (0, 0, 2),
    (0, 1, 0), (0, 1, 1), (0, 1, 2),
    (0, 2, 0), (0, 2, 1), (0, 2, 2),
    # first element 1
    (1, 0, 0), (1, 0, 1), (1, 0, 2),
    (1, 1, 0), (1, 1, 1), (1, 1, 2),
    (1, 2, 0), (1, 2, 1), (1, 2, 2),
    # first element 2
    (2, 0, 0), (2, 0, 1), (2, 0, 2),
    (2, 1, 0), (2, 1, 1), (2, 1, 2),
    (2, 2, 0), (2, 2, 1), (2, 2, 2),
]
26
27# Helper classes
28
class BasicIterClass:
    """Minimal iterator producing 0, 1, ..., n-1 through __next__."""

    def __init__(self, n):
        self.n = n  # exclusive upper bound
        self.i = 0  # next value to hand out

    def __next__(self):
        current = self.i
        if current < self.n:
            self.i = current + 1
            return current
        raise StopIteration

    def __iter__(self):
        return self
41
class IteratingSequenceClass:
    """Iterable (not itself an iterator): each iter() call returns a fresh
    BasicIterClass over 0..n-1."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        fresh = BasicIterClass(self.n)
        return fresh
47
class IteratorProxyClass:
    """Iterator that forwards every __next__ call to a wrapped iterator."""

    def __init__(self, i):
        self.i = i  # the wrapped iterator

    def __iter__(self):
        return self

    def __next__(self):
        wrapped = self.i
        return next(wrapped)
55
class SequenceClass:
    """Sequence without __iter__: iteration goes through __getitem__, which
    returns the index itself for 0 <= i < n and raises IndexError otherwise."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, i):
        if not 0 <= i < self.n:
            raise IndexError
        return i
64
class SequenceProxyClass:
    """Sequence that delegates item access to a wrapped sequence."""

    def __init__(self, s):
        self.s = s  # the wrapped sequence

    def __getitem__(self, i):
        target = self.s
        return target[i]
70
class UnlimitedSequenceClass:
    """Sequence whose item at every index is the index itself; it never
    raises IndexError, so iterating it never ends on its own."""

    def __getitem__(self, i):
        value = i
        return value
74
class DefaultIterClass:
    """Defines neither __iter__ nor __getitem__, so instances are not iterable."""
77
class NoIterClass:
    """Has __getitem__ but explicitly opts out of iteration."""

    # Setting __iter__ to None makes iter() raise TypeError instead of
    # falling back to the __getitem__ sequence protocol.
    __iter__ = None

    def __getitem__(self, i):
        return i
82
class BadIterableClass:
    """Iterable whose __iter__ always fails, to check error propagation."""

    def __iter__(self):
        raise ZeroDivisionError()
86
class CallableIterClass:
    """Callable returning 0, 1, 2, ... on successive calls, for use with
    two-argument iter(callable, sentinel)."""

    def __init__(self):
        self.i = 0

    def __call__(self):
        value = self.i
        self.i += 1
        # Safety net: if a test's sentinel is never reached, give up after
        # returning 0..100 instead of looping forever.
        if value > 100:
            raise IndexError
        return value
96
class EmptyIterClass:
    """Reports length 0 and ends iteration at once by raising StopIteration
    from __getitem__."""

    def __getitem__(self, i):
        raise StopIteration

    def __len__(self):
        return 0
102
103# Main test suite
104
105class TestCase(unittest.TestCase):
106
107    # Helper to check that an iterator returns a given sequence
108    def check_iterator(self, it, seq, pickle=True):
109        if pickle:
110            self.check_pickle(it, seq)
111        res = []
112        while 1:
113            try:
114                val = next(it)
115            except StopIteration:
116                break
117            res.append(val)
118        self.assertEqual(res, seq)
119
120    # Helper to check that a for loop generates a given sequence
121    def check_for_loop(self, expr, seq, pickle=True):
122        if pickle:
123            self.check_pickle(iter(expr), seq)
124        res = []
125        for val in expr:
126            res.append(val)
127        self.assertEqual(res, seq)
128
129    # Helper to check picklability
130    def check_pickle(self, itorg, seq):
131        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
132            d = pickle.dumps(itorg, proto)
133            it = pickle.loads(d)
134            # Cannot assert type equality because dict iterators unpickle as list
135            # iterators.
136            # self.assertEqual(type(itorg), type(it))
137            self.assertTrue(isinstance(it, collections.abc.Iterator))
138            self.assertEqual(list(it), seq)
139
140            it = pickle.loads(d)
141            try:
142                next(it)
143            except StopIteration:
144                continue
145            d = pickle.dumps(it, proto)
146            it = pickle.loads(d)
147            self.assertEqual(list(it), seq[1:])
148
149    # Test basic use of iter() function
150    def test_iter_basic(self):
151        self.check_iterator(iter(range(10)), list(range(10)))
152
153    # Test that iter(iter(x)) is the same as iter(x)
154    def test_iter_idempotency(self):
155        seq = list(range(10))
156        it = iter(seq)
157        it2 = iter(it)
158        self.assertTrue(it is it2)
159
160    # Test that for loops over iterators work
161    def test_iter_for_loop(self):
162        self.check_for_loop(iter(range(10)), list(range(10)))
163
164    # Test several independent iterators over the same list
165    def test_iter_independence(self):
166        seq = range(3)
167        res = []
168        for i in iter(seq):
169            for j in iter(seq):
170                for k in iter(seq):
171                    res.append((i, j, k))
172        self.assertEqual(res, TRIPLETS)
173
174    # Test triple list comprehension using iterators
175    def test_nested_comprehensions_iter(self):
176        seq = range(3)
177        res = [(i, j, k)
178               for i in iter(seq) for j in iter(seq) for k in iter(seq)]
179        self.assertEqual(res, TRIPLETS)
180
181    # Test triple list comprehension without iterators
182    def test_nested_comprehensions_for(self):
183        seq = range(3)
184        res = [(i, j, k) for i in seq for j in seq for k in seq]
185        self.assertEqual(res, TRIPLETS)
186
187    # Test a class with __iter__ in a for loop
188    def test_iter_class_for(self):
189        self.check_for_loop(IteratingSequenceClass(10), list(range(10)))
190
191    # Test a class with __iter__ with explicit iter()
192    def test_iter_class_iter(self):
193        self.check_iterator(iter(IteratingSequenceClass(10)), list(range(10)))
194
195    # Test for loop on a sequence class without __iter__
196    def test_seq_class_for(self):
197        self.check_for_loop(SequenceClass(10), list(range(10)))
198
199    # Test iter() on a sequence class without __iter__
200    def test_seq_class_iter(self):
201        self.check_iterator(iter(SequenceClass(10)), list(range(10)))
202
203    def test_mutating_seq_class_iter_pickle(self):
204        orig = SequenceClass(5)
205        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
206            # initial iterator
207            itorig = iter(orig)
208            d = pickle.dumps((itorig, orig), proto)
209            it, seq = pickle.loads(d)
210            seq.n = 7
211            self.assertIs(type(it), type(itorig))
212            self.assertEqual(list(it), list(range(7)))
213
214            # running iterator
215            next(itorig)
216            d = pickle.dumps((itorig, orig), proto)
217            it, seq = pickle.loads(d)
218            seq.n = 7
219            self.assertIs(type(it), type(itorig))
220            self.assertEqual(list(it), list(range(1, 7)))
221
222            # empty iterator
223            for i in range(1, 5):
224                next(itorig)
225            d = pickle.dumps((itorig, orig), proto)
226            it, seq = pickle.loads(d)
227            seq.n = 7
228            self.assertIs(type(it), type(itorig))
229            self.assertEqual(list(it), list(range(5, 7)))
230
231            # exhausted iterator
232            self.assertRaises(StopIteration, next, itorig)
233            d = pickle.dumps((itorig, orig), proto)
234            it, seq = pickle.loads(d)
235            seq.n = 7
236            self.assertTrue(isinstance(it, collections.abc.Iterator))
237            self.assertEqual(list(it), [])
238
239    def test_mutating_seq_class_exhausted_iter(self):
240        a = SequenceClass(5)
241        exhit = iter(a)
242        empit = iter(a)
243        for x in exhit:  # exhaust the iterator
244            next(empit)  # not exhausted
245        a.n = 7
246        self.assertEqual(list(exhit), [])
247        self.assertEqual(list(empit), [5, 6])
248        self.assertEqual(list(a), [0, 1, 2, 3, 4, 5, 6])
249
250    def test_reduce_mutating_builtins_iter(self):
251        # This is a reproducer of issue #101765
252        # where iter `__reduce__` calls could lead to a segfault or SystemError
253        # depending on the order of C argument evaluation, which is undefined
254
255        # Backup builtins
256        builtins_dict = builtins.__dict__
257        orig = {"iter": iter, "reversed": reversed}
258
259        def run(builtin_name, item, sentinel=None):
260            it = iter(item) if sentinel is None else iter(item, sentinel)
261
262            class CustomStr:
263                def __init__(self, name, iterator):
264                    self.name = name
265                    self.iterator = iterator
266                def __hash__(self):
267                    return hash(self.name)
268                def __eq__(self, other):
269                    # Here we exhaust our iterator, possibly changing
270                    # its `it_seq` pointer to NULL
271                    # The `__reduce__` call should correctly get
272                    # the pointers after this call
273                    list(self.iterator)
274                    return other == self.name
275
276            # del is required here
277            # to not prematurely call __eq__ from
278            # the hash collision with the old key
279            del builtins_dict[builtin_name]
280            builtins_dict[CustomStr(builtin_name, it)] = orig[builtin_name]
281
282            return it.__reduce__()
283
284        types = [
285            (EmptyIterClass(),),
286            (bytes(8),),
287            (bytearray(8),),
288            ((1, 2, 3),),
289            (lambda: 0, 0),
290            (tuple[int],)  # GenericAlias
291        ]
292
293        try:
294            run_iter = functools.partial(run, "iter")
295            # The returned value of `__reduce__` should not only be valid
296            # but also *empty*, as `it` was exhausted during `__eq__`
297            # i.e "xyz" returns (iter, ("",))
298            self.assertEqual(run_iter("xyz"), (orig["iter"], ("",)))
299            self.assertEqual(run_iter([1, 2, 3]), (orig["iter"], ([],)))
300
301            # _PyEval_GetBuiltin is also called for `reversed` in a branch of
302            # listiter_reduce_general
303            self.assertEqual(
304                run("reversed", orig["reversed"](list(range(8)))),
305                (iter, ([],))
306            )
307
308            for case in types:
309                self.assertEqual(run_iter(*case), (orig["iter"], ((),)))
310        finally:
311            # Restore original builtins
312            for key, func in orig.items():
313                # need to suppress KeyErrors in case
314                # a failed test deletes the key without setting anything
315                with contextlib.suppress(KeyError):
316                    # del is required here
317                    # to not invoke our custom __eq__ from
318                    # the hash collision with the old key
319                    del builtins_dict[key]
320                builtins_dict[key] = func
321
    # Test a new-style class with __iter__ but no __next__() method
323    def test_new_style_iter_class(self):
324        class IterClass(object):
325            def __iter__(self):
326                return self
327        self.assertRaises(TypeError, iter, IterClass())
328
329    # Test two-argument iter() with callable instance
330    def test_iter_callable(self):
331        self.check_iterator(iter(CallableIterClass(), 10), list(range(10)), pickle=True)
332
333    # Test two-argument iter() with function
334    def test_iter_function(self):
335        def spam(state=[0]):
336            i = state[0]
337            state[0] = i+1
338            return i
339        self.check_iterator(iter(spam, 10), list(range(10)), pickle=False)
340
341    # Test two-argument iter() with function that raises StopIteration
342    def test_iter_function_stop(self):
343        def spam(state=[0]):
344            i = state[0]
345            if i == 10:
346                raise StopIteration
347            state[0] = i+1
348            return i
349        self.check_iterator(iter(spam, 20), list(range(10)), pickle=False)
350
351    def test_iter_function_concealing_reentrant_exhaustion(self):
352        # gh-101892: Test two-argument iter() with a function that
353        # exhausts its associated iterator but forgets to either return
354        # a sentinel value or raise StopIteration.
355        HAS_MORE = 1
356        NO_MORE = 2
357
358        def exhaust(iterator):
359            """Exhaust an iterator without raising StopIteration."""
360            list(iterator)
361
362        def spam():
363            # Touching the iterator with exhaust() below will call
364            # spam() once again so protect against recursion.
365            if spam.is_recursive_call:
366                return NO_MORE
367            spam.is_recursive_call = True
368            exhaust(spam.iterator)
369            return HAS_MORE
370
371        spam.is_recursive_call = False
372        spam.iterator = iter(spam, NO_MORE)
373        with self.assertRaises(StopIteration):
374            next(spam.iterator)
375
376    # Test exception propagation through function iterator
377    def test_exception_function(self):
378        def spam(state=[0]):
379            i = state[0]
380            state[0] = i+1
381            if i == 10:
382                raise RuntimeError
383            return i
384        res = []
385        try:
386            for x in iter(spam, 20):
387                res.append(x)
388        except RuntimeError:
389            self.assertEqual(res, list(range(10)))
390        else:
391            self.fail("should have raised RuntimeError")
392
393    # Test exception propagation through sequence iterator
394    def test_exception_sequence(self):
395        class MySequenceClass(SequenceClass):
396            def __getitem__(self, i):
397                if i == 10:
398                    raise RuntimeError
399                return SequenceClass.__getitem__(self, i)
400        res = []
401        try:
402            for x in MySequenceClass(20):
403                res.append(x)
404        except RuntimeError:
405            self.assertEqual(res, list(range(10)))
406        else:
407            self.fail("should have raised RuntimeError")
408
409    # Test for StopIteration from __getitem__
410    def test_stop_sequence(self):
411        class MySequenceClass(SequenceClass):
412            def __getitem__(self, i):
413                if i == 10:
414                    raise StopIteration
415                return SequenceClass.__getitem__(self, i)
416        self.check_for_loop(MySequenceClass(20), list(range(10)), pickle=False)
417
418    # Test a big range
419    def test_iter_big_range(self):
420        self.check_for_loop(iter(range(10000)), list(range(10000)))
421
422    # Test an empty list
423    def test_iter_empty(self):
424        self.check_for_loop(iter([]), [])
425
426    # Test a tuple
427    def test_iter_tuple(self):
428        self.check_for_loop(iter((0,1,2,3,4,5,6,7,8,9)), list(range(10)))
429
430    # Test a range
431    def test_iter_range(self):
432        self.check_for_loop(iter(range(10)), list(range(10)))
433
434    # Test a string
435    def test_iter_string(self):
436        self.check_for_loop(iter("abcde"), ["a", "b", "c", "d", "e"])
437
    # Test a dictionary
439    def test_iter_dict(self):
440        dict = {}
441        for i in range(10):
442            dict[i] = None
443        self.check_for_loop(dict, list(dict.keys()))
444
445    # Test a file
446    def test_iter_file(self):
447        f = open(TESTFN, "w", encoding="utf-8")
448        try:
449            for i in range(5):
450                f.write("%d\n" % i)
451        finally:
452            f.close()
453        f = open(TESTFN, "r", encoding="utf-8")
454        try:
455            self.check_for_loop(f, ["0\n", "1\n", "2\n", "3\n", "4\n"], pickle=False)
456            self.check_for_loop(f, [], pickle=False)
457        finally:
458            f.close()
459            try:
460                unlink(TESTFN)
461            except OSError:
462                pass
463
464    # Test list()'s use of iterators.
465    def test_builtin_list(self):
466        self.assertEqual(list(SequenceClass(5)), list(range(5)))
467        self.assertEqual(list(SequenceClass(0)), [])
468        self.assertEqual(list(()), [])
469
470        d = {"one": 1, "two": 2, "three": 3}
471        self.assertEqual(list(d), list(d.keys()))
472
473        self.assertRaises(TypeError, list, list)
474        self.assertRaises(TypeError, list, 42)
475
476        f = open(TESTFN, "w", encoding="utf-8")
477        try:
478            for i in range(5):
479                f.write("%d\n" % i)
480        finally:
481            f.close()
482        f = open(TESTFN, "r", encoding="utf-8")
483        try:
484            self.assertEqual(list(f), ["0\n", "1\n", "2\n", "3\n", "4\n"])
485            f.seek(0, 0)
486            self.assertEqual(list(f),
487                             ["0\n", "1\n", "2\n", "3\n", "4\n"])
488        finally:
489            f.close()
490            try:
491                unlink(TESTFN)
492            except OSError:
493                pass
494
    # Test tuple()'s use of iterators.
496    def test_builtin_tuple(self):
497        self.assertEqual(tuple(SequenceClass(5)), (0, 1, 2, 3, 4))
498        self.assertEqual(tuple(SequenceClass(0)), ())
499        self.assertEqual(tuple([]), ())
500        self.assertEqual(tuple(()), ())
501        self.assertEqual(tuple("abc"), ("a", "b", "c"))
502
503        d = {"one": 1, "two": 2, "three": 3}
504        self.assertEqual(tuple(d), tuple(d.keys()))
505
506        self.assertRaises(TypeError, tuple, list)
507        self.assertRaises(TypeError, tuple, 42)
508
509        f = open(TESTFN, "w", encoding="utf-8")
510        try:
511            for i in range(5):
512                f.write("%d\n" % i)
513        finally:
514            f.close()
515        f = open(TESTFN, "r", encoding="utf-8")
516        try:
517            self.assertEqual(tuple(f), ("0\n", "1\n", "2\n", "3\n", "4\n"))
518            f.seek(0, 0)
519            self.assertEqual(tuple(f),
520                             ("0\n", "1\n", "2\n", "3\n", "4\n"))
521        finally:
522            f.close()
523            try:
524                unlink(TESTFN)
525            except OSError:
526                pass
527
528    # Test filter()'s use of iterators.
529    def test_builtin_filter(self):
530        self.assertEqual(list(filter(None, SequenceClass(5))),
531                         list(range(1, 5)))
532        self.assertEqual(list(filter(None, SequenceClass(0))), [])
533        self.assertEqual(list(filter(None, ())), [])
534        self.assertEqual(list(filter(None, "abc")), ["a", "b", "c"])
535
536        d = {"one": 1, "two": 2, "three": 3}
537        self.assertEqual(list(filter(None, d)), list(d.keys()))
538
539        self.assertRaises(TypeError, filter, None, list)
540        self.assertRaises(TypeError, filter, None, 42)
541
542        class Boolean:
543            def __init__(self, truth):
544                self.truth = truth
545            def __bool__(self):
546                return self.truth
547        bTrue = Boolean(True)
548        bFalse = Boolean(False)
549
550        class Seq:
551            def __init__(self, *args):
552                self.vals = args
553            def __iter__(self):
554                class SeqIter:
555                    def __init__(self, vals):
556                        self.vals = vals
557                        self.i = 0
558                    def __iter__(self):
559                        return self
560                    def __next__(self):
561                        i = self.i
562                        self.i = i + 1
563                        if i < len(self.vals):
564                            return self.vals[i]
565                        else:
566                            raise StopIteration
567                return SeqIter(self.vals)
568
569        seq = Seq(*([bTrue, bFalse] * 25))
570        self.assertEqual(list(filter(lambda x: not x, seq)), [bFalse]*25)
571        self.assertEqual(list(filter(lambda x: not x, iter(seq))), [bFalse]*25)
572
573    # Test max() and min()'s use of iterators.
574    def test_builtin_max_min(self):
575        self.assertEqual(max(SequenceClass(5)), 4)
576        self.assertEqual(min(SequenceClass(5)), 0)
577        self.assertEqual(max(8, -1), 8)
578        self.assertEqual(min(8, -1), -1)
579
580        d = {"one": 1, "two": 2, "three": 3}
581        self.assertEqual(max(d), "two")
582        self.assertEqual(min(d), "one")
583        self.assertEqual(max(d.values()), 3)
584        self.assertEqual(min(iter(d.values())), 1)
585
586        f = open(TESTFN, "w", encoding="utf-8")
587        try:
588            f.write("medium line\n")
589            f.write("xtra large line\n")
590            f.write("itty-bitty line\n")
591        finally:
592            f.close()
593        f = open(TESTFN, "r", encoding="utf-8")
594        try:
595            self.assertEqual(min(f), "itty-bitty line\n")
596            f.seek(0, 0)
597            self.assertEqual(max(f), "xtra large line\n")
598        finally:
599            f.close()
600            try:
601                unlink(TESTFN)
602            except OSError:
603                pass
604
605    # Test map()'s use of iterators.
606    def test_builtin_map(self):
607        self.assertEqual(list(map(lambda x: x+1, SequenceClass(5))),
608                         list(range(1, 6)))
609
610        d = {"one": 1, "two": 2, "three": 3}
611        self.assertEqual(list(map(lambda k, d=d: (k, d[k]), d)),
612                         list(d.items()))
613        dkeys = list(d.keys())
614        expected = [(i < len(d) and dkeys[i] or None,
615                     i,
616                     i < len(d) and dkeys[i] or None)
617                    for i in range(3)]
618
619        f = open(TESTFN, "w", encoding="utf-8")
620        try:
621            for i in range(10):
622                f.write("xy" * i + "\n") # line i has len 2*i+1
623        finally:
624            f.close()
625        f = open(TESTFN, "r", encoding="utf-8")
626        try:
627            self.assertEqual(list(map(len, f)), list(range(1, 21, 2)))
628        finally:
629            f.close()
630            try:
631                unlink(TESTFN)
632            except OSError:
633                pass
634
635    # Test zip()'s use of iterators.
636    def test_builtin_zip(self):
637        self.assertEqual(list(zip()), [])
638        self.assertEqual(list(zip(*[])), [])
639        self.assertEqual(list(zip(*[(1, 2), 'ab'])), [(1, 'a'), (2, 'b')])
640
641        self.assertRaises(TypeError, zip, None)
642        self.assertRaises(TypeError, zip, range(10), 42)
643        self.assertRaises(TypeError, zip, range(10), zip)
644
645        self.assertEqual(list(zip(IteratingSequenceClass(3))),
646                         [(0,), (1,), (2,)])
647        self.assertEqual(list(zip(SequenceClass(3))),
648                         [(0,), (1,), (2,)])
649
650        d = {"one": 1, "two": 2, "three": 3}
651        self.assertEqual(list(d.items()), list(zip(d, d.values())))
652
653        # Generate all ints starting at constructor arg.
654        class IntsFrom:
655            def __init__(self, start):
656                self.i = start
657
658            def __iter__(self):
659                return self
660
661            def __next__(self):
662                i = self.i
663                self.i = i+1
664                return i
665
666        f = open(TESTFN, "w", encoding="utf-8")
667        try:
668            f.write("a\n" "bbb\n" "cc\n")
669        finally:
670            f.close()
671        f = open(TESTFN, "r", encoding="utf-8")
672        try:
673            self.assertEqual(list(zip(IntsFrom(0), f, IntsFrom(-100))),
674                             [(0, "a\n", -100),
675                              (1, "bbb\n", -99),
676                              (2, "cc\n", -98)])
677        finally:
678            f.close()
679            try:
680                unlink(TESTFN)
681            except OSError:
682                pass
683
684        self.assertEqual(list(zip(range(5))), [(i,) for i in range(5)])
685
686        # Classes that lie about their lengths.
687        class NoGuessLen5:
688            def __getitem__(self, i):
689                if i >= 5:
690                    raise IndexError
691                return i
692
693        class Guess3Len5(NoGuessLen5):
694            def __len__(self):
695                return 3
696
697        class Guess30Len5(NoGuessLen5):
698            def __len__(self):
699                return 30
700
701        def lzip(*args):
702            return list(zip(*args))
703
704        self.assertEqual(len(Guess3Len5()), 3)
705        self.assertEqual(len(Guess30Len5()), 30)
706        self.assertEqual(lzip(NoGuessLen5()), lzip(range(5)))
707        self.assertEqual(lzip(Guess3Len5()), lzip(range(5)))
708        self.assertEqual(lzip(Guess30Len5()), lzip(range(5)))
709
710        expected = [(i, i) for i in range(5)]
711        for x in NoGuessLen5(), Guess3Len5(), Guess30Len5():
712            for y in NoGuessLen5(), Guess3Len5(), Guess30Len5():
713                self.assertEqual(lzip(x, y), expected)
714
715    def test_unicode_join_endcase(self):
716
717        # This class inserts a Unicode object into its argument's natural
718        # iteration, in the 3rd position.
719        class OhPhooey:
720            def __init__(self, seq):
721                self.it = iter(seq)
722                self.i = 0
723
724            def __iter__(self):
725                return self
726
727            def __next__(self):
728                i = self.i
729                self.i = i+1
730                if i == 2:
731                    return "fooled you!"
732                return next(self.it)
733
734        f = open(TESTFN, "w", encoding="utf-8")
735        try:
736            f.write("a\n" + "b\n" + "c\n")
737        finally:
738            f.close()
739
740        f = open(TESTFN, "r", encoding="utf-8")
741        # Nasty:  string.join(s) can't know whether unicode.join() is needed
742        # until it's seen all of s's elements.  But in this case, f's
743        # iterator cannot be restarted.  So what we're testing here is
744        # whether string.join() can manage to remember everything it's seen
745        # and pass that on to unicode.join().
746        try:
747            got = " - ".join(OhPhooey(f))
748            self.assertEqual(got, "a\n - b\n - fooled you! - c\n")
749        finally:
750            f.close()
751            try:
752                unlink(TESTFN)
753            except OSError:
754                pass
755
756    # Test iterators with 'x in y' and 'x not in y'.
757    def test_in_and_not_in(self):
758        for sc5 in IteratingSequenceClass(5), SequenceClass(5):
759            for i in range(5):
760                self.assertIn(i, sc5)
761            for i in "abc", -1, 5, 42.42, (3, 4), [], {1: 1}, 3-12j, sc5:
762                self.assertNotIn(i, sc5)
763
764        self.assertIn(ALWAYS_EQ, IteratorProxyClass(iter([1])))
765        self.assertIn(ALWAYS_EQ, SequenceProxyClass([1]))
766        self.assertNotIn(ALWAYS_EQ, IteratorProxyClass(iter([NEVER_EQ])))
767        self.assertNotIn(ALWAYS_EQ, SequenceProxyClass([NEVER_EQ]))
768        self.assertIn(NEVER_EQ, IteratorProxyClass(iter([ALWAYS_EQ])))
769        self.assertIn(NEVER_EQ, SequenceProxyClass([ALWAYS_EQ]))
770
771        self.assertRaises(TypeError, lambda: 3 in 12)
772        self.assertRaises(TypeError, lambda: 3 not in map)
773        self.assertRaises(ZeroDivisionError, lambda: 3 in BadIterableClass())
774
775        d = {"one": 1, "two": 2, "three": 3, 1j: 2j}
776        for k in d:
777            self.assertIn(k, d)
778            self.assertNotIn(k, d.values())
779        for v in d.values():
780            self.assertIn(v, d.values())
781            self.assertNotIn(v, d)
782        for k, v in d.items():
783            self.assertIn((k, v), d.items())
784            self.assertNotIn((v, k), d.items())
785
786        f = open(TESTFN, "w", encoding="utf-8")
787        try:
788            f.write("a\n" "b\n" "c\n")
789        finally:
790            f.close()
791        f = open(TESTFN, "r", encoding="utf-8")
792        try:
793            for chunk in "abc":
794                f.seek(0, 0)
795                self.assertNotIn(chunk, f)
796                f.seek(0, 0)
797                self.assertIn((chunk + "\n"), f)
798        finally:
799            f.close()
800            try:
801                unlink(TESTFN)
802            except OSError:
803                pass
804
805    # Test iterators with operator.countOf (PySequence_Count).
806    def test_countOf(self):
807        from operator import countOf
808        self.assertEqual(countOf([1,2,2,3,2,5], 2), 3)
809        self.assertEqual(countOf((1,2,2,3,2,5), 2), 3)
810        self.assertEqual(countOf("122325", "2"), 3)
811        self.assertEqual(countOf("122325", "6"), 0)
812
813        self.assertRaises(TypeError, countOf, 42, 1)
814        self.assertRaises(TypeError, countOf, countOf, countOf)
815
816        d = {"one": 3, "two": 3, "three": 3, 1j: 2j}
817        for k in d:
818            self.assertEqual(countOf(d, k), 1)
819        self.assertEqual(countOf(d.values(), 3), 3)
820        self.assertEqual(countOf(d.values(), 2j), 1)
821        self.assertEqual(countOf(d.values(), 1j), 0)
822
823        f = open(TESTFN, "w", encoding="utf-8")
824        try:
825            f.write("a\n" "b\n" "c\n" "b\n")
826        finally:
827            f.close()
828        f = open(TESTFN, "r", encoding="utf-8")
829        try:
830            for letter, count in ("a", 1), ("b", 2), ("c", 1), ("d", 0):
831                f.seek(0, 0)
832                self.assertEqual(countOf(f, letter + "\n"), count)
833        finally:
834            f.close()
835            try:
836                unlink(TESTFN)
837            except OSError:
838                pass
839
840    # Test iterators with operator.indexOf (PySequence_Index).
841    def test_indexOf(self):
842        from operator import indexOf
843        self.assertEqual(indexOf([1,2,2,3,2,5], 1), 0)
844        self.assertEqual(indexOf((1,2,2,3,2,5), 2), 1)
845        self.assertEqual(indexOf((1,2,2,3,2,5), 3), 3)
846        self.assertEqual(indexOf((1,2,2,3,2,5), 5), 5)
847        self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 0)
848        self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 6)
849
850        self.assertEqual(indexOf("122325", "2"), 1)
851        self.assertEqual(indexOf("122325", "5"), 5)
852        self.assertRaises(ValueError, indexOf, "122325", "6")
853
854        self.assertRaises(TypeError, indexOf, 42, 1)
855        self.assertRaises(TypeError, indexOf, indexOf, indexOf)
856        self.assertRaises(ZeroDivisionError, indexOf, BadIterableClass(), 1)
857
858        f = open(TESTFN, "w", encoding="utf-8")
859        try:
860            f.write("a\n" "b\n" "c\n" "d\n" "e\n")
861        finally:
862            f.close()
863        f = open(TESTFN, "r", encoding="utf-8")
864        try:
865            fiter = iter(f)
866            self.assertEqual(indexOf(fiter, "b\n"), 1)
867            self.assertEqual(indexOf(fiter, "d\n"), 1)
868            self.assertEqual(indexOf(fiter, "e\n"), 0)
869            self.assertRaises(ValueError, indexOf, fiter, "a\n")
870        finally:
871            f.close()
872            try:
873                unlink(TESTFN)
874            except OSError:
875                pass
876
877        iclass = IteratingSequenceClass(3)
878        for i in range(3):
879            self.assertEqual(indexOf(iclass, i), i)
880        self.assertRaises(ValueError, indexOf, iclass, -1)
881
882    # Test iterators with file.writelines().
883    def test_writelines(self):
884        f = open(TESTFN, "w", encoding="utf-8")
885
886        try:
887            self.assertRaises(TypeError, f.writelines, None)
888            self.assertRaises(TypeError, f.writelines, 42)
889
890            f.writelines(["1\n", "2\n"])
891            f.writelines(("3\n", "4\n"))
892            f.writelines({'5\n': None})
893            f.writelines({})
894
895            # Try a big chunk too.
896            class Iterator:
897                def __init__(self, start, finish):
898                    self.start = start
899                    self.finish = finish
900                    self.i = self.start
901
902                def __next__(self):
903                    if self.i >= self.finish:
904                        raise StopIteration
905                    result = str(self.i) + '\n'
906                    self.i += 1
907                    return result
908
909                def __iter__(self):
910                    return self
911
912            class Whatever:
913                def __init__(self, start, finish):
914                    self.start = start
915                    self.finish = finish
916
917                def __iter__(self):
918                    return Iterator(self.start, self.finish)
919
920            f.writelines(Whatever(6, 6+2000))
921            f.close()
922
923            f = open(TESTFN, encoding="utf-8")
924            expected = [str(i) + "\n" for i in range(1, 2006)]
925            self.assertEqual(list(f), expected)
926
927        finally:
928            f.close()
929            try:
930                unlink(TESTFN)
931            except OSError:
932                pass
933
934
935    # Test iterators on RHS of unpacking assignments.
936    def test_unpack_iter(self):
937        a, b = 1, 2
938        self.assertEqual((a, b), (1, 2))
939
940        a, b, c = IteratingSequenceClass(3)
941        self.assertEqual((a, b, c), (0, 1, 2))
942
943        try:    # too many values
944            a, b = IteratingSequenceClass(3)
945        except ValueError:
946            pass
947        else:
948            self.fail("should have raised ValueError")
949
950        try:    # not enough values
951            a, b, c = IteratingSequenceClass(2)
952        except ValueError:
953            pass
954        else:
955            self.fail("should have raised ValueError")
956
957        try:    # not iterable
958            a, b, c = len
959        except TypeError:
960            pass
961        else:
962            self.fail("should have raised TypeError")
963
964        a, b, c = {1: 42, 2: 42, 3: 42}.values()
965        self.assertEqual((a, b, c), (42, 42, 42))
966
967        f = open(TESTFN, "w", encoding="utf-8")
968        lines = ("a\n", "bb\n", "ccc\n")
969        try:
970            for line in lines:
971                f.write(line)
972        finally:
973            f.close()
974        f = open(TESTFN, "r", encoding="utf-8")
975        try:
976            a, b, c = f
977            self.assertEqual((a, b, c), lines)
978        finally:
979            f.close()
980            try:
981                unlink(TESTFN)
982            except OSError:
983                pass
984
985        (a, b), (c,) = IteratingSequenceClass(2), {42: 24}
986        self.assertEqual((a, b, c), (0, 1, 42))
987
988
    @cpython_only
    def test_ref_counting_behavior(self):
        # A failed unpacking of an iterator must not leave references to the
        # items behind.  C.count tracks live instances: __new__ increments it
        # and __del__ decrements it.  The assertions rely on objects being
        # reclaimed as soon as their last reference goes away (presumably why
        # this is restricted to CPython's reference counting).
        class C(object):
            count = 0
            def __new__(cls):
                cls.count += 1
                return object.__new__(cls)
            def __del__(self):
                cls = self.__class__
                assert cls.count > 0
                cls.count -= 1
        # Sanity check of the counting mechanism itself.
        x = C()
        self.assertEqual(C.count, 1)
        del x
        self.assertEqual(C.count, 0)
        l = [C(), C(), C()]
        self.assertEqual(C.count, 3)
        try:
            # Three items into two targets: raises ValueError (too many
            # values to unpack), so neither a nor b is bound.
            a, b = iter(l)
        except ValueError:
            pass
        # Dropping the list must free every instance; a nonzero count means
        # the failed unpacking leaked a reference.
        del l
        self.assertEqual(C.count, 0)
1012
1013
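    # Make sure StopIteration is a "sink state": once an iterator has raised
    # StopIteration, every later next() call must raise it again, even if the
    # underlying container changes.
    # This tests various things that weren't sink states in Python 2.2.1,
    # plus various things that always were fine.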
1017
1018    def test_sinkstate_list(self):
1019        # This used to fail
1020        a = list(range(5))
1021        b = iter(a)
1022        self.assertEqual(list(b), list(range(5)))
1023        a.extend(range(5, 10))
1024        self.assertEqual(list(b), [])
1025
1026    def test_sinkstate_tuple(self):
1027        a = (0, 1, 2, 3, 4)
1028        b = iter(a)
1029        self.assertEqual(list(b), list(range(5)))
1030        self.assertEqual(list(b), [])
1031
1032    def test_sinkstate_string(self):
1033        a = "abcde"
1034        b = iter(a)
1035        self.assertEqual(list(b), ['a', 'b', 'c', 'd', 'e'])
1036        self.assertEqual(list(b), [])
1037
1038    def test_sinkstate_sequence(self):
1039        # This used to fail
1040        a = SequenceClass(5)
1041        b = iter(a)
1042        self.assertEqual(list(b), list(range(5)))
1043        a.n = 10
1044        self.assertEqual(list(b), [])
1045
1046    def test_sinkstate_callable(self):
1047        # This used to fail
1048        def spam(state=[0]):
1049            i = state[0]
1050            state[0] = i+1
1051            if i == 10:
1052                raise AssertionError("shouldn't have gotten this far")
1053            return i
1054        b = iter(spam, 5)
1055        self.assertEqual(list(b), list(range(5)))
1056        self.assertEqual(list(b), [])
1057
1058    def test_sinkstate_dict(self):
1059        # XXX For a more thorough test, see towards the end of:
1060        # http://mail.python.org/pipermail/python-dev/2002-July/026512.html
1061        a = {1:1, 2:2, 0:0, 4:4, 3:3}
1062        for b in iter(a), a.keys(), a.items(), a.values():
1063            b = iter(a)
1064            self.assertEqual(len(list(b)), 5)
1065            self.assertEqual(list(b), [])
1066
1067    def test_sinkstate_yield(self):
1068        def gen():
1069            for i in range(5):
1070                yield i
1071        b = gen()
1072        self.assertEqual(list(b), list(range(5)))
1073        self.assertEqual(list(b), [])
1074
1075    def test_sinkstate_range(self):
1076        a = range(5)
1077        b = iter(a)
1078        self.assertEqual(list(b), list(range(5)))
1079        self.assertEqual(list(b), [])
1080
1081    def test_sinkstate_enumerate(self):
1082        a = range(5)
1083        e = enumerate(a)
1084        b = iter(e)
1085        self.assertEqual(list(b), list(zip(range(5), range(5))))
1086        self.assertEqual(list(b), [])
1087
1088    def test_3720(self):
1089        # Avoid a crash, when an iterator deletes its next() method.
1090        class BadIterator(object):
1091            def __iter__(self):
1092                return self
1093            def __next__(self):
1094                del BadIterator.__next__
1095                return 1
1096
1097        try:
1098            for i in BadIterator() :
1099                pass
1100        except TypeError:
1101            pass
1102
1103    def test_extending_list_with_iterator_does_not_segfault(self):
1104        # The code to extend a list with an iterator has a fair
1105        # amount of nontrivial logic in terms of guessing how
1106        # much memory to allocate in advance, "stealing" refs,
1107        # and then shrinking at the end.  This is a basic smoke
1108        # test for that scenario.
1109        def gen():
1110            for i in range(500):
1111                yield i
1112        lst = [0] * 500
1113        for i in range(240):
1114            lst.pop(0)
1115        lst.extend(gen())
1116        self.assertEqual(len(lst), 760)
1117
1118    @cpython_only
1119    def test_iter_overflow(self):
1120        # Test for the issue 22939
1121        it = iter(UnlimitedSequenceClass())
1122        # Manually set `it_index` to PY_SSIZE_T_MAX-2 without a loop
1123        it.__setstate__(sys.maxsize - 2)
1124        self.assertEqual(next(it), sys.maxsize - 2)
1125        self.assertEqual(next(it), sys.maxsize - 1)
1126        with self.assertRaises(OverflowError):
1127            next(it)
1128        # Check that Overflow error is always raised
1129        with self.assertRaises(OverflowError):
1130            next(it)
1131
1132    def test_iter_neg_setstate(self):
1133        it = iter(UnlimitedSequenceClass())
1134        it.__setstate__(-42)
1135        self.assertEqual(next(it), 0)
1136        self.assertEqual(next(it), 1)
1137
1138    def test_free_after_iterating(self):
1139        check_free_after_iterating(self, iter, SequenceClass, (0,))
1140
1141    def test_error_iter(self):
1142        for typ in (DefaultIterClass, NoIterClass):
1143            self.assertRaises(TypeError, iter, typ())
1144        self.assertRaises(ZeroDivisionError, iter, BadIterableClass())
1145
1146
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
1149