xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/comm.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 #include <ATen/core/ivalue.h>
5 #include <torch/csrc/Export.h>
6 #include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
7 #include <utility>
8 
9 namespace c10d {
10 
11 // Broadcast many tensors to all processes in the process group.
12 TORCH_API void broadcast_coalesced(
13     const c10::intrusive_ptr<c10d::ProcessGroup>& process_group,
14     at::TensorList tensors,
15     size_t buffer_size,
16     int rank = 0);
17 
18 // This class passes bucket contents tensor to DDP communication hook.
19 class TORCH_API GradBucket {
20  public:
GradBucket(size_t index,size_t bucket_count,at::Tensor tensor,std::vector<size_t> offsets,std::vector<size_t> lengths,std::vector<c10::IntArrayRef> sizes_vec,std::vector<at::Tensor> parameters,std::optional<at::Tensor> sparse_grad_indices)21   explicit GradBucket(
22       size_t index,
23       size_t bucket_count,
24       at::Tensor tensor,
25       std::vector<size_t> offsets,
26       std::vector<size_t> lengths,
27       std::vector<c10::IntArrayRef> sizes_vec,
28       std::vector<at::Tensor> parameters,
29       std::optional<at::Tensor> sparse_grad_indices)
30       : index_(index),
31         bucket_count_(bucket_count),
32         buffer_(std::move(tensor)),
33         offsets_(std::move(offsets)),
34         lengths_(std::move(lengths)),
35         sizes_vec_(std::move(sizes_vec)),
36         parameters_(std::move(parameters)),
37         sparse_grad_indices_(std::move(sparse_grad_indices)) {}
38 
39   // Returns the index of the bucket, which is unique across all the buckets.
getIndex() const40   size_t getIndex() const {
41     return index_;
42   }
43 
getBuffer() const44   const at::Tensor& getBuffer() const {
45     return buffer_;
46   }
47 
48   // Returns a mutable buffer compared with the above method.
getBufferRef()49   at::Tensor& getBufferRef() {
50     return buffer_;
51   }
52 
53   // Overwrites the buffer at a specific index.
setBuffer(at::Tensor & buffer)54   void setBuffer(at::Tensor& buffer) {
55     buffer_ = buffer;
56   }
57 
58   // Each tensor in the list that getGradients corresponds to a
59   // parameter.
60   std::vector<at::Tensor> getGradients() const;
61 
62   // Returns model parameters belonging to this bucket. They are returned in the
63   // same order as gradient tensors via getGradients(). For example,
64   // getParameters[i] will have its gradient stored in
65   // getGradients[i]
getParameters() const66   const std::vector<at::Tensor> getParameters() const {
67     return parameters_;
68   }
69 
70   // Returns whther this bucket is the last bucket to allreduce in an iteration.
isLast() const71   bool isLast() const {
72     return index_ == bucket_count_ - 1;
73   }
74 
getSparseGradIndices()75   std::optional<at::Tensor>& getSparseGradIndices() {
76     return sparse_grad_indices_;
77   }
78 
79  private:
80   size_t index_;
81   size_t bucket_count_;
82   at::Tensor buffer_;
83 
84   // Per-variable info in buffer_.
85   std::vector<size_t> offsets_;
86   std::vector<size_t> lengths_;
87   std::vector<c10::IntArrayRef> sizes_vec_;
88 
89   // Model parameters for this bucket.
90   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
91   const std::vector<at::Tensor> parameters_;
92 
93   // Predefined sparse indices for this bucket (only used for sparse tensors).
94   // The gradients will be updated to have indices with these tensor values
95   std::optional<at::Tensor> sparse_grad_indices_;
96 };
97 
98 // Base class of both `PythonCommHook` and `CppCommHook`.
99 // Requires implementing 1) `runHook` method that communicates gradients
100 // asynchronously, and 2) `parseHookResult` method that converts the hook
101 // result into a tensor.
102 class TORCH_API CommHookInterface {
103  public:
104   virtual ~CommHookInterface() = default;
105 
106   // Passes the input grad bucket to the registered communication hook.
107   // Once the tensor in the bucket are ready, kicks off the hook asynchronously
108   // and returns a future that holds the communication results.
109   virtual c10::intrusive_ptr<c10::ivalue::Future> runHook(
110       GradBucket& bucket) = 0;
111 
112   // Returns the resulting tensor once the communication hook result is
113   // ready. The resulting tensor will then be copied to the grads of
114   // individual parameters.
115   virtual at::Tensor parseHookResult(const c10::IValue& result) = 0;
116 };
117 
118 namespace detail {
119 // This helper function is called both by CppCommHookInterface below and inside
120 // reducer.
121 TORCH_API at::Tensor parseCppCommHookResult(const c10::IValue& result);
122 } // namespace detail
123 
124 // This CppCommHook interface only requires implementing runHook method that
125 // potentially uses a state.
126 template <typename T>
127 class CppCommHookInterface : public CommHookInterface {
128  public:
CppCommHookInterface(T state)129   explicit CppCommHookInterface(T state) : state_(std::move(state)) {}
130 
131   ~CppCommHookInterface() override = default;
132 
parseHookResult(const c10::IValue & result)133   at::Tensor parseHookResult(const c10::IValue& result) override {
134     return detail::parseCppCommHookResult(result);
135   }
136 
137  protected:
138   T state_;
139 };
140 
141 } // namespace c10d
142