1import tempfile
2import os
3import shutil
4import sys
5import contextlib
6import site
7import io
8
9import pkg_resources
10from filelock import FileLock
11
12
@contextlib.contextmanager
def tempdir(cd=lambda dir: None, **kwargs):
    """
    Create a fresh temporary directory for the lifetime of the context.

    :param cd: callable invoked with the new directory on entry and with the
        working directory captured at entry on exit (no-op by default).
    :param kwargs: forwarded unchanged to ``tempfile.mkdtemp``.
    :yields: the path of the temporary directory.

    The directory and everything in it is deleted when the context exits,
    whether or not the body raised.
    """
    path = tempfile.mkdtemp(**kwargs)
    previous_cwd = os.getcwd()
    try:
        cd(path)
        yield path
    finally:
        # Undo the "cd" before deleting, so we never remove our own cwd.
        cd(previous_cwd)
        shutil.rmtree(path)
23
24
@contextlib.contextmanager
def environment(**replacements):
    """
    In a context, patch the environment with replacements. Pass None values
    to clear the values.

    :yields: a dict mapping each patched key that existed before entry to its
        original value.

    On exit every patched key is removed and the saved originals are put
    back. This also happens when applying the patch itself fails part-way,
    e.g. ``os.environ`` rejecting a non-str value with TypeError.
    """
    saved = {key: os.environ[key] for key in replacements if key in os.environ}
    cleared = [key for key, value in replacements.items() if value is None]
    to_set = {
        key: value for key, value in replacements.items() if value is not None
    }

    try:
        # All mutations happen inside ``try`` so that a failure while
        # patching is rolled back by the ``finally`` below.
        for key in cleared:
            os.environ.pop(key, None)
        os.environ.update(to_set)
        yield saved
    finally:
        for key in to_set:
            os.environ.pop(key, None)
        # Restores both overwritten and cleared keys that existed on entry.
        os.environ.update(saved)
51
52
@contextlib.contextmanager
def quiet():
    """
    Capture everything written to stdout/stderr inside the context.

    :yields: a ``(stdout_buffer, stderr_buffer)`` pair of ``io.StringIO``
        objects. Both are rewound to position 0 on exit, so callers can
        ``.read()`` the captured text after the ``with`` block.

    Intended to silence console output from distutils commands. The
    previous ``sys.stdout``/``sys.stderr`` are reinstated on exit.
    """
    captured_out = io.StringIO()
    captured_err = io.StringIO()
    with contextlib.redirect_stdout(captured_out), \
            contextlib.redirect_stderr(captured_err):
        try:
            yield captured_out, captured_err
        finally:
            captured_out.seek(0)
            captured_err.seek(0)
71
72
@contextlib.contextmanager
def save_user_site_setting():
    """
    Put ``site.ENABLE_USER_SITE`` back to its entry value when the context
    exits, even if the body changed it or raised.

    :yields: the value of ``site.ENABLE_USER_SITE`` captured on entry.
    """
    original = site.ENABLE_USER_SITE
    try:
        yield original
    finally:
        site.ENABLE_USER_SITE = original
80
81
@contextlib.contextmanager
def save_pkg_resources_state():
    """
    Snapshot ``pkg_resources`` global state and ``sys.path``; restore both
    when the context exits.

    :yields: ``(pkg_resources_state, sys_path_copy)``, where the state is
        whatever ``pkg_resources.__getstate__()`` returned.
    """
    state = pkg_resources.__getstate__()
    path_snapshot = list(sys.path)
    try:
        yield state, path_snapshot
    finally:
        # Slice-assign so the existing sys.path list object is updated in
        # place rather than rebound.
        sys.path[:] = path_snapshot
        pkg_resources.__setstate__(state)
92
93
@contextlib.contextmanager
def suppress_exceptions(*excs):
    """
    Silently swallow any of the given exception types raised in the body.

    :param excs: exception classes to swallow. Anything else propagates. With
        no arguments nothing is swallowed.
    """
    swallowed = excs
    try:
        yield
    except swallowed:
        # The caller explicitly opted in to ignoring these types.
        return
100
101
def multiproc(request):
    """
    Tell whether the test session is split across multiple xdist workers.

    :param request: pytest ``request`` object used to look up the
        ``worker_id`` fixture.
    :returns: False when the fixture cannot be obtained at all, or when its
        value is ``'master'``; True otherwise.
    """
    try:
        worker = request.getfixturevalue('worker_id')
    except Exception:
        # Deliberately best-effort: any failure to look up the fixture is
        # treated as "not running under multiple workers".
        return False
    else:
        return worker != 'master'
112
113
@contextlib.contextmanager
def session_locked_tmp_dir(request, tmp_path_factory, name):
    """Uses a file lock to guarantee only one worker can access a temp dir.

    :param request: pytest ``request`` object, used to detect multi-worker runs.
    :param tmp_path_factory: pytest ``tmp_path_factory`` providing the base
        temp directory.
    :param name: directory name (relative to the shared temp dir) to lock.
    :yields: the ``pathlib`` path of the locked directory, created if missing.

    The lock is held for the whole body of the ``with`` statement.
    """
    # get the temp directory shared by all workers
    base = tmp_path_factory.getbasetemp()
    # When multiple workers are in use, the parent of basetemp is the
    # location treated as shared between them.
    shared_dir = base.parent if multiproc(request) else base

    locked_dir = shared_dir / name
    # Append ".lock" rather than using ``with_suffix``, which *replaces* an
    # existing suffix. Otherwise name="cache.lock" would make the lock file
    # and the directory the same path, and "a.x"/"a.y" would share one lock.
    # For suffix-less names the resulting lock path is unchanged.
    lock_path = locked_dir.with_name(locked_dir.name + ".lock")
    with FileLock(lock_path):
        # ^-- prevent multiple workers to access the directory at once
        locked_dir.mkdir(exist_ok=True, parents=True)
        yield locked_dir
126