1 #include "caffe2/perfkernels/embedding_lookup_idx.h"
2
3 #include <c10/util/BFloat16.h>
4 #include <c10/util/Half.h>
5 #include <c10/util/Logging.h>
6 #include <c10/util/irange.h>
7 #include "caffe2/perfkernels/common.h"
8
9 namespace caffe2 {
10
11 /**
12 * Base implementation does runtime dispatch for each segment of reduction
13 * @return false if there is an out-of-bound error
14 */
15 template <
16 typename IndexType,
17 typename InType,
18 typename OutType,
19 bool IS_WEIGHT_POSITIONAL = false>
EmbeddingLookupGenericSlowIdx(const int64_t block_size,const int64_t output_size,const int64_t index_size,const int64_t data_size,const InType * input,const IndexType * indices,const IndexType * offsets,const float * weights,const float * scale_bias,bool normalize_by_lengths,OutType * out)20 static bool EmbeddingLookupGenericSlowIdx(
21 const int64_t block_size,
22 const int64_t output_size,
23 const int64_t index_size,
24 const int64_t data_size,
25 const InType* input,
26 const IndexType* indices,
27 const IndexType* offsets,
28 const float* weights, // optional, can be null for sum reducer
29 const float* scale_bias, // optional scale & bias params for uint8 input
30 bool normalize_by_lengths,
31 OutType* out) {
32 int64_t current = 0;
33 for (const auto m : c10::irange(output_size)) {
34 memset(out, 0, sizeof(OutType) * block_size);
35 if (current != offsets[m] - offsets[0]) {
36 return false;
37 }
38 int64_t start_offset = offsets[m];
39 int64_t end_offset = offsets[m + 1];
40 int64_t length = end_offset - start_offset;
41 for (const auto i : c10::irange(start_offset, end_offset)) {
42 int64_t idx = indices[current];
43 if (idx < 0 || idx >= data_size) {
44 return false;
45 }
46 #ifdef __GNUC__
47 if (current + 1 < index_size) {
48 __builtin_prefetch(input + block_size * indices[current + 1], 0, 1);
49 }
50 #endif // __GNUC__
51
52 float w = 1.f, b = 0.f;
53 if (weights) {
54 w = weights[IS_WEIGHT_POSITIONAL ? i - start_offset : current];
55 }
56 if (scale_bias) {
57 b = w * scale_bias[2 * indices[current] + 1];
58 w = w * scale_bias[2 * indices[current]];
59 }
60
61 for (const auto j : c10::irange(block_size)) {
62 out[j] += w * input[block_size * indices[current] + j] + b;
63 }
64
65 ++current;
66 }
67 if (normalize_by_lengths && length) {
68 float scale = 1.f / length;
69 for (const auto j : c10::irange(block_size)) {
70 out[j] *= scale;
71 }
72 }
73 out += block_size;
74 }
75 return current == index_size;
76 }
77
// clang-format off
// Proxy back to generic implementation.
//
// For each (IndexType, InType, OutType, IS_WEIGHT_POSITIONAL) combination this
// macro emits:
//   * a `__base` function forwarding to EmbeddingLookupGenericSlowIdx;
//   * a declaration of the matching `__avx2_fma` kernel (defined in the
//     generated/compiled perfkernels translation unit);
//   * a dispatcher that validates scale_bias (required iff InType is uint8_t)
//     and picks the best available kernel via AVX2_FMA_DO / BASE_DO;
//   * the public EmbeddingLookupIdx<> specialization, which re-walks the
//     offsets/indices on failure to produce a precise error message.
//
// NOTE: no `//` comments inside the macro body — backslash continuations are
// spliced before comment removal, so a `//` would swallow the rest of it.
#define EMBEDDING_IDX_SPECIALIZATION(                                                                \
    IndexType, InTypeName, InType, OutType, IS_WEIGHT_POSITIONAL)                                    \
  bool                                                                                               \
      EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base(    \
          const int64_t block_size,                                                                  \
          const int64_t output_size,                                                                 \
          const int64_t index_size,                                                                  \
          const int64_t data_size,                                                                   \
          const InType* input,                                                                       \
          const IndexType* indices,                                                                  \
          const IndexType* offsets,                                                                  \
          const float* weights,                                                                      \
          const float* scale_bias,                                                                   \
          bool normalize_by_lengths,                                                                 \
          OutType* out) {                                                                            \
    return EmbeddingLookupGenericSlowIdx<                                                            \
        IndexType,                                                                                   \
        InType,                                                                                      \
        OutType,                                                                                     \
        IS_WEIGHT_POSITIONAL>(                                                                       \
        block_size,                                                                                  \
        output_size,                                                                                 \
        index_size,                                                                                  \
        data_size,                                                                                   \
        input,                                                                                       \
        indices,                                                                                     \
        offsets,                                                                                     \
        weights,                                                                                     \
        scale_bias,                                                                                  \
        normalize_by_lengths,                                                                        \
        out);                                                                                        \
  }                                                                                                  \
  decltype(                                                                                          \
      EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base)    \
      EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__avx2_fma; \
  bool                                                                                               \
      EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL(            \
          const int64_t block_size,                                                                  \
          const int64_t output_size,                                                                 \
          const int64_t index_size,                                                                  \
          const int64_t data_size,                                                                   \
          const InType* input,                                                                       \
          const IndexType* indices,                                                                  \
          const IndexType* offsets,                                                                  \
          const float* weights,                                                                      \
          const float* scale_bias,                                                                   \
          bool normalize_by_lengths,                                                                 \
          OutType* out) {                                                                            \
    if (std::is_same<InType, uint8_t>::value) {                                                      \
      CAFFE_ENFORCE(scale_bias != nullptr, "scale_bias must not be nullptr");                        \
    } else {                                                                                         \
      CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr");                            \
    }                                                                                                \
    AVX2_FMA_DO(                                                                                     \
        EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL,          \
        block_size,                                                                                  \
        output_size,                                                                                 \
        index_size,                                                                                  \
        data_size,                                                                                   \
        input,                                                                                       \
        indices,                                                                                     \
        offsets,                                                                                     \
        weights,                                                                                     \
        scale_bias,                                                                                  \
        normalize_by_lengths,                                                                        \
        out);                                                                                        \
    BASE_DO(                                                                                         \
        EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL,          \
        block_size,                                                                                  \
        output_size,                                                                                 \
        index_size,                                                                                  \
        data_size,                                                                                   \
        input,                                                                                       \
        indices,                                                                                     \
        offsets,                                                                                     \
        weights,                                                                                     \
        scale_bias,                                                                                  \
        normalize_by_lengths,                                                                        \
        out);                                                                                        \
  }                                                                                                  \
  template <>                                                                                        \
  void EmbeddingLookupIdx<IndexType, InType, OutType, IS_WEIGHT_POSITIONAL>(                         \
      const int64_t block_size,                                                                      \
      const int64_t output_size,                                                                     \
      const int64_t index_size,                                                                      \
      const int64_t data_size,                                                                       \
      const InType* input,                                                                           \
      const IndexType* indices,                                                                      \
      const IndexType* offsets,                                                                      \
      const float* weights,                                                                          \
      const float* scale_bias,                                                                       \
      bool normalize_by_lengths,                                                                     \
      OutType* out) {                                                                                \
    bool success =                                                                                   \
        EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL(          \
            block_size,                                                                              \
            output_size,                                                                             \
            index_size,                                                                              \
            data_size,                                                                               \
            input,                                                                                   \
            indices,                                                                                 \
            offsets,                                                                                 \
            weights,                                                                                 \
            scale_bias,                                                                              \
            normalize_by_lengths,                                                                    \
            out);                                                                                    \
    if (success) {                                                                                   \
      return;                                                                                        \
    }                                                                                                \
    int64_t current = 0;                                                                             \
    for (int64_t m = 0; m < output_size; ++m) {                                                      \
      for (int64_t i = offsets[m]; i < offsets[m + 1]; ++i) {                                        \
        CAFFE_ENFORCE_LT(current, index_size);                                                       \
        IndexType idx = indices[current];                                                            \
        CAFFE_ENFORCE(                                                                               \
            0 <= idx && idx < data_size,                                                             \
            "Index ",                                                                                \
            current,                                                                                 \
            " is out of bounds: ",                                                                   \
            idx,                                                                                     \
            ", range 0 to ",                                                                         \
            data_size);                                                                              \
        ++current;                                                                                   \
      }                                                                                              \
    }                                                                                                \
    CAFFE_ENFORCE_EQ(                                                                                \
        current,                                                                                     \
        index_size,                                                                                  \
        "Your input seems to be incorrect: the sum of lengths values should be " \
        "the size of the indices tensor, but it appears not.");                                      \
  }
// clang-format on
212
213 EMBEDDING_IDX_SPECIALIZATION(int32_t, float, float, float, false);
214 EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, false);
215 EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, false);
216 EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, false);
217 EMBEDDING_IDX_SPECIALIZATION(int32_t, bfloat16, at::BFloat16, float, false);
218 EMBEDDING_IDX_SPECIALIZATION(int64_t, bfloat16, at::BFloat16, float, false);
219 EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, false);
220 EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, false);
221
222 EMBEDDING_IDX_SPECIALIZATION(int32_t, float, float, float, true);
223 EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, true);
224 EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, true);
225 EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, true);
226 EMBEDDING_IDX_SPECIALIZATION(int32_t, bfloat16, at::BFloat16, float, true);
227 EMBEDDING_IDX_SPECIALIZATION(int64_t, bfloat16, at::BFloat16, float, true);
228 EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, true);
229 EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, true);
230
231 #undef EMBEDDING_IDX_SPECIALIZATION
232
233 } // namespace caffe2
234