1import collections 2from typing import Deque, Optional 3 4import torch 5 6 7class _FreeEventQueue: 8 """ 9 This tracks all pending frees corresponding to inflight all-gathers. The 10 queueing pattern is iterative enqueues with a single dequeue per iteration 11 once the limit ``_max_num_inflight_all_gathers`` is reached. 12 """ 13 14 def __init__(self) -> None: 15 self._queue: Deque[torch.Event] = collections.deque() 16 self._max_num_inflight_all_gathers = 2 # empirically chosen 17 18 def enqueue(self, free_event: torch.Event) -> None: 19 """Enqueues a free event.""" 20 self._queue.append(free_event) 21 22 def dequeue_if_needed(self) -> Optional[torch.Event]: 23 """Dequeues a single event if the limit is reached.""" 24 if len(self._queue) >= self._max_num_inflight_all_gathers: 25 return self._dequeue() 26 return None 27 28 def _dequeue(self) -> Optional[torch.Event]: 29 """Dequeues a free event if possible.""" 30 if self._queue: 31 event = self._queue.popleft() 32 return event 33 return None 34