"""Test equality and order comparisons."""
import unittest
from test.support import ALWAYS_EQ
from fractions import Fraction
from decimal import Decimal


class ComparisonSimpleTest(unittest.TestCase):
    """Test equality and order comparisons for some simple cases."""

    class Empty:
        """Class with no comparison methods of its own."""

        def __repr__(self):
            return '<Empty>'

    class Cmp:
        """Wrapper whose __eq__ compares the wrapped value to the other operand."""

        def __init__(self, arg):
            self.arg = arg

        def __repr__(self):
            return '<Cmp %s>' % self.arg

        def __eq__(self, other):
            return self.arg == other

    # Every member of set1 is expected to compare equal to every other one;
    # members of set2 are expected to equal only themselves.
    set1 = [2, 2.0, 2, 2+0j, Cmp(2.0)]
    set2 = [[1], (3,), None, Empty()]
    candidates = set1 + set2

    def test_comparisons(self):
        """Check pairwise equality across the set1/set2 candidates."""
        for a in self.candidates:
            for b in self.candidates:
                if ((a in self.set1) and (b in self.set1)) or a is b:
                    self.assertEqual(a, b)
                else:
                    self.assertNotEqual(a, b)

    def test_id_comparisons(self):
        # Ensure default comparison compares id() of args
        objs = []
        for _ in range(10):
            # Insert into the middle so list order does not simply follow
            # creation order.
            objs.insert(len(objs)//2, self.Empty())
        for a in objs:
            for b in objs:
                self.assertEqual(a == b, a is b, 'a=%r, b=%r' % (a, b))

    def test_ne_defaults_to_not_eq(self):
        """Cmp defines only __eq__; != must be its negation."""
        a = self.Cmp(1)
        b = self.Cmp(1)
        c = self.Cmp(2)
        self.assertIs(a == b, True)
        self.assertIs(a != b, False)
        self.assertIs(a != c, True)

    def test_ne_high_priority(self):
        """object.__ne__() should allow reflected __ne__() to be tried"""
        calls = []
        class Left:
            # Inherits object.__ne__()
            def __eq__(*args):
                calls.append('Left.__eq__')
                return NotImplemented
        class Right:
            def __eq__(*args):
                calls.append('Right.__eq__')
                return NotImplemented
            def __ne__(*args):
                calls.append('Right.__ne__')
                return NotImplemented
        Left() != Right()
        self.assertSequenceEqual(calls, ['Left.__eq__', 'Right.__ne__'])

    def test_ne_low_priority(self):
        """object.__ne__() should not invoke reflected __eq__()"""
        calls = []
        class Base:
            # Inherits object.__ne__()
            def __eq__(*args):
                calls.append('Base.__eq__')
                return NotImplemented
        class Derived(Base):  # Subclassing forces higher priority
            def __eq__(*args):
                calls.append('Derived.__eq__')
                return NotImplemented
            def __ne__(*args):
                calls.append('Derived.__ne__')
                return NotImplemented
        Base() != Derived()
        self.assertSequenceEqual(calls, ['Derived.__ne__', 'Base.__eq__'])

    def test_other_delegation(self):
        """No default delegation between operations except __ne__()"""
        ops = (
            ('__eq__', lambda a, b: a == b),
            ('__lt__', lambda a, b: a < b),
            ('__le__', lambda a, b: a <= b),
            ('__gt__', lambda a, b: a > b),
            ('__ge__', lambda a, b: a >= b),
        )
        for name, func in ops:
            with self.subTest(name):
                def unexpected(*args):
                    self.fail('Unexpected operator method called')
                # Every comparison method except `name` fails the test if
                # it is called.
                class C:
                    __ne__ = unexpected
                for other, _ in ops:
                    if other != name:
                        setattr(C, other, unexpected)
                if name == '__eq__':
                    self.assertIs(func(C(), object()), False)
                else:
                    self.assertRaises(TypeError, func, C(), object())

    def test_issue_1393(self):
        """A lambda and a bare object() compare equal to ALWAYS_EQ both ways."""
        x = lambda: None
        self.assertEqual(x, ALWAYS_EQ)
        self.assertEqual(ALWAYS_EQ, x)
        y = object()
        self.assertEqual(y, ALWAYS_EQ)
        self.assertEqual(ALWAYS_EQ, y)


class ComparisonFullTest(unittest.TestCase):
    """Test equality and ordering comparisons for built-in types and
    user-defined classes that implement relevant combinations of rich
    comparison methods.
    """

    class CompBase:
        """Base class for classes with rich comparison methods.

        The "x" attribute should be set to an underlying value to compare.

        Derived classes have a "meth" tuple attribute listing names of
        comparison methods implemented. See assert_total_order().
        """

    # Class without any rich comparison methods.
    class CompNone(CompBase):
        meth = ()

    # Classes with all combinations of value-based equality comparison methods.
    class CompEq(CompBase):
        meth = ("eq",)
        def __eq__(self, other):
            return self.x == other.x

    class CompNe(CompBase):
        meth = ("ne",)
        def __ne__(self, other):
            return self.x != other.x

    class CompEqNe(CompBase):
        meth = ("eq", "ne")
        def __eq__(self, other):
            return self.x == other.x
        def __ne__(self, other):
            return self.x != other.x

    # Classes with all combinations of value-based less/greater-than order
    # comparison methods.
    class CompLt(CompBase):
        meth = ("lt",)
        def __lt__(self, other):
            return self.x < other.x

    class CompGt(CompBase):
        meth = ("gt",)
        def __gt__(self, other):
            return self.x > other.x

    class CompLtGt(CompBase):
        meth = ("lt", "gt")
        def __lt__(self, other):
            return self.x < other.x
        def __gt__(self, other):
            return self.x > other.x

    # Classes with all combinations of value-based less/greater-or-equal-than
    # order comparison methods
    class CompLe(CompBase):
        meth = ("le",)
        def __le__(self, other):
            return self.x <= other.x

    class CompGe(CompBase):
        meth = ("ge",)
        def __ge__(self, other):
            return self.x >= other.x

    class CompLeGe(CompBase):
        meth = ("le", "ge")
        def __le__(self, other):
            return self.x <= other.x
        def __ge__(self, other):
            return self.x >= other.x

    # It should be sufficient to combine the comparison methods only within
    # each group.
    all_comp_classes = (
            CompNone,
            CompEq, CompNe, CompEqNe,  # equal group
            CompLt, CompGt, CompLtGt,  # less/greater-than group
            CompLe, CompGe, CompLeGe)  # less/greater-or-equal group

    def create_sorted_instances(self, class_, values):
        """Create objects of type `class_` and return them in a list.

        `values` is a sequence of values that determines the value of data
        attribute `x` of each object.

        Objects in the returned list are sorted by their identity.  They
        are assigned values in `values` list order.  By assigning decreasing
        values to objects with increasing identities, testcases can assert
        that order comparison is performed by value and not by identity.
        """

        instances = [class_() for _ in values]
        instances.sort(key=id)
        # Assign the provided values to the instances.
        for inst, value in zip(instances, values):
            inst.x = value
        return instances

    def assert_equality_only(self, a, b, equal):
        """Assert equality result and that ordering is not implemented.

        a, b: Instances to be tested (of same or different type).
        equal: Boolean indicating the expected equality comparison results.
        """
        self.assertEqual(a == b, equal)
        self.assertEqual(b == a, equal)
        self.assertEqual(a != b, not equal)
        self.assertEqual(b != a, not equal)
        with self.assertRaisesRegex(TypeError, "not supported"):
            a < b
        with self.assertRaisesRegex(TypeError, "not supported"):
            a <= b
        with self.assertRaisesRegex(TypeError, "not supported"):
            a > b
        with self.assertRaisesRegex(TypeError, "not supported"):
            a >= b
        with self.assertRaisesRegex(TypeError, "not supported"):
            b < a
        with self.assertRaisesRegex(TypeError, "not supported"):
            b <= a
        with self.assertRaisesRegex(TypeError, "not supported"):
            b > a
        with self.assertRaisesRegex(TypeError, "not supported"):
            b >= a

    def assert_total_order(self, a, b, comp, a_meth=None, b_meth=None):
        """Test total ordering comparison of two instances.

        a, b: Instances to be tested (of same or different type).

        comp: -1, 0, or 1 indicates that the expected order comparison
           result for operations that are supported by the classes is
           a <, ==, or > b.

        a_meth, b_meth: Either None, indicating that all rich comparison
           methods are available, as for builtins, or the tuple (subset)
           of "eq", "ne", "lt", "le", "gt", and "ge" that are available
           for the corresponding instance (of a user-defined class).
        """
        self.assert_eq_subtest(a, b, comp, a_meth, b_meth)
        self.assert_ne_subtest(a, b, comp, a_meth, b_meth)
        self.assert_lt_subtest(a, b, comp, a_meth, b_meth)
        self.assert_le_subtest(a, b, comp, a_meth, b_meth)
        self.assert_gt_subtest(a, b, comp, a_meth, b_meth)
        self.assert_ge_subtest(a, b, comp, a_meth, b_meth)

    # The body of each subtest has form:
    #
    #     if value-based comparison methods:
    #         expect what the testcase defined for a op b and b rop a;
    #     else: no value-based comparison
    #         expect default behavior of object for a op b and b rop a.
    #
    # Note: a_meth and b_meth are expected to be both None or both tuples.

    def assert_eq_subtest(self, a, b, comp, a_meth, b_meth):
        """Check ``a == b`` and ``b == a``."""
        if a_meth is None or "eq" in a_meth or "eq" in b_meth:
            self.assertEqual(a == b, comp == 0)
            self.assertEqual(b == a, comp == 0)
        else:
            self.assertEqual(a == b, a is b)
            self.assertEqual(b == a, a is b)

    def assert_ne_subtest(self, a, b, comp, a_meth, b_meth):
        """Check ``a != b`` and ``b != a``."""
        # __ne__ falls back to a negated __eq__, so either method counts.
        if a_meth is None or not {"ne", "eq"}.isdisjoint(a_meth + b_meth):
            self.assertEqual(a != b, comp != 0)
            self.assertEqual(b != a, comp != 0)
        else:
            self.assertEqual(a != b, a is not b)
            self.assertEqual(b != a, a is not b)

    def assert_lt_subtest(self, a, b, comp, a_meth, b_meth):
        """Check ``a < b`` and its reflection ``b > a``."""
        if a_meth is None or "lt" in a_meth or "gt" in b_meth:
            self.assertEqual(a < b, comp < 0)
            self.assertEqual(b > a, comp < 0)
        else:
            with self.assertRaisesRegex(TypeError, "not supported"):
                a < b
            with self.assertRaisesRegex(TypeError, "not supported"):
                b > a

    def assert_le_subtest(self, a, b, comp, a_meth, b_meth):
        """Check ``a <= b`` and its reflection ``b >= a``."""
        if a_meth is None or "le" in a_meth or "ge" in b_meth:
            self.assertEqual(a <= b, comp <= 0)
            self.assertEqual(b >= a, comp <= 0)
        else:
            with self.assertRaisesRegex(TypeError, "not supported"):
                a <= b
            with self.assertRaisesRegex(TypeError, "not supported"):
                b >= a

    def assert_gt_subtest(self, a, b, comp, a_meth, b_meth):
        """Check ``a > b`` and its reflection ``b < a``."""
        if a_meth is None or "gt" in a_meth or "lt" in b_meth:
            self.assertEqual(a > b, comp > 0)
            self.assertEqual(b < a, comp > 0)
        else:
            with self.assertRaisesRegex(TypeError, "not supported"):
                a > b
            with self.assertRaisesRegex(TypeError, "not supported"):
                b < a

    def assert_ge_subtest(self, a, b, comp, a_meth, b_meth):
        """Check ``a >= b`` and its reflection ``b <= a``."""
        if a_meth is None or "ge" in a_meth or "le" in b_meth:
            self.assertEqual(a >= b, comp >= 0)
            self.assertEqual(b <= a, comp >= 0)
        else:
            with self.assertRaisesRegex(TypeError, "not supported"):
                a >= b
            with self.assertRaisesRegex(TypeError, "not supported"):
                b <= a

    def test_objects(self):
        """Compare instances of type 'object'."""
        a = object()
        b = object()
        self.assert_equality_only(a, a, True)
        self.assert_equality_only(a, b, False)

    def test_comp_classes_same(self):
        """Compare same-class instances with comparison methods."""

        for cls in self.all_comp_classes:
            with self.subTest(cls):
                instances = self.create_sorted_instances(cls, (1, 2, 1))

                # Same object.
                self.assert_total_order(instances[0], instances[0], 0,
                                        cls.meth, cls.meth)

                # Different objects, same value.
                self.assert_total_order(instances[0], instances[2], 0,
                                        cls.meth, cls.meth)

                # Different objects, value ascending for ascending identities.
                self.assert_total_order(instances[0], instances[1], -1,
                                        cls.meth, cls.meth)

                # different objects, value descending for ascending identities.
                # This is the interesting case to assert that order comparison
                # is performed based on the value and not based on the identity.
                self.assert_total_order(instances[1], instances[2], +1,
                                        cls.meth, cls.meth)

    def test_comp_classes_different(self):
        """Compare different-class instances with comparison methods."""

        for cls_a in self.all_comp_classes:
            for cls_b in self.all_comp_classes:
                with self.subTest(a=cls_a, b=cls_b):
                    a1 = cls_a()
                    a1.x = 1
                    b1 = cls_b()
                    b1.x = 1
                    b2 = cls_b()
                    b2.x = 2

                    self.assert_total_order(
                        a1, b1, 0, cls_a.meth, cls_b.meth)
                    self.assert_total_order(
                        a1, b2, -1, cls_a.meth, cls_b.meth)

    def test_str_subclass(self):
        """Compare instances of str and a subclass."""
        class StrSubclass(str):
            pass

        s1 = str("a")
        s2 = str("b")
        c1 = StrSubclass("a")
        c2 = StrSubclass("b")
        c3 = StrSubclass("b")

        self.assert_total_order(s1, s1, 0)
        self.assert_total_order(s1, s2, -1)
        self.assert_total_order(c1, c1, 0)
        self.assert_total_order(c1, c2, -1)
        self.assert_total_order(c2, c3, 0)

        self.assert_total_order(s1, c2, -1)
        self.assert_total_order(s2, c3, 0)
        self.assert_total_order(c1, s2, -1)
        self.assert_total_order(c2, s2, 0)

    def test_numbers(self):
        """Compare number types."""

        # Same types.
        i1 = 1001
        i2 = 1002
        self.assert_total_order(i1, i1, 0)
        self.assert_total_order(i1, i2, -1)

        f1 = 1001.0
        f2 = 1001.1
        self.assert_total_order(f1, f1, 0)
        self.assert_total_order(f1, f2, -1)

        q1 = Fraction(2002, 2)
        q2 = Fraction(2003, 2)
        self.assert_total_order(q1, q1, 0)
        self.assert_total_order(q1, q2, -1)

        d1 = Decimal('1001.0')
        d2 = Decimal('1001.1')
        self.assert_total_order(d1, d1, 0)
        self.assert_total_order(d1, d2, -1)

        c1 = 1001+0j
        c2 = 1001+1j
        self.assert_equality_only(c1, c1, True)
        self.assert_equality_only(c1, c2, False)

        # Mixing types.
        for n1, n2 in ((i1,f1), (i1,q1), (i1,d1), (f1,q1), (f1,d1), (q1,d1)):
            self.assert_total_order(n1, n2, 0)
        for n1 in (i1, f1, q1, d1):
            self.assert_equality_only(n1, c1, True)

    def test_sequences(self):
        """Compare list, tuple, and range."""
        l1 = [1, 2]
        l2 = [2, 3]
        self.assert_total_order(l1, l1, 0)
        self.assert_total_order(l1, l2, -1)

        t1 = (1, 2)
        t2 = (2, 3)
        self.assert_total_order(t1, t1, 0)
        self.assert_total_order(t1, t2, -1)

        r1 = range(1, 2)
        r2 = range(2, 2)
        self.assert_equality_only(r1, r1, True)
        self.assert_equality_only(r1, r2, False)

        self.assert_equality_only(t1, l1, False)
        self.assert_equality_only(l1, r1, False)
        self.assert_equality_only(r1, t1, False)

    def test_bytes(self):
        """Compare bytes and bytearray."""
        bs1 = b'a1'
        bs2 = b'b2'
        self.assert_total_order(bs1, bs1, 0)
        self.assert_total_order(bs1, bs2, -1)

        ba1 = bytearray(b'a1')
        ba2 = bytearray(b'b2')
        self.assert_total_order(ba1, ba1, 0)
        self.assert_total_order(ba1, ba2, -1)

        self.assert_total_order(bs1, ba1, 0)
        self.assert_total_order(bs1, ba2, -1)
        self.assert_total_order(ba1, bs1, 0)
        self.assert_total_order(ba1, bs2, -1)

    def test_sets(self):
        """Compare set and frozenset."""
        s1 = {1, 2}
        s2 = {1, 2, 3}
        self.assert_total_order(s1, s1, 0)
        self.assert_total_order(s1, s2, -1)

        f1 = frozenset(s1)
        f2 = frozenset(s2)
        self.assert_total_order(f1, f1, 0)
        self.assert_total_order(f1, f2, -1)

        self.assert_total_order(s1, f1, 0)
        self.assert_total_order(s1, f2, -1)
        self.assert_total_order(f1, s1, 0)
        self.assert_total_order(f1, s2, -1)

    def test_mappings(self):
        """Compare dict."""
        d1 = {1: "a", 2: "b"}
        d2 = {2: "b", 3: "c"}
        d3 = {3: "c", 2: "b"}
        self.assert_equality_only(d1, d1, True)
        self.assert_equality_only(d1, d2, False)
        self.assert_equality_only(d2, d3, True)


if __name__ == '__main__':
    unittest.main()