#pragma once

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>

namespace c10d::intra_node_comm {

using namespace c10d::symmetric_memory;

constexpr size_t kMaxDevices = 8;
constexpr size_t kDefaultBufferSize = 10ull * 1024 * 1024;
constexpr size_t kP2pStateSize = 2048;

using NvlMesh = std::array<std::array<size_t, kMaxDevices>, kMaxDevices>;
using HybridCubeMesh = std::array<std::array<int, 4>, kMaxDevices>;

enum class Topology : uint8_t {
  UNKNOWN = 0,
  FULLY_CONNECTED = 1,
  HYBRID_CUBE_MESH = 2
};

enum class AllReduceAlgo : uint8_t {
  NONE = 0,
  ONE_SHOT = 1,
  TWO_SHOT = 2,
  HCM = 3
};

// NOTE: this class will be removed soon in favor of SymmetricMemory
class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target {
 public:
  IntraNodeComm(
      c10::intrusive_ptr<c10d::Store> store,
      size_t rank,
      size_t worldSize,
      std::optional<size_t> bufferSize = std::nullopt);

  ~IntraNodeComm() override;

  static bool isEnabled();

  /**
   * Performs rendezvous.
   * If rendezvous fails, the IntraNodeComm object will be in an invalid
   * state and it is the caller's responsibility to dispose of it.
   */
  bool rendezvous();

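  /**
   * Example usage (a sketch, not part of the original header; assumes a valid
   * c10d::Store, that the correct CUDA device is already set for this rank,
   * and that all participating ranks call rendezvous() concurrently):
   *
   *   auto comm = c10::make_intrusive<IntraNodeComm>(store, rank, worldSize);
   *   if (!comm->rendezvous()) {
   *     // Rendezvous failed; comm is in an invalid state and the caller is
   *     // responsible for disposing of it.
   *     comm.reset();
   *   }
   */
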
  Topology getTopology() {
    return topology_;
  }

  size_t getBufferSize() {
    return bufferSize_;
  }

  /**
   * Selects an AllReduceAlgo that we think will outperform NCCL.
   * Returns AllReduceAlgo::NONE if we don't think we can outperform NCCL.
   */
  AllReduceAlgo selectAllReduceAlgo(const at::Tensor& input);

  at::Tensor allReduce(const at::Tensor& input, AllReduceAlgo algo);

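  /**
   * Example usage (a sketch, not part of the original header; assumes `comm`
   * has completed rendezvous and `input` is a CUDA tensor on this rank's
   * device):
   *
   *   AllReduceAlgo algo = comm->selectAllReduceAlgo(input);
   *   if (algo == AllReduceAlgo::NONE) {
   *     // Not expected to beat NCCL for this input; fall back to NCCL.
   *   } else {
   *     at::Tensor output = comm->allReduce(input, algo);
   *   }
   */
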
  /**
   * Performs a barrier among the specified ranks.
   */
  void barrier(std::optional<std::vector<int64_t>> ranks = std::nullopt);

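  /**
   * Example usage (a sketch, not part of the original header; the
   * std::nullopt default is expected to involve all ranks):
   *
   *   comm->barrier(std::vector<int64_t>{0, 1});  // only ranks 0 and 1
   *   comm->barrier();                            // all ranks
   */
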
  at::Tensor getBuffer(
      size_t rank,
      const std::vector<int64_t>& sizes,
      c10::ScalarType dtype,
      int64_t storageOffset);

 private:
  at::Tensor oneShotAllReduce(
      const at::Tensor& input,
      at::cuda::CUDAStream& stream);

  at::Tensor twoShotAllReduce(
      const at::Tensor& input,
      at::cuda::CUDAStream& stream);

  at::Tensor hybridCubeMeshAllReduce(
      const at::Tensor& input,
      at::cuda::CUDAStream& stream);

  c10::intrusive_ptr<Store> store_;
  size_t rank_;
  size_t worldSize_;
  size_t bufferSize_;
  at::cuda::CUDAEvent barrierReady_;

  /**
   * Members initialized after rendezvous
   */
  bool isInitialized_ = false;
  int deviceIdx_;
  Topology topology_ = Topology::UNKNOWN;
  void* symmetricMemoryPtr_ = nullptr;
  c10::intrusive_ptr<SymmetricMemory> symmetricMemory_ = nullptr;
  void* p2pStatesDev_{};
  void* buffersDev_{};
  void* topoInfo_{};
};

/**
 * NOTE [IntraNodeComm Stream Semantics]
 *
 * ProcessGroupNCCL launches kernels differently from the conventional PyTorch
 * CUDA semantics: it always launches collective kernels onto a dedicated
 * communication stream. Therefore, it needs to:
 *
 * - Synchronize the calling stream and the comm stream.
 * - Ensure the memory safety of the operands (via record_stream or stashing).
 * - Synchronize the waiting stream with the comm stream.
 *
 * Unconditionally performing these tasks makes sense when we expect most of
 * the communication to benefit from compute/comm overlap. However,
 * IntraNodeComm primarily aims to optimize small, latency-sensitive, blocking
 * communication, in which the overhead incurred by the above steps can be
 * quite pronounced.
 *
 * Thus, IntraNodeComm follows the conventional PyTorch CUDA semantics and
 * launches kernels onto the stream specified by the user. Although the user
 * can perform the necessary synchronization via wait_stream, to provide a UX
 * consistent with that of ProcessGroupNCCL, the necessary stream
 * synchronization can also be performed via IntraNodeCommWork::wait().
 */
class IntraNodeCommWork : public c10d::Work {
 public:
  IntraNodeCommWork() : c10d::Work() {
    event_.record();
  }

  bool wait(std::chrono::milliseconds timeout = kNoTimeout) override {
    event_.block(at::cuda::getCurrentCUDAStream());
    return true;
  }

 private:
  at::cuda::CUDAEvent event_;
};

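/**
 * Example usage (a sketch, not part of the original header; illustrates the
 * stream semantics described in the NOTE above):
 *
 *   // The collective kernel is launched onto the user's current stream.
 *   at::Tensor output = comm->allReduce(input, algo);
 *
 *   // Constructing the work right after the launch records an event on the
 *   // current stream; a consumer can later block its own current stream on
 *   // that event via wait(), mirroring ProcessGroupNCCL's UX.
 *   auto work = c10::make_intrusive<IntraNodeCommWork>();
 *   work->wait();
 */
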
TORCH_API int64_t getIntraNodeCommUsageCounter();

} // namespace c10d::intra_node_comm