1import enum
2
3from types import TracebackType
4from typing import final, Optional, Type
5
6from . import events
7from . import exceptions
8from . import tasks
9
10
# Public API of this module.
__all__ = ("Timeout", "timeout", "timeout_at")
16
17
class _State(enum.Enum):
    """Lifecycle stages of a ``Timeout`` context manager.

    The string values are the labels shown in ``Timeout.__repr__`` and in
    the ``reschedule`` error message.
    """

    CREATED = "created"    # constructed, ``async with`` not entered yet
    ENTERED = "active"     # currently inside the ``async with`` block
    EXPIRING = "expiring"  # deadline fired; the task has been cancelled
    EXPIRED = "expired"    # block exited after the deadline fired
    EXITED = "finished"    # block exited before the deadline fired
24
25
26@final
27class Timeout:
28    """Asynchronous context manager for cancelling overdue coroutines.
29
30    Use `timeout()` or `timeout_at()` rather than instantiating this class directly.
31    """
32
33    def __init__(self, when: Optional[float]) -> None:
34        """Schedule a timeout that will trigger at a given loop time.
35
36        - If `when` is `None`, the timeout will never trigger.
37        - If `when < loop.time()`, the timeout will trigger on the next
38          iteration of the event loop.
39        """
40        self._state = _State.CREATED
41
42        self._timeout_handler: Optional[events.TimerHandle] = None
43        self._task: Optional[tasks.Task] = None
44        self._when = when
45
46    def when(self) -> Optional[float]:
47        """Return the current deadline."""
48        return self._when
49
50    def reschedule(self, when: Optional[float]) -> None:
51        """Reschedule the timeout."""
52        assert self._state is not _State.CREATED
53        if self._state is not _State.ENTERED:
54            raise RuntimeError(
55                f"Cannot change state of {self._state.value} Timeout",
56            )
57
58        self._when = when
59
60        if self._timeout_handler is not None:
61            self._timeout_handler.cancel()
62
63        if when is None:
64            self._timeout_handler = None
65        else:
66            loop = events.get_running_loop()
67            if when <= loop.time():
68                self._timeout_handler = loop.call_soon(self._on_timeout)
69            else:
70                self._timeout_handler = loop.call_at(when, self._on_timeout)
71
72    def expired(self) -> bool:
73        """Is timeout expired during execution?"""
74        return self._state in (_State.EXPIRING, _State.EXPIRED)
75
76    def __repr__(self) -> str:
77        info = ['']
78        if self._state is _State.ENTERED:
79            when = round(self._when, 3) if self._when is not None else None
80            info.append(f"when={when}")
81        info_str = ' '.join(info)
82        return f"<Timeout [{self._state.value}]{info_str}>"
83
84    async def __aenter__(self) -> "Timeout":
85        self._state = _State.ENTERED
86        self._task = tasks.current_task()
87        self._cancelling = self._task.cancelling()
88        if self._task is None:
89            raise RuntimeError("Timeout should be used inside a task")
90        self.reschedule(self._when)
91        return self
92
93    async def __aexit__(
94        self,
95        exc_type: Optional[Type[BaseException]],
96        exc_val: Optional[BaseException],
97        exc_tb: Optional[TracebackType],
98    ) -> Optional[bool]:
99        assert self._state in (_State.ENTERED, _State.EXPIRING)
100
101        if self._timeout_handler is not None:
102            self._timeout_handler.cancel()
103            self._timeout_handler = None
104
105        if self._state is _State.EXPIRING:
106            self._state = _State.EXPIRED
107
108            if self._task.uncancel() <= self._cancelling and exc_type is exceptions.CancelledError:
109                # Since there are no new cancel requests, we're
110                # handling this.
111                raise TimeoutError from exc_val
112        elif self._state is _State.ENTERED:
113            self._state = _State.EXITED
114
115        return None
116
117    def _on_timeout(self) -> None:
118        assert self._state is _State.ENTERED
119        self._task.cancel()
120        self._state = _State.EXPIRING
121        # drop the reference early
122        self._timeout_handler = None
123
124
125def timeout(delay: Optional[float]) -> Timeout:
126    """Timeout async context manager.
127
128    Useful in cases when you want to apply timeout logic around block
129    of code or in cases when asyncio.wait_for is not suitable. For example:
130
131    >>> async with asyncio.timeout(10):  # 10 seconds timeout
132    ...     await long_running_task()
133
134
135    delay - value in seconds or None to disable timeout logic
136
137    long_running_task() is interrupted by raising asyncio.CancelledError,
138    the top-most affected timeout() context manager converts CancelledError
139    into TimeoutError.
140    """
141    loop = events.get_running_loop()
142    return Timeout(loop.time() + delay if delay is not None else None)
143
144
145def timeout_at(when: Optional[float]) -> Timeout:
146    """Schedule the timeout at absolute time.
147
148    Like timeout() but argument gives absolute time in the same clock system
149    as loop.time().
150
151    Please note: it is not POSIX time but a time with
152    undefined starting base, e.g. the time of the system power on.
153
154    >>> async with asyncio.timeout_at(loop.time() + 10):
155    ...     await long_running_task()
156
157
158    when - a deadline when timeout occurs or None to disable timeout logic
159
160    long_running_task() is interrupted by raising asyncio.CancelledError,
161    the top-most affected timeout() context manager converts CancelledError
162    into TimeoutError.
163    """
164    return Timeout(when)
165