# mypy: allow-untyped-defs

import sys
import threading
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
from functools import partial, reduce

import torch
import torch.distributed as dist
import weakref
from torch._C._distributed_c10d import (
    _create_work_from_future,
    AllgatherOptions,
    AllreduceOptions,
    AllToAllOptions,
    BarrierOptions,
    BroadcastOptions,
    ReduceScatterOptions,
    ScatterOptions,
    Store,
    ReduceOp,
)
from torch.distributed.distributed_c10d import _CollOp, _store_based_barrier, P2POp
from torch.futures import Future
from torch.utils import _pytree as pytree

"""
TODO:
Lots of missing collectives.
Collectives validation.
Make timeout robust by making collectives respect the test deadline.
Make tests robust by making collectives interruptible.
We need some synchronization around cleanup to ensure that timed-out ranks don't cause spurious failures.
"""


def flatten_list(lst):
    return pytree.tree_leaves(lst)


def ret_work(ret):
    fut = Future()
    fut.set_result(ret)
    return _create_work_from_future(fut)


def binop_reduce(tensors, op):
    res = op(torch.stack(tensors), dim=0)
    if isinstance(res, torch.Tensor):
        return res
    # min/max return a namedtuple
    return res.values


def bitwise_reduce(tensors, op):
    return reduce(op, tensors)


_reduce_ops = {
    ReduceOp.SUM: partial(binop_reduce, op=torch.sum),
    ReduceOp.AVG: partial(binop_reduce, op=torch.mean),
    ReduceOp.PRODUCT: partial(binop_reduce, op=torch.prod),
    ReduceOp.MIN: partial(binop_reduce, op=torch.min),
    ReduceOp.MAX: partial(binop_reduce, op=torch.max),
    ReduceOp.BAND: partial(bitwise_reduce, op=torch.bitwise_and),
    ReduceOp.BOR: partial(bitwise_reduce, op=torch.bitwise_or),
    ReduceOp.BXOR: partial(bitwise_reduce, op=torch.bitwise_xor),
}


class AllToAll:
    @torch.no_grad()
    def work(self, data):
        world_size = len(data)
        for dest_rank in range(world_size):
            output_tensor_list, _ = data[dest_rank]
            for src_rank in range(world_size):
                _, input_tensor_list = data[src_rank]
                output_tensor_list[src_rank].copy_(input_tensor_list[dest_rank])


class AllToAllBase:
    @torch.no_grad()
    def work(self, data):
        world_size = len(data)
        for dest_rank in range(world_size):
            output_buffer, _, output_split_sizes, _ = data[dest_rank]

            output_indexes = self._size_cumsum(output_buffer.size(0), output_split_sizes, world_size)

            for src_rank in range(world_size):
                _, input_buffer, _, input_split_sizes = data[src_rank]
                input_indexes = self._size_cumsum(input_buffer.size(0), input_split_sizes, world_size)

                output_buffer[output_indexes[src_rank]:output_indexes[src_rank + 1]].copy_(
                    input_buffer[input_indexes[dest_rank]:input_indexes[dest_rank + 1]]
                )

    def _size_cumsum(self, buf_size: int, sizes: Union[torch.Tensor, List[int], None], world_size: int) -> torch.Tensor:
        if sizes is None or len(sizes) == 0:
            sizes = torch.full(
                (world_size,), buf_size // world_size, dtype=torch.int64
            )
        if not isinstance(sizes, torch.Tensor):
            sizes = torch.tensor(sizes, dtype=torch.int64)
        assert sizes.dtype == torch.int64
        sizes = torch.cumsum(
            torch.cat(
                (
                    torch.tensor([0], dtype=torch.int64, device=sizes.device), sizes
                ),
                dim=0
            ),
            dim=0
        )
        return sizes
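

# Illustrative sketch (the helper name below is hypothetical and not used by the
# tests): the offset computation AllToAllBase.work relies on. For a 6-row buffer
# split as [1, 2, 3], the cumulative boundaries are [0, 1, 3, 6], so rank r's
# slice is buffer[offsets[r]:offsets[r + 1]].
def _alltoall_offsets_sketch():
    offsets = AllToAllBase()._size_cumsum(6, [1, 2, 3], world_size=3)
    assert offsets.tolist() == [0, 1, 3, 6]
    return offsets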


class AllReduce:
    def __init__(self, op):
        if op.op not in _reduce_ops:
            raise NotImplementedError(
                f"AllReduce op {op.op} not supported on multithreaded pg for now."
            )
        self.op = op.op

    @torch.no_grad()
    def work(self, data):
        for i in range(len(data[0])):
            tensors = []
            # use rank 0 as the device for the reduction
            rank_0_device = data[0][i].device
            # collect the i-th tensor from every rank and move them all
            # to the rank 0 device
            for src_rank in range(0, len(data)):
                tensors.append(data[src_rank][i].to(rank_0_device))

            # now mimic the reduce across all ranks
            res = _reduce_ops[self.op](tensors)

            # copy the reduced value back to each rank
            for src_rank in range(len(data)):
                data[src_rank][i].copy_(res.to(data[src_rank][i].device))


class AllGather:
    @torch.no_grad()
    def work(self, data):
        for src_rank in range(len(data)):
            in_tensor_list = data[src_rank][1]
            # Can't handle all_gather with multiple tensors
            assert len(in_tensor_list) == 1
            src_tensor = in_tensor_list[0]

            for dest in data:
                dest_tensor = dest[0][0][src_rank]
                dest_tensor.copy_(src_tensor)


class Scatter:
    def __init__(self, src):
        self.src = src

    @torch.no_grad()
    def work(self, data):
        src_in_tensor_list = data[self.src][1]
        # Can't handle scatter with multiple input tensor lists
        assert len(src_in_tensor_list) == 1
        src_in_tensors = src_in_tensor_list[0]

        for rank, each_rank_data in enumerate(data):
            out_tensor_list = each_rank_data[0]
            # Can't handle scatter with multiple output tensor lists
            assert len(out_tensor_list) == 1
            dest_tensor = out_tensor_list[0]
            dest_tensor.copy_(src_in_tensors[rank])


class Gather:
    def __init__(self, dst):
        self.dst = dst

    @torch.no_grad()
    def work(self, data):
        # Can't handle gather with multiple tensor lists
        assert len(data[self.dst][0]) == 1
        out_tensor_list = data[self.dst][0][0]
        for rank, each_rank_data in enumerate(data):
            src_in_tensor_list = each_rank_data[1]
            # Can't handle gather with multiple tensor lists
            assert len(src_in_tensor_list) == 1
            dest_tensor = out_tensor_list[rank]
            dest_tensor.copy_(src_in_tensor_list[0])


class ReduceScatter:
    def __init__(self, op):
        if op != dist.ReduceOp.SUM and op != dist.ReduceOp.AVG:
            raise NotImplementedError(f"ReduceScatter does not support {op}")
        self.op = op

    @torch.no_grad()
    def work(self, data):
        start_reduction = [False for _ in range(len(data))]
        for each_rank_data in data:
            # Can't handle reduce_scatter with multiple scatter lists
            assert len(each_rank_data[1]) == 1
            to_scatter = each_rank_data[1][0]
            for i in range(len(to_scatter)):
                dest_tensor_on_rank_i = data[i][0]
                # Can't handle reduce_scatter with multiple output tensors
                assert len(dest_tensor_on_rank_i) == 1
                dst_tensor_device = dest_tensor_on_rank_i[0].device
                if not start_reduction[i]:
                    dest_tensor_on_rank_i[0].copy_(to_scatter[i].to(dst_tensor_device))
                    start_reduction[i] = True
                else:
                    dest_tensor_on_rank_i[0].add_(to_scatter[i].to(dst_tensor_device))
        if self.op == dist.ReduceOp.AVG:
            num_ranks = len(data)
            for each_rank_data in data:
                each_rank_data[0][0] /= num_ranks


class Broadcast:
    def __init__(self, src):
        self.src = src

    @torch.no_grad()
    def work(self, data):
        in_tensor_list = flatten_list(data[self.src])
        for i in range(len(data)):
            out_tensor_list = flatten_list(data[i])
            for j in range(len(in_tensor_list)):
                out_tensor_list[j].copy_(in_tensor_list[j])
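

# Illustrative sketch (the helper name below is hypothetical and not used by the
# tests): the per-rank layout that the `work` implementations above expect.
# Each entry of `data` is the argument tuple one rank passes to Collective.join;
# for AllGather it mirrors the (output_tensors, input_tensor) arguments of
# ProcessLocalGroup.allgather.
def _allgather_data_layout_sketch(world_size: int = 2):
    data = []
    for rank in range(world_size):
        outputs = [[torch.empty(2) for _ in range(world_size)]]  # one output tensor list per rank
        inputs = [torch.full((2,), float(rank))]  # this rank's contribution
        data.append((outputs, inputs))
    AllGather().work(data)
    # every rank r now holds rank s's input at data[r][0][0][s]
    return data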


class Collective:
    def __init__(self, world_size, collective, pg):
        self._world_size = world_size
        self._collective = collective

        self._start_cond = threading.Condition()
        self._done_cond = threading.Condition()

        self._data = [None] * world_size
        self._count = 0
        self._done = False

        self._pg = pg

    def join(self, rank, data):
        with self._start_cond:
            self._data[rank] = data
            self._count += 1

            # notify rank 0
            if self._count == self._world_size:
                if rank > 0:
                    self._start_cond.notify()

            if rank == 0:
                self._start_cond.wait_for(
                    lambda: self._count == self._world_size or self._pg._terminate.is_set()
                )
                # SystemExit is a subclass of BaseException but not of Exception,
                # so it can be distinguished from exceptions raised by program
                # errors and hidden from the exception queue
                if self._pg._terminate.is_set():
                    sys.exit("Test termination event occurs.")

        with self._done_cond:
            # wait for rank 0 to finish
            if rank > 0:
                self._done_cond.wait_for(lambda: self._done or self._pg._terminate.is_set())
                if self._pg._terminate.is_set():
                    sys.exit("Test termination event occurs.")
            else:
                # copy data around
                self._collective.work(self._data)
                self._done = True
                self._done_cond.notify_all()
        return ret_work(data)


class ProcessLocalGroup(dist.ProcessGroup):
    _coll_lock = threading.Lock()
    _cur_coll_on_pgs = {}

    _terminate = threading.Event()

    @classmethod
    def _start_coll(cls, collective, pg):
        with cls._coll_lock:
            # pg_name is unique; use it to record the mapping between pg and collective
            if pg.pg_name not in cls._cur_coll_on_pgs:
                cls._cur_coll_on_pgs[pg.pg_name] = Collective(pg.size(), collective, cls)
            return cls._cur_coll_on_pgs[pg.pg_name]

    @classmethod
    def _end_coll(cls, collective, pg):
        # This is racily called by all ranks, so only one will succeed
        with cls._coll_lock:
            if pg.pg_name in cls._cur_coll_on_pgs and cls._cur_coll_on_pgs[pg.pg_name] == collective:
                cls._cur_coll_on_pgs.pop(pg.pg_name)

    @classmethod
    def exception_handle(cls, exc):
        cls._terminate.set()
        for coll in cls._cur_coll_on_pgs.values():
            with coll._start_cond:
                coll._start_cond.notify()
            with coll._done_cond:
                coll._done_cond.notify_all()

    @classmethod
    def reset(cls):
        with cls._coll_lock:
            cls._cur_coll_on_pgs = {}
            cls._terminate.clear()

    def alltoall_base(
        self,
        output_buffer: torch.Tensor,
        input_buffer: torch.Tensor,
        output_split_sizes: Optional[List[int]],
        input_split_sizes: Optional[List[int]],
        opts=AllToAllOptions()
    ) -> torch.Tensor:
        coll = ProcessLocalGroup._start_coll(AllToAllBase(), self)
        res = coll.join(self._rank, (output_buffer, input_buffer, output_split_sizes, input_split_sizes))
        ProcessLocalGroup._end_coll(coll, self)
        return res

    def alltoall(self, output_tensor_list, input_tensor_list, opts=AllToAllOptions()):
        coll = ProcessLocalGroup._start_coll(AllToAll(), self)
        res = coll.join(self._rank, (output_tensor_list, input_tensor_list))
        ProcessLocalGroup._end_coll(coll, self)
        return res

    def allreduce(self, tensor_list, opts=AllreduceOptions()):
        coll = ProcessLocalGroup._start_coll(AllReduce(opts.reduceOp), self)
        res = coll.join(self._rank, tensor_list)
        ProcessLocalGroup._end_coll(coll, self)
        return res

    def allreduce_coalesced(self, tensor_list, opts=AllreduceOptions()):
        coll = ProcessLocalGroup._start_coll(AllReduce(opts.reduceOp), self)
        res = coll.join(self._rank, tensor_list)
        ProcessLocalGroup._end_coll(coll, self)
        return res

    def barrier(self, opts=BarrierOptions()):
        return self.allreduce(tensor_list=[torch.ones(1)])

    def allgather(self, output_tensors, input_tensor, opts=AllgatherOptions()):
        coll = ProcessLocalGroup._start_coll(AllGather(), self)
        res = coll.join(self._rank, (output_tensors, input_tensor))
        ProcessLocalGroup._end_coll(coll, self)
        return res

    def _allgather_base(self, output_tensor, input_tensor, opts=AllgatherOptions()):
        tensor_list = list(torch.chunk(output_tensor, self._world_size))
        return self.allgather([tensor_list], [input_tensor], opts)

    def broadcast(self, tensor_list, opts=BroadcastOptions()):
        coll = ProcessLocalGroup._start_coll(Broadcast(opts.rootRank), self)
        res = coll.join(self._rank, tensor_list)
        ProcessLocalGroup._end_coll(coll, self)
        return res

    def scatter(self, output_tensors, input_tensors, opts=ScatterOptions()):
        coll = ProcessLocalGroup._start_coll(Scatter(opts.rootRank), self)
        res = coll.join(self._rank, (output_tensors, input_tensors))
        ProcessLocalGroup._end_coll(coll, self)
        return res

    def gather(self, output_tensors, input_tensors, opts=ScatterOptions()):
        coll = ProcessLocalGroup._start_coll(Gather(opts.rootRank), self)
        res = coll.join(self._rank, (output_tensors, input_tensors))
        ProcessLocalGroup._end_coll(coll, self)
        return res

    def reduce_scatter(self, output_tensor, scatter_list, opts=ReduceScatterOptions()):
        coll = ProcessLocalGroup._start_coll(ReduceScatter(opts.reduceOp), self)
        res = coll.join(self._rank, (output_tensor, scatter_list))
        ProcessLocalGroup._end_coll(coll, self)
        return res

    def _reduce_scatter_base(self, output_tensor, input_tensor, opts=ReduceScatterOptions()):
        tensor_list = list(torch.chunk(input_tensor, self._world_size))
        return self.reduce_scatter([output_tensor], [tensor_list], opts)

    def reduce_scatter_tensor_coalesced(self, output_tensors, input_tensors, opts=ReduceScatterOptions()):
        works = [
            self._reduce_scatter_base(output_tensor, input_tensor, opts)
            for output_tensor, input_tensor
            in zip(output_tensors, input_tensors)
        ]
        for work in works[:-1]:
            work.wait()
        return works[-1]

    def allgather_into_tensor_coalesced(self, output_tensor_list, input_tensor_list, opts=AllgatherOptions()):
        res = None
        for o_t, i_t in zip(output_tensor_list, input_tensor_list):
            res = self._allgather_base(o_t, i_t)
        return res

    def __init__(self, rank, world_size):
        super().__init__(rank, world_size)
        self._rank = rank
        self._world_size = world_size
        world = dist.distributed_c10d._world
        if isinstance(world, ThreadLocalWorld):
            world = world._get_world()
        self._world = weakref.ref(world)
        self._ctx = torch.autograd.set_multithreading_enabled(False)

    def size(self):
        return self._world_size

    @property
    def pg_name(self):
        """
        Return the globally registered name of the current pg in the world.
        """
        return self._world().pg_names[self]

    @property
    def group_name(self):
        return self.pg_name

    def getBackendName(self):
        return "threaded"

    def __repr__(self):
        return f"ThreadedPG world_size:{self._world_size} rank:{self._rank}"
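

# Illustrative sketch (the helper name below is hypothetical and not used by the
# tests): how Collective coordinates one collective round. Each "rank" is a
# thread calling join(); once every rank has checked in, rank 0 performs the
# data movement while the other ranks wait on the done condition. Passing the
# ProcessLocalGroup class as `pg` matches what _start_coll does.
def _collective_rendezvous_sketch(world_size: int = 2):
    tensors = [torch.full((1,), float(rank)) for rank in range(world_size)]
    coll = Collective(world_size, AllReduce(AllreduceOptions().reduceOp), ProcessLocalGroup)
    threads = [
        threading.Thread(target=coll.join, args=(rank, [tensors[rank]]))
        for rank in range(world_size)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    # each tensors[rank] now holds 0.0 + 1.0 + ... + (world_size - 1)
    return tensors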


def _create_threaded_pg(prefix_store, rank, world_size, timeout):
    pg = ProcessLocalGroup(rank, world_size)
    # https://github.com/pytorch/pytorch/pull/103033 made the store-based barrier optional.
    # When a device mesh involves sub-groups and the store-based barrier is not
    # enabled in c10d, different threads may be initializing different groups
    # concurrently (even though the threaded pg's actual collectives are assumed
    # to be single threaded), leading to race conditions.
    # For example, if we have a mesh of [[0, 1], [2, 3]], the sub groups
    # (dim 0 and 1) would be initialized in different threads independently.
    # In this case we can no longer rely on class or global variables
    # but have to rely on the store-based barrier to make sure each group
    # is ready separately before we can invoke collectives in any of the groups.

    # the prefix store is already per group, so we pass an empty name here
    _store_based_barrier(rank, prefix_store, "", world_size, timeout)
    return pg


dist.Backend.register_backend("threaded", _create_threaded_pg, devices=["cpu", "cuda"])


@dataclass
class WorldData:
    default_pg: dist.ProcessGroup
    pg_map: Dict[dist.ProcessGroup, Tuple[str, Optional[Store]]]
    pg_names: Dict[dist.ProcessGroup, str]
    pg_group_ranks: Dict[dist.ProcessGroup, Dict[int, int]]
    pg_backend_config: Dict[dist.ProcessGroup, str]
    group_count: int
    tags_to_pg: Dict[str, List[dist.ProcessGroup]]
    pg_to_tag: Dict[dist.ProcessGroup, str]
    pg_coalesce_state: Dict[dist.ProcessGroup, List[Union[_CollOp, P2POp]]]
    pg_default_device: Dict[dist.ProcessGroup, torch.device]


class ThreadLocalWorld:
    _world = threading.local()

    def _get_world(self) -> WorldData:
        if not hasattr(ThreadLocalWorld._world, "world"):
            ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0, {}, {}, {}, {})
        return ThreadLocalWorld._world.world

    @property
    def default_pg(self):
        return self._get_world().default_pg

    @default_pg.setter
    def default_pg(self, value):
        self._get_world().default_pg = value

    @property
    def pg_map(self):
        return self._get_world().pg_map

    @property
    def pg_names(self):
        return self._get_world().pg_names

    @property
    def pg_group_ranks(self):
        return self._get_world().pg_group_ranks

    @property
    def pg_backend_config(self):
        return self._get_world().pg_backend_config

    @property
    def group_count(self) -> int:
        return self._get_world().group_count

    @group_count.setter
    def group_count(self, value):
        self._get_world().group_count = value

    @property
    def tags_to_pg(self):
        return self._get_world().tags_to_pg

    @property
    def pg_to_tag(self):
        return self._get_world().pg_to_tag

    @property
    def pg_coalesce_state(self) -> Dict[dist.ProcessGroup, List[Union[_CollOp, P2POp]]]:
        return self._get_world().pg_coalesce_state

    @property
    def pg_default_device(self) -> Dict[dist.ProcessGroup, torch.device]:
        return self._get_world().pg_default_device


_old_pg_world = None
_ctx_manager = None


def _install_threaded_pg():
    global _old_pg_world
    global _ctx_manager
    _old_pg_world = dist.distributed_c10d._world
    dist.distributed_c10d._world = ThreadLocalWorld()
    _ctx_manager = torch.autograd.set_multithreading_enabled(False)

    return dist.distributed_c10d._world


def _uninstall_threaded_pg():
    dist.distributed_c10d._world = _old_pg_world
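

# Illustrative sketch (hypothetical helper; a simplified version of how the
# multi-threaded test utilities drive this backend): install the thread-local
# world, then let each rank-thread initialize the "threaded" backend against a
# shared in-memory store and issue ordinary c10d collectives.
def _threaded_backend_usage_sketch(world_size: int = 2):
    _install_threaded_pg()
    store = dist.HashStore()

    def run(rank):
        dist.init_process_group(
            backend="threaded", rank=rank, world_size=world_size, store=store
        )
        t = torch.ones(1)
        dist.all_reduce(t)  # every thread ends up with the value `world_size`
        dist.destroy_process_group()

    threads = [threading.Thread(target=run, args=(rank,)) for rank in range(world_size)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    _uninstall_threaded_pg()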