1 #pragma once 2 3 #ifdef USE_C10D_NCCL 4 5 #include <stdio.h> 6 #include <stdlib.h> 7 8 #include <memory> 9 #include <mutex> 10 #include <thread> 11 12 #include <ATen/ATen.h> 13 #include <ATen/cuda/CUDAEvent.h> 14 #include <c10/util/Exception.h> 15 #include <nccl.h> 16 #include <torch/csrc/distributed/c10d/TraceUtils.h> 17 #include <optional> 18 19 #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ 20 (NCCL_MINOR >= 14) 21 #define NCCL_HAS_COMM_NONBLOCKING 22 #endif 23 24 #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ 25 (NCCL_MINOR >= 18) 26 #define NCCL_HAS_COMM_SPLIT 27 #endif 28 29 // ncclGetLastError() is enabled only for NCCL versions 2.13+ 30 // ncclRemoteError only exists in NCCL versions 2.13+ 31 #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ 32 (NCCL_MINOR >= 13) 33 #define ENABLE_NCCL_GET_LAST_ERROR 34 #define NCCL_REMOTE_ERROR 35 #elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3) 36 #define ENABLE_NCCL_GET_LAST_ERROR 37 #define NCCL_REMOTE_ERROR 38 #endif 39 40 // Error checking is enabled only for NCCL versions 2.4+ since ncclCommAbort() 41 // and ncclCommGetAsyncError() are not supported in earlier versions. 42 #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ 43 (NCCL_MINOR >= 4) 44 #define ENABLE_NCCL_ERROR_CHECKING 45 #elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3) 46 #define ENABLE_NCCL_ERROR_CHECKING 47 #endif 48 49 // P2P is enabled only for NCCL versions 2.7+ since ncclSend() 50 // and ncclRecv() are not supported in earlier versions. 
51 #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ 52 (NCCL_MINOR >= 7) 53 #define ENABLE_NCCL_P2P_SUPPORT 54 #elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3) 55 #define ENABLE_NCCL_P2P_SUPPORT 56 #endif 57 58 #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ 59 (NCCL_MINOR >= 11) 60 #define ENABLE_NCCL_PREMUL_SUM_SUPPORT 61 #elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3) 62 #define ENABLE_NCCL_PREMUL_SUM_SUPPORT 63 #endif 64 65 #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ 66 (NCCL_MINOR >= 17) 67 #define NCCL_HAS_COMM_CTA_CGA 68 #elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3) 69 #define NCCL_HAS_COMM_CTA_CGA 70 #endif 71 72 #if defined(NCCL_REGISTRATION_SUPPORTED) || \ 73 ((defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ 74 (NCCL_MINOR >= 19))) 75 #define NCCL_HAS_COMM_REGISTER 76 #elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3) 77 #define NCCL_HAS_COMM_REGISTER 78 #endif 79 80 // Macro to throw on a non-successful NCCL return value. 81 #define C10D_NCCL_CHECK(cmd, failureReason) \ 82 do { \ 83 ncclResult_t result = cmd; \ 84 if (result != ncclSuccess) { \ 85 std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \ 86 std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) + \ 87 "\n" + getNcclErrorDetailStr(result, failureReason); \ 88 TORCH_CHECK_WITH(DistBackendError, false, err); \ 89 } \ 90 } while (0) 91 92 // Macro to throw on a non-successful NCCL return value for NONBLOCKING calls. 
93 #define C10D_NCCL_CHECK_NONBLOCKING(cmd, failureReason) \ 94 do { \ 95 ncclResult_t result = cmd; \ 96 if (result != ncclSuccess && result != ncclInProgress) { \ 97 std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \ 98 std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) + \ 99 "\n" + getNcclErrorDetailStr(result, failureReason); \ 100 TORCH_CHECK_WITH(DistBackendError, false, err); \ 101 } \ 102 } while (0) 103 104 // Macro to throw on a non-successful NCCL return value, non-blocking. 105 #define C10D_NCCL_CHECK_TIMEOUT(cmd, comm, failureReason) \ 106 ncclResult_t result = cmd; \ 107 auto startTimepoint = std::chrono::steady_clock::now(); \ 108 while (result == ncclInProgress) { \ 109 if (nccl_nonblocking_timeout() > 0) { \ 110 auto currentTimepoint = std::chrono::steady_clock::now(); \ 111 auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>( \ 112 currentTimepoint - startTimepoint) \ 113 .count(); \ 114 if (timeElapsed > nccl_nonblocking_timeout()) { \ 115 std::string err = "NCCL timeout in: " + std::string(__FILE__) + ":" + \ 116 std::to_string(__LINE__) + ", " + \ 117 ncclGetErrorWithVersion(result) + "\n" + \ 118 getNcclErrorDetailStr(result, failureReason); \ 119 TORCH_CHECK_WITH(DistBackendError, false, err); \ 120 } \ 121 } \ 122 ncclCommGetAsyncError(comm, &result); \ 123 } \ 124 if (result != ncclSuccess) { \ 125 std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \ 126 std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) + \ 127 "\n" + getNcclErrorDetailStr(result, failureReason); \ 128 TORCH_CHECK_WITH(DistBackendError, false, err); \ 129 } 130 131 #define C10D_NCCL_CHECK_TIMEOUT_GROUPEND(cmd, comm, failureReason) \ 132 ncclResult_t state = cmd; \ 133 auto startTimepoint = std::chrono::steady_clock::now(); \ 134 if (state == ncclInProgress) { \ 135 do { \ 136 if (nccl_nonblocking_timeout() > 0) { \ 137 auto currentTimepoint = std::chrono::steady_clock::now(); \ 138 auto 
timeElapsed = std::chrono::duration_cast<std::chrono::seconds>( \ 139 currentTimepoint - startTimepoint) \ 140 .count(); \ 141 if (timeElapsed > nccl_nonblocking_timeout()) { \ 142 std::string err = "NCCL timeout in: " + std::string(__FILE__) + \ 143 ":" + std::to_string(__LINE__) + ", " + \ 144 ncclGetErrorWithVersion(state) + "\n" + \ 145 getNcclErrorDetailStr(state, failureReason); \ 146 TORCH_CHECK_WITH(DistBackendError, false, err); \ 147 } \ 148 } \ 149 ncclCommGetAsyncError(comm->getNcclComm(), &state); \ 150 } while (state == ncclInProgress); \ 151 } \ 152 if (state != ncclSuccess) { \ 153 std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \ 154 std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(state) + \ 155 "\n" + getNcclErrorDetailStr(state, failureReason); \ 156 TORCH_CHECK_WITH(DistBackendError, false, err); \ 157 } 158 159 // Macro to print and abort on a non-successful NCCL return value. 160 #define C10D_NCCL_ASSERT(cmd) \ 161 do { \ 162 ncclResult_t result = cmd; \ 163 if (result != ncclSuccess) { \ 164 std::string err = ncclGetErrorWithVersion(result); \ 165 fprintf( \ 166 stderr, \ 167 "NCCL error in: %s:%d, %s\n", \ 168 __FILE__, \ 169 __LINE__, \ 170 err.c_str()); \ 171 abort(); \ 172 } \ 173 } while (0) 174 175 namespace c10d { 176 #define DEFINE_CONSTANT(name, value) \ 177 static c10::IValue name = value; \ 178 static std::string name##_str = value; 179 // Update whenever changing contents or formatting of the dump 180 // (minor when adding fields, major when changing existing fields) 181 // Also update both JSON and Pickle dumps to make use of the newly defined 182 // field(s). 
// Flight-recorder dump format version (see comment above DEFINE_CONSTANT).
DEFINE_CONSTANT(version_val, "2.4");
// Top-level dump sections.
DEFINE_CONSTANT(entries_key, "entries");
DEFINE_CONSTANT(nccl_comm_key, "nccl_comm_state");
DEFINE_CONSTANT(version_key, "version");
DEFINE_CONSTANT(pg_config_key, "pg_config");
DEFINE_CONSTANT(pg_status_key, "pg_status");
// Per-entry identification fields.
DEFINE_CONSTANT(record_id_key, "record_id");
DEFINE_CONSTANT(pg_id_key, "pg_id");
DEFINE_CONSTANT(pg_name_key, "process_group");
DEFINE_CONSTANT(collective_seq_id_key, "collective_seq_id");
DEFINE_CONSTANT(p2p_seq_id_key, "p2p_seq_id");
DEFINE_CONSTANT(is_p2p_key, "is_p2p");
DEFINE_CONSTANT(op_id_key, "op_id");
DEFINE_CONSTANT(profiling_name_key, "profiling_name");
// Tensor size/dtype fields.
DEFINE_CONSTANT(input_sizes_key, "input_sizes");
DEFINE_CONSTANT(input_dtypes_key, "input_dtypes");
DEFINE_CONSTANT(output_sizes_key, "output_sizes");
DEFINE_CONSTANT(output_dtypes_key, "output_dtypes");
// Timing and state fields.
DEFINE_CONSTANT(time_created_key, "time_created_ns");
DEFINE_CONSTANT(duration_key, "duration_ms");
DEFINE_CONSTANT(timeout_key, "timeout_ms");
// Stack-trace frame fields.
DEFINE_CONSTANT(frames_key, "frames");
DEFINE_CONSTANT(state_key, "state");
DEFINE_CONSTANT(line_key, "line");
DEFINE_CONSTANT(name_key, "name");
DEFINE_CONSTANT(filename_key, "filename");
DEFINE_CONSTANT(retired_key, "retired");
DEFINE_CONSTANT(time_discovered_started_key, "time_discovered_started_ns");
DEFINE_CONSTANT(time_discovered_completed_key, "time_discovered_completed_ns");
// Values reported under state_key.
DEFINE_CONSTANT(completed_state, "completed");
DEFINE_CONSTANT(scheduled_state, "scheduled");
DEFINE_CONSTANT(started_state, "started");
#undef DEFINE_CONSTANT

// Hash of tensor data, used by the flight recorder / debugging utilities.
TORCH_API size_t hashTensors(const std::vector<at::Tensor>& tensors);
// Human-readable NCCL version string.
TORCH_API std::string getNcclVersion();
// Error string for `error`, suffixed with the NCCL version.
TORCH_API std::string ncclGetErrorWithVersion(ncclResult_t error);
// Whether nonblocking NCCL communicators are requested (env-controlled).
bool nccl_use_nonblocking();
// Timeout (seconds) for polling nonblocking NCCL calls; <= 0 disables it.
int nccl_nonblocking_timeout();

// Provides additional detail into NCCL error codes based on when these are
// thrown in the NCCL codebase.
TORCH_API std::string getNcclErrorDetailStr(
    ncclResult_t error,
    std::optional<std::string> processGroupFailureReason = std::nullopt);

// Write NCCL debug info to local disk or any storage users define.
// There are some constrains we set for the debug info writer:
// 1. The writer should only be registered once.
// 2. Once registered, users cannot change it including un-register.
// 3. It is recommended to register the customized writer in the trainer setup,
//    If users don't register before calling launchAsyncDebugDump, then users
//    lose the chance to register (and the default writer will be
//    auto-registered).
class TORCH_API DebugInfoWriter {
 public:
  virtual ~DebugInfoWriter() = default;
  // Persist the serialized NCCL trace; default implementation defined
  // elsewhere, subclasses may write to custom storage.
  virtual void write(const std::string& ncclTrace);
  // Returns the process-wide writer for `rank` (auto-registers a default
  // writer if none was registered yet; see class comment).
  static DebugInfoWriter& getWriter(int rank);
  // One-shot registration; see constraints in the class comment.
  static void registerWriter(std::unique_ptr<DebugInfoWriter> writer);
  // Destination description (the file name for the default writer).
  virtual std::string getWriterTarget() {
    return filename_;
  }

 protected:
  // Target file name is "<namePrefix><rank>".
  DebugInfoWriter(std::string namePrefix, int rank) {
    filename_ = c10::str(namePrefix, rank);
  }
  std::string filename_;

 private:
  static std::unique_ptr<DebugInfoWriter> writer_;
  static std::atomic<bool> hasWriterRegistered_;
};

// RAII wrapper for NCCL communicator
class NCCLComm {
 public:
  // Wraps an existing (possibly null) raw communicator handle.
  explicit NCCLComm(ncclComm_t ncclComm)
      : ncclComm_(ncclComm),
        aborted_(false),
        ncclAsyncErr_(ncclSuccess),
        commFailureReason_(std::nullopt),
        initialized_(false) {}

  NCCLComm() : NCCLComm(nullptr) {}

  ~NCCLComm() noexcept {
    // Add lock in this destructor, as aborted_ needs to be read after memory
    // barrier here.
    std::unique_lock<std::mutex> lock(mutex_);
    // Only tear down communicators that were fully created and not already
    // aborted.
    if (ncclComm_ && initialized_ && !aborted_) {
#ifdef ENABLE_NCCL_ERROR_CHECKING
      // Use ncclCommAbort instead of ncclCommDestroy here since
      // ncclCommDestroy could block forever waiting for work to complete on
      // the communicator.
      C10D_NCCL_ASSERT(::ncclCommAbort(ncclComm_));
#else
      C10D_NCCL_ASSERT(::ncclCommDestroy(ncclComm_));
#endif
    }
  }

  // Blocking creation via ncclCommInitRank; throws (via C10D_NCCL_CHECK) on
  // failure. The returned comm is fully initialized.
  static std::shared_ptr<NCCLComm> create(
      int numRanks,
      int rank,
      ncclUniqueId commId) {
    auto comm = std::make_shared<NCCLComm>();
    C10D_NCCL_CHECK(
        ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank),
        std::nullopt);
    comm->ncclId_ = commId;
    comm->rank_ = rank;
    comm->initialized_ = true;
    return comm;
  }

#ifdef NCCL_HAS_COMM_NONBLOCKING
  // Creation via ncclCommInitRankConfig. In nonblocking mode (see
  // nccl_use_nonblocking()) the returned comm may not be initialized yet
  // (initialized_ == false); callers are expected to wait for initialization
  // before use (see waitUntilInitialized / getNcclComm).
  static std::shared_ptr<NCCLComm> create(
      int numRanks,
      int rank,
      ncclUniqueId commId,
      ncclConfig_t& config) {
    auto comm = std::make_shared<NCCLComm>();
    bool isInitialized = false;
    if (nccl_use_nonblocking()) {
      config.blocking = 0;
      LOG(INFO) << "Rank " << rank
                << ": creating NCCL communicator in nonblocking mode";
      C10D_NCCL_CHECK_NONBLOCKING(
          ncclCommInitRankConfig(
              &(comm->ncclComm_), numRanks, commId, rank, &config),
          std::nullopt);
    } else {
      C10D_NCCL_CHECK(
          ncclCommInitRankConfig(
              &(comm->ncclComm_), numRanks, commId, rank, &config),
          std::nullopt);
      // under blocking mode, comm is initialized after NCCL CHECK
      isInitialized = true;
    }
    comm->ncclId_ = commId;
    comm->rank_ = rank;
    comm->initialized_ = isInitialized;
    return comm;
  }

  // Split a new communicator off `source` (ncclCommSplit path); defined in
  // the corresponding .cpp.
  static std::shared_ptr<NCCLComm> split(
      NCCLComm* source,
      int color_id,
      int rank,
      ncclConfig_t& config,
      std::vector<uint64_t>& ranks_ull);
#endif

#if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP)
  // Dump internal communicator state (NCCLX extension). Returns an empty map
  // if the communicator was already aborted.
  std::unordered_map<std::string, std::string> ncclCommDump() {
    std::unordered_map<std::string, std::string> dump;
    if (isAborted()) {
      LOG(INFO) << "Communicator was aborted before trying to dump its state.";
      return dump;
    }
    C10D_NCCL_CHECK(::ncclCommDump(ncclComm_, dump), std::nullopt);
    return dump;
  }
#endif

  ncclUniqueId getNcclId() {
    return ncclId_;
  }

  // Must not be copyable
  NCCLComm(const NCCLComm&) = delete;
  NCCLComm& operator=(const NCCLComm&) = delete;

  // Do not support move assignment as there is no valid use case
  NCCLComm& operator=(NCCLComm&& other) = delete;

  // Move constructable
  // NOTE(review): only ncclComm_/aborted_/ncclAsyncErr_/initialized_ are
  // swapped; ncclId_, rank_, ncclCommSplitCounter_ and commFailureReason_ are
  // NOT transferred from `other` — confirm this is intentional.
  NCCLComm(NCCLComm&& other) {
    // Using other's lock, as it reads other's states
    // Can not use this.mutex_, as this object is being constructed.
    std::unique_lock<std::mutex> lock(other.mutex_);
    std::swap(ncclComm_, other.ncclComm_);
    std::swap(aborted_, other.aborted_);
    std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
    std::swap(initialized_, other.initialized_);
  }

  // Raw handle accessor; defined in the .cpp (may wait for nonblocking init).
  ncclComm_t getNcclComm();

  std::optional<std::string> getNcclCommFailureReason() const {
    std::unique_lock<std::mutex> lock(mutex_);
    return commFailureReason_;
  }

  // Abort the communicator (idempotent). Deregisters any registered segments
  // first, records the failure reason, and poisons ncclAsyncErr_ so the comm
  // is not reused. No-op when error checking is compiled out.
  void ncclCommAbort(
      std::optional<std::string> commFailureReason = std::nullopt) {
    std::unique_lock<std::mutex> lock(mutex_);
#ifdef ENABLE_NCCL_ERROR_CHECKING
    if (aborted_ && !initialized_) {
      // Should not abort twice.
      return;
    }

#ifdef NCCL_HAS_COMM_REGISTER
    // Deregister all registered segments before aborting.
    for (auto& it : registeredSegmentHandles_) {
      void* handle = it.second;
      C10D_NCCL_CHECK(
          ::ncclCommDeregister(ncclComm_, handle),
          c10::str(
              "Failed to deregister segment handle ",
              handle,
              " on ncclComm_ ",
              ncclComm_));
    }
    registeredSegmentHandles_.clear();
#endif

    // Set true failure reason if provided by ProcessGroupNCCL (e.g. work
    // timeout)
    commFailureReason_ = commFailureReason;
    LOG(INFO) << "Aborting ncclComm_ " << ncclComm_ << " with reason: "
              << (commFailureReason ? *commFailureReason
                                    : "No abort reason provided.");
#ifndef NCCL_HAS_COMM_NONBLOCKING
    C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_);
#else
    // Nonblocking abort may return ncclInProgress; poll until done.
    C10D_NCCL_CHECK_TIMEOUT(
        ::ncclCommAbort(ncclComm_), ncclComm_, commFailureReason_);
#endif
    aborted_ = true;
    ncclComm_ = nullptr;

    // Set an appropriate error so that we avoid using the communicator.
    if (ncclAsyncErr_ == ncclSuccess) {
      ncclAsyncErr_ = ncclSystemError;
    }
#else
    // This is a NOOP, if error checks are disabled.
    return;
#endif
  }

  bool isAborted() const {
    std::unique_lock<std::mutex> lock(mutex_);
    return aborted_;
  }

  uint64_t getCommSplitCounter() const {
    return ncclCommSplitCounter_;
  }

  // Returns the sticky async error for this communicator, querying
  // ncclCommGetAsyncError once no error has been recorded yet.
  ncclResult_t checkForNcclError() {
    std::unique_lock<std::mutex> lock(mutex_);
#ifdef ENABLE_NCCL_ERROR_CHECKING
    if (ncclAsyncErr_ != ncclSuccess) {
      return ncclAsyncErr_;
    }
    C10D_NCCL_CHECK(
        ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_), commFailureReason_);
    return ncclAsyncErr_;
#else
    // Always return success, if error checks are disabled.
    return ncclSuccess;
#endif
  }

  // Register a caching-allocator segment with NCCL for zero-copy use.
  // Returns ncclInvalidUsage when registration support is compiled out.
  ncclResult_t registerSegment(void* ptr, size_t size) {
    std::unique_lock<std::mutex> lock(mutex_);
#ifdef NCCL_HAS_COMM_REGISTER
    // We register only segments from cache allocator
    // which are guaranteed to be with disjoint addr ranges. Thus, a ptr always
    // maps to a unique handle and should not be registered before the current
    // ptr is deregistered and freed.
    TORCH_CHECK(
        registeredSegmentHandles_.count(ptr) == 0,
        "Segment with ptr ",
        ptr,
        " has already been registered on ncclComm_ ",
        ncclComm_);

    void* handle;
    C10D_NCCL_CHECK(
        ncclCommRegister(ncclComm_, ptr, size, &handle),
        c10::str(
            "Failed to register segment with ptr ",
            ptr,
            ", size ",
            size,
            " on ncclComm_ ",
            ncclComm_));
    registeredSegmentHandles_[ptr] = handle;
    return ncclSuccess;
#else
    return ncclInvalidUsage;
#endif
  }

  // Inverse of registerSegment(); `ptr` must be currently registered.
  ncclResult_t deregisterSegment(void* ptr) {
    std::unique_lock<std::mutex> lock(mutex_);
#ifdef NCCL_HAS_COMM_REGISTER
    TORCH_CHECK(
        registeredSegmentHandles_.count(ptr) == 1,
        "Segment with ptr ",
        ptr,
        " is not registered on ncclComm_ ",
        ncclComm_);

    void* handle = registeredSegmentHandles_[ptr];
    C10D_NCCL_CHECK(
        ncclCommDeregister(ncclComm_, handle),
        c10::str(
            "Failed to deregister segment handle ",
            handle,
            ", with ptr ",
            ptr,
            " on ncclComm_ ",
            ncclComm_));
    registeredSegmentHandles_.erase(ptr);
    return ncclSuccess;
#else
    return ncclInvalidUsage;
#endif
  }

  friend class ProcessGroupNCCL;

 protected:
  // a helper function to wait until the communicator is initialized;
  void waitUntilInitialized(int timeoutSecs);
  ncclComm_t ncclComm_;
  // Unique nccl_id for this communicator.
  ncclUniqueId ncclId_;
  bool aborted_;
  uint64_t ncclCommSplitCounter_{0};
  // Sticky async error; once non-success, the comm should not be reused.
  ncclResult_t ncclAsyncErr_;
  // Guards all mutable state above/below; mutable so const getters can lock.
  mutable std::mutex mutex_;
  // Rank that this communicator corresponds to.
  // NOTE(review): not brace-initialized; default-constructed NCCLComm leaves
  // rank_ indeterminate until create()/split() assigns it — confirm no reader
  // observes it before then.
  int rank_;
  // Optional reason for communicator failure, provided by ProcessGroupNCCL for
  // better error messaging.
  std::optional<std::string> commFailureReason_;
  bool initialized_{false};
#ifdef NCCL_HAS_COMM_REGISTER
  // Stores handlers for tensors registered by NCCL
  std::unordered_map<void*, void*> registeredSegmentHandles_;
#endif
};

// Helper that automatically cleans up premul sums.
struct ncclRedOpRAII {
  ncclRedOpRAII() = default;
  // Plain (non-owning) reduction op; nothing to destroy.
  ncclRedOpRAII(ncclRedOp_t op) : op_(op) {}
  // Owning wrapper for a premul-sum op created on `comm`; destroyed in dtor.
  ncclRedOpRAII(ncclRedOp_t op, ncclComm_t comm)
      : op_(op), comm_(comm), premul_sum_(true) {}
  ncclRedOpRAII(const ncclRedOpRAII&) = delete;
  ncclRedOpRAII& operator=(const ncclRedOpRAII&) = delete;
  // Move leaves `tmp` as a default (non-owning) wrapper via swap.
  ncclRedOpRAII(ncclRedOpRAII&& tmp) : ncclRedOpRAII() {
    std::swap(tmp.op_, this->op_);
    std::swap(tmp.comm_, this->comm_);
    std::swap(tmp.premul_sum_, this->premul_sum_);
  }
#if defined(ENABLE_NCCL_PREMUL_SUM_SUPPORT)
  ~ncclRedOpRAII() {
    if (premul_sum_) {
      ncclRedOpDestroy(op_, comm_);
    }
  }
#endif
  // Implicit conversion so the wrapper can be passed straight to NCCL calls.
  operator ncclRedOp_t() const {
    return op_;
  }
  ncclRedOp_t op_;
  ncclComm_t comm_;
  bool premul_sum_ = false;
};

/* Helper used by work::getDuration() and nccl flight recorder */
float getDurationFromEvent(
    at::cuda::CUDAEvent& ncclStartEvent,
    at::cuda::CUDAEvent& ncclEndEvent);

// Circular buffer of recent collective/P2P operations ("flight recorder"),
// used to dump debug state on timeout/hang. Sized and enabled via
// TORCH_NCCL_TRACE_BUFFER_SIZE.
struct NCCLTraceBuffer {
  static NCCLTraceBuffer* get() {
    // intentionally leak on exit
    // because this will hold python state that may get destructed
    static NCCLTraceBuffer* instance = new NCCLTraceBuffer();
    return instance;
  }
  NCCLTraceBuffer() {
    // Buffer size 0 (the default) disables recording entirely.
    max_entries_ = getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0);
    capture_cpp_stack_ = getCvarBool({"TORCH_NCCL_TRACE_CPP_STACK"}, false);
    enabled_ = max_entries_ > 0;
  }
  using Event = at::cuda::CUDAEvent;
  // One recorded operation in the circular buffer.
  struct Entry {
    size_t id_; // incremented id in the trace buffer
                // used to figure out where in the circular entries
                // buffer this entry will be located to
                // update state information
    size_t pg_id_;
    std::tuple<std::string, std::string> pg_name_; // <group_name, group_desc>

    // collective_seq_id and p2p_seq_id refer to actual kernel launches (e.g. 1
    // per coalesced group).
    // collective_seq_id only increments for true collective operations (over
    // all ranks in the group). p2p_seq_id only increments over non-collective
    // operations in the group. op_id refers to logical operations (e.g. one per
    // op inside coalesced group)
    size_t collective_seq_id_;
    size_t p2p_seq_id_;
    size_t op_id_;
    std::string profiling_name_;

    std::shared_ptr<torch::CapturedTraceback> traceback_;
    // we borrow pointers to start_ and end_ so we can query the state
    // on reporting. However, once the event is completed, the call
    // to `complete` will clear these.
    Event *start_, *end_;

    // timestamp when the entry was created, likely close to the time the work
    // was 'enqueued'- not necessarily started
    c10::time_t time_created_;

    // configured timeout for this entry
    c10::time_t timeout_ms_;

    // Is this a P2P event?
    bool isP2P_;

    std::optional<float> duration_;

    // timestamp when our CPU threads discovered that the kernel started.
    // will always be _after_ it actually started, and can be very late
    // if the watchdog thread got stuck on CUDA APIs.
    std::optional<c10::time_t> time_discovered_started_;

    // timestamp when our CPU threads discovered that the kernel completed.
    // will always be _after_ it actually complated, and can be the same time
    // as the discovery of the start if the watchdog thread is stuck on CUDA
    // APIs
    std::optional<c10::time_t> time_discovered_completed_;

    // size information for input/output tensors
    c10::SmallVector<int, 4> input_dims_;
    std::vector<c10::ScalarType> input_dtypes_;
    c10::SmallVector<int, 4> output_dims_;
    std::vector<c10::ScalarType> output_dtypes_;
    c10::SmallVector<int64_t, 8> sizes_; // flattened from inputs, outputs
    bool retired_ = false; // is this work entry no longer in the workMetaList_?
                           // a retired but not completed event has timed out
  };

  bool enabled_ = false;
  bool capture_cpp_stack_ = false;
  // Guards entries_ and the bookkeeping counters below.
  std::mutex mutex_;
  std::vector<Entry> entries_;
  size_t max_entries_ = 0;
  size_t next_ = 0; // next write position in the circular buffer
  size_t id_ = 0; // monotonically increasing entry id
  std::map<size_t, std::shared_ptr<ProcessGroupStatus>> all_pg_status_ = {};
  std::map<std::tuple<std::string, std::string>, std::vector<uint64_t>>
      pg_name_to_ranks_ = {};

  // Record one operation; returns its id (used later by retire_id), or
  // nullopt when recording is disabled.
  std::optional<size_t> record(
      size_t pg_id,
      const std::tuple<std::string, std::string>& pg_name,
      size_t collective_seq_id,
      size_t p2p_seq_id,
      size_t op_id,
      std::string profiling_name,
      const std::vector<at::Tensor>& inputs,
      const std::vector<at::Tensor>& outputs,
      Event* start,
      Event* end,
      std::chrono::milliseconds timeout_ms,
      std::shared_ptr<ProcessGroupStatus> pg_status,
      bool isP2P);

  // Remember the rank membership of a process group for the dump.
  void record_pg_ranks(
      const std::tuple<std::string, std::string>& pg_name,
      std::vector<uint64_t> ranks);

  // Refresh an entry's started/completed discovery timestamps.
  void update_state(Entry& r);

  // Snapshot of the current buffer contents.
  std::vector<Entry> dump_entries();

  /*
  Mark an Event as completed and free its events.
  This is called by the watchdog thread, and is asynchronous from the
  perspective of the main thread.
  compute_duration defaults to true since retire_id is only called in the
  watchdog thread, which is currently a place we call cuda APIs which may hang,
  but care should be taken to avoid computing duration in any function that must
  never hang. (timing must also be enabled for compute_duration - see
  TORCH_NCCL_ENABLE_TIMING).
  */
  void retire_id(std::optional<size_t> id, bool compute_duration = true);

  const c10::List<c10::IValue> getCollectiveTrace(
      bool includeStacktraces,
      bool onlyActive);

  // dump pg_entries
  const c10::Dict<c10::IValue, c10::IValue> getPgConfig();

  const std::map<std::string, std::map<std::string, std::string>>
  getPgConfigJson();

  // dump pg_status
  const c10::Dict<c10::IValue, c10::IValue> getPgStatus();

  const std::map<std::string, std::map<std::string, std::string>>
  getPgStatusJson();

  // JSON-formatted dump (see version_val for the format version).
  std::string dump_json(
      const std::optional<std::unordered_map<
          std::string,
          std::unordered_map<std::string, std::string>>>& ncclDumpMap,
      bool includeCollectives,
      bool onlyActive);

  // dump all collectives + ncclDumpMap
  std::string dump(
      const std::optional<std::unordered_map<
          std::string,
          std::unordered_map<std::string, std::string>>>& ncclDumpMap,
      bool includeCollectives,
      bool includeStackTraces,
      bool onlyActive);
};
} // namespace c10d

#endif // USE_C10D_NCCL