xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/intra_node_comm.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
2 
3 #include <ATen/Dispatch.h>
4 #include <ATen/cuda/CUDAContext.h>
5 #include <c10/cuda/CUDAGuard.h>
6 
7 namespace c10d {
8 namespace intra_node_comm {
9 
// Each thread moves one 16-byte vector (8 bf16 elements) per loop iteration.
static constexpr size_t kBytesPerThread = 16;
// Upper bound on the number of blocks launched by the allreduce kernels; the
// per-block signal slots in P2pState are sized by this value.
static constexpr size_t kMaxAllReduceBlocks = 24;
static constexpr size_t kThreadsPerBlock = 1024;
static constexpr size_t kWarpSize = 32;

// Size limits (in bytes, after alignment) used by selectAllReduceAlgo() when
// deciding whether an algorithm applies to an input.
static constexpr size_t kHcmThreshBytes = size_t{256} << 10; // 256 KiB
static constexpr size_t kOneShotThreshBytes = size_t{256} << 10; // 256 KiB
static constexpr size_t kTwoShotThreshBytes = size_t{10} << 20; // 10 MiB
18 
#if defined(USE_ROCM)
// Under ROCm, __nv_bfloat162 is declared here as a 32-bit placeholder so the
// shared types below still compile. The bf16 arithmetic and load/store
// helpers in this file hit CUDA_KERNEL_ASSERT(false) on ROCm, so values of
// this type are never actually computed there.
using __nv_bfloat162 = uint32_t;
#endif

// Eight bf16 values packed as four bf16x2 pairs (16 bytes in total). The
// 16-byte alignment lets the load/store helpers move a whole bf16x8 with a
// single 128-bit access (a uint4, or two 64-bit words in inline PTX).
struct __align__(16) bf16x8 {
  __nv_bfloat162 vals[4];
};

// Force-inlined device qualifier used by the small helpers in this file.
#define DEVICE_INLINE __device__ inline __attribute__((always_inline))
28 
/**
 * Lane-wise addition of two packed bf16x2 values.
 *
 * Uses the native __hadd2 on SM80+ CUDA builds. On ROCm and on pre-SM80
 * device passes it fires CUDA_KERNEL_ASSERT(false); the value returned after
 * the assert is a placeholder and carries no meaning.
 */
DEVICE_INLINE __nv_bfloat162
bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if !defined(USE_ROCM) && !(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  return __hadd2(x, y);
#elif defined(USE_ROCM)
  CUDA_KERNEL_ASSERT(false);
  return 0;
#else
  CUDA_KERNEL_ASSERT(false);
  __nv_bfloat162 placeholder;
  return placeholder;
#endif
}
42 
/// Lane-wise sum of two bf16x8 vectors, computed one bf16x2 pair at a time.
DEVICE_INLINE bf16x8 add_bf16x8(bf16x8 a, bf16x8 b) {
  bf16x8 out;
#pragma unroll
  for (int pair = 0; pair < 4; ++pair) {
    out.vals[pair] = bf16hadd2(a.vals[pair], b.vals[pair]);
  }
  return out;
}
51 
52 /**
53  * NOTE [cross device memory synchronization]
54  *
55  * The multi-stage algorithms (e.g. two-shot, hcm allreduce) require the writes
56  * of a thread to be visible by threads with the same block/thread ID on other
57  * devices. To satisfy CUDA's memory consistency model, every thread has to
58  * release its writes at the system scope, and the consuming thread has to
 * acquire the writes at the system scope. This incurs high overhead, and
 * attempts at optimizing this process can be prone to race conditions.
61  *
62  * Instead, we go around caching by having each thread:
63  *
64  * - Directly write to global memory via st.cs (cache-streaming).
65  * - Synchronize with threads within the block.
66  * - Perform cross device synchronization at block level (via system scope
67  *   atomic ops).
68  * - Synchronize with threads within the block.
69  * - Directly read from global memory via ld.nc (non-coherent/non-cached).
70  */
/**
 * Reads 16 bytes at `addr` into `val` with a non-coherent global load
 * (ld.global.nc), as described in NOTE [cross device memory
 * synchronization]. `addr` must be 16-byte aligned (vector access).
 * Only available on SM80+ CUDA builds; asserts otherwise.
 */
template <typename T>
DEVICE_INLINE void streamLoad128(bf16x8& val, const T* addr) {
#if !defined(USE_ROCM) && !(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  unsigned long long int lo;
  unsigned long long int hi;
  asm("ld.global.nc.v2.u64 {%0, %1}, [%2];" : "=l"(lo), "=l"(hi) : "l"(addr));
  auto* words = reinterpret_cast<unsigned long long int*>(&val);
  words[0] = lo;
  words[1] = hi;
#else
  CUDA_KERNEL_ASSERT(false);
#endif
}
84 
/**
 * Writes the 16 bytes of `val` to `addr` with a cache-streaming global store
 * (st.global.cs), as described in NOTE [cross device memory
 * synchronization]. `addr` must be 16-byte aligned (vector access).
 * Only available on SM80+ CUDA builds; asserts otherwise.
 */
__device__ inline void streamStore128(at::BFloat16* addr, const bf16x8& val) {
#if !defined(USE_ROCM) && !(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  const auto* words = reinterpret_cast<const unsigned long long int*>(&val);
  const unsigned long long int lo = words[0];
  const unsigned long long int hi = words[1];
  asm("st.global.cs.v2.u64 [%0], {%1, %2};" : : "l"(addr), "l"(lo), "l"(hi));
#else
  CUDA_KERNEL_ASSERT(false);
#endif
}
95 
/// Plain (cached) 128-bit load of the 16 bytes at `addr` into `val`.
/// `addr` must be 16-byte aligned.
template <typename T>
DEVICE_INLINE void load128(bf16x8& val, const T* addr) {
  const uint4 chunk = *reinterpret_cast<const uint4*>(addr);
  *reinterpret_cast<uint4*>(&val) = chunk;
}
100 
/// Plain (cached) 128-bit store of the 16 bytes of `val` to `addr`.
/// `addr` must be 16-byte aligned.
template <typename T>
DEVICE_INLINE void store128(T* addr, const bf16x8& val) {
  const uint4 chunk = *reinterpret_cast<const uint4*>(&val);
  *reinterpret_cast<uint4*>(addr) = chunk;
}
105 
/// Increments the signal counter at `addr` with a system-scope atomic add.
/// Paired with acquireSignal() on the receiving device, which waits for and
/// consumes one increment. Only available on SM80+ CUDA builds.
DEVICE_INLINE void releaseSignal(uint32_t* addr) {
#if !defined(USE_ROCM) && !(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  atomicAdd_system(addr, 1);
#else
  CUDA_KERNEL_ASSERT(false);
#endif
}
113 
/// Spins until the signal counter at `addr` is non-zero, then decrements it
/// by one with a system-scope compare-and-swap (retrying if another update
/// races in between). Only available on SM80+ CUDA builds.
DEVICE_INLINE void acquireSignal(uint32_t* addr) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  CUDA_KERNEL_ASSERT(false);
#else
  // Polled through a volatile pointer so every iteration re-reads memory.
  volatile uint32_t* const polled = addr;
  for (;;) {
    const uint32_t observed = *polled;
    if (observed != 0 &&
        atomicCAS_system(addr, observed, observed - 1) == observed) {
      break;
    }
  }
#endif
}
125 
126 ////////////////////////////////////////////////////////////////////////////////
127 // Fully Connected Algos
128 ////////////////////////////////////////////////////////////////////////////////
129 
// Per-device signal counters for cross-device handshakes, indexed as
// signalsX[blockIdx.x][peerRank]. A peer increments its slot in this
// device's state (releaseSignal) and this device consumes it
// (acquireSignal). signals0 backs the entry handshake of every kernel (plus
// the HCM relay exchange and the barrier kernel); signals1 backs the second
// handshake of the two-shot kernel.
struct P2pState {
  uint32_t signals0[kMaxAllReduceBlocks][kMaxDevices];
  uint32_t signals1[kMaxAllReduceBlocks][kMaxDevices];
};

// initP2pState allocates sizeof(P2pState) bytes; this keeps the struct within
// kP2pStateSize (defined elsewhere — presumably the per-rank budget reserved
// for it; confirm in intra_node_comm.hpp).
static_assert(sizeof(P2pState) <= kP2pStateSize);
136 
/**
 * One-shot allreduce (sum) over a fully connected topology of kWorldSize
 * ranks: every rank reads the whole padded input from every rank's buffer
 * and reduces it locally.
 *
 * Launch: 1-D grid with at most kMaxAllReduceBlocks blocks (signal slots are
 * indexed by blockIdx.x) and blockDim.x >= kWorldSize (threads
 * [0, kWorldSize) perform the handshake). Each thread handles one bf16x8
 * (16 bytes) per iteration of a grid-stride loop.
 *
 * @param input          Reduced in place (first N elements). Also the copy
 *                       source when fuseInputCopy is set, in which case it
 *                       must hold N_aligned elements.
 * @param N              Number of valid elements in `input`.
 * @param N_aligned      N rounded up to a whole number of warps' worth of
 *                       bf16x8 vectors; every buffer holds at least this many
 *                       elements (checked by the host).
 * @param p2pStates      Per-rank signal state pointers.
 * @param buffers        Per-rank symmetric buffers.
 * @param rank           This device's rank.
 * @param fuseInputCopy  Copy `input` into buffers[rank] inside the kernel
 *                       instead of relying on a preceding memcpy.
 */
template <uint32_t kWorldSize, bool kAligned>
static __global__ void oneShotAllReduceKernel(
    at::BFloat16* input,
    size_t N,
    size_t N_aligned,
    P2pState** p2pStates,
    at::BFloat16** buffers,
    size_t rank,
    bool fuseInputCopy) {
  const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16);
  const size_t offset =
      (blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread;
  const size_t stride = blockDim.x * gridDim.x * numelPerThread;

  if (fuseInputCopy) {
    for (size_t i = offset; i < N_aligned; i += stride) {
      bf16x8 val;
      streamLoad128(val, &input[i]);
      streamStore128(&buffers[rank][i], val);
    }
    // Per NOTE [cross device memory synchronization], every thread of the
    // block must finish its stores before the block-level signal below is
    // released. A peer's thread t reads exactly the elements written here by
    // thread t, but only threads [0, kWorldSize) release the signal, so
    // without this barrier a peer could read slots that other warps have not
    // stored yet. `fuseInputCopy` is a kernel argument, so all threads of the
    // block take this branch together and the barrier is not divergent.
    __syncthreads();
  }

  // Wait for all other ranks to enter the kernel
  if (threadIdx.x < kWorldSize) {
    auto targetRank = threadIdx.x;
    releaseSignal(&p2pStates[targetRank]->signals0[blockIdx.x][rank]);
    acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][targetRank]);
  }
  __syncthreads();

  // Source buffers in a per-rank rotated order: srcs[ii] belongs to rank
  // srcRanks[ii] == (rank + ii) % kWorldSize, so srcs[0] is our own buffer.
  // (Presumably rotated so that ranks start reading from different peers.)
  const at::BFloat16* srcs[kWorldSize];
  size_t srcRanks[kWorldSize];
#pragma unroll kWorldSize
  for (int ii = 0; ii < kWorldSize; ++ii) {
    int srcRank = (rank + ii) % kWorldSize;
    srcs[ii] = buffers[srcRank];
    srcRanks[ii] = srcRank;
  }

  for (size_t i = offset; i < N_aligned; i += stride) {
    bf16x8 vals[kWorldSize];
#pragma unroll kWorldSize
    for (size_t ii = 0; ii < kWorldSize; ++ii) {
      // Store each value at its source rank's index so that the summation
      // below always runs in rank order 0..kWorldSize-1. bf16 addition is not
      // associative, so a rank-dependent order could give different ranks
      // bitwise-different results. (Indexing by (ii - rank) instead would
      // permute the order on every rank with 2 * rank % kWorldSize != 0.)
      streamLoad128(vals[srcRanks[ii]], &srcs[ii][i]);
    }

    bf16x8 sums;
    memset(reinterpret_cast<void*>(&sums), 0, sizeof(sums));

#pragma unroll kWorldSize
    for (size_t ii = 0; ii < kWorldSize; ++ii) {
      sums = add_bf16x8(sums, vals[ii]);
    }
    if constexpr (kAligned) {
      streamStore128(&input[i], sums);
    } else {
      // Tail handling: only write the elements that exist in `input`.
      for (size_t ii = 0; ii < numelPerThread; ++ii) {
        if (i + ii < N) {
          input[i + ii] = reinterpret_cast<at::BFloat16*>(&sums)[ii];
        }
      }
    }
  }
}
203 
/**
 * Two-shot (reduce-scatter + all-gather) allreduce (sum) over a fully
 * connected topology of kWorldSize ranks.
 *
 * The N_aligned padded elements are split into kWorldSize contiguous
 * segments of N_per_rank = N_aligned / kWorldSize elements; segment r is
 * reduced by rank r.
 *   1. Handshake (signals0) so every peer has entered the kernel.
 *   2. Reduce this rank's segment across all ranks' buffers and write the
 *      result both to this rank's buffer and to `input`.
 *   3. Handshake (signals1) so every peer has finished step 2.
 *   4. Copy every other rank's reduced segment from its buffer into `input`.
 *
 * Launch: 1-D grid with at most kMaxAllReduceBlocks blocks (signal slots are
 * indexed by blockIdx.x) and blockDim.x >= kWorldSize (threads
 * [0, kWorldSize) perform the handshakes). The host sizes the launch for
 * N_per_rank elements.
 *
 * @param input      Destination holding at least N_aligned elements (the host
 *                   passes a padded temporary when numel() != N_aligned).
 * @param N_aligned  Total padded element count; each segment must be a whole
 *                   number of 16-byte vectors (the host aligns N_aligned to
 *                   worldSize * numelPerWarp).
 * @param p2pStates  Per-rank signal state pointers.
 * @param buffers    Per-rank symmetric buffers, already holding each rank's
 *                   input.
 * @param rank       This device's rank.
 */
template <uint32_t kWorldSize>
static __launch_bounds__(1024) __global__ void twoShotAllReduceKernel(
    at::BFloat16* input,
    size_t N_aligned,
    P2pState** p2pStates,
    at::BFloat16** buffers,
    size_t rank) {
  const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16);
  const size_t offset =
      (blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread;
  const size_t stride = blockDim.x * gridDim.x * numelPerThread;
  const size_t N_per_rank = N_aligned / kWorldSize;
  // First element of the segment this rank reduces.
  const size_t N_start = N_per_rank * rank;

  // Wait for all other ranks to enter the kernel
  if (threadIdx.x < kWorldSize) {
    auto targetRank = threadIdx.x;
    releaseSignal(&p2pStates[targetRank]->signals0[blockIdx.x][rank]);
    acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][targetRank]);
  }
  __syncthreads();

  // Source buffers in a per-rank rotated order: srcs[ii] belongs to rank
  // srcRanks[ii] == (rank + ii) % kWorldSize, so srcs[0] is our own buffer.
  at::BFloat16* srcs[kWorldSize];
  size_t srcRanks[kWorldSize];
#pragma unroll kWorldSize
  for (int ii = 0; ii < kWorldSize; ++ii) {
    int srcRank = (rank + ii) % kWorldSize;
    srcs[ii] = buffers[srcRank];
    srcRanks[ii] = srcRank;
  }

  // Stage 1 (reduce-scatter): reduce this rank's segment across all ranks.
  for (size_t i = offset; i < N_per_rank; i += stride) {
    bf16x8 vals[kWorldSize];
#pragma unroll kWorldSize
    for (size_t ii = 0; ii < kWorldSize; ++ii) {
      // Store each value at its source rank's index so that the summation
      // below runs in rank order 0..kWorldSize-1 regardless of which rank
      // owns the segment (bf16 addition is not associative).
      streamLoad128(vals[srcRanks[ii]], &srcs[ii][N_start + i]);
    }

    bf16x8 sums;
    memset(reinterpret_cast<void*>(&sums), 0, sizeof(sums));

#pragma unroll kWorldSize
    for (size_t ii = 0; ii < kWorldSize; ++ii) {
      sums = add_bf16x8(sums, vals[ii]);
    }
    // Publish the reduced segment in our own buffer for the other ranks.
    streamStore128(&srcs[0][N_start + i], sums);
    // Store local sums into input now so we can avoid
    // a global memory access later for it.
    streamStore128(&input[N_start + i], sums);
  }
  __syncthreads();

  // Wait until every rank has published its reduced segment.
  if (threadIdx.x < kWorldSize) {
    auto targetRank = threadIdx.x;
    releaseSignal(&p2pStates[targetRank]->signals1[blockIdx.x][rank]);
    acquireSignal(&p2pStates[rank]->signals1[blockIdx.x][targetRank]);
  }
  __syncthreads();

  // Stage 2 (all-gather): copy the other ranks' reduced segments into input.
  for (size_t i = offset; i < N_per_rank; i += stride) {
#pragma unroll kWorldSize - 1
    for (size_t ii = 1; ii < kWorldSize; ++ii) {
      // Element i of the segment owned by rank srcRanks[ii].
      const size_t k = srcRanks[ii] * N_per_rank + i;
      bf16x8 val;
      streamLoad128(val, &srcs[ii][k]);
      streamStore128(&input[k], val);
    }
  }
}
277 
278 ////////////////////////////////////////////////////////////////////////////////
279 // Hybrid Cube Mesh Algos
280 ////////////////////////////////////////////////////////////////////////////////
281 
282 /**
283  * NOTE [hybrid cube mesh]
284  *
285  * In a hybrid cube mesh topology, every device has exactly 4 neighbors
 * (directly connected via NVLink). For every device X, exactly 1 of its
 * neighbors, Y, is also a neighbor of all 3 non-neighbors of X. We call Y the
288  * relay neighbor of X. This property is symmetrical: X is also guaranteed to
289  * be the relay neighbor of Y.
290  *
291  * With this property, we can perform a variant of one-shot allreduce algo that
292  * only moves data across NVLinks:
293  *
 * - Each device performs a one-shot allreduce among itself and its 3
 *   non-relay neighbors.
 * - Each device exchanges data with its relay neighbor.
296  *
297  * HybridCubeMesh is a data structure for describing the topology:
298  *
 * - hcm[X][0:3] are the 3 non-relay neighbors of X.
 * - hcm[X][3] is the relay neighbor of X.
 * - For load balancing purposes, we also ensure that if hcm[X][k] = Y,
 *   hcm[Y][k] = X.
303  */
/**
 * Builds the HybridCubeMesh table described in NOTE [hybrid cube mesh] from
 * an NVLink connectivity matrix.
 *
 * @param nvlMesh  nvlMesh[i][j] > 0 means device i has an NVLink to device j.
 * @return The per-device table (slots 0..2: non-relay neighbors, slot 3:
 *         relay neighbor), or std::nullopt if the topology does not satisfy
 *         the HCM conditions.
 * @throws c10::Error (TORCH_CHECK) if the topology passes the HCM checks but
 *         the non-relay neighbors cannot be assigned to matching slots.
 */
std::optional<HybridCubeMesh> getHybridCubeMesh(NvlMesh nvlMesh) {
  std::array<std::unordered_set<size_t>, kMaxDevices> neighbors = {};
  std::array<size_t, kMaxDevices> neighborMasks = {};
  for (size_t i = 0; i < kMaxDevices; ++i) {
    for (size_t j = 0; j < kMaxDevices; ++j) {
      if (nvlMesh[i][j] > 0) {
        neighbors[i].insert(j);
        neighborMasks[i] |= (1ul << j);
      }
    }
  }
  HybridCubeMesh hcm = {};
  for (auto& row : hcm) {
    row.fill(-1);
  }
  // A topology is an HCM if:
  // - Every device has exactly 4 neighbors.
  // - For every device, it has exactly 1 relay neighbor that is
  //   a neighbor of the 3 non-neighbors of the device.
  for (size_t i = 0; i < kMaxDevices; ++i) {
    // Condition 1: check the number of neighbors
    if (neighbors[i].size() != 4) {
      return std::nullopt;
    }
    // Relay candidates are devices whose neighbor set is disjoint from ours.
    std::vector<size_t> relayNeighbors;
    for (size_t j = 0; j < kMaxDevices; ++j) {
      if ((neighborMasks[i] & neighborMasks[j]) == 0) {
        relayNeighbors.push_back(j);
      }
    }
    // Condition 2: check the number of relay neighbors
    if (relayNeighbors.size() != 1) {
      return std::nullopt;
    }
    neighbors[i].erase(relayNeighbors[0]);
    hcm[i][3] = relayNeighbors[0];
  }

  // Assign the 3 non-relay neighbors of every device to slots 0..2 such that
  // hcm[i][k] == j implies hcm[j][k] == i.
  for (size_t i = 0; i < kMaxDevices; ++i) {
    for (size_t k = 0; k < 3; ++k) {
      // The slot may already have been filled from the other side of a pair
      // while processing an earlier device. Never overwrite it: doing so can
      // leave a device listing a peer that no longer lists it back, and the
      // HCM kernel's pairwise release/acquire handshakes would then wait
      // forever.
      if (hcm[i][k] == -1) {
        // We can only fill hcm[i][k] with j if hcm[j][k] is not filled
        for (size_t j : neighbors[i]) {
          if (hcm[j][k] == -1) {
            hcm[i][k] = j;
            hcm[j][k] = i;
            break;
          }
        }
      }
      TORCH_CHECK(hcm[i][k] != -1);
      const auto peer = static_cast<size_t>(hcm[i][k]);
      // Drop the pairing from both sides so neither device assigns the same
      // peer to a second slot.
      neighbors[i].erase(peer);
      neighbors[peer].erase(i);
    }
  }
  return hcm;
}
358 
/**
 * Hybrid-cube-mesh allreduce (sum); see NOTE [hybrid cube mesh].
 *
 *   1. Handshake (signals0) with the 3 non-relay neighbors.
 *   2. Sum this rank's buffer with the 3 non-relay neighbors' buffers and
 *      store the partial sum in the second half of this rank's buffer.
 *   3. Handshake (signals0, relay slot) with the relay neighbor.
 *   4. Add the relay neighbor's partial sum to ours and write it to `input`.
 *
 * Launch: 1-D grid with at most kMaxAllReduceBlocks blocks (signal slots are
 * indexed by blockIdx.x) and blockDim.x >= 3 (threads 0..2 perform the
 * neighbor handshake, thread 0 the relay handshake).
 *
 * @param input       Overwritten with the result (first N elements).
 * @param N           Number of valid elements in `input`.
 * @param N_aligned   N rounded up to a whole number of warps' worth of
 *                    bf16x8 vectors; the host checks that each buffer holds
 *                    2 * N_aligned elements.
 * @param p2pStates   Per-rank signal state pointers.
 * @param buffers     Per-rank symmetric buffers, already holding each rank's
 *                    input at the start.
 * @param hcmInfo     Device array of 4 ints: this rank's 3 non-relay
 *                    neighbors followed by its relay neighbor.
 * @param bufferSize  Size of each symmetric buffer in bytes.
 * @param rank        This device's rank.
 *
 * NOTE(review): unlike the one-shot kernel, the partial sum is accumulated in
 * the order [self, hcmInfo[0], hcmInfo[1], hcmInfo[2]], which differs between
 * ranks. bf16 addition is not associative, so ranks may end up with
 * bitwise-different results — confirm whether that is acceptable.
 */
template <bool kAligned>
static __global__ void hybridCubeMeshAllReduceKernel(
    at::BFloat16* input,
    size_t N,
    size_t N_aligned,
    P2pState** p2pStates,
    at::BFloat16** buffers,
    int hcmInfo[4],
    size_t bufferSize,
    size_t rank) {
  const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16);
  const size_t offset =
      (blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread;
  const size_t stride = blockDim.x * gridDim.x * numelPerThread;
  const int relayRank = hcmInfo[3];

  // Wait for the 3 non-relay HCM neighbors to enter the kernel
  if (threadIdx.x < 3) {
    auto targetRank = hcmInfo[threadIdx.x];
    releaseSignal(&p2pStates[targetRank]->signals0[blockIdx.x][rank]);
    acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][targetRank]);
  }
  __syncthreads();

  const at::BFloat16* srcs[4] = {
      buffers[rank],
      buffers[hcmInfo[0]],
      buffers[hcmInfo[1]],
      buffers[hcmInfo[2]],
  };
  // Use the second half of each buffer as relay space for partial sums
  at::BFloat16* localRelay =
      buffers[rank] + (bufferSize / sizeof(at::BFloat16) / 2);
  at::BFloat16* remoteRelay =
      buffers[relayRank] + (bufferSize / sizeof(at::BFloat16) / 2);

  for (size_t i = offset; i < N_aligned; i += stride) {
    bf16x8 vals[4];

#pragma unroll 4
    for (size_t ii = 0; ii < 4; ++ii) {
      streamLoad128(vals[ii], &srcs[ii][i]);
    }

    bf16x8 sums;
    memset(reinterpret_cast<void*>(&sums), 0, sizeof(sums));

#pragma unroll 4
    for (size_t ii = 0; ii < 4; ++ii) {
      sums = add_bf16x8(sums, vals[ii]);
    }
    // Cached store for local sums
    store128(&localRelay[i], sums);
  }
  __syncthreads();

  // Relay handshake. It reuses signals0, but on the relay neighbor's slot,
  // which is distinct from the three non-relay neighbors' slots used above.
  if (threadIdx.x == 0) {
    releaseSignal(&p2pStates[relayRank]->signals0[blockIdx.x][rank]);
    acquireSignal(&p2pStates[rank]->signals0[blockIdx.x][relayRank]);
  }
  __syncthreads();

  for (size_t i = offset; i < N_aligned; i += stride) {
    bf16x8 localSum, remoteSum;
    // Cached load for local sums
    load128(localSum, &localRelay[i]);
    streamLoad128(remoteSum, &remoteRelay[i]);
    localSum = add_bf16x8(localSum, remoteSum);
    if constexpr (kAligned) {
      streamStore128(&input[i], localSum);
    } else {
      // Tail handling: only write the elements that exist in `input`.
      for (size_t ii = 0; ii < numelPerThread; ++ii) {
        if (i + ii < N) {
          input[i + ii] = reinterpret_cast<at::BFloat16*>(&localSum)[ii];
        }
      }
    }
  }
}
438 
// Ceiling division: the smallest q with q * b >= a (requires b > 0).
// Takes size_t (not uint32_t) so element/byte counts of 2^32 or more, e.g.
// the byte sizes passed by selectAllReduceAlgo(), are not silently truncated.
static inline size_t divUp(size_t a, size_t b) {
  return (a + b - 1) / b;
}

// Rounds `a` up to the next multiple of `b` (requires b > 0).
static inline size_t alignUp(size_t a, size_t b) {
  return divUp(a, b) * b;
}
446 
/**
 * Validates that `input` can be handled by the IntraNodeComm allreduce
 * kernels; throws via TORCH_CHECK otherwise.
 *
 * Requirements: bf16 dtype, a non-overlapping and dense layout (presumably so
 * the callers can copy it as one byte range of numel() * element_size()),
 * a CUDA device, and specifically device `deviceIdx`.
 *
 * NOTE(review): the dtype error message names oneShotAllReduce, but this
 * check is shared by all three algorithms.
 */
static void checkInput(const at::Tensor& input, int deviceIdx) {
  TORCH_CHECK(
      input.dtype() == at::kBFloat16,
      "oneShotAllReduce only supports bf16 for now");
  TORCH_CHECK(input.is_non_overlapping_and_dense());
  TORCH_CHECK(input.device().is_cuda());
  TORCH_CHECK(
      input.get_device() == deviceIdx,
      "IntraNodeComm: expect input to be on device ",
      deviceIdx,
      ", got device ",
      input.get_device());
}
460 
/**
 * Computes a 1-D launch configuration for the allreduce kernels.
 *
 * @param N_aligned  Number of elements the grid must cover; must be a
 *                   multiple of the per-warp element count (TORCH_CHECKed).
 * @param elemSize   Element size in bytes.
 * @param blocks     [out] Grid dimensions; only .x is non-trivial.
 * @param threads    [out] Block dimensions; only .x is non-trivial.
 *
 * Inputs smaller than one full block get a single block with exactly one
 * thread per 16-byte vector. Larger inputs get up to kMaxAllReduceBlocks
 * blocks of whole warps (capped at kThreadsPerBlock threads); the kernels'
 * strided loops cover whatever one pass of the grid does not.
 */
static void getLaunchConfig(
    size_t N_aligned,
    size_t elemSize,
    dim3& blocks,
    dim3& threads) {
  // Reset both outputs up front; only .x is filled in below.
  threads = dim3(0, 1, 1);
  blocks = dim3(0, 1, 1);

  const auto numelPerThread = kBytesPerThread / elemSize;
  const auto numelPerWarp = numelPerThread * kWarpSize;
  TORCH_CHECK(N_aligned % numelPerThread == 0);
  TORCH_CHECK(N_aligned % numelPerWarp == 0);

  const bool fitsInOneBlock = N_aligned < numelPerThread * kThreadsPerBlock;
  if (fitsInOneBlock) {
    // One block, one thread per vector (a whole number of warps).
    blocks.x = 1;
    threads.x = N_aligned / numelPerWarp * kWarpSize;
    return;
  }

  // Spread the work over as many blocks as needed (up to the cap) and give
  // each block enough whole warps to cover its share.
  const auto threadsRequired = N_aligned / numelPerThread;
  const auto warpsRequired = N_aligned / numelPerWarp;
  blocks.x =
      std::min(divUp(threadsRequired, kThreadsPerBlock), kMaxAllReduceBlocks);
  const auto warpsPerBlock = divUp(warpsRequired, blocks.x);
  threads.x = std::min(kThreadsPerBlock, warpsPerBlock * kWarpSize);
}
485 
/**
 * Reports whether the intra-node allreduce kernels are usable in this build.
 *
 * NOTE(review): this is a host function, and __CUDA_ARCH__ is only defined
 * while compiling device code, so the SM80 part of the check likely never
 * applies here; on CUDA builds this presumably always returns true. Confirm
 * whether a runtime compute-capability check was intended.
 */
bool isIntraNodeCommSupported() {
#if !defined(USE_ROCM) && !(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  constexpr bool kSupported = true;
#else
  constexpr bool kSupported = false;
#endif
  return kSupported;
}
493 
/**
 * Allocates one P2pState on the current CUDA device, zero-fills it, and
 * returns it as an opaque pointer. Throws on CUDA errors (AT_CUDA_CHECK).
 *
 * The zero fill matters: every signal counter must start at 0 so that
 * acquireSignal() waits until a matching releaseSignal() has happened.
 * NOTE(review): if cudaMemset fails, the allocation is leaked.
 */
void* initP2pState() {
  void* state = nullptr;
  AT_CUDA_CHECK(cudaMalloc(&state, sizeof(P2pState)));
  AT_CUDA_CHECK(cudaMemset(state, 0, sizeof(P2pState)));
  return state;
}
500 
/**
 * Uploads this rank's hybrid-cube-mesh neighbor row to device memory.
 *
 * @param topology  Detected topology; only HYBRID_CUBE_MESH needs topo info.
 * @param nvlMesh   NVLink connectivity matrix.
 * @param rank      This device's rank (row of the HCM table to upload).
 * @return nullptr for non-HCM topologies. Otherwise a cudaMalloc'd device
 *         pointer to 4 ints: the 3 non-relay neighbors followed by the relay
 *         neighbor.
 * @throws c10::Error if the mesh is not a valid HCM, `rank` is out of range,
 *         or a CUDA call fails.
 */
void* initTopoInfo(Topology topology, NvlMesh nvlMesh, size_t rank) {
  void* topoInfo = nullptr;
  if (topology != Topology::HYBRID_CUBE_MESH) {
    return topoInfo;
  }
  auto hcm = getHybridCubeMesh(nvlMesh);
  // Dereferencing an empty optional is undefined behavior; fail loudly if the
  // topology classification disagrees with getHybridCubeMesh().
  TORCH_CHECK(
      hcm.has_value(),
      "IntraNodeComm: topology is HYBRID_CUBE_MESH but the NVLink mesh is "
      "not a valid hybrid cube mesh");
  TORCH_CHECK(
      rank < hcm->size(),
      "IntraNodeComm: rank ",
      rank,
      " is out of range for the hybrid cube mesh");
  int hcmInfo[4];
  std::copy((*hcm)[rank].begin(), (*hcm)[rank].begin() + 4, hcmInfo);
  AT_CUDA_CHECK(cudaMalloc(&topoInfo, sizeof(hcmInfo)));
  AT_CUDA_CHECK(
      cudaMemcpy(topoInfo, hcmInfo, sizeof(hcmInfo), cudaMemcpyHostToDevice));
  return topoInfo;
}
514 
/**
 * In-place one-shot allreduce (sum) of a bf16 tensor across all ranks.
 *
 * The input is staged into this rank's symmetric buffer, either with a
 * cudaMemcpyAsync or, for small aligned inputs, inside the kernel. Every rank
 * then reduces all ranks' buffers itself. Work is enqueued on `stream`.
 *
 * @return `input`, which holds the result once `stream` reaches this work.
 * @throws c10::Error on invalid input, insufficient buffer space, or CUDA
 *         errors.
 */
at::Tensor IntraNodeComm::oneShotAllReduce(
    const at::Tensor& input,
    at::cuda::CUDAStream& stream) {
  checkInput(input, deviceIdx_);

  // Empty input: nothing to reduce. Without this early return N_aligned would
  // be 0 and getLaunchConfig would produce a zero-thread block, which the CUDA
  // runtime rejects as an invalid launch configuration. Collectives are
  // expected to see the same numel on every rank, so all ranks skip the
  // kernel and its cross-device handshakes together.
  if (input.numel() == 0) {
    return input;
  }

  const size_t numelPerWarp =
      kBytesPerThread / input.element_size() * kWarpSize;
  const size_t N_aligned = alignUp(input.numel(), numelPerWarp);
  const bool isAligned = (N_aligned == static_cast<size_t>(input.numel()));
  TORCH_CHECK(N_aligned <= bufferSize_ / input.element_size());

  dim3 blocks, threads;
  getLaunchConfig(N_aligned, input.element_size(), blocks, threads);

  at::cuda::OptionalCUDAGuard guard(input.get_device());

  // When the input data is small, copying inside the kernel is faster. Because
  // in such cases, the launch overhead of cudaMemcpyAsync outweighs its
  // efficiency. Here we consider the input data to be small if the copy loop
  // can finish in a single iteration.
  const bool fuseInputCopy = isAligned && blocks.x < kMaxAllReduceBlocks;
  if (!fuseInputCopy) {
    AT_CUDA_CHECK(cudaMemcpyAsync(
        symmetricMemory_->get_buffer_ptrs()[rank_],
        input.data_ptr(),
        input.numel() * input.element_size(),
        cudaMemcpyDeviceToDevice,
        stream));
  }

  // NOTE(review): world sizes outside [2, 8] match no branch below, so no
  // kernel is launched and `input` is returned unreduced.
#define X(kWorldSize, kAligned)                            \
  if (worldSize_ == kWorldSize) {                          \
    oneShotAllReduceKernel<kWorldSize, kAligned>           \
        <<<blocks, threads, 0, stream>>>(                  \
            input.data_ptr<at::BFloat16>(),                \
            input.numel(),                                 \
            N_aligned,                                     \
            reinterpret_cast<P2pState**>(p2pStatesDev_),   \
            reinterpret_cast<at::BFloat16**>(buffersDev_), \
            rank_,                                         \
            fuseInputCopy);                                \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                        \
  }

#define DISPATCH_ALL_WORLD_SIZES(kAligned) \
  X(2, kAligned);                          \
  X(3, kAligned);                          \
  X(4, kAligned);                          \
  X(5, kAligned);                          \
  X(6, kAligned);                          \
  X(7, kAligned);                          \
  X(8, kAligned);

  if (isAligned) {
    DISPATCH_ALL_WORLD_SIZES(true);
  } else {
    DISPATCH_ALL_WORLD_SIZES(false);
  }

#undef DISPATCH_ALL_WORLD_SIZES
#undef X
  return input;
}
578 
/**
 * In-place two-shot allreduce (sum) of a bf16 tensor across all ranks.
 *
 * The input is copied into this rank's symmetric buffer. The padded length
 * N_aligned is a multiple of worldSize_ * numelPerWarp. Each rank reduces one
 * segment and then gathers the others. When padding is needed, the kernel
 * writes into a scratch tensor of N_aligned elements, and the first numel()
 * elements are copied back into `input`. Work is enqueued on `stream`.
 *
 * @return `input`, which holds the result once `stream` reaches this work.
 * @throws c10::Error on invalid input, insufficient buffer space, or CUDA
 *         errors.
 */
at::Tensor IntraNodeComm::twoShotAllReduce(
    const at::Tensor& input,
    at::cuda::CUDAStream& stream) {
  checkInput(input, deviceIdx_);

  // Empty input: nothing to reduce. Without this early return N_per_rank
  // would be 0 and getLaunchConfig would produce a zero-thread block, which
  // the CUDA runtime rejects as an invalid launch configuration. Collectives
  // are expected to see the same numel on every rank, so all ranks skip the
  // kernel and its cross-device handshakes together.
  if (input.numel() == 0) {
    return input;
  }

  size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
  size_t N_aligned = alignUp(input.numel(), worldSize_ * numelPerWarp);
  size_t N_per_rank = N_aligned / worldSize_;
  TORCH_CHECK(N_aligned <= bufferSize_ / input.element_size());

  dim3 blocks, threads;
  // The kernel's loops only run over one N_per_rank-element segment.
  getLaunchConfig(N_per_rank, input.element_size(), blocks, threads);

  // The kernel writes all N_aligned elements, so padded inputs need scratch.
  auto output = N_aligned == static_cast<size_t>(input.numel())
      ? input
      : input.new_empty(N_aligned);

  at::cuda::OptionalCUDAGuard guard(input.get_device());
  AT_CUDA_CHECK(cudaMemcpyAsync(
      symmetricMemory_->get_buffer_ptrs()[rank_],
      input.data_ptr(),
      input.numel() * input.element_size(),
      cudaMemcpyDeviceToDevice,
      stream));

  // NOTE(review): world sizes outside [2, 8] match no branch below, so no
  // kernel is launched.
#define X(kWorldSize)                                                   \
  if (worldSize_ == kWorldSize) {                                       \
    twoShotAllReduceKernel<kWorldSize><<<blocks, threads, 0, stream>>>( \
        output.data_ptr<at::BFloat16>(),                                \
        N_aligned,                                                      \
        reinterpret_cast<P2pState**>(p2pStatesDev_),                    \
        reinterpret_cast<at::BFloat16**>(buffersDev_),                  \
        rank_);                                                         \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                                     \
  }
  X(2);
  X(3);
  X(4);
  X(5);
  X(6);
  X(7);
  X(8);
#undef X

  if (output.data_ptr() != input.data_ptr()) {
    AT_CUDA_CHECK(cudaMemcpyAsync(
        input.data_ptr(),
        output.data_ptr(),
        input.numel() * input.element_size(),
        cudaMemcpyDeviceToDevice,
        stream));
  }
  return input;
}
633 
/**
 * In-place hybrid-cube-mesh allreduce (sum) of a bf16 tensor across the
 * ranks of an HCM topology (see NOTE [hybrid cube mesh]).
 *
 * The input is copied into this rank's symmetric buffer. The kernel uses the
 * second half of every buffer as relay space, hence the 2 * N_aligned buffer
 * requirement. Work is enqueued on `stream`.
 *
 * @return `input`, which holds the result once `stream` reaches this work.
 * @throws c10::Error on invalid input, insufficient buffer space, or CUDA
 *         errors.
 */
at::Tensor IntraNodeComm::hybridCubeMeshAllReduce(
    const at::Tensor& input,
    at::cuda::CUDAStream& stream) {
  checkInput(input, deviceIdx_);

  // Empty input: nothing to reduce. Without this early return N_aligned would
  // be 0 and getLaunchConfig would produce a zero-thread block, which the CUDA
  // runtime rejects as an invalid launch configuration. Collectives are
  // expected to see the same numel on every rank, so all ranks skip the
  // kernel and its cross-device handshakes together.
  if (input.numel() == 0) {
    return input;
  }

  size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
  size_t N_aligned = alignUp(input.numel(), numelPerWarp);
  TORCH_CHECK(N_aligned * 2 <= bufferSize_ / input.element_size());

  dim3 blocks, threads;
  getLaunchConfig(N_aligned, input.element_size(), blocks, threads);

  at::cuda::OptionalCUDAGuard guard(input.get_device());
  AT_CUDA_CHECK(cudaMemcpyAsync(
      symmetricMemory_->get_buffer_ptrs()[rank_],
      input.data_ptr(),
      input.numel() * input.element_size(),
      cudaMemcpyDeviceToDevice,
      stream));

#define X(kAligned)                                                        \
  hybridCubeMeshAllReduceKernel<kAligned><<<blocks, threads, 0, stream>>>( \
      input.data_ptr<at::BFloat16>(),                                      \
      input.numel(),                                                       \
      N_aligned,                                                           \
      reinterpret_cast<P2pState**>(p2pStatesDev_),                         \
      reinterpret_cast<at::BFloat16**>(buffersDev_),                       \
      static_cast<int*>(topoInfo_),                                        \
      bufferSize_,                                                         \
      rank_);                                                              \
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  if (N_aligned == static_cast<size_t>(input.numel())) {
    X(true);
  } else {
    X(false);
  }
#undef X
  return input;
}
674 
/**
 * Picks an intra-node allreduce algorithm for `input`.
 *
 * Returns AllReduceAlgo::NONE when none applies: non-bf16 input, a topology
 * other than HCM / fully connected, or an input that exceeds the size
 * threshold or the symmetric buffer. Sizes are compared after padding to
 * each kernel's granularity: one warp's worth of bytes for HCM and one-shot,
 * worldSize_ warps for two-shot. HCM needs twice the padded size in the
 * buffer because it uses the buffer's second half as relay space.
 * Throws (TORCH_CHECK) for an HCM topology whose world size is not 8.
 */
AllReduceAlgo IntraNodeComm::selectAllReduceAlgo(const at::Tensor& input) {
  // Only support bf16 for now
  if (input.dtype() != at::kBFloat16) {
    return AllReduceAlgo::NONE;
  }
  const auto inputSize = input.numel() * input.element_size();
  const auto bytesPerWarp = kBytesPerThread * kWarpSize;

  if (topology_ == Topology::HYBRID_CUBE_MESH) {
    TORCH_CHECK(
        worldSize_ == 8, "hyperCubeAllReduce only supports exactly 8 GPUs");
    const auto paddedSize = alignUp(inputSize, bytesPerWarp);
    const auto requiredBuffer = paddedSize * 2;
    const bool useHcm =
        paddedSize <= kHcmThreshBytes && requiredBuffer <= bufferSize_;
    return useHcm ? AllReduceAlgo::HCM : AllReduceAlgo::NONE;
  }

  if (topology_ != Topology::FULLY_CONNECTED) {
    return AllReduceAlgo::NONE;
  }

  const auto oneShotSize = alignUp(inputSize, bytesPerWarp);
  if (oneShotSize <= kOneShotThreshBytes && oneShotSize <= bufferSize_) {
    return AllReduceAlgo::ONE_SHOT;
  }

  const auto twoShotSize = alignUp(inputSize, bytesPerWarp * worldSize_);
  if (twoShotSize <= kTwoShotThreshBytes && twoShotSize <= bufferSize_) {
    return AllReduceAlgo::TWO_SHOT;
  }
  return AllReduceAlgo::NONE;
}
709 
// Number of IntraNodeComm::allReduce calls in this process, reported through
// getIntraNodeCommUsageCounter() for testing. A plain (non-atomic) int64_t.
static int64_t usageCounter = 0;
711 
/**
 * Runs allreduce algorithm `algo` in place on `input` on the current CUDA
 * stream and returns `input`.
 *
 * Bumps the usage counter and records `input`'s storage as in use on that
 * stream with the caching allocator. Throws ValueError for
 * AllReduceAlgo::NONE or any other value without a matching implementation.
 */
at::Tensor IntraNodeComm::allReduce(
    const at::Tensor& input,
    AllReduceAlgo algo) {
  // Testing hook; wrap-around is irrelevant.
  ++usageCounter;
  auto stream = at::cuda::getCurrentCUDAStream();
  c10::cuda::CUDACachingAllocator::recordStream(
      input.storage().data_ptr(), stream);
  if (algo == AllReduceAlgo::ONE_SHOT) {
    return oneShotAllReduce(input, stream);
  }
  if (algo == AllReduceAlgo::TWO_SHOT) {
    return twoShotAllReduce(input, stream);
  }
  if (algo == AllReduceAlgo::HCM) {
    return hybridCubeMeshAllReduce(input, stream);
  }
  C10_THROW_ERROR(ValueError, "IntraNodeComm: invalid algo");
}
732 
/// Returns how many times IntraNodeComm::allReduce has been called in this
/// process (the value of usageCounter).
int64_t getIntraNodeCommUsageCounter() {
  return usageCounter;
}
736 
/**
 * Single-block barrier among the ranks whose bits are set in `mask`.
 *
 * Thread t (t < worldSize, bit t set) signals rank t and then waits for rank
 * t's signal, using slot signals0[0][*]. Must be launched with at least
 * worldSize threads; other threads exit immediately.
 */
static __global__ void barrierKernel(
    P2pState** p2pStates,
    uint64_t mask,
    size_t rank,
    size_t worldSize) {
  const auto peer = threadIdx.x;
  const bool participates = peer < worldSize && (mask & (1ULL << peer));
  if (!participates) {
    return;
  }
  releaseSignal(&p2pStates[peer]->signals0[0][rank]);
  acquireSignal(&p2pStates[rank]->signals0[0][peer]);
}
748 
/**
 * Enqueues a barrier among `ranks` (every rank when unset) on the current
 * CUDA stream.
 *
 * Calls barrierReady_.block() on the stream before launching the barrier
 * kernel and barrierReady_.record() afterwards (presumably an event that
 * orders successive barriers — confirm). Throws (TORCH_CHECK) if a rank is
 * outside [0, worldSize_).
 */
void IntraNodeComm::barrier(std::optional<std::vector<int64_t>> ranks) {
  auto stream = at::cuda::getCurrentCUDAStream();
  barrierReady_.block(stream);
  if (!ranks.has_value()) {
    // Default: every rank of the communicator, 0..worldSize_-1.
    ranks.emplace(worldSize_);
    std::iota(ranks->begin(), ranks->end(), 0);
  }
  uint64_t mask = 0;
  for (const auto& r : *ranks) {
    TORCH_CHECK(r >= 0 && r < static_cast<int64_t>(worldSize_));
    mask |= (1ULL << r);
  }
  barrierKernel<<<1, kWarpSize, 0, stream>>>(
      reinterpret_cast<P2pState**>(p2pStatesDev_), mask, rank_, worldSize_);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  barrierReady_.record();
}
765 
/**
 * Thin forwarder to symmetricMemory_->get_buffer(): returns whatever it
 * returns for `rank` with the given sizes, dtype and storage offset
 * (presumably a tensor aliasing that rank's symmetric buffer — confirm).
 */
at::Tensor IntraNodeComm::getBuffer(
    size_t rank,
    const std::vector<int64_t>& sizes,
    c10::ScalarType dtype,
    int64_t storageOffset) {
  return symmetricMemory_->get_buffer(rank, sizes, dtype, storageOffset);
}
773 
774 } // namespace intra_node_comm
775 } // namespace c10d
776