#include <chrono>
#include <filesystem>
#include <fstream>
#include <thread>

#include <c10/util/irange.h>
#include <torch/csrc/cuda/nccl.h>
#include <torch/csrc/distributed/c10d/FileStore.hpp>
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include "CUDATest.hpp"
#include "TestUtils.hpp"

#include <gtest/gtest.h>

using namespace c10d::test;

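// Minimum NCCL version code required by the error-handling paths exercised in
// these tests (2400, i.e. presumably NCCL 2.4.0); see skipTest() below.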
constexpr int kNcclErrorHandlingVersion = 2400;

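// A WorkNCCL whose NCCL error check can be forced to report a failure, so the
// tests can exercise the error-handling path without a real NCCL error.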
class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
 public:
  WorkNCCLSimulateErrors(
      at::Device& device,
      bool simulate_error,
      int rank,
      c10d::OpType opType,
      uint64_t seq)
      : WorkNCCL("0", "default_pg", device, rank, opType, seq),
        simulateError_(simulate_error) {}

  std::exception_ptr checkForNCCLErrors() override {
    if (simulateError_) {
      return std::make_exception_ptr(std::runtime_error("Error"));
    }
    return c10d::ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors();
  }

 private:
  bool simulateError_;
};

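// A ProcessGroupNCCL that hands out WorkNCCLSimulateErrors work items and
// whose communicator error check can likewise be toggled to report a failure.
// It also exposes the communicator cache size for the tests below.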
class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
 public:
  ProcessGroupNCCLSimulateErrors(
      const c10::intrusive_ptr<c10d::Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
      : ProcessGroupNCCL(store, rank, size, opts), simulateError_(false) {}

  std::exception_ptr checkForNCCLErrors(
      std::shared_ptr<c10d::NCCLComm>& ncclComm) override {
    if (simulateError_) {
      return std::make_exception_ptr(std::runtime_error("Error"));
    }
    return c10d::ProcessGroupNCCL::checkForNCCLErrors(ncclComm);
  }

  std::chrono::duration<int64_t, std::milli> getWatchdogSleepInterval() {
    return std::chrono::milliseconds(
        ProcessGroupNCCLSimulateErrors::kWatchdogThreadSleepMillis);
  }

  c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
      at::Device& device,
      int rank,
      c10d::OpType opType,
      const char* profilingTitle,
      const std::vector<at::Tensor>& inputs = {},
      const std::vector<at::Tensor>& outputs = {},
      bool record = false) override {
    return c10::make_intrusive<WorkNCCLSimulateErrors>(
        device, simulateError_, rank, opType, seqCollective_);
  }

  size_t getNCCLCommCacheSize() {
    return devNCCLCommMap_.size();
  }

  void simulateError() {
    simulateError_ = true;
  }

  void resetError() {
    simulateError_ = false;
  }

 private:
  bool simulateError_;
};

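// A WorkNCCL that can be forced to never complete, so blocking waits run into
// the collective timeout.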
class WorkNCCLTimedoutErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
 public:
  WorkNCCLTimedoutErrors(
      at::Device& device,
      bool set_timedout_error,
      int rank,
      c10d::OpType opType,
      uint64_t seq)
      : WorkNCCL("0", "default_pg", device, rank, opType, seq),
        setTimedoutError_(set_timedout_error) {}

 private:
  bool isCompleted() override {
    if (setTimedoutError_) {
      return false;
    }
    return c10d::ProcessGroupNCCL::WorkNCCL::isCompleted();
  }

 private:
  bool setTimedoutError_;
};

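// A process group whose work items can be forced to hang (see
// WorkNCCLTimedoutErrors) and which records whether the watchdog finished
// collecting debug info.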
class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors {
 public:
  ProcessGroupNCCLTimedOutErrors(
      const c10::intrusive_ptr<c10d::Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
      : ProcessGroupNCCLSimulateErrors(store, rank, size, opts),
        watchDogDebugInfoFinished_(false),
        setTimedoutError_(false) {}

  c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
      at::Device& device,
      int rank,
      c10d::OpType opType,
      const char* profilingTitle,
      const std::vector<at::Tensor>& inputs = {},
      const std::vector<at::Tensor>& outputs = {},
      bool record = false) override {
    return c10::make_intrusive<WorkNCCLTimedoutErrors>(
        device, setTimedoutError_, rank, opType, seqCollective_);
  }

  void setTimedoutError() {
    setTimedoutError_ = true;
  }

  void resetTimedoutError() {
    setTimedoutError_ = false;
  }

  bool getWatchDogDebugInfoFinishedFlag() {
    return watchDogDebugInfoFinished_;
  }

  // In the constructor of ProcessGroupNCCL, we don't allow the watchdog
  // thread to run any error handling or desync report while the main thread
  // is doing a blocking wait: even if users enable error handling and the
  // desyncDebug flag, those settings get reset. For ease of unit testing we
  // want the main thread to block-wait, so this hack manually sets the desync
  // debug flag after PG creation.
  void forceSetDesyncDebugFlag() {
    desyncDebug_ = true;
  }

 protected:
  std::string getNCCLWatchdogDebugInfo() override {
    LOG(INFO) << "overridden getNCCLWatchdogDebugInfo called";
    watchDogDebugInfoFinished_ = true;
    return "";
  }
  bool watchDogDebugInfoFinished_;

 private:
  bool setTimedoutError_;
};

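// A process group whose heartbeat monitor throws instead of aborting the
// process, recording in a flag whether the monitor fired.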
class ProcessGroupNCCLNoHeartbeatCaught
    : public ProcessGroupNCCLTimedOutErrors {
 public:
  ProcessGroupNCCLNoHeartbeatCaught(
      const c10::intrusive_ptr<c10d::Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
      : ProcessGroupNCCLTimedOutErrors(store, rank, size, opts),
        hasMonitorThreadCaughtError_(false) {}

  std::mutex& getWatchdogMutex() {
    return workMetaListMutex_;
  }

  bool getErrorCaughtFlag() {
    return hasMonitorThreadCaughtError_;
  }

  void forceTryWriteDebugInfo() {
    std::future<bool> asyncDebugDump = std::async(
        std::launch::async, [this]() { return this->dumpDebuggingInfo(); });
    asyncDebugDump.wait();
  }

 protected:
  // Override the heartbeat monitor function so that we capture the exception
  // thrown in the monitor thread (we cannot try-catch it from the main
  // thread) and set a flag for the main thread to check.
  void heartbeatMonitor() override {
    try {
      c10d::ProcessGroupNCCL::heartbeatMonitor();
    } catch (std::runtime_error& e) {
      hasMonitorThreadCaughtError_ = true;
    }
  }

  // It's really hard to unit test std::abort, so we override terminateProcess
  // to throw instead. Without this override, the process does abort with a
  // core dump.
  void terminateProcess(std::string errMsg) override {
    throw std::runtime_error(errMsg);
  }

  bool hasMonitorThreadCaughtError_;
};

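// A process group whose debug-info collection sleeps far longer than the
// heartbeat timeout, simulating a watchdog that is stuck dumping debug info.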
class ProcessGroupNCCLDebugInfoStuck
    : public ProcessGroupNCCLNoHeartbeatCaught {
 public:
  ProcessGroupNCCLDebugInfoStuck(
      const c10::intrusive_ptr<c10d::Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
      : ProcessGroupNCCLNoHeartbeatCaught(store, rank, size, opts) {}

 protected:
  // Override the debug-info collection to sleep for a long time, mimicking a
  // watchdog that gets stuck while gathering debug info.
  std::string getNCCLWatchdogDebugInfo() override {
    std::this_thread::sleep_for(
        std::chrono::seconds(heartbeatTimeoutInSec_ * 20));
    watchDogDebugInfoFinished_ = true;
    return "";
  }
};

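// Common fixture: a single-rank, FileStore-backed setup with one CUDA tensor.
// Tests are skipped when no GPU is available or NCCL is too old.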
class ProcessGroupNCCLErrorsTest : public ::testing::Test {
 protected:
  bool skipTest() {
    if (cudaNumDevices() == 0) {
      LOG(INFO) << "Skipping test since CUDA is not available";
      return true;
    }
#ifdef USE_C10D_NCCL
    if (torch::cuda::nccl::version() < kNcclErrorHandlingVersion) {
      LOG(INFO) << "Skipping test since NCCL version is too old";
      return true;
    }
#endif
    return false;
  }

  void SetUp() override {
    // Enable LOG(INFO) messages.
    c10::initLogging();
    // This check needs to happen in SetUp so that we only run the test --
    // including the init -- when there are GPUs available.
    if (skipTest()) {
      GTEST_SKIP() << "Skipping ProcessGroupNCCLErrorsTest because system "
                   << "requirement is not met (no CUDA or GPU).";
    }

    size_t numDevices = 1; // One device per rank (thread)
    TemporaryFile file;
    store_ = c10::make_intrusive<::c10d::FileStore>(file.path, 1);

    tensors_.resize(numDevices);
    tensors_[0] = at::empty({3, 3}, at::kCUDA);
  }

  void TearDown() override {
    ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "0", 1) == 0);
  }

  std::vector<at::Tensor> tensors_;
  c10::intrusive_ptr<::c10d::FileStore> store_;
};

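// With blocking wait enabled, a simulated NCCL error should surface as a
// std::runtime_error from wait().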
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) {
  ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0);
  auto options = c10d::ProcessGroupNCCL::Options::create();
  options->timeout = std::chrono::milliseconds(1000);
  ProcessGroupNCCLSimulateErrors pg(store_, 0, 1, options);

  auto work = pg.allreduce(tensors_);
  work->wait();
  EXPECT_EQ(1, pg.getNCCLCommCacheSize());

  // Now run allreduce with errors.
  pg.simulateError();
  work = pg.allreduce(tensors_);
  EXPECT_THROW(work->wait(), std::runtime_error);

  // Verify the work item failed.
  EXPECT_TRUE(work->isCompleted());
  EXPECT_THROW(work->wait(), std::runtime_error);

  // Communicators might be aborted here; further operations would fail.
}

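// With blocking wait enabled, a work item that never completes should hit the
// collective timeout and throw c10::DistBackendError.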
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLTimedoutErrorsBlocking) {
  ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0);
  auto options = c10d::ProcessGroupNCCL::Options::create();
  options->timeout = std::chrono::milliseconds(3000);
  ProcessGroupNCCLTimedOutErrors pg(store_, 0, 1, options);

  auto work = pg.allreduce(tensors_);
  work->wait();
  EXPECT_EQ(1, pg.getNCCLCommCacheSize());

  // Now run allreduce with errors.
  pg.setTimedoutError();
  work = pg.allreduce(tensors_);
  EXPECT_THROW(work->wait(), c10::DistBackendError);

  // Communicators might be aborted here; further operations would fail.
}

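// Without blocking wait, wait() is expected not to throw even when an NCCL
// error is simulated.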
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) {
  auto options = c10d::ProcessGroupNCCL::Options::create();
  options->timeout = std::chrono::milliseconds(3000);
  ProcessGroupNCCLSimulateErrors pg(store_, 0, 1, options);

  auto work = pg.allreduce(tensors_);
  pg.barrier()->wait();
  EXPECT_EQ(1, pg.getNCCLCommCacheSize());

  // Now run allreduce with errors.
  pg.simulateError();
  work = pg.allreduce(tensors_);

  // Should not throw exceptions.
  work->wait();
  pg.barrier()->wait();

  EXPECT_TRUE(work->isCompleted());
  // Communicators might be aborted here; further operations would fail.
}

// Function to read back what we wrote to local disk, for validation.
std::string readTraceFromFile(const std::string& filename, size_t size) {
  std::ifstream file(filename, std::ios::binary);
  // Read the string from the file, as long as the stream is in a good state.
  if (file) {
    std::string str(size, '\0');
    file.read(&str[0], size);
    if (file) {
      return str;
    }
  }
  return "";
}

// Extend DebugInfoWriter outside the process group so the test can also keep
// the written traces in memory for validation.
class TestDebugInfoWriter : public c10d::DebugInfoWriter {
 public:
  TestDebugInfoWriter(std::string namePrefix)
      : DebugInfoWriter(namePrefix, 0) {}

  void write(const std::string& ncclTrace) override {
    traces_.assign(ncclTrace.begin(), ncclTrace.end());
    c10d::DebugInfoWriter::write(ncclTrace);
  }

  std::vector<uint8_t>& getTraces() {
    return traces_;
  }

 private:
  std::vector<uint8_t> traces_;
};

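// Holds the watchdog lock long enough for the heartbeat monitor to fire, then
// checks that the flight-recorder trace written to disk matches the copy the
// test writer kept in memory.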
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
  int heartBeatIntervalInSec = 2;
  std::string timeInterval = std::to_string(heartBeatIntervalInSec);
  ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0);
  ASSERT_TRUE(
      setenv(
          c10d::TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC[0].c_str(),
          timeInterval.c_str(),
          1) == 0);
  ASSERT_TRUE(
      setenv(c10d::TORCH_NCCL_ENABLE_MONITORING[0].c_str(), "1", 1) == 0);
  auto tempFilename = c10::str(
      std::filesystem::temp_directory_path().string(), "/nccl_trace_rank_");
  ASSERT_TRUE(
      setenv("TORCH_NCCL_DEBUG_INFO_TEMP_FILE", tempFilename.c_str(), 1) == 0);
  // Enable the NCCL flight recorder.
  ASSERT_TRUE(setenv("TORCH_NCCL_TRACE_BUFFER_SIZE", "10", 1) == 0);
  ASSERT_TRUE(setenv(c10d::TORCH_NCCL_DUMP_ON_TIMEOUT[0].c_str(), "1", 1) == 0);
  auto options = c10d::ProcessGroupNCCL::Options::create();
  // Set a long watchdog timeout so that we have enough time to lock the
  // watchdog and let the heartbeat monitor thread kick in.
  options->timeout = std::chrono::milliseconds(30000);
  ProcessGroupNCCLNoHeartbeatCaught pg(store_, 0, 1, options);
  // The writer here is very similar to the default writer; the only
  // difference is that we also store the traces in memory for validation.
  std::string fileNamePrefix = c10d::getCvarString(
      {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
  std::unique_ptr<TestDebugInfoWriter> writerForTestPtr =
      std::make_unique<TestDebugInfoWriter>(fileNamePrefix);
  std::vector<uint8_t>& traces = writerForTestPtr->getTraces();
  c10d::DebugInfoWriter::registerWriter(std::move(writerForTestPtr));

  // Normal collective case.
  auto work = pg.allreduce(tensors_);
  work->wait();

  work = pg.allreduce(tensors_);
  {
    // Now run allreduce with errors.
    std::lock_guard<std::mutex> lock(pg.getWatchdogMutex());
    LOG(INFO) << "Lock watchdog thread.";
    // Wait long enough for the monitor thread to throw its exception.
    std::this_thread::sleep_for(
        std::chrono::seconds(heartBeatIntervalInSec * 3));
    // Check that the monitoring thread launched and the exception was thrown.
    EXPECT_TRUE(pg.getErrorCaughtFlag());
  }
  work->wait();
  EXPECT_TRUE(traces.size() > 0);
  auto filename = c10::str(tempFilename, 0);
  auto traceFromStorage = readTraceFromFile(filename, traces.size());
  // Check that the traces read from storage match the original NCCL trace.
  EXPECT_TRUE(traceFromStorage == std::string(traces.begin(), traces.end()));
  std::filesystem::remove(filename);
}

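// Fixture for the watchdog-timeout tests below: a very short collective
// timeout, a short heartbeat, desync debug enabled, and async error handling
// disabled. (Currently skipped; see SetUp.)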
class ProcessGroupNCCLWatchdogTimeoutTest : public ProcessGroupNCCLErrorsTest {
 protected:
  void SetUp() override {
    // TODO (kwen2501)
    GTEST_SKIP() << "Skipping tests under ProcessGroupNCCLWatchdogTimeoutTest; "
                 << "will rewrite them after refactoring Work queues.";
    ProcessGroupNCCLErrorsTest::SetUp();
    std::string timeInterval = std::to_string(heartBeatIntervalInSec);
    ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0);
    ASSERT_TRUE(
        setenv(
            c10d::TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC[0].c_str(),
            timeInterval.c_str(),
            1) == 0);
    ASSERT_TRUE(
        setenv(c10d::TORCH_NCCL_ENABLE_MONITORING[0].c_str(), "1", 1) == 0);
    ASSERT_TRUE(setenv(c10d::TORCH_NCCL_DESYNC_DEBUG[0].c_str(), "1", 1) == 0);
    // We cannot capture an exception thrown in the watchdog thread without
    // making lots of changes to the code, so we don't let the watchdog throw.
    ASSERT_TRUE(
        setenv(c10d::TORCH_NCCL_ASYNC_ERROR_HANDLING[0].c_str(), "0", 1) == 0);
    options_ = c10d::ProcessGroupNCCL::Options::create();
    // Set a super short watchdog timeout.
    options_->timeout = std::chrono::milliseconds(100);
  }

  void watchdogTimeoutTestCommon(
      ProcessGroupNCCLNoHeartbeatCaught& pg,
      int multiplier) {
    pg.forceSetDesyncDebugFlag();
    pg.setTimedoutError();
    auto work = pg.allreduce(tensors_);
    std::this_thread::sleep_for(
        std::chrono::seconds(heartBeatIntervalInSec * multiplier));
    EXPECT_THROW(work->wait(), c10::DistBackendError);
  }

  const int heartBeatIntervalInSec = 2;
  c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> options_;
};

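// Debug-info collection finishes quickly, so the heartbeat monitor should not
// report an error even though the collective times out.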
TEST_F(ProcessGroupNCCLWatchdogTimeoutTest, testNCCLTimedoutDebugInfoFinished) {
  ProcessGroupNCCLNoHeartbeatCaught pg(store_, 0, 1, options_);
  // Writing debug info makes the watchdog thread wait for 30 seconds, and
  // that is hard to override, so we just trigger it beforehand. Otherwise we
  // would need to set a long heartbeat timeout, which would make the test
  // much slower.
  pg.forceTryWriteDebugInfo();
  watchdogTimeoutTestCommon(pg, 2);

  // This flag being true shows that the heartbeat monitor thread did not kill
  // the watchdog thread while it was collecting debug info such as the desync
  // report.
  EXPECT_TRUE(pg.getWatchDogDebugInfoFinishedFlag());
  // This flag being false shows that the heartbeat monitor thread did not
  // trigger a process abort, because collecting debug info and destroying the
  // PG were fast enough.
  EXPECT_FALSE(pg.getErrorCaughtFlag());

  // Communicators might be aborted here; further operations would fail.
}

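// Debug-info collection is stuck, so the heartbeat monitor should fire and
// (via the overridden terminateProcess) record an error.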
TEST_F(ProcessGroupNCCLWatchdogTimeoutTest, testNCCLTimedoutDebugInfoStuck) {
  ProcessGroupNCCLDebugInfoStuck pg(store_, 0, 1, options_);
  // Keep the main thread asleep longer so that the heartbeat monitor thread
  // can finish its extra wait and flip the flag.
  watchdogTimeoutTestCommon(pg, 4);
  // This flag being false shows that the watchdog thread got stuck while
  // collecting debug info such as the desync report.
  EXPECT_FALSE(pg.getWatchDogDebugInfoFinishedFlag());
  // This flag being true shows that the heartbeat monitor thread did trigger
  // a process abort when collecting debug info got stuck.
  EXPECT_TRUE(pg.getErrorCaughtFlag());

  // Communicators might be aborted here; further operations would fail.
}