1 #pragma once 2 3 #include <ATen/core/ivalue.h> 4 #include <chrono> 5 #include <cstdint> 6 #include <string> 7 #include <vector> 8 9 #include <c10/macros/Macros.h> 10 #include <torch/custom_class.h> 11 12 namespace c10d { 13 14 using namespace std::chrono_literals; 15 16 class TORCH_API ControlCollectives : public torch::CustomClassHolder { 17 public: 18 virtual void barrier( 19 const std::string& key, 20 std::chrono::milliseconds timeout = 5min, 21 bool block = true) = 0; 22 23 virtual void broadcastSend( 24 const std::string& key, 25 const std::vector<uint8_t>& data, 26 std::chrono::milliseconds timeout = 5min) = 0; 27 virtual std::vector<uint8_t> broadcastRecv( 28 const std::string& key, 29 std::chrono::milliseconds timeout = 5min) = 0; 30 31 virtual void gatherSend( 32 const std::string& key, 33 const std::vector<uint8_t>& data, 34 std::chrono::milliseconds timeout = 5min) = 0; 35 virtual std::vector<std::vector<uint8_t>> gatherRecv( 36 const std::string& key, 37 const std::vector<uint8_t>& data, 38 std::chrono::milliseconds timeout = 5min) = 0; 39 40 virtual std::vector<uint8_t> scatterSend( 41 const std::string& key, 42 const std::vector<std::vector<uint8_t>>& data, 43 std::chrono::milliseconds timeout = 5min) = 0; 44 virtual std::vector<uint8_t> scatterRecv( 45 const std::string& key, 46 std::chrono::milliseconds timeout = 5min) = 0; 47 48 virtual std::vector<std::vector<uint8_t>> allGather( 49 const std::string& key, 50 const std::vector<uint8_t>& data, 51 std::chrono::milliseconds timeout = 5min) = 0; 52 53 virtual int64_t allSum( 54 const std::string& key, 55 int64_t data, 56 std::chrono::milliseconds timeout = 5min) = 0; 57 }; 58 59 } // namespace c10d 60