1 #pragma once
2
3 #include <ATen/ATen.h>
4 #include <c10/util/Exception.h>
5 #include <c10/util/accumulate.h>
6 #include <c10/util/irange.h>
7 #include <torch/csrc/distributed/c10d/Types.hpp>
8
9 #ifdef _WIN32
10 #include <winsock2.h>
11 #include <ws2tcpip.h>
12 typedef SSIZE_T ssize_t;
13 #pragma comment(lib, "Ws2_32.lib")
14 #else
15 #include <fcntl.h>
16 #include <netdb.h>
17 #include <sys/poll.h>
18 #include <sys/socket.h>
19 #include <unistd.h>
20 #endif
21
22 #include <sys/types.h>
23
#include <cctype>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <sstream>
#include <string>
#include <vector>
29
30 namespace c10d {
31
// Total number of elements across all tensors in the list.
TORCH_API size_t getTensorsNumel(const std::vector<at::Tensor>& tensors);

// Retrieve tensor shapes from a given tensor.
TORCH_API std::vector<at::Tensor> getTensorShapes(
    const std::vector<at::Tensor>& tensors);

// Use -2 to represent unset state of env vars
#define C10D_ENV_NOT_SET -2

// Emits a one-time deprecation warning (per call site, via TORCH_WARN_ONCE)
// directing users from `deprecated_env` to `new_env`. Both arguments must be
// std::string (they are concatenated with string literals).
#define WARN_ENV_VAR_ONCE(deprecated_env, new_env) \
  TORCH_WARN_ONCE(                                 \
      "Environment variable " + deprecated_env + " is deprecated; use " + \
      new_env + " instead");
45
46 // Turns at::IntArrayRef into "(1, 2, 3, 4)".
toString(at::IntArrayRef l)47 inline std::string toString(at::IntArrayRef l) {
48 std::stringstream ss;
49 ss << "(";
50 for (const auto i : c10::irange(l.size())) {
51 if (i > 0) {
52 ss << ", ";
53 }
54 ss << l[i];
55 }
56 ss << ")";
57 return ss.str();
58 }
59
toString(const c10::Layout & layout)60 inline std::string toString(const c10::Layout& layout) {
61 std::stringstream ss;
62 ss << layout;
63 return ss.str();
64 }
65
assertSameType(const at::DeprecatedTypeProperties & type,const std::vector<at::Tensor> & tensors)66 inline void assertSameType(
67 const at::DeprecatedTypeProperties& type,
68 const std::vector<at::Tensor>& tensors) {
69 for (const auto i : c10::irange(tensors.size())) {
70 if (!tensors[i].options().type_equal(type.options())) {
71 const std::string expected = type.toString();
72 const std::string actual = tensors[i].toString();
73 throw std::invalid_argument(
74 // NOLINTNEXTLINE(performance-inefficient-string-concatenation)
75 "mixed types (" + expected + " and " + actual + ")");
76 }
77 }
78 }
79
// Splits `string` at every `separator` and returns the pieces. Follows
// std::getline semantics: an empty input yields an empty vector, and a
// trailing separator does not produce a trailing empty piece.
inline std::vector<std::string> split(
    char separator,
    const std::string& string) {
  std::vector<std::string> result;
  std::istringstream stream(string);
  for (std::string piece; std::getline(stream, piece, separator);) {
    result.push_back(std::move(piece));
  }
  return result;
}
91
getCvarString(const std::vector<std::string> & env,const char * def)92 inline std::string getCvarString(
93 const std::vector<std::string>& env,
94 const char* def) {
95 const char* ret = def;
96
97 if (env.empty()) {
98 TORCH_CHECK(false, "No environment variables passed");
99 return ret;
100 }
101
102 /* parse environment variable in reverse order, so the early
103 * versions of a variable get higher priority than the latter
104 * versions of the same variable */
105 for (ssize_t i = static_cast<ssize_t>(env.size()) - 1; i >= 0; i--) {
106 const char* val = std::getenv(env[i].c_str());
107 if (val == nullptr) {
108 continue;
109 } else if (i) {
110 WARN_ENV_VAR_ONCE(env[i], env[0]);
111 }
112
113 ret = val;
114 }
115
116 return ret;
117 }
118
getCvarInt(const std::vector<std::string> & env,int def)119 inline int getCvarInt(const std::vector<std::string>& env, int def) {
120 int ret = def;
121
122 if (env.empty()) {
123 TORCH_CHECK(false, "No environment variables passed");
124 return ret;
125 }
126
127 /* parse environment variable in reverse order, so the early
128 * versions of a variable get higher priority than the latter
129 * versions of the same variable */
130 for (ssize_t i = static_cast<ssize_t>(env.size()) - 1; i >= 0; i--) {
131 char* val = std::getenv(env[i].c_str());
132 if (val == nullptr) {
133 continue;
134 } else if (i) {
135 WARN_ENV_VAR_ONCE(env[i], env[0]);
136 }
137
138 try {
139 ret = std::stoi(val);
140 } catch (std::exception&) {
141 TORCH_CHECK(false, "Invalid value for environment variable: " + env[i]);
142 }
143 }
144
145 return ret;
146 }
147
getCvarBool(const std::vector<std::string> & env,bool def)148 inline bool getCvarBool(const std::vector<std::string>& env, bool def) {
149 bool ret = def;
150
151 if (env.empty()) {
152 TORCH_CHECK(false, "No environment variables passed");
153 return ret;
154 }
155
156 /* parse environment variable in reverse order, so the early
157 * versions of a variable get higher priority than the latter
158 * versions of the same variable */
159 for (ssize_t i = static_cast<ssize_t>(env.size()) - 1; i >= 0; i--) {
160 char* val_ = std::getenv(env[i].c_str());
161 if (val_ == nullptr) {
162 continue;
163 } else if (i) {
164 WARN_ENV_VAR_ONCE(env[i], env[0]);
165 }
166
167 std::string val = std::string(val_);
168 for (auto& x : val) {
169 // NOLINTNEXTLINE(*-narrowing-conversions)
170 x = std::tolower(x);
171 }
172
173 if (val == "y" || val == "yes" || val == "1" || val == "t" ||
174 val == "true") {
175 ret = true;
176 } else if (
177 val == "n" || val == "no" || val == "0" || val == "f" ||
178 val == "false") {
179 ret = false;
180 } else {
181 TORCH_CHECK(false, "Invalid value for environment variable: " + env[i]);
182 return ret;
183 }
184 }
185
186 return ret;
187 }
188
assertSameSizes(const at::IntArrayRef & sizes,const std::vector<at::Tensor> & tensors)189 inline void assertSameSizes(
190 const at::IntArrayRef& sizes,
191 const std::vector<at::Tensor>& tensors) {
192 for (const auto i : c10::irange(tensors.size())) {
193 if (!tensors[i].sizes().equals(sizes)) {
194 const auto expected = toString(sizes);
195 const auto actual = toString(tensors[i].sizes());
196 throw std::invalid_argument(
197 // NOLINTNEXTLINE(performance-inefficient-string-concatenation)
198 "mixed sizes (" + expected + " and " + actual + ")");
199 }
200 }
201 }
202
assertSameSizeAndType(const std::vector<at::Tensor> & tensors)203 inline void assertSameSizeAndType(const std::vector<at::Tensor>& tensors) {
204 // Ensure we have at least one tensor
205 if (tensors.empty()) {
206 throw std::invalid_argument("argument is empty");
207 }
208
209 // Ensure all tensors have identical type and shape
210 auto options = tensors[0].options();
211 auto sizes = tensors[0].sizes();
212 for (const auto i : c10::irange(1, tensors.size())) {
213 if (!tensors[i].options().type_equal(options)) {
214 const auto expected = toString(options);
215 const auto actual = toString(tensors[i].options());
216 throw std::invalid_argument(
217 // NOLINTNEXTLINE(performance-inefficient-string-concatenation)
218 "argument contains mixed types (" + expected + " and " + actual +
219 ")");
220 }
221 if (!tensors[i].sizes().equals(sizes)) {
222 const auto expected = toString(sizes);
223 const auto actual = toString(tensors[i].sizes());
224 throw std::invalid_argument(
225 // NOLINTNEXTLINE(performance-inefficient-string-concatenation)
226 "argument contains mixed types (" + expected + " and " + actual +
227 ")");
228 }
229 }
230 }
231
assertTypeMatch(const std::function<void (const std::string &)> & fn,const at::DeprecatedTypeProperties & type,const at::ArrayRef<at::Tensor> tensors,size_t index)232 inline void assertTypeMatch(
233 const std::function<void(const std::string&)>& fn,
234 const at::DeprecatedTypeProperties& type,
235 const at::ArrayRef<at::Tensor> tensors,
236 size_t index) {
237 if (!tensors[index].options().type_equal(type.options())) {
238 fn("invalid tensor type at index " + std::to_string(index) + " (expected " +
239 type.toString() + ", got " + tensors[index].toString() + ")");
240 }
241 }
242
assertTypeMatch(const std::function<void (const std::string &)> & fn,const at::TensorOptions & options,const at::ArrayRef<at::Tensor> tensors,size_t index)243 inline void assertTypeMatch(
244 const std::function<void(const std::string&)>& fn,
245 const at::TensorOptions& options,
246 const at::ArrayRef<at::Tensor> tensors,
247 size_t index) {
248 if (!tensors[index].options().type_equal(options)) {
249 fn("invalid tensor type at index " + std::to_string(index) + " (expected " +
250 toString(options) + ", got " + toString(tensors[index].options()) + ")");
251 }
252 }
253
assertSizesMatch(const std::function<void (const std::string &)> & fn,const at::IntArrayRef & sizes,const at::ArrayRef<at::Tensor> tensors,size_t index)254 inline void assertSizesMatch(
255 const std::function<void(const std::string&)>& fn,
256 const at::IntArrayRef& sizes,
257 const at::ArrayRef<at::Tensor> tensors,
258 size_t index) {
259 if (tensors[index].sizes() != sizes) {
260 fn("invalid tensor size at index " + std::to_string(index) + " (expected " +
261 toString(sizes) + ", got " + toString(tensors[index].sizes()) + ")");
262 }
263 }
264
assertLayoutMatch(const std::function<void (const std::string &)> & fn,const c10::Layout & expected,const at::ArrayRef<at::Tensor> tensors,size_t index)265 inline void assertLayoutMatch(
266 const std::function<void(const std::string&)>& fn,
267 const c10::Layout& expected,
268 const at::ArrayRef<at::Tensor> tensors,
269 size_t index) {
270 const auto& actual = tensors[index].layout();
271 if (actual != expected) {
272 fn("invalid tensor layout at index " + std::to_string(index) +
273 " (expected " + toString(expected) + ", got " + toString(actual) + ")");
274 }
275 }
276
assertLayoutMatch(const std::function<void (const std::string &)> & fn,const at::ArrayRef<at::Tensor> tensors)277 inline void assertLayoutMatch(
278 const std::function<void(const std::string&)>& fn,
279 const at::ArrayRef<at::Tensor> tensors) {
280 const auto& layout = tensors[0].layout();
281 for (const auto i : c10::irange(1, tensors.size())) {
282 assertLayoutMatch(fn, layout, tensors, i);
283 }
284 }
285
assertNonEmpty(const std::function<void (const std::string &)> & fn,const at::ArrayRef<at::Tensor> tensors)286 inline void assertNonEmpty(
287 const std::function<void(const std::string&)>& fn,
288 const at::ArrayRef<at::Tensor> tensors) {
289 if (tensors.empty()) {
290 fn("requires non-empty tensor list");
291 }
292 }
293
assertSingleElement(const std::function<void (const std::string &)> & fn,const at::ArrayRef<at::Tensor> tensors)294 inline void assertSingleElement(
295 const std::function<void(const std::string&)>& fn,
296 const at::ArrayRef<at::Tensor> tensors) {
297 if (tensors.size() != 1) {
298 fn("requires a single-element tensor list");
299 }
300 }
301
assertSingleElementInput(const std::function<void (const std::string &)> & fn,const at::ArrayRef<at::Tensor> tensors)302 inline void assertSingleElementInput(
303 const std::function<void(const std::string&)>& fn,
304 const at::ArrayRef<at::Tensor> tensors) {
305 if (tensors.size() != 1) {
306 fn("requires a single-element input tensor list");
307 }
308 }
309
assertSingleElementOutput(const std::function<void (const std::string &)> & fn,const at::ArrayRef<at::Tensor> tensors)310 inline void assertSingleElementOutput(
311 const std::function<void(const std::string&)>& fn,
312 const at::ArrayRef<at::Tensor> tensors) {
313 if (tensors.size() != 1) {
314 fn("requires a single-element output tensor list");
315 }
316 }
317
// Reports via `fn` unless 0 <= rank < size.
inline void assertRootRank(
    const std::function<void(const std::string&)>& fn,
    int64_t rank,
    int64_t size) {
  const bool inRange = rank >= 0 && rank < size;
  if (!inRange) {
    fn("invalid root rank: " + std::to_string(rank));
  }
}
326
// Reports via `fn` unless the root tensor index satisfies 0 <= rank < size.
inline void assertRootTensor(
    const std::function<void(const std::string&)>& fn,
    int64_t rank,
    int64_t size) {
  const bool inRange = rank >= 0 && rank < size;
  if (!inRange) {
    fn("invalid root tensor: " + std::to_string(rank));
  }
}
335
assertDense(const std::function<void (const std::string &)> & fn,const at::ArrayRef<at::Tensor> tensors)336 inline void assertDense(
337 const std::function<void(const std::string&)>& fn,
338 const at::ArrayRef<at::Tensor> tensors) {
339 const auto& layout = tensors[0].layout();
340 if (layout != at::kStrided) {
341 fn("only supports dense tensors");
342 }
343 }
344
assertCPU(const std::function<void (const std::string &)> & fn,const at::ArrayRef<at::Tensor> tensors)345 inline void assertCPU(
346 const std::function<void(const std::string&)>& fn,
347 const at::ArrayRef<at::Tensor> tensors) {
348 const auto& device = tensors[0].device();
349 if (device.type() != at::kCPU) {
350 fn("only supports CPU tensors");
351 }
352 }
353
assertSameDevice(const std::function<void (const std::string &)> & fn,const at::ArrayRef<at::Tensor> tensors)354 inline void assertSameDevice(
355 const std::function<void(const std::string&)>& fn,
356 const at::ArrayRef<at::Tensor> tensors) {
357 if (tensors.size() < 2) {
358 return;
359 }
360 const auto& device = tensors[0].device();
361 for (const auto i : c10::irange(1, tensors.size())) {
362 if (tensors[i].device() != device) {
363 fn("tensors should be on the same device");
364 }
365 }
366 }
367
assertTypeAndSizesMatch(const std::function<void (const std::string &)> & fn,const at::ArrayRef<at::Tensor> tensors,const at::DeprecatedTypeProperties & type,const at::IntArrayRef & sizes)368 inline void assertTypeAndSizesMatch(
369 const std::function<void(const std::string&)>& fn,
370 const at::ArrayRef<at::Tensor> tensors,
371 const at::DeprecatedTypeProperties& type,
372 const at::IntArrayRef& sizes) {
373 for (const auto i : c10::irange(tensors.size())) {
374 assertTypeMatch(fn, type, tensors, i);
375 assertSizesMatch(fn, sizes, tensors, i);
376 }
377 }
378
assertTypeAndSizesMatch(const std::function<void (const std::string &)> & fn,const at::ArrayRef<at::Tensor> tensors,const at::TensorOptions & options,const at::IntArrayRef & sizes)379 inline void assertTypeAndSizesMatch(
380 const std::function<void(const std::string&)>& fn,
381 const at::ArrayRef<at::Tensor> tensors,
382 const at::TensorOptions& options,
383 const at::IntArrayRef& sizes) {
384 for (const auto i : c10::irange(tensors.size())) {
385 assertTypeMatch(fn, options, tensors, i);
386 assertSizesMatch(fn, sizes, tensors, i);
387 }
388 }
389
assertTypeAndSizesMatch(const std::function<void (const std::string &)> & fn,const at::ArrayRef<at::Tensor> tensors)390 inline void assertTypeAndSizesMatch(
391 const std::function<void(const std::string&)>& fn,
392 const at::ArrayRef<at::Tensor> tensors) {
393 const auto& options = tensors[0].options();
394 const auto sizes = tensors[0].sizes();
395 assertTypeAndSizesMatch(fn, tensors.slice(1), options, sizes);
396 }
397
// Copied from ATen/core/functional.h.
// Applies `fn` to every element of `inputs` and returns the collected
// results in order.
template <typename F, typename T>
inline auto fmap(T& inputs, const F& fn)
    -> std::vector<decltype(fn(*inputs.begin()))> {
  std::vector<decltype(fn(*inputs.begin()))> mapped;
  mapped.reserve(inputs.size());
  for (auto& element : inputs) {
    mapped.push_back(fn(element));
  }
  return mapped;
}
409
410 // Copied from torch/csrc/utils/tensor_flatten.h.
flattenDenseTensors(at::TensorList tensors)411 inline at::Tensor flattenDenseTensors(at::TensorList tensors) {
412 static const auto flatten = [](const at::Tensor& t) {
413 return t.contiguous().view({-1});
414 };
415 if (tensors.size() == 1) {
416 return flatten(tensors[0]);
417 }
418 return at::cat(::c10d::fmap(tensors, flatten));
419 }
420
// Allocates an uninitialized tensor shaped to hold all tensors in
// tensors[deviceIdx] stacked along a new leading dimension, on the same
// device and with the same options as the bucket's first tensor.
inline at::Tensor newLikeFlat(
    std::vector<std::vector<at::Tensor>>& tensors,
    size_t deviceIdx) {
  if (tensors.empty() || tensors[0].empty()) {
    TORCH_CHECK(false, "Received an empty list");
  }
  if (deviceIdx >= tensors.size()) {
    TORCH_CHECK(false, "Invalid device index");
  }
  auto& t = tensors[deviceIdx][0];
  auto device = t.device();
  // All tensors in the selected bucket must share a single device.
  for (const auto i : c10::irange(1, tensors[deviceIdx].size())) {
    if (tensors[deviceIdx][i].device() != device) {
      TORCH_CHECK(false, "Expecting all tensors on the same device");
    }
  }
  // Make sure the allocation below happens on t's device.
  at::DeviceGuard gpuGuard(device);
  // New shape is [bucket_size, *t.sizes()]. The leading stride is t.numel()
  // so each slot is laid out back-to-back; within a slot, t's own strides
  // are preserved.
  std::vector<int64_t> sizes{static_cast<int64_t>(tensors[deviceIdx].size())};
  std::vector<int64_t> strides{static_cast<int64_t>(t.numel())};
  sizes.insert(sizes.end(), t.sizes().begin(), t.sizes().end());
  strides.insert(strides.end(), t.strides().begin(), t.strides().end());
  // memory_format(std::nullopt) clears any memory-format setting from the
  // options — presumably because the explicit strides fully determine the
  // layout here (NOTE(review): confirm against at::empty_strided contract).
  return at::empty_strided(
      sizes, strides, t.options().memory_format(std::nullopt));
}
445
newLikeFlat(std::vector<at::Tensor> & tensors)446 inline at::Tensor newLikeFlat(std::vector<at::Tensor>& tensors) {
447 if (tensors.empty()) {
448 TORCH_CHECK(false, "Received an empty list");
449 }
450 auto& t = tensors[0];
451 at::DeviceGuard gpuGuard(t.device());
452 std::vector<int64_t> sizes{static_cast<int64_t>(tensors.size())};
453 sizes.insert(sizes.end(), t.sizes().begin(), t.sizes().end());
454 return at::empty(sizes, t.options());
455 }
456
getSizes(const std::vector<at::Tensor> & tensors)457 inline std::vector<std::vector<int64_t>> getSizes(
458 const std::vector<at::Tensor>& tensors) {
459 std::vector<std::vector<int64_t>> sizes(tensors.size());
460 for (const auto i : c10::irange(tensors.size())) {
461 sizes[i] = tensors[i].sizes().vec();
462 }
463 return sizes;
464 }
465
getDevices(const std::vector<at::Tensor> & tensors)466 inline std::vector<int> getDevices(const std::vector<at::Tensor>& tensors) {
467 std::vector<int> devices(tensors.size(), -1);
468 if (tensors[0].device().is_cuda()) {
469 for (const auto i : c10::irange(tensors.size())) {
470 // NOLINTNEXTLINE(bugprone-signed-char-misuse)
471 devices[i] = tensors[i].storage().device().index();
472 }
473 }
474 return devices;
475 }
476
477 template <typename T>
getDataPointer(const at::Tensor & tensor)478 inline T* getDataPointer(const at::Tensor& tensor) {
479 // This method is only used in ProcessGroupGloo for now. Call sites must make
480 // sure that the input tensor is contiguous. It is OK if the tensor does not
481 // start from the beginning of the storage. For example, it could come from
482 // chunk(..., dim=0)[1]. Hence, we need to use data_ptr() instead of
483 // tensor.storage().data()
484 // NB: not using tensor.data<T>() because tensor is not aware of gloo::TYPE
485 return static_cast<T*>(tensor.data_ptr());
486 }
487
488 template <typename T>
getDataPointers(const std::vector<at::Tensor> & tensors)489 std::vector<T*> getDataPointers(const std::vector<at::Tensor>& tensors) {
490 std::vector<T*> ptrs(tensors.size());
491 for (const auto i : c10::irange(tensors.size())) {
492 ptrs[i] = getDataPointer<T>(tensors[i]);
493 }
494 return ptrs;
495 }
496
497 // For alltoall split size sanity check
checkSplitSizes(const std::vector<int64_t> & split_sizes,const at::Tensor & tensor,int group_size)498 inline void checkSplitSizes(
499 const std::vector<int64_t>& split_sizes,
500 const at::Tensor& tensor,
501 int group_size) {
502 if (split_sizes.empty()) {
503 TORCH_CHECK(
504 tensor.size(0) % group_size == 0,
505 "Tensor's dim 0 does not divide equally across group size");
506 } else {
507 TORCH_CHECK(
508 split_sizes.size() == static_cast<size_t>(group_size),
509 "Number of tensor splits not equal to group size");
510 const auto sum = c10::sum_integers(split_sizes);
511 TORCH_CHECK(
512 sum == tensor.size(0), "Split sizes doesn't match total dim 0 size");
513 }
514 }
515
516 // Compute alltoall lengths and offsets, handling multi-dimension tensors
517 template <typename T>
computeLengthsAndOffsets(const std::vector<int64_t> & split_sizes,const at::Tensor & tensor,std::vector<T> * lengths,std::vector<T> * offsets)518 size_t computeLengthsAndOffsets(
519 const std::vector<int64_t>& split_sizes,
520 const at::Tensor& tensor,
521 std::vector<T>* lengths,
522 std::vector<T>* offsets) {
523 size_t group_size = lengths->size();
524 bool equal_splits = false;
525 size_t dim0_size = tensor.size(0);
526 size_t row_size = (dim0_size ? tensor.numel() / dim0_size : 1);
527 size_t split_size = 0;
528 size_t offset = 0;
529
530 if (split_sizes.empty()) {
531 equal_splits = true;
532 split_size = tensor.size(0) / group_size;
533 }
534 for (const auto i : c10::irange(group_size)) {
535 size_t length = row_size * (equal_splits ? split_size : split_sizes[i]);
536 (*lengths)[i] = length;
537 (*offsets)[i] = offset;
538 // TODO: see if we should add overflow protection for offset
539 offset += length;
540 }
541 return offset;
542 }
543
544 template <typename T>
computeLengthsAndOffsets(const std::vector<at::Tensor> & tensors,std::vector<T> * lengths,std::vector<T> * offsets)545 size_t computeLengthsAndOffsets(
546 const std::vector<at::Tensor>& tensors,
547 std::vector<T>* lengths,
548 std::vector<T>* offsets) {
549 size_t group_size = lengths->size();
550 size_t offset = 0;
551 for (const auto i : c10::irange(group_size)) {
552 size_t length = tensors[i].numel();
553 (*lengths)[i] = length;
554 (*offsets)[i] = offset;
555 offset += length;
556 }
557 return offset;
558 }
559
// Wire types used below: SizeType is the length-prefix type written by
// sendVector/sendString and read back by recvVector/recvString.
using RankType = uint32_t;
using SizeType = uint64_t;

// `errno` is only meaningful when it fails. E.g., a successful `fork()` sets
// `errno` to `EINVAL` in child process on some macos
// (https://stackoverflow.com/a/20295079), and thus `errno` should really only
// be inspected if an error occurred.
//
// `success_cond` is an expression used to check if an error has happend. So for
// `fork()`, we can use `SYSCHECK(pid = fork(), pid != -1)`. The function output
// is stored in variable `__output` and may be used in `success_cond`.
//
// Retries the expression on EINTR, throws DistNetworkError("Socket Timeout")
// on would-block/timeout conditions, and throws DistNetworkError with the
// system error string otherwise.
#ifdef _WIN32
// NOTE(review): the Windows branch compares plain `errno` (not the captured
// `errno_local` from WSAGetLastError) against EINTR — looks like it should
// test a WSA error code; confirm before relying on interrupt-retry here.
#define SYSCHECK(expr, success_cond) \
  while (true) { \
    auto __output = (expr); \
    auto errno_local = WSAGetLastError(); \
    (void)__output; \
    if (!(success_cond)) { \
      if (errno == EINTR) { \
        continue; \
      } else if ( \
          errno_local == WSAETIMEDOUT || errno_local == WSAEWOULDBLOCK) { \
        C10_THROW_ERROR(DistNetworkError, "Socket Timeout"); \
      } else { \
        C10_THROW_ERROR(DistNetworkError, std::strerror(errno_local)); \
      } \
    } else { \
      break; \
    } \
  }
#else
#define SYSCHECK(expr, success_cond) \
  while (true) { \
    auto __output = (expr); \
    (void)__output; \
    if (!(success_cond)) { \
      if (errno == EINTR) { \
        continue; \
      } else if (errno == EAGAIN || errno == EWOULDBLOCK) { \
        C10_THROW_ERROR(DistNetworkError, "Socket Timeout"); \
      } else { \
        C10_THROW_ERROR(DistNetworkError, std::strerror(errno)); \
      } \
    } else { \
      break; \
    } \
  }
#endif

// Most functions indicate error by returning `-1`. This is a helper macro for
// this common case with `SYSCHECK`.
// Since SOCKET_ERROR = -1 in MSVC, so also leverage SYSCHECK_ERR_RETURN_NEG1
#define SYSCHECK_ERR_RETURN_NEG1(expr) SYSCHECK(expr, __output != -1)
613
614 namespace tcputil {
615
616 // Send and receive
617 template <typename T>
sendBytes(int socket,const T * buffer,size_t length,bool moreData=false)618 void sendBytes(
619 int socket,
620 const T* buffer,
621 size_t length,
622 bool moreData = false) {
623 size_t bytesToSend = sizeof(T) * length;
624 if (bytesToSend == 0) {
625 return;
626 }
627
628 auto currentBytes = reinterpret_cast<const char*>(buffer);
629
630 int flags = 0;
631
632 #ifdef MSG_MORE
633 if (moreData) { // there is more data to send
634 flags |= MSG_MORE;
635 }
636 #endif
637
638 // Ignore SIGPIPE as the send() return value is always checked for error
639 #ifdef MSG_NOSIGNAL
640 flags |= MSG_NOSIGNAL;
641 #endif
642
643 while (bytesToSend > 0) {
644 ssize_t bytesSent = 0;
645 SYSCHECK_ERR_RETURN_NEG1(
646 bytesSent = ::send(socket, currentBytes, bytesToSend, flags))
647 if (bytesSent == 0) {
648 C10_THROW_ERROR(DistNetworkError, "failed to send, sent 0 bytes");
649 }
650
651 bytesToSend -= bytesSent;
652 currentBytes += bytesSent;
653 }
654 }
655
656 template <typename T>
recvBytes(int socket,T * buffer,size_t length)657 void recvBytes(int socket, T* buffer, size_t length) {
658 size_t bytesToReceive = sizeof(T) * length;
659 if (bytesToReceive == 0) {
660 return;
661 }
662
663 auto currentBytes = reinterpret_cast<char*>(buffer);
664
665 while (bytesToReceive > 0) {
666 ssize_t bytesReceived = 0;
667 SYSCHECK_ERR_RETURN_NEG1(
668 bytesReceived = recv(socket, currentBytes, bytesToReceive, 0))
669 if (bytesReceived == 0) {
670 C10_THROW_ERROR(DistNetworkError, "failed to recv, got 0 bytes");
671 }
672
673 bytesToReceive -= bytesReceived;
674 currentBytes += bytesReceived;
675 }
676 }
677
678 // send a vector's length and data
679 template <typename T>
sendVector(int socket,const std::vector<T> & vec,bool moreData=false)680 void sendVector(int socket, const std::vector<T>& vec, bool moreData = false) {
681 SizeType size = vec.size();
682 sendBytes<SizeType>(socket, &size, 1, true);
683 sendBytes<T>(socket, vec.data(), size, moreData);
684 }
685
686 // receive a vector as sent in sendVector
687 template <typename T>
recvVector(int socket)688 std::vector<T> recvVector(int socket) {
689 SizeType valueSize = 0;
690 recvBytes<SizeType>(socket, &valueSize, 1);
691 std::vector<T> value(valueSize);
692 recvBytes<T>(socket, value.data(), value.size());
693 return value;
694 }
695
// this is only for convenience when sending rvalues
// Sends a single T (no length prefix); pair with recvValue<T>.
template <typename T>
void sendValue(int socket, const T& value, bool moreData = false) {
  sendBytes<T>(socket, &value, 1, moreData);
}
701
// Receives a single T as sent by sendValue. `value` is filled byte-wise by
// recvBytes, so T should be trivially copyable.
template <typename T>
T recvValue(int socket) {
  T value;
  recvBytes<T>(socket, &value, 1);
  return value;
}
708
709 // send a string's length and data
sendString(int socket,const std::string & str,bool moreData=false)710 inline void sendString(
711 int socket,
712 const std::string& str,
713 bool moreData = false) {
714 SizeType size = str.size();
715 sendBytes<SizeType>(socket, &size, 1, true);
716 sendBytes<char>(socket, str.data(), size, moreData);
717 }
718
719 // receive a string as sent in sendString
recvString(int socket)720 inline std::string recvString(int socket) {
721 SizeType valueSize = 0;
722 recvBytes<SizeType>(socket, &valueSize, 1);
723 std::vector<char> value(valueSize);
724 recvBytes<char>(socket, value.data(), value.size());
725 return std::string(value.data(), value.size());
726 }
727
728 } // namespace tcputil
729 } // namespace c10d
730