#pragma once

#include <cstdint>

namespace caffe2 {

// clang-format off
/**
 * Embedding lookup with reduction.
 *
 * `input` of size data_size * block_size
 * `indices` of size index_size
 * `offsets` of size output_size + 1 — the pseudocode below reads
 *           offsets[i+1] for i = 0..output_size-1, so one extra "end"
 *           offset past the last segment is required
 * `weights` nullptr or array of size index_size (when IS_WEIGHT_POSITIONAL
 *           it is indexed by position within a segment, j - start_offset,
 *           so presumably it only needs max-segment-length entries — confirm
 *           against callers)
 * `scale_bias` nullptr, or scale & bias parameters (per the inline parameter
 *           comment, used for uint8 quantized `input`)
 * `out` of size output_size * block_size
 *
 * Behavior is roughly equivalent to pseudocode:
 *
 * pos = 0
 * for (i = 0..output_size-1)
 *   for (k = 0..block_size-1)
 *     out[i*block_size + k] = 0
 *   start_offset = offsets[i]
 *   end_offset = offsets[i+1]
 *   length = end_offset - start_offset
 *   for (j = start_offset..end_offset-1)
 *     for (k = 0..block_size-1)
 *       out[i*block_size + k] += input[indices[pos]*block_size + k] *
 *         (weights ? weights[IS_WEIGHT_POSITIONAL ? j - start_offset : pos] : 1.0)
 *     pos += 1
 *   if (normalize_by_lengths && length > 0)
 *     for (k = 0..block_size-1)
 *       out[i*block_size + k] /= length
 *
 * NOTE(review): a stale TODO here asked to make this API take "offsets"
 * rather than "lengths" to match PyTorch's EmbeddingBag — this Idx variant
 * already takes `offsets` (see the parameter list below), so that is done.
 */
// clang-format on
template <
    typename IndexType,
    typename InType,
    typename OutType,
    bool IS_WEIGHT_POSITIONAL = false>
void EmbeddingLookupIdx(
    const std::int64_t block_size,  // columns per embedding row
    const std::int64_t output_size, // number of output rows (segments)
    const std::int64_t index_size,  // total number of gathered indices
    const std::int64_t data_size,   // number of rows in `input`
    const InType* input,
    const IndexType* indices,
    const IndexType* offsets,
    const float* weights, // optional, can be null for non-weighted sum
    const float* scale_bias, // optional scale & bias params for uint8 input
    bool normalize_by_lengths,
    OutType* out);

} // namespace caffe2