# mypy: allow-untyped-defs
# Extra utilities for working with context managers that should have been
# in the standard library but are not

import functools
import inspect
import warnings
import sys
from typing import Any, Callable, TypeVar, cast

# Used for annotating the decorator usage of _DecoratorContextManager (e.g.,
# 'no_grad' and 'enable_grad').
# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
FuncType = Callable[..., Any]
F = TypeVar('F', bound=FuncType)


def _wrap_generator(ctx_factory, func):
    """
    Wrap each generator invocation with the context manager factory.

    The input should be a function that returns a context manager,
    not a context manager itself, to handle one-shot context managers.

    Args:
        ctx_factory: zero-argument callable returning a context manager. It is
            called once for every resumption of the wrapped generator (the
            initial start, each ``send``, each ``throw`` and ``close``).
        func: the generator function to wrap.

    Returns:
        A generator function with the metadata of ``func`` (via
        ``functools.wraps``) that forwards sent values, thrown exceptions,
        closure and the final return value to the generator created by
        ``func``.
    """
    @functools.wraps(func)
    def generator_context(*args, **kwargs):
        gen = func(*args, **kwargs)

        # Generators are suspended and unsuspended at `yield`, hence we
        # make sure the grad mode is properly set every time the execution
        # flow returns into the wrapped generator and restored when it
        # returns through our `yield` to our caller (see PR #49017).
        try:
            # Issuing `None` to a generator fires it up
            with ctx_factory():
                response = gen.send(None)

            while True:
                try:
                    # Forward the response to our caller and get its next request
                    request = yield response

                except GeneratorExit:
                    # Inform the still active generator about its imminent closure
                    with ctx_factory():
                        gen.close()
                    raise

                except BaseException as exc:
                    # Propagate the exception thrown at us by the caller.
                    # The exception instance is passed on its own: the
                    # three-argument (type, value, traceback) form of
                    # `generator.throw` is deprecated since Python 3.12, and
                    # the instance already carries the same traceback.
                    with ctx_factory():
                        response = gen.throw(exc)

                else:
                    # Pass the last request to the generator and get its response
                    with ctx_factory():
                        response = gen.send(request)

        # We let the exceptions raised above by the generator's `.throw` or
        # `.send` methods bubble up to our caller, except for StopIteration
        except StopIteration as e:
            # The generator informed us that it is done: take whatever its
            # returned value (if any) was and indicate that we're done too
            # by returning it (see docs for python's return-statement).
            return e.value

    return generator_context


def context_decorator(ctx, func):
    """
    Like contextlib.ContextDecorator.

    But with the following differences:
    1. Is done by wrapping, rather than inheritance, so it works with context
       managers that are implemented from C and thus cannot easily inherit from
       Python classes
    2. Wraps generators in the intuitive way (c.f. https://bugs.python.org/issue37743)
    3. Errors out if you try to wrap a class, because it is ambiguous whether
       or not you intended to wrap only the constructor

    The input argument can either be a context manager (in which case it must
    be a multi-shot context manager that can be directly invoked multiple times)
    or a callable that produces a context manager.

    Args:
        ctx: a context manager, or a zero-argument callable returning one.
        func: the function or generator function to wrap.

    Returns:
        A wrapper around ``func`` that runs each call (or, for generator
        functions, each resumption) inside the context manager.

    Raises:
        AssertionError: if ``ctx`` is both callable and has ``__enter__``.
        RuntimeError: if ``func`` is a class.
    """
    assert not (callable(ctx) and hasattr(ctx, '__enter__')), (
        f"Passed in {ctx} is both callable and also a valid context manager "
        "(has __enter__), making it ambiguous which interface to use. If you "
        "intended to pass a context manager factory, rewrite your call as "
        "context_decorator(lambda: ctx()); if you intended to pass a context "
        "manager directly, rewrite your call as context_decorator(lambda: ctx)"
    )

    # Normalize `ctx` so the rest of the code always deals with a factory.
    if not callable(ctx):
        def ctx_factory():
            return ctx
    else:
        ctx_factory = ctx

    if inspect.isclass(func):
        raise RuntimeError(
            "Cannot decorate classes; it is ambiguous whether or not only the "
            "constructor or all methods should have the context manager applied; "
            "additionally, decorating a class at definition-site will prevent "
            "use of the identifier as a conventional type.  "
            "To specify which methods to decorate, decorate each of them "
            "individually."
        )

    # Generator functions need the context re-entered on every resumption,
    # not just around the call that creates the generator object.
    if inspect.isgeneratorfunction(func):
        return _wrap_generator(ctx_factory, func)

    @functools.wraps(func)
    def decorate_context(*args, **kwargs):
        with ctx_factory():
            return func(*args, **kwargs)

    return decorate_context


class _DecoratorContextManager:
    """Allow a context manager to be used as a decorator."""

    def __call__(self, orig_func: F) -> F:
        """
        Decorate ``orig_func`` so that it runs inside this context manager.

        ``self.clone`` is handed to ``context_decorator`` as the factory, so a
        new context manager instance is created for every entry instead of
        re-entering ``self``.

        Decorating a class emits a ``FutureWarning``; the class is wrapped in a
        plain function that forwards to its constructor before being decorated.
        """
        if inspect.isclass(orig_func):
            warnings.warn(
                "Decorating classes is deprecated and will be disabled in "
                "future versions. You should only decorate functions or methods. "
                "To preserve the current behavior of class decoration, you can "
                "directly decorate the `__init__` method and nothing else.",
                FutureWarning,
                stacklevel=2,
            )
            func = cast(F, lambda *args, **kwargs: orig_func(*args, **kwargs))
        else:
            func = orig_func

        return cast(F, context_decorator(self.clone, func))

    def __enter__(self) -> None:
        """Enter the context; must be implemented by subclasses."""
        raise NotImplementedError

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        """Exit the context; must be implemented by subclasses."""
        raise NotImplementedError

    def clone(self):
        """Return a new instance of the same class, built with no arguments."""
        # override this method if your subclass takes __init__ parameters
        return self.__class__()


class _NoParamDecoratorContextManager(_DecoratorContextManager):
    """Allow a context manager to be used as a decorator without parentheses."""

    def __new__(cls, orig_func=None):
        """
        Create an instance, or decorate ``orig_func`` directly.

        With no argument a regular instance is returned. With an argument (the
        ``@cls`` form without parentheses) a fresh instance is created and
        immediately applied to ``orig_func``, so the decorated callable is
        returned instead of an instance.
        """
        if orig_func is None:
            return super().__new__(cls)
        return cls()(orig_func)