1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #include <executorch/runtime/kernel/kernel_includes.h> 10 #include <algorithm> 11 #include <cinttypes> 12 #include <cmath> 13 14 namespace torch { 15 namespace executor { 16 namespace native { 17 18 using Tensor = exec_aten::Tensor; 19 using Scalar = exec_aten::Scalar; 20 using ScalarType = exec_aten::ScalarType; 21 22 Tensor& quantized_embedding_xbit_out( 23 // TODO Evaluate whether this name is appropriate for an operator that takes 24 // non quant input and returns fp output 25 const Tensor& weight, 26 const Tensor& weight_scales, 27 const exec_aten::optional<Tensor>& opt_weight_zero_points, 28 const int64_t weight_quant_min, 29 const int64_t weight_quant_max, 30 const Tensor& indices, 31 Tensor& out, 32 int weight_nbit); 33 34 Tensor& quantized_embedding_xbit_out( 35 KernelRuntimeContext& context, 36 const Tensor& weight, 37 const Tensor& weight_scales, 38 const exec_aten::optional<Tensor>& opt_weight_zero_points, 39 int64_t weight_quant_min, 40 int64_t weight_quant_max, 41 const Tensor& indices, 42 Tensor& out, 43 int weight_nbit); 44 45 Tensor& quantized_embedding_xbit_dtype_out( 46 // TODO Evaluate whether this name is appropriate for an operator that takes 47 // non quant input and returns fp output 48 const Tensor& weight, 49 const Tensor& weight_scales, 50 const exec_aten::optional<Tensor>& opt_weight_zero_points, 51 const int64_t weight_quant_min, 52 const int64_t weight_quant_max, 53 const Tensor& indices, 54 exec_aten::optional<ScalarType> out_dtype, 55 Tensor& out, 56 int weight_nbit); 57 58 Tensor& quantized_embedding_xbit_dtype_out( 59 KernelRuntimeContext& context, 60 const Tensor& weight, 61 const Tensor& weight_scales, 62 const exec_aten::optional<Tensor>& opt_weight_zero_points, 63 int64_t weight_quant_min, 64 int64_t weight_quant_max, 65 const Tensor& indices, 66 exec_aten::optional<ScalarType> out_dtype, 67 Tensor& out, 68 int weight_nbit); 69 70 } // namespace native 71 } // namespace executor 72 } // namespace torch 73