xref: /aosp_15_r20/external/pytorch/test/cpp/c10d/TestUtils.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #ifndef _WIN32
4 #include <signal.h>
5 #include <sys/wait.h>
6 #include <unistd.h>
7 #endif
8 
9 #include <sys/types.h>
10 #include <cstring>
11 
12 #include <condition_variable>
13 #include <mutex>
14 #include <string>
15 #include <system_error>
16 #include <vector>
17 
18 namespace c10d {
19 namespace test {
20 
21 class Semaphore {
22  public:
post(int n=1)23   void post(int n = 1) {
24     std::unique_lock<std::mutex> lock(m_);
25     n_ += n;
26     cv_.notify_all();
27   }
28 
wait(int n=1)29   void wait(int n = 1) {
30     std::unique_lock<std::mutex> lock(m_);
31     while (n_ < n) {
32       cv_.wait(lock);
33     }
34     n_ -= n;
35   }
36 
37  protected:
38   int n_ = 0;
39   std::mutex m_;
40   std::condition_variable cv_;
41 };
42 
43 #ifdef _WIN32
autoGenerateTmpFilePath()44 std::string autoGenerateTmpFilePath() {
45   char tmp[L_tmpnam_s];
46   errno_t err;
47   err = tmpnam_s(tmp, L_tmpnam_s);
48   if (err != 0)
49   {
50     throw std::system_error(errno, std::system_category());
51   }
52   return std::string(tmp);
53 }
54 
tmppath()55 std::string tmppath() {
56   const char* tmpfile = getenv("TMPFILE");
57   if (tmpfile) {
58     return std::string(tmpfile);
59   }
60   else {
61     return autoGenerateTmpFilePath();
62   }
63 }
64 #else
tmppath()65 std::string tmppath() {
66   // TMPFILE is for manual test execution during which the user will specify
67   // the full temp file path using the environmental variable TMPFILE
68   const char* tmpfile = getenv("TMPFILE");
69   if (tmpfile) {
70     return std::string(tmpfile);
71   }
72 
73   const char* tmpdir = getenv("TMPDIR");
74   if (tmpdir == nullptr) {
75     tmpdir = "/tmp";
76   }
77 
78   // Create template
79   std::vector<char> tmp(256);
80   auto len = snprintf(tmp.data(), tmp.size(), "%s/testXXXXXX", tmpdir);
81   tmp.resize(len);
82 
83   // Create temporary file
84   auto fd = mkstemp(&tmp[0]);
85   if (fd == -1) {
86     throw std::system_error(errno, std::system_category());
87   }
88   close(fd);
89   return std::string(tmp.data(), tmp.size());
90 }
91 #endif
92 
isTSANEnabled()93 bool isTSANEnabled() {
94   auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
95   return s && strcmp(s, "1") == 0;
96 }
97 struct TemporaryFile {
98   std::string path;
99 
TemporaryFilec10d::test::TemporaryFile100   TemporaryFile() {
101     path = tmppath();
102   }
103 
~TemporaryFilec10d::test::TemporaryFile104   ~TemporaryFile() {
105     unlink(path.c_str());
106   }
107 };
108 
109 #ifndef _WIN32
110 struct Fork {
111   pid_t pid;
112 
Forkc10d::test::Fork113   Fork() {
114     pid = fork();
115     if (pid < 0) {
116       throw std::system_error(errno, std::system_category(), "fork");
117     }
118   }
119 
~Forkc10d::test::Fork120   ~Fork() {
121     if (pid > 0) {
122       kill(pid, SIGKILL);
123       waitpid(pid, nullptr, 0);
124     }
125   }
126 
isChildc10d::test::Fork127   bool isChild() {
128     return pid == 0;
129   }
130 };
131 #endif
132 
133 } // namespace test
134 } // namespace c10d
135