xref: /aosp_15_r20/prebuilts/build-tools/common/py3-stdlib/asyncio/locks.py (revision cda5da8d549138a6648c5ee6d7a49cf8f4a657be)
1"""Synchronization primitives."""
2
3__all__ = ('Lock', 'Event', 'Condition', 'Semaphore',
4           'BoundedSemaphore', 'Barrier')
5
6import collections
7import enum
8
9from . import exceptions
10from . import mixins
11from . import tasks
12
13class _ContextManagerMixin:
14    async def __aenter__(self):
15        await self.acquire()
16        # We have no use for the "as ..."  clause in the with
17        # statement for locks.
18        return None
19
20    async def __aexit__(self, exc_type, exc, tb):
21        self.release()
22
23
class Lock(_ContextManagerMixin, mixins._LoopBoundMixin):
    """Primitive lock objects.

    A primitive lock is a synchronization primitive that is not owned by a
    particular coroutine while it is locked.  It is always in one of two
    states, 'locked' or 'unlocked', and it starts out unlocked.

    Calling acquire() on an unlocked lock switches it to locked and returns
    at once.  Calling acquire() on a locked lock blocks until another
    coroutine calls release(); the waiter then re-locks the lock and
    returns.  release() must only be called while the lock is locked.  It
    switches the lock back to unlocked and returns immediately.  Releasing
    an unlocked lock raises RuntimeError.

    When several coroutines are blocked in acquire(), each release() lets
    exactly one of them proceed, in the order in which they started
    waiting.

    acquire() is a coroutine and must be awaited.  The preferred way to use
    a lock is the asynchronous context manager protocol:

        lock = Lock()
        ...
        async with lock:
            ...

    which is equivalent to:

        await lock.acquire()
        try:
            ...
        finally:
            lock.release()

    The current state can be queried with locked().
    """

    def __init__(self):
        # Deque of futures, one per blocked acquire(); created on first use.
        self._waiters = None
        self._locked = False

    def __repr__(self):
        base = super().__repr__()
        if self._locked:
            state = 'locked'
        else:
            state = 'unlocked'
        if self._waiters:
            state = f'{state}, waiters:{len(self._waiters)}'
        return f'<{base[1:-1]} [{state}]>'

    def locked(self):
        """Return True if lock is acquired."""
        return self._locked

    async def acquire(self):
        """Acquire the lock, waiting for a release() if it is held.

        Returns True once the lock is held by the caller.
        """
        if not self._locked:
            # Only take the lock directly when no live waiter is queued
            # ahead of us; waiters whose futures were cancelled do not count.
            queue = self._waiters
            if queue is None or all(w.cancelled() for w in queue):
                self._locked = True
                return True

        if self._waiters is None:
            self._waiters = collections.deque()
        waiter = self._get_loop().create_future()
        self._waiters.append(waiter)

        # The inner finally must run before the CancelledError handler, so
        # this waiter is already off the queue when _wake_up_first() runs
        # and cannot be chosen to wake itself.
        try:
            try:
                await waiter
            finally:
                self._waiters.remove(waiter)
        except exceptions.CancelledError:
            # If the lock is still free, make sure the next queued waiter
            # (if any) is woken, so it is not left idle behind us.
            if not self._locked:
                self._wake_up_first()
            raise
        else:
            self._locked = True
            return True

    def release(self):
        """Release the lock.

        Switches a locked lock to unlocked and, if coroutines are blocked in
        acquire(), lets exactly one of them proceed.  Raises RuntimeError
        when called on an unlocked lock.  There is no return value.
        """
        if not self._locked:
            raise RuntimeError('Lock is not acquired.')
        self._locked = False
        self._wake_up_first()

    def _wake_up_first(self):
        """Resolve the future at the head of the queue, unless already done."""
        if not self._waiters:
            return
        head = self._waiters[0]
        # A head that is already done will resume on its own: it either takes
        # the lock or, if cancelled with the lock still free, calls back into
        # this method to wake the next waiter.
        if not head.done():
            head.set_result(True)
156
157
class Event(mixins._LoopBoundMixin):
    """Asynchronous equivalent to threading.Event.

    An event manages an internal flag that starts out false.  set() makes
    the flag true and clear() makes it false again.  wait() blocks until the
    flag is true.
    """

    def __init__(self):
        # One future per coroutine currently blocked in wait().
        self._waiters = collections.deque()
        self._value = False

    def __repr__(self):
        base = super().__repr__()
        if self._value:
            state = 'set'
        else:
            state = 'unset'
        if self._waiters:
            state = f'{state}, waiters:{len(self._waiters)}'
        return f'<{base[1:-1]} [{state}]>'

    def is_set(self):
        """Return True if and only if the internal flag is true."""
        return self._value

    def set(self):
        """Set the internal flag to true and wake every waiting coroutine.

        Coroutines that call wait() while the flag is true do not block at
        all.  Calling set() on an already-set event does nothing.
        """
        if self._value:
            return
        self._value = True
        for waiter in self._waiters:
            if not waiter.done():
                waiter.set_result(True)

    def clear(self):
        """Reset the internal flag to false.

        Coroutines that call wait() afterwards block until set() is called
        again.
        """
        self._value = False

    async def wait(self):
        """Block until the internal flag is true, then return True.

        Returns immediately if the flag is already true on entry.
        """
        if not self._value:
            waiter = self._get_loop().create_future()
            self._waiters.append(waiter)
            try:
                await waiter
            finally:
                # Leave the queue whether we were woken or cancelled.
                self._waiters.remove(waiter)
        return True
217
218
class Condition(_ContextManagerMixin, mixins._LoopBoundMixin):
    """Asynchronous equivalent to threading.Condition.

    This class implements condition variable objects.  A condition variable
    allows one or more coroutines to wait until they are notified by another
    coroutine.

    If *lock* is given, it is used as the underlying lock; otherwise a new
    Lock object is created.  A coroutine woken from wait() should re-check
    the state it is waiting for, because that state may have changed before
    the lock was reacquired (wait_for() does this re-check).
    """

    def __init__(self, lock=None):
        if lock is None:
            lock = Lock()

        self._lock = lock
        # Export the lock's locked(), acquire() and release() methods.
        self.locked = lock.locked
        self.acquire = lock.acquire
        self.release = lock.release

        # One future per coroutine currently blocked in wait().
        self._waiters = collections.deque()

    def __repr__(self):
        res = super().__repr__()
        extra = 'locked' if self.locked() else 'unlocked'
        if self._waiters:
            extra = f'{extra}, waiters:{len(self._waiters)}'
        return f'<{res[1:-1]} [{extra}]>'

    async def wait(self):
        """Wait until notified.

        If the calling coroutine has not acquired the lock when this method
        is called, a RuntimeError is raised.

        This method releases the underlying lock, and then blocks until it
        is awakened by a notify() or notify_all() call for the same
        condition variable in another coroutine.  Once awakened, it
        re-acquires the lock and returns True.

        The lock is re-acquired even if the wait is cancelled, and the
        cancellation is re-raised afterwards.  If this coroutine had already
        been notified but leaves wait() with an exception instead of
        returning, the notification is passed on to another waiter so that
        it is not lost.
        """
        if not self.locked():
            raise RuntimeError('cannot wait on un-acquired lock')

        fut = None
        self.release()
        try:
            try:
                fut = self._get_loop().create_future()
                self._waiters.append(fut)
                try:
                    await fut
                    return True
                finally:
                    self._waiters.remove(fut)

            finally:
                # Must reacquire the lock even if the wait is cancelled.
                # Only CancelledError is retried; any other error from
                # acquire() propagates instead of spinning here.
                err = None
                while True:
                    try:
                        await self.acquire()
                        break
                    except exceptions.CancelledError as e:
                        err = e

                if err is not None:
                    try:
                        # Re-raise the caught instance (not a fresh one) so
                        # the cancel message and traceback are preserved.
                        raise err
                    finally:
                        err = None  # break the exception <-> frame cycle
        except BaseException:
            # A future that is done but not cancelled was resolved by
            # _notify(): this task consumed a wakeup it will never act on.
            # Hand it to another waiter so notify(n) still wakes n tasks.
            if fut is not None and fut.done() and not fut.cancelled():
                self._notify(1)
            raise

    async def wait_for(self, predicate):
        """Wait until a predicate becomes true.

        The predicate should be a callable whose result will be interpreted
        as a boolean value.  The final predicate value is the return value.
        """
        result = predicate()
        while not result:
            await self.wait()
            result = predicate()
        return result

    def notify(self, n=1):
        """Wake up at most *n* coroutines waiting on this condition (default 1).

        If the calling coroutine has not acquired the lock when this method
        is called, a RuntimeError is raised.  This is a no-op if no
        coroutines are waiting.

        Note: an awakened coroutine does not actually return from its wait()
        call until it can reacquire the lock.  Since notify() does not
        release the lock, its caller should.
        """
        if not self.locked():
            raise RuntimeError('cannot notify on un-acquired lock')
        self._notify(n)

    def _notify(self, n):
        # Resolve up to n still-pending waiter futures.  There is no lock
        # check here, because wait() also calls this to hand on a consumed
        # wakeup, possibly after acquire() failed and the lock is not held.
        idx = 0
        for fut in self._waiters:
            if idx >= n:
                break

            if not fut.done():
                idx += 1
                fut.set_result(False)

    def notify_all(self):
        """Wake up all coroutines waiting on this condition.

        This method acts like notify(), but wakes up all waiting coroutines
        instead of one.  If the calling coroutine has not acquired the lock
        when this method is called, a RuntimeError is raised.
        """
        self.notify(len(self._waiters))
329
330
class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin):
    """A Semaphore implementation.

    A semaphore manages an internal counter which is decremented by each
    acquire() call and incremented by each release() call.  The counter can
    never go below zero.  When acquire() finds that it is zero, it blocks
    until some other coroutine calls release().

    Semaphores also support the asynchronous context management protocol
    ('async with').

    The optional argument gives the initial value for the internal counter.
    It defaults to 1.  If the value given is less than 0, ValueError is
    raised.
    """

    def __init__(self, value=1):
        if value < 0:
            raise ValueError("Semaphore initial value must be >= 0")
        # Deque of futures, one per blocked acquire(); created on first use.
        self._waiters = None
        self._value = value

    def __repr__(self):
        base = super().__repr__()
        if self.locked():
            state = 'locked'
        else:
            state = f'unlocked, value:{self._value}'
        if self._waiters:
            state = f'{state}, waiters:{len(self._waiters)}'
        return f'<{base[1:-1]} [{state}]>'

    def locked(self):
        """Returns True if semaphore cannot be acquired immediately."""
        if self._value == 0:
            return True
        # Even with a positive counter, a caller must queue behind any
        # waiter that has not been cancelled, which keeps acquisition FIFO.
        queued = self._waiters or ()
        return any(not waiter.cancelled() for waiter in queued)

    async def acquire(self):
        """Acquire a semaphore.

        If the semaphore can be taken immediately, decrement the internal
        counter by one and return True.  Otherwise block until a release()
        hands this coroutine a unit, then return True.
        """
        if not self.locked():
            self._value -= 1
            return True

        if self._waiters is None:
            self._waiters = collections.deque()
        waiter = self._get_loop().create_future()
        self._waiters.append(waiter)

        # The inner finally must run before the CancelledError handler, so
        # this waiter is already off the queue when _wake_up_next() runs
        # and cannot be chosen to wake itself.
        try:
            try:
                await waiter
            finally:
                self._waiters.remove(waiter)
        except exceptions.CancelledError:
            if not waiter.cancelled():
                # _wake_up_next() already charged a unit to us, but we are
                # leaving without it: give it back and wake someone else.
                self._value += 1
                self._wake_up_next()
            raise
        else:
            # Units may still be free for coroutines that queued behind us.
            if self._value > 0:
                self._wake_up_next()
            return True

    def release(self):
        """Release a semaphore, incrementing the internal counter by one.

        If another coroutine is waiting for the counter to become larger
        than zero, wake up that coroutine.
        """
        self._value += 1
        self._wake_up_next()

    def _wake_up_next(self):
        """Wake up the first waiter that isn't done."""
        candidate = next(
            (fut for fut in self._waiters or () if not fut.done()), None)
        if candidate is None:
            return
        # Charge the unit to the woken waiter right away; acquire() hands it
        # back if that waiter is cancelled before it resumes.
        self._value -= 1
        candidate.set_result(True)
419
420
class BoundedSemaphore(Semaphore):
    """A semaphore that refuses to grow past its initial value.

    release() raises ValueError if it would increase the internal counter
    above the value the semaphore was created with.
    """

    def __init__(self, value=1):
        # Record the ceiling before Semaphore validates and stores the value.
        self._bound_value = value
        super().__init__(value)

    def release(self):
        if self._value < self._bound_value:
            super().release()
        else:
            raise ValueError('BoundedSemaphore released too many times')
436
437
438
439class _BarrierState(enum.Enum):
440    FILLING = 'filling'
441    DRAINING = 'draining'
442    RESETTING = 'resetting'
443    BROKEN = 'broken'
444
445
class Barrier(mixins._LoopBoundMixin):
    """Asyncio equivalent to threading.Barrier.

    Implements a barrier primitive for synchronizing a fixed number of tasks
    at known synchronization points.  Tasks block in wait() and are all
    woken together once the last of them has made its call.
    """

    def __init__(self, parties):
        """Create a barrier, initialised to 'parties' tasks."""
        if parties < 1:
            raise ValueError('parties must be > 0')

        # Every state change is broadcast to blocked tasks via notify_all().
        self._cond = Condition()

        self._parties = parties
        self._state = _BarrierState.FILLING
        # Number of tasks currently inside wait().
        self._count = 0

    def __repr__(self):
        base = super().__repr__()
        details = f'{self._state.value}'
        if not self.broken:
            details = f'{details}, waiters:{self.n_waiting}/{self.parties}'
        return f'<{base[1:-1]} [{details}]>'

    async def __aenter__(self):
        # Entering the context waits on the barrier; the index returned by
        # wait() is what an "async with ... as" clause receives.
        return await self.wait()

    async def __aexit__(self, *args):
        pass

    async def wait(self):
        """Wait for the barrier.

        When the specified number of tasks have started waiting, they are
        all woken simultaneously.  Returns this task's arrival index, from
        0 to parties-1.  The task whose arrival trips the barrier gets
        parties-1.

        Raises BrokenBarrierError if the barrier is aborted, or is reset
        while this task is waiting.
        """
        async with self._cond:
            # Do not join while a previous round is draining or resetting.
            await self._block()
            try:
                index = self._count
                self._count = index + 1
                if self._count == self._parties:
                    # This arrival completes the group: trip the barrier.
                    await self._release()
                else:
                    await self._wait()
                return index
            finally:
                self._count -= 1
                # If we are the last one out, let blocked newcomers in.
                self._exit()

    async def _block(self):
        # Wait until the barrier is neither draining nor resetting, then
        # refuse entry if it has been aborted.  Cancellation of the wait
        # propagates to the caller.
        def _ready():
            return self._state not in (
                _BarrierState.DRAINING, _BarrierState.RESETTING)

        await self._cond.wait_for(_ready)

        if self._state is _BarrierState.BROKEN:
            raise exceptions.BrokenBarrierError("Barrier aborted")

    async def _release(self):
        # Trip the barrier: switching to DRAINING wakes every task in
        # _wait() and holds back new arrivals in _block() until drained.
        self._state = _BarrierState.DRAINING
        self._cond.notify_all()

    async def _wait(self):
        # Sleep until the barrier leaves FILLING.  Being woken by an abort
        # or a reset, rather than a trip, is reported as an error.
        await self._cond.wait_for(
            lambda: self._state is not _BarrierState.FILLING)

        if self._state in (_BarrierState.BROKEN, _BarrierState.RESETTING):
            raise exceptions.BrokenBarrierError("Abort or reset of barrier")

    def _exit(self):
        # Only the last task to leave acts: a draining or resetting barrier
        # goes back to FILLING, and everyone blocked on the condition is
        # told about it.
        if self._count != 0:
            return
        if self._state in (_BarrierState.RESETTING, _BarrierState.DRAINING):
            self._state = _BarrierState.FILLING
        self._cond.notify_all()

    async def reset(self):
        """Reset the barrier to the initial state.

        Any tasks currently waiting get BrokenBarrierError raised.
        """
        async with self._cond:
            # With tasks inside, RESETTING makes them bail out, and the last
            # one out switches the barrier back to FILLING.  With nobody
            # inside, the barrier can be refilled straight away.
            if self._count > 0:
                self._state = _BarrierState.RESETTING
            else:
                self._state = _BarrierState.FILLING
            self._cond.notify_all()

    async def abort(self):
        """Place the barrier into a 'broken' state.

        Useful in case of error.  Any currently waiting tasks and tasks
        attempting to 'wait()' will have BrokenBarrierError raised.
        """
        async with self._cond:
            self._state = _BarrierState.BROKEN
            self._cond.notify_all()

    @property
    def parties(self):
        """Return the number of tasks required to trip the barrier."""
        return self._parties

    @property
    def n_waiting(self):
        """Return the number of tasks currently waiting at the barrier."""
        # Tasks still inside wait() while draining or resetting are on
        # their way out, so they are not counted as waiting.
        return self._count if self._state is _BarrierState.FILLING else 0

    @property
    def broken(self):
        """Return True if the barrier is in a broken state."""
        return self._state is _BarrierState.BROKEN
588