xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/EmbeddingBag.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/Tensor.h>
2 #include <ATen/Config.h>
3 #include <cstdint>
4 
5 #ifdef USE_FBGEMM
6 #include <fbgemm/FbgemmEmbedding.h>
7 #endif
8 
9 namespace at::native {
10 
// Reduction applied to the embedding vectors within each bag. The
// enumerator values match the raw int64_t `mode` argument threaded
// through the embedding_bag entry points below.
enum class EmbeddingBagMode {
  SUM = 0,
  MEAN = 1,
  MAX = 2,
};

// Convenience comparisons so call sites can test a raw int64_t `mode`
// directly against an EmbeddingBagMode enumerator.
// constexpr: pure value comparison, usable in constant expressions.
// [[nodiscard]]: a discarded comparison result is always a bug.
// [[maybe_unused]]: not every TU including this header uses both operators.
[[nodiscard]] [[maybe_unused]] static constexpr bool operator==(
    int64_t op1,
    EmbeddingBagMode op2) {
  return op1 == static_cast<int64_t>(op2);
}

[[nodiscard]] [[maybe_unused]] static constexpr bool operator!=(
    int64_t op1,
    EmbeddingBagMode op2) {
  return !(op1 == op2);
}
24 
// Validates the argument combination for an embedding_bag call before any
// computation. NOTE(review): the concrete checks (dtype/dimension/size
// agreement between weight, indices, offsets and per_sample_weights) live in
// EmbeddingBag.cpp — confirm details there.
void check_arguments(
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& offsets,
    const int64_t mode,
    const std::optional<Tensor>& per_sample_weights,
    bool include_last_offset);
32 
// Fills `bag_size_out` with the number of indices contributing to each bag,
// derived from `offsets` (and `indices` for the trailing bag).
// NOTE(review): how `mode`/`requires_grad` gate the computation (e.g. bag
// sizes needed for MEAN backward) is defined in EmbeddingBag.cpp — confirm
// against the implementation.
void make_bag_size_out(
    Tensor& bag_size_out,
    const Tensor& offsets,
    const Tensor& indices,
    const int64_t mode,
    const bool include_last_offset,
    const bool requires_grad);
40 
// Prepares `max_indices_out`, the per-bag argmax index tensor used when
// `mode` is EmbeddingBagMode::MAX. NOTE(review): presumably resizes the
// output based on weight/offsets shapes and is a no-op for SUM/MEAN —
// verify in EmbeddingBag.cpp.
void make_max_indices_out(
    Tensor& max_indices_out,
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& offsets,
    const Tensor& bag_size,
    const int64_t mode,
    bool include_last_offset);
49 
// Builds `offset2bag` (the mapping from each index position to the bag it
// belongs to) and may initialize `output` as part of the same pass.
// `padding_idx = -1` follows the usual PyTorch convention of "no padding
// index" — NOTE(review): confirm against the EmbeddingBag.cpp implementation.
void make_offset2bag_out(
    Tensor& offset2bag,
    Tensor& output,
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& offsets,
    const int64_t mode,
    const std::optional<Tensor>& per_sample_weights,
    const int64_t padding_idx = -1);
59 
60 #ifdef USE_FBGEMM
61 
62 template<bool has_weight, typename TIndex, typename TData>
63 struct _CallbackAndBlockSize {
64     using TCallback = typename fbgemm::EmbeddingSpMDMKernelSignature<TData, TIndex, TIndex, TData>::Type;
65 
66     int64_t blockSize = -1;
67     TCallback callback = nullptr;
68 
generateCallback_CallbackAndBlockSize69     static TCallback generateCallback(int64_t block_size) {
70         return fbgemm::GenerateEmbeddingSpMDM<TData, TIndex, TIndex, TData>(
71                 block_size,
72                 has_weight,
73                 /* normalize_by_lengths */false,
74                 /* prefetch */16,
75                 /* is_weight_positional */false,
76                 /* use_offsets */true);
77     }
78 
79     _CallbackAndBlockSize() = default;
80 
_CallbackAndBlockSize_CallbackAndBlockSize81     explicit _CallbackAndBlockSize(std::optional<int64_t> maybe_block_size)
82       : blockSize(maybe_block_size.value_or(-1))
83       , callback(maybe_block_size.has_value() ? generateCallback(maybe_block_size.value()) : nullptr)
84     {}
85 };
86 
87 template<typename... StorageMixins>
88 struct _EmbeddingBagKernelCacheImpl : private StorageMixins... {
89 
90     _EmbeddingBagKernelCacheImpl() = default;
91     // use each of the mixins to store corresponding kernel and block size
92     explicit _EmbeddingBagKernelCacheImpl(std::optional<int64_t> maybe_block_size)
93       : StorageMixins(maybe_block_size)...
94     {}
95 
96     // this method is thread safe (call sites may call from different threads)
97     template<bool has_weight, typename TIndex, typename TData>
98     typename _CallbackAndBlockSize<has_weight, TIndex, TData>::TCallback
getCallback_EmbeddingBagKernelCacheImpl99     getCallback(int64_t block_size) const {
100         // if the cache doesn't store the kernel for the incoming block size
101         // (so it is different from the one stored in corresponding mixin)
102         // regenerate the kernel (not writing it into the cache so we avoid locks)
103         if (block_size != _CallbackAndBlockSize<has_weight, TIndex, TData>::blockSize) {
104             return _CallbackAndBlockSize<has_weight, TIndex, TData>::generateCallback(block_size);
105         }
106         // else retrieve the cached kernel from the corresponding mixin
107         return _CallbackAndBlockSize<has_weight, TIndex, TData>::callback;
108     }
109 };
110 
// instantiate the cache with the list of storage mixins
// for each of the 8 _EmbeddingBagKernelCache* usages in the EmbeddingBag.cpp impl file
// NOTE(review): `unsigned short` is presumably the storage type for
// reduced-precision (half) embedding data — confirm in EmbeddingBag.cpp.
using _EmbeddingBagKernelCache = _EmbeddingBagKernelCacheImpl<
    _CallbackAndBlockSize<true, int32_t, float>,
    _CallbackAndBlockSize<false, int32_t, float>,
    _CallbackAndBlockSize<true, int64_t, float>,
    _CallbackAndBlockSize<false, int64_t, float>,
    _CallbackAndBlockSize<true, int32_t, unsigned short>,
    _CallbackAndBlockSize<false, int32_t, unsigned short>,
    _CallbackAndBlockSize<true, int64_t, unsigned short>,
    _CallbackAndBlockSize<false, int64_t, unsigned short>>;
122 #else
// Stub used when FBGEMM is not compiled in: accepts and ignores the block
// size so call sites can construct the cache unconditionally.
struct _EmbeddingBagKernelCache {
  explicit _EmbeddingBagKernelCache(std::optional<int64_t> /* maybe_block_size */) {}
};
126 #endif
127 
// Core CPU embedding_bag kernel writing into preallocated outputs.
// `max_indices` is a pointer so callers can pass nullptr when `mode` is not
// MAX; `fbgemm_kernel_cache` optionally supplies pre-generated FBGEMM
// kernels (nullptr means generate on demand / use the fallback path).
// NOTE(review): semantics of `padding_idx = -1` ("no padding") follow the
// usual PyTorch convention — confirm in EmbeddingBag.cpp.
void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag,
    Tensor& bag_size, Tensor* max_indices,
    const Tensor &weight, const Tensor &indices,
    const Tensor &offsets, const int64_t mode = 0,
    const std::optional<Tensor>& per_sample_weights = std::nullopt,
    bool include_last_offset = false,
    int64_t padding_idx = -1,
    _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
136 
// Out-variant entry point for the CPU embedding_bag forward pass; mirrors
// the functional overload's argument list (scale_grad_by_freq/sparse are
// part of the public signature even though this is the forward path).
// `p_max_indices` may be nullptr when `mode` is not MAX.
// NOTE(review): argument validation and dispatch details live in
// EmbeddingBag.cpp — confirm there.
void _embedding_bag_cpu_out(
    at::Tensor& output,
    at::Tensor& offset2bag,
    at::Tensor& bag_size,
    at::Tensor* p_max_indices,
    const at::Tensor& weight,
    const at::Tensor& indices,
    const at::Tensor& offsets,
    const bool scale_grad_by_freq,
    const int64_t mode,
    const bool sparse,
    const std::optional<at::Tensor>& per_sample_weights,
    const bool include_last_offset,
    const std::optional<int64_t>& padding_idx,
    _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
152 
153 } // namespace at::native
154