#include <chrono>
#include <filesystem>
#include <fstream>
#include <future> // std::async / std::future used in forceTryWriteDebugInfo()
#include <thread>

#include <c10/util/irange.h>
#include <torch/csrc/cuda/nccl.h>
#include <torch/csrc/distributed/c10d/FileStore.hpp>
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include "CUDATest.hpp"
#include "TestUtils.hpp"

#include <gtest/gtest.h>

using namespace c10d::test;

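// Minimum NCCL version that supports error handling (2400 corresponds to
// NCCL 2.4.0); tests are skipped on older versions, see skipTest() below.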
constexpr int kNcclErrorHandlingVersion = 2400;

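// A WorkNCCL whose NCCL error check can be forced to report a fake error, so
// that error propagation can be tested without a real NCCL failure.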
class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
 public:
  WorkNCCLSimulateErrors(
      at::Device& device,
      bool simulate_error,
      int rank,
      c10d::OpType opType,
      uint64_t seq)
      : WorkNCCL("0", "default_pg", device, rank, opType, seq),
        simulateError_(simulate_error) {}

  std::exception_ptr checkForNCCLErrors() override {
    if (simulateError_) {
      return std::make_exception_ptr(std::runtime_error("Error"));
    }
    return c10d::ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors();
  }

 private:
  bool simulateError_;
};

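// A ProcessGroupNCCL with a toggleable simulated NCCL error. It hands out
// WorkNCCLSimulateErrors work items and exposes a few internals (communicator
// cache size, watchdog sleep interval) for test assertions.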
class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
 public:
  ProcessGroupNCCLSimulateErrors(
      const c10::intrusive_ptr<c10d::Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
      : ProcessGroupNCCL(store, rank, size, opts), simulateError_(false) {}

  std::exception_ptr checkForNCCLErrors(
      std::shared_ptr<c10d::NCCLComm>& ncclComm) override {
    if (simulateError_) {
      return std::make_exception_ptr(std::runtime_error("Error"));
    }
    return c10d::ProcessGroupNCCL::checkForNCCLErrors(ncclComm);
  }

  std::chrono::duration<int64_t, std::milli> getWatchdogSleepInterval() {
    return std::chrono::milliseconds(
        ProcessGroupNCCLSimulateErrors::kWatchdogThreadSleepMillis);
  }

  c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
      at::Device& device,
      int rank,
      c10d::OpType opType,
      const char* profilingTitle,
      const std::vector<at::Tensor>& inputs = {},
      const std::vector<at::Tensor>& outputs = {},
      bool record = false) override {
    return c10::make_intrusive<WorkNCCLSimulateErrors>(
        device, simulateError_, rank, opType, seqCollective_);
  }

  size_t getNCCLCommCacheSize() {
    return devNCCLCommMap_.size();
  }

  void simulateError() {
    simulateError_ = true;
  }

  void resetError() {
    simulateError_ = false;
  }

 private:
  bool simulateError_;
};

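// A WorkNCCL that can be forced to never report completion, simulating a
// collective that hangs until the watchdog times it out.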
class WorkNCCLTimedoutErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
 public:
  WorkNCCLTimedoutErrors(
      at::Device& device,
      bool set_timedout_error,
      int rank,
      c10d::OpType opType,
      uint64_t seq)
      : WorkNCCL("0", "default_pg", device, rank, opType, seq),
        setTimedoutError_(set_timedout_error) {}

 private:
  bool isCompleted() override {
    if (setTimedoutError_) {
      return false;
    }
    return c10d::ProcessGroupNCCL::WorkNCCL::isCompleted();
  }

 private:
  bool setTimedoutError_;
};

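// A process group that issues WorkNCCLTimedoutErrors work items, plus hooks
// (desync debug flag, debug-info-finished flag) used by the timeout tests.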
class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors {
 public:
  ProcessGroupNCCLTimedOutErrors(
      const c10::intrusive_ptr<c10d::Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
      : ProcessGroupNCCLSimulateErrors(store, rank, size, opts),
        watchDogDebugInfoFinished_(false),
        setTimedoutError_(false) {}

  c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
      at::Device& device,
      int rank,
      c10d::OpType opType,
      const char* profilingTitle,
      const std::vector<at::Tensor>& inputs = {},
      const std::vector<at::Tensor>& outputs = {},
      bool record = false) override {
    return c10::make_intrusive<WorkNCCLTimedoutErrors>(
        device, setTimedoutError_, rank, opType, seqCollective_);
  }

  void setTimedoutError() {
    setTimedoutError_ = true;
  }

  void resetTimedoutError() {
    setTimedoutError_ = false;
  }

  bool getWatchDogDebugInfoFinishedFlag() {
    return watchDogDebugInfoFinished_;
  }

  // In the constructor of ProcessGroupNCCL, the watchdog thread is not allowed
  // to run any error handling or desync report while the main thread is doing
  // a blocking wait: even if users enable handling and the desyncDebug flag,
  // both get reset. For ease of unit testing we want the main thread to do a
  // blocking wait, so this hack manually sets the desync debug flag after PG
  // creation.
  void forceSetDesyncDebugFlag() {
    desyncDebug_ = true;
  }

 protected:
  std::string getNCCLWatchdogDebugInfo() override {
    LOG(INFO) << "overridden getNCCLWatchdogDebugInfo called";
    watchDogDebugInfoFinished_ = true;
    return "";
  }
  bool watchDogDebugInfoFinished_;

 private:
  bool setTimedoutError_;
};

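// A process group whose heartbeat monitor records its exception in a flag
// instead of aborting the process, so tests can detect that the monitor fired.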
class ProcessGroupNCCLNoHeartbeatCaught
    : public ProcessGroupNCCLTimedOutErrors {
 public:
  ProcessGroupNCCLNoHeartbeatCaught(
      const c10::intrusive_ptr<c10d::Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
      : ProcessGroupNCCLTimedOutErrors(store, rank, size, opts),
        hasMonitorThreadCaughtError_(false) {}

  std::mutex& getWatchdogMutex() {
    return workMetaListMutex_;
  }

  bool getErrorCaughtFlag() {
    return hasMonitorThreadCaughtError_;
  }

  void forceTryWriteDebugInfo() {
    std::future<bool> asyncDebugDump = std::async(
        std::launch::async, [this]() { return this->dumpDebuggingInfo(); });
    asyncDebugDump.wait();
  }

 protected:
  // Override the heartbeat monitor function so that the exception it raises is
  // captured in the monitor thread (we cannot try-catch it from the main
  // thread); a flag is set for the main thread to check instead.
  void heartbeatMonitor() override {
    try {
      c10d::ProcessGroupNCCL::heartbeatMonitor();
    } catch (std::runtime_error& e) {
      hasMonitorThreadCaughtError_ = true;
    }
  }

  // It's really hard to unit test std::abort, so we override terminateProcess
  // instead. (With this override commented out, we did see the process abort
  // with a core dump, as expected.)
  void terminateProcess(std::string errMsg) override {
    throw std::runtime_error(errMsg);
  }

  bool hasMonitorThreadCaughtError_;
};

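// A process group whose debug-info collection sleeps far longer than the
// heartbeat timeout, simulating a watchdog stuck while dumping debug info.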
class ProcessGroupNCCLDebugInfoStuck
    : public ProcessGroupNCCLNoHeartbeatCaught {
 public:
  ProcessGroupNCCLDebugInfoStuck(
      const c10::intrusive_ptr<c10d::Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
      : ProcessGroupNCCLNoHeartbeatCaught(store, rank, size, opts) {}

 protected:
  // Override the debug-info collection with a long sleep to mimic getting
  // stuck while gathering debug info.
  std::string getNCCLWatchdogDebugInfo() override {
    std::this_thread::sleep_for(
        std::chrono::seconds(heartbeatTimeoutInSec_ * 20));
    watchDogDebugInfoFinished_ = true;
    return "";
  }
};

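// Common fixture: skips when CUDA or a new-enough NCCL is unavailable, and
// sets up a single-rank FileStore plus a CUDA tensor for the collectives.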
class ProcessGroupNCCLErrorsTest : public ::testing::Test {
 protected:
  bool skipTest() {
    if (cudaNumDevices() == 0) {
      LOG(INFO) << "Skipping test since CUDA is not available";
      return true;
    }
#ifdef USE_C10D_NCCL
    if (torch::cuda::nccl::version() < kNcclErrorHandlingVersion) {
      LOG(INFO) << "Skipping test since NCCL version is too old";
      return true;
    }
#endif
    return false;
  }

  void SetUp() override {
    // Enable LOG(INFO) messages.
    c10::initLogging();
    // This check needs to happen in SetUp so that we only run the test --
    // including the init -- when there are GPUs available.
    if (skipTest()) {
      GTEST_SKIP() << "Skipping ProcessGroupNCCLErrorsTest because system "
                   << "requirement is not met (no CUDA or GPU).";
    }

    size_t numDevices = 1; // One device per rank (thread)
    TemporaryFile file;
    store_ = c10::make_intrusive<::c10d::FileStore>(file.path, 1);

    tensors_.resize(numDevices);
    tensors_[0] = at::empty({3, 3}, at::kCUDA);
  }

  void TearDown() override {
    ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "0", 1) == 0);
  }

  std::vector<at::Tensor> tensors_;
  c10::intrusive_ptr<::c10d::FileStore> store_;
};

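// With blocking wait enabled, a simulated NCCL error should surface as an
// exception from wait(), and the failed work should still report completion.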
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) {
  ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0);
  auto options = c10d::ProcessGroupNCCL::Options::create();
  options->timeout = std::chrono::milliseconds(1000);
  ProcessGroupNCCLSimulateErrors pg(store_, 0, 1, options);

  auto work = pg.allreduce(tensors_);
  work->wait();
  EXPECT_EQ(1, pg.getNCCLCommCacheSize());

  // Now run all reduce with errors.
  pg.simulateError();
  work = pg.allreduce(tensors_);
  EXPECT_THROW(work->wait(), std::runtime_error);

  // Verify the work item failed.
  EXPECT_TRUE(work->isCompleted());
  EXPECT_THROW(work->wait(), std::runtime_error);

  // Communicators might be aborted here, further operations would fail.
}

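// With blocking wait enabled, a work item that never completes should make
// wait() throw a c10::DistBackendError once the timeout expires.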
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLTimedoutErrorsBlocking) {
  ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0);
  auto options = c10d::ProcessGroupNCCL::Options::create();
  options->timeout = std::chrono::milliseconds(3000);
  ProcessGroupNCCLTimedOutErrors pg(store_, 0, 1, options);

  auto work = pg.allreduce(tensors_);
  work->wait();
  EXPECT_EQ(1, pg.getNCCLCommCacheSize());

  // Now run all reduce with errors.
  pg.setTimedoutError();
  work = pg.allreduce(tensors_);
  EXPECT_THROW(work->wait(), c10::DistBackendError);

  // Communicators might be aborted here, further operations would fail.
}

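// Without blocking wait, a simulated NCCL error should not make wait() throw;
// the test only checks that the work completes.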
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) {
  auto options = c10d::ProcessGroupNCCL::Options::create();
  options->timeout = std::chrono::milliseconds(3000);
  ProcessGroupNCCLSimulateErrors pg(store_, 0, 1, options);

  auto work = pg.allreduce(tensors_);
  pg.barrier()->wait();
  EXPECT_EQ(1, pg.getNCCLCommCacheSize());

  // Now run all reduce with errors.
  pg.simulateError();
  work = pg.allreduce(tensors_);

  // Should not throw exceptions.
  work->wait();
  pg.barrier()->wait();

  EXPECT_TRUE(work->isCompleted());
  // Communicators might be aborted here, further operations would fail.
}

// Function to read back what we wrote to local disk, for validation.
std::string readTraceFromFile(const std::string& filename, size_t size) {
  std::ifstream file(filename, std::ios::binary);
  // Read the string from the file while the stream is in a good state.
  if (file) {
    std::string str(size, '\0');
    file.read(&str[0], size);
    if (file) {
      return str;
    }
  }
  return "";
}

// Extends c10d::DebugInfoWriter so that the NCCL trace is also kept in memory
// for validation against what was written to disk.
class TestDebugInfoWriter : public c10d::DebugInfoWriter {
 public:
  TestDebugInfoWriter(std::string namePrefix)
      : DebugInfoWriter(namePrefix, 0) {}

  void write(const std::string& ncclTrace) override {
    traces_.assign(ncclTrace.begin(), ncclTrace.end());
    c10d::DebugInfoWriter::write(ncclTrace);
  }

  std::vector<uint8_t>& getTraces() {
    return traces_;
  }

 private:
  std::vector<uint8_t> traces_;
};

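// Blocks the watchdog by holding its lock so its heartbeat stops, then checks
// that the heartbeat monitor fired and that the flight-recorder dump written
// to disk matches the in-memory trace.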
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
  int heartBeatIntervalInSec = 2;
  std::string timeInterval = std::to_string(heartBeatIntervalInSec);
  ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0);
  ASSERT_TRUE(
      setenv(
          c10d::TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC[0].c_str(),
          timeInterval.c_str(),
          1) == 0);
  ASSERT_TRUE(
      setenv(c10d::TORCH_NCCL_ENABLE_MONITORING[0].c_str(), "1", 1) == 0);
  auto tempFilename = c10::str(
      std::filesystem::temp_directory_path().string(), "/nccl_trace_rank_");
  ASSERT_TRUE(
      setenv("TORCH_NCCL_DEBUG_INFO_TEMP_FILE", tempFilename.c_str(), 1) == 0);
  // Enable the NCCL flight recorder.
  ASSERT_TRUE(setenv("TORCH_NCCL_TRACE_BUFFER_SIZE", "10", 1) == 0);
  ASSERT_TRUE(setenv(c10d::TORCH_NCCL_DUMP_ON_TIMEOUT[0].c_str(), "1", 1) == 0);
  auto options = c10d::ProcessGroupNCCL::Options::create();
  // Set a long watchdog timeout so that we have enough time to lock the
  // watchdog and let the heartbeat monitor thread kick in.
  options->timeout = std::chrono::milliseconds(30000);
  ProcessGroupNCCLNoHeartbeatCaught pg(store_, 0, 1, options);
  // The writer here is very similar to the fallback writer; the only
  // difference is that we also keep the traces in memory for validation.
  std::string fileNamePrefix = c10d::getCvarString(
      {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
  std::unique_ptr<TestDebugInfoWriter> writerForTestPtr =
      std::make_unique<TestDebugInfoWriter>(fileNamePrefix);
  std::vector<uint8_t>& traces = writerForTestPtr->getTraces();
  c10d::DebugInfoWriter::registerWriter(std::move(writerForTestPtr));

  // Normal collective case.
  auto work = pg.allreduce(tensors_);
  work->wait();

  work = pg.allreduce(tensors_);
  {
    // Block the watchdog thread by holding its lock.
    std::lock_guard<std::mutex> lock(pg.getWatchdogMutex());
    LOG(INFO) << "Lock watchdog thread.";
    // Wait long enough for the monitor thread to throw its exception.
    std::this_thread::sleep_for(
        std::chrono::seconds(heartBeatIntervalInSec * 3));
    // Check that the monitoring thread launched and the exception was thrown.
    EXPECT_TRUE(pg.getErrorCaughtFlag());
  }
  work->wait();
  EXPECT_TRUE(traces.size() > 0);
  auto filename = c10::str(tempFilename, 0);
  auto traceFromStorage = readTraceFromFile(filename, traces.size());
  // Check that the traces read from storage match the original NCCL trace.
  EXPECT_TRUE(traceFromStorage == std::string(traces.begin(), traces.end()));
  std::filesystem::remove(filename);
}

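// Fixture for the watchdog-timeout tests (currently skipped, see SetUp):
// configures a short heartbeat timeout, enables monitoring and desync debug,
// and disables async error handling so the watchdog does not throw.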
class ProcessGroupNCCLWatchdogTimeoutTest : public ProcessGroupNCCLErrorsTest {
 protected:
  void SetUp() override {
    // TODO (kwen2501)
    GTEST_SKIP() << "Skipping tests under ProcessGroupNCCLWatchdogTimeoutTest; "
                 << "will rewrite them after refactoring Work queues.";
    ProcessGroupNCCLErrorsTest::SetUp();
    std::string timeInterval = std::to_string(heartBeatIntervalInSec);
    ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0);
    ASSERT_TRUE(
        setenv(
            c10d::TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC[0].c_str(),
            timeInterval.c_str(),
            1) == 0);
    ASSERT_TRUE(
        setenv(c10d::TORCH_NCCL_ENABLE_MONITORING[0].c_str(), "1", 1) == 0);
    ASSERT_TRUE(setenv(c10d::TORCH_NCCL_DESYNC_DEBUG[0].c_str(), "1", 1) == 0);
    // We cannot capture an exception thrown in the watchdog thread without
    // making lots of changes to the code, so we don't let the watchdog throw.
    ASSERT_TRUE(
        setenv(c10d::TORCH_NCCL_ASYNC_ERROR_HANDLING[0].c_str(), "0", 1) == 0);
    options_ = c10d::ProcessGroupNCCL::Options::create();
    // Set a super short watchdog timeout.
    options_->timeout = std::chrono::milliseconds(100);
  }

  void watchdogTimeoutTestCommon(
      ProcessGroupNCCLNoHeartbeatCaught& pg,
      int multiplier) {
    pg.forceSetDesyncDebugFlag();
    pg.setTimedoutError();
    auto work = pg.allreduce(tensors_);
    std::this_thread::sleep_for(
        std::chrono::seconds(heartBeatIntervalInSec * multiplier));
    EXPECT_THROW(work->wait(), c10::DistBackendError);
  }

  const int heartBeatIntervalInSec = 2;
  c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> options_;
};

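// Checks that the heartbeat monitor does not abort the process when debug-info
// collection finishes in time.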
TEST_F(ProcessGroupNCCLWatchdogTimeoutTest, testNCCLTimedoutDebugInfoFinished) {
  ProcessGroupNCCLNoHeartbeatCaught pg(store_, 0, 1, options_);
  // Writing debug info makes the watchdog thread wait for 30 seconds, which is
  // hard to override, so we just call it beforehand. Otherwise we would need
  // to set a long heartbeat timeout, which would make the test much slower.
  pg.forceTryWriteDebugInfo();
  watchdogTimeoutTestCommon(pg, 2);

  // This flag being true shows that the heartbeat monitor thread does not kill
  // the watchdog thread while it is collecting debug info (such as desync
  // debug info).
  EXPECT_TRUE(pg.getWatchDogDebugInfoFinishedFlag());
  // This flag being false shows that the heartbeat monitor thread does not
  // trigger a process abort if collecting debug info and destroying the PG is
  // fast.
  EXPECT_FALSE(pg.getErrorCaughtFlag());

  // Communicators might be aborted here, further operations would fail.
}

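// Checks that the heartbeat monitor does trigger the (overridden) process
// abort when debug-info collection gets stuck.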
TEST_F(ProcessGroupNCCLWatchdogTimeoutTest, testNCCLTimedoutDebugInfoStuck) {
  ProcessGroupNCCLDebugInfoStuck pg(store_, 0, 1, options_);
  // Keep the main thread asleep longer so that the heartbeat monitor thread
  // can finish its extra wait and flip the flag.
  watchdogTimeoutTestCommon(pg, 4);
  // This flag being false shows that the watchdog thread got stuck while
  // collecting debug info (such as desync debug info).
  EXPECT_FALSE(pg.getWatchDogDebugInfoFinishedFlag());
  // This flag being true shows that the heartbeat monitor thread does trigger
  // a process abort if collecting debug info gets stuck.
  EXPECT_TRUE(pg.getErrorCaughtFlag());

  // Communicators might be aborted here, further operations would fail.
}
509