# xref: /aosp_15_r20/prebuilts/build-tools/common/py3-stdlib/asyncio/taskgroups.py (revision cda5da8d549138a6648c5ee6d7a49cf8f4a657be)
# Adapted with permission from the EdgeDB project;
# license: PSFL.


# A tuple, like the __all__ of the other asyncio submodules, so it can be
# concatenated with theirs.
__all__ = ("TaskGroup",)
6
7from . import events
8from . import exceptions
9from . import tasks
10
11
12class TaskGroup:
13    """Asynchronous context manager for managing groups of tasks.
14
15    Example use:
16
17        async with asyncio.TaskGroup() as group:
18            task1 = group.create_task(some_coroutine(...))
19            task2 = group.create_task(other_coroutine(...))
20        print("Both tasks have completed now.")
21
22    All tasks are awaited when the context manager exits.
23
24    Any exceptions other than `asyncio.CancelledError` raised within
25    a task will cancel all remaining tasks and wait for them to exit.
26    The exceptions are then combined and raised as an `ExceptionGroup`.
27    """
28    def __init__(self):
29        self._entered = False
30        self._exiting = False
31        self._aborting = False
32        self._loop = None
33        self._parent_task = None
34        self._parent_cancel_requested = False
35        self._tasks = set()
36        self._errors = []
37        self._base_error = None
38        self._on_completed_fut = None
39
40    def __repr__(self):
41        info = ['']
42        if self._tasks:
43            info.append(f'tasks={len(self._tasks)}')
44        if self._errors:
45            info.append(f'errors={len(self._errors)}')
46        if self._aborting:
47            info.append('cancelling')
48        elif self._entered:
49            info.append('entered')
50
51        info_str = ' '.join(info)
52        return f'<TaskGroup{info_str}>'
53
54    async def __aenter__(self):
55        if self._entered:
56            raise RuntimeError(
57                f"TaskGroup {self!r} has been already entered")
58        self._entered = True
59
60        if self._loop is None:
61            self._loop = events.get_running_loop()
62
63        self._parent_task = tasks.current_task(self._loop)
64        if self._parent_task is None:
65            raise RuntimeError(
66                f'TaskGroup {self!r} cannot determine the parent task')
67
68        return self
69
70    async def __aexit__(self, et, exc, tb):
71        self._exiting = True
72
73        if (exc is not None and
74                self._is_base_error(exc) and
75                self._base_error is None):
76            self._base_error = exc
77
78        propagate_cancellation_error = \
79            exc if et is exceptions.CancelledError else None
80        if self._parent_cancel_requested:
81            # If this flag is set we *must* call uncancel().
82            if self._parent_task.uncancel() == 0:
83                # If there are no pending cancellations left,
84                # don't propagate CancelledError.
85                propagate_cancellation_error = None
86
87        if et is not None:
88            if not self._aborting:
89                # Our parent task is being cancelled:
90                #
91                #    async with TaskGroup() as g:
92                #        g.create_task(...)
93                #        await ...  # <- CancelledError
94                #
95                # or there's an exception in "async with":
96                #
97                #    async with TaskGroup() as g:
98                #        g.create_task(...)
99                #        1 / 0
100                #
101                self._abort()
102
103        # We use while-loop here because "self._on_completed_fut"
104        # can be cancelled multiple times if our parent task
105        # is being cancelled repeatedly (or even once, when
106        # our own cancellation is already in progress)
107        while self._tasks:
108            if self._on_completed_fut is None:
109                self._on_completed_fut = self._loop.create_future()
110
111            try:
112                await self._on_completed_fut
113            except exceptions.CancelledError as ex:
114                if not self._aborting:
115                    # Our parent task is being cancelled:
116                    #
117                    #    async def wrapper():
118                    #        async with TaskGroup() as g:
119                    #            g.create_task(foo)
120                    #
121                    # "wrapper" is being cancelled while "foo" is
122                    # still running.
123                    propagate_cancellation_error = ex
124                    self._abort()
125
126            self._on_completed_fut = None
127
128        assert not self._tasks
129
130        if self._base_error is not None:
131            raise self._base_error
132
133        # Propagate CancelledError if there is one, except if there
134        # are other errors -- those have priority.
135        if propagate_cancellation_error and not self._errors:
136            raise propagate_cancellation_error
137
138        if et is not None and et is not exceptions.CancelledError:
139            self._errors.append(exc)
140
141        if self._errors:
142            # Exceptions are heavy objects that can have object
143            # cycles (bad for GC); let's not keep a reference to
144            # a bunch of them.
145            try:
146                me = BaseExceptionGroup('unhandled errors in a TaskGroup', self._errors)
147                raise me from None
148            finally:
149                self._errors = None
150
151    def create_task(self, coro, *, name=None, context=None):
152        """Create a new task in this group and return it.
153
154        Similar to `asyncio.create_task`.
155        """
156        if not self._entered:
157            raise RuntimeError(f"TaskGroup {self!r} has not been entered")
158        if self._exiting and not self._tasks:
159            raise RuntimeError(f"TaskGroup {self!r} is finished")
160        if self._aborting:
161            raise RuntimeError(f"TaskGroup {self!r} is shutting down")
162        if context is None:
163            task = self._loop.create_task(coro)
164        else:
165            task = self._loop.create_task(coro, context=context)
166        tasks._set_task_name(task, name)
167        task.add_done_callback(self._on_task_done)
168        self._tasks.add(task)
169        return task
170
171    # Since Python 3.8 Tasks propagate all exceptions correctly,
172    # except for KeyboardInterrupt and SystemExit which are
173    # still considered special.
174
175    def _is_base_error(self, exc: BaseException) -> bool:
176        assert isinstance(exc, BaseException)
177        return isinstance(exc, (SystemExit, KeyboardInterrupt))
178
179    def _abort(self):
180        self._aborting = True
181
182        for t in self._tasks:
183            if not t.done():
184                t.cancel()
185
186    def _on_task_done(self, task):
187        self._tasks.discard(task)
188
189        if self._on_completed_fut is not None and not self._tasks:
190            if not self._on_completed_fut.done():
191                self._on_completed_fut.set_result(True)
192
193        if task.cancelled():
194            return
195
196        exc = task.exception()
197        if exc is None:
198            return
199
200        self._errors.append(exc)
201        if self._is_base_error(exc) and self._base_error is None:
202            self._base_error = exc
203
204        if self._parent_task.done():
205            # Not sure if this case is possible, but we want to handle
206            # it anyways.
207            self._loop.call_exception_handler({
208                'message': f'Task {task!r} has errored out but its parent '
209                           f'task {self._parent_task} is already completed',
210                'exception': exc,
211                'task': task,
212            })
213            return
214
215        if not self._aborting and not self._parent_cancel_requested:
216            # If parent task *is not* being cancelled, it means that we want
217            # to manually cancel it to abort whatever is being run right now
218            # in the TaskGroup.  But we want to mark parent task as
219            # "not cancelled" later in __aexit__.  Example situation that
220            # we need to handle:
221            #
222            #    async def foo():
223            #        try:
224            #            async with TaskGroup() as g:
225            #                g.create_task(crash_soon())
226            #                await something  # <- this needs to be canceled
227            #                                 #    by the TaskGroup, e.g.
228            #                                 #    foo() needs to be cancelled
229            #        except Exception:
230            #            # Ignore any exceptions raised in the TaskGroup
231            #            pass
232            #        await something_else     # this line has to be called
233            #                                 # after TaskGroup is finished.
234            self._abort()
235            self._parent_cancel_requested = True
236            self._parent_task.cancel()
237