1# Test iterators. 2 3import sys 4import unittest 5from test.support import cpython_only 6from test.support.os_helper import TESTFN, unlink 7from test.support import check_free_after_iterating, ALWAYS_EQ, NEVER_EQ 8import pickle 9import collections.abc 10import functools 11import contextlib 12import builtins 13 14# Test result of triple loop (too big to inline) 15TRIPLETS = [(0, 0, 0), (0, 0, 1), (0, 0, 2), 16 (0, 1, 0), (0, 1, 1), (0, 1, 2), 17 (0, 2, 0), (0, 2, 1), (0, 2, 2), 18 19 (1, 0, 0), (1, 0, 1), (1, 0, 2), 20 (1, 1, 0), (1, 1, 1), (1, 1, 2), 21 (1, 2, 0), (1, 2, 1), (1, 2, 2), 22 23 (2, 0, 0), (2, 0, 1), (2, 0, 2), 24 (2, 1, 0), (2, 1, 1), (2, 1, 2), 25 (2, 2, 0), (2, 2, 1), (2, 2, 2)] 26 27# Helper classes 28 29class BasicIterClass: 30 def __init__(self, n): 31 self.n = n 32 self.i = 0 33 def __next__(self): 34 res = self.i 35 if res >= self.n: 36 raise StopIteration 37 self.i = res + 1 38 return res 39 def __iter__(self): 40 return self 41 42class IteratingSequenceClass: 43 def __init__(self, n): 44 self.n = n 45 def __iter__(self): 46 return BasicIterClass(self.n) 47 48class IteratorProxyClass: 49 def __init__(self, i): 50 self.i = i 51 def __next__(self): 52 return next(self.i) 53 def __iter__(self): 54 return self 55 56class SequenceClass: 57 def __init__(self, n): 58 self.n = n 59 def __getitem__(self, i): 60 if 0 <= i < self.n: 61 return i 62 else: 63 raise IndexError 64 65class SequenceProxyClass: 66 def __init__(self, s): 67 self.s = s 68 def __getitem__(self, i): 69 return self.s[i] 70 71class UnlimitedSequenceClass: 72 def __getitem__(self, i): 73 return i 74 75class DefaultIterClass: 76 pass 77 78class NoIterClass: 79 def __getitem__(self, i): 80 return i 81 __iter__ = None 82 83class BadIterableClass: 84 def __iter__(self): 85 raise ZeroDivisionError 86 87class CallableIterClass: 88 def __init__(self): 89 self.i = 0 90 def __call__(self): 91 i = self.i 92 self.i = i + 1 93 if i > 100: 94 raise IndexError # Emergency stop 95 return i 96 
class EmptyIterClass:
    """Sequence reporting length 0 whose __getitem__ raises StopIteration."""

    def __len__(self):
        return 0

    def __getitem__(self, i):
        raise StopIteration

# Main test suite

class TestCase(unittest.TestCase):
    """Tests for the iterator protocol, iter(), and iterator consumers."""

    def check_iterator(self, it, seq, pickle=True):
        """Drain *it* with explicit next() calls and compare it against *seq*.

        When *pickle* is true, pickle round-trips of *it* are checked first
        via check_pickle().
        """
        if pickle:
            self.check_pickle(it, seq)
        res = []
        while True:
            try:
                val = next(it)
            except StopIteration:
                break
            res.append(val)
        self.assertEqual(res, seq)

    def check_for_loop(self, expr, seq, pickle=True):
        """Iterate *expr* with a for loop and compare the items against *seq*.

        When *pickle* is true, pickle round-trips of iter(expr) are checked
        first via check_pickle().
        """
        if pickle:
            self.check_pickle(iter(expr), seq)
        res = []
        for val in expr:
            res.append(val)
        self.assertEqual(res, seq)

    def check_pickle(self, itorg, seq):
        """Check that *itorg* survives pickling with every protocol.

        The unpickled copy must be an Iterator that yields *seq*.  A copy
        advanced by one item and pickled again must yield seq[1:].
        """
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            d = pickle.dumps(itorg, proto)
            it = pickle.loads(d)
            # Only the Iterator ABC is checked, not the exact type: dict
            # iterators, for example, unpickle as list iterators.
            self.assertIsInstance(it, collections.abc.Iterator)
            self.assertEqual(list(it), seq)

            it = pickle.loads(d)
            try:
                next(it)
            except StopIteration:
                # Empty sequence: nothing left to check for this protocol.
                continue
            d = pickle.dumps(it, proto)
            it = pickle.loads(d)
            self.assertEqual(list(it), seq[1:])

    def _create_testfile(self, data):
        """Write *data* to TESTFN as UTF-8 text and schedule its removal.

        The cleanup is registered before the file is opened, so it also
        runs if writing fails.  Cleanups run after the test method returns,
        which is after the callers' ``with`` blocks have closed the file.
        """
        self.addCleanup(self._remove_testfile)
        with open(TESTFN, "w", encoding="utf-8") as f:
            f.write(data)

    @staticmethod
    def _remove_testfile():
        """Remove TESTFN on a best-effort basis."""
        # A file that cannot be removed must not turn a passing test into
        # an error.
        try:
            unlink(TESTFN)
        except OSError:
            pass

    # Test basic use of iter() function
    def test_iter_basic(self):
        self.check_iterator(iter(range(10)), list(range(10)))

    # Test that iter() applied to an iterator returns that same iterator
    def test_iter_idempotency(self):
        seq = list(range(10))
        it = iter(seq)
        it2 = iter(it)
        self.assertIs(it, it2)

    # Test that for loops over iterators work
    def test_iter_for_loop(self):
        self.check_for_loop(iter(range(10)), list(range(10)))

    # Test several independent iterators over the same sequence
    def test_iter_independence(self):
        seq = range(3)
        res = []
        for i in iter(seq):
            for j in iter(seq):
                for k in iter(seq):
                    res.append((i, j, k))
        self.assertEqual(res, TRIPLETS)

    # Test triple list comprehension using iterators
    def test_nested_comprehensions_iter(self):
        seq = range(3)
        res = [(i, j, k)
               for i in iter(seq) for j in iter(seq) for k in iter(seq)]
        self.assertEqual(res, TRIPLETS)

    # Test triple list comprehension without iterators
    def test_nested_comprehensions_for(self):
        seq = range(3)
        res = [(i, j, k) for i in seq for j in seq for k in seq]
        self.assertEqual(res, TRIPLETS)

    # Test a class with __iter__ in a for loop
    def test_iter_class_for(self):
        self.check_for_loop(IteratingSequenceClass(10), list(range(10)))

    # Test a class with __iter__ with explicit iter()
    def test_iter_class_iter(self):
        self.check_iterator(iter(IteratingSequenceClass(10)), list(range(10)))

    # Test for loop on a sequence class without __iter__
    def test_seq_class_for(self):
        self.check_for_loop(SequenceClass(10), list(range(10)))

    # Test iter() on a sequence class without __iter__
    def test_seq_class_iter(self):
        self.check_iterator(iter(SequenceClass(10)), list(range(10)))

    def test_mutating_seq_class_iter_pickle(self):
        # The iterator and its sequence are pickled together, so the
        # unpickled iterator refers to the unpickled sequence.  Growing
        # seq.n afterwards shows whether the iterator still reads from it.
        orig = SequenceClass(5)
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            # initial iterator
            itorig = iter(orig)
            d = pickle.dumps((itorig, orig), proto)
            it, seq = pickle.loads(d)
            seq.n = 7
            self.assertIs(type(it), type(itorig))
            self.assertEqual(list(it), list(range(7)))

            # running iterator
            next(itorig)
            d = pickle.dumps((itorig, orig), proto)
            it, seq = pickle.loads(d)
            seq.n = 7
            self.assertIs(type(it), type(itorig))
            self.assertEqual(list(it), list(range(1, 7)))

            # iterator at the end (all items consumed, StopIteration not
            # yet raised)
            for _ in range(4):
                next(itorig)
            d = pickle.dumps((itorig, orig), proto)
            it, seq = pickle.loads(d)
            seq.n = 7
            self.assertIs(type(it), type(itorig))
            self.assertEqual(list(it), list(range(5, 7)))

            # exhausted iterator
            self.assertRaises(StopIteration, next, itorig)
            d = pickle.dumps((itorig, orig), proto)
            it, seq = pickle.loads(d)
            seq.n = 7
            self.assertIsInstance(it, collections.abc.Iterator)
            self.assertEqual(list(it), [])

    def test_mutating_seq_class_exhausted_iter(self):
        a = SequenceClass(5)
        exhit = iter(a)
        empit = iter(a)
        for _ in exhit:  # exhaust the iterator
            next(empit)  # not exhausted
        a.n = 7
        self.assertEqual(list(exhit), [])
        self.assertEqual(list(empit), [5, 6])
        self.assertEqual(list(a), [0, 1, 2, 3, 4, 5, 6])

    def test_reduce_mutating_builtins_iter(self):
        # This is a reproducer of issue #101765
        # where iter `__reduce__` calls could lead to a segfault or SystemError
        # depending on the order of C argument evaluation, which is undefined

        # Backup builtins
        builtins_dict = builtins.__dict__
        orig = {"iter": iter, "reversed": reversed}

        def run(builtin_name, item, sentinel=None):
            it = iter(item) if sentinel is None else iter(item, sentinel)

            class CustomStr:
                def __init__(self, name, iterator):
                    self.name = name
                    self.iterator = iterator
                def __hash__(self):
                    return hash(self.name)
                def __eq__(self, other):
                    # Here we exhaust our iterator, possibly changing
                    # its `it_seq` pointer to NULL
                    # The `__reduce__` call should correctly get
                    # the pointers after this call
                    list(self.iterator)
                    return other == self.name

            # del is required here
            # to not prematurely call __eq__ from
            # the hash collision with the old key
            del builtins_dict[builtin_name]
            builtins_dict[CustomStr(builtin_name, it)] = orig[builtin_name]

            return it.__reduce__()

        cases = [
            (EmptyIterClass(),),
            (bytes(8),),
            (bytearray(8),),
            ((1, 2, 3),),
            (lambda: 0, 0),
            (tuple[int],)  # GenericAlias
        ]

        try:
            run_iter = functools.partial(run, "iter")
            # The returned value of `__reduce__` should not only be valid
            # but also *empty*, as `it` was exhausted during `__eq__`,
            # e.g. "xyz" returns (iter, ("",))
            self.assertEqual(run_iter("xyz"), (orig["iter"], ("",)))
            self.assertEqual(run_iter([1, 2, 3]), (orig["iter"], ([],)))

            # _PyEval_GetBuiltin is also called for `reversed` in a branch of
            # listiter_reduce_general
            self.assertEqual(
                run("reversed", orig["reversed"](list(range(8)))),
                (orig["iter"], ([],))
            )

            for case in cases:
                self.assertEqual(run_iter(*case), (orig["iter"], ((),)))
        finally:
            # Restore original builtins
            for key, func in orig.items():
                # need to suppress KeyErrors in case
                # a failed test deletes the key without setting anything
                with contextlib.suppress(KeyError):
                    # del is required here
                    # to not invoke our custom __eq__ from
                    # the hash collision with the old key
                    del builtins_dict[key]
                builtins_dict[key] = func

    # Test a class whose __iter__ returns an object without __next__
    def test_new_style_iter_class(self):
        class IterClass(object):
            def __iter__(self):
                return self
        self.assertRaises(TypeError, iter, IterClass())

    # Test two-argument iter() with callable instance
    def test_iter_callable(self):
        self.check_iterator(iter(CallableIterClass(), 10), list(range(10)), pickle=True)

    # Test two-argument iter() with function
    def test_iter_function(self):
        # The mutable default deliberately carries state between calls.
        def spam(state=[0]):
            i = state[0]
            state[0] = i+1
            return i
        self.check_iterator(iter(spam, 10), list(range(10)), pickle=False)

    # Test two-argument iter() with function that raises StopIteration
    def test_iter_function_stop(self):
        def spam(state=[0]):
            i = state[0]
            if i == 10:
                raise StopIteration
            state[0] = i+1
            return i
        self.check_iterator(iter(spam, 20), list(range(10)), pickle=False)

    def test_iter_function_concealing_reentrant_exhaustion(self):
        # gh-101892: Test two-argument iter() with a function that
        # exhausts its associated iterator but forgets to either return
        # a sentinel value or raise StopIteration.
        HAS_MORE = 1
        NO_MORE = 2

        def exhaust(iterator):
            """Exhaust an iterator without raising StopIteration."""
            list(iterator)

        def spam():
            # Touching the iterator with exhaust() below will call
            # spam() once again so protect against recursion.
            if spam.is_recursive_call:
                return NO_MORE
            spam.is_recursive_call = True
            exhaust(spam.iterator)
            return HAS_MORE

        spam.is_recursive_call = False
        spam.iterator = iter(spam, NO_MORE)
        with self.assertRaises(StopIteration):
            next(spam.iterator)

    # Test exception propagation through function iterator
    def test_exception_function(self):
        def spam(state=[0]):
            i = state[0]
            state[0] = i+1
            if i == 10:
                raise RuntimeError
            return i
        res = []
        with self.assertRaises(RuntimeError):
            for x in iter(spam, 20):
                res.append(x)
        self.assertEqual(res, list(range(10)))

    # Test exception propagation through sequence iterator
    def test_exception_sequence(self):
        class MySequenceClass(SequenceClass):
            def __getitem__(self, i):
                if i == 10:
                    raise RuntimeError
                return SequenceClass.__getitem__(self, i)
        res = []
        with self.assertRaises(RuntimeError):
            for x in MySequenceClass(20):
                res.append(x)
        self.assertEqual(res, list(range(10)))

    # Test for StopIteration from __getitem__
    def test_stop_sequence(self):
        class MySequenceClass(SequenceClass):
            def __getitem__(self, i):
                if i == 10:
                    raise StopIteration
                return SequenceClass.__getitem__(self, i)
        self.check_for_loop(MySequenceClass(20), list(range(10)), pickle=False)

    # Test a big range
    def test_iter_big_range(self):
        self.check_for_loop(iter(range(10000)), list(range(10000)))

    # Test an empty list
    def test_iter_empty(self):
        self.check_for_loop(iter([]), [])

    # Test a tuple
    def test_iter_tuple(self):
        self.check_for_loop(iter((0,1,2,3,4,5,6,7,8,9)), list(range(10)))

    # Test a range
    def test_iter_range(self):
        self.check_for_loop(iter(range(10)), list(range(10)))

    # Test a string
    def test_iter_string(self):
        self.check_for_loop(iter("abcde"), ["a", "b", "c", "d", "e"])

    # Test a dictionary
    def test_iter_dict(self):
        d = {}
        for i in range(10):
            d[i] = None
        self.check_for_loop(d, list(d.keys()))

    # Test a file
    def test_iter_file(self):
        self._create_testfile("".join("%d\n" % i for i in range(5)))
        with open(TESTFN, "r", encoding="utf-8") as f:
            self.check_for_loop(f, ["0\n", "1\n", "2\n", "3\n", "4\n"], pickle=False)
            # The file iterator is now exhausted.
            self.check_for_loop(f, [], pickle=False)

    # Test list()'s use of iterators.
    def test_builtin_list(self):
        self.assertEqual(list(SequenceClass(5)), list(range(5)))
        self.assertEqual(list(SequenceClass(0)), [])
        self.assertEqual(list(()), [])

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(list(d), list(d.keys()))

        self.assertRaises(TypeError, list, list)
        self.assertRaises(TypeError, list, 42)

        self._create_testfile("".join("%d\n" % i for i in range(5)))
        with open(TESTFN, "r", encoding="utf-8") as f:
            self.assertEqual(list(f), ["0\n", "1\n", "2\n", "3\n", "4\n"])
            f.seek(0, 0)
            self.assertEqual(list(f),
                             ["0\n", "1\n", "2\n", "3\n", "4\n"])

    # Test tuple()'s use of iterators.
    def test_builtin_tuple(self):
        self.assertEqual(tuple(SequenceClass(5)), (0, 1, 2, 3, 4))
        self.assertEqual(tuple(SequenceClass(0)), ())
        self.assertEqual(tuple([]), ())
        self.assertEqual(tuple(()), ())
        self.assertEqual(tuple("abc"), ("a", "b", "c"))

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(tuple(d), tuple(d.keys()))

        self.assertRaises(TypeError, tuple, list)
        self.assertRaises(TypeError, tuple, 42)

        self._create_testfile("".join("%d\n" % i for i in range(5)))
        with open(TESTFN, "r", encoding="utf-8") as f:
            self.assertEqual(tuple(f), ("0\n", "1\n", "2\n", "3\n", "4\n"))
            f.seek(0, 0)
            self.assertEqual(tuple(f),
                             ("0\n", "1\n", "2\n", "3\n", "4\n"))

    # Test filter()'s use of iterators.
    def test_builtin_filter(self):
        self.assertEqual(list(filter(None, SequenceClass(5))),
                         list(range(1, 5)))
        self.assertEqual(list(filter(None, SequenceClass(0))), [])
        self.assertEqual(list(filter(None, ())), [])
        self.assertEqual(list(filter(None, "abc")), ["a", "b", "c"])

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(list(filter(None, d)), list(d.keys()))

        self.assertRaises(TypeError, filter, None, list)
        self.assertRaises(TypeError, filter, None, 42)

        class Boolean:
            def __init__(self, truth):
                self.truth = truth
            def __bool__(self):
                return self.truth
        bTrue = Boolean(True)
        bFalse = Boolean(False)

        class Seq:
            def __init__(self, *args):
                self.vals = args
            def __iter__(self):
                class SeqIter:
                    def __init__(self, vals):
                        self.vals = vals
                        self.i = 0
                    def __iter__(self):
                        return self
                    def __next__(self):
                        i = self.i
                        self.i = i + 1
                        if i < len(self.vals):
                            return self.vals[i]
                        else:
                            raise StopIteration
                return SeqIter(self.vals)

        seq = Seq(*([bTrue, bFalse] * 25))
        self.assertEqual(list(filter(lambda x: not x, seq)), [bFalse]*25)
        self.assertEqual(list(filter(lambda x: not x, iter(seq))), [bFalse]*25)

    # Test max() and min()'s use of iterators.
    def test_builtin_max_min(self):
        self.assertEqual(max(SequenceClass(5)), 4)
        self.assertEqual(min(SequenceClass(5)), 0)
        self.assertEqual(max(8, -1), 8)
        self.assertEqual(min(8, -1), -1)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(max(d), "two")
        self.assertEqual(min(d), "one")
        self.assertEqual(max(d.values()), 3)
        self.assertEqual(min(iter(d.values())), 1)

        self._create_testfile("medium line\n"
                              "xtra large line\n"
                              "itty-bitty line\n")
        with open(TESTFN, "r", encoding="utf-8") as f:
            self.assertEqual(min(f), "itty-bitty line\n")
            f.seek(0, 0)
            self.assertEqual(max(f), "xtra large line\n")

    # Test map()'s use of iterators.
    def test_builtin_map(self):
        self.assertEqual(list(map(lambda x: x+1, SequenceClass(5))),
                         list(range(1, 6)))

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(list(map(lambda k, d=d: (k, d[k]), d)),
                         list(d.items()))

        # line i has len 2*i+1
        self._create_testfile("".join("xy" * i + "\n" for i in range(10)))
        with open(TESTFN, "r", encoding="utf-8") as f:
            self.assertEqual(list(map(len, f)), list(range(1, 21, 2)))

    # Test zip()'s use of iterators.
    def test_builtin_zip(self):
        self.assertEqual(list(zip()), [])
        self.assertEqual(list(zip(*[])), [])
        self.assertEqual(list(zip(*[(1, 2), 'ab'])), [(1, 'a'), (2, 'b')])

        self.assertRaises(TypeError, zip, None)
        self.assertRaises(TypeError, zip, range(10), 42)
        self.assertRaises(TypeError, zip, range(10), zip)

        self.assertEqual(list(zip(IteratingSequenceClass(3))),
                         [(0,), (1,), (2,)])
        self.assertEqual(list(zip(SequenceClass(3))),
                         [(0,), (1,), (2,)])

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(list(d.items()), list(zip(d, d.values())))

        # Generate all ints starting at constructor arg.
        class IntsFrom:
            def __init__(self, start):
                self.i = start

            def __iter__(self):
                return self

            def __next__(self):
                i = self.i
                self.i = i+1
                return i

        # The file is the shortest input, so it bounds the zip.
        self._create_testfile("a\n" "bbb\n" "cc\n")
        with open(TESTFN, "r", encoding="utf-8") as f:
            self.assertEqual(list(zip(IntsFrom(0), f, IntsFrom(-100))),
                             [(0, "a\n", -100),
                              (1, "bbb\n", -99),
                              (2, "cc\n", -98)])

        self.assertEqual(list(zip(range(5))), [(i,) for i in range(5)])

        # Classes that lie about their lengths.
        class NoGuessLen5:
            def __getitem__(self, i):
                if i >= 5:
                    raise IndexError
                return i

        class Guess3Len5(NoGuessLen5):
            def __len__(self):
                return 3

        class Guess30Len5(NoGuessLen5):
            def __len__(self):
                return 30

        def lzip(*args):
            return list(zip(*args))

        self.assertEqual(len(Guess3Len5()), 3)
        self.assertEqual(len(Guess30Len5()), 30)
        self.assertEqual(lzip(NoGuessLen5()), lzip(range(5)))
        self.assertEqual(lzip(Guess3Len5()), lzip(range(5)))
        self.assertEqual(lzip(Guess30Len5()), lzip(range(5)))

        expected = [(i, i) for i in range(5)]
        for x in NoGuessLen5(), Guess3Len5(), Guess30Len5():
            for y in NoGuessLen5(), Guess3Len5(), Guess30Len5():
                self.assertEqual(lzip(x, y), expected)

    def test_unicode_join_endcase(self):

        # This class inserts a string into its argument's natural
        # iteration, in the 3rd position.
        class OhPhooey:
            def __init__(self, seq):
                self.it = iter(seq)
                self.i = 0

            def __iter__(self):
                return self

            def __next__(self):
                i = self.i
                self.i = i+1
                if i == 2:
                    return "fooled you!"
                return next(self.it)

        self._create_testfile("a\n" + "b\n" + "c\n")

        # Originally this guarded Python 2's string.join() switching to
        # unicode.join() part-way through.  The file iterator cannot be
        # restarted, so join() has to keep every item it pulls from the
        # one-shot iterator, including the injected one, in order.
        with open(TESTFN, "r", encoding="utf-8") as f:
            got = " - ".join(OhPhooey(f))
            self.assertEqual(got, "a\n - b\n - fooled you! - c\n")

    # Test iterators with 'x in y' and 'x not in y'.
    def test_in_and_not_in(self):
        for sc5 in IteratingSequenceClass(5), SequenceClass(5):
            for i in range(5):
                self.assertIn(i, sc5)
            for i in "abc", -1, 5, 42.42, (3, 4), [], {1: 1}, 3-12j, sc5:
                self.assertNotIn(i, sc5)

        self.assertIn(ALWAYS_EQ, IteratorProxyClass(iter([1])))
        self.assertIn(ALWAYS_EQ, SequenceProxyClass([1]))
        self.assertNotIn(ALWAYS_EQ, IteratorProxyClass(iter([NEVER_EQ])))
        self.assertNotIn(ALWAYS_EQ, SequenceProxyClass([NEVER_EQ]))
        self.assertIn(NEVER_EQ, IteratorProxyClass(iter([ALWAYS_EQ])))
        self.assertIn(NEVER_EQ, SequenceProxyClass([ALWAYS_EQ]))

        self.assertRaises(TypeError, lambda: 3 in 12)
        self.assertRaises(TypeError, lambda: 3 not in map)
        self.assertRaises(ZeroDivisionError, lambda: 3 in BadIterableClass())

        d = {"one": 1, "two": 2, "three": 3, 1j: 2j}
        for k in d:
            self.assertIn(k, d)
            self.assertNotIn(k, d.values())
        for v in d.values():
            self.assertIn(v, d.values())
            self.assertNotIn(v, d)
        for k, v in d.items():
            self.assertIn((k, v), d.items())
            self.assertNotIn((v, k), d.items())

        self._create_testfile("a\n" "b\n" "c\n")
        with open(TESTFN, "r", encoding="utf-8") as f:
            for chunk in "abc":
                f.seek(0, 0)
                self.assertNotIn(chunk, f)
                f.seek(0, 0)
                self.assertIn((chunk + "\n"), f)

    # Test iterators with operator.countOf (PySequence_Count).
    def test_countOf(self):
        from operator import countOf
        self.assertEqual(countOf([1,2,2,3,2,5], 2), 3)
        self.assertEqual(countOf((1,2,2,3,2,5), 2), 3)
        self.assertEqual(countOf("122325", "2"), 3)
        self.assertEqual(countOf("122325", "6"), 0)

        self.assertRaises(TypeError, countOf, 42, 1)
        self.assertRaises(TypeError, countOf, countOf, countOf)

        d = {"one": 3, "two": 3, "three": 3, 1j: 2j}
        for k in d:
            self.assertEqual(countOf(d, k), 1)
        self.assertEqual(countOf(d.values(), 3), 3)
        self.assertEqual(countOf(d.values(), 2j), 1)
        self.assertEqual(countOf(d.values(), 1j), 0)

        self._create_testfile("a\n" "b\n" "c\n" "b\n")
        with open(TESTFN, "r", encoding="utf-8") as f:
            for letter, count in ("a", 1), ("b", 2), ("c", 1), ("d", 0):
                f.seek(0, 0)
                self.assertEqual(countOf(f, letter + "\n"), count)

    # Test iterators with operator.indexOf (PySequence_Index).
    def test_indexOf(self):
        from operator import indexOf
        self.assertEqual(indexOf([1,2,2,3,2,5], 1), 0)
        self.assertEqual(indexOf((1,2,2,3,2,5), 2), 1)
        self.assertEqual(indexOf((1,2,2,3,2,5), 3), 3)
        self.assertEqual(indexOf((1,2,2,3,2,5), 5), 5)
        self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 0)
        self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 6)

        self.assertEqual(indexOf("122325", "2"), 1)
        self.assertEqual(indexOf("122325", "5"), 5)
        self.assertRaises(ValueError, indexOf, "122325", "6")

        self.assertRaises(TypeError, indexOf, 42, 1)
        self.assertRaises(TypeError, indexOf, indexOf, indexOf)
        self.assertRaises(ZeroDivisionError, indexOf, BadIterableClass(), 1)

        self._create_testfile("a\n" "b\n" "c\n" "d\n" "e\n")
        with open(TESTFN, "r", encoding="utf-8") as f:
            # Each search consumes the shared iterator, so indices are
            # relative to where the previous search stopped.
            fiter = iter(f)
            self.assertEqual(indexOf(fiter, "b\n"), 1)
            self.assertEqual(indexOf(fiter, "d\n"), 1)
            self.assertEqual(indexOf(fiter, "e\n"), 0)
            self.assertRaises(ValueError, indexOf, fiter, "a\n")

        iclass = IteratingSequenceClass(3)
        for i in range(3):
            self.assertEqual(indexOf(iclass, i), i)
        self.assertRaises(ValueError, indexOf, iclass, -1)

    # Test iterators with file.writelines().
    def test_writelines(self):
        self.addCleanup(self._remove_testfile)

        class Iterator:
            def __init__(self, start, finish):
                self.start = start
                self.finish = finish
                self.i = self.start

            def __next__(self):
                if self.i >= self.finish:
                    raise StopIteration
                result = str(self.i) + '\n'
                self.i += 1
                return result

            def __iter__(self):
                return self

        class Whatever:
            def __init__(self, start, finish):
                self.start = start
                self.finish = finish

            def __iter__(self):
                return Iterator(self.start, self.finish)

        with open(TESTFN, "w", encoding="utf-8") as f:
            self.assertRaises(TypeError, f.writelines, None)
            self.assertRaises(TypeError, f.writelines, 42)

            f.writelines(["1\n", "2\n"])
            f.writelines(("3\n", "4\n"))
            f.writelines({'5\n': None})
            f.writelines({})

            # Try a big chunk too.
            f.writelines(Whatever(6, 6+2000))

        with open(TESTFN, encoding="utf-8") as f:
            expected = [str(i) + "\n" for i in range(1, 2006)]
            self.assertEqual(list(f), expected)

    # Test iterators on RHS of unpacking assignments.
    def test_unpack_iter(self):
        a, b = 1, 2
        self.assertEqual((a, b), (1, 2))

        a, b, c = IteratingSequenceClass(3)
        self.assertEqual((a, b, c), (0, 1, 2))

        with self.assertRaises(ValueError):  # too many values
            a, b = IteratingSequenceClass(3)

        with self.assertRaises(ValueError):  # not enough values
            a, b, c = IteratingSequenceClass(2)

        with self.assertRaises(TypeError):  # not iterable
            a, b, c = len

        a, b, c = {1: 42, 2: 42, 3: 42}.values()
        self.assertEqual((a, b, c), (42, 42, 42))

        lines = ("a\n", "bb\n", "ccc\n")
        self._create_testfile("".join(lines))
        with open(TESTFN, "r", encoding="utf-8") as f:
            a, b, c = f
        self.assertEqual((a, b, c), lines)

        (a, b), (c,) = IteratingSequenceClass(2), {42: 24}
        self.assertEqual((a, b, c), (0, 1, 42))


    @cpython_only
    def test_ref_counting_behavior(self):
        # C.count tracks live instances: __new__ increments it and
        # __del__ decrements it.
        class C(object):
            count = 0
            def __new__(cls):
                cls.count += 1
                return object.__new__(cls)
            def __del__(self):
                cls = self.__class__
                assert cls.count > 0
                cls.count -= 1
        x = C()
        self.assertEqual(C.count, 1)
        del x
        self.assertEqual(C.count, 0)
        objs = [C(), C(), C()]
        self.assertEqual(C.count, 3)
        # A failed unpacking must not leak references to the items.
        try:
            a, b = iter(objs)
        except ValueError:
            pass
        del objs
        self.assertEqual(C.count, 0)


    # Make sure StopIteration is a "sink state".
    # This tests various things that weren't sink states in Python 2.2.1,
    # plus various things that always were fine.

    def test_sinkstate_list(self):
        # This used to fail
        a = list(range(5))
        b = iter(a)
        self.assertEqual(list(b), list(range(5)))
        a.extend(range(5, 10))
        self.assertEqual(list(b), [])

    def test_sinkstate_tuple(self):
        a = (0, 1, 2, 3, 4)
        b = iter(a)
        self.assertEqual(list(b), list(range(5)))
        self.assertEqual(list(b), [])

    def test_sinkstate_string(self):
        a = "abcde"
        b = iter(a)
        self.assertEqual(list(b), ['a', 'b', 'c', 'd', 'e'])
        self.assertEqual(list(b), [])

    def test_sinkstate_sequence(self):
        # This used to fail
        a = SequenceClass(5)
        b = iter(a)
        self.assertEqual(list(b), list(range(5)))
        a.n = 10
        self.assertEqual(list(b), [])

    def test_sinkstate_callable(self):
        # This used to fail
        def spam(state=[0]):
            i = state[0]
            state[0] = i+1
            if i == 10:
                raise AssertionError("shouldn't have gotten this far")
            return i
        b = iter(spam, 5)
        self.assertEqual(list(b), list(range(5)))
        self.assertEqual(list(b), [])

    def test_sinkstate_dict(self):
        # XXX For a more thorough test, see towards the end of:
        # http://mail.python.org/pipermail/python-dev/2002-July/026512.html
        a = {1:1, 2:2, 0:0, 4:4, 3:3}
        for src in iter(a), a.keys(), a.items(), a.values():
            # Iterate the object under test (not `a` itself) so that the
            # keys, items and values iterators are all exercised.
            b = iter(src)
            self.assertEqual(len(list(b)), 5)
            self.assertEqual(list(b), [])

    def test_sinkstate_yield(self):
        def gen():
            for i in range(5):
                yield i
        b = gen()
        self.assertEqual(list(b), list(range(5)))
        self.assertEqual(list(b), [])

    def test_sinkstate_range(self):
        a = range(5)
        b = iter(a)
        self.assertEqual(list(b), list(range(5)))
        self.assertEqual(list(b), [])

    def test_sinkstate_enumerate(self):
        a = range(5)
        e = enumerate(a)
        b = iter(e)
        self.assertEqual(list(b), list(zip(range(5), range(5))))
        self.assertEqual(list(b), [])

    def test_3720(self):
        # Avoid a crash, when an iterator deletes its next() method.
        class BadIterator(object):
            def __iter__(self):
                return self
            def __next__(self):
                del BadIterator.__next__
                return 1

        # A TypeError is acceptable; only a crash would be a failure.
        with contextlib.suppress(TypeError):
            for _ in BadIterator():
                pass

    def test_extending_list_with_iterator_does_not_segfault(self):
        # The code to extend a list with an iterator has a fair
        # amount of nontrivial logic in terms of guessing how
        # much memory to allocate in advance, "stealing" refs,
        # and then shrinking at the end.  This is a basic smoke
        # test for that scenario.
        def gen():
            for i in range(500):
                yield i
        lst = [0] * 500
        for i in range(240):
            lst.pop(0)
        lst.extend(gen())
        self.assertEqual(len(lst), 760)

    @cpython_only
    def test_iter_overflow(self):
        # Test for the issue 22939
        it = iter(UnlimitedSequenceClass())
        # Manually set `it_index` to PY_SSIZE_T_MAX-2 without a loop
        it.__setstate__(sys.maxsize - 2)
        self.assertEqual(next(it), sys.maxsize - 2)
        self.assertEqual(next(it), sys.maxsize - 1)
        with self.assertRaises(OverflowError):
            next(it)
        # Check that Overflow error is always raised
        with self.assertRaises(OverflowError):
            next(it)

    def test_iter_neg_setstate(self):
        # A negative state restarts the sequence iterator at index 0.
        it = iter(UnlimitedSequenceClass())
        it.__setstate__(-42)
        self.assertEqual(next(it), 0)
        self.assertEqual(next(it), 1)

    def test_free_after_iterating(self):
        check_free_after_iterating(self, iter, SequenceClass, (0,))

    def test_error_iter(self):
        for typ in (DefaultIterClass, NoIterClass):
            self.assertRaises(TypeError, iter, typ())
        self.assertRaises(ZeroDivisionError, iter, BadIterableClass())


if __name__ == "__main__":
    unittest.main()