1"""Tests for distutils.util."""
2import os
3import sys
4import unittest
5import sysconfig as stdlib_sysconfig
6from copy import copy
7from test.support import run_unittest
8from unittest import mock
9
10from distutils.errors import DistutilsPlatformError, DistutilsByteCompileError
11from distutils.util import (get_platform, convert_path, change_root,
12                            check_environ, split_quoted, strtobool,
13                            rfc822_escape, byte_compile,
14                            grok_environment_error, get_host_platform)
15from distutils import util # used to patch _environ_checked
16from distutils import sysconfig
17from distutils.tests import support
18
class UtilTestCase(support.EnvironGuard, unittest.TestCase):
    """Tests for the helpers in distutils.util.

    Many tests monkeypatch attributes of ``os``, ``os.path`` and ``sys`` as
    well as distutils module state.  setUp() snapshots every value a test may
    replace and tearDown() puts it back, so tests stay independent.
    """

    def setUp(self):
        super().setUp()
        # saving the environment
        self.name = os.name
        self.platform = sys.platform
        self.version = sys.version
        self.sep = os.sep
        self.join = os.path.join
        self.isabs = os.path.isabs
        self.splitdrive = os.path.splitdrive
        self._config_vars = copy(sysconfig._config_vars)
        # The check_environ tests reset this module-level flag to 0 and expect
        # check_environ() to set it to 1; remember the real value so it can be
        # restored instead of leaking into later tests.
        self._environ_checked = util._environ_checked

        # patching os.uname
        if hasattr(os, 'uname'):
            self.uname = os.uname
            self._uname = os.uname()
        else:
            self.uname = None
            self._uname = None

        # Route os.uname() through _get_uname() so a test can control the
        # value it returns via _set_uname().
        os.uname = self._get_uname

    def tearDown(self):
        # getting back the environment
        os.name = self.name
        sys.platform = self.platform
        sys.version = self.version
        os.sep = self.sep
        os.path.join = self.join
        os.path.isabs = self.isabs
        os.path.splitdrive = self.splitdrive
        if self.uname is not None:
            os.uname = self.uname
        else:
            # os.uname did not exist before setUp() installed the stub.
            del os.uname
        sysconfig._config_vars = copy(self._config_vars)
        util._environ_checked = self._environ_checked
        super().tearDown()

    def _set_uname(self, uname):
        """Set the value the patched os.uname() will return."""
        self._uname = uname

    def _get_uname(self):
        """Stand-in for os.uname() installed by setUp()."""
        return self._uname

    def test_get_host_platform(self):
        # On 'nt' the result is expected to follow the architecture tag
        # embedded in sys.version.
        with mock.patch('os.name', 'nt'):
            with mock.patch('sys.version', '... [... (ARM64)]'):
                self.assertEqual(get_host_platform(), 'win-arm64')
            with mock.patch('sys.version', '... [... (ARM)]'):
                self.assertEqual(get_host_platform(), 'win-arm32')

        # With version_info patched to 3.9.0 the answer must agree with the
        # stdlib sysconfig module.
        with mock.patch('sys.version_info', (3, 9, 0, 'final', 0)):
            self.assertEqual(get_host_platform(),
                             stdlib_sysconfig.get_platform())

    def test_get_platform(self):
        # On 'nt', get_platform() is expected to map VSCMD_ARG_TGT_ARCH to a
        # distutils platform name.
        expected_by_arch = {
            'x86': 'win32',
            'x64': 'win-amd64',
            'arm': 'win-arm32',
            'arm64': 'win-arm64',
        }
        with mock.patch('os.name', 'nt'):
            for arch, expected in expected_by_arch.items():
                with self.subTest(arch=arch), \
                        mock.patch.dict('os.environ',
                                        {'VSCMD_ARG_TGT_ARCH': arch}):
                    self.assertEqual(get_platform(), expected)

    def test_convert_path(self):
        # linux/mac
        os.sep = '/'
        # Stub with the same variadic signature as os.path.join.
        def _join(*path):
            return '/'.join(path)
        os.path.join = _join

        # With a '/' separator the path is expected to come back unchanged.
        self.assertEqual(convert_path('/home/to/my/stuff'),
                         '/home/to/my/stuff')

        # win
        os.sep = '\\'
        def _join(*path):
            return '\\'.join(path)
        os.path.join = _join

        # Absolute paths and trailing slashes are rejected.
        self.assertRaises(ValueError, convert_path, '/home/to/my/stuff')
        self.assertRaises(ValueError, convert_path, 'home/to/my/stuff/')

        self.assertEqual(convert_path('home/to/my/stuff'),
                         'home\\to\\my\\stuff')
        self.assertEqual(convert_path('.'),
                         os.curdir)

    def test_change_root(self):
        # linux/mac
        os.name = 'posix'
        def _isabs(path):
            return path[0] == '/'
        os.path.isabs = _isabs
        def _join(*path):
            return '/'.join(path)
        os.path.join = _join

        self.assertEqual(change_root('/root', '/old/its/here'),
                         '/root/old/its/here')
        self.assertEqual(change_root('/root', 'its/here'),
                         '/root/its/here')

        # windows
        os.name = 'nt'
        def _isabs(path):
            return path.startswith('c:\\')
        os.path.isabs = _isabs
        def _splitdrive(path):
            if path.startswith('c:'):
                return ('', path.replace('c:', ''))
            return ('', path)
        os.path.splitdrive = _splitdrive
        def _join(*path):
            return '\\'.join(path)
        os.path.join = _join

        self.assertEqual(change_root('c:\\root', 'c:\\old\\its\\here'),
                         'c:\\root\\old\\its\\here')
        self.assertEqual(change_root('c:\\root', 'its\\here'),
                         'c:\\root\\its\\here')

        # BugsBunny os (it's a great os)
        os.name = 'BugsBunny'
        self.assertRaises(DistutilsPlatformError,
                          change_root, 'c:\\root', 'its\\here')

        # XXX platforms to be covered: mac

    def test_check_environ(self):
        util._environ_checked = 0
        os.environ.pop('HOME', None)

        check_environ()

        self.assertEqual(os.environ['PLAT'], get_platform())
        self.assertEqual(util._environ_checked, 1)

    @unittest.skipUnless(os.name == 'posix', 'specific to posix')
    def test_check_environ_getpwuid(self):
        util._environ_checked = 0
        os.environ.pop('HOME', None)

        import pwd

        # only set pw_dir field, other fields are not used
        result = pwd.struct_passwd((None, None, None, None, None,
                                    '/home/distutils', None))
        with mock.patch.object(pwd, 'getpwuid', return_value=result):
            check_environ()
            self.assertEqual(os.environ['HOME'], '/home/distutils')

        util._environ_checked = 0
        os.environ.pop('HOME', None)

        # bpo-10496: Catch pwd.getpwuid() error
        with mock.patch.object(pwd, 'getpwuid', side_effect=KeyError):
            check_environ()
            self.assertNotIn('HOME', os.environ)

    def test_split_quoted(self):
        self.assertEqual(split_quoted('""one"" "two" \'three\' \\four'),
                         ['one', 'two', 'three', 'four'])

    def test_strtobool(self):
        # Mixed-case spellings are included because matching is expected to
        # be case-insensitive.
        yes = ('y', 'Y', 'yes', 'True', 't', 'true', 'TRUE', 'On', 'on', '1')
        no = ('n', 'no', 'f', 'false', 'off', '0', 'Off', 'No', 'N')

        for y in yes:
            with self.subTest(value=y):
                self.assertTrue(strtobool(y))

        for n in no:
            with self.subTest(value=n):
                self.assertFalse(strtobool(n))

        # Anything outside the recognised spellings must be rejected.
        self.assertRaises(ValueError, strtobool, 'maybe')

    def test_rfc822_escape(self):
        header = 'I am a\npoor\nlonesome\nheader\n'
        res = rfc822_escape(header)
        wanted = ('I am a%(8s)spoor%(8s)slonesome%(8s)s'
                  'header%(8s)s') % {'8s': '\n'+8*' '}
        self.assertEqual(res, wanted)

    def test_dont_write_bytecode(self):
        # makes sure byte_compile raise a DistutilsError
        # if sys.dont_write_bytecode is True
        with mock.patch.object(sys, 'dont_write_bytecode', True):
            self.assertRaises(DistutilsByteCompileError, byte_compile, [])

    def test_grok_environment_error(self):
        # test obsolete function to ensure backward compat (#4931)
        # (IOError is an alias of OSError.)
        exc = OSError("Unable to find batch file")
        msg = grok_environment_error(exc)
        self.assertEqual(msg, "error: Unable to find batch file")
218
219
def test_suite():
    """Return a TestSuite containing every test method of UtilTestCase."""
    loader = unittest.TestLoader()
    suite = loader.loadTestsFromTestCase(UtilTestCase)
    return suite
222
if __name__ == "__main__":
    # Executed directly: run this module's suite via test.support's runner.
    suite = test_suite()
    run_unittest(suite)
225