import tempfile
import os
import shutil
import sys
import contextlib
import site
import io

import pkg_resources
from filelock import FileLock


@contextlib.contextmanager
def tempdir(cd=lambda dir: None, **kwargs):
    """
    Create a temporary directory for the duration of the context.

    ``cd`` is called with the new directory on entry and with the original
    working directory on exit; the default does nothing. Extra keyword
    arguments are passed to ``tempfile.mkdtemp``. Yields the directory path
    and removes the directory on exit.
    """
    # capture the cwd first so a failure here cannot leak a temp directory
    orig_dir = os.getcwd()
    temp_dir = tempfile.mkdtemp(**kwargs)
    try:
        cd(temp_dir)
        yield temp_dir
    finally:
        try:
            cd(orig_dir)
        finally:
            # always remove the directory, even if changing back failed
            shutil.rmtree(temp_dir)


@contextlib.contextmanager
def environment(**replacements):
    """
    In a context, patch the environment with replacements. Pass None values
    to clear the values.

    Yields a dict holding the original values of the patched variables that
    were set on entry. On exit the variables that were assigned are removed
    and the original values are restored.
    """
    saved = {
        key: os.environ[key]
        for key in replacements
        if key in os.environ
    }

    # split the request into variables to clear and variables to assign
    cleared = [key for key, value in replacements.items() if value is None]
    assigned = {
        key: value
        for key, value in replacements.items()
        if value is not None
    }

    try:
        # patch inside the try block so that a failure part-way through
        # still restores the original environment
        for key in cleared:
            os.environ.pop(key, None)
        os.environ.update(assigned)
        yield saved
    finally:
        for key in assigned:
            os.environ.pop(key, None)
        os.environ.update(saved)


@contextlib.contextmanager
def quiet():
    """
    Redirect stdout/stderr to StringIO objects to prevent console output from
    distutils commands.

    Yields the ``(stdout, stderr)`` StringIO pair. On exit the original
    streams are put back and both buffers are rewound to the start.
    """

    old_stdout = sys.stdout
    old_stderr = sys.stderr
    new_stdout = sys.stdout = io.StringIO()
    new_stderr = sys.stderr = io.StringIO()
    try:
        yield new_stdout, new_stderr
    finally:
        # restore the real streams first so that a failing rewind below
        # (e.g. a buffer that was closed) cannot leave them replaced
        sys.stdout = old_stdout
        sys.stderr = old_stderr
        new_stdout.seek(0)
        new_stderr.seek(0)


@contextlib.contextmanager
def save_user_site_setting():
    """
    Restore ``site.ENABLE_USER_SITE`` to its entry value when the context
    exits. Yields the saved value.
    """
    saved = site.ENABLE_USER_SITE
    try:
        yield saved
    finally:
        site.ENABLE_USER_SITE = saved


@contextlib.contextmanager
def save_pkg_resources_state():
    """
    Snapshot the ``pkg_resources`` state and ``sys.path`` on entry and put
    both back on exit. Yields the ``(pkg_resources state, sys.path copy)``
    pair.
    """
    pr_state = pkg_resources.__getstate__()
    # also save sys.path
    sys_path = sys.path[:]
    try:
        yield pr_state, sys_path
    finally:
        # assign through the slice so the sys.path list object is preserved
        sys.path[:] = sys_path
        pkg_resources.__setstate__(pr_state)


@contextlib.contextmanager
def suppress_exceptions(*excs):
    """
    Swallow any of the given exception types raised inside the context.
    """
    try:
        yield
    except excs:
        pass


def multiproc(request):
    """
    Return True if running under xdist and multiple
    workers are used.
    """
    try:
        worker_id = request.getfixturevalue('worker_id')
    except Exception:
        # any failure to obtain the fixture is treated as "not multi-worker"
        return False
    return worker_id != 'master'


@contextlib.contextmanager
def session_locked_tmp_dir(request, tmp_path_factory, name):
    """Uses a file lock to guarantee only one worker can access a temp dir"""
    # get the temp directory shared by all workers
    base = tmp_path_factory.getbasetemp()
    shared_dir = base.parent if multiproc(request) else base

    locked_dir = shared_dir / name
    with FileLock(locked_dir.with_suffix(".lock")):
        # ^-- prevent multiple workers to access the directory at once
        locked_dir.mkdir(exist_ok=True, parents=True)
        yield locked_dir