xref: /aosp_15_r20/external/pytorch/torch/csrc/cuda/comm.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 #include <ATen/cuda/ATenCUDAGeneral.h>
5 #include <ATen/cuda/CUDAContext.h>
6 #include <torch/csrc/Export.h>
7 #include <optional>
8 
9 #include <cstddef>
10 #include <vector>
11 
12 namespace torch::cuda {
13 
// Two-level tensor container: an outer std::vector whose elements are each a
// std::vector<at::Tensor>. Used in this header as the return type of
// broadcast_coalesced().
// NOTE(review): presumably the outer index selects a destination device and
// the inner index selects a tensor -- confirm against the implementation.
14 using tensor_list2d = std::vector<std::vector<at::Tensor>>;
15 
// Out-variant of broadcast(): operates on caller-provided output tensors
// instead of returning a freshly built vector. Declaration only; the
// definition is not in this header.
//
// Parameters:
//   tensor      - source tensor, taken by const reference.
//   out_tensors - caller-owned vector of tensors, taken by mutable reference.
// Returns: a reference to a std::vector<at::Tensor>.
// NOTE(review): by name this looks like it copies `tensor` into every entry
// of `out_tensors` and returns `out_tensors` itself -- confirm in the
// implementation, since neither is visible from the declaration.
16 TORCH_CUDA_CU_API std::vector<at::Tensor>& broadcast_out(
17     const at::Tensor& tensor,
18     std::vector<at::Tensor>& out_tensors);
// Value-returning counterpart of broadcast_out(): takes a source tensor and a
// list of integer device identifiers and returns a new vector of tensors.
// Declaration only; the definition is not in this header.
//
// Parameters:
//   tensor  - source tensor, taken by const reference.
//   devices - at::IntArrayRef (non-owning view of integers).
// Returns: std::vector<at::Tensor> by value.
// NOTE(review): presumably `devices` holds CUDA device indices and the result
// has one tensor per entry of `devices` -- confirm against the implementation.
19 TORCH_CUDA_CU_API std::vector<at::Tensor> broadcast(
20     const at::Tensor& tensor,
21     at::IntArrayRef devices);
// Multi-tensor form of broadcast(): takes a list of tensors rather than a
// single one and returns a tensor_list2d. Declaration only; the definition is
// not in this header.
//
// Parameters:
//   tensors     - at::TensorList (non-owning view of the input tensors).
//   devices     - at::IntArrayRef (non-owning view of integers).
//   buffer_size - size_t with no default, so callers must always supply it.
// Returns: tensor_list2d by value.
// NOTE(review): presumably `buffer_size` caps the size (in bytes?) of the
// intermediate buffers used to coalesce small tensors, and the result is
// indexed [device][tensor] -- TODO confirm units and layout in the
// implementation.
22 TORCH_CUDA_CU_API tensor_list2d broadcast_coalesced(
23     at::TensorList tensors,
24     at::IntArrayRef devices,
25     size_t buffer_size);
26 
// Out-variant of scatter(): operates on caller-provided output tensors.
// Declaration only; the definition is not in this header.
//
// Parameters:
//   tensor      - source tensor, taken by const reference.
//   out_tensors - caller-owned vector of tensors, taken by mutable reference.
//   dim         - dimension argument; defaults to 0.
//   streams     - optional vector whose elements are themselves optional
//                 at::cuda::CUDAStream values; defaults to std::nullopt.
// Returns: a reference to a std::vector<at::Tensor>.
// NOTE(review): presumably `tensor` is split along `dim` with one piece
// written to each entry of `out_tensors`, `streams[i]` (when set) is the
// stream used for the i-th piece, and the returned reference is
// `out_tensors` -- confirm all three against the implementation.
27 TORCH_CUDA_CU_API std::vector<at::Tensor>& scatter_out(
28     const at::Tensor& tensor,
29     std::vector<at::Tensor>& out_tensors,
30     int64_t dim = 0,
31     const std::optional<std::vector<std::optional<at::cuda::CUDAStream>>>&
32         streams = std::nullopt);
33 
// Value-returning counterpart of scatter_out(): takes a source tensor and a
// list of integer device identifiers and returns a new vector of tensors.
// Declaration only; the definition is not in this header.
//
// Parameters:
//   tensor      - source tensor, taken by const reference.
//   devices     - at::IntArrayRef (non-owning view of integers).
//   chunk_sizes - optional vector of int64_t; defaults to std::nullopt.
//   dim         - dimension argument; defaults to 0.
//   streams     - optional vector whose elements are themselves optional
//                 at::cuda::CUDAStream values; defaults to std::nullopt.
// Returns: std::vector<at::Tensor> by value.
// NOTE(review): presumably `tensor` is split along `dim` into one chunk per
// device, with `chunk_sizes` overriding an otherwise even split -- the
// behavior when `chunk_sizes` is absent, and any required relation between
// the lengths of `devices`, `chunk_sizes` and `streams`, are not visible
// here; confirm against the implementation.
34 TORCH_CUDA_CU_API std::vector<at::Tensor> scatter(
35     const at::Tensor& tensor,
36     at::IntArrayRef devices,
37     const std::optional<std::vector<int64_t>>& chunk_sizes = std::nullopt,
38     int64_t dim = 0,
39     const std::optional<std::vector<std::optional<at::cuda::CUDAStream>>>&
40         streams = std::nullopt);
41 
// Out-variant of gather(): operates on a caller-provided output tensor.
// Declaration only; the definition is not in this header.
//
// Parameters:
//   tensors    - at::TensorList (non-owning view of the input tensors).
//   out_tensor - caller-owned tensor, taken by mutable reference.
//   dim        - dimension argument; unlike scatter_out() it has no default.
// Returns: a reference to an at::Tensor.
// NOTE(review): presumably the inputs are concatenated along `dim` into
// `out_tensor`, which is then returned -- confirm against the implementation.
42 TORCH_CUDA_CU_API at::Tensor& gather_out(
43     at::TensorList tensors,
44     at::Tensor& out_tensor,
45     int64_t dim);
46 
// Value-returning counterpart of gather_out(): takes a list of tensors and
// returns a single new tensor. Declaration only; the definition is not in
// this header.
//
// Parameters:
//   tensors           - at::TensorList (non-owning view of the input tensors).
//   dim               - dimension argument; no default.
//   destination_index - std::optional<int32_t>, passed by value; no default,
//                       so callers must pass either a value or std::nullopt.
// Returns: at::Tensor by value.
// NOTE(review): presumably `destination_index` selects the device that holds
// the result; what std::nullopt or a negative index means is not visible
// from this declaration -- confirm against the implementation.
47 TORCH_CUDA_CU_API at::Tensor gather(
48     at::TensorList tensors,
49     int64_t dim,
50     std::optional<int32_t> destination_index);
51 
52 } // namespace torch::cuda
53