1"""Test equality and order comparisons."""
2import unittest
3from test.support import ALWAYS_EQ
4from fractions import Fraction
5from decimal import Decimal
6
7
class ComparisonSimpleTest(unittest.TestCase):
    """Check == / != and default identity equality on a few simple objects."""

    class Empty:
        # Defines no comparison methods, so == falls back to identity.
        def __repr__(self):
            return '<Empty>'

    class Cmp:
        # Wraps a value; compares equal to whatever the wrapped value
        # compares equal to.
        def __init__(self, arg):
            self.arg = arg

        def __repr__(self):
            return '<Cmp %s>' % self.arg

        def __eq__(self, other):
            return self.arg == other

    # Every member of set1 is expected to equal every other member of set1.
    set1 = [2, 2.0, 2, 2+0j, Cmp(2.0)]
    # Members of set2 are expected to equal only themselves.
    set2 = [[1], (3,), None, Empty()]
    candidates = set1 + set2

    @staticmethod
    def _recording_method(log, label):
        """Return a comparison method that appends `label` to `log` and
        then returns NotImplemented, so the interpreter keeps looking for
        another method to try."""
        def method(*args):
            log.append(label)
            return NotImplemented
        return method

    def test_comparisons(self):
        group = self.set1
        for left in self.candidates:
            for right in self.candidates:
                in_same_group = (left in group) and (right in group)
                if in_same_group or left is right:
                    self.assertEqual(left, right)
                else:
                    self.assertNotEqual(left, right)

    def test_id_comparisons(self):
        # Without __eq__, equality must be decided purely by object identity.
        objs = []
        for _ in range(10):
            objs.insert(len(objs) // 2, self.Empty())
        for left in objs:
            for right in objs:
                self.assertEqual(left == right, left is right,
                                 'a=%r, b=%r' % (left, right))

    def test_ne_defaults_to_not_eq(self):
        # Cmp defines only __eq__; != must come from inverting it.
        first, twin, other = self.Cmp(1), self.Cmp(1), self.Cmp(2)
        self.assertIs(first == twin, True)
        self.assertIs(first != twin, False)
        self.assertIs(first != other, True)

    def test_ne_high_priority(self):
        """object.__ne__() must still let the reflected __ne__() run"""
        calls = []
        record = self._recording_method

        class Left:
            # Inherits object.__ne__()
            __eq__ = record(calls, 'Left.__eq__')

        class Right:
            __eq__ = record(calls, 'Right.__eq__')
            __ne__ = record(calls, 'Right.__ne__')

        Left() != Right()  # only the sequence of method calls matters
        self.assertSequenceEqual(calls, ['Left.__eq__', 'Right.__ne__'])

    def test_ne_low_priority(self):
        """object.__ne__() must not fall through to the reflected __eq__()"""
        calls = []
        record = self._recording_method

        class Base:
            # Inherits object.__ne__()
            __eq__ = record(calls, 'Base.__eq__')

        class Derived(Base):  # being a subclass gives it priority
            __eq__ = record(calls, 'Derived.__eq__')
            __ne__ = record(calls, 'Derived.__ne__')

        Base() != Derived()  # only the sequence of method calls matters
        self.assertSequenceEqual(calls, ['Derived.__ne__', 'Base.__eq__'])

    def test_other_delegation(self):
        """Apart from __ne__(), no operator falls back to another by default"""
        operators = {
            '__eq__': lambda a, b: a == b,
            '__lt__': lambda a, b: a < b,
            '__le__': lambda a, b: a <= b,
            '__gt__': lambda a, b: a > b,
            '__ge__': lambda a, b: a >= b,
        }
        for name, func in operators.items():
            with self.subTest(name):
                def unexpected(*args):
                    self.fail('Unexpected operator method called')

                # Every operator except the one under test is booby-trapped.
                class C:
                    __ne__ = unexpected
                for other in operators:
                    if other != name:
                        setattr(C, other, unexpected)

                if name != '__eq__':
                    self.assertRaises(TypeError, func, C(), object())
                else:
                    self.assertIs(func(C(), object()), False)

    def test_issue_1393(self):
        # A reflected __eq__ that says "equal" must win in both directions.
        for obj in (lambda: None, object()):
            self.assertEqual(obj, ALWAYS_EQ)
            self.assertEqual(ALWAYS_EQ, obj)
120
121
class ComparisonFullTest(unittest.TestCase):
    """Test equality and ordering comparisons for built-in types and
    user-defined classes that implement relevant combinations of rich
    comparison methods.
    """

    class CompBase:
        """Base class for classes with rich comparison methods.

        The "x" attribute should be set to an underlying value to compare.

        Derived classes have a "meth" tuple attribute listing names of
        comparison methods implemented. See assert_total_order().
        """

    # Class without any rich comparison methods.
    class CompNone(CompBase):
        meth = ()

    # Classes with all combinations of value-based equality comparison methods.
    class CompEq(CompBase):
        meth = ("eq",)
        def __eq__(self, other):
            return self.x == other.x

    class CompNe(CompBase):
        meth = ("ne",)
        def __ne__(self, other):
            return self.x != other.x

    class CompEqNe(CompBase):
        meth = ("eq", "ne")
        def __eq__(self, other):
            return self.x == other.x
        def __ne__(self, other):
            return self.x != other.x

    # Classes with all combinations of value-based less/greater-than order
    # comparison methods.
    class CompLt(CompBase):
        meth = ("lt",)
        def __lt__(self, other):
            return self.x < other.x

    class CompGt(CompBase):
        meth = ("gt",)
        def __gt__(self, other):
            return self.x > other.x

    class CompLtGt(CompBase):
        meth = ("lt", "gt")
        def __lt__(self, other):
            return self.x < other.x
        def __gt__(self, other):
            return self.x > other.x

    # Classes with all combinations of value-based less/greater-or-equal-than
    # order comparison methods.
    class CompLe(CompBase):
        meth = ("le",)
        def __le__(self, other):
            return self.x <= other.x

    class CompGe(CompBase):
        meth = ("ge",)
        def __ge__(self, other):
            return self.x >= other.x

    class CompLeGe(CompBase):
        meth = ("le", "ge")
        def __le__(self, other):
            return self.x <= other.x
        def __ge__(self, other):
            return self.x >= other.x

    # It should be sufficient to combine the comparison methods only within
    # each group.
    all_comp_classes = (
        CompNone,
        CompEq, CompNe, CompEqNe,  # equal group
        CompLt, CompGt, CompLtGt,  # less/greater-than group
        CompLe, CompGe, CompLeGe,  # less/greater-or-equal group
    )

    def create_sorted_instances(self, class_, values):
        """Create objects of type `class_` and return them in a list.

        `values` is an iterable of values that determines the value of data
        attribute `x` of each object.

        Objects in the returned list are sorted by their identity (``id()``)
        and are then assigned the values in `values` order.  By assigning
        decreasing values to objects with increasing identities, test cases
        can assert that order comparison is performed by value and not by
        identity.
        """
        # Materialise once so one-shot iterators work with both passes below.
        values = tuple(values)
        instances = [class_() for _ in values]
        instances.sort(key=id)
        # Assign the provided values to the instances.
        for inst, value in zip(instances, values):
            inst.x = value
        return instances

    def assert_equality_only(self, a, b, equal):
        """Assert equality result and that ordering is not implemented.

        a, b: Instances to be tested (of same or different type).
        equal: Boolean indicating the expected equality comparison results.
        """
        self.assertEqual(a == b, equal)
        self.assertEqual(b == a, equal)
        self.assertEqual(a != b, not equal)
        self.assertEqual(b != a, not equal)
        with self.assertRaisesRegex(TypeError, "not supported"):
            a < b
        with self.assertRaisesRegex(TypeError, "not supported"):
            a <= b
        with self.assertRaisesRegex(TypeError, "not supported"):
            a > b
        with self.assertRaisesRegex(TypeError, "not supported"):
            a >= b
        with self.assertRaisesRegex(TypeError, "not supported"):
            b < a
        with self.assertRaisesRegex(TypeError, "not supported"):
            b <= a
        with self.assertRaisesRegex(TypeError, "not supported"):
            b > a
        with self.assertRaisesRegex(TypeError, "not supported"):
            b >= a

    def assert_total_order(self, a, b, comp, a_meth=None, b_meth=None):
        """Test total ordering comparison of two instances.

        a, b: Instances to be tested (of same or different type).

        comp: -1, 0, or 1 indicates that the expected order comparison
           result for operations that are supported by the classes is
           a <, ==, or > b.

        a_meth, b_meth: Either None, indicating that all rich comparison
           methods are available, as for builtins, or the tuple (subset)
           of "eq", "ne", "lt", "le", "gt", and "ge" that are available
           for the corresponding instance (of a user-defined class).
           Pass None for both or a tuple for both: the subtests only check
           a_meth for None before testing membership in b_meth.
        """
        self.assert_eq_subtest(a, b, comp, a_meth, b_meth)
        self.assert_ne_subtest(a, b, comp, a_meth, b_meth)
        self.assert_lt_subtest(a, b, comp, a_meth, b_meth)
        self.assert_le_subtest(a, b, comp, a_meth, b_meth)
        self.assert_gt_subtest(a, b, comp, a_meth, b_meth)
        self.assert_ge_subtest(a, b, comp, a_meth, b_meth)

    # The body of each subtest has form:
    #
    #     if value-based comparison methods:
    #         expect what the testcase defined for a op b and b rop a;
    #     else:  no value-based comparison
    #         expect default behavior of object for a op b and b rop a.
    #
    # "a op b" uses a's method or, failing that, b's reflected method
    # (e.g. a < b may be answered by b.__gt__(a)), hence the pairs of names
    # checked in each condition below.

    def assert_eq_subtest(self, a, b, comp, a_meth, b_meth):
        if a_meth is None or "eq" in a_meth or "eq" in b_meth:
            self.assertEqual(a == b, comp == 0)
            self.assertEqual(b == a, comp == 0)
        else:
            self.assertEqual(a == b, a is b)
            self.assertEqual(b == a, a is b)

    def assert_ne_subtest(self, a, b, comp, a_meth, b_meth):
        # object.__ne__() inverts __eq__(), so a class that defines only
        # __eq__ still gets value-based != (see ComparisonSimpleTest).
        if a_meth is None or not {"ne", "eq"}.isdisjoint(a_meth + b_meth):
            self.assertEqual(a != b, comp != 0)
            self.assertEqual(b != a, comp != 0)
        else:
            self.assertEqual(a != b, a is not b)
            self.assertEqual(b != a, a is not b)

    def assert_lt_subtest(self, a, b, comp, a_meth, b_meth):
        if a_meth is None or "lt" in a_meth or "gt" in b_meth:
            self.assertEqual(a < b, comp < 0)
            self.assertEqual(b > a, comp < 0)
        else:
            with self.assertRaisesRegex(TypeError, "not supported"):
                a < b
            with self.assertRaisesRegex(TypeError, "not supported"):
                b > a

    def assert_le_subtest(self, a, b, comp, a_meth, b_meth):
        if a_meth is None or "le" in a_meth or "ge" in b_meth:
            self.assertEqual(a <= b, comp <= 0)
            self.assertEqual(b >= a, comp <= 0)
        else:
            with self.assertRaisesRegex(TypeError, "not supported"):
                a <= b
            with self.assertRaisesRegex(TypeError, "not supported"):
                b >= a

    def assert_gt_subtest(self, a, b, comp, a_meth, b_meth):
        if a_meth is None or "gt" in a_meth or "lt" in b_meth:
            self.assertEqual(a > b, comp > 0)
            self.assertEqual(b < a, comp > 0)
        else:
            with self.assertRaisesRegex(TypeError, "not supported"):
                a > b
            with self.assertRaisesRegex(TypeError, "not supported"):
                b < a

    def assert_ge_subtest(self, a, b, comp, a_meth, b_meth):
        if a_meth is None or "ge" in a_meth or "le" in b_meth:
            self.assertEqual(a >= b, comp >= 0)
            self.assertEqual(b <= a, comp >= 0)
        else:
            with self.assertRaisesRegex(TypeError, "not supported"):
                a >= b
            with self.assertRaisesRegex(TypeError, "not supported"):
                b <= a

    def test_objects(self):
        """Compare instances of type 'object'."""
        a = object()
        b = object()
        self.assert_equality_only(a, a, True)
        self.assert_equality_only(a, b, False)

    def test_comp_classes_same(self):
        """Compare same-class instances with comparison methods."""

        for cls in self.all_comp_classes:
            with self.subTest(cls):
                instances = self.create_sorted_instances(cls, (1, 2, 1))

                # Same object.
                self.assert_total_order(instances[0], instances[0], 0,
                                        cls.meth, cls.meth)

                # Different objects, same value.
                self.assert_total_order(instances[0], instances[2], 0,
                                        cls.meth, cls.meth)

                # Different objects, value ascending for ascending identities.
                self.assert_total_order(instances[0], instances[1], -1,
                                        cls.meth, cls.meth)

                # Different objects, value descending for ascending identities.
                # This is the interesting case to assert that order comparison
                # is performed based on the value and not based on the identity.
                self.assert_total_order(instances[1], instances[2], +1,
                                        cls.meth, cls.meth)

    def test_comp_classes_different(self):
        """Compare different-class instances with comparison methods."""

        for cls_a in self.all_comp_classes:
            for cls_b in self.all_comp_classes:
                with self.subTest(a=cls_a, b=cls_b):
                    a1 = cls_a()
                    a1.x = 1
                    b1 = cls_b()
                    b1.x = 1
                    b2 = cls_b()
                    b2.x = 2

                    self.assert_total_order(
                        a1, b1, 0, cls_a.meth, cls_b.meth)
                    self.assert_total_order(
                        a1, b2, -1, cls_a.meth, cls_b.meth)

    def test_str_subclass(self):
        """Compare instances of str and a subclass."""
        class StrSubclass(str):
            pass

        s1 = str("a")
        s2 = str("b")
        c1 = StrSubclass("a")
        c2 = StrSubclass("b")
        c3 = StrSubclass("b")

        self.assert_total_order(s1, s1,   0)
        self.assert_total_order(s1, s2, -1)
        self.assert_total_order(c1, c1,   0)
        self.assert_total_order(c1, c2, -1)
        self.assert_total_order(c2, c3,   0)

        self.assert_total_order(s1, c2, -1)
        self.assert_total_order(s2, c3,   0)
        self.assert_total_order(c1, s2, -1)
        self.assert_total_order(c2, s2,   0)

    def test_numbers(self):
        """Compare number types."""

        # Same types.
        i1 = 1001
        i2 = 1002
        self.assert_total_order(i1, i1, 0)
        self.assert_total_order(i1, i2, -1)

        f1 = 1001.0
        f2 = 1001.1
        self.assert_total_order(f1, f1, 0)
        self.assert_total_order(f1, f2, -1)

        q1 = Fraction(2002, 2)
        q2 = Fraction(2003, 2)
        self.assert_total_order(q1, q1, 0)
        self.assert_total_order(q1, q2, -1)

        d1 = Decimal('1001.0')
        d2 = Decimal('1001.1')
        self.assert_total_order(d1, d1, 0)
        self.assert_total_order(d1, d2, -1)

        # Complex numbers support equality but no ordering.
        c1 = 1001+0j
        c2 = 1001+1j
        self.assert_equality_only(c1, c1, True)
        self.assert_equality_only(c1, c2, False)

        # Mixing types.
        for n1, n2 in ((i1,f1), (i1,q1), (i1,d1), (f1,q1), (f1,d1), (q1,d1)):
            self.assert_total_order(n1, n2, 0)
        for n1 in (i1, f1, q1, d1):
            self.assert_equality_only(n1, c1, True)

    def test_sequences(self):
        """Compare list, tuple, and range."""
        l1 = [1, 2]
        l2 = [2, 3]
        self.assert_total_order(l1, l1, 0)
        self.assert_total_order(l1, l2, -1)

        t1 = (1, 2)
        t2 = (2, 3)
        self.assert_total_order(t1, t1, 0)
        self.assert_total_order(t1, t2, -1)

        r1 = range(1, 2)
        r2 = range(2, 2)
        self.assert_equality_only(r1, r1, True)
        self.assert_equality_only(r1, r2, False)

        self.assert_equality_only(t1, l1, False)
        self.assert_equality_only(l1, r1, False)
        self.assert_equality_only(r1, t1, False)

    def test_bytes(self):
        """Compare bytes and bytearray."""
        bs1 = b'a1'
        bs2 = b'b2'
        self.assert_total_order(bs1, bs1, 0)
        self.assert_total_order(bs1, bs2, -1)

        ba1 = bytearray(b'a1')
        ba2 = bytearray(b'b2')
        self.assert_total_order(ba1, ba1,  0)
        self.assert_total_order(ba1, ba2, -1)

        self.assert_total_order(bs1, ba1, 0)
        self.assert_total_order(bs1, ba2, -1)
        self.assert_total_order(ba1, bs1, 0)
        self.assert_total_order(ba1, bs2, -1)

    def test_sets(self):
        """Compare set and frozenset."""
        s1 = {1, 2}
        s2 = {1, 2, 3}
        self.assert_total_order(s1, s1, 0)
        self.assert_total_order(s1, s2, -1)

        f1 = frozenset(s1)
        f2 = frozenset(s2)
        self.assert_total_order(f1, f1,  0)
        self.assert_total_order(f1, f2, -1)

        self.assert_total_order(s1, f1, 0)
        self.assert_total_order(s1, f2, -1)
        self.assert_total_order(f1, s1, 0)
        self.assert_total_order(f1, s2, -1)

    def test_mappings(self):
        """Compare dict."""
        d1 = {1: "a", 2: "b"}
        d2 = {2: "b", 3: "c"}
        d3 = {3: "c", 2: "b"}
        self.assert_equality_only(d1, d1, True)
        self.assert_equality_only(d1, d2, False)
        self.assert_equality_only(d2, d3, True)
507
508
if __name__ == '__main__':
    # Run this module's test cases when it is executed as a script.
    unittest.main()
511