1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/native/quantized/cpu/EmbeddingPackedParams.h>
4 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
5 #include <ATen/native/quantized/cpu/qembeddingbag.h>
6 #include <torch/library.h>
7 #ifdef USE_FBGEMM
8 #include <fbgemm/Fbgemm.h>
9 #include <fbgemm/FbgemmEmbedding.h>
10 #endif
11 
12 #include <ATen/Parallel.h>
13 #include <ATen/Utils.h>
14 #include <c10/util/irange.h>
15 
16 #include <array>
17 
18 #ifndef AT_PER_OPERATOR_HEADERS
19 #include <ATen/Functions.h>
20 #include <ATen/NativeFunctions.h>
21 #else
22 #include <ATen/ops/arange.h>
23 #include <ATen/ops/empty.h>
24 #include <ATen/ops/resize_native.h>
25 #endif
26 
27 int register_embedding_params();
28 
29 namespace {
30 
31 // Fallback implementation when FBGEMM is not available.
32 template <
33     typename IndexType,
34     typename OffsetType,
35     int BIT_RATE,
36     int NUM_ELEM_PER_BYTE>
37 at::Tensor& embedding_lookup_fallback_impl(
38     const at::Tensor& weight,
39     const at::Tensor& indices,
40     const at::Tensor& offsets,
41     const std::optional<at::Tensor>& per_sample_weights_,
42     const std::optional<at::Tensor>& compressed_indices_mapping,
43     at::Tensor& output,
44     const int64_t block_size,
45     const int64_t output_size,
46     bool include_last_offset,
47     bool pruned) {
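  // Row layout assumed by this fallback: each of the N rows of `weight` holds
  // `block_size` quantized values packed NUM_ELEM_PER_BYTE per byte, followed by
  // a scale and a bias (two fp32 values for BIT_RATE == 8, two fp16 values for
  // the sub-byte bit rates). Each looked-up index contributes
  // weight_val * (scale * q + bias) to its bag, where q is the unpacked value.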
48   auto* output_data = output.data_ptr<float>();
49   const auto weight_data = weight.data_ptr<uint8_t>();
50   const auto indices_data = indices.data_ptr<IndexType>();
51   int32_t* compressed_indices_mapping_data = nullptr;
52   const auto weight_sizes = weight.sizes();
53   const int64_t N = weight_sizes[0];
54   const int64_t weight_size = weight_sizes[1];
55   const int index_size = indices.numel();
56 
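  // Convert offsets into per-bag lengths; when include_last_offset is false the
  // final bag runs to the end of the indices tensor.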
57   auto accessor = offsets.accessor<OffsetType, 1>();
58   std::vector<OffsetType> lengths_data;
59 
60   int64_t lower = accessor[0];
61   for (const auto i : c10::irange(1, offsets.numel())) {
62     lengths_data.push_back(accessor[i] - lower);
63     lower = accessor[i];
64   }
65   if (!include_last_offset) {
66     lengths_data.push_back(indices.numel() - lower);
67   }
68 
69   int64_t current = 0;
70   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
71   float* per_sample_weights_data;
72   if (per_sample_weights_.has_value()) {
73     per_sample_weights_data = per_sample_weights_.value().data_ptr<float>();
74   }
75   for (const auto m : c10::irange(output_size)) {
76     memset(output_data, 0, block_size * sizeof(float));
77     TORCH_CHECK(
78         current + lengths_data[m] <= index_size,
79         "Expect the lengths data to be less than indices size");
80 
81     for (int i = 0; i < lengths_data[m]; ++i, ++current) {
82       // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
83       int64_t idx;
84       if (!pruned) {
85         idx = indices_data[current];
86         TORCH_CHECK((idx >= 0 && idx < N), "Invalid indices data");
87       } else {
88         int64_t uncompressed_idx = indices_data[current];
89         int compressed_index_size = compressed_indices_mapping.value().numel();
90         compressed_indices_mapping_data =
91             compressed_indices_mapping.value().data_ptr<int32_t>();
92         TORCH_CHECK(
93             uncompressed_idx >= 0 && uncompressed_idx < compressed_index_size,
94             "Invalid indices data for Sparse Op.")
95         idx = compressed_indices_mapping_data[uncompressed_idx];
96         if (idx == -1) {
97           continue;
98         }
99       }
100 
101       float weight_val = 1.0f;
102       if (per_sample_weights_.has_value()) {
103         weight_val = per_sample_weights_data[current];
104       }
105       // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
106       float scale, bias;
107       if (BIT_RATE == 8) {
108         const uint8_t* scale_bias =
109             weight_data + (idx + 1) * weight_size - 2 * sizeof(float);
110         uint32_t scale_val_int32 = 0;
111 #if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
112         scale_val_int32 = scale_val_int32 |
113           (scale_bias[0]) |
114           (scale_bias[1] << 8) |
115           (scale_bias[2] << 16) |
116           (scale_bias[3] << 24);
117 #elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
118         scale_val_int32 = scale_val_int32 |
119           (scale_bias[3]) |
120           (scale_bias[2] << 8) |
121           (scale_bias[1] << 16) |
122           (scale_bias[0] << 24);
123 #else
124 #error Unexpected or undefined __BYTE_ORDER__
125 #endif
126         float scale_val = (reinterpret_cast<float*>(&scale_val_int32))[0];
127         uint32_t bias_val_int32 = 0;
128 #if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
129         bias_val_int32 = bias_val_int32 |
130           (scale_bias[4]) |
131           (scale_bias[5] << 8) |
132           (scale_bias[6] << 16) |
133           (scale_bias[7] << 24);
134 #elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
135         bias_val_int32 = bias_val_int32 |
136           (scale_bias[7]) |
137           (scale_bias[6] << 8) |
138           (scale_bias[5] << 16) |
139           (scale_bias[4] << 24);
140 #else
141 #error Unexpected or undefined __BYTE_ORDER__
142 #endif
143         float bias_val = (reinterpret_cast<float*>(&bias_val_int32))[0];
144         scale = weight_val * scale_val;
145         bias = weight_val * bias_val;
146       } else {
147         const uint8_t* scale_bias =
148             weight_data + (idx + 1) * weight_size - 2 * sizeof(at::Half);
149         uint16_t scale_val_int16 = 0;
150 #if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
151         scale_val_int16 = scale_val_int16 |
152           (scale_bias[0]) |
153           (scale_bias[1] << 8);
154 #elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
155         scale_val_int16 = scale_val_int16 |
156           (scale_bias[1]) |
157           (scale_bias[0] << 8);
158 #else
159 #error Unexpected or undefined __BYTE_ORDER__
160 #endif
161         at::Half scale_val = (reinterpret_cast<at::Half*>(&scale_val_int16))[0];
162         uint16_t bias_val_int16 = 0;
163 #if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
164         bias_val_int16 = bias_val_int16 |
165           (scale_bias[2]) |
166           (scale_bias[3] << 8);
167 #elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
168         bias_val_int16 = bias_val_int16 |
169           (scale_bias[3]) |
170           (scale_bias[2] << 8);
171 #else
172 #error Unexpected or undefined __BYTE_ORDER__
173 #endif
174         at::Half bias_val = (reinterpret_cast<at::Half*>(&bias_val_int16))[0];
175         scale = weight_val * scale_val;
176         bias = weight_val * bias_val;
177       }
178 
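      // Unpack the j-th quantized value of row `idx`: each byte stores
      // NUM_ELEM_PER_BYTE values of BIT_RATE bits, so shift and mask to isolate it.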
179       for (const auto j : c10::irange(block_size)) {
180         uint8_t quantized =
181             weight_data[idx * weight_size + j / NUM_ELEM_PER_BYTE];
182         quantized >>= (j % NUM_ELEM_PER_BYTE) * BIT_RATE;
183         quantized &= (1 << BIT_RATE) - 1;
184 
185         output_data[j] = fma(scale, quantized, output_data[j] + bias);
186       }
187     } // for each i
188     output_data += block_size;
189   } // for each m
190   return output;
191 }
192 
193 namespace {
194 template <typename IndexType, typename OffsetType>
195 void fbgemm_spmdm_report_error_(
196     int64_t output_size,
197     int index_size,
198     int64_t N,
199     const OffsetType* offsets,
200     const IndexType* indices) {
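  // Slow re-walk of the indices, used only after an FBGEMM kernel reports
  // failure, to produce a precise out-of-bounds error message.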
201   for (const auto m : c10::irange(output_size)) {
202     for (OffsetType i = offsets[m]; i < offsets[m + 1]; ++i) {
203       TORCH_CHECK(i < index_size);
204       IndexType idx = indices[i];
205       TORCH_CHECK(
206           0 <= idx && idx < N,
207           "Index ",
208           i,
209           " is out of bounds: ",
210           idx,
211           ", range 0 to ",
212           N);
213     }
214   }
215   TORCH_CHECK(
216       offsets[output_size] == index_size,
217       "Yout input seems to be incorrect: the last offset value should be "
218       "the size of the indices tensor, but it appears not.");
219 }
220 } // namespace
221 
222 template <typename IndexType, typename OffsetType>
223 at::Tensor& embedding_bag_nbit_impl(
224     at::Tensor& output,
225     const at::Tensor& weight,
226     const int bit_width,
227     const at::Tensor& indices,
228     const at::Tensor& offsets,
229     bool pruned_weights,
230     const std::optional<at::Tensor>& per_sample_weights_,
231     const std::optional<at::Tensor>& compressed_indices_mapping,
232     bool include_last_offset,
233     bool is_embedding_op) {
234   TORCH_CHECK(weight.dim() == 2);
235   TORCH_CHECK(offsets.dim() == 1);
236 
237   auto offsets_data = offsets.data_ptr<OffsetType>();
238 
239   // Get compressed indices for pruned_weights op.
240   int32_t* compressed_indices_mapping_data = nullptr;
241   int compressed_index_size = 0;
242   bool fallback_to_no_sparse = false;
243   if (pruned_weights) {
244     compressed_index_size = compressed_indices_mapping.value().numel();
245     compressed_indices_mapping_data =
246         compressed_indices_mapping.value().data_ptr<int32_t>();
247 
248     // If compressed_indices_mapping is [0], it is an indicator that
249     // we should fall back to the non-sparse embedding lookup kernel.
250     if ((compressed_index_size == 1 &&
251          compressed_indices_mapping_data[0] == 0)) {
252       fallback_to_no_sparse = true;
253     }
254   }
255 
256   const auto weight_sizes = weight.sizes();
257   const int64_t weight_size = weight_sizes[1];
258   int NUM_ELEM_PER_BYTE = 8 / bit_width;
259   const int64_t D =
260       (weight_size - 2 * sizeof(at::Half)) * NUM_ELEM_PER_BYTE; // NB: 2-byte fp16 scale and 2-byte zero_offset
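  // Worked example: for bit_width == 4, NUM_ELEM_PER_BYTE == 2, so a row of
  // weight_size bytes holds (weight_size - 4) * 2 quantized values plus the
  // trailing fp16 scale and fp16 zero point.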
261   const int64_t M = offsets.sizes()[0];
262 
263   int64_t output_size = M - 1;
264   std::vector<OffsetType> offsets_include_last_val;
265   if (!include_last_offset) {
266     output_size = M;
267     offsets_include_last_val.resize(M + 1);
268     // Avoid `null pointer passed as argument 2` ASAN violation when offsets
269     // tensor is empty.
270     if (M > 0) {
271       std::memcpy(
272           offsets_include_last_val.data(),
273           offsets_data,
274           sizeof(OffsetType) * M);
275     }
276     offsets_include_last_val[M] = indices.numel();
277     offsets_data = offsets_include_last_val.data();
278   }
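  // At this point offsets_data is expected to hold output_size + 1 entries whose
  // last value equals indices.numel(), which is the layout the offsets-based
  // kernels below consume.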
279   {
280     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
281     std::array<int64_t, 3> shape_arr;
282     c10::IntArrayRef shape;
283     if (indices.dim() == 2 && is_embedding_op) {
284       const auto indices_sizes = indices.sizes();
285       shape_arr[0] = indices_sizes[0];
286       shape_arr[1] = indices_sizes[1];
287       shape_arr[2] = D;
288       shape = shape_arr;
289     } else {
290       shape_arr[0] = output_size;
291       shape_arr[1] = D;
292       shape = c10::IntArrayRef(&shape_arr[0], 2);
293     }
294     at::native::resize_(output, shape, std::nullopt);
295   }
296 #ifdef USE_FBGEMM
297   const auto indices_data = indices.data_ptr<IndexType>();
298   const auto weight_data = weight.data_ptr<uint8_t>();
299   auto* output_data = output.data_ptr<float>();
300   const int64_t N = weight_sizes[0];
301 
302   const int64_t block_size = D;
303   const int index_size = indices.numel();
304   constexpr int prefetch_distance = 16;
305   if (!pruned_weights || fallback_to_no_sparse) {
306     // Generate the fbgemm kernel
307     auto kernel = fbgemm::GenerateEmbeddingSpMDMNBit<IndexType, OffsetType>(
308         /*bit rate=*/bit_width,
309         /*block size=*/block_size,
310         /*has weights=*/per_sample_weights_.has_value(),
311         /*normalize_by_lengths=*/false,
312         /*prefetch distance=*/prefetch_distance,
313         /*is_weight_positional=*/false,
314         /*use_offsets=*/true);
315 
316     bool success = kernel(
317         /*output_size=*/output_size,
318         /*index_size=*/index_size,
319         /*data_size=*/N,
320         /*input=*/weight_data,
321         /*indices=*/indices_data,
322         /*offsets=*/offsets_data,
323         /*weights=*/
324         per_sample_weights_.has_value()
325             ? per_sample_weights_.value().data_ptr<float>()
326             : nullptr,
327         /*output=*/output_data);
328 
329     if (!success) {
330       fbgemm_spmdm_report_error_(
331           output_size, index_size, N, offsets_data, indices_data);
332     }
333   } else {
334     auto kernel =
335         fbgemm::GenerateEmbeddingSpMDMNBitRowWiseSparse<IndexType, OffsetType>(
336             /*bit rate=*/bit_width,
337             /*block_size=*/block_size,
338             /*has weights=*/per_sample_weights_.has_value(),
339             /*normalize_by_lengths=*/false,
340             /*prefetch distance*/ prefetch_distance,
341             /*is_weight_positional*/ false,
342             /*use_offsets*/ true);
343     bool success = kernel(
344         /*output_size=*/output_size,
345         /*index_size=*/index_size,
346         /*data_size=*/compressed_index_size,
347         /*input=*/weight_data,
348         /*indices=*/indices_data,
349         /*offsets=*/offsets_data,
350         /*weights=*/
351         per_sample_weights_.has_value()
352             ? per_sample_weights_.value().data_ptr<float>()
353             : nullptr,
354         /*output=*/output_data,
355         /*compressed_indices_table=*/compressed_indices_mapping_data);
356     if (!success) {
357       fbgemm_spmdm_report_error_(
358           output_size,
359           index_size,
360           compressed_index_size,
361           offsets_data,
362           indices_data);
363     }
364   }
365   return output;
366 #else
367   if (bit_width == 4) {
368     return embedding_lookup_fallback_impl<IndexType, OffsetType, 4, 2>(
369       weight,
370       indices,
371       offsets,
372       per_sample_weights_,
373       compressed_indices_mapping,
374       output,
375       D,
376       output_size,
377       include_last_offset,
378       (pruned_weights && !fallback_to_no_sparse));
379   }
380   // bit_width == 2
381   return embedding_lookup_fallback_impl<IndexType, OffsetType, 2, 4>(
382     weight,
383     indices,
384     offsets,
385     per_sample_weights_,
386     compressed_indices_mapping,
387     output,
388     D,
389     output_size,
390     include_last_offset,
391     (pruned_weights && !fallback_to_no_sparse));
392 #endif
393 }
394 
395 template <typename IndexType, typename OffsetType>
396 at::Tensor& embedding_bag_byte_impl(
397     at::Tensor& output,
398     const at::Tensor& weight,
399     const at::Tensor& indices,
400     const at::Tensor& offsets,
401     bool pruned_weights,
402     const std::optional<at::Tensor>& per_sample_weights_,
403     const std::optional<at::Tensor>& compressed_indices_mapping,
404     bool include_last_offset,
405     bool is_embedding_op) {
406   TORCH_CHECK(weight.scalar_type() == at::kByte);
407   TORCH_CHECK(weight.dim() == 2);
408   TORCH_CHECK(offsets.dim() == 1);
409   auto offsets_data = offsets.data_ptr<OffsetType>();
410 
411   // Get compressed indices for pruned_weights.
412   int32_t* compressed_indices_mapping_data = nullptr;
413   int compressed_index_size = 0;
414   bool fallback_to_no_sparse = false;
415   if (pruned_weights) {
416     compressed_index_size = compressed_indices_mapping.value().numel();
417     compressed_indices_mapping_data =
418         compressed_indices_mapping.value().data_ptr<int32_t>();
419 
420     // If compressed_indices_mapping is [0], it is an indicator that
421     // we should fall back to the non-sparse embedding lookup kernel.
422     if ((compressed_index_size == 1 &&
423          compressed_indices_mapping_data[0] == 0)) {
424       fallback_to_no_sparse = true;
425     }
426   }
427 
428   const auto weight_sizes = weight.sizes();
429   const int64_t D = weight_sizes[1] - 8; // NB: -8 to account for scale and bias
430   const int64_t M = offsets.sizes()[0];
431 
432   int64_t output_size = M - 1;
433   std::vector<OffsetType> offsets_include_last_val;
434 
435   if (!include_last_offset) {
436     output_size = M;
437     offsets_include_last_val.resize(M + 1);
438     // Avoid `null pointer passed as argument 2` ASAN violation when offsets
439     // tensor is empty.
440     if (M > 0) {
441       std::memcpy(
442           offsets_include_last_val.data(),
443           offsets_data,
444           sizeof(OffsetType) * M);
445     }
446     offsets_include_last_val[M] = indices.numel();
447     offsets_data = offsets_include_last_val.data();
448   }
449   {
450     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
451     std::array<int64_t, 3> shape_arr;
452     c10::IntArrayRef shape;
453     if (indices.dim() == 2 && is_embedding_op) {
454       const auto indices_sizes = indices.sizes();
455       shape_arr[0] = indices_sizes[0];
456       shape_arr[1] = indices_sizes[1];
457       shape_arr[2] = D;
458       shape = shape_arr;
459     } else {
460       shape_arr[0] = output_size;
461       shape_arr[1] = D;
462       shape = c10::IntArrayRef(&shape_arr[0], 2);
463     }
464     at::native::resize_(output, shape, std::nullopt);
465   }
466 #ifdef USE_FBGEMM
467   const int64_t N = weight_sizes[0];
468   const auto weight_data = weight.data_ptr<uint8_t>();
469   const auto indices_data = indices.data_ptr<IndexType>();
470   auto* output_data = output.data_ptr<float>();
471   const int index_size = indices.numel();
472 
473   if (!pruned_weights || fallback_to_no_sparse) {
474     auto kernel_i8 =
475         fbgemm::GenerateEmbeddingSpMDM<uint8_t, IndexType, OffsetType, /*OutType=*/float, /*THREAD_LOCAL=*/true>(
476             /*block_size=*/D,
477             /*has_weight=*/per_sample_weights_.has_value(),
478             /*normalize_by_lengths=*/false,
479             /*prefetch=*/16, // NOLINT(cppcoreguidelines-avoid-magic-numbers)
480             /*is_weight_positional=*/false,
481             /*use_offsets=*/true);
482 
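    // Split the bags across threads: each chunk handles output rows
    // [start_idx, end_idx) and shifts the indices/weights/output pointers by the
    // corresponding offsets, so every thread sees a self-contained sub-problem.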
483     at::parallel_for(
484         0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) {
485           bool success = kernel_i8(
486               /*output_size=*/end_idx - start_idx,
487               /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx],
488               /*data_size=*/N,
489               /*input=*/weight_data,
490               /*indices=*/indices_data + offsets_data[start_idx],
491               /*offsets_or_lengths=*/offsets_data + start_idx,
492               /*weights=*/
493               per_sample_weights_
494                   ? per_sample_weights_.value().const_data_ptr<float>() +
495                       offsets_data[start_idx]
496                   : nullptr,
497               /*out=*/output_data + start_idx * D);
498 
499           if (!success) {
500             fbgemm_spmdm_report_error_(
501                 end_idx - start_idx,
502                 offsets_data[end_idx] - offsets_data[start_idx],
503                 N,
504                 offsets_data + start_idx,
505                 indices_data + offsets_data[start_idx]);
506           }
507         });
508   } else {
509     // pruned weights
510     auto kernel_i8_sparse = fbgemm::
511         GenerateEmbeddingSpMDMRowWiseSparse<uint8_t, IndexType, OffsetType>(
512             /*block_size=*/D,
513             /*has_weight=*/per_sample_weights_.has_value(),
514             /*normalize_by_lengths=*/false,
515             /*prefetch=*/16, // NOLINT(cppcoreguidelines-avoid-magic-numbers)
516             /*is_weight_positional=*/false,
517             /*use_offsets=*/true);
518 
519     auto success = kernel_i8_sparse(
520         /*output_size=*/output_size,
521         /*index_size=*/index_size,
522         /*data_size=*/compressed_index_size,
523         /*input=*/weight_data,
524         /*indices=*/indices_data,
525         /*offsets=*/offsets_data,
526         /*weights=*/
527         per_sample_weights_.has_value()
528             ? per_sample_weights_.value().data_ptr<float>()
529             : nullptr,
530         /*output=*/output_data,
531         /*compressed_indices_table=*/compressed_indices_mapping_data);
532     if (!success) {
533       fbgemm_spmdm_report_error_(
534           output_size,
535           index_size,
536           compressed_index_size,
537           offsets_data,
538           indices_data);
539     }
540   }
541   return output;
542 #else
543   return embedding_lookup_fallback_impl<IndexType, OffsetType, 8, 1>(
544       weight,
545       indices,
546       offsets,
547       per_sample_weights_,
548       compressed_indices_mapping,
549       output,
550       D,
551       output_size,
552       include_last_offset,
553       (pruned_weights && !fallback_to_no_sparse));
554 #endif
555 }
556 
557 at::Tensor& embedding_bag_byte_helper(
558     at::Tensor& output,
559     const at::Tensor& weight,
560     const at::Tensor& indices,
561     const std::optional<at::Tensor>& offsets_in,
562     bool pruned_weights,
563     const std::optional<at::Tensor>& per_sample_weights_,
564     const std::optional<at::Tensor>& compressed_indices_mapping,
565     bool include_last_offset,
566     bool is_embedding_op) {
567   c10::MaybeOwned<at::Tensor> offsets;
568   TORCH_CHECK(
569       indices.dim() == 1 || indices.dim() == 2,
570       "qembedding/qembedding_bag operator supports 1 or 2d indices, got ",
571       indices.dim());
572   // For embedding_bag operator with 2D indices, we set the offsets explicitly
573   // here.
574   if (indices.dim() == 2 && !is_embedding_op) {
575     TORCH_CHECK(
576         !offsets_in.has_value(),
577         "embedding_bag_byte operator: input is 2D, then offsets has to be None, as input is treated is a mini-batch of fixed length sequences.");
578 
579     offsets = c10::MaybeOwned<at::Tensor>::owned(at::arange(0, indices.numel(), indices.sizes()[1], indices.scalar_type()));
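    // arange with step indices.sizes()[1] produces one offset per row, i.e. the
    // 2D input is treated as sizes()[0] bags of fixed length sizes()[1].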
580   } else {
581     TORCH_CHECK(
582         offsets_in.has_value(),
583         "embedding_bag_byte expects offsets to be set for 1D indices.");
584     offsets = c10::MaybeOwned<at::Tensor>::borrowed(offsets_in.value());
585   }
586 
587   TORCH_CHECK(
588       indices.scalar_type() == at::kInt || indices.scalar_type() == at::kLong,
589       "Expect 32 or 64 bit indices, but found ",
590       indices.scalar_type(),
591       " instead.");
592   TORCH_CHECK(
593       offsets->scalar_type() == at::kInt || offsets->scalar_type() == at::kLong,
594       "Expect 32 or 64 bit offsets, but found ",
595       offsets->scalar_type(),
596       " instead.");
597   TORCH_CHECK(
598       weight.is_contiguous() && indices.is_contiguous() &&
599           offsets->is_contiguous(),
600       "Expect weight, indices, and offsets to be contiguous.");
601 
602   // Use a helper function to support the different type combinations without
603   // casting, which would add performance overhead.
604   if (indices.scalar_type() == at::kInt && offsets->scalar_type() == at::kInt) {
605     return embedding_bag_byte_impl<int, int>(
606         output,
607         weight,
608         indices,
609         *offsets,
610         pruned_weights,
611         per_sample_weights_,
612         compressed_indices_mapping,
613         include_last_offset,
614         is_embedding_op);
615   } else if (
616       indices.scalar_type() == at::kInt && offsets->scalar_type() == at::kLong) {
617     return embedding_bag_byte_impl<int, int64_t>(
618         output,
619         weight,
620         indices,
621         *offsets,
622         pruned_weights,
623         per_sample_weights_,
624         compressed_indices_mapping,
625         include_last_offset,
626         is_embedding_op);
627   } else if (
628       indices.scalar_type() == at::kLong && offsets->scalar_type() == at::kInt) {
629     return embedding_bag_byte_impl<int64_t, int>(
630         output,
631         weight,
632         indices,
633         *offsets,
634         pruned_weights,
635         per_sample_weights_,
636         compressed_indices_mapping,
637         include_last_offset,
638         is_embedding_op);
639   }
640 
641   // default case given the TORCH_CHECK above
642   return embedding_bag_byte_impl<int64_t, int64_t>(
643       output,
644       weight,
645       indices,
646       *offsets,
647       pruned_weights,
648       per_sample_weights_,
649       compressed_indices_mapping,
650       include_last_offset,
651       is_embedding_op);
652 }
653 
654 at::Tensor& _embedding_bag_nbit_helper(
655     at::Tensor& output,
656     const at::Tensor& weight,
657     const int bit_width,
658     const at::Tensor& indices,
659     const std::optional<at::Tensor>& offsets_in,
660     bool pruned_weights,
661     const std::optional<at::Tensor>& per_sample_weights_,
662     const std::optional<at::Tensor>& compressed_indices_mapping,
663     bool include_last_offset,
664     bool is_embedding_op) {
665   c10::MaybeOwned<at::Tensor> offsets;
666   TORCH_CHECK(
667       bit_width == 4 || bit_width == 2,
668       "qembedding/qembedding_bag operator supports bit_width 2 or 4, got ",
669       bit_width);
670   TORCH_CHECK(
671       indices.dim() == 1 || indices.dim() == 2,
672       "qembedding/qembedding_bag operator supports 1 or 2d indices, got ",
673       indices.dim());
674 
675   // For embedding_bag operator with 2D indices, we need to set the offsets
676   // explicitly here.
677   if (indices.dim() == 2 && !is_embedding_op) {
678     TORCH_CHECK(
679         !offsets_in.has_value(),
680         "embedding_bag_4bit/embedding_bag_2bit operator: input is 2D, then offsets has to be None, as input is treated is a mini-batch of fixed length sequences.");
681 
682     offsets = c10::MaybeOwned<at::Tensor>::owned(at::arange(
683         0, indices.numel(), indices.sizes()[1], indices.scalar_type()));
684   } else {
685     TORCH_CHECK(
686         offsets_in.has_value(),
687         "embedding_bag_4bit/embedding_bag_2bit operator expects offsets to be set for 1D indices.");
688     offsets = c10::MaybeOwned<at::Tensor>::borrowed(offsets_in.value());
689   }
690 
691   TORCH_CHECK(
692       indices.scalar_type() == at::kInt || indices.scalar_type() == at::kLong,
693       "Expect 32 or 64 bit indices, but found ",
694       indices.scalar_type(),
695       " instead.");
696   TORCH_CHECK(
697       offsets->scalar_type() == at::kInt || offsets->scalar_type() == at::kLong,
698       "Expect 32 or 64 bit offsets, but found ",
699       offsets->scalar_type(),
700       " instead.");
701   TORCH_CHECK(
702       weight.is_contiguous() && indices.is_contiguous() &&
703           offsets->is_contiguous(),
704       "Expect weight, indices, and offsets to be contiguous.");
705 
706   // Use a helper function to support the different type combinations without
707   // casting, which would add performance overhead.
708   if (indices.scalar_type() == at::kInt && offsets->scalar_type() == at::kInt) {
709     return embedding_bag_nbit_impl<int, int>(
710         output,
711         weight,
712         bit_width,
713         indices,
714         *offsets,
715         pruned_weights,
716         per_sample_weights_,
717         compressed_indices_mapping,
718         include_last_offset,
719         is_embedding_op);
720   } else if (
721       indices.scalar_type() == at::kInt && offsets->scalar_type() == at::kLong) {
722     return embedding_bag_nbit_impl<int, int64_t>(
723         output,
724         weight,
725         bit_width,
726         indices,
727         *offsets,
728         pruned_weights,
729         per_sample_weights_,
730         compressed_indices_mapping,
731         include_last_offset,
732         is_embedding_op);
733   } else if (
734       indices.scalar_type() == at::kLong && offsets->scalar_type() == at::kInt) {
735     return embedding_bag_nbit_impl<int64_t, int>(
736         output,
737         weight,
738         bit_width,
739         indices,
740         *offsets,
741         pruned_weights,
742         per_sample_weights_,
743         compressed_indices_mapping,
744         include_last_offset,
745         is_embedding_op);
746   }
747   return embedding_bag_nbit_impl<int64_t, int64_t>(
748       output,
749       weight,
750       bit_width,
751       indices,
752       *offsets,
753       pruned_weights,
754       per_sample_weights_,
755       compressed_indices_mapping,
756       include_last_offset,
757       is_embedding_op);
758 }
759 } // namespace
760 
761 at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte(
762     const at::Tensor& indices,
763     const std::optional<at::Tensor>& offsets_in,
764     bool pruned_weights,
765     const std::optional<at::Tensor>& per_sample_weights_,
766     const std::optional<at::Tensor>& compressed_indices_mapping,
767     bool include_last_offset,
768     bool is_embedding_op) {
769   auto output = at::empty({0}, packed_w.options().dtype(at::kFloat));
770   return embedding_bag_byte_helper(
771       output,
772       packed_w,
773       indices,
774       offsets_in,
775       pruned_weights,
776       per_sample_weights_,
777       compressed_indices_mapping,
778       include_last_offset,
779       is_embedding_op);
780 }
781 
782 at::Tensor PackedEmbeddingBagWeight::embeddingbag_4bit(
783     const at::Tensor& indices,
784     const std::optional<at::Tensor>& offsets_in,
785     bool pruned_weights,
786     const std::optional<at::Tensor>& per_sample_weights_,
787     const std::optional<at::Tensor>& compressed_indices_mapping,
788     bool include_last_offset,
789     bool is_embedding_op) {
790   if (per_sample_weights_.has_value()) {
791     TORCH_CHECK(
792         (per_sample_weights_.value().scalar_type() == at::kFloat ||
793          per_sample_weights_.value().scalar_type() == at::kHalf),
794         "Expect fp32 or fp16 weights, but found",
795         per_sample_weights_.value().scalar_type(),
796         " instead")
797   }
798 
799   auto output = at::empty({0}, packed_w.options().dtype(at::kFloat));
800   return _embedding_bag_nbit_helper(
801     output,
802     packed_w,
803     4,
804     indices,
805     offsets_in,
806     pruned_weights,
807     per_sample_weights_.has_value()
808         ? per_sample_weights_.value().to(at::kFloat)
809         : per_sample_weights_,
810     compressed_indices_mapping,
811     include_last_offset,
812     is_embedding_op);
813 }
814 
815 namespace at {
816 namespace native {
817 
818 Tensor& embedding_bag_byte_rowwise_offsets_out(
819     Tensor& output,
820     const Tensor& weight,
821     const Tensor& indices,
822     const std::optional<Tensor>& offsets_in,
823     const bool /* scale_grad_by_freq */,
824     const int64_t /* mode */,
825     bool pruned_weights,
826     const std::optional<Tensor>& per_sample_weights_,
827     const std::optional<Tensor>& compressed_indices_mapping,
828     bool include_last_offset) {
829   return embedding_bag_byte_helper(
830       output,
831       weight,
832       indices,
833       offsets_in,
834       pruned_weights,
835       per_sample_weights_,
836       compressed_indices_mapping,
837       include_last_offset,
838       false /* is_embedding_op */);
839 }
840 
841 Tensor& embedding_bag_4bit_rowwise_offsets_out(
842     Tensor& output,
843     const Tensor& weight,
844     const Tensor& indices,
845     const std::optional<Tensor>& offsets_in,
846     const bool /* scale_grad_by_freq */,
847     const int64_t /* mode */,
848     bool pruned_weights,
849     const std::optional<Tensor>& per_sample_weights_,
850     const std::optional<Tensor>& compressed_indices_mapping,
851     bool include_last_offset) {
852 
853   if (per_sample_weights_.has_value()) {
854     TORCH_CHECK(
855         (per_sample_weights_.value().scalar_type() == at::kFloat ||
856          per_sample_weights_.value().scalar_type() == at::kHalf),
857         "Expect fp32 or fp16 weights, but found",
858         per_sample_weights_.value().scalar_type(),
859         " instead")
860   }
861   return _embedding_bag_nbit_helper(
862       output,
863       weight,
864       4,
865       indices,
866       offsets_in,
867       pruned_weights,
868       per_sample_weights_.has_value()
869           ? per_sample_weights_.value().to(at::kFloat)
870           : per_sample_weights_,
871       compressed_indices_mapping,
872       include_last_offset,
873       false);
874 }
875 
876 static Tensor& embedding_bag_2bit_rowwise_offsets_out(
877     Tensor& output,
878     const Tensor& weight,
879     const Tensor& indices,
880     const std::optional<Tensor>& offsets_in,
881     const bool /* scale_grad_by_freq */,
882     const int64_t /* mode */,
883     bool pruned_weights,
884     const std::optional<Tensor>& per_sample_weights_,
885     const std::optional<Tensor>& compressed_indices_mapping,
886     bool include_last_offset) {
887 
888   if (per_sample_weights_.has_value()) {
889     TORCH_CHECK(
890         (per_sample_weights_.value().scalar_type() == at::kFloat ||
891          per_sample_weights_.value().scalar_type() == at::kHalf),
892         "Expect fp32 or fp16 weights, but found",
893         per_sample_weights_.value().scalar_type(),
894         " instead")
895   }
896   return _embedding_bag_nbit_helper(
897       output,
898       weight,
899       2,
900       indices,
901       offsets_in,
902       pruned_weights,
903       per_sample_weights_.has_value()
904           ? per_sample_weights_.value().to(at::kFloat)
905           : per_sample_weights_,
906       compressed_indices_mapping,
907       include_last_offset,
908       false);
909 }
910 
911 namespace {
912 
913 
914 inline at::Tensor create_empty_from(
915     const at::Tensor& t,
916     c10::ScalarType dtype) {
917   return at::detail::empty_cpu(
918       {0}, dtype, t.layout(), t.device(), std::nullopt, std::nullopt);
919 }
920 
921 Tensor embedding_bag_byte_rowwise_offsets(
922     const Tensor& weight,
923     const Tensor& indices,
924     const std::optional<Tensor>& offsets_in,
925     const bool /* scale_grad_by_freq */,
926     const int64_t /* mode */,
927     bool pruned_weights,
928     const std::optional<Tensor>& per_sample_weights_,
929     const std::optional<Tensor>& compressed_indices_mapping,
930     bool include_last_offset) {
931   auto output = create_empty_from(weight, at::kFloat);
932   embedding_bag_byte_rowwise_offsets_out(
933       output,
934       weight,
935       indices,
936       offsets_in,
937       false /*unused scale_grad_by_freq*/,
938       0 /*unused mode*/,
939       pruned_weights,
940       per_sample_weights_,
941       compressed_indices_mapping,
942       include_last_offset);
943   return output;
944 }
945 
946 Tensor embedding_bag_4bit_rowwise_offsets(
947     const Tensor& weight,
948     const Tensor& indices,
949     const std::optional<Tensor>& offsets_in,
950     const bool /* scale_grad_by_freq */,
951     const int64_t /* mode */,
952     bool pruned_weights,
953     const std::optional<Tensor>& per_sample_weights_,
954     const std::optional<Tensor>& compressed_indices_mapping,
955     bool include_last_offset) {
956   auto output = create_empty_from(weight, at::kFloat);
957   embedding_bag_4bit_rowwise_offsets_out(
958     output,
959     weight,
960     indices,
961     offsets_in,
962     false, // unused scale_grad_by_freq
963     0, // unused mode
964     pruned_weights,
965     per_sample_weights_,
966     compressed_indices_mapping,
967     include_last_offset);
968   return output;
969 }
970 
971 Tensor embedding_bag_2bit_rowwise_offsets(
972     const Tensor& weight,
973     const Tensor& indices,
974     const std::optional<Tensor>& offsets_in,
975     const bool /* scale_grad_by_freq */,
976     const int64_t /* mode */,
977     bool pruned_weights,
978     const std::optional<Tensor>& per_sample_weights_,
979     const std::optional<Tensor>& compressed_indices_mapping,
980     bool include_last_offset) {
981   auto output = create_empty_from(weight, at::kFloat);
982   embedding_bag_2bit_rowwise_offsets_out(
983     output,
984     weight,
985     indices,
986     offsets_in,
987     false, // unused scale_grad_by_freq
988     0, // unused mode
989     pruned_weights,
990     per_sample_weights_,
991     compressed_indices_mapping,
992     include_last_offset);
993   return output;
994 }
995 
996 Tensor embedding_bag_byte_rowwise_offsets_meta(
997     const Tensor& weight,
998     const Tensor& indices,
999     const std::optional<Tensor>& offsets_in,
1000     const bool /* scale_grad_by_freq */,
1001     const int64_t /* mode */,
1002     bool /* pruned_weights */,
1003     const std::optional<Tensor>& /* per_sample_weights_ */,
1004     const std::optional<Tensor>& /* compressed_indices_mapping */,
1005     bool include_last_offset) {
1006   TORCH_CHECK(
1007       indices.dim() == 1 || indices.dim() == 2,
1008       "quantized::embedding_bag_byte_rowwise_offsets_meta operator supports 1 or 2d indices, got ",
1009       indices.dim());
1010 
1011   TORCH_CHECK(
1012       offsets_in.has_value(),
1013       "Currently quantized::embedding_bag_byte_rowwise_offsets_meta only supports having offsets.");
1014   c10::MaybeOwned<at::Tensor> offsets =
1015       c10::MaybeOwned<at::Tensor>::borrowed(offsets_in.value());
1016 
1017   TORCH_CHECK(
1018       indices.scalar_type() == at::kInt || indices.scalar_type() == at::kLong,
1019       "Expect 32 or 64 bit indices, but found ",
1020       indices.scalar_type(),
1021       " instead.");
1022   TORCH_CHECK(
1023       offsets->scalar_type() == at::kInt || offsets->scalar_type() == at::kLong,
1024       "Expect 32 or 64 bit offsets, but found ",
1025       offsets->scalar_type(),
1026       " instead.");
1027 
1028   const auto D = weight.sym_size(1) - 8; // NB: -8 to account for scale and bias
1029   const auto M = offsets->sym_size(0);
1030   const auto output_size = include_last_offset ? M - 1 : M;
1031 
1032   return at::empty_symint({output_size, D}, weight.options().dtype(at::kFloat));
1033 }
1034 
1035 template <int bit_rate>
1036 class QEmbeddingBag final {
1037  public:
1038   static at::Tensor run(
1039       const c10::intrusive_ptr<EmbeddingPackedParamsBase>& packed_weight,
1040       const Tensor& indices,
1041       const std::optional<Tensor>& offsets,
1042       const bool /* scale_grad_by_freq */,
1043       const int64_t /* mode */,
1044       bool pruned_weights,
1045       const std::optional<Tensor>& per_sample_weights_,
1046       const std::optional<Tensor>& compressed_indices_mapping,
1047       bool include_last_offset) {
1048     if (bit_rate == 8) {
1049       return packed_weight->embeddingbag_byte(
1050           indices,
1051           offsets,
1052           pruned_weights,
1053           per_sample_weights_,
1054           compressed_indices_mapping,
1055           include_last_offset,
1056           false /* is_embedding_op */);
1057     } else if (bit_rate == 4) {
1058       return packed_weight->embeddingbag_4bit(
1059           indices,
1060           offsets,
1061           pruned_weights,
1062           per_sample_weights_,
1063           compressed_indices_mapping,
1064           include_last_offset,
1065           false);
1066     } else {
1067       TORCH_INTERNAL_ASSERT(
1068           false,
1069           "Currently only support 8-bit embedding_bag quantization");
1070     }
1071   }
1072 };
1073 
1074 template <int bit_rate>
1075 class QEmbedding final {
1076  public:
1077   static at::Tensor run(
1078       const c10::intrusive_ptr<EmbeddingPackedParamsBase>& packed_weight,
1079       const Tensor& indices,
1080       bool pruned_weights) {
1081     // Set default offsets here since the FBGEMM lookup op expects them.
1082     const auto offsets_size = indices.numel();
1083     at::Tensor offsets = at::arange(0, offsets_size, indices.scalar_type());
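    // One offset per index means every bag has length one, so the bag kernel
    // reduces to a plain per-index embedding lookup.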
1084     at::Tensor output;
1085     if (bit_rate == 8) {
1086       return packed_weight->embeddingbag_byte(
1087           indices,
1088           offsets,
1089           pruned_weights,
1090           std::nullopt,
1091           std::nullopt,
1092           false /* include_last_offset */,
1093           true /* is_embedding_op */);
1094     } else if (bit_rate == 4) {
1095       return packed_weight->embeddingbag_4bit(
1096           indices,
1097           offsets,
1098           pruned_weights,
1099           std::nullopt,
1100           std::nullopt,
1101           false,
1102           true);
1103     } else {
1104       TORCH_INTERNAL_ASSERT(
1105           false,
1106           "Currently only support 8-bit embedding quantization");
1107     }
1108     return output;
1109   }
1110 };
1111 
1112 TORCH_LIBRARY_IMPL(quantized, CPU, m) {
1113   // Function that works on TorchBind packed weights.
1114   m.impl(
1115       TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte"),
1116       TORCH_FN(QEmbeddingBag<8>::run));
1117   m.impl(
1118       TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit"),
1119       TORCH_FN(QEmbeddingBag<4>::run));
1120   m.impl(
1121       TORCH_SELECTIVE_NAME("quantized::embedding_byte"),
1122       TORCH_FN(QEmbedding<8>::run));
1123   m.impl(
1124       TORCH_SELECTIVE_NAME("quantized::embedding_4bit"),
1125       TORCH_FN(QEmbedding<4>::run));
1126 
1127   // Functions that work on at::Tensor packed weight.
1128   m.impl(
1129       TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_rowwise_offsets"),
1130       embedding_bag_byte_rowwise_offsets);
1131   m.impl(
1132       TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_rowwise_offsets"),
1133       embedding_bag_4bit_rowwise_offsets);
1134   m.impl(
1135       TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_rowwise_offsets"),
1136       embedding_bag_2bit_rowwise_offsets);
1137 }
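// A minimal usage sketch of the rowwise-offsets path registered above, assuming
// a weight packed by quantized::embedding_bag_byte_prepack (the Python-side
// tensor names are illustrative only):
//
//   packed = torch.ops.quantized.embedding_bag_byte_prepack(weight_fp32)
//   out = torch.ops.quantized.embedding_bag_byte_rowwise_offsets(
//       packed, indices, offsets, False, 0, False, None, None, False)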
1138 
1139 TORCH_LIBRARY_IMPL(quantized, Meta, m) {
1140   m.impl(
1141       "quantized::embedding_bag_byte_rowwise_offsets",
1142       embedding_bag_byte_rowwise_offsets_meta);
1143 }
1144 
1145 } // namespace
1146 } // namespace native
1147 } // namespace at
1148