xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cuh (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/core/Tensor.h>
3 #include <ATen/cuda/Atomic.cuh>
4 #include <ATen/cuda/CUDAContext.h>
5 #include <ATen/TensorUtils.h>
6 
namespace at {
namespace native {

// Host-side entry point for the CUDA embedding backward computation.
// Only the declaration is visible here; the definition lives elsewhere.
//
// NOTE(review): the parameter descriptions below are inferred from names and
// defaults, not from the definition. Confirm them against the implementing .cu file.
//
//   grad               - incoming gradient tensor (presumably w.r.t. the
//                        embedding lookup output).
//   orig_indices       - presumably the original positions corresponding to
//                        sorted_indices.
//   sorted_indices     - lookup indices, presumably pre-sorted by the caller.
//   count              - presumably per-index occurrence counts; TODO confirm
//                        whether it may be an undefined Tensor.
//   num_weights        - number of rows in the weight table (presumably the
//                        leading dimension of the returned gradient).
//   padding_idx        - defaults to -1 (presumably "no padding row").
//   mode_mean          - defaults to false (presumably selects mean reduction
//                        for embedding-bag callers).
//   offset2bag,
//   bag_size,
//   per_sample_weights - default to an undefined Tensor(); presumably used
//                        only by embedding-bag callers.
//
// Returns a Tensor (presumably the gradient w.r.t. the embedding weight).
Tensor embedding_backward_cuda_kernel(
    const Tensor& grad,
    const Tensor& orig_indices,
    const Tensor& sorted_indices,
    const Tensor& count,
    int64_t num_weights,
    int padding_idx = -1,
    bool mode_mean = false,
    const Tensor& offset2bag = Tensor(),
    const Tensor& bag_size = Tensor(),
    const Tensor& per_sample_weights = Tensor());

} // namespace native
} // namespace at
23