1# Deliberately use "from dataclasses import *". Every name in __all__ 2# is tested, so they all must be present. This is a way to catch 3# missing ones. 4 5from dataclasses import * 6 7import abc 8import io 9import pickle 10import inspect 11import builtins 12import types 13import weakref 14import traceback 15import unittest 16from unittest.mock import Mock 17from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol 18from typing import get_type_hints 19from collections import deque, OrderedDict, namedtuple 20from functools import total_ordering 21 22import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation. 23import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation. 24 25# Just any custom exception we can catch. 26class CustomError(Exception): pass 27 28class TestCase(unittest.TestCase): 29 def test_no_fields(self): 30 @dataclass 31 class C: 32 pass 33 34 o = C() 35 self.assertEqual(len(fields(C)), 0) 36 37 def test_no_fields_but_member_variable(self): 38 @dataclass 39 class C: 40 i = 0 41 42 o = C() 43 self.assertEqual(len(fields(C)), 0) 44 45 def test_one_field_no_default(self): 46 @dataclass 47 class C: 48 x: int 49 50 o = C(42) 51 self.assertEqual(o.x, 42) 52 53 def test_field_default_default_factory_error(self): 54 msg = "cannot specify both default and default_factory" 55 with self.assertRaisesRegex(ValueError, msg): 56 @dataclass 57 class C: 58 x: int = field(default=1, default_factory=int) 59 60 def test_field_repr(self): 61 int_field = field(default=1, init=True, repr=False) 62 int_field.name = "id" 63 repr_output = repr(int_field) 64 expected_output = "Field(name='id',type=None," \ 65 f"default=1,default_factory={MISSING!r}," \ 66 "init=True,repr=False,hash=None," \ 67 "compare=True,metadata=mappingproxy({})," \ 68 f"kw_only={MISSING!r}," \ 69 "_field_type=None)" 70 71 self.assertEqual(repr_output, expected_output) 72 73 def 
test_field_recursive_repr(self): 74 rec_field = field() 75 rec_field.type = rec_field 76 rec_field.name = "id" 77 repr_output = repr(rec_field) 78 79 self.assertIn(",type=...,", repr_output) 80 81 def test_recursive_annotation(self): 82 class C: 83 pass 84 85 @dataclass 86 class D: 87 C: C = field() 88 89 self.assertIn(",type=...,", repr(D.__dataclass_fields__["C"])) 90 91 def test_named_init_params(self): 92 @dataclass 93 class C: 94 x: int 95 96 o = C(x=32) 97 self.assertEqual(o.x, 32) 98 99 def test_two_fields_one_default(self): 100 @dataclass 101 class C: 102 x: int 103 y: int = 0 104 105 o = C(3) 106 self.assertEqual((o.x, o.y), (3, 0)) 107 108 # Non-defaults following defaults. 109 with self.assertRaisesRegex(TypeError, 110 "non-default argument 'y' follows " 111 "default argument"): 112 @dataclass 113 class C: 114 x: int = 0 115 y: int 116 117 # A derived class adds a non-default field after a default one. 118 with self.assertRaisesRegex(TypeError, 119 "non-default argument 'y' follows " 120 "default argument"): 121 @dataclass 122 class B: 123 x: int = 0 124 125 @dataclass 126 class C(B): 127 y: int 128 129 # Override a base class field and add a default to 130 # a field which didn't use to have a default. 131 with self.assertRaisesRegex(TypeError, 132 "non-default argument 'y' follows " 133 "default argument"): 134 @dataclass 135 class B: 136 x: int 137 y: int 138 139 @dataclass 140 class C(B): 141 x: int = 0 142 143 def test_overwrite_hash(self): 144 # Test that declaring this class isn't an error. It should 145 # use the user-provided __hash__. 146 @dataclass(frozen=True) 147 class C: 148 x: int 149 def __hash__(self): 150 return 301 151 self.assertEqual(hash(C(100)), 301) 152 153 # Test that declaring this class isn't an error. It should 154 # use the generated __hash__. 
155 @dataclass(frozen=True) 156 class C: 157 x: int 158 def __eq__(self, other): 159 return False 160 self.assertEqual(hash(C(100)), hash((100,))) 161 162 # But this one should generate an exception, because with 163 # unsafe_hash=True, it's an error to have a __hash__ defined. 164 with self.assertRaisesRegex(TypeError, 165 'Cannot overwrite attribute __hash__'): 166 @dataclass(unsafe_hash=True) 167 class C: 168 def __hash__(self): 169 pass 170 171 # Creating this class should not generate an exception, 172 # because even though __hash__ exists before @dataclass is 173 # called, (due to __eq__ being defined), since it's None 174 # that's okay. 175 @dataclass(unsafe_hash=True) 176 class C: 177 x: int 178 def __eq__(self): 179 pass 180 # The generated hash function works as we'd expect. 181 self.assertEqual(hash(C(10)), hash((10,))) 182 183 # Creating this class should generate an exception, because 184 # __hash__ exists and is not None, which it would be if it 185 # had been auto-generated due to __eq__ being defined. 186 with self.assertRaisesRegex(TypeError, 187 'Cannot overwrite attribute __hash__'): 188 @dataclass(unsafe_hash=True) 189 class C: 190 x: int 191 def __eq__(self): 192 pass 193 def __hash__(self): 194 pass 195 196 def test_overwrite_fields_in_derived_class(self): 197 # Note that x from C1 replaces x in Base, but the order remains 198 # the same as defined in Base. 
199 @dataclass 200 class Base: 201 x: Any = 15.0 202 y: int = 0 203 204 @dataclass 205 class C1(Base): 206 z: int = 10 207 x: int = 15 208 209 o = Base() 210 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)') 211 212 o = C1() 213 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)') 214 215 o = C1(x=5) 216 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)') 217 218 def test_field_named_self(self): 219 @dataclass 220 class C: 221 self: str 222 c=C('foo') 223 self.assertEqual(c.self, 'foo') 224 225 # Make sure the first parameter is not named 'self'. 226 sig = inspect.signature(C.__init__) 227 first = next(iter(sig.parameters)) 228 self.assertNotEqual('self', first) 229 230 # But we do use 'self' if no field named self. 231 @dataclass 232 class C: 233 selfx: str 234 235 # Make sure the first parameter is named 'self'. 236 sig = inspect.signature(C.__init__) 237 first = next(iter(sig.parameters)) 238 self.assertEqual('self', first) 239 240 def test_field_named_object(self): 241 @dataclass 242 class C: 243 object: str 244 c = C('foo') 245 self.assertEqual(c.object, 'foo') 246 247 def test_field_named_object_frozen(self): 248 @dataclass(frozen=True) 249 class C: 250 object: str 251 c = C('foo') 252 self.assertEqual(c.object, 'foo') 253 254 def test_field_named_BUILTINS_frozen(self): 255 # gh-96151 256 @dataclass(frozen=True) 257 class C: 258 BUILTINS: int 259 c = C(5) 260 self.assertEqual(c.BUILTINS, 5) 261 262 def test_field_named_like_builtin(self): 263 # Attribute names can shadow built-in names 264 # since code generation is used. 265 # Ensure that this is not happening. 
266 exclusions = {'None', 'True', 'False'} 267 builtins_names = sorted( 268 b for b in builtins.__dict__.keys() 269 if not b.startswith('__') and b not in exclusions 270 ) 271 attributes = [(name, str) for name in builtins_names] 272 C = make_dataclass('C', attributes) 273 274 c = C(*[name for name in builtins_names]) 275 276 for name in builtins_names: 277 self.assertEqual(getattr(c, name), name) 278 279 def test_field_named_like_builtin_frozen(self): 280 # Attribute names can shadow built-in names 281 # since code generation is used. 282 # Ensure that this is not happening 283 # for frozen data classes. 284 exclusions = {'None', 'True', 'False'} 285 builtins_names = sorted( 286 b for b in builtins.__dict__.keys() 287 if not b.startswith('__') and b not in exclusions 288 ) 289 attributes = [(name, str) for name in builtins_names] 290 C = make_dataclass('C', attributes, frozen=True) 291 292 c = C(*[name for name in builtins_names]) 293 294 for name in builtins_names: 295 self.assertEqual(getattr(c, name), name) 296 297 def test_0_field_compare(self): 298 # Ensure that order=False is the default. 299 @dataclass 300 class C0: 301 pass 302 303 @dataclass(order=False) 304 class C1: 305 pass 306 307 for cls in [C0, C1]: 308 with self.subTest(cls=cls): 309 self.assertEqual(cls(), cls()) 310 for idx, fn in enumerate([lambda a, b: a < b, 311 lambda a, b: a <= b, 312 lambda a, b: a > b, 313 lambda a, b: a >= b]): 314 with self.subTest(idx=idx): 315 with self.assertRaisesRegex(TypeError, 316 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"): 317 fn(cls(), cls()) 318 319 @dataclass(order=True) 320 class C: 321 pass 322 self.assertLessEqual(C(), C()) 323 self.assertGreaterEqual(C(), C()) 324 325 def test_1_field_compare(self): 326 # Ensure that order=False is the default. 
327 @dataclass 328 class C0: 329 x: int 330 331 @dataclass(order=False) 332 class C1: 333 x: int 334 335 for cls in [C0, C1]: 336 with self.subTest(cls=cls): 337 self.assertEqual(cls(1), cls(1)) 338 self.assertNotEqual(cls(0), cls(1)) 339 for idx, fn in enumerate([lambda a, b: a < b, 340 lambda a, b: a <= b, 341 lambda a, b: a > b, 342 lambda a, b: a >= b]): 343 with self.subTest(idx=idx): 344 with self.assertRaisesRegex(TypeError, 345 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"): 346 fn(cls(0), cls(0)) 347 348 @dataclass(order=True) 349 class C: 350 x: int 351 self.assertLess(C(0), C(1)) 352 self.assertLessEqual(C(0), C(1)) 353 self.assertLessEqual(C(1), C(1)) 354 self.assertGreater(C(1), C(0)) 355 self.assertGreaterEqual(C(1), C(0)) 356 self.assertGreaterEqual(C(1), C(1)) 357 358 def test_simple_compare(self): 359 # Ensure that order=False is the default. 360 @dataclass 361 class C0: 362 x: int 363 y: int 364 365 @dataclass(order=False) 366 class C1: 367 x: int 368 y: int 369 370 for cls in [C0, C1]: 371 with self.subTest(cls=cls): 372 self.assertEqual(cls(0, 0), cls(0, 0)) 373 self.assertEqual(cls(1, 2), cls(1, 2)) 374 self.assertNotEqual(cls(1, 0), cls(0, 0)) 375 self.assertNotEqual(cls(1, 0), cls(1, 1)) 376 for idx, fn in enumerate([lambda a, b: a < b, 377 lambda a, b: a <= b, 378 lambda a, b: a > b, 379 lambda a, b: a >= b]): 380 with self.subTest(idx=idx): 381 with self.assertRaisesRegex(TypeError, 382 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"): 383 fn(cls(0, 0), cls(0, 0)) 384 385 @dataclass(order=True) 386 class C: 387 x: int 388 y: int 389 390 for idx, fn in enumerate([lambda a, b: a == b, 391 lambda a, b: a <= b, 392 lambda a, b: a >= b]): 393 with self.subTest(idx=idx): 394 self.assertTrue(fn(C(0, 0), C(0, 0))) 395 396 for idx, fn in enumerate([lambda a, b: a < b, 397 lambda a, b: a <= b, 398 lambda a, b: a != b]): 399 with self.subTest(idx=idx): 400 self.assertTrue(fn(C(0, 0), C(0, 1))) 
401 self.assertTrue(fn(C(0, 1), C(1, 0))) 402 self.assertTrue(fn(C(1, 0), C(1, 1))) 403 404 for idx, fn in enumerate([lambda a, b: a > b, 405 lambda a, b: a >= b, 406 lambda a, b: a != b]): 407 with self.subTest(idx=idx): 408 self.assertTrue(fn(C(0, 1), C(0, 0))) 409 self.assertTrue(fn(C(1, 0), C(0, 1))) 410 self.assertTrue(fn(C(1, 1), C(1, 0))) 411 412 def test_compare_subclasses(self): 413 # Comparisons fail for subclasses, even if no fields 414 # are added. 415 @dataclass 416 class B: 417 i: int 418 419 @dataclass 420 class C(B): 421 pass 422 423 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False), 424 (lambda a, b: a != b, True)]): 425 with self.subTest(idx=idx): 426 self.assertEqual(fn(B(0), C(0)), expected) 427 428 for idx, fn in enumerate([lambda a, b: a < b, 429 lambda a, b: a <= b, 430 lambda a, b: a > b, 431 lambda a, b: a >= b]): 432 with self.subTest(idx=idx): 433 with self.assertRaisesRegex(TypeError, 434 "not supported between instances of 'B' and 'C'"): 435 fn(B(0), C(0)) 436 437 def test_eq_order(self): 438 # Test combining eq and order. 
439 for (eq, order, result ) in [ 440 (False, False, 'neither'), 441 (False, True, 'exception'), 442 (True, False, 'eq_only'), 443 (True, True, 'both'), 444 ]: 445 with self.subTest(eq=eq, order=order): 446 if result == 'exception': 447 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'): 448 @dataclass(eq=eq, order=order) 449 class C: 450 pass 451 else: 452 @dataclass(eq=eq, order=order) 453 class C: 454 pass 455 456 if result == 'neither': 457 self.assertNotIn('__eq__', C.__dict__) 458 self.assertNotIn('__lt__', C.__dict__) 459 self.assertNotIn('__le__', C.__dict__) 460 self.assertNotIn('__gt__', C.__dict__) 461 self.assertNotIn('__ge__', C.__dict__) 462 elif result == 'both': 463 self.assertIn('__eq__', C.__dict__) 464 self.assertIn('__lt__', C.__dict__) 465 self.assertIn('__le__', C.__dict__) 466 self.assertIn('__gt__', C.__dict__) 467 self.assertIn('__ge__', C.__dict__) 468 elif result == 'eq_only': 469 self.assertIn('__eq__', C.__dict__) 470 self.assertNotIn('__lt__', C.__dict__) 471 self.assertNotIn('__le__', C.__dict__) 472 self.assertNotIn('__gt__', C.__dict__) 473 self.assertNotIn('__ge__', C.__dict__) 474 else: 475 assert False, f'unknown result {result!r}' 476 477 def test_field_no_default(self): 478 @dataclass 479 class C: 480 x: int = field() 481 482 self.assertEqual(C(5).x, 5) 483 484 with self.assertRaisesRegex(TypeError, 485 r"__init__\(\) missing 1 required " 486 "positional argument: 'x'"): 487 C() 488 489 def test_field_default(self): 490 default = object() 491 @dataclass 492 class C: 493 x: object = field(default=default) 494 495 self.assertIs(C.x, default) 496 c = C(10) 497 self.assertEqual(c.x, 10) 498 499 # If we delete the instance attribute, we should then see the 500 # class attribute. 
501 del c.x 502 self.assertIs(c.x, default) 503 504 self.assertIs(C().x, default) 505 506 def test_not_in_repr(self): 507 @dataclass 508 class C: 509 x: int = field(repr=False) 510 with self.assertRaises(TypeError): 511 C() 512 c = C(10) 513 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()') 514 515 @dataclass 516 class C: 517 x: int = field(repr=False) 518 y: int 519 c = C(10, 20) 520 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)') 521 522 def test_not_in_compare(self): 523 @dataclass 524 class C: 525 x: int = 0 526 y: int = field(compare=False, default=4) 527 528 self.assertEqual(C(), C(0, 20)) 529 self.assertEqual(C(1, 10), C(1, 20)) 530 self.assertNotEqual(C(3), C(4, 10)) 531 self.assertNotEqual(C(3, 10), C(4, 10)) 532 533 def test_no_unhashable_default(self): 534 # See bpo-44674. 535 class Unhashable: 536 __hash__ = None 537 538 unhashable_re = 'mutable default .* for field a is not allowed' 539 with self.assertRaisesRegex(ValueError, unhashable_re): 540 @dataclass 541 class A: 542 a: dict = {} 543 544 with self.assertRaisesRegex(ValueError, unhashable_re): 545 @dataclass 546 class A: 547 a: Any = Unhashable() 548 549 # Make sure that the machinery looking for hashability is using the 550 # class's __hash__, not the instance's __hash__. 551 with self.assertRaisesRegex(ValueError, unhashable_re): 552 unhashable = Unhashable() 553 # This shouldn't make the variable hashable. 
554 unhashable.__hash__ = lambda: 0 555 @dataclass 556 class A: 557 a: Any = unhashable 558 559 def test_hash_field_rules(self): 560 # Test all 6 cases of: 561 # hash=True/False/None 562 # compare=True/False 563 for (hash_, compare, result ) in [ 564 (True, False, 'field' ), 565 (True, True, 'field' ), 566 (False, False, 'absent'), 567 (False, True, 'absent'), 568 (None, False, 'absent'), 569 (None, True, 'field' ), 570 ]: 571 with self.subTest(hash=hash_, compare=compare): 572 @dataclass(unsafe_hash=True) 573 class C: 574 x: int = field(compare=compare, hash=hash_, default=5) 575 576 if result == 'field': 577 # __hash__ contains the field. 578 self.assertEqual(hash(C(5)), hash((5,))) 579 elif result == 'absent': 580 # The field is not present in the hash. 581 self.assertEqual(hash(C(5)), hash(())) 582 else: 583 assert False, f'unknown result {result!r}' 584 585 def test_init_false_no_default(self): 586 # If init=False and no default value, then the field won't be 587 # present in the instance. 588 @dataclass 589 class C: 590 x: int = field(init=False) 591 592 self.assertNotIn('x', C().__dict__) 593 594 @dataclass 595 class C: 596 x: int 597 y: int = 0 598 z: int = field(init=False) 599 t: int = 10 600 601 self.assertNotIn('z', C(0).__dict__) 602 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0}) 603 604 def test_class_marker(self): 605 @dataclass 606 class C: 607 x: int 608 y: str = field(init=False, default=None) 609 z: str = field(repr=False) 610 611 the_fields = fields(C) 612 # the_fields is a tuple of 3 items, each value 613 # is in __annotations__. 
614 self.assertIsInstance(the_fields, tuple) 615 for f in the_fields: 616 self.assertIs(type(f), Field) 617 self.assertIn(f.name, C.__annotations__) 618 619 self.assertEqual(len(the_fields), 3) 620 621 self.assertEqual(the_fields[0].name, 'x') 622 self.assertEqual(the_fields[0].type, int) 623 self.assertFalse(hasattr(C, 'x')) 624 self.assertTrue (the_fields[0].init) 625 self.assertTrue (the_fields[0].repr) 626 self.assertEqual(the_fields[1].name, 'y') 627 self.assertEqual(the_fields[1].type, str) 628 self.assertIsNone(getattr(C, 'y')) 629 self.assertFalse(the_fields[1].init) 630 self.assertTrue (the_fields[1].repr) 631 self.assertEqual(the_fields[2].name, 'z') 632 self.assertEqual(the_fields[2].type, str) 633 self.assertFalse(hasattr(C, 'z')) 634 self.assertTrue (the_fields[2].init) 635 self.assertFalse(the_fields[2].repr) 636 637 def test_field_order(self): 638 @dataclass 639 class B: 640 a: str = 'B:a' 641 b: str = 'B:b' 642 c: str = 'B:c' 643 644 @dataclass 645 class C(B): 646 b: str = 'C:b' 647 648 self.assertEqual([(f.name, f.default) for f in fields(C)], 649 [('a', 'B:a'), 650 ('b', 'C:b'), 651 ('c', 'B:c')]) 652 653 @dataclass 654 class D(B): 655 c: str = 'D:c' 656 657 self.assertEqual([(f.name, f.default) for f in fields(D)], 658 [('a', 'B:a'), 659 ('b', 'B:b'), 660 ('c', 'D:c')]) 661 662 @dataclass 663 class E(D): 664 a: str = 'E:a' 665 d: str = 'E:d' 666 667 self.assertEqual([(f.name, f.default) for f in fields(E)], 668 [('a', 'E:a'), 669 ('b', 'B:b'), 670 ('c', 'D:c'), 671 ('d', 'E:d')]) 672 673 def test_class_attrs(self): 674 # We only have a class attribute if a default value is 675 # specified, either directly or via a field with a default. 
676 default = object() 677 @dataclass 678 class C: 679 x: int 680 y: int = field(repr=False) 681 z: object = default 682 t: int = field(default=100) 683 684 self.assertFalse(hasattr(C, 'x')) 685 self.assertFalse(hasattr(C, 'y')) 686 self.assertIs (C.z, default) 687 self.assertEqual(C.t, 100) 688 689 def test_disallowed_mutable_defaults(self): 690 # For the known types, don't allow mutable default values. 691 for typ, empty, non_empty in [(list, [], [1]), 692 (dict, {}, {0:1}), 693 (set, set(), set([1])), 694 ]: 695 with self.subTest(typ=typ): 696 # Can't use a zero-length value. 697 with self.assertRaisesRegex(ValueError, 698 f'mutable default {typ} for field ' 699 'x is not allowed'): 700 @dataclass 701 class Point: 702 x: typ = empty 703 704 705 # Nor a non-zero-length value 706 with self.assertRaisesRegex(ValueError, 707 f'mutable default {typ} for field ' 708 'y is not allowed'): 709 @dataclass 710 class Point: 711 y: typ = non_empty 712 713 # Check subtypes also fail. 714 class Subclass(typ): pass 715 716 with self.assertRaisesRegex(ValueError, 717 f"mutable default .*Subclass'>" 718 ' for field z is not allowed' 719 ): 720 @dataclass 721 class Point: 722 z: typ = Subclass() 723 724 # Because this is a ClassVar, it can be mutable. 725 @dataclass 726 class C: 727 z: ClassVar[typ] = typ() 728 729 # Because this is a ClassVar, it can be mutable. 730 @dataclass 731 class C: 732 x: ClassVar[typ] = Subclass() 733 734 def test_deliberately_mutable_defaults(self): 735 # If a mutable default isn't in the known list of 736 # (list, dict, set), then it's okay. 737 class Mutable: 738 def __init__(self): 739 self.l = [] 740 741 @dataclass 742 class C: 743 x: Mutable 744 745 # These 2 instances will share this value of x. 
746 lst = Mutable() 747 o1 = C(lst) 748 o2 = C(lst) 749 self.assertEqual(o1, o2) 750 o1.x.l.extend([1, 2]) 751 self.assertEqual(o1, o2) 752 self.assertEqual(o1.x.l, [1, 2]) 753 self.assertIs(o1.x, o2.x) 754 755 def test_no_options(self): 756 # Call with dataclass(). 757 @dataclass() 758 class C: 759 x: int 760 761 self.assertEqual(C(42).x, 42) 762 763 def test_not_tuple(self): 764 # Make sure we can't be compared to a tuple. 765 @dataclass 766 class Point: 767 x: int 768 y: int 769 self.assertNotEqual(Point(1, 2), (1, 2)) 770 771 # And that we can't compare to another unrelated dataclass. 772 @dataclass 773 class C: 774 x: int 775 y: int 776 self.assertNotEqual(Point(1, 3), C(1, 3)) 777 778 def test_not_other_dataclass(self): 779 # Test that some of the problems with namedtuple don't happen 780 # here. 781 @dataclass 782 class Point3D: 783 x: int 784 y: int 785 z: int 786 787 @dataclass 788 class Date: 789 year: int 790 month: int 791 day: int 792 793 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3)) 794 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3)) 795 796 # Make sure we can't unpack. 797 with self.assertRaisesRegex(TypeError, 'unpack'): 798 x, y, z = Point3D(4, 5, 6) 799 800 # Make sure another class with the same field names isn't 801 # equal. 802 @dataclass 803 class Point3Dv1: 804 x: int = 0 805 y: int = 0 806 z: int = 0 807 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1()) 808 809 def test_function_annotations(self): 810 # Some dummy class and instance to use as a default. 811 class F: 812 pass 813 f = F() 814 815 def validate_class(cls): 816 # First, check __annotations__, even though they're not 817 # function annotations. 818 self.assertEqual(cls.__annotations__['i'], int) 819 self.assertEqual(cls.__annotations__['j'], str) 820 self.assertEqual(cls.__annotations__['k'], F) 821 self.assertEqual(cls.__annotations__['l'], float) 822 self.assertEqual(cls.__annotations__['z'], complex) 823 824 # Verify __init__. 
825 826 signature = inspect.signature(cls.__init__) 827 # Check the return type, should be None. 828 self.assertIs(signature.return_annotation, None) 829 830 # Check each parameter. 831 params = iter(signature.parameters.values()) 832 param = next(params) 833 # This is testing an internal name, and probably shouldn't be tested. 834 self.assertEqual(param.name, 'self') 835 param = next(params) 836 self.assertEqual(param.name, 'i') 837 self.assertIs (param.annotation, int) 838 self.assertEqual(param.default, inspect.Parameter.empty) 839 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) 840 param = next(params) 841 self.assertEqual(param.name, 'j') 842 self.assertIs (param.annotation, str) 843 self.assertEqual(param.default, inspect.Parameter.empty) 844 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) 845 param = next(params) 846 self.assertEqual(param.name, 'k') 847 self.assertIs (param.annotation, F) 848 # Don't test for the default, since it's set to MISSING. 849 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) 850 param = next(params) 851 self.assertEqual(param.name, 'l') 852 self.assertIs (param.annotation, float) 853 # Don't test for the default, since it's set to MISSING. 854 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) 855 self.assertRaises(StopIteration, next, params) 856 857 858 @dataclass 859 class C: 860 i: int 861 j: str 862 k: F = f 863 l: float=field(default=None) 864 z: complex=field(default=3+4j, init=False) 865 866 validate_class(C) 867 868 # Now repeat with __hash__. 869 @dataclass(frozen=True, unsafe_hash=True) 870 class C: 871 i: int 872 j: str 873 k: F = f 874 l: float=field(default=None) 875 z: complex=field(default=3+4j, init=False) 876 877 validate_class(C) 878 879 def test_missing_default(self): 880 # Test that MISSING works the same as a default not being 881 # specified. 
882 @dataclass 883 class C: 884 x: int=field(default=MISSING) 885 with self.assertRaisesRegex(TypeError, 886 r'__init__\(\) missing 1 required ' 887 'positional argument'): 888 C() 889 self.assertNotIn('x', C.__dict__) 890 891 @dataclass 892 class D: 893 x: int 894 with self.assertRaisesRegex(TypeError, 895 r'__init__\(\) missing 1 required ' 896 'positional argument'): 897 D() 898 self.assertNotIn('x', D.__dict__) 899 900 def test_missing_default_factory(self): 901 # Test that MISSING works the same as a default factory not 902 # being specified (which is really the same as a default not 903 # being specified, too). 904 @dataclass 905 class C: 906 x: int=field(default_factory=MISSING) 907 with self.assertRaisesRegex(TypeError, 908 r'__init__\(\) missing 1 required ' 909 'positional argument'): 910 C() 911 self.assertNotIn('x', C.__dict__) 912 913 @dataclass 914 class D: 915 x: int=field(default=MISSING, default_factory=MISSING) 916 with self.assertRaisesRegex(TypeError, 917 r'__init__\(\) missing 1 required ' 918 'positional argument'): 919 D() 920 self.assertNotIn('x', D.__dict__) 921 922 def test_missing_repr(self): 923 self.assertIn('MISSING_TYPE object', repr(MISSING)) 924 925 def test_dont_include_other_annotations(self): 926 @dataclass 927 class C: 928 i: int 929 def foo(self) -> int: 930 return 4 931 @property 932 def bar(self) -> int: 933 return 5 934 self.assertEqual(list(C.__annotations__), ['i']) 935 self.assertEqual(C(10).foo(), 4) 936 self.assertEqual(C(10).bar, 5) 937 self.assertEqual(C(10).i, 10) 938 939 def test_post_init(self): 940 # Just make sure it gets called 941 @dataclass 942 class C: 943 def __post_init__(self): 944 raise CustomError() 945 with self.assertRaises(CustomError): 946 C() 947 948 @dataclass 949 class C: 950 i: int = 10 951 def __post_init__(self): 952 if self.i == 10: 953 raise CustomError() 954 with self.assertRaises(CustomError): 955 C() 956 # post-init gets called, but doesn't raise. 
This is just 957 # checking that self is used correctly. 958 C(5) 959 960 # If there's not an __init__, then post-init won't get called. 961 @dataclass(init=False) 962 class C: 963 def __post_init__(self): 964 raise CustomError() 965 # Creating the class won't raise 966 C() 967 968 @dataclass 969 class C: 970 x: int = 0 971 def __post_init__(self): 972 self.x *= 2 973 self.assertEqual(C().x, 0) 974 self.assertEqual(C(2).x, 4) 975 976 # Make sure that if we're frozen, post-init can't set 977 # attributes. 978 @dataclass(frozen=True) 979 class C: 980 x: int = 0 981 def __post_init__(self): 982 self.x *= 2 983 with self.assertRaises(FrozenInstanceError): 984 C() 985 986 def test_post_init_super(self): 987 # Make sure super() post-init isn't called by default. 988 class B: 989 def __post_init__(self): 990 raise CustomError() 991 992 @dataclass 993 class C(B): 994 def __post_init__(self): 995 self.x = 5 996 997 self.assertEqual(C().x, 5) 998 999 # Now call super(), and it will raise. 1000 @dataclass 1001 class C(B): 1002 def __post_init__(self): 1003 super().__post_init__() 1004 1005 with self.assertRaises(CustomError): 1006 C() 1007 1008 # Make sure post-init is called, even if not defined in our 1009 # class. 
1010 @dataclass 1011 class C(B): 1012 pass 1013 1014 with self.assertRaises(CustomError): 1015 C() 1016 1017 def test_post_init_staticmethod(self): 1018 flag = False 1019 @dataclass 1020 class C: 1021 x: int 1022 y: int 1023 @staticmethod 1024 def __post_init__(): 1025 nonlocal flag 1026 flag = True 1027 1028 self.assertFalse(flag) 1029 c = C(3, 4) 1030 self.assertEqual((c.x, c.y), (3, 4)) 1031 self.assertTrue(flag) 1032 1033 def test_post_init_classmethod(self): 1034 @dataclass 1035 class C: 1036 flag = False 1037 x: int 1038 y: int 1039 @classmethod 1040 def __post_init__(cls): 1041 cls.flag = True 1042 1043 self.assertFalse(C.flag) 1044 c = C(3, 4) 1045 self.assertEqual((c.x, c.y), (3, 4)) 1046 self.assertTrue(C.flag) 1047 1048 def test_post_init_not_auto_added(self): 1049 # See bpo-46757, which had proposed always adding __post_init__. As 1050 # Raymond Hettinger pointed out, that would be a breaking change. So, 1051 # add a test to make sure that the current behavior doesn't change. 1052 1053 @dataclass 1054 class A0: 1055 pass 1056 1057 @dataclass 1058 class B0: 1059 b_called: bool = False 1060 def __post_init__(self): 1061 self.b_called = True 1062 1063 @dataclass 1064 class C0(A0, B0): 1065 c_called: bool = False 1066 def __post_init__(self): 1067 super().__post_init__() 1068 self.c_called = True 1069 1070 # Since A0 has no __post_init__, and one wasn't automatically added 1071 # (because that's the rule: it's never added by @dataclass, it's only 1072 # the class author that can add it), then B0.__post_init__ is called. 1073 # Verify that. 1074 c = C0() 1075 self.assertTrue(c.b_called) 1076 self.assertTrue(c.c_called) 1077 1078 ###################################### 1079 # Now, the same thing, except A1 defines __post_init__. 
1080 @dataclass 1081 class A1: 1082 def __post_init__(self): 1083 pass 1084 1085 @dataclass 1086 class B1: 1087 b_called: bool = False 1088 def __post_init__(self): 1089 self.b_called = True 1090 1091 @dataclass 1092 class C1(A1, B1): 1093 c_called: bool = False 1094 def __post_init__(self): 1095 super().__post_init__() 1096 self.c_called = True 1097 1098 # This time, B1.__post_init__ isn't being called. This mimics what 1099 # would happen if A1.__post_init__ had been automatically added, 1100 # instead of manually added as we see here. This test isn't really 1101 # needed, but I'm including it just to demonstrate the changed 1102 # behavior when A1 does define __post_init__. 1103 c = C1() 1104 self.assertFalse(c.b_called) 1105 self.assertTrue(c.c_called) 1106 1107 def test_class_var(self): 1108 # Make sure ClassVars are ignored in __init__, __repr__, etc. 1109 @dataclass 1110 class C: 1111 x: int 1112 y: int = 10 1113 z: ClassVar[int] = 1000 1114 w: ClassVar[int] = 2000 1115 t: ClassVar[int] = 3000 1116 s: ClassVar = 4000 1117 1118 c = C(5) 1119 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)') 1120 self.assertEqual(len(fields(C)), 2) # We have 2 fields. 1121 self.assertEqual(len(C.__annotations__), 6) # And 4 ClassVars. 1122 self.assertEqual(c.z, 1000) 1123 self.assertEqual(c.w, 2000) 1124 self.assertEqual(c.t, 3000) 1125 self.assertEqual(c.s, 4000) 1126 C.z += 1 1127 self.assertEqual(c.z, 1001) 1128 c = C(20) 1129 self.assertEqual((c.x, c.y), (20, 10)) 1130 self.assertEqual(c.z, 1001) 1131 self.assertEqual(c.w, 2000) 1132 self.assertEqual(c.t, 3000) 1133 self.assertEqual(c.s, 4000) 1134 1135 def test_class_var_no_default(self): 1136 # If a ClassVar has no default value, it should not be set on the class. 1137 @dataclass 1138 class C: 1139 x: ClassVar[int] 1140 1141 self.assertNotIn('x', C.__dict__) 1142 1143 def test_class_var_default_factory(self): 1144 # It makes no sense for a ClassVar to have a default factory. 
When 1145 # would it be called? Call it yourself, since it's class-wide. 1146 with self.assertRaisesRegex(TypeError, 1147 'cannot have a default factory'): 1148 @dataclass 1149 class C: 1150 x: ClassVar[int] = field(default_factory=int) 1151 1152 self.assertNotIn('x', C.__dict__) 1153 1154 def test_class_var_with_default(self): 1155 # If a ClassVar has a default value, it should be set on the class. 1156 @dataclass 1157 class C: 1158 x: ClassVar[int] = 10 1159 self.assertEqual(C.x, 10) 1160 1161 @dataclass 1162 class C: 1163 x: ClassVar[int] = field(default=10) 1164 self.assertEqual(C.x, 10) 1165 1166 def test_class_var_frozen(self): 1167 # Make sure ClassVars work even if we're frozen. 1168 @dataclass(frozen=True) 1169 class C: 1170 x: int 1171 y: int = 10 1172 z: ClassVar[int] = 1000 1173 w: ClassVar[int] = 2000 1174 t: ClassVar[int] = 3000 1175 1176 c = C(5) 1177 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)') 1178 self.assertEqual(len(fields(C)), 2) # We have 2 fields 1179 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars 1180 self.assertEqual(c.z, 1000) 1181 self.assertEqual(c.w, 2000) 1182 self.assertEqual(c.t, 3000) 1183 # We can still modify the ClassVar, it's only instances that are 1184 # frozen. 1185 C.z += 1 1186 self.assertEqual(c.z, 1001) 1187 c = C(20) 1188 self.assertEqual((c.x, c.y), (20, 10)) 1189 self.assertEqual(c.z, 1001) 1190 self.assertEqual(c.w, 2000) 1191 self.assertEqual(c.t, 3000) 1192 1193 def test_init_var_no_default(self): 1194 # If an InitVar has no default value, it should not be set on the class. 1195 @dataclass 1196 class C: 1197 x: InitVar[int] 1198 1199 self.assertNotIn('x', C.__dict__) 1200 1201 def test_init_var_default_factory(self): 1202 # It makes no sense for an InitVar to have a default factory. When 1203 # would it be called? Call it yourself, since it's class-wide. 
1204 with self.assertRaisesRegex(TypeError, 1205 'cannot have a default factory'): 1206 @dataclass 1207 class C: 1208 x: InitVar[int] = field(default_factory=int) 1209 1210 self.assertNotIn('x', C.__dict__) 1211 1212 def test_init_var_with_default(self): 1213 # If an InitVar has a default value, it should be set on the class. 1214 @dataclass 1215 class C: 1216 x: InitVar[int] = 10 1217 self.assertEqual(C.x, 10) 1218 1219 @dataclass 1220 class C: 1221 x: InitVar[int] = field(default=10) 1222 self.assertEqual(C.x, 10) 1223 1224 def test_init_var(self): 1225 @dataclass 1226 class C: 1227 x: int = None 1228 init_param: InitVar[int] = None 1229 1230 def __post_init__(self, init_param): 1231 if self.x is None: 1232 self.x = init_param*2 1233 1234 c = C(init_param=10) 1235 self.assertEqual(c.x, 20) 1236 1237 def test_init_var_preserve_type(self): 1238 self.assertEqual(InitVar[int].type, int) 1239 1240 # Make sure the repr is correct. 1241 self.assertEqual(repr(InitVar[int]), 'dataclasses.InitVar[int]') 1242 self.assertEqual(repr(InitVar[List[int]]), 1243 'dataclasses.InitVar[typing.List[int]]') 1244 self.assertEqual(repr(InitVar[list[int]]), 1245 'dataclasses.InitVar[list[int]]') 1246 self.assertEqual(repr(InitVar[int|str]), 1247 'dataclasses.InitVar[int | str]') 1248 1249 def test_init_var_inheritance(self): 1250 # Note that this deliberately tests that a dataclass need not 1251 # have a __post_init__ function if it has an InitVar field. 1252 # It could just be used in a derived class, as shown here. 1253 @dataclass 1254 class Base: 1255 x: int 1256 init_base: InitVar[int] 1257 1258 # We can instantiate by passing the InitVar, even though 1259 # it's not used. 
1260 b = Base(0, 10) 1261 self.assertEqual(vars(b), {'x': 0}) 1262 1263 @dataclass 1264 class C(Base): 1265 y: int 1266 init_derived: InitVar[int] 1267 1268 def __post_init__(self, init_base, init_derived): 1269 self.x = self.x + init_base 1270 self.y = self.y + init_derived 1271 1272 c = C(10, 11, 50, 51) 1273 self.assertEqual(vars(c), {'x': 21, 'y': 101}) 1274 1275 def test_default_factory(self): 1276 # Test a factory that returns a new list. 1277 @dataclass 1278 class C: 1279 x: int 1280 y: list = field(default_factory=list) 1281 1282 c0 = C(3) 1283 c1 = C(3) 1284 self.assertEqual(c0.x, 3) 1285 self.assertEqual(c0.y, []) 1286 self.assertEqual(c0, c1) 1287 self.assertIsNot(c0.y, c1.y) 1288 self.assertEqual(astuple(C(5, [1])), (5, [1])) 1289 1290 # Test a factory that returns a shared list. 1291 l = [] 1292 @dataclass 1293 class C: 1294 x: int 1295 y: list = field(default_factory=lambda: l) 1296 1297 c0 = C(3) 1298 c1 = C(3) 1299 self.assertEqual(c0.x, 3) 1300 self.assertEqual(c0.y, []) 1301 self.assertEqual(c0, c1) 1302 self.assertIs(c0.y, c1.y) 1303 self.assertEqual(astuple(C(5, [1])), (5, [1])) 1304 1305 # Test various other field flags. 
1306 # repr 1307 @dataclass 1308 class C: 1309 x: list = field(default_factory=list, repr=False) 1310 self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()') 1311 self.assertEqual(C().x, []) 1312 1313 # hash 1314 @dataclass(unsafe_hash=True) 1315 class C: 1316 x: list = field(default_factory=list, hash=False) 1317 self.assertEqual(astuple(C()), ([],)) 1318 self.assertEqual(hash(C()), hash(())) 1319 1320 # init (see also test_default_factory_with_no_init) 1321 @dataclass 1322 class C: 1323 x: list = field(default_factory=list, init=False) 1324 self.assertEqual(astuple(C()), ([],)) 1325 1326 # compare 1327 @dataclass 1328 class C: 1329 x: list = field(default_factory=list, compare=False) 1330 self.assertEqual(C(), C([1])) 1331 1332 def test_default_factory_with_no_init(self): 1333 # We need a factory with a side effect. 1334 factory = Mock() 1335 1336 @dataclass 1337 class C: 1338 x: list = field(default_factory=factory, init=False) 1339 1340 # Make sure the default factory is called for each new instance. 1341 C().x 1342 self.assertEqual(factory.call_count, 1) 1343 C().x 1344 self.assertEqual(factory.call_count, 2) 1345 1346 def test_default_factory_not_called_if_value_given(self): 1347 # We need a factory that we can test if it's been called. 1348 factory = Mock() 1349 1350 @dataclass 1351 class C: 1352 x: int = field(default_factory=factory) 1353 1354 # Make sure that if a field has a default factory function, 1355 # it's not called if a value is specified. 1356 C().x 1357 self.assertEqual(factory.call_count, 1) 1358 self.assertEqual(C(10).x, 10) 1359 self.assertEqual(factory.call_count, 1) 1360 C().x 1361 self.assertEqual(factory.call_count, 2) 1362 1363 def test_default_factory_derived(self): 1364 # See bpo-32896. 
1365 @dataclass 1366 class Foo: 1367 x: dict = field(default_factory=dict) 1368 1369 @dataclass 1370 class Bar(Foo): 1371 y: int = 1 1372 1373 self.assertEqual(Foo().x, {}) 1374 self.assertEqual(Bar().x, {}) 1375 self.assertEqual(Bar().y, 1) 1376 1377 @dataclass 1378 class Baz(Foo): 1379 pass 1380 self.assertEqual(Baz().x, {}) 1381 1382 def test_intermediate_non_dataclass(self): 1383 # Test that an intermediate class that defines 1384 # annotations does not define fields. 1385 1386 @dataclass 1387 class A: 1388 x: int 1389 1390 class B(A): 1391 y: int 1392 1393 @dataclass 1394 class C(B): 1395 z: int 1396 1397 c = C(1, 3) 1398 self.assertEqual((c.x, c.z), (1, 3)) 1399 1400 # .y was not initialized. 1401 with self.assertRaisesRegex(AttributeError, 1402 'object has no attribute'): 1403 c.y 1404 1405 # And if we again derive a non-dataclass, no fields are added. 1406 class D(C): 1407 t: int 1408 d = D(4, 5) 1409 self.assertEqual((d.x, d.z), (4, 5)) 1410 1411 def test_classvar_default_factory(self): 1412 # It's an error for a ClassVar to have a factory function. 1413 with self.assertRaisesRegex(TypeError, 1414 'cannot have a default factory'): 1415 @dataclass 1416 class C: 1417 x: ClassVar[int] = field(default_factory=int) 1418 1419 def test_is_dataclass(self): 1420 class NotDataClass: 1421 pass 1422 1423 self.assertFalse(is_dataclass(0)) 1424 self.assertFalse(is_dataclass(int)) 1425 self.assertFalse(is_dataclass(NotDataClass)) 1426 self.assertFalse(is_dataclass(NotDataClass())) 1427 1428 @dataclass 1429 class C: 1430 x: int 1431 1432 @dataclass 1433 class D: 1434 d: C 1435 e: int 1436 1437 c = C(10) 1438 d = D(c, 4) 1439 1440 self.assertTrue(is_dataclass(C)) 1441 self.assertTrue(is_dataclass(c)) 1442 self.assertFalse(is_dataclass(c.x)) 1443 self.assertTrue(is_dataclass(d.d)) 1444 self.assertFalse(is_dataclass(d.e)) 1445 1446 def test_is_dataclass_when_getattr_always_returns(self): 1447 # See bpo-37868. 
1448 class A: 1449 def __getattr__(self, key): 1450 return 0 1451 self.assertFalse(is_dataclass(A)) 1452 a = A() 1453 1454 # Also test for an instance attribute. 1455 class B: 1456 pass 1457 b = B() 1458 b.__dataclass_fields__ = [] 1459 1460 for obj in a, b: 1461 with self.subTest(obj=obj): 1462 self.assertFalse(is_dataclass(obj)) 1463 1464 # Indirect tests for _is_dataclass_instance(). 1465 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'): 1466 asdict(obj) 1467 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'): 1468 astuple(obj) 1469 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'): 1470 replace(obj, x=0) 1471 1472 def test_is_dataclass_genericalias(self): 1473 @dataclass 1474 class A(types.GenericAlias): 1475 origin: type 1476 args: type 1477 self.assertTrue(is_dataclass(A)) 1478 a = A(list, int) 1479 self.assertTrue(is_dataclass(type(a))) 1480 self.assertTrue(is_dataclass(a)) 1481 1482 1483 def test_helper_fields_with_class_instance(self): 1484 # Check that we can call fields() on either a class or instance, 1485 # and get back the same thing. 1486 @dataclass 1487 class C: 1488 x: int 1489 y: float 1490 1491 self.assertEqual(fields(C), fields(C(0, 0.0))) 1492 1493 def test_helper_fields_exception(self): 1494 # Check that TypeError is raised if not passed a dataclass or 1495 # instance. 
1496 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'): 1497 fields(0) 1498 1499 class C: pass 1500 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'): 1501 fields(C) 1502 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'): 1503 fields(C()) 1504 1505 def test_clean_traceback_from_fields_exception(self): 1506 stdout = io.StringIO() 1507 try: 1508 fields(object) 1509 except TypeError as exc: 1510 traceback.print_exception(exc, file=stdout) 1511 printed_traceback = stdout.getvalue() 1512 self.assertNotIn("AttributeError", printed_traceback) 1513 self.assertNotIn("__dataclass_fields__", printed_traceback) 1514 1515 def test_helper_asdict(self): 1516 # Basic tests for asdict(), it should return a new dictionary. 1517 @dataclass 1518 class C: 1519 x: int 1520 y: int 1521 c = C(1, 2) 1522 1523 self.assertEqual(asdict(c), {'x': 1, 'y': 2}) 1524 self.assertEqual(asdict(c), asdict(c)) 1525 self.assertIsNot(asdict(c), asdict(c)) 1526 c.x = 42 1527 self.assertEqual(asdict(c), {'x': 42, 'y': 2}) 1528 self.assertIs(type(asdict(c)), dict) 1529 1530 def test_helper_asdict_raises_on_classes(self): 1531 # asdict() should raise on a class object. 
1532 @dataclass 1533 class C: 1534 x: int 1535 y: int 1536 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 1537 asdict(C) 1538 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 1539 asdict(int) 1540 1541 def test_helper_asdict_copy_values(self): 1542 @dataclass 1543 class C: 1544 x: int 1545 y: List[int] = field(default_factory=list) 1546 initial = [] 1547 c = C(1, initial) 1548 d = asdict(c) 1549 self.assertEqual(d['y'], initial) 1550 self.assertIsNot(d['y'], initial) 1551 c = C(1) 1552 d = asdict(c) 1553 d['y'].append(1) 1554 self.assertEqual(c.y, []) 1555 1556 def test_helper_asdict_nested(self): 1557 @dataclass 1558 class UserId: 1559 token: int 1560 group: int 1561 @dataclass 1562 class User: 1563 name: str 1564 id: UserId 1565 u = User('Joe', UserId(123, 1)) 1566 d = asdict(u) 1567 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}}) 1568 self.assertIsNot(asdict(u), asdict(u)) 1569 u.id.group = 2 1570 self.assertEqual(asdict(u), {'name': 'Joe', 1571 'id': {'token': 123, 'group': 2}}) 1572 1573 def test_helper_asdict_builtin_containers(self): 1574 @dataclass 1575 class User: 1576 name: str 1577 id: int 1578 @dataclass 1579 class GroupList: 1580 id: int 1581 users: List[User] 1582 @dataclass 1583 class GroupTuple: 1584 id: int 1585 users: Tuple[User, ...] 
1586 @dataclass 1587 class GroupDict: 1588 id: int 1589 users: Dict[str, User] 1590 a = User('Alice', 1) 1591 b = User('Bob', 2) 1592 gl = GroupList(0, [a, b]) 1593 gt = GroupTuple(0, (a, b)) 1594 gd = GroupDict(0, {'first': a, 'second': b}) 1595 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1}, 1596 {'name': 'Bob', 'id': 2}]}) 1597 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1}, 1598 {'name': 'Bob', 'id': 2})}) 1599 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1}, 1600 'second': {'name': 'Bob', 'id': 2}}}) 1601 1602 def test_helper_asdict_builtin_object_containers(self): 1603 @dataclass 1604 class Child: 1605 d: object 1606 1607 @dataclass 1608 class Parent: 1609 child: Child 1610 1611 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}}) 1612 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}}) 1613 1614 def test_helper_asdict_factory(self): 1615 @dataclass 1616 class C: 1617 x: int 1618 y: int 1619 c = C(1, 2) 1620 d = asdict(c, dict_factory=OrderedDict) 1621 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)])) 1622 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict)) 1623 c.x = 42 1624 d = asdict(c, dict_factory=OrderedDict) 1625 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)])) 1626 self.assertIs(type(d), OrderedDict) 1627 1628 def test_helper_asdict_namedtuple(self): 1629 T = namedtuple('T', 'a b c') 1630 @dataclass 1631 class C: 1632 x: str 1633 y: T 1634 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2)) 1635 1636 d = asdict(c) 1637 self.assertEqual(d, {'x': 'outer', 1638 'y': T(1, 1639 {'x': 'inner', 1640 'y': T(11, 12, 13)}, 1641 2), 1642 } 1643 ) 1644 1645 # Now with a dict_factory. OrderedDict is convenient, but 1646 # since it compares to dicts, we also need to have separate 1647 # assertIs tests. 
1648 d = asdict(c, dict_factory=OrderedDict) 1649 self.assertEqual(d, {'x': 'outer', 1650 'y': T(1, 1651 {'x': 'inner', 1652 'y': T(11, 12, 13)}, 1653 2), 1654 } 1655 ) 1656 1657 # Make sure that the returned dicts are actually OrderedDicts. 1658 self.assertIs(type(d), OrderedDict) 1659 self.assertIs(type(d['y'][1]), OrderedDict) 1660 1661 def test_helper_asdict_namedtuple_key(self): 1662 # Ensure that a field that contains a dict which has a 1663 # namedtuple as a key works with asdict(). 1664 1665 @dataclass 1666 class C: 1667 f: dict 1668 T = namedtuple('T', 'a') 1669 1670 c = C({T('an a'): 0}) 1671 1672 self.assertEqual(asdict(c), {'f': {T(a='an a'): 0}}) 1673 1674 def test_helper_asdict_namedtuple_derived(self): 1675 class T(namedtuple('Tbase', 'a')): 1676 def my_a(self): 1677 return self.a 1678 1679 @dataclass 1680 class C: 1681 f: T 1682 1683 t = T(6) 1684 c = C(t) 1685 1686 d = asdict(c) 1687 self.assertEqual(d, {'f': T(a=6)}) 1688 # Make sure that t has been copied, not used directly. 1689 self.assertIsNot(d['f'], t) 1690 self.assertEqual(d['f'].my_a(), 6) 1691 1692 def test_helper_astuple(self): 1693 # Basic tests for astuple(), it should return a new tuple. 1694 @dataclass 1695 class C: 1696 x: int 1697 y: int = 0 1698 c = C(1) 1699 1700 self.assertEqual(astuple(c), (1, 0)) 1701 self.assertEqual(astuple(c), astuple(c)) 1702 self.assertIsNot(astuple(c), astuple(c)) 1703 c.y = 42 1704 self.assertEqual(astuple(c), (1, 42)) 1705 self.assertIs(type(astuple(c)), tuple) 1706 1707 def test_helper_astuple_raises_on_classes(self): 1708 # astuple() should raise on a class object. 
1709 @dataclass 1710 class C: 1711 x: int 1712 y: int 1713 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 1714 astuple(C) 1715 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 1716 astuple(int) 1717 1718 def test_helper_astuple_copy_values(self): 1719 @dataclass 1720 class C: 1721 x: int 1722 y: List[int] = field(default_factory=list) 1723 initial = [] 1724 c = C(1, initial) 1725 t = astuple(c) 1726 self.assertEqual(t[1], initial) 1727 self.assertIsNot(t[1], initial) 1728 c = C(1) 1729 t = astuple(c) 1730 t[1].append(1) 1731 self.assertEqual(c.y, []) 1732 1733 def test_helper_astuple_nested(self): 1734 @dataclass 1735 class UserId: 1736 token: int 1737 group: int 1738 @dataclass 1739 class User: 1740 name: str 1741 id: UserId 1742 u = User('Joe', UserId(123, 1)) 1743 t = astuple(u) 1744 self.assertEqual(t, ('Joe', (123, 1))) 1745 self.assertIsNot(astuple(u), astuple(u)) 1746 u.id.group = 2 1747 self.assertEqual(astuple(u), ('Joe', (123, 2))) 1748 1749 def test_helper_astuple_builtin_containers(self): 1750 @dataclass 1751 class User: 1752 name: str 1753 id: int 1754 @dataclass 1755 class GroupList: 1756 id: int 1757 users: List[User] 1758 @dataclass 1759 class GroupTuple: 1760 id: int 1761 users: Tuple[User, ...] 
1762 @dataclass 1763 class GroupDict: 1764 id: int 1765 users: Dict[str, User] 1766 a = User('Alice', 1) 1767 b = User('Bob', 2) 1768 gl = GroupList(0, [a, b]) 1769 gt = GroupTuple(0, (a, b)) 1770 gd = GroupDict(0, {'first': a, 'second': b}) 1771 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)])) 1772 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2)))) 1773 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)})) 1774 1775 def test_helper_astuple_builtin_object_containers(self): 1776 @dataclass 1777 class Child: 1778 d: object 1779 1780 @dataclass 1781 class Parent: 1782 child: Child 1783 1784 self.assertEqual(astuple(Parent(Child([1]))), (([1],),)) 1785 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),)) 1786 1787 def test_helper_astuple_factory(self): 1788 @dataclass 1789 class C: 1790 x: int 1791 y: int 1792 NT = namedtuple('NT', 'x y') 1793 def nt(lst): 1794 return NT(*lst) 1795 c = C(1, 2) 1796 t = astuple(c, tuple_factory=nt) 1797 self.assertEqual(t, NT(1, 2)) 1798 self.assertIsNot(t, astuple(c, tuple_factory=nt)) 1799 c.x = 42 1800 t = astuple(c, tuple_factory=nt) 1801 self.assertEqual(t, NT(42, 2)) 1802 self.assertIs(type(t), NT) 1803 1804 def test_helper_astuple_namedtuple(self): 1805 T = namedtuple('T', 'a b c') 1806 @dataclass 1807 class C: 1808 x: str 1809 y: T 1810 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2)) 1811 1812 t = astuple(c) 1813 self.assertEqual(t, ('outer', T(1, ('inner', (11, 12, 13)), 2))) 1814 1815 # Now, using a tuple_factory. list is convenient here. 1816 t = astuple(c, tuple_factory=list) 1817 self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)]) 1818 1819 def test_dynamic_class_creation(self): 1820 cls_dict = {'__annotations__': {'x': int, 'y': int}, 1821 } 1822 1823 # Create the class. 1824 cls = type('C', (), cls_dict) 1825 1826 # Make it a dataclass. 
1827 cls1 = dataclass(cls) 1828 1829 self.assertEqual(cls1, cls) 1830 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2}) 1831 1832 def test_dynamic_class_creation_using_field(self): 1833 cls_dict = {'__annotations__': {'x': int, 'y': int}, 1834 'y': field(default=5), 1835 } 1836 1837 # Create the class. 1838 cls = type('C', (), cls_dict) 1839 1840 # Make it a dataclass. 1841 cls1 = dataclass(cls) 1842 1843 self.assertEqual(cls1, cls) 1844 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5}) 1845 1846 def test_init_in_order(self): 1847 @dataclass 1848 class C: 1849 a: int 1850 b: int = field() 1851 c: list = field(default_factory=list, init=False) 1852 d: list = field(default_factory=list) 1853 e: int = field(default=4, init=False) 1854 f: int = 4 1855 1856 calls = [] 1857 def setattr(self, name, value): 1858 calls.append((name, value)) 1859 1860 C.__setattr__ = setattr 1861 c = C(0, 1) 1862 self.assertEqual(('a', 0), calls[0]) 1863 self.assertEqual(('b', 1), calls[1]) 1864 self.assertEqual(('c', []), calls[2]) 1865 self.assertEqual(('d', []), calls[3]) 1866 self.assertNotIn(('e', 4), calls) 1867 self.assertEqual(('f', 4), calls[4]) 1868 1869 def test_items_in_dicts(self): 1870 @dataclass 1871 class C: 1872 a: int 1873 b: list = field(default_factory=list, init=False) 1874 c: list = field(default_factory=list) 1875 d: int = field(default=4, init=False) 1876 e: int = 0 1877 1878 c = C(0) 1879 # Class dict 1880 self.assertNotIn('a', C.__dict__) 1881 self.assertNotIn('b', C.__dict__) 1882 self.assertNotIn('c', C.__dict__) 1883 self.assertIn('d', C.__dict__) 1884 self.assertEqual(C.d, 4) 1885 self.assertIn('e', C.__dict__) 1886 self.assertEqual(C.e, 0) 1887 # Instance dict 1888 self.assertIn('a', c.__dict__) 1889 self.assertEqual(c.a, 0) 1890 self.assertIn('b', c.__dict__) 1891 self.assertEqual(c.b, []) 1892 self.assertIn('c', c.__dict__) 1893 self.assertEqual(c.c, []) 1894 self.assertNotIn('d', c.__dict__) 1895 self.assertIn('e', c.__dict__) 1896 
self.assertEqual(c.e, 0) 1897 1898 def test_alternate_classmethod_constructor(self): 1899 # Since __post_init__ can't take params, use a classmethod 1900 # alternate constructor. This is mostly an example to show 1901 # how to use this technique. 1902 @dataclass 1903 class C: 1904 x: int 1905 @classmethod 1906 def from_file(cls, filename): 1907 # In a real example, create a new instance 1908 # and populate 'x' from contents of a file. 1909 value_in_file = 20 1910 return cls(value_in_file) 1911 1912 self.assertEqual(C.from_file('filename').x, 20) 1913 1914 def test_field_metadata_default(self): 1915 # Make sure the default metadata is read-only and of 1916 # zero length. 1917 @dataclass 1918 class C: 1919 i: int 1920 1921 self.assertFalse(fields(C)[0].metadata) 1922 self.assertEqual(len(fields(C)[0].metadata), 0) 1923 with self.assertRaisesRegex(TypeError, 1924 'does not support item assignment'): 1925 fields(C)[0].metadata['test'] = 3 1926 1927 def test_field_metadata_mapping(self): 1928 # Make sure only a mapping can be passed as metadata 1929 # zero length. 1930 with self.assertRaises(TypeError): 1931 @dataclass 1932 class C: 1933 i: int = field(metadata=0) 1934 1935 # Make sure an empty dict works. 1936 d = {} 1937 @dataclass 1938 class C: 1939 i: int = field(metadata=d) 1940 self.assertFalse(fields(C)[0].metadata) 1941 self.assertEqual(len(fields(C)[0].metadata), 0) 1942 # Update should work (see bpo-35960). 1943 d['foo'] = 1 1944 self.assertEqual(len(fields(C)[0].metadata), 1) 1945 self.assertEqual(fields(C)[0].metadata['foo'], 1) 1946 with self.assertRaisesRegex(TypeError, 1947 'does not support item assignment'): 1948 fields(C)[0].metadata['test'] = 3 1949 1950 # Make sure a non-empty dict works. 
1951 d = {'test': 10, 'bar': '42', 3: 'three'} 1952 @dataclass 1953 class C: 1954 i: int = field(metadata=d) 1955 self.assertEqual(len(fields(C)[0].metadata), 3) 1956 self.assertEqual(fields(C)[0].metadata['test'], 10) 1957 self.assertEqual(fields(C)[0].metadata['bar'], '42') 1958 self.assertEqual(fields(C)[0].metadata[3], 'three') 1959 # Update should work. 1960 d['foo'] = 1 1961 self.assertEqual(len(fields(C)[0].metadata), 4) 1962 self.assertEqual(fields(C)[0].metadata['foo'], 1) 1963 with self.assertRaises(KeyError): 1964 # Non-existent key. 1965 fields(C)[0].metadata['baz'] 1966 with self.assertRaisesRegex(TypeError, 1967 'does not support item assignment'): 1968 fields(C)[0].metadata['test'] = 3 1969 1970 def test_field_metadata_custom_mapping(self): 1971 # Try a custom mapping. 1972 class SimpleNameSpace: 1973 def __init__(self, **kw): 1974 self.__dict__.update(kw) 1975 1976 def __getitem__(self, item): 1977 if item == 'xyzzy': 1978 return 'plugh' 1979 return getattr(self, item) 1980 1981 def __len__(self): 1982 return self.__dict__.__len__() 1983 1984 @dataclass 1985 class C: 1986 i: int = field(metadata=SimpleNameSpace(a=10)) 1987 1988 self.assertEqual(len(fields(C)[0].metadata), 1) 1989 self.assertEqual(fields(C)[0].metadata['a'], 10) 1990 with self.assertRaises(AttributeError): 1991 fields(C)[0].metadata['b'] 1992 # Make sure we're still talking to our custom mapping. 1993 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh') 1994 1995 def test_generic_dataclasses(self): 1996 T = TypeVar('T') 1997 1998 @dataclass 1999 class LabeledBox(Generic[T]): 2000 content: T 2001 label: str = '<unknown>' 2002 2003 box = LabeledBox(42) 2004 self.assertEqual(box.content, 42) 2005 self.assertEqual(box.label, '<unknown>') 2006 2007 # Subscripting the resulting class should work, etc. 
2008 Alias = List[LabeledBox[int]] 2009 2010 def test_generic_extending(self): 2011 S = TypeVar('S') 2012 T = TypeVar('T') 2013 2014 @dataclass 2015 class Base(Generic[T, S]): 2016 x: T 2017 y: S 2018 2019 @dataclass 2020 class DataDerived(Base[int, T]): 2021 new_field: str 2022 Alias = DataDerived[str] 2023 c = Alias(0, 'test1', 'test2') 2024 self.assertEqual(astuple(c), (0, 'test1', 'test2')) 2025 2026 class NonDataDerived(Base[int, T]): 2027 def new_method(self): 2028 return self.y 2029 Alias = NonDataDerived[float] 2030 c = Alias(10, 1.0) 2031 self.assertEqual(c.new_method(), 1.0) 2032 2033 def test_generic_dynamic(self): 2034 T = TypeVar('T') 2035 2036 @dataclass 2037 class Parent(Generic[T]): 2038 x: T 2039 Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)], 2040 bases=(Parent[int], Generic[T]), namespace={'other': 42}) 2041 self.assertIs(Child[int](1, 2).z, None) 2042 self.assertEqual(Child[int](1, 2, 3).z, 3) 2043 self.assertEqual(Child[int](1, 2, 3).other, 42) 2044 # Check that type aliases work correctly. 2045 Alias = Child[T] 2046 self.assertEqual(Alias[int](1, 2).x, 1) 2047 # Check MRO resolution. 
2048 self.assertEqual(Child.__mro__, (Child, Parent, Generic, object)) 2049 2050 def test_dataclasses_pickleable(self): 2051 global P, Q, R 2052 @dataclass 2053 class P: 2054 x: int 2055 y: int = 0 2056 @dataclass 2057 class Q: 2058 x: int 2059 y: int = field(default=0, init=False) 2060 @dataclass 2061 class R: 2062 x: int 2063 y: List[int] = field(default_factory=list) 2064 q = Q(1) 2065 q.y = 2 2066 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])] 2067 for sample in samples: 2068 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 2069 with self.subTest(sample=sample, proto=proto): 2070 new_sample = pickle.loads(pickle.dumps(sample, proto)) 2071 self.assertEqual(sample.x, new_sample.x) 2072 self.assertEqual(sample.y, new_sample.y) 2073 self.assertIsNot(sample, new_sample) 2074 new_sample.x = 42 2075 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto)) 2076 self.assertEqual(new_sample.x, another_new_sample.x) 2077 self.assertEqual(sample.y, another_new_sample.y) 2078 2079 def test_dataclasses_qualnames(self): 2080 @dataclass(order=True, unsafe_hash=True, frozen=True) 2081 class A: 2082 x: int 2083 y: int 2084 2085 self.assertEqual(A.__init__.__name__, "__init__") 2086 for function in ( 2087 '__eq__', 2088 '__lt__', 2089 '__le__', 2090 '__gt__', 2091 '__ge__', 2092 '__hash__', 2093 '__init__', 2094 '__repr__', 2095 '__setattr__', 2096 '__delattr__', 2097 ): 2098 self.assertEqual(getattr(A, function).__qualname__, f"TestCase.test_dataclasses_qualnames.<locals>.A.{function}") 2099 2100 with self.assertRaisesRegex(TypeError, r"A\.__init__\(\) missing"): 2101 A() 2102 2103 2104class TestFieldNoAnnotation(unittest.TestCase): 2105 def test_field_without_annotation(self): 2106 with self.assertRaisesRegex(TypeError, 2107 "'f' is a field but has no type annotation"): 2108 @dataclass 2109 class C: 2110 f = field() 2111 2112 def test_field_without_annotation_but_annotation_in_base(self): 2113 @dataclass 2114 class B: 2115 f: int 2116 2117 with 
self.assertRaisesRegex(TypeError, 2118 "'f' is a field but has no type annotation"): 2119 # This is still an error: make sure we don't pick up the 2120 # type annotation in the base class. 2121 @dataclass 2122 class C(B): 2123 f = field() 2124 2125 def test_field_without_annotation_but_annotation_in_base_not_dataclass(self): 2126 # Same test, but with the base class not a dataclass. 2127 class B: 2128 f: int 2129 2130 with self.assertRaisesRegex(TypeError, 2131 "'f' is a field but has no type annotation"): 2132 # This is still an error: make sure we don't pick up the 2133 # type annotation in the base class. 2134 @dataclass 2135 class C(B): 2136 f = field() 2137 2138 2139class TestDocString(unittest.TestCase): 2140 def assertDocStrEqual(self, a, b): 2141 # Because 3.6 and 3.7 differ in how inspect.signature work 2142 # (see bpo #32108), for the time being just compare them with 2143 # whitespace stripped. 2144 self.assertEqual(a.replace(' ', ''), b.replace(' ', '')) 2145 2146 def test_existing_docstring_not_overridden(self): 2147 @dataclass 2148 class C: 2149 """Lorem ipsum""" 2150 x: int 2151 2152 self.assertEqual(C.__doc__, "Lorem ipsum") 2153 2154 def test_docstring_no_fields(self): 2155 @dataclass 2156 class C: 2157 pass 2158 2159 self.assertDocStrEqual(C.__doc__, "C()") 2160 2161 def test_docstring_one_field(self): 2162 @dataclass 2163 class C: 2164 x: int 2165 2166 self.assertDocStrEqual(C.__doc__, "C(x:int)") 2167 2168 def test_docstring_two_fields(self): 2169 @dataclass 2170 class C: 2171 x: int 2172 y: int 2173 2174 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)") 2175 2176 def test_docstring_three_fields(self): 2177 @dataclass 2178 class C: 2179 x: int 2180 y: int 2181 z: str 2182 2183 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)") 2184 2185 def test_docstring_one_field_with_default(self): 2186 @dataclass 2187 class C: 2188 x: int = 3 2189 2190 self.assertDocStrEqual(C.__doc__, "C(x:int=3)") 2191 2192 def 
test_docstring_one_field_with_default_none(self): 2193 @dataclass 2194 class C: 2195 x: Union[int, type(None)] = None 2196 2197 self.assertDocStrEqual(C.__doc__, "C(x:Optional[int]=None)") 2198 2199 def test_docstring_list_field(self): 2200 @dataclass 2201 class C: 2202 x: List[int] 2203 2204 self.assertDocStrEqual(C.__doc__, "C(x:List[int])") 2205 2206 def test_docstring_list_field_with_default_factory(self): 2207 @dataclass 2208 class C: 2209 x: List[int] = field(default_factory=list) 2210 2211 self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)") 2212 2213 def test_docstring_deque_field(self): 2214 @dataclass 2215 class C: 2216 x: deque 2217 2218 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)") 2219 2220 def test_docstring_deque_field_with_default_factory(self): 2221 @dataclass 2222 class C: 2223 x: deque = field(default_factory=deque) 2224 2225 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)") 2226 2227 def test_docstring_with_no_signature(self): 2228 # See https://github.com/python/cpython/issues/103449 2229 class Meta(type): 2230 __call__ = dict 2231 class Base(metaclass=Meta): 2232 pass 2233 2234 @dataclass 2235 class C(Base): 2236 pass 2237 2238 self.assertDocStrEqual(C.__doc__, "C") 2239 2240 2241class TestInit(unittest.TestCase): 2242 def test_base_has_init(self): 2243 class B: 2244 def __init__(self): 2245 self.z = 100 2246 pass 2247 2248 # Make sure that declaring this class doesn't raise an error. 2249 # The issue is that we can't override __init__ in our class, 2250 # but it should be okay to add __init__ to us if our base has 2251 # an __init__. 2252 @dataclass 2253 class C(B): 2254 x: int = 0 2255 c = C(10) 2256 self.assertEqual(c.x, 10) 2257 self.assertNotIn('z', vars(c)) 2258 2259 # Make sure that if we don't add an init, the base __init__ 2260 # gets called. 
2261 @dataclass(init=False) 2262 class C(B): 2263 x: int = 10 2264 c = C() 2265 self.assertEqual(c.x, 10) 2266 self.assertEqual(c.z, 100) 2267 2268 def test_no_init(self): 2269 @dataclass(init=False) 2270 class C: 2271 i: int = 0 2272 self.assertEqual(C().i, 0) 2273 2274 @dataclass(init=False) 2275 class C: 2276 i: int = 2 2277 def __init__(self): 2278 self.i = 3 2279 self.assertEqual(C().i, 3) 2280 2281 def test_overwriting_init(self): 2282 # If the class has __init__, use it no matter the value of 2283 # init=. 2284 2285 @dataclass 2286 class C: 2287 x: int 2288 def __init__(self, x): 2289 self.x = 2 * x 2290 self.assertEqual(C(3).x, 6) 2291 2292 @dataclass(init=True) 2293 class C: 2294 x: int 2295 def __init__(self, x): 2296 self.x = 2 * x 2297 self.assertEqual(C(4).x, 8) 2298 2299 @dataclass(init=False) 2300 class C: 2301 x: int 2302 def __init__(self, x): 2303 self.x = 2 * x 2304 self.assertEqual(C(5).x, 10) 2305 2306 def test_inherit_from_protocol(self): 2307 # Dataclasses inheriting from protocol should preserve their own `__init__`. 2308 # See bpo-45081. 
2309 2310 class P(Protocol): 2311 a: int 2312 2313 @dataclass 2314 class C(P): 2315 a: int 2316 2317 self.assertEqual(C(5).a, 5) 2318 2319 @dataclass 2320 class D(P): 2321 def __init__(self, a): 2322 self.a = a * 2 2323 2324 self.assertEqual(D(5).a, 10) 2325 2326 2327class TestRepr(unittest.TestCase): 2328 def test_repr(self): 2329 @dataclass 2330 class B: 2331 x: int 2332 2333 @dataclass 2334 class C(B): 2335 y: int = 10 2336 2337 o = C(4) 2338 self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)') 2339 2340 @dataclass 2341 class D(C): 2342 x: int = 20 2343 self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)') 2344 2345 @dataclass 2346 class C: 2347 @dataclass 2348 class D: 2349 i: int 2350 @dataclass 2351 class E: 2352 pass 2353 self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)') 2354 self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()') 2355 2356 def test_no_repr(self): 2357 # Test a class with no __repr__ and repr=False. 2358 @dataclass(repr=False) 2359 class C: 2360 x: int 2361 self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at', 2362 repr(C(3))) 2363 2364 # Test a class with a __repr__ and repr=False. 2365 @dataclass(repr=False) 2366 class C: 2367 x: int 2368 def __repr__(self): 2369 return 'C-class' 2370 self.assertEqual(repr(C(3)), 'C-class') 2371 2372 def test_overwriting_repr(self): 2373 # If the class has __repr__, use it no matter the value of 2374 # repr=. 
2375 2376 @dataclass 2377 class C: 2378 x: int 2379 def __repr__(self): 2380 return 'x' 2381 self.assertEqual(repr(C(0)), 'x') 2382 2383 @dataclass(repr=True) 2384 class C: 2385 x: int 2386 def __repr__(self): 2387 return 'x' 2388 self.assertEqual(repr(C(0)), 'x') 2389 2390 @dataclass(repr=False) 2391 class C: 2392 x: int 2393 def __repr__(self): 2394 return 'x' 2395 self.assertEqual(repr(C(0)), 'x') 2396 2397 2398class TestEq(unittest.TestCase): 2399 def test_no_eq(self): 2400 # Test a class with no __eq__ and eq=False. 2401 @dataclass(eq=False) 2402 class C: 2403 x: int 2404 self.assertNotEqual(C(0), C(0)) 2405 c = C(3) 2406 self.assertEqual(c, c) 2407 2408 # Test a class with an __eq__ and eq=False. 2409 @dataclass(eq=False) 2410 class C: 2411 x: int 2412 def __eq__(self, other): 2413 return other == 10 2414 self.assertEqual(C(3), 10) 2415 2416 def test_overwriting_eq(self): 2417 # If the class has __eq__, use it no matter the value of 2418 # eq=. 2419 2420 @dataclass 2421 class C: 2422 x: int 2423 def __eq__(self, other): 2424 return other == 3 2425 self.assertEqual(C(1), 3) 2426 self.assertNotEqual(C(1), 1) 2427 2428 @dataclass(eq=True) 2429 class C: 2430 x: int 2431 def __eq__(self, other): 2432 return other == 4 2433 self.assertEqual(C(1), 4) 2434 self.assertNotEqual(C(1), 1) 2435 2436 @dataclass(eq=False) 2437 class C: 2438 x: int 2439 def __eq__(self, other): 2440 return other == 5 2441 self.assertEqual(C(1), 5) 2442 self.assertNotEqual(C(1), 1) 2443 2444 2445class TestOrdering(unittest.TestCase): 2446 def test_functools_total_ordering(self): 2447 # Test that functools.total_ordering works with this class. 2448 @total_ordering 2449 @dataclass 2450 class C: 2451 x: int 2452 def __lt__(self, other): 2453 # Perform the test "backward", just to make 2454 # sure this is being called. 
2455 return self.x >= other 2456 2457 self.assertLess(C(0), -1) 2458 self.assertLessEqual(C(0), -1) 2459 self.assertGreater(C(0), 1) 2460 self.assertGreaterEqual(C(0), 1) 2461 2462 def test_no_order(self): 2463 # Test that no ordering functions are added by default. 2464 @dataclass(order=False) 2465 class C: 2466 x: int 2467 # Make sure no order methods are added. 2468 self.assertNotIn('__le__', C.__dict__) 2469 self.assertNotIn('__lt__', C.__dict__) 2470 self.assertNotIn('__ge__', C.__dict__) 2471 self.assertNotIn('__gt__', C.__dict__) 2472 2473 # Test that __lt__ is still called 2474 @dataclass(order=False) 2475 class C: 2476 x: int 2477 def __lt__(self, other): 2478 return False 2479 # Make sure other methods aren't added. 2480 self.assertNotIn('__le__', C.__dict__) 2481 self.assertNotIn('__ge__', C.__dict__) 2482 self.assertNotIn('__gt__', C.__dict__) 2483 2484 def test_overwriting_order(self): 2485 with self.assertRaisesRegex(TypeError, 2486 'Cannot overwrite attribute __lt__' 2487 '.*using functools.total_ordering'): 2488 @dataclass(order=True) 2489 class C: 2490 x: int 2491 def __lt__(self): 2492 pass 2493 2494 with self.assertRaisesRegex(TypeError, 2495 'Cannot overwrite attribute __le__' 2496 '.*using functools.total_ordering'): 2497 @dataclass(order=True) 2498 class C: 2499 x: int 2500 def __le__(self): 2501 pass 2502 2503 with self.assertRaisesRegex(TypeError, 2504 'Cannot overwrite attribute __gt__' 2505 '.*using functools.total_ordering'): 2506 @dataclass(order=True) 2507 class C: 2508 x: int 2509 def __gt__(self): 2510 pass 2511 2512 with self.assertRaisesRegex(TypeError, 2513 'Cannot overwrite attribute __ge__' 2514 '.*using functools.total_ordering'): 2515 @dataclass(order=True) 2516 class C: 2517 x: int 2518 def __ge__(self): 2519 pass 2520 2521class TestHash(unittest.TestCase): 2522 def test_unsafe_hash(self): 2523 @dataclass(unsafe_hash=True) 2524 class C: 2525 x: int 2526 y: str 2527 self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo'))) 2528 
2529 def test_hash_rules(self): 2530 def non_bool(value): 2531 # Map to something else that's True, but not a bool. 2532 if value is None: 2533 return None 2534 if value: 2535 return (3,) 2536 return 0 2537 2538 def test(case, unsafe_hash, eq, frozen, with_hash, result): 2539 with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq, 2540 frozen=frozen): 2541 if result != 'exception': 2542 if with_hash: 2543 @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen) 2544 class C: 2545 def __hash__(self): 2546 return 0 2547 else: 2548 @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen) 2549 class C: 2550 pass 2551 2552 # See if the result matches what's expected. 2553 if result == 'fn': 2554 # __hash__ contains the function we generated. 2555 self.assertIn('__hash__', C.__dict__) 2556 self.assertIsNotNone(C.__dict__['__hash__']) 2557 2558 elif result == '': 2559 # __hash__ is not present in our class. 2560 if not with_hash: 2561 self.assertNotIn('__hash__', C.__dict__) 2562 2563 elif result == 'none': 2564 # __hash__ is set to None. 2565 self.assertIn('__hash__', C.__dict__) 2566 self.assertIsNone(C.__dict__['__hash__']) 2567 2568 elif result == 'exception': 2569 # Creating the class should cause an exception. 2570 # This only happens with with_hash==True. 2571 assert(with_hash) 2572 with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'): 2573 @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen) 2574 class C: 2575 def __hash__(self): 2576 return 0 2577 2578 else: 2579 assert False, f'unknown result {result!r}' 2580 2581 # There are 8 cases of: 2582 # unsafe_hash=True/False 2583 # eq=True/False 2584 # frozen=True/False 2585 # And for each of these, a different result if 2586 # __hash__ is defined or not. 
2587 for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate([ 2588 (False, False, False, '', ''), 2589 (False, False, True, '', ''), 2590 (False, True, False, 'none', ''), 2591 (False, True, True, 'fn', ''), 2592 (True, False, False, 'fn', 'exception'), 2593 (True, False, True, 'fn', 'exception'), 2594 (True, True, False, 'fn', 'exception'), 2595 (True, True, True, 'fn', 'exception'), 2596 ], 1): 2597 test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash) 2598 test(case, unsafe_hash, eq, frozen, True, res_defined_hash) 2599 2600 # Test non-bool truth values, too. This is just to 2601 # make sure the data-driven table in the decorator 2602 # handles non-bool values. 2603 test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash) 2604 test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True, res_defined_hash) 2605 2606 2607 def test_eq_only(self): 2608 # If a class defines __eq__, __hash__ is automatically added 2609 # and set to None. This is normal Python behavior, not 2610 # related to dataclasses. Make sure we don't interfere with 2611 # that (see bpo=32546). 2612 2613 @dataclass 2614 class C: 2615 i: int 2616 def __eq__(self, other): 2617 return self.i == other.i 2618 self.assertEqual(C(1), C(1)) 2619 self.assertNotEqual(C(1), C(4)) 2620 2621 # And make sure things work in this case if we specify 2622 # unsafe_hash=True. 2623 @dataclass(unsafe_hash=True) 2624 class C: 2625 i: int 2626 def __eq__(self, other): 2627 return self.i == other.i 2628 self.assertEqual(C(1), C(1.0)) 2629 self.assertEqual(hash(C(1)), hash(C(1.0))) 2630 2631 # And check that the classes __eq__ is being used, despite 2632 # specifying eq=True. 
2633 @dataclass(unsafe_hash=True, eq=True) 2634 class C: 2635 i: int 2636 def __eq__(self, other): 2637 return self.i == 3 and self.i == other.i 2638 self.assertEqual(C(3), C(3)) 2639 self.assertNotEqual(C(1), C(1)) 2640 self.assertEqual(hash(C(1)), hash(C(1.0))) 2641 2642 def test_0_field_hash(self): 2643 @dataclass(frozen=True) 2644 class C: 2645 pass 2646 self.assertEqual(hash(C()), hash(())) 2647 2648 @dataclass(unsafe_hash=True) 2649 class C: 2650 pass 2651 self.assertEqual(hash(C()), hash(())) 2652 2653 def test_1_field_hash(self): 2654 @dataclass(frozen=True) 2655 class C: 2656 x: int 2657 self.assertEqual(hash(C(4)), hash((4,))) 2658 self.assertEqual(hash(C(42)), hash((42,))) 2659 2660 @dataclass(unsafe_hash=True) 2661 class C: 2662 x: int 2663 self.assertEqual(hash(C(4)), hash((4,))) 2664 self.assertEqual(hash(C(42)), hash((42,))) 2665 2666 def test_hash_no_args(self): 2667 # Test dataclasses with no hash= argument. This exists to 2668 # make sure that if the @dataclass parameter name is changed 2669 # or the non-default hashing behavior changes, the default 2670 # hashability keeps working the same way. 2671 2672 class Base: 2673 def __hash__(self): 2674 return 301 2675 2676 # If frozen or eq is None, then use the default value (do not 2677 # specify any value in the decorator). 
2678 for frozen, eq, base, expected in [ 2679 (None, None, object, 'unhashable'), 2680 (None, None, Base, 'unhashable'), 2681 (None, False, object, 'object'), 2682 (None, False, Base, 'base'), 2683 (None, True, object, 'unhashable'), 2684 (None, True, Base, 'unhashable'), 2685 (False, None, object, 'unhashable'), 2686 (False, None, Base, 'unhashable'), 2687 (False, False, object, 'object'), 2688 (False, False, Base, 'base'), 2689 (False, True, object, 'unhashable'), 2690 (False, True, Base, 'unhashable'), 2691 (True, None, object, 'tuple'), 2692 (True, None, Base, 'tuple'), 2693 (True, False, object, 'object'), 2694 (True, False, Base, 'base'), 2695 (True, True, object, 'tuple'), 2696 (True, True, Base, 'tuple'), 2697 ]: 2698 2699 with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected): 2700 # First, create the class. 2701 if frozen is None and eq is None: 2702 @dataclass 2703 class C(base): 2704 i: int 2705 elif frozen is None: 2706 @dataclass(eq=eq) 2707 class C(base): 2708 i: int 2709 elif eq is None: 2710 @dataclass(frozen=frozen) 2711 class C(base): 2712 i: int 2713 else: 2714 @dataclass(frozen=frozen, eq=eq) 2715 class C(base): 2716 i: int 2717 2718 # Now, make sure it hashes as expected. 2719 if expected == 'unhashable': 2720 c = C(10) 2721 with self.assertRaisesRegex(TypeError, 'unhashable type'): 2722 hash(c) 2723 2724 elif expected == 'base': 2725 self.assertEqual(hash(C(10)), 301) 2726 2727 elif expected == 'object': 2728 # I'm not sure what test to use here. object's 2729 # hash isn't based on id(), so calling hash() 2730 # won't tell us much. So, just check the 2731 # function used is object's. 
2732 self.assertIs(C.__hash__, object.__hash__) 2733 2734 elif expected == 'tuple': 2735 self.assertEqual(hash(C(42)), hash((42,))) 2736 2737 else: 2738 assert False, f'unknown value for expected={expected!r}' 2739 2740 2741class TestFrozen(unittest.TestCase): 2742 def test_frozen(self): 2743 @dataclass(frozen=True) 2744 class C: 2745 i: int 2746 2747 c = C(10) 2748 self.assertEqual(c.i, 10) 2749 with self.assertRaises(FrozenInstanceError): 2750 c.i = 5 2751 self.assertEqual(c.i, 10) 2752 2753 def test_inherit(self): 2754 @dataclass(frozen=True) 2755 class C: 2756 i: int 2757 2758 @dataclass(frozen=True) 2759 class D(C): 2760 j: int 2761 2762 d = D(0, 10) 2763 with self.assertRaises(FrozenInstanceError): 2764 d.i = 5 2765 with self.assertRaises(FrozenInstanceError): 2766 d.j = 6 2767 self.assertEqual(d.i, 0) 2768 self.assertEqual(d.j, 10) 2769 2770 def test_inherit_nonfrozen_from_empty_frozen(self): 2771 @dataclass(frozen=True) 2772 class C: 2773 pass 2774 2775 with self.assertRaisesRegex(TypeError, 2776 'cannot inherit non-frozen dataclass from a frozen one'): 2777 @dataclass 2778 class D(C): 2779 j: int 2780 2781 def test_inherit_nonfrozen_from_empty(self): 2782 @dataclass 2783 class C: 2784 pass 2785 2786 @dataclass 2787 class D(C): 2788 j: int 2789 2790 d = D(3) 2791 self.assertEqual(d.j, 3) 2792 self.assertIsInstance(d, C) 2793 2794 # Test both ways: with an intermediate normal (non-dataclass) 2795 # class and without an intermediate class. 
# --- TestFrozen (continued) ---
    def test_inherit_nonfrozen_from_frozen(self):
        # A plain @dataclass may not derive from a frozen dataclass, whether
        # the frozen base is the direct parent or is reached through an
        # ordinary (non-dataclass) intermediate class.
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(I):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        # Mirror of the test above: a frozen dataclass may not derive from
        # a non-frozen dataclass, directly or through an ordinary class.
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass
                class C:
                    i: int

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(I):
                        pass

    def test_inherit_from_normal_class(self):
        # A frozen dataclass may derive from an ordinary class (directly or
        # through another ordinary class), and its own fields stay frozen.
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                class C:
                    pass

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                @dataclass(frozen=True)
                class D(I):
                    i: int

            # Runs once per loop iteration, after the subTest block, using
            # the D defined in that iteration.
            d = D(10)
            with self.assertRaises(FrozenInstanceError):
                d.i = 5

    def test_non_frozen_normal_derived(self):
        # See bpo-32953.
        # An undecorated subclass of a frozen dataclass can set new
        # attributes (here `cached`), but the inherited dataclass fields
        # remain frozen.

        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        class S(D):
            pass

        s = S(3)
        self.assertEqual(s.x, 3)
        self.assertEqual(s.y, 10)
        s.cached = True

        # But can't change the frozen attributes.
        with self.assertRaises(FrozenInstanceError):
            s.x = 5
        with self.assertRaises(FrozenInstanceError):
            s.y = 5
        self.assertEqual(s.x, 3)
        self.assertEqual(s.y, 10)
        self.assertEqual(s.cached, True)

    def test_overwriting_frozen(self):
        # frozen uses __setattr__ and __delattr__.
        # So a frozen dataclass that defines either method raises TypeError.
        # A non-frozen dataclass keeps the user's __setattr__: __init__'s
        # assignment goes through it, which is why x ends up doubled.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __delattr__(self):
                    pass

        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2
        self.assertEqual(C(10).x, 20)

    def test_frozen_hash(self):
        # A frozen dataclass is hashed from its field values, so hashing
        # fails when a field holds an unhashable value.
        @dataclass(frozen=True)
        class C:
            x: Any

        # If x is immutable, we can compute the hash.  No exception is
        # raised.
        hash(C(3))

        # If x is mutable, computing the hash is an error.
        with self.assertRaisesRegex(TypeError, 'unhashable type'):
            hash(C({}))


class TestSlots(unittest.TestCase):
    def test_simple(self):
        # A dataclass with a hand-written __slots__ still treats the slotted
        # name as a required field with no default.
        @dataclass
        class C:
            __slots__ = ('x',)
            x: Any

        # There was a bug where a variable in a slot was assumed to
        #  also have a default value (of type
        #  types.MemberDescriptorType).
        with self.assertRaisesRegex(TypeError,
                                    r"__init__\(\) missing 1 required positional argument: 'x'"):
            C()

        # We can create an instance, and assign to x.
        c = C(10)
        self.assertEqual(c.x, 10)
        c.x = 5
        self.assertEqual(c.x, 5)

        # We can't assign to anything else.
        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
            c.y = 5

    def test_derived_added_field(self):
        # See bpo-33100.
        # Derived declares no __slots__ of its own, so its instances get a
        # __dict__ and accept arbitrary new attributes.
        @dataclass
        class Base:
            __slots__ = ('x',)
            x: Any

        @dataclass
        class Derived(Base):
            x: int
            y: int

        d = Derived(1, 2)
        self.assertEqual((d.x, d.y), (1, 2))

        # We can add a new field to the derived instance.
        d.z = 10

    def test_generated_slots(self):
        # slots=True generates __slots__ for the fields: the fields are
        # readable and writable, and unknown attributes are rejected.
        @dataclass(slots=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        self.assertEqual((c.x, c.y), (1, 2))

        c.x = 3
        c.y = 4
        self.assertEqual((c.x, c.y), (3, 4))

        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'z'"):
            c.z = 5

    def test_add_slots_when_slots_exists(self):
        # slots=True is rejected when the class body already defines
        # __slots__.
        with self.assertRaisesRegex(TypeError, '^C already specifies __slots__$'):
            @dataclass(slots=True)
            class C:
                __slots__ = ('x',)
                x: int

    def test_generated_slots_value(self):
        # The generated __slots__ contains only the field names that no base
        # class already provides as a slot.  The Root* bases cover the
        # forms __slots__ can take here: set, dict, list and a single
        # string.

        class Root:
            __slots__ = {'x'}

        class Root2(Root):
            __slots__ = {'k': '...', 'j': ''}

        class Root3(Root2):
            __slots__ = ['h']

        class Root4(Root3):
            __slots__ = 'aa'

        @dataclass(slots=True)
        class Base(Root4):
            y: int
            j: str
            h: str

        self.assertEqual(Base.__slots__, ('y', ))

        @dataclass(slots=True)
        class Derived(Base):
            aa: float
            x: str
            z: int
            k: str
            h: str

        self.assertEqual(Derived.__slots__, ('z', ))

        # Without slots=True no __slots__ is added to the subclass.
        @dataclass
        class AnotherDerived(Base):
            z: int

        self.assertNotIn('__slots__', AnotherDerived.__dict__)

    def test_cant_inherit_from_iterator_slots(self):
        # Root's __slots__ is an iterator.  Presumably it is exhausted once
        # Root has been created, so dataclass cannot recover Root's slot
        # names and must raise instead.

        class Root:
            __slots__ = iter(['a'])

        class Root2(Root):
            __slots__ = ('b', )

        with self.assertRaisesRegex(
            TypeError,
            "^Slots of 'Root' cannot be determined"
        ):
            @dataclass(slots=True)
            class C(Root2):
                x: int

    def test_returns_new_class(self):
        # With slots=True, dataclass() returns a new class and leaves the
        # original class A without __slots__.
        class A:
            x: int

        B = dataclass(A, slots=True)
        self.assertIsNot(A, B)

        self.assertFalse(hasattr(A, "__slots__"))
        self.assertTrue(hasattr(B, "__slots__"))

    # Can't be local to test_frozen_pickle.
3045 @dataclass(frozen=True, slots=True) 3046 class FrozenSlotsClass: 3047 foo: str 3048 bar: int 3049 3050 @dataclass(frozen=True) 3051 class FrozenWithoutSlotsClass: 3052 foo: str 3053 bar: int 3054 3055 def test_frozen_pickle(self): 3056 # bpo-43999 3057 3058 self.assertEqual(self.FrozenSlotsClass.__slots__, ("foo", "bar")) 3059 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 3060 with self.subTest(proto=proto): 3061 obj = self.FrozenSlotsClass("a", 1) 3062 p = pickle.loads(pickle.dumps(obj, protocol=proto)) 3063 self.assertIsNot(obj, p) 3064 self.assertEqual(obj, p) 3065 3066 obj = self.FrozenWithoutSlotsClass("a", 1) 3067 p = pickle.loads(pickle.dumps(obj, protocol=proto)) 3068 self.assertIsNot(obj, p) 3069 self.assertEqual(obj, p) 3070 3071 @dataclass(frozen=True, slots=True) 3072 class FrozenSlotsGetStateClass: 3073 foo: str 3074 bar: int 3075 3076 getstate_called: bool = field(default=False, compare=False) 3077 3078 def __getstate__(self): 3079 object.__setattr__(self, 'getstate_called', True) 3080 return [self.foo, self.bar] 3081 3082 @dataclass(frozen=True, slots=True) 3083 class FrozenSlotsSetStateClass: 3084 foo: str 3085 bar: int 3086 3087 setstate_called: bool = field(default=False, compare=False) 3088 3089 def __setstate__(self, state): 3090 object.__setattr__(self, 'setstate_called', True) 3091 object.__setattr__(self, 'foo', state[0]) 3092 object.__setattr__(self, 'bar', state[1]) 3093 3094 @dataclass(frozen=True, slots=True) 3095 class FrozenSlotsAllStateClass: 3096 foo: str 3097 bar: int 3098 3099 getstate_called: bool = field(default=False, compare=False) 3100 setstate_called: bool = field(default=False, compare=False) 3101 3102 def __getstate__(self): 3103 object.__setattr__(self, 'getstate_called', True) 3104 return [self.foo, self.bar] 3105 3106 def __setstate__(self, state): 3107 object.__setattr__(self, 'setstate_called', True) 3108 object.__setattr__(self, 'foo', state[0]) 3109 object.__setattr__(self, 'bar', state[1]) 3110 3111 def 
test_frozen_slots_pickle_custom_state(self): 3112 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 3113 with self.subTest(proto=proto): 3114 obj = self.FrozenSlotsGetStateClass('a', 1) 3115 dumped = pickle.dumps(obj, protocol=proto) 3116 3117 self.assertTrue(obj.getstate_called) 3118 self.assertEqual(obj, pickle.loads(dumped)) 3119 3120 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 3121 with self.subTest(proto=proto): 3122 obj = self.FrozenSlotsSetStateClass('a', 1) 3123 obj2 = pickle.loads(pickle.dumps(obj, protocol=proto)) 3124 3125 self.assertTrue(obj2.setstate_called) 3126 self.assertEqual(obj, obj2) 3127 3128 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 3129 with self.subTest(proto=proto): 3130 obj = self.FrozenSlotsAllStateClass('a', 1) 3131 dumped = pickle.dumps(obj, protocol=proto) 3132 3133 self.assertTrue(obj.getstate_called) 3134 3135 obj2 = pickle.loads(dumped) 3136 self.assertTrue(obj2.setstate_called) 3137 self.assertEqual(obj, obj2) 3138 3139 def test_slots_with_default_no_init(self): 3140 # Originally reported in bpo-44649. 3141 @dataclass(slots=True) 3142 class A: 3143 a: str 3144 b: str = field(default='b', init=False) 3145 3146 obj = A("a") 3147 self.assertEqual(obj.a, 'a') 3148 self.assertEqual(obj.b, 'b') 3149 3150 def test_slots_with_default_factory_no_init(self): 3151 # Originally reported in bpo-44649. 3152 @dataclass(slots=True) 3153 class A: 3154 a: str 3155 b: str = field(default_factory=lambda:'b', init=False) 3156 3157 obj = A("a") 3158 self.assertEqual(obj.a, 'a') 3159 self.assertEqual(obj.b, 'b') 3160 3161 def test_slots_no_weakref(self): 3162 @dataclass(slots=True) 3163 class A: 3164 # No weakref. 
3165 pass 3166 3167 self.assertNotIn("__weakref__", A.__slots__) 3168 a = A() 3169 with self.assertRaisesRegex(TypeError, 3170 "cannot create weak reference"): 3171 weakref.ref(a) 3172 with self.assertRaises(AttributeError): 3173 a.__weakref__ 3174 3175 def test_slots_weakref(self): 3176 @dataclass(slots=True, weakref_slot=True) 3177 class A: 3178 a: int 3179 3180 self.assertIn("__weakref__", A.__slots__) 3181 a = A(1) 3182 a_ref = weakref.ref(a) 3183 3184 self.assertIs(a.__weakref__, a_ref) 3185 3186 def test_slots_weakref_base_str(self): 3187 class Base: 3188 __slots__ = '__weakref__' 3189 3190 @dataclass(slots=True) 3191 class A(Base): 3192 a: int 3193 3194 # __weakref__ is in the base class, not A. But an A is still weakref-able. 3195 self.assertIn("__weakref__", Base.__slots__) 3196 self.assertNotIn("__weakref__", A.__slots__) 3197 a = A(1) 3198 weakref.ref(a) 3199 3200 def test_slots_weakref_base_tuple(self): 3201 # Same as test_slots_weakref_base, but use a tuple instead of a string 3202 # in the base class. 3203 class Base: 3204 __slots__ = ('__weakref__',) 3205 3206 @dataclass(slots=True) 3207 class A(Base): 3208 a: int 3209 3210 # __weakref__ is in the base class, not A. But an A is still 3211 # weakref-able. 3212 self.assertIn("__weakref__", Base.__slots__) 3213 self.assertNotIn("__weakref__", A.__slots__) 3214 a = A(1) 3215 weakref.ref(a) 3216 3217 def test_weakref_slot_without_slot(self): 3218 with self.assertRaisesRegex(TypeError, 3219 "weakref_slot is True but slots is False"): 3220 @dataclass(weakref_slot=True) 3221 class A: 3222 a: int 3223 3224 def test_weakref_slot_make_dataclass(self): 3225 A = make_dataclass('A', [('a', int),], slots=True, weakref_slot=True) 3226 self.assertIn("__weakref__", A.__slots__) 3227 a = A(1) 3228 weakref.ref(a) 3229 3230 # And make sure if raises if slots=True is not given. 
3231 with self.assertRaisesRegex(TypeError, 3232 "weakref_slot is True but slots is False"): 3233 B = make_dataclass('B', [('a', int),], weakref_slot=True) 3234 3235 def test_weakref_slot_subclass_weakref_slot(self): 3236 @dataclass(slots=True, weakref_slot=True) 3237 class Base: 3238 field: int 3239 3240 # A *can* also specify weakref_slot=True if it wants to (gh-93521) 3241 @dataclass(slots=True, weakref_slot=True) 3242 class A(Base): 3243 ... 3244 3245 # __weakref__ is in the base class, not A. But an instance of A 3246 # is still weakref-able. 3247 self.assertIn("__weakref__", Base.__slots__) 3248 self.assertNotIn("__weakref__", A.__slots__) 3249 a = A(1) 3250 a_ref = weakref.ref(a) 3251 self.assertIs(a.__weakref__, a_ref) 3252 3253 def test_weakref_slot_subclass_no_weakref_slot(self): 3254 @dataclass(slots=True, weakref_slot=True) 3255 class Base: 3256 field: int 3257 3258 @dataclass(slots=True) 3259 class A(Base): 3260 ... 3261 3262 # __weakref__ is in the base class, not A. Even though A doesn't 3263 # specify weakref_slot, it should still be weakref-able. 3264 self.assertIn("__weakref__", Base.__slots__) 3265 self.assertNotIn("__weakref__", A.__slots__) 3266 a = A(1) 3267 a_ref = weakref.ref(a) 3268 self.assertIs(a.__weakref__, a_ref) 3269 3270 def test_weakref_slot_normal_base_weakref_slot(self): 3271 class Base: 3272 __slots__ = ('__weakref__',) 3273 3274 @dataclass(slots=True, weakref_slot=True) 3275 class A(Base): 3276 field: int 3277 3278 # __weakref__ is in the base class, not A. But an instance of 3279 # A is still weakref-able. 3280 self.assertIn("__weakref__", Base.__slots__) 3281 self.assertNotIn("__weakref__", A.__slots__) 3282 a = A(1) 3283 a_ref = weakref.ref(a) 3284 self.assertIs(a.__weakref__, a_ref) 3285 3286 3287class TestDescriptors(unittest.TestCase): 3288 def test_set_name(self): 3289 # See bpo-33141. 3290 3291 # Create a descriptor. 
3292 class D: 3293 def __set_name__(self, owner, name): 3294 self.name = name + 'x' 3295 def __get__(self, instance, owner): 3296 if instance is not None: 3297 return 1 3298 return self 3299 3300 # This is the case of just normal descriptor behavior, no 3301 # dataclass code is involved in initializing the descriptor. 3302 @dataclass 3303 class C: 3304 c: int=D() 3305 self.assertEqual(C.c.name, 'cx') 3306 3307 # Now test with a default value and init=False, which is the 3308 # only time this is really meaningful. If not using 3309 # init=False, then the descriptor will be overwritten, anyway. 3310 @dataclass 3311 class C: 3312 c: int=field(default=D(), init=False) 3313 self.assertEqual(C.c.name, 'cx') 3314 self.assertEqual(C().c, 1) 3315 3316 def test_non_descriptor(self): 3317 # PEP 487 says __set_name__ should work on non-descriptors. 3318 # Create a descriptor. 3319 3320 class D: 3321 def __set_name__(self, owner, name): 3322 self.name = name + 'x' 3323 3324 @dataclass 3325 class C: 3326 c: int=field(default=D(), init=False) 3327 self.assertEqual(C.c.name, 'cx') 3328 3329 def test_lookup_on_instance(self): 3330 # See bpo-33175. 3331 class D: 3332 pass 3333 3334 d = D() 3335 # Create an attribute on the instance, not type. 3336 d.__set_name__ = Mock() 3337 3338 # Make sure d.__set_name__ is not called. 3339 @dataclass 3340 class C: 3341 i: int=field(default=d, init=False) 3342 3343 self.assertEqual(d.__set_name__.call_count, 0) 3344 3345 def test_lookup_on_class(self): 3346 # See bpo-33175. 3347 class D: 3348 pass 3349 D.__set_name__ = Mock() 3350 3351 # Make sure D.__set_name__ is called. 3352 @dataclass 3353 class C: 3354 i: int=field(default=D(), init=False) 3355 3356 self.assertEqual(D.__set_name__.call_count, 1) 3357 3358 def test_init_calls_set(self): 3359 class D: 3360 pass 3361 3362 D.__set__ = Mock() 3363 3364 @dataclass 3365 class C: 3366 i: D = D() 3367 3368 # Make sure D.__set__ is called. 
3369 D.__set__.reset_mock() 3370 c = C(5) 3371 self.assertEqual(D.__set__.call_count, 1) 3372 3373 def test_getting_field_calls_get(self): 3374 class D: 3375 pass 3376 3377 D.__set__ = Mock() 3378 D.__get__ = Mock() 3379 3380 @dataclass 3381 class C: 3382 i: D = D() 3383 3384 c = C(5) 3385 3386 # Make sure D.__get__ is called. 3387 D.__get__.reset_mock() 3388 value = c.i 3389 self.assertEqual(D.__get__.call_count, 1) 3390 3391 def test_setting_field_calls_set(self): 3392 class D: 3393 pass 3394 3395 D.__set__ = Mock() 3396 3397 @dataclass 3398 class C: 3399 i: D = D() 3400 3401 c = C(5) 3402 3403 # Make sure D.__set__ is called. 3404 D.__set__.reset_mock() 3405 c.i = 10 3406 self.assertEqual(D.__set__.call_count, 1) 3407 3408 def test_setting_uninitialized_descriptor_field(self): 3409 class D: 3410 pass 3411 3412 D.__set__ = Mock() 3413 3414 @dataclass 3415 class C: 3416 i: D 3417 3418 # D.__set__ is not called because there's no D instance to call it on 3419 D.__set__.reset_mock() 3420 c = C(5) 3421 self.assertEqual(D.__set__.call_count, 0) 3422 3423 # D.__set__ still isn't called after setting i to an instance of D 3424 # because descriptors don't behave like that when stored as instance vars 3425 c.i = D() 3426 c.i = 5 3427 self.assertEqual(D.__set__.call_count, 0) 3428 3429 def test_default_value(self): 3430 class D: 3431 def __get__(self, instance: Any, owner: object) -> int: 3432 if instance is None: 3433 return 100 3434 3435 return instance._x 3436 3437 def __set__(self, instance: Any, value: int) -> None: 3438 instance._x = value 3439 3440 @dataclass 3441 class C: 3442 i: D = D() 3443 3444 c = C() 3445 self.assertEqual(c.i, 100) 3446 3447 c = C(5) 3448 self.assertEqual(c.i, 5) 3449 3450 def test_no_default_value(self): 3451 class D: 3452 def __get__(self, instance: Any, owner: object) -> int: 3453 if instance is None: 3454 raise AttributeError() 3455 3456 return instance._x 3457 3458 def __set__(self, instance: Any, value: int) -> None: 3459 instance._x = 
value 3460 3461 @dataclass 3462 class C: 3463 i: D = D() 3464 3465 with self.assertRaisesRegex(TypeError, 'missing 1 required positional argument'): 3466 c = C() 3467 3468class TestStringAnnotations(unittest.TestCase): 3469 def test_classvar(self): 3470 # Some expressions recognized as ClassVar really aren't. But 3471 # if you're using string annotations, it's not an exact 3472 # science. 3473 # These tests assume that both "import typing" and "from 3474 # typing import *" have been run in this file. 3475 for typestr in ('ClassVar[int]', 3476 'ClassVar [int]', 3477 ' ClassVar [int]', 3478 'ClassVar', 3479 ' ClassVar ', 3480 'typing.ClassVar[int]', 3481 'typing.ClassVar[str]', 3482 ' typing.ClassVar[str]', 3483 'typing .ClassVar[str]', 3484 'typing. ClassVar[str]', 3485 'typing.ClassVar [str]', 3486 'typing.ClassVar [ str]', 3487 3488 # Not syntactically valid, but these will 3489 # be treated as ClassVars. 3490 'typing.ClassVar.[int]', 3491 'typing.ClassVar+', 3492 ): 3493 with self.subTest(typestr=typestr): 3494 @dataclass 3495 class C: 3496 x: typestr 3497 3498 # x is a ClassVar, so C() takes no args. 3499 C() 3500 3501 # And it won't appear in the class's dict because it doesn't 3502 # have a default. 3503 self.assertNotIn('x', C.__dict__) 3504 3505 def test_isnt_classvar(self): 3506 for typestr in ('CV', 3507 't.ClassVar', 3508 't.ClassVar[int]', 3509 'typing..ClassVar[int]', 3510 'Classvar', 3511 'Classvar[int]', 3512 'typing.ClassVarx[int]', 3513 'typong.ClassVar[int]', 3514 'dataclasses.ClassVar[int]', 3515 'typingxClassVar[str]', 3516 ): 3517 with self.subTest(typestr=typestr): 3518 @dataclass 3519 class C: 3520 x: typestr 3521 3522 # x is not a ClassVar, so C() takes one arg. 3523 self.assertEqual(C(10).x, 10) 3524 3525 def test_initvar(self): 3526 # These tests assume that both "import dataclasses" and "from 3527 # dataclasses import *" have been run in this file. 
3528 for typestr in ('InitVar[int]', 3529 'InitVar [int]' 3530 ' InitVar [int]', 3531 'InitVar', 3532 ' InitVar ', 3533 'dataclasses.InitVar[int]', 3534 'dataclasses.InitVar[str]', 3535 ' dataclasses.InitVar[str]', 3536 'dataclasses .InitVar[str]', 3537 'dataclasses. InitVar[str]', 3538 'dataclasses.InitVar [str]', 3539 'dataclasses.InitVar [ str]', 3540 3541 # Not syntactically valid, but these will 3542 # be treated as InitVars. 3543 'dataclasses.InitVar.[int]', 3544 'dataclasses.InitVar+', 3545 ): 3546 with self.subTest(typestr=typestr): 3547 @dataclass 3548 class C: 3549 x: typestr 3550 3551 # x is an InitVar, so doesn't create a member. 3552 with self.assertRaisesRegex(AttributeError, 3553 "object has no attribute 'x'"): 3554 C(1).x 3555 3556 def test_isnt_initvar(self): 3557 for typestr in ('IV', 3558 'dc.InitVar', 3559 'xdataclasses.xInitVar', 3560 'typing.xInitVar[int]', 3561 ): 3562 with self.subTest(typestr=typestr): 3563 @dataclass 3564 class C: 3565 x: typestr 3566 3567 # x is not an InitVar, so there will be a member x. 3568 self.assertEqual(C(10).x, 10) 3569 3570 def test_classvar_module_level_import(self): 3571 from test import dataclass_module_1 3572 from test import dataclass_module_1_str 3573 from test import dataclass_module_2 3574 from test import dataclass_module_2_str 3575 3576 for m in (dataclass_module_1, dataclass_module_1_str, 3577 dataclass_module_2, dataclass_module_2_str, 3578 ): 3579 with self.subTest(m=m): 3580 # There's a difference in how the ClassVars are 3581 # interpreted when using string annotations or 3582 # not. See the imported modules for details. 3583 if m.USING_STRINGS: 3584 c = m.CV(10) 3585 else: 3586 c = m.CV() 3587 self.assertEqual(c.cv0, 20) 3588 3589 3590 # There's a difference in how the InitVars are 3591 # interpreted when using string annotations or 3592 # not. See the imported modules for details. 
3593 c = m.IV(0, 1, 2, 3, 4) 3594 3595 for field_name in ('iv0', 'iv1', 'iv2', 'iv3'): 3596 with self.subTest(field_name=field_name): 3597 with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"): 3598 # Since field_name is an InitVar, it's 3599 # not an instance field. 3600 getattr(c, field_name) 3601 3602 if m.USING_STRINGS: 3603 # iv4 is interpreted as a normal field. 3604 self.assertIn('not_iv4', c.__dict__) 3605 self.assertEqual(c.not_iv4, 4) 3606 else: 3607 # iv4 is interpreted as an InitVar, so it 3608 # won't exist on the instance. 3609 self.assertNotIn('not_iv4', c.__dict__) 3610 3611 def test_text_annotations(self): 3612 from test import dataclass_textanno 3613 3614 self.assertEqual( 3615 get_type_hints(dataclass_textanno.Bar), 3616 {'foo': dataclass_textanno.Foo}) 3617 self.assertEqual( 3618 get_type_hints(dataclass_textanno.Bar.__init__), 3619 {'foo': dataclass_textanno.Foo, 3620 'return': type(None)}) 3621 3622 3623class TestMakeDataclass(unittest.TestCase): 3624 def test_simple(self): 3625 C = make_dataclass('C', 3626 [('x', int), 3627 ('y', int, field(default=5))], 3628 namespace={'add_one': lambda self: self.x + 1}) 3629 c = C(10) 3630 self.assertEqual((c.x, c.y), (10, 5)) 3631 self.assertEqual(c.add_one(), 11) 3632 3633 3634 def test_no_mutate_namespace(self): 3635 # Make sure a provided namespace isn't mutated. 
3636 ns = {} 3637 C = make_dataclass('C', 3638 [('x', int), 3639 ('y', int, field(default=5))], 3640 namespace=ns) 3641 self.assertEqual(ns, {}) 3642 3643 def test_base(self): 3644 class Base1: 3645 pass 3646 class Base2: 3647 pass 3648 C = make_dataclass('C', 3649 [('x', int)], 3650 bases=(Base1, Base2)) 3651 c = C(2) 3652 self.assertIsInstance(c, C) 3653 self.assertIsInstance(c, Base1) 3654 self.assertIsInstance(c, Base2) 3655 3656 def test_base_dataclass(self): 3657 @dataclass 3658 class Base1: 3659 x: int 3660 class Base2: 3661 pass 3662 C = make_dataclass('C', 3663 [('y', int)], 3664 bases=(Base1, Base2)) 3665 with self.assertRaisesRegex(TypeError, 'required positional'): 3666 c = C(2) 3667 c = C(1, 2) 3668 self.assertIsInstance(c, C) 3669 self.assertIsInstance(c, Base1) 3670 self.assertIsInstance(c, Base2) 3671 3672 self.assertEqual((c.x, c.y), (1, 2)) 3673 3674 def test_init_var(self): 3675 def post_init(self, y): 3676 self.x *= y 3677 3678 C = make_dataclass('C', 3679 [('x', int), 3680 ('y', InitVar[int]), 3681 ], 3682 namespace={'__post_init__': post_init}, 3683 ) 3684 c = C(2, 3) 3685 self.assertEqual(vars(c), {'x': 6}) 3686 self.assertEqual(len(fields(c)), 1) 3687 3688 def test_class_var(self): 3689 C = make_dataclass('C', 3690 [('x', int), 3691 ('y', ClassVar[int], 10), 3692 ('z', ClassVar[int], field(default=20)), 3693 ]) 3694 c = C(1) 3695 self.assertEqual(vars(c), {'x': 1}) 3696 self.assertEqual(len(fields(c)), 1) 3697 self.assertEqual(C.y, 10) 3698 self.assertEqual(C.z, 20) 3699 3700 def test_other_params(self): 3701 C = make_dataclass('C', 3702 [('x', int), 3703 ('y', ClassVar[int], 10), 3704 ('z', ClassVar[int], field(default=20)), 3705 ], 3706 init=False) 3707 # Make sure we have a repr, but no init. 3708 self.assertNotIn('__init__', vars(C)) 3709 self.assertIn('__repr__', vars(C)) 3710 3711 # Make sure random other params don't work. 
3712 with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'): 3713 C = make_dataclass('C', 3714 [], 3715 xxinit=False) 3716 3717 def test_no_types(self): 3718 C = make_dataclass('Point', ['x', 'y', 'z']) 3719 c = C(1, 2, 3) 3720 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3}) 3721 self.assertEqual(C.__annotations__, {'x': 'typing.Any', 3722 'y': 'typing.Any', 3723 'z': 'typing.Any'}) 3724 3725 C = make_dataclass('Point', ['x', ('y', int), 'z']) 3726 c = C(1, 2, 3) 3727 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3}) 3728 self.assertEqual(C.__annotations__, {'x': 'typing.Any', 3729 'y': int, 3730 'z': 'typing.Any'}) 3731 3732 def test_invalid_type_specification(self): 3733 for bad_field in [(), 3734 (1, 2, 3, 4), 3735 ]: 3736 with self.subTest(bad_field=bad_field): 3737 with self.assertRaisesRegex(TypeError, r'Invalid field: '): 3738 make_dataclass('C', ['a', bad_field]) 3739 3740 # And test for things with no len(). 3741 for bad_field in [float, 3742 lambda x:x, 3743 ]: 3744 with self.subTest(bad_field=bad_field): 3745 with self.assertRaisesRegex(TypeError, r'has no len\(\)'): 3746 make_dataclass('C', ['a', bad_field]) 3747 3748 def test_duplicate_field_names(self): 3749 for field in ['a', 'ab']: 3750 with self.subTest(field=field): 3751 with self.assertRaisesRegex(TypeError, 'Field name duplicated'): 3752 make_dataclass('C', [field, 'a', field]) 3753 3754 def test_keyword_field_names(self): 3755 for field in ['for', 'async', 'await', 'as']: 3756 with self.subTest(field=field): 3757 with self.assertRaisesRegex(TypeError, 'must not be keywords'): 3758 make_dataclass('C', ['a', field]) 3759 with self.assertRaisesRegex(TypeError, 'must not be keywords'): 3760 make_dataclass('C', [field]) 3761 with self.assertRaisesRegex(TypeError, 'must not be keywords'): 3762 make_dataclass('C', [field, 'a']) 3763 3764 def test_non_identifier_field_names(self): 3765 for field in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']: 3766 with 
self.subTest(field=field): 3767 with self.assertRaisesRegex(TypeError, 'must be valid identifiers'): 3768 make_dataclass('C', ['a', field]) 3769 with self.assertRaisesRegex(TypeError, 'must be valid identifiers'): 3770 make_dataclass('C', [field]) 3771 with self.assertRaisesRegex(TypeError, 'must be valid identifiers'): 3772 make_dataclass('C', [field, 'a']) 3773 3774 def test_underscore_field_names(self): 3775 # Unlike namedtuple, it's okay if dataclass field names have 3776 # an underscore. 3777 make_dataclass('C', ['_', '_a', 'a_a', 'a_']) 3778 3779 def test_funny_class_names_names(self): 3780 # No reason to prevent weird class names, since 3781 # types.new_class allows them. 3782 for classname in ['()', 'x,y', '*', '2@3', '']: 3783 with self.subTest(classname=classname): 3784 C = make_dataclass(classname, ['a', 'b']) 3785 self.assertEqual(C.__name__, classname) 3786 3787class TestReplace(unittest.TestCase): 3788 def test(self): 3789 @dataclass(frozen=True) 3790 class C: 3791 x: int 3792 y: int 3793 3794 c = C(1, 2) 3795 c1 = replace(c, x=3) 3796 self.assertEqual(c1.x, 3) 3797 self.assertEqual(c1.y, 2) 3798 3799 def test_frozen(self): 3800 @dataclass(frozen=True) 3801 class C: 3802 x: int 3803 y: int 3804 z: int = field(init=False, default=10) 3805 t: int = field(init=False, default=100) 3806 3807 c = C(1, 2) 3808 c1 = replace(c, x=3) 3809 self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100)) 3810 self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100)) 3811 3812 3813 with self.assertRaisesRegex(ValueError, 'init=False'): 3814 replace(c, x=3, z=20, t=50) 3815 with self.assertRaisesRegex(ValueError, 'init=False'): 3816 replace(c, z=20) 3817 replace(c, x=3, z=20, t=50) 3818 3819 # Make sure the result is still frozen. 3820 with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"): 3821 c1.x = 3 3822 3823 # Make sure we can't replace an attribute that doesn't exist, 3824 # if we're also replacing one that does exist. 
Test this 3825 # here, because setting attributes on frozen instances is 3826 # handled slightly differently from non-frozen ones. 3827 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected " 3828 "keyword argument 'a'"): 3829 c1 = replace(c, x=20, a=5) 3830 3831 def test_invalid_field_name(self): 3832 @dataclass(frozen=True) 3833 class C: 3834 x: int 3835 y: int 3836 3837 c = C(1, 2) 3838 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected " 3839 "keyword argument 'z'"): 3840 c1 = replace(c, z=3) 3841 3842 def test_invalid_object(self): 3843 @dataclass(frozen=True) 3844 class C: 3845 x: int 3846 y: int 3847 3848 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 3849 replace(C, x=3) 3850 3851 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 3852 replace(0, x=3) 3853 3854 def test_no_init(self): 3855 @dataclass 3856 class C: 3857 x: int 3858 y: int = field(init=False, default=10) 3859 3860 c = C(1) 3861 c.y = 20 3862 3863 # Make sure y gets the default value. 3864 c1 = replace(c, x=5) 3865 self.assertEqual((c1.x, c1.y), (5, 10)) 3866 3867 # Trying to replace y is an error. 3868 with self.assertRaisesRegex(ValueError, 'init=False'): 3869 replace(c, x=2, y=30) 3870 3871 with self.assertRaisesRegex(ValueError, 'init=False'): 3872 replace(c, y=30) 3873 3874 def test_classvar(self): 3875 @dataclass 3876 class C: 3877 x: int 3878 y: ClassVar[int] = 1000 3879 3880 c = C(1) 3881 d = C(2) 3882 3883 self.assertIs(c.y, d.y) 3884 self.assertEqual(c.y, 1000) 3885 3886 # Trying to replace y is an error: can't replace ClassVars. 
3887 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an " 3888 "unexpected keyword argument 'y'"): 3889 replace(c, y=30) 3890 3891 replace(c, x=5) 3892 3893 def test_initvar_is_specified(self): 3894 @dataclass 3895 class C: 3896 x: int 3897 y: InitVar[int] 3898 3899 def __post_init__(self, y): 3900 self.x *= y 3901 3902 c = C(1, 10) 3903 self.assertEqual(c.x, 10) 3904 with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be " 3905 "specified with replace()"): 3906 replace(c, x=3) 3907 c = replace(c, x=3, y=5) 3908 self.assertEqual(c.x, 15) 3909 3910 def test_initvar_with_default_value(self): 3911 @dataclass 3912 class C: 3913 x: int 3914 y: InitVar[int] = None 3915 z: InitVar[int] = 42 3916 3917 def __post_init__(self, y, z): 3918 if y is not None: 3919 self.x += y 3920 if z is not None: 3921 self.x += z 3922 3923 c = C(x=1, y=10, z=1) 3924 self.assertEqual(replace(c), C(x=12)) 3925 self.assertEqual(replace(c, y=4), C(x=12, y=4, z=42)) 3926 self.assertEqual(replace(c, y=4, z=1), C(x=12, y=4, z=1)) 3927 3928 def test_recursive_repr(self): 3929 @dataclass 3930 class C: 3931 f: "C" 3932 3933 c = C(None) 3934 c.f = c 3935 self.assertEqual(repr(c), "TestReplace.test_recursive_repr.<locals>.C(f=...)") 3936 3937 def test_recursive_repr_two_attrs(self): 3938 @dataclass 3939 class C: 3940 f: "C" 3941 g: "C" 3942 3943 c = C(None, None) 3944 c.f = c 3945 c.g = c 3946 self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs" 3947 ".<locals>.C(f=..., g=...)") 3948 3949 def test_recursive_repr_indirection(self): 3950 @dataclass 3951 class C: 3952 f: "D" 3953 3954 @dataclass 3955 class D: 3956 f: "C" 3957 3958 c = C(None) 3959 d = D(None) 3960 c.f = d 3961 d.f = c 3962 self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection" 3963 ".<locals>.C(f=TestReplace.test_recursive_repr_indirection" 3964 ".<locals>.D(f=...))") 3965 3966 def test_recursive_repr_indirection_two(self): 3967 @dataclass 3968 class C: 3969 f: "D" 3970 3971 @dataclass 3972 
class D: 3973 f: "E" 3974 3975 @dataclass 3976 class E: 3977 f: "C" 3978 3979 c = C(None) 3980 d = D(None) 3981 e = E(None) 3982 c.f = d 3983 d.f = e 3984 e.f = c 3985 self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection_two" 3986 ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two" 3987 ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two" 3988 ".<locals>.E(f=...)))") 3989 3990 def test_recursive_repr_misc_attrs(self): 3991 @dataclass 3992 class C: 3993 f: "C" 3994 g: int 3995 3996 c = C(None, 1) 3997 c.f = c 3998 self.assertEqual(repr(c), "TestReplace.test_recursive_repr_misc_attrs" 3999 ".<locals>.C(f=..., g=1)") 4000 4001 ## def test_initvar(self): 4002 ## @dataclass 4003 ## class C: 4004 ## x: int 4005 ## y: InitVar[int] 4006 4007 ## c = C(1, 10) 4008 ## d = C(2, 20) 4009 4010 ## # In our case, replacing an InitVar is a no-op 4011 ## self.assertEqual(c, replace(c, y=5)) 4012 4013 ## replace(c, x=5) 4014 4015class TestAbstract(unittest.TestCase): 4016 def test_abc_implementation(self): 4017 class Ordered(abc.ABC): 4018 @abc.abstractmethod 4019 def __lt__(self, other): 4020 pass 4021 4022 @abc.abstractmethod 4023 def __le__(self, other): 4024 pass 4025 4026 @dataclass(order=True) 4027 class Date(Ordered): 4028 year: int 4029 month: 'Month' 4030 day: 'int' 4031 4032 self.assertFalse(inspect.isabstract(Date)) 4033 self.assertGreater(Date(2020,12,25), Date(2020,8,31)) 4034 4035 def test_maintain_abc(self): 4036 class A(abc.ABC): 4037 @abc.abstractmethod 4038 def foo(self): 4039 pass 4040 4041 @dataclass 4042 class Date(A): 4043 year: int 4044 month: 'Month' 4045 day: 'int' 4046 4047 self.assertTrue(inspect.isabstract(Date)) 4048 msg = 'class Date with abstract method foo' 4049 self.assertRaisesRegex(TypeError, msg, Date) 4050 4051 4052class TestMatchArgs(unittest.TestCase): 4053 def test_match_args(self): 4054 @dataclass 4055 class C: 4056 a: int 4057 self.assertEqual(C(42).__match_args__, ('a',)) 4058 4059 def 
test_explicit_match_args(self): 4060 ma = () 4061 @dataclass 4062 class C: 4063 a: int 4064 __match_args__ = ma 4065 self.assertIs(C(42).__match_args__, ma) 4066 4067 def test_bpo_43764(self): 4068 @dataclass(repr=False, eq=False, init=False) 4069 class X: 4070 a: int 4071 b: int 4072 c: int 4073 self.assertEqual(X.__match_args__, ("a", "b", "c")) 4074 4075 def test_match_args_argument(self): 4076 @dataclass(match_args=False) 4077 class X: 4078 a: int 4079 self.assertNotIn('__match_args__', X.__dict__) 4080 4081 @dataclass(match_args=False) 4082 class Y: 4083 a: int 4084 __match_args__ = ('b',) 4085 self.assertEqual(Y.__match_args__, ('b',)) 4086 4087 @dataclass(match_args=False) 4088 class Z(Y): 4089 z: int 4090 self.assertEqual(Z.__match_args__, ('b',)) 4091 4092 # Ensure parent dataclass __match_args__ is seen, if child class 4093 # specifies match_args=False. 4094 @dataclass 4095 class A: 4096 a: int 4097 z: int 4098 @dataclass(match_args=False) 4099 class B(A): 4100 b: int 4101 self.assertEqual(B.__match_args__, ('a', 'z')) 4102 4103 def test_make_dataclasses(self): 4104 C = make_dataclass('C', [('x', int), ('y', int)]) 4105 self.assertEqual(C.__match_args__, ('x', 'y')) 4106 4107 C = make_dataclass('C', [('x', int), ('y', int)], match_args=True) 4108 self.assertEqual(C.__match_args__, ('x', 'y')) 4109 4110 C = make_dataclass('C', [('x', int), ('y', int)], match_args=False) 4111 self.assertNotIn('__match__args__', C.__dict__) 4112 4113 C = make_dataclass('C', [('x', int), ('y', int)], namespace={'__match_args__': ('z',)}) 4114 self.assertEqual(C.__match_args__, ('z',)) 4115 4116 4117class TestKeywordArgs(unittest.TestCase): 4118 def test_no_classvar_kwarg(self): 4119 msg = 'field a is a ClassVar but specifies kw_only' 4120 with self.assertRaisesRegex(TypeError, msg): 4121 @dataclass 4122 class A: 4123 a: ClassVar[int] = field(kw_only=True) 4124 4125 with self.assertRaisesRegex(TypeError, msg): 4126 @dataclass 4127 class A: 4128 a: ClassVar[int] = 
field(kw_only=False) 4129 4130 with self.assertRaisesRegex(TypeError, msg): 4131 @dataclass(kw_only=True) 4132 class A: 4133 a: ClassVar[int] = field(kw_only=False) 4134 4135 def test_field_marked_as_kwonly(self): 4136 ####################### 4137 # Using dataclass(kw_only=True) 4138 @dataclass(kw_only=True) 4139 class A: 4140 a: int 4141 self.assertTrue(fields(A)[0].kw_only) 4142 4143 @dataclass(kw_only=True) 4144 class A: 4145 a: int = field(kw_only=True) 4146 self.assertTrue(fields(A)[0].kw_only) 4147 4148 @dataclass(kw_only=True) 4149 class A: 4150 a: int = field(kw_only=False) 4151 self.assertFalse(fields(A)[0].kw_only) 4152 4153 ####################### 4154 # Using dataclass(kw_only=False) 4155 @dataclass(kw_only=False) 4156 class A: 4157 a: int 4158 self.assertFalse(fields(A)[0].kw_only) 4159 4160 @dataclass(kw_only=False) 4161 class A: 4162 a: int = field(kw_only=True) 4163 self.assertTrue(fields(A)[0].kw_only) 4164 4165 @dataclass(kw_only=False) 4166 class A: 4167 a: int = field(kw_only=False) 4168 self.assertFalse(fields(A)[0].kw_only) 4169 4170 ####################### 4171 # Not specifying dataclass(kw_only) 4172 @dataclass 4173 class A: 4174 a: int 4175 self.assertFalse(fields(A)[0].kw_only) 4176 4177 @dataclass 4178 class A: 4179 a: int = field(kw_only=True) 4180 self.assertTrue(fields(A)[0].kw_only) 4181 4182 @dataclass 4183 class A: 4184 a: int = field(kw_only=False) 4185 self.assertFalse(fields(A)[0].kw_only) 4186 4187 def test_match_args(self): 4188 # kw fields don't show up in __match_args__. 
4189 @dataclass(kw_only=True) 4190 class C: 4191 a: int 4192 self.assertEqual(C(a=42).__match_args__, ()) 4193 4194 @dataclass 4195 class C: 4196 a: int 4197 b: int = field(kw_only=True) 4198 self.assertEqual(C(42, b=10).__match_args__, ('a',)) 4199 4200 def test_KW_ONLY(self): 4201 @dataclass 4202 class A: 4203 a: int 4204 _: KW_ONLY 4205 b: int 4206 c: int 4207 A(3, c=5, b=4) 4208 msg = "takes 2 positional arguments but 4 were given" 4209 with self.assertRaisesRegex(TypeError, msg): 4210 A(3, 4, 5) 4211 4212 4213 @dataclass(kw_only=True) 4214 class B: 4215 a: int 4216 _: KW_ONLY 4217 b: int 4218 c: int 4219 B(a=3, b=4, c=5) 4220 msg = "takes 1 positional argument but 4 were given" 4221 with self.assertRaisesRegex(TypeError, msg): 4222 B(3, 4, 5) 4223 4224 # Explicitly make a field that follows KW_ONLY be non-keyword-only. 4225 @dataclass 4226 class C: 4227 a: int 4228 _: KW_ONLY 4229 b: int 4230 c: int = field(kw_only=False) 4231 c = C(1, 2, b=3) 4232 self.assertEqual(c.a, 1) 4233 self.assertEqual(c.b, 3) 4234 self.assertEqual(c.c, 2) 4235 c = C(1, b=3, c=2) 4236 self.assertEqual(c.a, 1) 4237 self.assertEqual(c.b, 3) 4238 self.assertEqual(c.c, 2) 4239 c = C(1, b=3, c=2) 4240 self.assertEqual(c.a, 1) 4241 self.assertEqual(c.b, 3) 4242 self.assertEqual(c.c, 2) 4243 c = C(c=2, b=3, a=1) 4244 self.assertEqual(c.a, 1) 4245 self.assertEqual(c.b, 3) 4246 self.assertEqual(c.c, 2) 4247 4248 def test_KW_ONLY_as_string(self): 4249 @dataclass 4250 class A: 4251 a: int 4252 _: 'dataclasses.KW_ONLY' 4253 b: int 4254 c: int 4255 A(3, c=5, b=4) 4256 msg = "takes 2 positional arguments but 4 were given" 4257 with self.assertRaisesRegex(TypeError, msg): 4258 A(3, 4, 5) 4259 4260 def test_KW_ONLY_twice(self): 4261 msg = "'Y' is KW_ONLY, but KW_ONLY has already been specified" 4262 4263 with self.assertRaisesRegex(TypeError, msg): 4264 @dataclass 4265 class A: 4266 a: int 4267 X: KW_ONLY 4268 Y: KW_ONLY 4269 b: int 4270 c: int 4271 4272 with self.assertRaisesRegex(TypeError, msg): 
4273 @dataclass 4274 class A: 4275 a: int 4276 X: KW_ONLY 4277 b: int 4278 Y: KW_ONLY 4279 c: int 4280 4281 with self.assertRaisesRegex(TypeError, msg): 4282 @dataclass 4283 class A: 4284 a: int 4285 X: KW_ONLY 4286 b: int 4287 c: int 4288 Y: KW_ONLY 4289 4290 # But this usage is okay, since it's not using KW_ONLY. 4291 @dataclass 4292 class A: 4293 a: int 4294 _: KW_ONLY 4295 b: int 4296 c: int = field(kw_only=True) 4297 4298 # And if inheriting, it's okay. 4299 @dataclass 4300 class A: 4301 a: int 4302 _: KW_ONLY 4303 b: int 4304 c: int 4305 @dataclass 4306 class B(A): 4307 _: KW_ONLY 4308 d: int 4309 4310 # Make sure the error is raised in a derived class. 4311 with self.assertRaisesRegex(TypeError, msg): 4312 @dataclass 4313 class A: 4314 a: int 4315 _: KW_ONLY 4316 b: int 4317 c: int 4318 @dataclass 4319 class B(A): 4320 X: KW_ONLY 4321 d: int 4322 Y: KW_ONLY 4323 4324 4325 def test_post_init(self): 4326 @dataclass 4327 class A: 4328 a: int 4329 _: KW_ONLY 4330 b: InitVar[int] 4331 c: int 4332 d: InitVar[int] 4333 def __post_init__(self, b, d): 4334 raise CustomError(f'{b=} {d=}') 4335 with self.assertRaisesRegex(CustomError, 'b=3 d=4'): 4336 A(1, c=2, b=3, d=4) 4337 4338 @dataclass 4339 class B: 4340 a: int 4341 _: KW_ONLY 4342 b: InitVar[int] 4343 c: int 4344 d: InitVar[int] 4345 def __post_init__(self, b, d): 4346 self.a = b 4347 self.c = d 4348 b = B(1, c=2, b=3, d=4) 4349 self.assertEqual(asdict(b), {'a': 3, 'c': 4}) 4350 4351 def test_defaults(self): 4352 # For kwargs, make sure we can have defaults after non-defaults. 4353 @dataclass 4354 class A: 4355 a: int = 0 4356 _: KW_ONLY 4357 b: int 4358 c: int = 1 4359 d: int 4360 4361 a = A(d=4, b=3) 4362 self.assertEqual(a.a, 0) 4363 self.assertEqual(a.b, 3) 4364 self.assertEqual(a.c, 1) 4365 self.assertEqual(a.d, 4) 4366 4367 # Make sure we still check for non-kwarg non-defaults not following 4368 # defaults. 
4369 err_regex = "non-default argument 'z' follows default argument" 4370 with self.assertRaisesRegex(TypeError, err_regex): 4371 @dataclass 4372 class A: 4373 a: int = 0 4374 z: int 4375 _: KW_ONLY 4376 b: int 4377 c: int = 1 4378 d: int 4379 4380 def test_make_dataclass(self): 4381 A = make_dataclass("A", ['a'], kw_only=True) 4382 self.assertTrue(fields(A)[0].kw_only) 4383 4384 B = make_dataclass("B", 4385 ['a', ('b', int, field(kw_only=False))], 4386 kw_only=True) 4387 self.assertTrue(fields(B)[0].kw_only) 4388 self.assertFalse(fields(B)[1].kw_only) 4389 4390 4391if __name__ == '__main__': 4392 unittest.main() 4393