1import builtins
2import contextlib
3import errno
4import functools
5from importlib import machinery, util, invalidate_caches
6import marshal
7import os
8import os.path
9from test.support import import_helper
10from test.support import os_helper
11import unittest
12import sys
13import tempfile
14import types
15
16
# Sample module names keyed by whether they appear in
# sys.builtin_module_names.  ``good_name`` is 'errno' when that module is
# built in; ``bad_name`` is 'importlib' when that module is not built in.
# Either attribute stays None when its condition does not hold.
BUILTINS = types.SimpleNamespace(
    good_name='errno' if 'errno' in sys.builtin_module_names else None,
    bad_name=('importlib' if 'importlib' not in sys.builtin_module_names
              else None),
)
24
# Location details for the extension module named by ``EXTENSIONS.name``.
# Every field except ``name`` stays None unless _extension_details() finds a
# matching file on sys.path.
EXTENSIONS = types.SimpleNamespace(
    path=None,
    ext=None,
    filename=None,
    file_path=None,
    name='_testcapi',
)


def _extension_details():
    """Fill in EXTENSIONS with the first matching extension file on sys.path.

    Each sys.path entry is tried with every suffix from
    machinery.EXTENSION_SUFFIXES, in order.  The search stops at the first
    candidate file that exists; if none exists EXTENSIONS is left untouched.
    """
    candidates = (
        (directory, suffix)
        for directory in sys.path
        for suffix in machinery.EXTENSION_SUFFIXES
    )
    for directory, suffix in candidates:
        basename = EXTENSIONS.name + suffix
        full_path = os.path.join(directory, basename)
        if not os.path.exists(full_path):
            continue
        EXTENSIONS.path = directory
        EXTENSIONS.ext = suffix
        EXTENSIONS.filename = basename
        EXTENSIONS.file_path = full_path
        return


_extension_details()
46
47
def import_importlib(module_name):
    """Import a module from importlib both w/ and w/o _frozen_importlib.

    Returns a dict with two entries: 'Frozen' is a fresh import of
    *module_name*; 'Source' is a fresh import made with '_frozen_importlib'
    and '_frozen_importlib_external' blocked.  For dotted names 'importlib'
    itself is also passed as a module to import fresh.
    """
    blocked_names = ('_frozen_importlib', '_frozen_importlib_external')
    also_fresh = ('importlib',) if '.' in module_name else ()
    frozen_module = import_helper.import_fresh_module(module_name)
    source_module = import_helper.import_fresh_module(
        module_name, fresh=also_fresh, blocked=blocked_names)
    return dict(Frozen=frozen_module, Source=source_module)
55
56
def specialize_class(cls, kind, base=None, **kwargs):
    """Create a subclass of *cls* specialized for *kind*.

    The new class is named '<kind>_<cls name>' and has the bases
    ``(cls, base)``.  *base* defaults to unittest.TestCase; when it is not a
    class it is indexed with *kind* to pick the base.  Each keyword argument
    is likewise indexed with *kind* and the result stored as a class
    attribute of the same name.  The original class name and the kind are
    recorded as ``_NAME`` and ``_KIND``.
    """
    # XXX Support passing in submodule names--load (and cache) them?
    # That would clean up the test modules a bit more.
    if base is None:
        parent = unittest.TestCase
    elif isinstance(base, type):
        parent = base
    else:
        parent = base[kind]
    specialized = types.new_class(
        '{}_{}'.format(kind, cls.__name__), (cls, parent))
    # Report the specialized class as living where the template class does.
    specialized.__module__ = cls.__module__
    specialized._NAME = cls.__name__
    specialized._KIND = kind
    for attr, per_kind in kwargs.items():
        setattr(specialized, attr, per_kind[kind])
    return specialized
74
75
def split_frozen(cls, base=None, **kwargs):
    """Return the ('Frozen', 'Source') specializations of *cls* as a tuple.

    All arguments are forwarded unchanged to specialize_class(), once per
    kind, with the 'Frozen' class built first.
    """
    return tuple(specialize_class(cls, kind, base, **kwargs)
                 for kind in ('Frozen', 'Source'))
80
81
def test_both(test_class, base=None, **kwargs):
    """Alias for split_frozen(); returns the same (frozen, source) pair."""
    frozen_and_source = split_frozen(test_class, base, **kwargs)
    return frozen_and_source
84
85
# Windows is the only OS that is *always* case-insensitive
# (OS X *can* be case-sensitive).
if sys.platform in ('win32', 'cygwin'):
    CASE_INSENSITIVE_FS = True
else:
    # Probe the file system: look this file up under a differently-cased
    # spelling of its own path and see whether it is still found.
    changed_name = __file__.upper()
    if changed_name == __file__:
        changed_name = __file__.lower()
    CASE_INSENSITIVE_FS = os.path.exists(changed_name)
95
# The 'Source' flavour of importlib, i.e. the copy imported with the frozen
# bootstrap modules blocked (see import_importlib()).
source_importlib = import_importlib('importlib')['Source']
# __import__ implementation for each flavour, keyed like the mapping that
# import_importlib() returns.  Presumably wrapped in staticmethod so the
# values can be stored as class attributes without binding -- TODO confirm.
__import__ = dict(
    Frozen=staticmethod(builtins.__import__),
    Source=staticmethod(source_importlib.__import__),
)
99
100
def case_insensitive_tests(test):
    """Class decorator that skips tests requiring a case-insensitive
    file system when CASE_INSENSITIVE_FS is false."""
    requires_insensitive_fs = unittest.skipUnless(
        CASE_INSENSITIVE_FS, "requires a case-insensitive filesystem")
    return requires_insensitive_fs(test)
106
107
def submodule(parent, name, pkg_dir, content=''):
    """Write *content* to '<pkg_dir>/<name>.py' as UTF-8.

    Returns a 2-tuple of the dotted name '<parent>.<name>' and the path of
    the file that was written.
    """
    target = os.path.join(pkg_dir, name + '.py')
    with open(target, 'w', encoding='utf-8') as stream:
        stream.write(content)
    qualified_name = '{}.{}'.format(parent, name)
    return qualified_name, target
113
114
def get_code_from_pyc(pyc_path):
    """Reads a pyc file and returns the unmarshalled code object within.

    The first 16 bytes of the file are skipped and the remainder is handed
    to marshal.  No header validation is performed.
    """
    header_size = 16
    with open(pyc_path, 'rb') as stream:
        # Discard the header; its contents are never inspected.
        stream.read(header_size)
        code = marshal.load(stream)
    return code
123
124
@contextlib.contextmanager
def uncache(*names):
    """Uncache a module from sys.modules.

    Every name in *names* is removed from sys.modules on entry and again on
    exit; names that are not cached are ignored.

    A basic sanity check is performed to prevent uncaching modules that either
    cannot/shouldn't be uncached: ValueError is raised on entry if any name
    is 'sys', 'marshal' or 'imp'.  The check covers all names before anything
    is removed, so a rejected call leaves sys.modules untouched.

    """
    # Validate everything up front; checking inside the removal loop would
    # already have dropped the names preceding the offending one.
    for name in names:
        if name in ('sys', 'marshal', 'imp'):
            raise ValueError(
                "cannot uncache {0}".format(name))
    for name in names:
        sys.modules.pop(name, None)
    try:
        yield
    finally:
        # Drop whatever was (re)imported under these names inside the block.
        for name in names:
            sys.modules.pop(name, None)
149
150
@contextlib.contextmanager
def temp_module(name, content='', *, pkg=False):
    """Create module or package *name* in a temporary working directory.

    While the context is active the current directory is a temporary one
    (os_helper.temp_cwd) that is also handed to import_helper.DirsOnSysPath,
    and *name* plus every cached module whose top-level name is *name* are
    uncached.  The yielded value is the path of the module inside that
    directory, without any '.py' suffix.

    With ``pkg=False`` the file '<name>.py' is always written (a *content* of
    None is treated as '').  With ``pkg=True`` the directory is created and
    '__init__.py' is written only when *content* is not None.
    """
    conflicts = [cached for cached in sys.modules
                 if cached.partition('.')[0] == name]
    with os_helper.temp_cwd(None) as cwd, \
            uncache(name, *conflicts), \
            import_helper.DirsOnSysPath(cwd):
        invalidate_caches()

        location = os.path.join(cwd, name)
        if pkg:
            # Relative path is fine: the temporary directory is the cwd.
            os.mkdir(name)
            modpath = os.path.join(location, '__init__.py')
        else:
            modpath = location + '.py'
            # Make sure the module file gets created.
            content = '' if content is None else content
        if content is not None:
            # not a namespace package
            with open(modpath, 'w', encoding='utf-8') as modfile:
                modfile.write(content)
        yield location
173
174
@contextlib.contextmanager
def import_state(**kwargs):
    """Context manager to manage the various importers and stored state in the
    sys module.

    Accepted keywords are 'meta_path', 'path', 'path_hooks' and
    'path_importer_cache'.  Each of those sys attributes is replaced by the
    supplied value, or by an empty list/dict when not supplied, and restored
    on exit.  Any other keyword raises ValueError (after the attributes that
    were already swapped have been restored).

    The 'modules' attribute is not supported as the interpreter state stores a
    pointer to the dict that the interpreter uses internally;
    reassigning to sys.modules does not have the desired effect.

    """
    managed = (
        ('meta_path', []),
        ('path', []),
        ('path_hooks', []),
        ('path_importer_cache', {}),
    )
    originals = {}
    try:
        for attr, fallback in managed:
            # Remember the current value before overwriting it.
            originals[attr] = getattr(sys, attr)
            setattr(sys, attr, kwargs.pop(attr, fallback))
        # Anything still left in kwargs was not a managed attribute.
        if kwargs:
            raise ValueError(
                    'unrecognized arguments: {0}'.format(kwargs.keys()))
        yield
    finally:
        for attr, saved in originals.items():
            setattr(sys, attr, saved)
204
205
class _ImporterMock:

    """Base class to help with creating importer mocks.

    Each positional *name* becomes a pre-built module object stored in
    ``self.modules``.  A name ending in '.__init__' denotes a package: the
    suffix is stripped from the import name and the module gets a mock
    ``__path__``.  *module_code* optionally maps import names to callables;
    entries for the created modules are copied into ``self.module_code``.

    Used as a context manager, the mock uncaches all of its module names
    from sys.modules on entry and on exit.
    """

    def __init__(self, *names, module_code=None):
        # None stands in for "no code" so that no mutable default object is
        # shared between calls.
        if module_code is None:
            module_code = {}
        self.modules = {}
        self.module_code = {}
        for name in names:
            if not name.endswith('.__init__'):
                import_name = name
            else:
                import_name = name[:-len('.__init__')]
            if '.' not in name:
                # Top-level module.
                package = None
            elif import_name == name:
                # Submodule: its package is everything before the last dot.
                package = name.rsplit('.', 1)[0]
            else:
                # Package __init__: the package is the module itself.
                package = import_name
            module = types.ModuleType(import_name)
            module.__loader__ = self
            module.__file__ = '<mock __file__>'
            module.__package__ = package
            # Record the name exactly as it was passed in.
            module.attr = name
            if import_name != name:
                module.__path__ = ['<mock __path__>']
            self.modules[import_name] = module
            if import_name in module_code:
                self.module_code[import_name] = module_code[import_name]

    def __getitem__(self, name):
        """Return the mock module registered under import name *name*."""
        return self.modules[name]

    def __enter__(self):
        self._uncache = uncache(*self.modules.keys())
        self._uncache.__enter__()
        return self

    def __exit__(self, *exc_info):
        # The exception details are deliberately not forwarded; uncache only
        # needs to run its cleanup.
        self._uncache.__exit__(None, None, None)
245
246
class mock_modules(_ImporterMock):

    """Importer mock using PEP 302 APIs."""

    def find_module(self, fullname, path=None):
        """Return this mock if it knows *fullname*, otherwise None."""
        return self if fullname in self.modules else None

    def load_module(self, fullname):
        """Install the mock module *fullname* in sys.modules and return it.

        Raises ImportError for an unknown name.  If code is registered for
        the module it is run after the module is installed; should it raise,
        the module is removed from sys.modules again and the exception
        propagates.
        """
        if fullname not in self.modules:
            raise ImportError
        sys.modules[fullname] = self.modules[fullname]
        if fullname in self.module_code:
            try:
                self.module_code[fullname]()
            except Exception:
                # Do not leave a half-initialized module in the cache.
                del sys.modules[fullname]
                raise
        return self.modules[fullname]
269
270
class mock_spec(_ImporterMock):

    """Importer mock using PEP 451 APIs."""

    def find_spec(self, fullname, path=None, parent=None):
        """Return a spec for the mock module *fullname*, or None if unknown.

        The spec names this mock as loader and takes the module's
        ``__path__`` (when present) as its submodule search locations.
        """
        try:
            module = self.modules[fullname]
        except KeyError:
            return None
        spec = util.spec_from_file_location(
                fullname, module.__file__, loader=self,
                submodule_search_locations=getattr(module, '__path__', None))
        return spec

    def create_module(self, spec):
        """Return the pre-built module for *spec*; ImportError if unknown."""
        if spec.name not in self.modules:
            raise ImportError
        return self.modules[spec.name]

    def exec_module(self, module):
        """Run the code registered for *module*, if there is any.

        Exceptions raised by the registered code always propagate.
        """
        # Only the lookup is guarded: a KeyError raised by the module code
        # itself must not be mistaken for "no code registered".
        try:
            code = self.module_code[module.__spec__.name]
        except KeyError:
            return
        code()
295
296
def writes_bytecode_files(fxn):
    """Decorator to protect sys.dont_write_bytecode from mutation and to skip
    tests that require it to be set to False.

    If sys.dont_write_bytecode is true at decoration time the test is marked
    as skipped.  Otherwise the returned wrapper sets the flag to False for
    the duration of each call and then puts back the value it found.
    """
    if sys.dont_write_bytecode:
        return unittest.skip("relies on writing bytecode")(fxn)

    @functools.wraps(fxn)
    def wrapper(*args, **kwargs):
        saved_flag = sys.dont_write_bytecode
        sys.dont_write_bytecode = False
        try:
            return fxn(*args, **kwargs)
        finally:
            # Runs after the return value is computed, even on error.
            sys.dont_write_bytecode = saved_flag
    return wrapper
312
313
def ensure_bytecode_path(bytecode_path):
    """Ensure that the __pycache__ directory for PEP 3147 pyc file exists.

    Only the immediate parent directory of *bytecode_path* is created; an
    already existing path is not an error, any other OSError propagates.

    :param bytecode_path: File system path to PEP 3147 pyc file.
    """
    cache_dir = os.path.dirname(bytecode_path)
    try:
        os.mkdir(cache_dir)
    except FileExistsError:
        # EEXIST: something is already there, nothing to do.
        pass
324
325
@contextlib.contextmanager
def temporary_pycache_prefix(prefix):
    """Adjust and restore sys.pycache_prefix.

    sys.pycache_prefix is set to *prefix* for the duration of the ``with``
    block and reset to its previous value afterwards, even on error.
    """
    previous = sys.pycache_prefix
    sys.pycache_prefix = prefix
    try:
        yield
    finally:
        sys.pycache_prefix = previous
335
336
@contextlib.contextmanager
def create_modules(*names):
    """Temporarily create each named module with an attribute (named 'attr')
    that contains the name passed into the context manager that caused the
    creation of the module.

    All files are created in a temporary directory returned by
    tempfile.mkdtemp(). While the context is active the import state is
    swapped with import_state(path=[<that directory>]), so sys.path holds
    only the temporary directory, and the created modules are uncached. When
    the context manager exits the whole directory tree (source and bytecode)
    is deleted.

    The yielded mapping goes from each name to the path of its source file,
    plus the key '.root' for the temporary directory itself.

    No magic is performed when creating packages! This means that if you create
    a module within a package you must also create the package's __init__ as
    well.

    """
    source = 'attr = {0!r}'
    init_suffix = '.__init__'
    # Created outside the try block: if mkdtemp() itself fails there is
    # nothing to clean up and its error must not be masked by the cleanup.
    temp_dir = tempfile.mkdtemp()
    try:
        mapping = {'.root': temp_dir}
        import_names = set()
        for name in names:
            # 'pkg.__init__' is imported as 'pkg'; test for the dotted suffix
            # so only the text that was matched is stripped.
            if name.endswith(init_suffix):
                import_name = name[:-len(init_suffix)]
            else:
                import_name = name
            import_names.add(import_name)
            sys.modules.pop(import_name, None)
            *packages, leaf = name.split('.')
            file_path = temp_dir
            for directory in packages:
                file_path = os.path.join(file_path, directory)
                if not os.path.exists(file_path):
                    os.mkdir(file_path)
            file_path = os.path.join(file_path, leaf + '.py')
            with open(file_path, 'w', encoding='utf-8') as file:
                file.write(source.format(name))
            mapping[name] = file_path
        with uncache(*import_names), import_state(path=[temp_dir]):
            yield mapping
    finally:
        os_helper.rmtree(temp_dir)
393
394
def mock_path_hook(*entries, importer):
    """A mock sys.path_hooks entry.

    The returned hook hands back *importer* for any path entry listed in
    *entries* and raises ImportError for every other entry.
    """
    def hook(entry):
        if entry in entries:
            return importer
        raise ImportError
    return hook
402
403
class CASEOKTestBase:

    def caseok_env_changed(self, *, should_exist):
        """Skip the test if PYTHONCASEOK's presence is not as expected.

        The variable is looked up, as both bytes and str, in the ``environ``
        of the ``_os`` module used by ``self.importlib._bootstrap_external``.
        If its presence differs from *should_exist* the test is skipped.
        """
        environ = self.importlib._bootstrap_external._os.environ
        present = (b'PYTHONCASEOK' in environ) or ('PYTHONCASEOK' in environ)
        if present != should_exist:
            self.skipTest('os.environ changes not reflected in _os.environ')
411