1from collections import namedtuple
2import contextlib
3import itertools
4import os
5import pickle
6import sys
7from textwrap import dedent
8import threading
9import time
10import unittest
11
12from test import support
13from test.support import import_helper
14from test.support import script_helper
15
16
17interpreters = import_helper.import_module('_xxsubinterpreters')
18
19
20##################################
21# helpers
22
23def _captured_script(script):
24    r, w = os.pipe()
25    indented = script.replace('\n', '\n                ')
26    wrapped = dedent(f"""
27        import contextlib
28        with open({w}, 'w', encoding="utf-8") as spipe:
29            with contextlib.redirect_stdout(spipe):
30                {indented}
31        """)
32    return wrapped, open(r, encoding="utf-8")
33
34
def _run_output(interp, request, shared=None):
    """Run *request* in *interp* (with optional *shared* data) and return
    everything it wrote to stdout."""
    wrapped, reader = _captured_script(request)
    with reader:
        interpreters.run_string(interp, wrapped, shared)
        output = reader.read()
    return output
40
41
def _wait_for_interp_to_run(interp, timeout=None):
    """Poll until *interp* reports it is running.

    Raises RuntimeError if that has not happened within *timeout* seconds
    (support.SHORT_TIMEOUT when *timeout* is None).
    """
    # bpo-37224: Running this test file in multiple processes failed randomly:
    # the thread that runs the subinterpreter may not get the CPU before the
    # main thread checks on it, so keep polling until a deadline.
    limit = support.SHORT_TIMEOUT if timeout is None else timeout
    deadline = time.monotonic() + limit
    while not interpreters.is_running(interp):
        if time.monotonic() > deadline:
            raise RuntimeError('interp is not running')
        time.sleep(0.010)
54
55
@contextlib.contextmanager
def _running(interp):
    """Keep *interp* busy running a script for the duration of the block.

    A worker thread runs a script in *interp* that blocks reading from a
    pipe.  The context manager waits until the interpreter reports it is
    running, then yields.  On exit (normal or not) it writes to the pipe so
    the script finishes, and joins the thread.
    """
    r, w = os.pipe()
    def run():
        interpreters.run_string(interp, dedent(f"""
            # wait for "signal"
            with open({r}, encoding="utf-8") as rpipe:
                rpipe.read()
            """))

    t = threading.Thread(target=run)
    t.start()
    try:
        _wait_for_interp_to_run(interp)
        yield
    finally:
        # Always release the script, even if the wait timed out or the body
        # of the with-statement raised.  Otherwise the worker stays blocked
        # in rpipe.read() forever and the test run hangs.
        with open(w, 'w', encoding="utf-8") as spipe:
            spipe.write('done')
        t.join()
75
76
77#@contextmanager
78#def run_threaded(id, source, **shared):
79#    def run():
80#        run_interp(id, source, **shared)
81#    t = threading.Thread(target=run)
82#    t.start()
83#    yield
84#    t.join()
85
86
def run_interp(id, source, **shared):
    """Run *source* in interpreter *id*, passing keyword args as shared data."""
    return _run_interp(id, source, shared)
89
90
def _run_interp(id, source, shared, _mainns={}):
    """Run the dedented *source* in interpreter *id*.

    For a subinterpreter, *source* goes through run_string() with *shared*.
    For the main interpreter, it is exec()'d in ``_mainns``, a namespace
    that deliberately persists across calls (mutable default).  That path
    raises RuntimeError unless the caller is itself running in the main
    interpreter.
    """
    code = dedent(source)
    main = interpreters.get_main()
    if main == id:
        if interpreters.get_current() != main:
            raise RuntimeError
        # XXX Run a func?
        exec(code, _mainns)
        return
    interpreters.run_string(id, code, shared)
101
102
class Interpreter(namedtuple('Interpreter', 'name id')):
    """A (name, id) pair describing an interpreter used by a test scenario.

    The main interpreter is always named "main".  A name other than "main"
    without an id creates a new interpreter.
    """

    @classmethod
    def from_raw(cls, raw):
        """Coerce *raw* (an Interpreter or a name string) to an Interpreter."""
        if isinstance(raw, cls):
            return raw
        if isinstance(raw, str):
            return cls(raw)
        raise NotImplementedError

    def __new__(cls, name=None, id=None):
        main = interpreters.get_main()
        if id == main:
            # The main interpreter may only be (and defaults to) "main".
            if name and name != 'main':
                raise ValueError(
                    'name mismatch (expected "main", got "{}")'.format(name))
            name, id = 'main', main
        elif id is not None:
            if name == 'main':
                raise ValueError('name mismatch (unexpected "main")')
            name = name or 'interp'
            if not isinstance(id, interpreters.InterpreterID):
                id = interpreters.InterpreterID(id)
        elif name and name != 'main':
            # A named interpreter without an id: create it now.
            id = interpreters.create()
        else:
            name, id = 'main', main
        return super().__new__(cls, name, id)
137
138
139# XXX expect_channel_closed() is unnecessary once we improve exc propagation.
140
@contextlib.contextmanager
def expect_channel_closed():
    """Assert that the with-block raises ChannelClosedError.

    The ChannelClosedError is swallowed.  If the block completes without
    raising, AssertionError('channel not closed') is raised instead.  Other
    exceptions propagate unchanged.
    """
    try:
        yield
    except interpreters.ChannelClosedError:
        pass
    else:
        # An explicit raise (rather than ``assert False``) keeps this check
        # active when Python runs with -O, which strips assert statements.
        raise AssertionError('channel not closed')
149
150
class ChannelAction(namedtuple('ChannelAction', 'action end interp')):
    """One step of a channel scenario: what to do, to which end, and where.

    *action* is 'use', 'close' or 'force-close'.  *end* defaults to 'both'
    and *interp* defaults to 'main'.  Invalid combinations raise ValueError.
    """

    def __new__(cls, action, end=None, interp=None):
        return super().__new__(cls, action, end or 'both', interp or 'main')

    def __init__(self, *args, **kwargs):
        # Validate the fields filled in by __new__().
        action, end, interp = self
        if action == 'use':
            allowed_ends = ('same', 'opposite', 'send', 'recv')
        elif action in ('close', 'force-close'):
            allowed_ends = ('both', 'same', 'opposite', 'send', 'recv')
        else:
            raise ValueError(action)
        if end not in allowed_ends:
            raise ValueError(end)
        if interp not in ('main', 'same', 'other', 'extra'):
            raise ValueError(interp)

    def resolve_end(self, end):
        """Map this action's (possibly relative) end onto a concrete one,
        given the *end* the scenario is currently using."""
        if self.end == 'opposite':
            return 'recv' if end == 'send' else 'send'
        return end if self.end == 'same' else self.end

    def resolve_interp(self, interp, other, extra):
        """Pick which of *interp*, *other* or *extra* this action runs in.

        Raises RuntimeError if the selected interpreter is unavailable.
        """
        kind = self.interp
        if kind == 'same':
            return interp
        if kind in ('other', 'extra'):
            chosen = other if kind == 'other' else extra
            if chosen is None:
                raise RuntimeError
            return chosen
        if kind == 'main':
            if interp.name == 'main':
                return interp
            if other and other.name == 'main':
                return other
            raise RuntimeError
        # Per __init__(), there aren't any others.
200
201
class ChannelState(namedtuple('ChannelState', 'pending closed')):
    """Expected state of a channel: number of pending items and closed flag.

    All methods return a (possibly identical) state rather than mutating.
    """

    def __new__(cls, pending=0, *, closed=False):
        return super().__new__(cls, pending, closed)

    def incr(self):
        """Return the state with one more pending item."""
        return self._replace(pending=self.pending + 1)

    def decr(self):
        """Return the state with one fewer pending item."""
        return self._replace(pending=self.pending - 1)

    def close(self, *, force=True):
        """Return the closed state.

        With *force* the pending items are dropped.  An already-closed state
        is returned unchanged unless a forced close still has items to drop.
        """
        if self.closed and (not force or self.pending == 0):
            return self
        return type(self)(0 if force else self.pending, closed=True)
219
220
def run_action(cid, action, end, state, *, hideclosed=True):
    """Run *action* on *end* of channel *cid* and return the new state.

    Once *state* is closed, the only action expected to succeed is receiving
    on the recv end while items are still pending; anything else is expected
    to fail.  A ChannelClosedError from the action is swallowed and the
    closed state returned.  The exception is re-raised only when
    *hideclosed* is false and the failure was not expected.

    Raises AssertionError if an action that was expected to fail succeeds.
    """
    if state.closed:
        expectfail = not (action == 'use' and end == 'recv' and state.pending)
    else:
        expectfail = False

    try:
        result = _run_action(cid, action, end, state)
    except interpreters.ChannelClosedError:
        if not hideclosed and not expectfail:
            raise
        result = state.close()
    else:
        if expectfail:
            # A real exception: ``raise ...`` would only produce an unrelated
            # TypeError ("exceptions must derive from BaseException").
            raise AssertionError(
                f'{action!r} on the {end!r} end of channel {cid} '
                f'unexpectedly succeeded (state: {state})')
    return result
240
241
242def _run_action(cid, action, end, state):
243    if action == 'use':
244        if end == 'send':
245            interpreters.channel_send(cid, b'spam')
246            return state.incr()
247        elif end == 'recv':
248            if not state.pending:
249                try:
250                    interpreters.channel_recv(cid)
251                except interpreters.ChannelEmptyError:
252                    return state
253                else:
254                    raise Exception('expected ChannelEmptyError')
255            else:
256                interpreters.channel_recv(cid)
257                return state.decr()
258        else:
259            raise ValueError(end)
260    elif action == 'close':
261        kwargs = {}
262        if end in ('recv', 'send'):
263            kwargs[end] = True
264        interpreters.channel_close(cid, **kwargs)
265        return state.close()
266    elif action == 'force-close':
267        kwargs = {
268            'force': True,
269            }
270        if end in ('recv', 'send'):
271            kwargs[end] = True
272        interpreters.channel_close(cid, **kwargs)
273        return state.close(force=True)
274    else:
275        raise ValueError(action)
276
277
def clean_up_interpreters():
    """Destroy every interpreter except the main one.

    RuntimeError from destroy() is ignored, so one interpreter that cannot
    be destroyed (already gone, or presumably still running) does not stop
    the rest from being cleaned up.
    """
    # Ask for the main interpreter's ID instead of assuming it is 0.
    main = interpreters.get_main()
    for id in interpreters.list_all():
        if id == main:
            continue
        try:
            interpreters.destroy(id)
        except RuntimeError:
            pass  # already destroyed, or not destroyable right now
286
287
def clean_up_channels():
    """Destroy every channel that still exists, ignoring ones already gone."""
    for channel_id in interpreters.channel_list_all():
        with contextlib.suppress(interpreters.ChannelNotFoundError):
            interpreters.channel_destroy(channel_id)
294
295
class TestBase(unittest.TestCase):
    """Base class for tests that create interpreters and/or channels.

    tearDown() destroys every non-main interpreter and every remaining
    channel, so leftovers do not leak into later tests.
    """

    def tearDown(self):
        try:
            clean_up_interpreters()
        finally:
            # Clean up channels even if interpreter cleanup failed.
            clean_up_channels()
        super().tearDown()
301
302
303##################################
304# misc. tests
305
class IsShareableTests(unittest.TestCase):
    """Check which objects is_shareable() accepts."""

    def test_default_shareables(self):
        # singletons, then builtin objects
        for obj in (None, b'spam', 'spam', 10, -10):
            with self.subTest(obj):
                self.assertTrue(interpreters.is_shareable(obj))

    def test_not_shareable(self):
        class Cheese:
            def __init__(self, name):
                self.name = name

            def __str__(self):
                return self.name

        class SubBytes(bytes):
            """A subclass of a shareable type."""

        singletons = [True, False, NotImplemented, ...]
        builtin_objects = [type, object, object(), Exception(), 100.0]
        user_defined = [Cheese, Cheese('Wensleydale'), SubBytes(b'spam')]
        for obj in singletons + builtin_objects + user_defined:
            with self.subTest(repr(obj)):
                self.assertFalse(interpreters.is_shareable(obj))
354
355
class ShareableTypeTests(unittest.TestCase):
    """Round-trip values of shareable types through a channel."""

    def setUp(self):
        super().setUp()
        self.cid = interpreters.channel_create()

    def tearDown(self):
        interpreters.channel_destroy(self.cid)
        super().tearDown()

    def _assert_values(self, values):
        """Send each value through the channel and check that it comes back
        equal and of exactly the same type."""
        send = interpreters.channel_send
        recv = interpreters.channel_recv
        for value in values:
            with self.subTest(value):
                send(self.cid, value)
                received = recv(self.cid)

                self.assertEqual(received, value)
                self.assertIs(type(received), type(value))
                # XXX Check the following in the channel tests?
                #self.assertIsNot(received, value)

    def test_singletons(self):
        for value in (None,):
            with self.subTest(value):
                interpreters.channel_send(self.cid, value)
                received = interpreters.channel_recv(self.cid)

                # XXX What about between interpreters?
                self.assertIs(received, value)

    def test_types(self):
        self._assert_values([b'spam', 9999, self.cid])

    def test_bytes(self):
        self._assert_values([n.to_bytes(2, 'little', signed=True)
                             for n in range(-1, 258)])

    def test_strs(self):
        self._assert_values(['hello world', '你好世界', ''])

    def test_int(self):
        self._assert_values([*range(-1, 258), sys.maxsize, -sys.maxsize - 1])

    def test_non_shareable_int(self):
        # Ints outside the C ssize_t range cannot be sent.
        for value in (sys.maxsize + 1, -sys.maxsize - 2, 2**1000):
            with self.subTest(value):
                with self.assertRaises(OverflowError):
                    interpreters.channel_send(self.cid, value)
414
415
416##################################
417# interpreter tests
418
class ListAllTests(TestBase):
    """list_all() reports the main interpreter plus live subinterpreters."""

    def test_initial(self):
        main = interpreters.get_main()
        self.assertEqual(interpreters.list_all(), [main])

    def test_after_creating(self):
        main = interpreters.get_main()
        first, second = interpreters.create(), interpreters.create()
        self.assertEqual(interpreters.list_all(), [main, first, second])

    def test_after_destroying(self):
        main = interpreters.get_main()
        first, second = interpreters.create(), interpreters.create()
        interpreters.destroy(first)
        self.assertEqual(interpreters.list_all(), [main, second])
440
441
class GetCurrentTests(TestBase):
    """get_current() identifies the calling interpreter."""

    def test_main(self):
        main = interpreters.get_main()
        current = interpreters.get_current()
        self.assertEqual(current, main)
        self.assertIsInstance(current, interpreters.InterpreterID)

    def test_subinterpreter(self):
        main = interpreters.get_main()
        interp = interpreters.create()
        script = dedent("""
            import _xxsubinterpreters as _interpreters
            cur = _interpreters.get_current()
            print(cur)
            assert isinstance(cur, _interpreters.InterpreterID)
            """)
        current = int(_run_output(interp, script).strip())
        _, expected = interpreters.list_all()
        self.assertEqual(current, expected)
        self.assertNotEqual(current, main)
463
464
class GetMainTests(TestBase):
    """get_main() returns the same ID from the main and a subinterpreter."""

    def test_from_main(self):
        [expected] = interpreters.list_all()
        main_id = interpreters.get_main()
        self.assertEqual(main_id, expected)
        self.assertIsInstance(main_id, interpreters.InterpreterID)

    def test_from_subinterpreter(self):
        [expected] = interpreters.list_all()
        interp = interpreters.create()
        script = dedent("""
            import _xxsubinterpreters as _interpreters
            main = _interpreters.get_main()
            print(main)
            assert isinstance(main, _interpreters.InterpreterID)
            """)
        main_id = int(_run_output(interp, script).strip())
        self.assertEqual(main_id, expected)
484
485
class IsRunningTests(TestBase):
    """is_running() for the main interpreter, subinterpreters and bad IDs."""

    def test_main(self):
        self.assertTrue(interpreters.is_running(interpreters.get_main()))

    @unittest.skip('Fails on FreeBSD')
    def test_subinterpreter(self):
        interp = interpreters.create()
        self.assertFalse(interpreters.is_running(interp))

        with _running(interp):
            self.assertTrue(interpreters.is_running(interp))
        self.assertFalse(interpreters.is_running(interp))

    def test_from_subinterpreter(self):
        interp = interpreters.create()
        script = dedent(f"""
            import _xxsubinterpreters as _interpreters
            if _interpreters.is_running({interp}):
                print(True)
            else:
                print(False)
            """)
        out = _run_output(interp, script)
        self.assertEqual(out.strip(), 'True')

    def test_already_destroyed(self):
        interp = interpreters.create()
        interpreters.destroy(interp)
        self.assertRaises(RuntimeError, interpreters.is_running, interp)

    def test_does_not_exist(self):
        self.assertRaises(RuntimeError, interpreters.is_running, 1_000_000)

    def test_bad_id(self):
        self.assertRaises(ValueError, interpreters.is_running, -1)
525
526
class InterpreterIDTests(TestBase):
    """Construction, conversion and comparison of InterpreterID objects."""

    def test_with_int(self):
        id = interpreters.InterpreterID(10, force=True)

        self.assertEqual(int(id), 10)

    def test_coerce_id(self):
        class Int(str):
            def __index__(self):
                return 10

        id = interpreters.InterpreterID(Int(), force=True)
        self.assertEqual(int(id), 10)

    def test_bad_id(self):
        self.assertRaises(TypeError, interpreters.InterpreterID, object())
        self.assertRaises(TypeError, interpreters.InterpreterID, 10.0)
        self.assertRaises(TypeError, interpreters.InterpreterID, '10')
        self.assertRaises(TypeError, interpreters.InterpreterID, b'10')
        self.assertRaises(ValueError, interpreters.InterpreterID, -1)
        self.assertRaises(OverflowError, interpreters.InterpreterID, 2**64)

    def test_does_not_exist(self):
        # Derive the ID from the newest *interpreter* (not from a channel,
        # whose IDs are unrelated), so that int(id) + 1 does not exist yet.
        id = interpreters.create()
        with self.assertRaises(RuntimeError):
            interpreters.InterpreterID(int(id) + 1)  # unforced

    def test_str(self):
        id = interpreters.InterpreterID(10, force=True)
        self.assertEqual(str(id), '10')

    def test_repr(self):
        id = interpreters.InterpreterID(10, force=True)
        self.assertEqual(repr(id), 'InterpreterID(10)')

    def test_equality(self):
        id1 = interpreters.create()
        id2 = interpreters.InterpreterID(int(id1))
        id3 = interpreters.create()

        self.assertTrue(id1 == id1)
        self.assertTrue(id1 == id2)
        self.assertTrue(id1 == int(id1))
        self.assertTrue(int(id1) == id1)
        self.assertTrue(id1 == float(int(id1)))
        self.assertTrue(float(int(id1)) == id1)
        self.assertFalse(id1 == float(int(id1)) + 0.1)
        self.assertFalse(id1 == str(int(id1)))
        self.assertFalse(id1 == 2**1000)
        self.assertFalse(id1 == float('inf'))
        self.assertFalse(id1 == 'spam')
        self.assertFalse(id1 == id3)

        self.assertFalse(id1 != id1)
        self.assertFalse(id1 != id2)
        self.assertTrue(id1 != id3)
584
585
class CreateTests(TestBase):
    """create() from the main interpreter, threads and subinterpreters."""

    def test_in_main(self):
        new_id = interpreters.create()
        self.assertIsInstance(new_id, interpreters.InterpreterID)

        self.assertIn(new_id, interpreters.list_all())

    @unittest.skip('enable this test when working on pystate.c')
    def test_unique_id(self):
        seen = set()
        for _ in range(100):
            new_id = interpreters.create()
            interpreters.destroy(new_id)
            seen.add(new_id)

        self.assertEqual(len(seen), 100)

    def test_in_thread(self):
        lock = threading.Lock()
        created = None

        def create_in_thread():
            nonlocal created
            created = interpreters.create()
            # Wait here until the main thread releases the lock.
            with lock:
                pass

        worker = threading.Thread(target=create_in_thread)
        with lock:
            worker.start()
        worker.join()
        self.assertIn(created, interpreters.list_all())

    def test_in_subinterpreter(self):
        main, = interpreters.list_all()
        id1 = interpreters.create()
        script = dedent("""
            import _xxsubinterpreters as _interpreters
            id = _interpreters.create()
            print(id)
            assert isinstance(id, _interpreters.InterpreterID)
            """)
        id2 = int(_run_output(id1, script).strip())

        self.assertEqual(set(interpreters.list_all()), {main, id1, id2})

    def test_in_threaded_subinterpreter(self):
        main, = interpreters.list_all()
        id1 = interpreters.create()
        id2 = None

        def create_from_thread():
            nonlocal id2
            out = _run_output(id1, dedent("""
                import _xxsubinterpreters as _interpreters
                id = _interpreters.create()
                print(id)
                """))
            id2 = int(out.strip())

        worker = threading.Thread(target=create_from_thread)
        worker.start()
        worker.join()

        self.assertEqual(set(interpreters.list_all()), {main, id1, id2})

    def test_after_destroy_all(self):
        before = set(interpreters.list_all())
        # Create 3 subinterpreters, then destroy them all.
        created = [interpreters.create() for _ in range(3)]
        for interp_id in created:
            interpreters.destroy(interp_id)
        # Finally, create another.
        last = interpreters.create()
        self.assertEqual(set(interpreters.list_all()), before | {last})

    def test_after_destroy_some(self):
        before = set(interpreters.list_all())
        # Create 3 subinterpreters and destroy the first and last.
        first = interpreters.create()
        middle = interpreters.create()
        last = interpreters.create()
        interpreters.destroy(first)
        interpreters.destroy(last)
        # Finally, create another.
        newest = interpreters.create()
        self.assertEqual(set(interpreters.list_all()),
                         before | {newest, middle})
677
678
class DestroyTests(TestBase):
    """destroy() for various interpreters and calling contexts."""

    def test_one(self):
        id1 = interpreters.create()
        id2 = interpreters.create()
        id3 = interpreters.create()
        self.assertIn(id2, interpreters.list_all())
        interpreters.destroy(id2)
        self.assertNotIn(id2, interpreters.list_all())
        self.assertIn(id1, interpreters.list_all())
        self.assertIn(id3, interpreters.list_all())

    def test_all(self):
        before = set(interpreters.list_all())
        ids = set()
        for _ in range(3):
            id = interpreters.create()
            ids.add(id)
        self.assertEqual(set(interpreters.list_all()), before | ids)
        for id in ids:
            interpreters.destroy(id)
        self.assertEqual(set(interpreters.list_all()), before)

    def test_main(self):
        main, = interpreters.list_all()
        with self.assertRaises(RuntimeError):
            interpreters.destroy(main)

        # A failed assertion inside a worker thread does not fail the test,
        # so record the outcome in the thread and check it here.
        errors = []
        def f():
            try:
                interpreters.destroy(main)
            except RuntimeError as exc:
                errors.append(exc)

        t = threading.Thread(target=f)
        t.start()
        t.join()
        self.assertEqual(len(errors), 1,
                         "destroying main from another thread should fail")

    def test_already_destroyed(self):
        id = interpreters.create()
        interpreters.destroy(id)
        with self.assertRaises(RuntimeError):
            interpreters.destroy(id)

    def test_does_not_exist(self):
        with self.assertRaises(RuntimeError):
            interpreters.destroy(1_000_000)

    def test_bad_id(self):
        with self.assertRaises(ValueError):
            interpreters.destroy(-1)

    def test_from_current(self):
        main, = interpreters.list_all()
        id = interpreters.create()
        script = dedent(f"""
            import _xxsubinterpreters as _interpreters
            try:
                _interpreters.destroy({id})
            except RuntimeError:
                pass
            """)

        interpreters.run_string(id, script)
        self.assertEqual(set(interpreters.list_all()), {main, id})

    def test_from_sibling(self):
        main, = interpreters.list_all()
        id1 = interpreters.create()
        id2 = interpreters.create()
        script = dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.destroy({id2})
            """)
        interpreters.run_string(id1, script)

        self.assertEqual(set(interpreters.list_all()), {main, id1})

    def test_from_other_thread(self):
        id = interpreters.create()
        def f():
            interpreters.destroy(id)

        t = threading.Thread(target=f)
        t.start()
        t.join()
        # The destruction done in the other thread must be visible here.
        self.assertNotIn(id, interpreters.list_all())

    def test_still_running(self):
        main, = interpreters.list_all()
        interp = interpreters.create()
        with _running(interp):
            self.assertTrue(interpreters.is_running(interp),
                            msg=f"Interp {interp} should be running before destruction.")

            with self.assertRaises(RuntimeError,
                                   msg=f"Should not be able to destroy interp {interp} while it's still running."):
                interpreters.destroy(interp)
            self.assertTrue(interpreters.is_running(interp))
775
776
777class RunStringTests(TestBase):
778
779    def setUp(self):
780        super().setUp()
781        self.id = interpreters.create()
782
783    def test_success(self):
784        script, file = _captured_script('print("it worked!", end="")')
785        with file:
786            interpreters.run_string(self.id, script)
787            out = file.read()
788
789        self.assertEqual(out, 'it worked!')
790
791    def test_in_thread(self):
792        script, file = _captured_script('print("it worked!", end="")')
793        with file:
794            def f():
795                interpreters.run_string(self.id, script)
796
797            t = threading.Thread(target=f)
798            t.start()
799            t.join()
800            out = file.read()
801
802        self.assertEqual(out, 'it worked!')
803
804    def test_create_thread(self):
805        subinterp = interpreters.create(isolated=False)
806        script, file = _captured_script("""
807            import threading
808            def f():
809                print('it worked!', end='')
810
811            t = threading.Thread(target=f)
812            t.start()
813            t.join()
814            """)
815        with file:
816            interpreters.run_string(subinterp, script)
817            out = file.read()
818
819        self.assertEqual(out, 'it worked!')
820
821    @support.requires_fork()
822    def test_fork(self):
823        import tempfile
824        with tempfile.NamedTemporaryFile('w+', encoding="utf-8") as file:
825            file.write('')
826            file.flush()
827
828            expected = 'spam spam spam spam spam'
829            script = dedent(f"""
830                import os
831                try:
832                    os.fork()
833                except RuntimeError:
834                    with open('{file.name}', 'w', encoding='utf-8') as out:
835                        out.write('{expected}')
836                """)
837            interpreters.run_string(self.id, script)
838
839            file.seek(0)
840            content = file.read()
841            self.assertEqual(content, expected)
842
843    def test_already_running(self):
844        with _running(self.id):
845            with self.assertRaises(RuntimeError):
846                interpreters.run_string(self.id, 'print("spam")')
847
848    def test_does_not_exist(self):
849        id = 0
850        while id in interpreters.list_all():
851            id += 1
852        with self.assertRaises(RuntimeError):
853            interpreters.run_string(id, 'print("spam")')
854
855    def test_error_id(self):
856        with self.assertRaises(ValueError):
857            interpreters.run_string(-1, 'print("spam")')
858
859    def test_bad_id(self):
860        with self.assertRaises(TypeError):
861            interpreters.run_string('spam', 'print("spam")')
862
863    def test_bad_script(self):
864        with self.assertRaises(TypeError):
865            interpreters.run_string(self.id, 10)
866
867    def test_bytes_for_script(self):
868        with self.assertRaises(TypeError):
869            interpreters.run_string(self.id, b'print("spam")')
870
871    @contextlib.contextmanager
872    def assert_run_failed(self, exctype, msg=None):
873        with self.assertRaises(interpreters.RunFailedError) as caught:
874            yield
875        if msg is None:
876            self.assertEqual(str(caught.exception).split(':')[0],
877                             str(exctype))
878        else:
879            self.assertEqual(str(caught.exception),
880                             "{}: {}".format(exctype, msg))
881
882    def test_invalid_syntax(self):
883        with self.assert_run_failed(SyntaxError):
884            # missing close paren
885            interpreters.run_string(self.id, 'print("spam"')
886
887    def test_failure(self):
888        with self.assert_run_failed(Exception, 'spam'):
889            interpreters.run_string(self.id, 'raise Exception("spam")')
890
891    def test_SystemExit(self):
892        with self.assert_run_failed(SystemExit, '42'):
893            interpreters.run_string(self.id, 'raise SystemExit(42)')
894
895    def test_sys_exit(self):
896        with self.assert_run_failed(SystemExit):
897            interpreters.run_string(self.id, dedent("""
898                import sys
899                sys.exit()
900                """))
901
902        with self.assert_run_failed(SystemExit, '42'):
903            interpreters.run_string(self.id, dedent("""
904                import sys
905                sys.exit(42)
906                """))
907
908    def test_with_shared(self):
909        r, w = os.pipe()
910
911        shared = {
912                'spam': b'ham',
913                'eggs': b'-1',
914                'cheddar': None,
915                }
916        script = dedent(f"""
917            eggs = int(eggs)
918            spam = 42
919            result = spam + eggs
920
921            ns = dict(vars())
922            del ns['__builtins__']
923            import pickle
924            with open({w}, 'wb') as chan:
925                pickle.dump(ns, chan)
926            """)
927        interpreters.run_string(self.id, script, shared)
928        with open(r, 'rb') as chan:
929            ns = pickle.load(chan)
930
931        self.assertEqual(ns['spam'], 42)
932        self.assertEqual(ns['eggs'], -1)
933        self.assertEqual(ns['result'], 41)
934        self.assertIsNone(ns['cheddar'])
935
936    def test_shared_overwrites(self):
937        interpreters.run_string(self.id, dedent("""
938            spam = 'eggs'
939            ns1 = dict(vars())
940            del ns1['__builtins__']
941            """))
942
943        shared = {'spam': b'ham'}
944        script = dedent(f"""
945            ns2 = dict(vars())
946            del ns2['__builtins__']
947        """)
948        interpreters.run_string(self.id, script, shared)
949
950        r, w = os.pipe()
951        script = dedent(f"""
952            ns = dict(vars())
953            del ns['__builtins__']
954            import pickle
955            with open({w}, 'wb') as chan:
956                pickle.dump(ns, chan)
957            """)
958        interpreters.run_string(self.id, script)
959        with open(r, 'rb') as chan:
960            ns = pickle.load(chan)
961
962        self.assertEqual(ns['ns1']['spam'], 'eggs')
963        self.assertEqual(ns['ns2']['spam'], b'ham')
964        self.assertEqual(ns['spam'], b'ham')
965
966    def test_shared_overwrites_default_vars(self):
967        r, w = os.pipe()
968
969        shared = {'__name__': b'not __main__'}
970        script = dedent(f"""
971            spam = 42
972
973            ns = dict(vars())
974            del ns['__builtins__']
975            import pickle
976            with open({w}, 'wb') as chan:
977                pickle.dump(ns, chan)
978            """)
979        interpreters.run_string(self.id, script, shared)
980        with open(r, 'rb') as chan:
981            ns = pickle.load(chan)
982
983        self.assertEqual(ns['__name__'], b'not __main__')
984
985    def test_main_reused(self):
986        r, w = os.pipe()
987        interpreters.run_string(self.id, dedent(f"""
988            spam = True
989
990            ns = dict(vars())
991            del ns['__builtins__']
992            import pickle
993            with open({w}, 'wb') as chan:
994                pickle.dump(ns, chan)
995            del ns, pickle, chan
996            """))
997        with open(r, 'rb') as chan:
998            ns1 = pickle.load(chan)
999
1000        r, w = os.pipe()
1001        interpreters.run_string(self.id, dedent(f"""
1002            eggs = False
1003
1004            ns = dict(vars())
1005            del ns['__builtins__']
1006            import pickle
1007            with open({w}, 'wb') as chan:
1008                pickle.dump(ns, chan)
1009            """))
1010        with open(r, 'rb') as chan:
1011            ns2 = pickle.load(chan)
1012
1013        self.assertIn('spam', ns1)
1014        self.assertNotIn('eggs', ns1)
1015        self.assertIn('eggs', ns2)
1016        self.assertIn('spam', ns2)
1017
1018    def test_execution_namespace_is_main(self):
1019        r, w = os.pipe()
1020
1021        script = dedent(f"""
1022            spam = 42
1023
1024            ns = dict(vars())
1025            ns['__builtins__'] = str(ns['__builtins__'])
1026            import pickle
1027            with open({w}, 'wb') as chan:
1028                pickle.dump(ns, chan)
1029            """)
1030        interpreters.run_string(self.id, script)
1031        with open(r, 'rb') as chan:
1032            ns = pickle.load(chan)
1033
1034        ns.pop('__builtins__')
1035        ns.pop('__loader__')
1036        self.assertEqual(ns, {
1037            '__name__': '__main__',
1038            '__annotations__': {},
1039            '__doc__': None,
1040            '__package__': None,
1041            '__spec__': None,
1042            'spam': 42,
1043            })
1044
1045    # XXX Fix this test!
1046    @unittest.skip('blocking forever')
1047    def test_still_running_at_exit(self):
1048        script = dedent(f"""
1049        from textwrap import dedent
1050        import threading
1051        import _xxsubinterpreters as _interpreters
1052        id = _interpreters.create()
1053        def f():
1054            _interpreters.run_string(id, dedent('''
1055                import time
1056                # Give plenty of time for the main interpreter to finish.
1057                time.sleep(1_000_000)
1058                '''))
1059
1060        t = threading.Thread(target=f)
1061        t.start()
1062        """)
1063        with support.temp_dir() as dirname:
1064            filename = script_helper.make_script(dirname, 'interp', script)
1065            with script_helper.spawn_python(filename) as proc:
1066                retcode = proc.wait()
1067
1068        self.assertEqual(retcode, 0)
1069
1070
1071##################################
1072# channel tests
1073
class ChannelIDTests(TestBase):
    """Tests for the ChannelID objects produced by ``_channel_id()``."""

    def test_default_kwargs(self):
        chan = interpreters._channel_id(10, force=True)

        self.assertEqual(int(chan), 10)
        self.assertEqual(chan.end, 'both')

    def test_with_kwargs(self):
        # (keyword arguments, expected value of ``end``)
        cases = [
            (dict(send=True), 'send'),
            (dict(send=True, recv=False), 'send'),
            (dict(recv=True), 'recv'),
            (dict(recv=True, send=False), 'recv'),
            (dict(send=True, recv=True), 'both'),
        ]
        for kwargs, expected_end in cases:
            chan = interpreters._channel_id(10, force=True, **kwargs)
            self.assertEqual(chan.end, expected_end)

    def test_coerce_id(self):
        # A str subclass that still supports __index__ must be accepted.
        class IndexableStr(str):
            def __index__(self):
                return 10

        chan = interpreters._channel_id(IndexableStr(), force=True)
        self.assertEqual(int(chan), 10)

    def test_bad_id(self):
        # (expected exception, invalid id)
        bad_ids = [
            (TypeError, object()),
            (TypeError, 10.0),
            (TypeError, '10'),
            (TypeError, b'10'),
            (ValueError, -1),
            (OverflowError, 2**64),
        ]
        for exc_type, bad in bad_ids:
            self.assertRaises(exc_type, interpreters._channel_id, bad)

    def test_bad_kwargs(self):
        # At least one end must be selected.
        self.assertRaises(ValueError, interpreters._channel_id,
                          10, send=False, recv=False)

    def test_does_not_exist(self):
        existing = interpreters.channel_create()
        missing = int(existing) + 1
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters._channel_id(missing)  # unforced

    def test_str(self):
        chan = interpreters._channel_id(10, force=True)
        self.assertEqual(str(chan), '10')

    def test_repr(self):
        # (keyword arguments, expected repr)
        cases = [
            ({}, 'ChannelID(10)'),
            ({'send': True}, 'ChannelID(10, send=True)'),
            ({'recv': True}, 'ChannelID(10, recv=True)'),
            ({'send': True, 'recv': True}, 'ChannelID(10)'),
        ]
        for kwargs, expected in cases:
            chan = interpreters._channel_id(10, force=True, **kwargs)
            self.assertEqual(repr(chan), expected)

    def test_equality(self):
        first = interpreters.channel_create()
        same = interpreters._channel_id(int(first))
        other = interpreters.channel_create()
        num = int(first)

        # Equality in both operand orders, against ids, ints and floats.
        self.assertTrue(first == first)
        self.assertTrue(first == same)
        self.assertTrue(first == num)
        self.assertTrue(num == first)
        self.assertTrue(first == float(num))
        self.assertTrue(float(num) == first)
        self.assertFalse(first == float(num) + 0.1)
        self.assertFalse(first == str(num))
        self.assertFalse(first == 2**1000)
        self.assertFalse(first == float('inf'))
        self.assertFalse(first == 'spam')
        self.assertFalse(first == other)

        self.assertFalse(first != first)
        self.assertFalse(first != same)
        self.assertTrue(first != other)
1161
1162
class ChannelTests(TestBase):
    """Tests for the module-level channel API of _xxsubinterpreters.

    Covers:
    - channel creation and ids;
    - listing the interpreters associated with each end of a channel;
    - send/recv within and across interpreters and threads;
    - the various forms of channel_close().
    """

    def test_create_cid(self):
        cid = interpreters.channel_create()
        self.assertIsInstance(cid, interpreters.ChannelID)

    def test_sequential_ids(self):
        before = interpreters.channel_list_all()
        id1 = interpreters.channel_create()
        id2 = interpreters.channel_create()
        id3 = interpreters.channel_create()
        after = interpreters.channel_list_all()

        self.assertEqual(id2, int(id1) + 1)
        self.assertEqual(id3, int(id2) + 1)
        self.assertEqual(set(after) - set(before), {id1, id2, id3})

    def test_ids_global(self):
        # The test expects channel ids to come from one sequence shared by
        # all interpreters: a channel created in a second subinterpreter
        # gets the id right after one created in the first.
        id1 = interpreters.create()
        out = _run_output(id1, dedent("""
            import _xxsubinterpreters as _interpreters
            cid = _interpreters.channel_create()
            print(cid)
            """))
        cid1 = int(out.strip())

        id2 = interpreters.create()
        out = _run_output(id2, dedent("""
            import _xxsubinterpreters as _interpreters
            cid = _interpreters.channel_create()
            print(cid)
            """))
        cid2 = int(out.strip())

        # Both ids were parsed from printed output and are already ints.
        self.assertEqual(cid2, cid1 + 1)

    def test_channel_list_interpreters_none(self):
        """Test listing interpreters for a channel with no associations."""
        # Test for channel with no associated interpreters.
        cid = interpreters.channel_create()
        send_interps = interpreters.channel_list_interpreters(cid, send=True)
        recv_interps = interpreters.channel_list_interpreters(cid, send=False)
        self.assertEqual(send_interps, [])
        self.assertEqual(recv_interps, [])

    def test_channel_list_interpreters_basic(self):
        """Test basic listing channel interpreters."""
        interp0 = interpreters.get_main()
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, "send")
        # Test for a channel that has one end associated to an interpreter.
        send_interps = interpreters.channel_list_interpreters(cid, send=True)
        recv_interps = interpreters.channel_list_interpreters(cid, send=False)
        self.assertEqual(send_interps, [interp0])
        self.assertEqual(recv_interps, [])

        interp1 = interpreters.create()
        _run_output(interp1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_recv({cid})
            """))
        # Test for channel that has both ends associated to an interpreter.
        send_interps = interpreters.channel_list_interpreters(cid, send=True)
        recv_interps = interpreters.channel_list_interpreters(cid, send=False)
        self.assertEqual(send_interps, [interp0])
        self.assertEqual(recv_interps, [interp1])

    def test_channel_list_interpreters_multiple(self):
        """Test listing interpreters for a channel with many associations."""
        interp0 = interpreters.get_main()
        interp1 = interpreters.create()
        interp2 = interpreters.create()
        interp3 = interpreters.create()
        cid = interpreters.channel_create()

        interpreters.channel_send(cid, "send")
        _run_output(interp1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({cid}, "send")
            """))
        _run_output(interp2, dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_recv({cid})
            """))
        _run_output(interp3, dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_recv({cid})
            """))
        send_interps = interpreters.channel_list_interpreters(cid, send=True)
        recv_interps = interpreters.channel_list_interpreters(cid, send=False)
        self.assertEqual(set(send_interps), {interp0, interp1})
        self.assertEqual(set(recv_interps), {interp2, interp3})

    def test_channel_list_interpreters_destroyed(self):
        """Test listing channel interpreters with a destroyed interpreter."""
        interp0 = interpreters.get_main()
        interp1 = interpreters.create()
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, "send")
        _run_output(interp1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_recv({cid})
            """))
        # Should be one interpreter associated with each end.
        send_interps = interpreters.channel_list_interpreters(cid, send=True)
        recv_interps = interpreters.channel_list_interpreters(cid, send=False)
        self.assertEqual(send_interps, [interp0])
        self.assertEqual(recv_interps, [interp1])

        interpreters.destroy(interp1)
        # Destroyed interpreter should not be listed.
        send_interps = interpreters.channel_list_interpreters(cid, send=True)
        recv_interps = interpreters.channel_list_interpreters(cid, send=False)
        self.assertEqual(send_interps, [interp0])
        self.assertEqual(recv_interps, [])

    def test_channel_list_interpreters_released(self):
        """Test listing channel interpreters with a released channel."""
        # Set up one channel with main interpreter on the send end and two
        # subinterpreters on the receive end.
        interp0 = interpreters.get_main()
        interp1 = interpreters.create()
        interp2 = interpreters.create()
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, "data")
        _run_output(interp1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_recv({cid})
            """))
        interpreters.channel_send(cid, "data")
        _run_output(interp2, dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_recv({cid})
            """))
        # Check the setup: identify the associated interpreters exactly,
        # not just count them.
        send_interps = interpreters.channel_list_interpreters(cid, send=True)
        recv_interps = interpreters.channel_list_interpreters(cid, send=False)
        self.assertEqual(send_interps, [interp0])
        self.assertEqual(set(recv_interps), {interp1, interp2})

        # Release the main interpreter from the send end.
        interpreters.channel_release(cid, send=True)
        # Send end should have no associated interpreters.
        send_interps = interpreters.channel_list_interpreters(cid, send=True)
        recv_interps = interpreters.channel_list_interpreters(cid, send=False)
        self.assertEqual(len(send_interps), 0)
        self.assertEqual(len(recv_interps), 2)

        # Release one of the subinterpreters from the receive end.
        _run_output(interp2, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_release({cid})
            """))
        # Receive end should have the released interpreter removed.
        send_interps = interpreters.channel_list_interpreters(cid, send=True)
        recv_interps = interpreters.channel_list_interpreters(cid, send=False)
        self.assertEqual(len(send_interps), 0)
        self.assertEqual(recv_interps, [interp1])

    def test_channel_list_interpreters_closed(self):
        """Test listing channel interpreters with a closed channel."""
        interp0 = interpreters.get_main()
        cid = interpreters.channel_create()
        # Put something in the channel so that it's not empty.
        interpreters.channel_send(cid, "send")

        # Check initial state: only the main interpreter (the sender).
        send_interps = interpreters.channel_list_interpreters(cid, send=True)
        recv_interps = interpreters.channel_list_interpreters(cid, send=False)
        self.assertEqual(send_interps, [interp0])
        self.assertEqual(len(recv_interps), 0)

        # Force close the channel.
        interpreters.channel_close(cid, force=True)
        # Both ends should raise an error.
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_list_interpreters(cid, send=True)
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_list_interpreters(cid, send=False)

    def test_channel_list_interpreters_closed_send_end(self):
        """Test listing channel interpreters with a channel's send end closed."""
        interp0 = interpreters.get_main()
        interp1 = interpreters.create()
        cid = interpreters.channel_create()
        # Put something in the channel so that it's not empty.
        interpreters.channel_send(cid, "send")

        # Check initial state: only the main interpreter (the sender).
        send_interps = interpreters.channel_list_interpreters(cid, send=True)
        recv_interps = interpreters.channel_list_interpreters(cid, send=False)
        self.assertEqual(send_interps, [interp0])
        self.assertEqual(len(recv_interps), 0)

        # Close the send end of the channel.
        interpreters.channel_close(cid, send=True)
        # Send end should raise an error.
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_list_interpreters(cid, send=True)
        # Receive end should not be closed (since channel is not empty).
        recv_interps = interpreters.channel_list_interpreters(cid, send=False)
        self.assertEqual(len(recv_interps), 0)

        # Close the receive end of the channel from a subinterpreter.
        _run_output(interp1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_close({cid}, force=True)
            """))
        # Both ends should raise an error.
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_list_interpreters(cid, send=True)
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_list_interpreters(cid, send=False)

    ####################

    def test_send_recv_main(self):
        cid = interpreters.channel_create()
        orig = b'spam'
        interpreters.channel_send(cid, orig)
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, orig)
        self.assertIsNot(obj, orig)

    def test_send_recv_same_interpreter(self):
        id1 = interpreters.create()
        # The checks run inside the subinterpreter; a failed assert there
        # surfaces as an error from _run_output().
        _run_output(id1, dedent("""
            import _xxsubinterpreters as _interpreters
            cid = _interpreters.channel_create()
            orig = b'spam'
            _interpreters.channel_send(cid, orig)
            obj = _interpreters.channel_recv(cid)
            assert obj is not orig
            assert obj == orig
            """))

    def test_send_recv_different_interpreters(self):
        cid = interpreters.channel_create()
        id1 = interpreters.create()
        _run_output(id1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({cid}, b'spam')
            """))
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'spam')

    def test_send_recv_different_threads(self):
        cid = interpreters.channel_create()

        def f():
            # Poll until the main thread's item arrives, then echo it back.
            while True:
                try:
                    obj = interpreters.channel_recv(cid)
                    break
                except interpreters.ChannelEmptyError:
                    time.sleep(0.1)
            interpreters.channel_send(cid, obj)
        t = threading.Thread(target=f)
        t.start()

        interpreters.channel_send(cid, b'spam')
        t.join()
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'spam')

    def test_send_recv_different_interpreters_and_threads(self):
        cid = interpreters.channel_create()
        id1 = interpreters.create()
        out = None

        def f():
            nonlocal out
            out = _run_output(id1, dedent(f"""
                import time
                import _xxsubinterpreters as _interpreters
                while True:
                    try:
                        obj = _interpreters.channel_recv({cid})
                        break
                    except _interpreters.ChannelEmptyError:
                        time.sleep(0.1)
                assert(obj == b'spam')
                _interpreters.channel_send({cid}, b'eggs')
                """))
        t = threading.Thread(target=f)
        t.start()

        interpreters.channel_send(cid, b'spam')
        t.join()
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'eggs')

    def test_send_not_found(self):
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters.channel_send(10, b'spam')

    def test_recv_not_found(self):
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters.channel_recv(10)

    def test_recv_empty(self):
        cid = interpreters.channel_create()
        with self.assertRaises(interpreters.ChannelEmptyError):
            interpreters.channel_recv(cid)

    def test_recv_default(self):
        default = object()
        cid = interpreters.channel_create()
        obj1 = interpreters.channel_recv(cid, default)
        interpreters.channel_send(cid, None)
        interpreters.channel_send(cid, 1)
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'eggs')
        obj2 = interpreters.channel_recv(cid, default)
        obj3 = interpreters.channel_recv(cid, default)
        obj4 = interpreters.channel_recv(cid)
        obj5 = interpreters.channel_recv(cid, default)
        obj6 = interpreters.channel_recv(cid, default)

        self.assertIs(obj1, default)
        self.assertIs(obj2, None)
        self.assertEqual(obj3, 1)
        self.assertEqual(obj4, b'spam')
        self.assertEqual(obj5, b'eggs')
        self.assertIs(obj6, default)

    def test_run_string_arg_unresolved(self):
        cid = interpreters.channel_create()
        interp = interpreters.create()

        out = _run_output(interp, dedent("""
            import _xxsubinterpreters as _interpreters
            print(cid.end)
            _interpreters.channel_send(cid, b'spam')
            """),
            dict(cid=cid.send))
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'spam')
        self.assertEqual(out.strip(), 'send')

    # XXX For now there is no high-level channel into which the
    # sent channel ID can be converted...
    # Note: this test caused crashes on some buildbots (bpo-33615).
    @unittest.skip('disabled until high-level channels exist')
    def test_run_string_arg_resolved(self):
        cid = interpreters.channel_create()
        cid = interpreters._channel_id(cid, _resolve=True)
        interp = interpreters.create()

        out = _run_output(interp, dedent("""
            import _xxsubinterpreters as _interpreters
            print(chan.id.end)
            _interpreters.channel_send(chan.id, b'spam')
            """),
            dict(chan=cid.send))
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'spam')
        self.assertEqual(out.strip(), 'send')

    # close

    def test_close_single_user(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_multiple_users(self):
        cid = interpreters.channel_create()
        id1 = interpreters.create()
        id2 = interpreters.create()
        # id1 uses the send end, id2 uses the recv end.
        interpreters.run_string(id1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({cid}, b'spam')
            """))
        interpreters.run_string(id2, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_recv({cid})
            """))
        interpreters.channel_close(cid)
        # After closing, each user must fail on the end it actually used.
        # ``_interpreters`` is still bound from the earlier runs because each
        # interpreter's __main__ namespace is reused.
        with self.assertRaises(interpreters.RunFailedError) as cm:
            interpreters.run_string(id1, dedent(f"""
                _interpreters.channel_send({cid}, b'spam')
                """))
        self.assertIn('ChannelClosedError', str(cm.exception))
        with self.assertRaises(interpreters.RunFailedError) as cm:
            interpreters.run_string(id2, dedent(f"""
                _interpreters.channel_recv({cid})
                """))
        self.assertIn('ChannelClosedError', str(cm.exception))

    def test_close_multiple_times(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_close(cid)

    def test_close_empty(self):
        tests = [
            (False, False),
            (True, False),
            (False, True),
            (True, True),
            ]
        for send, recv in tests:
            with self.subTest((send, recv)):
                cid = interpreters.channel_create()
                interpreters.channel_send(cid, b'spam')
                interpreters.channel_recv(cid)
                interpreters.channel_close(cid, send=send, recv=recv)

                with self.assertRaises(interpreters.ChannelClosedError):
                    interpreters.channel_send(cid, b'eggs')
                with self.assertRaises(interpreters.ChannelClosedError):
                    interpreters.channel_recv(cid)

    def test_close_defaults_with_unused_items(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')

        with self.assertRaises(interpreters.ChannelNotEmptyError):
            interpreters.channel_close(cid)
        interpreters.channel_recv(cid)
        interpreters.channel_send(cid, b'eggs')

    def test_close_recv_with_unused_items_unforced(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')

        with self.assertRaises(interpreters.ChannelNotEmptyError):
            interpreters.channel_close(cid, recv=True)
        interpreters.channel_recv(cid)
        interpreters.channel_send(cid, b'eggs')
        interpreters.channel_recv(cid)
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid, recv=True)

    def test_close_send_with_unused_items_unforced(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')
        interpreters.channel_close(cid, send=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        interpreters.channel_recv(cid)
        interpreters.channel_recv(cid)
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_both_with_unused_items_unforced(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')

        with self.assertRaises(interpreters.ChannelNotEmptyError):
            interpreters.channel_close(cid, recv=True, send=True)
        interpreters.channel_recv(cid)
        interpreters.channel_send(cid, b'eggs')
        interpreters.channel_recv(cid)
        interpreters.channel_recv(cid)
        # Once drained, closing *both* ends (what this test is about) must
        # succeed; previously only the recv end was closed here.
        interpreters.channel_close(cid, recv=True, send=True)

    def test_close_recv_with_unused_items_forced(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')
        interpreters.channel_close(cid, recv=True, force=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_send_with_unused_items_forced(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')
        interpreters.channel_close(cid, send=True, force=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_both_with_unused_items_forced(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')
        interpreters.channel_close(cid, send=True, recv=True, force=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_never_used(self):
        cid = interpreters.channel_create()
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'spam')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_by_unassociated_interp(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interp = interpreters.create()
        interpreters.run_string(interp, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_close({cid}, force=True)
            """))
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_close(cid)

    def test_close_used_multiple_times_by_single_user(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid, force=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_channel_list_interpreters_invalid_channel(self):
        cid = interpreters.channel_create()
        # Test for invalid channel ID.  Derive the id from the newest channel
        # rather than hard-coding one: channel ids come from a global counter,
        # so a fixed number could name a channel that actually exists.
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters.channel_list_interpreters(int(cid) + 1, send=True)

        interpreters.channel_close(cid)
        # Test for a channel that has been closed.
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_list_interpreters(cid, send=True)

    def test_channel_list_interpreters_invalid_args(self):
        # Tests for invalid arguments passed to the API.
        cid = interpreters.channel_create()
        with self.assertRaises(TypeError):
            interpreters.channel_list_interpreters(cid)
1727
1728
class ChannelReleaseTests(TestBase):

    # XXX Add more test coverage a la the tests for close().

    """Tests for channel_release().

    Dimensions worth covering (only partly done so far):
    - main / interp / other
    - run in: current thread / new thread / other thread / different threads
    - end / opposite
    - force / no force
    - used / not used  (associated / not associated)
    - empty / emptied / never emptied / partly emptied
    - closed / not closed
    - released / not released
    - creator (interp) / other
    - associated interpreter not running
    - associated interpreter destroyed
    """

    # Phases of a release scenario: use, pre-release, release, after, check.
    #
    # Permutation matrix:
    #   release in:         main, interp1
    #   creator:            same, other (incl. interp2)
    #
    #   use:                None,send,recv,send/recv in None,same,other(incl. interp2),same+other(incl. interp2),all
    #   pre-release:        None,send,recv,both in None,same,other(incl. interp2),same+other(incl. interp2),all
    #   pre-release forced: None,send,recv,both in None,same,other(incl. interp2),same+other(incl. interp2),all
    #
    #   release:            same
    #   release forced:     same
    #
    #   use after:          None,send,recv,send/recv in None,same,other(incl. interp2),same+other(incl. interp2),all
    #   release after:      None,send,recv,send/recv in None,same,other(incl. interp2),same+other(incl. interp2),all
    #   check released:     send/recv for same/other(incl. interp2)
    #   check closed:       send/recv for same/other(incl. interp2)

    # helpers (not collected by unittest: no "test" prefix)

    def _new_drained_channel(self):
        # Create a channel, send one item and receive it back from this
        # interpreter, leaving the channel empty after using both ends.
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        return cid

    def _assert_both_ends_closed(self, cid, payload=b'eggs'):
        # Sending *payload* and receiving must both fail as "closed".
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, payload)
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    # tests

    def test_single_user(self):
        cid = self._new_drained_channel()
        interpreters.channel_release(cid, send=True, recv=True)

        self._assert_both_ends_closed(cid)

    def test_multiple_users(self):
        cid = interpreters.channel_create()
        sender = interpreters.create()
        receiver = interpreters.create()
        interpreters.run_string(sender, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({cid}, b'spam')
            """))
        out = _run_output(receiver, dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_recv({cid})
            _interpreters.channel_release({cid})
            print(repr(obj))
            """))
        # The sender's namespace still holds `_interpreters` from its first
        # run, so no re-import is needed here.
        interpreters.run_string(sender, dedent(f"""
            _interpreters.channel_release({cid})
            """))

        self.assertEqual(out.strip(), "b'spam'")

    def test_no_kwargs(self):
        # With no keyword arguments, release behaves like releasing both ends.
        cid = self._new_drained_channel()
        interpreters.channel_release(cid)

        self._assert_both_ends_closed(cid)

    def test_multiple_times(self):
        cid = self._new_drained_channel()
        interpreters.channel_release(cid, send=True, recv=True)

        # A second full release of the same channel is an error.
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_release(cid, send=True, recv=True)

    def test_with_unused_items(self):
        cid = interpreters.channel_create()
        for item in (b'spam', b'ham'):
            interpreters.channel_send(cid, item)
        interpreters.channel_release(cid, send=True, recv=True)

        # Queued-but-unreceived items are not reachable after release.
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_never_used(self):
        cid = interpreters.channel_create()
        interpreters.channel_release(cid)

        self._assert_both_ends_closed(cid, b'spam')

    def test_by_unassociated_interp(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        other = interpreters.create()
        interpreters.run_string(other, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_release({cid})
            """))
        # The other interpreter's release did not stop this interpreter from
        # receiving the queued item.
        received = interpreters.channel_recv(cid)
        interpreters.channel_release(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        self.assertEqual(received, b'spam')

    def test_close_if_unassociated(self):
        # XXX Something's not right with this test...
        cid = interpreters.channel_create()
        other = interpreters.create()
        interpreters.run_string(other, dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_send({cid}, b'spam')
            _interpreters.channel_release({cid})
            """))

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_partially(self):
        # XXX Is partial close too weird/confusing?
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, None)
        interpreters.channel_recv(cid)
        interpreters.channel_send(cid, b'spam')
        # Releasing only the send end leaves the pending item receivable.
        interpreters.channel_release(cid, send=True)

        self.assertEqual(interpreters.channel_recv(cid), b'spam')

    def test_used_multiple_times_by_single_user(self):
        cid = interpreters.channel_create()
        for _ in range(3):
            interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_release(cid, send=True, recv=True)

        self._assert_both_ends_closed(cid)
1892
1893
class ChannelCloseFixture(namedtuple('ChannelCloseFixture',
                                     'end interp other extra creator')):
    """One test permutation used by ExhaustiveChannelTests.

    Fields:
        end: 'send' or 'recv', the channel end the permutation targets.
        interp: interpreter resolved for the role names 'same'/'interp'.
        other: interpreter resolved for the role name 'other'.
        extra: interpreter resolved for the role name 'extra'.
        creator: name of the interpreter that creates the channel under
            test; a falsy value is replaced with 'same'.

    The fixture also carries the expected ChannelState, updated through
    record_action().
    """

    # Set this to True to avoid creating interpreters, e.g. when
    # scanning through test permutations without running them.
    QUICK = False

    def __new__(cls, end, interp, other, extra, creator):
        # Explicit check rather than assert so it is not stripped by -O.
        if end not in ('send', 'recv'):
            raise ValueError(f"end must be 'send' or 'recv', got {end!r}")
        if cls.QUICK:
            # Keep the raw names; no interpreters are created.
            known = {}
        else:
            interp = Interpreter.from_raw(interp)
            other = Interpreter.from_raw(other)
            extra = Interpreter.from_raw(extra)
            known = {
                interp.name: interp,
                other.name: other,
                extra.name: extra,
                }
        if not creator:
            creator = 'same'
        self = super().__new__(cls, end, interp, other, extra, creator)
        # This subclass has no __slots__, so instances have a __dict__ and
        # can carry per-fixture state alongside the tuple fields.
        self._prepped = set()        # ids of interpreters already prepared
        self._state = ChannelState()
        self._known = known          # name -> Interpreter
        return self

    @property
    def state(self):
        """The expected ChannelState after the actions recorded so far."""
        return self._state

    @property
    def cid(self):
        """The channel under test, created lazily by `creator` and cached."""
        try:
            return self._cid
        except AttributeError:
            creator = self._get_interpreter(self.creator)
            self._cid = self._new_channel(creator)
            return self._cid

    def get_interpreter(self, interp):
        """Resolve *interp* (a role or interpreter name) and prepare it."""
        interp = self._get_interpreter(interp)
        self._prep_interpreter(interp)
        return interp

    def expect_closed_error(self, end=None):
        """Return True if using *end* should raise ChannelClosedError.

        *end* defaults to the fixture's own end.  When the recorded state
        says only the send end is closed, the recv end is still usable.
        """
        if end is None:
            end = self.end
        if end == 'recv' and self.state.closed == 'send':
            return False
        return bool(self.state.closed)

    def prep_interpreter(self, interp):
        """Public wrapper around _prep_interpreter()."""
        self._prep_interpreter(interp)

    def record_action(self, action, result):
        """Store *result* as the new expected state (*action* is unused)."""
        self._state = result

    def clean_up(self):
        """Destroy the interpreters and channels created for this fixture."""
        clean_up_interpreters()
        clean_up_channels()

    # internal methods

    def _new_channel(self, creator):
        # Create the channel under test from inside *creator* and return its
        # ID.  Caching is left to the `cid` property.
        if creator.name == 'main':
            return interpreters.channel_create()
        # A temporary channel carries the new ID back to this interpreter.
        handoff = interpreters.channel_create()
        script = f"""
                import _xxsubinterpreters
                cid = _xxsubinterpreters.channel_create()
                # We purposefully send back an int to avoid tying the
                # channel to the other interpreter.
                _xxsubinterpreters.channel_send({handoff}, int(cid))
                del _xxsubinterpreters
                """
        try:
            run_interp(creator.id, script)
            return interpreters.channel_recv(handoff)
        finally:
            # The hand-off channel is only needed for this single transfer;
            # force=True also covers the case where run_interp failed.
            interpreters.channel_close(handoff, force=True)

    def _get_interpreter(self, interp):
        # Map a role name to the fixture's interpreter.  Any other name gets
        # its own Interpreter, created on first use and cached in _known.
        if interp in ('same', 'interp'):
            return self.interp
        elif interp == 'other':
            return self.other
        elif interp == 'extra':
            return self.extra
        else:
            name = interp
            try:
                interp = self._known[name]
            except KeyError:
                interp = self._known[name] = Interpreter(name)
            return interp

    def _prep_interpreter(self, interp):
        # Once per interpreter (keyed by id), run a setup script that imports
        # the modules used by action scripts and binds `cid` if the
        # interpreter does not already have it.  The main interpreter is
        # skipped.
        if interp.id in self._prepped:
            return
        self._prepped.add(interp.id)
        if interp.name == 'main':
            return
        run_interp(interp.id, f"""
            import _xxsubinterpreters as interpreters
            import test.test__xxsubinterpreters as helpers
            ChannelState = helpers.ChannelState
            try:
                cid
            except NameError:
                cid = interpreters._channel_id({self.cid})
            """)
2005
2006
@unittest.skip('these tests take several hours to run')
class ExhaustiveChannelTests(TestBase):

    """
    - main / interp / other
    - run in: current thread / new thread / other thread / different threads
    - end / opposite
    - force / no force
    - used / not used  (associated / not associated)
    - empty / emptied / never emptied / partly emptied
    - closed / not closed
    - released / not released
    - creator (interp) / other
    - associated interpreter not running
    - associated interpreter destroyed

    - close after unbound
    """

    # Phases of a close scenario: use, pre-close, close, after, check.
    #
    # Permutation matrix:
    #   close in:         main, interp1
    #   creator:          same, other, extra
    #
    #   use:              None,send,recv,send/recv in None,same,other,same+other,all
    #   pre-close:        None,send,recv in None,same,other,same+other,all
    #   pre-close forced: None,send,recv in None,same,other,same+other,all
    #
    #   close:            same
    #   close forced:     same
    #
    #   use after:        None,send,recv,send/recv in None,same,other,extra,same+other,all
    #   close after:      None,send,recv,send/recv in None,same,other,extra,same+other,all
    #   check closed:     send/recv for same/other(incl. interp2)

    # Interpreter-role pairs used for the "use" and "pre-close" phases.
    _ROLE_PAIRS = (('same', 'other'), ('other', 'extra'))

    def iter_action_sets(self):
        """Yield every list of ChannelActions to run before the final close.

        Covers:
        - used / not used  (associated / not associated)
        - empty / emptied / never emptied / partly emptied
        - closed / not closed
        - released / not released
        """
        # never used
        yield []

        # only pre-closed (and possibly used after)
        for pair in self._ROLE_PAIRS:
            yield from self._iter_close_then_post([], *pair)

        # used, then optionally pre-closed (and possibly used after)
        for usepair in self._ROLE_PAIRS:
            for useactions in self._iter_use_action_sets(*usepair):
                yield useactions
                for closepair in self._ROLE_PAIRS:
                    yield from self._iter_close_then_post(useactions,
                                                          *closepair)

    def _iter_close_then_post(self, prefix, interp1, interp2):
        # For each close-action set: yield prefix+close, then that followed
        # by each post-close action set.
        for closeactions in self._iter_close_action_sets(interp1, interp2):
            actions = prefix + closeactions
            yield actions
            for postactions in self._iter_post_close_action_sets():
                yield actions + postactions

    def _iter_use_action_sets(self, interp1, interp2):
        # Yield lists of 'use' actions spread over the two interpreters,
        # grouped by how much of the queue ends up consumed.
        interps = (interp1, interp2)

        # only recv end used
        yield [ChannelAction('use', 'recv', interp1)]
        yield [ChannelAction('use', 'recv', interp2)]
        yield [
            ChannelAction('use', 'recv', interp1),
            ChannelAction('use', 'recv', interp2),
            ]

        # never emptied
        yield [ChannelAction('use', 'send', interp1)]
        yield [ChannelAction('use', 'send', interp2)]
        yield [
            ChannelAction('use', 'send', interp1),
            ChannelAction('use', 'send', interp2),
            ]

        # partially emptied: two sends, one recv
        for send1, send2, recv1 in itertools.product(interps, repeat=3):
            yield [
                ChannelAction('use', 'send', send1),
                ChannelAction('use', 'send', send2),
                ChannelAction('use', 'recv', recv1),
                ]

        # fully emptied: two sends, two recvs
        for send1, send2, recv1, recv2 in itertools.product(interps,
                                                            repeat=4):
            yield [
                ChannelAction('use', 'send', send1),
                ChannelAction('use', 'send', send2),
                ChannelAction('use', 'recv', recv1),
                ChannelAction('use', 'recv', recv2),
                ]

    def _iter_close_action_sets(self, interp1, interp2):
        # Yield lists of close actions: first a single (forced, then plain)
        # close of one end, then every recv-close + send-close combination.
        ends = ('recv', 'send')
        interps = (interp1, interp2)
        for force in (True, False):
            op = 'force-close' if force else 'close'
            for interp in interps:
                for end in ends:
                    yield [
                        ChannelAction(op, end, interp),
                        ]
        ops = ('close', 'force-close')
        combos = itertools.product(ops, ops, interps, interps)
        for recvop, sendop, recv, send in combos:
            yield [
                ChannelAction(recvop, 'recv', recv),
                ChannelAction(sendop, 'send', send),
                ]

    def _iter_post_close_action_sets(self):
        # A single recv or send from each role, after the channel is closed.
        for interp in ('same', 'extra', 'other'):
            yield [
                ChannelAction('use', 'recv', interp),
                ]
            yield [
                ChannelAction('use', 'send', interp),
                ]

    def run_actions(self, fix, actions):
        """Run each action in *actions*, in order, against *fix*."""
        for action in actions:
            self.run_action(fix, action)

    def run_action(self, fix, action, *, hideclosed=True):
        """Run one ChannelAction where it belongs and record the new state.

        For a subinterpreter, the resulting state comes back over a
        temporary channel as two messages: the pending count as a single
        byte (so it assumes pending < 256) and b'X' / b'' for closed.
        """
        end = action.resolve_end(fix.end)
        interp = action.resolve_interp(fix.interp, fix.other, fix.extra)
        fix.prep_interpreter(interp)
        if interp.name == 'main':
            result = run_action(
                fix.cid,
                action.action,
                end,
                fix.state,
                hideclosed=hideclosed,
                )
            fix.record_action(action, result)
        else:
            _cid = interpreters.channel_create()
            run_interp(interp.id, f"""
                result = helpers.run_action(
                    {fix.cid},
                    {repr(action.action)},
                    {repr(end)},
                    {repr(fix.state)},
                    hideclosed={hideclosed},
                    )
                interpreters.channel_send({_cid}, result.pending.to_bytes(1, 'little'))
                interpreters.channel_send({_cid}, b'X' if result.closed else b'')
                """)
            result = ChannelState(
                pending=int.from_bytes(interpreters.channel_recv(_cid), 'little'),
                closed=bool(interpreters.channel_recv(_cid)),
                )
            fix.record_action(action, result)

    def iter_fixtures(self):
        """Yield a ChannelCloseFixture for each interp/creator/end combo."""
        # XXX threads?
        # Named so as not to shadow the module-level `interpreters`.
        interp_triples = [
            ('main', 'interp', 'extra'),
            ('interp', 'main', 'extra'),
            ('interp1', 'interp2', 'extra'),
            ('interp1', 'interp2', 'main'),
        ]
        for interp, other, extra in interp_triples:
            for creator in ('same', 'other', 'creator'):
                for end in ('send', 'recv'):
                    yield ChannelCloseFixture(end, interp, other, extra, creator)

    def _close(self, fix, *, force):
        # Close the fixture's end from 'same', expecting ChannelClosedError
        # if the recorded state says the channel is already closed.
        op = 'force-close' if force else 'close'
        close = ChannelAction(op, fix.end, 'same')
        if not fix.expect_closed_error():
            self.run_action(fix, close, hideclosed=False)
        else:
            with self.assertRaises(interpreters.ChannelClosedError):
                self.run_action(fix, close, hideclosed=False)

    def _assert_closed_in_interp(self, fix, interp=None):
        # Check that recv, send, close and force-close all report the
        # channel as closed, either here (interp None or main) or inside
        # the given subinterpreter.
        if interp is None or interp.name == 'main':
            with self.assertRaises(interpreters.ChannelClosedError):
                interpreters.channel_recv(fix.cid)
            with self.assertRaises(interpreters.ChannelClosedError):
                interpreters.channel_send(fix.cid, b'spam')
            with self.assertRaises(interpreters.ChannelClosedError):
                interpreters.channel_close(fix.cid)
            with self.assertRaises(interpreters.ChannelClosedError):
                interpreters.channel_close(fix.cid, force=True)
        else:
            run_interp(interp.id, """
                with helpers.expect_channel_closed():
                    interpreters.channel_recv(cid)
                """)
            run_interp(interp.id, """
                with helpers.expect_channel_closed():
                    interpreters.channel_send(cid, b'spam')
                """)
            run_interp(interp.id, """
                with helpers.expect_channel_closed():
                    interpreters.channel_close(cid)
                """)
            run_interp(interp.id, """
                with helpers.expect_channel_closed():
                    interpreters.channel_close(cid, force=True)
                """)

    def _assert_closed(self, fix):
        # Drain any items the state says are pending, then verify "closed"
        # from main, from the fixture's non-main interpreters, and from a
        # fresh interpreter.
        self.assertTrue(fix.state.closed)

        for _ in range(fix.state.pending):
            interpreters.channel_recv(fix.cid)
        self._assert_closed_in_interp(fix)

        for interp in ('same', 'other'):
            interp = fix.get_interpreter(interp)
            if interp.name == 'main':
                continue
            self._assert_closed_in_interp(fix, interp)

        interp = fix.get_interpreter('fresh')
        self._assert_closed_in_interp(fix, interp)

    def _iter_close_tests(self, verbose=False):
        # Yield (index, fixture, actions) for every permutation, capped at
        # 1000, printing progress as it goes.
        i = 0
        for actions in self.iter_action_sets():
            print()
            for fix in self.iter_fixtures():
                i += 1
                if i > 1000:
                    return
                if verbose:
                    if (i - 1) % 6 == 0:
                        print()
                    print(i, fix, '({} actions)'.format(len(actions)))
                else:
                    if (i - 1) % 6 == 0:
                        print(' ', end='')
                    print('.', end='', flush=True)
                yield i, fix, actions
            if verbose:
                print('---')
        print()

    # This is useful for scanning through the possible tests.
    def _skim_close_tests(self):
        # QUICK is shared class state; restore it so later fixtures still
        # create real interpreters.
        saved = ChannelCloseFixture.QUICK
        ChannelCloseFixture.QUICK = True
        try:
            for _ in self._iter_close_tests():
                pass
        finally:
            ChannelCloseFixture.QUICK = saved

    def _check_close(self, *, force):
        # Shared driver for test_close() and test_force_close().
        for i, fix, actions in self._iter_close_tests():
            with self.subTest('{} {}  {}'.format(i, fix, actions)):
                fix.prep_interpreter(fix.interp)
                self.run_actions(fix, actions)

                self._close(fix, force=force)

                self._assert_closed(fix)
            # XXX Things slow down if we have too many interpreters.
            fix.clean_up()

    def test_close(self):
        self._check_close(force=False)

    def test_force_close(self):
        self._check_close(force=True)
2324
2325
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
2328