/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/quantized/cpu/embeddingxb.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <algorithm>
#include <cassert>
#include <cinttypes>
#include <cmath>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using Scalar = exec_aten::Scalar;
using ScalarType = exec_aten::ScalarType;

namespace {

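// Decodes the signed sub-byte weight at logical position `index` from the
// packed uint8 buffer `w_data`. For 2-bit weights, four values are packed per
// byte starting at the least-significant bits and each value is offset by -2;
// for 4-bit weights, two values are packed per byte with the even-indexed
// value in the high nibble and each value is offset by -8.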
static inline int32_t
weight_value(const unsigned char* w_data, int32_t index, int32_t weight_nbit) {
  if (weight_nbit == 2) {
    int32_t subbyte = index % 4;
    index >>= 2;
    switch (subbyte) {
      case 0:
        return (int32_t)(w_data[index] & 3) - 2;
      case 1:
        return (int32_t)((w_data[index] & 12) >> 2) - 2;
      case 2:
        return (int32_t)((w_data[index] & 48) >> 4) - 2;
      case 3:
        return (int32_t)((w_data[index] & 192) >> 6) - 2;
    }
  } else if (weight_nbit == 4) {
    int32_t odd = index & 1;
    index >>= 1;
    if (odd) {
      return (int32_t)(w_data[index] & 0x0F) - 8;
    } else {
      return (int32_t)((w_data[index] >> 4) & 0x0F) - 8;
    }
  }

  ET_CHECK_MSG(false, "invalid weight_nbit");
}

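// Converts the packed row width (number of uint8 bytes per embedding row) into
// the number of unpacked values per row, i.e. the logical embedding dim.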
static inline int32_t get_embedding_dim(
    int32_t packed_dim,
    int32_t weight_nbit) {
  ET_CHECK_MSG(8 % weight_nbit == 0, "invalid embedding dim");
  int packed_values_per_byte = 8 / weight_nbit;
  return packed_dim * packed_values_per_byte;
}

/**
 * Asserts that the parameters are valid.
 */
void check_embedding_xbit_args(
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    const int64_t weight_quant_min,
    const int64_t weight_quant_max,
    const Tensor& indices,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out,
    int weight_nbit) {
  ET_CHECK_MSG(8 % weight_nbit == 0, "nbit must divide 8");

  ET_CHECK_MSG(
      weight.dim() == 2, "weight must be 2D but got %zd dims", weight.dim());

  ET_CHECK_MSG(
      weight_scales.dim() == 1 || weight_scales.dim() == 2,
      "weight_scales must be 1D or 2D but got %zd dims",
      weight_scales.dim());

  ET_CHECK_MSG(
      weight_scales.size(0) == weight.size(0),
      "Number of scales must be == weight.size(0)=%zd"
      ", but got %zd",
      weight.size(0),
      weight_scales.size(0));

  if (weight_scales.dim() == 2) {
    auto num_groups = weight_scales.size(1);
    ET_CHECK_MSG(
        // each packed uint8 column of weight holds packed_values_per_byte
        // unpacked columns, so check against the unpacked embedding dim
        get_embedding_dim(weight.size(1), weight_nbit) % num_groups == 0,
        "Number of groups must divide the unpacked embedding dim; "
        "weight.size(1)=%zd, # of groups = %zd",
        weight.size(1),
        num_groups);
  }

  ET_CHECK_MSG(
      weight.scalar_type() == ScalarType::Byte,
      "weight.scalar_type() %" PRId8 " is not supported",
      static_cast<int8_t>(weight.scalar_type()));

  ET_CHECK_MSG(
      out.scalar_type() == ScalarType::Float ||
          out.scalar_type() == ScalarType::Half,
      "out.scalar_type() %" PRId8 " is not supported",
      static_cast<int8_t>(out.scalar_type()));

  ET_CHECK_MSG(
      weight_scales.scalar_type() == ScalarType::Float ||
          weight_scales.scalar_type() == ScalarType::Half,
      "weight_scales.scalar_type() %" PRId8 " is not supported",
      static_cast<int8_t>(weight_scales.scalar_type()));

  if (opt_weight_zero_points.has_value()) {
    ET_CHECK_MSG(
        opt_weight_zero_points.value().dim() == weight_scales.dim(),
        "weight_zero_points's rank must match that of weight_scales. "
        "weight_zero_points rank: %" PRId8 ", weight_scales rank: %" PRId8,
        static_cast<int8_t>(opt_weight_zero_points.value().dim()),
        static_cast<int8_t>(weight_scales.dim()));

    ET_CHECK_MSG(
        opt_weight_zero_points.value().scalar_type() == out.scalar_type(),
        "weight zero points scalar type %" PRId8
        " does not match out.scalar_type()",
        static_cast<int8_t>(opt_weight_zero_points.value().scalar_type()));

    for (int32_t i = 0; i < weight_scales.dim(); ++i) {
      ET_CHECK_MSG(
          opt_weight_zero_points.value().size(i) == weight_scales.size(i),
          "Dimension size mismatch at dim %" PRIi32
          ": weight_zero_points size = %zd"
          ", weight_scales size = %zd.",
          i,
          opt_weight_zero_points.value().size(i),
          weight_scales.size(i));
    }
  }

  ET_CHECK_MSG(
      indices.scalar_type() == ScalarType::Long,
      "indices.scalar_type() %" PRId8 " is not supported; only Long is supported",
      static_cast<int8_t>(indices.scalar_type()));

  ET_CHECK_MSG(
      weight_quant_min <= weight_quant_max,
      "weight quant min: %" PRId64
      " is greater than weight quant max: %" PRId64,
      weight_quant_min,
      weight_quant_max);

  if (out_dtype.has_value()) {
    ET_CHECK_MSG(
        out.scalar_type() == out_dtype.value(),
        "out_dtype must match the dtype of the out tensor");
  }
}

/**
 * Retrieves the embeddings specified by indices, dequantizes them, and stores
 * them in out. The weight tensor is always packed into uint8 storage.
 */
template <typename CTYPE_PARAMS, typename CTYPE_OUT>
void embedding_xbit_per_channel(
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    const Tensor& indices,
    Tensor& out,
    int weight_nbit) {
  auto embedding_dim = get_embedding_dim(weight.size(1), weight_nbit);

  int32_t num_groups_per_channel = 1;
  if (weight_scales.dim() == 2) {
    num_groups_per_channel = weight_scales.size(1);
  }
  int32_t group_size = embedding_dim / num_groups_per_channel;

  CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
  const int64_t* indices_ptr = indices.const_data_ptr<int64_t>();

  const CTYPE_PARAMS* scales = weight_scales.const_data_ptr<CTYPE_PARAMS>();
  const CTYPE_PARAMS* zero_points = nullptr;
  if (opt_weight_zero_points.has_value()) {
    zero_points = opt_weight_zero_points.value().const_data_ptr<CTYPE_PARAMS>();
  }

  for (int i = 0; i < indices.numel(); i++) {
    int64_t index = indices_ptr[i];
    // If using groupwise embedding
    int32_t qparams_index = index * num_groups_per_channel;
    CTYPE_PARAMS zp = 0.0;
    const CTYPE_PARAMS* scale_ptr = scales + qparams_index;
    const CTYPE_PARAMS* zero_points_ptr = nullptr;
    if (opt_weight_zero_points.has_value()) {
      zero_points_ptr = zero_points + qparams_index;
    }

    const uint8_t* w_data =
        weight.const_data_ptr<uint8_t>() + weight.size(1) * index;

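    // Dequantize each packed value in this row:
    //   out[j] = (q[j] - zero_point[g]) * scale[g], with g = j / group_size.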
    for (int j = 0; j < embedding_dim; ++j) {
      int32_t group_id = j / group_size;
      const CTYPE_PARAMS scale = scale_ptr[group_id];
      if (opt_weight_zero_points.has_value()) {
        zp = zero_points_ptr[group_id];
      }
      out_data[j] = static_cast<CTYPE_OUT>(
          (static_cast<float>(weight_value(w_data, j, weight_nbit)) -
           static_cast<float>(zp)) *
          static_cast<float>(scale));
    }
    out_data += embedding_dim;
  }
}

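// Resizes out so that its shape matches indices.sizes() with a trailing
// dimension of the unpacked embedding dim (out is expected to have one more
// dimension than indices).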
void resize_out_tensor(
    const Tensor& weight,
    const Tensor& indices,
    Tensor& out,
    int weight_nbit) {
  exec_aten::SizesType expected_output_size[kTensorDimensionLimit];
  for (size_t i = 0; i < indices.dim(); i++) {
    expected_output_size[i] = indices.size(i);
  }
  const size_t embedding_dim = get_embedding_dim(weight.size(1), weight_nbit);
  expected_output_size[out.dim() - 1] = embedding_dim;

  exec_aten::ArrayRef<exec_aten::SizesType> output_size{
      expected_output_size, static_cast<size_t>(out.dim())};

  torch::executor::Error err = resize_tensor(out, output_size);
  ET_CHECK_MSG(
      err == torch::executor::Error::Ok,
      "Failed to resize out Tensor in quantized_embedding_xbit_out");
}

} // namespace

/**
 * Retrieves the embeddings specified by indices, dequantizes them, and stores
 * them in out. The weight is quantized per channel, with a scale and zero_point
 * for each embedding.
 *
 * This is the out variant of torch.ops.quantized.embedding_xbit.
 *
 * NOTE: quant_min, quant_max, and Dtype are not used in computation, but rather
 * metadata that is passed around which can be useful for pattern matching. See
 * https://github.com/pytorch/pytorch/pull/87093#discussion_r1000841181 for more
 * info.
 */
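//
// Expected shapes, as implied by the checks and kernels above (not an
// exhaustive spec):
//   weight:        (num_embeddings, packed_dim), uint8, where
//                  packed_dim * (8 / weight_nbit) == embedding_dim
//   weight_scales: (num_embeddings,) or (num_embeddings, num_groups), Float/Half
//   indices:       Long tensor of arbitrary shape
//   out:           indices.sizes() + [embedding_dim], Float/Half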
Tensor& quantized_embedding_xbit_out(
    // TODO Evaluate whether this name is appropriate for an operator that takes
    // non quant input and returns fp output
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    const int64_t weight_quant_min,
    const int64_t weight_quant_max,
    const Tensor& indices,
    Tensor& out,
    int weight_nbit) {
  ScalarType out_type = out.scalar_type();

  // TODO (jakeszwe): improve these to account for the size of out in relation
  // to weight and indices accounting for a possible batch dimension
  check_embedding_xbit_args(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out_type,
      out,
      weight_nbit);

  constexpr auto name = "quantized_decomposed::embedding_xbit.out";
  ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
    embedding_xbit_per_channel<CTYPE_OUT, CTYPE_OUT>(
        weight,
        weight_scales,
        opt_weight_zero_points,
        indices,
        out,
        weight_nbit);
  });

  return out;
}

Tensor& quantized_embedding_xbit_out(
    KernelRuntimeContext& context,
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    int64_t weight_quant_min,
    int64_t weight_quant_max,
    const Tensor& indices,
    Tensor& out,
    int weight_nbit) {
  // TODO(larryliu): Add a context arg to the real op function and remove this
  // wrapper
  (void)context;
  resize_out_tensor(weight, indices, out, weight_nbit);
  return quantized_embedding_xbit_out(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out,
      weight_nbit);
}

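// Same as quantized_embedding_xbit_out, but additionally takes the expected
// output dtype and checks that it matches out.scalar_type(); the scale /
// zero-point dtype may differ from the output dtype.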
Tensor& quantized_embedding_xbit_dtype_out(
    // TODO Evaluate whether this name is appropriate for an operator that takes
    // non quant input and returns fp output
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    const int64_t weight_quant_min,
    const int64_t weight_quant_max,
    const Tensor& indices,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out,
    int weight_nbit) {
  // TODO (jakeszwe): improve these to account for the size of out in relation
  // to weight and indices accounting for a possible batch dimension
  check_embedding_xbit_args(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out_dtype,
      out,
      weight_nbit);

  ScalarType params_type = weight_scales.scalar_type();
  ScalarType out_type = out.scalar_type();

  constexpr auto name = "quantized_decomposed::embedding_xbit.dtype_out";
  ET_SWITCH_TWO_TYPES(Float, Half, params_type, ctx, name, CTYPE_P, [&]() {
    ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
      embedding_xbit_per_channel<CTYPE_P, CTYPE_OUT>(
          weight,
          weight_scales,
          opt_weight_zero_points,
          indices,
          out,
          weight_nbit);
    });
  });

  return out;
}

Tensor& quantized_embedding_xbit_dtype_out(
    KernelRuntimeContext& context,
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    int64_t weight_quant_min,
    int64_t weight_quant_max,
    const Tensor& indices,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out,
    int weight_nbit) {
  // TODO(larryliu): Add a context arg to the real op function and remove this
  // wrapper
  (void)context;
  resize_out_tensor(weight, indices, out, weight_nbit);
  return quantized_embedding_xbit_dtype_out(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out_dtype,
      out,
      weight_nbit);
}

} // namespace native
} // namespace executor
} // namespace torch