1 // Copyright (c) Meta Platforms, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // This source code is licensed under the BSD-style license found in the
5 // LICENSE file in the root directory of this source tree.
6
7 #include <torch/csrc/distributed/c10d/socket.h>
8
#include <algorithm>
#include <chrono>
#include <cstring>
#include <limits>
#include <optional>
#include <system_error>
#include <thread>
#include <utility>
#include <vector>
14
15 #ifdef _WIN32
16 #include <mutex>
17
18 #include <winsock2.h>
19 #include <ws2tcpip.h>
20 #else
21 #include <arpa/inet.h>
22 #include <fcntl.h>
23 #include <netdb.h>
24 #include <netinet/tcp.h>
25 #include <poll.h>
26 #include <sys/socket.h>
27 #include <sys/types.h>
28 #include <unistd.h>
29 #endif
30
31 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated")
32 #include <fmt/chrono.h>
33 C10_DIAGNOSTIC_POP()
34 #include <fmt/format.h>
35 #include <fmt/ranges.h>
36
37 #include <torch/csrc/distributed/c10d/error.h>
38 #include <torch/csrc/distributed/c10d/exception.h>
39 #include <torch/csrc/distributed/c10d/logging.h>
40
41 #include <c10/util/CallOnce.h>
42
43 namespace c10d::detail {
44 namespace {
#ifdef _WIN32

// Since Winsock uses the name `WSAPoll` instead of `poll`, we alias it here
// to avoid #ifdefs in the source code.
const auto pollFd = ::WSAPoll;

// Winsock's `getsockopt()` and `setsockopt()` functions expect option values to
// be passed as `char*` instead of `void*`. We wrap them here to avoid redundant
// casts in the source code.
//
// Returns whatever `::getsockopt()` returns; callers treat 0 as success.
int getSocketOption(
    SOCKET s,
    int level,
    int optname,
    void* optval,
    int* optlen) {
  return ::getsockopt(s, level, optname, static_cast<char*>(optval), optlen);
}

// Returns whatever `::setsockopt()` returns; callers treat 0 as success.
int setSocketOption(
    SOCKET s,
    int level,
    int optname,
    const void* optval,
    int optlen) {
  return ::setsockopt(
      s, level, optname, static_cast<const char*>(optval), optlen);
}

// Winsock has its own error codes which differ from Berkeley's. Fortunately the
// C++ Standard Library on Windows can map them to standard error codes.
//
// Wraps `WSAGetLastError()` in a std::error_code (system category).
inline std::error_code getSocketError() noexcept {
  return std::error_code{::WSAGetLastError(), std::system_category()};
}

// Overwrites the last Winsock error with `val` so that a subsequent
// getSocketError() call observes it (used to surface SO_ERROR values).
inline void setSocketError(int val) noexcept {
  ::WSASetLastError(val);
}

#else

// On POSIX the Berkeley socket API is used directly; these aliases only give
// both platforms the same names.
const auto pollFd = ::poll;

const auto getSocketOption = ::getsockopt;
const auto setSocketOption = ::setsockopt;

// `lastError()` is declared outside this file (presumably error.h); it
// presumably wraps `errno` in a std::error_code — confirm.
inline std::error_code getSocketError() noexcept {
  return lastError();
}

// Stores `val` in `errno` so that a subsequent getSocketError() call observes
// it, assuming lastError() reads errno (see note above).
inline void setSocketError(int val) noexcept {
  errno = val;
}

#endif
99
100 // Suspends the current thread for the specified duration.
delay(std::chrono::milliseconds d)101 void delay(std::chrono::milliseconds d) {
102 #ifdef _WIN32
103 std::this_thread::sleep_for(d);
104 #else
105 ::timespec req{};
106 auto ms = d.count();
107 req.tv_sec = ms / 1000;
108 req.tv_nsec = (ms % 1000) * 1000000;
109
110 // The C++ Standard does not specify whether `sleep_for()` should be signal-
111 // aware; therefore, we use the `nanosleep()` syscall.
112 if (::nanosleep(&req, nullptr) != 0) {
113 std::error_code err = getSocketError();
114 // We don't care about error conditions other than EINTR since a failure
115 // here is not critical.
116 if (err == std::errc::interrupted) {
117 C10_THROW_ERROR(DistNetworkError, std::strerror(err.value()));
118 }
119 }
120 #endif
121 }
122
// Forward declarations so that SocketImpl (declared below) can name the listen
// and connect operations as friends; both are defined later in this file.
class SocketListenOp;
class SocketConnectOp;
125 } // namespace
126
127 class SocketImpl {
128 friend class SocketListenOp;
129 friend class SocketConnectOp;
130
131 public:
132 #ifdef _WIN32
133 using Handle = SOCKET;
134 #else
135 using Handle = int;
136 #endif
137
138 #ifdef _WIN32
139 static constexpr Handle invalid_socket = INVALID_SOCKET;
140 #else
141 static constexpr Handle invalid_socket = -1;
142 #endif
143
SocketImpl(Handle hnd)144 explicit SocketImpl(Handle hnd) noexcept : hnd_{hnd} {}
145
146 explicit SocketImpl(Handle hnd, const ::addrinfo& remote);
147
148 SocketImpl(const SocketImpl& other) = delete;
149
150 SocketImpl& operator=(const SocketImpl& other) = delete;
151
152 SocketImpl(SocketImpl&& other) noexcept = delete;
153
154 SocketImpl& operator=(SocketImpl&& other) noexcept = delete;
155
156 ~SocketImpl();
157
158 std::unique_ptr<SocketImpl> accept() const;
159
160 void closeOnExec() noexcept;
161
162 void enableNonBlocking();
163
164 void disableNonBlocking();
165
166 bool enableNoDelay() noexcept;
167
168 bool enableDualStack() noexcept;
169
170 #ifndef _WIN32
171 bool enableAddressReuse() noexcept;
172 #endif
173
174 #ifdef _WIN32
175 bool enableExclusiveAddressUse() noexcept;
176 #endif
177
178 std::uint16_t getPort() const;
179
handle() const180 Handle handle() const noexcept {
181 return hnd_;
182 }
183
remote() const184 const std::optional<std::string>& remote() const noexcept {
185 return remote_;
186 }
187
188 bool waitForInput(std::chrono::milliseconds timeout);
189
190 private:
191 bool setSocketFlag(int level, int optname, bool value) noexcept;
192
193 Handle hnd_;
194 const std::optional<std::string> remote_;
195 };
196 } // namespace c10d::detail
197
198 //
199 // libfmt formatters for `addrinfo` and `Socket`
200 //
201 namespace fmt {
202
203 template <>
204 struct formatter<::addrinfo> {
parsefmt::formatter205 constexpr decltype(auto) parse(format_parse_context& ctx) const {
206 return ctx.begin();
207 }
208
209 template <typename FormatContext>
formatfmt::formatter210 decltype(auto) format(const ::addrinfo& addr, FormatContext& ctx) const {
211 char host[NI_MAXHOST], port[NI_MAXSERV]; // NOLINT
212
213 int r = ::getnameinfo(
214 addr.ai_addr,
215 addr.ai_addrlen,
216 host,
217 NI_MAXHOST,
218 port,
219 NI_MAXSERV,
220 NI_NUMERICSERV);
221 if (r != 0) {
222 // if we can't resolve the hostname, display the IP address
223 if (addr.ai_family == AF_INET) {
224 struct sockaddr_in* psai = (struct sockaddr_in*)addr.ai_addr;
225 char ip[INET_ADDRSTRLEN];
226 if (inet_ntop(addr.ai_family, &(psai->sin_addr), ip, INET_ADDRSTRLEN) !=
227 NULL) {
228 return fmt::format_to(ctx.out(), "{}:{}", ip, psai->sin_port);
229 }
230 } else if (addr.ai_family == AF_INET6) {
231 struct sockaddr_in6* psai = (struct sockaddr_in6*)addr.ai_addr;
232 char ip[INET6_ADDRSTRLEN];
233 if (inet_ntop(
234 addr.ai_family, &(psai->sin6_addr), ip, INET6_ADDRSTRLEN) !=
235 NULL) {
236 return fmt::format_to(ctx.out(), "[{}]:{}", ip, psai->sin6_port);
237 }
238 }
239 C10_THROW_ERROR(
240 DistNetworkError,
241 fmt::format(
242 "failed to format addr, unknown family={}", addr.ai_family));
243 }
244
245 if (addr.ai_addr->sa_family == AF_INET) {
246 return fmt::format_to(ctx.out(), "{}:{}", host, port);
247 } else {
248 return fmt::format_to(ctx.out(), "[{}]:{}", host, port);
249 }
250 }
251 };
252
253 template <>
254 struct formatter<c10d::detail::SocketImpl> {
parsefmt::formatter255 constexpr decltype(auto) parse(format_parse_context& ctx) const {
256 return ctx.begin();
257 }
258
259 template <typename FormatContext>
formatfmt::formatter260 decltype(auto) format(
261 const c10d::detail::SocketImpl& socket,
262 FormatContext& ctx) const {
263 ::sockaddr_storage addr_s{};
264
265 auto addr_ptr = reinterpret_cast<::sockaddr*>(&addr_s);
266
267 ::socklen_t addr_len = sizeof(addr_s);
268
269 auto fd = socket.handle();
270
271 if (::getsockname(fd, addr_ptr, &addr_len) != 0) {
272 return fmt::format_to(ctx.out(), "?UNKNOWN?");
273 }
274
275 ::addrinfo addr{};
276 addr.ai_addr = addr_ptr;
277 addr.ai_addrlen = addr_len;
278
279 auto remote = socket.remote();
280 std::string remoteStr = remote ? *remote : "none";
281
282 return fmt::format_to(
283 ctx.out(),
284 "SocketImpl(fd={}, addr={}, remote={})",
285 fd,
286 addr,
287 remoteStr);
288 }
289 };
290
291 } // namespace fmt
292
293 namespace c10d::detail {
294
SocketImpl(Handle hnd,const::addrinfo & remote)295 SocketImpl::SocketImpl(Handle hnd, const ::addrinfo& remote)
296 : hnd_{hnd}, remote_{fmt::format("{}", remote)} {}
297
~SocketImpl()298 SocketImpl::~SocketImpl() {
299 #ifdef _WIN32
300 ::closesocket(hnd_);
301 #else
302 ::close(hnd_);
303 #endif
304 }
305
accept() const306 std::unique_ptr<SocketImpl> SocketImpl::accept() const {
307 ::sockaddr_storage addr_s{};
308
309 auto addr_ptr = reinterpret_cast<::sockaddr*>(&addr_s);
310
311 ::socklen_t addr_len = sizeof(addr_s);
312
313 Handle hnd = ::accept(hnd_, addr_ptr, &addr_len);
314 if (hnd == invalid_socket) {
315 std::error_code err = getSocketError();
316 if (err == std::errc::interrupted) {
317 C10_THROW_ERROR(DistNetworkError, std::strerror(err.value()));
318 }
319
320 std::string msg{};
321 if (err == std::errc::invalid_argument) {
322 msg = fmt::format(
323 "The server socket on {} is not listening for connections.", *this);
324 } else {
325 msg = fmt::format(
326 "The server socket on {} has failed to accept a connection {}.",
327 *this,
328 err);
329 }
330
331 C10D_ERROR(msg);
332
333 C10D_THROW_ERROR(SocketError, msg);
334 }
335
336 ::addrinfo addr{};
337 addr.ai_addr = addr_ptr;
338 addr.ai_addrlen = addr_len;
339
340 C10D_DEBUG(
341 "The server socket on {} has accepted a connection from {}.",
342 *this,
343 addr);
344
345 auto impl = std::make_unique<SocketImpl>(hnd, addr);
346
347 // Make sure that we do not "leak" our file descriptors to child processes.
348 impl->closeOnExec();
349
350 if (!impl->enableNoDelay()) {
351 C10D_WARNING(
352 "The no-delay option cannot be enabled for the client socket on {}.",
353 addr);
354 }
355
356 return impl;
357 }
358
closeOnExec()359 void SocketImpl::closeOnExec() noexcept {
360 #ifndef _WIN32
361 ::fcntl(hnd_, F_SETFD, FD_CLOEXEC);
362 #endif
363 }
364
enableNonBlocking()365 void SocketImpl::enableNonBlocking() {
366 #ifdef _WIN32
367 unsigned long value = 1;
368 if (::ioctlsocket(hnd_, FIONBIO, &value) == 0) {
369 return;
370 }
371 #else
372 int flg = ::fcntl(hnd_, F_GETFL);
373 if (flg != -1) {
374 if (::fcntl(hnd_, F_SETFL, flg | O_NONBLOCK) == 0) {
375 return;
376 }
377 }
378 #endif
379 C10D_THROW_ERROR(
380 SocketError, "The socket cannot be switched to non-blocking mode.");
381 }
382
383 // TODO: Remove once we migrate everything to non-blocking mode.
disableNonBlocking()384 void SocketImpl::disableNonBlocking() {
385 #ifdef _WIN32
386 unsigned long value = 0;
387 if (::ioctlsocket(hnd_, FIONBIO, &value) == 0) {
388 return;
389 }
390 #else
391 int flg = ::fcntl(hnd_, F_GETFL);
392 if (flg != -1) {
393 if (::fcntl(hnd_, F_SETFL, flg & ~O_NONBLOCK) == 0) {
394 return;
395 }
396 }
397 #endif
398 C10D_THROW_ERROR(
399 SocketError, "The socket cannot be switched to blocking mode.");
400 }
401
enableNoDelay()402 bool SocketImpl::enableNoDelay() noexcept {
403 return setSocketFlag(IPPROTO_TCP, TCP_NODELAY, true);
404 }
405
enableDualStack()406 bool SocketImpl::enableDualStack() noexcept {
407 return setSocketFlag(IPPROTO_IPV6, IPV6_V6ONLY, false);
408 }
409
410 #ifndef _WIN32
enableAddressReuse()411 bool SocketImpl::enableAddressReuse() noexcept {
412 return setSocketFlag(SOL_SOCKET, SO_REUSEADDR, true);
413 }
414 #endif
415
416 #ifdef _WIN32
enableExclusiveAddressUse()417 bool SocketImpl::enableExclusiveAddressUse() noexcept {
418 return setSocketFlag(SOL_SOCKET, SO_EXCLUSIVEADDRUSE, true);
419 }
420 #endif
421
getPort() const422 std::uint16_t SocketImpl::getPort() const {
423 ::sockaddr_storage addr_s{};
424
425 ::socklen_t addr_len = sizeof(addr_s);
426
427 if (::getsockname(hnd_, reinterpret_cast<::sockaddr*>(&addr_s), &addr_len) !=
428 0) {
429 C10D_THROW_ERROR(
430 SocketError, "The port number of the socket cannot be retrieved.");
431 }
432
433 if (addr_s.ss_family == AF_INET) {
434 return ntohs(reinterpret_cast<::sockaddr_in*>(&addr_s)->sin_port);
435 } else {
436 return ntohs(reinterpret_cast<::sockaddr_in6*>(&addr_s)->sin6_port);
437 }
438 }
439
setSocketFlag(int level,int optname,bool value)440 bool SocketImpl::setSocketFlag(int level, int optname, bool value) noexcept {
441 #ifdef _WIN32
442 auto buf = value ? TRUE : FALSE;
443 #else
444 auto buf = value ? 1 : 0;
445 #endif
446 return setSocketOption(hnd_, level, optname, &buf, sizeof(buf)) == 0;
447 }
448
waitForInput(std::chrono::milliseconds timeout)449 bool SocketImpl::waitForInput(std::chrono::milliseconds timeout) {
450 using Clock = std::chrono::steady_clock;
451
452 auto deadline = Clock::now() + timeout;
453 do {
454 ::pollfd pfd{};
455 pfd.fd = hnd_;
456 pfd.events = POLLIN;
457
458 int res = pollFd(&pfd, 1, static_cast<int>(timeout.count()));
459 if (res > 0) {
460 return true;
461 } else if (res == 0) {
462 C10D_WARNING(
463 "waitForInput: poll for socket {} returned 0, likely a timeout",
464 *this);
465 continue;
466 }
467
468 std::error_code err = getSocketError();
469 if (err == std::errc::operation_in_progress) {
470 bool timedout = Clock::now() >= deadline;
471 if (timedout) {
472 return false;
473 }
474 C10D_WARNING(
475 "waitForInput: poll for socket {} returned operation_in_progress before a timeout",
476 *this);
477 } else if (err != std::errc::interrupted) {
478 C10D_WARNING(
479 "waitForInput: poll for socket {} failed with res={}, err={}.",
480 *this,
481 res,
482 err);
483 return false;
484 }
485 } while (Clock::now() < deadline);
486
487 C10D_WARNING(
488 "waitForInput: socket {} timed out after {}ms", *this, timeout.count());
489 return false;
490 }
491
492 namespace {
493
494 struct addrinfo_delete {
operator ()c10d::detail::__anonac1afc6b0211::addrinfo_delete495 void operator()(::addrinfo* addr) const noexcept {
496 ::freeaddrinfo(addr);
497 }
498 };
499
500 using addrinfo_ptr = std::unique_ptr<::addrinfo, addrinfo_delete>;
501
502 class SocketListenOp {
503 public:
504 SocketListenOp(std::uint16_t port, const SocketOptions& opts);
505
506 std::unique_ptr<SocketImpl> run();
507
508 private:
509 bool tryListen(int family);
510
511 bool tryListen(const ::addrinfo& addr);
512
513 template <typename... Args>
514 // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
recordError(fmt::string_view format,Args &&...args)515 void recordError(fmt::string_view format, Args&&... args) {
516 auto msg = fmt::vformat(format, fmt::make_format_args(args...));
517
518 C10D_WARNING(msg);
519
520 errors_.emplace_back(std::move(msg));
521 }
522
523 std::string port_;
524 const SocketOptions* opts_;
525 std::vector<std::string> errors_{};
526 std::unique_ptr<SocketImpl> socket_{};
527 };
528
SocketListenOp(std::uint16_t port,const SocketOptions & opts)529 SocketListenOp::SocketListenOp(std::uint16_t port, const SocketOptions& opts)
530 : port_{fmt::to_string(port)}, opts_{&opts} {}
531
run()532 std::unique_ptr<SocketImpl> SocketListenOp::run() {
533 if (opts_->prefer_ipv6()) {
534 C10D_DEBUG("The server socket will attempt to listen on an IPv6 address.");
535 if (tryListen(AF_INET6)) {
536 return std::move(socket_);
537 }
538
539 C10D_DEBUG("The server socket will attempt to listen on an IPv4 address.");
540 if (tryListen(AF_INET)) {
541 return std::move(socket_);
542 }
543 } else {
544 C10D_DEBUG(
545 "The server socket will attempt to listen on an IPv4 or IPv6 address.");
546 if (tryListen(AF_UNSPEC)) {
547 return std::move(socket_);
548 }
549 }
550
551 constexpr auto* msg =
552 "The server socket has failed to listen on any local network address.";
553
554 C10D_ERROR(msg);
555
556 C10D_THROW_ERROR(
557 SocketError, fmt::format("{} {}", msg, fmt::join(errors_, " ")));
558 }
559
tryListen(int family)560 bool SocketListenOp::tryListen(int family) {
561 ::addrinfo hints{}, *naked_result = nullptr;
562
563 hints.ai_flags = AI_PASSIVE | AI_NUMERICSERV;
564 hints.ai_family = family;
565 hints.ai_socktype = SOCK_STREAM;
566
567 int r = ::getaddrinfo(nullptr, port_.c_str(), &hints, &naked_result);
568 if (r != 0) {
569 const char* gai_err = ::gai_strerror(r);
570
571 recordError(
572 "The local {}network addresses cannot be retrieved (gai error: {} - {}).",
573 family == AF_INET ? "IPv4 "
574 : family == AF_INET6 ? "IPv6 "
575 : "",
576 r,
577 gai_err);
578
579 return false;
580 }
581
582 addrinfo_ptr result{naked_result};
583
584 for (::addrinfo* addr = naked_result; addr != nullptr; addr = addr->ai_next) {
585 C10D_DEBUG("The server socket is attempting to listen on {}.", *addr);
586 if (tryListen(*addr)) {
587 return true;
588 }
589 }
590
591 return false;
592 }
593
tryListen(const::addrinfo & addr)594 bool SocketListenOp::tryListen(const ::addrinfo& addr) {
595 SocketImpl::Handle hnd =
596 ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol);
597 if (hnd == SocketImpl::invalid_socket) {
598 recordError(
599 "The server socket cannot be initialized on {} {}.",
600 addr,
601 getSocketError());
602
603 return false;
604 }
605
606 socket_ = std::make_unique<SocketImpl>(hnd);
607
608 #ifndef _WIN32
609 if (!socket_->enableAddressReuse()) {
610 C10D_WARNING(
611 "The address reuse option cannot be enabled for the server socket on {}.",
612 addr);
613 }
614 #endif
615
616 #ifdef _WIN32
617 // The SO_REUSEADDR flag has a significantly different behavior on Windows
618 // compared to Unix-like systems. It allows two or more processes to share
619 // the same port simultaneously, which is totally unsafe.
620 //
621 // Here we follow the recommendation of Microsoft and use the non-standard
622 // SO_EXCLUSIVEADDRUSE flag instead.
623 if (!socket_->enableExclusiveAddressUse()) {
624 C10D_WARNING(
625 "The exclusive address use option cannot be enabled for the server socket on {}.",
626 addr);
627 }
628 #endif
629
630 // Not all operating systems support dual-stack sockets by default. Since we
631 // wish to use our IPv6 socket for IPv4 communication as well, we explicitly
632 // ask the system to enable it.
633 if (addr.ai_family == AF_INET6 && !socket_->enableDualStack()) {
634 C10D_WARNING(
635 "The server socket does not support IPv4 communication on {}.", addr);
636 }
637
638 if (::bind(socket_->handle(), addr.ai_addr, addr.ai_addrlen) != 0) {
639 recordError(
640 "The server socket has failed to bind to {} {}.",
641 addr,
642 getSocketError());
643
644 return false;
645 }
646
647 // NOLINTNEXTLINE(bugprone-argument-comment)
648 if (::listen(socket_->handle(), -1 /* backlog */) != 0) {
649 recordError(
650 "The server socket has failed to listen on {} {}.",
651 addr,
652 getSocketError());
653
654 return false;
655 }
656
657 socket_->closeOnExec();
658
659 C10D_INFO("The server socket has started to listen on {}.", addr);
660
661 return true;
662 }
663
664 class SocketListenFromFdOp {
665 public:
666 SocketListenFromFdOp(int fd, std::uint16_t expected_port);
667
668 std::unique_ptr<SocketImpl> run() const;
669
670 private:
671 // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
672 const int fd_;
673 // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
674 const std::uint16_t expected_port_;
675 };
676
SocketListenFromFdOp(int fd,std::uint16_t expected_port)677 SocketListenFromFdOp::SocketListenFromFdOp(int fd, std::uint16_t expected_port)
678 : fd_(fd), expected_port_(expected_port) {}
679
run() const680 std::unique_ptr<SocketImpl> SocketListenFromFdOp::run() const {
681 C10D_DEBUG("listenFromFd: fd {}, expected port {}", fd_, expected_port_);
682
683 ::sockaddr_storage addr_storage{};
684 ::socklen_t addr_len = sizeof(addr_storage);
685 if (::getsockname(
686 fd_, reinterpret_cast<::sockaddr*>(&addr_storage), &addr_len) < 0) {
687 C10D_THROW_ERROR(
688 SocketError,
689 fmt::format("getsockname failed for fd {}: {}", fd_, getSocketError()));
690 }
691
692 auto socket = std::make_unique<SocketImpl>(fd_);
693 const auto port = socket->getPort();
694
695 if (port != expected_port_) {
696 C10D_THROW_ERROR(
697 SocketError,
698 fmt::format(
699 "listen fd {} is bound to port {}, expected to be bound to port {}",
700 fd_,
701 port,
702 expected_port_));
703 }
704
705 if (::listen(socket->handle(), -1 /* backlog */) != 0) {
706 C10D_THROW_ERROR(
707 SocketError,
708 fmt::format(
709 "Failed to listen on socket initialized from fd {}: {}.",
710 socket->handle(),
711 getSocketError()));
712 }
713
714 socket->closeOnExec();
715
716 C10D_INFO(
717 "The server has taken over the listening socket with fd {}, address {}",
718 fd_,
719 *socket);
720 return socket;
721 }
722
723 class SocketConnectOp {
724 using Clock = std::chrono::steady_clock;
725 using Duration = std::chrono::steady_clock::duration;
726 using TimePoint = std::chrono::time_point<std::chrono::steady_clock>;
727
728 enum class ConnectResult : uint8_t { Success, Error, Retry };
729
730 public:
731 SocketConnectOp(
732 const std::string& host,
733 std::uint16_t port,
734 const SocketOptions& opts);
735
736 std::unique_ptr<SocketImpl> run();
737
738 private:
739 bool tryConnect(int family);
740
741 ConnectResult tryConnect(const ::addrinfo& addr);
742
743 ConnectResult tryConnectCore(const ::addrinfo& addr);
744
745 [[noreturn]] void throwTimeoutError() const;
746
747 template <typename... Args>
748 // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
recordError(fmt::string_view format,Args &&...args)749 void recordError(fmt::string_view format, Args&&... args) {
750 auto msg = fmt::vformat(format, fmt::make_format_args(args...));
751
752 C10D_WARNING(msg);
753
754 errors_.emplace_back(std::move(msg));
755 }
756
757 const char* host_;
758 std::string port_;
759 const SocketOptions* opts_;
760 TimePoint deadline_{};
761 std::vector<std::string> errors_{};
762 std::unique_ptr<SocketImpl> socket_{};
763 };
764
SocketConnectOp(const std::string & host,std::uint16_t port,const SocketOptions & opts)765 SocketConnectOp::SocketConnectOp(
766 const std::string& host,
767 std::uint16_t port,
768 const SocketOptions& opts)
769 : host_{host.c_str()}, port_{fmt::to_string(port)}, opts_{&opts} {}
770
run()771 std::unique_ptr<SocketImpl> SocketConnectOp::run() {
772 if (opts_->prefer_ipv6()) {
773 C10D_DEBUG(
774 "The client socket will attempt to connect to an IPv6 address of ({}, {}).",
775 host_,
776 port_);
777
778 if (tryConnect(AF_INET6)) {
779 return std::move(socket_);
780 }
781
782 C10D_DEBUG(
783 "The client socket will attempt to connect to an IPv4 address of ({}, {}).",
784 host_,
785 port_);
786
787 if (tryConnect(AF_INET)) {
788 return std::move(socket_);
789 }
790 } else {
791 C10D_DEBUG(
792 "The client socket will attempt to connect to an IPv4 or IPv6 address of ({}, {}).",
793 host_,
794 port_);
795
796 if (tryConnect(AF_UNSPEC)) {
797 return std::move(socket_);
798 }
799 }
800
801 auto msg = fmt::format(
802 "The client socket has failed to connect to any network address of ({}, {}).",
803 host_,
804 port_);
805
806 C10D_ERROR(msg);
807
808 C10D_THROW_ERROR(
809 SocketError, fmt::format("{} {}", msg, fmt::join(errors_, " ")));
810 }
811
tryConnect(int family)812 bool SocketConnectOp::tryConnect(int family) {
813 ::addrinfo hints{};
814 hints.ai_flags = AI_V4MAPPED | AI_ALL | AI_NUMERICSERV;
815 hints.ai_family = family;
816 hints.ai_socktype = SOCK_STREAM;
817
818 deadline_ = Clock::now() + opts_->connect_timeout();
819
820 bool retry; // NOLINT(cppcoreguidelines-init-variables)
821 do {
822 retry = false;
823
824 errors_.clear();
825
826 ::addrinfo* naked_result = nullptr;
827 // patternlint-disable cpp-dns-deps
828 int r = ::getaddrinfo(host_, port_.c_str(), &hints, &naked_result);
829 if (r != 0) {
830 const char* gai_err = ::gai_strerror(r);
831
832 recordError(
833 "The {}network addresses of ({}, {}) cannot be retrieved (gai error: {} - {}).",
834 family == AF_INET ? "IPv4 "
835 : family == AF_INET6 ? "IPv6 "
836 : "",
837 host_,
838 port_,
839 r,
840 gai_err);
841 retry = true;
842 } else {
843 addrinfo_ptr result{naked_result};
844
845 for (::addrinfo* addr = naked_result; addr != nullptr;
846 addr = addr->ai_next) {
847 C10D_TRACE("The client socket is attempting to connect to {}.", *addr);
848
849 ConnectResult cr = tryConnect(*addr);
850 if (cr == ConnectResult::Success) {
851 return true;
852 }
853
854 if (cr == ConnectResult::Retry) {
855 retry = true;
856 }
857 }
858 }
859
860 if (retry) {
861 auto connectBackoff = opts_->connect_backoff();
862 auto delayDuration = connectBackoff->nextBackoff();
863
864 if (Clock::now() < deadline_ - delayDuration) {
865 // Prevent our log output to be too noisy, warn only every 30 seconds.
866 static auto lastLog = std::chrono::steady_clock::now();
867 auto now = std::chrono::steady_clock::now();
868 if ((now - lastLog) >= std::chrono::seconds(30)) {
869 C10D_INFO(
870 "No socket on ({}, {}) is listening yet, will retry.",
871 host_,
872 port_);
873
874 lastLog = now;
875 }
876
877 // Wait to avoid choking the server.
878 delay(delayDuration);
879 } else {
880 throwTimeoutError();
881 }
882 }
883 } while (retry);
884
885 return false;
886 }
887
tryConnect(const::addrinfo & addr)888 SocketConnectOp::ConnectResult SocketConnectOp::tryConnect(
889 const ::addrinfo& addr) {
890 if (Clock::now() >= deadline_) {
891 throwTimeoutError();
892 }
893
894 SocketImpl::Handle hnd =
895 ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol);
896 if (hnd == SocketImpl::invalid_socket) {
897 recordError(
898 "The client socket cannot be initialized to connect to {} {}.",
899 addr,
900 getSocketError());
901
902 return ConnectResult::Error;
903 }
904
905 socket_ = std::make_unique<SocketImpl>(hnd, addr);
906
907 socket_->enableNonBlocking();
908
909 ConnectResult cr = tryConnectCore(addr);
910 if (cr == ConnectResult::Error) {
911 std::error_code err = getSocketError();
912 if (err == std::errc::interrupted) {
913 C10_THROW_ERROR(DistNetworkError, std::strerror(err.value()));
914 }
915
916 // Retry if the server is not yet listening or if its backlog is exhausted.
917 if (err == std::errc::connection_refused ||
918 err == std::errc::connection_reset) {
919 C10D_TRACE(
920 "The server socket on {} is not yet listening {}, will retry.",
921 addr,
922 err);
923
924 return ConnectResult::Retry;
925 } else {
926 recordError(
927 "The client socket has failed to connect to {} {}.", addr, err);
928
929 return ConnectResult::Error;
930 }
931 }
932
933 socket_->closeOnExec();
934
935 // TODO: Remove once we fully migrate to non-blocking mode.
936 socket_->disableNonBlocking();
937
938 C10D_INFO("The client socket has connected to {} on {}.", addr, *socket_);
939
940 if (!socket_->enableNoDelay()) {
941 C10D_WARNING(
942 "The no-delay option cannot be enabled for the client socket on {}.",
943 *socket_);
944 }
945
946 return ConnectResult::Success;
947 }
948
tryConnectCore(const::addrinfo & addr)949 SocketConnectOp::ConnectResult SocketConnectOp::tryConnectCore(
950 const ::addrinfo& addr) {
951 int r = ::connect(socket_->handle(), addr.ai_addr, addr.ai_addrlen);
952 if (r == 0) {
953 return ConnectResult::Success;
954 }
955
956 std::error_code err = getSocketError();
957 if (err == std::errc::already_connected) {
958 return ConnectResult::Success;
959 }
960
961 if (err != std::errc::operation_in_progress &&
962 err != std::errc::operation_would_block) {
963 return ConnectResult::Error;
964 }
965
966 Duration remaining = deadline_ - Clock::now();
967 if (remaining <= Duration::zero()) {
968 throwTimeoutError();
969 }
970
971 ::pollfd pfd{};
972 pfd.fd = socket_->handle();
973 pfd.events = POLLOUT;
974
975 auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(remaining);
976
977 r = pollFd(&pfd, 1, static_cast<int>(ms.count()));
978 if (r == 0) {
979 throwTimeoutError();
980 }
981 if (r == -1) {
982 return ConnectResult::Error;
983 }
984
985 int err_code = 0;
986
987 ::socklen_t err_len = sizeof(int);
988
989 r = getSocketOption(
990 socket_->handle(), SOL_SOCKET, SO_ERROR, &err_code, &err_len);
991 if (r != 0) {
992 return ConnectResult::Error;
993 }
994
995 if (err_code != 0) {
996 setSocketError(err_code);
997
998 return ConnectResult::Error;
999 } else {
1000 return ConnectResult::Success;
1001 }
1002 }
1003
throwTimeoutError() const1004 void SocketConnectOp::throwTimeoutError() const {
1005 auto msg = fmt::format(
1006 "The client socket has timed out after {} while trying to connect to ({}, {}).",
1007 opts_->connect_timeout(),
1008 host_,
1009 port_);
1010
1011 C10D_ERROR(msg);
1012
1013 C10D_THROW_ERROR(TimeoutError, msg);
1014 }
1015
1016 } // namespace
1017
initialize()1018 void Socket::initialize() {
1019 #ifdef _WIN32
1020 static c10::once_flag init_flag{};
1021
1022 // All processes that call socket functions on Windows must first initialize
1023 // the Winsock library.
1024 c10::call_once(init_flag, []() {
1025 WSADATA data{};
1026 if (::WSAStartup(MAKEWORD(2, 2), &data) != 0) {
1027 C10D_THROW_ERROR(
1028 SocketError, "The initialization of Winsock has failed.");
1029 }
1030 });
1031 #endif
1032 }
1033
listen(std::uint16_t port,const SocketOptions & opts)1034 Socket Socket::listen(std::uint16_t port, const SocketOptions& opts) {
1035 SocketListenOp op{port, opts};
1036
1037 return Socket{op.run()};
1038 }
1039
listenFromFd(int fd,std::uint16_t expected_port)1040 Socket Socket::listenFromFd(int fd, std::uint16_t expected_port) {
1041 SocketListenFromFdOp op{fd, expected_port};
1042
1043 return Socket{op.run()};
1044 }
1045
connect(const std::string & host,std::uint16_t port,const SocketOptions & opts)1046 Socket Socket::connect(
1047 const std::string& host,
1048 std::uint16_t port,
1049 const SocketOptions& opts) {
1050 SocketConnectOp op{host, port, opts};
1051
1052 return Socket{op.run()};
1053 }
1054
// Special members are defaulted out of line, in this file where SocketImpl is
// a complete type, as std::unique_ptr<SocketImpl> requires for destruction.
// Moving transfers ownership of the underlying SocketImpl.
Socket::Socket(Socket&& other) noexcept = default;

Socket& Socket::operator=(Socket&& other) noexcept = default;

Socket::~Socket() = default;
1060
accept() const1061 Socket Socket::accept() const {
1062 if (impl_) {
1063 return Socket{impl_->accept()};
1064 }
1065
1066 C10D_THROW_ERROR(SocketError, "The socket is not initialized.");
1067 }
1068
handle() const1069 int Socket::handle() const noexcept {
1070 if (impl_) {
1071 return impl_->handle();
1072 }
1073 return SocketImpl::invalid_socket;
1074 }
1075
port() const1076 std::uint16_t Socket::port() const {
1077 if (impl_) {
1078 return impl_->getPort();
1079 }
1080 return 0;
1081 }
1082
Socket(std::unique_ptr<SocketImpl> && impl)1083 Socket::Socket(std::unique_ptr<SocketImpl>&& impl) noexcept
1084 : impl_{std::move(impl)} {}
1085
waitForInput(std::chrono::milliseconds timeout)1086 bool Socket::waitForInput(std::chrono::milliseconds timeout) {
1087 return impl_->waitForInput(timeout);
1088 }
1089
repr() const1090 std::string Socket::repr() const {
1091 if (impl_) {
1092 return fmt::format("{}", *impl_);
1093 }
1094 return "Socket(no-impl)";
1095 }
1096
1097 } // namespace c10d::detail
1098