"""Synchronization primitives."""

__all__ = ('Lock', 'Event', 'Condition', 'Semaphore',
           'BoundedSemaphore', 'Barrier')

import collections
import enum

from . import exceptions
from . import mixins
from . import tasks


class _ContextManagerMixin:
    """Add 'async with' support to classes exposing acquire()/release().

    Entering the block awaits self.acquire(); leaving it calls
    self.release(), whether or not the body raised.
    """

    async def __aenter__(self):
        await self.acquire()
        # We have no use for the "as ..." clause in the with
        # statement for locks.
        return None

    async def __aexit__(self, exc_type, exc, tb):
        self.release()


class Lock(_ContextManagerMixin, mixins._LoopBoundMixin):
    """Primitive lock objects.

    A primitive lock is a synchronization primitive that is not owned
    by a particular coroutine when locked.  A primitive lock is in one
    of two states, 'locked' or 'unlocked'.

    It is created in the unlocked state.  It has two basic methods,
    acquire() and release().  When the state is unlocked, acquire()
    changes the state to locked and returns immediately.  When the
    state is locked, acquire() blocks until a call to release() in
    another coroutine changes it to unlocked, then the acquire() call
    resets it to locked and returns.  The release() method should only
    be called in the locked state; it changes the state to unlocked
    and returns immediately.  If an attempt is made to release an
    unlocked lock, a RuntimeError will be raised.

    When more than one coroutine is blocked in acquire() waiting for
    the state to turn to unlocked, only one coroutine proceeds when a
    release() call resets the state to unlocked; the coroutine that
    has been blocked in acquire() the longest is the one woken up.

    acquire() is a coroutine and should be called with 'await'.

    Locks also support the asynchronous context management protocol.
    'async with lock' statement should be used.

    Usage:

        lock = Lock()
        ...
        await lock.acquire()
        try:
            ...
        finally:
            lock.release()

    Context manager usage:

        lock = Lock()
        ...
        async with lock:
             ...

    Lock objects can be tested for locking state:

        if not lock.locked():
           await lock.acquire()
        else:
           # lock is acquired
           ...

    """

    def __init__(self):
        # The deque of waiter futures is created lazily, on the first
        # acquire() call that actually has to wait.
        self._waiters = None
        self._locked = False

    def __repr__(self):
        res = super().__repr__()
        extra = 'locked' if self._locked else 'unlocked'
        if self._waiters:
            extra = f'{extra}, waiters:{len(self._waiters)}'
        return f'<{res[1:-1]} [{extra}]>'

    def locked(self):
        """Return True if lock is acquired."""
        return self._locked

    async def acquire(self):
        """Acquire a lock.

        This method blocks until the lock is unlocked, then sets it to
        locked and returns True.
        """
        # Fast path: take the lock right away only when it is free AND
        # nobody is queued ahead of us (waiters that are all cancelled
        # do not count).  Otherwise we queue up behind them.
        if (not self._locked and (self._waiters is None or
                all(w.cancelled() for w in self._waiters))):
            self._locked = True
            return True

        if self._waiters is None:
            self._waiters = collections.deque()
        fut = self._get_loop().create_future()
        self._waiters.append(fut)

        # Finally block should be called before the CancelledError
        # handling as we don't want CancelledError to call
        # _wake_up_first() and attempt to wake up itself.
        try:
            try:
                await fut
            finally:
                self._waiters.remove(fut)
        except exceptions.CancelledError:
            # If the lock is free while we bail out, make sure the
            # waiter now at the head of the queue gets woken up.
            if not self._locked:
                self._wake_up_first()
            raise

        self._locked = True
        return True

    def release(self):
        """Release a lock.

        When the lock is locked, reset it to unlocked, and return.
        If any other coroutines are blocked waiting for the lock to become
        unlocked, allow exactly one of them to proceed.

        When invoked on an unlocked lock, a RuntimeError is raised.

        There is no return value.
        """
        if self._locked:
            self._locked = False
            self._wake_up_first()
        else:
            raise RuntimeError('Lock is not acquired.')

    def _wake_up_first(self):
        """Wake up the first waiter if it isn't done."""
        if not self._waiters:
            return
        try:
            fut = next(iter(self._waiters))
        except StopIteration:
            return

        # .done() necessarily means that a waiter will wake up later on and
        # either take the lock, or, if it was cancelled and lock wasn't
        # taken already, will hit this again and wake up a new waiter.
        if not fut.done():
            fut.set_result(True)


class Event(mixins._LoopBoundMixin):
    """Asynchronous equivalent to threading.Event.

    Class implementing event objects. An event manages a flag that can be set
    to true with the set() method and reset to false with the clear() method.
    The wait() method blocks until the flag is true. The flag is initially
    false.
    """

    def __init__(self):
        self._waiters = collections.deque()
        self._value = False

    def __repr__(self):
        res = super().__repr__()
        extra = 'set' if self._value else 'unset'
        if self._waiters:
            extra = f'{extra}, waiters:{len(self._waiters)}'
        return f'<{res[1:-1]} [{extra}]>'

    def is_set(self):
        """Return True if and only if the internal flag is true."""
        return self._value

    def set(self):
        """Set the internal flag to true. All coroutines waiting for it to
        become true are awakened. Coroutines that call wait() once the flag
        is true will not block at all.
        """
        if not self._value:
            self._value = True

            # Waiters remove their own future from the deque when they
            # resume, so here we only need to resolve the pending ones.
            for fut in self._waiters:
                if not fut.done():
                    fut.set_result(True)

    def clear(self):
        """Reset the internal flag to false. Subsequently, coroutines calling
        wait() will block until set() is called to set the internal flag
        to true again."""
        self._value = False

    async def wait(self):
        """Block until the internal flag is true.

        If the internal flag is true on entry, return True
        immediately.  Otherwise, block until another coroutine calls
        set() to set the flag to true, then return True.
        """
        if self._value:
            return True

        fut = self._get_loop().create_future()
        self._waiters.append(fut)
        try:
            await fut
            return True
        finally:
            # Runs on normal wake-up and on cancellation alike.
            self._waiters.remove(fut)


class Condition(_ContextManagerMixin, mixins._LoopBoundMixin):
    """Asynchronous equivalent to threading.Condition.

    This class implements condition variable objects. A condition variable
    allows one or more coroutines to wait until they are notified by another
    coroutine.

    If the *lock* argument is omitted or None, a new Lock object is created
    and used as the underlying lock; otherwise the given lock is used.
    """

    def __init__(self, lock=None):
        if lock is None:
            lock = Lock()

        self._lock = lock
        # Export the lock's locked(), acquire() and release() methods.
        self.locked = lock.locked
        self.acquire = lock.acquire
        self.release = lock.release

        self._waiters = collections.deque()

    def __repr__(self):
        res = super().__repr__()
        extra = 'locked' if self.locked() else 'unlocked'
        if self._waiters:
            extra = f'{extra}, waiters:{len(self._waiters)}'
        return f'<{res[1:-1]} [{extra}]>'

    async def wait(self):
        """Wait until notified.

        If the calling coroutine has not acquired the lock when this
        method is called, a RuntimeError is raised.

        This method releases the underlying lock, and then blocks
        until it is awakened by a notify() or notify_all() call for
        the same condition variable in another coroutine.  Once
        awakened, it re-acquires the lock and returns True.

        If the wait is cancelled, the lock is still re-acquired before
        the CancelledError propagates; the most recent CancelledError
        instance is re-raised unchanged.  If this waiter had already
        been notified when it fails, the notification is handed over to
        the next waiter so that it is not lost.
        """
        if not self.locked():
            raise RuntimeError('cannot wait on un-acquired lock')

        # Create the future before releasing, so that a failure to
        # create it leaves the lock untouched.
        fut = self._get_loop().create_future()
        self.release()
        try:
            try:
                self._waiters.append(fut)
                try:
                    await fut
                    return True
                finally:
                    self._waiters.remove(fut)

            finally:
                # Must re-acquire lock even if wait is cancelled.
                # Only CancelledError is retried; any other error raised
                # by acquire() propagates instead of making us spin.
                err = None
                while True:
                    try:
                        await self.acquire()
                        break
                    except exceptions.CancelledError as e:
                        err = e

                if err is not None:
                    try:
                        # Re-raise the caught instance rather than a new
                        # bare CancelledError, keeping its arguments.
                        raise err
                    finally:
                        err = None  # Break reference cycles.
        except BaseException:
            # A future that holds a result means notify() picked this
            # waiter, yet we are leaving with an error and will never act
            # on it.  Pass the wake-up on to another waiter; it may see a
            # spurious wake-up, which wait_for() handles by re-checking
            # its predicate.
            if fut.done() and not fut.cancelled():
                self._notify(1)
            raise

    async def wait_for(self, predicate):
        """Wait until a predicate becomes true.

        The predicate should be a callable whose result will be
        interpreted as a boolean value.  The final predicate value is
        the return value.
        """
        result = predicate()
        while not result:
            await self.wait()
            result = predicate()
        return result

    def notify(self, n=1):
        """By default, wake up one coroutine waiting on this condition, if any.
        If the calling coroutine has not acquired the lock when this method
        is called, a RuntimeError is raised.

        This method wakes up at most n of the coroutines waiting for the
        condition variable; it is a no-op if no coroutines are waiting.

        Note: an awakened coroutine does not actually return from its
        wait() call until it can reacquire the lock. Since notify() does
        not release the lock, its caller should.
        """
        if not self.locked():
            raise RuntimeError('cannot notify on un-acquired lock')
        self._notify(n)

    def _notify(self, n):
        """Resolve the futures of up to n pending waiters, oldest first.

        Unlike notify(), this does not check that the lock is held.
        """
        idx = 0
        for fut in self._waiters:
            if idx >= n:
                break

            # Futures that are already done (notified or cancelled) do
            # not count towards n.
            if not fut.done():
                idx += 1
                fut.set_result(False)

    def notify_all(self):
        """Wake up all coroutines waiting on this condition. This method acts
        like notify(), but wakes up all waiting coroutines instead of one. If
        the calling coroutine has not acquired the lock when this method is
        called, a RuntimeError is raised.
        """
        self.notify(len(self._waiters))


class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin):
    """A Semaphore implementation.

    A semaphore manages an internal counter which is decremented by each
    acquire() call and incremented by each release() call. The counter
    can never go below zero; when acquire() finds that it is zero, it blocks,
    waiting until some other coroutine calls release().

    Semaphores also support the asynchronous context management protocol
    ('async with').

    The optional argument gives the initial value for the internal
    counter; it defaults to 1. If the value given is less than 0,
    ValueError is raised.
    """

    def __init__(self, value=1):
        if value < 0:
            raise ValueError("Semaphore initial value must be >= 0")
        # The deque of waiter futures is created lazily, on the first
        # acquire() call that actually has to wait.
        self._waiters = None
        self._value = value

    def __repr__(self):
        res = super().__repr__()
        extra = 'locked' if self.locked() else f'unlocked, value:{self._value}'
        if self._waiters:
            extra = f'{extra}, waiters:{len(self._waiters)}'
        return f'<{res[1:-1]} [{extra}]>'

    def locked(self):
        """Returns True if semaphore cannot be acquired immediately."""
        # Besides a zero counter, any queued waiter that is not cancelled
        # also counts as "locked", so newcomers queue up behind it.
        return self._value == 0 or (
            any(not w.cancelled() for w in (self._waiters or ())))

    async def acquire(self):
        """Acquire a semaphore.

        If the internal counter is larger than zero on entry,
        decrement it by one and return True immediately.  If it is
        zero on entry, block, waiting until some other coroutine has
        called release() to make it larger than 0, and then return
        True.
        """
        if not self.locked():
            self._value -= 1
            return True

        if self._waiters is None:
            self._waiters = collections.deque()
        fut = self._get_loop().create_future()
        self._waiters.append(fut)

        # Finally block should be called before the CancelledError
        # handling as we don't want CancelledError to call
        # _wake_up_next() and attempt to wake up itself.
        try:
            try:
                await fut
            finally:
                self._waiters.remove(fut)
        except exceptions.CancelledError:
            # _wake_up_next() decrements the counter when it resolves a
            # future.  If ours was not cancelled, give that unit back and
            # offer it to the next waiter.
            if not fut.cancelled():
                self._value += 1
                self._wake_up_next()
            raise

        # The counter may still be positive while others are queued
        # behind us; let the next one through.
        if self._value > 0:
            self._wake_up_next()
        return True

    def release(self):
        """Release a semaphore, incrementing the internal counter by one.

        When it was zero on entry and another coroutine is waiting for it to
        become larger than zero again, wake up that coroutine.
        """
        self._value += 1
        self._wake_up_next()

    def _wake_up_next(self):
        """Wake up the first waiter that isn't done."""
        if not self._waiters:
            return

        for fut in self._waiters:
            if not fut.done():
                # The counter is decremented here, on behalf of the
                # waiter being woken up.
                self._value -= 1
                fut.set_result(True)
                return


class BoundedSemaphore(Semaphore):
    """A bounded semaphore implementation.

    This raises ValueError in release() if it would increase the value
    above the initial value.
    """

    def __init__(self, value=1):
        self._bound_value = value
        super().__init__(value)

    def release(self):
        if self._value >= self._bound_value:
            raise ValueError('BoundedSemaphore released too many times')
        super().release()


class _BarrierState(enum.Enum):
    """States a Barrier can be in."""

    FILLING = 'filling'
    DRAINING = 'draining'
    RESETTING = 'resetting'
    BROKEN = 'broken'


class Barrier(mixins._LoopBoundMixin):
    """Asyncio equivalent to threading.Barrier

    Implements a Barrier primitive.
    Useful for synchronizing a fixed number of tasks at known synchronization
    points. Tasks block on 'wait()' and are simultaneously awoken once they
    have all made their call.
    """

    def __init__(self, parties):
        """Create a barrier, initialised to 'parties' tasks.

        Raises ValueError if 'parties' is less than 1.
        """
        if parties < 1:
            raise ValueError('parties must be > 0')

        self._cond = Condition()  # notify all tasks when state changes

        self._parties = parties
        self._state = _BarrierState.FILLING
        self._count = 0       # count tasks in Barrier

    def __repr__(self):
        res = super().__repr__()
        extra = f'{self._state.value}'
        if not self.broken:
            extra += f', waiters:{self.n_waiting}/{self.parties}'
        return f'<{res[1:-1]} [{extra}]>'

    async def __aenter__(self):
        # Wait until the barrier has reached its number of parties;
        # once it starts draining, return the index of this task.
        return await self.wait()

    async def __aexit__(self, *args):
        pass

    async def wait(self):
        """Wait for the barrier.

        When the specified number of tasks have started waiting, they are all
        simultaneously awoken.
        Returns a unique and individual index number from 0 to 'parties-1'.

        Raises BrokenBarrierError if the barrier is broken, or is reset
        or aborted while this task is waiting in it.
        """
        async with self._cond:
            await self._block()  # Block while the barrier drains or resets.
            try:
                index = self._count
                self._count += 1
                if index + 1 == self._parties:
                    # We release the barrier
                    await self._release()
                else:
                    await self._wait()
                return index
            finally:
                self._count -= 1
                # Wake up any tasks waiting for barrier to drain.
                self._exit()

    async def _block(self):
        # Block until the barrier is ready for us,
        # or raise an exception if it is broken.
        #
        # It is draining or resetting, wait until done
        # unless a CancelledError occurs
        await self._cond.wait_for(
            lambda: self._state not in (
                _BarrierState.DRAINING, _BarrierState.RESETTING
            )
        )

        # see if the barrier is in a broken state
        if self._state is _BarrierState.BROKEN:
            raise exceptions.BrokenBarrierError("Barrier aborted")

    async def _release(self):
        # Release the tasks waiting in the barrier.

        # Enter draining state.
        # Next waiting tasks will be blocked until the end of draining.
        self._state = _BarrierState.DRAINING
        self._cond.notify_all()

    async def _wait(self):
        # Wait in the barrier until we are released. Raise an exception
        # if the barrier is reset or broken.

        # wait for end of filling
        # unless a CancelledError occurs
        await self._cond.wait_for(lambda: self._state is not _BarrierState.FILLING)

        if self._state in (_BarrierState.BROKEN, _BarrierState.RESETTING):
            raise exceptions.BrokenBarrierError("Abort or reset of barrier")

    def _exit(self):
        # If we are the last task to exit the barrier, signal any tasks
        # waiting for the barrier to drain.
        if self._count == 0:
            if self._state in (_BarrierState.RESETTING, _BarrierState.DRAINING):
                self._state = _BarrierState.FILLING
            self._cond.notify_all()

    async def reset(self):
        """Reset the barrier to the initial state.

        Any tasks currently waiting will get the BrokenBarrierError
        exception raised.
        """
        async with self._cond:
            if self._count > 0:
                if self._state is not _BarrierState.RESETTING:
                    # Reset the barrier, waking up tasks.  The last task
                    # to leave switches the state back to FILLING.
                    self._state = _BarrierState.RESETTING
            else:
                self._state = _BarrierState.FILLING
            self._cond.notify_all()

    async def abort(self):
        """Place the barrier into a 'broken' state.

        Useful in case of error.  Any currently waiting tasks and tasks
        attempting to 'wait()' will have BrokenBarrierError raised.
        """
        async with self._cond:
            self._state = _BarrierState.BROKEN
            self._cond.notify_all()

    @property
    def parties(self):
        """Return the number of tasks required to trip the barrier."""
        return self._parties

    @property
    def n_waiting(self):
        """Return the number of tasks currently waiting at the barrier."""
        if self._state is _BarrierState.FILLING:
            return self._count
        return 0

    @property
    def broken(self):
        """Return True if the barrier is in a broken state."""
        return self._state is _BarrierState.BROKEN