1 #pragma once 2 3 #include <c10/macros/Macros.h> 4 #include <c10/util/FbcodeMaps.h> 5 #include <torch/csrc/distributed/c10d/Store.hpp> 6 #include <torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp> 7 8 namespace c10d { 9 10 class TORCH_API StoreCollectives : public ControlCollectives { 11 public: 12 explicit StoreCollectives( 13 c10::intrusive_ptr<Store> store, 14 int rank, 15 int worldSize); 16 17 void barrier( 18 const std::string& key, 19 std::chrono::milliseconds timeout = 5min, 20 bool block = true) override; 21 22 void broadcastSend( 23 const std::string& key, 24 const std::vector<uint8_t>& data, 25 std::chrono::milliseconds timeout = 5min) override; 26 std::vector<uint8_t> broadcastRecv( 27 const std::string& key, 28 std::chrono::milliseconds timeout = 5min) override; 29 30 void gatherSend( 31 const std::string& key, 32 const std::vector<uint8_t>& data, 33 std::chrono::milliseconds timeout = 5min) override; 34 std::vector<std::vector<uint8_t>> gatherRecv( 35 const std::string& key, 36 const std::vector<uint8_t>& data, 37 std::chrono::milliseconds timeout = 5min) override; 38 39 std::vector<uint8_t> scatterSend( 40 const std::string& key, 41 const std::vector<std::vector<uint8_t>>& data, 42 std::chrono::milliseconds timeout = 5min) override; 43 std::vector<uint8_t> scatterRecv( 44 const std::string& key, 45 std::chrono::milliseconds timeout = 5min) override; 46 47 std::vector<std::vector<uint8_t>> allGather( 48 const std::string& key, 49 const std::vector<uint8_t>& data, 50 std::chrono::milliseconds timeout = 5min) override; 51 52 int64_t allSum( 53 const std::string& key, 54 int64_t data, 55 std::chrono::milliseconds timeout = 5min) override; 56 57 private: 58 void enforceUnique(const std::string& key); 59 60 private: 61 c10::intrusive_ptr<Store> store_; 62 int rank_; 63 int worldSize_; 64 65 c10::FastSet<std::string> seenKeys_{}; 66 }; 67 68 } // namespace c10d 69