# mypy: allow-untyped-defs
import sys
import warnings
from typing import cast, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d
from torch.distributed.device_mesh import DeviceMesh
from torch.fx.experimental.proxy_tensor import get_proxy_mode

from . import _functional_collectives_impl as fun_col_impl


try:
    from torch.utils._cxx_pytree import tree_map_only
except ImportError:
    from torch.utils._pytree import tree_map_only  # type: ignore[no-redef]


if torch._running_with_deploy():

    def is_torchdynamo_compiling():
        """Can't import torchdynamo in torchdeploy builds currently."""
        return False

else:
    try:
        from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling
    except Exception:
        warnings.warn(
            "Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly"
        )

        def is_torchdynamo_compiling():
            return False


"""
New traceable, functional collectives.
RFC: https://github.com/pytorch/pytorch/issues/93173

    compiler: trace these ops with plain-old-data schemas, then choose how to lower them.
    eager: execute these 'functional' ops, which in eager return AsyncCollectiveTensor subclasses,
           automatically calling .wait() on the underlying/hidden async 'work' obj only when fed to
           a downstream op.

Issues:
* Where should these ops live? We couldn't `import torch` if we put these ops in existing torch.distributed files.
* Proper support for eager requires inplace ops. We should explore having it as an option for the API.
"""

"""
Functional collectives are asynchronous only, and we perform implicit stream synchronization
on behalf of the user.

We use AsyncCollectiveTensor to wrap the result tensor of a collective; it lets us witness the
first usage of the tensor and insert a cross-stream sync at the right place.

The above are the easy bits; the hard one is how we match the Work object returned by
c10d and the tensor AsyncCollectiveTensor wraps. We alloc the tensor inside the collective
op implementation (see the ``clone()`` call in ``_all_reduce``) and then it's handled by the
dispatcher, which might call other implementations that are allowed to change the returned
tensor - even return a tensor with a different shape (see ``torch.vmap``).

This means the caller of our ops receives a Tensor that is not guaranteed to be the same one
allocated by our implementations, and that makes pairing the AsyncTensor to the original
tensor a lot harder. This pairing is needed so we can look up the Work object to use.

Originally, we tried WeakKeyDictionary to map from Tensor to Work, but because a Tensor's
identity is not stable across dispatch, the op caller would end up with a different Tensor
instance that would not match any in the dictionary.

With Tensor identity out of the question, we decided to use the tensor data pointer, which
should be stable across all the Tensor changes done during dispatch.

We have a dictionary of tensor::data_ptr -> Work that we insert into right after we call into c10d.

We use this dictionary when AsyncCollectiveTensor is used to invoke Work::wait()

Finally, we set up a finalizer against the tensor wrapper to observe it getting collected so we
can clean up stale entries in the dictionary.

To eliminate the possibility of races we have a global version counter that is used by the finalizer.

As a wise man once said: Don't cross the streams (https://www.youtube.com/watch?v=wyKQe_i9yyo)

"""

"""
Functional collectives can accept any of these types to describe the ranks participating in collectives.

The different types will be desugared to a canonical format.
"""
RANK_TYPES = Union[
    List[int],
    List[List[int]],
    dist.ProcessGroup,
    DeviceMesh,
    Tuple["dist.tensor.DeviceMesh", int],
    str,
]


"""
User facing APIs for functional collectives
-------------------------------------------

These APIs are called by user code and are expected to work both in eager execution and compilation,
but there are significant differences in how the two modes are implemented underneath.

Eager execution is 'optimized' using a tensor subclass that schedules the synchronization (via the wait_tensor() op)
just before the tensor is first used. Compiled tracing currently relies on the compiler to perform this optimization,
and cannot yet correctly trace the AsyncTensor wrapper class. In the future, these paths may be unified
if sufficient subclass support is added in dynamo.

Example: all_reduce is an entrypoint API, and other collectives follow a similar pattern.

Here's how it works under torch.compile/dynamo:
all_reduce(...)
  |--> _expand_group(...)               - desugars processgroup into canonical/traceable format
  |--> c10d_functional.all_reduce(...)  - dynamo captures this op call, doesn't trace deeper
  |--> _maybe_wrap_tensor(...)          - wait_tensor() op is immediately called, no AsyncTensor subclass needed

And under eager execution:
all_reduce(...)
  |--> _expand_group(...)               - same as above, but less critical for eager
  |--> c10d_functional.all_reduce(...)  - dispatches to real kernel OR records op in trace
  |--> _maybe_wrap_tensor(...)          - AsyncTensor wrapper applied to returned tensor,
                                          which issues wait_tensor() at the time of first use
"""
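# A minimal usage sketch of the eager flow described above (illustrative only and not
# executed on import; it assumes the default process group has already been initialized,
# e.g. via torch.distributed.init_process_group):
#
#   t = torch.ones(4)
#   out = all_reduce(t, "sum", dist.group.WORLD)  # returns an AsyncCollectiveTensor in eager
#   total = out + 1                               # first use triggers the implicit wait_tensor()
#
# Under torch.compile, the same call is captured as a c10d_functional.all_reduce node
# followed by an explicit wait_tensor node, so no wrapper subclass is involved.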


def wait_tensor(tensor):
    """
    Wait on a tensor returned by the collectives ops.

    Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA.
    """
    return torch.ops._c10d_functional.wait_tensor(tensor)  # type: ignore[attr-defined]


def broadcast(self: torch.Tensor, src: int, group: RANK_TYPES, tag: str = ""):
    """
    Broadcasts the tensor to all processes in the given process group.

    Args:
        src (int): Source rank
        group (ProcessGroup or List[int]): The process group to work on.
        tag (str, optional): A unique identifier for the collective. Default: empty string
    """
    group_name = _resolve_group_name(group, tag)
    tensor = torch.ops._c10d_functional.broadcast(self, src, group_name)
    return _maybe_wrap_tensor(tensor)


def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = ""):
    """
    Reduces the tensor data across all machines in such a way that all get
    the final result.

    The input tensor is left unmodified.

    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do a SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh

    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    group_name = _resolve_group_name(group, tag)
    tensor = torch.ops._c10d_functional.all_reduce(self, reduceOp.lower(), group_name)
    return _maybe_wrap_tensor(tensor)
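# Illustrative (hypothetical) values for the `group` argument, assuming a 4-rank job,
# a 1D DeviceMesh named `mesh`, and a 2D DeviceMesh named `mesh_2d` (sketch only,
# not executed on import):
#
#   all_reduce(t, "sum", [0, 1, 2, 3])        # 1D rank list (SPMD)
#   all_reduce(t, "sum", [[0, 1], [2, 3]])    # 2D rank mesh (MPMD)
#   all_reduce(t, "sum", dist.group.WORLD)    # an existing ProcessGroup
#   all_reduce(t, "sum", mesh)                # a 1D DeviceMesh
#   all_reduce(t, "sum", (mesh_2d, 0))        # one dimension of a 2D DeviceMesh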


def all_gather_tensor(
    self: torch.Tensor,
    gather_dim: int,
    group: RANK_TYPES,
    tag: str = "",
):
    """
    Gather tensor data from all machines and concatenate over ``gather_dim``.

    Note that it currently only supports gather_dim = 0.

    The input tensor is left unmodified.
    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do a SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh

    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    assert self.is_contiguous()
    group_name = _resolve_group_name(group, tag)
    group_size = c10d._get_group_size_by_name(group_name)
    tensor = torch.ops._c10d_functional.all_gather_into_tensor(
        self, group_size, group_name
    )
    res = _maybe_wrap_tensor(tensor)
    # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call
    if gather_dim != 0:
        # torch.cat accesses the data, so we already need to wait here; doing the wait
        # first and then chunk + cat avoids going through ACT dispatching logic again
        if isinstance(res, AsyncCollectiveTensor):
            res = res.wait()  # type: ignore[attr-defined]
        res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim)
    return res


def all_gather_tensor_autograd(
    self: torch.Tensor,
    gather_dim: int,
    group: RANK_TYPES,
    tag: str = "",
):
    """
    Gather tensor data from all machines and concatenate over ``gather_dim``.

    Note that it currently only supports gather_dim = 0.

    This function is the same as all_gather_tensor but will propagate the
    backwards gradient across workers.

    See all_gather_tensor for more details on usage.
    """
    group_name = _resolve_group_name(group, tag)
    group_size = c10d._get_group_size_by_name(group_name)

    tensor = torch.ops._c10d_functional_autograd.all_gather_into_tensor(
        self, group_size, group_name
    )
    res = _FromTorchTensor.apply(tensor)
    # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call
    if gather_dim != 0:
        # torch.cat accesses the data, so we already need to wait here; doing the wait
        # first and then chunk + cat avoids going through ACT dispatching logic again
        if isinstance(res, AsyncCollectiveTensor):
            res = res.wait()  # type: ignore[attr-defined]
        res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim)
    return res
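# Shape-contract sketch for all_gather_tensor (illustrative, assumes group_size == 4
# and an initialized process group; not executed on import):
#
#   shard = torch.randn(8, 16)
#   out0 = all_gather_tensor(shard, gather_dim=0, group=[0, 1, 2, 3])  # shape (32, 16)
#   out1 = all_gather_tensor(shard, gather_dim=1, group=[0, 1, 2, 3])  # shape (8, 64), via chunk + cat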


def reduce_scatter_tensor(
    self: torch.Tensor,
    reduceOp: str,
    scatter_dim: int,
    group: RANK_TYPES,
    tag: str = "",
):
    """
    Reduces the tensor data across all machines in such a way that all get
    the final result, then scatters the results to the corresponding ranks.

    The input tensor is left unmodified.
    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do a SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh

    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    group_name = _resolve_group_name(group, tag)
    group_size = c10d._get_group_size_by_name(group_name)

    assert (
        self.size(scatter_dim) % group_size == 0
    ), f"input dimension {scatter_dim} ({self.size(scatter_dim)}) must be a multiple of group_size {group_size}"
    if scatter_dim != 0:
        tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
        self = torch.cat(tensor_list)

    tensor = torch.ops._c10d_functional.reduce_scatter_tensor(
        self,
        reduceOp.lower(),
        group_size,
        group_name,  # type: ignore[possibly-undefined]
    )
    res = _maybe_wrap_tensor(tensor)
    return res
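# Shape-contract sketch for reduce_scatter_tensor (illustrative, assumes group_size == 4
# and an initialized process group; not executed on import):
#
#   x = torch.randn(32, 16)
#   y = reduce_scatter_tensor(x, "sum", scatter_dim=0, group=[0, 1, 2, 3])
#   # x.size(scatter_dim) must be divisible by group_size; each rank receives shape (8, 16)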


def reduce_scatter_tensor_autograd(
    self: torch.Tensor,
    reduceOp: str,
    scatter_dim: int,
    group: RANK_TYPES,
    tag: str = "",
):
    """
    Reduces the tensor data across all machines in such a way that all get
    the final result, then scatters the results to the corresponding ranks.

    This function is the same as reduce_scatter_tensor but will propagate the
    backwards gradient across workers.

    Currently only the "sum" reduceOp is supported.

    See reduce_scatter_tensor for more details on usage.
    """

    group_name = _resolve_group_name(group, tag)
    group_size = c10d._get_group_size_by_name(group_name)

    assert (
        self.size(scatter_dim) % group_size == 0
    ), f"input dimension {scatter_dim} ({self.size(scatter_dim)}) must be a multiple of group_size {group_size}"
    if scatter_dim != 0:
        tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
        self = torch.cat(tensor_list)

    tensor = torch.ops._c10d_functional_autograd.reduce_scatter_tensor(
        self,
        reduceOp.lower(),
        group_size,
        group_name,  # type: ignore[possibly-undefined]
    )
    res = _FromTorchTensor.apply(tensor)
    return res


def all_reduce_coalesced(
    self: List[torch.Tensor], reduceOp: str, group: RANK_TYPES, tag: str = ""
) -> List[torch.Tensor]:
    """
    Reduces a list of tensors across all machines in such a way that all get
    the final result.

    All tensors in the input list are left unmodified.

    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do a SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh

    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    group_name = _resolve_group_name(group, tag)
    tensor_list = torch.ops._c10d_functional.all_reduce_coalesced(  # type: ignore[attr-defined]
        self,
        reduceOp.lower(),
        group_name,
    )
    return list(map(_maybe_wrap_tensor, tensor_list))


def all_gather_into_tensor_coalesced(
    self: List[torch.Tensor], group: RANK_TYPES, tag: str = ""
) -> List[torch.Tensor]:
    """
    Gather a list of tensors from all machines.

    Note that it currently only supports gather_dim = 0.

    The input tensors are left unmodified.
    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do a SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh

    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    group_name = _resolve_group_name(group, tag)
    group_size = c10d._get_group_size_by_name(group_name)
    tensor_list = torch.ops._c10d_functional.all_gather_into_tensor_coalesced(  # type: ignore[attr-defined]
        self,
        group_size,
        group_name,
    )
    return list(map(_maybe_wrap_tensor, tensor_list))


def reduce_scatter_tensor_coalesced(
    inputs: List[torch.Tensor],
    reduceOp: str,
    scatter_dim: List[int],
    group: RANK_TYPES,
    tag: str = "",
) -> List[torch.Tensor]:
    """
    Reduces a list of tensors across all machines in such a way that all get
    the final result, then scatters the results to the corresponding ranks.

    The input tensors are left unmodified.
    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do a SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh

    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    group_name = _resolve_group_name(group, tag)
    group_size = c10d._get_group_size_by_name(group_name)

    assert len(scatter_dim) == len(inputs)
    for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)):
        assert (
            tensor.size(dim) % group_size == 0
        ), f"input dimension {dim} ({tensor.size(dim)}) must be a multiple of group_size {group_size} for tensor at index {idx}"
        if dim != 0:
            tensor_list = torch.chunk(tensor, group_size, dim=dim)
            inputs[idx] = torch.cat(tensor_list)

    tensor_list = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced(  # type: ignore[attr-defined]
        inputs,
        reduceOp.lower(),
        group_size,
        group_name,  # type: ignore[possibly-undefined]
    )

    return list(map(_maybe_wrap_tensor, tensor_list))
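# Usage sketch for the coalesced variants (illustrative, assumes an initialized
# process group; not executed on import):
#
#   grads = [torch.randn(128), torch.randn(64)]
#   reduced = all_reduce_coalesced(grads, "sum", [0, 1, 2, 3])
#   # `reduced` is a list of (possibly async) tensors, reduced in a single batched call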


# This is a bit unsafe: it checks if the first argument in the schema reports as a non-mutable alias.
# Today, this maps 1:1 with "aten ops that are views".
def _is_view_op(tgt):
    assert isinstance(tgt, torch._ops.OpOverload)
    schema = tgt._schema
    if len(schema.arguments) > 0:
        first_arg = schema.arguments[0]
        # check if op is a view
        return first_arg.alias_info is not None and not first_arg.alias_info.is_write


def all_to_all_single(
    self: torch.Tensor,
    output_split_sizes: Optional[List[int]],
    input_split_sizes: Optional[List[int]],
    group: RANK_TYPES,
    tag: str = "",
) -> torch.Tensor:
    """
    Each process splits the input tensor and then scatters the split list
    to all processes in a group. It then concatenates the received tensors from all
    the processes in the group and returns a single output tensor.

    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do a SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh

    :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    if output_split_sizes is not None:
        assert all(
            isinstance(size, (int, torch.SymInt)) for size in output_split_sizes
        ), output_split_sizes
    if input_split_sizes is not None:
        assert all(
            isinstance(size, (int, torch.SymInt)) for size in input_split_sizes
        ), input_split_sizes
    group_name = _resolve_group_name(group, tag)
    group_size = c10d._get_group_size_by_name(group_name)
    if output_split_sizes is None or input_split_sizes is None:
        assert output_split_sizes is None and input_split_sizes is None, (
            "output_split_sizes and input_split_sizes must either be "
            "specified together or both set to None"
        )
        output_split_sizes = [self.shape[0] // group_size] * group_size
        input_split_sizes = output_split_sizes
    tensor = torch.ops._c10d_functional.all_to_all_single(  # type: ignore[attr-defined]
        self,
        output_split_sizes,
        input_split_sizes,
        group_name,
    )
    return _maybe_wrap_tensor(tensor)
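# Split-size sketch for all_to_all_single (illustrative, assumes group_size == 2 and an
# initialized process group; not executed on import). input_split_sizes[i] is the number
# of dim-0 rows this rank sends to rank i; output_split_sizes[i] is the number it
# receives from rank i:
#
#   x = torch.arange(6.0)
#   out = all_to_all_single(x, output_split_sizes=[2, 4], input_split_sizes=[4, 2], group=[0, 1])
#   # passing None for both split lists splits dim 0 evenly across ranks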


def all_to_all_single_autograd(
    self: torch.Tensor,
    output_split_sizes: Optional[List[int]],
    input_split_sizes: Optional[List[int]],
    group: RANK_TYPES,
    tag: str = "",
) -> torch.Tensor:
    """
    Same as all_to_all_single but supports autograd.
    """
    if output_split_sizes is not None:
        assert all(
            isinstance(size, (int, torch.SymInt)) for size in output_split_sizes
        ), output_split_sizes
    if input_split_sizes is not None:
        assert all(
            isinstance(size, (int, torch.SymInt)) for size in input_split_sizes
        ), input_split_sizes

    group_name = _resolve_group_name(group, tag)
    group_size = c10d._get_group_size_by_name(group_name)
    if output_split_sizes is None or input_split_sizes is None:
        assert output_split_sizes is None and input_split_sizes is None, (
            "output_split_sizes and input_split_sizes must either be "
            "specified together or both set to None"
        )
        output_split_sizes = [self.shape[0] // group_size] * group_size
        input_split_sizes = output_split_sizes
    tensor = torch.ops._c10d_functional_autograd.all_to_all_single(  # type: ignore[attr-defined]
        self,
        output_split_sizes,
        input_split_sizes,
        group_name,
    )
    return _FromTorchTensor.apply(tensor)


def permute_tensor(
    self: torch.Tensor,
    src_dst: List[int],
    group: RANK_TYPES,
    tag: str = "",
) -> torch.Tensor:
    """
    Permutes the elements of the tensor according to the given source/destination pairs. ``src_dst`` should
    be defined such that src_dst[m] == n means m sends to n.

    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do a SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
    """
    t, rankset, group_size = _expand_group(group, tag)
    local_pg = c10d._find_or_create_pg_by_ranks_and_tag(t, rankset, group_size)

    output_split_sizes = [0] * group_size
    input_split_sizes = [0] * group_size
    for src, dst in enumerate(src_dst):
        if src == dist.get_rank(local_pg):
            input_split_sizes[dst] = self.numel()
        if dst == dist.get_rank(local_pg):
            output_split_sizes[src] = self.numel()

    return all_to_all_single(self, output_split_sizes, input_split_sizes, group, tag)
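# permute_tensor sketch (illustrative, assumes 3 ranks and an initialized process
# group; not executed on import):
#
#   # src_dst[m] == n means rank m sends its tensor to rank n, so [1, 2, 0]
#   # rotates tensors one rank "to the right":
#   out = permute_tensor(torch.randn(4), [1, 2, 0], group=[0, 1, 2])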


class AsyncCollectiveTensor(torch.Tensor):
    r"""
    A Tensor wrapper subclass that is used to trigger a call to wait
    prior to first use of the underlying tensor.
    Use it inside functional collective pytorch wrappers like the following:
    def functional_collective(self, group, tag):
        tag, rankset, group_size = _expand_group(group, tag)
        tensor = torch.ops.c10d_functional.{collective}(self, tag, rankset, group_size)
        return _maybe_wrap_tensor(tensor)
    """

    elem: torch.Tensor
    completed: bool

    __slots__ = ["elem", "completed"]

    @staticmethod
    def __new__(cls, elem: torch.Tensor):
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            elem.size(),
            strides=elem.stride(),
            storage_offset=elem.storage_offset(),
            dtype=elem.dtype,
            layout=elem.layout,
            device=elem.device,
            requires_grad=elem.requires_grad,
        )
        r.elem = elem
        r.completed = False
        return r

    def __tensor_flatten__(self):
        return ["elem"], None

    def tolist(self):
        return self.trigger_wait().tolist()

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert meta is None
        elem = inner_tensors["elem"]
        return AsyncCollectiveTensor(elem)

    def __repr__(self):
        return f"AsyncCollectiveTensor({self.trigger_wait()})"

    def trigger_wait(self):
        if not self.completed:
            out = wait_tensor(self.elem)
            self.completed = True
            return out
        else:
            return self.elem

    def wait(self) -> torch.Tensor:
        return wait_tensor(self.elem)

    def _get_acs_underlying_tensor(self):
        """This method enables _functional_collectives_impl to test if a tensor is an ACS"""
        return self.elem

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if func == torch.ops.aten.view.default:
            # Fast-path aten.view, since a lot of view-related ops go to aten.view
            # eventually; this avoids pytree slowdown
            res = func(args[0].elem, args[1])
            wrapper_res = AsyncCollectiveTensor(res)
            return wrapper_res

        is_view_op = _is_view_op(func)

        def unwrap(e: AsyncCollectiveTensor):
            # wait_tensor is idempotent and will do stream sync only once
            if not is_view_op:
                return e.trigger_wait()
            return e.elem

        def wrap(e: torch.Tensor):
            # wait_tensor is idempotent and will do stream sync only once
            assert not isinstance(e, AsyncCollectiveTensor)
            res = AsyncCollectiveTensor(e)
            return res

        unwrapped_args = tree_map_only(AsyncCollectiveTensor, unwrap, args)
        unwrapped_kwargs = tree_map_only(AsyncCollectiveTensor, unwrap, kwargs)

        # we don't wrap the result as it doesn't need to be waited on.
        out = func(*unwrapped_args, **unwrapped_kwargs)

        # View ops don't require a sync, so we should re-wrap the outputs.
        if is_view_op:
            out = tree_map_only(torch.Tensor, wrap, out)

        return out

    def numpy(self):
        return self.wait().numpy()
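# Deferred-wait sketch for AsyncCollectiveTensor (illustrative, assumes an initialized
# process group; not executed on import):
#
#   act = all_reduce(torch.ones(4), "sum", [0, 1, 2, 3])  # AsyncCollectiveTensor in eager mode
#   y = act * 2         # __torch_dispatch__ unwraps and calls wait_tensor() before the mul
#   z = act.view(2, 2)  # view ops stay wrapped; no sync is forced yet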


"""
Utils and infrastructure for tracing support
"""


def _expand_group(group: RANK_TYPES, tag: str = "") -> Tuple[str, List[int], int]:
    """
    _expand_group desugars the different RANK_TYPES types into a canonical format that is traceable.

    By having this be part of the explicit eager codepath, we avoid having to specialize behavior inside
    torchdynamo and can still interoperate with processgroup objects or other untraceable forms.
    """
    # had to define this hack _inside_ expand_group to avoid
    # graph_break [('torch.* op returned non-Tensor int
    # caused by 'cast_*` functions being treated as 'torch.*' ops (iiuc)
    if TYPE_CHECKING:

        def cast_listlistint(x):
            return cast(List[List[int]], x)

        def cast_listint(x):
            return cast(List[int], x)

    else:
        # fake cast op for use at runtime since dynamo doesn't support real cast
        # also, dynamo didn't like encountering 'typing' objects ()
        # NotImplementedError: argument of type: <class 'typing._GenericAlias'>
        def cast_listlistint(x):
            return x

        def cast_listint(x):
            return x

    rankset: List[int]
    if isinstance(group, list):
        if isinstance(group[0], list):
            nested_list = cast_listlistint(group)
            rankset = []
            group_size = -1
            for rs in nested_list:
                rankset.extend(rs)
                if group_size != -1 and group_size != len(rs):
                    raise ValueError(
                        f"group sizes must be identical, found {group_size} and {len(rs)}"
                    )
                group_size = len(rs)
        else:
            rankset = cast_listint(group)
            group_size = len(rankset)
    elif isinstance(group, dist.ProcessGroup):
        rankset = dist.get_process_group_ranks(group)
        group_size = len(rankset)
        tag = tag or c10d._get_group_tag(group)
    elif isinstance(group, DeviceMesh):
        assert (
            group.ndim == 1
        ), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
        # TODO: it should run collective in the whole mesh instead of dim 0
        tag, rankset, _ = group._dim_group_infos[0]
        group_size = len(rankset)
    elif isinstance(group, tuple):
        if (
            len(group) == 2
            and isinstance(group[0], DeviceMesh)
            and isinstance(group[1], int)
        ):
            dmesh = group[0]
            dim = group[1]
            tag, rankset, _ = dmesh._dim_group_infos[dim]
            group_size = len(rankset)
        else:
            raise ValueError("Invalid tuple for group, must be (DeviceMesh, int)")
    else:
        raise ValueError(
            "Invalid type for group, must be one of List, Processgroup, DeviceMesh or (DeviceMesh, int)."
        )

    return (tag, rankset, group_size)
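# Desugaring sketch for _expand_group (illustrative values only; the tag string for a
# ProcessGroup or DeviceMesh depends on how that group was created):
#
#   _expand_group([0, 1, 2, 3])      # -> ("", [0, 1, 2, 3], 4)
#   _expand_group([[0, 1], [2, 3]])  # -> ("", [0, 1, 2, 3], 2)  (two MPMD groups of size 2)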


def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
    """
    Given group in RANK_TYPES, return the group name.
    """
    # `tag` will be deprecated. See details in:
    # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208
    if isinstance(group, dist.ProcessGroup):
        return group.group_name
    elif isinstance(group, str):
        return group
    elif isinstance(group, DeviceMesh):
        assert (
            group.ndim == 1
        ), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
        return group._dim_group_infos[0][2]
    elif isinstance(group, tuple):
        if (
            len(group) == 2
            and isinstance(group[0], DeviceMesh)
            and isinstance(group[1], int)
        ):
            dmesh = group[0]
            dim = group[1]
            return dmesh._dim_group_infos[dim][2]
        else:
            raise ValueError("Invalid tuple for group, must be (DeviceMesh, int)")
    elif isinstance(group, list):
        if not is_torchdynamo_compiling():
            warnings.warn(
                "The combination of ranks + tag as process group "
                "identifier has been deprecated. Please switch to "
                "using ProcessGroup, DeviceMesh, or group name instead.",
                FutureWarning,
                stacklevel=3,
            )
        return c10d._resolve_group_name_by_ranks_and_tag(cast(List[int], group), tag)
    else:
        raise ValueError(f"Unsupported group type: {type(group)}, {group}")


class _FromTorchTensor(torch.autograd.Function):
    """
    _FromTorchTensor allows autograd to propagate from a normal Tensor to an
    AsyncCollectiveTensor.
    """

    @staticmethod
    def forward(  # type: ignore[override]
        ctx,  # pyre-ignore[2]: Parameter must be annotated.
        input: torch.Tensor,
    ) -> torch.Tensor:
        return _maybe_wrap_tensor(input)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        return grad_output


def _are_we_tracing() -> bool:
    if is_torchdynamo_compiling():
        return True
    # If functionalization is turned on, we are almost definitely compiling/tracing.
    # (In particular, AOTAutograd traces a model once with functionalization on
    # but proxy tracing turned off, so this is how we detect it.)
    if (
        torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL)
        is not None
    ):
        return True
    return get_proxy_mode() is not None


def _maybe_wrap_tensor(self) -> torch.Tensor:
    if _are_we_tracing():
        return wait_tensor(self)
    res = AsyncCollectiveTensor(self)
    return cast(torch.Tensor, res)


def _all_gather_into_tensor_coalesced_meta(self, tag, rankset, group_size):
    def mk_out_tensor(shard):
        out_size = list(shard.size())
        out_size[0] *= group_size
        out_tensor = shard.new_empty(out_size)
        return out_tensor

    return [mk_out_tensor(t) for t in self]


# We now register meta kernels to deal with tracing
def _broadcast_meta(self, *args):
    return torch.empty_like(self)


def _all_reduce_meta(self, *args):
    return torch.empty_like(self)


def _wait_tensor_meta(self, *args):
    return torch.empty_like(self)


def _all_gather_into_tensor_meta(shard, tag, rankset, group_size):
    out_size = list(shard.size())
    out_size[0] *= group_size
    return shard.new_empty(out_size)


def _reduce_scatter_tensor_meta(input, reduce_op, tag, rankset, group_size):
    out_size = list(input.size())
    out_size[0] //= group_size
    return input.new_empty(out_size)


def _all_reduce_coalesced_meta(self, *args):
    return [torch.empty_like(t) for t in self]


def _all_reduce__meta(inp, *args):
    return inp


def _broadcast__meta(inp, *args):
    return inp


def _all_reduce_coalesced__meta(inputs, *args):
    return inputs


def _reduce_scatter_tensor_coalesced_meta(inputs, reduceOp, tag, rankset, group_size):
    def mk_out_tensor(input):
        out_size = list(input.size())
        out_size[0] //= group_size
        out_tensor = input.new_empty(out_size)
        return out_tensor

    return [mk_out_tensor(t) for t in inputs]


# NB: We often say all_to_all has dynamic output size, but this is not
# technically true: instead, what typically happens is you manually
# communicate the output_split_sizes ahead of time (which is dynamic),
# but then you pass those sizes explicitly, and the all_to_all itself
# isn't dynamic, it just follows the specified output splits
def _all_to_all_single_meta(
    input, output_split_sizes, input_split_sizes, *args, **kwargs
):
    if output_split_sizes is None:
        return input.new_empty(input.size())
    else:
        for s in output_split_sizes:
            torch._check_is_size(s)
        out_size = list(input.size())
        out_size[0] = sum(output_split_sizes)
        return input.new_empty(out_size)


def _all_gather_into_tensor_out_native_meta(input, group_size, group_name, *, out):
    shape = list(input.size())
    shape[0] *= group_size
    return input.new_empty(shape)


def _all_gather_into_tensor_native_meta(input, group_size, group_name):
    shape = list(input.size())
    shape[0] *= group_size
    return input.new_empty(shape)


def _all_gather_into_tensor_coalesced_native_meta(inputs, group_size, group_name):
    return [
        _all_gather_into_tensor_native_meta(input, group_size, group_name)
        for input in inputs
    ]


def _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name):
    shape = list(inp.size())
    shape[0] //= group_size
    return inp.new_empty(shape)


def _reduce_scatter_tensor_coalesced_native_meta(
    inputs, reduce_op, group_size, group_name
):
    return [
        _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name)
        for inp in inputs
    ]


if not torch._running_with_deploy():
    # Library MUST be defined at module scope or it doesn't work
    # Creating a "DEF" Library always crashes torch::deploy so we create our
    # Library instances here guarded against running inside it
    lib_impl = torch.library.Library("_c10d_functional", "IMPL")
    lib_impl.impl("all_reduce", _all_reduce_meta, "Meta")
    lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta")
    lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta")
    lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta")
    lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta")
    lib_impl.impl(
        "all_gather_into_tensor_out", _all_gather_into_tensor_out_native_meta, "Meta"
    )
    lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta")
    lib_impl.impl(
        "all_gather_into_tensor_coalesced",
        _all_gather_into_tensor_coalesced_native_meta,
        "Meta",
    )
    lib_impl.impl("reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta")
    lib_impl.impl(
        "reduce_scatter_tensor_coalesced",
        _reduce_scatter_tensor_coalesced_native_meta,
        "Meta",
    )
    lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta")
    lib_impl.impl("broadcast", _broadcast_meta, "Meta")
    lib_impl.impl("broadcast_", _broadcast__meta, "Meta")

    # mark these ops as having side effects so that they won't be removed by DCE
    torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default)
    torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor)

    # Register legacy ops for backward compatibility
    # TODO(yifu): remove these in functional collective beta release
    legacy_lib = torch.library.Library("c10d_functional", "DEF")
    legacy_lib_impl = torch.library.Library("c10d_functional", "IMPL")
    ops_defs = [
        "broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor",
        "all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
        "all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]",
        "wait_tensor(Tensor self) -> Tensor",
        "all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor",
        "all_gather_into_tensor_coalesced(Tensor[] input, str tag, int[] ranks, int group_size) -> Tensor[]",
        "reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
        "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]",
        "all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor",  # noqa: B950
    ]

    my_module = sys.modules[__name__]
    for op_def in ops_defs:
        op_name = op_def[0 : op_def.index("(")]
        backend_impl = getattr(fun_col_impl, f"_{op_name}")
        legacy_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag)
        legacy_lib_impl.impl(op_name, backend_impl, "CompositeImplicitAutograd")

else:
    warnings.warn(
        "PyTorch Distributed functional collectives do not work with torch::deploy."
    )


"""
Dynamo Remappings allow seamless translation from non-functional collectives of supportable form into
functional collective calls followed by inplace copy ops, allowing them to be traced into a functional graph.

We implement this by writing a decomposition and teaching dynamo how to associate it to a corresponding op via
the mapping dict below.

These schemas intentionally match torch.distributed.distributed_c10d.* ops that we are trying to remap from
"""


def all_gather_tensor_inplace(
    output_tensor: torch.Tensor,
    input_tensor: torch.Tensor,
    group,  # TODO add a type,
    async_op: bool = False,
    tag: str = "",
    gather_dim: int = 0,
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"

    group = group or dist.group.WORLD
    assert group is not None

    return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag))


def reduce_scatter_tensor_inplace(
    output: torch.Tensor,
    input: torch.Tensor,
    op: str = "sum",  # TODO type is actually c10d ReduceOp. is this ok?
    group=None,  # TODO add a type
    async_op: bool = False,
    scatter_dim: int = 0,
    tag: str = "",
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"

    group = group or dist.group.WORLD
    assert group is not None

    return output.copy_(reduce_scatter_tensor(input, op, scatter_dim, group, tag))


REDUCE_OP_TO_STR = {
    dist.ReduceOp.SUM: "sum",
    dist.ReduceOp.AVG: "avg",
    dist.ReduceOp.PRODUCT: "product",
    dist.ReduceOp.MIN: "min",
    dist.ReduceOp.MAX: "max",
    dist.ReduceOp.BAND: "band",
    dist.ReduceOp.BOR: "bor",
    dist.ReduceOp.BXOR: "bxor",
}


def all_reduce_inplace(
    tensor: torch.Tensor,
    op: str = "sum",
    group=None,
    async_op: bool = False,
    tag: str = "",
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"

    group = group or dist.group.WORLD
    assert group is not None

    return tensor.copy_(all_reduce(tensor, op, group, tag))


def all_to_all_inplace(
    output: torch.Tensor,
    input: torch.Tensor,
    output_split_sizes=None,
    input_split_sizes=None,
    group=None,
    async_op=False,
    tag: str = "",
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"

    group = group or dist.group.WORLD
    assert group is not None

    return output.copy_(
        all_to_all_single(
            input,
            output_split_sizes,
            input_split_sizes,
            group,
            tag,
        )
    )


def all_gather_inplace(
    tensor_list: List[torch.Tensor],
    tensor: torch.Tensor,
    group=None,
    async_op=False,
    tag: str = "",
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"
    assert all(
        t.size(0) == tensor.size(0) for t in tensor_list
    ), "Remapping variable size all_gather is not yet supported"

    group = group or dist.group.WORLD
    assert group is not None

    output = all_gather_tensor(tensor, 0, group, tag)

    # Use aten.slice instead of aten.split because the latter causes
    # tensor.shape(0) to be unnecessarily baked in when it's a SymInt.
    output_splits = []
    offset = 0
    for t in tensor_list:
        output_splits.append(output[offset : offset + t.size(0)])
        offset += t.size(0)
    for dst, src in zip(tensor_list, output_splits):
        dst.copy_(src)
    return tensor_list


from torch.distributed.distributed_c10d import (
    _all_gather_base as legacy_all_gather_base,
    _reduce_scatter_base as legacy_reduce_scatter_base,
    all_gather as legacy_all_gather,
    all_gather_into_tensor as legacy_allgather,
    all_reduce as legacy_allreduce,
    all_to_all_single as legacy_all_to_all_single,
    reduce_scatter_tensor as legacy_reducescatter,
)


# This dict contains the pairs of functions that dynamo is allowed to remap.
# Functions in this dict should accept the same args/kwargs 1:1 as their mapping.
traceable_collective_remaps = {
    legacy_allgather: all_gather_tensor_inplace,
    legacy_reducescatter: reduce_scatter_tensor_inplace,
    legacy_allreduce: all_reduce_inplace,
    legacy_all_to_all_single: all_to_all_inplace,
    legacy_all_gather: all_gather_inplace,
    legacy_reduce_scatter_base: reduce_scatter_tensor_inplace,
    legacy_all_gather_base: all_gather_tensor_inplace,
}