#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <cstdint>

#ifdef USE_FBGEMM
#include <fbgemm/FbgemmEmbedding.h>
#endif

namespace at::native {

enum class EmbeddingBagMode {
  SUM = 0,
  MEAN = 1,
  MAX = 2,
};

// Convenience comparisons between a raw mode integer and EmbeddingBagMode.
[[maybe_unused]] static bool operator==(int64_t op1, EmbeddingBagMode op2) {
  return op1 == static_cast<int64_t>(op2);
}

[[maybe_unused]] static bool operator!=(int64_t op1, EmbeddingBagMode op2) {
  return !(op1 == op2);
}

// Validates dtypes, shapes, and mode-specific invariants of the inputs
// shared by the embedding_bag entry points below.
void check_arguments(
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& offsets,
    const int64_t mode,
    const std::optional<Tensor>& per_sample_weights,
    bool include_last_offset);

// Fills bag_size_out with the number of indices contributing to each bag.
void make_bag_size_out(
    Tensor& bag_size_out,
    const Tensor& offsets,
    const Tensor& indices,
    const int64_t mode,
    const bool include_last_offset,
    const bool requires_grad);

// Prepares the max_indices output used when mode == EmbeddingBagMode::MAX.
void make_max_indices_out(
    Tensor& max_indices_out,
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& offsets,
    const Tensor& bag_size,
    const int64_t mode,
    bool include_last_offset);

// Builds the offset2bag mapping (index position -> bag id) used by the
// forward and backward kernels.
void make_offset2bag_out(
    Tensor& offset2bag,
    Tensor& output,
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& offsets,
    const int64_t mode,
    const std::optional<Tensor>& per_sample_weights,
    const int64_t padding_idx = -1);

#ifdef USE_FBGEMM

template<bool has_weight, typename TIndex, typename TData>
struct _CallbackAndBlockSize {
  using TCallback = typename fbgemm::EmbeddingSpMDMKernelSignature<TData, TIndex, TIndex, TData>::Type;

  int64_t blockSize = -1;
  TCallback callback = nullptr;

  static TCallback generateCallback(int64_t block_size) {
    return fbgemm::GenerateEmbeddingSpMDM<TData, TIndex, TIndex, TData>(
        block_size,
        has_weight,
        /* normalize_by_lengths */false,
        /* prefetch */16,
        /* is_weight_positional */false,
        /* use_offsets */true);
  }

  _CallbackAndBlockSize() = default;

  explicit _CallbackAndBlockSize(std::optional<int64_t> maybe_block_size)
      : blockSize(maybe_block_size.value_or(-1))
      , callback(maybe_block_size.has_value() ? generateCallback(maybe_block_size.value()) : nullptr)
  {}
};

template<typename... StorageMixins>
struct _EmbeddingBagKernelCacheImpl : private StorageMixins... {

  _EmbeddingBagKernelCacheImpl() = default;
  // use each of the mixins to store the corresponding kernel and block size
  explicit _EmbeddingBagKernelCacheImpl(std::optional<int64_t> maybe_block_size)
      : StorageMixins(maybe_block_size)...
  {}

  // this method is thread safe (call sites may call from different threads)
  template<bool has_weight, typename TIndex, typename TData>
  typename _CallbackAndBlockSize<has_weight, TIndex, TData>::TCallback
  getCallback(int64_t block_size) const {
    // if the cache doesn't store the kernel for the incoming block size
    // (so it is different from the one stored in the corresponding mixin)
    // regenerate the kernel (not writing it into the cache so we avoid locks)
    if (block_size != _CallbackAndBlockSize<has_weight, TIndex, TData>::blockSize) {
      return _CallbackAndBlockSize<has_weight, TIndex, TData>::generateCallback(block_size);
    }
    // else retrieve the cached kernel from the corresponding mixin
    return _CallbackAndBlockSize<has_weight, TIndex, TData>::callback;
  }
};

// instantiate the cache with the list of storage mixins
// for each of the 8 _EmbeddingBagKernelCache* usages in the EmbeddingBag.cpp impl file
using _EmbeddingBagKernelCache = _EmbeddingBagKernelCacheImpl<
    _CallbackAndBlockSize<true, int32_t, float>,
    _CallbackAndBlockSize<false, int32_t, float>,
    _CallbackAndBlockSize<true, int64_t, float>,
    _CallbackAndBlockSize<false, int64_t, float>,
    _CallbackAndBlockSize<true, int32_t, unsigned short>,
    _CallbackAndBlockSize<false, int32_t, unsigned short>,
    _CallbackAndBlockSize<true, int64_t, unsigned short>,
    _CallbackAndBlockSize<false, int64_t, unsigned short>>;
#else
struct _EmbeddingBagKernelCache {
  explicit _EmbeddingBagKernelCache(std::optional<int64_t> /* maybe_block_size */) {}
};
#endif
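
// Usage sketch for the kernel cache above (illustrative only, assuming a
// USE_FBGEMM build; `embedding_dim` and the chosen template arguments are
// hypothetical, not part of this header):
//
//   _EmbeddingBagKernelCache cache(std::optional<int64_t>(embedding_dim));
//   // Hit: the requested block size matches the cached one, so the
//   // pre-generated FBGEMM kernel is returned directly.
//   auto hit = cache.getCallback</*has_weight=*/false, int64_t, float>(embedding_dim);
//   // Miss: a kernel for the new block size is generated on the fly and
//   // returned without being written back, which keeps getCallback lock-free.
//   auto miss = cache.getCallback</*has_weight=*/false, int64_t, float>(embedding_dim + 1);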

// Out-variant of the CPU embedding_bag forward; writes into the provided
// output, offset2bag, bag_size, and (for MAX mode) max_indices tensors.
void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag,
    Tensor& bag_size, Tensor* max_indices,
    const Tensor& weight, const Tensor& indices,
    const Tensor& offsets, const int64_t mode = 0,
    const std::optional<Tensor>& per_sample_weights = std::nullopt,
    bool include_last_offset = false,
    int64_t padding_idx = -1,
    _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);

void _embedding_bag_cpu_out(
    at::Tensor& output,
    at::Tensor& offset2bag,
    at::Tensor& bag_size,
    at::Tensor* p_max_indices,
    const at::Tensor& weight,
    const at::Tensor& indices,
    const at::Tensor& offsets,
    const bool scale_grad_by_freq,
    const int64_t mode,
    const bool sparse,
    const std::optional<at::Tensor>& per_sample_weights,
    const bool include_last_offset,
    const std::optional<int64_t>& padding_idx,
    _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);

} // namespace at::native
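
// Call sketch for the out-variant above (illustrative only; tensor allocation
// and shapes are elided, and `weight`, `indices`, and `offsets` stand in for
// suitably-typed tensors prepared by the caller). With the default mode
// (0 == EmbeddingBagMode::SUM), each bag of `weight` rows selected by
// `indices` is reduced into one row of `output`:
//
//   at::Tensor output, offset2bag, bag_size, max_indices;
//   at::native::_embedding_bag_cpu_impl_out(
//       output, offset2bag, bag_size, &max_indices,
//       weight, indices, offsets);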