"""Tests for the built-in ``enumerate`` and ``reversed`` iterators."""

import unittest
import operator
import sys
import pickle
import gc
import weakref

from test import support


class G:
    """Sequence using __getitem__ (the legacy sequence-iteration protocol)."""

    def __init__(self, seqn):
        self.seqn = seqn

    def __getitem__(self, i):
        return self.seqn[i]


class I:
    """Sequence using the iterator protocol (__iter__ returning self)."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.i >= len(self.seqn):
            raise StopIteration
        v = self.seqn[self.i]
        self.i += 1
        return v


class Ig:
    """Sequence using the iterator protocol defined with a generator."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        for val in self.seqn:
            yield val


class X:
    """Missing __getitem__ and __iter__, so it is not iterable at all."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __next__(self):
        if self.i >= len(self.seqn):
            raise StopIteration
        v = self.seqn[self.i]
        self.i += 1
        return v


class E:
    """Iterator whose __next__ raises, to test propagation of exceptions."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        # Deliberately raise ZeroDivisionError from inside iteration.
        3 // 0


class N:
    """Iterator missing __next__(): __iter__ returns a non-iterator."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        return self


class PickleTest:
    """Mixin providing a helper to check the picklability of iterators."""

    def check_pickle(self, itorg, seq):
        """Assert that iterator *itorg* round-trips through pickle.

        For every pickle protocol the fresh iterator is pickled and the
        reloaded copy must have the same type and yield *seq*.  A second
        reloaded copy is then advanced by one step and pickled again
        mid-stream; that copy must yield the remainder ``seq[1:]``.
        """
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            d = pickle.dumps(itorg, proto)
            it = pickle.loads(d)
            self.assertEqual(type(itorg), type(it))
            self.assertEqual(list(it), seq)

            it = pickle.loads(d)
            try:
                next(it)
            except StopIteration:
                # An empty iterator has no partially consumed state to check.
                self.assertFalse(seq[1:])
                continue
            d = pickle.dumps(it, proto)
            it = pickle.loads(d)
            self.assertEqual(list(it), seq[1:])


class EnumerateTestCase(unittest.TestCase, PickleTest):
    """Core enumerate tests; subclasses override ``enum``, ``seq`` and ``res``."""

    enum = enumerate
    seq, res = 'abc', [(0, 'a'), (1, 'b'), (2, 'c')]

    def test_basicfunction(self):
        self.assertEqual(type(self.enum(self.seq)), self.enum)
        e = self.enum(self.seq)
        self.assertEqual(iter(e), e)
        self.assertEqual(list(self.enum(self.seq)), self.res)
        # Accessing the docstring must not fail.
        self.enum.__doc__

    def test_pickle(self):
        self.check_pickle(self.enum(self.seq), self.res)

    def test_getitemseqn(self):
        self.assertEqual(list(self.enum(G(self.seq))), self.res)
        e = self.enum(G(''))
        self.assertRaises(StopIteration, next, e)

    def test_iteratorseqn(self):
        self.assertEqual(list(self.enum(I(self.seq))), self.res)
        e = self.enum(I(''))
        self.assertRaises(StopIteration, next, e)

    def test_iteratorgenerator(self):
        self.assertEqual(list(self.enum(Ig(self.seq))), self.res)
        e = self.enum(Ig(''))
        self.assertRaises(StopIteration, next, e)

    def test_noniterable(self):
        self.assertRaises(TypeError, self.enum, X(self.seq))

    def test_illformediterable(self):
        self.assertRaises(TypeError, self.enum, N(self.seq))

    def test_exception_propagation(self):
        self.assertRaises(ZeroDivisionError, list, self.enum(E(self.seq)))

    def test_argumentcheck(self):
        self.assertRaises(TypeError, self.enum)  # no arguments
        self.assertRaises(TypeError, self.enum, 1)  # wrong type (not iterable)
        self.assertRaises(TypeError, self.enum, 'abc', 'a')  # wrong type
        self.assertRaises(TypeError, self.enum, 'abc', 2, 3)  # too many arguments

    def test_kwargs(self):
        self.assertEqual(list(self.enum(iterable=Ig(self.seq))), self.res)
        expected = list(self.enum(Ig(self.seq), 0))
        self.assertEqual(list(self.enum(iterable=Ig(self.seq), start=0)),
                         expected)
        self.assertEqual(list(self.enum(start=0, iterable=Ig(self.seq))),
                         expected)
        self.assertRaises(TypeError, self.enum, iterable=[], x=3)
        self.assertRaises(TypeError, self.enum, start=0, x=3)
        self.assertRaises(TypeError, self.enum, x=0, y=3)
        self.assertRaises(TypeError, self.enum, x=0)

    @support.cpython_only
    def test_tuple_reuse(self):
        # Tests an implementation detail where tuple is reused
        # whenever nothing else holds a reference to it.
        # Keeping every result alive (via list) forces distinct tuples ...
        self.assertEqual(len(set(map(id, list(enumerate(self.seq))))),
                         len(self.seq))
        # ... while discarding each result lets the same tuple be recycled.
        self.assertEqual(len(set(map(id, enumerate(self.seq)))),
                         min(1, len(self.seq)))

    @support.cpython_only
    def test_enumerate_result_gc(self):
        # bpo-42536: enumerate's tuple-reuse speed trick breaks the GC's
        # assumptions about what can be untracked. Make sure we re-track result
        # tuples whenever we reuse them.
        it = self.enum([[]])
        gc.collect()
        # That GC collection probably untracked the recycled internal result
        # tuple, which is initialized to (None, None). Make sure it's re-tracked
        # when it's mutated and returned from __next__:
        self.assertTrue(gc.is_tracked(next(it)))


class MyEnum(enumerate):
    pass


class SubclassTestCase(EnumerateTestCase):

    enum = MyEnum


class TestEmpty(EnumerateTestCase):

    seq, res = '', []


class TestBig(EnumerateTestCase):

    seq = range(10, 20000, 2)
    res = list(zip(range(20000), seq))


class TestReversed(unittest.TestCase, PickleTest):
    """Tests for the built-in ``reversed``."""

    def test_simple(self):
        class A:
            def __getitem__(self, i):
                if i < 5:
                    return str(i)
                raise StopIteration

            def __len__(self):
                return 5

        for data in ('abc', range(5), tuple(enumerate('abc')), A(),
                     range(1, 17, 5), dict.fromkeys('abcde')):
            self.assertEqual(list(data)[::-1], list(reversed(data)))
        # don't allow keyword arguments
        self.assertRaises(TypeError, reversed, [], a=1)

    def test_range_optimization(self):
        # reversed(range) delegates to range.__reversed__, which returns the
        # same iterator type as iter(range).
        x = range(1)
        self.assertEqual(type(reversed(x)), type(iter(x)))

    def test_len(self):
        for s in ('hello', tuple('hello'), list('hello'), range(5)):
            self.assertEqual(operator.length_hint(reversed(s)), len(s))
            r = reversed(s)
            list(r)
            self.assertEqual(operator.length_hint(r), 0)

        class SeqWithWeirdLen:
            # __len__ succeeds once (inside reversed()) and then raises, so
            # length_hint on the iterator must propagate the error.
            called = False

            def __len__(self):
                if not self.called:
                    self.called = True
                    return 10
                raise ZeroDivisionError

            def __getitem__(self, index):
                return index

        r = reversed(SeqWithWeirdLen())
        self.assertRaises(ZeroDivisionError, operator.length_hint, r)

    def test_gc(self):
        # A sequence that holds a reference to its own reversed iterator
        # forms a reference cycle; the iterator must participate in cyclic
        # GC so that the cycle can actually be reclaimed.
        class Seq:
            def __len__(self):
                return 10

            def __getitem__(self, index):
                return index

        s = Seq()
        r = reversed(s)
        s.r = r
        wr = weakref.ref(s)
        del s, r
        support.gc_collect()
        self.assertIsNone(wr())

    def test_args(self):
        self.assertRaises(TypeError, reversed)
        self.assertRaises(TypeError, reversed, [], 'extra')

    @unittest.skipUnless(hasattr(sys, 'getrefcount'),
                         'test needs sys.getrefcount()')
    def test_bug1229429(self):
        # this bug was never in reversed, it was in
        # PyObject_CallMethod, and reversed_new calls that sometimes.
        def f():
            pass

        r = f.__reversed__ = object()
        rc = sys.getrefcount(r)
        for _ in range(10):
            try:
                reversed(f)
            except TypeError:
                pass
            else:
                self.fail("non-callable __reversed__ didn't raise!")
        # Repeated failing calls must not leak references to __reversed__.
        self.assertEqual(rc, sys.getrefcount(r))

    def test_objmethods(self):
        # Objects must have __len__() and __getitem__() implemented.
        class NoLen(object):
            def __getitem__(self, i):
                return 1

        nl = NoLen()
        self.assertRaises(TypeError, reversed, nl)

        class NoGetItem(object):
            def __len__(self):
                return 2

        ngi = NoGetItem()
        self.assertRaises(TypeError, reversed, ngi)

        class Blocked(object):
            def __getitem__(self, i):
                return 1

            def __len__(self):
                return 2

            # Setting __reversed__ to None explicitly opts out of reversal.
            __reversed__ = None

        b = Blocked()
        self.assertRaises(TypeError, reversed, b)

    def test_pickle(self):
        for data in 'abc', range(5), tuple(enumerate('abc')), range(1, 17, 5):
            self.check_pickle(reversed(data), list(data)[::-1])


class EnumerateStartTestCase(EnumerateTestCase):
    """Base for tests where ``enum`` is a wrapper function, not a type."""

    def test_basicfunction(self):
        # The type check of the parent test does not apply to a wrapper.
        e = self.enum(self.seq)
        self.assertEqual(iter(e), e)
        self.assertEqual(list(self.enum(self.seq)), self.res)


class TestStart(EnumerateStartTestCase):
    def enum(self, iterable, start=11):
        return enumerate(iterable, start=start)

    seq, res = 'abc', [(11, 'a'), (12, 'b'), (13, 'c')]


class TestLongStart(EnumerateStartTestCase):
    # A start beyond sys.maxsize exercises enumerate's arbitrary-precision
    # counter path.
    def enum(self, iterable, start=sys.maxsize + 1):
        return enumerate(iterable, start=start)

    seq, res = 'abc', [(sys.maxsize + 1, 'a'), (sys.maxsize + 2, 'b'),
                       (sys.maxsize + 3, 'c')]


if __name__ == "__main__":
    unittest.main()