xref: /aosp_15_r20/external/pytorch/torch/utils/_contextlib.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# Extra utilities for working with context managers that should have been
3# in the standard library but are not
4
5import functools
6import inspect
7import warnings
8import sys
9from typing import Any, Callable, TypeVar, cast
10
# Typing helpers for the decorator form of _DecoratorContextManager
# (e.g. 'no_grad' and 'enable_grad'): binding the decorated callable to a
# TypeVar lets type checkers see that decoration preserves its signature.
# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
FuncType = Callable[..., Any]
F = TypeVar("F", bound=FuncType)
16
17
def _wrap_generator(ctx_factory, func):
    """
    Wrap each generator invocation with the context manager factory.

    The input should be a function that returns a context manager,
    not a context manager itself, to handle one-shot context managers.

    Each time control re-enters the wrapped generator (the initial start,
    ``send``, ``throw`` and ``close``), a fresh context manager is obtained
    from ``ctx_factory`` and entered only for the duration of that step.
    The context is therefore active while the generator body runs and
    inactive while control is back with the caller.

    Args:
        ctx_factory: zero-argument callable returning a context manager.
        func: the generator function to wrap.

    Returns:
        A generator function carrying ``func``'s metadata (via
        ``functools.wraps``). It yields what the wrapped generator yields
        and returns the wrapped generator's return value.
    """

    @functools.wraps(func)
    def generator_context(*args, **kwargs):
        gen = func(*args, **kwargs)

        # Generators are suspended and unsuspended at `yield`, hence we
        # make sure the context is properly set every time the execution
        # flow returns into the wrapped generator and restored when it
        # returns through our `yield` to our caller (see PR #49017).
        try:
            # Issuing `None` to a generator fires it up
            with ctx_factory():
                response = gen.send(None)

            while True:
                try:
                    # Forward the response to our caller and get its next request
                    request = yield response

                except GeneratorExit:
                    # Inform the still active generator about its imminent closure
                    with ctx_factory():
                        gen.close()
                    raise

                except BaseException as exc:
                    # Propagate the exception thrown at us by the caller.
                    # Use the single-argument form of `throw`. The
                    # (type, value, traceback) signature is deprecated since
                    # Python 3.12 and emits a DeprecationWarning. The
                    # exception object already carries its own __traceback__.
                    with ctx_factory():
                        response = gen.throw(exc)

                else:
                    # Pass the last request to the generator and get its response
                    with ctx_factory():
                        response = gen.send(request)

        # We let the exceptions raised above by the generator's `.throw` or
        # `.send` methods bubble up to our caller, except for StopIteration
        except StopIteration as e:
            # The generator informed us that it is done: take whatever its
            # returned value (if any) was and indicate that we're done too
            # by returning it (see docs for python's return-statement).
            return e.value

    return generator_context
68
69
def context_decorator(ctx, func):
    """
    Like contextlib.ContextDecorator.

    But with the following differences:
    1. Is done by wrapping, rather than inheritance, so it works with context
       managers that are implemented from C and thus cannot easily inherit from
       Python classes
    2. Wraps generators in the intuitive way (c.f. https://bugs.python.org/issue37743)
    3. Errors out if you try to wrap a class, because it is ambiguous whether
       or not you intended to wrap only the constructor

    The input argument can either be a context manager (in which case it must
    be a multi-shot context manager that can be directly entered multiple times)
    or a callable that produces a context manager.

    Args:
        ctx: a reusable context manager, or a zero-argument callable that
            returns a context manager. It must not be both.
        func: the function or generator function to wrap.

    Returns:
        A wrapper carrying ``func``'s metadata (via ``functools.wraps``).
        For ordinary functions, the context is entered around every call.
        For generator functions, it is entered around every resumption of
        the generator.

    Raises:
        AssertionError: if ``ctx`` is callable and also has ``__enter__``.
        RuntimeError: if ``func`` is a class.
    """
    if callable(ctx) and hasattr(ctx, "__enter__"):
        # Raised explicitly rather than via `assert` so the ambiguity check
        # still runs under `python -O`. AssertionError is kept so existing
        # callers see the same exception type as before.
        raise AssertionError(
            f"Passed in {ctx} is both callable and also a valid context manager "
            "(has __enter__), making it ambiguous which interface to use.  If you "
            "intended to pass a context manager factory, rewrite your call as "
            "context_decorator(lambda: ctx()); if you intended to pass a context "
            "manager directly, rewrite your call as context_decorator(lambda: ctx)"
        )

    # Normalize to a factory so every call/resumption gets a context manager.
    if not callable(ctx):
        def ctx_factory():
            return ctx
    else:
        ctx_factory = ctx

    if inspect.isclass(func):
        raise RuntimeError(
            "Cannot decorate classes; it is ambiguous whether or not only the "
            "constructor or all methods should have the context manager applied; "
            "additionally, decorating a class at definition-site will prevent "
            "use of the identifier as a conventional type.  "
            "To specify which methods to decorate, decorate each of them "
            "individually."
        )

    if inspect.isgeneratorfunction(func):
        return _wrap_generator(ctx_factory, func)

    @functools.wraps(func)
    def decorate_context(*args, **kwargs):
        with ctx_factory():
            return func(*args, **kwargs)

    return decorate_context
119
120
class _DecoratorContextManager:
    """
    Allow a context manager to be used as a decorator.

    Subclasses provide ``__enter__``/``__exit__``. Applying an instance to a
    function hands ``self.clone`` to ``context_decorator`` as the factory, so
    each call (or generator resumption) runs under a freshly cloned manager.
    """

    def __call__(self, orig_func: F) -> F:
        """Decorate ``orig_func`` so that it runs under ``self.clone()``.

        Decorating a class is still accepted but emits a ``FutureWarning``.
        The class is then forwarded through a lambda that calls it.
        """
        func = orig_func
        if inspect.isclass(orig_func):
            warnings.warn(
                "Decorating classes is deprecated and will be disabled in "
                "future versions. You should only decorate functions or methods. "
                "To preserve the current behavior of class decoration, you can "
                "directly decorate the `__init__` method and nothing else.",
                FutureWarning,
                stacklevel=2,
            )
            func = cast(F, lambda *args, **kwargs: orig_func(*args, **kwargs))

        decorated = context_decorator(self.clone, func)
        return cast(F, decorated)

    def __enter__(self) -> None:
        # Must be provided by subclasses.
        raise NotImplementedError

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Must be provided by subclasses.
        raise NotImplementedError

    def clone(self):
        # Builds a new instance with no constructor arguments. Subclasses
        # whose __init__ takes parameters must override this.
        cls = self.__class__
        return cls()
149
150
class _NoParamDecoratorContextManager(_DecoratorContextManager):
    """
    Allow a context manager to be used as a decorator without parentheses.

    ``Cls()`` with no argument creates an ordinary instance. ``Cls(func)``
    (which is what a bare ``@Cls`` does) creates a default instance, applies
    it to ``func``, and returns the decorated result instead of an instance.
    """

    def __new__(cls, orig_func=None):
        if orig_func is not None:
            # Bare `@Cls` usage: build a default instance and decorate.
            return cls()(orig_func)
        return super().__new__(cls)
158