xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/reducer.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <c10/core/ScalarType.h>
#include <atomic>
#include <memory>
#include <mutex>
#include <tuple>
#include <unordered_map>
#include <vector>

#include <ATen/core/ivalue_inl.h>
#include <c10/macros/Macros.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/csrc/distributed/c10d/comm.hpp>
#include <torch/csrc/distributed/c10d/debug.h>
#include <torch/csrc/distributed/c10d/default_comm_hooks.hpp>
#include <torch/csrc/distributed/c10d/reducer_timer.hpp>
#ifndef _WIN32
#include <torch/csrc/distributed/autograd/context/context.h>
#endif

namespace c10d {

constexpr int kDefaultFirstBucketBytes = int(1024 * 1024);
constexpr int kDefaultBucketBytesCap = int(25 * 1024 * 1024);
// Collect runtime stats once for every kDDPRuntimeLoggingSampleRate iterations.
constexpr int kDDPRuntimeLoggingSampleRate = 100;

// Forward declaration
class Logger;

// Local accumulator type for a single bucket.
struct BucketAccumulator {
  std::vector<size_t> indices;
  size_t size = 0;
  size_t size_limit = 0;
};

class TORCH_API Reducer {
 public:
  // The constructor takes a list of variables (i.e. parameters) for this
  // process's single model replica (as DDP assumes single-process
  // single-device). The bucket assignment for this reducer, `bucket_indices`,
  // is specified as a list of buckets, each of which is specified as a list of
  // indices into the bucket's `variables` list.
  explicit Reducer(
      std::vector<at::Tensor> params,
      std::vector<std::vector<size_t>> bucket_indices,
      const std::vector<size_t>& per_bucket_size_limits,
      c10::intrusive_ptr<c10d::ProcessGroup> process_group,
      std::vector<bool> expect_sparse_gradients,
      int64_t bucket_bytes_cap,
      bool find_unused_parameters,
      bool gradient_as_bucket_view,
      std::unordered_map<size_t, std::string> param_names,
      int64_t first_bucket_bytes_cap);
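
  // Illustrative sketch (not part of this header): the Reducer is normally
  // constructed by DistributedDataParallel's Python-side wrapper; the
  // hypothetical values below only show how the arguments line up with the
  // declaration above.
  //
  //   auto reducer = std::make_unique<c10d::Reducer>(
  //       /*params=*/model_params,                     // flat parameter list
  //       /*bucket_indices=*/{{0, 1}, {2}},            // param indices per bucket
  //       /*per_bucket_size_limits=*/
  //       {kDefaultFirstBucketBytes, kDefaultBucketBytesCap},
  //       /*process_group=*/pg,
  //       /*expect_sparse_gradients=*/{false, false, false},
  //       /*bucket_bytes_cap=*/kDefaultBucketBytesCap,
  //       /*find_unused_parameters=*/false,
  //       /*gradient_as_bucket_view=*/false,
  //       /*param_names=*/{},
  //       /*first_bucket_bytes_cap=*/kDefaultFirstBucketBytes);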

  ~Reducer() noexcept(false);

  // To (re-)initialize bucket assignment, pass a list of buckets, each of
  // which is specified by a list of indices in the bucket's `variables` list.
  // This function performs validation that the variables within a bucket
  // all live on the same device and have the same dimensionality.
  void initialize_buckets(std::vector<std::vector<size_t>> bucket_indices);

  void autograd_hook(size_t index);

  // This function is called when the forward function has produced an output,
  // and the user wishes to reduce gradients in the backwards pass.
  // If they don't, and wish to accumulate gradients before reducing them,
  // a call to this function can simply be omitted.
  void prepare_for_backward(const std::vector<at::Tensor>& outputs);
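
  // Illustrative sketch (not part of this header) of the per-iteration call
  // sequence as driven by DistributedDataParallel; `reducer`, `model`,
  // `input`, and `loss_fn` are hypothetical.
  //
  //   reducer.prepare_for_forward();
  //   auto output = model.forward(input);
  //   reducer.prepare_for_backward({output});  // omit to accumulate grads
  //                                            // without reducing them
  //   loss_fn(output).backward();              // autograd_hook() fires as
  //                                            // each gradient becomes ready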

  // Called at the beginning of forward() inside DistributedDataParallel.
  // Currently it captures the start time of the forward pass in each
  // iteration.
  void prepare_for_forward();

  // Returns the relative time in nanoseconds when gradients were ready,
  // with respect to the time `prepare_for_backward` was called. The
  // vector is for parameters for a single model replica.
  std::vector<int64_t> get_backward_stats() const {
    return backward_stats_;
  }

  // Registers a hook with the reducer. The hook is of `CommHookInterface`
  // type to allow both Python and C++ hooks. This function can only
  // be called once, before calling backward.
  // Cannot be combined with a call to `register_builtin_comm_hook`.
  void register_comm_hook(std::unique_ptr<CommHookInterface> iface);

  // Registers a built-in C++ comm hook with the reducer. This function can
  // only be called once, before calling backward.
  // Cannot be combined with a call to `register_comm_hook`.
  void register_builtin_comm_hook(c10d::BuiltinCommHookType comm_hook_type);
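
  // Illustrative sketch (not part of this header): at most one hook may be
  // registered, before the first backward pass. `reducer` and `pg` are
  // hypothetical; the hook types come from default_comm_hooks.hpp.
  //
  //   reducer.register_builtin_comm_hook(c10d::BuiltinCommHookType::ALLREDUCE);
  //   // ...or, alternatively, a C++ hook object:
  //   // reducer.register_comm_hook(
  //   //     std::make_unique<c10d::AllReduceCommHook>(pg));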

  // Informs the reducer that the optimizer runs in the backward pass, so
  // gradients don't need to be copied from the buckets, as the optimizer will
  // already have been applied.
  void set_optimizer_in_backward() {
    optim_in_backward_ = true;
  }

  // Runs allreduce or the installed communication hook given a GradBucket
  // instance.
  c10::intrusive_ptr<c10::ivalue::Future> run_comm_hook(
      GradBucket& grad_bucket);

  // Runs the default allreduce hook.
  c10::intrusive_ptr<c10::ivalue::Future> run_allreduce_hook(
      GradBucket& grad_bucket);

  // Returns gradient buckets in the sequential order of buckets_. This is the
  // order in which buckets are reduced across processes. If
  // return_zero_tensors=true, returns zero tensors of the same shape instead
  // of the true tensors.
  std::vector<c10d::GradBucket> get_grad_buckets(
      bool return_zero_tensors = true) const;

  // Rebuilds buckets based on rebuilt_params_ and rebuilt_param_indices_,
  // i.e. according to when tensors received their gradients in the backward
  // pass.
  // TODO: this function makes a broadcast communication call that could be
  // overlapped with the next forward() call, so it could be made async. It
  // will be made async when rebuilding buckets for the
  // find_unused_parameters = true case, since buckets may be rebuilt more
  // than once in that case, where subgraphs are trained and the order of
  // parameter indices may change more frequently.
  // For the find_unused_parameters = false case, buckets are only rebuilt
  // once, so the performance cost is negligible. Returns true if the buckets
  // were rebuilt.
  bool rebuild_buckets();
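
  // Illustrative sketch (not part of this header) of how rebuilding is
  // usually sequenced: parameter indices are recorded in gradient-ready order
  // during backward, and the rebuild itself happens before reduction in a
  // later forward pass.
  //
  //   if (should_rebuild_buckets()) {
  //     push_rebuilt_params(index);      // inside autograd_hook()
  //   }
  //   // ... later, at the start of the next iteration:
  //   if (reducer.rebuild_buckets()) {
  //     // buckets were re-initialized in gradient-ready order
  //   }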

  void setSparseMetadata(std::map<std::string, at::Tensor>& metadata);

  // Install futures that should be awaited at the end of backwards. Currently
  // these are only used by user-defined custom buffer reduction hooks, but
  // can be generalized to any user-originating futures that need to be
  // awaited.
  void install_futures(c10::List<c10::intrusive_ptr<c10::ivalue::Future>> futs);

  // Returns true if we should rebuild buckets, else false. We only rebuild
  // buckets once after the first iteration and never rebuild them if
  // find_unused_parameters_ is set, unless the graph is static.
  inline bool should_rebuild_buckets() const {
    return (static_graph_ || !find_unused_parameters_) && !has_rebuilt_bucket_;
  }

  // Pushes all parameters to be rebuilt.
  void push_rebuilt_params_for_all_indices();

  // Creates and sets ForwardPassWorkHandle given a Work and the
  // corresponding tensor being reduced.
  void set_forward_pass_work_handle(
      c10::intrusive_ptr<c10d::Work> forwardPassWorkHandle,
      bool useStaticWorldSize);

  // Retrieves the on-device tensor used to track locally unused parameters.
  // Entry i is set to 1 if the variable with index i has been used.
  at::Tensor get_local_used_map_on_device() const;

  // A function for users to set the sample rate of collecting runtime stats.
  // The time stats are recorded for the first 10 iterations; after that, they
  // are recorded once every `sample_rate` training iterations.
  void set_ddp_runtime_logging_sample_rate(int sample_rate);

  // Specifies that the training graph is static.
  void set_static_graph();

  // Delays allreduce until the calculation of all gradients is complete.
  void delay_all_reduce();

  void set_mixed_precision_param_dtype(c10::ScalarType dtype);

  // Weak reference to the associated DDP logger. The reference is weak to
  // avoid a reference cycle between the reducer and the logger.
  void set_logger(std::weak_ptr<c10d::Logger> logger);

  // When the graph is not explicitly set as static by the user and has unused
  // parameters, this returns whether the graph has remained static up to the
  // current iteration, i.e. whether the set of unused parameters has not
  // changed.
  bool ddp_graph_static();

  // Removes autograd hooks registered by the Reducer on the model parameters.
  void remove_autograd_hooks();

  // Checks whether or not the reducer has finalized the current backward
  // iteration.
  void check_finalized();

  // Updates the underlying process group used by DDP with the new process
  // group.
  void update_process_group(
      c10::intrusive_ptr<c10d::ProcessGroup> new_process_group);

  // Resets reducer state.
  void reset_state();

 protected:
  // Forward declaration.
  struct Bucket;

  void push_rebuilt_params(const size_t& index);

  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  mutable std::mutex mutex_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const std::vector<at::Tensor> params_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  c10::intrusive_ptr<::c10d::ProcessGroup> process_group_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<bool> expect_sparse_gradients_;

  std::vector<std::shared_ptr<torch::autograd::Node>>
      grad_accumulators_; // NOLINT(cppcoreguidelines-non-private-member-variables-in-classes)
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::unordered_map<torch::autograd::Node*, size_t> gradAccToVariableMap_;
  std::vector<std::pair<uintptr_t, std::shared_ptr<torch::autograd::Node>>>
      hooks_; // NOLINT(cppcoreguidelines-non-private-member-variables-in-classes)

  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  bool expect_autograd_hooks_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  bool require_finalize_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  size_t next_bucket_;

  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  bool has_marked_unused_parameters_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const bool find_unused_parameters_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const bool gradient_as_bucket_view_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<size_t> unused_parameters_;
  // The previous iteration's unused params, used for checking whether the set
  // of unused parameters changes between iterations. Only filled during the
  // first backward call.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<size_t> prev_iteration_unused_parameters_;
  // Whether the graph is static or not. When the user does not explicitly set
  // static graph, the only possible dynamism is the set of unused parameters
  // changing between iterations, which is tracked by this flag.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  bool ddp_graph_static_{true};
  // Locally used parameter maps indicating whether parameters were used
  // locally during the current iteration or no_sync session, if no_sync is on.
  // Each map is a one-dimensional int32 tensor with one entry per parameter.
  // These tensors are marked in autograd_hook to indicate that the
  // corresponding param has been used, and are allreduced at the end of the
  // backward pass of the current iteration or no_sync session to figure out
  // the globally unused parameters.
  //
  // local_used_map_:     CPU tensor for bookkeeping locally used params
  // local_used_map_dev_: device tensor for reducing globally unused params
  at::Tensor local_used_map_;
  at::Tensor local_used_map_dev_;
  // Indicates that the reduction and the D2H copy are both done.
  bool local_used_map_reduced_;
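
  // Illustrative sketch (not part of this header) of how these maps are
  // typically used, per the description above: autograd_hook() marks the CPU
  // map, and at the end of backward the device copy is allreduced so that
  // every rank learns which parameters were unused globally.
  //
  //   local_used_map_[variable_index] = 1;  // in autograd_hook()
  //   // ... at the end of backward:
  //   local_used_map_dev_.copy_(local_used_map_, /*non_blocking=*/true);
  //   std::vector<at::Tensor> temp = {local_used_map_dev_};
  //   local_used_work_ = process_group_->allreduce(temp);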

  // Weak pointer to the associated DDP logger.
  std::weak_ptr<c10d::Logger> logger_;
  // List of futures installed by Reducer::install_futures that should be
  // awaited at the end of the backward pass.
  std::optional<c10::List<c10::intrusive_ptr<c10::ivalue::Future>>>
      installed_futures_{std::nullopt};
  // Mixed precision parameter dtype for bucket type checking.
  std::optional<c10::ScalarType> mixed_precision_param_dtype_{std::nullopt};

  // Work handle for the allreduce on local_used_map_.
  c10::intrusive_ptr<c10d::Work> local_used_work_;

  void mark_variable_ready_dense(size_t variable_index);

  void mark_variable_ready_sparse(size_t variable_index);

  void mark_variable_ready(size_t variable_index);

  void mark_bucket_ready(size_t bucket_index);

  void finalize_bucket_dense(Bucket& bucket);

  void finalize_backward();

  // Returns the list of model parameters corresponding to the given bucket.
  // `bucket_index` serves as the cache key; after buckets are rebuilt, this
  // mapping never changes.
  std::vector<at::Tensor> get_variables_for_bucket(
      size_t bucket_index,
      const Bucket& bucket) const;

  // Asserts that the reduction for the previous iteration has finished before
  // rebuilding buckets or kicking off the next one.
  void ensure_prior_reduction_finished();

  // Broadcasts rebuilt buckets from rank 0 to the other ranks before
  // initializing the buckets.
  void sync_bucket_indices(std::vector<std::vector<size_t>>& bucket_indices);

  // We'd like to use DistAutogradContext::GradCallback here but dist autograd
  // doesn't exist under Windows. So we just directly use the concrete type,
  // but to preserve and enforce our original intent we do a static assert
  // when dist autograd is available.
  using GradCallback = std::function<bool(at::Tensor&)>;
#ifndef _WIN32
  static_assert(
      std::is_same_v<
          GradCallback,
          torch::distributed::autograd::DistAutogradContext::GradCallback>);
#endif
  void runGradCallbackForVariable(at::Tensor& variable, GradCallback&& cb);

  // This function is called inside `initialize_buckets()`. It initializes
  // both `bucket_views_in` and `bucket_views_out` with views for each
  // variable's gradient into the bucket's flattened `gradients` tensor. Views
  // serve as entry points to `copy_()` each grad's data in/out of the
  // flattened `gradients` tensor.
  void initialize_bucket_views(Bucket& bucket);
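
  // Illustrative sketch (not part of this header) of how a single view might
  // be carved out of the flattened `gradients` tensor for variable `i`,
  // following the comment above; the real implementation lives in
  // reducer.cpp.
  //
  //   const auto offset = bucket.offsets[i];
  //   const auto length = bucket.lengths[i];
  //   bucket.bucket_views_in.push_back(
  //       bucket.gradients.narrow(0, offset, length).view(bucket.sizes_vec[i]));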

  // This function is called inside `finalize_backward`. It runs only if a DDP
  // communication hook was registered, and recreates just `bucket_views_out`
  // with the result of `future_work`.
  void populate_bucket_views_out(Bucket& bucket, at::Tensor& tensor);

  // If gradient_as_bucket_view_ is false, copy the bucket results back to the
  // grads after the buckets have been allreduced.
  void copy_bucket_to_grad(
      at::Tensor& variable,
      Reducer::Bucket& bucket,
      size_t intra_bucket_index,
      bool global_unused);
  // Checks the layout of grad and bucket_view before copying the grad into
  // the bucket.
  void check_grad_layout(const at::Tensor& grad, const at::Tensor& bucket_view);

  // A bucket contains [1..N] gradients to be reduced, where the gradients
  // have the same dtype and device.
  // Coalescing gradients together before reducing can result in lower
  // overhead and/or faster time to completion. Coalescing requires the
  // constituent gradients to have the same dtype and device, and the
  // resulting flattened tensor uses that common dtype and device. The
  // flattened tensor is filled as the corresponding gradients are computed
  // (triggered by autograd hooks), and the buckets are reduced in a
  // predetermined order that is consistent across processes.
  struct Bucket {
    // Gradients of the bucket flattened into a 1-dimensional tensor
    at::Tensor gradients;

    // Views into the `gradients` tensor for each individual gradient
    // Each view is created with layout (size and stride) matching the
    // gradient's expected layout (see the "Gradient Layout Contract" in
    // torch/csrc/autograd/functions/accumulate_grad.h).
    // `bucket_views_in[i].copy_(grad)` and `grad.copy_(bucket_views_out[i])`
    // provide convenient ways to copy gradient data in/out of `gradients`,
    // respectively.
    // We keep both `bucket_views_in` and `bucket_views_out` because
    // registering a DDP communication hook may re-initialize
    // `bucket_views_out` with the value of the hook's `future_work`, but we
    // still need separate views into the bucket's original flattened gradient
    // to copy in gradient data.
    std::vector<at::Tensor> bucket_views_in;
    std::vector<at::Tensor> bucket_views_out;

    // Variables whose gradients are held in this bucket
    // We use refcounted tensors here so that we can easily unflatten the
    // bucket's flattened `gradients` tensor into the participating variables
    // after reduction has completed.
    std::vector<at::Tensor> variables;

    // Per-variable offset/length into the flattened `gradients` tensor and
    // the corresponding `GradBucket` instance for communication hooks
    std::vector<size_t> offsets;
    std::vector<size_t> lengths;

    // Per-variable sizes used for slicing into the bucket's `gradients` tensor
    std::vector<c10::IntArrayRef> sizes_vec;

    // Number of gradients left to be computed before the bucket is ready to
    // be reduced
    size_t pending;

    // Global indices of the participating variables in the bucket
    std::vector<size_t> variable_indices;

    // Future work handle for the DDP communication hook
    // If no hook is registered, a temporary vanilla allreduce hook is used.
    c10::intrusive_ptr<at::ivalue::Future> future_work;

    // Whether this bucket should expect a single sparse gradient
    // If `true`, then this implies that `bucket.variables.size() == 1`.
    bool expect_sparse_gradient = false;

    // Sparse indices tensor
    std::optional<at::Tensor> sparse_tensor_indices = std::nullopt;

    // TODO(@pietern)
    // Memory copies from gradient tensors into the bucket are potentially
    // done on different CUDA streams. We record an event for every copy
    // so that we can synchronize with them prior to kicking off the reduction.
    // std::vector<at::cuda::CUDAEvent> events;
  };
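
  // Illustrative sketch (not part of this header) of the bookkeeping implied
  // by `pending`: each gradient that becomes ready decrements the counter,
  // and the bucket is reduced once the counter reaches zero.
  //
  //   if (--bucket.pending == 0) {
  //     mark_bucket_ready(bucket_index);  // may launch allreduce / comm hook
  //   }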

  std::vector<Bucket> buckets_;

  // A variable locator locates a particular variable in the reducer's buckets
  struct VariableLocator {
    // Index of the bucket containing the variable in the `buckets_` vector
    size_t bucket_index;
    // Index of the variable in the bucket, which may be used consistently
    // across `bucket_views_in`, `bucket_views_out`, `variables`, `offsets`,
    // `lengths`, `sizes_vec`, and `variable_indices` in `Bucket`
    size_t intra_bucket_index;

    VariableLocator() = default;

    VariableLocator(size_t bucket_index_, size_t intra_bucket_index_)
        : bucket_index(bucket_index_),
          intra_bucket_index(intra_bucket_index_) {}
  };

  // Map the index of a variable to its location in the bucket structure.
  std::vector<VariableLocator> variable_locators_;
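
  // Illustrative sketch (not part of this header): resolving a variable index
  // to its bucket and its flattened-gradient view via `variable_locators_`.
  //
  //   const auto& loc = variable_locators_[variable_index];
  //   auto& bucket = buckets_[loc.bucket_index];
  //   auto& grad_view = bucket.bucket_views_in[loc.intra_bucket_index];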

  // Tracks the number of iterations that have synchronized grads so far in
  // training.
  long num_iterations_;
  // Tracks the number of distinct backward calls. This is distinct from
  // num_iterations_, for example in the case of multiple forward passes
  // before a backward pass.
  long num_bwd_calls_;
  // Whether the first autograd hook for a distinct backward pass has been
  // called.
  bool first_autograd_hook_called_;
  // Tracks the number of buckets that have become ready for communication
  // calls like allreduce or communication hooks.
  int num_buckets_ready_;

  // Timing information.
  int64_t backward_compute_start_time_ = -1;
  std::unique_ptr<Timer> timer_;

  // We collect the relative timestamp of every gradient being ready
  // when executing autograd. This can be used to derive a timeline of
  // the points in time at which buckets were ready, or an ideal bucket
  // assignment/ordering.
  std::vector<int64_t> backward_stats_;

  bool should_collect_runtime_stats();
  void record_forward_compute_start_time();
  void record_backward_compute_start_time();
  void record_backward_compute_end_time();
  void record_backward_comm_start_time();
  void record_backward_comm_end_time();

  int get_ddp_runtime_logging_sample_rate();
  int ddp_runtime_logging_sample_rate_ = kDDPRuntimeLoggingSampleRate;

  bool is_multi_device_module_ = false;

  // The following variables help build the dynamic bucket order.
  bool has_rebuilt_bucket_;
  std::vector<at::Tensor> rebuilt_params_;
  std::vector<int64_t> rebuilt_param_indices_;
  const int64_t bucket_bytes_cap_;

#ifndef _WIN32
  struct RpcContext {
    using ContextPtr = torch::distributed::autograd::ContextPtr;
    // The shared_ptr is to hold the context instance.
    ContextPtr context_ptr_holder;
    std::atomic<ContextPtr::element_type*> context_ptr{nullptr};

    void set(ContextPtr&& new_context_ptr);
  };
  RpcContext rpc_context_;
#endif

  // A struct containing the work handle and tensor for the allreduce
  // scheduled in the forward pass, if applicable.
  struct ForwardPassAllreduceWork {
    c10::intrusive_ptr<c10d::Work> workHandle;
    at::Tensor resultTensor;
    // Whether we should divide by the initial world_size or by the number of
    // remaining DDP ranks.
    bool useStaticWorldSize;
  };

  // Handle for the currently scheduled allreduce in the forward pass, if
  // applicable.
  ForwardPassAllreduceWork forwardPassWorkHandle_;

  // Division factor for the reduction of gradients.
  // Equal to the process group size, with the exception of handling uneven
  // inputs.
  int div_factor_;

  bool static_graph_;

  // Key: size_t (index), Value: the number of times that a variable's
  // autograd_hook() should be triggered before marking this variable's grad
  // as ready for communication. The map does not change after the 1st
  // iteration.
  std::unordered_map<size_t, int> numGradHooksTriggeredMap_;
  // Key: size_t (index), Value: the number of times that a variable's
  // autograd_hook() remains to be triggered before marking this variable's
  // grad as ready for communication. The map changes after the 1st iteration
  // to track whether a grad is ready for communication or not.
  std::unordered_map<size_t, int> numGradHooksTriggeredMapPerIteration_;
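
  // Illustrative sketch (not part of this header), based on the comments
  // above: after the first iteration of a static graph, a variable could be
  // treated as ready for communication once its remaining hook count reaches
  // zero.
  //
  //   if (static_graph_after_first_iteration() &&
  //       --numGradHooksTriggeredMapPerIteration_[index] == 0) {
  //     mark_variable_ready(index);
  //   }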

 private:
  // Resets counting for buckets before backward starts.
  void reset_bucket_counting();
  // Searches for unused parameters before backward starts.
  void search_unused_parameters(
      const std::vector<torch::autograd::Variable>& outputs);
  void set_divide_factor();
  // Kicks off the allreduce for a ready bucket.
  void all_reduce_bucket(Bucket& bucket);
  // Kicks off the allreduce of the local used map, which helps find globally
  // unused parameters.
  void all_reduce_local_used_map();
  // Initializes the locally used parameter maps.
  void initialize_local_used_map();
  // Gets the current CUDA stream.
  const c10::Stream get_current_stream();
  bool dynamic_graph_find_unused();
  bool static_graph_first_iteration();
  bool static_graph_after_first_iteration();

  // comm_hook_ is used to access the DDP communication hook if registered.
  std::unique_ptr<CommHookInterface> comm_hook_;

  // Sparse metadata contains the indices that will be used when calling into
  // sparse allreduce. It is only used in the sparse allreduce collective
  // calls.
  std::unique_ptr<std::map<std::string, at::Tensor>> sparse_metadata_;

  // Debug level setting. It is parsed once when the Reducer is constructed
  // and remains the same across a single invocation of DDP training.
  DebugLevel ddp_debug_level_;
  // Mapping of variable index to the fully qualified parameter name, used to
  // notify users about errors when certain parameters do not get a gradient.
  std::unordered_map<size_t, std::string> param_names_;
  // Variable indices stored sequentially in the order in which gradients
  // become ready during the current backward pass.
  std::vector<int64_t> grad_ready_order_indices_;
  // Byte capacity of the first bucket; can be configured by the user.
  int64_t first_bucket_bytes_cap_;
  // Per-iteration set of parameter indices that have been marked ready.
  std::unordered_set<size_t> perIterationReadyParams_;
  // Retrieves parameter names that have not been marked as ready as part of
  // the previous iteration.
  std::vector<std::string> getUnmarkedParamsForIteration();
  // Retrieves parameter indices that have not been marked as ready as part of
  // the previous iteration.
  std::vector<size_t> getUnmarkedParamIndicesForIteration();
  // Raises an appropriate error if mark_variable_ready is called on the same
  // variable twice, which is unexpected.
  void checkAndRaiseMarkedTwiceError(size_t curVariableIndex);
  // Retrieves the parameter corresponding to the given VariableIndex.
  at::Tensor& get_param_from_index(size_t index);

  // Cached bucket index to model parameter mapping. Populated after buckets
  // are rebuilt, after which this mapping is static.
  mutable std::unordered_map<size_t, std::vector<at::Tensor>>
      cached_variables_for_bucket_;

  bool optim_in_backward_{false};
  friend class Logger;
};

// This is equivalent to take_tensors but returns indices into the
// tensor list argument for bucket assignment. Also, it is aware
// of device placement and will not allow buckets to span devices.
// If `tensor_indices` is non-empty, the index assigned to `tensors[i]` is
// `tensor_indices[i]`; otherwise the index assigned to `tensors[i]` is `i`
// itself.
TORCH_API std::tuple<std::vector<std::vector<size_t>>, std::vector<size_t>>
compute_bucket_assignment_by_size(
    const std::vector<at::Tensor>& tensors,
    const std::vector<size_t>& bucket_size,
    const std::vector<bool>& expect_sparse_gradient = {},
    const std::vector<int64_t>& tensor_indices = {},
    const std::optional<std::weak_ptr<c10d::Logger>>& logger = {});
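
// Illustrative sketch (not part of this header): computing bucket assignments
// for a parameter list with the default size limits; `params` is
// hypothetical.
//
//   auto [bucket_indices, per_bucket_sizes] =
//       c10d::compute_bucket_assignment_by_size(
//           params,
//           {c10d::kDefaultFirstBucketBytes, c10d::kDefaultBucketBytesCap});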

// Verifies that the models across all processes are the same as the model on
// rank 0 with respect to the number of parameters and matching
// dtype/size/layout.
TORCH_API void verify_params_across_processes(
    const c10::intrusive_ptr<c10d::ProcessGroup>& process_group,
    const std::vector<at::Tensor>& params,
    const std::optional<std::weak_ptr<c10d::Logger>>& logger);
} // namespace c10d