xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/EmbeddingPackedParams.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/Tensor.h>
4 #include <ATen/core/ivalue.h>
5 
// Abstract interface for packed embedding parameters.
//
// Every member function declared below is pure virtual, so this struct cannot
// be instantiated on its own; a concrete subclass has to implement all five.
// The struct derives from torch::jit::CustomClassHolder.
// NOTE(review): presumably that base is what lets instances be held as
// TorchScript custom-class objects - confirm against the registration site.
6 struct EmbeddingPackedParamsBase : public torch::jit::CustomClassHolder {
    // Returns an at::Tensor computed from `indices` and the optional inputs.
    // `offsets`, `per_sample_weights_` and `compressed_indices_mapping` are
    // std::optional and may therefore be absent; all tensor arguments are
    // taken by const reference.
    // NOTE(review): judging by the name, this is presumably the embedding-bag
    // lookup for byte (8-bit) packed weights - confirm in the implementation.
    // NOTE(review): the meaning of the flags `pruned_weights`,
    // `include_last_offset` and `is_embedding_op` is not visible in this
    // header; verify against the concrete subclass before relying on them.
7   virtual at::Tensor embeddingbag_byte(
8     const at::Tensor& indices,
9     const std::optional<at::Tensor>& offsets,
10     bool pruned_weights,
11     const std::optional<at::Tensor>& per_sample_weights_,
12     const std::optional<at::Tensor>& compressed_indices_mapping,
13     bool include_last_offset,
14     bool is_embedding_op) = 0;
15 
    // Same parameter list and return type as embeddingbag_byte.
    // NOTE(review): presumably the 4-bit packed-weight counterpart of
    // embeddingbag_byte - confirm in the implementation.
16   virtual at::Tensor embeddingbag_4bit(
17     const at::Tensor& indices,
18     const std::optional<at::Tensor>& offsets,
19     bool pruned_weights,
20     const std::optional<at::Tensor>& per_sample_weights_,
21     const std::optional<at::Tensor>& compressed_indices_mapping,
22     bool include_last_offset,
23     bool is_embedding_op) = 0;
24 
    // Takes no arguments and returns an at::Tensor. Not declared const.
    // NOTE(review): presumably returns the weights in unpacked form - confirm.
25   virtual at::Tensor unpack() = 0;
26 
    // Const accessors, each returning an int64_t.
    // NOTE(review): presumably the bit width of the packed weights and a
    // format/serialization version number respectively - confirm the valid
    // values against the concrete subclass.
27   virtual int64_t bit_rate() const = 0;
28   virtual int64_t version() const = 0;
29 };
30