1import contextlib
2import _imp
3import importlib
4import importlib.util
5import os
6import shutil
7import sys
8import unittest
9import warnings
10
11from .os_helper import unlink
12
13
@contextlib.contextmanager
def _ignore_deprecated_imports(ignore=True):
    """Silence module/package DeprecationWarnings raised inside the block.

    While active, any DeprecationWarning whose message matches the regex
    ".+ (module|package)" is ignored.  The warning filters are saved on entry
    and restored on exit.

    When *ignore* is false the manager is a no-op.
    """
    if not ignore:
        yield
        return
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", ".+ (module|package)",
                                DeprecationWarning)
        yield
28
29
def unload(name):
    """Drop *name* from sys.modules; do nothing if it is not there."""
    sys.modules.pop(name, None)
35
36
def forget(modname):
    """'Forget' a module was ever imported.

    The module is unloaded from sys.modules.  Then, for every directory on
    sys.path, the legacy ``<modname>.pyc`` and the PEP 3147/488 cached
    bytecode files (optimization levels '', 1 and 2) are unlinked.  Missing
    files are not an error: every candidate path is passed to unlink()
    whether or not it exists.
    """
    unload(modname)
    for path_entry in sys.path:
        py_path = os.path.join(path_entry, modname + '.py')
        # Legacy location: the .pyc sits right next to the .py.
        unlink(py_path + 'c')
        # PEP 3147/488 locations under __pycache__, one per optimization level.
        for level in ('', 1, 2):
            cached = importlib.util.cache_from_source(py_path,
                                                      optimization=level)
            unlink(cached)
51
52
def make_legacy_pyc(source):
    """Move a PEP 3147/488 pyc file to its legacy pyc location.

    The legacy location is ``<module>.pyc`` in the same directory as the
    source file ``<module>.py``.

    :param source: The file system path to the source file.  The source file
        does not need to exist, however the PEP 3147/488 pyc file must exist.
        A relative path is resolved against the current working directory.
    :return: The absolute file system path to the legacy pyc file.
    """
    pyc_file = importlib.util.cache_from_source(source)
    up_one = os.path.dirname(os.path.abspath(source))
    # Join only the basename.  Joining the whole relative *source* onto its
    # own absolute directory would repeat its directory components, so
    # "pkg/mod.py" would map to "<cwd>/pkg/pkg/mod.pyc".
    legacy_pyc = os.path.join(up_one, os.path.basename(source) + 'c')
    shutil.move(pyc_file, legacy_pyc)
    return legacy_pyc
65
66
def import_module(name, deprecated=False, *, required_on=()):
    """Import *name* for a test, skipping the test if it is unavailable.

    When *deprecated* is true, module/package deprecation warnings emitted
    during the import are suppressed.

    If the import raises ImportError, unittest.SkipTest is raised with the
    error text.  The exception is when sys.platform starts with one of the
    prefixes in *required_on*: the module is considered mandatory there, so
    the original ImportError is re-raised instead.  Note that *required_on*
    is converted with tuple(), so a bare string would be treated as a
    sequence of one-character prefixes.
    """
    with _ignore_deprecated_imports(deprecated):
        try:
            module = importlib.import_module(name)
        except ImportError as exc:
            mandatory_here = sys.platform.startswith(tuple(required_on))
            if mandatory_here:
                raise
            raise unittest.SkipTest(str(exc))
    return module
83
84
def _save_and_remove_modules(names):
    """Pop every module in *names*, and every submodule of one, from sys.modules.

    A key is removed when it equals one of *names* or starts with
    ``"<name>."``.  Returns a dict mapping each removed key to the object it
    held, in sys.modules iteration order, so the caller can restore them.
    """
    dotted = tuple(name + '.' for name in names)
    # Snapshot the keys first so sys.modules is not mutated while iterating.
    doomed = [key for key in list(sys.modules)
              if key in names or key.startswith(dotted)]
    return {key: sys.modules.pop(key) for key in doomed}
92
93
94@contextlib.contextmanager
95def frozen_modules(enabled=True):
96    """Force frozen modules to be used (or not).
97
98    This only applies to modules that haven't been imported yet.
99    Also, some essential modules will always be imported frozen.
100    """
101    _imp._override_frozen_modules_for_tests(1 if enabled else -1)
102    try:
103        yield
104    finally:
105        _imp._override_frozen_modules_for_tests(0)
106
107
def import_fresh_module(name, fresh=(), blocked=(), *,
                        deprecated=False,
                        usefrozen=False,
                        ):
    """Import and return a module, deliberately bypassing sys.modules.

    This function imports and returns a fresh copy of the named Python module
    by removing the named module from sys.modules before doing the import.
    Note that unlike reload, the original module is not affected by
    this operation.

    *fresh* is an iterable of additional module names that are also removed
    from the sys.modules cache before doing the import. If one of these
    modules can't be imported, None is returned.

    *blocked* is an iterable of module names that are replaced with None
    in the module cache during the import to ensure that attempts to import
    them raise ImportError.

    Submodules are handled together with their parents: every sys.modules
    key that starts with "<n>." for any *n* in *name*, *fresh* or *blocked*
    is removed and restored along with it.

    The named module and any modules named in the *fresh* and *blocked*
    parameters are saved before starting the import and then reinserted into
    sys.modules when the fresh import is complete.  At that point, whatever
    the fresh import registered under those names (or their submodules) is
    discarded.  The returned module is therefore not left in sys.modules;
    a pre-existing entry, if any, is put back instead.  Modules imported as
    a side effect under unrelated names remain in sys.modules.

    Module and package deprecation messages are suppressed during this import
    if *deprecated* is True.

    This function will raise ImportError if the named module cannot be
    imported.  (An ImportError from one of the *fresh* modules yields None
    instead, as described above.)

    If "usefrozen" is False (the default) then the frozen importer is
    disabled (except for essential modules like importlib._bootstrap).
    """
    # NOTE: test_heapq, test_json and test_warnings include extra sanity checks
    # to make sure that this utility function is working as expected
    with _ignore_deprecated_imports(deprecated):
        # Materialize the iterables: *fresh* and *blocked* are each walked
        # twice below (once to build *names*, once in a loop).
        fresh = list(fresh)
        blocked = list(blocked)
        # Every name touched by this call.  Those names and their submodules
        # are snapshotted here and restored in the ``finally`` clause below.
        names = {name, *fresh, *blocked}
        orig_modules = _save_and_remove_modules(names)
        # A None entry in sys.modules makes importing that name raise
        # ImportError for the duration of the call.
        for modname in blocked:
            sys.modules[modname] = None

        try:
            with frozen_modules(usefrozen):
                # Return None when one of the "fresh" modules can not be imported.
                try:
                    for modname in fresh:
                        __import__(modname)
                except ImportError:
                    return None
                # An ImportError for *name* itself propagates to the caller.
                return importlib.import_module(name)
        finally:
            # First drop the entries created during this call: the freshly
            # imported modules and the None placeholders.  The returned dict
            # is deliberately ignored.  Then reinstate the pre-call modules.
            _save_and_remove_modules(names)
            sys.modules.update(orig_modules)
164
165
class CleanImport(object):
    """Context manager to force import to return a new module reference.

    This is useful for testing module-level behaviours, such as
    the emission of a DeprecationWarning on import.

    Use like this:

        with CleanImport("foo"):
            importlib.import_module("foo") # new reference

    The named modules are removed from sys.modules as soon as the object is
    constructed, not on __enter__.  On exit, every entry that was present in
    sys.modules at construction time is reinstated.  Modules first imported
    inside the block are left in sys.modules.

    If "usefrozen" is False (the default) then the frozen importer is
    disabled (except for essential modules like importlib._bootstrap).
    """

    def __init__(self, *module_names, usefrozen=False):
        # Snapshot of sys.modules taken before anything is removed.
        # __exit__ merges it back into sys.modules.
        self.original_modules = sys.modules.copy()
        for module_name in module_names:
            if module_name not in sys.modules:
                continue
            module = sys.modules.pop(module_name)
            # module_name may be an alias for a module registered under a
            # different name (e.g. a stub for a renamed module).  Drop the
            # real entry too, so the import cache is really cleared.
            # getattr() tolerates entries without __name__ (e.g. None
            # placeholders).  pop() with a default tolerates the real entry
            # already being gone, e.g. CleanImport("real", "alias") removes
            # "real" before reaching "alias".
            real_name = getattr(module, '__name__', module_name)
            if real_name != module_name:
                sys.modules.pop(real_name, None)
        self._frozen_modules = frozen_modules(usefrozen)

    def __enter__(self):
        self._frozen_modules.__enter__()
        return self

    def __exit__(self, *ignore_exc):
        # Returns None, so exceptions raised inside the block propagate.
        try:
            sys.modules.update(self.original_modules)
        finally:
            # Always release the frozen-modules override, even if restoring
            # sys.modules somehow failed.
            self._frozen_modules.__exit__(*ignore_exc)
202
203
class DirsOnSysPath:
    """Temporarily append directories to sys.path.

    The directories given as positional arguments are appended when the
    object is constructed, not on __enter__.  On exit, the original list
    object is reinstalled as sys.path and its contents are reset to the
    snapshot taken at construction.  Any change made to sys.path inside the
    block is therefore undone, including rebinding sys.path to a different
    object.
    """

    def __init__(self, *paths):
        # Keep both the list object itself and a copy of its contents.
        self.original_object = sys.path
        self.original_value = list(sys.path)
        sys.path.extend(paths)

    def __enter__(self):
        return self

    def __exit__(self, *ignore_exc):
        restored = self.original_object
        restored[:] = self.original_value
        sys.path = restored
227
228
def modules_setup():
    """Return a one-element tuple holding a shallow copy of sys.modules."""
    snapshot = dict(sys.modules)
    return (snapshot,)
231
232
def modules_cleanup(oldmodules):
    """Reset sys.modules to the snapshot *oldmodules*.

    *oldmodules* is a mapping of module name to module object, such as a
    sys.modules copy.  Every current entry is dropped except loaded
    ``encodings.*`` submodules, and then *oldmodules* is merged back in.
    Entries from *oldmodules* win on any name conflict.
    """
    # Codec search functions are cached permanently in the interpreter's
    # codec registry.  Destroying their modules would leave those cached
    # functions with broken globals, so the encodings submodules survive.
    preserved = {modname: mod for modname, mod in sys.modules.items()
                 if modname.startswith('encodings.')}
    sys.modules.clear()
    sys.modules.update(preserved)
    # XXX: The same problem can affect more than just encodings; extension
    # modules (such as _ssl) in particular don't cope with reloading.
    # Ideally tests would remove only the test-specific modules they added
    # (as test_runpy does) instead of relying on this wholesale reset (as
    # test_importhooks and test_pkg currently do).  Implicitly imported
    # *real* modules should be left alone (see issue 10556).
    sys.modules.update(oldmodules)
249