1 #pragma once 2 3 #include <torch/csrc/Export.h> 4 #include <torch/csrc/autograd/function.h> 5 #include <torch/csrc/autograd/variable.h> 6 7 #include <ATen/ATen.h> 8 #include <c10/cuda/CUDAStream.h> 9 #include <optional> 10 11 #include <cstddef> 12 #include <vector> 13 14 namespace torch::autograd { 15 16 struct TORCH_CUDA_CU_API Scatter : public Node { 17 explicit Scatter( 18 std::vector<at::Device> devices, 19 std::optional<std::vector<int64_t>> chunk_sizes = std::nullopt, 20 int64_t dim = 0, 21 std::optional<std::vector<std::optional<at::cuda::CUDAStream>>> streams = 22 std::nullopt, 23 bool unsqueeze_scalars = false); 24 ~Scatter() override; 25 26 variable_list apply(variable_list&& inputs) override; 27 28 std::vector<at::Device> devices_; 29 std::optional<std::vector<int64_t>> chunk_sizes_; 30 int64_t dim_; 31 std::optional<std::vector<std::optional<at::cuda::CUDAStream>>> streams_; 32 bool unsqueeze_scalars_; 33 }; 34 35 struct TORCH_CUDA_CU_API Gather : public Node { 36 explicit Gather(const at::Device& destination_device, int64_t dim = 0); 37 ~Gather() override; 38 39 variable_list apply(variable_list&& inputs) override; 40 41 at::Device destination_device_; 42 int64_t dim_; 43 }; 44 45 } // namespace torch::autograd 46