// Source: /aosp_15_r20/external/executorch/kernels/quantized/cpu/embeddingxb.h
// (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/runtime/kernel/kernel_includes.h>
#include <algorithm>
#include <cinttypes>
#include <cmath>

14 namespace torch {
15 namespace executor {
16 namespace native {
17 
18 using Tensor = exec_aten::Tensor;
19 using Scalar = exec_aten::Scalar;
20 using ScalarType = exec_aten::ScalarType;
21 
22 Tensor& quantized_embedding_xbit_out(
23     // TODO Evaluate whether this name is appropriate for an operator that takes
24     // non quant input and returns fp output
25     const Tensor& weight,
26     const Tensor& weight_scales,
27     const exec_aten::optional<Tensor>& opt_weight_zero_points,
28     const int64_t weight_quant_min,
29     const int64_t weight_quant_max,
30     const Tensor& indices,
31     Tensor& out,
32     int weight_nbit);
33 
34 Tensor& quantized_embedding_xbit_out(
35     KernelRuntimeContext& context,
36     const Tensor& weight,
37     const Tensor& weight_scales,
38     const exec_aten::optional<Tensor>& opt_weight_zero_points,
39     int64_t weight_quant_min,
40     int64_t weight_quant_max,
41     const Tensor& indices,
42     Tensor& out,
43     int weight_nbit);
44 
45 Tensor& quantized_embedding_xbit_dtype_out(
46     // TODO Evaluate whether this name is appropriate for an operator that takes
47     // non quant input and returns fp output
48     const Tensor& weight,
49     const Tensor& weight_scales,
50     const exec_aten::optional<Tensor>& opt_weight_zero_points,
51     const int64_t weight_quant_min,
52     const int64_t weight_quant_max,
53     const Tensor& indices,
54     exec_aten::optional<ScalarType> out_dtype,
55     Tensor& out,
56     int weight_nbit);
57 
58 Tensor& quantized_embedding_xbit_dtype_out(
59     KernelRuntimeContext& context,
60     const Tensor& weight,
61     const Tensor& weight_scales,
62     const exec_aten::optional<Tensor>& opt_weight_zero_points,
63     int64_t weight_quant_min,
64     int64_t weight_quant_max,
65     const Tensor& indices,
66     exec_aten::optional<ScalarType> out_dtype,
67     Tensor& out,
68     int weight_nbit);
69 
70 } // namespace native
71 } // namespace executor
72 } // namespace torch
73