1 #pragma once 2 3 #include <ATen/core/Tensor.h> 4 #include <ATen/core/ivalue.h> 5 6 struct EmbeddingPackedParamsBase : public torch::jit::CustomClassHolder { 7 virtual at::Tensor embeddingbag_byte( 8 const at::Tensor& indices, 9 const std::optional<at::Tensor>& offsets, 10 bool pruned_weights, 11 const std::optional<at::Tensor>& per_sample_weights_, 12 const std::optional<at::Tensor>& compressed_indices_mapping, 13 bool include_last_offset, 14 bool is_embedding_op) = 0; 15 16 virtual at::Tensor embeddingbag_4bit( 17 const at::Tensor& indices, 18 const std::optional<at::Tensor>& offsets, 19 bool pruned_weights, 20 const std::optional<at::Tensor>& per_sample_weights_, 21 const std::optional<at::Tensor>& compressed_indices_mapping, 22 bool include_last_offset, 23 bool is_embedding_op) = 0; 24 25 virtual at::Tensor unpack() = 0; 26 27 virtual int64_t bit_rate() const = 0; 28 virtual int64_t version() const = 0; 29 }; 30