// torch/csrc/distributed/c10d/Utils.hpp
#pragma once

#include <ATen/ATen.h>
#include <c10/util/Exception.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <torch/csrc/distributed/c10d/Types.hpp>

#ifdef _WIN32
#include <winsock2.h>
#include <ws2tcpip.h>
typedef SSIZE_T ssize_t;
#pragma comment(lib, "Ws2_32.lib")
#else
#include <fcntl.h>
#include <netdb.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <unistd.h>
#endif

#include <sys/types.h>

#include <cctype>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <optional>
#include <sstream>
#include <string>
#include <vector>

namespace c10d {

TORCH_API size_t getTensorsNumel(const std::vector<at::Tensor>& tensors);

// Retrieve the shapes of the given tensors.
TORCH_API std::vector<at::Tensor> getTensorShapes(
    const std::vector<at::Tensor>& tensors);

// Use -2 to represent unset state of env vars
#define C10D_ENV_NOT_SET -2

#define WARN_ENV_VAR_ONCE(deprecated_env, new_env)                        \
  TORCH_WARN_ONCE(                                                        \
      "Environment variable " + deprecated_env + " is deprecated; use " + \
      new_env + " instead");

// Turns at::IntArrayRef into "(1, 2, 3, 4)".
inline std::string toString(at::IntArrayRef l) {
  std::stringstream ss;
  ss << "(";
  for (const auto i : c10::irange(l.size())) {
    if (i > 0) {
      ss << ", ";
    }
    ss << l[i];
  }
  ss << ")";
  return ss.str();
}

inline std::string toString(const c10::Layout& layout) {
  std::stringstream ss;
  ss << layout;
  return ss.str();
}

inline void assertSameType(
    const at::DeprecatedTypeProperties& type,
    const std::vector<at::Tensor>& tensors) {
  for (const auto i : c10::irange(tensors.size())) {
    if (!tensors[i].options().type_equal(type.options())) {
      const std::string expected = type.toString();
      const std::string actual = tensors[i].toString();
      throw std::invalid_argument(
          // NOLINTNEXTLINE(performance-inefficient-string-concatenation)
          "mixed types (" + expected + " and " + actual + ")");
    }
  }
}

inline std::vector<std::string> split(
    char separator,
    const std::string& string) {
  std::vector<std::string> pieces;
  std::stringstream ss(string);
  std::string item;
  while (std::getline(ss, item, separator)) {
    pieces.push_back(std::move(item));
  }
  return pieces;
}
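
// Usage sketch (illustrative, not part of the original header): splitting a
// comma-separated list.
//
//   std::vector<std::string> pieces = split(',', "rank0,rank1,rank2");
//   // pieces == {"rank0", "rank1", "rank2"}. A trailing separator does not
//   // append an extra empty piece, but "a,,b" yields an empty middle piece.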

inline std::string getCvarString(
    const std::vector<std::string>& env,
    const char* def) {
  const char* ret = def;

  if (env.empty()) {
    TORCH_CHECK(false, "No environment variables passed");
    return ret;
  }

  /* Parse the environment variables in reverse order, so that earlier
   * entries in `env` take priority over later entries naming the same
   * setting. */
  for (ssize_t i = static_cast<ssize_t>(env.size()) - 1; i >= 0; i--) {
    const char* val = std::getenv(env[i].c_str());
    if (val == nullptr) {
      continue;
    } else if (i) {
      WARN_ENV_VAR_ONCE(env[i], env[0]);
    }

    ret = val;
  }

  return ret;
}

inline int getCvarInt(const std::vector<std::string>& env, int def) {
  int ret = def;

  if (env.empty()) {
    TORCH_CHECK(false, "No environment variables passed");
    return ret;
  }

  /* Parse the environment variables in reverse order, so that earlier
   * entries in `env` take priority over later entries naming the same
   * setting. */
  for (ssize_t i = static_cast<ssize_t>(env.size()) - 1; i >= 0; i--) {
    char* val = std::getenv(env[i].c_str());
    if (val == nullptr) {
      continue;
    } else if (i) {
      WARN_ENV_VAR_ONCE(env[i], env[0]);
    }

    try {
      ret = std::stoi(val);
    } catch (std::exception&) {
      TORCH_CHECK(false, "Invalid value for environment variable: " + env[i]);
    }
  }

  return ret;
}

inline bool getCvarBool(const std::vector<std::string>& env, bool def) {
  bool ret = def;

  if (env.empty()) {
    TORCH_CHECK(false, "No environment variables passed");
    return ret;
  }

  /* Parse the environment variables in reverse order, so that earlier
   * entries in `env` take priority over later entries naming the same
   * setting. */
  for (ssize_t i = static_cast<ssize_t>(env.size()) - 1; i >= 0; i--) {
    char* val_ = std::getenv(env[i].c_str());
    if (val_ == nullptr) {
      continue;
    } else if (i) {
      WARN_ENV_VAR_ONCE(env[i], env[0]);
    }

    std::string val = std::string(val_);
    for (auto& x : val) {
      // NOLINTNEXTLINE(*-narrowing-conversions)
      x = std::tolower(x);
    }

    if (val == "y" || val == "yes" || val == "1" || val == "t" ||
        val == "true") {
      ret = true;
    } else if (
        val == "n" || val == "no" || val == "0" || val == "f" ||
        val == "false") {
      ret = false;
    } else {
      TORCH_CHECK(false, "Invalid value for environment variable: " + env[i]);
      return ret;
    }
  }

  return ret;
}
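
// Usage sketch for the getCvar* helpers (illustrative; the env var names and
// defaults below are hypothetical): env[0] is the current name and the
// remaining entries are deprecated aliases. If only an alias is set, its
// value is used and a one-time deprecation warning pointing at env[0] is
// emitted; if both are set, the value of env[0] wins.
//
//   std::string backend =
//       getCvarString({"TORCH_EXAMPLE_BACKEND", "OLD_EXAMPLE_BACKEND"}, "nccl");
//   int timeout = getCvarInt({"TORCH_EXAMPLE_TIMEOUT"}, 1800);
//   bool verbose = getCvarBool({"TORCH_EXAMPLE_VERBOSE"}, false);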

inline void assertSameSizes(
    const at::IntArrayRef& sizes,
    const std::vector<at::Tensor>& tensors) {
  for (const auto i : c10::irange(tensors.size())) {
    if (!tensors[i].sizes().equals(sizes)) {
      const auto expected = toString(sizes);
      const auto actual = toString(tensors[i].sizes());
      throw std::invalid_argument(
          // NOLINTNEXTLINE(performance-inefficient-string-concatenation)
          "mixed sizes (" + expected + " and " + actual + ")");
    }
  }
}

inline void assertSameSizeAndType(const std::vector<at::Tensor>& tensors) {
  // Ensure we have at least one tensor
  if (tensors.empty()) {
    throw std::invalid_argument("argument is empty");
  }

  // Ensure all tensors have identical type and shape
  auto options = tensors[0].options();
  auto sizes = tensors[0].sizes();
  for (const auto i : c10::irange(1, tensors.size())) {
    if (!tensors[i].options().type_equal(options)) {
      const auto expected = toString(options);
      const auto actual = toString(tensors[i].options());
      throw std::invalid_argument(
          // NOLINTNEXTLINE(performance-inefficient-string-concatenation)
          "argument contains mixed types (" + expected + " and " + actual +
          ")");
    }
    if (!tensors[i].sizes().equals(sizes)) {
      const auto expected = toString(sizes);
      const auto actual = toString(tensors[i].sizes());
      throw std::invalid_argument(
          // NOLINTNEXTLINE(performance-inefficient-string-concatenation)
          "argument contains mixed sizes (" + expected + " and " + actual +
          ")");
    }
  }
}

inline void assertTypeMatch(
    const std::function<void(const std::string&)>& fn,
    const at::DeprecatedTypeProperties& type,
    const at::ArrayRef<at::Tensor> tensors,
    size_t index) {
  if (!tensors[index].options().type_equal(type.options())) {
    fn("invalid tensor type at index " + std::to_string(index) + " (expected " +
       type.toString() + ", got " + tensors[index].toString() + ")");
  }
}

inline void assertTypeMatch(
    const std::function<void(const std::string&)>& fn,
    const at::TensorOptions& options,
    const at::ArrayRef<at::Tensor> tensors,
    size_t index) {
  if (!tensors[index].options().type_equal(options)) {
    fn("invalid tensor type at index " + std::to_string(index) + " (expected " +
       toString(options) + ", got " + toString(tensors[index].options()) + ")");
  }
}

inline void assertSizesMatch(
    const std::function<void(const std::string&)>& fn,
    const at::IntArrayRef& sizes,
    const at::ArrayRef<at::Tensor> tensors,
    size_t index) {
  if (tensors[index].sizes() != sizes) {
    fn("invalid tensor size at index " + std::to_string(index) + " (expected " +
       toString(sizes) + ", got " + toString(tensors[index].sizes()) + ")");
  }
}

inline void assertLayoutMatch(
    const std::function<void(const std::string&)>& fn,
    const c10::Layout& expected,
    const at::ArrayRef<at::Tensor> tensors,
    size_t index) {
  const auto& actual = tensors[index].layout();
  if (actual != expected) {
    fn("invalid tensor layout at index " + std::to_string(index) +
       " (expected " + toString(expected) + ", got " + toString(actual) + ")");
  }
}

inline void assertLayoutMatch(
    const std::function<void(const std::string&)>& fn,
    const at::ArrayRef<at::Tensor> tensors) {
  const auto& layout = tensors[0].layout();
  for (const auto i : c10::irange(1, tensors.size())) {
    assertLayoutMatch(fn, layout, tensors, i);
  }
}

inline void assertNonEmpty(
    const std::function<void(const std::string&)>& fn,
    const at::ArrayRef<at::Tensor> tensors) {
  if (tensors.empty()) {
    fn("requires non-empty tensor list");
  }
}

inline void assertSingleElement(
    const std::function<void(const std::string&)>& fn,
    const at::ArrayRef<at::Tensor> tensors) {
  if (tensors.size() != 1) {
    fn("requires a single-element tensor list");
  }
}

inline void assertSingleElementInput(
    const std::function<void(const std::string&)>& fn,
    const at::ArrayRef<at::Tensor> tensors) {
  if (tensors.size() != 1) {
    fn("requires a single-element input tensor list");
  }
}

inline void assertSingleElementOutput(
    const std::function<void(const std::string&)>& fn,
    const at::ArrayRef<at::Tensor> tensors) {
  if (tensors.size() != 1) {
    fn("requires a single-element output tensor list");
  }
}

inline void assertRootRank(
    const std::function<void(const std::string&)>& fn,
    int64_t rank,
    int64_t size) {
  if (rank < 0 || rank >= size) {
    fn("invalid root rank: " + std::to_string(rank));
  }
}

inline void assertRootTensor(
    const std::function<void(const std::string&)>& fn,
    int64_t rank,
    int64_t size) {
  if (rank < 0 || rank >= size) {
    fn("invalid root tensor: " + std::to_string(rank));
  }
}

inline void assertDense(
    const std::function<void(const std::string&)>& fn,
    const at::ArrayRef<at::Tensor> tensors) {
  const auto& layout = tensors[0].layout();
  if (layout != at::kStrided) {
    fn("only supports dense tensors");
  }
}

inline void assertCPU(
    const std::function<void(const std::string&)>& fn,
    const at::ArrayRef<at::Tensor> tensors) {
  const auto& device = tensors[0].device();
  if (device.type() != at::kCPU) {
    fn("only supports CPU tensors");
  }
}

inline void assertSameDevice(
    const std::function<void(const std::string&)>& fn,
    const at::ArrayRef<at::Tensor> tensors) {
  if (tensors.size() < 2) {
    return;
  }
  const auto& device = tensors[0].device();
  for (const auto i : c10::irange(1, tensors.size())) {
    if (tensors[i].device() != device) {
      fn("tensors should be on the same device");
    }
  }
}

inline void assertTypeAndSizesMatch(
    const std::function<void(const std::string&)>& fn,
    const at::ArrayRef<at::Tensor> tensors,
    const at::DeprecatedTypeProperties& type,
    const at::IntArrayRef& sizes) {
  for (const auto i : c10::irange(tensors.size())) {
    assertTypeMatch(fn, type, tensors, i);
    assertSizesMatch(fn, sizes, tensors, i);
  }
}

inline void assertTypeAndSizesMatch(
    const std::function<void(const std::string&)>& fn,
    const at::ArrayRef<at::Tensor> tensors,
    const at::TensorOptions& options,
    const at::IntArrayRef& sizes) {
  for (const auto i : c10::irange(tensors.size())) {
    assertTypeMatch(fn, options, tensors, i);
    assertSizesMatch(fn, sizes, tensors, i);
  }
}

inline void assertTypeAndSizesMatch(
    const std::function<void(const std::string&)>& fn,
    const at::ArrayRef<at::Tensor> tensors) {
  const auto& options = tensors[0].options();
  const auto sizes = tensors[0].sizes();
  assertTypeAndSizesMatch(fn, tensors.slice(1), options, sizes);
}

// Copied from ATen/core/functional.h.
template <typename F, typename T>
inline auto fmap(T& inputs, const F& fn)
    -> std::vector<decltype(fn(*inputs.begin()))> {
  std::vector<decltype(fn(*inputs.begin()))> r;
  r.reserve(inputs.size());
  for (auto& input : inputs) {
    r.push_back(fn(input));
  }
  return r;
}

// Copied from torch/csrc/utils/tensor_flatten.h.
inline at::Tensor flattenDenseTensors(at::TensorList tensors) {
  static const auto flatten = [](const at::Tensor& t) {
    return t.contiguous().view({-1});
  };
  if (tensors.size() == 1) {
    return flatten(tensors[0]);
  }
  return at::cat(::c10d::fmap(tensors, flatten));
}
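
// Illustrative note (not part of the original header): flattenDenseTensors
// concatenates contiguous 1-D views of its inputs, so two tensors of shape
// (2, 3) flatten into a single 1-D tensor with 12 elements; a single input is
// returned as its contiguous 1-D view, skipping at::cat.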

inline at::Tensor newLikeFlat(
    std::vector<std::vector<at::Tensor>>& tensors,
    size_t deviceIdx) {
  if (tensors.empty() || tensors[0].empty()) {
    TORCH_CHECK(false, "Received an empty list");
  }
  if (deviceIdx >= tensors.size()) {
    TORCH_CHECK(false, "Invalid device index");
  }
  auto& t = tensors[deviceIdx][0];
  auto device = t.device();
  for (const auto i : c10::irange(1, tensors[deviceIdx].size())) {
    if (tensors[deviceIdx][i].device() != device) {
      TORCH_CHECK(false, "Expecting all tensors on the same device");
    }
  }
  at::DeviceGuard gpuGuard(device);
  std::vector<int64_t> sizes{static_cast<int64_t>(tensors[deviceIdx].size())};
  std::vector<int64_t> strides{static_cast<int64_t>(t.numel())};
  sizes.insert(sizes.end(), t.sizes().begin(), t.sizes().end());
  strides.insert(strides.end(), t.strides().begin(), t.strides().end());
  return at::empty_strided(
      sizes, strides, t.options().memory_format(std::nullopt));
}

inline at::Tensor newLikeFlat(std::vector<at::Tensor>& tensors) {
  if (tensors.empty()) {
    TORCH_CHECK(false, "Received an empty list");
  }
  auto& t = tensors[0];
  at::DeviceGuard gpuGuard(t.device());
  std::vector<int64_t> sizes{static_cast<int64_t>(tensors.size())};
  sizes.insert(sizes.end(), t.sizes().begin(), t.sizes().end());
  return at::empty(sizes, t.options());
}
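
// Illustrative note (not part of the original header): both newLikeFlat
// overloads allocate an uninitialized output with one extra leading dimension
// equal to the number of tensors. E.g. for 4 tensors of shape (5, 6),
// newLikeFlat returns an empty tensor of shape (4, 5, 6) with the same
// dtype/device as the first input, suitable as a flat gather/scatter buffer.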

inline std::vector<std::vector<int64_t>> getSizes(
    const std::vector<at::Tensor>& tensors) {
  std::vector<std::vector<int64_t>> sizes(tensors.size());
  for (const auto i : c10::irange(tensors.size())) {
    sizes[i] = tensors[i].sizes().vec();
  }
  return sizes;
}

inline std::vector<int> getDevices(const std::vector<at::Tensor>& tensors) {
  std::vector<int> devices(tensors.size(), -1);
  if (tensors[0].device().is_cuda()) {
    for (const auto i : c10::irange(tensors.size())) {
      // NOLINTNEXTLINE(bugprone-signed-char-misuse)
      devices[i] = tensors[i].storage().device().index();
    }
  }
  return devices;
}

template <typename T>
inline T* getDataPointer(const at::Tensor& tensor) {
  // This method is only used in ProcessGroupGloo for now. Call sites must make
  // sure that the input tensor is contiguous. It is OK if the tensor does not
  // start from the beginning of the storage. For example, it could come from
  // chunk(..., dim=0)[1]. Hence, we need to use data_ptr() instead of
  // tensor.storage().data().
  // NB: not using tensor.data<T>() because tensor is not aware of gloo::TYPE.
  return static_cast<T*>(tensor.data_ptr());
}

template <typename T>
std::vector<T*> getDataPointers(const std::vector<at::Tensor>& tensors) {
  std::vector<T*> ptrs(tensors.size());
  for (const auto i : c10::irange(tensors.size())) {
    ptrs[i] = getDataPointer<T>(tensors[i]);
  }
  return ptrs;
}

// For alltoall split size sanity check
inline void checkSplitSizes(
    const std::vector<int64_t>& split_sizes,
    const at::Tensor& tensor,
    int group_size) {
  if (split_sizes.empty()) {
    TORCH_CHECK(
        tensor.size(0) % group_size == 0,
        "Tensor's dim 0 does not divide equally across group size");
  } else {
    TORCH_CHECK(
        split_sizes.size() == static_cast<size_t>(group_size),
        "Number of tensor splits not equal to group size");
    const auto sum = c10::sum_integers(split_sizes);
    TORCH_CHECK(
        sum == tensor.size(0), "Split sizes don't match total dim 0 size");
  }
}
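
// Illustrative example (values are made up): with group_size = 4 and a tensor
// whose dim 0 is 8, an empty split_sizes list is accepted because 8 % 4 == 0
// (equal splits of 2 rows each), and split_sizes = {3, 1, 2, 2} is accepted
// because it has 4 entries summing to 8. split_sizes = {5, 3} would be
// rejected for having the wrong number of entries.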

// Compute alltoall lengths and offsets, handling multi-dimensional tensors
template <typename T>
size_t computeLengthsAndOffsets(
    const std::vector<int64_t>& split_sizes,
    const at::Tensor& tensor,
    std::vector<T>* lengths,
    std::vector<T>* offsets) {
  size_t group_size = lengths->size();
  bool equal_splits = false;
  size_t dim0_size = tensor.size(0);
  size_t row_size = (dim0_size ? tensor.numel() / dim0_size : 1);
  size_t split_size = 0;
  size_t offset = 0;

  if (split_sizes.empty()) {
    equal_splits = true;
    split_size = tensor.size(0) / group_size;
  }
  for (const auto i : c10::irange(group_size)) {
    size_t length = row_size * (equal_splits ? split_size : split_sizes[i]);
    (*lengths)[i] = length;
    (*offsets)[i] = offset;
    // TODO: see if we should add overflow protection for offset
    offset += length;
  }
  return offset;
}
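
// Worked example (illustrative, with made-up sizes): for a tensor of shape
// (6, 4) split across group_size = 3 with empty split_sizes, row_size is 4
// and split_size is 2, so lengths become {8, 8, 8}, offsets become {0, 8, 16},
// and the function returns 24 (the total number of elements). With
// split_sizes = {1, 2, 3}, lengths would be {4, 8, 12} and offsets {0, 4, 12}.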

template <typename T>
size_t computeLengthsAndOffsets(
    const std::vector<at::Tensor>& tensors,
    std::vector<T>* lengths,
    std::vector<T>* offsets) {
  size_t group_size = lengths->size();
  size_t offset = 0;
  for (const auto i : c10::irange(group_size)) {
    size_t length = tensors[i].numel();
    (*lengths)[i] = length;
    (*offsets)[i] = offset;
    offset += length;
  }
  return offset;
}

using RankType = uint32_t;
using SizeType = uint64_t;

// `errno` is only meaningful when a call fails. E.g., a successful `fork()`
// sets `errno` to `EINVAL` in the child process on some macOS versions
// (https://stackoverflow.com/a/20295079), so `errno` should really only be
// inspected if an error occurred.
//
// `success_cond` is an expression used to check if an error has happened. So
// for `fork()`, we can use `SYSCHECK(pid = fork(), pid != -1)`. The function
// output is stored in the variable `__output` and may be used in
// `success_cond`.
#ifdef _WIN32
#define SYSCHECK(expr, success_cond)                                      \
  while (true) {                                                          \
    auto __output = (expr);                                               \
    auto errno_local = WSAGetLastError();                                 \
    (void)__output;                                                       \
    if (!(success_cond)) {                                                \
      if (errno == EINTR) {                                               \
        continue;                                                         \
      } else if (                                                         \
          errno_local == WSAETIMEDOUT || errno_local == WSAEWOULDBLOCK) { \
        C10_THROW_ERROR(DistNetworkError, "Socket Timeout");              \
      } else {                                                            \
        C10_THROW_ERROR(DistNetworkError, std::strerror(errno_local));    \
      }                                                                   \
    } else {                                                              \
      break;                                                              \
    }                                                                     \
  }
#else
#define SYSCHECK(expr, success_cond)                             \
  while (true) {                                                 \
    auto __output = (expr);                                      \
    (void)__output;                                              \
    if (!(success_cond)) {                                       \
      if (errno == EINTR) {                                      \
        continue;                                                \
      } else if (errno == EAGAIN || errno == EWOULDBLOCK) {      \
        C10_THROW_ERROR(DistNetworkError, "Socket Timeout");     \
      } else {                                                   \
        C10_THROW_ERROR(DistNetworkError, std::strerror(errno)); \
      }                                                          \
    } else {                                                     \
      break;                                                     \
    }                                                            \
  }
#endif

// Most functions indicate an error by returning `-1`. This is a helper macro
// for that common case with `SYSCHECK`. Since SOCKET_ERROR is -1 in MSVC,
// SYSCHECK_ERR_RETURN_NEG1 also covers Windows socket calls.
#define SYSCHECK_ERR_RETURN_NEG1(expr) SYSCHECK(expr, __output != -1)
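
// Usage sketch (illustrative; `fd`, `buf`, and `len` are hypothetical):
//
//   ssize_t bytesRead = 0;
//   SYSCHECK_ERR_RETURN_NEG1(bytesRead = ::read(fd, buf, len))
//
// The call is retried on EINTR, EAGAIN/EWOULDBLOCK is surfaced as a socket
// timeout, and any other failure throws DistNetworkError with strerror(errno).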

namespace tcputil {

// Send and receive
template <typename T>
void sendBytes(
    int socket,
    const T* buffer,
    size_t length,
    bool moreData = false) {
  size_t bytesToSend = sizeof(T) * length;
  if (bytesToSend == 0) {
    return;
  }

  auto currentBytes = reinterpret_cast<const char*>(buffer);

  int flags = 0;

#ifdef MSG_MORE
  if (moreData) { // there is more data to send
    flags |= MSG_MORE;
  }
#endif

// Ignore SIGPIPE as the send() return value is always checked for error
#ifdef MSG_NOSIGNAL
  flags |= MSG_NOSIGNAL;
#endif

  while (bytesToSend > 0) {
    ssize_t bytesSent = 0;
    SYSCHECK_ERR_RETURN_NEG1(
        bytesSent = ::send(socket, currentBytes, bytesToSend, flags))
    if (bytesSent == 0) {
      C10_THROW_ERROR(DistNetworkError, "failed to send, sent 0 bytes");
    }

    bytesToSend -= bytesSent;
    currentBytes += bytesSent;
  }
}

template <typename T>
void recvBytes(int socket, T* buffer, size_t length) {
  size_t bytesToReceive = sizeof(T) * length;
  if (bytesToReceive == 0) {
    return;
  }

  auto currentBytes = reinterpret_cast<char*>(buffer);

  while (bytesToReceive > 0) {
    ssize_t bytesReceived = 0;
    SYSCHECK_ERR_RETURN_NEG1(
        bytesReceived = recv(socket, currentBytes, bytesToReceive, 0))
    if (bytesReceived == 0) {
      C10_THROW_ERROR(DistNetworkError, "failed to recv, got 0 bytes");
    }

    bytesToReceive -= bytesReceived;
    currentBytes += bytesReceived;
  }
}

// send a vector's length and data
template <typename T>
void sendVector(int socket, const std::vector<T>& vec, bool moreData = false) {
  SizeType size = vec.size();
  sendBytes<SizeType>(socket, &size, 1, true);
  sendBytes<T>(socket, vec.data(), size, moreData);
}

// receive a vector as sent in sendVector
template <typename T>
std::vector<T> recvVector(int socket) {
  SizeType valueSize = 0;
  recvBytes<SizeType>(socket, &valueSize, 1);
  std::vector<T> value(valueSize);
  recvBytes<T>(socket, value.data(), value.size());
  return value;
}
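
// Illustrative note (not part of the original header): the wire format is a
// SizeType (uint64_t) element count followed by the raw elements, so both
// ends must agree on T. A sendVector<int64_t>(sock, {1, 2, 3}) on one side is
// matched by recvVector<int64_t>(sock) on the other, which reads the count 3
// and then 3 * sizeof(int64_t) bytes of payload.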

// this is only for convenience when sending rvalues
template <typename T>
void sendValue(int socket, const T& value, bool moreData = false) {
  sendBytes<T>(socket, &value, 1, moreData);
}

template <typename T>
T recvValue(int socket) {
  T value;
  recvBytes<T>(socket, &value, 1);
  return value;
}

// send a string's length and data
inline void sendString(
    int socket,
    const std::string& str,
    bool moreData = false) {
  SizeType size = str.size();
  sendBytes<SizeType>(socket, &size, 1, true);
  sendBytes<char>(socket, str.data(), size, moreData);
}

// receive a string as sent in sendString
inline std::string recvString(int socket) {
  SizeType valueSize = 0;
  recvBytes<SizeType>(socket, &valueSize, 1);
  std::vector<char> value(valueSize);
  recvBytes<char>(socket, value.data(), value.size());
  return std::string(value.data(), value.size());
}

} // namespace tcputil
} // namespace c10d