1 #pragma once 2 3 #include <ATen/ATen.h> 4 #include <ATen/cuda/ATenCUDAGeneral.h> 5 #include <ATen/cuda/CUDAContext.h> 6 #include <torch/csrc/Export.h> 7 #include <optional> 8 9 #include <cstddef> 10 #include <vector> 11 12 namespace torch::cuda { 13 14 using tensor_list2d = std::vector<std::vector<at::Tensor>>; 15 16 TORCH_CUDA_CU_API std::vector<at::Tensor>& broadcast_out( 17 const at::Tensor& tensor, 18 std::vector<at::Tensor>& out_tensors); 19 TORCH_CUDA_CU_API std::vector<at::Tensor> broadcast( 20 const at::Tensor& tensor, 21 at::IntArrayRef devices); 22 TORCH_CUDA_CU_API tensor_list2d broadcast_coalesced( 23 at::TensorList tensors, 24 at::IntArrayRef devices, 25 size_t buffer_size); 26 27 TORCH_CUDA_CU_API std::vector<at::Tensor>& scatter_out( 28 const at::Tensor& tensor, 29 std::vector<at::Tensor>& out_tensors, 30 int64_t dim = 0, 31 const std::optional<std::vector<std::optional<at::cuda::CUDAStream>>>& 32 streams = std::nullopt); 33 34 TORCH_CUDA_CU_API std::vector<at::Tensor> scatter( 35 const at::Tensor& tensor, 36 at::IntArrayRef devices, 37 const std::optional<std::vector<int64_t>>& chunk_sizes = std::nullopt, 38 int64_t dim = 0, 39 const std::optional<std::vector<std::optional<at::cuda::CUDAStream>>>& 40 streams = std::nullopt); 41 42 TORCH_CUDA_CU_API at::Tensor& gather_out( 43 at::TensorList tensors, 44 at::Tensor& out_tensor, 45 int64_t dim); 46 47 TORCH_CUDA_CU_API at::Tensor gather( 48 at::TensorList tensors, 49 int64_t dim, 50 std::optional<int32_t> destination_index); 51 52 } // namespace torch::cuda 53