1import unittest
2import operator
3import sys
4import pickle
5import gc
6
7from test import support
8
class G:
    """Sequence using __getitem__.

    Only the legacy sequence protocol is provided, so iteration falls back to
    calling __getitem__ with 0, 1, 2, ... until the wrapped sequence raises
    IndexError.
    """
    def __init__(self, seqn):
        self.seqn = seqn

    def __getitem__(self, i):
        item = self.seqn[i]
        return item
15
class I:
    """Sequence using iterator protocol.

    The instance is its own iterator: __next__ walks ``seqn`` by position and
    raises StopIteration once the position reaches len(seqn).
    """
    def __init__(self, seqn):
        self.seqn, self.i = seqn, 0

    def __iter__(self):
        return self

    def __next__(self):
        pos = self.i
        if pos < len(self.seqn):
            # Fetch before advancing so a failing lookup leaves the position as-is.
            value = self.seqn[pos]
            self.i = pos + 1
            return value
        raise StopIteration
28
class Ig:
    """Sequence using iterator protocol defined with a generator.

    Each call to __iter__ returns a fresh generator over ``seqn``. The ``i``
    attribute is initialised but never read.
    """
    def __init__(self, seqn):
        self.seqn, self.i = seqn, 0

    def __iter__(self):
        # Generator function: iteration over self.seqn starts lazily on the
        # first next() of the returned generator.
        for item in self.seqn:
            yield item
37
class X:
    """Missing __getitem__ and __iter__.

    Although __next__ is defined, the absence of both __iter__ and
    __getitem__ means iter() rejects instances with TypeError.
    """
    def __init__(self, seqn):
        self.seqn, self.i = seqn, 0

    def __next__(self):
        pos = self.i
        if pos < len(self.seqn):
            value = self.seqn[pos]
            self.i = pos + 1
            return value
        raise StopIteration
48
class E:
    """Test propagation of exceptions.

    The instance is its own iterator, and the very first __next__ call fails
    with ZeroDivisionError.
    """
    def __init__(self, seqn):
        self.seqn, self.i = seqn, 0

    def __iter__(self):
        return self

    def __next__(self):
        return 3 // 0
58
class N:
    """Iterator missing __next__().

    __iter__ hands back the instance itself. Because no __next__ exists,
    iter() refuses the returned object with TypeError.
    """
    def __init__(self, seqn):
        self.seqn, self.i = seqn, 0

    def __iter__(self):
        return self
66
class PickleTest:
    """Mixin that checks an iterator survives pickling.

    Intended to be combined with unittest.TestCase, which supplies the
    assertEqual/assertFalse methods used here.
    """

    def check_pickle(self, itorg, seq):
        """Assert *itorg* pickles at every protocol and yields *seq*.

        Two things are verified per protocol. First, a fresh unpickled copy
        has the same type and produces all of *seq*. Second, a copy advanced
        by one item and re-pickled produces the rest, seq[1:].
        """
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            payload = pickle.dumps(itorg, proto)

            fresh = pickle.loads(payload)
            self.assertEqual(type(itorg), type(fresh))
            self.assertEqual(list(fresh), seq)

            partial = pickle.loads(payload)
            try:
                next(partial)
            except StopIteration:
                # Nothing to consume: the expected sequence must be at most
                # one item long.
                self.assertFalse(seq[1:])
                continue
            resumed = pickle.loads(pickle.dumps(partial, proto))
            self.assertEqual(list(resumed), seq[1:])
85
class EnumerateTestCase(unittest.TestCase, PickleTest):
    """Tests for enumerate().

    Subclasses reuse the whole suite by overriding ``enum`` (the callable
    under test), ``seq`` (the input) and ``res`` (the list of pairs that
    list(enum(seq)) is expected to produce).
    """

    enum = enumerate
    seq, res = 'abc', [(0,'a'), (1,'b'), (2,'c')]

    def test_basicfunction(self):
        # Calling ``enum`` yields an instance of exactly that type, which is
        # its own iterator and produces the expected pairs.
        self.assertEqual(type(self.enum(self.seq)), self.enum)
        e = self.enum(self.seq)
        self.assertEqual(iter(e), e)
        self.assertEqual(list(self.enum(self.seq)), self.res)
        # Value is discarded; presumably this only checks that reading
        # __doc__ does not raise.
        self.enum.__doc__

    def test_pickle(self):
        # Round-trips through every pickle protocol, both fresh and after
        # one item has been consumed (see PickleTest.check_pickle).
        self.check_pickle(self.enum(self.seq), self.res)

    def test_getitemseqn(self):
        # Input iterated through the __getitem__ fallback (class G); an
        # empty input must stop on the first next().
        self.assertEqual(list(self.enum(G(self.seq))), self.res)
        e = self.enum(G(''))
        self.assertRaises(StopIteration, next, e)

    def test_iteratorseqn(self):
        # Input is a hand-written iterator (class I).
        self.assertEqual(list(self.enum(I(self.seq))), self.res)
        e = self.enum(I(''))
        self.assertRaises(StopIteration, next, e)

    def test_iteratorgenerator(self):
        # Input whose __iter__ is a generator function (class Ig).
        self.assertEqual(list(self.enum(Ig(self.seq))), self.res)
        e = self.enum(Ig(''))
        self.assertRaises(StopIteration, next, e)

    def test_noniterable(self):
        # X defines __next__ but neither __iter__ nor __getitem__, so it is
        # not iterable and construction must fail.
        self.assertRaises(TypeError, self.enum, X(self.seq))

    def test_illformediterable(self):
        # N.__iter__ returns an object that has no __next__.
        self.assertRaises(TypeError, self.enum, N(self.seq))

    def test_exception_propagation(self):
        # The ZeroDivisionError raised by E.__next__ must reach the caller
        # unchanged.
        self.assertRaises(ZeroDivisionError, list, self.enum(E(self.seq)))

    def test_argumentcheck(self):
        self.assertRaises(TypeError, self.enum) # no arguments
        self.assertRaises(TypeError, self.enum, 1) # wrong type (not iterable)
        self.assertRaises(TypeError, self.enum, 'abc', 'a') # wrong type for start
        self.assertRaises(TypeError, self.enum, 'abc', 2, 3) # too many arguments

    def test_kwargs(self):
        # ``iterable`` and ``start`` are accepted as keywords, in either
        # order, and match the positional form; any unknown keyword is a
        # TypeError.
        self.assertEqual(list(self.enum(iterable=Ig(self.seq))), self.res)
        expected = list(self.enum(Ig(self.seq), 0))
        self.assertEqual(list(self.enum(iterable=Ig(self.seq), start=0)),
                         expected)
        self.assertEqual(list(self.enum(start=0, iterable=Ig(self.seq))),
                         expected)
        self.assertRaises(TypeError, self.enum, iterable=[], x=3)
        self.assertRaises(TypeError, self.enum, start=0, x=3)
        self.assertRaises(TypeError, self.enum, x=0, y=3)
        self.assertRaises(TypeError, self.enum, x=0)

    @support.cpython_only
    def test_tuple_reuse(self):
        # Tests an implementation detail where tuple is reused
        # whenever nothing else holds a reference to it
        # NOTE(review): this calls the builtin enumerate directly rather than
        # self.enum, so subclasses overriding ``enum`` re-test the builtin.
        # First check: the list keeps every result alive, so no tuple can be
        # recycled and all ids differ.
        self.assertEqual(len(set(map(id, list(enumerate(self.seq))))), len(self.seq))
        # Second check: map(id, ...) drops each result before the next is
        # produced, so one tuple is recycled throughout. That gives a single
        # distinct id, or none for an empty input.
        self.assertEqual(len(set(map(id, enumerate(self.seq)))), min(1,len(self.seq)))

    @support.cpython_only
    def test_enumerate_result_gc(self):
        # bpo-42536: enumerate's tuple-reuse speed trick breaks the GC's
        # assumptions about what can be untracked. Make sure we re-track result
        # tuples whenever we reuse them.
        it = self.enum([[]])
        gc.collect()
        # That GC collection probably untracked the recycled internal result
        # tuple, which is initialized to (None, None). Make sure it's re-tracked
        # when it's mutated and returned from __next__:
        self.assertTrue(gc.is_tracked(next(it)))
161
class MyEnum(enumerate):
    # Trivial subclass: lets the suite confirm enumerate behaves the same
    # when subclassed. No docstring, so __doc__ stays None as before.
    ...
164
class SubclassTestCase(EnumerateTestCase):
    # Re-run every inherited enumerate test against the trivial subclass.
    enum = MyEnum
168
class TestEmpty(EnumerateTestCase):
    # Every inherited test again, this time over an empty input.
    seq = ''
    res = []
172
class TestBig(EnumerateTestCase):
    # A long input (9995 items) so enumerate runs well past a handful of
    # elements. Expected pairs are built with zip so the oracle does not
    # itself rely on enumerate.
    seq = range(10, 20000, 2)
    res = list(zip(range(len(seq)), seq))
177
class TestReversed(unittest.TestCase, PickleTest):
    """Tests for the reversed() builtin."""

    def test_simple(self):
        # reversed() must agree with a backwards copy of list(data), both for
        # builtin sequences and for a user class that provides only
        # __getitem__ and __len__.
        class A:
            def __getitem__(self, i):
                if i < 5:
                    return str(i)
                raise StopIteration
            def __len__(self):
                return 5
        for data in ('abc', range(5), tuple(enumerate('abc')), A(),
                    range(1,17,5), dict.fromkeys('abcde')):
            self.assertEqual(list(data)[::-1], list(reversed(data)))
        # don't allow keyword arguments
        self.assertRaises(TypeError, reversed, [], a=1)

    def test_range_optimization(self):
        # reversed() of a range returns the same iterator type as iter() does.
        x = range(1)
        self.assertEqual(type(reversed(x)), type(iter(x)))

    def test_len(self):
        # length_hint reports the items still to come, reaching 0 once the
        # iterator is exhausted.
        for s in ('hello', tuple('hello'), list('hello'), range(5)):
            self.assertEqual(operator.length_hint(reversed(s)), len(s))
            r = reversed(s)
            list(r)
            self.assertEqual(operator.length_hint(r), 0)
        # __len__ succeeds on its first call (presumably made by reversed()
        # itself) and raises afterwards. The error raised while computing the
        # hint must propagate out of length_hint.
        class SeqWithWeirdLen:
            called = False
            def __len__(self):
                if not self.called:
                    self.called = True
                    return 10
                raise ZeroDivisionError
            def __getitem__(self, index):
                return index
        r = reversed(SeqWithWeirdLen())
        self.assertRaises(ZeroDivisionError, operator.length_hint, r)

    def test_gc(self):
        # The sequence and its reversed() iterator reference each other, so
        # only the cyclic garbage collector can reclaim them. Previously the
        # cycle was built but never collected or checked; now we assert it
        # is actually freed.
        import weakref

        class Seq:
            def __len__(self):
                return 10
            def __getitem__(self, index):
                return index
        s = Seq()
        r = reversed(s)
        s.r = r
        wr = weakref.ref(s)
        del s, r
        support.gc_collect()
        self.assertIsNone(wr())

    def test_args(self):
        # Exactly one positional argument is required.
        self.assertRaises(TypeError, reversed)
        self.assertRaises(TypeError, reversed, [], 'extra')

    @unittest.skipUnless(hasattr(sys, 'getrefcount'), 'test needs sys.getrefcount()')
    def test_bug1229429(self):
        # this bug was never in reversed, it was in
        # PyObject_CallMethod, and reversed_new calls that sometimes.
        # A non-callable __reversed__ must raise TypeError without leaking a
        # reference to the attribute on each attempt.
        def f():
            pass
        r = f.__reversed__ = object()
        rc = sys.getrefcount(r)
        for _ in range(10):
            try:
                reversed(f)
            except TypeError:
                pass
            else:
                self.fail("non-callable __reversed__ didn't raise!")
        self.assertEqual(rc, sys.getrefcount(r))

    def test_objmethods(self):
        # Objects must have __len__() and __getitem__() implemented.
        class NoLen(object):
            def __getitem__(self, i): return 1
        nl = NoLen()
        self.assertRaises(TypeError, reversed, nl)

        class NoGetItem(object):
            def __len__(self): return 2
        ngi = NoGetItem()
        self.assertRaises(TypeError, reversed, ngi)

        # Setting __reversed__ to None opts out even when the sequence
        # protocol is otherwise available.
        class Blocked(object):
            def __getitem__(self, i): return 1
            def __len__(self): return 2
            __reversed__ = None
        b = Blocked()
        self.assertRaises(TypeError, reversed, b)

    def test_pickle(self):
        # Reversed iterators round-trip through pickle at every protocol.
        for data in 'abc', range(5), tuple(enumerate('abc')), range(1,17,5):
            self.check_pickle(reversed(data), list(data)[::-1])
270
271
class EnumerateStartTestCase(EnumerateTestCase):
    # Base for variants whose ``enum`` is a method wrapping enumerate() with a
    # custom start. Calling it does not return an instance of ``enum`` itself,
    # so the inherited type-identity check is dropped from this test.

    def test_basicfunction(self):
        it = self.enum(self.seq)
        self.assertEqual(iter(it), it)
        produced = list(self.enum(self.seq))
        self.assertEqual(produced, self.res)
278
279
class TestStart(EnumerateStartTestCase):
    # Counting begins at a small non-zero start passed by keyword.

    def enum(self, iterable, start=11):
        counted = enumerate(iterable, start=start)
        return counted

    seq = 'abc'
    res = [(11, 'a'), (12, 'b'), (13, 'c')]
285
286
class TestLongStart(EnumerateStartTestCase):
    # Counting starts just past sys.maxsize, so every index is an int that
    # no longer fits in a Py_ssize_t.

    def enum(self, iterable, start=sys.maxsize + 1):
        counted = enumerate(iterable, start=start)
        return counted

    seq = 'abc'
    res = [
        (sys.maxsize + 1, 'a'),
        (sys.maxsize + 2, 'b'),
        (sys.maxsize + 3, 'c'),
    ]
293
294
if __name__ == "__main__":
    # Run every TestCase in this module through unittest's command-line
    # runner. module="__main__" is unittest.main()'s default, spelled out.
    unittest.main(module="__main__")
297