# UserString is a wrapper around the built-in str type.
# UserString instances should behave similarly to built-in string objects.
3
4import unittest
5from test import string_tests
6
7from collections import UserString
8
class UserStringTest(
    string_tests.CommonTest,
    string_tests.MixinStrUnicodeUserStringTest,
    unittest.TestCase
    ):
    """Run the shared string test suite against ``collections.UserString``."""

    type2test = UserString

    # The shared suite's checkequal/checkraises/checkcall helpers are
    # overridden here. Only the receiver (and, for checkequal, the expected
    # result) is passed through fixtype(). The call arguments stay plain,
    # because UserString can't cope with arguments that were themselves
    # converted to UserString. Subclasses of UserString are not tested.

    def checkequal(self, result, object, methodname, *args, **kwargs):
        """Assert ``fixtype(object).<methodname>(*args, **kwargs)`` equals
        ``fixtype(result)``.

        The call arguments are deliberately left unconverted.
        """
        expected = self.fixtype(result)
        receiver = self.fixtype(object)
        actual = getattr(receiver, methodname)(*args, **kwargs)
        self.assertEqual(expected, actual)

    def checkraises(self, exc, obj, methodname, *args, expected_msg=None):
        """Assert ``fixtype(obj).<methodname>(*args)`` raises *exc*.

        The raised exception's message must be non-empty. When
        *expected_msg* is given, the message must equal it exactly.
        The call arguments are deliberately left unconverted.
        """
        receiver = self.fixtype(obj)
        # The attribute lookup stays inside the context manager so that an
        # exception raised by getattr() itself is also checked against *exc*.
        with self.assertRaises(exc) as caught:
            getattr(receiver, methodname)(*args)
        message = str(caught.exception)
        self.assertNotEqual(message, '')
        if expected_msg is not None:
            self.assertEqual(message, expected_msg)

    def checkcall(self, object, methodname, *args):
        """Call ``fixtype(object).<methodname>(*args)`` and discard the result.

        The call arguments are deliberately left unconverted. Any exception
        propagates to the caller.
        """
        receiver = self.fixtype(object)
        getattr(receiver, methodname)(*args)

    def test_rmod(self):
        # `%` must honour an overridden __rmod__ defined on a subclass.
        class Template(UserString):
            pass

        class Argument(Template):
            # Argument subclasses the left operand's type and defines its own
            # __rmod__, so Python tries this reflected method before
            # Template.__mod__.
            def __rmod__(self, other):
                return super().__rmod__(other)

        template = Template('value is %s')
        argument = Argument('TEST')
        self.assertEqual(template % argument, 'value is TEST')

    def _check_encode_defaults(self, *args):
        # encode(*args) must behave like encode() with no arguments.
        self.checkequal(b'hello', 'hello', 'encode', *args)
        # The encoding defaults to utf-8 (a non-BMP character becomes 4 bytes).
        self.checkequal(b'\xf0\xa3\x91\x96', '\U00023456', 'encode', *args)
        # Errors default to 'strict', so a lone surrogate cannot be encoded.
        self.checkraises(UnicodeError, '\ud800', 'encode', *args)

    def test_encode_default_args(self):
        self._check_encode_defaults()

    def test_encode_explicit_none_args(self):
        # Passing None explicitly must select the same defaults.
        self._check_encode_defaults(None, None)
69
70
if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()
73