#pragma once

#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/macros/Macros.h>

#include <torch/csrc/distributed/c10d/Work.hpp>
// *************************************************************************
// PROCESS GROUP collective communication API IS BEING CHANGED BETWEEN
// versions 1.7 and 1.8.
// PLEASE DO NOT ADD ANY DEPENDENCIES.
// SEE RFC: https://github.com/pytorch/pytorch/issues/39662
// *************************************************************************

constexpr auto kProcessGroupDefaultTimeout =
    std::chrono::milliseconds(30 * 60 * 1000);

namespace c10d {

// ProcessGroup is a base class that captures collective and point-to-point
// communication in a fixed set of processes.
//
// The functions specified in the class below describe the API alone;
// implementations are provided in subclasses.
//
// Every function that performs I/O is executed asynchronously by a
// thread pool owned by the ProcessGroup (by default). They return an
// object that can be used to wait for completion or error.
//
// The ProcessGroup can instantiate subgroups with fewer or an equal
// number of members. Implementations must take care that multiple
// process groups can be used in parallel and synchronize accordingly.
//
// The ProcessGroup assumes a fixed set of processes. If the set
// changes, existing instances must be destructed and instantiation
// and initialization must start from scratch. For members of the
// process group to find each other (referred to as rendezvous from
// here on), a shared Store is used (see the constructor below that
// takes a Store).
//
class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 public:
  // ProcessGroup::Options is a base struct that defines the basic options
  // used when constructing a ProcessGroup. Each ProcessGroup subclass should
  // extend this struct and define its own options if it wants to provide
  // more config options (beyond the basic ones defined here) to the end
  // user.
  struct TORCH_API Options : torch::CustomClassHolder {
    explicit Options(
        std::string backend,
        std::chrono::milliseconds timeout = kProcessGroupDefaultTimeout)
        : timeout(timeout), backend(std::move(backend)) {}
    ~Options() override = default;

    std::chrono::milliseconds timeout;

    // backend name
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
    const std::string backend;
  };

  enum BackendType : uint8_t {
    UNDEFINED = 0,
    GLOO = 1,
    NCCL = 2,
    UCC = 3,
    MPI = 4,
    CUSTOM = 5,
  };

  // Not used; kept for backwards compatibility and only used for TypeDef in
  // Ops.cpp.
  explicit ProcessGroup(int rank, int size);

  explicit ProcessGroup(
      const c10::intrusive_ptr<::c10d::Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<Options> options);
  ~ProcessGroup() override;

  int getRank() const {
    return rank_;
  }

  int getSize() const {
    return size_;
  }

  // Returns a unique opaque ID of this process group object.
  int64_t getID() const {
    return reinterpret_cast<std::intptr_t>(this);
  }
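  // Illustrative usage sketch (comment only, not part of the API) of the
  // asynchronous Work pattern described in the class comment above. It
  // assumes `pg` is a pointer to an already-initialized concrete subclass;
  // the tensor shape and options are arbitrary examples.
  //
  //   std::vector<at::Tensor> tensors = {at::ones({1024})};
  //   AllreduceOptions opts;
  //   opts.reduceOp = ReduceOp::SUM;
  //   // Collectives return a Work handle rather than blocking:
  //   c10::intrusive_ptr<Work> work = pg->allreduce(tensors, opts);
  //   work->wait();  // block until this collective completes (or errors)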
  // Returns a unique opaque ID of a backend for the specific backend type
  // that can correlate with this process group's collectives.
  int64_t getBackendID(BackendType backend_type) const {
    return reinterpret_cast<std::intptr_t>(getBackend(backend_type).get());
  }

  virtual const std::string getBackendName() const {
    return options_->backend;
  };

  BackendType getBackendType() const {
    return backendType_;
  };

  virtual void startCoalescing(c10::DeviceType deviceType) {
    // only nccl has implemented startCoalescing so only execute for nccl
    // backends
    auto backend = getBackend(deviceType);
    backend->startCoalescing();
  }

  virtual c10::intrusive_ptr<Work> endCoalescing(c10::DeviceType deviceType) {
    // only nccl has implemented endCoalescing so only execute for nccl
    // backends
    auto backend = getBackend(deviceType);
    auto work = backend->endCoalescing();
    return work;
  }

  virtual c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& tensors,
      const BroadcastOptions& opts = BroadcastOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::broadcast_", "")
            .typed<
                std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
                    at::TensorList,
                    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                    int64_t,
                    int64_t,
                    bool,
                    int64_t)>();
    // It's awkward to unbox the opts here and box them again in the custom
    // C++ op. But it's also complicated to make opts a CustomClassHolder.
    // Leave it as it is now.
    return std::get<1>(op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.rootRank,
        opts.rootTensor,
        opts.asyncOp,
        opts.timeout.count()));
  }

  virtual c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::allreduce_", "")
            .typed<
                std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
                    at::TensorList,
                    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                    const c10::intrusive_ptr<::c10d::ReduceOp>&,
                    const std::optional<at::Tensor>& sparse_indices,
                    int64_t)>();

    return std::get<1>(op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<ReduceOp>(opts.reduceOp),
        opts.sparseIndices,
        opts.timeout.count()));
  }

  virtual c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::allreduce_coalesced_", "")
                         .typed<c10::intrusive_ptr<::c10d::Work>(
                             at::TensorList,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             const c10::intrusive_ptr<::c10d::ReduceOp>&,
                             int64_t)>();

    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<ReduceOp>(opts.reduceOp),
        opts.timeout.count());
  }
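  // Illustrative sketch (comment only): the collectives above all follow the
  // same pattern -- resolve the registered "c10d::<op>_" schema once via the
  // dispatcher, then forward the unboxed options. A caller would typically
  // invoke them like this (`pg`, tensor shapes, and ranks are arbitrary
  // assumptions for the example):
  //
  //   std::vector<at::Tensor> tensors = {at::zeros({8})};
  //   BroadcastOptions bopts;
  //   bopts.rootRank = 0;  // rank whose tensor contents are broadcast
  //   pg->broadcast(tensors, bopts)->wait();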
  virtual c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::reduce_", "")
                         .typed<c10::intrusive_ptr<::c10d::Work>(
                             at::TensorList,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             const c10::intrusive_ptr<::c10d::ReduceOp>&,
                             int64_t,
                             int64_t,
                             int64_t)>();
    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<ReduceOp>(opts.reduceOp),
        opts.rootRank,
        opts.rootTensor,
        opts.timeout.count());
  }

  virtual c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::allgather_", "")
                         .typed<std::tuple<
                             std::vector<std::vector<at::Tensor>>,
                             c10::intrusive_ptr<Work>>(
                             const std::vector<std::vector<at::Tensor>>&,
                             at::TensorList,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             int64_t)>();

    return std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.timeout.count()));
  }

  // Gathers a single tensor inputBuffer into a single buffer outputBuffer
  // that is interpreted as a contiguous collection of size inputBuffer *
  // WORLD_SIZE. For implementers of the ProcessGroup API and advanced users
  // only.
  // Note: this function will be deprecated in the near future.
  virtual c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const AllgatherOptions& opts = AllgatherOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::_allgather_base_", "")
            .typed<std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(
                at::Tensor&,
                at::Tensor&,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                bool,
                int64_t)>();

    return std::get<1>(op.call(
        outputBuffer,
        inputBuffer,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.asyncOp,
        opts.timeout.count()));
  }
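  // Illustrative sketch (comment only) of the buffer contract described for
  // _allgather_base above: the output buffer holds world_size contiguous
  // copies of the input buffer. `pg` and the shapes are arbitrary
  // assumptions for the example.
  //
  //   at::Tensor input = at::ones({4});
  //   at::Tensor output = at::empty({pg->getSize() * 4});
  //   pg->_allgather_base(output, input)->wait();
  //   // output.narrow(0, r * 4, 4) now holds rank r's input.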
  // This function is deprecated and will be moved out of ProcessGroup to
  // comms:
  // * do not add dependencies on this function,
  // * do not implement it in your ProcessGroup, implement _allgather_base
  //   instead.
  virtual c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& outputTensorLists,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::allgather_coalesced_", "")
            .typed<c10::intrusive_ptr<Work>(
                const std::vector<std::vector<at::Tensor>>&,
                const at::TensorList&,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();

    return op.call(
        outputTensorLists,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
  }

  // This function is a coalesced version of `allgather_into_tensor`
  // (currently still named as `_allgather_base`). Each tensor in the vector
  // corresponds to an input/output of one `allgather_into_tensor` operation.
  virtual c10::intrusive_ptr<Work> allgather_into_tensor_coalesced(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::allgather_into_tensor_coalesced_", "")
            .typed<c10::intrusive_ptr<Work>(
                const at::TensorList,
                const at::TensorList,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();

    return op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
  }

  virtual c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const GatherOptions& opts = GatherOptions()) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::gather_", "")
                         .typed<c10::intrusive_ptr<::c10d::Work>(
                             const std::vector<std::vector<at::Tensor>>&,
                             const at::TensorList&,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             int64_t,
                             int64_t)>();
    return op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.rootRank,
        opts.timeout.count());
  }

  virtual c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ScatterOptions& opts = ScatterOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::scatter_", "")
            .typed<
                std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
                    const at::TensorList&,
                    const std::vector<std::vector<at::Tensor>>&,
                    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                    int64_t,
                    bool,
                    int64_t)>();
    return std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.rootRank,
        opts.asyncOp,
        opts.timeout.count()));
  }
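  // Illustrative sketch (comment only) for the root-based collectives above
  // (gather/scatter). The layout shown is an assumption of typical usage:
  // only the root rank supplies the nested output (gather) / input (scatter)
  // vectors, while other ranks pass empty vectors. `pg` is assumed.
  //
  //   GatherOptions gopts;
  //   gopts.rootRank = 0;
  //   std::vector<at::Tensor> input = {at::ones({2})};
  //   std::vector<std::vector<at::Tensor>> outputs;
  //   if (pg->getRank() == 0) {
  //     outputs.emplace_back();
  //     for (int r = 0; r < pg->getSize(); ++r) {
  //       outputs[0].push_back(at::empty({2}));
  //     }
  //   }
  //   pg->gather(outputs, input, gopts)->wait();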
  virtual c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::reduce_scatter_", "")
            .typed<
                std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
                    const at::TensorList&,
                    const std::vector<std::vector<at::Tensor>>&,
                    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                    const c10::intrusive_ptr<::c10d::ReduceOp>&,
                    int64_t)>();
    return std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
        opts.timeout.count()));
  }

  virtual c10::intrusive_ptr<Work> _reduce_scatter_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::_reduce_scatter_base_", "")
            .typed<std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(
                at::Tensor&,
                at::Tensor&,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                const c10::intrusive_ptr<::c10d::ReduceOp>&,
                bool,
                int64_t)>();
    return std::get<1>(op.call(
        outputBuffer,
        inputBuffer,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
        opts.asyncOp,
        opts.timeout.count()));
  }
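  // Illustrative sketch (comment only) of the buffer contract for
  // _reduce_scatter_base above: the input buffer is interpreted as
  // world_size contiguous chunks, the chunks are reduced element-wise across
  // ranks, and this rank's chunk is written to the output buffer. `pg` and
  // the shapes are arbitrary assumptions for the example.
  //
  //   ReduceScatterOptions rsopts;
  //   rsopts.reduceOp = ReduceOp::SUM;
  //   at::Tensor input = at::ones({pg->getSize() * 4});
  //   at::Tensor output = at::empty({4});
  //   pg->_reduce_scatter_base(output, input, rsopts)->wait();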
  // This function is a coalesced version of `reduce_scatter_tensor`
  // (currently still named as `_reduce_scatter_base`). Each tensor in the
  // vector corresponds to an input/output of one `reduce_scatter_tensor`
  // operation.
  virtual c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::reduce_scatter_tensor_coalesced_", "")
            .typed<c10::intrusive_ptr<Work>(
                const at::TensorList,
                const at::TensorList,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                const c10::intrusive_ptr<::c10d::ReduceOp>&,
                int64_t)>();

    return op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
        opts.timeout.count());
  }

  virtual c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      std::vector<int64_t>& outputSplitSizes,
      std::vector<int64_t>& inputSplitSizes,
      const AllToAllOptions& opts = AllToAllOptions()) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::alltoall_base_", "")
                         .typed<c10::intrusive_ptr<::c10d::Work>(
                             at::Tensor&,
                             at::Tensor&,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             std::vector<int64_t>,
                             std::vector<int64_t>,
                             int64_t)>();
    return op.call(
        outputBuffer,
        inputBuffer,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        outputSplitSizes,
        inputSplitSizes,
        opts.timeout.count());
  }

  virtual c10::intrusive_ptr<Work> alltoall(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllToAllOptions& opts = AllToAllOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::alltoall_", "")
            .typed<
                std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
                    const at::TensorList&,
                    const at::TensorList&,
                    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                    int64_t)>();
    return std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.timeout.count()));
  }

  virtual void monitoredBarrier(
      const BarrierOptions& opts,
      bool wait_all_ranks = false) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::monitored_barrier_", "")
                         .typed<void(
                             at::Tensor,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             const std::vector<int64_t>&,
                             int64_t,
                             bool)>();
    // Default to using the CPU implementation; monitored barrier is only
    // implemented for GLOO.
    at::Tensor tensor = at::empty({0}, at::TensorOptions().device(at::kCPU));
    op.call(
        tensor,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.device_ids,
        opts.timeout.count(),
        wait_all_ranks);
  }
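  // Illustrative sketch (comment only) for alltoall_base above, using uneven
  // splits. Empty split-size vectors mean even splitting; otherwise the i-th
  // entry is the number of elements sent to / received from rank i. The
  // concrete sizes below assume a hypothetical 2-rank group and are this
  // rank's view only; matching sizes on the peer rank are its responsibility.
  //
  //   std::vector<int64_t> inSplits = {1, 3};   // send 1 elem to rank 0,
  //                                             // 3 elems to rank 1
  //   std::vector<int64_t> outSplits = {1, 2};  // receive 1 from rank 0,
  //                                             // 2 from rank 1
  //   at::Tensor input = at::ones({4});
  //   at::Tensor output = at::empty({3});
  //   pg->alltoall_base(output, input, outSplits, inSplits)->wait();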
  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to other ranks using the store. Only
  // implemented for the GLOO, NCCL, and UCC backends currently.
  virtual void setSequenceNumberForGroup() {
    auto backendType = getBackendType();
    // TODO: HACK for backend name to get sequence number for that backend.
    if (backendType == ProcessGroup::BackendType::GLOO ||
        backendType == ProcessGroup::BackendType::NCCL ||
        backendType == ProcessGroup::BackendType::UCC) {
      getDefaultBackend()->setSequenceNumberForGroup();
    } else {
      TORCH_CHECK(
          false,
          c10::str(
              "ProcessGroup ",
              getBackendName(),
              " does not yet support sequence numbers."));
    }
  }

  // Retrieves the current sequence number for the whole group, which should
  // be in sync. If the returned number is not consistent across the group,
  // it may indicate that there is some sort of collective desynchronization.
  virtual uint64_t getSequenceNumberForGroup() {
    auto backendType = getBackendType();

    // TODO: HACK for backend name to get sequence number for that backend.
    if (backendType == ProcessGroup::BackendType::GLOO ||
        backendType == ProcessGroup::BackendType::NCCL ||
        backendType == ProcessGroup::BackendType::UCC) {
      return getDefaultBackend()->getSequenceNumberForGroup();
    } else {
      TORCH_CHECK(
          false,
          c10::str(
              "ProcessGroup ",
              getBackendName(),
              " does not yet support sequence numbers."));
    }
  }

  virtual c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::send", "")
                         .typed<c10::intrusive_ptr<::c10d::Work>(
                             at::TensorList,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             int64_t,
                             int64_t)>();
    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        dstRank,
        tag);
  }

  virtual c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::recv_", "")
                         .typed<c10::intrusive_ptr<::c10d::Work>(
                             at::TensorList,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             int64_t,
                             int64_t)>();
    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        srcRank,
        tag);
  }

  virtual c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& tensors,
      int tag) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::recv_any_source_", "")
                         .typed<c10::intrusive_ptr<::c10d::Work>(
                             at::TensorList,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             int64_t)>();
    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        tag);
  }
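  // Illustrative sketch (comment only) for the point-to-point APIs above.
  // The tag must match between sender and receiver; `pg`, the ranks, and the
  // buffer shape are arbitrary assumptions for a group with at least two
  // ranks.
  //
  //   std::vector<at::Tensor> buf = {at::empty({16})};
  //   constexpr int kTag = 7;
  //   if (pg->getRank() == 0) {
  //     buf[0].fill_(1);
  //     pg->send(buf, /*dstRank=*/1, kTag)->wait();
  //   } else if (pg->getRank() == 1) {
  //     pg->recv(buf, /*srcRank=*/0, kTag)->wait();
  //   }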
  virtual c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) {
    static at::Tensor tensor;
    // TODO: if nccl was specified then use it
    auto device = opts.device;
    if (device.has_value()) {
      // set device tensor from argument
      tensor = at::empty(
          {1}, at::TensorOptions().device(device.value()).dtype(at::kByte));
    } else if (backendType_ == c10d::ProcessGroup::BackendType::NCCL) {
      // set cuda tensor
      tensor = at::empty(
          {1},
          at::TensorOptions().device(at::DeviceType::CUDA).dtype(at::kByte));
    } else {
      // Default to using cpu implementation
      tensor = at::empty(
          {1},
          at::TensorOptions().device(at::DeviceType::CPU).dtype(at::kByte));
    }

    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::barrier", "")
                         .typed<c10::intrusive_ptr<::c10d::Work>(
                             at::Tensor,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             const std::vector<int64_t>&,
                             int64_t)>();

    return op.call(
        tensor,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.device_ids,
        opts.timeout.count());
  }

  c10::intrusive_ptr<Options> getOptions() {
    return options_;
  }

  bool hasBackends() {
    return !deviceTypeToBackendType_.empty();
  }

  void setBackend(
      c10::DeviceType deviceType,
      BackendType backendType,
      const std::optional<c10::intrusive_ptr<Backend>>& backend) {
    // TODO: should we add these entries after the backend setting succeeds?
    deviceTypeToBackendType_[deviceType] = backendType;
    deviceTypes_.insert(deviceType);
    // if the backendType is already set then reuse it for this device
    if (backendTypeToBackend_.find(backendType) !=
        backendTypeToBackend_.end()) {
      auto existingBackend = backendTypeToBackend_.at(backendType);
      deviceTypeToBackend_[deviceType] = existingBackend;
      TORCH_CHECK(
          existingBackend->getBoundDeviceId() ==
          (*backend)->getBoundDeviceId());
    } else {
      // check if backend has value
      if (backend.has_value()) {
        deviceTypeToBackend_[deviceType] = backend.value();
        backendTypeToBackend_[backendType] = backend.value();
        (*backend)->setBoundDeviceId(bound_device_id_);
      }
    }
  }

  c10::intrusive_ptr<Backend> getDefaultBackend() const {
    TORCH_CHECK(
        backendTypeToBackend_.find(backendType_) != backendTypeToBackend_.end(),
        "Could not find the default backend type ",
        backendType_,
        " for Process Group with name ",
        getBackendName(),
        ".");
    return backendTypeToBackend_.at(backendType_);
  }

  c10::intrusive_ptr<Backend> getBackend(c10::DeviceType deviceType);

  c10::intrusive_ptr<Backend> getBackend(BackendType backendType) const {
    TORCH_CHECK(
        backendTypeToBackend_.find(backendType) != backendTypeToBackend_.end(),
        "Could not find backend type ",
        backendType,
        ".");
    return backendTypeToBackend_.at(backendType);
  }
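  // Illustrative sketch (comment only) of the device-type -> backend mapping
  // maintained by setBackend/getBackend above. The backend instances
  // (`glooBackend`, `ncclBackend`) are hypothetical names for this example;
  // in practice they are installed during process-group initialization
  // rather than by end users.
  //
  //   // Route CPU collectives to a Gloo backend and CUDA collectives to an
  //   // NCCL backend of the same group:
  //   pg->setBackend(c10::DeviceType::CPU, BackendType::GLOO, glooBackend);
  //   pg->setBackend(c10::DeviceType::CUDA, BackendType::NCCL, ncclBackend);
  //   auto cudaBackend = pg->getBackend(c10::DeviceType::CUDA);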
  // Return device types supported by this ProcessGroup.
  // Note: the return type is `Device` rather than `DeviceType` for the
  // purpose of easy comparison at Python level. The `Device` will have
  // default index (-1).
  std::vector<c10::Device> getDeviceTypes() const {
    std::vector<c10::Device> devices;
    devices.reserve(deviceTypes_.size());
    for (auto& dt : deviceTypes_) {
      devices.emplace_back(dt);
    }
    return devices;
  }

  void registerOnCompletionHook(
      std::function<void(std::shared_ptr<WorkInfo>)>&& hook) {
    getDefaultBackend()->registerOnCompletionHook(std::move(hook));
  }

  void waitForPendingWorks() {
    getDefaultBackend()->waitForPendingWorks();
  }

  bool hasHooks() const {
    return getDefaultBackend()->hasHooks();
  }

  const std::string& getGroupName() const;
  void setGroupName(const std::string& name);
  const std::string& getGroupDesc() const;
  void setGroupDesc(const std::string& name);
  void enableCollectivesTiming();

  void release_resources() override;

  // ProcessGroups optionally can be "bound" to a specific device.
  // Currently this is only for nccl and allows for some opt-in
  // optimizations such as automatic use of ncclCommSplit. The device
  // is specified in `init_process_group` and eventually makes it
  // here and then down into the actual backend instances.
  std::optional<at::Device> getBoundDeviceId() const {
    return bound_device_id_;
  }

  void setBoundDeviceId(std::optional<at::Device> device) {
    if (device) {
      TORCH_CHECK(device->has_index(), "setBoundDeviceId must have an index");
    }
    bound_device_id_ = device;
  }

 protected:
  // Implementations of this interface need to call this to setup
  // appropriate logging etc.
  void init();

  c10::intrusive_ptr<c10d::Store> store_;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const int rank_;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const int size_;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const c10::intrusive_ptr<Options> options_;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const BackendType backendType_;
  std::string pg_desc_;

  // Debug level setting. It is parsed once when ProcessGroup is constructed
  // and remains the same across use of this process group.
  DebugLevel dist_debug_level_{DebugLevel::Off};

  // Backend classes for this ProcessGroup
  std::unordered_set<c10::DeviceType> deviceTypes_;
  std::unordered_map<c10::DeviceType, BackendType> deviceTypeToBackendType_;
  std::unordered_map<c10::DeviceType, c10::intrusive_ptr<Backend>>
      deviceTypeToBackend_;
  std::unordered_map<BackendType, c10::intrusive_ptr<Backend>>
      backendTypeToBackend_;

  std::optional<at::Device> bound_device_id_;
};

} // namespace c10d