# Adapted with permission from the EdgeDB project;
# license: PSFL.


__all__ = ["TaskGroup"]

from . import events
from . import exceptions
from . import tasks


class TaskGroup:
    """Asynchronous context manager for managing groups of tasks.

    Example use:

        async with asyncio.TaskGroup() as group:
            task1 = group.create_task(some_coroutine(...))
            task2 = group.create_task(other_coroutine(...))
        print("Both tasks have completed now.")

    All tasks are awaited when the context manager exits.

    Any exceptions other than `asyncio.CancelledError` raised within
    a task will cancel all remaining tasks and wait for them to exit.
    The exceptions are then combined and raised as an `ExceptionGroup`.
    """
    def __init__(self):
        # True once __aenter__ has run; a group may be entered only once.
        self._entered = False
        # True once __aexit__ has started.
        self._exiting = False
        # True once _abort() has been called (child tasks cancelled).
        self._aborting = False
        self._loop = None
        # The task that entered the group (set in __aenter__).
        self._parent_task = None
        # True if this group itself called self._parent_task.cancel();
        # __aexit__ must then balance it with exactly one uncancel().
        self._parent_cancel_requested = False
        # Child tasks that have not finished yet.
        self._tasks = set()
        # Exceptions raised by child tasks (and by the body, in __aexit__).
        self._errors = []
        # First SystemExit/KeyboardInterrupt seen; re-raised as-is.
        self._base_error = None
        # Future that __aexit__ awaits until the last child task is done.
        self._on_completed_fut = None

    def __repr__(self):
        info = ['']
        if self._tasks:
            info.append(f'tasks={len(self._tasks)}')
        if self._errors:
            info.append(f'errors={len(self._errors)}')
        if self._aborting:
            info.append('cancelling')
        elif self._entered:
            info.append('entered')

        info_str = ' '.join(info)
        return f'<TaskGroup{info_str}>'

    async def __aenter__(self):
        if self._entered:
            raise RuntimeError(
                f"TaskGroup {self!r} has been already entered")
        self._entered = True

        if self._loop is None:
            self._loop = events.get_running_loop()

        self._parent_task = tasks.current_task(self._loop)
        if self._parent_task is None:
            raise RuntimeError(
                f'TaskGroup {self!r} cannot determine the parent task')

        return self

    async def __aexit__(self, et, exc, tb):
        """Wait for every child task, then report what went wrong.

        Raise priority once all child tasks are done:
          1. the first SystemExit/KeyboardInterrupt seen (as-is);
          2. a pending CancelledError not owned by this group, if there
             are no other errors;
          3. a BaseExceptionGroup of all non-cancellation errors from the
             child tasks and from the body of the ``async with`` block.
        """
        self._exiting = True

        if (exc is not None and
                self._is_base_error(exc) and
                self._base_error is None):
            self._base_error = exc

        # issubclass() rather than an identity check, so that subclasses
        # of CancelledError are treated as cancellations as well.
        if et is not None and issubclass(et, exceptions.CancelledError):
            propagate_cancellation_error = exc
        else:
            propagate_cancellation_error = None

        if et is not None:
            if not self._aborting:
                # Our parent task is being cancelled:
                #
                #    async with TaskGroup() as g:
                #        g.create_task(...)
                #        await ...  # <- CancelledError
                #
                # or there's an exception in "async with":
                #
                #    async with TaskGroup() as g:
                #        g.create_task(...)
                #        1 / 0
                #
                self._abort()

        # We use while-loop here because "self._on_completed_fut"
        # can be cancelled multiple times if our parent task
        # is being cancelled repeatedly (or even once, when
        # our own cancellation is already in progress)
        while self._tasks:
            if self._on_completed_fut is None:
                self._on_completed_fut = self._loop.create_future()

            try:
                await self._on_completed_fut
            except exceptions.CancelledError as ex:
                if not self._aborting:
                    # Our parent task is being cancelled:
                    #
                    #    async def wrapper():
                    #        async with TaskGroup() as g:
                    #            g.create_task(foo)
                    #
                    # "wrapper" is being cancelled while "foo" is
                    # still running.
                    propagate_cancellation_error = ex
                    self._abort()

            self._on_completed_fut = None

        assert not self._tasks

        # A child task may fail while we are waiting above, in which case
        # _on_task_done() cancels the parent and sets the flag.  Only now
        # that every child is done is the flag final, so the balancing
        # uncancel() must happen here -- and before any raise below.
        if self._parent_cancel_requested:
            # If this flag is set we *must* call uncancel().
            if self._parent_task.uncancel() == 0:
                # If there are no pending cancellations left,
                # don't propagate CancelledError.
                propagate_cancellation_error = None

        if self._base_error is not None:
            raise self._base_error

        # Propagate CancelledError if there is one, except if there
        # are other errors -- those have priority.
        if propagate_cancellation_error is not None and not self._errors:
            raise propagate_cancellation_error

        if et is not None and not issubclass(et, exceptions.CancelledError):
            self._errors.append(exc)

        if self._errors:
            # If the parent task is also being cancelled from outside this
            # group, raising the ExceptionGroup would swallow that request.
            # Un-cancel and re-cancel the parent so a CancelledError is
            # delivered at its next await; the cancel count stays the same.
            if self._parent_task.cancelling():
                self._parent_task.uncancel()
                self._parent_task.cancel()
            # Exceptions are heavy objects that can have object
            # cycles (bad for GC); let's not keep a reference to
            # a bunch of them.
            try:
                me = BaseExceptionGroup(
                    'unhandled errors in a TaskGroup', self._errors)
                raise me from None
            finally:
                self._errors = None

    def create_task(self, coro, *, name=None, context=None):
        """Create a new task in this group and return it.

        Similar to `asyncio.create_task`.

        Raises RuntimeError if the group has not been entered, has already
        finished, or is shutting down.  In those cases *coro* is closed
        first, so it is not left un-awaited.
        """
        if not self._entered:
            coro.close()
            raise RuntimeError(f"TaskGroup {self!r} has not been entered")
        if self._exiting and not self._tasks:
            coro.close()
            raise RuntimeError(f"TaskGroup {self!r} is finished")
        if self._aborting:
            coro.close()
            raise RuntimeError(f"TaskGroup {self!r} is shutting down")
        if context is None:
            task = self._loop.create_task(coro)
        else:
            task = self._loop.create_task(coro, context=context)
        tasks._set_task_name(task, name)
        task.add_done_callback(self._on_task_done)
        self._tasks.add(task)
        return task

    # Since Python 3.8 Tasks propagate all exceptions correctly,
    # except for KeyboardInterrupt and SystemExit which are
    # still considered special.

    def _is_base_error(self, exc: BaseException) -> bool:
        """Return True if *exc* is SystemExit or KeyboardInterrupt."""
        assert isinstance(exc, BaseException)
        return isinstance(exc, (SystemExit, KeyboardInterrupt))

    def _abort(self):
        """Mark the group as aborting and cancel all unfinished child tasks."""
        self._aborting = True

        for t in self._tasks:
            if not t.done():
                t.cancel()

    def _on_task_done(self, task):
        """Done-callback for child tasks: record errors and abort on failure."""
        self._tasks.discard(task)

        # Wake up __aexit__ once the last child task has finished.
        if self._on_completed_fut is not None and not self._tasks:
            if not self._on_completed_fut.done():
                self._on_completed_fut.set_result(True)

        if task.cancelled():
            return

        exc = task.exception()
        if exc is None:
            return

        self._errors.append(exc)
        if self._is_base_error(exc) and self._base_error is None:
            self._base_error = exc

        if self._parent_task.done():
            # Not sure if this case is possible, but we want to handle
            # it anyways.
            self._loop.call_exception_handler({
                'message': f'Task {task!r} has errored out but its parent '
                           f'task {self._parent_task} is already completed',
                'exception': exc,
                'task': task,
            })
            return

        if not self._aborting and not self._parent_cancel_requested:
            # If parent task *is not* being cancelled, it means that we want
            # to manually cancel it to abort whatever is being run right now
            # in the TaskGroup.  But we want to mark parent task as
            # "not cancelled" later in __aexit__.  Example situation that
            # we need to handle:
            #
            #    async def foo():
            #        try:
            #            async with TaskGroup() as g:
            #                g.create_task(crash_soon())
            #                await something  # <- this needs to be canceled
            #                                 #    by the TaskGroup, e.g.
            #                                 #    foo() needs to be cancelled
            #        except Exception:
            #            # Ignore any exceptions raised in the TaskGroup
            #            pass
            #        await something_else     # this line has to be called
            #                                 # after TaskGroup is finished.
            self._abort()
            self._parent_cancel_requested = True
            self._parent_task.cancel()