#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/quantized/cpu/EmbeddingPackedParams.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/qembeddingbag.h>
#include <torch/library.h>
#ifdef USE_FBGEMM
#include <fbgemm/Fbgemm.h>
#include <fbgemm/FbgemmEmbedding.h>
#endif

#include <ATen/Parallel.h>
#include <ATen/Utils.h>
#include <c10/util/irange.h>

#include <array>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/resize_native.h>
#endif

int register_embedding_params();

namespace {

// Fallback implementation when FBGEMM is not available.
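// Each row of the packed weight tensor stores the quantized values followed
// by the quantization parameters: two fp32 values (scale, bias) for 8-bit
// rows, or two fp16 values for 4-bit/2-bit rows. Conceptually, each output
// element accumulates, over the indices in a bag,
//   output[j] += weight_val * (scale * q + bias)
// where q is the BIT_RATE-wide unsigned code unpacked from the row.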
template <
    typename IndexType,
    typename OffsetType,
    int BIT_RATE,
    int NUM_ELEM_PER_BYTE>
at::Tensor& embedding_lookup_fallback_impl(
    const at::Tensor& weight,
    const at::Tensor& indices,
    const at::Tensor& offsets,
    const std::optional<at::Tensor>& per_sample_weights_,
    const std::optional<at::Tensor>& compressed_indices_mapping,
    at::Tensor& output,
    const int64_t block_size,
    const int64_t output_size,
    bool include_last_offset,
    bool pruned) {
  auto* output_data = output.data_ptr<float>();
  const auto weight_data = weight.data_ptr<uint8_t>();
  const auto indices_data = indices.data_ptr<IndexType>();
  int32_t* compressed_indices_mapping_data = nullptr;
  const auto weight_sizes = weight.sizes();
  const int64_t N = weight_sizes[0];
  const int64_t weight_size = weight_sizes[1];
  const int index_size = indices.numel();

  auto accessor = offsets.accessor<OffsetType, 1>();
  std::vector<OffsetType> lengths_data;

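  // Convert the offsets tensor into per-bag lengths; when the caller did not
  // include a final offset, the last bag extends to the end of `indices`.
  // For example, offsets = [0, 2, 5] with 7 indices and
  // include_last_offset == false gives lengths = [2, 3, 2].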
  int64_t lower = accessor[0];
  for (const auto i : c10::irange(1, offsets.numel())) {
    lengths_data.push_back(accessor[i] - lower);
    lower = accessor[i];
  }
  if (!include_last_offset) {
    lengths_data.push_back(indices.numel() - lower);
  }

  int64_t current = 0;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  float* per_sample_weights_data;
  if (per_sample_weights_.has_value()) {
    per_sample_weights_data = per_sample_weights_.value().data_ptr<float>();
  }
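  // For each output row (bag): zero the block, then walk its indices and
  // accumulate the dequantized rows into `output_data`.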
  for (const auto m : c10::irange(output_size)) {
    memset(output_data, 0, block_size * sizeof(float));
    TORCH_CHECK(
        current + lengths_data[m] <= index_size,
        "Expect the lengths data to be less than indices size");

    for (int i = 0; i < lengths_data[m]; ++i, ++current) {
      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
      int64_t idx;
      if (!pruned) {
        idx = indices_data[current];
        TORCH_CHECK((idx >= 0 && idx < N), "Invalid indices data");
      } else {
        int64_t uncompressed_idx = indices_data[current];
        int compressed_index_size = compressed_indices_mapping.value().numel();
        compressed_indices_mapping_data =
            compressed_indices_mapping.value().data_ptr<int32_t>();
        TORCH_CHECK(
            uncompressed_idx >= 0 && uncompressed_idx < compressed_index_size,
            "Invalid indices data for Sparse Op.")
        idx = compressed_indices_mapping_data[uncompressed_idx];
        if (idx == -1) {
          continue;
        }
      }

      float weight_val = 1.0f;
      if (per_sample_weights_.has_value()) {
        weight_val = per_sample_weights_data[current];
      }
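      // Decode the row's quantization parameters from its trailing bytes,
      // assembling them byte-by-byte so the load is endianness-safe.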
      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
      float scale, bias;
      if (BIT_RATE == 8) {
        const uint8_t* scale_bias =
            weight_data + (idx + 1) * weight_size - 2 * sizeof(float);
        uint32_t scale_val_int32 = 0;
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
        scale_val_int32 = scale_val_int32 |
            (scale_bias[0]) |
            (scale_bias[1] << 8) |
            (scale_bias[2] << 16) |
            (scale_bias[3] << 24);
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
        scale_val_int32 = scale_val_int32 |
            (scale_bias[3]) |
            (scale_bias[2] << 8) |
            (scale_bias[1] << 16) |
            (scale_bias[0] << 24);
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
        float scale_val = (reinterpret_cast<float*>(&scale_val_int32))[0];
        uint32_t bias_val_int32 = 0;
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
        bias_val_int32 = bias_val_int32 |
            (scale_bias[4]) |
            (scale_bias[5] << 8) |
            (scale_bias[6] << 16) |
            (scale_bias[7] << 24);
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
        bias_val_int32 = bias_val_int32 |
            (scale_bias[7]) |
            (scale_bias[6] << 8) |
            (scale_bias[5] << 16) |
            (scale_bias[4] << 24);
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
        float bias_val = (reinterpret_cast<float*>(&bias_val_int32))[0];
        scale = weight_val * scale_val;
        bias = weight_val * bias_val;
      } else {
        const uint8_t* scale_bias =
            weight_data + (idx + 1) * weight_size - 2 * sizeof(at::Half);
        uint16_t scale_val_int16 = 0;
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
        scale_val_int16 = scale_val_int16 |
            (scale_bias[0]) |
            (scale_bias[1] << 8);
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
        scale_val_int16 = scale_val_int16 |
            (scale_bias[1]) |
            (scale_bias[0] << 8);
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
        at::Half scale_val = (reinterpret_cast<at::Half*>(&scale_val_int16))[0];
        uint16_t bias_val_int16 = 0;
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
        bias_val_int16 = bias_val_int16 |
            (scale_bias[2]) |
            (scale_bias[3] << 8);
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
        bias_val_int16 = bias_val_int16 |
            (scale_bias[3]) |
            (scale_bias[2] << 8);
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
        at::Half bias_val = (reinterpret_cast<at::Half*>(&bias_val_int16))[0];
        scale = weight_val * scale_val;
        bias = weight_val * bias_val;
      }

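      // Unpack the BIT_RATE-wide code for element j from its packed byte,
      // then dequantize and accumulate: output[j] += scale * q + bias
      // (scale and bias already include the per-sample weight).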
      for (const auto j : c10::irange(block_size)) {
        uint8_t quantized =
            weight_data[idx * weight_size + j / NUM_ELEM_PER_BYTE];
        quantized >>= (j % NUM_ELEM_PER_BYTE) * BIT_RATE;
        quantized &= (1 << BIT_RATE) - 1;

        output_data[j] = fma(scale, quantized, output_data[j] + bias);
      }
    } // for each i
    output_data += block_size;
  } // for each m
  return output;
}

namespace {
template <typename IndexType, typename OffsetType>
void fbgemm_spmdm_report_error_(
    int64_t output_size,
    int index_size,
    int64_t N,
    const OffsetType* offsets,
    const IndexType* indices) {
  for (const auto m : c10::irange(output_size)) {
    for (OffsetType i = offsets[m]; i < offsets[m + 1]; ++i) {
      TORCH_CHECK(i < index_size);
      IndexType idx = indices[i];
      TORCH_CHECK(
          0 <= idx && idx < N,
          "Index ",
          i,
          " is out of bounds: ",
          idx,
          ", range 0 to ",
          N);
    }
  }
  TORCH_CHECK(
      offsets[output_size] == index_size,
      "Your input seems to be incorrect: the last offset value should be "
      "the size of the indices tensor, but it appears not.");
}
} // namespace

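// embedding_bag_nbit_impl handles 4-bit and 2-bit packed rows, whose last
// four bytes hold an fp16 scale and an fp16 zero offset. With FBGEMM
// available it dispatches to an FBGEMM-generated SpMDM kernel; otherwise it
// falls back to the scalar implementation above.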
template <typename IndexType, typename OffsetType>
at::Tensor& embedding_bag_nbit_impl(
    at::Tensor& output,
    const at::Tensor& weight,
    const int bit_width,
    const at::Tensor& indices,
    const at::Tensor& offsets,
    bool pruned_weights,
    const std::optional<at::Tensor>& per_sample_weights_,
    const std::optional<at::Tensor>& compressed_indices_mapping,
    bool include_last_offset,
    bool is_embedding_op) {
  TORCH_CHECK(weight.dim() == 2);
  TORCH_CHECK(offsets.dim() == 1);

  auto offsets_data = offsets.data_ptr<OffsetType>();

  // Get compressed indices for the pruned_weights op.
  int32_t* compressed_indices_mapping_data = nullptr;
  int compressed_index_size = 0;
  bool fallback_to_no_sparse = false;
  if (pruned_weights) {
    compressed_index_size = compressed_indices_mapping.value().numel();
    compressed_indices_mapping_data =
        compressed_indices_mapping.value().data_ptr<int32_t>();

    // If compressed_indices_mapping is [0], it is an indicator that we
    // should fall back to the non-sparse embedding lookup kernel.
    if ((compressed_index_size == 1 &&
         compressed_indices_mapping_data[0] == 0)) {
      fallback_to_no_sparse = true;
    }
  }

  const auto weight_sizes = weight.sizes();
  const int64_t weight_size = weight_sizes[1];
  int NUM_ELEM_PER_BYTE = 8 / bit_width;
  const int64_t D =
      (weight_size - 2 * sizeof(at::Half)) * NUM_ELEM_PER_BYTE; // NB: 2-byte fp16 scale and 2-byte zero_offset
  const int64_t M = offsets.sizes()[0];

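  // output_size is the number of bags. When the caller did not supply a
  // trailing offset, append indices.numel() as the final offset so the
  // kernels below can assume offsets has output_size + 1 entries.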
  int64_t output_size = M - 1;
  std::vector<OffsetType> offsets_include_last_val;
  if (!include_last_offset) {
    output_size = M;
    offsets_include_last_val.resize(M + 1);
    // Avoid `null pointer passed as argument 2` ASAN violation when offsets
    // tensor is empty.
    if (M > 0) {
      std::memcpy(
          offsets_include_last_val.data(),
          offsets_data,
          sizeof(OffsetType) * M);
    }
    offsets_include_last_val[M] = indices.numel();
    offsets_data = offsets_include_last_val.data();
  }
  {
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
    std::array<int64_t, 3> shape_arr;
    c10::IntArrayRef shape;
    if (indices.dim() == 2 && is_embedding_op) {
      const auto indices_sizes = indices.sizes();
      shape_arr[0] = indices_sizes[0];
      shape_arr[1] = indices_sizes[1];
      shape_arr[2] = D;
      shape = shape_arr;
    } else {
      shape_arr[0] = output_size;
      shape_arr[1] = D;
      shape = c10::IntArrayRef(&shape_arr[0], 2);
    }
    at::native::resize_(output, shape, std::nullopt);
  }
#ifdef USE_FBGEMM
  const auto indices_data = indices.data_ptr<IndexType>();
  const auto weight_data = weight.data_ptr<uint8_t>();
  auto* output_data = output.data_ptr<float>();
  const int64_t N = weight_sizes[0];

  const int64_t block_size = D;
  const int index_size = indices.numel();
  constexpr int prefetch_distance = 16;
  if (!pruned_weights || fallback_to_no_sparse) {
    // Generate the fbgemm kernel
    auto kernel = fbgemm::GenerateEmbeddingSpMDMNBit<IndexType, OffsetType>(
        /*bit_rate=*/bit_width,
        /*block_size=*/block_size,
        /*has_weight=*/per_sample_weights_.has_value(),
        /*normalize_by_lengths=*/false,
        /*prefetch=*/prefetch_distance,
        /*is_weight_positional=*/false,
        /*use_offsets=*/true);

    bool success = kernel(
        /*output_size=*/output_size,
        /*index_size=*/index_size,
        /*data_size=*/N,
        /*input=*/weight_data,
        /*indices=*/indices_data,
        /*offsets=*/offsets_data,
        /*weights=*/
        per_sample_weights_.has_value()
            ? per_sample_weights_.value().data_ptr<float>()
            : nullptr,
        /*output=*/output_data);

    if (!success) {
      fbgemm_spmdm_report_error_(
          output_size, index_size, N, offsets_data, indices_data);
    }
  } else {
    auto kernel =
        fbgemm::GenerateEmbeddingSpMDMNBitRowWiseSparse<IndexType, OffsetType>(
            /*bit_rate=*/bit_width,
            /*block_size=*/block_size,
            /*has_weight=*/per_sample_weights_.has_value(),
            /*normalize_by_lengths=*/false,
            /*prefetch=*/prefetch_distance,
            /*is_weight_positional=*/false,
            /*use_offsets=*/true);
    bool success = kernel(
        /*output_size=*/output_size,
        /*index_size=*/index_size,
        /*data_size=*/compressed_index_size,
        /*input=*/weight_data,
        /*indices=*/indices_data,
        /*offsets=*/offsets_data,
        /*weights=*/
        per_sample_weights_.has_value()
            ? per_sample_weights_.value().data_ptr<float>()
            : nullptr,
        /*output=*/output_data,
        /*compressed_indices_table=*/compressed_indices_mapping_data);
    if (!success) {
      fbgemm_spmdm_report_error_(
          output_size,
          index_size,
          compressed_index_size,
          offsets_data,
          indices_data);
    }
  }
  return output;
#else
  if (bit_width == 4) {
    return embedding_lookup_fallback_impl<IndexType, OffsetType, 4, 2>(
        weight,
        indices,
        offsets,
        per_sample_weights_,
        compressed_indices_mapping,
        output,
        D,
        output_size,
        include_last_offset,
        (pruned_weights && !fallback_to_no_sparse));
  }
  // bit_width == 2
  return embedding_lookup_fallback_impl<IndexType, OffsetType, 2, 4>(
      weight,
      indices,
      offsets,
      per_sample_weights_,
      compressed_indices_mapping,
      output,
      D,
      output_size,
      include_last_offset,
      (pruned_weights && !fallback_to_no_sparse));
#endif
}

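// embedding_bag_byte_impl handles 8-bit rows, whose last 8 bytes hold an
// fp32 scale and an fp32 bias. The dense path is parallelized over bags with
// at::parallel_for; the pruned path goes through the row-wise-sparse kernel.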
template <typename IndexType, typename OffsetType>
at::Tensor& embedding_bag_byte_impl(
    at::Tensor& output,
    const at::Tensor& weight,
    const at::Tensor& indices,
    const at::Tensor& offsets,
    bool pruned_weights,
    const std::optional<at::Tensor>& per_sample_weights_,
    const std::optional<at::Tensor>& compressed_indices_mapping,
    bool include_last_offset,
    bool is_embedding_op) {
  TORCH_CHECK(weight.scalar_type() == at::kByte);
  TORCH_CHECK(weight.dim() == 2);
  TORCH_CHECK(offsets.dim() == 1);
  auto offsets_data = offsets.data_ptr<OffsetType>();

  // Get compressed indices for pruned_weights.
  int32_t* compressed_indices_mapping_data = nullptr;
  int compressed_index_size = 0;
  bool fallback_to_no_sparse = false;
  if (pruned_weights) {
    compressed_index_size = compressed_indices_mapping.value().numel();
    compressed_indices_mapping_data =
        compressed_indices_mapping.value().data_ptr<int32_t>();

    // If compressed_indices_mapping is [0], it is an indicator that we
    // should fall back to the non-sparse embedding lookup kernel.
    if ((compressed_index_size == 1 &&
         compressed_indices_mapping_data[0] == 0)) {
      fallback_to_no_sparse = true;
    }
  }

  const auto weight_sizes = weight.sizes();
  const int64_t D = weight_sizes[1] - 8; // NB: -8 to account for scale and bias
  const int64_t M = offsets.sizes()[0];

  int64_t output_size = M - 1;
  std::vector<OffsetType> offsets_include_last_val;

  if (!include_last_offset) {
    output_size = M;
    offsets_include_last_val.resize(M + 1);
    // Avoid `null pointer passed as argument 2` ASAN violation when offsets
    // tensor is empty.
    if (M > 0) {
      std::memcpy(
          offsets_include_last_val.data(),
          offsets_data,
          sizeof(OffsetType) * M);
    }
    offsets_include_last_val[M] = indices.numel();
    offsets_data = offsets_include_last_val.data();
  }
  {
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
    std::array<int64_t, 3> shape_arr;
    c10::IntArrayRef shape;
    if (indices.dim() == 2 && is_embedding_op) {
      const auto indices_sizes = indices.sizes();
      shape_arr[0] = indices_sizes[0];
      shape_arr[1] = indices_sizes[1];
      shape_arr[2] = D;
      shape = shape_arr;
    } else {
      shape_arr[0] = output_size;
      shape_arr[1] = D;
      shape = c10::IntArrayRef(&shape_arr[0], 2);
    }
    at::native::resize_(output, shape, std::nullopt);
  }
#ifdef USE_FBGEMM
  const int64_t N = weight_sizes[0];
  const auto weight_data = weight.data_ptr<uint8_t>();
  const auto indices_data = indices.data_ptr<IndexType>();
  auto* output_data = output.data_ptr<float>();
  const int index_size = indices.numel();

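  // Dense path: parallelize over bags. Each chunk advances the indices,
  // weights, and output pointers to its own start; the offsets pointer is
  // advanced but not rebased, since the call site relies on the kernel
  // consuming consecutive offset differences as bag lengths.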
  if (!pruned_weights || fallback_to_no_sparse) {
    auto kernel_i8 =
        fbgemm::GenerateEmbeddingSpMDM<uint8_t, IndexType, OffsetType, /*OutType=*/float, /*THREAD_LOCAL=*/true>(
            /*block_size=*/D,
            /*has_weight=*/per_sample_weights_.has_value(),
            /*normalize_by_lengths=*/false,
            /*prefetch=*/16, // NOLINT(cppcoreguidelines-avoid-magic-numbers)
            /*is_weight_positional=*/false,
            /*use_offsets=*/true);

    at::parallel_for(
        0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) {
          bool success = kernel_i8(
              /*output_size=*/end_idx - start_idx,
              /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx],
              /*data_size=*/N,
              /*input=*/weight_data,
              /*indices=*/indices_data + offsets_data[start_idx],
              /*offsets_or_lengths=*/offsets_data + start_idx,
              /*weights=*/
              per_sample_weights_
                  ? per_sample_weights_.value().const_data_ptr<float>() +
                      offsets_data[start_idx]
                  : nullptr,
              /*out=*/output_data + start_idx * D);

          if (!success) {
            fbgemm_spmdm_report_error_(
                end_idx - start_idx,
                offsets_data[end_idx] - offsets_data[start_idx],
                N,
                offsets_data + start_idx,
                indices_data + offsets_data[start_idx]);
          }
        });
  } else {
    // pruned weights
    auto kernel_i8_sparse = fbgemm::
        GenerateEmbeddingSpMDMRowWiseSparse<uint8_t, IndexType, OffsetType>(
            /*block_size=*/D,
            /*has_weight=*/per_sample_weights_.has_value(),
            /*normalize_by_lengths=*/false,
            /*prefetch=*/16, // NOLINT(cppcoreguidelines-avoid-magic-numbers)
            /*is_weight_positional=*/false,
            /*use_offsets=*/true);

    auto success = kernel_i8_sparse(
        /*output_size=*/output_size,
        /*index_size=*/index_size,
        /*data_size=*/compressed_index_size,
        /*input=*/weight_data,
        /*indices=*/indices_data,
        /*offsets=*/offsets_data,
        /*weights=*/
        per_sample_weights_.has_value()
            ? per_sample_weights_.value().data_ptr<float>()
            : nullptr,
        /*output=*/output_data,
        /*compressed_indices_table=*/compressed_indices_mapping_data);
    if (!success) {
      fbgemm_spmdm_report_error_(
          output_size,
          index_size,
          compressed_index_size,
          offsets_data,
          indices_data);
    }
  }
  return output;
#else
  return embedding_lookup_fallback_impl<IndexType, OffsetType, 8, 1>(
      weight,
      indices,
      offsets,
      per_sample_weights_,
      compressed_indices_mapping,
      output,
      D,
      output_size,
      include_last_offset,
      (pruned_weights && !fallback_to_no_sparse));
#endif
}

at::Tensor& embedding_bag_byte_helper(
    at::Tensor& output,
    const at::Tensor& weight,
    const at::Tensor& indices,
    const std::optional<at::Tensor>& offsets_in,
    bool pruned_weights,
    const std::optional<at::Tensor>& per_sample_weights_,
    const std::optional<at::Tensor>& compressed_indices_mapping,
    bool include_last_offset,
    bool is_embedding_op) {
  c10::MaybeOwned<at::Tensor> offsets;
  TORCH_CHECK(
      indices.dim() == 1 || indices.dim() == 2,
      "qembedding/qembedding_bag operator supports 1 or 2d indices, got ",
      indices.dim());
  // For the embedding_bag operator with 2D indices, we set the offsets
  // explicitly here.
  if (indices.dim() == 2 && !is_embedding_op) {
    TORCH_CHECK(
        !offsets_in.has_value(),
        "embedding_bag_byte operator: input is 2D, then offsets has to be None, as the input is treated as a mini-batch of fixed length sequences.");

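    // 2D indices of shape [B, L] are treated as B bags of fixed length L, so
    // the synthesized offsets are [0, L, 2*L, ..., (B-1)*L].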
    offsets = c10::MaybeOwned<at::Tensor>::owned(at::arange(
        0, indices.numel(), indices.sizes()[1], indices.scalar_type()));
  } else {
    TORCH_CHECK(
        offsets_in.has_value(),
        "embedding_bag_byte expects offsets to be set for 1D indices.");
    offsets = c10::MaybeOwned<at::Tensor>::borrowed(offsets_in.value());
  }

  TORCH_CHECK(
      indices.scalar_type() == at::kInt || indices.scalar_type() == at::kLong,
      "Expect 32 or 64 bit indices, but found ",
      indices.scalar_type(),
      " instead.");
  TORCH_CHECK(
      offsets->scalar_type() == at::kInt || offsets->scalar_type() == at::kLong,
      "Expect 32 or 64 bit offsets, but found ",
      offsets->scalar_type(),
      " instead.");
  TORCH_CHECK(
      weight.is_contiguous() && indices.is_contiguous() &&
          offsets->is_contiguous(),
      "Expect weight, indices, and offsets to be contiguous.");

  // Dispatch on the index/offset type combination to avoid casting, which
  // would add performance overhead.
  if (indices.scalar_type() == at::kInt && offsets->scalar_type() == at::kInt) {
    return embedding_bag_byte_impl<int, int>(
        output,
        weight,
        indices,
        *offsets,
        pruned_weights,
        per_sample_weights_,
        compressed_indices_mapping,
        include_last_offset,
        is_embedding_op);
  } else if (
      indices.scalar_type() == at::kInt && offsets->scalar_type() == at::kLong) {
    return embedding_bag_byte_impl<int, int64_t>(
        output,
        weight,
        indices,
        *offsets,
        pruned_weights,
        per_sample_weights_,
        compressed_indices_mapping,
        include_last_offset,
        is_embedding_op);
  } else if (
      indices.scalar_type() == at::kLong && offsets->scalar_type() == at::kInt) {
    return embedding_bag_byte_impl<int64_t, int>(
        output,
        weight,
        indices,
        *offsets,
        pruned_weights,
        per_sample_weights_,
        compressed_indices_mapping,
        include_last_offset,
        is_embedding_op);
  }

  // Default case given the TORCH_CHECK above: int64_t indices and offsets.
  return embedding_bag_byte_impl<int64_t, int64_t>(
      output,
      weight,
      indices,
      *offsets,
      pruned_weights,
      per_sample_weights_,
      compressed_indices_mapping,
      include_last_offset,
      is_embedding_op);
}

at::Tensor& _embedding_bag_nbit_helper(
    at::Tensor& output,
    const at::Tensor& weight,
    const int bit_width,
    const at::Tensor& indices,
    const std::optional<at::Tensor>& offsets_in,
    bool pruned_weights,
    const std::optional<at::Tensor>& per_sample_weights_,
    const std::optional<at::Tensor>& compressed_indices_mapping,
    bool include_last_offset,
    bool is_embedding_op) {
  c10::MaybeOwned<at::Tensor> offsets;
  TORCH_CHECK(
      bit_width == 4 || bit_width == 2,
      "qembedding/qembedding_bag operator supports bit_width 2 or 4, got ",
      bit_width);
  TORCH_CHECK(
      indices.dim() == 1 || indices.dim() == 2,
      "qembedding/qembedding_bag operator supports 1 or 2d indices, got ",
      indices.dim());

  // For the embedding_bag operator with 2D indices, we need to set the
  // offsets explicitly here.
  if (indices.dim() == 2 && !is_embedding_op) {
    TORCH_CHECK(
        !offsets_in.has_value(),
        "embedding_bag_4bit/embedding_bag_2bit operator: input is 2D, then offsets has to be None, as the input is treated as a mini-batch of fixed length sequences.");

    offsets = c10::MaybeOwned<at::Tensor>::owned(at::arange(
        0, indices.numel(), indices.sizes()[1], indices.scalar_type()));
  } else {
    TORCH_CHECK(
        offsets_in.has_value(),
        "embedding_bag_4bit/embedding_bag_2bit operator expects offsets to be set for 1D indices.");
    offsets = c10::MaybeOwned<at::Tensor>::borrowed(offsets_in.value());
  }

  TORCH_CHECK(
      indices.scalar_type() == at::kInt || indices.scalar_type() == at::kLong,
      "Expect 32 or 64 bit indices, but found ",
      indices.scalar_type(),
      " instead.");
  TORCH_CHECK(
      offsets->scalar_type() == at::kInt || offsets->scalar_type() == at::kLong,
      "Expect 32 or 64 bit offsets, but found ",
      offsets->scalar_type(),
      " instead.");
  TORCH_CHECK(
      weight.is_contiguous() && indices.is_contiguous() &&
          offsets->is_contiguous(),
      "Expect weight, indices, and offsets to be contiguous.");

  // Dispatch on the index/offset type combination to avoid casting, which
  // would add performance overhead.
  if (indices.scalar_type() == at::kInt && offsets->scalar_type() == at::kInt) {
    return embedding_bag_nbit_impl<int, int>(
        output,
        weight,
        bit_width,
        indices,
        *offsets,
        pruned_weights,
        per_sample_weights_,
        compressed_indices_mapping,
        include_last_offset,
        is_embedding_op);
  } else if (
      indices.scalar_type() == at::kInt && offsets->scalar_type() == at::kLong) {
    return embedding_bag_nbit_impl<int, int64_t>(
        output,
        weight,
        bit_width,
        indices,
        *offsets,
        pruned_weights,
        per_sample_weights_,
        compressed_indices_mapping,
        include_last_offset,
        is_embedding_op);
  } else if (
      indices.scalar_type() == at::kLong && offsets->scalar_type() == at::kInt) {
    return embedding_bag_nbit_impl<int64_t, int>(
        output,
        weight,
        bit_width,
        indices,
        *offsets,
        pruned_weights,
        per_sample_weights_,
        compressed_indices_mapping,
        include_last_offset,
        is_embedding_op);
  }
  return embedding_bag_nbit_impl<int64_t, int64_t>(
      output,
      weight,
      bit_width,
      indices,
      *offsets,
      pruned_weights,
      per_sample_weights_,
      compressed_indices_mapping,
      include_last_offset,
      is_embedding_op);
}
} // namespace

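// Entry points bound to the TorchBind packed-weight object
// (PackedEmbeddingBagWeight); they forward to the helpers above using the
// packed tensor stored in `packed_w`.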
at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte(
    const at::Tensor& indices,
    const std::optional<at::Tensor>& offsets_in,
    bool pruned_weights,
    const std::optional<at::Tensor>& per_sample_weights_,
    const std::optional<at::Tensor>& compressed_indices_mapping,
    bool include_last_offset,
    bool is_embedding_op) {
  auto output = at::empty({0}, packed_w.options().dtype(at::kFloat));
  return embedding_bag_byte_helper(
      output,
      packed_w,
      indices,
      offsets_in,
      pruned_weights,
      per_sample_weights_,
      compressed_indices_mapping,
      include_last_offset,
      is_embedding_op);
}

at::Tensor PackedEmbeddingBagWeight::embeddingbag_4bit(
    const at::Tensor& indices,
    const std::optional<at::Tensor>& offsets_in,
    bool pruned_weights,
    const std::optional<at::Tensor>& per_sample_weights_,
    const std::optional<at::Tensor>& compressed_indices_mapping,
    bool include_last_offset,
    bool is_embedding_op) {
  if (per_sample_weights_.has_value()) {
    TORCH_CHECK(
        (per_sample_weights_.value().scalar_type() == at::kFloat ||
         per_sample_weights_.value().scalar_type() == at::kHalf),
        "Expect fp32 or fp16 weights, but found ",
        per_sample_weights_.value().scalar_type(),
        " instead")
  }

  auto output = at::empty({0}, packed_w.options().dtype(at::kFloat));
  return _embedding_bag_nbit_helper(
      output,
      packed_w,
      4,
      indices,
      offsets_in,
      pruned_weights,
      per_sample_weights_.has_value()
          ? per_sample_weights_.value().to(at::kFloat)
          : per_sample_weights_,
      compressed_indices_mapping,
      include_last_offset,
      is_embedding_op);
}

namespace at {
namespace native {

Tensor& embedding_bag_byte_rowwise_offsets_out(
    Tensor& output,
    const Tensor& weight,
    const Tensor& indices,
    const std::optional<Tensor>& offsets_in,
    const bool /* scale_grad_by_freq */,
    const int64_t /* mode */,
    bool pruned_weights,
    const std::optional<Tensor>& per_sample_weights_,
    const std::optional<Tensor>& compressed_indices_mapping,
    bool include_last_offset) {
  return embedding_bag_byte_helper(
      output,
      weight,
      indices,
      offsets_in,
      pruned_weights,
      per_sample_weights_,
      compressed_indices_mapping,
      include_last_offset,
      false /* is_embedding_op */);
}

Tensor& embedding_bag_4bit_rowwise_offsets_out(
    Tensor& output,
    const Tensor& weight,
    const Tensor& indices,
    const std::optional<Tensor>& offsets_in,
    const bool /* scale_grad_by_freq */,
    const int64_t /* mode */,
    bool pruned_weights,
    const std::optional<Tensor>& per_sample_weights_,
    const std::optional<Tensor>& compressed_indices_mapping,
    bool include_last_offset) {
  if (per_sample_weights_.has_value()) {
    TORCH_CHECK(
        (per_sample_weights_.value().scalar_type() == at::kFloat ||
         per_sample_weights_.value().scalar_type() == at::kHalf),
        "Expect fp32 or fp16 weights, but found ",
        per_sample_weights_.value().scalar_type(),
        " instead")
  }
  return _embedding_bag_nbit_helper(
      output,
      weight,
      4,
      indices,
      offsets_in,
      pruned_weights,
      per_sample_weights_.has_value()
          ? per_sample_weights_.value().to(at::kFloat)
          : per_sample_weights_,
      compressed_indices_mapping,
      include_last_offset,
      false);
}

static Tensor& embedding_bag_2bit_rowwise_offsets_out(
    Tensor& output,
    const Tensor& weight,
    const Tensor& indices,
    const std::optional<Tensor>& offsets_in,
    const bool /* scale_grad_by_freq */,
    const int64_t /* mode */,
    bool pruned_weights,
    const std::optional<Tensor>& per_sample_weights_,
    const std::optional<Tensor>& compressed_indices_mapping,
    bool include_last_offset) {
  if (per_sample_weights_.has_value()) {
    TORCH_CHECK(
        (per_sample_weights_.value().scalar_type() == at::kFloat ||
         per_sample_weights_.value().scalar_type() == at::kHalf),
        "Expect fp32 or fp16 weights, but found ",
        per_sample_weights_.value().scalar_type(),
        " instead")
  }
  return _embedding_bag_nbit_helper(
      output,
      weight,
      2,
      indices,
      offsets_in,
      pruned_weights,
      per_sample_weights_.has_value()
          ? per_sample_weights_.value().to(at::kFloat)
          : per_sample_weights_,
      compressed_indices_mapping,
      include_last_offset,
      false);
}

namespace {

inline at::Tensor create_empty_from(
    const at::Tensor& t,
    c10::ScalarType dtype) {
  return at::detail::empty_cpu(
      {0}, dtype, t.layout(), t.device(), std::nullopt, std::nullopt);
}

Tensor embedding_bag_byte_rowwise_offsets(
    const Tensor& weight,
    const Tensor& indices,
    const std::optional<Tensor>& offsets_in,
    const bool /* scale_grad_by_freq */,
    const int64_t /* mode */,
    bool pruned_weights,
    const std::optional<Tensor>& per_sample_weights_,
    const std::optional<Tensor>& compressed_indices_mapping,
    bool include_last_offset) {
  auto output = create_empty_from(weight, at::kFloat);
  embedding_bag_byte_rowwise_offsets_out(
      output,
      weight,
      indices,
      offsets_in,
      false /*unused scale_grad_by_freq*/,
      0 /*unused mode*/,
      pruned_weights,
      per_sample_weights_,
      compressed_indices_mapping,
      include_last_offset);
  return output;
}

Tensor embedding_bag_4bit_rowwise_offsets(
    const Tensor& weight,
    const Tensor& indices,
    const std::optional<Tensor>& offsets_in,
    const bool /* scale_grad_by_freq */,
    const int64_t /* mode */,
    bool pruned_weights,
    const std::optional<Tensor>& per_sample_weights_,
    const std::optional<Tensor>& compressed_indices_mapping,
    bool include_last_offset) {
  auto output = create_empty_from(weight, at::kFloat);
  embedding_bag_4bit_rowwise_offsets_out(
      output,
      weight,
      indices,
      offsets_in,
      false, // unused scale_grad_by_freq
      0, // unused mode
      pruned_weights,
      per_sample_weights_,
      compressed_indices_mapping,
      include_last_offset);
  return output;
}

Tensor embedding_bag_2bit_rowwise_offsets(
    const Tensor& weight,
    const Tensor& indices,
    const std::optional<Tensor>& offsets_in,
    const bool /* scale_grad_by_freq */,
    const int64_t /* mode */,
    bool pruned_weights,
    const std::optional<Tensor>& per_sample_weights_,
    const std::optional<Tensor>& compressed_indices_mapping,
    bool include_last_offset) {
  auto output = create_empty_from(weight, at::kFloat);
  embedding_bag_2bit_rowwise_offsets_out(
      output,
      weight,
      indices,
      offsets_in,
      false, // unused scale_grad_by_freq
      0, // unused mode
      pruned_weights,
      per_sample_weights_,
      compressed_indices_mapping,
      include_last_offset);
  return output;
}

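// Meta kernel: checks dtypes and computes the symbolic output shape
// [output_size, D] without reading any data (used for shape inference on
// meta tensors).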
Tensor embedding_bag_byte_rowwise_offsets_meta(
    const Tensor& weight,
    const Tensor& indices,
    const std::optional<Tensor>& offsets_in,
    const bool /* scale_grad_by_freq */,
    const int64_t /* mode */,
    bool /* pruned_weights */,
    const std::optional<Tensor>& /* per_sample_weights_ */,
    const std::optional<Tensor>& /* compressed_indices_mapping */,
    bool include_last_offset) {
  TORCH_CHECK(
      indices.dim() == 1 || indices.dim() == 2,
      "quantized::embedding_bag_byte_rowwise_offsets_meta operator supports 1 or 2d indices, got ",
      indices.dim());

  TORCH_CHECK(
      offsets_in.has_value(),
      "Currently quantized::embedding_bag_byte_rowwise_offsets_meta only supports having offsets.");
  c10::MaybeOwned<at::Tensor> offsets =
      c10::MaybeOwned<at::Tensor>::borrowed(offsets_in.value());

  TORCH_CHECK(
      indices.scalar_type() == at::kInt || indices.scalar_type() == at::kLong,
      "Expect 32 or 64 bit indices, but found ",
      indices.scalar_type(),
      " instead.");
  TORCH_CHECK(
      offsets->scalar_type() == at::kInt || offsets->scalar_type() == at::kLong,
      "Expect 32 or 64 bit offsets, but found ",
      offsets->scalar_type(),
      " instead.");

  const auto D = weight.sym_size(1) - 8; // NB: -8 to account for scale and bias
  const auto M = offsets->sym_size(0);
  const auto output_size = include_last_offset ? M - 1 : M;

  return at::empty_symint({output_size, D}, weight.options().dtype(at::kFloat));
}

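// QEmbeddingBag / QEmbedding adapt the packed-weight methods to the operator
// signatures registered below; the bit_rate template parameter selects which
// packed method to call.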
template <int bit_rate>
class QEmbeddingBag final {
 public:
  static at::Tensor run(
      const c10::intrusive_ptr<EmbeddingPackedParamsBase>& packed_weight,
      const Tensor& indices,
      const std::optional<Tensor>& offsets,
      const bool /* scale_grad_by_freq */,
      const int64_t /* mode */,
      bool pruned_weights,
      const std::optional<Tensor>& per_sample_weights_,
      const std::optional<Tensor>& compressed_indices_mapping,
      bool include_last_offset) {
    if (bit_rate == 8) {
      return packed_weight->embeddingbag_byte(
          indices,
          offsets,
          pruned_weights,
          per_sample_weights_,
          compressed_indices_mapping,
          include_last_offset,
          false /* is_embedding_op */);
    } else if (bit_rate == 4) {
      return packed_weight->embeddingbag_4bit(
          indices,
          offsets,
          pruned_weights,
          per_sample_weights_,
          compressed_indices_mapping,
          include_last_offset,
          false);
    } else {
      TORCH_INTERNAL_ASSERT(
          false,
          "Currently only 8-bit and 4-bit embedding_bag quantization is supported");
    }
  }
};

template <int bit_rate>
class QEmbedding final {
 public:
  static at::Tensor run(
      const c10::intrusive_ptr<EmbeddingPackedParamsBase>& packed_weight,
      const Tensor& indices,
      bool pruned_weights) {
    // Set default offsets here since the FBGEMM lookup op expects it.
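    // A plain embedding lookup is expressed as one bag per index:
    // offsets = arange(numel), so each bag contains exactly one row.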
    const auto offsets_size = indices.numel();
    at::Tensor offsets = at::arange(0, offsets_size, indices.scalar_type());
    at::Tensor output;
    if (bit_rate == 8) {
      return packed_weight->embeddingbag_byte(
          indices,
          offsets,
          pruned_weights,
          std::nullopt,
          std::nullopt,
          false /* include_last_offset */,
          true /* is_embedding_op */);
    } else if (bit_rate == 4) {
      return packed_weight->embeddingbag_4bit(
          indices,
          offsets,
          pruned_weights,
          std::nullopt,
          std::nullopt,
          false,
          true);
    } else {
      TORCH_INTERNAL_ASSERT(
          false,
          "Currently only 8-bit and 4-bit embedding quantization is supported");
    }
    return output;
  }
};

TORCH_LIBRARY_IMPL(quantized, CPU, m) {
  // Functions that work on TorchBind packed weights.
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte"),
      TORCH_FN(QEmbeddingBag<8>::run));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit"),
      TORCH_FN(QEmbeddingBag<4>::run));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_byte"),
      TORCH_FN(QEmbedding<8>::run));
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_4bit"),
      TORCH_FN(QEmbedding<4>::run));

  // Functions that work on at::Tensor packed weights.
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_rowwise_offsets"),
      embedding_bag_byte_rowwise_offsets);
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_rowwise_offsets"),
      embedding_bag_4bit_rowwise_offsets);
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_rowwise_offsets"),
      embedding_bag_2bit_rowwise_offsets);
}

TORCH_LIBRARY_IMPL(quantized, Meta, m) {
  m.impl(
      "quantized::embedding_bag_byte_rowwise_offsets",
      embedding_bag_byte_rowwise_offsets_meta);
}

} // namespace
} // namespace native
} // namespace at