1# Deliberately use "from dataclasses import *".  Every name in __all__
2# is tested, so they all must be present.  This is a way to catch
3# missing ones.
4
5from dataclasses import *
6
7import abc
8import io
9import pickle
10import inspect
11import builtins
12import types
13import weakref
14import traceback
15import unittest
16from unittest.mock import Mock
17from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol
18from typing import get_type_hints
19from collections import deque, OrderedDict, namedtuple
20from functools import total_ordering
21
22import typing       # Needed for the string "typing.ClassVar[int]" to work as an annotation.
23import dataclasses  # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
24
25# Just any custom exception we can catch.
class CustomError(Exception):
    """Arbitrary exception type that tests can raise and catch explicitly."""
27
28class TestCase(unittest.TestCase):
29    def test_no_fields(self):
30        @dataclass
31        class C:
32            pass
33
34        o = C()
35        self.assertEqual(len(fields(C)), 0)
36
37    def test_no_fields_but_member_variable(self):
38        @dataclass
39        class C:
40            i = 0
41
42        o = C()
43        self.assertEqual(len(fields(C)), 0)
44
45    def test_one_field_no_default(self):
46        @dataclass
47        class C:
48            x: int
49
50        o = C(42)
51        self.assertEqual(o.x, 42)
52
53    def test_field_default_default_factory_error(self):
54        msg = "cannot specify both default and default_factory"
55        with self.assertRaisesRegex(ValueError, msg):
56            @dataclass
57            class C:
58                x: int = field(default=1, default_factory=int)
59
60    def test_field_repr(self):
61        int_field = field(default=1, init=True, repr=False)
62        int_field.name = "id"
63        repr_output = repr(int_field)
64        expected_output = "Field(name='id',type=None," \
65                           f"default=1,default_factory={MISSING!r}," \
66                           "init=True,repr=False,hash=None," \
67                           "compare=True,metadata=mappingproxy({})," \
68                           f"kw_only={MISSING!r}," \
69                           "_field_type=None)"
70
71        self.assertEqual(repr_output, expected_output)
72
73    def test_field_recursive_repr(self):
74        rec_field = field()
75        rec_field.type = rec_field
76        rec_field.name = "id"
77        repr_output = repr(rec_field)
78
79        self.assertIn(",type=...,", repr_output)
80
81    def test_recursive_annotation(self):
82        class C:
83            pass
84
85        @dataclass
86        class D:
87            C: C = field()
88
89        self.assertIn(",type=...,", repr(D.__dataclass_fields__["C"]))
90
91    def test_named_init_params(self):
92        @dataclass
93        class C:
94            x: int
95
96        o = C(x=32)
97        self.assertEqual(o.x, 32)
98
99    def test_two_fields_one_default(self):
100        @dataclass
101        class C:
102            x: int
103            y: int = 0
104
105        o = C(3)
106        self.assertEqual((o.x, o.y), (3, 0))
107
108        # Non-defaults following defaults.
109        with self.assertRaisesRegex(TypeError,
110                                    "non-default argument 'y' follows "
111                                    "default argument"):
112            @dataclass
113            class C:
114                x: int = 0
115                y: int
116
117        # A derived class adds a non-default field after a default one.
118        with self.assertRaisesRegex(TypeError,
119                                    "non-default argument 'y' follows "
120                                    "default argument"):
121            @dataclass
122            class B:
123                x: int = 0
124
125            @dataclass
126            class C(B):
127                y: int
128
129        # Override a base class field and add a default to
130        #  a field which didn't use to have a default.
131        with self.assertRaisesRegex(TypeError,
132                                    "non-default argument 'y' follows "
133                                    "default argument"):
134            @dataclass
135            class B:
136                x: int
137                y: int
138
139            @dataclass
140            class C(B):
141                x: int = 0
142
143    def test_overwrite_hash(self):
144        # Test that declaring this class isn't an error.  It should
145        #  use the user-provided __hash__.
146        @dataclass(frozen=True)
147        class C:
148            x: int
149            def __hash__(self):
150                return 301
151        self.assertEqual(hash(C(100)), 301)
152
153        # Test that declaring this class isn't an error.  It should
154        #  use the generated __hash__.
155        @dataclass(frozen=True)
156        class C:
157            x: int
158            def __eq__(self, other):
159                return False
160        self.assertEqual(hash(C(100)), hash((100,)))
161
162        # But this one should generate an exception, because with
163        #  unsafe_hash=True, it's an error to have a __hash__ defined.
164        with self.assertRaisesRegex(TypeError,
165                                    'Cannot overwrite attribute __hash__'):
166            @dataclass(unsafe_hash=True)
167            class C:
168                def __hash__(self):
169                    pass
170
171        # Creating this class should not generate an exception,
172        #  because even though __hash__ exists before @dataclass is
173        #  called, (due to __eq__ being defined), since it's None
174        #  that's okay.
175        @dataclass(unsafe_hash=True)
176        class C:
177            x: int
178            def __eq__(self):
179                pass
180        # The generated hash function works as we'd expect.
181        self.assertEqual(hash(C(10)), hash((10,)))
182
183        # Creating this class should generate an exception, because
184        #  __hash__ exists and is not None, which it would be if it
185        #  had been auto-generated due to __eq__ being defined.
186        with self.assertRaisesRegex(TypeError,
187                                    'Cannot overwrite attribute __hash__'):
188            @dataclass(unsafe_hash=True)
189            class C:
190                x: int
191                def __eq__(self):
192                    pass
193                def __hash__(self):
194                    pass
195
196    def test_overwrite_fields_in_derived_class(self):
197        # Note that x from C1 replaces x in Base, but the order remains
198        #  the same as defined in Base.
199        @dataclass
200        class Base:
201            x: Any = 15.0
202            y: int = 0
203
204        @dataclass
205        class C1(Base):
206            z: int = 10
207            x: int = 15
208
209        o = Base()
210        self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
211
212        o = C1()
213        self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
214
215        o = C1(x=5)
216        self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
217
218    def test_field_named_self(self):
219        @dataclass
220        class C:
221            self: str
222        c=C('foo')
223        self.assertEqual(c.self, 'foo')
224
225        # Make sure the first parameter is not named 'self'.
226        sig = inspect.signature(C.__init__)
227        first = next(iter(sig.parameters))
228        self.assertNotEqual('self', first)
229
230        # But we do use 'self' if no field named self.
231        @dataclass
232        class C:
233            selfx: str
234
235        # Make sure the first parameter is named 'self'.
236        sig = inspect.signature(C.__init__)
237        first = next(iter(sig.parameters))
238        self.assertEqual('self', first)
239
240    def test_field_named_object(self):
241        @dataclass
242        class C:
243            object: str
244        c = C('foo')
245        self.assertEqual(c.object, 'foo')
246
247    def test_field_named_object_frozen(self):
248        @dataclass(frozen=True)
249        class C:
250            object: str
251        c = C('foo')
252        self.assertEqual(c.object, 'foo')
253
254    def test_field_named_BUILTINS_frozen(self):
255        # gh-96151
256        @dataclass(frozen=True)
257        class C:
258            BUILTINS: int
259        c = C(5)
260        self.assertEqual(c.BUILTINS, 5)
261
262    def test_field_named_like_builtin(self):
263        # Attribute names can shadow built-in names
264        # since code generation is used.
265        # Ensure that this is not happening.
266        exclusions = {'None', 'True', 'False'}
267        builtins_names = sorted(
268            b for b in builtins.__dict__.keys()
269            if not b.startswith('__') and b not in exclusions
270        )
271        attributes = [(name, str) for name in builtins_names]
272        C = make_dataclass('C', attributes)
273
274        c = C(*[name for name in builtins_names])
275
276        for name in builtins_names:
277            self.assertEqual(getattr(c, name), name)
278
279    def test_field_named_like_builtin_frozen(self):
280        # Attribute names can shadow built-in names
281        # since code generation is used.
282        # Ensure that this is not happening
283        # for frozen data classes.
284        exclusions = {'None', 'True', 'False'}
285        builtins_names = sorted(
286            b for b in builtins.__dict__.keys()
287            if not b.startswith('__') and b not in exclusions
288        )
289        attributes = [(name, str) for name in builtins_names]
290        C = make_dataclass('C', attributes, frozen=True)
291
292        c = C(*[name for name in builtins_names])
293
294        for name in builtins_names:
295            self.assertEqual(getattr(c, name), name)
296
297    def test_0_field_compare(self):
298        # Ensure that order=False is the default.
299        @dataclass
300        class C0:
301            pass
302
303        @dataclass(order=False)
304        class C1:
305            pass
306
307        for cls in [C0, C1]:
308            with self.subTest(cls=cls):
309                self.assertEqual(cls(), cls())
310                for idx, fn in enumerate([lambda a, b: a < b,
311                                          lambda a, b: a <= b,
312                                          lambda a, b: a > b,
313                                          lambda a, b: a >= b]):
314                    with self.subTest(idx=idx):
315                        with self.assertRaisesRegex(TypeError,
316                                                    f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
317                            fn(cls(), cls())
318
319        @dataclass(order=True)
320        class C:
321            pass
322        self.assertLessEqual(C(), C())
323        self.assertGreaterEqual(C(), C())
324
325    def test_1_field_compare(self):
326        # Ensure that order=False is the default.
327        @dataclass
328        class C0:
329            x: int
330
331        @dataclass(order=False)
332        class C1:
333            x: int
334
335        for cls in [C0, C1]:
336            with self.subTest(cls=cls):
337                self.assertEqual(cls(1), cls(1))
338                self.assertNotEqual(cls(0), cls(1))
339                for idx, fn in enumerate([lambda a, b: a < b,
340                                          lambda a, b: a <= b,
341                                          lambda a, b: a > b,
342                                          lambda a, b: a >= b]):
343                    with self.subTest(idx=idx):
344                        with self.assertRaisesRegex(TypeError,
345                                                    f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
346                            fn(cls(0), cls(0))
347
348        @dataclass(order=True)
349        class C:
350            x: int
351        self.assertLess(C(0), C(1))
352        self.assertLessEqual(C(0), C(1))
353        self.assertLessEqual(C(1), C(1))
354        self.assertGreater(C(1), C(0))
355        self.assertGreaterEqual(C(1), C(0))
356        self.assertGreaterEqual(C(1), C(1))
357
358    def test_simple_compare(self):
359        # Ensure that order=False is the default.
360        @dataclass
361        class C0:
362            x: int
363            y: int
364
365        @dataclass(order=False)
366        class C1:
367            x: int
368            y: int
369
370        for cls in [C0, C1]:
371            with self.subTest(cls=cls):
372                self.assertEqual(cls(0, 0), cls(0, 0))
373                self.assertEqual(cls(1, 2), cls(1, 2))
374                self.assertNotEqual(cls(1, 0), cls(0, 0))
375                self.assertNotEqual(cls(1, 0), cls(1, 1))
376                for idx, fn in enumerate([lambda a, b: a < b,
377                                          lambda a, b: a <= b,
378                                          lambda a, b: a > b,
379                                          lambda a, b: a >= b]):
380                    with self.subTest(idx=idx):
381                        with self.assertRaisesRegex(TypeError,
382                                                    f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
383                            fn(cls(0, 0), cls(0, 0))
384
385        @dataclass(order=True)
386        class C:
387            x: int
388            y: int
389
390        for idx, fn in enumerate([lambda a, b: a == b,
391                                  lambda a, b: a <= b,
392                                  lambda a, b: a >= b]):
393            with self.subTest(idx=idx):
394                self.assertTrue(fn(C(0, 0), C(0, 0)))
395
396        for idx, fn in enumerate([lambda a, b: a < b,
397                                  lambda a, b: a <= b,
398                                  lambda a, b: a != b]):
399            with self.subTest(idx=idx):
400                self.assertTrue(fn(C(0, 0), C(0, 1)))
401                self.assertTrue(fn(C(0, 1), C(1, 0)))
402                self.assertTrue(fn(C(1, 0), C(1, 1)))
403
404        for idx, fn in enumerate([lambda a, b: a > b,
405                                  lambda a, b: a >= b,
406                                  lambda a, b: a != b]):
407            with self.subTest(idx=idx):
408                self.assertTrue(fn(C(0, 1), C(0, 0)))
409                self.assertTrue(fn(C(1, 0), C(0, 1)))
410                self.assertTrue(fn(C(1, 1), C(1, 0)))
411
412    def test_compare_subclasses(self):
413        # Comparisons fail for subclasses, even if no fields
414        #  are added.
415        @dataclass
416        class B:
417            i: int
418
419        @dataclass
420        class C(B):
421            pass
422
423        for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
424                                              (lambda a, b: a != b, True)]):
425            with self.subTest(idx=idx):
426                self.assertEqual(fn(B(0), C(0)), expected)
427
428        for idx, fn in enumerate([lambda a, b: a < b,
429                                  lambda a, b: a <= b,
430                                  lambda a, b: a > b,
431                                  lambda a, b: a >= b]):
432            with self.subTest(idx=idx):
433                with self.assertRaisesRegex(TypeError,
434                                            "not supported between instances of 'B' and 'C'"):
435                    fn(B(0), C(0))
436
437    def test_eq_order(self):
438        # Test combining eq and order.
439        for (eq,    order, result   ) in [
440            (False, False, 'neither'),
441            (False, True,  'exception'),
442            (True,  False, 'eq_only'),
443            (True,  True,  'both'),
444        ]:
445            with self.subTest(eq=eq, order=order):
446                if result == 'exception':
447                    with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
448                        @dataclass(eq=eq, order=order)
449                        class C:
450                            pass
451                else:
452                    @dataclass(eq=eq, order=order)
453                    class C:
454                        pass
455
456                    if result == 'neither':
457                        self.assertNotIn('__eq__', C.__dict__)
458                        self.assertNotIn('__lt__', C.__dict__)
459                        self.assertNotIn('__le__', C.__dict__)
460                        self.assertNotIn('__gt__', C.__dict__)
461                        self.assertNotIn('__ge__', C.__dict__)
462                    elif result == 'both':
463                        self.assertIn('__eq__', C.__dict__)
464                        self.assertIn('__lt__', C.__dict__)
465                        self.assertIn('__le__', C.__dict__)
466                        self.assertIn('__gt__', C.__dict__)
467                        self.assertIn('__ge__', C.__dict__)
468                    elif result == 'eq_only':
469                        self.assertIn('__eq__', C.__dict__)
470                        self.assertNotIn('__lt__', C.__dict__)
471                        self.assertNotIn('__le__', C.__dict__)
472                        self.assertNotIn('__gt__', C.__dict__)
473                        self.assertNotIn('__ge__', C.__dict__)
474                    else:
475                        assert False, f'unknown result {result!r}'
476
477    def test_field_no_default(self):
478        @dataclass
479        class C:
480            x: int = field()
481
482        self.assertEqual(C(5).x, 5)
483
484        with self.assertRaisesRegex(TypeError,
485                                    r"__init__\(\) missing 1 required "
486                                    "positional argument: 'x'"):
487            C()
488
489    def test_field_default(self):
490        default = object()
491        @dataclass
492        class C:
493            x: object = field(default=default)
494
495        self.assertIs(C.x, default)
496        c = C(10)
497        self.assertEqual(c.x, 10)
498
499        # If we delete the instance attribute, we should then see the
500        #  class attribute.
501        del c.x
502        self.assertIs(c.x, default)
503
504        self.assertIs(C().x, default)
505
506    def test_not_in_repr(self):
507        @dataclass
508        class C:
509            x: int = field(repr=False)
510        with self.assertRaises(TypeError):
511            C()
512        c = C(10)
513        self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
514
515        @dataclass
516        class C:
517            x: int = field(repr=False)
518            y: int
519        c = C(10, 20)
520        self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
521
522    def test_not_in_compare(self):
523        @dataclass
524        class C:
525            x: int = 0
526            y: int = field(compare=False, default=4)
527
528        self.assertEqual(C(), C(0, 20))
529        self.assertEqual(C(1, 10), C(1, 20))
530        self.assertNotEqual(C(3), C(4, 10))
531        self.assertNotEqual(C(3, 10), C(4, 10))
532
533    def test_no_unhashable_default(self):
534        # See bpo-44674.
535        class Unhashable:
536            __hash__ = None
537
538        unhashable_re = 'mutable default .* for field a is not allowed'
539        with self.assertRaisesRegex(ValueError, unhashable_re):
540            @dataclass
541            class A:
542                a: dict = {}
543
544        with self.assertRaisesRegex(ValueError, unhashable_re):
545            @dataclass
546            class A:
547                a: Any = Unhashable()
548
549        # Make sure that the machinery looking for hashability is using the
550        # class's __hash__, not the instance's __hash__.
551        with self.assertRaisesRegex(ValueError, unhashable_re):
552            unhashable = Unhashable()
553            # This shouldn't make the variable hashable.
554            unhashable.__hash__ = lambda: 0
555            @dataclass
556            class A:
557                a: Any = unhashable
558
559    def test_hash_field_rules(self):
560        # Test all 6 cases of:
561        #  hash=True/False/None
562        #  compare=True/False
563        for (hash_,    compare, result  ) in [
564            (True,     False,   'field' ),
565            (True,     True,    'field' ),
566            (False,    False,   'absent'),
567            (False,    True,    'absent'),
568            (None,     False,   'absent'),
569            (None,     True,    'field' ),
570            ]:
571            with self.subTest(hash=hash_, compare=compare):
572                @dataclass(unsafe_hash=True)
573                class C:
574                    x: int = field(compare=compare, hash=hash_, default=5)
575
576                if result == 'field':
577                    # __hash__ contains the field.
578                    self.assertEqual(hash(C(5)), hash((5,)))
579                elif result == 'absent':
580                    # The field is not present in the hash.
581                    self.assertEqual(hash(C(5)), hash(()))
582                else:
583                    assert False, f'unknown result {result!r}'
584
585    def test_init_false_no_default(self):
586        # If init=False and no default value, then the field won't be
587        #  present in the instance.
588        @dataclass
589        class C:
590            x: int = field(init=False)
591
592        self.assertNotIn('x', C().__dict__)
593
594        @dataclass
595        class C:
596            x: int
597            y: int = 0
598            z: int = field(init=False)
599            t: int = 10
600
601        self.assertNotIn('z', C(0).__dict__)
602        self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
603
604    def test_class_marker(self):
605        @dataclass
606        class C:
607            x: int
608            y: str = field(init=False, default=None)
609            z: str = field(repr=False)
610
611        the_fields = fields(C)
612        # the_fields is a tuple of 3 items, each value
613        #  is in __annotations__.
614        self.assertIsInstance(the_fields, tuple)
615        for f in the_fields:
616            self.assertIs(type(f), Field)
617            self.assertIn(f.name, C.__annotations__)
618
619        self.assertEqual(len(the_fields), 3)
620
621        self.assertEqual(the_fields[0].name, 'x')
622        self.assertEqual(the_fields[0].type, int)
623        self.assertFalse(hasattr(C, 'x'))
624        self.assertTrue (the_fields[0].init)
625        self.assertTrue (the_fields[0].repr)
626        self.assertEqual(the_fields[1].name, 'y')
627        self.assertEqual(the_fields[1].type, str)
628        self.assertIsNone(getattr(C, 'y'))
629        self.assertFalse(the_fields[1].init)
630        self.assertTrue (the_fields[1].repr)
631        self.assertEqual(the_fields[2].name, 'z')
632        self.assertEqual(the_fields[2].type, str)
633        self.assertFalse(hasattr(C, 'z'))
634        self.assertTrue (the_fields[2].init)
635        self.assertFalse(the_fields[2].repr)
636
637    def test_field_order(self):
638        @dataclass
639        class B:
640            a: str = 'B:a'
641            b: str = 'B:b'
642            c: str = 'B:c'
643
644        @dataclass
645        class C(B):
646            b: str = 'C:b'
647
648        self.assertEqual([(f.name, f.default) for f in fields(C)],
649                         [('a', 'B:a'),
650                          ('b', 'C:b'),
651                          ('c', 'B:c')])
652
653        @dataclass
654        class D(B):
655            c: str = 'D:c'
656
657        self.assertEqual([(f.name, f.default) for f in fields(D)],
658                         [('a', 'B:a'),
659                          ('b', 'B:b'),
660                          ('c', 'D:c')])
661
662        @dataclass
663        class E(D):
664            a: str = 'E:a'
665            d: str = 'E:d'
666
667        self.assertEqual([(f.name, f.default) for f in fields(E)],
668                         [('a', 'E:a'),
669                          ('b', 'B:b'),
670                          ('c', 'D:c'),
671                          ('d', 'E:d')])
672
673    def test_class_attrs(self):
674        # We only have a class attribute if a default value is
675        #  specified, either directly or via a field with a default.
676        default = object()
677        @dataclass
678        class C:
679            x: int
680            y: int = field(repr=False)
681            z: object = default
682            t: int = field(default=100)
683
684        self.assertFalse(hasattr(C, 'x'))
685        self.assertFalse(hasattr(C, 'y'))
686        self.assertIs   (C.z, default)
687        self.assertEqual(C.t, 100)
688
689    def test_disallowed_mutable_defaults(self):
690        # For the known types, don't allow mutable default values.
691        for typ, empty, non_empty in [(list, [], [1]),
692                                      (dict, {}, {0:1}),
693                                      (set, set(), set([1])),
694                                      ]:
695            with self.subTest(typ=typ):
696                # Can't use a zero-length value.
697                with self.assertRaisesRegex(ValueError,
698                                            f'mutable default {typ} for field '
699                                            'x is not allowed'):
700                    @dataclass
701                    class Point:
702                        x: typ = empty
703
704
705                # Nor a non-zero-length value
706                with self.assertRaisesRegex(ValueError,
707                                            f'mutable default {typ} for field '
708                                            'y is not allowed'):
709                    @dataclass
710                    class Point:
711                        y: typ = non_empty
712
713                # Check subtypes also fail.
714                class Subclass(typ): pass
715
716                with self.assertRaisesRegex(ValueError,
717                                            f"mutable default .*Subclass'>"
718                                            ' for field z is not allowed'
719                                            ):
720                    @dataclass
721                    class Point:
722                        z: typ = Subclass()
723
724                # Because this is a ClassVar, it can be mutable.
725                @dataclass
726                class C:
727                    z: ClassVar[typ] = typ()
728
729                # Because this is a ClassVar, it can be mutable.
730                @dataclass
731                class C:
732                    x: ClassVar[typ] = Subclass()
733
734    def test_deliberately_mutable_defaults(self):
735        # If a mutable default isn't in the known list of
736        #  (list, dict, set), then it's okay.
737        class Mutable:
738            def __init__(self):
739                self.l = []
740
741        @dataclass
742        class C:
743            x: Mutable
744
745        # These 2 instances will share this value of x.
746        lst = Mutable()
747        o1 = C(lst)
748        o2 = C(lst)
749        self.assertEqual(o1, o2)
750        o1.x.l.extend([1, 2])
751        self.assertEqual(o1, o2)
752        self.assertEqual(o1.x.l, [1, 2])
753        self.assertIs(o1.x, o2.x)
754
755    def test_no_options(self):
756        # Call with dataclass().
757        @dataclass()
758        class C:
759            x: int
760
761        self.assertEqual(C(42).x, 42)
762
763    def test_not_tuple(self):
764        # Make sure we can't be compared to a tuple.
765        @dataclass
766        class Point:
767            x: int
768            y: int
769        self.assertNotEqual(Point(1, 2), (1, 2))
770
771        # And that we can't compare to another unrelated dataclass.
772        @dataclass
773        class C:
774            x: int
775            y: int
776        self.assertNotEqual(Point(1, 3), C(1, 3))
777
778    def test_not_other_dataclass(self):
779        # Test that some of the problems with namedtuple don't happen
780        #  here.
781        @dataclass
782        class Point3D:
783            x: int
784            y: int
785            z: int
786
787        @dataclass
788        class Date:
789            year: int
790            month: int
791            day: int
792
793        self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
794        self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
795
796        # Make sure we can't unpack.
797        with self.assertRaisesRegex(TypeError, 'unpack'):
798            x, y, z = Point3D(4, 5, 6)
799
800        # Make sure another class with the same field names isn't
801        #  equal.
802        @dataclass
803        class Point3Dv1:
804            x: int = 0
805            y: int = 0
806            z: int = 0
807        self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
808
809    def test_function_annotations(self):
810        # Some dummy class and instance to use as a default.
811        class F:
812            pass
813        f = F()
814
815        def validate_class(cls):
816            # First, check __annotations__, even though they're not
817            #  function annotations.
818            self.assertEqual(cls.__annotations__['i'], int)
819            self.assertEqual(cls.__annotations__['j'], str)
820            self.assertEqual(cls.__annotations__['k'], F)
821            self.assertEqual(cls.__annotations__['l'], float)
822            self.assertEqual(cls.__annotations__['z'], complex)
823
824            # Verify __init__.
825
826            signature = inspect.signature(cls.__init__)
827            # Check the return type, should be None.
828            self.assertIs(signature.return_annotation, None)
829
830            # Check each parameter.
831            params = iter(signature.parameters.values())
832            param = next(params)
833            # This is testing an internal name, and probably shouldn't be tested.
834            self.assertEqual(param.name, 'self')
835            param = next(params)
836            self.assertEqual(param.name, 'i')
837            self.assertIs   (param.annotation, int)
838            self.assertEqual(param.default, inspect.Parameter.empty)
839            self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
840            param = next(params)
841            self.assertEqual(param.name, 'j')
842            self.assertIs   (param.annotation, str)
843            self.assertEqual(param.default, inspect.Parameter.empty)
844            self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
845            param = next(params)
846            self.assertEqual(param.name, 'k')
847            self.assertIs   (param.annotation, F)
848            # Don't test for the default, since it's set to MISSING.
849            self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
850            param = next(params)
851            self.assertEqual(param.name, 'l')
852            self.assertIs   (param.annotation, float)
853            # Don't test for the default, since it's set to MISSING.
854            self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
855            self.assertRaises(StopIteration, next, params)
856
857
858        @dataclass
859        class C:
860            i: int
861            j: str
862            k: F = f
863            l: float=field(default=None)
864            z: complex=field(default=3+4j, init=False)
865
866        validate_class(C)
867
868        # Now repeat with __hash__.
869        @dataclass(frozen=True, unsafe_hash=True)
870        class C:
871            i: int
872            j: str
873            k: F = f
874            l: float=field(default=None)
875            z: complex=field(default=3+4j, init=False)
876
877        validate_class(C)
878
879    def test_missing_default(self):
880        # Test that MISSING works the same as a default not being
881        #  specified.
882        @dataclass
883        class C:
884            x: int=field(default=MISSING)
885        with self.assertRaisesRegex(TypeError,
886                                    r'__init__\(\) missing 1 required '
887                                    'positional argument'):
888            C()
889        self.assertNotIn('x', C.__dict__)
890
891        @dataclass
892        class D:
893            x: int
894        with self.assertRaisesRegex(TypeError,
895                                    r'__init__\(\) missing 1 required '
896                                    'positional argument'):
897            D()
898        self.assertNotIn('x', D.__dict__)
899
900    def test_missing_default_factory(self):
901        # Test that MISSING works the same as a default factory not
902        #  being specified (which is really the same as a default not
903        #  being specified, too).
904        @dataclass
905        class C:
906            x: int=field(default_factory=MISSING)
907        with self.assertRaisesRegex(TypeError,
908                                    r'__init__\(\) missing 1 required '
909                                    'positional argument'):
910            C()
911        self.assertNotIn('x', C.__dict__)
912
913        @dataclass
914        class D:
915            x: int=field(default=MISSING, default_factory=MISSING)
916        with self.assertRaisesRegex(TypeError,
917                                    r'__init__\(\) missing 1 required '
918                                    'positional argument'):
919            D()
920        self.assertNotIn('x', D.__dict__)
921
922    def test_missing_repr(self):
923        self.assertIn('MISSING_TYPE object', repr(MISSING))
924
925    def test_dont_include_other_annotations(self):
926        @dataclass
927        class C:
928            i: int
929            def foo(self) -> int:
930                return 4
931            @property
932            def bar(self) -> int:
933                return 5
934        self.assertEqual(list(C.__annotations__), ['i'])
935        self.assertEqual(C(10).foo(), 4)
936        self.assertEqual(C(10).bar, 5)
937        self.assertEqual(C(10).i, 10)
938
939    def test_post_init(self):
940        # Just make sure it gets called
941        @dataclass
942        class C:
943            def __post_init__(self):
944                raise CustomError()
945        with self.assertRaises(CustomError):
946            C()
947
948        @dataclass
949        class C:
950            i: int = 10
951            def __post_init__(self):
952                if self.i == 10:
953                    raise CustomError()
954        with self.assertRaises(CustomError):
955            C()
956        # post-init gets called, but doesn't raise. This is just
957        #  checking that self is used correctly.
958        C(5)
959
960        # If there's not an __init__, then post-init won't get called.
961        @dataclass(init=False)
962        class C:
963            def __post_init__(self):
964                raise CustomError()
965        # Creating the class won't raise
966        C()
967
968        @dataclass
969        class C:
970            x: int = 0
971            def __post_init__(self):
972                self.x *= 2
973        self.assertEqual(C().x, 0)
974        self.assertEqual(C(2).x, 4)
975
976        # Make sure that if we're frozen, post-init can't set
977        #  attributes.
978        @dataclass(frozen=True)
979        class C:
980            x: int = 0
981            def __post_init__(self):
982                self.x *= 2
983        with self.assertRaises(FrozenInstanceError):
984            C()
985
986    def test_post_init_super(self):
987        # Make sure super() post-init isn't called by default.
988        class B:
989            def __post_init__(self):
990                raise CustomError()
991
992        @dataclass
993        class C(B):
994            def __post_init__(self):
995                self.x = 5
996
997        self.assertEqual(C().x, 5)
998
999        # Now call super(), and it will raise.
1000        @dataclass
1001        class C(B):
1002            def __post_init__(self):
1003                super().__post_init__()
1004
1005        with self.assertRaises(CustomError):
1006            C()
1007
1008        # Make sure post-init is called, even if not defined in our
1009        #  class.
1010        @dataclass
1011        class C(B):
1012            pass
1013
1014        with self.assertRaises(CustomError):
1015            C()
1016
1017    def test_post_init_staticmethod(self):
1018        flag = False
1019        @dataclass
1020        class C:
1021            x: int
1022            y: int
1023            @staticmethod
1024            def __post_init__():
1025                nonlocal flag
1026                flag = True
1027
1028        self.assertFalse(flag)
1029        c = C(3, 4)
1030        self.assertEqual((c.x, c.y), (3, 4))
1031        self.assertTrue(flag)
1032
1033    def test_post_init_classmethod(self):
1034        @dataclass
1035        class C:
1036            flag = False
1037            x: int
1038            y: int
1039            @classmethod
1040            def __post_init__(cls):
1041                cls.flag = True
1042
1043        self.assertFalse(C.flag)
1044        c = C(3, 4)
1045        self.assertEqual((c.x, c.y), (3, 4))
1046        self.assertTrue(C.flag)
1047
1048    def test_post_init_not_auto_added(self):
1049        # See bpo-46757, which had proposed always adding __post_init__.  As
1050        # Raymond Hettinger pointed out, that would be a breaking change.  So,
1051        # add a test to make sure that the current behavior doesn't change.
1052
1053        @dataclass
1054        class A0:
1055            pass
1056
1057        @dataclass
1058        class B0:
1059            b_called: bool = False
1060            def __post_init__(self):
1061                self.b_called = True
1062
1063        @dataclass
1064        class C0(A0, B0):
1065            c_called: bool = False
1066            def __post_init__(self):
1067                super().__post_init__()
1068                self.c_called = True
1069
1070        # Since A0 has no __post_init__, and one wasn't automatically added
1071        # (because that's the rule: it's never added by @dataclass, it's only
1072        # the class author that can add it), then B0.__post_init__ is called.
1073        # Verify that.
1074        c = C0()
1075        self.assertTrue(c.b_called)
1076        self.assertTrue(c.c_called)
1077
1078        ######################################
1079        # Now, the same thing, except A1 defines __post_init__.
1080        @dataclass
1081        class A1:
1082            def __post_init__(self):
1083                pass
1084
1085        @dataclass
1086        class B1:
1087            b_called: bool = False
1088            def __post_init__(self):
1089                self.b_called = True
1090
1091        @dataclass
1092        class C1(A1, B1):
1093            c_called: bool = False
1094            def __post_init__(self):
1095                super().__post_init__()
1096                self.c_called = True
1097
1098        # This time, B1.__post_init__ isn't being called.  This mimics what
1099        # would happen if A1.__post_init__ had been automatically added,
1100        # instead of manually added as we see here.  This test isn't really
1101        # needed, but I'm including it just to demonstrate the changed
1102        # behavior when A1 does define __post_init__.
1103        c = C1()
1104        self.assertFalse(c.b_called)
1105        self.assertTrue(c.c_called)
1106
1107    def test_class_var(self):
1108        # Make sure ClassVars are ignored in __init__, __repr__, etc.
1109        @dataclass
1110        class C:
1111            x: int
1112            y: int = 10
1113            z: ClassVar[int] = 1000
1114            w: ClassVar[int] = 2000
1115            t: ClassVar[int] = 3000
1116            s: ClassVar      = 4000
1117
1118        c = C(5)
1119        self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
1120        self.assertEqual(len(fields(C)), 2)                 # We have 2 fields.
1121        self.assertEqual(len(C.__annotations__), 6)         # And 4 ClassVars.
1122        self.assertEqual(c.z, 1000)
1123        self.assertEqual(c.w, 2000)
1124        self.assertEqual(c.t, 3000)
1125        self.assertEqual(c.s, 4000)
1126        C.z += 1
1127        self.assertEqual(c.z, 1001)
1128        c = C(20)
1129        self.assertEqual((c.x, c.y), (20, 10))
1130        self.assertEqual(c.z, 1001)
1131        self.assertEqual(c.w, 2000)
1132        self.assertEqual(c.t, 3000)
1133        self.assertEqual(c.s, 4000)
1134
1135    def test_class_var_no_default(self):
1136        # If a ClassVar has no default value, it should not be set on the class.
1137        @dataclass
1138        class C:
1139            x: ClassVar[int]
1140
1141        self.assertNotIn('x', C.__dict__)
1142
1143    def test_class_var_default_factory(self):
1144        # It makes no sense for a ClassVar to have a default factory. When
1145        #  would it be called? Call it yourself, since it's class-wide.
1146        with self.assertRaisesRegex(TypeError,
1147                                    'cannot have a default factory'):
1148            @dataclass
1149            class C:
1150                x: ClassVar[int] = field(default_factory=int)
1151
1152            self.assertNotIn('x', C.__dict__)
1153
1154    def test_class_var_with_default(self):
1155        # If a ClassVar has a default value, it should be set on the class.
1156        @dataclass
1157        class C:
1158            x: ClassVar[int] = 10
1159        self.assertEqual(C.x, 10)
1160
1161        @dataclass
1162        class C:
1163            x: ClassVar[int] = field(default=10)
1164        self.assertEqual(C.x, 10)
1165
1166    def test_class_var_frozen(self):
1167        # Make sure ClassVars work even if we're frozen.
1168        @dataclass(frozen=True)
1169        class C:
1170            x: int
1171            y: int = 10
1172            z: ClassVar[int] = 1000
1173            w: ClassVar[int] = 2000
1174            t: ClassVar[int] = 3000
1175
1176        c = C(5)
1177        self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
1178        self.assertEqual(len(fields(C)), 2)                 # We have 2 fields
1179        self.assertEqual(len(C.__annotations__), 5)         # And 3 ClassVars
1180        self.assertEqual(c.z, 1000)
1181        self.assertEqual(c.w, 2000)
1182        self.assertEqual(c.t, 3000)
1183        # We can still modify the ClassVar, it's only instances that are
1184        #  frozen.
1185        C.z += 1
1186        self.assertEqual(c.z, 1001)
1187        c = C(20)
1188        self.assertEqual((c.x, c.y), (20, 10))
1189        self.assertEqual(c.z, 1001)
1190        self.assertEqual(c.w, 2000)
1191        self.assertEqual(c.t, 3000)
1192
1193    def test_init_var_no_default(self):
1194        # If an InitVar has no default value, it should not be set on the class.
1195        @dataclass
1196        class C:
1197            x: InitVar[int]
1198
1199        self.assertNotIn('x', C.__dict__)
1200
1201    def test_init_var_default_factory(self):
1202        # It makes no sense for an InitVar to have a default factory. When
1203        #  would it be called? Call it yourself, since it's class-wide.
1204        with self.assertRaisesRegex(TypeError,
1205                                    'cannot have a default factory'):
1206            @dataclass
1207            class C:
1208                x: InitVar[int] = field(default_factory=int)
1209
1210            self.assertNotIn('x', C.__dict__)
1211
1212    def test_init_var_with_default(self):
1213        # If an InitVar has a default value, it should be set on the class.
1214        @dataclass
1215        class C:
1216            x: InitVar[int] = 10
1217        self.assertEqual(C.x, 10)
1218
1219        @dataclass
1220        class C:
1221            x: InitVar[int] = field(default=10)
1222        self.assertEqual(C.x, 10)
1223
1224    def test_init_var(self):
1225        @dataclass
1226        class C:
1227            x: int = None
1228            init_param: InitVar[int] = None
1229
1230            def __post_init__(self, init_param):
1231                if self.x is None:
1232                    self.x = init_param*2
1233
1234        c = C(init_param=10)
1235        self.assertEqual(c.x, 20)
1236
1237    def test_init_var_preserve_type(self):
1238        self.assertEqual(InitVar[int].type, int)
1239
1240        # Make sure the repr is correct.
1241        self.assertEqual(repr(InitVar[int]), 'dataclasses.InitVar[int]')
1242        self.assertEqual(repr(InitVar[List[int]]),
1243                         'dataclasses.InitVar[typing.List[int]]')
1244        self.assertEqual(repr(InitVar[list[int]]),
1245                         'dataclasses.InitVar[list[int]]')
1246        self.assertEqual(repr(InitVar[int|str]),
1247                         'dataclasses.InitVar[int | str]')
1248
1249    def test_init_var_inheritance(self):
1250        # Note that this deliberately tests that a dataclass need not
1251        #  have a __post_init__ function if it has an InitVar field.
1252        #  It could just be used in a derived class, as shown here.
1253        @dataclass
1254        class Base:
1255            x: int
1256            init_base: InitVar[int]
1257
1258        # We can instantiate by passing the InitVar, even though
1259        #  it's not used.
1260        b = Base(0, 10)
1261        self.assertEqual(vars(b), {'x': 0})
1262
1263        @dataclass
1264        class C(Base):
1265            y: int
1266            init_derived: InitVar[int]
1267
1268            def __post_init__(self, init_base, init_derived):
1269                self.x = self.x + init_base
1270                self.y = self.y + init_derived
1271
1272        c = C(10, 11, 50, 51)
1273        self.assertEqual(vars(c), {'x': 21, 'y': 101})
1274
1275    def test_default_factory(self):
1276        # Test a factory that returns a new list.
1277        @dataclass
1278        class C:
1279            x: int
1280            y: list = field(default_factory=list)
1281
1282        c0 = C(3)
1283        c1 = C(3)
1284        self.assertEqual(c0.x, 3)
1285        self.assertEqual(c0.y, [])
1286        self.assertEqual(c0, c1)
1287        self.assertIsNot(c0.y, c1.y)
1288        self.assertEqual(astuple(C(5, [1])), (5, [1]))
1289
1290        # Test a factory that returns a shared list.
1291        l = []
1292        @dataclass
1293        class C:
1294            x: int
1295            y: list = field(default_factory=lambda: l)
1296
1297        c0 = C(3)
1298        c1 = C(3)
1299        self.assertEqual(c0.x, 3)
1300        self.assertEqual(c0.y, [])
1301        self.assertEqual(c0, c1)
1302        self.assertIs(c0.y, c1.y)
1303        self.assertEqual(astuple(C(5, [1])), (5, [1]))
1304
1305        # Test various other field flags.
1306        # repr
1307        @dataclass
1308        class C:
1309            x: list = field(default_factory=list, repr=False)
1310        self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
1311        self.assertEqual(C().x, [])
1312
1313        # hash
1314        @dataclass(unsafe_hash=True)
1315        class C:
1316            x: list = field(default_factory=list, hash=False)
1317        self.assertEqual(astuple(C()), ([],))
1318        self.assertEqual(hash(C()), hash(()))
1319
1320        # init (see also test_default_factory_with_no_init)
1321        @dataclass
1322        class C:
1323            x: list = field(default_factory=list, init=False)
1324        self.assertEqual(astuple(C()), ([],))
1325
1326        # compare
1327        @dataclass
1328        class C:
1329            x: list = field(default_factory=list, compare=False)
1330        self.assertEqual(C(), C([1]))
1331
1332    def test_default_factory_with_no_init(self):
1333        # We need a factory with a side effect.
1334        factory = Mock()
1335
1336        @dataclass
1337        class C:
1338            x: list = field(default_factory=factory, init=False)
1339
1340        # Make sure the default factory is called for each new instance.
1341        C().x
1342        self.assertEqual(factory.call_count, 1)
1343        C().x
1344        self.assertEqual(factory.call_count, 2)
1345
1346    def test_default_factory_not_called_if_value_given(self):
1347        # We need a factory that we can test if it's been called.
1348        factory = Mock()
1349
1350        @dataclass
1351        class C:
1352            x: int = field(default_factory=factory)
1353
1354        # Make sure that if a field has a default factory function,
1355        #  it's not called if a value is specified.
1356        C().x
1357        self.assertEqual(factory.call_count, 1)
1358        self.assertEqual(C(10).x, 10)
1359        self.assertEqual(factory.call_count, 1)
1360        C().x
1361        self.assertEqual(factory.call_count, 2)
1362
1363    def test_default_factory_derived(self):
1364        # See bpo-32896.
1365        @dataclass
1366        class Foo:
1367            x: dict = field(default_factory=dict)
1368
1369        @dataclass
1370        class Bar(Foo):
1371            y: int = 1
1372
1373        self.assertEqual(Foo().x, {})
1374        self.assertEqual(Bar().x, {})
1375        self.assertEqual(Bar().y, 1)
1376
1377        @dataclass
1378        class Baz(Foo):
1379            pass
1380        self.assertEqual(Baz().x, {})
1381
1382    def test_intermediate_non_dataclass(self):
1383        # Test that an intermediate class that defines
1384        #  annotations does not define fields.
1385
1386        @dataclass
1387        class A:
1388            x: int
1389
1390        class B(A):
1391            y: int
1392
1393        @dataclass
1394        class C(B):
1395            z: int
1396
1397        c = C(1, 3)
1398        self.assertEqual((c.x, c.z), (1, 3))
1399
1400        # .y was not initialized.
1401        with self.assertRaisesRegex(AttributeError,
1402                                    'object has no attribute'):
1403            c.y
1404
1405        # And if we again derive a non-dataclass, no fields are added.
1406        class D(C):
1407            t: int
1408        d = D(4, 5)
1409        self.assertEqual((d.x, d.z), (4, 5))
1410
1411    def test_classvar_default_factory(self):
1412        # It's an error for a ClassVar to have a factory function.
1413        with self.assertRaisesRegex(TypeError,
1414                                    'cannot have a default factory'):
1415            @dataclass
1416            class C:
1417                x: ClassVar[int] = field(default_factory=int)
1418
1419    def test_is_dataclass(self):
1420        class NotDataClass:
1421            pass
1422
1423        self.assertFalse(is_dataclass(0))
1424        self.assertFalse(is_dataclass(int))
1425        self.assertFalse(is_dataclass(NotDataClass))
1426        self.assertFalse(is_dataclass(NotDataClass()))
1427
1428        @dataclass
1429        class C:
1430            x: int
1431
1432        @dataclass
1433        class D:
1434            d: C
1435            e: int
1436
1437        c = C(10)
1438        d = D(c, 4)
1439
1440        self.assertTrue(is_dataclass(C))
1441        self.assertTrue(is_dataclass(c))
1442        self.assertFalse(is_dataclass(c.x))
1443        self.assertTrue(is_dataclass(d.d))
1444        self.assertFalse(is_dataclass(d.e))
1445
1446    def test_is_dataclass_when_getattr_always_returns(self):
1447        # See bpo-37868.
1448        class A:
1449            def __getattr__(self, key):
1450                return 0
1451        self.assertFalse(is_dataclass(A))
1452        a = A()
1453
1454        # Also test for an instance attribute.
1455        class B:
1456            pass
1457        b = B()
1458        b.__dataclass_fields__ = []
1459
1460        for obj in a, b:
1461            with self.subTest(obj=obj):
1462                self.assertFalse(is_dataclass(obj))
1463
1464                # Indirect tests for _is_dataclass_instance().
1465                with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1466                    asdict(obj)
1467                with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1468                    astuple(obj)
1469                with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1470                    replace(obj, x=0)
1471
1472    def test_is_dataclass_genericalias(self):
1473        @dataclass
1474        class A(types.GenericAlias):
1475            origin: type
1476            args: type
1477        self.assertTrue(is_dataclass(A))
1478        a = A(list, int)
1479        self.assertTrue(is_dataclass(type(a)))
1480        self.assertTrue(is_dataclass(a))
1481
1482
1483    def test_helper_fields_with_class_instance(self):
1484        # Check that we can call fields() on either a class or instance,
1485        #  and get back the same thing.
1486        @dataclass
1487        class C:
1488            x: int
1489            y: float
1490
1491        self.assertEqual(fields(C), fields(C(0, 0.0)))
1492
1493    def test_helper_fields_exception(self):
1494        # Check that TypeError is raised if not passed a dataclass or
1495        #  instance.
1496        with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1497            fields(0)
1498
1499        class C: pass
1500        with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1501            fields(C)
1502        with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1503            fields(C())
1504
1505    def test_clean_traceback_from_fields_exception(self):
1506        stdout = io.StringIO()
1507        try:
1508            fields(object)
1509        except TypeError as exc:
1510            traceback.print_exception(exc, file=stdout)
1511        printed_traceback = stdout.getvalue()
1512        self.assertNotIn("AttributeError", printed_traceback)
1513        self.assertNotIn("__dataclass_fields__", printed_traceback)
1514
1515    def test_helper_asdict(self):
1516        # Basic tests for asdict(), it should return a new dictionary.
1517        @dataclass
1518        class C:
1519            x: int
1520            y: int
1521        c = C(1, 2)
1522
1523        self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1524        self.assertEqual(asdict(c), asdict(c))
1525        self.assertIsNot(asdict(c), asdict(c))
1526        c.x = 42
1527        self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1528        self.assertIs(type(asdict(c)), dict)
1529
1530    def test_helper_asdict_raises_on_classes(self):
1531        # asdict() should raise on a class object.
1532        @dataclass
1533        class C:
1534            x: int
1535            y: int
1536        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1537            asdict(C)
1538        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1539            asdict(int)
1540
1541    def test_helper_asdict_copy_values(self):
1542        @dataclass
1543        class C:
1544            x: int
1545            y: List[int] = field(default_factory=list)
1546        initial = []
1547        c = C(1, initial)
1548        d = asdict(c)
1549        self.assertEqual(d['y'], initial)
1550        self.assertIsNot(d['y'], initial)
1551        c = C(1)
1552        d = asdict(c)
1553        d['y'].append(1)
1554        self.assertEqual(c.y, [])
1555
1556    def test_helper_asdict_nested(self):
1557        @dataclass
1558        class UserId:
1559            token: int
1560            group: int
1561        @dataclass
1562        class User:
1563            name: str
1564            id: UserId
1565        u = User('Joe', UserId(123, 1))
1566        d = asdict(u)
1567        self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1568        self.assertIsNot(asdict(u), asdict(u))
1569        u.id.group = 2
1570        self.assertEqual(asdict(u), {'name': 'Joe',
1571                                     'id': {'token': 123, 'group': 2}})
1572
1573    def test_helper_asdict_builtin_containers(self):
1574        @dataclass
1575        class User:
1576            name: str
1577            id: int
1578        @dataclass
1579        class GroupList:
1580            id: int
1581            users: List[User]
1582        @dataclass
1583        class GroupTuple:
1584            id: int
1585            users: Tuple[User, ...]
1586        @dataclass
1587        class GroupDict:
1588            id: int
1589            users: Dict[str, User]
1590        a = User('Alice', 1)
1591        b = User('Bob', 2)
1592        gl = GroupList(0, [a, b])
1593        gt = GroupTuple(0, (a, b))
1594        gd = GroupDict(0, {'first': a, 'second': b})
1595        self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1596                                                         {'name': 'Bob', 'id': 2}]})
1597        self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1598                                                         {'name': 'Bob', 'id': 2})})
1599        self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1600                                                         'second': {'name': 'Bob', 'id': 2}}})
1601
1602    def test_helper_asdict_builtin_object_containers(self):
1603        @dataclass
1604        class Child:
1605            d: object
1606
1607        @dataclass
1608        class Parent:
1609            child: Child
1610
1611        self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1612        self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1613
1614    def test_helper_asdict_factory(self):
1615        @dataclass
1616        class C:
1617            x: int
1618            y: int
1619        c = C(1, 2)
1620        d = asdict(c, dict_factory=OrderedDict)
1621        self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1622        self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1623        c.x = 42
1624        d = asdict(c, dict_factory=OrderedDict)
1625        self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1626        self.assertIs(type(d), OrderedDict)
1627
1628    def test_helper_asdict_namedtuple(self):
1629        T = namedtuple('T', 'a b c')
1630        @dataclass
1631        class C:
1632            x: str
1633            y: T
1634        c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1635
1636        d = asdict(c)
1637        self.assertEqual(d, {'x': 'outer',
1638                             'y': T(1,
1639                                    {'x': 'inner',
1640                                     'y': T(11, 12, 13)},
1641                                    2),
1642                             }
1643                         )
1644
1645        # Now with a dict_factory.  OrderedDict is convenient, but
1646        # since it compares to dicts, we also need to have separate
1647        # assertIs tests.
1648        d = asdict(c, dict_factory=OrderedDict)
1649        self.assertEqual(d, {'x': 'outer',
1650                             'y': T(1,
1651                                    {'x': 'inner',
1652                                     'y': T(11, 12, 13)},
1653                                    2),
1654                             }
1655                         )
1656
1657        # Make sure that the returned dicts are actually OrderedDicts.
1658        self.assertIs(type(d), OrderedDict)
1659        self.assertIs(type(d['y'][1]), OrderedDict)
1660
1661    def test_helper_asdict_namedtuple_key(self):
1662        # Ensure that a field that contains a dict which has a
1663        # namedtuple as a key works with asdict().
1664
1665        @dataclass
1666        class C:
1667            f: dict
1668        T = namedtuple('T', 'a')
1669
1670        c = C({T('an a'): 0})
1671
1672        self.assertEqual(asdict(c), {'f': {T(a='an a'): 0}})
1673
1674    def test_helper_asdict_namedtuple_derived(self):
1675        class T(namedtuple('Tbase', 'a')):
1676            def my_a(self):
1677                return self.a
1678
1679        @dataclass
1680        class C:
1681            f: T
1682
1683        t = T(6)
1684        c = C(t)
1685
1686        d = asdict(c)
1687        self.assertEqual(d, {'f': T(a=6)})
1688        # Make sure that t has been copied, not used directly.
1689        self.assertIsNot(d['f'], t)
1690        self.assertEqual(d['f'].my_a(), 6)
1691
1692    def test_helper_astuple(self):
1693        # Basic tests for astuple(), it should return a new tuple.
1694        @dataclass
1695        class C:
1696            x: int
1697            y: int = 0
1698        c = C(1)
1699
1700        self.assertEqual(astuple(c), (1, 0))
1701        self.assertEqual(astuple(c), astuple(c))
1702        self.assertIsNot(astuple(c), astuple(c))
1703        c.y = 42
1704        self.assertEqual(astuple(c), (1, 42))
1705        self.assertIs(type(astuple(c)), tuple)
1706
1707    def test_helper_astuple_raises_on_classes(self):
1708        # astuple() should raise on a class object.
1709        @dataclass
1710        class C:
1711            x: int
1712            y: int
1713        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1714            astuple(C)
1715        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1716            astuple(int)
1717
1718    def test_helper_astuple_copy_values(self):
1719        @dataclass
1720        class C:
1721            x: int
1722            y: List[int] = field(default_factory=list)
1723        initial = []
1724        c = C(1, initial)
1725        t = astuple(c)
1726        self.assertEqual(t[1], initial)
1727        self.assertIsNot(t[1], initial)
1728        c = C(1)
1729        t = astuple(c)
1730        t[1].append(1)
1731        self.assertEqual(c.y, [])
1732
1733    def test_helper_astuple_nested(self):
1734        @dataclass
1735        class UserId:
1736            token: int
1737            group: int
1738        @dataclass
1739        class User:
1740            name: str
1741            id: UserId
1742        u = User('Joe', UserId(123, 1))
1743        t = astuple(u)
1744        self.assertEqual(t, ('Joe', (123, 1)))
1745        self.assertIsNot(astuple(u), astuple(u))
1746        u.id.group = 2
1747        self.assertEqual(astuple(u), ('Joe', (123, 2)))
1748
1749    def test_helper_astuple_builtin_containers(self):
1750        @dataclass
1751        class User:
1752            name: str
1753            id: int
1754        @dataclass
1755        class GroupList:
1756            id: int
1757            users: List[User]
1758        @dataclass
1759        class GroupTuple:
1760            id: int
1761            users: Tuple[User, ...]
1762        @dataclass
1763        class GroupDict:
1764            id: int
1765            users: Dict[str, User]
1766        a = User('Alice', 1)
1767        b = User('Bob', 2)
1768        gl = GroupList(0, [a, b])
1769        gt = GroupTuple(0, (a, b))
1770        gd = GroupDict(0, {'first': a, 'second': b})
1771        self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1772        self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1773        self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1774
1775    def test_helper_astuple_builtin_object_containers(self):
1776        @dataclass
1777        class Child:
1778            d: object
1779
1780        @dataclass
1781        class Parent:
1782            child: Child
1783
1784        self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1785        self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1786
1787    def test_helper_astuple_factory(self):
1788        @dataclass
1789        class C:
1790            x: int
1791            y: int
1792        NT = namedtuple('NT', 'x y')
1793        def nt(lst):
1794            return NT(*lst)
1795        c = C(1, 2)
1796        t = astuple(c, tuple_factory=nt)
1797        self.assertEqual(t, NT(1, 2))
1798        self.assertIsNot(t, astuple(c, tuple_factory=nt))
1799        c.x = 42
1800        t = astuple(c, tuple_factory=nt)
1801        self.assertEqual(t, NT(42, 2))
1802        self.assertIs(type(t), NT)
1803
1804    def test_helper_astuple_namedtuple(self):
1805        T = namedtuple('T', 'a b c')
1806        @dataclass
1807        class C:
1808            x: str
1809            y: T
1810        c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1811
1812        t = astuple(c)
1813        self.assertEqual(t, ('outer', T(1, ('inner', (11, 12, 13)), 2)))
1814
1815        # Now, using a tuple_factory.  list is convenient here.
1816        t = astuple(c, tuple_factory=list)
1817        self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)])
1818
1819    def test_dynamic_class_creation(self):
1820        cls_dict = {'__annotations__': {'x': int, 'y': int},
1821                    }
1822
1823        # Create the class.
1824        cls = type('C', (), cls_dict)
1825
1826        # Make it a dataclass.
1827        cls1 = dataclass(cls)
1828
1829        self.assertEqual(cls1, cls)
1830        self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1831
1832    def test_dynamic_class_creation_using_field(self):
1833        cls_dict = {'__annotations__': {'x': int, 'y': int},
1834                    'y': field(default=5),
1835                    }
1836
1837        # Create the class.
1838        cls = type('C', (), cls_dict)
1839
1840        # Make it a dataclass.
1841        cls1 = dataclass(cls)
1842
1843        self.assertEqual(cls1, cls)
1844        self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1845
1846    def test_init_in_order(self):
1847        @dataclass
1848        class C:
1849            a: int
1850            b: int = field()
1851            c: list = field(default_factory=list, init=False)
1852            d: list = field(default_factory=list)
1853            e: int = field(default=4, init=False)
1854            f: int = 4
1855
1856        calls = []
1857        def setattr(self, name, value):
1858            calls.append((name, value))
1859
1860        C.__setattr__ = setattr
1861        c = C(0, 1)
1862        self.assertEqual(('a', 0), calls[0])
1863        self.assertEqual(('b', 1), calls[1])
1864        self.assertEqual(('c', []), calls[2])
1865        self.assertEqual(('d', []), calls[3])
1866        self.assertNotIn(('e', 4), calls)
1867        self.assertEqual(('f', 4), calls[4])
1868
1869    def test_items_in_dicts(self):
1870        @dataclass
1871        class C:
1872            a: int
1873            b: list = field(default_factory=list, init=False)
1874            c: list = field(default_factory=list)
1875            d: int = field(default=4, init=False)
1876            e: int = 0
1877
1878        c = C(0)
1879        # Class dict
1880        self.assertNotIn('a', C.__dict__)
1881        self.assertNotIn('b', C.__dict__)
1882        self.assertNotIn('c', C.__dict__)
1883        self.assertIn('d', C.__dict__)
1884        self.assertEqual(C.d, 4)
1885        self.assertIn('e', C.__dict__)
1886        self.assertEqual(C.e, 0)
1887        # Instance dict
1888        self.assertIn('a', c.__dict__)
1889        self.assertEqual(c.a, 0)
1890        self.assertIn('b', c.__dict__)
1891        self.assertEqual(c.b, [])
1892        self.assertIn('c', c.__dict__)
1893        self.assertEqual(c.c, [])
1894        self.assertNotIn('d', c.__dict__)
1895        self.assertIn('e', c.__dict__)
1896        self.assertEqual(c.e, 0)
1897
1898    def test_alternate_classmethod_constructor(self):
1899        # Since __post_init__ can't take params, use a classmethod
1900        #  alternate constructor.  This is mostly an example to show
1901        #  how to use this technique.
1902        @dataclass
1903        class C:
1904            x: int
1905            @classmethod
1906            def from_file(cls, filename):
1907                # In a real example, create a new instance
1908                #  and populate 'x' from contents of a file.
1909                value_in_file = 20
1910                return cls(value_in_file)
1911
1912        self.assertEqual(C.from_file('filename').x, 20)
1913
1914    def test_field_metadata_default(self):
1915        # Make sure the default metadata is read-only and of
1916        #  zero length.
1917        @dataclass
1918        class C:
1919            i: int
1920
1921        self.assertFalse(fields(C)[0].metadata)
1922        self.assertEqual(len(fields(C)[0].metadata), 0)
1923        with self.assertRaisesRegex(TypeError,
1924                                    'does not support item assignment'):
1925            fields(C)[0].metadata['test'] = 3
1926
1927    def test_field_metadata_mapping(self):
1928        # Make sure only a mapping can be passed as metadata
1929        #  zero length.
1930        with self.assertRaises(TypeError):
1931            @dataclass
1932            class C:
1933                i: int = field(metadata=0)
1934
1935        # Make sure an empty dict works.
1936        d = {}
1937        @dataclass
1938        class C:
1939            i: int = field(metadata=d)
1940        self.assertFalse(fields(C)[0].metadata)
1941        self.assertEqual(len(fields(C)[0].metadata), 0)
1942        # Update should work (see bpo-35960).
1943        d['foo'] = 1
1944        self.assertEqual(len(fields(C)[0].metadata), 1)
1945        self.assertEqual(fields(C)[0].metadata['foo'], 1)
1946        with self.assertRaisesRegex(TypeError,
1947                                    'does not support item assignment'):
1948            fields(C)[0].metadata['test'] = 3
1949
1950        # Make sure a non-empty dict works.
1951        d = {'test': 10, 'bar': '42', 3: 'three'}
1952        @dataclass
1953        class C:
1954            i: int = field(metadata=d)
1955        self.assertEqual(len(fields(C)[0].metadata), 3)
1956        self.assertEqual(fields(C)[0].metadata['test'], 10)
1957        self.assertEqual(fields(C)[0].metadata['bar'], '42')
1958        self.assertEqual(fields(C)[0].metadata[3], 'three')
1959        # Update should work.
1960        d['foo'] = 1
1961        self.assertEqual(len(fields(C)[0].metadata), 4)
1962        self.assertEqual(fields(C)[0].metadata['foo'], 1)
1963        with self.assertRaises(KeyError):
1964            # Non-existent key.
1965            fields(C)[0].metadata['baz']
1966        with self.assertRaisesRegex(TypeError,
1967                                    'does not support item assignment'):
1968            fields(C)[0].metadata['test'] = 3
1969
1970    def test_field_metadata_custom_mapping(self):
1971        # Try a custom mapping.
1972        class SimpleNameSpace:
1973            def __init__(self, **kw):
1974                self.__dict__.update(kw)
1975
1976            def __getitem__(self, item):
1977                if item == 'xyzzy':
1978                    return 'plugh'
1979                return getattr(self, item)
1980
1981            def __len__(self):
1982                return self.__dict__.__len__()
1983
1984        @dataclass
1985        class C:
1986            i: int = field(metadata=SimpleNameSpace(a=10))
1987
1988        self.assertEqual(len(fields(C)[0].metadata), 1)
1989        self.assertEqual(fields(C)[0].metadata['a'], 10)
1990        with self.assertRaises(AttributeError):
1991            fields(C)[0].metadata['b']
1992        # Make sure we're still talking to our custom mapping.
1993        self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1994
1995    def test_generic_dataclasses(self):
1996        T = TypeVar('T')
1997
1998        @dataclass
1999        class LabeledBox(Generic[T]):
2000            content: T
2001            label: str = '<unknown>'
2002
2003        box = LabeledBox(42)
2004        self.assertEqual(box.content, 42)
2005        self.assertEqual(box.label, '<unknown>')
2006
2007        # Subscripting the resulting class should work, etc.
2008        Alias = List[LabeledBox[int]]
2009
2010    def test_generic_extending(self):
2011        S = TypeVar('S')
2012        T = TypeVar('T')
2013
2014        @dataclass
2015        class Base(Generic[T, S]):
2016            x: T
2017            y: S
2018
2019        @dataclass
2020        class DataDerived(Base[int, T]):
2021            new_field: str
2022        Alias = DataDerived[str]
2023        c = Alias(0, 'test1', 'test2')
2024        self.assertEqual(astuple(c), (0, 'test1', 'test2'))
2025
2026        class NonDataDerived(Base[int, T]):
2027            def new_method(self):
2028                return self.y
2029        Alias = NonDataDerived[float]
2030        c = Alias(10, 1.0)
2031        self.assertEqual(c.new_method(), 1.0)
2032
2033    def test_generic_dynamic(self):
2034        T = TypeVar('T')
2035
2036        @dataclass
2037        class Parent(Generic[T]):
2038            x: T
2039        Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)],
2040                               bases=(Parent[int], Generic[T]), namespace={'other': 42})
2041        self.assertIs(Child[int](1, 2).z, None)
2042        self.assertEqual(Child[int](1, 2, 3).z, 3)
2043        self.assertEqual(Child[int](1, 2, 3).other, 42)
2044        # Check that type aliases work correctly.
2045        Alias = Child[T]
2046        self.assertEqual(Alias[int](1, 2).x, 1)
2047        # Check MRO resolution.
2048        self.assertEqual(Child.__mro__, (Child, Parent, Generic, object))
2049
2050    def test_dataclasses_pickleable(self):
2051        global P, Q, R
2052        @dataclass
2053        class P:
2054            x: int
2055            y: int = 0
2056        @dataclass
2057        class Q:
2058            x: int
2059            y: int = field(default=0, init=False)
2060        @dataclass
2061        class R:
2062            x: int
2063            y: List[int] = field(default_factory=list)
2064        q = Q(1)
2065        q.y = 2
2066        samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
2067        for sample in samples:
2068            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
2069                with self.subTest(sample=sample, proto=proto):
2070                    new_sample = pickle.loads(pickle.dumps(sample, proto))
2071                    self.assertEqual(sample.x, new_sample.x)
2072                    self.assertEqual(sample.y, new_sample.y)
2073                    self.assertIsNot(sample, new_sample)
2074                    new_sample.x = 42
2075                    another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
2076                    self.assertEqual(new_sample.x, another_new_sample.x)
2077                    self.assertEqual(sample.y, another_new_sample.y)
2078
2079    def test_dataclasses_qualnames(self):
2080        @dataclass(order=True, unsafe_hash=True, frozen=True)
2081        class A:
2082            x: int
2083            y: int
2084
2085        self.assertEqual(A.__init__.__name__, "__init__")
2086        for function in (
2087            '__eq__',
2088            '__lt__',
2089            '__le__',
2090            '__gt__',
2091            '__ge__',
2092            '__hash__',
2093            '__init__',
2094            '__repr__',
2095            '__setattr__',
2096            '__delattr__',
2097        ):
2098            self.assertEqual(getattr(A, function).__qualname__, f"TestCase.test_dataclasses_qualnames.<locals>.A.{function}")
2099
2100        with self.assertRaisesRegex(TypeError, r"A\.__init__\(\) missing"):
2101            A()
2102
2103
class TestFieldNoAnnotation(unittest.TestCase):
    # Assigning field() in a class body without annotating that name in the
    #  same class is an error, even when a base class annotates the name.
    _MISSING_ANNOTATION = "'f' is a field but has no type annotation"

    def test_field_without_annotation(self):
        with self.assertRaisesRegex(TypeError, self._MISSING_ANNOTATION):
            @dataclass
            class C:
                f = field()

    def test_field_without_annotation_but_annotation_in_base(self):
        @dataclass
        class B:
            f: int

        # The base dataclass's annotation for 'f' must not be borrowed.
        with self.assertRaisesRegex(TypeError, self._MISSING_ANNOTATION):
            @dataclass
            class C(B):
                f = field()

    def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
        # Same check, but the annotated base is a plain (non-dataclass) class.
        class B:
            f: int

        with self.assertRaisesRegex(TypeError, self._MISSING_ANNOTATION):
            @dataclass
            class C(B):
                f = field()
2137
2138
class TestDocString(unittest.TestCase):
    # When a dataclass has no docstring of its own, the expected __doc__ is
    #  the class name followed by its __init__ parameters, e.g. "C(x:int=3)".
    #  These tests pin that format.
    def assertDocStrEqual(self, a, b):
        # Compare with all spaces removed.  3.6 and 3.7 differed in how
        #  inspect.signature formats these strings (see bpo #32108), so
        #  spacing is deliberately not checked.
        self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))

    def test_existing_docstring_not_overridden(self):
        # An explicit class docstring must be left untouched.
        @dataclass
        class C:
            """Lorem ipsum"""
            x: int

        self.assertEqual(C.__doc__, "Lorem ipsum")

    def test_docstring_no_fields(self):
        # No fields: an empty parameter list.
        @dataclass
        class C:
            pass

        self.assertDocStrEqual(C.__doc__, "C()")

    def test_docstring_one_field(self):
        # Each field appears as name:annotation.
        @dataclass
        class C:
            x: int

        self.assertDocStrEqual(C.__doc__, "C(x:int)")

    def test_docstring_two_fields(self):
        # Fields are listed in declaration order.
        @dataclass
        class C:
            x: int
            y: int

        self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")

    def test_docstring_three_fields(self):
        @dataclass
        class C:
            x: int
            y: int
            z: str

        self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")

    def test_docstring_one_field_with_default(self):
        # A plain default is shown after '='.
        @dataclass
        class C:
            x: int = 3

        self.assertDocStrEqual(C.__doc__, "C(x:int=3)")

    def test_docstring_one_field_with_default_none(self):
        # The expected text spells Union[int, None] as Optional[int].
        @dataclass
        class C:
            x: Union[int, type(None)] = None

        self.assertDocStrEqual(C.__doc__, "C(x:Optional[int]=None)")

    def test_docstring_list_field(self):
        # typing generics are shown by their typing repr.
        @dataclass
        class C:
            x: List[int]

        self.assertDocStrEqual(C.__doc__, "C(x:List[int])")

    def test_docstring_list_field_with_default_factory(self):
        # A default_factory is shown as the placeholder '<factory>'.
        @dataclass
        class C:
            x: List[int] = field(default_factory=list)

        self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")

    def test_docstring_deque_field(self):
        # A class outside builtins is shown with its module prefix.
        @dataclass
        class C:
            x: deque

        self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")

    def test_docstring_deque_field_with_default_factory(self):
        @dataclass
        class C:
            x: deque = field(default_factory=deque)

        self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")

    def test_docstring_with_no_signature(self):
        # See https://github.com/python/cpython/issues/103449
        # NOTE(review): the metaclass replaces __call__ with dict,
        #  presumably so that no signature can be computed; in that case the
        #  docstring is expected to be just the class name.
        class Meta(type):
            __call__ = dict
        class Base(metaclass=Meta):
            pass

        @dataclass
        class C(Base):
            pass

        self.assertDocStrEqual(C.__doc__, "C")
2239
2240
2241class TestInit(unittest.TestCase):
2242    def test_base_has_init(self):
2243        class B:
2244            def __init__(self):
2245                self.z = 100
2246                pass
2247
2248        # Make sure that declaring this class doesn't raise an error.
2249        #  The issue is that we can't override __init__ in our class,
2250        #  but it should be okay to add __init__ to us if our base has
2251        #  an __init__.
2252        @dataclass
2253        class C(B):
2254            x: int = 0
2255        c = C(10)
2256        self.assertEqual(c.x, 10)
2257        self.assertNotIn('z', vars(c))
2258
2259        # Make sure that if we don't add an init, the base __init__
2260        #  gets called.
2261        @dataclass(init=False)
2262        class C(B):
2263            x: int = 10
2264        c = C()
2265        self.assertEqual(c.x, 10)
2266        self.assertEqual(c.z, 100)
2267
2268    def test_no_init(self):
2269        @dataclass(init=False)
2270        class C:
2271            i: int = 0
2272        self.assertEqual(C().i, 0)
2273
2274        @dataclass(init=False)
2275        class C:
2276            i: int = 2
2277            def __init__(self):
2278                self.i = 3
2279        self.assertEqual(C().i, 3)
2280
2281    def test_overwriting_init(self):
2282        # If the class has __init__, use it no matter the value of
2283        #  init=.
2284
2285        @dataclass
2286        class C:
2287            x: int
2288            def __init__(self, x):
2289                self.x = 2 * x
2290        self.assertEqual(C(3).x, 6)
2291
2292        @dataclass(init=True)
2293        class C:
2294            x: int
2295            def __init__(self, x):
2296                self.x = 2 * x
2297        self.assertEqual(C(4).x, 8)
2298
2299        @dataclass(init=False)
2300        class C:
2301            x: int
2302            def __init__(self, x):
2303                self.x = 2 * x
2304        self.assertEqual(C(5).x, 10)
2305
2306    def test_inherit_from_protocol(self):
2307        # Dataclasses inheriting from protocol should preserve their own `__init__`.
2308        # See bpo-45081.
2309
2310        class P(Protocol):
2311            a: int
2312
2313        @dataclass
2314        class C(P):
2315            a: int
2316
2317        self.assertEqual(C(5).a, 5)
2318
2319        @dataclass
2320        class D(P):
2321            def __init__(self, a):
2322                self.a = a * 2
2323
2324        self.assertEqual(D(5).a, 10)
2325
2326
class TestRepr(unittest.TestCase):
    # Tests for the generated __repr__.  The expected strings embed each
    #  class's __qualname__, so they depend on the exact nesting of these
    #  test methods and local classes.
    def test_repr(self):
        @dataclass
        class B:
            x: int

        @dataclass
        class C(B):
            y: int = 10

        # Inherited fields come first, then the subclass's own fields.
        o = C(4)
        self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')

        # Redeclaring an inherited field changes its default but keeps its
        #  original position.
        @dataclass
        class D(C):
            x: int = 20
        self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')

        # Dataclasses nested inside another class show the full nested
        #  qualified name; a class with no fields renders as 'Name()'.
        @dataclass
        class C:
            @dataclass
            class D:
                i: int
            @dataclass
            class E:
                pass
        self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
        self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')

    def test_no_repr(self):
        # Test a class with no __repr__ and repr=False.
        #  The default object repr ('<module.qualname object at ...>') is used.
        @dataclass(repr=False)
        class C:
            x: int
        self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at',
                      repr(C(3)))

        # Test a class with a __repr__ and repr=False.
        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'C-class'
        self.assertEqual(repr(C(3)), 'C-class')

    def test_overwriting_repr(self):
        # If the class has __repr__, use it no matter the value of
        #  repr=.

        @dataclass
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')

        @dataclass(repr=True)
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')

        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')
2396
2397
class TestEq(unittest.TestCase):
    def test_no_eq(self):
        # eq=False and no __eq__ of our own: comparison falls back to
        #  identity, so equal-looking instances differ.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        same = C(3)
        self.assertEqual(same, same)

        # eq=False with a hand-written __eq__: that method is used.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # A class's own __eq__ is kept, whatever eq= says.
        for decorator, target in ((dataclass, 3),
                                  (dataclass(eq=True), 4),
                                  (dataclass(eq=False), 5)):
            @decorator
            class C:
                x: int
                def __eq__(self, other):
                    return other == target
            self.assertEqual(C(1), target)
            self.assertNotEqual(C(1), 1)
2443
2444
class TestOrdering(unittest.TestCase):
    def test_functools_total_ordering(self):
        # Test that functools.total_ordering works with this class.
        #  total_ordering derives the other comparisons from our __lt__.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Perform the test "backward", just to make
                #  sure this is being called.
                return self.x >= other

        self.assertLess(C(0), -1)
        self.assertLessEqual(C(0), -1)
        self.assertGreater(C(0), 1)
        self.assertGreaterEqual(C(0), 1)

    def test_no_order(self):
        # Test that no ordering functions are added by default.
        @dataclass(order=False)
        class C:
            x: int
        # Make sure no order methods are added.
        self.assertNotIn('__le__', C.__dict__)
        self.assertNotIn('__lt__', C.__dict__)
        self.assertNotIn('__ge__', C.__dict__)
        self.assertNotIn('__gt__', C.__dict__)

        # Test that a user-defined __lt__ is kept and still called.
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        self.assertIn('__lt__', C.__dict__)
        # The user's __lt__ always returns False, so '<' must be False even
        #  where field-wise ordering would say True.
        self.assertFalse(C(0) < C(1))
        # Make sure other methods aren't added.
        self.assertNotIn('__le__', C.__dict__)
        self.assertNotIn('__ge__', C.__dict__)
        self.assertNotIn('__gt__', C.__dict__)

    def test_overwriting_order(self):
        # order=True refuses to replace a user-defined comparison method
        #  and points the user at functools.total_ordering instead.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __lt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __lt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __le__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __le__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __gt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __gt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __ge__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __ge__(self):
                    pass
2520
2521class TestHash(unittest.TestCase):
2522    def test_unsafe_hash(self):
2523        @dataclass(unsafe_hash=True)
2524        class C:
2525            x: int
2526            y: str
2527        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))
2528
2529    def test_hash_rules(self):
2530        def non_bool(value):
2531            # Map to something else that's True, but not a bool.
2532            if value is None:
2533                return None
2534            if value:
2535                return (3,)
2536            return 0
2537
2538        def test(case, unsafe_hash, eq, frozen, with_hash, result):
2539            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
2540                              frozen=frozen):
2541                if result != 'exception':
2542                    if with_hash:
2543                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
2544                        class C:
2545                            def __hash__(self):
2546                                return 0
2547                    else:
2548                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
2549                        class C:
2550                            pass
2551
2552                # See if the result matches what's expected.
2553                if result == 'fn':
2554                    # __hash__ contains the function we generated.
2555                    self.assertIn('__hash__', C.__dict__)
2556                    self.assertIsNotNone(C.__dict__['__hash__'])
2557
2558                elif result == '':
2559                    # __hash__ is not present in our class.
2560                    if not with_hash:
2561                        self.assertNotIn('__hash__', C.__dict__)
2562
2563                elif result == 'none':
2564                    # __hash__ is set to None.
2565                    self.assertIn('__hash__', C.__dict__)
2566                    self.assertIsNone(C.__dict__['__hash__'])
2567
2568                elif result == 'exception':
2569                    # Creating the class should cause an exception.
2570                    #  This only happens with with_hash==True.
2571                    assert(with_hash)
2572                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
2573                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
2574                        class C:
2575                            def __hash__(self):
2576                                return 0
2577
2578                else:
2579                    assert False, f'unknown result {result!r}'
2580
2581        # There are 8 cases of:
2582        #  unsafe_hash=True/False
2583        #  eq=True/False
2584        #  frozen=True/False
2585        # And for each of these, a different result if
2586        #  __hash__ is defined or not.
2587        for case, (unsafe_hash,  eq,    frozen, res_no_defined_hash, res_defined_hash) in enumerate([
2588                  (False,        False, False,  '',                  ''),
2589                  (False,        False, True,   '',                  ''),
2590                  (False,        True,  False,  'none',              ''),
2591                  (False,        True,  True,   'fn',                ''),
2592                  (True,         False, False,  'fn',                'exception'),
2593                  (True,         False, True,   'fn',                'exception'),
2594                  (True,         True,  False,  'fn',                'exception'),
2595                  (True,         True,  True,   'fn',                'exception'),
2596                  ], 1):
2597            test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
2598            test(case, unsafe_hash, eq, frozen, True,  res_defined_hash)
2599
2600            # Test non-bool truth values, too.  This is just to
2601            #  make sure the data-driven table in the decorator
2602            #  handles non-bool values.
2603            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash)
2604            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True,  res_defined_hash)
2605
2606
2607    def test_eq_only(self):
2608        # If a class defines __eq__, __hash__ is automatically added
2609        #  and set to None.  This is normal Python behavior, not
2610        #  related to dataclasses.  Make sure we don't interfere with
2611        #  that (see bpo=32546).
2612
2613        @dataclass
2614        class C:
2615            i: int
2616            def __eq__(self, other):
2617                return self.i == other.i
2618        self.assertEqual(C(1), C(1))
2619        self.assertNotEqual(C(1), C(4))
2620
2621        # And make sure things work in this case if we specify
2622        #  unsafe_hash=True.
2623        @dataclass(unsafe_hash=True)
2624        class C:
2625            i: int
2626            def __eq__(self, other):
2627                return self.i == other.i
2628        self.assertEqual(C(1), C(1.0))
2629        self.assertEqual(hash(C(1)), hash(C(1.0)))
2630
2631        # And check that the classes __eq__ is being used, despite
2632        #  specifying eq=True.
2633        @dataclass(unsafe_hash=True, eq=True)
2634        class C:
2635            i: int
2636            def __eq__(self, other):
2637                return self.i == 3 and self.i == other.i
2638        self.assertEqual(C(3), C(3))
2639        self.assertNotEqual(C(1), C(1))
2640        self.assertEqual(hash(C(1)), hash(C(1.0)))
2641
2642    def test_0_field_hash(self):
2643        @dataclass(frozen=True)
2644        class C:
2645            pass
2646        self.assertEqual(hash(C()), hash(()))
2647
2648        @dataclass(unsafe_hash=True)
2649        class C:
2650            pass
2651        self.assertEqual(hash(C()), hash(()))
2652
2653    def test_1_field_hash(self):
2654        @dataclass(frozen=True)
2655        class C:
2656            x: int
2657        self.assertEqual(hash(C(4)), hash((4,)))
2658        self.assertEqual(hash(C(42)), hash((42,)))
2659
2660        @dataclass(unsafe_hash=True)
2661        class C:
2662            x: int
2663        self.assertEqual(hash(C(4)), hash((4,)))
2664        self.assertEqual(hash(C(42)), hash((42,)))
2665
2666    def test_hash_no_args(self):
2667        # Test dataclasses with no hash= argument.  This exists to
2668        #  make sure that if the @dataclass parameter name is changed
2669        #  or the non-default hashing behavior changes, the default
2670        #  hashability keeps working the same way.
2671
2672        class Base:
2673            def __hash__(self):
2674                return 301
2675
2676        # If frozen or eq is None, then use the default value (do not
2677        #  specify any value in the decorator).
2678        for frozen, eq,    base,   expected       in [
2679            (None,  None,  object, 'unhashable'),
2680            (None,  None,  Base,   'unhashable'),
2681            (None,  False, object, 'object'),
2682            (None,  False, Base,   'base'),
2683            (None,  True,  object, 'unhashable'),
2684            (None,  True,  Base,   'unhashable'),
2685            (False, None,  object, 'unhashable'),
2686            (False, None,  Base,   'unhashable'),
2687            (False, False, object, 'object'),
2688            (False, False, Base,   'base'),
2689            (False, True,  object, 'unhashable'),
2690            (False, True,  Base,   'unhashable'),
2691            (True,  None,  object, 'tuple'),
2692            (True,  None,  Base,   'tuple'),
2693            (True,  False, object, 'object'),
2694            (True,  False, Base,   'base'),
2695            (True,  True,  object, 'tuple'),
2696            (True,  True,  Base,   'tuple'),
2697            ]:
2698
2699            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
2700                # First, create the class.
2701                if frozen is None and eq is None:
2702                    @dataclass
2703                    class C(base):
2704                        i: int
2705                elif frozen is None:
2706                    @dataclass(eq=eq)
2707                    class C(base):
2708                        i: int
2709                elif eq is None:
2710                    @dataclass(frozen=frozen)
2711                    class C(base):
2712                        i: int
2713                else:
2714                    @dataclass(frozen=frozen, eq=eq)
2715                    class C(base):
2716                        i: int
2717
2718                # Now, make sure it hashes as expected.
2719                if expected == 'unhashable':
2720                    c = C(10)
2721                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
2722                        hash(c)
2723
2724                elif expected == 'base':
2725                    self.assertEqual(hash(C(10)), 301)
2726
2727                elif expected == 'object':
2728                    # I'm not sure what test to use here.  object's
2729                    #  hash isn't based on id(), so calling hash()
2730                    #  won't tell us much.  So, just check the
2731                    #  function used is object's.
2732                    self.assertIs(C.__hash__, object.__hash__)
2733
2734                elif expected == 'tuple':
2735                    self.assertEqual(hash(C(42)), hash((42,)))
2736
2737                else:
2738                    assert False, f'unknown value for expected={expected!r}'
2739
2740
class TestFrozen(unittest.TestCase):
    # Tests for @dataclass(frozen=True): assigning to a field raises
    #  FrozenInstanceError, and frozen and non-frozen dataclasses cannot
    #  be mixed through inheritance.

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            i: int

        c = C(10)
        self.assertEqual(c.i, 10)
        with self.assertRaises(FrozenInstanceError):
            c.i = 5
        # The rejected assignment must leave the value untouched.
        self.assertEqual(c.i, 10)

    def test_inherit(self):
        # In a frozen subclass of a frozen dataclass, both the inherited
        #  and the newly declared fields are read-only.
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        d = D(0, 10)
        with self.assertRaises(FrozenInstanceError):
            d.i = 5
        with self.assertRaises(FrozenInstanceError):
            d.j = 6
        self.assertEqual(d.i, 0)
        self.assertEqual(d.j, 10)

    def test_inherit_nonfrozen_from_empty_frozen(self):
        # Even a field-less frozen base forbids a non-frozen subclass.
        @dataclass(frozen=True)
        class C:
            pass

        with self.assertRaisesRegex(TypeError,
                                    'cannot inherit non-frozen dataclass from a frozen one'):
            @dataclass
            class D(C):
                j: int

    def test_inherit_nonfrozen_from_empty(self):
        @dataclass
        class C:
            pass

        @dataclass
        class D(C):
            j: int

        d = D(3)
        self.assertEqual(d.j, 3)
        self.assertIsInstance(d, C)

    # Test both ways: with an intermediate normal (non-dataclass)
    #  class and without an intermediate class.
    def test_inherit_nonfrozen_from_frozen(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(I):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass
                class C:
                    i: int

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(I):
                        pass

    def test_inherit_from_normal_class(self):
        # A frozen dataclass may derive from an ordinary class, directly
        #  or through an intermediate ordinary class.
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                class C:
                    pass

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                @dataclass(frozen=True)
                class D(I):
                    i: int

                # The checks belong inside the subTest so that a failure
                #  is reported against the intermediate_class case that
                #  produced it.
                d = D(10)
                with self.assertRaises(FrozenInstanceError):
                    d.i = 5

    def test_non_frozen_normal_derived(self):
        # See bpo-32953.

        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        class S(D):
            pass

        s = S(3)
        self.assertEqual(s.x, 3)
        self.assertEqual(s.y, 10)
        # Setting a non-field attribute on the ordinary subclass is allowed.
        s.cached = True

        # But can't change the frozen attributes.
        with self.assertRaises(FrozenInstanceError):
            s.x = 5
        with self.assertRaises(FrozenInstanceError):
            s.y = 5
        self.assertEqual(s.x, 3)
        self.assertEqual(s.y, 10)
        self.assertEqual(s.cached, True)

    def test_overwriting_frozen(self):
        # frozen=True generates __setattr__ and __delattr__, so a class
        #  that defines either one itself is rejected.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __delattr__(self):
                    pass

        # With frozen=False the user's __setattr__ is kept, and the
        #  generated __init__ assigns fields through it.
        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2
        self.assertEqual(C(10).x, 20)

    def test_frozen_hash(self):
        @dataclass(frozen=True)
        class C:
            x: Any

        # If x is immutable, we can compute the hash.  No exception is
        # raised.
        hash(C(3))

        # If x is mutable, computing the hash is an error.
        with self.assertRaisesRegex(TypeError, 'unhashable type'):
            hash(C({}))
2913
2914
class TestSlots(unittest.TestCase):
    # Tests for __slots__ handling: __slots__ written by hand on a
    #  dataclass, slots generated by slots=True, pickling of frozen
    #  slotted classes, and the weakref_slot option.
    def test_simple(self):
        @dataclass
        class C:
            __slots__ = ('x',)
            x: Any

        # There was a bug where a variable in a slot was assumed to
        #  also have a default value (of type
        #  types.MemberDescriptorType).
        with self.assertRaisesRegex(TypeError,
                                    r"__init__\(\) missing 1 required positional argument: 'x'"):
            C()

        # We can create an instance, and assign to x.
        c = C(10)
        self.assertEqual(c.x, 10)
        c.x = 5
        self.assertEqual(c.x, 5)

        # We can't assign to anything else.
        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
            c.y = 5

    def test_derived_added_field(self):
        # See bpo-33100.
        @dataclass
        class Base:
            __slots__ = ('x',)
            x: Any

        @dataclass
        class Derived(Base):
            x: int
            y: int

        d = Derived(1, 2)
        self.assertEqual((d.x, d.y), (1, 2))

        # Derived declares no __slots__ of its own, so its instances
        #  accept arbitrary new attributes.
        d.z = 10

    def test_generated_slots(self):
        # slots=True builds __slots__ from the fields, so only the
        #  declared fields can be assigned.
        @dataclass(slots=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        self.assertEqual((c.x, c.y), (1, 2))

        c.x = 3
        c.y = 4
        self.assertEqual((c.x, c.y), (3, 4))

        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'z'"):
            c.z = 5

    def test_add_slots_when_slots_exists(self):
        # slots=True refuses a class body that already defines __slots__.
        with self.assertRaisesRegex(TypeError, '^C already specifies __slots__$'):
            @dataclass(slots=True)
            class C:
                __slots__ = ('x',)
                x: int

    def test_generated_slots_value(self):
        # Names already provided as slots by base classes are left out of
        #  the generated __slots__.  The bases below use the different
        #  forms __slots__ can take: a set, a dict, a list and a str.

        class Root:
            __slots__ = {'x'}

        class Root2(Root):
            __slots__ = {'k': '...', 'j': ''}

        class Root3(Root2):
            __slots__ = ['h']

        class Root4(Root3):
            __slots__ = 'aa'

        @dataclass(slots=True)
        class Base(Root4):
            y: int
            j: str
            h: str

        # j and h already come from Root2 and Root3.
        self.assertEqual(Base.__slots__, ('y', ))

        @dataclass(slots=True)
        class Derived(Base):
            aa: float
            x: str
            z: int
            k: str
            h: str

        # Only z is not already a slot somewhere in the hierarchy.
        self.assertEqual(Derived.__slots__, ('z', ))

        # Without slots=True, no __slots__ is added to the subclass.
        @dataclass
        class AnotherDerived(Base):
            z: int

        self.assertNotIn('__slots__', AnotherDerived.__dict__)

    def test_cant_inherit_from_iterator_slots(self):
        # An iterator used as __slots__ can't be re-read after class
        #  creation, so the inherited slots can't be determined.

        class Root:
            __slots__ = iter(['a'])

        class Root2(Root):
            __slots__ = ('b', )

        with self.assertRaisesRegex(
           TypeError,
            "^Slots of 'Root' cannot be determined"
        ):
            @dataclass(slots=True)
            class C(Root2):
                x: int

    def test_returns_new_class(self):
        # slots=True returns a new class rather than modifying the
        #  original one in place.
        class A:
            x: int

        B = dataclass(A, slots=True)
        self.assertIsNot(A, B)

        self.assertFalse(hasattr(A, "__slots__"))
        self.assertTrue(hasattr(B, "__slots__"))

    # Can't be local to test_frozen_pickle: pickle refers to classes by
    #  their qualified name, so they must be reachable as attributes of
    #  this class.
    @dataclass(frozen=True, slots=True)
    class FrozenSlotsClass:
        foo: str
        bar: int

    @dataclass(frozen=True)
    class FrozenWithoutSlotsClass:
        foo: str
        bar: int

    def test_frozen_pickle(self):
        # bpo-43999

        self.assertEqual(self.FrozenSlotsClass.__slots__, ("foo", "bar"))
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(proto=proto):
                obj = self.FrozenSlotsClass("a", 1)
                p = pickle.loads(pickle.dumps(obj, protocol=proto))
                self.assertIsNot(obj, p)
                self.assertEqual(obj, p)

                obj = self.FrozenWithoutSlotsClass("a", 1)
                p = pickle.loads(pickle.dumps(obj, protocol=proto))
                self.assertIsNot(obj, p)
                self.assertEqual(obj, p)

    # The next three classes define their own __getstate__ and/or
    #  __setstate__ and record (via object.__setattr__, since they are
    #  frozen) whether those were called.
    @dataclass(frozen=True, slots=True)
    class FrozenSlotsGetStateClass:
        foo: str
        bar: int

        getstate_called: bool = field(default=False, compare=False)

        def __getstate__(self):
            object.__setattr__(self, 'getstate_called', True)
            return [self.foo, self.bar]

    @dataclass(frozen=True, slots=True)
    class FrozenSlotsSetStateClass:
        foo: str
        bar: int

        setstate_called: bool = field(default=False, compare=False)

        def __setstate__(self, state):
            object.__setattr__(self, 'setstate_called', True)
            object.__setattr__(self, 'foo', state[0])
            object.__setattr__(self, 'bar', state[1])

    @dataclass(frozen=True, slots=True)
    class FrozenSlotsAllStateClass:
        foo: str
        bar: int

        getstate_called: bool = field(default=False, compare=False)
        setstate_called: bool = field(default=False, compare=False)

        def __getstate__(self):
            object.__setattr__(self, 'getstate_called', True)
            return [self.foo, self.bar]

        def __setstate__(self, state):
            object.__setattr__(self, 'setstate_called', True)
            object.__setattr__(self, 'foo', state[0])
            object.__setattr__(self, 'bar', state[1])

    def test_frozen_slots_pickle_custom_state(self):
        # User-defined __getstate__/__setstate__ on a frozen slotted
        #  dataclass must be the ones pickle actually calls.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(proto=proto):
                obj = self.FrozenSlotsGetStateClass('a', 1)
                dumped = pickle.dumps(obj, protocol=proto)

                self.assertTrue(obj.getstate_called)
                self.assertEqual(obj, pickle.loads(dumped))

        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(proto=proto):
                obj = self.FrozenSlotsSetStateClass('a', 1)
                obj2 = pickle.loads(pickle.dumps(obj, protocol=proto))

                self.assertTrue(obj2.setstate_called)
                self.assertEqual(obj, obj2)

        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(proto=proto):
                obj = self.FrozenSlotsAllStateClass('a', 1)
                dumped = pickle.dumps(obj, protocol=proto)

                self.assertTrue(obj.getstate_called)

                obj2 = pickle.loads(dumped)
                self.assertTrue(obj2.setstate_called)
                self.assertEqual(obj, obj2)

    def test_slots_with_default_no_init(self):
        # Originally reported in bpo-44649.
        @dataclass(slots=True)
        class A:
            a: str
            b: str = field(default='b', init=False)

        obj = A("a")
        self.assertEqual(obj.a, 'a')
        self.assertEqual(obj.b, 'b')

    def test_slots_with_default_factory_no_init(self):
        # Originally reported in bpo-44649.
        @dataclass(slots=True)
        class A:
            a: str
            b: str = field(default_factory=lambda:'b', init=False)

        obj = A("a")
        self.assertEqual(obj.a, 'a')
        self.assertEqual(obj.b, 'b')

    def test_slots_no_weakref(self):
        # Without weakref_slot=True, instances of a slotted dataclass
        #  can't be weakly referenced.
        @dataclass(slots=True)
        class A:
            # No weakref.
            pass

        self.assertNotIn("__weakref__", A.__slots__)
        a = A()
        with self.assertRaisesRegex(TypeError,
                                    "cannot create weak reference"):
            weakref.ref(a)
        with self.assertRaises(AttributeError):
            a.__weakref__

    def test_slots_weakref(self):
        @dataclass(slots=True, weakref_slot=True)
        class A:
            a: int

        self.assertIn("__weakref__", A.__slots__)
        a = A(1)
        a_ref = weakref.ref(a)

        self.assertIs(a.__weakref__, a_ref)

    def test_slots_weakref_base_str(self):
        class Base:
            __slots__ = '__weakref__'

        @dataclass(slots=True)
        class A(Base):
            a: int

        # __weakref__ is in the base class, not A.  But an A is still weakref-able.
        self.assertIn("__weakref__", Base.__slots__)
        self.assertNotIn("__weakref__", A.__slots__)
        a = A(1)
        weakref.ref(a)

    def test_slots_weakref_base_tuple(self):
        # Same as test_slots_weakref_base_str, but use a tuple instead of
        # a string in the base class.
        class Base:
            __slots__ = ('__weakref__',)

        @dataclass(slots=True)
        class A(Base):
            a: int

        # __weakref__ is in the base class, not A.  But an A is still
        # weakref-able.
        self.assertIn("__weakref__", Base.__slots__)
        self.assertNotIn("__weakref__", A.__slots__)
        a = A(1)
        weakref.ref(a)

    def test_weakref_slot_without_slot(self):
        with self.assertRaisesRegex(TypeError,
                                    "weakref_slot is True but slots is False"):
            @dataclass(weakref_slot=True)
            class A:
                a: int

    def test_weakref_slot_make_dataclass(self):
        A = make_dataclass('A', [('a', int),], slots=True, weakref_slot=True)
        self.assertIn("__weakref__", A.__slots__)
        a = A(1)
        weakref.ref(a)

        # And make sure it raises if slots=True is not given.
        with self.assertRaisesRegex(TypeError,
                                    "weakref_slot is True but slots is False"):
            B = make_dataclass('B', [('a', int),], weakref_slot=True)

    def test_weakref_slot_subclass_weakref_slot(self):
        # Note: the field is named "field", which shadows field() only
        #  inside this class body.
        @dataclass(slots=True, weakref_slot=True)
        class Base:
            field: int

        # A *can* also specify weakref_slot=True if it wants to (gh-93521)
        @dataclass(slots=True, weakref_slot=True)
        class A(Base):
            ...

        # __weakref__ is in the base class, not A.  But an instance of A
        # is still weakref-able.
        self.assertIn("__weakref__", Base.__slots__)
        self.assertNotIn("__weakref__", A.__slots__)
        a = A(1)
        a_ref = weakref.ref(a)
        self.assertIs(a.__weakref__, a_ref)

    def test_weakref_slot_subclass_no_weakref_slot(self):
        @dataclass(slots=True, weakref_slot=True)
        class Base:
            field: int

        @dataclass(slots=True)
        class A(Base):
            ...

        # __weakref__ is in the base class, not A.  Even though A doesn't
        # specify weakref_slot, it should still be weakref-able.
        self.assertIn("__weakref__", Base.__slots__)
        self.assertNotIn("__weakref__", A.__slots__)
        a = A(1)
        a_ref = weakref.ref(a)
        self.assertIs(a.__weakref__, a_ref)

    def test_weakref_slot_normal_base_weakref_slot(self):
        class Base:
            __slots__ = ('__weakref__',)

        @dataclass(slots=True, weakref_slot=True)
        class A(Base):
            field: int

        # __weakref__ is in the base class, not A.  But an instance of
        # A is still weakref-able.
        self.assertIn("__weakref__", Base.__slots__)
        self.assertNotIn("__weakref__", A.__slots__)
        a = A(1)
        a_ref = weakref.ref(a)
        self.assertIs(a.__weakref__, a_ref)
3285
3286
class TestDescriptors(unittest.TestCase):
    # How dataclass fields interact with descriptors and with objects
    #  that define __set_name__.

    def test_set_name(self):
        # See bpo-33141.

        class Desc:
            # Remember the attribute name, with an 'x' suffix.
            def __set_name__(self, owner, name):
                self.name = f'{name}x'
            # Instances read 1; access on the class yields the descriptor.
            def __get__(self, instance, owner):
                return self if instance is None else 1

        # Here the descriptor is itself the class attribute, so normal
        #  class creation calls __set_name__ without dataclass's help.
        @dataclass
        class C:
            c: int=Desc()
        self.assertEqual(C.c.name, 'cx')

        # With a field() default and init=False (the only case where this
        #  matters, since otherwise __init__ overwrites the attribute on
        #  the instance), __set_name__ must still reach the default.
        @dataclass
        class C:
            c: int=field(default=Desc(), init=False)
        self.assertEqual(C.c.name, 'cx')
        self.assertEqual(C().c, 1)

    def test_non_descriptor(self):
        # PEP 487 says __set_name__ should work on non-descriptors.
        # NameRecorder defines __set_name__ but no __get__/__set__.

        class NameRecorder:
            def __set_name__(self, owner, name):
                self.name = f'{name}x'

        @dataclass
        class C:
            c: int=field(default=NameRecorder(), init=False)
        self.assertEqual(C.c.name, 'cx')

    def test_lookup_on_instance(self):
        # See bpo-33175.  A __set_name__ attached only to the default
        #  object (not to its type) must not be called.
        class Plain:
            pass

        obj = Plain()
        obj.__set_name__ = Mock()

        @dataclass
        class C:
            i: int=field(default=obj, init=False)

        obj.__set_name__.assert_not_called()

    def test_lookup_on_class(self):
        # See bpo-33175.  A __set_name__ defined on the default's type
        #  must be called exactly once.
        class Plain:
            pass
        Plain.__set_name__ = Mock()

        @dataclass
        class C:
            i: int=field(default=Plain(), init=False)

        Plain.__set_name__.assert_called_once()

    def test_init_calls_set(self):
        class Desc:
            pass

        Desc.__set__ = Mock()

        @dataclass
        class C:
            i: Desc = Desc()

        # The generated __init__ must assign through Desc.__set__.
        Desc.__set__.reset_mock()
        C(5)
        Desc.__set__.assert_called_once()

    def test_getting_field_calls_get(self):
        class Desc:
            pass

        Desc.__set__ = Mock()
        Desc.__get__ = Mock()

        @dataclass
        class C:
            i: Desc = Desc()

        c = C(5)

        # Reading the field must go through Desc.__get__.
        Desc.__get__.reset_mock()
        _ = c.i
        Desc.__get__.assert_called_once()

    def test_setting_field_calls_set(self):
        class Desc:
            pass

        Desc.__set__ = Mock()

        @dataclass
        class C:
            i: Desc = Desc()

        c = C(5)

        # Assigning to the field after construction uses Desc.__set__.
        Desc.__set__.reset_mock()
        c.i = 10
        Desc.__set__.assert_called_once()

    def test_setting_uninitialized_descriptor_field(self):
        class Desc:
            pass

        Desc.__set__ = Mock()

        @dataclass
        class C:
            i: Desc

        # No Desc instance lives on C, so nothing's __set__ can be called.
        Desc.__set__.reset_mock()
        c = C(5)
        Desc.__set__.assert_not_called()

        # A Desc stored in the instance dict doesn't act as a descriptor;
        #  descriptors only work as class attributes.
        c.i = Desc()
        c.i = 5
        Desc.__set__.assert_not_called()

    def test_default_value(self):
        class Desc:
            # Access on the class returns 100, which becomes the default.
            def __get__(self, instance: Any, owner: object) -> int:
                return 100 if instance is None else instance._x

            def __set__(self, instance: Any, value: int) -> None:
                instance._x = value

        @dataclass
        class C:
            i: Desc = Desc()

        self.assertEqual(C().i, 100)
        self.assertEqual(C(5).i, 5)

    def test_no_default_value(self):
        class Desc:
            # Raising AttributeError on class access leaves the field
            #  without a default.
            def __get__(self, instance: Any, owner: object) -> int:
                if instance is None:
                    raise AttributeError()
                return instance._x

            def __set__(self, instance: Any, value: int) -> None:
                instance._x = value

        @dataclass
        class C:
            i: Desc = Desc()

        with self.assertRaisesRegex(TypeError, 'missing 1 required positional argument'):
            C()
3467
3468class TestStringAnnotations(unittest.TestCase):
3469    def test_classvar(self):
3470        # Some expressions recognized as ClassVar really aren't.  But
3471        #  if you're using string annotations, it's not an exact
3472        #  science.
3473        # These tests assume that both "import typing" and "from
3474        # typing import *" have been run in this file.
3475        for typestr in ('ClassVar[int]',
3476                        'ClassVar [int]',
3477                        ' ClassVar [int]',
3478                        'ClassVar',
3479                        ' ClassVar ',
3480                        'typing.ClassVar[int]',
3481                        'typing.ClassVar[str]',
3482                        ' typing.ClassVar[str]',
3483                        'typing .ClassVar[str]',
3484                        'typing. ClassVar[str]',
3485                        'typing.ClassVar [str]',
3486                        'typing.ClassVar [ str]',
3487
3488                        # Not syntactically valid, but these will
3489                        #  be treated as ClassVars.
3490                        'typing.ClassVar.[int]',
3491                        'typing.ClassVar+',
3492                        ):
3493            with self.subTest(typestr=typestr):
3494                @dataclass
3495                class C:
3496                    x: typestr
3497
3498                # x is a ClassVar, so C() takes no args.
3499                C()
3500
3501                # And it won't appear in the class's dict because it doesn't
3502                # have a default.
3503                self.assertNotIn('x', C.__dict__)
3504
3505    def test_isnt_classvar(self):
3506        for typestr in ('CV',
3507                        't.ClassVar',
3508                        't.ClassVar[int]',
3509                        'typing..ClassVar[int]',
3510                        'Classvar',
3511                        'Classvar[int]',
3512                        'typing.ClassVarx[int]',
3513                        'typong.ClassVar[int]',
3514                        'dataclasses.ClassVar[int]',
3515                        'typingxClassVar[str]',
3516                        ):
3517            with self.subTest(typestr=typestr):
3518                @dataclass
3519                class C:
3520                    x: typestr
3521
3522                # x is not a ClassVar, so C() takes one arg.
3523                self.assertEqual(C(10).x, 10)
3524
3525    def test_initvar(self):
3526        # These tests assume that both "import dataclasses" and "from
3527        #  dataclasses import *" have been run in this file.
3528        for typestr in ('InitVar[int]',
3529                        'InitVar [int]'
3530                        ' InitVar [int]',
3531                        'InitVar',
3532                        ' InitVar ',
3533                        'dataclasses.InitVar[int]',
3534                        'dataclasses.InitVar[str]',
3535                        ' dataclasses.InitVar[str]',
3536                        'dataclasses .InitVar[str]',
3537                        'dataclasses. InitVar[str]',
3538                        'dataclasses.InitVar [str]',
3539                        'dataclasses.InitVar [ str]',
3540
3541                        # Not syntactically valid, but these will
3542                        #  be treated as InitVars.
3543                        'dataclasses.InitVar.[int]',
3544                        'dataclasses.InitVar+',
3545                        ):
3546            with self.subTest(typestr=typestr):
3547                @dataclass
3548                class C:
3549                    x: typestr
3550
3551                # x is an InitVar, so doesn't create a member.
3552                with self.assertRaisesRegex(AttributeError,
3553                                            "object has no attribute 'x'"):
3554                    C(1).x
3555
3556    def test_isnt_initvar(self):
3557        for typestr in ('IV',
3558                        'dc.InitVar',
3559                        'xdataclasses.xInitVar',
3560                        'typing.xInitVar[int]',
3561                        ):
3562            with self.subTest(typestr=typestr):
3563                @dataclass
3564                class C:
3565                    x: typestr
3566
3567                # x is not an InitVar, so there will be a member x.
3568                self.assertEqual(C(10).x, 10)
3569
3570    def test_classvar_module_level_import(self):
3571        from test import dataclass_module_1
3572        from test import dataclass_module_1_str
3573        from test import dataclass_module_2
3574        from test import dataclass_module_2_str
3575
3576        for m in (dataclass_module_1, dataclass_module_1_str,
3577                  dataclass_module_2, dataclass_module_2_str,
3578                  ):
3579            with self.subTest(m=m):
3580                # There's a difference in how the ClassVars are
3581                # interpreted when using string annotations or
3582                # not. See the imported modules for details.
3583                if m.USING_STRINGS:
3584                    c = m.CV(10)
3585                else:
3586                    c = m.CV()
3587                self.assertEqual(c.cv0, 20)
3588
3589
3590                # There's a difference in how the InitVars are
3591                # interpreted when using string annotations or
3592                # not. See the imported modules for details.
3593                c = m.IV(0, 1, 2, 3, 4)
3594
3595                for field_name in ('iv0', 'iv1', 'iv2', 'iv3'):
3596                    with self.subTest(field_name=field_name):
3597                        with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"):
3598                            # Since field_name is an InitVar, it's
3599                            # not an instance field.
3600                            getattr(c, field_name)
3601
3602                if m.USING_STRINGS:
3603                    # iv4 is interpreted as a normal field.
3604                    self.assertIn('not_iv4', c.__dict__)
3605                    self.assertEqual(c.not_iv4, 4)
3606                else:
3607                    # iv4 is interpreted as an InitVar, so it
3608                    # won't exist on the instance.
3609                    self.assertNotIn('not_iv4', c.__dict__)
3610
3611    def test_text_annotations(self):
3612        from test import dataclass_textanno
3613
3614        self.assertEqual(
3615            get_type_hints(dataclass_textanno.Bar),
3616            {'foo': dataclass_textanno.Foo})
3617        self.assertEqual(
3618            get_type_hints(dataclass_textanno.Bar.__init__),
3619            {'foo': dataclass_textanno.Foo,
3620             'return': type(None)})
3621
3622
class TestMakeDataclass(unittest.TestCase):
    """Tests for make_dataclass(), which builds a dataclass from a field list."""

    def test_simple(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', int, field(default=5))],
                           namespace={'add_one': lambda self: self.x + 1})
        c = C(10)
        self.assertEqual((c.x, c.y), (10, 5))
        self.assertEqual(c.add_one(), 11)

    def test_no_mutate_namespace(self):
        # Make sure a provided namespace isn't mutated.
        ns = {}
        make_dataclass('C',
                       [('x', int),
                        ('y', int, field(default=5))],
                       namespace=ns)
        self.assertEqual(ns, {})

    def test_base(self):
        class Base1:
            pass
        class Base2:
            pass
        C = make_dataclass('C',
                           [('x', int)],
                           bases=(Base1, Base2))
        c = C(2)
        self.assertIsInstance(c, C)
        self.assertIsInstance(c, Base1)
        self.assertIsInstance(c, Base2)

    def test_base_dataclass(self):
        # Fields inherited from a dataclass base precede the new ones in
        # __init__, so both x and y are required positionally.
        @dataclass
        class Base1:
            x: int
        class Base2:
            pass
        C = make_dataclass('C',
                           [('y', int)],
                           bases=(Base1, Base2))
        with self.assertRaisesRegex(TypeError, 'required positional'):
            C(2)
        c = C(1, 2)
        self.assertIsInstance(c, C)
        self.assertIsInstance(c, Base1)
        self.assertIsInstance(c, Base2)

        self.assertEqual((c.x, c.y), (1, 2))

    def test_init_var(self):
        def post_init(self, y):
            self.x *= y

        C = make_dataclass('C',
                           [('x', int),
                            ('y', InitVar[int]),
                            ],
                           namespace={'__post_init__': post_init},
                           )
        c = C(2, 3)
        self.assertEqual(vars(c), {'x': 6})
        self.assertEqual(len(fields(c)), 1)

    def test_class_var(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20)),
                            ])
        c = C(1)
        self.assertEqual(vars(c), {'x': 1})
        self.assertEqual(len(fields(c)), 1)
        self.assertEqual(C.y, 10)
        self.assertEqual(C.z, 20)

    def test_other_params(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20)),
                            ],
                           init=False)
        # Make sure we have a repr, but no init.
        self.assertNotIn('__init__', vars(C))
        self.assertIn('__repr__', vars(C))

        # Make sure random other params don't work.
        with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
            make_dataclass('C',
                           [],
                           xxinit=False)

    def test_no_types(self):
        # Fields given only by name are annotated with the string 'typing.Any'.
        C = make_dataclass('Point', ['x', 'y', 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': 'typing.Any',
                                             'z': 'typing.Any'})

        C = make_dataclass('Point', ['x', ('y', int), 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': int,
                                             'z': 'typing.Any'})

    def test_invalid_type_specification(self):
        for bad_field in [(),
                          (1, 2, 3, 4),
                          ]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'Invalid field: '):
                    make_dataclass('C', ['a', bad_field])

        # And test for things with no len().
        for bad_field in [float,
                          lambda x:x,
                          ]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'has no len\(\)'):
                    make_dataclass('C', ['a', bad_field])

    # The loop variable in the field-name tests below is 'field_name',
    # not 'field', so it does not shadow dataclasses.field inside the test.

    def test_duplicate_field_names(self):
        for field_name in ['a', 'ab']:
            with self.subTest(field=field_name):
                with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
                    make_dataclass('C', [field_name, 'a', field_name])

    def test_keyword_field_names(self):
        for field_name in ['for', 'async', 'await', 'as']:
            with self.subTest(field=field_name):
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', ['a', field_name])
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', [field_name])
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', [field_name, 'a'])

    def test_non_identifier_field_names(self):
        for field_name in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']:
            with self.subTest(field=field_name):
                with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
                    make_dataclass('C', ['a', field_name])
                with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
                    make_dataclass('C', [field_name])
                with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
                    make_dataclass('C', [field_name, 'a'])

    def test_underscore_field_names(self):
        # Unlike namedtuple, it's okay if dataclass field names have
        # an underscore.
        make_dataclass('C', ['_', '_a', 'a_a', 'a_'])

    def test_funny_class_names_names(self):
        # No reason to prevent weird class names, since
        # types.new_class allows them.
        for classname in ['()', 'x,y', '*', '2@3', '']:
            with self.subTest(classname=classname):
                C = make_dataclass(classname, ['a', 'b'])
                self.assertEqual(C.__name__, classname)
3786
class TestReplace(unittest.TestCase):
    """Tests for dataclasses.replace(), plus repr() of self-referential instances."""

    def test(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual(c1.x, 3)
        self.assertEqual(c1.y, 2)

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int
            z: int = field(init=False, default=10)
            t: int = field(init=False, default=100)

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
        self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))

        # init=False fields can't be passed to replace(), either alone or
        # together with init=True fields.  Each call gets its own context
        # manager: a call placed after one that raises would never run.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=3, z=20, t=50)
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, z=20)
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, t=50)

        # Make sure the result is still frozen.
        with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
            c1.x = 3

        # Make sure we can't replace an attribute that doesn't exist,
        #  if we're also replacing one that does exist.  Test this
        #  here, because setting attributes on frozen instances is
        #  handled slightly differently from non-frozen ones.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                    "keyword argument 'a'"):
            replace(c, x=20, a=5)

    def test_invalid_field_name(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                    "keyword argument 'z'"):
            replace(c, z=3)

    def test_invalid_object(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(C, x=3)

        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(0, x=3)

    def test_no_init(self):
        @dataclass
        class C:
            x: int
            y: int = field(init=False, default=10)

        c = C(1)
        c.y = 20

        # Make sure y gets the default value.
        c1 = replace(c, x=5)
        self.assertEqual((c1.x, c1.y), (5, 10))

        # Trying to replace y is an error.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=2, y=30)

        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, y=30)

    def test_classvar(self):
        @dataclass
        class C:
            x: int
            y: ClassVar[int] = 1000

        c = C(1)
        d = C(2)

        self.assertIs(c.y, d.y)
        self.assertEqual(c.y, 1000)

        # Trying to replace y is an error: can't replace ClassVars.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an "
                                    "unexpected keyword argument 'y'"):
            replace(c, y=30)

        replace(c, x=5)

    def test_initvar_is_specified(self):
        @dataclass
        class C:
            x: int
            y: InitVar[int]

            def __post_init__(self, y):
                self.x *= y

        c = C(1, 10)
        self.assertEqual(c.x, 10)
        with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be "
                                    "specified with replace()"):
            replace(c, x=3)
        c = replace(c, x=3, y=5)
        self.assertEqual(c.x, 15)

    def test_initvar_with_default_value(self):
        @dataclass
        class C:
            x: int
            y: InitVar[int] = None
            z: InitVar[int] = 42

            def __post_init__(self, y, z):
                if y is not None:
                    self.x += y
                if z is not None:
                    self.x += z

        c = C(x=1, y=10, z=1)
        self.assertEqual(replace(c), C(x=12))
        self.assertEqual(replace(c, y=4), C(x=12, y=4, z=42))
        self.assertEqual(replace(c, y=4, z=1), C(x=12, y=4, z=1))

    def test_recursive_repr(self):
        @dataclass
        class C:
            f: "C"

        c = C(None)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr.<locals>.C(f=...)")

    def test_recursive_repr_two_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: "C"

        c = C(None, None)
        c.f = c
        c.g = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs"
                                  ".<locals>.C(f=..., g=...)")

    def test_recursive_repr_indirection(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "C"

        c = C(None)
        d = D(None)
        c.f = d
        d.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.D(f=...))")

    def test_recursive_repr_indirection_two(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "E"

        @dataclass
        class E:
            f: "C"

        c = C(None)
        d = D(None)
        e = E(None)
        c.f = d
        d.f = e
        e.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.E(f=...)))")

    def test_recursive_repr_misc_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: int

        c = C(None, 1)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_misc_attrs"
                                  ".<locals>.C(f=..., g=1)")
4014
4015class TestAbstract(unittest.TestCase):
4016    def test_abc_implementation(self):
4017        class Ordered(abc.ABC):
4018            @abc.abstractmethod
4019            def __lt__(self, other):
4020                pass
4021
4022            @abc.abstractmethod
4023            def __le__(self, other):
4024                pass
4025
4026        @dataclass(order=True)
4027        class Date(Ordered):
4028            year: int
4029            month: 'Month'
4030            day: 'int'
4031
4032        self.assertFalse(inspect.isabstract(Date))
4033        self.assertGreater(Date(2020,12,25), Date(2020,8,31))
4034
4035    def test_maintain_abc(self):
4036        class A(abc.ABC):
4037            @abc.abstractmethod
4038            def foo(self):
4039                pass
4040
4041        @dataclass
4042        class Date(A):
4043            year: int
4044            month: 'Month'
4045            day: 'int'
4046
4047        self.assertTrue(inspect.isabstract(Date))
4048        msg = 'class Date with abstract method foo'
4049        self.assertRaisesRegex(TypeError, msg, Date)
4050
4051
4052class TestMatchArgs(unittest.TestCase):
4053    def test_match_args(self):
4054        @dataclass
4055        class C:
4056            a: int
4057        self.assertEqual(C(42).__match_args__, ('a',))
4058
4059    def test_explicit_match_args(self):
4060        ma = ()
4061        @dataclass
4062        class C:
4063            a: int
4064            __match_args__ = ma
4065        self.assertIs(C(42).__match_args__, ma)
4066
4067    def test_bpo_43764(self):
4068        @dataclass(repr=False, eq=False, init=False)
4069        class X:
4070            a: int
4071            b: int
4072            c: int
4073        self.assertEqual(X.__match_args__, ("a", "b", "c"))
4074
4075    def test_match_args_argument(self):
4076        @dataclass(match_args=False)
4077        class X:
4078            a: int
4079        self.assertNotIn('__match_args__', X.__dict__)
4080
4081        @dataclass(match_args=False)
4082        class Y:
4083            a: int
4084            __match_args__ = ('b',)
4085        self.assertEqual(Y.__match_args__, ('b',))
4086
4087        @dataclass(match_args=False)
4088        class Z(Y):
4089            z: int
4090        self.assertEqual(Z.__match_args__, ('b',))
4091
4092        # Ensure parent dataclass __match_args__ is seen, if child class
4093        # specifies match_args=False.
4094        @dataclass
4095        class A:
4096            a: int
4097            z: int
4098        @dataclass(match_args=False)
4099        class B(A):
4100            b: int
4101        self.assertEqual(B.__match_args__, ('a', 'z'))
4102
4103    def test_make_dataclasses(self):
4104        C = make_dataclass('C', [('x', int), ('y', int)])
4105        self.assertEqual(C.__match_args__, ('x', 'y'))
4106
4107        C = make_dataclass('C', [('x', int), ('y', int)], match_args=True)
4108        self.assertEqual(C.__match_args__, ('x', 'y'))
4109
4110        C = make_dataclass('C', [('x', int), ('y', int)], match_args=False)
4111        self.assertNotIn('__match__args__', C.__dict__)
4112
4113        C = make_dataclass('C', [('x', int), ('y', int)], namespace={'__match_args__': ('z',)})
4114        self.assertEqual(C.__match_args__, ('z',))
4115
4116
4117class TestKeywordArgs(unittest.TestCase):
4118    def test_no_classvar_kwarg(self):
4119        msg = 'field a is a ClassVar but specifies kw_only'
4120        with self.assertRaisesRegex(TypeError, msg):
4121            @dataclass
4122            class A:
4123                a: ClassVar[int] = field(kw_only=True)
4124
4125        with self.assertRaisesRegex(TypeError, msg):
4126            @dataclass
4127            class A:
4128                a: ClassVar[int] = field(kw_only=False)
4129
4130        with self.assertRaisesRegex(TypeError, msg):
4131            @dataclass(kw_only=True)
4132            class A:
4133                a: ClassVar[int] = field(kw_only=False)
4134
4135    def test_field_marked_as_kwonly(self):
4136        #######################
4137        # Using dataclass(kw_only=True)
4138        @dataclass(kw_only=True)
4139        class A:
4140            a: int
4141        self.assertTrue(fields(A)[0].kw_only)
4142
4143        @dataclass(kw_only=True)
4144        class A:
4145            a: int = field(kw_only=True)
4146        self.assertTrue(fields(A)[0].kw_only)
4147
4148        @dataclass(kw_only=True)
4149        class A:
4150            a: int = field(kw_only=False)
4151        self.assertFalse(fields(A)[0].kw_only)
4152
4153        #######################
4154        # Using dataclass(kw_only=False)
4155        @dataclass(kw_only=False)
4156        class A:
4157            a: int
4158        self.assertFalse(fields(A)[0].kw_only)
4159
4160        @dataclass(kw_only=False)
4161        class A:
4162            a: int = field(kw_only=True)
4163        self.assertTrue(fields(A)[0].kw_only)
4164
4165        @dataclass(kw_only=False)
4166        class A:
4167            a: int = field(kw_only=False)
4168        self.assertFalse(fields(A)[0].kw_only)
4169
4170        #######################
4171        # Not specifying dataclass(kw_only)
4172        @dataclass
4173        class A:
4174            a: int
4175        self.assertFalse(fields(A)[0].kw_only)
4176
4177        @dataclass
4178        class A:
4179            a: int = field(kw_only=True)
4180        self.assertTrue(fields(A)[0].kw_only)
4181
4182        @dataclass
4183        class A:
4184            a: int = field(kw_only=False)
4185        self.assertFalse(fields(A)[0].kw_only)
4186
4187    def test_match_args(self):
4188        # kw fields don't show up in __match_args__.
4189        @dataclass(kw_only=True)
4190        class C:
4191            a: int
4192        self.assertEqual(C(a=42).__match_args__, ())
4193
4194        @dataclass
4195        class C:
4196            a: int
4197            b: int = field(kw_only=True)
4198        self.assertEqual(C(42, b=10).__match_args__, ('a',))
4199
4200    def test_KW_ONLY(self):
4201        @dataclass
4202        class A:
4203            a: int
4204            _: KW_ONLY
4205            b: int
4206            c: int
4207        A(3, c=5, b=4)
4208        msg = "takes 2 positional arguments but 4 were given"
4209        with self.assertRaisesRegex(TypeError, msg):
4210            A(3, 4, 5)
4211
4212
4213        @dataclass(kw_only=True)
4214        class B:
4215            a: int
4216            _: KW_ONLY
4217            b: int
4218            c: int
4219        B(a=3, b=4, c=5)
4220        msg = "takes 1 positional argument but 4 were given"
4221        with self.assertRaisesRegex(TypeError, msg):
4222            B(3, 4, 5)
4223
4224        # Explicitly make a field that follows KW_ONLY be non-keyword-only.
4225        @dataclass
4226        class C:
4227            a: int
4228            _: KW_ONLY
4229            b: int
4230            c: int = field(kw_only=False)
4231        c = C(1, 2, b=3)
4232        self.assertEqual(c.a, 1)
4233        self.assertEqual(c.b, 3)
4234        self.assertEqual(c.c, 2)
4235        c = C(1, b=3, c=2)
4236        self.assertEqual(c.a, 1)
4237        self.assertEqual(c.b, 3)
4238        self.assertEqual(c.c, 2)
4239        c = C(1, b=3, c=2)
4240        self.assertEqual(c.a, 1)
4241        self.assertEqual(c.b, 3)
4242        self.assertEqual(c.c, 2)
4243        c = C(c=2, b=3, a=1)
4244        self.assertEqual(c.a, 1)
4245        self.assertEqual(c.b, 3)
4246        self.assertEqual(c.c, 2)
4247
4248    def test_KW_ONLY_as_string(self):
4249        @dataclass
4250        class A:
4251            a: int
4252            _: 'dataclasses.KW_ONLY'
4253            b: int
4254            c: int
4255        A(3, c=5, b=4)
4256        msg = "takes 2 positional arguments but 4 were given"
4257        with self.assertRaisesRegex(TypeError, msg):
4258            A(3, 4, 5)
4259
4260    def test_KW_ONLY_twice(self):
4261        msg = "'Y' is KW_ONLY, but KW_ONLY has already been specified"
4262
4263        with self.assertRaisesRegex(TypeError, msg):
4264            @dataclass
4265            class A:
4266                a: int
4267                X: KW_ONLY
4268                Y: KW_ONLY
4269                b: int
4270                c: int
4271
4272        with self.assertRaisesRegex(TypeError, msg):
4273            @dataclass
4274            class A:
4275                a: int
4276                X: KW_ONLY
4277                b: int
4278                Y: KW_ONLY
4279                c: int
4280
4281        with self.assertRaisesRegex(TypeError, msg):
4282            @dataclass
4283            class A:
4284                a: int
4285                X: KW_ONLY
4286                b: int
4287                c: int
4288                Y: KW_ONLY
4289
4290        # But this usage is okay, since it's not using KW_ONLY.
4291        @dataclass
4292        class A:
4293            a: int
4294            _: KW_ONLY
4295            b: int
4296            c: int = field(kw_only=True)
4297
4298        # And if inheriting, it's okay.
4299        @dataclass
4300        class A:
4301            a: int
4302            _: KW_ONLY
4303            b: int
4304            c: int
4305        @dataclass
4306        class B(A):
4307            _: KW_ONLY
4308            d: int
4309
4310        # Make sure the error is raised in a derived class.
4311        with self.assertRaisesRegex(TypeError, msg):
4312            @dataclass
4313            class A:
4314                a: int
4315                _: KW_ONLY
4316                b: int
4317                c: int
4318            @dataclass
4319            class B(A):
4320                X: KW_ONLY
4321                d: int
4322                Y: KW_ONLY
4323
4324
4325    def test_post_init(self):
4326        @dataclass
4327        class A:
4328            a: int
4329            _: KW_ONLY
4330            b: InitVar[int]
4331            c: int
4332            d: InitVar[int]
4333            def __post_init__(self, b, d):
4334                raise CustomError(f'{b=} {d=}')
4335        with self.assertRaisesRegex(CustomError, 'b=3 d=4'):
4336            A(1, c=2, b=3, d=4)
4337
4338        @dataclass
4339        class B:
4340            a: int
4341            _: KW_ONLY
4342            b: InitVar[int]
4343            c: int
4344            d: InitVar[int]
4345            def __post_init__(self, b, d):
4346                self.a = b
4347                self.c = d
4348        b = B(1, c=2, b=3, d=4)
4349        self.assertEqual(asdict(b), {'a': 3, 'c': 4})
4350
4351    def test_defaults(self):
4352        # For kwargs, make sure we can have defaults after non-defaults.
4353        @dataclass
4354        class A:
4355            a: int = 0
4356            _: KW_ONLY
4357            b: int
4358            c: int = 1
4359            d: int
4360
4361        a = A(d=4, b=3)
4362        self.assertEqual(a.a, 0)
4363        self.assertEqual(a.b, 3)
4364        self.assertEqual(a.c, 1)
4365        self.assertEqual(a.d, 4)
4366
4367        # Make sure we still check for non-kwarg non-defaults not following
4368        # defaults.
4369        err_regex = "non-default argument 'z' follows default argument"
4370        with self.assertRaisesRegex(TypeError, err_regex):
4371            @dataclass
4372            class A:
4373                a: int = 0
4374                z: int
4375                _: KW_ONLY
4376                b: int
4377                c: int = 1
4378                d: int
4379
4380    def test_make_dataclass(self):
4381        A = make_dataclass("A", ['a'], kw_only=True)
4382        self.assertTrue(fields(A)[0].kw_only)
4383
4384        B = make_dataclass("B",
4385                           ['a', ('b', int, field(kw_only=False))],
4386                           kw_only=True)
4387        self.assertTrue(fields(B)[0].kw_only)
4388        self.assertFalse(fields(B)[1].kw_only)
4389
4390
if __name__ == '__main__':
    # When executed as a script, discover and run every TestCase defined in
    # this module ('__main__' is unittest.main's default module argument).
    unittest.main(module='__main__')
4393