xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/socket.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Copyright (c) Meta Platforms, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // This source code is licensed under the BSD-style license found in the
5 // LICENSE file in the root directory of this source tree.
6 
7 #include <torch/csrc/distributed/c10d/socket.h>
8 
#include <algorithm>
#include <chrono>
#include <climits>
#include <cstring>
#include <optional>
#include <system_error>
#include <utility>
#include <vector>

#ifdef _WIN32
#include <mutex>
#include <thread>

#include <winsock2.h>
#include <ws2tcpip.h>
#else
#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/tcp.h>
#include <poll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#endif
30 
31 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated")
32 #include <fmt/chrono.h>
33 C10_DIAGNOSTIC_POP()
34 #include <fmt/format.h>
35 #include <fmt/ranges.h>
36 
37 #include <torch/csrc/distributed/c10d/error.h>
38 #include <torch/csrc/distributed/c10d/exception.h>
39 #include <torch/csrc/distributed/c10d/logging.h>
40 
41 #include <c10/util/CallOnce.h>
42 
43 namespace c10d::detail {
44 namespace {
45 #ifdef _WIN32
46 
47 // Since Winsock uses the name `WSAPoll` instead of `poll`, we alias it here
48 // to avoid #ifdefs in the source code.
49 const auto pollFd = ::WSAPoll;
50 
51 // Winsock's `getsockopt()` and `setsockopt()` functions expect option values to
52 // be passed as `char*` instead of `void*`. We wrap them here to avoid redundant
53 // casts in the source code.
getSocketOption(SOCKET s,int level,int optname,void * optval,int * optlen)54 int getSocketOption(
55     SOCKET s,
56     int level,
57     int optname,
58     void* optval,
59     int* optlen) {
60   return ::getsockopt(s, level, optname, static_cast<char*>(optval), optlen);
61 }
62 
setSocketOption(SOCKET s,int level,int optname,const void * optval,int optlen)63 int setSocketOption(
64     SOCKET s,
65     int level,
66     int optname,
67     const void* optval,
68     int optlen) {
69   return ::setsockopt(
70       s, level, optname, static_cast<const char*>(optval), optlen);
71 }
72 
73 // Winsock has its own error codes which differ from Berkeley's. Fortunately the
74 // C++ Standard Library on Windows can map them to standard error codes.
getSocketError()75 inline std::error_code getSocketError() noexcept {
76   return std::error_code{::WSAGetLastError(), std::system_category()};
77 }
78 
setSocketError(int val)79 inline void setSocketError(int val) noexcept {
80   ::WSASetLastError(val);
81 }
82 
83 #else
84 
85 const auto pollFd = ::poll;
86 
87 const auto getSocketOption = ::getsockopt;
88 const auto setSocketOption = ::setsockopt;
89 
90 inline std::error_code getSocketError() noexcept {
91   return lastError();
92 }
93 
94 inline void setSocketError(int val) noexcept {
95   errno = val;
96 }
97 
98 #endif
99 
100 // Suspends the current thread for the specified duration.
delay(std::chrono::milliseconds d)101 void delay(std::chrono::milliseconds d) {
102 #ifdef _WIN32
103   std::this_thread::sleep_for(d);
104 #else
105   ::timespec req{};
106   auto ms = d.count();
107   req.tv_sec = ms / 1000;
108   req.tv_nsec = (ms % 1000) * 1000000;
109 
110   // The C++ Standard does not specify whether `sleep_for()` should be signal-
111   // aware; therefore, we use the `nanosleep()` syscall.
112   if (::nanosleep(&req, nullptr) != 0) {
113     std::error_code err = getSocketError();
114     // We don't care about error conditions other than EINTR since a failure
115     // here is not critical.
116     if (err == std::errc::interrupted) {
117       C10_THROW_ERROR(DistNetworkError, std::strerror(err.value()));
118     }
119   }
120 #endif
121 }
122 
123 class SocketListenOp;
124 class SocketConnectOp;
125 } // namespace
126 
// Owning wrapper around a native socket handle. The handle is released by the
// destructor. Copy and move are deleted, so ownership is never shared or
// transferred implicitly.
class SocketImpl {
  // The listen/connect operations manipulate the socket directly.
  friend class SocketListenOp;
  friend class SocketConnectOp;

 public:
#ifdef _WIN32
  using Handle = SOCKET;
  static constexpr Handle invalid_socket = INVALID_SOCKET;
#else
  using Handle = int;
  static constexpr Handle invalid_socket = -1;
#endif

  // Takes ownership of `hnd`; no remote address is recorded.
  explicit SocketImpl(Handle hnd) noexcept : hnd_{hnd} {}

  // Takes ownership of `hnd` and records a printable form of `remote`.
  explicit SocketImpl(Handle hnd, const ::addrinfo& remote);

  SocketImpl(const SocketImpl& other) = delete;
  SocketImpl& operator=(const SocketImpl& other) = delete;
  SocketImpl(SocketImpl&& other) noexcept = delete;
  SocketImpl& operator=(SocketImpl&& other) noexcept = delete;

  ~SocketImpl();

  std::unique_ptr<SocketImpl> accept() const;

  // Socket option helpers.
  void closeOnExec() noexcept;
  void enableNonBlocking();
  void disableNonBlocking();
  bool enableNoDelay() noexcept;
  bool enableDualStack() noexcept;
#ifdef _WIN32
  bool enableExclusiveAddressUse() noexcept;
#else
  bool enableAddressReuse() noexcept;
#endif

  std::uint16_t getPort() const;

  Handle handle() const noexcept {
    return hnd_;
  }

  const std::optional<std::string>& remote() const noexcept {
    return remote_;
  }

  bool waitForInput(std::chrono::milliseconds timeout);

 private:
  bool setSocketFlag(int level, int optname, bool value) noexcept;

  Handle hnd_;
  const std::optional<std::string> remote_;
};
196 } // namespace c10d::detail
197 
198 //
// libfmt formatters for `::addrinfo` and `c10d::detail::SocketImpl`
200 //
201 namespace fmt {
202 
203 template <>
204 struct formatter<::addrinfo> {
parsefmt::formatter205   constexpr decltype(auto) parse(format_parse_context& ctx) const {
206     return ctx.begin();
207   }
208 
209   template <typename FormatContext>
formatfmt::formatter210   decltype(auto) format(const ::addrinfo& addr, FormatContext& ctx) const {
211     char host[NI_MAXHOST], port[NI_MAXSERV]; // NOLINT
212 
213     int r = ::getnameinfo(
214         addr.ai_addr,
215         addr.ai_addrlen,
216         host,
217         NI_MAXHOST,
218         port,
219         NI_MAXSERV,
220         NI_NUMERICSERV);
221     if (r != 0) {
222       // if we can't resolve the hostname, display the IP address
223       if (addr.ai_family == AF_INET) {
224         struct sockaddr_in* psai = (struct sockaddr_in*)addr.ai_addr;
225         char ip[INET_ADDRSTRLEN];
226         if (inet_ntop(addr.ai_family, &(psai->sin_addr), ip, INET_ADDRSTRLEN) !=
227             NULL) {
228           return fmt::format_to(ctx.out(), "{}:{}", ip, psai->sin_port);
229         }
230       } else if (addr.ai_family == AF_INET6) {
231         struct sockaddr_in6* psai = (struct sockaddr_in6*)addr.ai_addr;
232         char ip[INET6_ADDRSTRLEN];
233         if (inet_ntop(
234                 addr.ai_family, &(psai->sin6_addr), ip, INET6_ADDRSTRLEN) !=
235             NULL) {
236           return fmt::format_to(ctx.out(), "[{}]:{}", ip, psai->sin6_port);
237         }
238       }
239       C10_THROW_ERROR(
240           DistNetworkError,
241           fmt::format(
242               "failed to format addr, unknown family={}", addr.ai_family));
243     }
244 
245     if (addr.ai_addr->sa_family == AF_INET) {
246       return fmt::format_to(ctx.out(), "{}:{}", host, port);
247     } else {
248       return fmt::format_to(ctx.out(), "[{}]:{}", host, port);
249     }
250   }
251 };
252 
253 template <>
254 struct formatter<c10d::detail::SocketImpl> {
parsefmt::formatter255   constexpr decltype(auto) parse(format_parse_context& ctx) const {
256     return ctx.begin();
257   }
258 
259   template <typename FormatContext>
formatfmt::formatter260   decltype(auto) format(
261       const c10d::detail::SocketImpl& socket,
262       FormatContext& ctx) const {
263     ::sockaddr_storage addr_s{};
264 
265     auto addr_ptr = reinterpret_cast<::sockaddr*>(&addr_s);
266 
267     ::socklen_t addr_len = sizeof(addr_s);
268 
269     auto fd = socket.handle();
270 
271     if (::getsockname(fd, addr_ptr, &addr_len) != 0) {
272       return fmt::format_to(ctx.out(), "?UNKNOWN?");
273     }
274 
275     ::addrinfo addr{};
276     addr.ai_addr = addr_ptr;
277     addr.ai_addrlen = addr_len;
278 
279     auto remote = socket.remote();
280     std::string remoteStr = remote ? *remote : "none";
281 
282     return fmt::format_to(
283         ctx.out(),
284         "SocketImpl(fd={}, addr={}, remote={})",
285         fd,
286         addr,
287         remoteStr);
288   }
289 };
290 
291 } // namespace fmt
292 
293 namespace c10d::detail {
294 
SocketImpl(Handle hnd,const::addrinfo & remote)295 SocketImpl::SocketImpl(Handle hnd, const ::addrinfo& remote)
296     : hnd_{hnd}, remote_{fmt::format("{}", remote)} {}
297 
~SocketImpl()298 SocketImpl::~SocketImpl() {
299 #ifdef _WIN32
300   ::closesocket(hnd_);
301 #else
302   ::close(hnd_);
303 #endif
304 }
305 
accept() const306 std::unique_ptr<SocketImpl> SocketImpl::accept() const {
307   ::sockaddr_storage addr_s{};
308 
309   auto addr_ptr = reinterpret_cast<::sockaddr*>(&addr_s);
310 
311   ::socklen_t addr_len = sizeof(addr_s);
312 
313   Handle hnd = ::accept(hnd_, addr_ptr, &addr_len);
314   if (hnd == invalid_socket) {
315     std::error_code err = getSocketError();
316     if (err == std::errc::interrupted) {
317       C10_THROW_ERROR(DistNetworkError, std::strerror(err.value()));
318     }
319 
320     std::string msg{};
321     if (err == std::errc::invalid_argument) {
322       msg = fmt::format(
323           "The server socket on {} is not listening for connections.", *this);
324     } else {
325       msg = fmt::format(
326           "The server socket on {} has failed to accept a connection {}.",
327           *this,
328           err);
329     }
330 
331     C10D_ERROR(msg);
332 
333     C10D_THROW_ERROR(SocketError, msg);
334   }
335 
336   ::addrinfo addr{};
337   addr.ai_addr = addr_ptr;
338   addr.ai_addrlen = addr_len;
339 
340   C10D_DEBUG(
341       "The server socket on {} has accepted a connection from {}.",
342       *this,
343       addr);
344 
345   auto impl = std::make_unique<SocketImpl>(hnd, addr);
346 
347   // Make sure that we do not "leak" our file descriptors to child processes.
348   impl->closeOnExec();
349 
350   if (!impl->enableNoDelay()) {
351     C10D_WARNING(
352         "The no-delay option cannot be enabled for the client socket on {}.",
353         addr);
354   }
355 
356   return impl;
357 }
358 
closeOnExec()359 void SocketImpl::closeOnExec() noexcept {
360 #ifndef _WIN32
361   ::fcntl(hnd_, F_SETFD, FD_CLOEXEC);
362 #endif
363 }
364 
enableNonBlocking()365 void SocketImpl::enableNonBlocking() {
366 #ifdef _WIN32
367   unsigned long value = 1;
368   if (::ioctlsocket(hnd_, FIONBIO, &value) == 0) {
369     return;
370   }
371 #else
372   int flg = ::fcntl(hnd_, F_GETFL);
373   if (flg != -1) {
374     if (::fcntl(hnd_, F_SETFL, flg | O_NONBLOCK) == 0) {
375       return;
376     }
377   }
378 #endif
379   C10D_THROW_ERROR(
380       SocketError, "The socket cannot be switched to non-blocking mode.");
381 }
382 
383 // TODO: Remove once we migrate everything to non-blocking mode.
disableNonBlocking()384 void SocketImpl::disableNonBlocking() {
385 #ifdef _WIN32
386   unsigned long value = 0;
387   if (::ioctlsocket(hnd_, FIONBIO, &value) == 0) {
388     return;
389   }
390 #else
391   int flg = ::fcntl(hnd_, F_GETFL);
392   if (flg != -1) {
393     if (::fcntl(hnd_, F_SETFL, flg & ~O_NONBLOCK) == 0) {
394       return;
395     }
396   }
397 #endif
398   C10D_THROW_ERROR(
399       SocketError, "The socket cannot be switched to blocking mode.");
400 }
401 
enableNoDelay()402 bool SocketImpl::enableNoDelay() noexcept {
403   return setSocketFlag(IPPROTO_TCP, TCP_NODELAY, true);
404 }
405 
enableDualStack()406 bool SocketImpl::enableDualStack() noexcept {
407   return setSocketFlag(IPPROTO_IPV6, IPV6_V6ONLY, false);
408 }
409 
410 #ifndef _WIN32
enableAddressReuse()411 bool SocketImpl::enableAddressReuse() noexcept {
412   return setSocketFlag(SOL_SOCKET, SO_REUSEADDR, true);
413 }
414 #endif
415 
416 #ifdef _WIN32
enableExclusiveAddressUse()417 bool SocketImpl::enableExclusiveAddressUse() noexcept {
418   return setSocketFlag(SOL_SOCKET, SO_EXCLUSIVEADDRUSE, true);
419 }
420 #endif
421 
getPort() const422 std::uint16_t SocketImpl::getPort() const {
423   ::sockaddr_storage addr_s{};
424 
425   ::socklen_t addr_len = sizeof(addr_s);
426 
427   if (::getsockname(hnd_, reinterpret_cast<::sockaddr*>(&addr_s), &addr_len) !=
428       0) {
429     C10D_THROW_ERROR(
430         SocketError, "The port number of the socket cannot be retrieved.");
431   }
432 
433   if (addr_s.ss_family == AF_INET) {
434     return ntohs(reinterpret_cast<::sockaddr_in*>(&addr_s)->sin_port);
435   } else {
436     return ntohs(reinterpret_cast<::sockaddr_in6*>(&addr_s)->sin6_port);
437   }
438 }
439 
setSocketFlag(int level,int optname,bool value)440 bool SocketImpl::setSocketFlag(int level, int optname, bool value) noexcept {
441 #ifdef _WIN32
442   auto buf = value ? TRUE : FALSE;
443 #else
444   auto buf = value ? 1 : 0;
445 #endif
446   return setSocketOption(hnd_, level, optname, &buf, sizeof(buf)) == 0;
447 }
448 
waitForInput(std::chrono::milliseconds timeout)449 bool SocketImpl::waitForInput(std::chrono::milliseconds timeout) {
450   using Clock = std::chrono::steady_clock;
451 
452   auto deadline = Clock::now() + timeout;
453   do {
454     ::pollfd pfd{};
455     pfd.fd = hnd_;
456     pfd.events = POLLIN;
457 
458     int res = pollFd(&pfd, 1, static_cast<int>(timeout.count()));
459     if (res > 0) {
460       return true;
461     } else if (res == 0) {
462       C10D_WARNING(
463           "waitForInput: poll for socket {} returned 0, likely a timeout",
464           *this);
465       continue;
466     }
467 
468     std::error_code err = getSocketError();
469     if (err == std::errc::operation_in_progress) {
470       bool timedout = Clock::now() >= deadline;
471       if (timedout) {
472         return false;
473       }
474       C10D_WARNING(
475           "waitForInput: poll for socket {} returned operation_in_progress before a timeout",
476           *this);
477     } else if (err != std::errc::interrupted) {
478       C10D_WARNING(
479           "waitForInput: poll for socket {} failed with res={}, err={}.",
480           *this,
481           res,
482           err);
483       return false;
484     }
485   } while (Clock::now() < deadline);
486 
487   C10D_WARNING(
488       "waitForInput: socket {} timed out after {}ms", *this, timeout.count());
489   return false;
490 }
491 
492 namespace {
493 
// Deleter that hands an addrinfo list from getaddrinfo() back to
// freeaddrinfo().
struct addrinfo_delete {
  void operator()(::addrinfo* list) const noexcept {
    ::freeaddrinfo(list);
  }
};

// Owning handle for the head of a getaddrinfo() result list.
using addrinfo_ptr = std::unique_ptr<::addrinfo, addrinfo_delete>;
501 
502 class SocketListenOp {
503  public:
504   SocketListenOp(std::uint16_t port, const SocketOptions& opts);
505 
506   std::unique_ptr<SocketImpl> run();
507 
508  private:
509   bool tryListen(int family);
510 
511   bool tryListen(const ::addrinfo& addr);
512 
513   template <typename... Args>
514   // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
recordError(fmt::string_view format,Args &&...args)515   void recordError(fmt::string_view format, Args&&... args) {
516     auto msg = fmt::vformat(format, fmt::make_format_args(args...));
517 
518     C10D_WARNING(msg);
519 
520     errors_.emplace_back(std::move(msg));
521   }
522 
523   std::string port_;
524   const SocketOptions* opts_;
525   std::vector<std::string> errors_{};
526   std::unique_ptr<SocketImpl> socket_{};
527 };
528 
SocketListenOp(std::uint16_t port,const SocketOptions & opts)529 SocketListenOp::SocketListenOp(std::uint16_t port, const SocketOptions& opts)
530     : port_{fmt::to_string(port)}, opts_{&opts} {}
531 
run()532 std::unique_ptr<SocketImpl> SocketListenOp::run() {
533   if (opts_->prefer_ipv6()) {
534     C10D_DEBUG("The server socket will attempt to listen on an IPv6 address.");
535     if (tryListen(AF_INET6)) {
536       return std::move(socket_);
537     }
538 
539     C10D_DEBUG("The server socket will attempt to listen on an IPv4 address.");
540     if (tryListen(AF_INET)) {
541       return std::move(socket_);
542     }
543   } else {
544     C10D_DEBUG(
545         "The server socket will attempt to listen on an IPv4 or IPv6 address.");
546     if (tryListen(AF_UNSPEC)) {
547       return std::move(socket_);
548     }
549   }
550 
551   constexpr auto* msg =
552       "The server socket has failed to listen on any local network address.";
553 
554   C10D_ERROR(msg);
555 
556   C10D_THROW_ERROR(
557       SocketError, fmt::format("{} {}", msg, fmt::join(errors_, " ")));
558 }
559 
// Resolves the passive (wildcard) local addresses of `family` for port_ and
// tries to listen on each in turn. Returns true as soon as one succeeds.
bool SocketListenOp::tryListen(int family) {
  ::addrinfo hints{};
  hints.ai_flags = AI_PASSIVE | AI_NUMERICSERV;
  hints.ai_family = family;
  hints.ai_socktype = SOCK_STREAM;

  ::addrinfo* naked_result = nullptr;
  int r = ::getaddrinfo(nullptr, port_.c_str(), &hints, &naked_result);
  if (r != 0) {
    const char* familyLabel = "";
    if (family == AF_INET) {
      familyLabel = "IPv4 ";
    } else if (family == AF_INET6) {
      familyLabel = "IPv6 ";
    }

    recordError(
        "The local {}network addresses cannot be retrieved (gai error: {} - {}).",
        familyLabel,
        r,
        ::gai_strerror(r));

    return false;
  }

  // Frees the whole result list when leaving this scope.
  const addrinfo_ptr result{naked_result};

  for (const ::addrinfo* addr = result.get(); addr != nullptr;
       addr = addr->ai_next) {
    C10D_DEBUG("The server socket is attempting to listen on {}.", *addr);
    if (tryListen(*addr)) {
      return true;
    }
  }

  return false;
}
593 
tryListen(const::addrinfo & addr)594 bool SocketListenOp::tryListen(const ::addrinfo& addr) {
595   SocketImpl::Handle hnd =
596       ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol);
597   if (hnd == SocketImpl::invalid_socket) {
598     recordError(
599         "The server socket cannot be initialized on {} {}.",
600         addr,
601         getSocketError());
602 
603     return false;
604   }
605 
606   socket_ = std::make_unique<SocketImpl>(hnd);
607 
608 #ifndef _WIN32
609   if (!socket_->enableAddressReuse()) {
610     C10D_WARNING(
611         "The address reuse option cannot be enabled for the server socket on {}.",
612         addr);
613   }
614 #endif
615 
616 #ifdef _WIN32
617   // The SO_REUSEADDR flag has a significantly different behavior on Windows
618   // compared to Unix-like systems. It allows two or more processes to share
619   // the same port simultaneously, which is totally unsafe.
620   //
621   // Here we follow the recommendation of Microsoft and use the non-standard
622   // SO_EXCLUSIVEADDRUSE flag instead.
623   if (!socket_->enableExclusiveAddressUse()) {
624     C10D_WARNING(
625         "The exclusive address use option cannot be enabled for the server socket on {}.",
626         addr);
627   }
628 #endif
629 
630   // Not all operating systems support dual-stack sockets by default. Since we
631   // wish to use our IPv6 socket for IPv4 communication as well, we explicitly
632   // ask the system to enable it.
633   if (addr.ai_family == AF_INET6 && !socket_->enableDualStack()) {
634     C10D_WARNING(
635         "The server socket does not support IPv4 communication on {}.", addr);
636   }
637 
638   if (::bind(socket_->handle(), addr.ai_addr, addr.ai_addrlen) != 0) {
639     recordError(
640         "The server socket has failed to bind to {} {}.",
641         addr,
642         getSocketError());
643 
644     return false;
645   }
646 
647   // NOLINTNEXTLINE(bugprone-argument-comment)
648   if (::listen(socket_->handle(), -1 /* backlog */) != 0) {
649     recordError(
650         "The server socket has failed to listen on {} {}.",
651         addr,
652         getSocketError());
653 
654     return false;
655   }
656 
657   socket_->closeOnExec();
658 
659   C10D_INFO("The server socket has started to listen on {}.", addr);
660 
661   return true;
662 }
663 
664 class SocketListenFromFdOp {
665  public:
666   SocketListenFromFdOp(int fd, std::uint16_t expected_port);
667 
668   std::unique_ptr<SocketImpl> run() const;
669 
670  private:
671   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
672   const int fd_;
673   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
674   const std::uint16_t expected_port_;
675 };
676 
SocketListenFromFdOp(int fd,std::uint16_t expected_port)677 SocketListenFromFdOp::SocketListenFromFdOp(int fd, std::uint16_t expected_port)
678     : fd_(fd), expected_port_(expected_port) {}
679 
run() const680 std::unique_ptr<SocketImpl> SocketListenFromFdOp::run() const {
681   C10D_DEBUG("listenFromFd: fd {}, expected port {}", fd_, expected_port_);
682 
683   ::sockaddr_storage addr_storage{};
684   ::socklen_t addr_len = sizeof(addr_storage);
685   if (::getsockname(
686           fd_, reinterpret_cast<::sockaddr*>(&addr_storage), &addr_len) < 0) {
687     C10D_THROW_ERROR(
688         SocketError,
689         fmt::format("getsockname failed for fd {}: {}", fd_, getSocketError()));
690   }
691 
692   auto socket = std::make_unique<SocketImpl>(fd_);
693   const auto port = socket->getPort();
694 
695   if (port != expected_port_) {
696     C10D_THROW_ERROR(
697         SocketError,
698         fmt::format(
699             "listen fd {} is bound to port {}, expected to be bound to port {}",
700             fd_,
701             port,
702             expected_port_));
703   }
704 
705   if (::listen(socket->handle(), -1 /* backlog */) != 0) {
706     C10D_THROW_ERROR(
707         SocketError,
708         fmt::format(
709             "Failed to listen on socket initialized from fd {}: {}.",
710             socket->handle(),
711             getSocketError()));
712   }
713 
714   socket->closeOnExec();
715 
716   C10D_INFO(
717       "The server has taken over the listening socket with fd {}, address {}",
718       fd_,
719       *socket);
720   return socket;
721 }
722 
723 class SocketConnectOp {
724   using Clock = std::chrono::steady_clock;
725   using Duration = std::chrono::steady_clock::duration;
726   using TimePoint = std::chrono::time_point<std::chrono::steady_clock>;
727 
728   enum class ConnectResult : uint8_t { Success, Error, Retry };
729 
730  public:
731   SocketConnectOp(
732       const std::string& host,
733       std::uint16_t port,
734       const SocketOptions& opts);
735 
736   std::unique_ptr<SocketImpl> run();
737 
738  private:
739   bool tryConnect(int family);
740 
741   ConnectResult tryConnect(const ::addrinfo& addr);
742 
743   ConnectResult tryConnectCore(const ::addrinfo& addr);
744 
745   [[noreturn]] void throwTimeoutError() const;
746 
747   template <typename... Args>
748   // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
recordError(fmt::string_view format,Args &&...args)749   void recordError(fmt::string_view format, Args&&... args) {
750     auto msg = fmt::vformat(format, fmt::make_format_args(args...));
751 
752     C10D_WARNING(msg);
753 
754     errors_.emplace_back(std::move(msg));
755   }
756 
757   const char* host_;
758   std::string port_;
759   const SocketOptions* opts_;
760   TimePoint deadline_{};
761   std::vector<std::string> errors_{};
762   std::unique_ptr<SocketImpl> socket_{};
763 };
764 
SocketConnectOp(const std::string & host,std::uint16_t port,const SocketOptions & opts)765 SocketConnectOp::SocketConnectOp(
766     const std::string& host,
767     std::uint16_t port,
768     const SocketOptions& opts)
769     : host_{host.c_str()}, port_{fmt::to_string(port)}, opts_{&opts} {}
770 
run()771 std::unique_ptr<SocketImpl> SocketConnectOp::run() {
772   if (opts_->prefer_ipv6()) {
773     C10D_DEBUG(
774         "The client socket will attempt to connect to an IPv6 address of ({}, {}).",
775         host_,
776         port_);
777 
778     if (tryConnect(AF_INET6)) {
779       return std::move(socket_);
780     }
781 
782     C10D_DEBUG(
783         "The client socket will attempt to connect to an IPv4 address of ({}, {}).",
784         host_,
785         port_);
786 
787     if (tryConnect(AF_INET)) {
788       return std::move(socket_);
789     }
790   } else {
791     C10D_DEBUG(
792         "The client socket will attempt to connect to an IPv4 or IPv6 address of ({}, {}).",
793         host_,
794         port_);
795 
796     if (tryConnect(AF_UNSPEC)) {
797       return std::move(socket_);
798     }
799   }
800 
801   auto msg = fmt::format(
802       "The client socket has failed to connect to any network address of ({}, {}).",
803       host_,
804       port_);
805 
806   C10D_ERROR(msg);
807 
808   C10D_THROW_ERROR(
809       SocketError, fmt::format("{} {}", msg, fmt::join(errors_, " ")));
810 }
811 
// Connects to (host_, port_) using addresses of `family` (AF_INET, AF_INET6 or
// AF_UNSPEC).
//
// Each round resolves the host with getaddrinfo() and tries every returned
// address. If the resolution fails, or any attempt asks for a retry, the round
// is repeated after a backoff delay taken from the socket options. Returns
// true once connected, and false when a round ends with no retry requested.
// Throws TimeoutError when the next backoff would run past the deadline
// derived from connect_timeout().
bool SocketConnectOp::tryConnect(int family) {
  ::addrinfo hints{};
  hints.ai_flags = AI_V4MAPPED | AI_ALL | AI_NUMERICSERV;
  hints.ai_family = family;
  hints.ai_socktype = SOCK_STREAM;

  deadline_ = Clock::now() + opts_->connect_timeout();

  // Prefix used in the resolution error message, e.g. "The IPv4 network ...".
  const char* familyLabel = "";
  if (family == AF_INET) {
    familyLabel = "IPv4 ";
  } else if (family == AF_INET6) {
    familyLabel = "IPv6 ";
  }

  while (true) {
    bool retry = false;

    // Only the failures of the latest round are kept.
    errors_.clear();

    ::addrinfo* naked_result = nullptr;
    // patternlint-disable cpp-dns-deps
    int r = ::getaddrinfo(host_, port_.c_str(), &hints, &naked_result);
    if (r == 0) {
      addrinfo_ptr result{naked_result};

      for (::addrinfo* addr = naked_result; addr != nullptr;
           addr = addr->ai_next) {
        C10D_TRACE("The client socket is attempting to connect to {}.", *addr);

        ConnectResult cr = tryConnect(*addr);
        if (cr == ConnectResult::Success) {
          return true;
        }
        if (cr == ConnectResult::Retry) {
          retry = true;
        }
      }
    } else {
      // Name resolution failures are retried like refused connections.
      recordError(
          "The {}network addresses of ({}, {}) cannot be retrieved (gai error: {} - {}).",
          familyLabel,
          host_,
          port_,
          r,
          ::gai_strerror(r));
      retry = true;
    }

    if (!retry) {
      return false;
    }

    auto connectBackoff = opts_->connect_backoff();
    auto delayDuration = connectBackoff->nextBackoff();

    if (Clock::now() >= deadline_ - delayDuration) {
      throwTimeoutError();
    }

    // Prevent our log output to be too noisy, warn only every 30 seconds.
    static auto lastLog = std::chrono::steady_clock::now();
    auto now = std::chrono::steady_clock::now();
    if ((now - lastLog) >= std::chrono::seconds(30)) {
      C10D_INFO(
          "No socket on ({}, {}) is listening yet, will retry.",
          host_,
          port_);

      lastLog = now;
    }

    // Wait to avoid choking the server.
    delay(delayDuration);
  }
}
887 
tryConnect(const::addrinfo & addr)888 SocketConnectOp::ConnectResult SocketConnectOp::tryConnect(
889     const ::addrinfo& addr) {
890   if (Clock::now() >= deadline_) {
891     throwTimeoutError();
892   }
893 
894   SocketImpl::Handle hnd =
895       ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol);
896   if (hnd == SocketImpl::invalid_socket) {
897     recordError(
898         "The client socket cannot be initialized to connect to {} {}.",
899         addr,
900         getSocketError());
901 
902     return ConnectResult::Error;
903   }
904 
905   socket_ = std::make_unique<SocketImpl>(hnd, addr);
906 
907   socket_->enableNonBlocking();
908 
909   ConnectResult cr = tryConnectCore(addr);
910   if (cr == ConnectResult::Error) {
911     std::error_code err = getSocketError();
912     if (err == std::errc::interrupted) {
913       C10_THROW_ERROR(DistNetworkError, std::strerror(err.value()));
914     }
915 
916     // Retry if the server is not yet listening or if its backlog is exhausted.
917     if (err == std::errc::connection_refused ||
918         err == std::errc::connection_reset) {
919       C10D_TRACE(
920           "The server socket on {} is not yet listening {}, will retry.",
921           addr,
922           err);
923 
924       return ConnectResult::Retry;
925     } else {
926       recordError(
927           "The client socket has failed to connect to {} {}.", addr, err);
928 
929       return ConnectResult::Error;
930     }
931   }
932 
933   socket_->closeOnExec();
934 
935   // TODO: Remove once we fully migrate to non-blocking mode.
936   socket_->disableNonBlocking();
937 
938   C10D_INFO("The client socket has connected to {} on {}.", addr, *socket_);
939 
940   if (!socket_->enableNoDelay()) {
941     C10D_WARNING(
942         "The no-delay option cannot be enabled for the client socket on {}.",
943         *socket_);
944   }
945 
946   return ConnectResult::Success;
947 }
948 
tryConnectCore(const::addrinfo & addr)949 SocketConnectOp::ConnectResult SocketConnectOp::tryConnectCore(
950     const ::addrinfo& addr) {
951   int r = ::connect(socket_->handle(), addr.ai_addr, addr.ai_addrlen);
952   if (r == 0) {
953     return ConnectResult::Success;
954   }
955 
956   std::error_code err = getSocketError();
957   if (err == std::errc::already_connected) {
958     return ConnectResult::Success;
959   }
960 
961   if (err != std::errc::operation_in_progress &&
962       err != std::errc::operation_would_block) {
963     return ConnectResult::Error;
964   }
965 
966   Duration remaining = deadline_ - Clock::now();
967   if (remaining <= Duration::zero()) {
968     throwTimeoutError();
969   }
970 
971   ::pollfd pfd{};
972   pfd.fd = socket_->handle();
973   pfd.events = POLLOUT;
974 
975   auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(remaining);
976 
977   r = pollFd(&pfd, 1, static_cast<int>(ms.count()));
978   if (r == 0) {
979     throwTimeoutError();
980   }
981   if (r == -1) {
982     return ConnectResult::Error;
983   }
984 
985   int err_code = 0;
986 
987   ::socklen_t err_len = sizeof(int);
988 
989   r = getSocketOption(
990       socket_->handle(), SOL_SOCKET, SO_ERROR, &err_code, &err_len);
991   if (r != 0) {
992     return ConnectResult::Error;
993   }
994 
995   if (err_code != 0) {
996     setSocketError(err_code);
997 
998     return ConnectResult::Error;
999   } else {
1000     return ConnectResult::Success;
1001   }
1002 }
1003 
throwTimeoutError() const1004 void SocketConnectOp::throwTimeoutError() const {
1005   auto msg = fmt::format(
1006       "The client socket has timed out after {} while trying to connect to ({}, {}).",
1007       opts_->connect_timeout(),
1008       host_,
1009       port_);
1010 
1011   C10D_ERROR(msg);
1012 
1013   C10D_THROW_ERROR(TimeoutError, msg);
1014 }
1015 
1016 } // namespace
1017 
initialize()1018 void Socket::initialize() {
1019 #ifdef _WIN32
1020   static c10::once_flag init_flag{};
1021 
1022   // All processes that call socket functions on Windows must first initialize
1023   // the Winsock library.
1024   c10::call_once(init_flag, []() {
1025     WSADATA data{};
1026     if (::WSAStartup(MAKEWORD(2, 2), &data) != 0) {
1027       C10D_THROW_ERROR(
1028           SocketError, "The initialization of Winsock has failed.");
1029     }
1030   });
1031 #endif
1032 }
1033 
listen(std::uint16_t port,const SocketOptions & opts)1034 Socket Socket::listen(std::uint16_t port, const SocketOptions& opts) {
1035   SocketListenOp op{port, opts};
1036 
1037   return Socket{op.run()};
1038 }
1039 
listenFromFd(int fd,std::uint16_t expected_port)1040 Socket Socket::listenFromFd(int fd, std::uint16_t expected_port) {
1041   SocketListenFromFdOp op{fd, expected_port};
1042 
1043   return Socket{op.run()};
1044 }
1045 
connect(const std::string & host,std::uint16_t port,const SocketOptions & opts)1046 Socket Socket::connect(
1047     const std::string& host,
1048     std::uint16_t port,
1049     const SocketOptions& opts) {
1050   SocketConnectOp op{host, port, opts};
1051 
1052   return Socket{op.run()};
1053 }
1054 
1055 Socket::Socket(Socket&& other) noexcept = default;
1056 
1057 Socket& Socket::operator=(Socket&& other) noexcept = default;
1058 
1059 Socket::~Socket() = default;
1060 
accept() const1061 Socket Socket::accept() const {
1062   if (impl_) {
1063     return Socket{impl_->accept()};
1064   }
1065 
1066   C10D_THROW_ERROR(SocketError, "The socket is not initialized.");
1067 }
1068 
handle() const1069 int Socket::handle() const noexcept {
1070   if (impl_) {
1071     return impl_->handle();
1072   }
1073   return SocketImpl::invalid_socket;
1074 }
1075 
port() const1076 std::uint16_t Socket::port() const {
1077   if (impl_) {
1078     return impl_->getPort();
1079   }
1080   return 0;
1081 }
1082 
Socket(std::unique_ptr<SocketImpl> && impl)1083 Socket::Socket(std::unique_ptr<SocketImpl>&& impl) noexcept
1084     : impl_{std::move(impl)} {}
1085 
waitForInput(std::chrono::milliseconds timeout)1086 bool Socket::waitForInput(std::chrono::milliseconds timeout) {
1087   return impl_->waitForInput(timeout);
1088 }
1089 
repr() const1090 std::string Socket::repr() const {
1091   if (impl_) {
1092     return fmt::format("{}", *impl_);
1093   }
1094   return "Socket(no-impl)";
1095 }
1096 
1097 } // namespace c10d::detail
1098