#pragma once

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>

namespace c10d::intra_node_comm {

using namespace c10d::symmetric_memory;

constexpr size_t kMaxDevices = 8;
constexpr size_t kDefaultBufferSize = 10ull * 1024 * 1024;
constexpr size_t kP2pStateSize = 2048;

using NvlMesh = std::array<std::array<size_t, kMaxDevices>, kMaxDevices>;
using HybridCubeMesh = std::array<std::array<int, 4>, kMaxDevices>;

enum class Topology : uint8_t {
  UNKNOWN = 0,
  FULLY_CONNECTED = 1,
  HYBRID_CUBE_MESH = 2
};

enum class AllReduceAlgo : uint8_t {
  NONE = 0,
  ONE_SHOT = 1,
  TWO_SHOT = 2,
  HCM = 3
};

// NOTE: this class will be removed soon in favor of SymmetricMemory
class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target {
 public:
  IntraNodeComm(
      c10::intrusive_ptr<c10d::Store> store,
      size_t rank,
      size_t worldSize,
      std::optional<size_t> bufferSize = std::nullopt);

  ~IntraNodeComm() override;

  static bool isEnabled();

  /**
   * Performs rendezvous.
   * If rendezvous fails, the IntraNodeComm object will be in an invalid
   * state and it is the caller's responsibility to dispose of it.
   */
  bool rendezvous();

  Topology getTopology() {
    return topology_;
  }

  size_t getBufferSize() {
    return bufferSize_;
  }

  /**
   * Selects an AllReduceAlgo that we think will outperform nccl.
   * Returns AllReduceAlgo::NONE if we don't think we can outperform nccl.
   */
  AllReduceAlgo selectAllReduceAlgo(const at::Tensor& input);

  at::Tensor allReduce(const at::Tensor& input, AllReduceAlgo algo);

  /**
   * Performs a barrier among the specified ranks.
   */
  void barrier(std::optional<std::vector<int64_t>> ranks = std::nullopt);

  at::Tensor getBuffer(
      size_t rank,
      const std::vector<int64_t>& sizes,
      c10::ScalarType dtype,
      int64_t storageOffset);

 private:
  at::Tensor oneShotAllReduce(
      const at::Tensor& input,
      at::cuda::CUDAStream& stream);

  at::Tensor twoShotAllReduce(
      const at::Tensor& input,
      at::cuda::CUDAStream& stream);

  at::Tensor hybridCubeMeshAllReduce(
      const at::Tensor& input,
      at::cuda::CUDAStream& stream);

  c10::intrusive_ptr<Store> store_;
  size_t rank_;
  size_t worldSize_;
  size_t bufferSize_;
  at::cuda::CUDAEvent barrierReady_;

  /**
   * Members initialized after rendezvous
   */
  bool isInitialized_ = false;
  int deviceIdx_;
  Topology topology_ = Topology::UNKNOWN;
  void* symmetricMemoryPtr_ = nullptr;
  c10::intrusive_ptr<SymmetricMemory> symmetricMemory_ = nullptr;
  void* p2pStatesDev_{};
  void* buffersDev_{};
  void* topoInfo_{};
};
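// Usage sketch (illustrative only; the helper name tryIntraNodeAllReduce and
// the std::optional fallback convention are assumptions, not part of this
// header): after a successful rendezvous(), a caller consults
// selectAllReduceAlgo() per input and dispatches to allReduce() only when an
// algorithm is expected to outperform nccl.
inline std::optional<at::Tensor> tryIntraNodeAllReduce(
    const c10::intrusive_ptr<IntraNodeComm>& comm,
    const at::Tensor& input) {
  // NONE indicates no algorithm is expected to beat nccl for this input.
  const AllReduceAlgo algo = comm->selectAllReduceAlgo(input);
  if (algo == AllReduceAlgo::NONE) {
    return std::nullopt; // caller falls back to, e.g., ProcessGroupNCCL
  }
  // Launches onto the caller's current stream; see the stream semantics
  // note below.
  return comm->allReduce(input, algo);
}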
/**
 * NOTE [IntraNodeComm Stream Semantics]
 *
 * ProcessGroupNCCL launches kernels differently from the conventional PyTorch
 * CUDA semantics: it always launches collective kernels onto a dedicated
 * communication stream. Therefore, it needs to:
 *
 * - Synchronize the calling stream and the comm stream.
 * - Ensure the memory safety of the operands (via record_stream or stashing).
 * - Synchronize the waiting stream with the comm stream.
 *
 * Unconditionally performing these tasks makes sense when we expect most of
 * the communication to benefit from compute/comm overlap. However,
 * IntraNodeComm primarily aims to optimize small, latency-sensitive, blocking
 * communication, for which the overhead incurred by the above steps can be
 * quite pronounced.
 *
 * Thus, IntraNodeComm follows the conventional PyTorch CUDA semantics and
 * launches kernels onto the stream specified by the user. Although the user
 * can perform the necessary synchronization via wait_stream, to provide a UX
 * consistent with that of ProcessGroupNCCL, the necessary stream
 * synchronization can also be performed via IntraNodeCommWork::wait().
 */
class IntraNodeCommWork : public c10d::Work {
 public:
  IntraNodeCommWork() : c10d::Work() {
    event_.record();
  }

  bool wait(std::chrono::milliseconds timeout = kNoTimeout) override {
    event_.block(at::cuda::getCurrentCUDAStream());
    return true;
  }

 private:
  at::cuda::CUDAEvent event_;
};

TORCH_API int64_t getIntraNodeCommUsageCounter();

} // namespace c10d::intra_node_comm
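// Usage sketch (illustrative; only IntraNodeCommWork::wait() above is real,
// the surrounding calls are assumptions): per the stream semantics note,
// collective kernels run on the caller's current stream. Constructing the
// work object right after the launch records an event on that stream, and
// wait() later blocks whatever stream is then current:
//
//   // On the launching stream: run the collective, then record completion
//   // (the IntraNodeCommWork constructor records its CUDA event).
//   at::Tensor out = comm->allReduce(input, algo);
//   auto work = c10::make_intrusive<c10d::intra_node_comm::IntraNodeCommWork>();
//
//   // Later, possibly with a different current stream: make that stream
//   // wait for the recorded event before consuming `out`.
//   work->wait();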