xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <chrono>
#include <cstdint>
#include <string>
#include <vector>

#include <c10/macros/Macros.h>
#include <c10/util/FbcodeMaps.h>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp>
7 
8 namespace c10d {
9 
10 class TORCH_API StoreCollectives : public ControlCollectives {
11  public:
12   explicit StoreCollectives(
13       c10::intrusive_ptr<Store> store,
14       int rank,
15       int worldSize);
16 
17   void barrier(
18       const std::string& key,
19       std::chrono::milliseconds timeout = 5min,
20       bool block = true) override;
21 
22   void broadcastSend(
23       const std::string& key,
24       const std::vector<uint8_t>& data,
25       std::chrono::milliseconds timeout = 5min) override;
26   std::vector<uint8_t> broadcastRecv(
27       const std::string& key,
28       std::chrono::milliseconds timeout = 5min) override;
29 
30   void gatherSend(
31       const std::string& key,
32       const std::vector<uint8_t>& data,
33       std::chrono::milliseconds timeout = 5min) override;
34   std::vector<std::vector<uint8_t>> gatherRecv(
35       const std::string& key,
36       const std::vector<uint8_t>& data,
37       std::chrono::milliseconds timeout = 5min) override;
38 
39   std::vector<uint8_t> scatterSend(
40       const std::string& key,
41       const std::vector<std::vector<uint8_t>>& data,
42       std::chrono::milliseconds timeout = 5min) override;
43   std::vector<uint8_t> scatterRecv(
44       const std::string& key,
45       std::chrono::milliseconds timeout = 5min) override;
46 
47   std::vector<std::vector<uint8_t>> allGather(
48       const std::string& key,
49       const std::vector<uint8_t>& data,
50       std::chrono::milliseconds timeout = 5min) override;
51 
52   int64_t allSum(
53       const std::string& key,
54       int64_t data,
55       std::chrono::milliseconds timeout = 5min) override;
56 
57  private:
58   void enforceUnique(const std::string& key);
59 
60  private:
61   c10::intrusive_ptr<Store> store_;
62   int rank_;
63   int worldSize_;
64 
65   c10::FastSet<std::string> seenKeys_{};
66 };
67 
68 } // namespace c10d
69