xref: /aosp_15_r20/external/pytorch/torch/distributed/fsdp/_limiter_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import collections
2from typing import Deque, Optional
3
4import torch
5
6
7class _FreeEventQueue:
8    """
9    This tracks all pending frees corresponding to inflight all-gathers. The
10    queueing pattern is iterative enqueues with a single dequeue per iteration
11    once the limit ``_max_num_inflight_all_gathers`` is reached.
12    """
13
14    def __init__(self) -> None:
15        self._queue: Deque[torch.Event] = collections.deque()
16        self._max_num_inflight_all_gathers = 2  # empirically chosen
17
18    def enqueue(self, free_event: torch.Event) -> None:
19        """Enqueues a free event."""
20        self._queue.append(free_event)
21
22    def dequeue_if_needed(self) -> Optional[torch.Event]:
23        """Dequeues a single event if the limit is reached."""
24        if len(self._queue) >= self._max_num_inflight_all_gathers:
25            return self._dequeue()
26        return None
27
28    def _dequeue(self) -> Optional[torch.Event]:
29        """Dequeues a free event if possible."""
30        if self._queue:
31            event = self._queue.popleft()
32            return event
33        return None
34