xref: /aosp_15_r20/external/pytorch/caffe2/perfkernels/embedding_lookup_idx.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <cstdint>
4 
5 namespace caffe2 {
6 
7 // clang-format off
8 /**
9  * Embedding lookup with reduction.
10  *
11  * `input` of size data_size * block_size
12  * `indices` of size index_size
13  * `offsets` of size output_size + 1 (the pseudocode below reads offsets[i+1] for i = output_size-1)
14  * `weights` nullptr or array of size index_size
15  * `out` of size output_size * block_size
16  *
17  * Behavior is roughly equivalent to pseudocode:
18  *
19  * pos = 0
20  * for (i = 0..output_size-1)
21  *   for (k = 0..block_size-1)
22  *     out[i*block_size + k] = 0
23  *   start_offset = offsets[i]
24  *   end_offset = offsets[i+1]
25  *   length = end_offset - start_offset
26  *   for (j = start_offset..end_offset-1)
27  *     for (k = 0..block_size-1)
28  *       out[i*block_size + k] += input[indices[pos]*block_size + k] *
29  *           (weights ? weights[IS_WEIGHT_POSITIONAL ? j - start_offset : pos] : 1.0)
30  *     pos += 1
31  *   if (normalize_weights && length > 0)
32  *     for (k = 0..block_size-1)
33  *       out[i*block_size + k] /= length
34  *
35  * Note: unlike the older lengths-based EmbeddingLookup API, this variant
36  *       already takes "offsets", matching PyTorch's EmbeddingBag.
37  */
38 // clang-format on
39 template <
40     typename IndexType,  // integral element type of `indices` and `offsets`
41     typename InType,     // element type of `input`
42     typename OutType,    // element type of `out`
43     bool IS_WEIGHT_POSITIONAL = false>  // if true, weights are indexed by position within a bag (j - start_offset); otherwise by flat position `pos` (see pseudocode above)
44 void EmbeddingLookupIdx(
45     const std::int64_t block_size,   // embedding row width; `input` and `out` rows have this many elements
46     const std::int64_t output_size,  // number of output rows (bags)
47     const std::int64_t index_size,   // number of entries in `indices`
48     const std::int64_t data_size,    // number of rows in `input` (input holds data_size * block_size elements)
49     const InType* input,             // embedding table, data_size * block_size
50     const IndexType* indices,        // row indices into `input`, index_size entries
51     const IndexType* offsets,        // bag boundaries: offsets[i]..offsets[i+1] delimits bag i
52     const float* weights, // optional, can be null for non-weighted sum
53     const float* scale_bias, // optional scale & bias params for uint8 input
54     bool normalize_by_lengths,       // if true, divide each output row by its bag length (when length > 0)
55     OutType* out);                   // result buffer, output_size * block_size
56 
57 } // namespace caffe2
58