xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/ivalue.h>
4 #include <chrono>
5 #include <cstdint>
6 #include <string>
7 #include <vector>
8 
9 #include <c10/macros/Macros.h>
10 #include <torch/custom_class.h>
11 
12 namespace c10d {
13 
14 using namespace std::chrono_literals;
15 
16 class TORCH_API ControlCollectives : public torch::CustomClassHolder {
17  public:
18   virtual void barrier(
19       const std::string& key,
20       std::chrono::milliseconds timeout = 5min,
21       bool block = true) = 0;
22 
23   virtual void broadcastSend(
24       const std::string& key,
25       const std::vector<uint8_t>& data,
26       std::chrono::milliseconds timeout = 5min) = 0;
27   virtual std::vector<uint8_t> broadcastRecv(
28       const std::string& key,
29       std::chrono::milliseconds timeout = 5min) = 0;
30 
31   virtual void gatherSend(
32       const std::string& key,
33       const std::vector<uint8_t>& data,
34       std::chrono::milliseconds timeout = 5min) = 0;
35   virtual std::vector<std::vector<uint8_t>> gatherRecv(
36       const std::string& key,
37       const std::vector<uint8_t>& data,
38       std::chrono::milliseconds timeout = 5min) = 0;
39 
40   virtual std::vector<uint8_t> scatterSend(
41       const std::string& key,
42       const std::vector<std::vector<uint8_t>>& data,
43       std::chrono::milliseconds timeout = 5min) = 0;
44   virtual std::vector<uint8_t> scatterRecv(
45       const std::string& key,
46       std::chrono::milliseconds timeout = 5min) = 0;
47 
48   virtual std::vector<std::vector<uint8_t>> allGather(
49       const std::string& key,
50       const std::vector<uint8_t>& data,
51       std::chrono::milliseconds timeout = 5min) = 0;
52 
53   virtual int64_t allSum(
54       const std::string& key,
55       int64_t data,
56       std::chrono::milliseconds timeout = 5min) = 0;
57 };
58 
59 } // namespace c10d
60