1import collections.abc
2import traceback
3import types
4import unittest
5
6
7class TestExceptionGroupTypeHierarchy(unittest.TestCase):
8    def test_exception_group_types(self):
9        self.assertTrue(issubclass(ExceptionGroup, Exception))
10        self.assertTrue(issubclass(ExceptionGroup, BaseExceptionGroup))
11        self.assertTrue(issubclass(BaseExceptionGroup, BaseException))
12
13    def test_exception_is_not_generic_type(self):
14        with self.assertRaisesRegex(TypeError, 'Exception'):
15            Exception[OSError]
16
17    def test_exception_group_is_generic_type(self):
18        E = OSError
19        self.assertIsInstance(ExceptionGroup[E], types.GenericAlias)
20        self.assertIsInstance(BaseExceptionGroup[E], types.GenericAlias)
21
22
23class BadConstructorArgs(unittest.TestCase):
24    def test_bad_EG_construction__too_many_args(self):
25        MSG = r'BaseExceptionGroup.__new__\(\) takes exactly 2 arguments'
26        with self.assertRaisesRegex(TypeError, MSG):
27            ExceptionGroup('no errors')
28        with self.assertRaisesRegex(TypeError, MSG):
29            ExceptionGroup([ValueError('no msg')])
30        with self.assertRaisesRegex(TypeError, MSG):
31            ExceptionGroup('eg', [ValueError('too')], [TypeError('many')])
32
33    def test_bad_EG_construction__bad_message(self):
34        MSG = 'argument 1 must be str, not '
35        with self.assertRaisesRegex(TypeError, MSG):
36            ExceptionGroup(ValueError(12), SyntaxError('bad syntax'))
37        with self.assertRaisesRegex(TypeError, MSG):
38            ExceptionGroup(None, [ValueError(12)])
39
40    def test_bad_EG_construction__bad_excs_sequence(self):
41        MSG = r'second argument \(exceptions\) must be a sequence'
42        with self.assertRaisesRegex(TypeError, MSG):
43            ExceptionGroup('errors not sequence', {ValueError(42)})
44        with self.assertRaisesRegex(TypeError, MSG):
45            ExceptionGroup("eg", None)
46
47        MSG = r'second argument \(exceptions\) must be a non-empty sequence'
48        with self.assertRaisesRegex(ValueError, MSG):
49            ExceptionGroup("eg", [])
50
51    def test_bad_EG_construction__nested_non_exceptions(self):
52        MSG = (r'Item [0-9]+ of second argument \(exceptions\)'
53              ' is not an exception')
54        with self.assertRaisesRegex(ValueError, MSG):
55            ExceptionGroup('expect instance, not type', [OSError]);
56        with self.assertRaisesRegex(ValueError, MSG):
57            ExceptionGroup('bad error', ["not an exception"])
58
59
60class InstanceCreation(unittest.TestCase):
61    def test_EG_wraps_Exceptions__creates_EG(self):
62        excs = [ValueError(1), TypeError(2)]
63        self.assertIs(
64            type(ExceptionGroup("eg", excs)),
65            ExceptionGroup)
66
67    def test_BEG_wraps_Exceptions__creates_EG(self):
68        excs = [ValueError(1), TypeError(2)]
69        self.assertIs(
70            type(BaseExceptionGroup("beg", excs)),
71            ExceptionGroup)
72
73    def test_EG_wraps_BaseException__raises_TypeError(self):
74        MSG= "Cannot nest BaseExceptions in an ExceptionGroup"
75        with self.assertRaisesRegex(TypeError, MSG):
76            eg = ExceptionGroup("eg", [ValueError(1), KeyboardInterrupt(2)])
77
78    def test_BEG_wraps_BaseException__creates_BEG(self):
79        beg = BaseExceptionGroup("beg", [ValueError(1), KeyboardInterrupt(2)])
80        self.assertIs(type(beg), BaseExceptionGroup)
81
82    def test_EG_subclass_wraps_non_base_exceptions(self):
83        class MyEG(ExceptionGroup):
84            pass
85
86        self.assertIs(
87            type(MyEG("eg", [ValueError(12), TypeError(42)])),
88            MyEG)
89
90    def test_EG_subclass_does_not_wrap_base_exceptions(self):
91        class MyEG(ExceptionGroup):
92            pass
93
94        msg = "Cannot nest BaseExceptions in 'MyEG'"
95        with self.assertRaisesRegex(TypeError, msg):
96            MyEG("eg", [ValueError(12), KeyboardInterrupt(42)])
97
98    def test_BEG_and_E_subclass_does_not_wrap_base_exceptions(self):
99        class MyEG(BaseExceptionGroup, ValueError):
100            pass
101
102        msg = "Cannot nest BaseExceptions in 'MyEG'"
103        with self.assertRaisesRegex(TypeError, msg):
104            MyEG("eg", [ValueError(12), KeyboardInterrupt(42)])
105
106    def test_EG_and_specific_subclass_can_wrap_any_nonbase_exception(self):
107        class MyEG(ExceptionGroup, ValueError):
108            pass
109
110        # The restriction is specific to Exception, not "the other base class"
111        MyEG("eg", [ValueError(12), Exception()])
112
113    def test_BEG_and_specific_subclass_can_wrap_any_nonbase_exception(self):
114        class MyEG(BaseExceptionGroup, ValueError):
115            pass
116
117        # The restriction is specific to Exception, not "the other base class"
118        MyEG("eg", [ValueError(12), Exception()])
119
120
121    def test_BEG_subclass_wraps_anything(self):
122        class MyBEG(BaseExceptionGroup):
123            pass
124
125        self.assertIs(
126            type(MyBEG("eg", [ValueError(12), TypeError(42)])),
127            MyBEG)
128        self.assertIs(
129            type(MyBEG("eg", [ValueError(12), KeyboardInterrupt(42)])),
130            MyBEG)
131
132
133class StrAndReprTests(unittest.TestCase):
134    def test_ExceptionGroup(self):
135        eg = BaseExceptionGroup(
136            'flat', [ValueError(1), TypeError(2)])
137
138        self.assertEqual(str(eg), "flat (2 sub-exceptions)")
139        self.assertEqual(repr(eg),
140            "ExceptionGroup('flat', [ValueError(1), TypeError(2)])")
141
142        eg = BaseExceptionGroup(
143            'nested', [eg, ValueError(1), eg, TypeError(2)])
144
145        self.assertEqual(str(eg), "nested (4 sub-exceptions)")
146        self.assertEqual(repr(eg),
147            "ExceptionGroup('nested', "
148                "[ExceptionGroup('flat', "
149                    "[ValueError(1), TypeError(2)]), "
150                 "ValueError(1), "
151                 "ExceptionGroup('flat', "
152                    "[ValueError(1), TypeError(2)]), TypeError(2)])")
153
154    def test_BaseExceptionGroup(self):
155        eg = BaseExceptionGroup(
156            'flat', [ValueError(1), KeyboardInterrupt(2)])
157
158        self.assertEqual(str(eg), "flat (2 sub-exceptions)")
159        self.assertEqual(repr(eg),
160            "BaseExceptionGroup("
161                "'flat', "
162                "[ValueError(1), KeyboardInterrupt(2)])")
163
164        eg = BaseExceptionGroup(
165            'nested', [eg, ValueError(1), eg])
166
167        self.assertEqual(str(eg), "nested (3 sub-exceptions)")
168        self.assertEqual(repr(eg),
169            "BaseExceptionGroup('nested', "
170                "[BaseExceptionGroup('flat', "
171                    "[ValueError(1), KeyboardInterrupt(2)]), "
172                "ValueError(1), "
173                "BaseExceptionGroup('flat', "
174                    "[ValueError(1), KeyboardInterrupt(2)])])")
175
176    def test_custom_exception(self):
177        class MyEG(ExceptionGroup):
178            pass
179
180        eg = MyEG(
181            'flat', [ValueError(1), TypeError(2)])
182
183        self.assertEqual(str(eg), "flat (2 sub-exceptions)")
184        self.assertEqual(repr(eg), "MyEG('flat', [ValueError(1), TypeError(2)])")
185
186        eg = MyEG(
187            'nested', [eg, ValueError(1), eg, TypeError(2)])
188
189        self.assertEqual(str(eg), "nested (4 sub-exceptions)")
190        self.assertEqual(repr(eg), (
191                 "MyEG('nested', "
192                     "[MyEG('flat', [ValueError(1), TypeError(2)]), "
193                      "ValueError(1), "
194                      "MyEG('flat', [ValueError(1), TypeError(2)]), "
195                      "TypeError(2)])"))
196
197
# Build and return a raised ExceptionGroup('simple eg', [...]) whose three
# members are each raised and caught inside this function:
#   [0] ValueError(1)  -- raised "from" a MemoryError, which becomes both its
#                         __cause__ and its __context__
#   [1] TypeError(int) -- raised while handling an OSError (implicit
#                         __context__, no __cause__)
#   [2] ValueError(2)  -- raised while handling an ImportError (implicit
#                         __context__, no __cause__)
# NOTE: tests assert traceback line numbers as offsets from
# create_simple_eg.__code__.co_firstlineno (6, 14, 22 and 27), so do not add
# or remove lines inside this function's body.
def create_simple_eg():
    excs = []
    try:
        try:
            raise MemoryError("context and cause for ValueError(1)")
        except MemoryError as e:
            raise ValueError(1) from e  # traceback line: first line + 6
    except ValueError as e:
        excs.append(e)

    try:
        try:
            raise OSError("context for TypeError")
        except OSError as e:
            raise TypeError(int)  # traceback line: first line + 14
    except TypeError as e:
        excs.append(e)

    try:
        try:
            raise ImportError("context for ValueError(2)")
        except ImportError as e:
            raise ValueError(2)  # traceback line: first line + 22
    except ValueError as e:
        excs.append(e)

    try:
        raise ExceptionGroup('simple eg', excs)  # first line + 27
    except ExceptionGroup as e:
        return e
228
229
class ExceptionGroupFields(unittest.TestCase):
    """The message/exceptions fields and tracebacks of exception groups."""

    def test_basics_ExceptionGroup_fields(self):
        eg = create_simple_eg()

        # check msg
        self.assertEqual(eg.message, 'simple eg')
        self.assertEqual(eg.args[0], 'simple eg')

        # check cause and context
        self.assertIsInstance(eg.exceptions[0], ValueError)
        self.assertIsInstance(eg.exceptions[0].__cause__, MemoryError)
        self.assertIsInstance(eg.exceptions[0].__context__, MemoryError)
        self.assertIsInstance(eg.exceptions[1], TypeError)
        self.assertIsNone(eg.exceptions[1].__cause__)
        self.assertIsInstance(eg.exceptions[1].__context__, OSError)
        self.assertIsInstance(eg.exceptions[2], ValueError)
        self.assertIsNone(eg.exceptions[2].__cause__)
        self.assertIsInstance(eg.exceptions[2].__context__, ImportError)

        # check tracebacks: each is a single entry at the line (relative to
        # the first line of create_simple_eg) where the exception was raised,
        # so these offsets must track that function's layout.
        line0 = create_simple_eg.__code__.co_firstlineno
        group_lineno = line0 + 27
        leaf_linenos = [line0 + 6, line0 + 14, line0 + 22]
        self.assertEqual(eg.__traceback__.tb_lineno, group_lineno)
        self.assertIsNone(eg.__traceback__.tb_next)
        self.assertEqual(len(eg.exceptions), len(leaf_linenos))
        for exc, expected_lineno in zip(eg.exceptions, leaf_linenos):
            with self.subTest(exc=exc):
                tb = exc.__traceback__
                self.assertIsNone(tb.tb_next)
                self.assertEqual(tb.tb_lineno, expected_lineno)

    def test_fields_are_readonly(self):
        excs = [TypeError(1), OSError(2)]
        eg = ExceptionGroup('eg', excs)

        # The nested exceptions are exposed as a tuple, whatever sequence
        # type was passed in.
        self.assertIs(type(eg.exceptions), tuple)
        self.assertEqual(eg.exceptions, tuple(excs))

        # Neither field can be rebound, and a failed rebind leaves it intact.
        self.assertEqual(eg.message, 'eg')
        with self.assertRaises(AttributeError):
            eg.message = "new msg"
        self.assertEqual(eg.message, 'eg')

        with self.assertRaises(AttributeError):
            eg.exceptions = [OSError('xyz')]
        self.assertEqual(eg.exceptions, tuple(excs))
272
273
274class ExceptionGroupTestBase(unittest.TestCase):
275    def assertMatchesTemplate(self, exc, exc_type, template):
276        """ Assert that the exception matches the template
277
278            A template describes the shape of exc. If exc is a
279            leaf exception (i.e., not an exception group) then
280            template is an exception instance that has the
281            expected type and args value of exc. If exc is an
282            exception group, then template is a list of the
283            templates of its nested exceptions.
284        """
285        if exc_type is not None:
286            self.assertIs(type(exc), exc_type)
287
288        if isinstance(exc, BaseExceptionGroup):
289            self.assertIsInstance(template, collections.abc.Sequence)
290            self.assertEqual(len(exc.exceptions), len(template))
291            for e, t in zip(exc.exceptions, template):
292                self.assertMatchesTemplate(e, None, t)
293        else:
294            self.assertIsInstance(template, BaseException)
295            self.assertEqual(type(exc), type(template))
296            self.assertEqual(exc.args, template.args)
297
298
class ExceptionGroupSubgroupTests(ExceptionGroupTestBase):
    """BaseExceptionGroup.subgroup() with type and predicate matchers.

    Every test works on the flat group built by create_simple_eg(), whose
    leaves are ValueError(1), TypeError(int) and ValueError(2).
    """

    def setUp(self):
        self.eg = create_simple_eg()
        self.eg_template = [ValueError(1), TypeError(int), ValueError(2)]

    def test_basics_subgroup_split__bad_arg_type(self):
        # Neither subgroup() nor split() accepts a string, an exception
        # instance, a list of types, or a tuple containing a non-type.
        bad_args = ["bad arg",
                    OSError('instance not type'),
                    [OSError, TypeError],
                    (OSError, 42)]
        for arg in bad_args:
            with self.subTest(arg=arg):
                with self.assertRaises(TypeError):
                    self.eg.subgroup(arg)
                with self.assertRaises(TypeError):
                    self.eg.split(arg)

    def test_basics_subgroup_by_type__passthrough(self):
        # A matcher that covers the whole group returns the group itself.
        eg = self.eg
        for match_type in (BaseException, Exception,
                           BaseExceptionGroup, ExceptionGroup):
            with self.subTest(match=match_type):
                self.assertIs(eg, eg.subgroup(match_type))

    def test_basics_subgroup_by_type__no_match(self):
        self.assertIsNone(self.eg.subgroup(OSError))

    def test_basics_subgroup_by_type__match(self):
        eg = self.eg
        testcases = [
            # (match_type, result_template)
            (ValueError, [ValueError(1), ValueError(2)]),
            (TypeError, [TypeError(int)]),
            ((ValueError, TypeError), self.eg_template)]

        for match_type, template in testcases:
            with self.subTest(match=match_type):
                subeg = eg.subgroup(match_type)
                self.assertEqual(subeg.message, eg.message)
                self.assertMatchesTemplate(subeg, ExceptionGroup, template)

    def test_basics_subgroup_by_predicate__passthrough(self):
        self.assertIs(self.eg, self.eg.subgroup(lambda e: True))

    def test_basics_subgroup_by_predicate__no_match(self):
        self.assertIsNone(self.eg.subgroup(lambda e: False))

    def test_basics_subgroup_by_predicate__match(self):
        eg = self.eg
        testcases = [
            # (match_type, result_template)
            (ValueError, [ValueError(1), ValueError(2)]),
            (TypeError, [TypeError(int)]),
            ((ValueError, TypeError), self.eg_template)]

        for match_type, template in testcases:
            with self.subTest(match=match_type):
                # Bind match_type as a default argument instead of closing
                # over the loop variable.
                subeg = eg.subgroup(lambda e, t=match_type: isinstance(e, t))
                self.assertEqual(subeg.message, eg.message)
                self.assertMatchesTemplate(subeg, ExceptionGroup, template)
357
358
class ExceptionGroupSplitTests(ExceptionGroupTestBase):
    """BaseExceptionGroup.split() with type and predicate matchers.

    Every test works on the flat group built by create_simple_eg(), whose
    leaves are ValueError(1), TypeError(int) and ValueError(2).
    """

    def setUp(self):
        self.eg = create_simple_eg()
        self.eg_template = [ValueError(1), TypeError(int), ValueError(2)]

    def test_basics_split_by_type__passthrough(self):
        # A matcher that covers the whole group puts everything in match.
        for E in [BaseException, Exception,
                  BaseExceptionGroup, ExceptionGroup]:
            with self.subTest(match=E):
                match, rest = self.eg.split(E)
                self.assertMatchesTemplate(
                    match, ExceptionGroup, self.eg_template)
                self.assertIsNone(rest)

    def test_basics_split_by_type__no_match(self):
        match, rest = self.eg.split(OSError)
        self.assertIsNone(match)
        self.assertMatchesTemplate(
            rest, ExceptionGroup, self.eg_template)

    def test_basics_split_by_type__match(self):
        eg = self.eg
        VE = ValueError
        TE = TypeError
        testcases = [
            # (matcher, match_template, rest_template)
            (VE, [VE(1), VE(2)], [TE(int)]),
            (TE, [TE(int)], [VE(1), VE(2)]),
            ((VE, TE), self.eg_template, None),
            ((OSError, VE), [VE(1), VE(2)], [TE(int)]),
        ]

        for match_type, match_template, rest_template in testcases:
            with self.subTest(match=match_type):
                match, rest = eg.split(match_type)
                self.assertEqual(match.message, eg.message)
                self.assertMatchesTemplate(
                    match, ExceptionGroup, match_template)
                if rest_template is not None:
                    self.assertEqual(rest.message, eg.message)
                    self.assertMatchesTemplate(
                        rest, ExceptionGroup, rest_template)
                else:
                    self.assertIsNone(rest)

    def test_basics_split_by_predicate__passthrough(self):
        match, rest = self.eg.split(lambda e: True)
        self.assertMatchesTemplate(match, ExceptionGroup, self.eg_template)
        self.assertIsNone(rest)

    def test_basics_split_by_predicate__no_match(self):
        match, rest = self.eg.split(lambda e: False)
        self.assertIsNone(match)
        self.assertMatchesTemplate(rest, ExceptionGroup, self.eg_template)

    def test_basics_split_by_predicate__match(self):
        eg = self.eg
        VE = ValueError
        TE = TypeError
        testcases = [
            # (matcher, match_template, rest_template)
            (VE, [VE(1), VE(2)], [TE(int)]),
            (TE, [TE(int)], [VE(1), VE(2)]),
            ((VE, TE), self.eg_template, None),
        ]

        for match_type, match_template, rest_template in testcases:
            with self.subTest(match=match_type):
                # Bind match_type as a default argument instead of closing
                # over the loop variable.
                match, rest = eg.split(
                    lambda e, t=match_type: isinstance(e, t))
                self.assertEqual(match.message, eg.message)
                self.assertMatchesTemplate(
                    match, ExceptionGroup, match_template)
                if rest_template is not None:
                    self.assertEqual(rest.message, eg.message)
                    self.assertMatchesTemplate(
                        rest, ExceptionGroup, rest_template)
                else:
                    # Previously unchecked: a full match must leave no rest.
                    self.assertIsNone(rest)
432
433
434class DeepRecursionInSplitAndSubgroup(unittest.TestCase):
435    def make_deep_eg(self):
436        e = TypeError(1)
437        for i in range(2000):
438            e = ExceptionGroup('eg', [e])
439        return e
440
441    def test_deep_split(self):
442        e = self.make_deep_eg()
443        with self.assertRaises(RecursionError):
444            e.split(TypeError)
445
446    def test_deep_subgroup(self):
447        e = self.make_deep_eg()
448        with self.assertRaises(RecursionError):
449            e.subgroup(TypeError)
450
451
452def leaf_generator(exc, tbs=None):
453    if tbs is None:
454        tbs = []
455    tbs.append(exc.__traceback__)
456    if isinstance(exc, BaseExceptionGroup):
457        for e in exc.exceptions:
458            yield from leaf_generator(e, tbs)
459    else:
460        # exc is a leaf exception and its traceback
461        # is the concatenation of the traceback
462        # segments in tbs
463        yield exc, tbs
464    tbs.pop()
465
466
class LeafGeneratorTest(unittest.TestCase):
    # PEP 654 suggests leaf_generator as a way to iterate over the leaf
    # exceptions of an exception group. It is also used below as a test
    # utility, so it is tested on its own here.

    def test_leaf_generator(self):
        eg = create_simple_eg()

        leaves = [leaf for leaf, _ in leaf_generator(eg)]
        self.assertSequenceEqual(leaves, eg.exceptions)

        # create_simple_eg() builds a flat group, so each leaf's traceback
        # chain is the group's own segment followed by the leaf's.
        for leaf, segments in leaf_generator(eg):
            expected_segments = [eg.__traceback__, leaf.__traceback__]
            self.assertSequenceEqual(segments, expected_segments)
482
483
# Build and return a raised two-level group:
#   ExceptionGroup("root", [ExceptionGroup("nested", [TypeError(bytes)]),
#                           ValueError(1)])
# The "nested" group is raised while handling the TypeError, so that
# TypeError is its __context__; ValueError(1) is raised "from" a MemoryError,
# which becomes both its __cause__ and its __context__.
# NOTE: tests assert traceback line numbers as offsets from
# create_nested_eg.__code__.co_firstlineno (4, 6, 14 and 19), so do not add
# or remove lines inside this function's body.
def create_nested_eg():
    excs = []
    try:
        try:
            raise TypeError(bytes)  # traceback line: first line + 4
        except TypeError as e:
            raise ExceptionGroup("nested", [e])  # first line + 6
    except ExceptionGroup as e:
        excs.append(e)

    try:
        try:
            raise MemoryError('out of memory')
        except MemoryError as e:
            raise ValueError(1) from e  # traceback line: first line + 14
    except ValueError as e:
        excs.append(e)

    try:
        raise ExceptionGroup("root", excs)  # traceback line: first line + 19
    except ExceptionGroup as eg:
        return eg
506
507
class NestedExceptionGroupBasicsTest(ExceptionGroupTestBase):
    """Shape, chaining and tracebacks of the group from create_nested_eg()."""

    def test_nested_group_matches_template(self):
        self.assertMatchesTemplate(
            create_nested_eg(),
            ExceptionGroup,
            [[TypeError(bytes)], ValueError(1)])

    def test_nested_group_chaining(self):
        eg = create_nested_eg()
        inner_group = eg.exceptions[0]
        value_error = eg.exceptions[1]
        self.assertIsInstance(value_error.__context__, MemoryError)
        self.assertIsInstance(value_error.__cause__, MemoryError)
        self.assertIsInstance(inner_group.__context__, TypeError)

    def test_nested_exception_group_tracebacks(self):
        eg = create_nested_eg()
        first_line = create_nested_eg.__code__.co_firstlineno

        # Each exception's own traceback is a single entry at the line
        # (relative to create_nested_eg's first line) where it was raised.
        raised_at = [
            (eg, 19),
            (eg.exceptions[0], 6),
            (eg.exceptions[1], 14),
            (eg.exceptions[0].exceptions[0], 4),
        ]
        for exc, offset in raised_at:
            tb = exc.__traceback__
            self.assertEqual(tb.tb_lineno, first_line + offset)
            self.assertIsNone(tb.tb_next)

    def test_iteration_full_tracebacks(self):
        eg = create_nested_eg()
        # Iterating over the leaves yields, per leaf, the chain of traceback
        # segments from the root group down to that leaf.
        self.assertEqual(len(list(leaf_generator(eg))), 2)

        first_line = create_nested_eg.__code__.co_firstlineno
        offsets_per_leaf = [[19, 6, 4], [19, 14]]

        for index, (_, segments) in enumerate(leaf_generator(eg)):
            actual = [tb.tb_lineno for tb in segments]
            expected = [first_line + off for off in offsets_per_leaf[index]]
            self.assertSequenceEqual(actual, expected)
549
550
class ExceptionGroupSplitTestBase(ExceptionGroupTestBase):
    """Base class providing a checked wrapper around BaseExceptionGroup.split()."""

    def split_exception_group(self, eg, types):
        """Split *eg* by *types* and sanity-check the result.

        Checks that:
        - split() and subgroup() agree on the matching leaves;
        - every leaf of *eg* lands in exactly one of match and rest;
        - each non-None part keeps *eg*'s message, cause, context,
          traceback and notes;
        - the full traceback chain of every leaf is the same in *eg* and in
          the part that contains it.

        Returns the ``(match, rest)`` pair from ``eg.split(types)``.
        """
        self.assertIsInstance(eg, BaseExceptionGroup)

        match, rest = eg.split(types)
        sg = eg.subgroup(types)

        if match is not None:
            self.assertIsInstance(match, BaseExceptionGroup)
            for e, _ in leaf_generator(match):
                self.assertIsInstance(e, types)

            self.assertIsNotNone(sg)
            self.assertIsInstance(sg, BaseExceptionGroup)
            for e, _ in leaf_generator(sg):
                self.assertIsInstance(e, types)

        if rest is not None:
            self.assertIsInstance(rest, BaseExceptionGroup)

        def leaves(exc):
            return [] if exc is None else [e for e, _ in leaf_generator(exc)]

        eg_leaves = leaves(eg)
        match_leaves = leaves(match)
        rest_leaves = leaves(rest)

        # match and subgroup have the same leaves
        self.assertSequenceEqual(match_leaves, leaves(sg))

        # Each leaf exception of eg is in exactly one of match and rest.
        # Compare plain booleans: the old `match and ...` form produced None
        # for a missing part, and None != False let a leaf that is in
        # neither part slip through.
        self.assertEqual(len(eg_leaves), len(match_leaves) + len(rest_leaves))
        for e in eg_leaves:
            in_match = any(e is m for m in match_leaves)
            in_rest = any(e is r for r in rest_leaves)
            self.assertNotEqual(in_match, in_rest)

        # message, cause and context, traceback and notes equal to eg
        for part in (match, rest, sg):
            if part is None:
                continue
            self.assertEqual(eg.message, part.message)
            self.assertIs(eg.__cause__, part.__cause__)
            self.assertIs(eg.__context__, part.__context__)
            self.assertIs(eg.__traceback__, part.__traceback__)
            self.assertEqual(
                getattr(eg, '__notes__', None),
                getattr(part, '__notes__', None))

        def tbs_for_leaf(leaf, group):
            # Return a copy of the traceback chain from group down to leaf,
            # so the result does not depend on leaf_generator's internal
            # accumulator.
            for e, tbs in leaf_generator(group):
                if e is leaf:
                    return list(tbs)
            self.fail(f"{leaf!r} is not a leaf of {group!r}")

        def tb_linenos(tbs):
            # Segments may be None (exception never raised at that level).
            return [tb.tb_lineno for tb in tbs if tb is not None]

        # full tracebacks match
        for part in (match, rest, sg):
            for e in leaves(part):
                self.assertSequenceEqual(
                    tb_linenos(tbs_for_leaf(e, eg)),
                    tb_linenos(tbs_for_leaf(e, part)))

        return match, rest
618
619
class NestedExceptionGroupSplitTest(ExceptionGroupSplitTestBase):
    """split() on multi-level groups, BaseExceptionGroups, and __notes__."""

    def test_split_by_type(self):
        class MyExceptionGroup(ExceptionGroup):
            pass

        def raiseVE(v):
            raise ValueError(v)

        def raiseTE(t):
            raise TypeError(t)

        def nested_group():
            # Raise the three-level group described by eg_template below.
            # Every nested exception is raised and caught, so each one
            # carries its own traceback segment.
            def level1(i):
                excs = []
                for f, arg in [(raiseVE, i), (raiseTE, int), (raiseVE, i+1)]:
                    try:
                        f(arg)
                    except Exception as e:
                        excs.append(e)
                raise ExceptionGroup('msg1', excs)

            def level2(i):
                excs = []
                for f, arg in [(level1, i), (level1, i+1), (raiseVE, i+2)]:
                    try:
                        f(arg)
                    except Exception as e:
                        excs.append(e)
                raise MyExceptionGroup('msg2', excs)

            def level3(i):
                excs = []
                for f, arg in [(level2, i+1), (raiseVE, i+2)]:
                    try:
                        f(arg)
                    except Exception as e:
                        excs.append(e)
                raise ExceptionGroup('msg3', excs)

            level3(5)

        try:
            nested_group()
        except ExceptionGroup as e:
            e.add_note(f"the note: {id(e)}")
            eg = e
        else:
            # Fail cleanly rather than with UnboundLocalError on 'eg' below.
            self.fail("nested_group() did not raise an ExceptionGroup")

        eg_template = [
            [
                [ValueError(6), TypeError(int), ValueError(7)],
                [ValueError(7), TypeError(int), ValueError(8)],
                ValueError(8),
            ],
            ValueError(7)]

        valueErrors_template = [
            [
                [ValueError(6), ValueError(7)],
                [ValueError(7), ValueError(8)],
                ValueError(8),
            ],
            ValueError(7)]

        typeErrors_template = [[[TypeError(int)], [TypeError(int)]]]

        self.assertMatchesTemplate(eg, ExceptionGroup, eg_template)

        # Match Nothing
        match, rest = self.split_exception_group(eg, SyntaxError)
        self.assertIsNone(match)
        self.assertMatchesTemplate(rest, ExceptionGroup, eg_template)

        # Match Everything
        match, rest = self.split_exception_group(eg, BaseException)
        self.assertMatchesTemplate(match, ExceptionGroup, eg_template)
        self.assertIsNone(rest)
        match, rest = self.split_exception_group(eg, (ValueError, TypeError))
        self.assertMatchesTemplate(match, ExceptionGroup, eg_template)
        self.assertIsNone(rest)

        # Match ValueErrors
        match, rest = self.split_exception_group(eg, ValueError)
        self.assertMatchesTemplate(match, ExceptionGroup, valueErrors_template)
        self.assertMatchesTemplate(rest, ExceptionGroup, typeErrors_template)

        # Match TypeErrors
        match, rest = self.split_exception_group(eg, (TypeError, SyntaxError))
        self.assertMatchesTemplate(match, ExceptionGroup, typeErrors_template)
        self.assertMatchesTemplate(rest, ExceptionGroup, valueErrors_template)

        # The two checks below call eg.split() directly:
        # split_exception_group() asserts that every leaf of the match is an
        # instance of the matcher, which does not hold when the matcher is a
        # group type.

        # Match ExceptionGroup: the top-level group itself matches.
        match, rest = eg.split(ExceptionGroup)
        self.assertIs(match, eg)
        self.assertIsNone(rest)

        # Match MyExceptionGroup (ExceptionGroup subclass): only the level-2
        # group matches, and both parts are plain ExceptionGroups.
        match, rest = eg.split(MyExceptionGroup)
        self.assertMatchesTemplate(match, ExceptionGroup, [eg_template[0]])
        self.assertMatchesTemplate(rest, ExceptionGroup, [eg_template[1]])

    def test_split_BaseExceptionGroup(self):
        def exc(ex):
            # Raise and catch ex so that it carries a traceback.
            try:
                raise ex
            except BaseException as e:
                return e

        try:
            raise BaseExceptionGroup(
                "beg", [exc(ValueError(1)), exc(KeyboardInterrupt(2))])
        except BaseExceptionGroup as e:
            beg = e

        # Match Nothing
        match, rest = self.split_exception_group(beg, TypeError)
        self.assertIsNone(match)
        self.assertMatchesTemplate(
            rest, BaseExceptionGroup, [ValueError(1), KeyboardInterrupt(2)])

        # Match Everything
        match, rest = self.split_exception_group(
            beg, (ValueError, KeyboardInterrupt))
        self.assertMatchesTemplate(
            match, BaseExceptionGroup, [ValueError(1), KeyboardInterrupt(2)])
        self.assertIsNone(rest)

        # Match ValueErrors: a part holding only Exceptions is an
        # ExceptionGroup, the part with the KeyboardInterrupt is not.
        match, rest = self.split_exception_group(beg, ValueError)
        self.assertMatchesTemplate(
            match, ExceptionGroup, [ValueError(1)])
        self.assertMatchesTemplate(
            rest, BaseExceptionGroup, [KeyboardInterrupt(2)])

        # Match KeyboardInterrupts
        match, rest = self.split_exception_group(beg, KeyboardInterrupt)
        self.assertMatchesTemplate(
            match, BaseExceptionGroup, [KeyboardInterrupt(2)])
        self.assertMatchesTemplate(
            rest, ExceptionGroup, [ValueError(1)])

    def test_split_copies_notes(self):
        # make sure each exception group after a split has its own __notes__ list
        eg = ExceptionGroup("eg", [ValueError(1), TypeError(2)])
        eg.add_note("note1")
        eg.add_note("note2")
        orig_notes = list(eg.__notes__)
        match, rest = eg.split(TypeError)
        self.assertEqual(eg.__notes__, orig_notes)
        self.assertEqual(match.__notes__, orig_notes)
        self.assertEqual(rest.__notes__, orig_notes)
        self.assertIsNot(eg.__notes__, match.__notes__)
        self.assertIsNot(eg.__notes__, rest.__notes__)
        self.assertIsNot(match.__notes__, rest.__notes__)
        # Adding a note to one part must not affect the others.
        eg.add_note("eg")
        match.add_note("match")
        rest.add_note("rest")
        self.assertEqual(eg.__notes__, orig_notes + ["eg"])
        self.assertEqual(match.__notes__, orig_notes + ["match"])
        self.assertEqual(rest.__notes__, orig_notes + ["rest"])

    def test_split_does_not_copy_non_sequence_notes(self):
        # __notes__ should be a sequence, which is shallow copied.
        # If it is not a sequence, the split parts don't get any notes.
        eg = ExceptionGroup("eg", [ValueError(1), TypeError(2)])
        eg.__notes__ = 123
        match, rest = eg.split(TypeError)
        self.assertFalse(hasattr(match, '__notes__'))
        self.assertFalse(hasattr(rest, '__notes__'))
789
790
class NestedExceptionGroupSubclassSplitTest(ExceptionGroupSplitTestBase):
    """split() behaviour for user-defined (Base)ExceptionGroup subclasses.

    The tests cover three kinds of subclass:
    - one that overrides nothing;
    - one that overrides only __new__;
    - one that overrides both __new__ and derive().

    Each test checks the class (and, where relevant, the extra state) of
    every part that split() produces.
    """

    def _assert_split(self, eg, types, expected_match, expected_rest):
        """Split *eg* by *types* and check both parts.

        Each expectation is either None, meaning that part must be None,
        or a ``(group_class, template)`` pair forwarded to
        assertMatchesTemplate. The match part is checked before the rest
        part, and the ``(match, rest)`` pair is returned so callers can
        make further assertions on it.
        """
        match, rest = self.split_exception_group(eg, types)
        checks = ((match, expected_match), (rest, expected_rest))
        for part, expected in checks:
            if expected is None:
                self.assertIsNone(part)
                continue
            group_class, template = expected
            self.assertMatchesTemplate(part, group_class, template)
        return match, rest

    def test_split_ExceptionGroup_subclass_no_derive_no_new_override(self):
        class EG(ExceptionGroup):
            pass

        try:
            try:
                try:
                    raise TypeError(2)
                except TypeError as te:
                    raise EG("nested", [te])
            except EG as nested:
                try:
                    raise ValueError(1)
                except ValueError as ve:
                    raise EG("eg", [ve, nested])
        except EG as caught:
            eg = caught

        self.assertMatchesTemplate(eg, EG, [ValueError(1), [TypeError(2)]])

        # Without a derive() override, every part is a plain ExceptionGroup.

        # Match Nothing
        self._assert_split(
            eg, OSError,
            None,
            (ExceptionGroup, [ValueError(1), [TypeError(2)]]))

        # Match Everything
        self._assert_split(
            eg, (ValueError, TypeError),
            (ExceptionGroup, [ValueError(1), [TypeError(2)]]),
            None)

        # Match ValueErrors
        self._assert_split(
            eg, ValueError,
            (ExceptionGroup, [ValueError(1)]),
            (ExceptionGroup, [[TypeError(2)]]))

        # Match TypeErrors
        self._assert_split(
            eg, TypeError,
            (ExceptionGroup, [[TypeError(2)]]),
            (ExceptionGroup, [ValueError(1)]))

    def test_split_BaseExceptionGroup_subclass_no_derive_new_override(self):
        class EG(BaseExceptionGroup):
            def __new__(cls, message, excs, unused):
                # The extra "unused" argument shows that the default
                # derive() used by split() never calls this constructor.
                # If it did, the call would fail on the missing argument,
                # because derive() assumes the BaseExceptionGroup.__new__
                # signature.
                return super().__new__(cls, message, excs)

        try:
            raise EG("eg", [ValueError(1), KeyboardInterrupt(2)], "unused")
        except EG as caught:
            eg = caught

        self.assertMatchesTemplate(
            eg, EG, [ValueError(1), KeyboardInterrupt(2)])

        # Match Nothing
        self._assert_split(
            eg, OSError,
            None,
            (BaseExceptionGroup, [ValueError(1), KeyboardInterrupt(2)]))

        # Match Everything
        self._assert_split(
            eg, (ValueError, KeyboardInterrupt),
            (BaseExceptionGroup, [ValueError(1), KeyboardInterrupt(2)]),
            None)

        # Match ValueErrors: the part holding only an Exception comes back
        # as a plain ExceptionGroup.
        self._assert_split(
            eg, ValueError,
            (ExceptionGroup, [ValueError(1)]),
            (BaseExceptionGroup, [KeyboardInterrupt(2)]))

        # Match KeyboardInterrupt
        self._assert_split(
            eg, KeyboardInterrupt,
            (BaseExceptionGroup, [KeyboardInterrupt(2)]),
            (ExceptionGroup, [ValueError(1)]))

    def test_split_ExceptionGroup_subclass_derive_and_new_overrides(self):
        class EG(ExceptionGroup):
            def __new__(cls, message, excs, code):
                obj = super().__new__(cls, message, excs)
                obj.code = code
                return obj

            def derive(self, excs):
                return EG(self.message, excs, self.code)

        try:
            try:
                try:
                    raise TypeError(2)
                except TypeError as te:
                    raise EG("nested", [te], 101)
            except EG as nested:
                try:
                    raise ValueError(1)
                except ValueError as ve:
                    raise EG("eg", [ve, nested], 42)
        except EG as caught:
            eg = caught

        self.assertMatchesTemplate(eg, EG, [ValueError(1), [TypeError(2)]])

        # With derive() overridden, every part is an EG and carries the code
        # of the group it was derived from: 42 for the outer group, 101 for
        # the nested one.

        # Match Nothing
        _, rest = self._assert_split(
            eg, OSError,
            None,
            (EG, [ValueError(1), [TypeError(2)]]))
        self.assertEqual(rest.code, 42)
        self.assertEqual(rest.exceptions[1].code, 101)

        # Match Everything
        match, _ = self._assert_split(
            eg, (ValueError, TypeError),
            (EG, [ValueError(1), [TypeError(2)]]),
            None)
        self.assertEqual(match.code, 42)
        self.assertEqual(match.exceptions[1].code, 101)

        # Match ValueErrors
        match, rest = self._assert_split(
            eg, ValueError,
            (EG, [ValueError(1)]),
            (EG, [[TypeError(2)]]))
        self.assertEqual(match.code, 42)
        self.assertEqual(rest.code, 42)
        self.assertEqual(rest.exceptions[0].code, 101)

        # Match TypeErrors
        match, rest = self._assert_split(
            eg, TypeError,
            (EG, [[TypeError(2)]]),
            (EG, [ValueError(1)]))
        self.assertEqual(match.code, 42)
        self.assertEqual(match.exceptions[0].code, 101)
        self.assertEqual(rest.code, 42)
932
933
# Allow running this test module directly as a script, using unittest's
# command-line runner.
if __name__ == '__main__':
    unittest.main()
936