1import contextlib
2import os
3import threading
4from textwrap import dedent
5import unittest
6import time
7
8from test import support
9from test.support import import_helper
10_interpreters = import_helper.import_module('_xxsubinterpreters')
11from test.support import interpreters
12
13
14def _captured_script(script):
15    r, w = os.pipe()
16    indented = script.replace('\n', '\n                ')
17    wrapped = dedent(f"""
18        import contextlib
19        with open({w}, 'w', encoding='utf-8') as spipe:
20            with contextlib.redirect_stdout(spipe):
21                {indented}
22        """)
23    return wrapped, open(r, encoding='utf-8')
24
25
def clean_up_interpreters():
    """Close every listed interpreter except the main one (ID 0).

    A ``RuntimeError`` from ``close()`` is ignored: it means the interpreter
    was already destroyed.
    """
    for interp in interpreters.list_all():
        if interp.id != 0:
            with contextlib.suppress(RuntimeError):  # already destroyed
                interp.close()
34
35
def _run_output(interp, request, channels=None):
    """Run *request* in *interp* and return the text it wrote to stdout.

    *channels* is forwarded unchanged to ``interp.run()``.
    """
    wrapped, reader = _captured_script(request)
    with reader:
        interp.run(wrapped, channels=channels)
        output = reader.read()
    return output
41
42
43@contextlib.contextmanager
44def _running(interp):
45    r, w = os.pipe()
46    def run():
47        interp.run(dedent(f"""
48            # wait for "signal"
49            with open({r}) as rpipe:
50                rpipe.read()
51            """))
52
53    t = threading.Thread(target=run)
54    t.start()
55
56    yield
57
58    with open(w, 'w') as spipe:
59        spipe.write('done')
60    t.join()
61
62
class TestBase(unittest.TestCase):
    """Common base class: leaves no subinterpreters behind after a test."""

    def tearDown(self):
        # Destroy every non-main interpreter the test may have created.
        clean_up_interpreters()
67
68
class CreateTests(TestBase):
    """Tests for interpreters.create()."""

    def test_in_main(self):
        created = interpreters.create()
        self.assertIsInstance(created, interpreters.Interpreter)
        self.assertIn(created, interpreters.list_all())

    def test_in_thread(self):
        gate = threading.Lock()
        created = None

        def worker():
            nonlocal created
            created = interpreters.create()
            # Wait until the main thread has released the gate.
            gate.acquire()
            gate.release()

        thread = threading.Thread(target=worker)
        with gate:
            thread.start()
        thread.join()
        self.assertIn(created, interpreters.list_all())

    def test_in_subinterpreter(self):
        main, = interpreters.list_all()
        outer = interpreters.create()
        out = _run_output(outer, dedent("""
            from test.support import interpreters
            interp = interpreters.create()
            print(interp.id)
            """))
        inner = interpreters.Interpreter(int(out))
        self.assertEqual(interpreters.list_all(), [main, outer, inner])

    def test_after_destroy_all(self):
        before = set(interpreters.list_all())
        # Create 3 subinterpreters, then destroy all of them.
        doomed = [interpreters.create() for _ in range(3)]
        for victim in doomed:
            victim.close()
        # Finally, create another.
        survivor = interpreters.create()
        self.assertEqual(set(interpreters.list_all()), before | {survivor})

    def test_after_destroy_some(self):
        before = set(interpreters.list_all())
        # Create 3 subinterpreters (in order).
        first, second, third = (interpreters.create() for _ in range(3))
        # Destroy only the first two.
        first.close()
        second.close()
        # Finally, create another.
        latest = interpreters.create()
        self.assertEqual(set(interpreters.list_all()),
                         before | {third, latest})
127
128
class GetCurrentTests(TestBase):
    """Tests for interpreters.get_current()."""

    def test_main(self):
        main = interpreters.get_main()
        current = interpreters.get_current()
        self.assertEqual(current, main)

    def test_subinterpreter(self):
        # Use the high-level get_main() so both sides of the comparison are
        # Interpreter objects.  The low-level _interpreters.get_main() is not
        # an Interpreter (presumably a raw ID), so comparing against it could
        # pass regardless of which interpreter is current.
        main = interpreters.get_main()
        interp = interpreters.create()
        out = _run_output(interp, dedent("""
            from test.support import interpreters
            cur = interpreters.get_current()
            print(cur.id)
            """))
        current = interpreters.Interpreter(int(out))
        self.assertEqual(current, interp)
        self.assertNotEqual(current, main)
146
147
class ListAllTests(TestBase):
    """Tests for interpreters.list_all()."""

    def test_initial(self):
        self.assertEqual(1, len(interpreters.list_all()))

    def test_after_creating(self):
        main = interpreters.get_current()
        first = interpreters.create()
        second = interpreters.create()

        ids = [interp.id for interp in interpreters.list_all()]

        self.assertEqual(ids, [main.id, first.id, second.id])

    def test_after_destroying(self):
        main = interpreters.get_current()
        first = interpreters.create()
        second = interpreters.create()
        first.close()

        ids = [interp.id for interp in interpreters.list_all()]

        self.assertEqual(ids, [main.id, second.id])
176
177
class TestInterpreterAttrs(TestBase):
    """Tests for Interpreter.id, Interpreter.isolated and equality."""

    def test_id_type(self):
        candidates = (
            interpreters.get_main(),
            interpreters.get_current(),
            interpreters.create(),
        )
        for candidate in candidates:
            self.assertIsInstance(candidate.id, _interpreters.InterpreterID)

    def test_main_id(self):
        self.assertEqual(interpreters.get_main().id, 0)

    def test_custom_id(self):
        custom = interpreters.Interpreter(1)
        self.assertEqual(custom.id, 1)

        with self.assertRaises(TypeError):
            interpreters.Interpreter('1')

    def test_id_readonly(self):
        custom = interpreters.Interpreter(1)
        with self.assertRaises(AttributeError):
            custom.id = 2

    @unittest.skip('not ready yet (see bpo-32604)')
    def test_main_isolated(self):
        self.assertFalse(interpreters.get_main().isolated)

    @unittest.skip('not ready yet (see bpo-32604)')
    def test_subinterpreter_isolated_default(self):
        created = interpreters.create()
        self.assertFalse(created.isolated)

    def test_subinterpreter_isolated_explicit(self):
        isolated = interpreters.create(isolated=True)
        shared = interpreters.create(isolated=False)
        self.assertTrue(isolated.isolated)
        self.assertFalse(shared.isolated)

    @unittest.skip('not ready yet (see bpo-32604)')
    def test_custom_isolated_default(self):
        custom = interpreters.Interpreter(1)
        self.assertFalse(custom.isolated)

    def test_custom_isolated_explicit(self):
        isolated = interpreters.Interpreter(1, isolated=True)
        shared = interpreters.Interpreter(1, isolated=False)
        self.assertTrue(isolated.isolated)
        self.assertFalse(shared.isolated)

    def test_isolated_readonly(self):
        custom = interpreters.Interpreter(1)
        with self.assertRaises(AttributeError):
            custom.isolated = True

    def test_equality(self):
        first = interpreters.create()
        second = interpreters.create()
        self.assertEqual(first, first)
        self.assertNotEqual(first, second)
241
242
class TestInterpreterIsRunning(TestBase):
    """Tests for Interpreter.is_running()."""

    def test_main(self):
        self.assertTrue(interpreters.get_main().is_running())

    @unittest.skip('Fails on FreeBSD')
    def test_subinterpreter(self):
        sub = interpreters.create()
        self.assertFalse(sub.is_running())

        with _running(sub):
            self.assertTrue(sub.is_running())
        self.assertFalse(sub.is_running())

    def test_from_subinterpreter(self):
        sub = interpreters.create()
        out = _run_output(sub, dedent(f"""
            import _xxsubinterpreters as _interpreters
            if _interpreters.is_running({sub.id}):
                print(True)
            else:
                print(False)
            """))
        self.assertEqual(out.strip(), 'True')

    def test_already_destroyed(self):
        sub = interpreters.create()
        sub.close()
        with self.assertRaises(RuntimeError):
            sub.is_running()

    def test_does_not_exist(self):
        missing = interpreters.Interpreter(1_000_000)
        with self.assertRaises(RuntimeError):
            missing.is_running()

    def test_bad_id(self):
        bogus = interpreters.Interpreter(-1)
        with self.assertRaises(ValueError):
            bogus.is_running()
284
285
class TestInterpreterClose(TestBase):
    """Tests for Interpreter.close()."""

    def test_basic(self):
        main = interpreters.get_main()
        interp1 = interpreters.create()
        interp2 = interpreters.create()
        interp3 = interpreters.create()
        self.assertEqual(set(interpreters.list_all()),
                         {main, interp1, interp2, interp3})
        interp2.close()
        self.assertEqual(set(interpreters.list_all()),
                         {main, interp1, interp3})

    def test_all(self):
        before = set(interpreters.list_all())
        interps = {interpreters.create() for _ in range(3)}
        self.assertEqual(set(interpreters.list_all()), before | interps)
        for interp in interps:
            interp.close()
        self.assertEqual(set(interpreters.list_all()), before)

    def test_main(self):
        main, = interpreters.list_all()
        with self.assertRaises(RuntimeError):
            main.close()

        # unittest does not see exceptions (including failed assertions)
        # raised in a worker thread, so record the outcome there and check it
        # from the main thread.
        raised = []
        def f():
            try:
                main.close()
            except RuntimeError as exc:
                raised.append(exc)

        t = threading.Thread(target=f)
        t.start()
        t.join()
        self.assertEqual(len(raised), 1)

    def test_already_destroyed(self):
        interp = interpreters.create()
        interp.close()
        with self.assertRaises(RuntimeError):
            interp.close()

    def test_does_not_exist(self):
        interp = interpreters.Interpreter(1_000_000)
        with self.assertRaises(RuntimeError):
            interp.close()

    def test_bad_id(self):
        interp = interpreters.Interpreter(-1)
        with self.assertRaises(ValueError):
            interp.close()

    def test_from_current(self):
        main, = interpreters.list_all()
        interp = interpreters.create()
        out = _run_output(interp, dedent(f"""
            from test.support import interpreters
            interp = interpreters.Interpreter({int(interp.id)})
            try:
                interp.close()
            except RuntimeError:
                print('failed')
            """))
        self.assertEqual(out.strip(), 'failed')
        self.assertEqual(set(interpreters.list_all()), {main, interp})

    def test_from_sibling(self):
        main, = interpreters.list_all()
        interp1 = interpreters.create()
        interp2 = interpreters.create()
        self.assertEqual(set(interpreters.list_all()),
                         {main, interp1, interp2})
        interp1.run(dedent(f"""
            from test.support import interpreters
            interp2 = interpreters.Interpreter(int({interp2.id}))
            interp2.close()
            interp3 = interpreters.create()
            interp3.close()
            """))
        self.assertEqual(set(interpreters.list_all()), {main, interp1})

    def test_from_other_thread(self):
        interp = interpreters.create()
        errors = []
        def f():
            try:
                interp.close()
            except Exception as exc:  # relayed to the main thread below
                errors.append(exc)

        t = threading.Thread(target=f)
        t.start()
        t.join()
        self.assertEqual(errors, [])
        self.assertNotIn(interp, interpreters.list_all())

    @unittest.skip('Fails on FreeBSD')
    def test_still_running(self):
        main, = interpreters.list_all()
        interp = interpreters.create()
        with _running(interp):
            with self.assertRaises(RuntimeError):
                interp.close()
            self.assertTrue(interp.is_running())
384            self.assertTrue(interp.is_running())
385
386
class TestInterpreterRun(TestBase):
    """Tests for Interpreter.run()."""

    def test_success(self):
        interp = interpreters.create()
        script, file = _captured_script('print("it worked!", end="")')
        with file:
            interp.run(script)
            out = file.read()

        self.assertEqual(out, 'it worked!')

    def test_in_thread(self):
        interp = interpreters.create()
        script, file = _captured_script('print("it worked!", end="")')
        with file:
            # Exceptions in the worker thread are invisible to unittest, so
            # relay them to the main thread.
            errors = []
            def f():
                try:
                    interp.run(script)
                except Exception as exc:
                    errors.append(exc)

            t = threading.Thread(target=f)
            t.start()
            t.join()
            # Check before reading: if run() failed before the script closed
            # the pipe's write end, read() would block forever.
            self.assertEqual(errors, [])
            out = file.read()

        self.assertEqual(out, 'it worked!')

    @support.requires_fork()
    def test_fork(self):
        interp = interpreters.create()
        import tempfile
        with tempfile.NamedTemporaryFile('w+', encoding='utf-8') as file:
            file.write('')
            file.flush()

            expected = 'spam spam spam spam spam'
            # Embed the path and text via repr() so they are always valid
            # string literals.  If fork() unexpectedly succeeds, the child
            # exits immediately instead of carrying on with the rest of the
            # test run, and the parent reaps it.  The test then fails cleanly
            # because nothing was written.
            script = dedent(f"""
                import os
                try:
                    pid = os.fork()
                except RuntimeError:
                    with open({file.name!r}, 'w', encoding='utf-8') as out:
                        out.write({expected!r})
                else:
                    if pid == 0:
                        os._exit(0)
                    os.waitpid(pid, 0)
                """)
            interp.run(script)

            file.seek(0)
            content = file.read()
            self.assertEqual(content, expected)

    @unittest.skip('Fails on FreeBSD')
    def test_already_running(self):
        interp = interpreters.create()
        with _running(interp):
            with self.assertRaises(RuntimeError):
                interp.run('print("spam")')

    def test_does_not_exist(self):
        interp = interpreters.Interpreter(1_000_000)
        with self.assertRaises(RuntimeError):
            interp.run('print("spam")')

    def test_bad_id(self):
        interp = interpreters.Interpreter(-1)
        with self.assertRaises(ValueError):
            interp.run('print("spam")')

    def test_bad_script(self):
        interp = interpreters.create()
        with self.assertRaises(TypeError):
            interp.run(10)

    def test_bytes_for_script(self):
        interp = interpreters.create()
        with self.assertRaises(TypeError):
            interp.run(b'print("spam")')

    # test_xxsubinterpreters covers the remaining Interpreter.run() behavior.
463
464
class TestIsShareable(TestBase):
    """Tests for interpreters.is_shareable()."""

    def test_default_shareables(self):
        shareables = [
            # singletons
            None,
            # builtin objects
            b'spam',
            'spam',
            10,
            -10,
        ]
        for obj in shareables:
            with self.subTest(obj):
                self.assertTrue(interpreters.is_shareable(obj))

    def test_not_shareable(self):
        class Cheese:
            def __init__(self, name):
                self.name = name

            def __str__(self):
                return self.name

        class SubBytes(bytes):
            """A subclass of a shareable type."""

        not_shareables = [
            # singletons
            True,
            False,
            NotImplemented,
            ...,
            # builtin types and objects
            type,
            object,
            object(),
            Exception(),
            100.0,
            # user-defined types and objects
            Cheese,
            Cheese('Wensleydale'),
            SubBytes(b'spam'),
        ]
        for obj in not_shareables:
            with self.subTest(repr(obj)):
                self.assertFalse(interpreters.is_shareable(obj))
513
514
class TestChannels(TestBase):
    """Tests for create_channel() and list_all_channels()."""

    def test_create(self):
        recv_end, send_end = interpreters.create_channel()
        self.assertIsInstance(recv_end, interpreters.RecvChannel)
        self.assertIsInstance(send_end, interpreters.SendChannel)

    def test_list_all(self):
        self.assertEqual(interpreters.list_all_channels(), [])
        created = {interpreters.create_channel() for _ in range(3)}
        self.assertEqual(set(interpreters.list_all_channels()), created)
530
531
class TestRecvChannelAttrs(TestBase):
    """Tests for RecvChannel.id and RecvChannel equality."""

    def test_id_type(self):
        recv_end, _ = interpreters.create_channel()
        self.assertIsInstance(recv_end.id, _interpreters.ChannelID)

    def test_custom_id(self):
        recv_end = interpreters.RecvChannel(1)
        self.assertEqual(recv_end.id, 1)

        with self.assertRaises(TypeError):
            interpreters.RecvChannel('1')

    def test_id_readonly(self):
        recv_end = interpreters.RecvChannel(1)
        with self.assertRaises(AttributeError):
            recv_end.id = 2

    def test_equality(self):
        first, _ = interpreters.create_channel()
        second, _ = interpreters.create_channel()
        self.assertEqual(first, first)
        self.assertNotEqual(first, second)
555
556
class TestSendChannelAttrs(TestBase):
    """Tests for SendChannel.id and SendChannel equality."""

    def test_id_type(self):
        _, send_end = interpreters.create_channel()
        self.assertIsInstance(send_end.id, _interpreters.ChannelID)

    def test_custom_id(self):
        send_end = interpreters.SendChannel(1)
        self.assertEqual(send_end.id, 1)

        with self.assertRaises(TypeError):
            interpreters.SendChannel('1')

    def test_id_readonly(self):
        send_end = interpreters.SendChannel(1)
        with self.assertRaises(AttributeError):
            send_end.id = 2

    def test_equality(self):
        _, first = interpreters.create_channel()
        _, second = interpreters.create_channel()
        self.assertEqual(first, first)
        self.assertNotEqual(first, second)
580
581
class TestSendRecv(TestBase):
    """Tests for sending and receiving objects over channels."""

    def test_send_recv_main(self):
        r, s = interpreters.create_channel()
        orig = b'spam'
        s.send_nowait(orig)
        obj = r.recv()

        self.assertEqual(obj, orig)
        self.assertIsNot(obj, orig)

    def test_send_recv_same_interpreter(self):
        interp = interpreters.create()
        interp.run(dedent("""
            from test.support import interpreters
            r, s = interpreters.create_channel()
            orig = b'spam'
            s.send_nowait(orig)
            obj = r.recv()
            assert obj == orig, 'expected: obj == orig'
            assert obj is not orig, 'expected: obj is not orig'
            """))

    @unittest.skip('broken (see BPO-...)')
    def test_send_recv_different_interpreters(self):
        r1, s1 = interpreters.create_channel()
        r2, s2 = interpreters.create_channel()
        orig1 = b'spam'
        s1.send_nowait(orig1)
        out = _run_output(
            interpreters.create(),
            dedent(f"""
                obj1 = r.recv()
                assert obj1 == b'spam', 'expected: obj1 == orig1'
                # When going to another interpreter we get a copy.
                assert id(obj1) != {id(orig1)}, 'expected: obj1 is not orig1'
                orig2 = b'eggs'
                print(id(orig2))
                s.send_nowait(orig2)
                """),
            channels=dict(r=r1, s=s2),
            )
        obj2 = r2.recv()

        self.assertEqual(obj2, b'eggs')
        self.assertNotEqual(id(obj2), int(out))

    def test_send_recv_different_threads(self):
        r, s = interpreters.create_channel()
        # unittest does not see exceptions raised in the worker thread, so
        # relay them to the main thread.
        errors = []

        def f():
            try:
                while True:
                    try:
                        obj = r.recv()
                        break
                    except interpreters.ChannelEmptyError:
                        time.sleep(0.1)
                s.send(obj)
            except Exception as exc:
                errors.append(exc)
        t = threading.Thread(target=f)
        t.start()

        orig = b'spam'
        s.send(orig)
        t.join()
        # Check before recv(): if the worker failed, it never sent anything
        # back, and a blocking recv() might then never return.
        self.assertEqual(errors, [])
        obj = r.recv()

        self.assertEqual(obj, orig)
        self.assertIsNot(obj, orig)

    def test_send_recv_nowait_main(self):
        r, s = interpreters.create_channel()
        orig = b'spam'
        s.send_nowait(orig)
        obj = r.recv_nowait()

        self.assertEqual(obj, orig)
        self.assertIsNot(obj, orig)

    def test_send_recv_nowait_main_with_default(self):
        r, _ = interpreters.create_channel()
        obj = r.recv_nowait(None)

        self.assertIsNone(obj)

    def test_send_recv_nowait_same_interpreter(self):
        interp = interpreters.create()
        interp.run(dedent("""
            from test.support import interpreters
            r, s = interpreters.create_channel()
            orig = b'spam'
            s.send_nowait(orig)
            obj = r.recv_nowait()
            assert obj == orig, 'expected: obj == orig'
            # Even within the same interpreter we get a copy, not the original.
            assert obj is not orig, 'expected: obj is not orig'
            """))

    @unittest.skip('broken (see BPO-...)')
    def test_send_recv_nowait_different_interpreters(self):
        r1, s1 = interpreters.create_channel()
        r2, s2 = interpreters.create_channel()
        orig1 = b'spam'
        s1.send_nowait(orig1)
        out = _run_output(
            interpreters.create(),
            dedent(f"""
                obj1 = r.recv_nowait()
                assert obj1 == b'spam', 'expected: obj1 == orig1'
                # When going to another interpreter we get a copy.
                assert id(obj1) != {id(orig1)}, 'expected: obj1 is not orig1'
                orig2 = b'eggs'
                print(id(orig2))
                s.send_nowait(orig2)
                """),
            channels=dict(r=r1, s=s2),
            )
        obj2 = r2.recv_nowait()

        self.assertEqual(obj2, b'eggs')
        self.assertNotEqual(id(obj2), int(out))

    def test_recv_channel_does_not_exist(self):
        ch = interpreters.RecvChannel(1_000_000)
        with self.assertRaises(interpreters.ChannelNotFoundError):
            ch.recv()

    def test_send_channel_does_not_exist(self):
        ch = interpreters.SendChannel(1_000_000)
        with self.assertRaises(interpreters.ChannelNotFoundError):
            ch.send(b'spam')

    def test_recv_nowait_channel_does_not_exist(self):
        ch = interpreters.RecvChannel(1_000_000)
        with self.assertRaises(interpreters.ChannelNotFoundError):
            ch.recv_nowait()

    def test_send_nowait_channel_does_not_exist(self):
        ch = interpreters.SendChannel(1_000_000)
        with self.assertRaises(interpreters.ChannelNotFoundError):
            ch.send_nowait(b'spam')

    def test_recv_nowait_empty(self):
        ch, _ = interpreters.create_channel()
        with self.assertRaises(interpreters.ChannelEmptyError):
            ch.recv_nowait()

    def test_recv_nowait_default(self):
        default = object()
        rch, sch = interpreters.create_channel()
        obj1 = rch.recv_nowait(default)
        sch.send_nowait(None)
        sch.send_nowait(1)
        sch.send_nowait(b'spam')
        sch.send_nowait(b'eggs')
        obj2 = rch.recv_nowait(default)
        obj3 = rch.recv_nowait(default)
        obj4 = rch.recv_nowait()
        obj5 = rch.recv_nowait(default)
        obj6 = rch.recv_nowait(default)

        self.assertIs(obj1, default)
        self.assertIs(obj2, None)
        self.assertEqual(obj3, 1)
        self.assertEqual(obj4, b'spam')
        self.assertEqual(obj5, b'eggs')
        self.assertIs(obj6, default)
748