1 #pragma once 2 #include <ATen/core/Tensor.h> 3 #include <ATen/cuda/Atomic.cuh> 4 #include <ATen/cuda/CUDAContext.h> 5 #include <ATen/TensorUtils.h> 6 7 namespace at { 8 namespace native { 9 10 Tensor embedding_backward_cuda_kernel( 11 const Tensor &grad, 12 const Tensor &orig_indices, 13 const Tensor &sorted_indices, 14 const Tensor &count, 15 int64_t num_weights, 16 int padding_idx = -1, 17 bool mode_mean = false, 18 const Tensor &offset2bag = Tensor(), 19 const Tensor &bag_size = Tensor(), 20 const Tensor &per_sample_weights = Tensor()); 21 22 }} 23