from typing import Optional

import torch.distributed as dist

from . import config


# Cached process group; stays None until get_compile_pg() creates it.
_COMPILE_PG: Optional[dist.ProcessGroup] = None


def get_compile_pg() -> Optional[dist.ProcessGroup]:
    """Return the cached compile process group, creating it on first use.

    Returns:
        The process group stored in ``_COMPILE_PG`` (created with the tag
        ``"pt2_compile_pg"`` the first time it is needed), or ``None`` when
        ``config.enable_compiler_collectives`` is falsy, or when
        ``torch.distributed`` is unavailable or not yet initialized.
    """
    global _COMPILE_PG

    # Guards are checked in this order so the distributed queries only run
    # when compiler collectives are enabled.
    if not config.enable_compiler_collectives:
        return None
    if not dist.is_available():
        return None
    if not dist.is_initialized():
        return None

    if _COMPILE_PG is None:
        # No `timeout` is passed; `timeout=datetime.timedelta(seconds=2)` was
        # left commented out here. TODO(review): confirm whether it is wanted.
        new_group = dist.distributed_c10d._new_group_with_tag
        _COMPILE_PG = new_group(pg_tag="pt2_compile_pg")

    return _COMPILE_PG