#pragma once

#include <c10/core/ScalarType.h>
#include <atomic>
#include <memory>
#include <mutex>
#include <tuple>
#include <unordered_map>
#include <vector>

#include <ATen/core/ivalue_inl.h>
#include <c10/macros/Macros.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/csrc/distributed/c10d/comm.hpp>
#include <torch/csrc/distributed/c10d/debug.h>
#include <torch/csrc/distributed/c10d/default_comm_hooks.hpp>
#include <torch/csrc/distributed/c10d/reducer_timer.hpp>
#ifndef _WIN32
#include <torch/csrc/distributed/autograd/context/context.h>
#endif

namespace c10d {

constexpr int kDefaultFirstBucketBytes = int(1024 * 1024);
constexpr int kDefaultBucketBytesCap = int(25 * 1024 * 1024);
// Collect runtime stats once for every kDDPRuntimeLoggingSampleRate iterations.
constexpr int kDDPRuntimeLoggingSampleRate = 100;

// Forward declaration
class Logger;

// Local accumulator type for a single bucket.
struct BucketAccumulator {
  std::vector<size_t> indices;
  size_t size = 0;
  size_t size_limit = 0;
};

class TORCH_API Reducer {
 public:
  // The constructor takes a list of variables (i.e. parameters) for this
  // process's single model replica (as DDP assumes single-process
  // single-device). The bucket assignment for this reducer, `bucket_indices`,
  // is specified as a list of buckets, each of which is specified as a list
  // of indices into the bucket's `variables` list.
  explicit Reducer(
      std::vector<at::Tensor> params,
      std::vector<std::vector<size_t>> bucket_indices,
      const std::vector<size_t>& per_bucket_size_limits,
      c10::intrusive_ptr<c10d::ProcessGroup> process_group,
      std::vector<bool> expect_sparse_gradients,
      int64_t bucket_bytes_cap,
      bool find_unused_parameters,
      bool gradient_as_bucket_view,
      std::unordered_map<size_t, std::string> param_names,
      int64_t first_bucket_bytes_cap);

  ~Reducer() noexcept(false);

  // To (re-)initialize bucket assignment, pass a list of buckets, each of
  // which is specified by a list of indices in the bucket's `variables` list.
  // This function performs validation that the variables within a bucket
  // all live on the same device and have the same dimensionality.
  void initialize_buckets(std::vector<std::vector<size_t>> bucket_indices);

  void autograd_hook(size_t index);

  // This function is called when the forward function has produced an output,
  // and the user wishes to reduce gradients in the backwards pass.
  // If they don't, and wish to accumulate gradients before reducing them,
  // a call to this function can simply be omitted.
  void prepare_for_backward(const std::vector<at::Tensor>& outputs);

  // Called at the beginning of forward() inside DistributedDataParallel,
  // right now it captures the starting time of forward in each iteration.
  void prepare_for_forward();

  // Returns the relative time in nanoseconds when gradients were ready,
  // with respect to the time `prepare_for_backward` was called. The
  // vector is for parameters for a single model replica.
  std::vector<int64_t> get_backward_stats() const {
    return backward_stats_;
  }

  // Registers a hook to the reducer.
  // The hook is of type `CommHookInterface` to allow both Python and CPP
  // hooks. This function can only be called once before calling backward.
  // Cannot be combined with a call to `register_builtin_comm_hook`.
  void register_comm_hook(std::unique_ptr<CommHookInterface> iface);

  // Registers a built-in C++ comm hook to the reducer. This function can only
  // be called once before calling backward.
  // Cannot be combined with a call to `register_comm_hook`.
  void register_builtin_comm_hook(c10d::BuiltinCommHookType comm_hook_type);

  // Informs the reducer that the optimizer is running in backward, so
  // gradients don't need to be copied from buckets as the optimizer would
  // have already been applied.
  void set_optimizer_in_backward() {
    optim_in_backward_ = true;
  }

  // Runs allreduce or the installed communication hook given a GradBucket
  // instance.
  c10::intrusive_ptr<c10::ivalue::Future> run_comm_hook(
      GradBucket& grad_bucket);

  // Runs the default allreduce hook.
  c10::intrusive_ptr<c10::ivalue::Future> run_allreduce_hook(
      GradBucket& grad_bucket);

  // Returns gradient buckets in the sequential order of buckets_. This is the
  // order in which buckets are reduced across processes. If
  // return_zero_tensors=true, returns zero tensors of the same shape instead
  // of the true tensors.
  std::vector<c10d::GradBucket> get_grad_buckets(
      bool return_zero_tensors = true) const;

  // Rebuilds buckets based on rebuilt_params_ and rebuilt_param_indices_
  // according to when tensors received grads in the backward pass.
  // TODO this function makes a broadcast communication call and could be
  // overlapped with the next forward() call, thus it could be async.
  // Will make it async when rebuilding buckets for the
  // find_unused_parameters = true case, as buckets could be rebuilt more than
  // once in that case, where subgraphs are trained and the order of parameter
  // indices may change more frequently.
  // For the find_unused_parameters = false case, buckets are only rebuilt
  // once, so the performance cost is negligible. Returns true if the buckets
  // were rebuilt.
  bool rebuild_buckets();

  void setSparseMetadata(std::map<std::string, at::Tensor>& metadata);

  // Installs futures that should be awaited at the end of backwards. Currently
  // these are only used by user-defined custom buffer reduction hooks, but
  // this can be generalized to any user-originating futures that need to be
  // awaited.
  void install_futures(c10::List<c10::intrusive_ptr<c10::ivalue::Future>> futs);

  // Returns true if we should rebuild buckets, else false. We only rebuild
  // buckets once after the first iteration and never rebuild them if
  // find_unused_parameters_.
  inline bool should_rebuild_buckets() const {
    return (static_graph_ || !find_unused_parameters_) && !has_rebuilt_bucket_;
  }

  // Pushes all parameters to be rebuilt.
  void push_rebuilt_params_for_all_indices();

  // Creates and sets ForwardPassWorkHandle given a Work and the
  // corresponding tensor being reduced.
  void set_forward_pass_work_handle(
      c10::intrusive_ptr<c10d::Work> forwardPassWorkHandle,
      bool useStaticWorldSize);

  // Retrieves the on-device tensor used to track locally unused parameters:
  // index i = 1 if the Variable with that index has been used.
  at::Tensor get_local_used_map_on_device() const;

  // A function for users to set the sample rate for collecting runtime stats.
  // The time stats will be recorded for the first 10 iterations; after 10
  // iterations, time stats will be recorded once every "sample_rate" training
  // iterations.
  void set_ddp_runtime_logging_sample_rate(int sample_rate);

  // Specify that the training graph is static.
  void set_static_graph();

  // Delay all reduce to be after all gradients' calculation is complete.
  void delay_all_reduce();

  void set_mixed_precision_param_dtype(c10::ScalarType dtype);

  // Weak reference to associated DDP logger. The reference is weak to avoid
  // a refcycle between the reducer and the logger.
  void set_logger(std::weak_ptr<c10d::Logger> logger);

  // When the graph is not explicitly set by the user as static and has unused
  // parameters, this returns whether the graph has been static up to the
  // current iteration, which means the set of unused params has not changed.
  bool ddp_graph_static();

  // Removes autograd hooks registered by the Reducer on the model parameters.
  void remove_autograd_hooks();

  // Checks whether or not the reducer has finalized the current backward
  // iteration.
  void check_finalized();

  // Updates the underlying process group used by DDP with the new process
  // group.
  void update_process_group(
      c10::intrusive_ptr<c10d::ProcessGroup> new_process_group);

  // Resets reducer state.
  void reset_state();

 protected:
  // Forward declaration.
  struct Bucket;

  void push_rebuilt_params(const size_t& index);

  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  mutable std::mutex mutex_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const std::vector<at::Tensor> params_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  c10::intrusive_ptr<::c10d::ProcessGroup> process_group_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<bool> expect_sparse_gradients_;

  std::vector<std::shared_ptr<torch::autograd::Node>>
      grad_accumulators_; // NOLINT(cppcoreguidelines-non-private-member-variables-in-classes)
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::unordered_map<torch::autograd::Node*, size_t> gradAccToVariableMap_;
  std::vector<std::pair<uintptr_t, std::shared_ptr<torch::autograd::Node>>>
      hooks_; // NOLINT(cppcoreguidelines-non-private-member-variables-in-classes)

  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  bool expect_autograd_hooks_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  bool require_finalize_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  size_t next_bucket_;

  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  bool has_marked_unused_parameters_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const bool find_unused_parameters_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const bool gradient_as_bucket_view_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<size_t> unused_parameters_;
  // Previous iteration's unused params, used for checking if unused parameters
  // change between iterations. Only filled during the first backwards call.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<size_t> prev_iteration_unused_parameters_;
  // Whether the graph is static or not. When the user does not explicitly set
  // a static graph, the only possible dynamism is the set of unused parameters
  // changing between iterations, which is tracked by this flag.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  bool ddp_graph_static_{true};
  // Locally used parameter maps indicating if parameters are used locally
  // during the current iteration or no_sync session if no_sync is on.
  // Each map is a one-dim int32 tensor of the number of parameters. These
  // tensors are marked in autograd_hook to indicate that the corresponding
  // param has been used, and get allreduced at the end of the backward step of
  // the current iteration or no_sync session to figure out the globally unused
  // parameters.
  //
  // local_used_map_:     CPU tensor for bookkeeping locally used params
  // local_used_map_dev_: dev tensor for reducing globally unused params
  at::Tensor local_used_map_;
  at::Tensor local_used_map_dev_;
  // Indicates that the reduction is done and the D2H copy is done as well.
  bool local_used_map_reduced_;

  // Weak pointer to associated DDP logger.
  std::weak_ptr<c10d::Logger> logger_;
  // List of futures installed by Reducer::install_futures that should be
  // awaited at the end of the backwards pass.
  std::optional<c10::List<c10::intrusive_ptr<c10::ivalue::Future>>>
      installed_futures_{std::nullopt};
  // Mixed precision parameter dtype for bucket type checking.
  std::optional<c10::ScalarType> mixed_precision_param_dtype_{std::nullopt};

  // Work handle for allreduce on local_used_map_
  c10::intrusive_ptr<c10d::Work> local_used_work_;

  void mark_variable_ready_dense(size_t variable_index);

  void mark_variable_ready_sparse(size_t variable_index);

  void mark_variable_ready(size_t variable_index);

  void mark_bucket_ready(size_t bucket_index);

  void finalize_bucket_dense(Bucket& bucket);

  void finalize_backward();

  // Returns the list of model parameters corresponding to the given bucket.
  // bucket_index is a key to cache after buckets are rebuilt, after which this
  // mapping never changes.
  std::vector<at::Tensor> get_variables_for_bucket(
      size_t bucket_index,
      const Bucket& bucket) const;

  // Asserts that the reduction for the previous iteration has finished before
  // rebuilding buckets or kicking off the next one.
  void ensure_prior_reduction_finished();

  // Broadcast rebuilt buckets from rank 0 to other ranks before initializing
  // the buckets.
  void sync_bucket_indices(std::vector<std::vector<size_t>>& bucket_indices);

  // We'd like to use DistAutogradContext::GradCallback here, but dist autograd
  // doesn't exist under Windows. So we directly use the concrete type, and to
  // preserve and enforce our original intent we static_assert the equivalence
  // when dist autograd is available.
  using GradCallback = std::function<bool(at::Tensor&)>;
#ifndef _WIN32
  static_assert(
      std::is_same_v<
          GradCallback,
          torch::distributed::autograd::DistAutogradContext::GradCallback>);
#endif
  void runGradCallbackForVariable(at::Tensor& variable, GradCallback&& cb);

  // This function is called inside `initialize_buckets()`. It initializes both
  // `bucket_views_in` and `bucket_views_out` with views for each variable's
  // gradient into the bucket's flattened `gradients` tensor. Views serve as
  // entry points to `copy_()` each grad's data in/out of the flattened
  // `gradients` tensor.
  void initialize_bucket_views(Bucket& bucket);

  // This function is called inside `finalize_backward`. It runs only if a DDP
  // communication hook was registered, to recreate just `bucket_views_out`
  // with the result of `future_work`.
  void populate_bucket_views_out(Bucket& bucket, at::Tensor& tensor);

  // If gradient_as_bucket_view_ is false, copy the bucket results back to the
  // grads after allreducing the buckets.
  void copy_bucket_to_grad(
      at::Tensor& variable,
      Reducer::Bucket& bucket,
      size_t intra_bucket_index,
      bool global_unused);
  // Checks the layout of grad and bucket_view before copying the grad into the
  // bucket.
  void check_grad_layout(const at::Tensor& grad, const at::Tensor& bucket_view);

  // A bucket contains [1..N] gradients to be reduced, where the gradients
  // have the same dtype and device.
  // Coalescing gradients together before reducing can result in lower overhead
  // and/or faster time to completion. Coalescing requires the constituent
  // gradients to have the same dtype and device, and the resulting flattened
  // tensor uses that common dtype and device. The flattened tensor is filled
  // as the corresponding gradients are computed (triggered by autograd hooks),
  // and the buckets are reduced in a predetermined order that is consistent
  // across processes.
  struct Bucket {
    // Gradients of the bucket flattened into a 1-dimensional tensor
    at::Tensor gradients;

    // Views into the `gradients` tensor for each individual gradient
    // Each view is created with layout (size and stride) matching the
    // gradient's expected layout (see the "Gradient Layout Contract" in
    // torch/csrc/autograd/functions/accumulate_grad.h).
    // `bucket_views_in[i].copy_(grad)` and `grad.copy_(bucket_views_out[i])`
    // provide convenient ways to copy gradient data in/out of `gradients`,
    // respectively.
    // We keep both `bucket_views_in` and `bucket_views_out` because
    // registering a DDP communication hook may re-initialize
    // `bucket_views_out` with the value of the hook's `future_work`, but we
    // still need separate views into the bucket's original flattened gradient
    // to copy in gradient data.
    std::vector<at::Tensor> bucket_views_in;
    std::vector<at::Tensor> bucket_views_out;

    // Variables whose gradients are held in this bucket
    // We use refcounted tensors here so that we can easily unflatten the
    // bucket's flattened `gradients` tensor into the participating variables
    // after reduction has completed.
    std::vector<at::Tensor> variables;

    // Per-variable offset/length into the flattened `gradients` tensor and
    // the corresponding `GradBucket` instance for communication hooks
    std::vector<size_t> offsets;
    std::vector<size_t> lengths;

    // Per-variable sizes slicing into the bucket's `gradients` tensor
    std::vector<c10::IntArrayRef> sizes_vec;

    // Number of gradients left to be computed before the bucket is ready to
    // be reduced
    size_t pending;

    // Global indices of participating variables in the bucket
    std::vector<size_t> variable_indices;

    // Future work handle for DDP communication hook
    // If no hook is registered, a temporary vanilla allreduce hook is used.
    c10::intrusive_ptr<at::ivalue::Future> future_work;

    // If this bucket should expect a single sparse gradient
    // If `true`, then this implies that `bucket.variables.size() == 1`.
    bool expect_sparse_gradient = false;

    // Sparse indices tensor
    std::optional<at::Tensor> sparse_tensor_indices = std::nullopt;

    // TODO(@pietern)
    // Memory copies from gradient tensors into the bucket are potentially
    // done on different CUDA streams. We record an event for every copy
    // so that we can synchronize with them prior to kicking off the reduction.
    // std::vector<at::cuda::CUDAEvent> events;
  };

  std::vector<Bucket> buckets_;

  // A variable locator locates a particular variable in the reducer's buckets.
  struct VariableLocator {
    // Index of the bucket containing the variable in the `buckets_` vector
    size_t bucket_index;
    // Index of the variable in the bucket, which may be used consistently
    // across `bucket_views_in`, `bucket_views_out`, `variables`, `offsets`,
    // `lengths`, `sizes_vec`, and `variable_indices` in `Bucket`
    size_t intra_bucket_index;

    VariableLocator() = default;

    VariableLocator(size_t bucket_index_, size_t intra_bucket_index_)
        : bucket_index(bucket_index_),
          intra_bucket_index(intra_bucket_index_) {}
  };

  // Maps the index of a variable to its location in the bucket structure.
  std::vector<VariableLocator> variable_locators_;

  // Tracks the number of iterations in which grads have been synchronized in
  // training so far.
  long num_iterations_;
  // Tracks the number of distinct backward calls. This is distinct from
  // num_iterations_, for example in the case of multiple forwards before a
  // backward.
  long num_bwd_calls_;
  // Whether the first autograd hook for a distinct backward pass has been
  // called.
  bool first_autograd_hook_called_;
  // Tracks the number of buckets that have been ready for communication calls
  // like allreduce or communication hooks.
  int num_buckets_ready_;

  // Timing information.
  int64_t backward_compute_start_time_ = -1;
  std::unique_ptr<Timer> timer_;

  // We collect the relative timestamp of every gradient being ready
  // when executing autograd. This can be used to derive a timeline of
  // the points in time at which buckets were ready, or an ideal bucket
  // assignment/ordering.
  std::vector<int64_t> backward_stats_;

  bool should_collect_runtime_stats();
  void record_forward_compute_start_time();
  void record_backward_compute_start_time();
  void record_backward_compute_end_time();
  void record_backward_comm_start_time();
  void record_backward_comm_end_time();

  int get_ddp_runtime_logging_sample_rate();
  int ddp_runtime_logging_sample_rate_ = kDDPRuntimeLoggingSampleRate;

  bool is_multi_device_module_ = false;

  // The following variables help build the dynamic bucket order.
  bool has_rebuilt_bucket_;
  std::vector<at::Tensor> rebuilt_params_;
  std::vector<int64_t> rebuilt_param_indices_;
  const int64_t bucket_bytes_cap_;

#ifndef _WIN32
  struct RpcContext {
    using ContextPtr = torch::distributed::autograd::ContextPtr;
    // The shared_ptr is to hold the context instance.
    ContextPtr context_ptr_holder;
    std::atomic<ContextPtr::element_type*> context_ptr{nullptr};

    void set(ContextPtr&& new_context_ptr);
  };
  RpcContext rpc_context_;
#endif

  // A struct containing the work handle and tensor for the allreduce scheduled
  // in the forward pass, if applicable.
  struct ForwardPassAllreduceWork {
    c10::intrusive_ptr<c10d::Work> workHandle;
    at::Tensor resultTensor;
    // Whether we should divide by the initial world_size or the number of
    // remaining DDP ranks.
    bool useStaticWorldSize;
  };

  // Handle for the currently scheduled allreduce in the forward pass, if
  // applicable.
  ForwardPassAllreduceWork forwardPassWorkHandle_;

  // Division factor for reduction of gradients.
  // Equal to the process group size, with an exception for handling uneven
  // input.
  int div_factor_;

  bool static_graph_;

  // Key: size_t (index), Value: the number of times that a variable's
  // autograd_hook() should be triggered before marking this variable's grad as
  // ready for communication. The map will not change after the 1st iteration.
  std::unordered_map<size_t, int> numGradHooksTriggeredMap_;
  // Key: size_t (index), Value: the number of times that a variable's
  // autograd_hook() is left to be triggered before marking this variable's
  // grad as ready for communication. The map will change after the 1st
  // iteration to track whether a grad is ready for communication or not.
  std::unordered_map<size_t, int> numGradHooksTriggeredMapPerIteration_;

 private:
  // reset counting for buckets before backward starts
  void reset_bucket_counting();
  // search for unused parameters before backward starts
  void search_unused_parameters(
      const std::vector<torch::autograd::Variable>& outputs);
  void set_divide_factor();
  // kick off allreduce for the ready bucket
  void all_reduce_bucket(Bucket& bucket);
  // kick off allreduce of the local used map, which helps find globally
  // unused parameters
  void all_reduce_local_used_map();
  // initialize locally used parameter maps
  void initialize_local_used_map();
  // get the current CUDA stream
  const c10::Stream get_current_stream();
  bool dynamic_graph_find_unused();
  bool static_graph_first_iteration();
  bool static_graph_after_first_iteration();

  // comm_hook_ is used to access the DDP communication hook if registered.
  std::unique_ptr<CommHookInterface> comm_hook_;

  // Sparse metadata contains the indices that will be used
  // when calling into sparse allreduce.
  // This is only used in the sparse allreduce collective calls.
  std::unique_ptr<std::map<std::string, at::Tensor>> sparse_metadata_;

  // Debug level setting. It is parsed once when the Reducer is constructed and
  // remains the same across a single invocation of DDP training.
  DebugLevel ddp_debug_level_;
  // Mapping of variable index to the parameter's fully qualified name, used to
  // notify users about errors when certain parameters do not get a gradient.
  std::unordered_map<size_t, std::string> param_names_;
  // Variable indices stored sequentially in the order in which the gradient
  // becomes ready for the current backwards pass.
  std::vector<int64_t> grad_ready_order_indices_;
  // Byte capacity of the first bucket; can be configured by the user.
  int64_t first_bucket_bytes_cap_;
  // Per-iteration set of parameter indices that have been marked ready.
  std::unordered_set<size_t> perIterationReadyParams_;
  // Retrieves parameter names that have not been marked as ready as part of
  // the previous iteration.
  std::vector<std::string> getUnmarkedParamsForIteration();
  // Retrieves parameter indices that have not been marked as ready as part of
  // the previous iteration.
  std::vector<size_t> getUnmarkedParamIndicesForIteration();
  // Raises an appropriate error if mark_variable_ready is called on the same
  // variable twice, which is unexpected.
  void checkAndRaiseMarkedTwiceError(size_t curVariableIndex);
  // Retrieves the parameter corresponding to the given VariableIndex.
  at::Tensor& get_param_from_index(size_t index);

  // Cached bucket index to model parameter mapping. Populated after buckets
  // are rebuilt, after which this mapping is static.
  mutable std::unordered_map<size_t, std::vector<at::Tensor>>
      cached_variables_for_bucket_;

  bool optim_in_backward_{false};
  friend class Logger;
};

// This is equivalent to take_tensors but returns indices into the
// tensor list argument for bucket assignment. Also, it is aware
// of device placement and will not allow buckets to span devices.
// The index of tensors[i] assigned to a bucket is tensor_indices[i]; when
// tensor_indices is empty, the index of tensors[i] assigned to a bucket is i.
TORCH_API std::tuple<std::vector<std::vector<size_t>>, std::vector<size_t>>
compute_bucket_assignment_by_size(
    const std::vector<at::Tensor>& tensors,
    const std::vector<size_t>& bucket_size,
    const std::vector<bool>& expect_sparse_gradient = {},
    const std::vector<int64_t>& tensor_indices = {},
    const std::optional<std::weak_ptr<c10d::Logger>>& logger = {});

// Verify that models across all processes are the same as the model on rank 0
// with respect to the number of params and matching dtype/size/layout.
TORCH_API void verify_params_across_processes(
    const c10::intrusive_ptr<c10d::ProcessGroup>& process_group,
    const std::vector<at::Tensor>& params,
    const std::optional<std::weak_ptr<c10d::Logger>>& logger);

} // namespace c10d
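// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the API above): roughly how a
// caller such as DistributedDataParallel could wire up a Reducer. The names
// `params` and `pg` are hypothetical placeholders; the argument values below
// mirror the defaults declared in this header but are not authoritative.
//
//   std::vector<at::Tensor> params = /* module parameters */;
//   c10::intrusive_ptr<c10d::ProcessGroup> pg = /* e.g. a NCCL process group */;
//
//   // Assign parameters to size-capped buckets; the first bucket is kept
//   // small so reduction of early-ready gradients can start sooner.
//   auto [bucket_indices, per_bucket_size_limits] =
//       c10d::compute_bucket_assignment_by_size(
//           params,
//           {c10d::kDefaultFirstBucketBytes, c10d::kDefaultBucketBytesCap});
//
//   auto reducer = std::make_unique<c10d::Reducer>(
//       params,
//       bucket_indices,
//       per_bucket_size_limits,
//       pg,
//       /*expect_sparse_gradients=*/std::vector<bool>(params.size(), false),
//       /*bucket_bytes_cap=*/c10d::kDefaultBucketBytesCap,
//       /*find_unused_parameters=*/false,
//       /*gradient_as_bucket_view=*/false,
//       /*param_names=*/{},
//       /*first_bucket_bytes_cap=*/c10d::kDefaultFirstBucketBytes);
//
//   // Per iteration: call reducer->prepare_for_forward(), run the forward
//   // pass, call reducer->prepare_for_backward(outputs); the autograd hooks
//   // installed by the Reducer then mark gradients ready and kick off bucket
//   // reduction during the backward pass.
// ---------------------------------------------------------------------------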