xref: /aosp_15_r20/external/pytorch/caffe2/perfkernels/embedding_lookup_idx.cc (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include "caffe2/perfkernels/embedding_lookup_idx.h"
2 
3 #include <c10/util/BFloat16.h>
4 #include <c10/util/Half.h>
5 #include <c10/util/Logging.h>
6 #include <c10/util/irange.h>
7 #include "caffe2/perfkernels/common.h"
8 
9 namespace caffe2 {
10 
/**
 * Reference (non-vectorized) embedding lookup that reduces each segment of
 * `indices` into one output row.
 *
 * Segment `m` (for m in [0, output_size)) consumes
 * `offsets[m + 1] - offsets[m]` consecutive entries of `indices`; `offsets`
 * is interpreted relative to `offsets[0]` and must therefore expose
 * `output_size + 1` readable entries. For every consumed index `idx` the row
 * `input[block_size * idx .. block_size * idx + block_size)` is accumulated
 * into the current output row as `w * row + b`, where
 *   - `w` starts at 1 and is replaced by the matching entry of `weights` when
 *     `weights` is non-null (indexed by the position inside the segment when
 *     IS_WEIGHT_POSITIONAL is true, otherwise by the position in `indices`);
 *   - when `scale_bias` is non-null, `b = w * scale_bias[2 * idx + 1]` and
 *     `w` is additionally multiplied by `scale_bias[2 * idx]`.
 * With `normalize_by_lengths` set, each non-empty segment is finally scaled
 * by `1 / length`.
 *
 * @param block_size  number of elements per input/output row
 * @param output_size number of segments (rows written to `out`)
 * @param index_size  number of entries in `indices`
 * @param data_size   number of rows in `input`; valid indices are
 *                    [0, data_size)
 * @param input       row-major table of `data_size * block_size` elements
 * @param indices     rows to gather, `index_size` entries
 * @param offsets     segment boundaries, `output_size + 1` entries
 * @param weights     optional, can be null for sum reducer
 * @param scale_bias  optional per-row (scale, bias) pairs
 * @param normalize_by_lengths divide each non-empty segment by its length
 * @param out         destination, `output_size * block_size` elements
 *
 * @return false if the inputs are inconsistent: an offset does not match the
 *         number of indices consumed so far, a segment would read past
 *         `index_size`, an index lies outside [0, data_size), or the
 *         segments do not consume exactly `index_size` indices. `out` may
 *         have been partially written in that case.
 */
template <
    typename IndexType,
    typename InType,
    typename OutType,
    bool IS_WEIGHT_POSITIONAL = false>
static bool EmbeddingLookupGenericSlowIdx(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const InType* input,
    const IndexType* indices,
    const IndexType* offsets,
    const float* weights, // optional, can be null for sum reducer
    const float* scale_bias, // optional scale & bias params for uint8 input
    bool normalize_by_lengths,
    OutType* out) {
  // Number of entries of `indices` consumed so far; never exceeds index_size.
  int64_t current = 0;
  for (int64_t m = 0; m < output_size; ++m) {
    memset(out, 0, sizeof(OutType) * block_size);
    // Segments must be contiguous: each one starts where the previous ended.
    if (current != offsets[m] - offsets[0]) {
      return false;
    }
    const int64_t start_offset = offsets[m];
    const int64_t end_offset = offsets[m + 1];
    const int64_t length = end_offset - start_offset;
    // Reject a segment that claims more entries than `indices` still holds,
    // before any of them is read. `index_size - current` cannot be negative
    // because of the invariant on `current`.
    if (length > index_size - current) {
      return false;
    }
    for (int64_t i = start_offset; i < end_offset; ++i) {
      const int64_t idx = indices[current];
      if (idx < 0 || idx >= data_size) {
        return false;
      }
#ifdef __GNUC__
      // Hint only: start fetching the row of the next index. That index has
      // not been range-checked yet; it is validated when it becomes current.
      if (current + 1 < index_size) {
        __builtin_prefetch(input + block_size * indices[current + 1], 0, 1);
      }
#endif // __GNUC__

      float w = 1.f, b = 0.f;
      if (weights) {
        w = weights[IS_WEIGHT_POSITIONAL ? i - start_offset : current];
      }
      if (scale_bias) {
        // The bias must be derived from the un-scaled weight, so compute it
        // before folding the scale into `w`.
        b = w * scale_bias[2 * idx + 1];
        w = w * scale_bias[2 * idx];
      }

      const InType* row = input + block_size * idx;
      for (int64_t j = 0; j < block_size; ++j) {
        out[j] += w * row[j] + b;
      }

      ++current;
    }
    // Empty segments are left as zeros (also avoids dividing by zero).
    if (normalize_by_lengths && length) {
      const float scale = 1.f / length;
      for (int64_t j = 0; j < block_size; ++j) {
        out[j] *= scale;
      }
    }
    out += block_size;
  }
  return current == index_size;
}
77 
// clang-format off
// EMBEDDING_IDX_SPECIALIZATION(IndexType, InTypeName, InType, OutType,
//                              IS_WEIGHT_POSITIONAL)
//
// Generates, for one combination of types, the pieces that route
// EmbeddingLookupIdx<IndexType, InType, OutType, IS_WEIGHT_POSITIONAL> to a
// kernel. InTypeName is only used to build identifiers (InType itself may be
// a qualified name such as at::Half, which cannot be token-pasted).
//
//   1. EmbeddingLookupIdx_<...>__base: proxies back to the generic
//      implementation, EmbeddingLookupGenericSlowIdx.
//   2. A declaration of EmbeddingLookupIdx_<...>__avx2_fma with the same
//      function type as the __base variant; it is not defined in this file.
//   3. EmbeddingLookupIdx_<...>: enforces that scale_bias is non-null exactly
//      when InType is uint8_t, then hands the arguments to AVX2_FMA_DO and
//      BASE_DO. Both macros are defined outside this file (presumably in
//      caffe2/perfkernels/common.h, selecting the __avx2_fma or __base
//      variant and returning its result -- confirm there).
//   4. The EmbeddingLookupIdx specialization itself: calls the dispatcher
//      and returns if it succeeded. Otherwise it re-walks offsets/indices
//      with CAFFE_ENFORCE checks so the failure is reported with a
//      descriptive message. The segment counter in that walk is int64_t to
//      match output_size.
#define EMBEDDING_IDX_SPECIALIZATION(                                                                 \
    IndexType, InTypeName, InType, OutType, IS_WEIGHT_POSITIONAL)                                     \
  bool                                                                                                \
      EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base(     \
          const int64_t block_size,                                                                   \
          const int64_t output_size,                                                                  \
          const int64_t index_size,                                                                   \
          const int64_t data_size,                                                                    \
          const InType* input,                                                                        \
          const IndexType* indices,                                                                   \
          const IndexType* offsets,                                                                   \
          const float* weights,                                                                       \
          const float* scale_bias,                                                                    \
          bool normalize_by_lengths,                                                                  \
          OutType* out) {                                                                             \
    return EmbeddingLookupGenericSlowIdx<                                                             \
        IndexType,                                                                                    \
        InType,                                                                                       \
        OutType,                                                                                      \
        IS_WEIGHT_POSITIONAL>(                                                                        \
        block_size,                                                                                   \
        output_size,                                                                                  \
        index_size,                                                                                   \
        data_size,                                                                                    \
        input,                                                                                        \
        indices,                                                                                      \
        offsets,                                                                                      \
        weights,                                                                                      \
        scale_bias,                                                                                   \
        normalize_by_lengths,                                                                         \
        out);                                                                                         \
  }                                                                                                   \
  decltype(                                                                                           \
      EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base)     \
      EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__avx2_fma; \
  bool                                                                                                \
      EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL(             \
          const int64_t block_size,                                                                   \
          const int64_t output_size,                                                                  \
          const int64_t index_size,                                                                   \
          const int64_t data_size,                                                                    \
          const InType* input,                                                                        \
          const IndexType* indices,                                                                   \
          const IndexType* offsets,                                                                   \
          const float* weights,                                                                       \
          const float* scale_bias,                                                                    \
          bool normalize_by_lengths,                                                                  \
          OutType* out) {                                                                             \
    if (std::is_same<InType, uint8_t>::value) {                                                       \
      CAFFE_ENFORCE(scale_bias != nullptr, "scale_bias must not be nullptr");                         \
    } else {                                                                                          \
      CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr");                             \
    }                                                                                                 \
    AVX2_FMA_DO(                                                                                      \
        EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL,           \
        block_size,                                                                                   \
        output_size,                                                                                  \
        index_size,                                                                                   \
        data_size,                                                                                    \
        input,                                                                                        \
        indices,                                                                                      \
        offsets,                                                                                      \
        weights,                                                                                      \
        scale_bias,                                                                                   \
        normalize_by_lengths,                                                                         \
        out);                                                                                         \
    BASE_DO(                                                                                          \
        EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL,           \
        block_size,                                                                                   \
        output_size,                                                                                  \
        index_size,                                                                                   \
        data_size,                                                                                    \
        input,                                                                                        \
        indices,                                                                                      \
        offsets,                                                                                      \
        weights,                                                                                      \
        scale_bias,                                                                                   \
        normalize_by_lengths,                                                                         \
        out);                                                                                         \
  }                                                                                                   \
  template <>                                                                                         \
  void EmbeddingLookupIdx<IndexType, InType, OutType, IS_WEIGHT_POSITIONAL>(                          \
      const int64_t block_size,                                                                       \
      const int64_t output_size,                                                                      \
      const int64_t index_size,                                                                       \
      const int64_t data_size,                                                                        \
      const InType* input,                                                                            \
      const IndexType* indices,                                                                       \
      const IndexType* offsets,                                                                       \
      const float* weights,                                                                           \
      const float* scale_bias,                                                                        \
      bool normalize_by_lengths,                                                                      \
      OutType* out) {                                                                                 \
    bool success =                                                                                    \
        EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL(           \
            block_size,                                                                               \
            output_size,                                                                              \
            index_size,                                                                               \
            data_size,                                                                                \
            input,                                                                                    \
            indices,                                                                                  \
            offsets,                                                                                  \
            weights,                                                                                  \
            scale_bias,                                                                               \
            normalize_by_lengths,                                                                     \
            out);                                                                                     \
    if (success) {                                                                                    \
      return;                                                                                         \
    }                                                                                                 \
    int64_t current = 0;                                                                              \
    for (int64_t m = 0; m < output_size; ++m) {                                                       \
      for (int64_t i = offsets[m]; i < offsets[m + 1]; ++i) {                                         \
        CAFFE_ENFORCE_LT(current, index_size);                                                        \
        IndexType idx = indices[current];                                                             \
        CAFFE_ENFORCE(                                                                                \
            0 <= idx && idx < data_size,                                                              \
            "Index ",                                                                                 \
            current,                                                                                  \
            " is out of bounds: ",                                                                    \
            idx,                                                                                      \
            ", range 0 to ",                                                                          \
            data_size);                                                                               \
        ++current;                                                                                    \
      }                                                                                               \
    }                                                                                                 \
    CAFFE_ENFORCE_EQ(                                                                                 \
        current,                                                                                      \
        index_size,                                                                                   \
        "Your input seems to be incorrect: the sum of lengths values should be "                      \
        "the size of the indices tensor, but it appears not.");                                       \
  }
// clang-format on
212 
213 EMBEDDING_IDX_SPECIALIZATION(int32_t, float, float, float, false);
214 EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, false);
215 EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, false);
216 EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, false);
217 EMBEDDING_IDX_SPECIALIZATION(int32_t, bfloat16, at::BFloat16, float, false);
218 EMBEDDING_IDX_SPECIALIZATION(int64_t, bfloat16, at::BFloat16, float, false);
219 EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, false);
220 EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, false);
221 
222 EMBEDDING_IDX_SPECIALIZATION(int32_t, float, float, float, true);
223 EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, true);
224 EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, true);
225 EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, true);
226 EMBEDDING_IDX_SPECIALIZATION(int32_t, bfloat16, at::BFloat16, float, true);
227 EMBEDDING_IDX_SPECIALIZATION(int64_t, bfloat16, at::BFloat16, float, true);
228 EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, true);
229 EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, true);
230 
231 #undef EMBEDDING_IDX_SPECIALIZATION
232 
233 } // namespace caffe2
234