1 #include <c10/core/ScalarType.h> 2 #include <c10/util/Exception.h> 3 #include <torch/csrc/distributed/c10d/default_comm_hooks.hpp> 4 5 #include <torch/csrc/distributed/c10d/ProcessGroup.hpp> 6 #include <torch/csrc/distributed/c10d/comm.hpp> 7 #include <torch/torch.h> 8 9 namespace c10d { 10 runHook(GradBucket & bucket)11c10::intrusive_ptr<c10::ivalue::Future> AllReduceCommHook::runHook( 12 GradBucket& bucket) { 13 std::vector<at::Tensor> tensors = {bucket.getBufferRef()}; 14 // Apply the division first to avoid overflow, especially for FP16. 15 tensors[0] /= state_->getSize(); 16 return state_->allreduce(tensors)->getFuture(); 17 } 18 runHook(GradBucket & bucket)19c10::intrusive_ptr<c10::ivalue::Future> FP16CompressCommHook::runHook( 20 GradBucket& bucket) { 21 auto compressed_tensor = bucket.getBufferRef().to(torch::kFloat16); 22 // Apply the division first to avoid overflow. 23 compressed_tensor /= state_->getSize(); 24 std::vector<at::Tensor> tensors = {compressed_tensor}; 25 26 auto allreduce_fut = state_->allreduce(tensors)->getFuture(); 27 auto decompressed_tensor = bucket.getBufferRef(); 28 auto decompress = [decompressed_tensor](c10::ivalue::Future& allreduce_fut) { 29 auto result = allreduce_fut.value(); 30 TORCH_INTERNAL_ASSERT( 31 result.isTensorList(), 32 "ProcessGroup::allreduce should return TensorList"); 33 34 auto reduce_tensor = result.toTensorVector()[0]; 35 TORCH_INTERNAL_ASSERT_DEBUG_ONLY( 36 reduce_tensor.scalar_type() == at::ScalarType::Half, 37 "Expected reduced tensor to be fp16 in FP16CompressHook, but got type ", 38 reduce_tensor.scalar_type()); 39 decompressed_tensor.copy_(reduce_tensor); 40 return c10::IValue(decompressed_tensor); 41 }; 42 43 return allreduce_fut->then(decompress, allreduce_fut->elementType()); 44 } 45 runHook(GradBucket & bucket)46c10::intrusive_ptr<c10::ivalue::Future> _AllReduceBySumCommHook::runHook( 47 GradBucket& bucket) { 48 std::vector<at::Tensor> tensors = {bucket.getBufferRef()}; 49 #ifdef IS_NCCLX 50 // case with sparse_metadata_ set and using indices from there 51 if (bucket.getSparseGradIndices().has_value()) { 52 AllreduceOptions opts = AllreduceOptions(); 53 opts.sparseIndices = bucket.getSparseGradIndices().value(); 54 return state_->allreduce(tensors, opts)->getFuture(); 55 } 56 #else 57 return state_->allreduce(tensors)->getFuture(); 58 #endif 59 } 60 61 } // namespace c10d 62